Compare commits

...

255 Commits

Author SHA1 Message Date
Jonathan
7a6760acad Update presets.py (#8846) 2026-02-06 08:52:24 -05:00
Jonathan
91c1e64f0b Add DyPE area option (#8844)
* Add DyPE area option

* Added tests and fixed frontend build

* Made more pythonic
2026-02-06 08:50:24 -05:00
Lincoln Stein
cbe528eef7 chore(release): prep for 6.11.1 bugfix release 2026-02-06 08:10:22 -05:00
Alexander Eichhorn
4081f8701e fix(flux2): Fix FLUX.2 Klein image generation quality (#8838)
* fix(flux2): Fix image quality degradation at resolutions > 1024x1024

This commit addresses severe quality degradation and artifacts when
generating images larger than 1024x1024 with FLUX.2 Klein models.

Root causes fixed:

1. Dynamic max_image_seq_len in scheduler (flux2_denoise.py)
   - Previously hardcoded to 4096 (1024x1024 only)
   - Now dynamically calculated based on actual resolution
   - Allows proper schedule shifting at all resolutions

2. Smoothed mu calculation discontinuity (sampling_utils.py)
   - Eliminated 40-50% mu value drop at seq_len 4300 threshold
   - Implemented smooth cosine interpolation (4096-4500 transition zone)
   - Gradual blend between low-res and high-res formulas

Impact:
- FLUX.2 Klein 9B: Major quality improvement at high resolutions
- FLUX.2 Klein 4B: Improved quality at high resolutions
- Baseline 1024x1024: Unchanged (no regression)
- All generation modes: T2I and Kontext (reference images)

Fixes: Community-reported quality degradation issue
See: Discord discussions in #garbage-bin and #devchat

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* fix(flux2): Fix high-resolution quality degradation for FLUX.2 Klein

  Fixes grid/diamond artifacts and color loss at resolutions > 1024x1024.

  Root causes identified and fixed:
  - BN normalization was incorrectly applied to random noise input
    (diffusers only normalizes image latents from VAE.encode)
  - BN denormalization must be applied to output before VAE decode
  - mu parameter was resolution-dependent causing over-shifted schedules
    at high resolutions (now fixed to 2.02, matching ComfyUI)

  Changes:
  - Remove BN normalization on noise input (not needed for N(0,1) noise)
  - Preserve BN denormalization on denoised output (required for VAE)
  - Fix mu to constant 2.02 for all resolutions (matches ComfyUI)

  Tested at 2048x2048 with FLUX.2 Klein 4B

* Chore Ruff

---------

Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-authored-by: Jonathan <34005131+JPPhoto@users.noreply.github.com>
2026-02-06 08:04:03 -05:00
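
The smooth cosine blend over the 4096-4500 seq_len transition zone described in the commit above could look roughly like the sketch below. This is only an illustrative reading of the commit message: the helper name `smoothed_mu` is hypothetical, and the low- and high-resolution mu formulas are abstracted as the `mu_low` / `mu_high` inputs rather than reproduced from `sampling_utils.py`.

```python
import math

def smoothed_mu(seq_len: int, mu_low: float, mu_high: float,
                lo: int = 4096, hi: int = 4500) -> float:
    """Blend two mu estimates with a cosine ramp across the transition zone.

    mu_low / mu_high are the values produced by the (hypothetical)
    low-resolution and high-resolution formulas for this seq_len.
    """
    if seq_len <= lo:
        return mu_low
    if seq_len >= hi:
        return mu_high
    # t runs 0 -> 1 across the transition zone; the cosine term starts and
    # ends the blend with zero slope, avoiding the abrupt 40-50% mu drop.
    t = (seq_len - lo) / (hi - lo)
    blend = 0.5 * (1.0 - math.cos(math.pi * t))
    return (1.0 - blend) * mu_low + blend * mu_high
```
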
Lincoln Stein
5649b60672 chore(release): bump version to 6.11.1
This is a bugfix release that contains fixes for bugs in 6.11.0, as
well as updates to the Russian language translation.

No new user-facing features are included in this release.
2026-02-04 15:29:26 -05:00
Weblate (bot)
714eeed74d translationBot(ui): update translation (Russian) (#8830)
Currently translated at 59.7% (1344 of 2249 strings)


Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ru/
Translation: InvokeAI/Web UI

Co-authored-by: DustyShoe <warukeichi@gmail.com>
2026-02-04 15:28:27 -05:00
Alexander Eichhorn
656b50e6ad fix(ui): remove duplicate DyPE preset dropdown in generation settings (#8831)
The ParamFluxDypePreset component was rendered twice in the FLUX
generation settings accordion, causing the DyPE dropdown to appear
both after the scheduler and after the guidance slider.

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-02-04 15:27:24 -05:00
Alexander Eichhorn
0263f4032c fix(ui): reset seed variance toggle when recalling images without that metadata (#8829)
When recalling an image that lacks `z_image_seed_variance_enabled` metadata
(e.g. older images), the toggle now defaults to off instead of retaining the
previous state.
2026-02-04 15:27:03 -05:00
Alexander Eichhorn
dd87e0a946 The FLUX.2 Klein PR (b92c6ae63) replaced the user's denoising strength (#8828)
setting with hardcoded full denoising (start=0, end=1) in addOutpaint.
This caused denoising strength to be completely ignored whenever the
canvas bbox extended beyond the raster layer content, triggering outpaint
mode. The issue affected all model types (SDXL, SD1.5, FLUX, etc.).

Restore the original behavior by reading denoising_start/end from the
user's img2imgStrength setting via getDenoisingStartAndEnd().

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-02-04 15:26:39 -05:00
Alexander Eichhorn
438eea1159 fix(flux2): support Heun scheduler for FLUX.2 Klein models (#8794)
* fix(flux2): support Heun scheduler for FLUX.2 Klein models

FlowMatchHeunDiscreteScheduler does not support dynamic shifting parameters
(use_dynamic_shifting, base_shift, max_shift, etc.) or sigmas/mu in set_timesteps.
This caused FLUX.2 Klein to fail when using Heun scheduler.

- Create Heun scheduler with only num_train_timesteps and shift parameters
- Use num_inference_steps instead of sigmas for Heun's set_timesteps call
- Euler and LCM schedulers continue to use full dynamic shifting support

* fix(flux2): fix Heun scheduler detection using inspect.signature

The previous hasattr check for state_in_first_order failed because
the attribute doesn't exist before set_timesteps() is called. Now
using inspect.signature to check for sigmas parameter support,
matching the FLUX1 implementation.

---------

Co-authored-by: Jonathan <34005131+JPPhoto@users.noreply.github.com>
2026-02-04 15:26:12 -05:00
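
A minimal sketch of the `inspect.signature`-based detection mentioned above, assuming the scheduler exposes a `set_timesteps` method; the helper name is hypothetical and the real check in the FLUX code paths may consider additional parameters (e.g. `mu`).

```python
import inspect

def scheduler_supports_sigmas(scheduler) -> bool:
    """Check whether a scheduler's set_timesteps accepts a `sigmas` argument.

    Heun-style flow-match schedulers may not accept `sigmas`/`mu`, so callers
    fall back to passing num_inference_steps instead.
    """
    params = inspect.signature(scheduler.set_timesteps).parameters
    return "sigmas" in params
```

Callers would then branch on the result: schedulers that accept `sigmas` keep the full dynamic-shifting path, while Heun-style schedulers are driven with a plain `num_inference_steps` call.
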
Alexander Eichhorn
d93e451831 fix(ui): only show FLUX.1 VAEs when a FLUX.1 main model is selected (#8821)
Use useFlux1VAEModels() instead of useFluxVAEModels() in the FLUX VAE
selector, which was incorrectly returning both FLUX.1 and FLUX.2 VAEs.
Remove the now-unused useFluxVAEModels hook.

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-02-04 15:25:48 -05:00
Lincoln Stein
efc7a262b7 chore(release): bump version to 6.11.0 2026-01-31 17:27:56 -05:00
Weblate (bot)
a873ce0175 ui: translations update from weblate (#8816)
* translationBot(ui): update translation (Italian)

Currently translated at 95.0% (2124 of 2235 strings)

translationBot(ui): update translation (Italian)

Currently translated at 94.5% (2114 of 2235 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation files

Updated by "Remove blank strings" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation (Italian)

Currently translated at 98.2% (2195 of 2235 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation (Italian)

Currently translated at 98.2% (2197 of 2235 strings)

Translation: InvokeAI/Web UI
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/

* translationBot(ui): update translation (Russian)

Currently translated at 60.0% (1341 of 2235 strings)

Translation: InvokeAI/Web UI
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ru/

---------

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Co-authored-by: DustyShoe <warukeichi@gmail.com>
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-01-31 22:18:09 +00:00
Alexander Eichhorn
9ee7baaba5 fix(ui): convert reference image configs when switching main model base (#8811)
When switching between FLUX.2 (model-less reference images) and other
models that require IP adapter/Redux models, the reference image configs
were not being converted, leaving stale config types that hid or showed
the wrong UI controls.

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-01-31 22:04:23 +00:00
Weblate (bot)
fb5c43a905 ui: translations update from weblate (#8814)
* translationBot(ui): update translation (Italian)

Currently translated at 95.0% (2124 of 2235 strings)

translationBot(ui): update translation (Italian)

Currently translated at 94.5% (2114 of 2235 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation files

Updated by "Remove blank strings" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation (Italian)

Currently translated at 98.2% (2195 of 2235 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation (Italian)

Currently translated at 98.2% (2197 of 2235 strings)

Translation: InvokeAI/Web UI
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/

* translationBot(ui): update translation (Russian)

Currently translated at 60.0% (1341 of 2235 strings)

Translation: InvokeAI/Web UI
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ru/

---------

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Co-authored-by: DustyShoe <warukeichi@gmail.com>
2026-01-31 17:03:47 -05:00
Weblate (bot)
0f69f4bb9a ui: translations update from weblate (#8813)
* translationBot(ui): update translation (Italian)

Currently translated at 95.0% (2124 of 2235 strings)

translationBot(ui): update translation (Italian)

Currently translated at 94.5% (2114 of 2235 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation files

Updated by "Remove blank strings" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation (Italian)

Currently translated at 98.2% (2195 of 2235 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation (Italian)

Currently translated at 98.2% (2197 of 2235 strings)

Translation: InvokeAI/Web UI
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/

* translationBot(ui): update translation (Russian)

Currently translated at 60.0% (1341 of 2235 strings)

Translation: InvokeAI/Web UI
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ru/

---------

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Co-authored-by: DustyShoe <warukeichi@gmail.com>
2026-01-31 16:41:12 -05:00
Weblate (bot)
8a355e66fa ui: translations update from weblate (#8812)
* translationBot(ui): update translation (Italian)

Currently translated at 95.0% (2124 of 2235 strings)

translationBot(ui): update translation (Italian)

Currently translated at 94.5% (2114 of 2235 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation files

Updated by "Remove blank strings" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation (Italian)

Currently translated at 98.2% (2195 of 2235 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation (Italian)

Currently translated at 98.2% (2197 of 2235 strings)

Translation: InvokeAI/Web UI
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/

* translationBot(ui): update translation (Russian)

Currently translated at 60.0% (1341 of 2235 strings)

Translation: InvokeAI/Web UI
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ru/

---------

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Co-authored-by: DustyShoe <warukeichi@gmail.com>
2026-01-31 08:52:27 -05:00
blessedcoolant
b811602b38 fix(ui): Flux 2 Model Manager default settings not showing Guidance (#8810) 2026-01-31 13:41:05 +00:00
DustyShoe
0716b2fa75 Fix blur filter clipping by expanding padded bounds (#8773)
Co-authored-by: Alexander Eichhorn <alex@eichhorn.dev>
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-01-30 20:56:51 +00:00
Alexander Eichhorn
4d71609115 fix(ui): remove scheduler selection for FLUX.2 Klein (#8808)
The scheduler dropdown is no longer shown for FLUX.2 Klein models.
The backend default (Euler) is used instead.

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-01-30 02:16:12 +00:00
blessedcoolant
0ecb903ae2 fix: Klein 2 Inpainting breaking when there is a reference image (#8803)
Co-authored-by: Alexander Eichhorn <alex@eichhorn.dev>
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-01-30 02:12:41 +00:00
Alexander Eichhorn
736f4ffeb1 fix(ui): improve DyPE field ordering and add 'On' preset option (#8793)
* fix(ui): improve DyPE field ordering and add 'On' preset option

- Add ui_order to DyPE fields (100, 101, 102) to group them at bottom of node
- Change DyPEPreset from Enum to Literal type for proper frontend dropdown support
- Add ui_choice_labels for human-readable dropdown options
- Add new 'On' preset to enable DyPE regardless of resolution
- Fix frontend input field sorting to respect ui_order (unordered first, then ordered)
- Bump flux_denoise node version to 4.4.0

* Chore Ruff check fix

* fix(flux): remove .value from dype_preset logging

DyPEPreset is now a Literal type (string) instead of an Enum,
so .value is no longer needed.

* fix(tests): update DyPE tests for Literal type change

Update test imports and assertions to use string constants
instead of Enum attributes since DyPEPreset is now a Literal type.

* feat(flux): add DyPE scale and exponent controls to Linear UI

- Add dype_scale (λs) and dype_exponent (λt) sliders to generation settings
- Add Zod schemas and parameter types for DyPE scale/exponent
- Pass custom values from Linear UI to flux_denoise node
- Fix bug where DyPE was enabled even when preset was "off"
- Add enhanced logging showing all DyPE parameters when enabled

* fix(flux): apply DyPE scale/exponent and add metadata recall

- Fix DyPE scale and exponent parameters not being applied in frequency
  computation (compute_vision_yarn_freqs, compute_yarn_freqs now call
  get_timestep_mscale)
- Add metadata handlers for dype_scale and dype_exponent to enable
  recall from generated images
- Add i18n translations referencing existing parameter labels

* fix(flux): apply DyPE scale/exponent and add metadata recall

- Fix DyPE scale and exponent parameters not being applied in frequency
  computation (compute_vision_yarn_freqs, compute_yarn_freqs now call
  get_timestep_mscale)
- Add metadata handlers for dype_scale and dype_exponent to enable
  recall from generated images
- Add i18n translations referencing existing parameter labels

* feat(ui): show DyPE scale/exponent only when preset is "on"

- Hide scale/exponent controls in UI when preset is not "on"
- Only parse/recall scale/exponent from metadata when preset is "on"
- Prevents confusion where custom values override preset behavior

* fix(dype): only allow custom scale/exponent with 'on' preset

Presets (auto, 4k) now use their predefined values and ignore
any custom_scale/custom_exponent parameters. Only the 'on' preset
allows manual override of these values.

This matches the frontend UI behavior where the scale/exponent
fields are only shown when 'On' is selected.

* refactor(dype): rename 'on' preset to 'manual'

Rename the 'on' DyPE preset to 'manual' to better reflect its purpose:
allowing users to manually configure scale and exponent values.

Updated in:
- Backend presets (DYPE_PRESET_ON -> DYPE_PRESET_MANUAL)
- Frontend UI labels and options
- Redux slice type definitions
- Zod schema validation
- Tests

* refactor(dype): rename 'on' preset to 'manual'

Rename the 'on' DyPE preset to 'manual' to better reflect its purpose:
allowing users to manually configure scale and exponent values.

Updated in:
- Backend presets (DYPE_PRESET_ON -> DYPE_PRESET_MANUAL)
- Frontend UI labels and options
- Redux slice type definitions
- Zod schema validation
- Tests

* fix(dype): update remaining 'on' references to 'manual'

- Update docstrings, comments, and error messages to use 'manual' preset name
- Simplify FLUX graph builder to always send dype_scale/dype_exponent
- Fix UI condition to show DyPE controls for 'manual' preset

---------

Co-authored-by: Jonathan <34005131+JPPhoto@users.noreply.github.com>
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-01-30 01:28:28 +00:00
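
A rough sketch of the preset handling this PR converges on: a `Literal` preset type, with custom scale/exponent honoured only for the 'manual' preset. The preset default values shown are placeholders, not the numbers used in `invokeai/backend/flux/dype/presets.py`.

```python
from typing import Literal, Optional

# Illustrative preset type; the real definitions live in
# invokeai/backend/flux/dype/presets.py and may differ.
DyPEPreset = Literal["off", "auto", "4k", "manual"]

PRESET_DEFAULTS: dict[str, tuple[float, float]] = {
    "auto": (2.0, 64.0),   # placeholder (scale, exponent) values
    "4k": (4.0, 128.0),    # placeholder values
}

def resolve_dype_params(
    preset: DyPEPreset,
    custom_scale: Optional[float],
    custom_exponent: Optional[float],
) -> Optional[tuple[float, float]]:
    """Return (scale, exponent) for a preset, or None when DyPE is disabled.

    Only the 'manual' preset honours user-supplied overrides; 'auto' and '4k'
    always use their predefined values, matching the UI behaviour described
    in the commit above.
    """
    if preset == "off":
        return None
    if preset == "manual":
        return (custom_scale if custom_scale is not None else 1.0,
                custom_exponent if custom_exponent is not None else 1.0)
    return PRESET_DEFAULTS[preset]
```
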
Weblate (bot)
2102b43edc ui: translations update from weblate (#8807)
* translationBot(ui): update translation (Italian)

Currently translated at 95.0% (2124 of 2235 strings)

translationBot(ui): update translation (Italian)

Currently translated at 94.5% (2114 of 2235 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation files

Updated by "Remove blank strings" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation (Italian)

Currently translated at 98.2% (2195 of 2235 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation (Italian)

Currently translated at 98.2% (2197 of 2235 strings)

Translation: InvokeAI/Web UI
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/

---------

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-01-30 01:22:50 +00:00
Lincoln Stein
5801e59e2b Documentation: InvokeAI PR review and merge policy (#8795)
* docs: Add a PR review and merge policy

* doc(release): add policy on release candidates

* docs(CD/CI): add best practice for external components
2026-01-30 01:03:43 +00:00
Lincoln Stein
5fc950b745 Release Workflow: Fix workflow edge case (#8792)
* release(docker): fix workflow edge case that prevented CUDA build from completing

* bugfix(release): fix yaml syntax error

* bugfix(CI/CD): fix similar problem in typegen check
2026-01-30 01:02:24 +00:00
Weblate (bot)
63dec985cd ui: translations update from weblate (#8806)
* translationBot(ui): update translation (Italian)

Currently translated at 95.0% (2124 of 2235 strings)

translationBot(ui): update translation (Italian)

Currently translated at 94.5% (2114 of 2235 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation files

Updated by "Remove blank strings" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation (Italian)

Currently translated at 98.2% (2195 of 2235 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation (Italian)

Currently translated at 98.2% (2197 of 2235 strings)

Translation: InvokeAI/Web UI
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/

---------

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
2026-01-29 20:52:21 +00:00
Weblate (bot)
03cdd6df2e ui: translations update from weblate (#8804)
* translationBot(ui): update translation (Italian)

Currently translated at 95.0% (2124 of 2235 strings)

translationBot(ui): update translation (Italian)

Currently translated at 94.5% (2114 of 2235 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation files

Updated by "Remove blank strings" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation (Italian)

Currently translated at 98.2% (2195 of 2235 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation (Italian)

Currently translated at 98.2% (2197 of 2235 strings)

Translation: InvokeAI/Web UI
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/

---------

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-01-29 15:42:10 -05:00
Lincoln Stein
99f4070ce7 translationBot(ui): update translation (Italian) (#8805)
Currently translated at 98.2% (2197 of 2235 strings)

Translation: InvokeAI/Web UI
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
2026-01-29 15:36:44 -05:00
Alexander Eichhorn
cf07f8be14 Add new model type integration guide (#8779)
* Add new model type integration guide

Comprehensive documentation covering all steps required to integrate
a new model type into InvokeAI, including:

- Backend: Model manager, configs, loaders, invocations, sampling
- Frontend: Graph building, state management, parameter recall
- Metadata, starter models, and optional features (ControlNet, LoRA, IP-Adapter)

Uses FLUX.1, FLUX.2 Klein, SD3, SDXL, and Z-Image as reference implementations.

* docs: improve new model integration guide

- Move document to docs/contributing/ directory
- Fix broken TOC links by replacing '&' with 'and' in headings
- Add code example for text encoder config (section 2.4)
- Add text encoder loader example (new section 3.3)
- Expand text encoder invocation to show full conditioning flow (section 4.2)

---------

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-01-29 13:45:29 +00:00
Alexander Eichhorn
1f0d92defc fix(ui): allow guidance slider to reach 1 for FLUX.2 Klein
FLUX.2 Klein models require guidance=1 (no CFG), but the slider minimum
was set to 2. Changed sliderMin from 2 to 1 to allow proper configuration.
2026-01-29 07:27:18 +05:30
blessedcoolant
68089ca688 fix(ui): use proper FLUX2 latent RGB factors for preview images (#8802)
## Summary

Replace placeholder zeros with actual 32-channel factors from ComfyUI
and add latent_rgb_bias support for improved FLUX2 denoising previews.

## Related Issues / Discussions

https://github.com/Comfy-Org/ComfyUI/blob/main/comfy/latent_formats.py

https://github.com/user-attachments/assets/dfbc3d81-b855-46b8-8217-50b140f13520

## QA Instructions

1. Generate an image with a FLUX2 model (e.g. FLUX.2 Kontext)
2. Observe the denoising preview during generation
3. Preview should now show more accurate colors instead of
washed-out/incorrect colors from the previous placeholder factors
2026-01-29 07:12:39 +05:30
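
The preview projection described above amounts to a per-channel linear map from the 32 latent channels to RGB plus a bias term. A minimal sketch, with the actual ComfyUI factor values omitted and the function name invented for illustration:

```python
import torch

def latents_to_preview_rgb(latents: torch.Tensor,
                           factors: torch.Tensor,
                           bias: torch.Tensor) -> torch.Tensor:
    """Project FLUX.2 latents (B, 32, H, W) to a rough RGB preview (B, H, W, 3).

    `factors` is a (32, 3) matrix and `bias` a (3,) vector; the concrete
    numbers come from ComfyUI's latent_formats.py referenced in the PR above
    and are not reproduced here.
    """
    # Contract the channel dimension against the per-channel RGB factors.
    rgb = torch.einsum("bchw,cr->bhwr", latents, factors) + bias
    # Callers typically clamp / rescale this into a displayable range.
    return rgb
```
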
blessedcoolant
32e2132948 Merge branch 'main' into fix/flux2-latent-preview-factors 2026-01-29 07:07:50 +05:30
Alexander Eichhorn
bec3586930 fix(ui): use proper FLUX2 latent RGB factors for preview images
Replace placeholder zeros with actual 32-channel factors from ComfyUI
and add latent_rgb_bias support for improved FLUX2 denoising previews.
2026-01-29 02:22:17 +01:00
Weblate (bot)
8bf4d1ea59 ui: translations update from weblate (#8797)
* translationBot(ui): update translation (Italian)

Currently translated at 95.0% (2124 of 2235 strings)

translationBot(ui): update translation (Italian)

Currently translated at 94.5% (2114 of 2235 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation files

Updated by "Remove blank strings" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation (Italian)

Currently translated at 98.2% (2195 of 2235 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

---------

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-01-28 17:21:42 -05:00
Jonathan
fd7a3aebd2 Add input connectors to the FLUX model loader (#8785)
* Update flux_model_loader.py

Added input connectors to the model loader so that a model selection node can be used to pass in FLUX models.

* typegen

* Fixed existing ruff error

---------

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-01-28 16:49:16 -05:00
Alexander Eichhorn
72491e2153 Fix ref_images metadata format for FLUX Kontext recall (#8791)
Remove extra array wrapper when saving ref_images metadata for FLUX.2 Klein
and FLUX.1 Kontext reference images. The double-nested array [[...]] was
preventing recall from parsing the metadata correctly.
2026-01-27 08:44:44 -05:00
Lincoln Stein
3d0725072d Prep for 6.11.0.rc1 (#8771)
* chore(release): add flux.2-klein to whats new items & bump version

* doc(release): update the WhatsNew text

* chore(frontend): run lint:prettier and frontend-typegen
2026-01-27 05:40:09 +00:00
Alexander Eichhorn
0ae7392c81 fix(model_manager): detect Flux1/2 VAE by latent space dimensions instead of filename (#8790)
* fix(model_manager): detect Flux VAE by latent space dimensions instead of filename

VAE detection previously relied solely on filename pattern matching, which failed
for Flux VAE files with generic names like "ae.safetensors". Now probes the model's
decoder.conv_in weight shape to determine the latent space dimensions:
- 16 channels -> Flux VAE
- 4 channels -> SD/SDXL VAE (with filename fallback for SD1/SD2/SDXL distinction)

* fix(model_manager): add latent space probing for Flux2 VAE detection

Extend Flux2 VAE detection to also check for 32-dimensional latent space
(decoder.conv_in with 32 input channels) in addition to BatchNorm layers.
This provides more robust detection for Flux2 VAE files regardless of filename.

* Chore Ruff format

---------

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-01-27 05:20:50 +00:00
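
A sketch of the latent-channel probe described above, assuming the decoder input convolution is stored under a `decoder.conv_in.weight` key; the exact key names and returned labels in the real model-manager probe may differ.

```python
def classify_vae_by_latent_channels(state_dict: dict) -> str:
    """Guess a VAE family from the decoder's input-channel count.

    Mirrors the heuristic in the commit above: 4 latent channels for SD/SDXL
    VAEs, 16 for FLUX.1, 32 for FLUX.2.
    """
    weight = state_dict.get("decoder.conv_in.weight")
    if weight is None:
        raise ValueError("Not a recognizable VAE checkpoint")
    in_channels = weight.shape[1]  # weight shape is (out_ch, in_ch, kH, kW)
    if in_channels == 32:
        return "flux2"
    if in_channels == 16:
        return "flux1"
    if in_channels == 4:
        return "sd"  # further disambiguated by filename in the real probe
    raise ValueError(f"Unexpected latent channel count: {in_channels}")
```
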
Alexander Eichhorn
cff20b45f3 Feature: Add DyPE (Dynamic Position Extrapolation) support to FLUX models for improved high-resolution image generation (#8763)
* docs: add DyPE implementation plan for FLUX high-resolution generation

Add detailed plan for porting ComfyUI-DyPE (Dynamic Position Extrapolation)
to InvokeAI, enabling 4K+ image generation with FLUX models without
training. Estimated effort: 5-7 developer days.

* docs: update DyPE plan with design decisions

- Integrate DyPE directly into FluxDenoise (no separate node)
- Add 4K preset and "auto" mode for automatic activation
- Confirm FLUX Schnell support (same base resolution as Dev)

* docs: add activation threshold for DyPE auto mode

FLUX can handle resolutions up to ~1.5x natively without artifacts.
Set activation_threshold=1536 so DyPE only kicks in above that.

* feat(flux): implement DyPE for high-resolution generation

Add Dynamic Position Extrapolation (DyPE) support to FLUX models,
enabling artifact-free generation at 4K+ resolutions.

New files:
- invokeai/backend/flux/dype/base.py: DyPEConfig and scaling calculations
- invokeai/backend/flux/dype/rope.py: DyPE-enhanced RoPE functions
- invokeai/backend/flux/dype/embed.py: DyPEEmbedND position embedder
- invokeai/backend/flux/dype/presets.py: Presets (off, auto, 4k)
- invokeai/backend/flux/extensions/dype_extension.py: Pipeline integration

Modified files:
- invokeai/backend/flux/denoise.py: Add dype_extension parameter
- invokeai/app/invocations/flux_denoise.py: Add UI parameters

UI parameters:
- dype_preset: off | auto | 4k
- dype_scale: Custom magnitude override (0-8)
- dype_exponent: Custom decay speed override (0-1000)

Auto mode activates DyPE for resolutions > 1536px.

Based on: https://github.com/wildminder/ComfyUI-DyPE

* feat(flux): add DyPE preset selector to Linear UI

Add Linear UI integration for FLUX DyPE (Dynamic Position Extrapolation):

- Add ParamFluxDypePreset component with Off/Auto/4K options
- Integrate preset selector in GenerationSettingsAccordion for FLUX models
- Add state management (paramsSlice, types) for fluxDypePreset
- Add dype_preset to FLUX denoise graph builder and metadata
- Add translations for DyPE preset label and popover
- Add zFluxDypePresetField schema definition

Fix DyPE frequency computation:
- Remove incorrect mscale multiplication on frequencies
- Use only NTK-aware theta scaling for position extrapolation

* feat(flux): add DyPE preset to metadata recall

- Add FluxDypePreset handler to ImageMetadataHandlers
- Parse dype_preset from metadata and dispatch setFluxDypePreset on recall
- Add translation key metadata.dypePreset

* chore: remove dype-implementation-plan.md

Remove internal planning document from the branch.

* chore(flux): bump flux_denoise version to 4.3.0

Version bump for dype_preset field addition.

* chore: ruff check fix

* chore: ruff format

* Fix truncated DyPE label in advanced options UI

Shorten the label from "DyPE (High-Res)" to "DyPE" to prevent text truncation in the sidebar. The high-resolution context is preserved in the informational popover tooltip.

* Add DyPE preset to recall parameters in image viewer

The dype_preset metadata was being saved but not displayed in the Recall Parameters tab. Add FluxDypePreset handler to ImageMetadataActions so users can see and recall this parameter.

---------

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Jonathan <34005131+JPPhoto@users.noreply.github.com>
2026-01-26 23:54:44 -05:00
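
The auto-activation rule described in the plan above (threshold 1536 px, since FLUX handles up to ~1.5x its base resolution natively) can be sketched as follows; the function name and the treatment of the '4k'/'manual' presets are assumptions for illustration.

```python
def dype_should_activate(width: int, height: int,
                         preset: str,
                         activation_threshold: int = 1536) -> bool:
    """Decide whether DyPE is applied for a generation.

    'auto' enables DyPE only when either image dimension exceeds the
    threshold; '4k' and 'manual' always enable it; 'off' never does.
    Preset names follow the later rename of 'on' to 'manual'.
    """
    if preset == "off":
        return False
    if preset == "auto":
        return max(width, height) > activation_threshold
    return True  # '4k', 'manual'
```
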
Alexander Eichhorn
b92c6ae633 feat(flux2): add FLUX.2 klein model support (#8768)
* WIP: feat(flux2): add FLUX 2 Kontext model support

- Add new invocation nodes for FLUX 2:
  - flux2_denoise: Denoising invocation for FLUX 2
  - flux2_klein_model_loader: Model loader for Klein architecture
  - flux2_klein_text_encoder: Text encoder for Qwen3-based encoding
  - flux2_vae_decode: VAE decoder for FLUX 2

- Add backend support:
  - New flux2 module with denoise and sampling utilities
  - Extended model manager configs for FLUX 2 models
  - Updated model loaders for Klein architecture

- Update frontend:
  - Extended graph builder for FLUX 2 support
  - Added FLUX 2 model types and configurations
  - Updated readiness checks and UI components

* fix(flux2): correct VAE decode with proper BN denormalization

FLUX.2 VAE uses Batch Normalization in the patchified latent space
(128 channels). The decode must:
1. Patchify latents from (B, 32, H, W) to (B, 128, H/2, W/2)
2. Apply BN denormalization using running_mean/running_var
3. Unpatchify back to (B, 32, H, W) for VAE decode

Also fixed image normalization from [-1, 1] to [0, 255].

This fixes washed-out colors in generated FLUX.2 Klein images.

* feat(flux2): add FLUX.2 Klein model support with ComfyUI checkpoint compatibility

- Add FLUX.2 transformer loader with BFL-to-diffusers weight conversion
- Fix AdaLayerNorm scale-shift swap for final_layer.adaLN_modulation weights
- Add VAE batch normalization handling for FLUX.2 latent normalization
- Add Qwen3 text encoder loader with ComfyUI FP8 quantization support
- Add frontend components for FLUX.2 Klein model selection
- Update configs and schema for FLUX.2 model types

* Chore Ruff

* Fix Flux1 vae probing

* Fix Windows Paths schema.ts

* Add 4B and 9B Klein to Starter Models.

* feat(flux2): add non-commercial license indicator for FLUX.2 Klein 9B

- Add isFlux2Klein9BMainModelConfig and isNonCommercialMainModelConfig functions
- Update MainModelPicker and InitialStateMainModelPicker to show license icon
- Update license tooltip text to include FLUX.2 Klein 9B

* feat(flux2): add Klein/Qwen3 variant support and encoder filtering

Backend:
- Add klein_4b/klein_9b variants for FLUX.2 Klein models
- Add qwen3_4b/qwen3_8b variants for Qwen3 encoder models
- Validate encoder variant matches Klein model (4B↔4B, 9B↔8B)
- Auto-detect Qwen3 variant from hidden_size during probing

Frontend:
- Show variant field for all model types in ModelView
- Filter Qwen3 encoder dropdown to only show compatible variants
- Update variant type definitions (zFlux2VariantType, zQwen3VariantType)
- Remove unused exports (isFluxDevMainModelConfig, isFlux2Klein9BMainModelConfig)

* Chore Ruff

* feat(flux2): add Klein 9B Base (undistilled) variant support

Distinguish between FLUX.2 Klein 9B (distilled) and Klein 9B Base (undistilled)
models by checking guidance_embeds in diffusers config or guidance_in keys in
safetensors. Klein 9B Base requires more steps but offers higher quality.

* feat(flux2): improve diffusers compatibility and distilled model support

Backend changes:
- Update text encoder layers from [9,18,27] to (10,20,30) matching diffusers
- Use apply_chat_template with system message instead of manual formatting
- Change position IDs from ones to zeros to match diffusers implementation
- Add get_schedule_flux2() with empirical mu computation for proper schedule shifting
- Add txt_embed_scale parameter for Qwen3 embedding magnitude control
- Add shift_schedule toggle for base (28+ steps) vs distilled (4 steps) models
- Zero out guidance_embedder weights for Klein models without guidance_embeds

UI changes:
- Clear Klein VAE and Qwen3 encoder when switching away from flux2 base
- Clear Qwen3 encoder when switching between different Klein model variants
- Add toast notification informing user to select compatible encoder

* feat(flux2): fix distilled model scheduling with proper dynamic shifting

- Configure scheduler with FLUX.2 Klein parameters from scheduler_config.json
  (use_dynamic_shifting=True, shift=3.0, time_shift_type="exponential")
- Pass mu parameter to scheduler.set_timesteps() for resolution-aware shifting
- Remove manual shift_schedule parameter (scheduler handles this automatically)
- Simplify get_schedule_flux2() to return linear sigmas only
- Remove txt_embed_scale parameter (no longer needed)

This matches the diffusers Flux2KleinPipeline behavior where the
FlowMatchEulerDiscreteScheduler applies dynamic timestep shifting
based on image resolution via the mu parameter.

Fixes 4-step distilled Klein 9B model quality issues.

* fix(ui): fix FLUX.1 graph building with posCondCollect node lookup

The posCondCollect node was created with getPrefixedId() which generates
a random suffix (e.g., 'pos_cond_collect:abc123'), but g.getNode() was
called with the plain string 'pos_cond_collect', causing a node lookup
failure.

Fix by declaring posCondCollect as a module-scoped variable and
referencing it directly instead of using g.getNode().

* Remove Flux2 Klein Base from Starter Models

* Remove Logging

* Add Default Values for Flux2 Klein and add variant as additional info to from_base

* Add migrations for the z-image qwen3 encoder without a variant value

* Add img2img, inpainting and outpainting support for FLUX.2 Klein

- Add flux2_vae_encode invocation for encoding images to FLUX.2 latents
- Integrate inpaint_extension into FLUX.2 denoise loop for proper mask handling
- Apply BN normalization to init_latents and noise for consistency in inpainting
- Use manual Euler stepping for img2img/inpaint to preserve exact timestep schedule
- Add flux2_img2img, flux2_inpaint, flux2_outpaint generation modes
- Expand starter models with FP8 variants, standalone transformers, and separate VAE/encoders
- Fix outpainting to always use full denoising (0-1) since strength doesn't apply
- Improve error messages in model loader with clear guidance for standalone models

* Add GGUF quantized model support and Diffusers VAE loader for FLUX.2 Klein

- Add Main_GGUF_Flux2_Config for GGUF-quantized FLUX.2 transformer models
- Add VAE_Diffusers_Flux2_Config for FLUX.2 VAE in diffusers format
- Add Flux2GGUFCheckpointModel loader with BFL-to-diffusers conversion
- Add Flux2VAEDiffusersLoader for AutoencoderKLFlux2
- Add FLUX.2 Klein 4B/9B hardware requirements to documentation
- Update starter model descriptions to clarify dependencies install together
- Update frontend schema for new model configs

* Fix FLUX.2 model detection and add FP8 weight dequantization support

- Improve FLUX.2 variant detection for GGUF/checkpoint models (BFL format keys)
- Fix guidance_embeds logic: distilled=False, undistilled=True
- Add FP8 weight dequantization for ComfyUI-style quantized models
- Prevent FLUX.2 models from being misidentified as FLUX.1
- Preserve user-editable fields (name, description, etc.) on model reidentify
- Improve Qwen3Encoder detection by variant in starter models
- Add defensive checks for tensor operations

* Chore ruff format

* Chore Typegen

* Fix FLUX.2 Klein 9B model loading by detecting hidden_size from weights

Previously num_attention_heads was hardcoded to 24, which is correct for
Klein 4B but causes size mismatches when loading Klein 9B checkpoints.

Now dynamically calculates num_attention_heads from the hidden_size
dimension of context_embedder weights:
- Klein 4B: hidden_size=3072 → num_attention_heads=24
- Klein 9B: hidden_size=4096 → num_attention_heads=32

Fixes both Checkpoint and GGUF loaders for FLUX.2 models.

* Only clear Qwen3 encoder when FLUX.2 Klein variant changes

Previously the encoder was cleared whenever switching between any Klein
models, even if they had the same variant. Now compares the variant of
the old and new model and only clears the encoder when switching between
different variants (e.g., klein_4b to klein_9b).

This allows users to switch between different Klein 9B models without
having to re-select the Qwen3 encoder each time.

* Add metadata recall support for FLUX.2 Klein parameters

The scheduler, VAE model, and Qwen3 encoder model were not being
recalled correctly for FLUX.2 Klein images. This adds dedicated
metadata handlers for the Klein-specific parameters.

* Fix FLUX.2 Klein denoising scaling and Z-Image VAE compatibility

- Apply exponential denoising scaling (exponent 0.2) to FLUX.2 Klein,
  matching FLUX.1 behavior for more intuitive inpainting strength
- Add isFlux1VAEModelConfig type guard to filter FLUX 1.0 VAEs only
- Restrict Z-Image VAE selection to FLUX 1.0 VAEs, excluding FLUX.2
  Klein 32-channel VAEs which are incompatible

* chore pnpm fix

* Add FLUX.2 Klein to starter bundles and documentation

- Add FLUX.2 Klein hardware requirements to quick start guide
- Create flux2_klein_bundle with GGUF Q4 model, VAE, and Qwen3 encoder
- Add "What's New" entry announcing FLUX.2 Klein support

* Add FLUX.2 Klein built-in reference image editing support

FLUX.2 Klein has native multi-reference image editing without requiring
a separate model (unlike FLUX.1 which needs a Kontext model).

Backend changes:
- Add Flux2RefImageExtension for encoding reference images with FLUX.2 VAE
- Apply BN normalization to reference image latents for correct scaling
- Use T-coordinate offset scale=10 like diffusers (T=10, 20, 30...)
- Concatenate reference latents with generated image during denoising
- Extract only generated portion in step callback for correct preview

Frontend changes:
- Add flux2_reference_image config type without model field
- Hide model selector for FLUX.2 reference images (built-in support)
- Add type guards to handle configs without model property
- Update validators to skip model validation for FLUX.2
- Add 'flux2' to SUPPORTS_REF_IMAGES_BASE_MODELS

* Chore windows path fix

* Add reference image resizing for FLUX.2 Klein

Resize large reference images to match BFL FLUX.2 sampling.py limits:
- Single reference: max 2024² pixels (~4.1M)
- Multiple references: max 1024² pixels (~1M)

Uses same scaling approach as BFL's cap_pixels() function.
2026-01-26 23:21:37 -05:00
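
The BN denormalization step for the FLUX.2 VAE described in this PR (patchify 32 channels to 128, denormalize with the stored running statistics, unpatchify for decode) might look roughly like the sketch below. The use of `pixel_unshuffle`/`pixel_shuffle` for the patchify ordering and the `eps` default are assumptions, not the actual implementation.

```python
import torch
import torch.nn.functional as F

def bn_denormalize_latents(latents: torch.Tensor,
                           running_mean: torch.Tensor,
                           running_var: torch.Tensor,
                           eps: float = 1e-5) -> torch.Tensor:
    """Undo the FLUX.2 BatchNorm applied in the patchified latent space.

    Steps per the PR above: patchify (B, 32, H, W) -> (B, 128, H/2, W/2),
    denormalize with running_mean/running_var, then unpatchify back to
    (B, 32, H, W) for the VAE decoder.
    """
    patched = F.pixel_unshuffle(latents, downscale_factor=2)  # (B, 128, H/2, W/2)
    mean = running_mean.view(1, -1, 1, 1)
    std = (running_var.view(1, -1, 1, 1) + eps).sqrt()
    denorm = patched * std + mean                             # invert BN normalization
    return F.pixel_shuffle(denorm, upscale_factor=2)          # back to (B, 32, H, W)
```
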
DustyShoe
729bae19a5 Feat(UI): Search bar in image info code tabs and add vertical margins for improved UX in Recall Parameters tab. (#8786)
* Adjusted Search bar position and added padding in image info viewer.

* Minor bug fix with spaces being highlighted.
2026-01-25 22:38:43 +01:00
Copilot
fcc81f17a5 Limit automated issue closure to bug issues only (#8776)
* Initial plan

* Add only-labels parameter to limit automated issue closure to bugs only

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
2026-01-21 02:43:59 +05:30
Weblate (bot)
27ae70a428 translationBot(ui): update translation files (#8767)
Updated by "Cleanup translation files" hook in Weblate.


Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2026-01-20 15:50:21 -05:00
Lincoln Stein
82819cdadc Add user survey section to README (#8766)
* Add user survey section to README

Added a section for new and returning users to take a survey.

* docs: add user survey link to WhatsNew

* Fix formatting issues in WhatsNew.tsx

---------

Co-authored-by: Alexander Eichhorn <alex@eichhorn.dev>
2026-01-16 03:32:16 +01:00
Alexander Eichhorn
b2b8820519 fix(model_manager): prevent Z-Image LoRAs from being misclassified as main models (#8754)
* fix(model_manager): prevent Z-Image LoRAs from being misclassified as main models

Z-Image LoRAs containing keys like `diffusion_model.context_refiner.*` were being
incorrectly classified as main checkpoint models instead of LoRAs. This happened
because the `_has_z_image_keys()` function checked for Z-Image specific keys
(like `context_refiner`) without verifying if the file was actually a LoRA.

Since main models have higher priority than LoRAs in the classification sort order,
the incorrect main model classification would win.

The fix adds detection of LoRA-specific weight suffixes (`.lora_down.weight`,
`.lora_up.weight`, `.lora_A.weight`, `.lora_B.weight`, `.dora_scale`) and returns
False if any are found, ensuring LoRAs are correctly classified.

* refactor(mm): simplify _has_z_image_keys with early return

Return True directly when a Z-Image key is found instead of using an
intermediate variable.
2026-01-14 22:35:17 -05:00
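
A sketch of the early-return LoRA check described above; `context_refiner` is the Z-Image key named in the commit, while the function signature and the rest of the key list are simplified for illustration.

```python
LORA_WEIGHT_SUFFIXES = (
    ".lora_down.weight",
    ".lora_up.weight",
    ".lora_A.weight",
    ".lora_B.weight",
    ".dora_scale",
)

def has_z_image_keys(state_dict_keys: list[str]) -> bool:
    """Return True only for Z-Image main checkpoints, not Z-Image LoRAs.

    Bail out early if any LoRA-specific weight suffix is present, then look
    for Z-Image-specific keys such as `context_refiner`.
    """
    for key in state_dict_keys:
        if key.endswith(LORA_WEIGHT_SUFFIXES):
            return False
    return any("context_refiner" in key for key in state_dict_keys)
```
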
Alexander Eichhorn
bb6c544603 feat(z-image): add Seed Variance Enhancer node and Linear UI integration (#8753)
* feat(z-image): add Seed Variance Enhancer node and Linear UI integration

Add a new conditioning node for Z-Image models that injects seed-based
noise into text embeddings to increase visual variation between seeds.

Backend:
- New invocation: z_image_seed_variance_enhancer.py
- Parameters: strength (0-2), randomize_percent (1-100%), seed

Frontend:
- State management in paramsSlice with selectors and reducers
- UI components in SeedVariance/ folder with toggle and sliders
- Integration in GenerationSettingsAccordion (Advanced Options)
- Graph builder integration in buildZImageGraph.ts
- Metadata recall handlers for remix functionality
- Translations and tooltip descriptions

Based on: github.com/Pfannkuchensack/invokeai-z-image-seed-variance-enhancer

* chore: ruff and typegen fix

* chore: ruff and typegen fix

* Revise seedVarianceStrength explanation

Updated description for seedVarianceStrength.

* Update description for seedVarianceStrength

* fix(z-image): correct noise range comment from [-1, 1] to [-1, 1)

torch.rand() generates [0, 1), so the scaled range excludes 1.
2026-01-12 20:36:21 +01:00
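
A minimal sketch of the seed-based noise injection this node performs, assuming noise in [-1, 1) scaled by `strength` and applied to a `randomize_percent` fraction of embedding elements; the exact masking, defaults, and parameter handling of the real invocation may differ.

```python
import torch

def enhance_seed_variance(embeddings: torch.Tensor,
                          seed: int,
                          strength: float = 0.5,
                          randomize_percent: float = 50.0) -> torch.Tensor:
    """Inject seed-dependent noise into text embeddings (illustrative sketch).

    Noise is drawn in [-1, 1): torch.rand covers [0, 1), so the scaled range
    excludes 1. Roughly randomize_percent of the elements are perturbed.
    """
    gen = torch.Generator().manual_seed(seed)
    noise = torch.rand(embeddings.shape, generator=gen) * 2.0 - 1.0
    mask = torch.rand(embeddings.shape, generator=gen) < (randomize_percent / 100.0)
    noise = (noise * mask).to(embeddings.device, embeddings.dtype)
    return embeddings + noise * strength
```
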
blessedcoolant
8a18914637 chore(CI/CD): Remove codeowners from /docs directory (#8737)
## Summary

This PR removes codeowners from the `/docs` directory, allowing any team
member with repo write permissions to review and approve PRs involving
documentation.

## Related Issues / Discussions

Documentation review is a shared responsibility.

## QA Instructions

None needed.

## Merge Plan

Simple merge.

## Checklist

- [X] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Changes to a redux slice have a corresponding migration_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2026-01-12 15:19:22 +05:30
blessedcoolant
d66df9a0d0 Merge branch 'main' into lstein/chore/codeowners 2026-01-12 15:18:19 +05:30
DustyShoe
5c00684701 Feat(UI): Canvas high level transform smoothing (#8756)
* WIP transform smoothing controls

* Fix transform smoothing control typings

* High level resize algo for transformation

* ESLint fix

* format with prettier

---------

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-01-11 15:48:27 -05:00
DustyShoe
d93ce6ac42 Fix(UI): Canvas numeric brush size (#8761)
* Fix for brush/eraser size not updating on up/down arrow click

* Made further improvements on brush size selection behavior

---------

Co-authored-by: Alexander Eichhorn <alex@eichhorn.dev>
2026-01-11 15:23:06 -05:00
blessedcoolant
13bf5feb4d Fix(UI): Error message for extract region (#8759)
## Summary

This PR fixes the misleading popup message "Canvas is empty" that appeared when
attempting to extract a region with an empty mask layer. It is replaced with the
correct message "Mask layer is empty". A few other popups were also redirected to
use the translation file.


## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Changes to a redux slice have a corresponding migration_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2026-01-11 21:53:48 +05:30
DustyShoe
53ab178edd Merge branch 'invoke-ai:main' into Fix(UI)--Error-messsage-for-extract-region 2026-01-11 02:13:35 +02:00
DustyShoe
2d8317f1aa Corrected error message and redirected popup messages to use translation file 2026-01-11 02:08:47 +02:00
Lincoln Stein
04f815638c chore(invocation stats): remove old dangling debug statement 2026-01-10 11:32:37 -05:00
Lincoln Stein
d6ad6a2dcb fix(invocation stats): Report delta VRAM for each invocation and fix reporting of RAM cache size 2026-01-10 11:32:37 -05:00
Hosted Weblate
784503e484 translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

translationBot(ui): update translation files

Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2026-01-08 16:48:16 -05:00
RyoKoba
da2809b000 translationBot(ui): update translation (Japanese)
Currently translated at 99.6% (2155 of 2163 strings)

Co-authored-by: RyoKoba <kobayashi_ryo@cyberagent.co.jp>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ja/
Translation: InvokeAI/Web UI
2026-01-08 16:48:16 -05:00
Weblate (bot)
53c34eb95e ui: translations update from weblate (#8748)
* translationBot(ui): update translation (Italian)

Currently translated at 98.4% (2099 of 2132 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.4% (2130 of 2163 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.4% (2130 of 2163 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation (Japanese)

Currently translated at 99.6% (2155 of 2163 strings)

Co-authored-by: RyoKoba <kobayashi_ryo@cyberagent.co.jp>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ja/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation files

Updated by "Cleanup translation files" hook in Weblate.

translationBot(ui): update translation files

Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation (Italian)

Currently translated at 98.4% (2103 of 2136 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

* translationBot(ui): added translation (English (United Kingdom))

* translationBot(ui): update translation files

Updated by "Cleanup translation files" hook in Weblate.

Translation: InvokeAI/Web UI
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/

---------

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Co-authored-by: RyoKoba <kobayashi_ryo@cyberagent.co.jp>
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-01-08 15:56:33 -05:00
Weblate (bot)
18fc822d37 ui: translations update from weblate (#8747)
* translationBot(ui): update translation (Italian)

Currently translated at 98.4% (2099 of 2132 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.4% (2130 of 2163 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.4% (2130 of 2163 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation (Japanese)

Currently translated at 99.6% (2155 of 2163 strings)

Co-authored-by: RyoKoba <kobayashi_ryo@cyberagent.co.jp>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ja/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation files

Updated by "Cleanup translation files" hook in Weblate.

translationBot(ui): update translation files

Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation (Italian)

Currently translated at 98.4% (2103 of 2136 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

* translationBot(ui): added translation (English (United Kingdom))

* translationBot(ui): update translation files

Updated by "Cleanup translation files" hook in Weblate.

Translation: InvokeAI/Web UI
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/

---------

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Co-authored-by: RyoKoba <kobayashi_ryo@cyberagent.co.jp>
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-01-08 20:31:33 +00:00
Lincoln Stein
89dc50bd7c Chore: Fix weblate merge conflicts (#8744)
* translationBot(ui): update translation (Italian)

Currently translated at 98.4% (2099 of 2132 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.4% (2130 of 2163 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.4% (2130 of 2163 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation (Japanese)

Currently translated at 99.6% (2155 of 2163 strings)

Co-authored-by: RyoKoba <kobayashi_ryo@cyberagent.co.jp>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ja/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation files

Updated by "Cleanup translation files" hook in Weblate.

translationBot(ui): update translation files

Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation (Italian)

Currently translated at 98.4% (2103 of 2136 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

* chore: add weblate.ini file to gitignore

* Fix duplicate entry in ja.json

Removed duplicate 'jump' entry in Japanese locale.

---------

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Co-authored-by: RyoKoba <kobayashi_ryo@cyberagent.co.jp>
Co-authored-by: Hosted Weblate <hosted@weblate.org>
2026-01-08 15:25:11 -05:00
Lincoln Stein
d34655fd58 Fix(model manager): Improve calculation of Z-Image VAE working memory needs (#8740)
* Fix Z-Image VAE encode/decode to request working memory

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>

* fix: remove check for non-flux vae

* fix: remove check for non-flux vae: latents_to_image

* Remove conditional estimation tests

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
2026-01-08 17:48:09 +00:00
Lincoln Stein
c1a8300e96 chore(release): bump development version to 6.10.0.post1 (#8745)
* chore(release): bump development version 6.10.0post1

* chore: fix version syntax
2026-01-08 12:42:11 -05:00
Lincoln Stein
9c5b2f6498 (chore) Bump to version 6.10.0 (#8742)
* (chore) Prep for v6.10.0rc2

* (chore) bump to version v6.10.0
2026-01-05 23:47:57 -05:00
Alexander Eichhorn
dbb4a07a8f feat(z-image): add add_noise option to Z-Image Denoise (#8739)
* feat(z-image): add `add_noise` option to Z-Image Denoise

Add the same `add_noise` option that exists in FLUX Denoise to Z-Image Denoise.
When set to false, no noise is added to the input latents during image-to-image,
allowing for more controlled transformations.
2026-01-05 21:24:44 -05:00
Lincoln Stein
f66a1a38c8 Merge branch 'main' into lstein/chore/codeowners 2026-01-05 15:16:33 -05:00
Alexander Eichhorn
be2635161c Feature: z-image + metadata node (#8733)
## Summary

Add a new "Denoise - Z-Image + Metadata" node
(`ZImageDenoiseMetaInvocation`) that extends the Z-Image denoise node
with metadata output for image recall functionality.

This follows the same pattern as existing `denoise_latents_meta`
(SD1.5/SDXL) and `flux_denoise_meta` (FLUX) nodes.

**Captured metadata:**
- `width` / `height`
- `steps`
- `guidance` (guidance_scale)
- `denoising_start` / `denoising_end`
- `scheduler`
- `model` (transformer)
- `seed`
- `loras` (if applied)

## Related Issues / Discussions

Enables metadata recall for Z-Image generated images, similar to
existing support for SD1.5, SDXL, and FLUX models.

## QA Instructions

1. Create a workflow using the new "Denoise - Z-Image + Metadata" node
2. Connect the metadata output to a "Save Image" node
3. Generate an image
4. Check that metadata is saved with the image (visible in image info
panel)
5. Verify all generation parameters are captured correctly

## Merge Plan

Requires `feature/zimage-scheduler-support` #8705 branch to be merged
first (base branch).

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Changes to a redux slice have a corresponding migration_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2026-01-05 01:56:22 +01:00
Alexander Eichhorn
384a1a689d Merge branch 'main' into z-image_metadata_node 2026-01-05 01:50:28 +01:00
Lincoln Stein
0021404639 chore: remove dangling debug statements (#8738) 2026-01-05 00:47:46 +00:00
Alexander Eichhorn
a05a626644 Fix typegen 2026-01-05 01:42:49 +01:00
Alexander Eichhorn
97b82d752e Add configurable model cache timeout for automatic memory management (#8693)
## Summary

Adds `model_cache_keep_alive_min` config option (minutes, default 5) to
automatically clear model cache after inactivity. Addresses memory
contention when running InvokeAI alongside other GPU applications like
Ollama.

**Implementation:**
- **Config**: New `model_cache_keep_alive_min` field in
`InvokeAIAppConfig` with 5-minute default
- **ModelCache**: Activity tracking on get/lock/unlock/put operations,
threading.Timer for scheduled clearing
- **Thread safety**: Double-check pattern handles race conditions,
daemon threads for clean shutdown (see the timer sketch after the behavior list below)
- **Integration**: ModelManagerService passes config to cache, calls
shutdown() on stop
- **Logging**: Smart timeout logging that only shows messages when
unlocked models are actually cleared
- **Tests**: Comprehensive unit tests with properly configured mock
logger

**Usage:**
```yaml
# invokeai.yaml
model_cache_keep_alive_min: 10  # Clear after 10 minutes idle
model_cache_keep_alive_min: 0   # Set to 0 for indefinite caching (old behavior)
```

**Key Behavior:**
- **Default timeout**: 5 minutes - models are automatically cleared
after 5 minutes of inactivity
- Clearing uses same logic as "Clear Model Cache" button (make_room with
1000GB)
- Only clears **unlocked** models (respects models actively in use
during generation)
- Timeout message only appears when models are actually cleared
- Debug logging available for timeout events when no action is taken
- Prevents misleading log entries during active generation
- Users can set to 0 to restore indefinite caching behavior
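A minimal sketch of the inactivity timer described above, using simplified names rather than the real ModelCache internals; the callback stands in for the make_room-based clearing:

```python
import threading
import time
from typing import Callable, Optional

class InactivityTimeout:
    """Sketch of the keep-alive timer: every cache operation records activity,
    and a daemon Timer clears unlocked models once the idle period elapses."""

    def __init__(self, keep_alive_min: float, clear_unlocked_models: Callable[[], None]) -> None:
        self._keep_alive_s = keep_alive_min * 60
        self._clear_unlocked_models = clear_unlocked_models
        self._last_activity = time.monotonic()
        self._timer: Optional[threading.Timer] = None
        self._lock = threading.Lock()

    def record_activity(self) -> None:
        """Called on get/lock/unlock/put; restarts the countdown."""
        with self._lock:
            self._last_activity = time.monotonic()
            if self._keep_alive_s <= 0:  # 0 disables the timeout (indefinite caching)
                return
            if self._timer is not None:
                self._timer.cancel()
            self._timer = threading.Timer(self._keep_alive_s, self._on_timeout)
            self._timer.daemon = True  # daemon thread: don't block interpreter shutdown
            self._timer.start()

    def _on_timeout(self) -> None:
        with self._lock:
            # Double-check: activity may have happened after the timer was scheduled.
            if time.monotonic() - self._last_activity < self._keep_alive_s:
                return
        self._clear_unlocked_models()  # e.g. make_room-style clearing that skips locked models
```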

## Related Issues / Discussions

Addresses enhancement request for automatic model unloading from memory
after inactivity period.

## QA Instructions

1. **Test default behavior (5-minute timeout)**:
   - Start InvokeAI without explicit config
   - Run a generation
   - Wait 6 minutes with no activity
   - Check logs for "Clearing X unlocked model(s) from cache" message
   - Verify cache is empty

2. **Test custom timeout**:
   - Set `model_cache_keep_alive_min: 0.1` (6 seconds) in config
   - Load a model (run generation)
   - Wait 7+ seconds with no activity
   - Check logs for "Clearing X unlocked model(s) from cache" message
   - Verify cache is empty

3. **Test no timeout (old behavior)**:
   - Set `model_cache_keep_alive_min: 0` in config
   - Run generations and wait extended periods
   - Verify models remain cached indefinitely

4. **Test during active use**:
   - Run continuous generations with any timeout setting
- Verify no timeout messages appear during active use (models are
locked)
- After generation completes, wait for timeout and verify unlocked
models are cleared

## Merge Plan

N/A - Additive change with sensible defaults. The 5-minute default
enables automatic memory management while remaining practical for
typical workflows.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [ ] _Changes to a redux slice have a corresponding migration_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_

Original prompt (quoted issue):
> 
> Issue: [enhancement]: option to unload from memory
> 
> ### Is there an existing issue for this?
> 
> - [X] I have searched the existing issues
> 
> ### Contact Details
> 
> ### What should this feature add?
> 
> a command line option to unload model from RAM after a defined period
of time
> 
> ### Alternatives
> 
> running as a container and using Sablier to shut down the container
after some time; this has the downside that if traffic isn't seen through
the web interface, it will be shut down even if jobs are running.
> 
> ### Additional Content
> 
> _No response_
> 
> ## Comments on the Issue
> 
> Comment from @lstein:
> I am reopening this issue. I'm running ollama and invoke on the same
server and I find their memory requirements are frequently clashing. It
would be helpful to offer users the option to have the model cache
automatically cleared after a fixed amount of inactivity. I would
suggest the following:
> 
> 1. Introduce a new config file option `model_cache_keep_alive` which
specifies, in minutes, how long to keep a model in cache between
generations. The default is 0, which means to keep the model in cache
indefinitely, as is currently the case.
> 2. If no model generations occur within the timeout period, the model
cache is cleared using the same backend code as the "Clear Model Cache"
button in the queue tab.
> 
> I'm going to assign this to GitHub copilot, partly to test how well it
can manage the Invoke code base.



- Fixes invoke-ai/InvokeAI#6856
2026-01-05 01:41:40 +01:00
Alexander Eichhorn
f29820a7ba feat(ui): improve Z-Image model selector UX with auto-clearing conflicts
Instead of disabling mutually exclusive model selectors, automatically
clear conflicting models when a new selection is made. This applies to
VAE, Qwen3 Encoder, and Qwen3 Source selectors - selecting one now
clears the others. Also applies same logic during metadata recall.
2026-01-05 00:57:45 +01:00
Lincoln Stein
47a634d8fb fix(naming style) change name of model_cache_keep_alive to model_cache_keep_alive_min 2026-01-04 17:36:55 -05:00
Lincoln Stein
768f3dbde0 chore: remove codeowners from /docs directory 2026-01-04 17:08:45 -05:00
Alexander Eichhorn
1ca589ea10 Merge branch 'main' into z-image_metadata_node 2026-01-04 23:07:06 +01:00
Jonathan
3a21e7699f Merge branch 'main' into copilot/add-unload-model-option 2026-01-04 10:22:44 -05:00
Lincoln Stein
56fd7bc7c4 docs(z-image) add Z-Image requirements and starter bundle (#8734)
* docs(z-image) add minimum requirements for Z-Image and create Z-Image starter bundle

* fix(model manager) add flux VAE to Z-Image bundle

* docs(model manager) remove out-of-date model info link

* chore: fix frontend checks

* chore: lint:prettier

* docs(model manager): clarify minimum hardware for z-image turbo

* (fix) add flux VAE to ZIT starter dependencies & tweak UI docs
2026-01-04 10:17:26 -05:00
Lincoln Stein
2425005aad chore: typegen update 2026-01-04 09:28:43 -05:00
Lincoln Stein
2ccadd1834 Merge branch 'main' into z-image_metadata_node 2026-01-04 07:03:25 -05:00
Lincoln Stein
5cef8bd364 (fix) default timeout to 0 min, to disable timeout feature and restore previous default behavior 2026-01-04 07:01:01 -05:00
Lincoln Stein
8a6d593fe8 Merge branch 'main' into copilot/add-unload-model-option 2026-01-03 22:48:36 -05:00
Lincoln Stein
14309562b8 chore: typegen 2026-01-03 22:48:19 -05:00
Alexander Eichhorn
9f8f9965f9 fix(model-loaders): add local_files_only=True to prevent network requests (#8735) 2026-01-03 22:21:42 -05:00
Jonathan
44a21a348d Merge branch 'main' into copilot/add-unload-model-option 2026-01-03 22:00:11 -05:00
Alexander Eichhorn
81d83d5aab Merge branch 'main' into z-image_metadata_node 2026-01-03 23:06:42 +01:00
Alexander Eichhorn
d99707fdcb fix(ui): fix z-image scheduler recall by reordering metadata handlers
Move Scheduler handler after MainModel in ImageMetadataHandlers so that
base-dependent recall logic (z-image scheduler) works correctly. The
Scheduler handler checks `base === 'z-image'` before dispatching the
z-image scheduler action, but this check failed when Scheduler ran
before MainModel was recalled.
2026-01-03 22:33:18 +01:00
dunkeroni
252dd5b426 Add @dunkeroni as code owner for some paths (#8732)
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-01-03 20:56:21 +00:00
Alexander Eichhorn
f922f6c634 Update CODEOWNERS (#8731)
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-01-03 20:53:16 +00:00
Alexander Eichhorn
be0cbe046c feat(flux): add scheduler selection for Flux models (#8704)
* feat(flux): add scheduler selection for Flux models

Add support for alternative diffusers Flow Matching schedulers:
- Euler (default, 1st order)
- Heun (2nd order, better quality, 2x slower)
- LCM (optimized for few steps)

Backend:
- Add schedulers.py with scheduler type definitions and class mapping
- Modify denoise.py to accept optional scheduler parameter
- Add scheduler InputField to flux_denoise invocation (v4.2.0)

Frontend:
- Add fluxScheduler to Redux state and paramsSlice
- Create ParamFluxScheduler component for Linear UI
- Add scheduler to buildFLUXGraph for generation

* fix(flux): prevent progress percentage overflow with LCM scheduler

LCM scheduler may have more internal timesteps than user-facing steps,
causing user_step to exceed total_steps. This resulted in progress
percentage > 1.0, which caused a pydantic validation error.

Fix: Only call step_callback when user_step <= total_steps.

* Ruff format

* fix(flux): remove initial step-0 callback for consistent step count

Remove the initial step_callback at step=0 to match SD/SDXL behavior.
Previously Flux showed N+1 steps (step 0 + N denoising steps), while
SD/SDXL showed only N steps. Now all models display N steps consistently.

* feat(flux): add scheduler support with metadata recall

- Handle LCM scheduler by using num_inference_steps instead of custom sigmas
- Fix progress bar to show user-facing steps instead of internal scheduler steps
- Pass scheduler parameter to Flux denoise node in graph builder
- Add model-aware metadata recall for Flux scheduler

---------

Co-authored-by: Jonathan <34005131+JPPhoto@users.noreply.github.com>
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-01-03 15:52:00 -05:00
Jonathan
e39b880f6d Merge branch 'main' into copilot/add-unload-model-option 2026-01-03 15:41:59 -05:00
Jonathan
4f8ec07d2f Update CODEOWNERS (#8728)
Adding @JPPhoto to CODEOWNERS

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-01-03 20:40:27 +00:00
Alexander Eichhorn
689953e3cf Feature/zimage scheduler support (#8705)
* feat(flux): add scheduler selection for Flux models

Add support for alternative diffusers Flow Matching schedulers:
- Euler (default, 1st order)
- Heun (2nd order, better quality, 2x slower)
- LCM (optimized for few steps)

Backend:
- Add schedulers.py with scheduler type definitions and class mapping
- Modify denoise.py to accept optional scheduler parameter
- Add scheduler InputField to flux_denoise invocation (v4.2.0)

Frontend:
- Add fluxScheduler to Redux state and paramsSlice
- Create ParamFluxScheduler component for Linear UI
- Add scheduler to buildFLUXGraph for generation

* feat(z-image): add scheduler selection for Z-Image models

Add support for alternative diffusers Flow Matching schedulers for Z-Image:
- Euler (default) - 1st order, optimized for Z-Image-Turbo (8 steps)
- Heun (2nd order) - Better quality, 2x slower
- LCM - Optimized for few-step generation

Backend:
- Extend schedulers.py with Z-Image scheduler types and mapping
- Add scheduler InputField to z_image_denoise invocation (v1.3.0)
- Refactor denoising loop to support diffusers schedulers

Frontend:
- Add zImageScheduler to Redux state in paramsSlice
- Create ParamZImageScheduler component for Linear UI
- Add scheduler to buildZImageGraph for generation

* fix ruff check

* fix(schedulers): prevent progress percentage overflow with LCM scheduler

LCM scheduler may have more internal timesteps than user-facing steps,
causing user_step to exceed total_steps. This resulted in progress
percentage > 1.0, which caused a pydantic validation error.

Fix: Only call step_callback when user_step <= total_steps.

* Ruff format

* fix(schedulers): remove initial step-0 callback for consistent step count

Remove the initial step_callback at step=0 to match SD/SDXL behavior.
Previously Flux/Z-Image showed N+1 steps (step 0 + N denoising steps),
while SD/SDXL showed only N steps. Now all models display N steps
consistently in the server log.

* feat(z-image): add scheduler support with metadata recall

- Handle LCM scheduler by using num_inference_steps instead of custom sigmas
- Fix progress bar to show user-facing steps instead of internal scheduler steps
- Pass scheduler parameter to Z-Image denoise node in graph builder
- Add model-aware metadata recall for Flux and Z-Image schedulers

---------

Co-authored-by: Jonathan <34005131+JPPhoto@users.noreply.github.com>
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-01-03 20:37:04 +00:00
Lincoln Stein
61c2589e39 (chore) update WhatsNew translation text (#8727) 2026-01-03 20:31:50 +00:00
Lincoln Stein
8cf4c6944a (style) ruff fix 2026-01-03 14:54:15 -05:00
Lincoln Stein
db228ddc4f (style) add @record_activity and @synchronized to locked methods 2026-01-03 14:52:31 -05:00
Lincoln Stein
858c94b575 Merge remote-tracking branch 'refs/remotes/origin/copilot/add-unload-model-option' into copilot/add-unload-model-option 2026-01-03 14:26:20 -05:00
Alexander Eichhorn
252794d717 ruff fix 2026-01-03 19:50:08 +01:00
Alexander Eichhorn
7847ccea13 fix typegen 2026-01-03 19:48:11 +01:00
Alexander Eichhorn
1bcf589d19 feat(z-image): add Z-Image Denoise + Metadata node
Add ZImageDenoiseMetaInvocation that extends ZImageDenoiseInvocation
with metadata output for image recall. Captures generation parameters
including steps, guidance, scheduler, seed, model, and LoRAs.
2026-01-03 18:28:17 +01:00
Alexander Eichhorn
132a48497b feat(z-image): add scheduler support with metadata recall
- Handle LCM scheduler by using num_inference_steps instead of custom sigmas
- Fix progress bar to show user-facing steps instead of internal scheduler steps
- Pass scheduler parameter to Z-Image denoise node in graph builder
- Add model-aware metadata recall for Flux and Z-Image schedulers
2026-01-03 17:11:05 +01:00
Jonathan
f49e1b8dae Merge branch 'main' into copilot/add-unload-model-option 2026-01-01 21:31:08 -05:00
Jonathan
e7233efb79 Merge branch 'main' into feature/zimage-scheduler-support 2026-01-01 21:30:44 -05:00
Alexander Eichhorn
3b2d2ef10a fix(gguf): ensure dequantized tensors are on correct device for MPS (#8713)
When using GGUF-quantized models on MPS (Apple Silicon), the
dequantized tensors could end up on a different device than the
other operands in math operations, causing "Expected all tensors
to be on the same device" errors.

This fix ensures that after dequantization, tensors are moved to
the same device as the other tensors in the operation.
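A hedged sketch of the idea: after dequantizing, move the result onto the device (and dtype) of the other operand before doing the math; the dequantization call here is a hypothetical GGML-tensor helper, not necessarily the repo's exact API:

```python
import torch

def dequantize_and_align(quantized, other: torch.Tensor) -> torch.Tensor:
    """Illustrative only: dequantize a GGUF tensor and align it with the other
    operand so MPS ops see matching devices."""
    dequantized = quantized.get_dequantized_tensor()  # hypothetical helper name
    return dequantized.to(device=other.device, dtype=other.dtype)
```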

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-01-02 00:45:50 +00:00
Alexander Eichhorn
66974841f1 fix(model-manager): support offline Qwen3 tokenizer loading for Z-Image (#8719)
Add local_files_only fallback for Qwen3 tokenizer loading in both
Checkpoint and GGUF loaders. This ensures Z-Image models can generate
images offline after the initial tokenizer download.

The tokenizer is now loaded with local_files_only=True first, falling
back to network download only if files aren't cached yet.
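A minimal sketch of the local-first loading pattern with transformers' `AutoTokenizer` (which accepts `local_files_only`); the repo id and exception handling here are assumptions for illustration:

```python
from transformers import AutoTokenizer

def load_qwen3_tokenizer(repo_id: str = "Qwen/Qwen3-4B"):  # repo id is illustrative
    """Try the local Hugging Face cache first so offline generation works; fall
    back to a one-time network download only if the files were never cached."""
    try:
        return AutoTokenizer.from_pretrained(repo_id, local_files_only=True)
    except OSError:
        return AutoTokenizer.from_pretrained(repo_id, local_files_only=False)
```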

Fixes #8716

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-01-02 00:40:08 +00:00
Lincoln Stein
87608ade45 (chore) update config docstrings 2026-01-01 19:35:15 -05:00
Weblate (bot)
1e83aeeb79 ui: translations update from weblate (#8725)
* translationBot(ui): update translation (Italian)

Currently translated at 98.4% (2099 of 2132 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.4% (2130 of 2163 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.4% (2130 of 2163 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation (Japanese)

Currently translated at 99.6% (2155 of 2163 strings)

Co-authored-by: RyoKoba <kobayashi_ryo@cyberagent.co.jp>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ja/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation files

Updated by "Cleanup translation files" hook in Weblate.

translationBot(ui): update translation files

Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation (Italian)

Currently translated at 98.4% (2103 of 2136 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

* translationBot(ui): added translation (English (United Kingdom))

---------

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Co-authored-by: RyoKoba <kobayashi_ryo@cyberagent.co.jp>
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2026-01-02 00:35:09 +00:00
Alex Yankov
1c76d295a2 fix(docs) Bump versions in mkdocs github actions (#8722) 2026-01-01 19:31:33 -05:00
Lincoln Stein
384250ff8c Merge branch 'main' into copilot/add-unload-model-option 2026-01-01 19:28:45 -05:00
Lincoln Stein
6c3ce8e7e9 Merge branch 'main' into feature/zimage-scheduler-support 2026-01-01 19:08:56 -05:00
Weblate (bot)
d658ef4322 ui: translations update from weblate (#8724)
* translationBot(ui): update translation (Italian)

Currently translated at 98.4% (2099 of 2132 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.4% (2130 of 2163 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.4% (2130 of 2163 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation (Japanese)

Currently translated at 99.6% (2155 of 2163 strings)

Co-authored-by: RyoKoba <kobayashi_ryo@cyberagent.co.jp>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ja/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation files

Updated by "Cleanup translation files" hook in Weblate.

translationBot(ui): update translation files

Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation (Italian)

Currently translated at 98.4% (2103 of 2136 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

---------

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Co-authored-by: RyoKoba <kobayashi_ryo@cyberagent.co.jp>
2025-12-30 13:21:36 -05:00
Alexander Eichhorn
8d880ef5a0 fix(schedulers): remove initial step-0 callback for consistent step count
Remove the initial step_callback at step=0 to match SD/SDXL behavior.
Previously Flux/Z-Image showed N+1 steps (step 0 + N denoising steps),
while SD/SDXL showed only N steps. Now all models display N steps
consistently in the server log.
2025-12-29 12:39:39 +01:00
Lincoln Stein
c6775cc999 (style) ruff and typegen updates 2025-12-28 22:40:36 -05:00
Lincoln Stein
d44b99ae0a Merge branch 'main' into copilot/add-unload-model-option 2025-12-28 22:39:45 -05:00
blessedcoolant
1675712094 Implement PBR Maps Node (#8700)
* feat: Implement PBR Maps Generation Node

* feat(ui): Add PBR Maps Generation to UI

* chore: fix typegen checks

* chore: possible fix for nvidia 5000 series cards

* fix: Use safetensor models for PBR maps instead of pickles.

* fix: incorrect naming of upconv_block for PBR network

* fix: incorrect naming of displacement map variable

* chore: add relevant docs to the PBR generate function

* fix: clear cuda cache after loading state_dict for PBR maps

* fix: load torch_device only once as multiple models are loaded

* chore(ui): update the filter icon for PBR to CubeBold

More relevant

---------

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2025-12-29 02:11:46 +00:00
Kyle H
2924d052c5 Fix an issue with regional guidance and multiple quick-queued generations after moving bbox (#8613)
* Fix an issue with multiple quick-queued generations after moving bbox

After moving the canvas bbox we still handed out the previous regional-guidance mask because only two parts of the system knew anything had changed. The adapter’s
cache key doesn’t include the bbox, so the next few graph builds reused the stale mask from before the move; if the user queued several runs back‑to‑back, every
background enqueue except the last skipped rerasterizing altogether because another raster job was still in flight. The fix makes the canvas manager invalidate each
region adapter’s cached mask whenever the bbox (or a related setting) changes, and—if a reraster is already running—queues up and waits instead of bailing. Now the
first run after a bbox edit forces a new mask, and rapid-fire enqueues just wait their turn, so every queued generation gets the correct regional prompt.

* (fix) Update invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts

Fixes race condition identified during copilot review.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-12-29 02:01:21 +00:00
Lincoln Stein
f1624a6215 Merge branch 'main' into copilot/add-unload-model-option 2025-12-28 20:38:42 -05:00
Alexander Eichhorn
b7e28e4fa6 fix(ui): make Z-Image model selects mutually exclusive (#8717)
* fix(ui): make Z-Image model selects mutually exclusive

VAE and Qwen3 Encoder selects are disabled when Qwen3 Source is selected,
and vice versa. This prevents invalid model combinations.

* feat(ui): auto-select Z-Image component models on model change

When switching to a Z-Image model, automatically set valid defaults
if no configuration exists:
- Prefers Qwen3 Source (Diffusers model) if available
- Falls back to Qwen3 Encoder + FLUX VAE combination

This ensures the generate button is enabled immediately after selecting
a Z-Image model, without requiring manual configuration.

* fix(ui): save and restore Qwen3 Source model in metadata

Qwen3 Source (Diffusers Z-Image) model was not being saved to image
metadata or restored during Remix. This adds:
- Saving qwen3_source to metadata in buildZImageGraph
- ZImageQwen3SourceModel metadata handler for parsing and recall
- i18n translation for qwen3Source
2025-12-28 20:25:35 -05:00
Alexander Eichhorn
d7d051200f fix(z_image): use unrestricted image self-attention for regional prompting (#8718)
Changes image self-attention from restricted (region-isolated) to unrestricted
(all image tokens can attend to each other), similar to the FLUX approach.

This fixes the issue where ZImage-Turbo with multiple regional guidance layers
would generate two separate/disconnected images instead of compositing them
into a single unified image.

The regional text-image attention remains restricted so that each region still
responds to its corresponding prompt.

Fixes #8715
2025-12-28 11:32:50 -05:00
Alexander Eichhorn
0f830ddd00 Ruff format 2025-12-28 12:37:21 +01:00
Alexander Eichhorn
9617140b7f Merge branch 'feature/zimage-scheduler-support' of https://github.com/Pfannkuchensack/InvokeAI into feature/zimage-scheduler-support 2025-12-28 12:29:19 +01:00
Alexander Eichhorn
bc4783028f Merge branch 'main' into feature/zimage-scheduler-support 2025-12-28 12:29:14 +01:00
Alexander Eichhorn
16fedfb538 fix(schedulers): prevent progress percentage overflow with LCM scheduler
LCM scheduler may have more internal timesteps than user-facing steps,
causing user_step to exceed total_steps. This resulted in progress
percentage > 1.0, which caused a pydantic validation error.

Fix: Only call step_callback when user_step <= total_steps.
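A minimal sketch of the guard inside a denoising loop; `step` and `step_callback` are hypothetical callables standing in for the actual invocation code:

```python
def run_denoise_loop(scheduler, latents, total_steps: int, step, step_callback):
    """Illustrative only: LCM may expose more internal timesteps than
    user-facing steps, so the callback fires only while user_step <=
    total_steps, keeping the reported progress <= 1.0."""
    for i, timestep in enumerate(scheduler.timesteps):
        latents = step(latents, timestep)
        user_step = i + 1
        if user_step <= total_steps:
            step_callback(user_step, total_steps)
    return latents
```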
2025-12-28 12:22:28 +01:00
Lincoln Stein
d781a3b8a2 Merge branch 'main' into copilot/add-unload-model-option 2025-12-27 23:27:19 -05:00
blessedcoolant
7182ff26dc fix(ui): misaligned Color Compensation Option (#8714) 2025-12-27 23:11:48 -05:00
Lincoln Stein
95ee27d5c0 Merge branch 'main' into copilot/add-unload-model-option 2025-12-27 21:56:55 -05:00
Lincoln Stein
b4f05d3fe7 Merge branch 'main' into feature/zimage-scheduler-support 2025-12-27 21:50:05 -05:00
Josh Corbett
8deafabe6b feat(prompts): 💄 increase prompt font size (#8712)
* feat(prompts): 💄 increase prompt font size

* style(prompts): 🚨 satisfy linter
2025-12-27 21:18:23 -05:00
copilot-swe-agent[bot]
1bd1c76a2c Change default model_cache_keep_alive to 5 minutes
Changed the default value of model_cache_keep_alive from 0 (indefinite)
to 5 minutes as requested. This means models will now be automatically
cleared from cache after 5 minutes of inactivity by default, unless
users explicitly configure a different value.

Users can still set it to 0 in their config to get the old behavior
of keeping models indefinitely.

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
2025-12-28 02:11:20 +00:00
Lincoln Stein
56fd1da888 Merge branch 'main' into copilot/add-unload-model-option 2025-12-27 21:08:17 -05:00
Alexander Eichhorn
0956ce0cd3 Merge branch 'main' into feature/zimage-scheduler-support 2025-12-28 00:44:10 +01:00
blessedcoolant
d42bf9c941 fix(model-manager): add Z-Image LoRA/DoRA detection support (#8709)
## Summary

Fix Z-Image LoRA/DoRA model detection failing during installation.

Z-Image LoRAs use different key patterns than SD/SDXL LoRAs. The base
`LoRA_LyCORIS_Config_Base` class only checked for key suffixes like
`lora_A.weight` and `lora_B.weight`, but Z-Image LoRAs (especially those
in DoRA format) use:
- `lora_down.weight` / `lora_up.weight` (standard LoRA format)
- `dora_scale` (DoRA weight decomposition)

This PR overrides `_validate_looks_like_lora` in
`LoRA_LyCORIS_ZImage_Config` to recognize Z-Image specific patterns:
- Keys starting with `diffusion_model.layers.` (Z-Image S3-DiT
architecture)
- Keys ending with `lora_down.weight`, `lora_up.weight`,
`lora_A.weight`, `lora_B.weight`, or `dora_scale`

## Related Issues / Discussions

Fixes installation of Z-Image LoRAs trained with DoRA (Weight-Decomposed
Low-Rank Adaptation).

## QA Instructions

1. Download a Z-Image LoRA in DoRA format (e.g., from CivitAI with keys
like `diffusion_model.layers.X.attention.to_k.lora_down.weight`)
2. Try to install the LoRA via Model Manager
3. Verify the model is recognized as a Z-Image LoRA and installs
successfully
4. Verify the LoRA can be applied when generating with Z-Image

## Merge Plan

Standard merge, no special considerations.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Changes to a redux slice have a corresponding migration_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-12-27 23:10:06 +05:30
Alexander Eichhorn
d403587c7f Merge branch 'fix/z-image-lora-dora-detection' of https://github.com/Pfannkuchensack/InvokeAI into fix/z-image-lora-dora-detection 2025-12-27 09:17:33 +01:00
Alexander Eichhorn
355c985cc3 fix(model-manager): add Z-Image LoRA/DoRA detection and loading support
Two fixes for Z-Image LoRA support:

1. Override _validate_looks_like_lora in LoRA_LyCORIS_ZImage_Config to
   recognize Z-Image specific LoRA formats that use different key patterns
   than SD/SDXL LoRAs. Z-Image LoRAs use lora_down.weight/lora_up.weight
   and dora_scale suffixes instead of lora_A.weight/lora_B.weight.

2. Fix _group_by_layer in z_image_lora_conversion_utils.py to correctly
   group LoRA keys by layer name. The previous logic used rsplit with
   maxsplit=2 which incorrectly grouped keys like:
   - "to_k.alpha" -> layer "diffusion_model.layers.17.attention"
   - "lora_down.weight" -> layer "diffusion_model.layers.17.attention.to_k"

   Now uses suffix matching to ensure all keys for a layer are grouped
   together (alpha, dora_scale, lora_down.weight, lora_up.weight).
2025-12-27 09:17:29 +01:00
Alexander Eichhorn
41742146e2 fix(model-manager): add Z-Image LoRA/DoRA detection support
Override _validate_looks_like_lora in LoRA_LyCORIS_ZImage_Config to
recognize Z-Image specific LoRA formats that use different key patterns
than SD/SDXL LoRAs.

Z-Image LoRAs (including DoRA format) use keys like:
- diffusion_model.layers.X.attention.to_k.lora_down.weight
- diffusion_model.layers.X.attention.to_k.dora_scale

The base LyCORIS config only checked for lora_A.weight/lora_B.weight
suffixes, missing the lora_down.weight/lora_up.weight and dora_scale
patterns used by Z-Image LoRAs.
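A minimal sketch of that kind of key-pattern check; the class and method names come from the commit above, but this body is an illustrative assumption, not the repo's exact implementation:

```python
ZIMAGE_LORA_SUFFIXES = (
    "lora_down.weight",
    "lora_up.weight",
    "lora_A.weight",
    "lora_B.weight",
    "dora_scale",
)

def looks_like_z_image_lora(state_dict: dict) -> bool:
    """Keys must target the Z-Image S3-DiT layer prefix and end with a known
    LoRA/DoRA suffix (alpha keys also appear in these checkpoints)."""
    layer_keys = [k for k in state_dict if k.startswith("diffusion_model.layers.")]
    return bool(layer_keys) and all(
        k.endswith(ZIMAGE_LORA_SUFFIXES) or k.endswith(".alpha") for k in layer_keys
    )
```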
2025-12-27 07:06:12 +01:00
Jonathan
eb516e1998 Merge branch 'main' into feature/zimage-scheduler-support 2025-12-26 22:06:49 -05:00
Lincoln Stein
0b1befa9ab (chore) Prep for v6.10.0rc2 (#8701) 2025-12-26 18:26:04 -05:00
Alexander Eichhorn
bd678b1c95 fix ruff check 2025-12-26 21:22:46 +01:00
Alexander Eichhorn
56bef0b089 feat(z-image): add scheduler selection for Z-Image models
Add support for alternative diffusers Flow Matching schedulers for Z-Image:
- Euler (default) - 1st order, optimized for Z-Image-Turbo (8 steps)
- Heun (2nd order) - Better quality, 2x slower
- LCM - Optimized for few-step generation

Backend:
- Extend schedulers.py with Z-Image scheduler types and mapping
- Add scheduler InputField to z_image_denoise invocation (v1.3.0)
- Refactor denoising loop to support diffusers schedulers

Frontend:
- Add zImageScheduler to Redux state in paramsSlice
- Create ParamZImageScheduler component for Linear UI
- Add scheduler to buildZImageGraph for generation
2025-12-26 21:15:26 +01:00
Alexander Eichhorn
99fc1243cb feat(flux): add scheduler selection for Flux models
Add support for alternative diffusers Flow Matching schedulers:
- Euler (default, 1st order)
- Heun (2nd order, better quality, 2x slower)
- LCM (optimized for few steps)

Backend:
- Add schedulers.py with scheduler type definitions and class mapping (a mapping sketch follows this message)
- Modify denoise.py to accept optional scheduler parameter
- Add scheduler InputField to flux_denoise invocation (v4.2.0)

Frontend:
- Add fluxScheduler to Redux state and paramsSlice
- Create ParamFluxScheduler component for Linear UI
- Add scheduler to buildFLUXGraph for generation
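A minimal sketch of the name-to-class mapping such a schedulers.py could hold, built on diffusers' flow-matching schedulers; the literal option names are assumptions for illustration:

```python
from diffusers import (
    FlowMatchEulerDiscreteScheduler,
    FlowMatchHeunDiscreteScheduler,
    LCMScheduler,
)

# Illustrative mapping; option strings are assumed, not the repo's exact values.
FLUX_SCHEDULER_MAP = {
    "euler": FlowMatchEulerDiscreteScheduler,  # default, 1st order
    "heun": FlowMatchHeunDiscreteScheduler,    # 2nd order, better quality, ~2x slower
    "lcm": LCMScheduler,                       # optimized for few steps
}

def build_scheduler(name: str = "euler"):
    return FLUX_SCHEDULER_MAP[name]()
```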
2025-12-26 20:53:59 +01:00
Lincoln Stein
a7205e4e36 Merge branch 'main' into copilot/add-unload-model-option 2025-12-25 21:33:59 -05:00
Alexander Eichhorn
65efc3db7d Feature: Add Z-Image-Turbo regional guidance (#8672)
* feat: Add Regional Guidance support for Z-Image model

Implements regional prompting for Z-Image (S3-DiT Transformer) allowing
different prompts to affect different image regions using attention masks.

Backend changes:
- Add ZImageRegionalPromptingExtension for mask preparation
- Add ZImageTextConditioning and ZImageRegionalTextConditioning data classes
- Patch transformer forward to inject 4D regional attention masks
- Use additive float mask (0.0 attend, -inf block) in bfloat16 for compatibility
- Alternate regional/full attention layers for global coherence

Frontend changes:
- Update buildZImageGraph to support regional conditioning collectors
- Update addRegions to create z_image_text_encoder nodes for regions
- Update addZImageLoRAs to handle optional negCond when guidance_scale=0
- Add Z-Image validation (no IP adapters, no autoNegative)

* @Pfannkuchensack
Fix windows path again

* ruff check fix

* ruff formating

* fix(ui): Z-Image CFG guidance_scale check uses > 1 instead of > 0

Changed the guidance_scale check from > 0 to > 1 for Z-Image models.
Since Z-Image uses guidance_scale=1.0 as "no CFG" (matching FLUX convention),
negative conditioning should only be created when guidance_scale > 1.

---------

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2025-12-26 02:25:38 +00:00
Lincoln Stein
de1aa557b8 chore: bump version to v6.10.0rc1 (#8695)
* chore: bump version to v6.10.0rc1

* docs: fix names of code owners in release doc
2025-12-26 02:08:14 +00:00
Lincoln Stein
b9493ddce7 Workaround for Windows being unable to remove tmp directories when installing GGUF files (#8699)
* (bugfix)(mm) work around Windows being unable to rmtree tmp directories after GGUF install

* (style) fix ruff error

* (fix) add workaround for Windows Permission Denied on GGUF file move() call

* (fix) perform torch copy() in GGUF reader to avoid deletion failures on Windows

* (style) fix ruff formatting issues
2025-12-26 02:02:39 +00:00
Lincoln Stein
ca14c5c9e1 Merge branch 'main' into copilot/add-unload-model-option 2025-12-25 00:08:28 -05:00
Josh Corbett
ddb85ca669 fix(prompts): 🐛 prompt attention behaviors, add tests (#8683)
* fix(prompts): 🐛 prompt attention adjust elevation edge cases, added tests

* refactor(prompts): ♻️ create attention edit helper for prompt boxes

* feat(prompts):  apply attention keybinds to negative prompt

* feat(prompts): 🚀 reconsider behaviors, simplify code

* fix(prompts): 🐛 keybind attention update not tracked by undo/redo

* feat(prompts):  overhaul prompt attention behavior

* fix(prompts): 🩹 remove unused type

* fix(prompts): 🩹 remove unused `Token` type

---------

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2025-12-24 17:38:24 -05:00
Lincoln Stein
5b69403ba8 Merge branch 'main' into copilot/add-unload-model-option 2025-12-24 15:39:46 -05:00
Alexander Eichhorn
ac245cbf6c feat(backend): add support for xlabs Flux LoRA format (#8686)
Add support for loading Flux LoRA models in the xlabs format, which uses
keys like `double_blocks.X.processor.{qkv|proj}_lora{1|2}.{down|up}.weight`.

The xlabs format maps:
- lora1 -> img_attn (image attention stream)
- lora2 -> txt_attn (text attention stream)
- qkv -> query/key/value projection
- proj -> output projection

Changes:
- Add FluxLoRAFormat.XLabs enum value
- Add flux_xlabs_lora_conversion_utils.py with detection and conversion (detection sketch below)
- Update formats.py to detect xlabs format
- Update lora.py loader to handle xlabs format
- Update model probe to accept recognized Flux LoRA formats
- Add unit tests for xlabs format detection and conversion
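A minimal sketch of detecting the xlabs key layout described above; this regex check is illustrative, not the repo's exact probe:

```python
import re

# Keys look like: double_blocks.<i>.processor.{qkv|proj}_lora{1|2}.{down|up}.weight
_XLABS_KEY = re.compile(
    r"^double_blocks\.\d+\.processor\.(qkv|proj)_lora[12]\.(down|up)\.weight$"
)

def is_xlabs_flux_lora(state_dict: dict) -> bool:
    """Illustrative detection: every key matches the xlabs pattern."""
    keys = list(state_dict)
    return bool(keys) and all(_XLABS_KEY.match(k) for k in keys)
```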

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2025-12-24 20:18:11 +00:00
Alexander Eichhorn
5be1e03d73 Feature/user workflow tags (#8698)
* Feature: Add Tag System for user made Workflows

* feat(ui): display tags on workflow library tiles

Show workflow tags at the bottom of each tile in the workflow browser,
making it easier to identify workflow categories at a glance.

---------

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2025-12-24 14:54:22 -05:00
Josh Corbett
87314142b5 feat(hotkeys modal): loading state + performance improvements (#8694)
* feat(hotkeys modal):  loading state + performance improvements

* feat(hotkeys modal): add tooltip to edit button and adjust layout spacing

* style(hotkeys modal): 🚨 satisfy the linter

---------

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2025-12-24 14:39:14 -05:00
Alexander Eichhorn
4cb9b8d97d Feature: add prompt template node (#8680)
* feat(nodes): add Prompt Template node

Add a new node that applies Style Preset templates to prompts in workflows.
The node takes a style preset ID and positive/negative prompts as inputs,
then replaces {prompt} placeholders in the template with the provided prompts.

This makes Style Preset templates accessible in Workflow mode, enabling
users to apply consistent styling across their workflow-based generations.

* feat(nodes): add StylePresetField for database-driven preset selection

Adds a new StylePresetField type that enables dropdown selection of
style presets from the database in the workflow editor.

Changes:
- Add StylePresetField to backend (fields.py)
- Update Prompt Template node to use StylePresetField instead of string ID
- Add frontend field type definitions (zod schemas, type guards)
- Create StylePresetFieldInputComponent with Combobox
- Register field in InputFieldRenderer and nodesSlice
- Add translations for preset selection

* fix schema.ts on windows.

* chore(api): regenerate schema.ts after merge

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-12-24 14:33:16 -05:00
Lincoln Stein
83deb0233e Merge remote-tracking branch 'refs/remotes/origin/copilot/add-unload-model-option' into copilot/add-unload-model-option 2025-12-24 00:44:32 -05:00
Lincoln Stein
8ebb6dd3d9 (chore) regenerate typescript schema 2025-12-24 00:43:06 -05:00
copilot-swe-agent[bot]
b7afd9b5b3 Fix test failures caused by MagicMock TypeError
Configure mock logger to return a valid log level for getEffectiveLevel()
to prevent TypeError when comparing with logging.DEBUG constant.

The issue was that ModelCache._log_cache_state() checks
self._logger.getEffectiveLevel() > logging.DEBUG, and when the logger
is a MagicMock without configuration, getEffectiveLevel() returns another
MagicMock, causing a TypeError when compared with an int.

Fixes all 4 test failures in test_model_cache_timeout.py

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
2025-12-24 05:42:45 +00:00
copilot-swe-agent[bot]
4987b4da1c Fix timeout message appearing during active generation
Only log "Clearing model cache" message when there are actually unlocked
models to clear. This prevents the misleading message from appearing during
active generation when all models are locked.

Changes:
- Check for unlocked models before logging clear message
- Add count of unlocked models in log message
- Add debug log when all models are locked
- Improves user experience by avoiding confusing messages

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
2025-12-24 05:31:11 +00:00
Lincoln Stein
a21b7792d8 (chore) regenerate config docstrings 2025-12-24 00:29:48 -05:00
Lincoln Stein
8819cc30be (chore) regenerate schema.ts 2025-12-24 00:28:55 -05:00
Lincoln Stein
9d1de81fe2 (style) correct ruff formatting error 2025-12-24 00:19:25 -05:00
Lincoln Stein
1e15b8c106 Merge branch 'main' into copilot/add-unload-model-option 2025-12-24 00:14:45 -05:00
Alexander Eichhorn
21138e5d52 fix support multi-subfolder downloads for Z-Image Qwen3 encoder (#8692)
* fix(model-install): support multi-subfolder downloads for Z-Image Qwen3 encoder

The Z-Image Qwen3 text encoder requires both text_encoder and tokenizer
subfolders from the HuggingFace repo, but the previous implementation
only downloaded the text_encoder subfolder, causing model identification
to fail.

Changes:
- Add subfolders property to HFModelSource supporting '+' separated paths (parsing sketch below)
- Extend filter_files() and download_urls() to handle multiple subfolders
- Update _multifile_download() to preserve subfolder structure
- Make Qwen3Encoder probe check both nested and direct config.json paths
- Update Qwen3EncoderLoader to handle both directory structures
- Change starter model source to text_encoder+tokenizer
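A minimal sketch of how a '+'-separated subfolder spec could be parsed and used to filter repo files; names here are illustrative rather than the actual HFModelSource code:

```python
from pathlib import PurePosixPath

def parse_subfolders(spec: str) -> list[PurePosixPath]:
    """Split a spec such as 'text_encoder+tokenizer' into subfolder paths."""
    return [PurePosixPath(part) for part in spec.split("+") if part]

def filter_repo_files(files: list[str], subfolders: list[PurePosixPath]) -> list[str]:
    """Keep only files that live under one of the requested subfolders,
    preserving the subfolder structure when the paths are reused later."""
    if not subfolders:
        return files
    return [
        f for f in files
        if any(PurePosixPath(f).is_relative_to(sub) for sub in subfolders)
    ]
```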

* ruff format

* fix schema description

* fix schema description

---------

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2025-12-23 23:39:43 -05:00
copilot-swe-agent[bot]
8d76b4e4d4 Fix ruff whitespace errors and improve timeout logging
- Remove all trailing whitespace (W293 errors)
- Add debug logging when timeout fires but activity detected
- Add debug logging when timeout fires but cache is empty
- Only log "Clearing model cache" message when actually clearing
- Prevents misleading timeout messages during active generation

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
2025-12-24 04:05:57 +00:00
Lincoln Stein
9662d1fdb6 Merge branch 'main' into copilot/add-unload-model-option 2025-12-23 22:48:11 -05:00
Alexander Eichhorn
39114b0ad0 Feature (UI): add model path update for external models (#8675)
* feat(ui): add model path update for external models

Add ability to update file paths for externally managed models (models with
absolute paths). Invoke-controlled models (with relative paths in the models
directory) are excluded from this feature to prevent breaking internal
model management.

- Add ModelUpdatePathButton component with modal dialog
- Only show button for external models (absolute path check)
- Add translations for path update UI elements

* Added support for Windows UNC paths in ModelView.tsx:38-41. The isExternalModel function now detects:
Unix absolute paths: /home/user/models/...
Windows drive paths: C:\Models\... or D:/Models/...
Windows UNC paths: \\ServerName\ShareName\... or //ServerName/ShareName/...

* fix(ui): validate path format in Update Path modal to prevent invalid paths

When updating an external model's path, the new path is now validated to ensure
it follows an absolute path format (Unix, Windows drive, or UNC). This prevents
users from accidentally entering invalid paths that would cause the Update Path
button to disappear, leaving them unable to correct the mistake.

* fix(ui): extract isExternalModel to separate file to fix circular dependency

Moves the isExternalModel utility function to its own file to break the
circular dependency between ModelView.tsx and ModelUpdatePathButton.tsx.

---------

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2025-12-23 22:46:50 -05:00
Josh Corbett
3fe5f62c48 feat(hotkeys): Overhaul hotkeys modal UI (#8682)
* feat(hotkeys):  overhaul hotkeys modal UI

* fix(model manager): 🩹 improved check for hotkey search clear button

* fix(model manager): 🩹 remove unused exports

* feat(starter-models): add Z-Image Turbo starter models

Add Z-Image Turbo and related models to the starter models list:
- Z-Image Turbo (full precision, ~13GB)
- Z-Image Turbo quantized (GGUF Q4_K, ~4GB)
- Z-Image Qwen3 Text Encoder (full precision, ~8GB)
- Z-Image Qwen3 Text Encoder quantized (GGUF Q6_K, ~3.3GB)
- Z-Image ControlNet Union (Canny, HED, Depth, Pose, MLSD, Inpainting)

The quantized Turbo model includes the quantized Qwen3 encoder as a
dependency for automatic installation.

* feat(starter-models): add Z-Image Q8 quant and ControlNet Tile

Add higher quality Q8_0 quantization option for Z-Image Turbo (~6.6GB)
to complement existing Q4_K variant, providing better quality for users
with more VRAM.

Add dedicated Z-Image ControlNet Tile model (~6.7GB) for upscaling and
detail enhancement workflows.

* feat(hotkeys):  overhaul hotkeys modal UI

* feat(hotkeys modal): 💄 shrink add hotkey button

* fix(hotkeys): normalization and detection issues

* style: 🚨 satisfy the linter

* fix(hotkeys modal): 🩹 remove unused exports

---------

Co-authored-by: Alexander Eichhorn <alex@eichhorn.dev>
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2025-12-23 22:24:00 -05:00
Josh Corbett
73c6b31011 feat(model manager): 💄 refactor model manager bulk actions UI (#8684)
* feat(model manager): 💄 refactor model manager bulk actions UI

* feat(model manager): 💄 tweak model list item ui for checkbox selects

* style(model manager): 🚨 satisfy the linter

* feat(model manager): 💄 tweak search and actions dropdown placement

* refactor(model manager): 🔥 remove unused `ModelListHeader` component

* fix(model manager): 🐛 list items overlapping sticky headers

---------

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2025-12-23 22:17:07 -05:00
copilot-swe-agent[bot]
b16717bbf8 Explicitly pass all ModelCache constructor parameters
- Add explicit storage_device parameter (cpu)
- Add explicit log_memory_usage parameter from config
- Improves code clarity and configuration transparency

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
2025-12-24 00:30:51 +00:00
copilot-swe-agent[bot]
c3217d8a08 Address code review feedback
- Remove unused variable in test
- Add clarifying comment for daemon thread setting
- Add detailed comment explaining cache clearing with 1000 GB value
- Improve code documentation

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
2025-12-24 00:27:39 +00:00
blessedcoolant
f82bcd40fc fix: CFG Scale min value reset to zero (#8691)
No longer needed because Z-Image works at 1.0
2025-12-23 19:27:30 -05:00
copilot-swe-agent[bot]
2500153ed8 Fix race condition in timeout mechanism
- Added clarifying comment that _record_activity is called with lock held
- Enhanced double-check in _on_timeout for thread safety
- Added lock protection to shutdown method
- Improved handling of edge cases where timer fires during activity

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
2025-12-24 00:26:01 +00:00
copilot-swe-agent[bot]
75a14e2a4b Add unit tests for model cache timeout functionality
- Created test_model_cache_timeout.py with comprehensive tests
- Tests timeout clearing behavior
- Tests activity resetting timeout
- Tests no-timeout default behavior
- Tests shutdown canceling timers

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
2025-12-24 00:24:31 +00:00
copilot-swe-agent[bot]
9bbd2b3f11 Add model_cache_keep_alive config option and timeout mechanism
- Added model_cache_keep_alive config field (minutes, default 0 = infinite)
- Implemented timeout tracking in ModelCache class
- Added _record_activity() to track model usage
- Added _on_timeout() to auto-clear cache when timeout expires
- Added shutdown() method to clean up timers
- Integrated timeout with get(), lock(), unlock(), and put() operations
- Updated ModelManagerService to pass keep_alive parameter
- Added cleanup in stop() method

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
2025-12-24 00:22:59 +00:00
copilot-swe-agent[bot]
c26445253c Initial plan 2025-12-24 00:16:17 +00:00
blessedcoolant
5a0b227256 feat(starter-models): add Z-Image Turbo starter models (#8689)
## Summary

Add Z-Image Turbo and related models to the starter models list for easy
installation via the Model Manager:

- **Z-Image Turbo** - Full precision Diffusers format (~13GB)
- **Z-Image Turbo (quantized)** - GGUF Q4_K format (~4GB)
- **Z-Image Qwen3 Text Encoder** - Full precision (~8GB)
- **Z-Image Qwen3 Text Encoder (quantized)** - GGUF Q6_K format (~3.3GB)
- **Z-Image ControlNet Union** - Unified ControlNet supporting Canny,
HED, Depth, Pose, MLSD, and Inpainting modes

The quantized Turbo model includes the quantized Qwen3 encoder as a
dependency for automatic installation.

## Related Issues / Discussions

Builds on the Z-Image Turbo support added in main.

## QA Instructions

1. Open Model Manager → Starter Models
2. Search for "Z-Image"
3. Verify all 5 models appear with correct descriptions
4. Install the quantized version and confirm the Qwen3 encoder
dependency is also installed

## Merge Plan

Standard merge, no special considerations.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Changes to a redux slice have a corresponding migration_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-12-23 08:31:34 +05:30
blessedcoolant
1b5d91d1cf Merge branch 'main' into feat/z-image-starter-models 2025-12-23 08:27:25 +05:30
Alexander Eichhorn
a748519e92 feat(starter-models): add Z-Image Q8 quant and ControlNet Tile
Add higher quality Q8_0 quantization option for Z-Image Turbo (~6.6GB)
to complement existing Q4_K variant, providing better quality for users
with more VRAM.

Add dedicated Z-Image ControlNet Tile model (~6.7GB) for upscaling and
detail enhancement workflows.
2025-12-23 03:27:09 +01:00
blessedcoolant
90e34002f0 fix(z-image): Fix padding token shape mismatch for GGUF models (#8690)
## Summary

Fix shape mismatch when loading GGUF-quantized Z-Image transformer
models.

GGUF Z-Image models store `x_pad_token` and `cap_pad_token` with shape
`[3840]`, but diffusers `ZImageTransformer2DModel` expects `[1, 3840]`
(with batch dimension). This caused a `RuntimeError` on Linux systems
when loading models like `z_image_turbo-Q4_K.gguf`.

The fix:
- Dequantizes GGMLTensors first (since they don't support `unsqueeze`)
- Reshapes the tensors to add the missing batch dimension

## Related Issues / Discussions

Reported by Linux user using:
-
https://huggingface.co/leejet/Z-Image-Turbo-GGUF/resolve/main/z_image_turbo-Q4_K.gguf
-
https://huggingface.co/worstplayer/Z-Image_Qwen_3_4b_text_encoder_GGUF/resolve/main/Qwen_3_4b-Q6_K.gguf

## QA Instructions

1. Install a GGUF-quantized Z-Image model (e.g.,
`z_image_turbo-Q4_K.gguf`)
2. Install a Qwen3 GGUF encoder
3. Run a Z-Image generation
4. Verify no `RuntimeError: size mismatch for x_pad_token` error occurs

## Merge Plan

None, straightforward fix.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Changes to a redux slice have a corresponding migration_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-12-23 06:04:40 +05:30
blessedcoolant
7068cf956a Merge branch 'main' into pr/8690 2025-12-23 05:59:49 +05:30
blessedcoolant
aa764f8bf4 Feature: z-image Turbo Control Net (#8679)
## Summary

Add support for Z-Image ControlNet V2.0 alongside the existing V1
support.

**Key changes:**
- Auto-detect `control_in_dim` from adapter weights (16 for V1, 33 for
V2.0)
- Auto-detect `n_refiner_layers` from state dict
- Add zero-padding for V2.0's additional control channels (diffusers
approach)
- Use `accelerate.init_empty_weights()` for more efficient model
creation
- Add `ControlNet_Checkpoint_ZImage_Config` to frontend schema

## Related Issues / Discussions

Part of Z-Image feature implementation.

## QA Instructions

1. Load a Z-Image ControlNet V1 model (control_in_dim=16) and verify it
works
2. Load a Z-Image ControlNet V2.0 model (control_in_dim=33) and verify
it works
3. Test with different control types: Canny, Depth, Pose
4. Recommended `control_context_scale`: 0.65-0.80

## Merge Plan

Can be merged after review. No special considerations needed.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Changes to a redux slice have a corresponding migration_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-12-23 05:58:58 +05:30
Alexander Eichhorn
73be5e5d35 Merge branch 'main' into feature/z-image-control 2025-12-22 22:56:30 +01:00
DustyShoe
259304bac5 Feature(UI): add extract masked area from raster layers (#8667)
* chore: localize extraction errors

* chore: rename extract masked area menu item

* chore: rename inpaint mask extract component

* fix: use mask bounds for extraction region

* Prettier format applied to InpaintMaskMenuItemsExtractMaskedArea.tsx

* Fix base64 image import bug in extracted area in InpaintMaskMenuItemsExtractMaskedArea.tsx and removed unused locales entries in en.json

* Fix formatting issue in InpaintMaskMenuItemsExtractMaskedArea.tsx

* Minor comment fix

---------

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2025-12-22 15:57:27 -05:00
Alexander Eichhorn
2be701cfe3 Feature: Add Tag System for user made Workflows (#8673)
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2025-12-22 15:41:48 -05:00
blessedcoolant
874b547598 chore: format code for ruff checks 2025-12-23 01:04:22 +05:30
blessedcoolant
7b9ce35806 Merge branch 'main' into pr/8679 2025-12-23 01:03:43 +05:30
Alexander Eichhorn
84f3e44a5d Merge branch 'main' into feat/z-image-starter-models 2025-12-22 20:16:05 +01:00
Alexander Eichhorn
5264b7511c Merge branch 'main' into fix/z-image-gguf-padding-token-shape 2025-12-22 20:15:18 +01:00
Alexander Eichhorn
f8b1f42f6d fix(z-image): Fix padding token shape mismatch for GGUF models
GGUF Z-Image models store x_pad_token and cap_pad_token with shape [dim],
but diffusers ZImageTransformer2DModel expects [1, dim]. This caused a
RuntimeError when loading GGUF-quantized Z-Image models.

The fix dequantizes GGMLTensors first (since they don't support unsqueeze),
then reshapes to add the batch dimension.
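A hedged sketch of the reshape fix; the dequantization helper name is hypothetical:

```python
import torch

def fix_pad_token_shape(tensor) -> torch.Tensor:
    """Illustrative only: GGML-quantized tensors don't support unsqueeze, so
    dequantize first, then add the batch dimension diffusers'
    ZImageTransformer2DModel expects ([dim] -> [1, dim])."""
    if hasattr(tensor, "get_dequantized_tensor"):  # hypothetical GGML tensor API
        tensor = tensor.get_dequantized_tensor()
    if tensor.ndim == 1:
        tensor = tensor.unsqueeze(0)
    return tensor
```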
2025-12-22 18:31:57 +01:00
Josh Corbett
e1acb636d8 fix(ui): 🐛 HotkeysModal and SettingsModal initial focus (#8687)
* fix(ui): 🐛 `HotkeysModal` and `SettingsModal` initial focus

Instead of using the `initialFocusRef` prop, the `Modal` component was focusing on the last available Button. This workaround uses `tabIndex` instead, which seems to work.

Closes #8685

* style: 🚨 satisfy linter

---------

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2025-12-22 11:20:44 -05:00
Alexander Eichhorn
b08accd4be feat(starter-models): add Z-Image Turbo starter models
Add Z-Image Turbo and related models to the starter models list:
- Z-Image Turbo (full precision, ~13GB)
- Z-Image Turbo quantized (GGUF Q4_K, ~4GB)
- Z-Image Qwen3 Text Encoder (full precision, ~8GB)
- Z-Image Qwen3 Text Encoder quantized (GGUF Q6_K, ~3.3GB)
- Z-Image ControlNet Union (Canny, HED, Depth, Pose, MLSD, Inpainting)

The quantized Turbo model includes the quantized Qwen3 encoder as a
dependency for automatic installation.
2025-12-22 15:04:27 +01:00
Alexander Eichhorn
3668d5b83b feat(z-image): add Extension-based Z-Image ControlNet support
Implement Z-Image ControlNet as an Extension pattern (similar to FLUX ControlNet)
instead of merging control weights into the base transformer. This provides:
- Lower memory usage (no weight duplication)
- Flexibility to enable/disable control per step
- Cleaner architecture with separate control adapter

Key implementation details:
- ZImageControlNetExtension: computes control hints per denoising step
- z_image_forward_with_control: custom forward pass with hint injection
- patchify_control_context: utility for control image patchification
- ZImageControlAdapter: standalone adapter with control_layers and noise_refiner

Architecture matches original VideoX-Fun implementation:
- Hints computed ONCE using INITIAL unified state (before main layers)
- Hints injected at every other main transformer layer (15 control blocks)
- Control signal added after each designated layer's forward pass

V2.0 ControlNet support (control_in_dim=33; padding sketch follows the list):
- Channels 0-15: control image latents
- Channels 16-31: reference image (zeros for pure control)
- Channel 32: inpaint mask (1.0 = don't inpaint, use control signal)
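A minimal sketch of the zero-padding for a pure-control V2.0 input, following the channel layout above; shapes and the helper name are assumptions:

```python
import torch

def pad_control_context_v2(control_latents: torch.Tensor) -> torch.Tensor:
    """control_latents: [B, 16, H, W] control image latents (channels 0-15).
    Returns [B, 33, H, W]: reference channels 16-31 zeroed, inpaint-mask
    channel 32 set to 1.0 ("don't inpaint, use control signal")."""
    b, _, h, w = control_latents.shape
    zeros_ref = control_latents.new_zeros(b, 16, h, w)
    inpaint_mask = control_latents.new_ones(b, 1, h, w)
    return torch.cat([control_latents, zeros_ref, inpaint_mask], dim=1)
```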
2025-12-21 22:30:28 +01:00
Alexander Eichhorn
1c13ca8159 style: apply ruff formatting 2025-12-21 18:52:12 +01:00
Alexander Eichhorn
3ed0e55d9d fix: resolve linting errors in Z-Image ControlNet support
- Add missing ControlNet_Checkpoint_ZImage_Config import
- Remove unused imports (Any, Dict, ADALN_EMBED_DIM, is_torch_version)
- Add strict=True to zip() calls
- Replace mutable list defaults with immutable tuples
- Replace dict() calls with literal syntax
- Sort imports in z_image_denoise.py
2025-12-21 18:50:43 +01:00
Alexander Eichhorn
8db8aa8594 Add Z-Image ControlNet V2.0 support
VRAM usage is high.

- Auto-detect control_in_dim from adapter weights (16 for V1, 33 for V2.0)
- Auto-detect n_refiner_layers from state dict
- Add zero-padding for V2.0's additional channels
- Use accelerate.init_empty_weights() for efficient model creation
- Add ControlNet_Checkpoint_ZImage_Config to frontend schema
2025-12-21 18:43:02 +01:00
Alexander Eichhorn
456d578f20 WIP not working.
feat: Add Z-Image ControlNet support with spatial conditioning

Add comprehensive ControlNet support for Z-Image models including:

Backend:
- New ControlNet_Checkpoint_ZImage_Config for Z-Image control adapter models
- Z-Image control key detection (_has_z_image_control_keys) to identify control layers
- ZImageControlAdapter loader for standalone control models
- ZImageControlTransformer2DModel combining base transformer with control layers
- Memory-efficient model loading by building combined state dict
2025-12-21 18:43:02 +01:00
blessedcoolant
ab6b6721dc Feature: Add Z-Image-Turbo model support (#8671)
Add comprehensive support for Z-Image-Turbo (S3-DiT) models including:

Backend:
- New BaseModelType.ZImage in taxonomy
- Z-Image model config classes (ZImageTransformerConfig,
Qwen3TextEncoderConfig)
- Model loader for Z-Image transformer and Qwen3 text encoder
- Z-Image conditioning data structures
- Step callback support for Z-Image with FLUX latent RGB factors

Invocations:
- z_image_model_loader: Load Z-Image transformer and Qwen3 encoder
- z_image_text_encoder: Encode prompts using Qwen3 with chat template
- z_image_denoise: Flow matching denoising with time-shifted sigmas
- z_image_image_to_latents: Encode images to 16-channel latents
- z_image_latents_to_image: Decode latents using FLUX VAE

Frontend:
- Z-Image graph builder for text-to-image generation
- Model picker and validation updates for z-image base type
- CFG scale now allows 0 (required for Z-Image-Turbo)
- Clip skip disabled for Z-Image (uses Qwen3, not CLIP)
- Optimal dimension settings for Z-Image (1024x1024)

Technical details:
- Uses Qwen3 text encoder (not CLIP/T5)
- 16 latent channels with FLUX-compatible VAE
- Flow matching scheduler with dynamic time shift
- 8 inference steps recommended for Turbo variant
- bfloat16 inference dtype


## QA Instructions

- Install a Z-Image-Turbo model (e.g., from HuggingFace)
- Select the model in the Model Picker
- Generate a text-to-image with:
- CFG Scale: 0
- Steps: 8
- Resolution: 1024x1024
- Verify the generated image is coherent (not noise)

## Merge Plan

Standard merge, no special considerations needed.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Changes to a redux slice have a corresponding migration_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-12-21 22:11:37 +05:30
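
The "flow matching scheduler with dynamic time shift" mentioned above can be approximated with the shifted-sigma schedule many flow-matching samplers use. The sketch below assumes the common `shift*s / (1 + (shift-1)*s)` warp with a fixed shift value; that formula and the constant are assumptions, not the node's actual implementation.

```python
def shifted_sigmas(num_steps: int, shift: float = 3.0) -> list[float]:
    """Linear 1.0 -> 0.0 sigmas, warped by s' = shift*s / (1 + (shift-1)*s)."""
    sigmas = [1.0 - i / num_steps for i in range(num_steps + 1)]
    return [shift * s / (1.0 + (shift - 1.0) * s) for s in sigmas]


def euler_flow_matching_step(x, velocity, sigma: float, sigma_next: float):
    """One Euler step along the model's predicted velocity field."""
    return x + (sigma_next - sigma) * velocity


if __name__ == "__main__":
    # 8 inference steps are recommended for the Turbo variant.
    print([round(s, 4) for s in shifted_sigmas(8)])
```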
blessedcoolant
93a587da90 Merge branch 'main' into feat/z-image-turbo-support 2025-12-21 21:58:22 +05:30
blessedcoolant
87bebf9c28 chore: upgrade diffusers to 0.36.0 to support z image 2025-12-21 21:54:47 +05:30
Alexander Eichhorn
f417c269d1 fix(vae): Fix dtype mismatch in FP32 VAE decode mode
The previous mixed-precision optimization for FP32 mode only converted
some VAE decoder layers (post_quant_conv, conv_in, mid_block) to the
latents dtype while leaving others (up_blocks, conv_norm_out) in float32.
This caused "expected scalar type Half but found Float" errors after
recent diffusers updates.

Simplify FP32 mode to consistently use float32 for both VAE and latents,
removing the incomplete mixed-precision logic. This trades some VRAM
usage for stability and correctness.

Also removes now-unused attention processor imports.
2025-12-16 15:58:48 +01:00
Alexander Eichhorn
4ce0ef5260 stupid windows file path again. 2025-12-16 10:31:52 +01:00
Alexander Eichhorn
39cdcdc9e8 fix(z-image): remove unused WithMetadata and WithBoard mixins from denoise node
The Z-Image denoise node outputs latents, not images, so these mixins
were unnecessary. Metadata and board handling is correctly done in the
L2I (latents-to-image) node. This aligns with how FLUX denoise works.
2025-12-16 09:41:26 +01:00
Josh Corbett
926923bb2b feat(prompts): hotkey controlled prompt weighting (#8647)
* feat(prompts): add abstract syntax tree (AST) builder for prompts

* fix(prompts): add escaped parens to AST

* test(prompts): add AST tests

* fix(prompts): appease the linter

* perf(prompts): break up tokenize function into subroutines

* feat(prompts): add hotkey controlled prompt attention adjust

* fix(hotkeys): 🩹 add translations for hotkey dialog

* fix: 🏷️ remove unused exports

* fix(keybinds): 🐛 use `arrowup`/`arrowdown` over `up`/`down`

* refactor(prompts): ♻️ use better language for attention direction

* style: 🚨 appease the linter

---------

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2025-12-15 21:53:58 -05:00
blessedcoolant
8785d9a3a9 chore: fix ruff checks 2025-12-14 19:51:22 +05:30
Alexander Eichhorn
1e72feb744 Remove unneeded logging 2025-12-14 06:44:29 +01:00
Alexander Eichhorn
3ee24cbdde Remove the ParamScheduler for z-images
Fixed the DEFAULT_TOKENIZER_SOURCE to Qwen/Qwen3-4B
2025-12-13 04:23:34 +01:00
Alexander Eichhorn
f9605e18a0 z-image-turbo-fp8-e5m2 works; the z-image-turbo_fp8_scaled_e4m3fn_KJ doesn't. 2025-12-10 17:15:54 +01:00
Alexander Eichhorn
8551ff8569 fix typegen 2025-12-10 15:04:39 +01:00
Alexander Eichhorn
fb1a99b650 feat(cache): add partial loading support for Z-Image RMSNorm and LayerNorm
- Add CustomDiffusersRMSNorm for diffusers.models.normalization.RMSNorm
- Add CustomLayerNorm for torch.nn.LayerNorm
- Register both in AUTOCAST_MODULE_TYPE_MAPPING

Enables partial loading (enable_partial_loading: true) for Z-Image models
by wrapping their normalization layers with device autocast support
2025-12-10 03:45:42 +01:00
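
The idea behind the autocast wrappers registered in `AUTOCAST_MODULE_TYPE_MAPPING` can be sketched roughly as below. The real `CustomLayerNorm` cooperates with the partial-loading cache and is more involved; treat this as an illustration of the device-autocast concept only.

```python
import torch


class AutocastLayerNorm(torch.nn.LayerNorm):
    """Illustrative LayerNorm that moves its weights to the input's device on use."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.weight.to(x.device) if self.weight is not None else None
        bias = self.bias.to(x.device) if self.bias is not None else None
        return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)


# Modules are swapped in by type, along these lines:
AUTOCAST_MODULE_TYPE_MAPPING: dict[type, type] = {
    torch.nn.LayerNorm: AutocastLayerNorm,
}
```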
Alexander Eichhorn
3b5d9c26d3 feat(z-image): add Qwen3 GGUF text encoder support and default parameters
- Add Qwen3EncoderGGUFLoader for llama.cpp GGUF quantized text encoders
- Convert llama.cpp key format (blk.X., token_embd) to PyTorch format
- Handle tied embeddings (lm_head.weight ↔ embed_tokens.weight)
- Dequantize embed_tokens for embedding lookups (GGMLTensor limitation)
- Add QK normalization key mappings (q_norm, k_norm) for Qwen3
- Set Z-Image defaults: steps=9, cfg_scale=0.0, width/height=1024
- Allow cfg_scale >= 0 (was >= 1) for Z-Image Turbo compatibility
- Add GGUF format detection for Qwen3 model probing
2025-12-10 03:07:07 +01:00
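
An abbreviated sketch of the llama.cpp-to-PyTorch key conversion and embedding tying described above. The mapping table is partial and the exact tensor names the real loader handles (including the q_norm/k_norm QK-normalization weights) may differ.

```python
import re

# Partial mapping from llama.cpp GGUF tensor suffixes to the Hugging Face Qwen layout.
_SUFFIX_MAP = {
    "attn_q.weight": "self_attn.q_proj.weight",
    "attn_k.weight": "self_attn.k_proj.weight",
    "attn_v.weight": "self_attn.v_proj.weight",
    "attn_output.weight": "self_attn.o_proj.weight",
    "ffn_gate.weight": "mlp.gate_proj.weight",
    "ffn_up.weight": "mlp.up_proj.weight",
    "ffn_down.weight": "mlp.down_proj.weight",
    "attn_norm.weight": "input_layernorm.weight",
    "ffn_norm.weight": "post_attention_layernorm.weight",
}


def convert_gguf_key(key: str) -> str | None:
    """Translate a llama.cpp tensor name (blk.X., token_embd) into PyTorch naming."""
    if key == "token_embd.weight":
        return "model.embed_tokens.weight"
    if key == "output_norm.weight":
        return "model.norm.weight"
    m = re.match(r"blk\.(\d+)\.(.+)", key)
    if m and m.group(2) in _SUFFIX_MAP:
        return f"model.layers.{m.group(1)}.{_SUFFIX_MAP[m.group(2)]}"
    return None  # unknown / unmapped tensor


def tie_lm_head(state_dict: dict) -> None:
    """Tied embeddings: reuse embed_tokens for lm_head when it is absent."""
    if "lm_head.weight" not in state_dict:
        state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"]
```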
Alexander Eichhorn
0a986c2720 fix(ui): replace misused isCheckpointMainModelConfig with isFluxDevMainModelConfig
The FLUX Dev license warning in model pickers used isCheckpointMainModelConfig
incorrectly:
```
isCheckpointMainModelConfig(config) && config.variant === 'dev'
```

This caused a TypeScript error because CheckpointModelConfig type doesn't
include the 'variant' property (it's extracted as `{ type: 'main'; format:
'checkpoint' }` which doesn't narrow to include variant).

Changes:
- Add isFluxDevMainModelConfig type guard that properly checks
  base='flux' AND variant='dev', returning MainModelConfig
- Update MainModelPicker and InitialStateMainModelPicker to use new guard
- Remove isCheckpointMainModelConfig as it had no other usages

The function was removed because:
1. It was only used for detecting FLUX Dev models (incorrect use case)
2. No other code needs a generic "is checkpoint format" check
3. The pattern in this codebase is specific type guards per model variant
   (isFluxFillMainModelModelConfig, isRefinerMainModelModelConfig, etc.)
2025-12-09 08:18:17 +01:00
Alexander Eichhorn
3e862ced25 fix incorrect typegen 2025-12-09 07:46:12 +01:00
Alexander Eichhorn
ba2475c3f0 fix(z-image): improve device/dtype compatibility and error handling
Add robust device capability detection for bfloat16, replacing hardcoded
dtype with runtime checks that fallback to float16/float32 on unsupported
hardware. This prevents runtime failures on GPUs and CPUs without bfloat16.

Key changes:
- Add TorchDevice.choose_bfloat16_safe_dtype() helper for safe dtype selection
- Fix LoRA device mismatch in layer_patcher.py (add device= to .to() call)
- Replace all assert statements with descriptive exceptions (TypeError/ValueError)
- Add hidden_states bounds check and apply_chat_template fallback in text encoder
- Add GGUF QKV tensor validation (divisible by 3 check)
- Fix CPU noise generation to use float32 for compatibility
- Remove verbose debug logging from LoRA conversion utils
2025-12-09 07:37:06 +01:00
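
A minimal sketch of the safe-dtype selection idea behind `TorchDevice.choose_bfloat16_safe_dtype()`; the actual helper may perform additional checks.

```python
import torch


def choose_bfloat16_safe_dtype(device: torch.device) -> torch.dtype:
    """Prefer bfloat16 where supported, otherwise fall back to float16/float32."""
    if device.type == "cuda":
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    if device.type == "mps":
        return torch.float16
    return torch.float32  # CPU: stick to float32 for compatibility
```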
Alexander Eichhorn
841372944f feat(z-image): add metadata recall for VAE and Qwen3 encoder
Add support for saving and recalling Z-Image component models (VAE and
Qwen3 Encoder) in image metadata.

Backend:
- Add qwen3_encoder field to CoreMetadataInvocation (version 2.1.0)

Frontend:
- Add vae and qwen3_encoder to Z-Image graph metadata
- Add Qwen3EncoderModel metadata handler for recall
- Add ZImageVAEModel metadata handler (uses zImageVaeModelSelected
  instead of vaeSelected to set Z-Image-specific VAE state)
- Add qwen3Encoder translation key

This enables "Recall Parameters" / "Remix Image" to restore the VAE
and Qwen3 Encoder settings used for Z-Image generations.
2025-12-09 07:12:36 +01:00
Alexander Eichhorn
e9d52734d1 feat(z-image): add single-file checkpoint support for Z-Image models
Add support for loading Z-Image transformer and Qwen3 encoder models
from single-file safetensors format (in addition to existing diffusers
directory format).

Changes:
- Add Main_Checkpoint_ZImage_Config and Main_GGUF_ZImage_Config for
  single-file Z-Image transformer models
- Add Qwen3Encoder_Checkpoint_Config for single-file Qwen3 text encoder
- Add ZImageCheckpointModel and ZImageGGUFCheckpointModel loaders with
  automatic key conversion from original to diffusers format
- Add Qwen3EncoderCheckpointLoader using Qwen3ForCausalLM with fast
  loading via init_empty_weights and proper weight tying for lm_head
- Update z_image_denoise to accept Checkpoint format models
2025-12-09 06:32:51 +01:00
Alexander Eichhorn
2e0cd4d68c Patch from @lstein for the update of diffusers 2025-12-06 03:12:50 +01:00
Alexander Eichhorn
b28d58b8ce Merge branch 'feat/z-image-turbo-support' of https://github.com/Pfannkuchensack/InvokeAI into feat/z-image-turbo-support 2025-12-05 01:12:34 +01:00
Alexander Eichhorn
4a1710b795 fix for the typegen-checks 2025-12-05 01:12:19 +01:00
Alexander Eichhorn
9f6d04c690 Merge branch 'main' into feat/z-image-turbo-support 2025-12-05 00:45:02 +01:00
Alexander Eichhorn
66729ea9eb Fix windows path again again again... 2025-12-03 03:28:43 +01:00
Alexander Eichhorn
280202908a feat: Add GGUF quantized Z-Image support and improve VAE/encoder flexibility
Add comprehensive support for GGUF quantized Z-Image models and improve component flexibility:

Backend:
- New Main_GGUF_ZImage_Config for GGUF quantized Z-Image transformers
- Z-Image key detection (_has_z_image_keys) to identify S3-DiT models
- GGUF quantization detection and sidecar LoRA patching for quantized models
- Qwen3Encoder_Qwen3Encoder_Config for standalone Qwen3 encoder models

Model Loader:
- Split Z-Image model
2025-12-02 20:31:11 +01:00
Alexander Eichhorn
2b062b21cd fix: Improve Flux AI Toolkit LoRA detection to prevent Z-Image misidentification
Move Flux layer structure check before metadata check to prevent misidentifying Z-Image LoRAs (which use `diffusion_model.layers.X`) as Flux AI Toolkit format. Flux models use `double_blocks` and `single_blocks` patterns which are now checked first regardless of metadata presence.
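
A sketch of the structural check described above, run before any metadata inspection; the helper names are illustrative, not the actual detection functions.

```python
def looks_like_flux_lora(keys: list[str]) -> bool:
    """Flux LoRAs carry `double_blocks` / `single_blocks` layer names."""
    return any("double_blocks" in k or "single_blocks" in k for k in keys)


def looks_like_z_image_lora(keys: list[str]) -> bool:
    """Z-Image LoRAs use `diffusion_model.layers.N...` style keys instead."""
    return any(k.startswith("diffusion_model.layers.") for k in keys)


def detect_lora_family(keys: list[str]) -> str:
    # Layer structure is checked FIRST, regardless of any AI Toolkit metadata,
    # so Z-Image LoRAs are never misidentified as Flux.
    if looks_like_flux_lora(keys):
        return "flux"
    if looks_like_z_image_lora(keys):
        return "z-image"
    return "unknown"
```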
2025-12-02 15:50:01 +01:00
Alexander Eichhorn
6f9f8e57ac Feature(UI): bulk remove models loras (#8659)
* feat: Add bulk delete functionality for models, LoRAs, and embeddings

Implements a comprehensive bulk deletion feature for the model manager that allows users to select and delete multiple models, LoRAs, and embeddings at once.

Key changes:

Frontend:
- Add multi-selection state management to modelManagerV2 slice
- Update ModelListItem to support Ctrl/Cmd+Click multi-selection with checkboxes
- Create ModelListHeader component showing selection count and bulk actions
- Create BulkDeleteModelsModal for confirming bulk deletions
- Integrate bulk delete UI into ModelList with proper error handling
- Add API mutation for bulk delete operations

Backend:
- Add POST /api/v2/models/i/bulk_delete endpoint
- Implement BulkDeleteModelsRequest and BulkDeleteModelsResponse schemas
- Handle partial failures with detailed error reporting
- Return lists of successfully deleted and failed models

This feature significantly improves user experience when managing large model libraries, especially when restructuring model storage locations.

Fixes issue where users had to delete models individually after moving model files to new storage locations.

* fix: prevent model list header from scrolling with content

* fix: improve error handling in bulk model deletion

- Added proper error serialization using serialize-error for better error logging
- Explicitly defined BulkDeleteModelsResponse type instead of relying on generated schema reference

* refactor: improve code organization in ModelList components

- Reordered imports to follow conventional grouping (external, internal, then third-party utilities)
- Added type assertion for error serialization to satisfy TypeScript
- Extracted inline event handler into named callback function for better readability

* refactor: consolidate Button component props to single line

* feat(ui): enhance model manager bulk selection with select-all and actions menu

- Added select-all checkbox in navigation header with indeterminate state support
- Replaced single delete button with actions dropdown menu for future extensibility
- Made checkboxes always visible instead of conditionally showing on selection
- Moved model filtering logic to ModelListNavigation for select-all functionality
- Improved UX by showing selection state for filtered models only

* fix the wrong path separator from my windows system

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-12-01 20:09:27 -05:00
Alexander Eichhorn
eaf4742799 Fix windows path again again 2025-12-01 22:28:39 +01:00
Alexander Eichhorn
f05ea28cbd feat: Add Z-Image LoRA support
Add comprehensive LoRA support for Z-Image models including:

Backend:
- New Z-Image LoRA config classes (LoRA_LyCORIS_ZImage_Config, LoRA_Diffusers_ZImage_Config)
- Z-Image LoRA conversion utilities with key mapping for transformer and Qwen3 encoder
- LoRA prefix constants (Z_IMAGE_LORA_TRANSFORMER_PREFIX, Z_IMAGE_LORA_QWEN3_PREFIX)
- LoRA detection logic to distinguish Z-Image from Flux models
- Layer patcher improvements for proper dtype conversion and parameter
2025-12-01 22:23:30 +01:00
Alexander Eichhorn
13ac16e2c0 fix windows path again. 2025-12-01 00:30:53 +01:00
Alexander Eichhorn
eb3f1c9a61 feat: Add Z-Image-Turbo model support
Add comprehensive support for Z-Image-Turbo (S3-DiT) models including:

Backend:
- New BaseModelType.ZImage in taxonomy
- Z-Image model config classes (ZImageTransformerConfig, Qwen3TextEncoderConfig)
- Model loader for Z-Image transformer and Qwen3 text encoder
- Z-Image conditioning data structures
- Step callback support for Z-Image with FLUX latent RGB factors

Invocations:
- z_image_model_loader: Load Z-Image transformer and Qwen3 encoder
- z_image_text_encoder: Encode prompts using Qwen3 with chat template
- z_image_denoise: Flow matching denoising with time-shifted sigmas
- z_image_image_to_latents: Encode images to 16-channel latents
- z_image_latents_to_image: Decode latents using FLUX VAE

Frontend:
- Z-Image graph builder for text-to-image generation
- Model picker and validation updates for z-image base type
- CFG scale now allows 0 (required for Z-Image-Turbo)
- Clip skip disabled for Z-Image (uses Qwen3, not CLIP)
- Optimal dimension settings for Z-Image (1024x1024)

Technical details:
- Uses Qwen3 text encoder (not CLIP/T5)
- 16 latent channels with FLUX-compatible VAE
- Flow matching scheduler with dynamic time shift
- 8 inference steps recommended for Turbo variant
- bfloat16 inference dtype
2025-12-01 00:22:32 +01:00
Kent Keirsey
c6a9847bbd feat(ui): Color Picker V2 (#8585)
* pinned colorpicker

* hex options

* remove unused consts

---------

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2025-11-16 09:49:55 -05:00
Alexander Eichhorn
a2e109b3c2 feat(ui): improve hotkey customization UX with interactive controls and validation (#8649)
* feat: remove the ModelFooter in the ModelView and add the Delete Model Button from the Footer into the View

* forgot to run pnpm fix

* chore(ui): reorder the model view buttons

* Initial plan

* Add customizable hotkeys infrastructure with UI

Co-authored-by: dunkeroni <3298737+dunkeroni@users.noreply.github.com>

* Fix ESLint issues in HotkeyEditor component

Co-authored-by: dunkeroni <3298737+dunkeroni@users.noreply.github.com>

* Fix knip unused export warning

Co-authored-by: dunkeroni <3298737+dunkeroni@users.noreply.github.com>

* Add tests for hotkeys slice

Co-authored-by: dunkeroni <3298737+dunkeroni@users.noreply.github.com>

* Fix tests to actually call reducer and add documentation

Co-authored-by: dunkeroni <3298737+dunkeroni@users.noreply.github.com>

* docs: add comprehensive hotkeys system documentation

- Created new HOTKEYS.md technical documentation for developers explaining architecture, data flow, and implementation details
- Added user-facing hotkeys.md guide with features overview and usage instructions
- Removed old CUSTOMIZABLE_HOTKEYS.md in favor of new split documentation
- Expanded documentation with detailed sections on:
  - State management and persistence
  - Component architecture and responsibilities
  - Developer integration

* Behavior changed to hotkey press instead of input + checking for already used hotkeys

---------

Co-authored-by: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: dunkeroni <3298737+dunkeroni@users.noreply.github.com>
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2025-11-16 14:35:37 +00:00
dunkeroni
5642099a40 Feat: SDXL Color Compensation (#8637)
* feat(nodes/UI): add SDXL color compensation option

* adjust value

* Better warnings on wrong VAE base model

* Restrict XL compensation to XL models

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>

* fix: BaseModelType missing import

* (chore): appease the ruff

---------

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2025-11-16 14:32:12 +00:00
gogurtenjoyer
382d85ee23 Fix memory issues when installing models on Windows (#8652)
* Wrap GGUF loader for context managed close()

Wrap gguf.GGUFReader and then use a context manager to load memory-mapped GGUF files, so that they will automatically close properly when no longer needed. Should prevent the 'file in use in another process' errors on Windows.

* Additional check for cached state_dict

Additional check for cached state_dict as path is now optional - should solve model manager 'missing' this and the resultant memory errors.

* Appease ruff

* Further ruff appeasement

* ruff

* loaders.py fix for linux

No longer attempting to delete internal object.

* loaders.py - one more _mmap ref removed

---------

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2025-11-16 09:25:52 -05:00
Jonathan
abcc987f6f Rework graph.py (#8642)
* Rework graph, add documentation

* Minor fixes to README.md

* Updated schema

* Fixed test to match behavior - all nodes executed, parents before children

* Update invokeai/app/services/shared/graph.py

Cleaned up code

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>

* Change silent corrections to enforcing invariants

---------

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2025-11-16 09:10:47 -05:00
Lincoln Stein
36e400dd5d (chore) Update requirements to python 3.11-12 (#8657)
* (chore) update requirements to python 3.11-12

* update uv.lock
2025-11-08 21:29:43 -05:00
Weblate (bot)
0113931956 ui: translations update from weblate (#8599)
* translationBot(ui): update translation (Italian)

Currently translated at 98.4% (2099 of 2132 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.4% (2130 of 2163 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.4% (2130 of 2163 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation (Japanese)

Currently translated at 99.6% (2155 of 2163 strings)

Co-authored-by: RyoKoba <kobayashi_ryo@cyberagent.co.jp>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ja/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation files

Updated by "Cleanup translation files" hook in Weblate.

translationBot(ui): update translation files

Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI

* translationBot(ui): update translation (Italian)

Currently translated at 98.4% (2103 of 2136 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI

---------

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Co-authored-by: RyoKoba <kobayashi_ryo@cyberagent.co.jp>
2025-11-04 02:29:05 +00:00
DustyShoe
8d6e00533e Fix to enable loading fp16 repo variant ControlNets (#8643)
* Fix ControlNet repo variant detection for fp16 weights

* Remove ControlNet diffusers fp16 regression test

* Update invokeai/backend/model_manager/configs/controlnet.py

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>

* style: ruff format controlnet.py

---------

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
2025-11-03 21:23:35 -05:00
Lincoln Stein
10eebb6c0c remove jazzhaiku as well 2025-11-02 15:22:13 -05:00
Lincoln Stein
68bcf2ebe0 chore(codeowners): remove commercial dev codeowners 2025-11-02 15:22:13 -05:00
blessedcoolant
ad0b09c738 chore(ui): reorder the model view buttons 2025-10-28 00:16:20 +05:30
Alexander Eichhorn
737cf795e8 forgot to run pnpm fix 2025-10-28 00:16:20 +05:30
Alexander Eichhorn
6192ff5abb feat: remove the ModelFooter in the ModelView and add the Delete Model Button from the Footer into the View 2025-10-28 00:16:20 +05:30
blessedcoolant
066ba5fb19 fix(mm): directory path leakage on scan folder error (#8641)
## Summary

This fixes a bug in which private directory paths on the host could be
leaked to the user interface. The error occurs during the `scan_folders`
operation when a subdirectory is not accessible. The UI shows a
permission denied error message, followed by the path of the offending
directory. This patch limits the error message to the error type only
and does not give further details.

## Related Issues / Discussions

This bug was reported in a private DM on the Discord server.

## QA Instructions

Before applying this PR, go to ***Model Manager -> Add Model -> Scan
Folder*** and enter the path of a directory that has subdirectories that
the backend should not have access to, for example `/etc`. Press the
***Scan Folder*** button. You will see a Permission Denied error message
that gives away the path of the first inaccessible subdirectory.

After applying this PR, you will see just the Permission Denied error
without details.

## Merge Plan

Merge when approved.

## Checklist

- [X] _The PR has a short but descriptive title, suitable for a
changelog_
- [X] _Tests added / updated (if applicable)_
- [X] _Changes to a redux slice have a corresponding migration_
- [X] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-10-28 00:02:59 +05:30
Lincoln Stein
2fb4c92310 fix(mm): directory path leakage on scan folder error 2025-10-27 08:54:57 -04:00
psychedelicious
3fdceba5fc chore: bump version to v6.9.0 2025-10-17 12:13:01 +11:00
psychedelicious
ae4bcc08f2 chore(ui): point ui lib dep at gh repo 2025-10-17 07:22:39 +11:00
psychedelicious
e1d88f93ca fix(ui): generator nodes
Closes #8617
2025-10-16 10:37:14 +11:00
psychedelicious
4ad2574835 feat(ui): add button to reidentify model to mm 2025-10-16 10:33:02 +11:00
psychedelicious
0e3d4beb48 chore(ui): typegen 2025-10-16 10:33:02 +11:00
psychedelicious
dcfd4ea756 feat(mm): reidentify models
Add route and model record service method to reidentify a model. This
re-probes the model files and replaces the model's config with the new
one if it does not error.
2025-10-16 10:33:02 +11:00
psychedelicious
093f8d6720 fix(mm): ignore files in hidden directories when identifying models 2025-10-16 10:33:02 +11:00
psychedelicious
22fdfab764 chore: bump version to v6.9.0rc3 2025-10-16 08:08:44 +11:00
psychedelicious
7a0b157fb8 feat(mm): more exports in invocation api 2025-10-16 08:08:44 +11:00
psychedelicious
563da9ee8e feat(mm): write warning README file to models dir 2025-10-16 08:08:44 +11:00
psychedelicious
c8d9cdc22e docs(mm): add readme for updating or adding new model support 2025-10-16 08:08:44 +11:00
psychedelicious
e9c2411da9 chore: bump version to v6.9.0rc2 2025-10-16 08:08:44 +11:00
psychedelicious
90989291ed fix(ui): wait for nav api to be ready before loading main app component 2025-10-16 08:08:44 +11:00
psychedelicious
d04fc343f0 feat(ui): add flag for connected status 2025-10-16 08:08:44 +11:00
psychedelicious
437594915a feat(mm): add model taxonomy and other classes to public API exports 2025-10-16 08:08:44 +11:00
psychedelicious
875aba8979 tidy(mm): remove unused class 2025-10-16 08:08:44 +11:00
psychedelicious
61d13f20ea chore: bump version to v6.9.0rc1 2025-10-16 08:08:44 +11:00
psychedelicious
3b0dd5768b chore(ui): update whatsnew 2025-10-16 08:08:44 +11:00
285 changed files with 87055 additions and 2063 deletions

39
.github/CODEOWNERS vendored
View File

@@ -1,31 +1,32 @@
# continuous integration
/.github/workflows/ @lstein @blessedcoolant @hipsterusername @ebr @jazzhaiku @psychedelicious
/.github/workflows/ @lstein @blessedcoolant
# documentation
/docs/ @lstein @blessedcoolant @hipsterusername @psychedelicious
/mkdocs.yml @lstein @blessedcoolant @hipsterusername @psychedelicious
# documentation - anyone with write privileges can review
/docs/
/mkdocs.yml
# nodes
/invokeai/app/ @blessedcoolant @psychedelicious @hipsterusername @jazzhaiku
/invokeai/app/ @blessedcoolant @lstein @dunkeroni @JPPhoto
# installation and configuration
/pyproject.toml @lstein @blessedcoolant @psychedelicious @hipsterusername
/docker/ @lstein @blessedcoolant @psychedelicious @hipsterusername @ebr
/scripts/ @ebr @lstein @psychedelicious @hipsterusername
/installer/ @lstein @ebr @psychedelicious @hipsterusername
/invokeai/assets @lstein @ebr @psychedelicious @hipsterusername
/invokeai/configs @lstein @psychedelicious @hipsterusername
/invokeai/version @lstein @blessedcoolant @psychedelicious @hipsterusername
/pyproject.toml @lstein @blessedcoolant
/docker/ @lstein @blessedcoolant
/scripts/ @lstein
/installer/ @lstein
/invokeai/assets @lstein
/invokeai/configs @lstein
/invokeai/version @lstein @blessedcoolant
# web ui
/invokeai/frontend @blessedcoolant @psychedelicious @lstein @maryhipp @hipsterusername
/invokeai/frontend @blessedcoolant @lstein @dunkeroni
# generation, model management, postprocessing
/invokeai/backend @lstein @blessedcoolant @hipsterusername @jazzhaiku @psychedelicious @maryhipp
/invokeai/backend @lstein @blessedcoolant @dunkeroni @JPPhoto
# front ends
/invokeai/frontend/CLI @lstein @psychedelicious @hipsterusername
/invokeai/frontend/install @lstein @ebr @psychedelicious @hipsterusername
/invokeai/frontend/merge @lstein @blessedcoolant @psychedelicious @hipsterusername
/invokeai/frontend/training @lstein @blessedcoolant @psychedelicious @hipsterusername
/invokeai/frontend/web @psychedelicious @blessedcoolant @maryhipp @hipsterusername
/invokeai/frontend/CLI @lstein
/invokeai/frontend/install @lstein
/invokeai/frontend/merge @lstein @blessedcoolant
/invokeai/frontend/training @lstein @blessedcoolant
/invokeai/frontend/web @blessedcoolant @lstein @dunkeroni @Pfannkuchensack

View File

@@ -53,8 +53,10 @@ jobs:
df -h
sudo rm -rf /usr/share/dotnet
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
sudo swapoff /mnt/swapfile
sudo rm -rf /mnt/swapfile
if [ -f /mnt/swapfile ]; then
sudo swapoff /mnt/swapfile
sudo rm -rf /mnt/swapfile
fi
if [ -d /mnt ]; then
sudo chmod -R 777 /mnt
echo '{"data-root": "/mnt/docker-root"}' | sudo tee /etc/docker/daemon.json

View File

@@ -23,6 +23,7 @@ jobs:
close-issue-message: "Due to inactivity, this issue was automatically closed. If you are still experiencing the issue, please recreate the issue."
days-before-pr-stale: -1
days-before-pr-close: -1
only-labels: "bug"
exempt-issue-labels: "Active Issue"
repo-token: ${{ secrets.GITHUB_TOKEN }}
operations-per-run: 500

View File

@@ -22,12 +22,12 @@ jobs:
steps:
- name: checkout
uses: actions/checkout@v4
uses: actions/checkout@v5
- name: setup python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: '3.10'
python-version: '3.12'
cache: pip
cache-dependency-path: pyproject.toml

View File

@@ -46,8 +46,10 @@ jobs:
df -h
sudo rm -rf /usr/share/dotnet
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
sudo swapoff /mnt/swapfile
sudo rm -rf /mnt/swapfile
if [ -f /mnt/swapfile ]; then
sudo swapoff /mnt/swapfile
sudo rm -rf /mnt/swapfile
fi
echo "----- Free space after cleanup"
df -h

3
.gitignore vendored
View File

@@ -192,3 +192,6 @@ installer/InvokeAI-Installer/
.aider*
.claude/
# Weblate configuration file
weblate.ini

View File

@@ -16,6 +16,12 @@ Invoke is a leading creative engine built to empower professionals and enthusias
![Highlighted Features - Canvas and Workflows](https://github.com/invoke-ai/InvokeAI/assets/31807370/708f7a82-084f-4860-bfbe-e2588c53548d)
---
> ## 📣 Are you a new or returning InvokeAI user?
> Take our first annual [User's Survey](https://forms.gle/rCE5KuQ7Wfrd1UnS7)
---
# Documentation
| **Quick Links** |

View File

@@ -16,7 +16,9 @@ The launcher uses GitHub as the source of truth for available releases.
## General Prep
Make a developer call-out for PRs to merge. Merge and test things out. Bump the version by editing `invokeai/version/invokeai_version.py`.
Make a developer call-out for PRs to merge. Merge and test things
out. Create a branch with a name like user/chore/vX.X.X-prep and bump the version by editing
`invokeai/version/invokeai_version.py` and commit locally.
## Release Workflow
@@ -26,14 +28,14 @@ It is triggered on **tag push**, when the tag matches `v*`.
### Triggering the Workflow
Ensure all commits that should be in the release are merged, and you have pulled them locally.
Double-check that you have checked out the commit that will represent the release (typically the latest commit on `main`).
Ensure all commits that should be in the release are merged into this branch, and that you have pulled them locally.
Run `make tag-release` to tag the current commit and kick off the workflow. You will be prompted to provide a message - use the version specifier.
If this version's tag already exists for some reason (maybe you had to make a last minute change), the script will overwrite it.
Push the commit to trigger the workflow.
> In case you cannot use the Make target, the release may also be dispatched [manually] via GH.
### Workflow Jobs and Process
@@ -89,7 +91,7 @@ The publish jobs will not run if any of the previous jobs fail.
They use [GitHub environments], which are configured as [trusted publishers] on PyPI.
Both jobs require a @hipsterusername or @psychedelicious to approve them from the workflow's **Summary** tab.
Both jobs require a @lstein or @blessedcoolant to approve them from the workflow's **Summary** tab.
- Click the **Review deployments** button
- Select the environment (either `testpypi` or `pypi` - typically you select both)
@@ -101,7 +103,7 @@ Both jobs require a @hipsterusername or @psychedelicious to approve them from th
Check the [python infrastructure status page] for incidents.
If there are no incidents, contact @hipsterusername or @lstein, who have owner access to GH and PyPI, to see if access has expired or something like that.
If there are no incidents, contact @lstein or @blessedcoolant, who have owner access to GH and PyPI, to see if access has expired or something like that.
#### `publish-testpypi` Job

View File

@@ -0,0 +1,295 @@
# Hotkeys System
This document describes the technical implementation of the customizable hotkeys system in InvokeAI.
> **Note:** For user-facing documentation on how to use customizable hotkeys, see [Hotkeys Feature Documentation](../features/hotkeys.md).
## Overview
The hotkeys system allows users to customize keyboard shortcuts throughout the application. All hotkeys are:
- Centrally defined and managed
- Customizable by users
- Persisted across sessions
- Type-safe and validated
## Architecture
The customizable hotkeys feature is built on top of the existing hotkey system with the following components:
### 1. Hotkeys State Slice (`hotkeysSlice.ts`)
Location: `invokeai/frontend/web/src/features/system/store/hotkeysSlice.ts`
**Responsibilities:**
- Stores custom hotkey mappings in Redux state
- Persisted to IndexedDB using `redux-remember`
- Provides actions to change, reset individual, or reset all hotkeys
**State Shape:**
```typescript
{
_version: 1,
customHotkeys: {
'app.invoke': ['mod+enter'],
'canvas.undo': ['mod+z'],
// ...
}
}
```
**Actions:**
- `hotkeyChanged(id, hotkeys)` - Update a single hotkey
- `hotkeyReset(id)` - Reset a single hotkey to default
- `allHotkeysReset()` - Reset all hotkeys to defaults
### 2. useHotkeyData Hook (`useHotkeyData.ts`)
Location: `invokeai/frontend/web/src/features/system/components/HotkeysModal/useHotkeyData.ts`
**Responsibilities:**
- Defines all default hotkeys
- Merges default hotkeys with custom hotkeys from the store
- Returns the effective hotkeys that should be used throughout the app
- Provides platform-specific key translations (Ctrl/Cmd, Alt/Option)
**Key Functions:**
- `useHotkeyData()` - Returns all hotkeys organized by category
- `useRegisteredHotkeys()` - Hook to register a hotkey in a component
### 3. HotkeyEditor Component (`HotkeyEditor.tsx`)
Location: `invokeai/frontend/web/src/features/system/components/HotkeysModal/HotkeyEditor.tsx`
**Features:**
- Inline editor with input field
- Modifier buttons (Mod, Ctrl, Shift, Alt) for quick insertion
- Live preview of hotkey combinations
- Validation with visual feedback
- Help tooltip with syntax examples
- Save/cancel/reset buttons
**Smart Features:**
- Automatic `+` insertion between modifiers
- Cursor position preservation
- Validation prevents invalid combinations (e.g., modifier-only keys)
### 4. HotkeysModal Component (`HotkeysModal.tsx`)
Location: `invokeai/frontend/web/src/features/system/components/HotkeysModal/HotkeysModal.tsx`
**Features:**
- View Mode / Edit Mode toggle
- Search functionality
- Category-based organization
- Shows HotkeyEditor components when in edit mode
- "Reset All to Default" button in edit mode
## Data Flow
```
┌─────────────────────────────────────────────────────────────┐
│ 1. User opens Hotkeys Modal │
│ 2. User clicks "Edit Mode" button │
│ 3. User clicks edit icon next to a hotkey │
│ 4. User enters new hotkey(s) using editor │
│ 5. User clicks save or presses Enter │
│ 6. Custom hotkey stored via hotkeyChanged() action │
│ 7. Redux state persisted to IndexedDB (redux-remember) │
│ 8. useHotkeyData() hook picks up the change │
│ 9. All components using useRegisteredHotkeys() get update │
└─────────────────────────────────────────────────────────────┘
```
## Hotkey Format
Hotkeys use the format from `react-hotkeys-hook` library:
- **Modifiers:** `mod`, `ctrl`, `shift`, `alt`, `meta`
- **Keys:** Letters, numbers, function keys, special keys
- **Separator:** `+` between keys in a combination
- **Multiple hotkeys:** Comma-separated (e.g., `mod+a, ctrl+b`)
**Examples:**
- `mod+enter` - Mod key + Enter
- `shift+x` - Shift + X
- `ctrl+shift+a` - Control + Shift + A
- `f1, f2` - F1 or F2 (alternatives)
## Developer Guide
### Using Hotkeys in Components
To use a hotkey in a component:
```tsx
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
const MyComponent = () => {
const handleAction = useCallback(() => {
// Your action here
}, []);
// This automatically uses custom hotkeys if configured
useRegisteredHotkeys({
id: 'myAction',
category: 'app', // or 'canvas', 'viewer', 'gallery', 'workflows'
callback: handleAction,
options: { enabled: true, preventDefault: true },
dependencies: [handleAction]
});
// ...
};
```
**Options:**
- `enabled` - Whether the hotkey is active
- `preventDefault` - Prevent default browser behavior
- `enableOnFormTags` - Allow hotkey in form elements (default: false)
### Adding New Hotkeys
To add a new hotkey to the system:
#### 1. Add Translation Strings
In `invokeai/frontend/web/public/locales/en.json`:
```json
{
"hotkeys": {
"app": {
"myAction": {
"title": "My Action",
"desc": "Description of what this hotkey does"
}
}
}
}
```
#### 2. Register the Hotkey
In `invokeai/frontend/web/src/features/system/components/HotkeysModal/useHotkeyData.ts`:
```typescript
// Inside the appropriate category builder function
addHotkey('app', 'myAction', ['mod+k']); // Default binding
```
#### 3. Use the Hotkey
In your component:
```typescript
useRegisteredHotkeys({
id: 'myAction',
category: 'app',
callback: handleMyAction,
options: { enabled: true },
dependencies: [handleMyAction]
});
```
### Hotkey Categories
Current categories:
- **app** - Global application hotkeys
- **canvas** - Canvas/drawing operations
- **viewer** - Image viewer operations
- **gallery** - Gallery/image grid operations
- **workflows** - Node workflow editor
To add a new category, update `useHotkeyData.ts` and add translations.
## Testing
Tests are located in `invokeai/frontend/web/src/features/system/store/hotkeysSlice.test.ts`.
**Test Coverage:**
- Adding custom hotkeys
- Updating existing custom hotkeys
- Resetting individual hotkeys
- Resetting all hotkeys
- State persistence and migration
Run tests with:
```bash
cd invokeai/frontend/web
pnpm test:no-watch
```
## Persistence
Custom hotkeys are persisted using the same mechanism as other app settings:
- Stored in Redux state under the `hotkeys` slice
- Persisted to IndexedDB via `redux-remember`
- Automatically loaded when the app starts
- Survives page refreshes and browser restarts
- Includes migration support for state schema changes
**State Location:**
- IndexedDB database: `invoke`
- Store key: `hotkeys`
## Dependencies
- **react-hotkeys-hook** (v4.5.0) - Core hotkey handling
- **@reduxjs/toolkit** - State management
- **redux-remember** - Persistence
- **zod** - State validation
## Best Practices
1. **Use `mod` instead of `ctrl`** - Automatically maps to Cmd on Mac, Ctrl elsewhere
2. **Provide descriptive translations** - Help users understand what each hotkey does
3. **Avoid conflicts** - Check existing hotkeys before adding new ones
4. **Use preventDefault** - Prevent browser default behavior when appropriate
5. **Check enabled state** - Only activate hotkeys when the action is available
6. **Use dependencies correctly** - Ensure callbacks are stable with useCallback
## Common Patterns
### Conditional Hotkeys
```typescript
useRegisteredHotkeys({
id: 'save',
category: 'app',
callback: handleSave,
options: {
enabled: hasUnsavedChanges && !isLoading, // Only when valid
preventDefault: true
},
dependencies: [hasUnsavedChanges, isLoading, handleSave]
});
```
### Multiple Hotkeys for Same Action
```typescript
// In useHotkeyData.ts
addHotkey('canvas', 'redo', ['mod+shift+z', 'mod+y']); // Two alternatives
```
### Focus-Scoped Hotkeys
```typescript
import { useFocusRegion } from 'common/hooks/focus';
const MyComponent = () => {
const focusRegionRef = useFocusRegion('myRegion');
// Hotkey only works when this region has focus
useRegisteredHotkeys({
id: 'myAction',
category: 'app',
callback: handleAction,
options: { enabled: true }
});
return <div ref={focusRegionRef}>...</div>;
};
```

File diff suppressed because it is too large.

View File

@@ -0,0 +1,64 @@
# Pull Request Merge Policy
This document outlines the process for reviewing and merging pull requests (PRs) into the InvokeAI repository.
## Review Process
### 1. Assignment
One of the repository maintainers will assign collaborators to review a pull request. The assigned reviewer(s) will be responsible for conducting the code review.
### 2. Review and Iteration
The assignee is responsible for:
- Reviewing the PR thoroughly
- Providing constructive feedback
- Iterating with the PR author until the assignee is satisfied that the PR is fit to merge
- Ensuring the PR meets code quality standards, follows project conventions, and doesn't introduce bugs or regressions
### 3. Approval and Notification
Once the assignee is satisfied with the PR:
- The assignee approves the PR
- The assignee alerts one of the maintainers that the PR is ready for merge using the **#request-reviews Discord channel**
### 4. Final Merge
One of the maintainers is responsible for:
- Performing a final check of the PR
- Merging the PR into the appropriate branch
**Important:** Collaborators are strongly discouraged from merging PRs on their own, except in case of emergency (e.g., critical bug fix and no maintainer is available).
### 5. Release Policy
Once a feature release candidate is published, no feature PRs are to
be merged into main. Only bugfixes are allowed until the final
release.
## Best Practices
### Clean Commit History
To encourage a clean development log, PR authors are encouraged to use `git rebase -i` to suppress trivial commit messages (e.g., `ruff` and `prettier` formatting fixes) after the PR is accepted but before it is merged.
### Merge Strategy
The maintainer will perform either a **3-way merge** or **squash merge** when merging a PR into the `main` branch. This approach helps avoid rebase conflict hell and maintains a cleaner project history.
### Attribution
The PR author should reference any papers, source code or
documentation that they used while creating the code both in the PR
and as comments in the code itself. If there are any licensing
restrictions, these should be linked to and/or reproduced in the repo
root.
## Summary
This policy ensures that:
- All PRs receive proper review from assigned collaborators
- Maintainers have final oversight before code enters the main branch
- The commit history remains clean and meaningful
- Merge conflicts are minimized through appropriate merge strategies

80
docs/features/hotkeys.md Normal file
View File

@@ -0,0 +1,80 @@
# Customizable Hotkeys
InvokeAI allows you to customize all keyboard shortcuts (hotkeys) to match your workflow preferences.
## Features
- **View All Hotkeys**: See all available keyboard shortcuts in one place
- **Customize Any Hotkey**: Change any shortcut to your preference
- **Multiple Bindings**: Assign multiple key combinations to the same action
- **Smart Validation**: Built-in validation prevents invalid combinations
- **Persistent Settings**: Your custom hotkeys are saved and restored across sessions
- **Easy Reset**: Reset individual hotkeys or all hotkeys back to defaults
## How to Use
### Opening the Hotkeys Modal
Press `Shift+?` or click the keyboard icon in the application to open the Hotkeys Modal.
### Viewing Hotkeys
In **View Mode** (default), you can:
- Browse all available hotkeys organized by category (App, Canvas, Gallery, Workflows, etc.)
- Search for specific hotkeys using the search bar
- See the current key combination for each action
### Customizing Hotkeys
1. Click the **Edit Mode** button at the bottom of the Hotkeys Modal
2. Find the hotkey you want to change
3. Click the **pencil icon** next to it
4. The editor will appear with:
- **Input field**: Enter your new hotkey combination
- **Modifier buttons**: Quick-insert Mod, Ctrl, Shift, Alt keys
- **Help icon** (?): Shows syntax examples and valid keys
- **Live preview**: See how your hotkey will look
5. Enter your new hotkey using the format:
- `mod+a` - Mod key + A (Mod = Ctrl on Windows/Linux, Cmd on Mac)
- `ctrl+shift+k` - Multiple modifiers
- `f1` - Function keys
- `mod+enter, ctrl+enter` - Multiple alternatives (separated by comma)
6. Click the **checkmark** or press Enter to save
7. Click the **X** or press Escape to cancel
### Resetting Hotkeys
**Reset a single hotkey:**
- Click the counter-clockwise arrow icon that appears next to customized hotkeys
**Reset all hotkeys:**
- In Edit Mode, click the **Reset All to Default** button at the bottom
### Hotkey Format Reference
**Valid Modifiers:**
- `mod` - Context-aware: Ctrl (Windows/Linux) or Cmd (Mac)
- `ctrl` - Control key
- `shift` - Shift key
- `alt` - Alt key (Option on Mac)
**Valid Keys:**
- Letters: `a-z`
- Numbers: `0-9`
- Function keys: `f1-f12`
- Special keys: `enter`, `space`, `tab`, `backspace`, `delete`, `escape`
- Arrow keys: `up`, `down`, `left`, `right`
- And more...
**Examples:**
- `mod+s` - Save action
- `ctrl+shift+p` - Command palette
- `f5, mod+r` - Two alternatives for refresh
- `mod+` - Invalid (no key after modifier)
- `shift+ctrl+` - Invalid (ends with modifier)
## For Developers
For technical implementation details, architecture, and how to add new hotkeys to the system, see the [Hotkeys Developer Documentation](../contributing/HOTKEYS.md).

View File

@@ -70,7 +70,7 @@ Prior to installing PyPatchMatch, you need to take the following steps:
`from patchmatch import patch_match`: It should look like the following:
```py
Python 3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0] on linux
Python 3.12.3 (main, Aug 14 2025, 17:47:21) [GCC 13.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from patchmatch import patch_match
Compiling and loading c extensions from "/home/lstein/Projects/InvokeAI/.invokeai-env/src/pypatchmatch/patchmatch".

View File

@@ -25,12 +25,24 @@ Hardware requirements vary significantly depending on model and image output siz
- Memory: At least 16GB RAM.
- Disk: 10GB for base installation plus 100GB for models.
=== "FLUX - 1024×1024"
=== "FLUX.1 - 1024×1024"
- GPU: Nvidia 20xx series or later, 10GB+ VRAM.
- Memory: At least 32GB RAM.
- Disk: 10GB for base installation plus 200GB for models.
=== "FLUX.2 Klein - 1024×1024"
- GPU: Nvidia 20xx series or later, 6GB+ VRAM for GGUF Q4 quantized models, 12GB+ for full precision.
- Memory: At least 16GB RAM.
- Disk: 10GB for base installation plus 20GB for models.
=== "Z-Image Turbo - 1024x1024"
- GPU: Nvidia 20xx series or later, 8GB+ VRAM for the Q4_K quantized model. 16GB+ needed for the Q8 or BF16 models.
- Memory: At least 16GB RAM.
- Disk: 10GB for base installation plus 35GB for models.
More detail on system requirements can be found [here](./requirements.md).
## Step 2: Download and Set Up the Launcher

View File

@@ -25,12 +25,29 @@ The requirements below are rough guidelines for best performance. GPUs with less
- Memory: At least 16GB RAM.
- Disk: 10GB for base installation plus 100GB for models.
=== "FLUX - 1024×1024"
=== "FLUX.1 - 1024×1024"
- GPU: Nvidia 20xx series or later, 10GB+ VRAM.
- Memory: At least 32GB RAM.
- Disk: 10GB for base installation plus 200GB for models.
=== "FLUX.2 Klein 4B - 1024×1024"
- GPU: Nvidia 30xx series or later, 12GB+ VRAM (e.g. RTX 3090, RTX 4070). FP8 version works with 8GB+ VRAM.
- Memory: At least 16GB RAM.
- Disk: 10GB for base installation plus 20GB for models (Diffusers format with encoder).
=== "FLUX.2 Klein 9B - 1024×1024"
- GPU: Nvidia 40xx series, 24GB+ VRAM (e.g. RTX 4090). FP8 version works with 12GB+ VRAM.
- Memory: At least 32GB RAM.
- Disk: 10GB for base installation plus 40GB for models (Diffusers format with encoder).
=== "Z-Image Turbo - 1024x1024"
- GPU: Nvidia 20xx series or later, 8GB+ VRAM for the Q4_K quantized model. 16GB+ needed for the Q8 or BF16 models.
- Memory: At least 16GB RAM.
- Disk: 10GB for base installation plus 35GB for models.
!!! info "`tmpfs` on Linux"
If your temporary directory is mounted as a `tmpfs`, ensure it has sufficient space.
@@ -41,7 +58,7 @@ The requirements below are rough guidelines for best performance. GPUs with less
You don't need to do this if you are installing with the [Invoke Launcher](./quick_start.md).
Invoke requires python 3.10 through 3.12. If you don't already have one of these versions installed, we suggest installing 3.12, as it will be supported for longer.
Invoke requires python 3.11 through 3.12. If you don't already have one of these versions installed, we suggest installing 3.12, as it will be supported for longer.
Check that your system has an up-to-date Python installed by running `python3 --version` in the terminal (Linux, macOS) or cmd/powershell (Windows).
@@ -56,7 +73,7 @@ Check that your system has an up-to-date Python installed by running `python3 --
=== "macOS"
- Install python with [an official installer].
- If model installs fail with a certificate error, you may need to run this command (changing the python version to match what you have installed): `/Applications/Python\ 3.10/Install\ Certificates.command`
- If model installs fail with a certificate error, you may need to run this command (changing the python version to match what you have installed): `/Applications/Python\ 3.11/Install\ Certificates.command`
- If you haven't already, you will need to install the XCode CLI Tools by running `xcode-select --install` in a terminal.
=== "Linux"

View File

@@ -49,6 +49,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
FLUXConditioningInfo,
SD3ConditioningInfo,
SDXLConditioningInfo,
ZImageConditioningInfo,
)
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
@@ -129,6 +130,7 @@ class ApiDependencies:
FLUXConditioningInfo,
SD3ConditioningInfo,
CogView4ConditioningInfo,
ZImageConditioningInfo,
],
ephemeral=True,
),

View File

@@ -28,7 +28,7 @@ from invokeai.app.services.model_records import (
UnknownModelException,
)
from invokeai.app.util.suppress_output import SuppressOutput
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory
from invokeai.backend.model_manager.configs.main import (
Main_Checkpoint_SD1_Config,
Main_Checkpoint_SD2_Config,
@@ -38,6 +38,7 @@ from invokeai.backend.model_manager.configs.main import (
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.model_manager.starter_models import (
STARTER_BUNDLES,
@@ -191,6 +192,49 @@ async def get_model_record(
raise HTTPException(status_code=404, detail=str(e))
@model_manager_router.post(
"/i/{key}/reidentify",
operation_id="reidentify_model",
responses={
200: {
"description": "The model configuration was retrieved successfully",
"content": {"application/json": {"example": example_model_config}},
},
400: {"description": "Bad request"},
404: {"description": "The model could not be found"},
},
)
async def reidentify_model(
key: Annotated[str, Path(description="Key of the model to reidentify.")],
) -> AnyModelConfig:
"""Attempt to reidentify a model by re-probing its weights file."""
try:
config = ApiDependencies.invoker.services.model_manager.store.get_model(key)
models_path = ApiDependencies.invoker.services.configuration.models_path
if pathlib.Path(config.path).is_relative_to(models_path):
model_path = pathlib.Path(config.path)
else:
model_path = models_path / config.path
mod = ModelOnDisk(model_path)
result = ModelConfigFactory.from_model_on_disk(mod)
if result.config is None:
raise InvalidModelException("Unable to identify model format")
# Retain user-editable fields from the original config
result.config.key = config.key
result.config.name = config.name
result.config.description = config.description
result.config.cover_image = config.cover_image
result.config.trigger_phrases = config.trigger_phrases
result.config.source = config.source
result.config.source_type = config.source_type
new_config = ApiDependencies.invoker.services.model_manager.store.replace_model(config.key, result.config)
return new_config
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
class FoundModel(BaseModel):
path: str = Field(description="Path to the model")
is_installed: bool = Field(description="Whether or not the model is already installed")
@@ -238,9 +282,10 @@ async def scan_for_models(
found_model = FoundModel(path=path, is_installed=is_installed)
scan_results.append(found_model)
except Exception as e:
error_type = type(e).__name__
raise HTTPException(
status_code=500,
detail=f"An error occurred while searching the directory: {e}",
detail=f"An error occurred while searching the directory: {error_type}",
)
return scan_results
@@ -411,6 +456,59 @@ async def delete_model(
raise HTTPException(status_code=404, detail=str(e))
class BulkDeleteModelsRequest(BaseModel):
"""Request body for bulk model deletion."""
keys: List[str] = Field(description="List of model keys to delete")
class BulkDeleteModelsResponse(BaseModel):
"""Response body for bulk model deletion."""
deleted: List[str] = Field(description="List of successfully deleted model keys")
failed: List[dict] = Field(description="List of failed deletions with error messages")
@model_manager_router.post(
"/i/bulk_delete",
operation_id="bulk_delete_models",
responses={
200: {"description": "Models deleted (possibly with some failures)"},
},
status_code=200,
)
async def bulk_delete_models(
request: BulkDeleteModelsRequest = Body(description="List of model keys to delete"),
) -> BulkDeleteModelsResponse:
"""
Delete multiple model records from database.
The configuration records will be removed. The corresponding weights files will be
deleted as well if they reside within the InvokeAI "models" directory.
Returns a list of successfully deleted keys and failed deletions with error messages.
"""
logger = ApiDependencies.invoker.services.logger
installer = ApiDependencies.invoker.services.model_manager.install
deleted = []
failed = []
for key in request.keys:
try:
installer.delete(key)
deleted.append(key)
logger.info(f"Deleted model: {key}")
except UnknownModelException as e:
logger.error(f"Failed to delete model {key}: {str(e)}")
failed.append({"key": key, "error": str(e)})
except Exception as e:
logger.error(f"Failed to delete model {key}: {str(e)}")
failed.append({"key": key, "error": str(e)})
logger.info(f"Bulk delete completed: {len(deleted)} deleted, {len(failed)} failed")
return BulkDeleteModelsResponse(deleted=deleted, failed=failed)
@model_manager_router.delete(
"/i/{key}/image",
operation_id="delete_model_image",
@@ -816,15 +914,48 @@ class StarterModelResponse(BaseModel):
def get_is_installed(
starter_model: StarterModel | StarterModelWithoutDependencies, installed_models: list[AnyModelConfig]
) -> bool:
from invokeai.backend.model_manager.taxonomy import ModelType
for model in installed_models:
# Check if source matches exactly
if model.source == starter_model.source:
return True
# Check if name (or previous names), base and type match
if (
(model.name == starter_model.name or model.name in starter_model.previous_names)
and model.base == starter_model.base
and model.type == starter_model.type
):
return True
# Special handling for Qwen3Encoder models - check by type and variant
# This allows renamed models to still be detected as installed
if starter_model.type == ModelType.Qwen3Encoder:
from invokeai.backend.model_manager.taxonomy import Qwen3VariantType
# Determine expected variant from source pattern
expected_variant: Qwen3VariantType | None = None
if "klein-9B" in starter_model.source or "qwen3_8b" in starter_model.source.lower():
expected_variant = Qwen3VariantType.Qwen3_8B
elif (
"klein-4B" in starter_model.source
or "qwen3_4b" in starter_model.source.lower()
or "Z-Image" in starter_model.source
):
expected_variant = Qwen3VariantType.Qwen3_4B
if expected_variant is not None:
for model in installed_models:
if model.type == ModelType.Qwen3Encoder and hasattr(model, "variant"):
model_variant = model.variant
# Handle both enum and string values
if isinstance(model_variant, Qwen3VariantType):
if model_variant == expected_variant:
return True
elif isinstance(model_variant, str):
if model_variant == expected_variant.value:
return True
return False

View File

@@ -223,6 +223,15 @@ async def get_workflow_thumbnail(
raise HTTPException(status_code=404)
@workflows_router.get("/tags", operation_id="get_all_tags")
async def get_all_tags(
categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"),
) -> list[str]:
"""Gets all unique tags from workflows"""
return ApiDependencies.invoker.services.workflow_records.get_all_tags(categories=categories)
@workflows_router.get("/counts_by_tag", operation_id="get_counts_by_tag")
async def get_counts_by_tag(
tags: list[str] = Query(description="The tags to get counts for"),

View File

@@ -154,6 +154,7 @@ class FieldDescriptions:
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
t5_encoder = "T5 tokenizer and text encoder"
glm_encoder = "GLM (THUDM) tokenizer and text encoder"
qwen3_encoder = "Qwen3 tokenizer and text encoder"
clip_embed_model = "CLIP Embed loader"
clip_g_model = "CLIP-G Embed loader"
unet = "UNet (scheduler, LoRAs)"
@@ -169,6 +170,7 @@ class FieldDescriptions:
flux_model = "Flux model (Transformer) to load"
sd3_model = "SD3 model (MMDiTX) to load"
cogview4_model = "CogView4 model (Transformer) to load"
z_image_model = "Z-Image model (Transformer) to load"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Model (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
@@ -241,6 +243,12 @@ class BoardField(BaseModel):
board_id: str = Field(description="The id of the board")
class StylePresetField(BaseModel):
"""A style preset primitive field"""
style_preset_id: str = Field(description="The id of the style preset")
class DenoiseMaskField(BaseModel):
"""An inpaint mask field"""
@@ -321,6 +329,17 @@ class CogView4ConditioningField(BaseModel):
conditioning_name: str = Field(description="The name of conditioning tensor")
class ZImageConditioningField(BaseModel):
"""A Z-Image conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor")
mask: Optional[TensorField] = Field(
default=None,
description="The mask associated with this conditioning tensor for regional prompting. "
"Excluded regions should be set to False, included regions should be set to True.",
)
class ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
@@ -513,7 +532,7 @@ def migrate_model_ui_type(ui_type: UIType | str, json_schema_extra: dict[str, An
case UIType.VAEModel:
ui_model_type = [ModelType.VAE]
case UIType.FluxVAEModel:
ui_model_base = [BaseModelType.Flux]
ui_model_base = [BaseModelType.Flux, BaseModelType.Flux2]
ui_model_type = [ModelType.VAE]
case UIType.LoRAModel:
ui_model_type = [ModelType.LoRA]

View File

@@ -0,0 +1,505 @@
"""Flux2 Klein Denoise Invocation.
Run denoising process with a FLUX.2 Klein transformer model.
Uses Qwen3 conditioning instead of CLIP+T5.
"""
from contextlib import ExitStack
from typing import Callable, Iterator, Optional, Tuple
import torch
import torchvision.transforms as tv_transforms
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
FluxConditioningField,
FluxKontextConditioningField,
Input,
InputField,
LatentsField,
)
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
from invokeai.backend.flux.schedulers import FLUX_SCHEDULER_LABELS, FLUX_SCHEDULER_MAP, FLUX_SCHEDULER_NAME_VALUES
from invokeai.backend.flux2.denoise import denoise
from invokeai.backend.flux2.ref_image_extension import Flux2RefImageExtension
from invokeai.backend.flux2.sampling_utils import (
compute_empirical_mu,
generate_img_ids_flux2,
get_noise_flux2,
get_schedule_flux2,
pack_flux2,
unpack_flux2,
)
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@invocation(
"flux2_denoise",
title="FLUX2 Denoise",
tags=["image", "flux", "flux2", "klein", "denoise"],
category="image",
version="1.3.0",
classification=Classification.Prototype,
)
class Flux2DenoiseInvocation(BaseInvocation):
"""Run denoising process with a FLUX.2 Klein transformer model.
This node is designed for FLUX.2 Klein models which use Qwen3 as the text encoder.
It does not support ControlNet, IP-Adapters, or regional prompting.
"""
latents: Optional[LatentsField] = InputField(
default=None,
description=FieldDescriptions.latents,
input=Input.Connection,
)
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None,
description=FieldDescriptions.denoise_mask,
input=Input.Connection,
)
denoising_start: float = InputField(
default=0.0,
ge=0,
le=1,
description=FieldDescriptions.denoising_start,
)
denoising_end: float = InputField(
default=1.0,
ge=0,
le=1,
description=FieldDescriptions.denoising_end,
)
add_noise: bool = InputField(default=True, description="Add noise based on denoising start.")
transformer: TransformerField = InputField(
description=FieldDescriptions.flux_model,
input=Input.Connection,
title="Transformer",
)
positive_text_conditioning: FluxConditioningField = InputField(
description=FieldDescriptions.positive_cond,
input=Input.Connection,
)
negative_text_conditioning: Optional[FluxConditioningField] = InputField(
default=None,
description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
input=Input.Connection,
)
cfg_scale: float = InputField(
default=1.0,
description=FieldDescriptions.cfg_scale,
title="CFG Scale",
)
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
num_steps: int = InputField(
default=4,
description="Number of diffusion steps. Use 4 for distilled models, 28+ for base models.",
)
scheduler: FLUX_SCHEDULER_NAME_VALUES = InputField(
default="euler",
description="Scheduler (sampler) for the denoising process. 'euler' is fast and standard. "
"'heun' is 2nd-order (better quality, 2x slower). 'lcm' is optimized for few steps.",
ui_choice_labels=FLUX_SCHEDULER_LABELS,
)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
vae: VAEField = InputField(
description="FLUX.2 VAE model (required for BN statistics).",
input=Input.Connection,
)
kontext_conditioning: FluxKontextConditioningField | list[FluxKontextConditioningField] | None = InputField(
default=None,
description="FLUX Kontext conditioning (reference images for multi-reference image editing).",
input=Input.Connection,
title="Reference Images",
)
def _get_bn_stats(self, context: InvocationContext) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
"""Extract BN statistics from the FLUX.2 VAE.
The FLUX.2 VAE uses batch normalization on the patchified 128-channel representation.
IMPORTANT: BFL FLUX.2 VAE uses affine=False, so there are NO learnable weight/bias.
BN formula (affine=False): y = (x - mean) / std
Inverse: x = y * std + mean
Returns:
Tuple of (bn_mean, bn_std) tensors of shape (128,), or None if BN layer not found.
"""
with context.models.load(self.vae.vae).model_on_device() as (_, vae):
# Ensure VAE is in eval mode to prevent BN stats from being updated
vae.eval()
# Try to find the BN layer - it may be at different locations depending on model format
bn_layer = None
if hasattr(vae, "bn"):
bn_layer = vae.bn
elif hasattr(vae, "batch_norm"):
bn_layer = vae.batch_norm
elif hasattr(vae, "encoder") and hasattr(vae.encoder, "bn"):
bn_layer = vae.encoder.bn
if bn_layer is None:
return None
# Verify running statistics are initialized
if bn_layer.running_mean is None or bn_layer.running_var is None:
return None
# Get BN running statistics from VAE
bn_mean = bn_layer.running_mean.clone() # Shape: (128,)
bn_var = bn_layer.running_var.clone() # Shape: (128,)
bn_eps = bn_layer.eps if hasattr(bn_layer, "eps") else 1e-4 # BFL uses 1e-4
bn_std = torch.sqrt(bn_var + bn_eps)
return bn_mean, bn_std
def _bn_normalize(
self,
x: torch.Tensor,
bn_mean: torch.Tensor,
bn_std: torch.Tensor,
) -> torch.Tensor:
"""Apply BN normalization to packed latents.
BN formula (affine=False): y = (x - mean) / std
Args:
x: Packed latents of shape (B, seq, 128).
bn_mean: BN running mean of shape (128,).
bn_std: BN running std of shape (128,).
Returns:
Normalized latents of same shape.
"""
# x: (B, seq, 128), params: (128,) -> broadcast over batch and sequence dims
bn_mean = bn_mean.to(x.device, x.dtype)
bn_std = bn_std.to(x.device, x.dtype)
return (x - bn_mean) / bn_std
def _bn_denormalize(
self,
x: torch.Tensor,
bn_mean: torch.Tensor,
bn_std: torch.Tensor,
) -> torch.Tensor:
"""Apply BN denormalization to packed latents (inverse of normalization).
Inverse BN (affine=False): x = y * std + mean
Args:
x: Packed latents of shape (B, seq, 128).
bn_mean: BN running mean of shape (128,).
bn_std: BN running std of shape (128,).
Returns:
Denormalized latents of same shape.
"""
# x: (B, seq, 128), params: (128,) -> broadcast over batch and sequence dims
bn_mean = bn_mean.to(x.device, x.dtype)
bn_std = bn_std.to(x.device, x.dtype)
return x * bn_std + bn_mean
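A minimal round-trip check of the normalize/denormalize pair defined above, using the shapes stated in the docstrings; the statistics are random stand-ins, not real VAE values.
import torch

# Stand-in BN statistics with the documented shape (128,).
bn_mean = torch.randn(128)
bn_std = torch.rand(128) + 0.5            # keep std strictly positive

x = torch.randn(2, 4096, 128)             # (B, seq, 128) packed latents
y = (x - bn_mean) / bn_std                # _bn_normalize: y = (x - mean) / std
x_back = y * bn_std + bn_mean             # _bn_denormalize: x = y * std + mean

assert torch.allclose(x, x_back, atol=1e-5)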
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = self._run_diffusion(context)
latents = latents.detach().to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
inference_dtype = torch.bfloat16
device = TorchDevice.choose_torch_device()
# Get BN statistics from VAE for latent denormalization (optional)
# BFL FLUX.2 VAE uses affine=False, so only mean/std are needed
# Some VAE formats (e.g. diffusers) may not expose BN stats directly
bn_stats = self._get_bn_stats(context)
bn_mean, bn_std = bn_stats if bn_stats is not None else (None, None)
# Load the input latents, if provided
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
if init_latents is not None:
init_latents = init_latents.to(device=device, dtype=inference_dtype)
# Prepare input noise (FLUX.2 uses 32 channels)
noise = get_noise_flux2(
num_samples=1,
height=self.height,
width=self.width,
device=device,
dtype=inference_dtype,
seed=self.seed,
)
b, _c, latent_h, latent_w = noise.shape
packed_h = latent_h // 2
packed_w = latent_w // 2
# Load the conditioning data
pos_cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
assert len(pos_cond_data.conditionings) == 1
pos_flux_conditioning = pos_cond_data.conditionings[0]
assert isinstance(pos_flux_conditioning, FLUXConditioningInfo)
pos_flux_conditioning = pos_flux_conditioning.to(dtype=inference_dtype, device=device)
# Qwen3 stacked embeddings (stored in t5_embeds field for compatibility)
txt = pos_flux_conditioning.t5_embeds
# Generate text position IDs (4D format for FLUX.2: T, H, W, L)
# FLUX.2 uses 4D position coordinates for its rotary position embeddings
# IMPORTANT: Position IDs must be int64 (long) dtype
# Diffusers uses: T=0, H=0, W=0, L=0..seq_len-1
seq_len = txt.shape[1]
txt_ids = torch.zeros(1, seq_len, 4, device=device, dtype=torch.long)
txt_ids[..., 3] = torch.arange(seq_len, device=device, dtype=torch.long) # L coordinate varies
# Load negative conditioning if provided
neg_txt = None
neg_txt_ids = None
if self.negative_text_conditioning is not None:
neg_cond_data = context.conditioning.load(self.negative_text_conditioning.conditioning_name)
assert len(neg_cond_data.conditionings) == 1
neg_flux_conditioning = neg_cond_data.conditionings[0]
assert isinstance(neg_flux_conditioning, FLUXConditioningInfo)
neg_flux_conditioning = neg_flux_conditioning.to(dtype=inference_dtype, device=device)
neg_txt = neg_flux_conditioning.t5_embeds
# For text tokens: T=0, H=0, W=0, L=0..seq_len-1 (only L varies per token)
neg_seq_len = neg_txt.shape[1]
neg_txt_ids = torch.zeros(1, neg_seq_len, 4, device=device, dtype=torch.long)
neg_txt_ids[..., 3] = torch.arange(neg_seq_len, device=device, dtype=torch.long)
# Validate transformer config
transformer_config = context.models.get_config(self.transformer.transformer)
assert transformer_config.base == BaseModelType.Flux2 and transformer_config.type == ModelType.Main
# Calculate the timestep schedule using FLUX.2 specific schedule
# This matches diffusers' Flux2Pipeline implementation
# Note: Schedule shifting is handled by the scheduler via mu parameter
image_seq_len = packed_h * packed_w
timesteps = get_schedule_flux2(
num_steps=self.num_steps,
image_seq_len=image_seq_len,
)
# Compute mu for dynamic schedule shifting (used by FlowMatchEulerDiscreteScheduler)
mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=self.num_steps)
# Clip the timesteps schedule based on denoising_start and denoising_end
timesteps = clip_timestep_schedule_fractional(timesteps, self.denoising_start, self.denoising_end)
# Prepare input latent image
if init_latents is not None:
if self.add_noise:
t_0 = timesteps[0]
x = t_0 * noise + (1.0 - t_0) * init_latents
else:
x = init_latents
else:
if self.denoising_start > 1e-5:
raise ValueError("denoising_start should be 0 when initial latents are not provided.")
x = noise
# If len(timesteps) == 1, then short-circuit
if len(timesteps) <= 1:
return x
# Generate image position IDs (FLUX.2 uses 4D coordinates)
# Position IDs use int64 dtype like diffusers
img_ids = generate_img_ids_flux2(h=latent_h, w=latent_w, batch_size=b, device=device)
# Prepare inpaint mask
inpaint_mask = self._prep_inpaint_mask(context, x)
# Pack all latent tensors
init_latents_packed = pack_flux2(init_latents) if init_latents is not None else None
inpaint_mask_packed = pack_flux2(inpaint_mask) if inpaint_mask is not None else None
noise_packed = pack_flux2(noise)
x = pack_flux2(x)
# BN normalization for txt2img:
# - DO NOT normalize random noise (it's already N(0,1) distributed)
# - Diffusers only normalizes image latents from VAE (for img2img/kontext)
# - Output MUST be denormalized after denoising before VAE decode
#
# For img2img with init_latents, we should normalize init_latents on unpacked
# shape (B, 128, H/16, W/16) - this is handled by _bn_normalize_unpacked below
# Verify packed dimensions
assert packed_h * packed_w == x.shape[1]
# Prepare inpaint extension
inpaint_extension: Optional[RectifiedFlowInpaintExtension] = None
if inpaint_mask_packed is not None:
assert init_latents_packed is not None
inpaint_extension = RectifiedFlowInpaintExtension(
init_latents=init_latents_packed,
inpaint_mask=inpaint_mask_packed,
noise=noise_packed,
)
# Prepare CFG scale list
num_steps = len(timesteps) - 1
cfg_scale_list = [self.cfg_scale] * num_steps
# Check if we're doing inpainting (have a mask or a clipped schedule)
is_inpainting = self.denoise_mask is not None or self.denoising_start > 1e-5
# Create scheduler with FLUX.2 Klein configuration
# For inpainting/img2img, use manual Euler stepping to preserve the exact timestep schedule
# For txt2img, use the scheduler with dynamic shifting for optimal results
scheduler = None
if self.scheduler in FLUX_SCHEDULER_MAP and not is_inpainting:
# Only use scheduler for txt2img - use manual Euler for inpainting to preserve exact timesteps
scheduler_class = FLUX_SCHEDULER_MAP[self.scheduler]
# FlowMatchHeunDiscreteScheduler only supports num_train_timesteps and shift parameters
# FlowMatchEulerDiscreteScheduler and FlowMatchLCMScheduler support dynamic shifting
if self.scheduler == "heun":
scheduler = scheduler_class(
num_train_timesteps=1000,
shift=3.0,
)
else:
scheduler = scheduler_class(
num_train_timesteps=1000,
shift=3.0,
use_dynamic_shifting=True,
base_shift=0.5,
max_shift=1.15,
base_image_seq_len=256,
max_image_seq_len=4096,
time_shift_type="exponential",
)
# Prepare reference image extension for FLUX.2 Klein built-in editing
ref_image_extension = None
if self.kontext_conditioning:
ref_image_extension = Flux2RefImageExtension(
context=context,
ref_image_conditioning=self.kontext_conditioning
if isinstance(self.kontext_conditioning, list)
else [self.kontext_conditioning],
vae_field=self.vae,
device=device,
dtype=inference_dtype,
bn_mean=bn_mean,
bn_std=bn_std,
)
with ExitStack() as exit_stack:
# Load the transformer model
(cached_weights, transformer) = exit_stack.enter_context(
context.models.load(self.transformer.transformer).model_on_device()
)
config = transformer_config
# Determine if the model is quantized
if config.format in [ModelFormat.Diffusers]:
model_is_quantized = False
elif config.format in [
ModelFormat.BnbQuantizedLlmInt8b,
ModelFormat.BnbQuantizednf4b,
ModelFormat.GGUFQuantized,
]:
model_is_quantized = True
else:
model_is_quantized = False
# Apply LoRA models to the transformer
exit_stack.enter_context(
LayerPatcher.apply_smart_model_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
dtype=inference_dtype,
cached_weights=cached_weights,
force_sidecar_patching=model_is_quantized,
)
)
# Prepare reference image conditioning if provided
img_cond_seq = None
img_cond_seq_ids = None
if ref_image_extension is not None:
# Ensure batch sizes match
ref_image_extension.ensure_batch_size(x.shape[0])
img_cond_seq, img_cond_seq_ids = (
ref_image_extension.ref_image_latents,
ref_image_extension.ref_image_ids,
)
x = denoise(
model=transformer,
img=x,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
timesteps=timesteps,
step_callback=self._build_step_callback(context),
cfg_scale=cfg_scale_list,
neg_txt=neg_txt,
neg_txt_ids=neg_txt_ids,
scheduler=scheduler,
mu=mu,
inpaint_extension=inpaint_extension,
img_cond_seq=img_cond_seq,
img_cond_seq_ids=img_cond_seq_ids,
)
# Apply BN denormalization if BN stats are available
# The diffusers Flux2KleinPipeline applies: latents = latents * bn_std + bn_mean
# This transforms latents from normalized space to VAE's expected input space
if bn_mean is not None and bn_std is not None:
x = self._bn_denormalize(x, bn_mean, bn_std)
x = unpack_flux2(x.float(), self.height, self.width)
return x
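For reference, a sketch of the 2x2 patchify packing implied by packed_h = latent_h // 2 and the (B, seq, 128) shapes above. This assumes pack_flux2/unpack_flux2 follow the same einops pattern as FLUX.1's pack/unpack; their actual implementation is not shown in this diff, and the helper names below are placeholders.
import torch
from einops import rearrange

def pack_flux2_sketch(x: torch.Tensor) -> torch.Tensor:
    # (B, 32, H, W) -> (B, (H/2)*(W/2), 128): each 2x2 patch of 32 channels becomes one 128-dim token.
    return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

def unpack_flux2_sketch(x: torch.Tensor, latent_h: int, latent_w: int) -> torch.Tensor:
    # Inverse: (B, seq, 128) -> (B, 32, H, W).
    return rearrange(
        x, "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=latent_h // 2, w=latent_w // 2, ph=2, pw=2,
    )

noise = torch.randn(1, 32, 128, 128)       # e.g. a 1024x1024 image at 1/8 latent scale
packed = pack_flux2_sketch(noise)          # (1, 4096, 128) -- matches the 4096 seq len cited for 1024x1024
assert packed.shape == (1, 64 * 64, 128)
assert torch.equal(unpack_flux2_sketch(packed, 128, 128), noise)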
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> Optional[torch.Tensor]:
"""Prepare the inpaint mask."""
if self.denoise_mask is None:
return None
mask = context.tensors.load(self.denoise_mask.mask_name)
mask = 1.0 - mask
_, _, latent_height, latent_width = latents.shape
mask = tv_resize(
img=mask,
size=[latent_height, latent_width],
interpolation=tv_transforms.InterpolationMode.BILINEAR,
antialias=False,
)
mask = mask.to(device=latents.device, dtype=latents.dtype)
return mask.expand_as(latents)
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
"""Iterate over LoRA models to apply."""
for lora in self.transformer.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
"""Build a callback for step progress updates."""
def step_callback(state: PipelineIntermediateState) -> None:
latents = state.latents.float()
state.latents = unpack_flux2(latents, self.height, self.width).squeeze()
context.util.flux2_step_callback(state)
return step_callback
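As background on the mu parameter passed to denoise above: with use_dynamic_shifting=True and time_shift_type="exponential", diffusers-style flow-matching schedulers warp a uniform sigma schedule as exp(mu) / (exp(mu) + (1/sigma - 1)). The sketch below assumes that formula (with the exponent fixed at 1) and uses purely illustrative mu values; the exact behaviour of compute_empirical_mu and get_schedule_flux2 is not shown in this diff.
import math

def exponential_time_shift(mu: float, t: float) -> float:
    # Diffusers-style exponential shift with sigma exponent 1:
    # larger mu keeps more of the schedule at high-noise timesteps.
    return math.exp(mu) / (math.exp(mu) + (1.0 / t - 1.0))

uniform = [i / 8 for i in range(8, 0, -1)]   # 1.0 .. 0.125, illustrative only
for mu in (0.5, 2.0):                        # arbitrary example values
    shifted = [round(exponential_time_shift(mu, t), 3) for t in uniform]
    print(f"mu={mu}: {shifted}")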

View File

@@ -0,0 +1,222 @@
"""Flux2 Klein Model Loader Invocation.
Loads a Flux2 Klein model with its Qwen3 text encoder and VAE.
Unlike standard FLUX which uses CLIP+T5, Klein uses only Qwen3.
"""
from typing import Literal, Optional
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.model import (
ModelIdentifierField,
Qwen3EncoderField,
TransformerField,
VAEField,
)
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
Flux2VariantType,
ModelFormat,
ModelType,
Qwen3VariantType,
SubModelType,
)
@invocation_output("flux2_klein_model_loader_output")
class Flux2KleinModelLoaderOutput(BaseInvocationOutput):
"""Flux2 Klein model loader output."""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
qwen3_encoder: Qwen3EncoderField = OutputField(description=FieldDescriptions.qwen3_encoder, title="Qwen3 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
max_seq_len: Literal[256, 512] = OutputField(
description="The max sequence length for the Qwen3 encoder.",
title="Max Seq Length",
)
@invocation(
"flux2_klein_model_loader",
title="Main Model - Flux2 Klein",
tags=["model", "flux", "klein", "qwen3"],
category="model",
version="1.0.0",
classification=Classification.Prototype,
)
class Flux2KleinModelLoaderInvocation(BaseInvocation):
"""Loads a Flux2 Klein model, outputting its submodels.
Flux2 Klein uses Qwen3 as the text encoder instead of CLIP+T5.
It uses a 32-channel VAE (AutoencoderKLFlux2) instead of the 16-channel FLUX.1 VAE.
When using a Diffusers format model, both VAE and Qwen3 encoder are extracted
automatically from the main model. You can override with standalone models:
- Transformer: Always from Flux2 Klein main model
- VAE: From main model (Diffusers) or standalone VAE
- Qwen3 Encoder: From main model (Diffusers) or standalone Qwen3 model
"""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.flux_model,
input=Input.Direct,
ui_model_base=BaseModelType.Flux2,
ui_model_type=ModelType.Main,
title="Transformer",
)
vae_model: Optional[ModelIdentifierField] = InputField(
default=None,
description="Standalone VAE model. Flux2 Klein uses the same VAE as FLUX (16-channel). "
"If not provided, VAE will be loaded from the Qwen3 Source model.",
input=Input.Direct,
ui_model_base=[BaseModelType.Flux, BaseModelType.Flux2],
ui_model_type=ModelType.VAE,
title="VAE",
)
qwen3_encoder_model: Optional[ModelIdentifierField] = InputField(
default=None,
description="Standalone Qwen3 Encoder model. "
"If not provided, encoder will be loaded from the Qwen3 Source model.",
input=Input.Direct,
ui_model_type=ModelType.Qwen3Encoder,
title="Qwen3 Encoder",
)
qwen3_source_model: Optional[ModelIdentifierField] = InputField(
default=None,
description="Diffusers Flux2 Klein model to extract VAE and/or Qwen3 encoder from. "
"Use this if you don't have separate VAE/Qwen3 models. "
"Ignored if both VAE and Qwen3 Encoder are provided separately.",
input=Input.Direct,
ui_model_base=BaseModelType.Flux2,
ui_model_type=ModelType.Main,
ui_model_format=ModelFormat.Diffusers,
title="Qwen3 Source (Diffusers)",
)
max_seq_len: Literal[256, 512] = InputField(
default=512,
description="Max sequence length for the Qwen3 encoder.",
title="Max Seq Length",
)
def invoke(self, context: InvocationContext) -> Flux2KleinModelLoaderOutput:
# Transformer always comes from the main model
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
# Check if main model is Diffusers format (can extract VAE directly)
main_config = context.models.get_config(self.model)
main_is_diffusers = main_config.format == ModelFormat.Diffusers
# Determine VAE source
# IMPORTANT: FLUX.2 Klein uses a 32-channel VAE (AutoencoderKLFlux2), not the 16-channel FLUX.1 VAE.
# The VAE should come from the FLUX.2 Klein Diffusers model, not a separate FLUX VAE.
if self.vae_model is not None:
# Use standalone VAE (user explicitly selected one)
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
elif main_is_diffusers:
# Extract VAE from main model (recommended for FLUX.2)
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
elif self.qwen3_source_model is not None:
# Extract from Qwen3 source Diffusers model
self._validate_diffusers_format(context, self.qwen3_source_model, "Qwen3 Source")
vae = self.qwen3_source_model.model_copy(update={"submodel_type": SubModelType.VAE})
else:
raise ValueError(
"No VAE source provided. Standalone safetensors/GGUF models require a separate VAE. "
"Options:\n"
" 1. Set 'VAE' to a standalone FLUX VAE model\n"
" 2. Set 'Qwen3 Source' to a Diffusers Flux2 Klein model to extract the VAE from"
)
# Determine Qwen3 Encoder source
if self.qwen3_encoder_model is not None:
# Use standalone Qwen3 Encoder - validate it matches the FLUX.2 Klein variant
self._validate_qwen3_encoder_variant(context, main_config)
qwen3_tokenizer = self.qwen3_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
qwen3_encoder = self.qwen3_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
elif main_is_diffusers:
# Extract from main model (recommended for FLUX.2 Klein)
qwen3_tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
qwen3_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
elif self.qwen3_source_model is not None:
# Extract from separate Diffusers model
self._validate_diffusers_format(context, self.qwen3_source_model, "Qwen3 Source")
qwen3_tokenizer = self.qwen3_source_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
qwen3_encoder = self.qwen3_source_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
else:
raise ValueError(
"No Qwen3 Encoder source provided. Standalone safetensors/GGUF models require a separate text encoder. "
"Options:\n"
" 1. Set 'Qwen3 Encoder' to a standalone Qwen3 text encoder model "
"(Klein 4B needs Qwen3 4B, Klein 9B needs Qwen3 8B)\n"
" 2. Set 'Qwen3 Source' to a Diffusers Flux2 Klein model to extract the encoder from"
)
return Flux2KleinModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
qwen3_encoder=Qwen3EncoderField(tokenizer=qwen3_tokenizer, text_encoder=qwen3_encoder),
vae=VAEField(vae=vae),
max_seq_len=self.max_seq_len,
)
def _validate_diffusers_format(
self, context: InvocationContext, model: ModelIdentifierField, model_name: str
) -> None:
"""Validate that a model is in Diffusers format."""
config = context.models.get_config(model)
if config.format != ModelFormat.Diffusers:
raise ValueError(
f"The {model_name} model must be a Diffusers format model. "
f"The selected model '{config.name}' is in {config.format.value} format."
)
def _validate_qwen3_encoder_variant(self, context: InvocationContext, main_config) -> None:
"""Validate that the standalone Qwen3 encoder variant matches the FLUX.2 Klein variant.
- FLUX.2 Klein 4B requires Qwen3 4B encoder
- FLUX.2 Klein 9B requires Qwen3 8B encoder
"""
if self.qwen3_encoder_model is None:
return
# Get the Qwen3 encoder config
qwen3_config = context.models.get_config(self.qwen3_encoder_model)
# Check if the config has a variant field
if not hasattr(qwen3_config, "variant"):
# Can't validate, skip
return
qwen3_variant = qwen3_config.variant
# Get the FLUX.2 Klein variant from the main model config
if not hasattr(main_config, "variant"):
return
flux2_variant = main_config.variant
# Validate the variants match
# Klein4B requires Qwen3_4B, Klein9B/Klein9BBase requires Qwen3_8B
expected_qwen3_variant = None
if flux2_variant == Flux2VariantType.Klein4B:
expected_qwen3_variant = Qwen3VariantType.Qwen3_4B
elif flux2_variant in (Flux2VariantType.Klein9B, Flux2VariantType.Klein9BBase):
expected_qwen3_variant = Qwen3VariantType.Qwen3_8B
if expected_qwen3_variant is not None and qwen3_variant != expected_qwen3_variant:
raise ValueError(
f"Qwen3 encoder variant mismatch: FLUX.2 Klein {flux2_variant.value} requires "
f"{expected_qwen3_variant.value} encoder, but {qwen3_variant.value} was selected. "
f"Please select a matching Qwen3 encoder or use a Diffusers format model which includes the correct encoder."
)

View File

@@ -0,0 +1,222 @@
"""Flux2 Klein Text Encoder Invocation.
Flux2 Klein uses Qwen3 as the text encoder instead of CLIP+T5.
The key difference is that it extracts hidden states from layers (9, 18, 27)
and stacks them together for richer text representations.
This implementation matches the diffusers Flux2KleinPipeline exactly.
"""
from contextlib import ExitStack
from typing import Iterator, Literal, Optional, Tuple
import torch
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
FluxConditioningField,
Input,
InputField,
TensorField,
UIComponent,
)
from invokeai.app.invocations.model import Qwen3EncoderField
from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_T5_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
# FLUX.2 Klein extracts hidden states from these specific layers
# Matching diffusers Flux2KleinPipeline: (9, 18, 27)
# hidden_states[0] is embedding layer, so layer N is at index N
KLEIN_EXTRACTION_LAYERS = (9, 18, 27)
# Default max sequence length for Klein models
KLEIN_MAX_SEQ_LEN = 512
@invocation(
"flux2_klein_text_encoder",
title="Prompt - Flux2 Klein",
tags=["prompt", "conditioning", "flux", "klein", "qwen3"],
category="conditioning",
version="1.1.0",
classification=Classification.Prototype,
)
class Flux2KleinTextEncoderInvocation(BaseInvocation):
"""Encodes and preps a prompt for Flux2 Klein image generation.
Flux2 Klein uses Qwen3 as the text encoder, extracting hidden states from
layers (9, 18, 27) and stacking them for richer text representations.
This matches the diffusers Flux2KleinPipeline implementation exactly.
"""
prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
qwen3_encoder: Qwen3EncoderField = InputField(
title="Qwen3 Encoder",
description=FieldDescriptions.qwen3_encoder,
input=Input.Connection,
)
max_seq_len: Literal[256, 512] = InputField(
default=512,
description="Max sequence length for the Qwen3 encoder.",
)
mask: Optional[TensorField] = InputField(
default=None,
description="A mask defining the region that this conditioning prompt applies to.",
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
qwen3_embeds, pooled_embeds = self._encode_prompt(context)
# Use FLUXConditioningInfo for compatibility with existing Flux denoiser
# t5_embeds -> qwen3 stacked embeddings
# clip_embeds -> pooled qwen3 embedding
conditioning_data = ConditioningFieldData(
conditionings=[FLUXConditioningInfo(clip_embeds=pooled_embeds, t5_embeds=qwen3_embeds)]
)
conditioning_name = context.conditioning.save(conditioning_data)
return FluxConditioningOutput(
conditioning=FluxConditioningField(conditioning_name=conditioning_name, mask=self.mask)
)
def _encode_prompt(self, context: InvocationContext) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encode prompt using Qwen3 text encoder with Klein-style layer extraction.
This matches the diffusers Flux2KleinPipeline._get_qwen3_prompt_embeds() exactly.
Returns:
Tuple of (stacked_embeddings, pooled_embedding):
- stacked_embeddings: Hidden states from layers (9, 18, 27) stacked together.
Shape: (1, seq_len, hidden_size * 3)
- pooled_embedding: Pooled representation for global conditioning.
Shape: (1, hidden_size)
"""
prompt = self.prompt
device = TorchDevice.choose_torch_device()
text_encoder_info = context.models.load(self.qwen3_encoder.text_encoder)
tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer)
with ExitStack() as exit_stack:
(cached_weights, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())
(_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())
# Apply LoRA models to the text encoder
lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
exit_stack.enter_context(
LayerPatcher.apply_smart_model_patches(
model=text_encoder,
patches=self._lora_iterator(context),
prefix=FLUX_LORA_T5_PREFIX, # Reuse T5 prefix for Qwen3 LoRAs
dtype=lora_dtype,
cached_weights=cached_weights,
)
)
context.util.signal_progress("Running Qwen3 text encoder (Klein)")
if not isinstance(text_encoder, PreTrainedModel):
raise TypeError(
f"Expected PreTrainedModel for text encoder, got {type(text_encoder).__name__}. "
"The Qwen3 encoder model may be corrupted or incompatible."
)
if not isinstance(tokenizer, PreTrainedTokenizerBase):
raise TypeError(
f"Expected PreTrainedTokenizerBase for tokenizer, got {type(tokenizer).__name__}. "
"The Qwen3 tokenizer may be corrupted or incompatible."
)
# Format messages exactly like diffusers Flux2KleinPipeline:
# - Only user message, NO system message
# - add_generation_prompt=True (adds assistant prefix)
# - enable_thinking=False
messages = [{"role": "user", "content": prompt}]
# Step 1: Apply chat template to get formatted text (tokenize=False)
text: str = tokenizer.apply_chat_template( # type: ignore[assignment]
messages,
tokenize=False,
add_generation_prompt=True, # Adds assistant prefix like diffusers
enable_thinking=False, # Disable thinking mode
)
# Step 2: Tokenize the formatted text
inputs = tokenizer(
text,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.max_seq_len,
)
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
# Move to device
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
# Forward pass through the model - matching diffusers exactly
outputs = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
# Validate hidden_states output
if not hasattr(outputs, "hidden_states") or outputs.hidden_states is None:
raise RuntimeError(
"Text encoder did not return hidden_states. "
"Ensure output_hidden_states=True is supported by this model."
)
num_hidden_layers = len(outputs.hidden_states)
# Extract and stack hidden states - EXACTLY like diffusers:
# out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
# prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
hidden_states_list = []
for layer_idx in KLEIN_EXTRACTION_LAYERS:
if layer_idx >= num_hidden_layers:
layer_idx = num_hidden_layers - 1
hidden_states_list.append(outputs.hidden_states[layer_idx])
# Stack along dim=1, then permute and reshape - exactly like diffusers
out = torch.stack(hidden_states_list, dim=1)
out = out.to(dtype=text_encoder.dtype, device=device)
batch_size, num_channels, seq_len, hidden_dim = out.shape
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
# Create pooled embedding for global conditioning
# Use mean pooling over the sequence (excluding padding)
# This serves a similar role to CLIP's pooled output in standard FLUX
last_hidden_state = outputs.hidden_states[-1] # Use last layer for pooling
# Expand mask to match hidden state dimensions
expanded_mask = attention_mask.unsqueeze(-1).expand_as(last_hidden_state).float()
sum_embeds = (last_hidden_state * expanded_mask).sum(dim=1)
num_tokens = expanded_mask.sum(dim=1).clamp(min=1)
pooled_embeds = sum_embeds / num_tokens
return prompt_embeds, pooled_embeds
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
"""Iterate over LoRA models to apply to the Qwen3 text encoder."""
for lora in self.qwen3_encoder.loras:
lora_info = context.models.load(lora.lora)
if not isinstance(lora_info.model, ModelPatchRaw):
raise TypeError(
f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}. "
"The LoRA model may be corrupted or incompatible."
)
yield (lora_info.model, lora.weight)
del lora_info
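A shape-only sketch of the layer stacking described in _encode_prompt above; the hidden size, layer count, and sequence length are arbitrary stand-ins rather than real Qwen3 configuration values.
import torch

KLEIN_EXTRACTION_LAYERS = (9, 18, 27)
batch_size, seq_len, hidden_dim = 1, 512, 2560   # placeholder dimensions

# Stand-ins for outputs.hidden_states: embedding layer plus N transformer layers.
hidden_states = [torch.randn(batch_size, seq_len, hidden_dim) for _ in range(29)]

# Stack the three selected layers, then permute/reshape exactly as in the invocation.
out = torch.stack([hidden_states[k] for k in KLEIN_EXTRACTION_LAYERS], dim=1)   # (B, 3, seq, hidden)
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, 3 * hidden_dim)
assert prompt_embeds.shape == (1, 512, 3 * hidden_dim)                          # (1, seq_len, hidden_size * 3)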

View File

@@ -0,0 +1,92 @@
"""Flux2 Klein VAE Decode Invocation.
Decodes latents to images using the FLUX.2 32-channel VAE (AutoencoderKLFlux2).
"""
import torch
from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.util.devices import TorchDevice
@invocation(
"flux2_vae_decode",
title="Latents to Image - FLUX2",
tags=["latents", "image", "vae", "l2i", "flux2", "klein"],
category="latents",
version="1.0.0",
classification=Classification.Prototype,
)
class Flux2VaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents using FLUX.2 Klein's 32-channel VAE."""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
"""Decode latents to image using FLUX.2 VAE.
Input latents should already be in the correct space after BN denormalization
was applied in the denoiser. The VAE expects (B, 32, H, W) format.
"""
with vae_info.model_on_device() as (_, vae):
vae_dtype = next(iter(vae.parameters())).dtype
device = TorchDevice.choose_torch_device()
latents = latents.to(device=device, dtype=vae_dtype)
# Decode using diffusers API
decoded = vae.decode(latents, return_dict=False)[0]
# Convert from [-1, 1] to [0, 1] then to [0, 255] PIL image
img = (decoded / 2 + 0.5).clamp(0, 1)
img = rearrange(img[0], "c h w -> h w c")
img_np = (img * 255).byte().cpu().numpy()
# Explicitly create RGB image (not grayscale)
img_pil = Image.fromarray(img_np, mode="RGB")
return img_pil
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
# Log latent statistics for debugging black image issues
context.logger.debug(
f"FLUX.2 VAE decode input: shape={latents.shape}, "
f"min={latents.min().item():.4f}, max={latents.max().item():.4f}, "
f"mean={latents.mean().item():.4f}"
)
# Warn if input latents are all zeros or very small (would cause black images)
if latents.abs().max() < 1e-6:
context.logger.warning(
"FLUX.2 VAE decode received near-zero latents! This will cause black images. "
"The latent cache may be corrupted - try clearing the cache."
)
vae_info = context.models.load(self.vae.vae)
context.util.signal_progress("Running VAE")
image = self._vae_decode(vae_info=vae_info, latents=latents)
TorchDevice.empty_cache()
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)

View File

@@ -0,0 +1,88 @@
"""Flux2 Klein VAE Encode Invocation.
Encodes images to latents using the FLUX.2 32-channel VAE (AutoencoderKLFlux2).
"""
import einops
import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
Input,
InputField,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
@invocation(
"flux2_vae_encode",
title="Image to Latents - FLUX2",
tags=["latents", "image", "vae", "i2l", "flux2", "klein"],
category="latents",
version="1.0.0",
classification=Classification.Prototype,
)
class Flux2VaeEncodeInvocation(BaseInvocation):
"""Encodes an image into latents using FLUX.2 Klein's 32-channel VAE."""
image: ImageField = InputField(
description="The image to encode.",
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
def _vae_encode(self, vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
"""Encode image to latents using FLUX.2 VAE.
The VAE encodes to 32-channel latent space.
Output latents shape: (B, 32, H/8, W/8).
"""
with vae_info.model_on_device() as (_, vae):
vae_dtype = next(iter(vae.parameters())).dtype
device = TorchDevice.choose_torch_device()
image_tensor = image_tensor.to(device=device, dtype=vae_dtype)
# Encode using diffusers API
# The VAE.encode() returns a DiagonalGaussianDistribution-like object
latent_dist = vae.encode(image_tensor, return_dict=False)[0]
# Sample from the distribution (or use mode for deterministic output)
# Using mode() for deterministic encoding
if hasattr(latent_dist, "mode"):
latents = latent_dist.mode()
elif hasattr(latent_dist, "sample"):
# Fall back to sampling if mode is not available
generator = torch.Generator(device=device).manual_seed(0)
latents = latent_dist.sample(generator=generator)
else:
# Direct tensor output (some VAE implementations)
latents = latent_dist
return latents
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
vae_info = context.models.load(self.vae.vae)
# Convert image to tensor (HWC -> CHW, normalize to [-1, 1])
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
context.util.signal_progress("Running VAE Encode")
latents = self._vae_encode(vae_info=vae_info, image_tensor=image_tensor)
latents = latents.to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

View File

@@ -32,6 +32,13 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.dype.presets import (
DYPE_PRESET_LABELS,
DYPE_PRESET_OFF,
DyPEPreset,
get_dype_config_from_preset,
)
from invokeai.backend.flux.extensions.dype_extension import DyPEExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.kontext_extension import KontextExtension
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
@@ -47,6 +54,7 @@ from invokeai.backend.flux.sampling_utils import (
pack,
unpack,
)
from invokeai.backend.flux.schedulers import FLUX_SCHEDULER_LABELS, FLUX_SCHEDULER_MAP, FLUX_SCHEDULER_NAME_VALUES
from invokeai.backend.flux.text_conditioning import FluxReduxConditioning, FluxTextConditioning
from invokeai.backend.model_manager.taxonomy import BaseModelType, FluxVariantType, ModelFormat, ModelType
from invokeai.backend.patches.layer_patcher import LayerPatcher
@@ -63,7 +71,7 @@ from invokeai.backend.util.devices import TorchDevice
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
version="4.1.0",
version="4.5.1",
)
class FluxDenoiseInvocation(BaseInvocation):
"""Run denoising process with a FLUX transformer model."""
@@ -132,6 +140,12 @@ class FluxDenoiseInvocation(BaseInvocation):
num_steps: int = InputField(
default=4, description="Number of diffusion steps. Recommended values are schnell: 4, dev: 50."
)
scheduler: FLUX_SCHEDULER_NAME_VALUES = InputField(
default="euler",
description="Scheduler (sampler) for the denoising process. 'euler' is fast and standard. "
"'heun' is 2nd-order (better quality, 2x slower). 'lcm' is optimized for few steps.",
ui_choice_labels=FLUX_SCHEDULER_LABELS,
)
guidance: float = InputField(
default=4.0,
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
@@ -159,6 +173,31 @@ class FluxDenoiseInvocation(BaseInvocation):
input=Input.Connection,
)
# DyPE (Dynamic Position Extrapolation) for high-resolution generation
dype_preset: DyPEPreset = InputField(
default=DYPE_PRESET_OFF,
description=(
"DyPE preset for high-resolution generation. 'auto' enables automatically for resolutions > 1536px. "
"'area' enables automatically based on image area. '4k' uses optimized settings for 4K output."
),
ui_order=100,
ui_choice_labels=DYPE_PRESET_LABELS,
)
dype_scale: Optional[float] = InputField(
default=None,
ge=0.0,
le=8.0,
description="DyPE magnitude (λs). Higher values = stronger extrapolation. Only used when dype_preset is not 'off'.",
ui_order=101,
)
dype_exponent: Optional[float] = InputField(
default=None,
ge=0.0,
le=1000.0,
description="DyPE decay speed (λt). Controls transition from low to high frequency detail. Only used when dype_preset is not 'off'.",
ui_order=102,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = self._run_diffusion(context)
@@ -232,8 +271,14 @@ class FluxDenoiseInvocation(BaseInvocation):
)
transformer_config = context.models.get_config(self.transformer.transformer)
assert transformer_config.base is BaseModelType.Flux and transformer_config.type is ModelType.Main
is_schnell = transformer_config.variant is FluxVariantType.Schnell
assert (
transformer_config.base in (BaseModelType.Flux, BaseModelType.Flux2)
and transformer_config.type is ModelType.Main
)
# Schnell is only for FLUX.1, FLUX.2 Klein behaves like Dev (with guidance)
is_schnell = (
transformer_config.base is BaseModelType.Flux and transformer_config.variant is FluxVariantType.Schnell
)
# Calculate the timestep schedule.
timesteps = get_schedule(
@@ -242,6 +287,12 @@ class FluxDenoiseInvocation(BaseInvocation):
shift=not is_schnell,
)
# Create scheduler if not using default euler
scheduler = None
if self.scheduler in FLUX_SCHEDULER_MAP:
scheduler_class = FLUX_SCHEDULER_MAP[self.scheduler]
scheduler = scheduler_class(num_train_timesteps=1000)
# Clip the timesteps schedule based on denoising_start and denoising_end.
timesteps = clip_timestep_schedule_fractional(timesteps, self.denoising_start, self.denoising_end)
@@ -409,6 +460,30 @@ class FluxDenoiseInvocation(BaseInvocation):
kontext_extension.ensure_batch_size(x.shape[0])
img_cond_seq, img_cond_seq_ids = kontext_extension.kontext_latents, kontext_extension.kontext_ids
# Prepare DyPE extension for high-resolution generation
dype_extension: DyPEExtension | None = None
dype_config = get_dype_config_from_preset(
preset=self.dype_preset,
width=self.width,
height=self.height,
custom_scale=self.dype_scale,
custom_exponent=self.dype_exponent,
)
if dype_config is not None:
dype_extension = DyPEExtension(
config=dype_config,
target_height=self.height,
target_width=self.width,
)
context.logger.info(
f"DyPE enabled: resolution={self.width}x{self.height}, preset={self.dype_preset}, "
f"method={dype_config.method}, scale={dype_config.dype_scale:.2f}, "
f"exponent={dype_config.dype_exponent:.2f}, start_sigma={dype_config.dype_start_sigma:.2f}, "
f"base_resolution={dype_config.base_resolution}"
)
else:
context.logger.debug(f"DyPE disabled: resolution={self.width}x{self.height}, preset={self.dype_preset}")
x = denoise(
model=transformer,
img=x,
@@ -426,6 +501,8 @@ class FluxDenoiseInvocation(BaseInvocation):
img_cond=img_cond,
img_cond_seq=img_cond_seq,
img_cond_seq_ids=img_cond_seq_ids,
dype_extension=dype_extension,
scheduler=scheduler,
)
x = unpack(x.float(), self.height, self.width)

View File

@@ -162,7 +162,7 @@ class FLUXLoRACollectionLoader(BaseInvocation):
if not context.models.exists(lora.lora.key):
raise Exception(f"Unknown lora: {lora.lora.key}!")
assert lora.lora.base is BaseModelType.Flux
assert lora.lora.base in (BaseModelType.Flux, BaseModelType.Flux2)
added_loras.append(lora.lora.key)

View File

@@ -6,7 +6,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.t5_model_identifier import (
@@ -37,28 +37,25 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
title="Main Model - FLUX",
tags=["model", "flux"],
category="model",
version="1.0.6",
version="1.0.7",
)
class FluxModelLoaderInvocation(BaseInvocation):
"""Loads a flux base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.flux_model,
input=Input.Direct,
ui_model_base=BaseModelType.Flux,
ui_model_type=ModelType.Main,
)
t5_encoder_model: ModelIdentifierField = InputField(
description=FieldDescriptions.t5_encoder,
input=Input.Direct,
title="T5 Encoder",
ui_model_type=ModelType.T5Encoder,
)
clip_embed_model: ModelIdentifierField = InputField(
description=FieldDescriptions.clip_embed_model,
input=Input.Direct,
title="CLIP Embed",
ui_model_type=ModelType.CLIPEmbed,
)

View File

@@ -46,7 +46,12 @@ class IdealSizeInvocation(BaseInvocation):
dimension = 512
elif unet_config.base == BaseModelType.StableDiffusion2:
dimension = 768
elif unet_config.base in (BaseModelType.StableDiffusionXL, BaseModelType.Flux, BaseModelType.StableDiffusion3):
elif unet_config.base in (
BaseModelType.StableDiffusionXL,
BaseModelType.Flux,
BaseModelType.Flux2,
BaseModelType.StableDiffusion3,
):
dimension = 1024
else:
raise ValueError(f"Unsupported model type: {unet_config.base}")

View File

@@ -1,5 +1,6 @@
from contextlib import nullcontext
from functools import singledispatchmethod
from typing import Literal
import einops
import torch
@@ -20,7 +21,7 @@ from invokeai.app.invocations.fields import (
Input,
InputField,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.model import BaseModelType, VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
@@ -29,13 +30,21 @@ from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd15_sdxl
"""
SDXL VAE color compensation values determined experimentally to reduce color drift.
If more reliable values are found in the future (e.g. individual color channels), they can be updated.
SD1.5, TAESD, TAESDXL VAEs distort in less predictable ways, so no compensation is offered at this time.
"""
COMPENSATION_OPTIONS = Literal["None", "SDXL"]
COLOR_COMPENSATION_MAP = {"None": [1, 0], "SDXL": [1.015, -0.002]}
@invocation(
"i2l",
title="Image to Latents - SD1.5, SDXL",
tags=["latents", "image", "vae", "i2l"],
category="latents",
version="1.1.1",
version="1.2.0",
)
class ImageToLatentsInvocation(BaseInvocation):
"""Encodes an image into latents."""
@@ -52,6 +61,10 @@ class ImageToLatentsInvocation(BaseInvocation):
# offer a way to directly set None values.
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
color_compensation: COMPENSATION_OPTIONS = InputField(
default="None",
description="Apply VAE scaling compensation when encoding images (reduces color drift).",
)
@classmethod
def vae_encode(
@@ -62,7 +75,7 @@ class ImageToLatentsInvocation(BaseInvocation):
image_tensor: torch.Tensor,
tile_size: int = 0,
) -> torch.Tensor:
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny)), "VAE must be of type SD-1.5 or SDXL"
estimated_working_memory = estimate_vae_working_memory_sd15_sdxl(
operation="encode",
image_tensor=image_tensor,
@@ -71,7 +84,7 @@ class ImageToLatentsInvocation(BaseInvocation):
fp32=upcast,
)
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny)), "VAE must be of type SD-1.5 or SDXL"
orig_dtype = vae.dtype
if upcast:
vae.to(dtype=torch.float32)
@@ -127,9 +140,14 @@ class ImageToLatentsInvocation(BaseInvocation):
image = context.images.get_pil(self.image.image_name)
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny)), "VAE must be of type SD-1.5 or SDXL"
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if self.color_compensation != "None" and vae_info.config.base == BaseModelType.StableDiffusionXL:
scale, bias = COLOR_COMPENSATION_MAP[self.color_compensation]
image_tensor = image_tensor * scale + bias
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
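A minimal sketch of the compensation applied above, assuming image tensors normalized to [-1, 1] before VAE encoding; the magnitude printed at the end is illustrative arithmetic, not a measured drift value.
import torch

COLOR_COMPENSATION_MAP = {"None": [1, 0], "SDXL": [1.015, -0.002]}

image_tensor = torch.rand(1, 3, 64, 64) * 2 - 1          # toy image in [-1, 1]

scale, bias = COLOR_COMPENSATION_MAP["SDXL"]
compensated = image_tensor * scale + bias                 # slight pre-scaling before VAE.encode

# The adjustment is tiny: it nudges the encoder input to counteract the SDXL VAE's
# encode/decode color drift rather than visibly changing the image.
print((compensated - image_tensor).abs().max())           # on the order of ~0.017 for values near +/-1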

View File

@@ -2,12 +2,6 @@ from contextlib import nullcontext
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
@@ -77,26 +71,9 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
latents = latents.to(TorchDevice.choose_torch_device())
if self.fp32:
# FP32 mode: convert everything to float32 for maximum precision
vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
vae.post_quant_conv.to(latents.dtype)
vae.decoder.conv_in.to(latents.dtype)
vae.decoder.mid_block.to(latents.dtype)
else:
latents = latents.float()
latents = latents.float()
else:
vae.to(dtype=torch.float16)
latents = latents.half()

View File

@@ -150,6 +150,10 @@ GENERATION_MODES = Literal[
"flux_img2img",
"flux_inpaint",
"flux_outpaint",
"flux2_txt2img",
"flux2_img2img",
"flux2_inpaint",
"flux2_outpaint",
"sd3_txt2img",
"sd3_img2img",
"sd3_inpaint",
@@ -158,6 +162,10 @@ GENERATION_MODES = Literal[
"cogview4_img2img",
"cogview4_inpaint",
"cogview4_outpaint",
"z_image_txt2img",
"z_image_img2img",
"z_image_inpaint",
"z_image_outpaint",
]
@@ -166,7 +174,7 @@ GENERATION_MODES = Literal[
title="Core Metadata",
tags=["metadata"],
category="metadata",
version="2.0.0",
version="2.1.0",
classification=Classification.Internal,
)
class CoreMetadataInvocation(BaseInvocation):
@@ -217,6 +225,10 @@ class CoreMetadataInvocation(BaseInvocation):
default=None,
description="The VAE used for decoding, if the main model's default was not used",
)
qwen3_encoder: Optional[ModelIdentifierField] = InputField(
default=None,
description="The Qwen3 text encoder model used for Z-Image inference",
)
# High resolution fix metadata.
hrf_enabled: Optional[bool] = InputField(

View File

@@ -52,6 +52,7 @@ from invokeai.app.invocations.primitives import (
)
from invokeai.app.invocations.scheduler import SchedulerOutput
from invokeai.app.invocations.t2i_adapter import T2IAdapterField, T2IAdapterInvocation
from invokeai.app.invocations.z_image_denoise import ZImageDenoiseInvocation
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@@ -729,6 +730,52 @@ class FluxDenoiseLatentsMetaInvocation(FluxDenoiseInvocation, WithMetadata):
return LatentsMetaOutput(**params, metadata=MetadataField.model_validate(md))
@invocation(
"z_image_denoise_meta",
title=f"{ZImageDenoiseInvocation.UIConfig.title} + Metadata",
tags=["z-image", "latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
category="latents",
version="1.0.0",
)
class ZImageDenoiseMetaInvocation(ZImageDenoiseInvocation, WithMetadata):
"""Run denoising process with a Z-Image transformer model + metadata."""
def invoke(self, context: InvocationContext) -> LatentsMetaOutput:
def _loras_to_json(obj: Union[Any, list[Any]]):
if not isinstance(obj, list):
obj = [obj]
output: list[dict[str, Any]] = []
for item in obj:
output.append(
LoRAMetadataField(
model=item.lora,
weight=item.weight,
).model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
)
return output
obj = super().invoke(context)
md: Dict[str, Any] = {} if self.metadata is None else self.metadata.root
md.update({"width": obj.width})
md.update({"height": obj.height})
md.update({"steps": self.steps})
md.update({"guidance": self.guidance_scale})
md.update({"denoising_start": self.denoising_start})
md.update({"denoising_end": self.denoising_end})
md.update({"scheduler": self.scheduler})
md.update({"model": self.transformer.transformer})
md.update({"seed": self.seed})
if len(self.transformer.loras) > 0:
md.update({"loras": _loras_to_json(self.transformer.loras)})
params = obj.__dict__.copy()
del params["type"]
return LatentsMetaOutput(**params, metadata=MetadataField.model_validate(md))
@invocation(
"metadata_to_vae",
title="Metadata To VAE",

View File

@@ -72,6 +72,14 @@ class GlmEncoderField(BaseModel):
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
class Qwen3EncoderField(BaseModel):
"""Field for Qwen3 text encoder used by Z-Image models."""
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
loras: List[LoRAField] = Field(default_factory=list, description="LoRAs to apply on model loading")
class VAEField(BaseModel):
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
@@ -502,6 +510,7 @@ class VAELoaderInvocation(BaseInvocation):
BaseModelType.StableDiffusionXL,
BaseModelType.StableDiffusion3,
BaseModelType.Flux,
BaseModelType.Flux2,
],
ui_model_type=ModelType.VAE,
)

View File

@@ -0,0 +1,59 @@
import pathlib
from typing import Literal
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import ImageField, InputField, OutputField, WithBoard, WithMetadata
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.pbr_maps.architecture.pbr_rrdb_net import PBR_RRDB_Net
from invokeai.backend.image_util.pbr_maps.pbr_maps import NORMAL_MAP_MODEL, OTHER_MAP_MODEL, PBRMapsGenerator
from invokeai.backend.util.devices import TorchDevice
@invocation_output("pbr_maps-output")
class PBRMapsOutput(BaseInvocationOutput):
normal_map: ImageField = OutputField(default=None, description="The generated normal map")
roughness_map: ImageField = OutputField(default=None, description="The generated roughness map")
displacement_map: ImageField = OutputField(default=None, description="The generated displacement map")
@invocation("pbr_maps", title="PBR Maps", tags=["image", "material"], category="image", version="1.0.0")
class PBRMapsInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generate Normal, Displacement and Roughness Map from a given image"""
image: ImageField = InputField(description="Input image")
tile_size: int = InputField(default=512, description="Tile size")
border_mode: Literal["none", "seamless", "mirror", "replicate"] = InputField(
default="none", description="Border mode to apply to eliminate any artifacts or seams"
)
def invoke(self, context: InvocationContext) -> PBRMapsOutput:
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
def loader(model_path: pathlib.Path):
return PBRMapsGenerator.load_model(model_path, TorchDevice.choose_torch_device())
torch_device = TorchDevice.choose_torch_device()
with (
context.models.load_remote_model(NORMAL_MAP_MODEL, loader) as normal_map_model,
context.models.load_remote_model(OTHER_MAP_MODEL, loader) as other_map_model,
):
assert isinstance(normal_map_model, PBR_RRDB_Net)
assert isinstance(other_map_model, PBR_RRDB_Net)
pbr_pipeline = PBRMapsGenerator(normal_map_model, other_map_model, torch_device)
normal_map, roughness_map, displacement_map = pbr_pipeline.generate_maps(
image_pil, self.tile_size, self.border_mode
)
normal_map = context.images.save(normal_map)
normal_map_field = ImageField(image_name=normal_map.image_name)
roughness_map = context.images.save(roughness_map)
roughness_map_field = ImageField(image_name=roughness_map.image_name)
displacement_map = context.images.save(displacement_map)
displacement_map_field = ImageField(image_name=displacement_map.image_name)
return PBRMapsOutput(
normal_map=normal_map_field, roughness_map=roughness_map_field, displacement_map=displacement_map_field
)

View File

@@ -27,6 +27,7 @@ from invokeai.app.invocations.fields import (
SD3ConditioningField,
TensorField,
UIComponent,
ZImageConditioningField,
)
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -461,6 +462,17 @@ class CogView4ConditioningOutput(BaseInvocationOutput):
return cls(conditioning=CogView4ConditioningField(conditioning_name=conditioning_name))
@invocation_output("z_image_conditioning_output")
class ZImageConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a Z-Image text conditioning tensor."""
conditioning: ZImageConditioningField = OutputField(description=FieldDescriptions.cond)
@classmethod
def build(cls, conditioning_name: str) -> "ZImageConditioningOutput":
return cls(conditioning=ZImageConditioningField(conditioning_name=conditioning_name))
@invocation_output("conditioning_output")
class ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single conditioning tensor"""

View File

@@ -0,0 +1,57 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import InputField, OutputField, StylePresetField, UIComponent
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation_output("prompt_template_output")
class PromptTemplateOutput(BaseInvocationOutput):
"""Output for the Prompt Template node"""
positive_prompt: str = OutputField(description="The positive prompt with the template applied")
negative_prompt: str = OutputField(description="The negative prompt with the template applied")
@invocation(
"prompt_template",
title="Prompt Template",
tags=["prompt", "template", "style", "preset"],
category="prompt",
version="1.0.0",
)
class PromptTemplateInvocation(BaseInvocation):
"""Applies a Style Preset template to positive and negative prompts.
Select a Style Preset and provide positive/negative prompts. The node replaces
{prompt} placeholders in the template with your input prompts.
"""
style_preset: StylePresetField = InputField(
description="The Style Preset to use as a template",
)
positive_prompt: str = InputField(
default="",
description="The positive prompt to insert into the template's {prompt} placeholder",
ui_component=UIComponent.Textarea,
)
negative_prompt: str = InputField(
default="",
description="The negative prompt to insert into the template's {prompt} placeholder",
ui_component=UIComponent.Textarea,
)
def invoke(self, context: InvocationContext) -> PromptTemplateOutput:
# Fetch the style preset from the database
style_preset = context._services.style_preset_records.get(self.style_preset.style_preset_id)
# Get the template prompts
positive_template = style_preset.preset_data.positive_prompt
negative_template = style_preset.preset_data.negative_prompt
# Replace {prompt} placeholder with the input prompts
rendered_positive = positive_template.replace("{prompt}", self.positive_prompt)
rendered_negative = negative_template.replace("{prompt}", self.negative_prompt)
return PromptTemplateOutput(
positive_prompt=rendered_positive,
negative_prompt=rendered_negative,
)
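# Illustrative sketch (not part of the node above): the substitution invoke() performs, shown with a
# hypothetical preset whose templates are assumptions for demonstration only.
_example_positive_template = "cinematic photo of {prompt}, 35mm"  # hypothetical preset data
_example_negative_template = "blurry, {prompt}"                   # hypothetical preset data
# Replacing the {prompt} placeholder exactly as invoke() does:
_example_positive = _example_positive_template.replace("{prompt}", "a red barn at dusk")
_example_negative = _example_negative_template.replace("{prompt}", "low quality")
# _example_positive == "cinematic photo of a red barn at dusk, 35mm"
# _example_negative == "blurry, low quality"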

View File

@@ -0,0 +1,112 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Z-Image Control invocation for spatial conditioning."""
from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
InputField,
OutputField,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
class ZImageControlField(BaseModel):
"""A Z-Image control conditioning field for spatial control (Canny, HED, Depth, Pose, MLSD)."""
image_name: str = Field(description="The name of the preprocessed control image")
control_model: ModelIdentifierField = Field(description="The Z-Image ControlNet adapter model")
control_context_scale: float = Field(
default=0.75,
ge=0.0,
le=2.0,
description="The strength of the control signal. Recommended range: 0.65-0.80.",
)
begin_step_percent: float = Field(
default=0.0,
ge=0.0,
le=1.0,
description="When the control is first applied (% of total steps)",
)
end_step_percent: float = Field(
default=1.0,
ge=0.0,
le=1.0,
description="When the control is last applied (% of total steps)",
)
@invocation_output("z_image_control_output")
class ZImageControlOutput(BaseInvocationOutput):
"""Z-Image Control output containing control configuration."""
control: ZImageControlField = OutputField(description="Z-Image control conditioning")
@invocation(
"z_image_control",
title="Z-Image ControlNet",
tags=["image", "z-image", "control", "controlnet"],
category="control",
version="1.1.0",
classification=Classification.Prototype,
)
class ZImageControlInvocation(BaseInvocation):
"""Configure Z-Image ControlNet for spatial conditioning.
Takes a preprocessed control image (e.g., Canny edges, depth map, pose)
and a Z-Image ControlNet adapter model to enable spatial control.
Supports 5 control modes: Canny, HED, Depth, Pose, MLSD.
Recommended control_context_scale: 0.65-0.80.
"""
image: ImageField = InputField(
description="The preprocessed control image (Canny, HED, Depth, Pose, or MLSD)",
)
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model,
title="Control Model",
ui_model_base=BaseModelType.ZImage,
ui_model_type=ModelType.ControlNet,
)
control_context_scale: float = InputField(
default=0.75,
ge=0.0,
le=2.0,
description="Strength of the control signal. Recommended range: 0.65-0.80.",
title="Control Scale",
)
begin_step_percent: float = InputField(
default=0.0,
ge=0.0,
le=1.0,
description="When the control is first applied (% of total steps)",
)
end_step_percent: float = InputField(
default=1.0,
ge=0.0,
le=1.0,
description="When the control is last applied (% of total steps)",
)
def invoke(self, context: InvocationContext) -> ZImageControlOutput:
return ZImageControlOutput(
control=ZImageControlField(
image_name=self.image.image_name,
control_model=self.control_model,
control_context_scale=self.control_context_scale,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
)
)
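# Illustrative sketch (not part of the node above): constructing the control field that this node
# emits. The image name and model identifier values below are placeholders, not a real installed model.
_example_control = ZImageControlField(
    image_name="canny_edges.png",  # hypothetical preprocessed Canny control image
    control_model=ModelIdentifierField(
        key="placeholder-key",  # hypothetical identifiers for illustration only
        hash="placeholder-hash",
        name="z-image-canny",
        base=BaseModelType.ZImage,
        type=ModelType.ControlNet,
    ),
    control_context_scale=0.7,  # inside the recommended 0.65-0.80 range
)
# The Z-Image denoise node consumes a field like this via its `control` input.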

View File

@@ -0,0 +1,770 @@
import inspect
import math
from contextlib import ExitStack
from typing import Callable, Iterator, Optional, Tuple
import einops
import torch
import torchvision.transforms as tv_transforms
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from PIL import Image
from torchvision.transforms.functional import resize as tv_resize
from tqdm import tqdm
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
Input,
InputField,
LatentsField,
ZImageConditioningField,
)
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.invocations.z_image_control import ZImageControlField
from invokeai.app.invocations.z_image_image_to_latents import ZImageImageToLatentsInvocation
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.schedulers import ZIMAGE_SCHEDULER_LABELS, ZIMAGE_SCHEDULER_MAP, ZIMAGE_SCHEDULER_NAME_VALUES
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.z_image_lora_constants import Z_IMAGE_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ZImageConditioningInfo
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.z_image.extensions.regional_prompting_extension import ZImageRegionalPromptingExtension
from invokeai.backend.z_image.text_conditioning import ZImageTextConditioning
from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter
from invokeai.backend.z_image.z_image_controlnet_extension import (
ZImageControlNetExtension,
z_image_forward_with_control,
)
from invokeai.backend.z_image.z_image_transformer_patch import patch_transformer_for_regional_prompting
@invocation(
"z_image_denoise",
title="Denoise - Z-Image",
tags=["image", "z-image"],
category="image",
version="1.4.0",
classification=Classification.Prototype,
)
class ZImageDenoiseInvocation(BaseInvocation):
"""Run the denoising process with a Z-Image model.
Supports regional prompting by connecting multiple conditioning inputs with masks.
"""
# If latents is provided, this means we are doing image-to-image.
latents: Optional[LatentsField] = InputField(
default=None, description=FieldDescriptions.latents, input=Input.Connection
)
# denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None, description=FieldDescriptions.denoise_mask, input=Input.Connection
)
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
add_noise: bool = InputField(default=True, description="Add noise based on denoising start.")
transformer: TransformerField = InputField(
description=FieldDescriptions.z_image_model, input=Input.Connection, title="Transformer"
)
positive_conditioning: ZImageConditioningField | list[ZImageConditioningField] = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_conditioning: ZImageConditioningField | list[ZImageConditioningField] | None = InputField(
default=None, description=FieldDescriptions.negative_cond, input=Input.Connection
)
# Z-Image-Turbo works best without CFG (guidance_scale=1.0)
guidance_scale: float = InputField(
default=1.0,
ge=1.0,
description="Guidance scale for classifier-free guidance. 1.0 = no CFG (recommended for Z-Image-Turbo). "
"Values > 1.0 amplify guidance.",
title="Guidance Scale",
)
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
# Z-Image-Turbo uses 8 steps by default
steps: int = InputField(default=8, gt=0, description="Number of denoising steps. 8 recommended for Z-Image-Turbo.")
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
# Z-Image Control support
control: Optional[ZImageControlField] = InputField(
default=None,
description="Z-Image control conditioning for spatial control (Canny, HED, Depth, Pose, MLSD).",
input=Input.Connection,
)
# VAE for encoding control images (required when using control)
vae: Optional[VAEField] = InputField(
default=None,
description=FieldDescriptions.vae + " Required for control conditioning.",
input=Input.Connection,
)
# Scheduler selection for the denoising process
scheduler: ZIMAGE_SCHEDULER_NAME_VALUES = InputField(
default="euler",
description="Scheduler (sampler) for the denoising process. Euler is the default and recommended for "
"Z-Image-Turbo. Heun is 2nd-order (better quality, 2x slower). LCM is optimized for few steps.",
ui_choice_labels=ZIMAGE_SCHEDULER_LABELS,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = self._run_diffusion(context)
latents = latents.detach().to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
"""Prepare the inpaint mask."""
if self.denoise_mask is None:
return None
mask = context.tensors.load(self.denoise_mask.mask_name)
# Invert mask: 0.0 = regions to denoise, 1.0 = regions to preserve
mask = 1.0 - mask
_, _, latent_height, latent_width = latents.shape
mask = tv_resize(
img=mask,
size=[latent_height, latent_width],
interpolation=tv_transforms.InterpolationMode.BILINEAR,
antialias=False,
)
mask = mask.to(device=latents.device, dtype=latents.dtype)
return mask
def _load_text_conditioning(
self,
context: InvocationContext,
cond_field: ZImageConditioningField | list[ZImageConditioningField],
img_height: int,
img_width: int,
dtype: torch.dtype,
device: torch.device,
) -> list[ZImageTextConditioning]:
"""Load Z-Image text conditioning with optional regional masks.
Args:
context: The invocation context.
cond_field: Single conditioning field or list of fields.
img_height: Height of the image token grid (H // patch_size).
img_width: Width of the image token grid (W // patch_size).
dtype: Target dtype.
device: Target device.
Returns:
List of ZImageTextConditioning objects with embeddings and masks.
"""
# Normalize to a list
cond_list = [cond_field] if isinstance(cond_field, ZImageConditioningField) else cond_field
text_conditionings: list[ZImageTextConditioning] = []
for cond in cond_list:
# Load the text embeddings
cond_data = context.conditioning.load(cond.conditioning_name)
assert len(cond_data.conditionings) == 1
z_image_conditioning = cond_data.conditionings[0]
assert isinstance(z_image_conditioning, ZImageConditioningInfo)
z_image_conditioning = z_image_conditioning.to(dtype=dtype, device=device)
prompt_embeds = z_image_conditioning.prompt_embeds
# Load the mask, if provided
mask: torch.Tensor | None = None
if cond.mask is not None:
mask = context.tensors.load(cond.mask.tensor_name)
mask = mask.to(device=device)
mask = ZImageRegionalPromptingExtension.preprocess_regional_prompt_mask(
mask, img_height, img_width, dtype, device
)
text_conditionings.append(ZImageTextConditioning(prompt_embeds=prompt_embeds, mask=mask))
return text_conditionings
def _get_noise(
self,
batch_size: int,
num_channels_latents: int,
height: int,
width: int,
dtype: torch.dtype,
device: torch.device,
seed: int,
) -> torch.Tensor:
"""Generate initial noise tensor."""
# Generate noise as float32 on CPU for maximum compatibility,
# then cast to target dtype/device
rand_device = "cpu"
rand_dtype = torch.float32
return torch.randn(
batch_size,
num_channels_latents,
int(height) // LATENT_SCALE_FACTOR,
int(width) // LATENT_SCALE_FACTOR,
device=rand_device,
dtype=rand_dtype,
generator=torch.Generator(device=rand_device).manual_seed(seed),
).to(device=device, dtype=dtype)
def _calculate_shift(
self,
image_seq_len: int,
base_image_seq_len: int = 256,
max_image_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.15,
) -> float:
"""Calculate timestep shift based on image sequence length.
Based on diffusers ZImagePipeline.calculate_shift method.
"""
m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
b = base_shift - m * base_image_seq_len
mu = image_seq_len * m + b
return mu
def _get_sigmas(self, mu: float, num_steps: int) -> list[float]:
"""Generate sigma schedule with time shift.
Based on FlowMatchEulerDiscreteScheduler with shift.
Generates num_steps + 1 sigma values (including terminal 0.0).
"""
def time_shift(mu: float, sigma: float, t: float) -> float:
"""Apply time shift to a single timestep value."""
if t <= 0:
return 0.0
if t >= 1:
return 1.0
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
# Generate linearly spaced values from 1.0 down to 0.0 (inclusive), then apply the time shift;
# time_shift clamps t <= 0 and t >= 1, so the terminal sigmas are exactly 1.0 and 0.0.
sigmas = []
for i in range(num_steps + 1):
t = 1.0 - i / num_steps # Goes from 1.0 to 0.0
sigma = time_shift(mu, 1.0, t)
sigmas.append(sigma)
return sigmas
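# Worked example (illustrative): at 1024x1024 the latent grid is 128x128 and the token grid is
# 64x64, so img_seq_len = 4096. With the defaults above, m = (1.15 - 0.5) / (4096 - 256) and
# b = 0.5 - m * 256, giving mu = 4096 * m + b = 1.15 (the max_shift). Plugging that mu into
# time_shift at t = 0.5 yields exp(1.15) / (exp(1.15) + 1) ~= 0.76, i.e. the shifted schedule keeps
# the midpoint of the trajectory at a noticeably higher noise level than the unshifted value of 0.5.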
def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
device = TorchDevice.choose_torch_device()
inference_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
transformer_info = context.models.load(self.transformer.transformer)
# Calculate image token grid dimensions
patch_size = 2 # Z-Image uses patch_size=2
latent_height = self.height // LATENT_SCALE_FACTOR
latent_width = self.width // LATENT_SCALE_FACTOR
img_token_height = latent_height // patch_size
img_token_width = latent_width // patch_size
img_seq_len = img_token_height * img_token_width
# Load positive conditioning with regional masks
pos_text_conditionings = self._load_text_conditioning(
context=context,
cond_field=self.positive_conditioning,
img_height=img_token_height,
img_width=img_token_width,
dtype=inference_dtype,
device=device,
)
# Create regional prompting extension
regional_extension = ZImageRegionalPromptingExtension.from_text_conditionings(
text_conditionings=pos_text_conditionings,
img_seq_len=img_seq_len,
)
# Get the concatenated prompt embeddings for the transformer
pos_prompt_embeds = regional_extension.regional_text_conditioning.prompt_embeds
# Load negative conditioning if provided and guidance_scale != 1.0
# CFG formula: pred = pred_uncond + cfg_scale * (pred_cond - pred_uncond)
# At cfg_scale=1.0: pred = pred_cond (no effect, skip uncond computation)
# This matches FLUX's convention where 1.0 means "no CFG"
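# Illustrative arithmetic: with guidance_scale = 3.0, pred = pred_uncond + 3.0 * (pred_cond - pred_uncond),
# i.e. the conditional direction is amplified 3x relative to the unconditional prediction; at exactly 1.0
# the two uncond terms cancel and only pred_cond remains, so the uncond forward pass is skipped entirely.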
neg_prompt_embeds: torch.Tensor | None = None
do_classifier_free_guidance = (
not math.isclose(self.guidance_scale, 1.0) and self.negative_conditioning is not None
)
if do_classifier_free_guidance:
assert self.negative_conditioning is not None
# Load all negative conditionings and concatenate embeddings
# Note: We ignore masks for negative conditioning as regional negative prompting is not fully supported
neg_text_conditionings = self._load_text_conditioning(
context=context,
cond_field=self.negative_conditioning,
img_height=img_token_height,
img_width=img_token_width,
dtype=inference_dtype,
device=device,
)
# Concatenate all negative embeddings
neg_prompt_embeds = torch.cat([tc.prompt_embeds for tc in neg_text_conditionings], dim=0)
# Calculate shift based on image sequence length
mu = self._calculate_shift(img_seq_len)
# Generate sigma schedule with time shift
sigmas = self._get_sigmas(mu, self.steps)
# Apply denoising_start and denoising_end clipping
if self.denoising_start > 0 or self.denoising_end < 1:
# Calculate start and end indices based on denoising range
total_sigmas = len(sigmas)
start_idx = int(self.denoising_start * (total_sigmas - 1))
end_idx = int(self.denoising_end * (total_sigmas - 1)) + 1
sigmas = sigmas[start_idx:end_idx]
total_steps = len(sigmas) - 1
# Load input latents if provided (image-to-image)
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
if init_latents is not None:
init_latents = init_latents.to(device=device, dtype=inference_dtype)
# Generate initial noise
num_channels_latents = 16 # Z-Image uses 16 latent channels
noise = self._get_noise(
batch_size=1,
num_channels_latents=num_channels_latents,
height=self.height,
width=self.width,
dtype=inference_dtype,
device=device,
seed=self.seed,
)
# Prepare input latent image
if init_latents is not None:
if self.add_noise:
# Noise the init_latents by the appropriate amount for the first timestep.
s_0 = sigmas[0]
latents = s_0 * noise + (1.0 - s_0) * init_latents
else:
latents = init_latents
else:
if self.denoising_start > 1e-5:
raise ValueError("denoising_start should be 0 when initial latents are not provided.")
latents = noise
# Short-circuit if no denoising steps
if total_steps <= 0:
return latents
# Prepare inpaint extension
inpaint_mask = self._prep_inpaint_mask(context, latents)
inpaint_extension: RectifiedFlowInpaintExtension | None = None
if inpaint_mask is not None:
if init_latents is None:
raise ValueError("Initial latents are required when using an inpaint mask (image-to-image inpainting)")
inpaint_extension = RectifiedFlowInpaintExtension(
init_latents=init_latents,
inpaint_mask=inpaint_mask,
noise=noise,
)
step_callback = self._build_step_callback(context)
# Initialize the diffusers scheduler if not using built-in Euler
scheduler: SchedulerMixin | None = None
use_scheduler = self.scheduler != "euler"
if use_scheduler:
scheduler_class = ZIMAGE_SCHEDULER_MAP[self.scheduler]
scheduler = scheduler_class(
num_train_timesteps=1000,
shift=1.0,
)
# Set timesteps - LCM should use num_inference_steps (it has its own sigma schedule),
# while other schedulers can use custom sigmas if supported
is_lcm = self.scheduler == "lcm"
set_timesteps_sig = inspect.signature(scheduler.set_timesteps)
if not is_lcm and "sigmas" in set_timesteps_sig.parameters:
# Convert sigmas list to tensor for scheduler
scheduler.set_timesteps(sigmas=sigmas, device=device)
else:
# LCM or scheduler doesn't support custom sigmas - use num_inference_steps
scheduler.set_timesteps(num_inference_steps=total_steps, device=device)
# For Heun scheduler, the number of actual steps may differ
num_scheduler_steps = len(scheduler.timesteps)
else:
num_scheduler_steps = total_steps
with ExitStack() as exit_stack:
# Get transformer config to determine if it's quantized
transformer_config = context.models.get_config(self.transformer.transformer)
# Determine if the model is quantized.
# If the model is quantized, then we need to apply the LoRA weights as sidecar layers. This results in
# slower inference than direct patching, but is agnostic to the quantization format.
if transformer_config.format in [ModelFormat.Diffusers, ModelFormat.Checkpoint]:
model_is_quantized = False
elif transformer_config.format in [ModelFormat.GGUFQuantized]:
model_is_quantized = True
else:
raise ValueError(f"Unsupported Z-Image model format: {transformer_config.format}")
# Load transformer - always use base transformer, control is handled via extension
(cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
# Prepare control extension if control is provided
control_extension: ZImageControlNetExtension | None = None
if self.control is not None:
# Load control adapter using context manager (proper GPU memory management)
control_model_info = context.models.load(self.control.control_model)
(_, control_adapter) = exit_stack.enter_context(control_model_info.model_on_device())
assert isinstance(control_adapter, ZImageControlAdapter)
# Get control_in_dim from adapter config (16 for V1, 33 for V2.0)
adapter_config = control_adapter.config
control_in_dim = adapter_config.get("control_in_dim", 16)
num_control_blocks = adapter_config.get("num_control_blocks", 6)
# Log control configuration for debugging
version = "V2.0" if control_in_dim > 16 else "V1"
context.util.signal_progress(
f"Using Z-Image ControlNet {version} (Extension): control_in_dim={control_in_dim}, "
f"num_blocks={num_control_blocks}, scale={self.control.control_context_scale}"
)
# Load and prepare control image - must be VAE-encoded!
if self.vae is None:
raise ValueError("VAE is required when using Z-Image Control. Connect a VAE to the 'vae' input.")
control_image = context.images.get_pil(self.control.image_name)
# Resize control image to match output dimensions
control_image = control_image.convert("RGB")
control_image = control_image.resize((self.width, self.height), Image.Resampling.LANCZOS)
# Convert to tensor format for VAE encoding
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
control_image_tensor = image_resized_to_grid_as_tensor(control_image)
if control_image_tensor.dim() == 3:
control_image_tensor = einops.rearrange(control_image_tensor, "c h w -> 1 c h w")
# Encode control image through VAE to get latents
vae_info = context.models.load(self.vae.vae)
control_latents = ZImageImageToLatentsInvocation.vae_encode(
vae_info=vae_info,
image_tensor=control_image_tensor,
)
# Move to inference device/dtype
control_latents = control_latents.to(device=device, dtype=inference_dtype)
# Add frame dimension: [B, C, H, W] -> [C, 1, H, W] (single image)
control_latents = control_latents.squeeze(0).unsqueeze(1)
# Prepare control_cond based on control_in_dim
# V1: 16 channels (just control latents)
# V2.0: 33 channels = 16 control + 16 reference + 1 mask
# - Channels 0-15: control image latents (from VAE encoding)
# - Channels 16-31: reference/inpaint image latents (zeros for pure control)
# - Channel 32: inpaint mask (1.0 = don't inpaint, 0.0 = inpaint region)
# For pure control (no inpainting), we set mask=1 to tell model "use control, don't inpaint"
c, f, h, w = control_latents.shape
if c < control_in_dim:
padding_channels = control_in_dim - c
if padding_channels == 17:
# V2.0: 16 reference channels (zeros) + 1 mask channel (ones)
ref_padding = torch.zeros(
(16, f, h, w),
device=device,
dtype=inference_dtype,
)
# Mask channel = 1.0 means "don't inpaint this region, use control signal"
mask_channel = torch.ones(
(1, f, h, w),
device=device,
dtype=inference_dtype,
)
control_latents = torch.cat([control_latents, ref_padding, mask_channel], dim=0)
else:
# Generic padding with zeros for other cases
zero_padding = torch.zeros(
(padding_channels, f, h, w),
device=device,
dtype=inference_dtype,
)
control_latents = torch.cat([control_latents, zero_padding], dim=0)
# Create control extension (adapter is already on device from model_on_device)
control_extension = ZImageControlNetExtension(
control_adapter=control_adapter,
control_cond=control_latents,
weight=self.control.control_context_scale,
begin_step_percent=self.control.begin_step_percent,
end_step_percent=self.control.end_step_percent,
)
# Apply LoRA models to the transformer.
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
exit_stack.enter_context(
LayerPatcher.apply_smart_model_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=Z_IMAGE_LORA_TRANSFORMER_PREFIX,
dtype=inference_dtype,
cached_weights=cached_weights,
force_sidecar_patching=model_is_quantized,
)
)
# Apply regional prompting patch if we have regional masks
exit_stack.enter_context(
patch_transformer_for_regional_prompting(
transformer=transformer,
regional_attn_mask=regional_extension.regional_attn_mask,
img_seq_len=img_seq_len,
)
)
# Denoising loop - supports both built-in Euler and diffusers schedulers
# Track user-facing step for progress (accounts for Heun's double steps)
user_step = 0
if use_scheduler and scheduler is not None:
# Use diffusers scheduler for stepping
# Use tqdm with total_steps (user-facing steps) not num_scheduler_steps (internal steps)
# This ensures progress bar shows 1/8, 2/8, etc. even when scheduler uses more internal steps
pbar = tqdm(total=total_steps, desc="Denoising")
for step_index in range(num_scheduler_steps):
sched_timestep = scheduler.timesteps[step_index]
# Convert scheduler timestep (0-1000) to normalized sigma (0-1)
sigma_curr = sched_timestep.item() / scheduler.config.num_train_timesteps
# For Heun scheduler, track if we're in first or second order step
is_heun = hasattr(scheduler, "state_in_first_order")
in_first_order = scheduler.state_in_first_order if is_heun else True
# Timestep tensor for Z-Image model
# The model expects t=0 at start (noise) and t=1 at end (clean)
model_t = 1.0 - sigma_curr
timestep = torch.tensor([model_t], device=device, dtype=inference_dtype).expand(latents.shape[0])
# Run transformer for positive prediction
latent_model_input = latents.to(transformer.dtype)
latent_model_input = latent_model_input.unsqueeze(2) # Add frame dimension
latent_model_input_list = list(latent_model_input.unbind(dim=0))
# Determine if control should be applied at this step
apply_control = control_extension is not None and control_extension.should_apply(
user_step, total_steps
)
# Run forward pass
if apply_control:
model_out_list, _ = z_image_forward_with_control(
transformer=transformer,
x=latent_model_input_list,
t=timestep,
cap_feats=[pos_prompt_embeds],
control_extension=control_extension,
)
else:
model_output = transformer(
x=latent_model_input_list,
t=timestep,
cap_feats=[pos_prompt_embeds],
)
model_out_list = model_output[0]
noise_pred_cond = torch.stack([t.float() for t in model_out_list], dim=0)
noise_pred_cond = noise_pred_cond.squeeze(2)
noise_pred_cond = -noise_pred_cond # Z-Image uses v-prediction with negation
# Apply CFG if enabled
if do_classifier_free_guidance and neg_prompt_embeds is not None:
if apply_control:
model_out_list_uncond, _ = z_image_forward_with_control(
transformer=transformer,
x=latent_model_input_list,
t=timestep,
cap_feats=[neg_prompt_embeds],
control_extension=control_extension,
)
else:
model_output_uncond = transformer(
x=latent_model_input_list,
t=timestep,
cap_feats=[neg_prompt_embeds],
)
model_out_list_uncond = model_output_uncond[0]
noise_pred_uncond = torch.stack([t.float() for t in model_out_list_uncond], dim=0)
noise_pred_uncond = noise_pred_uncond.squeeze(2)
noise_pred_uncond = -noise_pred_uncond
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
else:
noise_pred = noise_pred_cond
# Use scheduler.step() for the update
step_output = scheduler.step(model_output=noise_pred, timestep=sched_timestep, sample=latents)
latents = step_output.prev_sample
# Get sigma_prev for inpainting (next sigma value)
if step_index + 1 < len(scheduler.sigmas):
sigma_prev = scheduler.sigmas[step_index + 1].item()
else:
sigma_prev = 0.0
if inpaint_extension is not None:
latents = inpaint_extension.merge_intermediate_latents_with_init_latents(latents, sigma_prev)
# For Heun, only increment user step after second-order step completes
if is_heun:
if not in_first_order:
user_step += 1
# Only call step_callback if we haven't exceeded total_steps
if user_step <= total_steps:
pbar.update(1)
step_callback(
PipelineIntermediateState(
step=user_step,
order=2,
total_steps=total_steps,
timestep=int(sigma_curr * 1000),
latents=latents,
),
)
else:
# For LCM and other first-order schedulers
user_step += 1
# Only call step_callback if we haven't exceeded total_steps
# (LCM scheduler may have more internal steps than user-facing steps)
if user_step <= total_steps:
pbar.update(1)
step_callback(
PipelineIntermediateState(
step=user_step,
order=1,
total_steps=total_steps,
timestep=int(sigma_curr * 1000),
latents=latents,
),
)
pbar.close()
else:
# Original Euler implementation (default, optimized for Z-Image)
for step_idx in tqdm(range(total_steps)):
sigma_curr = sigmas[step_idx]
sigma_prev = sigmas[step_idx + 1]
# Timestep tensor for Z-Image model
# The model expects t=0 at start (noise) and t=1 at end (clean)
# Sigma goes from 1 (noise) to 0 (clean), so model_t = 1 - sigma
model_t = 1.0 - sigma_curr
timestep = torch.tensor([model_t], device=device, dtype=inference_dtype).expand(latents.shape[0])
# Run transformer for positive prediction
# Z-Image transformer expects: x as list of [C, 1, H, W] tensors, t, cap_feats as list
# Prepare latent input: [B, C, H, W] -> [B, C, 1, H, W] -> list of [C, 1, H, W]
latent_model_input = latents.to(transformer.dtype)
latent_model_input = latent_model_input.unsqueeze(2) # Add frame dimension
latent_model_input_list = list(latent_model_input.unbind(dim=0))
# Determine if control should be applied at this step
apply_control = control_extension is not None and control_extension.should_apply(
step_idx, total_steps
)
# Run forward pass - use custom forward with control if extension is active
if apply_control:
model_out_list, _ = z_image_forward_with_control(
transformer=transformer,
x=latent_model_input_list,
t=timestep,
cap_feats=[pos_prompt_embeds],
control_extension=control_extension,
)
else:
model_output = transformer(
x=latent_model_input_list,
t=timestep,
cap_feats=[pos_prompt_embeds],
)
model_out_list = model_output[0] # Extract list of tensors from tuple
noise_pred_cond = torch.stack([t.float() for t in model_out_list], dim=0)
noise_pred_cond = noise_pred_cond.squeeze(2) # Remove frame dimension
noise_pred_cond = -noise_pred_cond # Z-Image uses v-prediction with negation
# Apply CFG if enabled
if do_classifier_free_guidance and neg_prompt_embeds is not None:
if apply_control:
model_out_list_uncond, _ = z_image_forward_with_control(
transformer=transformer,
x=latent_model_input_list,
t=timestep,
cap_feats=[neg_prompt_embeds],
control_extension=control_extension,
)
else:
model_output_uncond = transformer(
x=latent_model_input_list,
t=timestep,
cap_feats=[neg_prompt_embeds],
)
model_out_list_uncond = model_output_uncond[0] # Extract list of tensors from tuple
noise_pred_uncond = torch.stack([t.float() for t in model_out_list_uncond], dim=0)
noise_pred_uncond = noise_pred_uncond.squeeze(2)
noise_pred_uncond = -noise_pred_uncond
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
else:
noise_pred = noise_pred_cond
# Euler step
latents_dtype = latents.dtype
latents = latents.to(dtype=torch.float32)
latents = latents + (sigma_prev - sigma_curr) * noise_pred
latents = latents.to(dtype=latents_dtype)
if inpaint_extension is not None:
latents = inpaint_extension.merge_intermediate_latents_with_init_latents(latents, sigma_prev)
step_callback(
PipelineIntermediateState(
step=step_idx + 1,
order=1,
total_steps=total_steps,
timestep=int(sigma_curr * 1000),
latents=latents,
),
)
return latents
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, BaseModelType.ZImage)
return step_callback
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
"""Iterate over LoRA models to apply to the transformer."""
for lora in self.transformer.loras:
lora_info = context.models.load(lora.lora)
if not isinstance(lora_info.model, ModelPatchRaw):
raise TypeError(
f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}. "
"The LoRA model may be corrupted or incompatible."
)
yield (lora_info.model, lora.weight)
del lora_info

View File

@@ -0,0 +1,110 @@
from typing import Union
import einops
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
Input,
InputField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder as FluxAutoEncoder
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux
# Z-Image can use either the Diffusers AutoencoderKL or the FLUX AutoEncoder
ZImageVAE = Union[AutoencoderKL, FluxAutoEncoder]
@invocation(
"z_image_i2l",
title="Image to Latents - Z-Image",
tags=["image", "latents", "vae", "i2l", "z-image"],
category="image",
version="1.1.0",
classification=Classification.Prototype,
)
class ZImageImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates latents from an image using Z-Image VAE (supports both Diffusers and FLUX VAE)."""
image: ImageField = InputField(description="The image to encode.")
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
@staticmethod
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
if not isinstance(vae_info.model, (AutoencoderKL, FluxAutoEncoder)):
raise TypeError(
f"Expected AutoencoderKL or FluxAutoEncoder for Z-Image VAE, got {type(vae_info.model).__name__}. "
"Ensure you are using a compatible VAE model."
)
# Estimate working memory needed for VAE encode
estimated_working_memory = estimate_vae_working_memory_flux(
operation="encode",
image_tensor=image_tensor,
vae=vae_info.model,
)
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
if not isinstance(vae, (AutoencoderKL, FluxAutoEncoder)):
raise TypeError(
f"Expected AutoencoderKL or FluxAutoEncoder, got {type(vae).__name__}. "
"VAE model type changed unexpectedly after loading."
)
vae_dtype = next(iter(vae.parameters())).dtype
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
with torch.inference_mode():
if isinstance(vae, FluxAutoEncoder):
# FLUX VAE handles scaling internally
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
latents = vae.encode(image_tensor, sample=True, generator=generator)
else:
# AutoencoderKL - needs manual scaling
vae.disable_tiling()
image_tensor_dist = vae.encode(image_tensor).latent_dist
latents: torch.Tensor = image_tensor_dist.sample().to(dtype=vae.dtype)
# Apply scaling_factor and shift_factor from VAE config
# Z-Image uses: latents = (latents - shift_factor) * scaling_factor
scaling_factor = vae.config.scaling_factor
shift_factor = getattr(vae.config, "shift_factor", None)
if shift_factor is not None:
latents = latents - shift_factor
latents = latents * scaling_factor
return latents
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
vae_info = context.models.load(self.vae.vae)
if not isinstance(vae_info.model, (AutoencoderKL, FluxAutoEncoder)):
raise TypeError(
f"Expected AutoencoderKL or FluxAutoEncoder for Z-Image VAE, got {type(vae_info.model).__name__}. "
"Ensure you are using a compatible VAE model."
)
context.util.signal_progress("Running VAE")
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
latents = latents.to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

View File

@@ -0,0 +1,111 @@
from contextlib import nullcontext
from typing import Union
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder as FluxAutoEncoder
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux
# Z-Image can use either the Diffusers AutoencoderKL or the FLUX AutoEncoder
ZImageVAE = Union[AutoencoderKL, FluxAutoEncoder]
@invocation(
"z_image_l2i",
title="Latents to Image - Z-Image",
tags=["latents", "image", "vae", "l2i", "z-image"],
category="latents",
version="1.1.0",
classification=Classification.Prototype,
)
class ZImageLatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents using Z-Image VAE (supports both Diffusers and FLUX VAE)."""
latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
if not isinstance(vae_info.model, (AutoencoderKL, FluxAutoEncoder)):
raise TypeError(
f"Expected AutoencoderKL or FluxAutoEncoder for Z-Image VAE, got {type(vae_info.model).__name__}. "
"Ensure you are using a compatible VAE model."
)
is_flux_vae = isinstance(vae_info.model, FluxAutoEncoder)
# Estimate working memory needed for VAE decode
estimated_working_memory = estimate_vae_working_memory_flux(
operation="decode",
image_tensor=latents,
vae=vae_info.model,
)
# FLUX VAE doesn't support seamless, so only apply for AutoencoderKL
seamless_context = (
nullcontext() if is_flux_vae else SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes)
)
with seamless_context, vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
context.util.signal_progress("Running VAE")
if not isinstance(vae, (AutoencoderKL, FluxAutoEncoder)):
raise TypeError(
f"Expected AutoencoderKL or FluxAutoEncoder, got {type(vae).__name__}. "
"VAE model type changed unexpectedly after loading."
)
vae_dtype = next(iter(vae.parameters())).dtype
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
# Disable tiling for AutoencoderKL
if isinstance(vae, AutoencoderKL):
vae.disable_tiling()
# Clear memory as VAE decode can request a lot
TorchDevice.empty_cache()
with torch.inference_mode():
if isinstance(vae, FluxAutoEncoder):
# FLUX VAE handles scaling internally
img = vae.decode(latents)
else:
# AutoencoderKL - Apply scaling_factor and shift_factor from VAE config
# Z-Image uses: latents = latents / scaling_factor + shift_factor
scaling_factor = vae.config.scaling_factor
shift_factor = getattr(vae.config, "shift_factor", None)
latents = latents / scaling_factor
if shift_factor is not None:
latents = latents + shift_factor
img = vae.decode(latents, return_dict=False)[0]
img = img.clamp(-1, 1)
img = rearrange(img[0], "c h w -> h w c")
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
TorchDevice.empty_cache()
image_dto = context.images.save(image=img_pil)
return ImageOutput.build(image_dto)
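# Illustrative sketch (not part of the node above): the decode-side scaling here is the inverse of the
# encode-side scaling in the Z-Image Image to Latents node. The factor values are hypothetical; the
# real ones come from the loaded VAE's config.
_example_scaling_factor, _example_shift_factor = 0.3611, 0.1159  # hypothetical config values
_x = torch.randn(1, 16, 64, 64)
_encoded = (_x - _example_shift_factor) * _example_scaling_factor      # encode: (latents - shift) * scale
_decoded = _encoded / _example_scaling_factor + _example_shift_factor  # decode: latents / scale + shift
assert torch.allclose(_decoded, _x, atol=1e-6)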

View File

@@ -0,0 +1,153 @@
from typing import Optional
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, Qwen3EncoderField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
@invocation_output("z_image_lora_loader_output")
class ZImageLoRALoaderOutput(BaseInvocationOutput):
"""Z-Image LoRA Loader Output"""
transformer: Optional[TransformerField] = OutputField(
default=None, description=FieldDescriptions.transformer, title="Z-Image Transformer"
)
qwen3_encoder: Optional[Qwen3EncoderField] = OutputField(
default=None, description=FieldDescriptions.qwen3_encoder, title="Qwen3 Encoder"
)
@invocation(
"z_image_lora_loader",
title="Apply LoRA - Z-Image",
tags=["lora", "model", "z-image"],
category="model",
version="1.0.0",
)
class ZImageLoRALoaderInvocation(BaseInvocation):
"""Apply a LoRA model to a Z-Image transformer and/or Qwen3 text encoder."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model,
title="LoRA",
ui_model_base=BaseModelType.ZImage,
ui_model_type=ModelType.LoRA,
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
transformer: TransformerField | None = InputField(
default=None,
description=FieldDescriptions.transformer,
input=Input.Connection,
title="Z-Image Transformer",
)
qwen3_encoder: Qwen3EncoderField | None = InputField(
default=None,
title="Qwen3 Encoder",
description=FieldDescriptions.qwen3_encoder,
input=Input.Connection,
)
def invoke(self, context: InvocationContext) -> ZImageLoRALoaderOutput:
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise ValueError(f"Unknown lora: {lora_key}!")
# Check for existing LoRAs with the same key.
if self.transformer and any(lora.lora.key == lora_key for lora in self.transformer.loras):
raise ValueError(f'LoRA "{lora_key}" already applied to transformer.')
if self.qwen3_encoder and any(lora.lora.key == lora_key for lora in self.qwen3_encoder.loras):
raise ValueError(f'LoRA "{lora_key}" already applied to Qwen3 encoder.')
output = ZImageLoRALoaderOutput()
# Attach LoRA layers to the models.
if self.transformer is not None:
output.transformer = self.transformer.model_copy(deep=True)
output.transformer.loras.append(
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
if self.qwen3_encoder is not None:
output.qwen3_encoder = self.qwen3_encoder.model_copy(deep=True)
output.qwen3_encoder.loras.append(
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
return output
@invocation(
"z_image_lora_collection_loader",
title="Apply LoRA Collection - Z-Image",
tags=["lora", "model", "z-image"],
category="model",
version="1.0.0",
)
class ZImageLoRACollectionLoader(BaseInvocation):
"""Applies a collection of LoRAs to a Z-Image transformer."""
loras: Optional[LoRAField | list[LoRAField]] = InputField(
default=None, description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
)
transformer: Optional[TransformerField] = InputField(
default=None,
description=FieldDescriptions.transformer,
input=Input.Connection,
title="Transformer",
)
qwen3_encoder: Qwen3EncoderField | None = InputField(
default=None,
title="Qwen3 Encoder",
description=FieldDescriptions.qwen3_encoder,
input=Input.Connection,
)
def invoke(self, context: InvocationContext) -> ZImageLoRALoaderOutput:
output = ZImageLoRALoaderOutput()
loras = self.loras if isinstance(self.loras, list) else [self.loras]
added_loras: list[str] = []
if self.transformer is not None:
output.transformer = self.transformer.model_copy(deep=True)
if self.qwen3_encoder is not None:
output.qwen3_encoder = self.qwen3_encoder.model_copy(deep=True)
for lora in loras:
if lora is None:
continue
if lora.lora.key in added_loras:
continue
if not context.models.exists(lora.lora.key):
raise Exception(f"Unknown lora: {lora.lora.key}!")
if lora.lora.base is not BaseModelType.ZImage:
raise ValueError(
f"LoRA '{lora.lora.key}' is for {lora.lora.base.value if lora.lora.base else 'unknown'} models, "
"not Z-Image models. Ensure you are using a Z-Image compatible LoRA."
)
added_loras.append(lora.lora.key)
if self.transformer is not None and output.transformer is not None:
output.transformer.loras.append(lora)
if self.qwen3_encoder is not None and output.qwen3_encoder is not None:
output.qwen3_encoder.loras.append(lora)
return output

View File

@@ -0,0 +1,135 @@
from typing import Optional
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.model import (
ModelIdentifierField,
Qwen3EncoderField,
TransformerField,
VAEField,
)
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType, SubModelType
@invocation_output("z_image_model_loader_output")
class ZImageModelLoaderOutput(BaseInvocationOutput):
"""Z-Image base model loader output."""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
qwen3_encoder: Qwen3EncoderField = OutputField(description=FieldDescriptions.qwen3_encoder, title="Qwen3 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation(
"z_image_model_loader",
title="Main Model - Z-Image",
tags=["model", "z-image"],
category="model",
version="3.0.0",
classification=Classification.Prototype,
)
class ZImageModelLoaderInvocation(BaseInvocation):
"""Loads a Z-Image model, outputting its submodels.
Similar to FLUX, you can mix and match components:
- Transformer: From Z-Image main model (GGUF quantized or Diffusers format)
- VAE: Separate FLUX VAE (shared with FLUX models) or from a Diffusers Z-Image model
- Qwen3 Encoder: Separate Qwen3Encoder model or from a Diffusers Z-Image model
"""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.z_image_model,
input=Input.Direct,
ui_model_base=BaseModelType.ZImage,
ui_model_type=ModelType.Main,
title="Transformer",
)
vae_model: Optional[ModelIdentifierField] = InputField(
default=None,
description="Standalone VAE model. Z-Image uses the same VAE as FLUX (16-channel). "
"If not provided, VAE will be loaded from the Qwen3 Source model.",
input=Input.Direct,
ui_model_base=BaseModelType.Flux,
ui_model_type=ModelType.VAE,
title="VAE",
)
qwen3_encoder_model: Optional[ModelIdentifierField] = InputField(
default=None,
description="Standalone Qwen3 Encoder model. "
"If not provided, encoder will be loaded from the Qwen3 Source model.",
input=Input.Direct,
ui_model_type=ModelType.Qwen3Encoder,
title="Qwen3 Encoder",
)
qwen3_source_model: Optional[ModelIdentifierField] = InputField(
default=None,
description="Diffusers Z-Image model to extract VAE and/or Qwen3 encoder from. "
"Use this if you don't have separate VAE/Qwen3 models. "
"Ignored if both VAE and Qwen3 Encoder are provided separately.",
input=Input.Direct,
ui_model_base=BaseModelType.ZImage,
ui_model_type=ModelType.Main,
ui_model_format=ModelFormat.Diffusers,
title="Qwen3 Source (Diffusers)",
)
def invoke(self, context: InvocationContext) -> ZImageModelLoaderOutput:
# Transformer always comes from the main model
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
# Determine VAE source
if self.vae_model is not None:
# Use standalone FLUX VAE
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
elif self.qwen3_source_model is not None:
# Extract from Diffusers Z-Image model
self._validate_diffusers_format(context, self.qwen3_source_model, "Qwen3 Source")
vae = self.qwen3_source_model.model_copy(update={"submodel_type": SubModelType.VAE})
else:
raise ValueError(
"No VAE source provided. Either set 'VAE' to a FLUX VAE model, "
"or set 'Qwen3 Source' to a Diffusers Z-Image model."
)
# Determine Qwen3 Encoder source
if self.qwen3_encoder_model is not None:
# Use standalone Qwen3 Encoder
qwen3_tokenizer = self.qwen3_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
qwen3_encoder = self.qwen3_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
elif self.qwen3_source_model is not None:
# Extract from Diffusers Z-Image model
self._validate_diffusers_format(context, self.qwen3_source_model, "Qwen3 Source")
qwen3_tokenizer = self.qwen3_source_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
qwen3_encoder = self.qwen3_source_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
else:
raise ValueError(
"No Qwen3 Encoder source provided. Either set 'Qwen3 Encoder' to a standalone model, "
"or set 'Qwen3 Source' to a Diffusers Z-Image model."
)
return ZImageModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
qwen3_encoder=Qwen3EncoderField(tokenizer=qwen3_tokenizer, text_encoder=qwen3_encoder),
vae=VAEField(vae=vae),
)
def _validate_diffusers_format(
self, context: InvocationContext, model: ModelIdentifierField, model_name: str
) -> None:
"""Validate that a model is in Diffusers format."""
config = context.models.get_config(model)
if config.format != ModelFormat.Diffusers:
raise ValueError(
f"The {model_name} model must be a Diffusers format Z-Image model. "
f"The selected model '{config.name}' is in {config.format.value} format."
)

View File

@@ -0,0 +1,110 @@
import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
ZImageConditioningField,
)
from invokeai.app.invocations.primitives import ZImageConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ConditioningFieldData,
ZImageConditioningInfo,
)
@invocation(
"z_image_seed_variance_enhancer",
title="Seed Variance Enhancer - Z-Image",
tags=["conditioning", "z-image", "variance", "seed"],
category="conditioning",
version="1.0.0",
classification=Classification.Prototype,
)
class ZImageSeedVarianceEnhancerInvocation(BaseInvocation):
"""Adds seed-based noise to Z-Image conditioning to increase variance between seeds.
Z-Image-Turbo can produce relatively similar images with different seeds,
making it harder to explore variations of a prompt. This node injects
seed-based noise into the text embeddings to increase visual variation
while keeping results reproducible.
The noise strength is auto-calibrated relative to the embedding's standard
deviation, ensuring consistent results across different prompts.
"""
conditioning: ZImageConditioningField = InputField(
description=FieldDescriptions.cond,
input=Input.Connection,
title="Conditioning",
)
seed: int = InputField(
default=0,
ge=0,
description="Seed for reproducible noise generation. Different seeds produce different noise patterns.",
)
strength: float = InputField(
default=0.1,
ge=0.0,
le=2.0,
description="Noise strength as multiplier of embedding std. 0=off, 0.1=subtle, 0.5=strong.",
)
randomize_percent: float = InputField(
default=50.0,
ge=1.0,
le=100.0,
description="Percentage of embedding values to add noise to (1-100). Lower values create more selective noise patterns.",
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ZImageConditioningOutput:
# Load conditioning data
cond_data = context.conditioning.load(self.conditioning.conditioning_name)
assert len(cond_data.conditionings) == 1, "Expected exactly one conditioning tensor"
z_image_conditioning = cond_data.conditionings[0]
assert isinstance(z_image_conditioning, ZImageConditioningInfo), "Expected ZImageConditioningInfo"
# Early return if strength is zero (no modification needed)
if self.strength == 0:
return ZImageConditioningOutput(conditioning=self.conditioning)
# Clone embeddings to avoid modifying the original
prompt_embeds = z_image_conditioning.prompt_embeds.clone()
# Calculate actual noise strength based on embedding statistics
# This auto-calibration ensures consistent results across different prompts
embed_std = torch.std(prompt_embeds).item()
actual_strength = self.strength * embed_std
# Generate deterministic noise using the seed
generator = torch.Generator(device=prompt_embeds.device)
generator.manual_seed(self.seed)
noise = torch.rand(
prompt_embeds.shape, generator=generator, device=prompt_embeds.device, dtype=prompt_embeds.dtype
)
noise = noise * 2 - 1 # Scale to [-1, 1)
noise = noise * actual_strength
# Create selective mask for noise application
generator.manual_seed(self.seed + 1)
noise_mask = torch.bernoulli(
torch.ones_like(prompt_embeds) * (self.randomize_percent / 100.0),
generator=generator,
).bool()
# Apply noise only to masked positions
prompt_embeds = prompt_embeds + (noise * noise_mask)
# Save modified conditioning
new_conditioning = ZImageConditioningInfo(prompt_embeds=prompt_embeds)
conditioning_data = ConditioningFieldData(conditionings=[new_conditioning])
conditioning_name = context.conditioning.save(conditioning_data)
return ZImageConditioningOutput(
conditioning=ZImageConditioningField(
conditioning_name=conditioning_name,
mask=self.conditioning.mask,
)
)
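For reference, a minimal standalone sketch of the noise-injection scheme above, operating on a bare tensor rather than InvokeAI's conditioning objects (the function name and defaults are illustrative):

```python
import torch

def add_seed_variance(embeds: torch.Tensor, seed: int, strength: float = 0.1,
                      randomize_percent: float = 50.0) -> torch.Tensor:
    """Sketch of the auto-calibrated, seed-based noise injection described above."""
    if strength == 0:
        return embeds
    out = embeds.clone()
    actual_strength = strength * out.std().item()            # calibrate to the embedding's std
    gen = torch.Generator(device=out.device).manual_seed(seed)
    noise = torch.rand(out.shape, generator=gen, device=out.device, dtype=out.dtype)
    noise = (noise * 2 - 1) * actual_strength                 # uniform in [-1, 1), scaled
    gen.manual_seed(seed + 1)                                  # separate stream for the mask
    mask = torch.bernoulli(torch.full(out.shape, randomize_percent / 100.0, device=out.device),
                           generator=gen).bool()
    return out + noise * mask
```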

View File

@@ -0,0 +1,197 @@
from contextlib import ExitStack
from typing import Iterator, Optional, Tuple
import torch
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
TensorField,
UIComponent,
ZImageConditioningField,
)
from invokeai.app.invocations.model import Qwen3EncoderField
from invokeai.app.invocations.primitives import ZImageConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.z_image_lora_constants import Z_IMAGE_LORA_QWEN3_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ConditioningFieldData,
ZImageConditioningInfo,
)
from invokeai.backend.util.devices import TorchDevice
# Z-Image max sequence length based on diffusers default
Z_IMAGE_MAX_SEQ_LEN = 512
@invocation(
"z_image_text_encoder",
title="Prompt - Z-Image",
tags=["prompt", "conditioning", "z-image"],
category="conditioning",
version="1.1.0",
classification=Classification.Prototype,
)
class ZImageTextEncoderInvocation(BaseInvocation):
"""Encodes and preps a prompt for a Z-Image image.
Supports regional prompting by connecting a mask input.
"""
prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
qwen3_encoder: Qwen3EncoderField = InputField(
title="Qwen3 Encoder",
description=FieldDescriptions.qwen3_encoder,
input=Input.Connection,
)
mask: Optional[TensorField] = InputField(
default=None,
description="A mask defining the region that this conditioning prompt applies to.",
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ZImageConditioningOutput:
prompt_embeds = self._encode_prompt(context, max_seq_len=Z_IMAGE_MAX_SEQ_LEN)
conditioning_data = ConditioningFieldData(conditionings=[ZImageConditioningInfo(prompt_embeds=prompt_embeds)])
conditioning_name = context.conditioning.save(conditioning_data)
return ZImageConditioningOutput(
conditioning=ZImageConditioningField(conditioning_name=conditioning_name, mask=self.mask)
)
def _encode_prompt(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
"""Encode prompt using Qwen3 text encoder.
Based on the ZImagePipeline._encode_prompt method from diffusers.
"""
prompt = self.prompt
device = TorchDevice.choose_torch_device()
text_encoder_info = context.models.load(self.qwen3_encoder.text_encoder)
tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer)
with ExitStack() as exit_stack:
(_, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())
(_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())
# Apply LoRA models to the text encoder
lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
exit_stack.enter_context(
LayerPatcher.apply_smart_model_patches(
model=text_encoder,
patches=self._lora_iterator(context),
prefix=Z_IMAGE_LORA_QWEN3_PREFIX,
dtype=lora_dtype,
)
)
context.util.signal_progress("Running Qwen3 text encoder")
if not isinstance(text_encoder, PreTrainedModel):
raise TypeError(
f"Expected PreTrainedModel for text encoder, got {type(text_encoder).__name__}. "
"The Qwen3 encoder model may be corrupted or incompatible."
)
if not isinstance(tokenizer, PreTrainedTokenizerBase):
raise TypeError(
f"Expected PreTrainedTokenizerBase for tokenizer, got {type(tokenizer).__name__}. "
"The Qwen3 tokenizer may be corrupted or incompatible."
)
# Apply chat template similar to diffusers ZImagePipeline
# The chat template formats the prompt for the Qwen3 model
try:
prompt_formatted = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
tokenize=False,
add_generation_prompt=True,
enable_thinking=True,
)
except (AttributeError, TypeError) as e:
# Fallback if tokenizer doesn't support apply_chat_template or enable_thinking
context.logger.warning(f"Chat template failed ({e}), using raw prompt.")
prompt_formatted = prompt
# Tokenize the formatted prompt
text_inputs = tokenizer(
prompt_formatted,
padding="max_length",
max_length=max_seq_len,
truncation=True,
return_attention_mask=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask
if not isinstance(text_input_ids, torch.Tensor):
raise TypeError(
f"Expected torch.Tensor for input_ids, got {type(text_input_ids).__name__}. "
"Tokenizer returned unexpected type."
)
if not isinstance(attention_mask, torch.Tensor):
raise TypeError(
f"Expected torch.Tensor for attention_mask, got {type(attention_mask).__name__}. "
"Tokenizer returned unexpected type."
)
# Check for truncation
untruncated_ids = tokenizer(prompt_formatted, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = tokenizer.batch_decode(untruncated_ids[:, max_seq_len - 1 : -1])
context.logger.warning(
f"The following part of your input was truncated because `max_sequence_length` is set to "
f"{max_seq_len} tokens: {removed_text}"
)
# Get hidden states from the text encoder
# Use the second-to-last hidden state like diffusers does
prompt_mask = attention_mask.to(device).bool()
outputs = text_encoder(
text_input_ids.to(device),
attention_mask=prompt_mask,
output_hidden_states=True,
)
# Validate hidden_states output
if not hasattr(outputs, "hidden_states") or outputs.hidden_states is None:
raise RuntimeError(
"Text encoder did not return hidden_states. "
"Ensure output_hidden_states=True is supported by this model."
)
if len(outputs.hidden_states) < 2:
raise RuntimeError(
f"Expected at least 2 hidden states from text encoder, got {len(outputs.hidden_states)}. "
"This may indicate an incompatible model or configuration."
)
prompt_embeds = outputs.hidden_states[-2]
# Z-Image expects a 2D tensor [seq_len, hidden_dim] with only valid tokens
# Based on diffusers ZImagePipeline implementation:
# embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
# Since batch_size=1, we take the first item and filter by mask
prompt_embeds = prompt_embeds[0][prompt_mask[0]]
if not isinstance(prompt_embeds, torch.Tensor):
raise TypeError(
f"Expected torch.Tensor for prompt embeddings, got {type(prompt_embeds).__name__}. "
"Text encoder returned unexpected type."
)
return prompt_embeds
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
"""Iterate over LoRA models to apply to the Qwen3 text encoder."""
for lora in self.qwen3_encoder.loras:
lora_info = context.models.load(lora.lora)
if not isinstance(lora_info.model, ModelPatchRaw):
raise TypeError(
f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}. "
"The LoRA model may be corrupted or incompatible."
)
yield (lora_info.model, lora.weight)
del lora_info
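A toy illustration of the final mask-filtering step above, which reduces the batched hidden states to the 2D `[num_valid_tokens, hidden_dim]` tensor Z-Image expects (the hidden size 2560 and the token count are placeholders):

```python
import torch

prompt_embeds = torch.randn(1, 512, 2560)         # [batch, max_seq_len, hidden_dim], dummy values
prompt_mask = torch.zeros(1, 512, dtype=torch.bool)
prompt_mask[0, :37] = True                          # pretend the prompt produced 37 real tokens
filtered = prompt_embeds[0][prompt_mask[0]]         # boolean-mask indexing keeps only valid rows
assert filtered.shape == (37, 2560)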

View File

@@ -85,6 +85,7 @@ class InvokeAIAppConfig(BaseSettings):
max_cache_ram_gb: The maximum amount of CPU RAM to use for model caching in GB. If unset, the limit will be configured based on the available RAM. In most cases, it is recommended to leave this unset.
max_cache_vram_gb: The amount of VRAM to use for model caching in GB. If unset, the limit will be configured based on the available VRAM and the device_working_mem_gb. In most cases, it is recommended to leave this unset.
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
model_cache_keep_alive_min: How long to keep models in cache after last use, in minutes. A value of 0 (the default) means models are kept in cache indefinitely. If no model generations occur within the timeout period, the model cache is cleared using the same logic as the 'Clear Model Cache' button.
device_working_mem_gb: The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.
enable_partial_loading: Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as it's used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.
keep_ram_copy_of_weights: Whether to keep a full RAM copy of a model's weights when the model is loaded in VRAM. Keeping a RAM copy increases average RAM usage, but speeds up model switching and LoRA patching (assuming there is sufficient RAM). Set this to False if RAM pressure is consistently high.
@@ -165,9 +166,10 @@ class InvokeAIAppConfig(BaseSettings):
max_cache_ram_gb: Optional[float] = Field(default=None, gt=0, description="The maximum amount of CPU RAM to use for model caching in GB. If unset, the limit will be configured based on the available RAM. In most cases, it is recommended to leave this unset.")
max_cache_vram_gb: Optional[float] = Field(default=None, ge=0, description="The amount of VRAM to use for model caching in GB. If unset, the limit will be configured based on the available VRAM and the device_working_mem_gb. In most cases, it is recommended to leave this unset.")
log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")
model_cache_keep_alive_min: float = Field(default=0, ge=0, description="How long to keep models in cache after last use, in minutes. A value of 0 (the default) means models are kept in cache indefinitely. If no model generations occur within the timeout period, the model cache is cleared using the same logic as the 'Clear Model Cache' button.")
device_working_mem_gb: float = Field(default=3, description="The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.")
enable_partial_loading: bool = Field(default=False, description="Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as it's used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.")
keep_ram_copy_of_weights: bool = Field(default=True, description="Whether to keep a full RAM copy of a model's weights when the model is loaded in VRAM. Keeping a RAM copy increases average RAM usage, but speeds up model switching and LoRA patching (assuming there is sufficient RAM). Set this to False if RAM pressure is consistently high.")
keep_ram_copy_of_weights: bool = Field(default=True, description="Whether to keep a full RAM copy of a model's weights when the model is loaded in VRAM. Keeping a RAM copy increases average RAM usage, but speeds up model switching and LoRA patching (assuming there is sufficient RAM). Set this to False if RAM pressure is consistently high.")
# Deprecated CACHE configs
ram: Optional[float] = Field(default=None, gt=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")
vram: Optional[float] = Field(default=None, ge=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")
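As a sketch, the new keep-alive setting can be supplied like any other config field, either in `invokeai.yaml` or programmatically (the 30-minute value is just an example):

```python
from invokeai.app.services.config import InvokeAIAppConfig

# Roughly equivalent to putting `model_cache_keep_alive_min: 30` in invokeai.yaml:
# clear the model cache if no generation has used it for 30 minutes.
config = InvokeAIAppConfig(model_cache_keep_alive_min=30.0)
```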

View File

@@ -14,7 +14,7 @@ class NodeExecutionStatsSummary:
node_type: str
num_calls: int
time_used_seconds: float
peak_vram_gb: float
delta_vram_gb: float
@dataclass
@@ -58,10 +58,10 @@ class InvocationStatsSummary:
def __str__(self) -> str:
_str = ""
_str = f"Graph stats: {self.graph_stats.graph_execution_state_id}\n"
_str += f"{'Node':>30} {'Calls':>7} {'Seconds':>9} {'VRAM Used':>10}\n"
_str += f"{'Node':>30} {'Calls':>7} {'Seconds':>9} {'VRAM Change':+>10}\n"
for summary in self.node_stats:
_str += f"{summary.node_type:>30} {summary.num_calls:>7} {summary.time_used_seconds:>8.3f}s {summary.peak_vram_gb:>9.3f}G\n"
_str += f"{summary.node_type:>30} {summary.num_calls:>7} {summary.time_used_seconds:>8.3f}s {summary.delta_vram_gb:+10.3f}G\n"
_str += f"TOTAL GRAPH EXECUTION TIME: {self.graph_stats.execution_time_seconds:7.3f}s\n"
@@ -100,7 +100,7 @@ class NodeExecutionStats:
start_ram_gb: float # GB
end_ram_gb: float # GB
peak_vram_gb: float # GB
delta_vram_gb: float # GB
def total_time(self) -> float:
return self.end_time - self.start_time
@@ -174,9 +174,9 @@ class GraphExecutionStats:
for node_type, node_type_stats_list in node_stats_by_type.items():
num_calls = len(node_type_stats_list)
time_used = sum([n.total_time() for n in node_type_stats_list])
peak_vram = max([n.peak_vram_gb for n in node_type_stats_list])
delta_vram = max([n.delta_vram_gb for n in node_type_stats_list])
summary = NodeExecutionStatsSummary(
node_type=node_type, num_calls=num_calls, time_used_seconds=time_used, peak_vram_gb=peak_vram
node_type=node_type, num_calls=num_calls, time_used_seconds=time_used, delta_vram_gb=delta_vram
)
summaries.append(summary)

View File

@@ -52,8 +52,9 @@ class InvocationStatsService(InvocationStatsServiceBase):
# Record state before the invocation.
start_time = time.time()
start_ram = psutil.Process().memory_info().rss
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
# Remember current VRAM usage
vram_in_use = torch.cuda.memory_allocated() if torch.cuda.is_available() else 0.0
assert services.model_manager.load is not None
services.model_manager.load.ram_cache.stats = self._cache_stats[graph_execution_state_id]
@@ -62,14 +63,16 @@ class InvocationStatsService(InvocationStatsServiceBase):
# Let the invocation run.
yield None
finally:
# Record state after the invocation.
# Record delta VRAM
delta_vram_gb = ((torch.cuda.memory_allocated() - vram_in_use) / GB) if torch.cuda.is_available() else 0.0
node_stats = NodeExecutionStats(
invocation_type=invocation.get_type(),
start_time=start_time,
end_time=time.time(),
start_ram_gb=start_ram / GB,
end_ram_gb=psutil.Process().memory_info().rss / GB,
peak_vram_gb=torch.cuda.max_memory_allocated() / GB if torch.cuda.is_available() else 0.0,
delta_vram_gb=delta_vram_gb,
)
self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
@@ -81,6 +84,8 @@ class InvocationStatsService(InvocationStatsServiceBase):
graph_stats_summary = self._get_graph_summary(graph_execution_state_id)
node_stats_summaries = self._get_node_summaries(graph_execution_state_id)
model_cache_stats_summary = self._get_model_cache_summary(graph_execution_state_id)
# Note: We use memory_allocated() here (not memory_reserved()) because we want to show
# the current actively-used VRAM, not the total reserved memory including PyTorch's cache.
vram_usage_gb = torch.cuda.memory_allocated() / GB if torch.cuda.is_available() else None
return InvocationStatsSummary(

View File

@@ -85,9 +85,12 @@ class LocalModelSource(StringLikeSource):
class HFModelSource(StringLikeSource):
"""
A HuggingFace repo_id with optional variant, sub-folder and access token.
A HuggingFace repo_id with optional variant, sub-folder(s) and access token.
Note that the variant option, if not provided to the constructor, will default to fp16, which is
what people (almost) always want.
The subfolder can be a single path or multiple paths joined by '+' (e.g., "text_encoder+tokenizer").
When multiple subfolders are specified, all of them will be downloaded and combined into the model directory.
"""
repo_id: str
@@ -103,6 +106,16 @@ class HFModelSource(StringLikeSource):
raise ValueError(f"{v}: invalid repo_id format")
return v
@property
def subfolders(self) -> list[Path]:
"""Return list of subfolders (supports '+' separated multiple subfolders)."""
if self.subfolder is None:
return []
subfolder_str = self.subfolder.as_posix()
if "+" in subfolder_str:
return [Path(s.strip()) for s in subfolder_str.split("+")]
return [self.subfolder]
def __str__(self) -> str:
"""Return string version of repoid when string rep needed."""
base: str = self.repo_id
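A small illustration of the '+' syntax handled by the new `subfolders` property (the repo ids are placeholders and construction details are illustrative):

```python
from pathlib import Path

src = HFModelSource(repo_id="org/Z-Image-Turbo", subfolder=Path("text_encoder+tokenizer"))
assert src.subfolders == [Path("text_encoder"), Path("tokenizer")]

single = HFModelSource(repo_id="org/sdxl-turbo", subfolder=Path("vae"))
assert single.subfolders == [Path("vae")]
```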

View File

@@ -1,8 +1,10 @@
"""Model installation class."""
import gc
import locale
import os
import re
import sys
import threading
import time
from copy import deepcopy
@@ -135,6 +137,8 @@ class ModelInstallService(ModelInstallServiceBase):
for model in self._scan_for_missing_models():
self._logger.warning(f"Missing model file: {model.name} at {model.path}")
self._write_invoke_managed_models_dir_readme()
def stop(self, invoker: Optional[Invoker] = None) -> None:
"""Stop the installer thread; after this the object can be deleted and garbage collected."""
if not self._running:
@@ -147,6 +151,14 @@ class ModelInstallService(ModelInstallServiceBase):
self._install_thread.join()
self._running = False
def _write_invoke_managed_models_dir_readme(self) -> None:
"""Write a README file to the Invoke-managed models directory warning users to not fiddle with it."""
readme_path = self.app_config.models_path / "README.txt"
with open(readme_path, "wt", encoding=locale.getpreferredencoding()) as f:
f.write(
"This directory is managed by Invoke. Do not add, delete or move files in this directory.\n\nTo manage models, use the web interface.\n"
)
def _clear_pending_jobs(self) -> None:
for job in self.list_jobs():
if not job.in_terminal_state:
@@ -177,6 +189,22 @@ class ModelInstallService(ModelInstallServiceBase):
config.source_type = ModelSourceType.Path
return self._register(model_path, config)
# TODO: Replace this with a proper fix for the underlying problem of Windows holding
# the file open when it needs to be moved.
@staticmethod
def _move_with_retries(src: Path, dst: Path, attempts: int = 5, delay: float = 0.5) -> None:
"""Workaround for Windows file-handle issues when moving files."""
for tries_left in range(attempts, 0, -1):
try:
move(src, dst)
return
except PermissionError:
gc.collect()
if tries_left == 1:
raise
time.sleep(delay)
delay *= 2 # Exponential backoff
def install_path(
self,
model_path: Union[Path, str],
@@ -195,7 +223,7 @@ class ModelInstallService(ModelInstallServiceBase):
dest_dir.mkdir(parents=True)
dest_path = dest_dir / model_path.name if model_path.is_file() else dest_dir
if model_path.is_file():
move(model_path, dest_path)
self._move_with_retries(model_path, dest_path) # Windows workaround TODO: fix root cause
elif model_path.is_dir():
# Move the contents of the directory, not the directory itself
for item in model_path.iterdir():
@@ -407,10 +435,15 @@ class ModelInstallService(ModelInstallServiceBase):
model_path.mkdir(parents=True, exist_ok=True)
model_source = self._guess_source(str(source))
remote_files, _ = self._remote_files_from_source(model_source)
# Handle multiple subfolders for HFModelSource
subfolders = model_source.subfolders if isinstance(model_source, HFModelSource) else []
job = self._multifile_download(
dest=model_path,
remote_files=remote_files,
subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None,
subfolder=model_source.subfolder
if isinstance(model_source, HFModelSource) and len(subfolders) <= 1
else None,
subfolders=subfolders if len(subfolders) > 1 else None,
)
files_string = "file" if len(remote_files) == 1 else "files"
self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})")
@@ -428,10 +461,13 @@ class ModelInstallService(ModelInstallServiceBase):
if isinstance(source, HFModelSource):
metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
assert isinstance(metadata, ModelMetadataWithFiles)
# Use subfolders property which handles '+' separated multiple subfolders
subfolders = source.subfolders
return (
metadata.download_urls(
variant=source.variant or self._guess_variant(),
subfolder=source.subfolder,
subfolder=source.subfolder if len(subfolders) <= 1 else None,
subfolders=subfolders if len(subfolders) > 1 else None,
session=self._session,
),
metadata,
@@ -482,6 +518,39 @@ class ModelInstallService(ModelInstallServiceBase):
self._install_thread.start()
self._running = True
@staticmethod
def _safe_rmtree(path: Path, logger: Any) -> None:
"""Remove a directory tree with retry logic for Windows file locking issues.
On Windows, memory-mapped files may not be immediately released even after
the file handle is closed. This function retries the removal with garbage
collection to help release any lingering references.
"""
max_retries = 3
retry_delay = 0.5 # seconds
for attempt in range(max_retries):
try:
# Force garbage collection to release any lingering file references
gc.collect()
rmtree(path)
return
except PermissionError as e:
if attempt < max_retries - 1 and sys.platform == "win32":
logger.warning(
f"Failed to remove {path} (attempt {attempt + 1}/{max_retries}): {e}. "
f"Retrying in {retry_delay}s..."
)
time.sleep(retry_delay)
retry_delay *= 2 # Exponential backoff
else:
logger.error(f"Failed to remove temporary directory {path}: {e}")
# On final failure, don't raise - the temp dir will be cleaned up on next startup
return
except Exception as e:
logger.error(f"Unexpected error removing {path}: {e}")
return
def _install_next_item(self) -> None:
self._logger.debug(f"Installer thread {threading.get_ident()} starting")
while True:
@@ -511,7 +580,7 @@ class ModelInstallService(ModelInstallServiceBase):
finally:
# if this is an install of a remote file, then clean up the temporary directory
if job._install_tmpdir is not None:
rmtree(job._install_tmpdir)
self._safe_rmtree(job._install_tmpdir, self._logger)
self._install_completed_event.set()
self._install_queue.task_done()
self._logger.info(f"Installer thread {threading.get_ident()} exiting")
@@ -556,7 +625,7 @@ class ModelInstallService(ModelInstallServiceBase):
path = self._app_config.models_path
for tmpdir in path.glob(f"{TMPDIR_PREFIX}*"):
self._logger.info(f"Removing dangling temporary directory {tmpdir}")
rmtree(tmpdir)
self._safe_rmtree(tmpdir, self._logger)
def _scan_for_missing_models(self) -> list[AnyModelConfig]:
"""Scan the models directory for missing models and return a list of them."""
@@ -731,10 +800,13 @@ class ModelInstallService(ModelInstallServiceBase):
install_job._install_tmpdir = destdir
install_job.total_bytes = sum((x.size or 0) for x in remote_files)
# Handle multiple subfolders for HFModelSource
subfolders = source.subfolders if isinstance(source, HFModelSource) else []
multifile_job = self._multifile_download(
remote_files=remote_files,
dest=destdir,
subfolder=source.subfolder if isinstance(source, HFModelSource) else None,
subfolder=source.subfolder if isinstance(source, HFModelSource) and len(subfolders) <= 1 else None,
subfolders=subfolders if len(subfolders) > 1 else None,
access_token=source.access_token,
submit_job=False, # Important! Don't submit the job until we have set our _download_cache dict
)
@@ -761,6 +833,7 @@ class ModelInstallService(ModelInstallServiceBase):
remote_files: List[RemoteModelFile],
dest: Path,
subfolder: Optional[Path] = None,
subfolders: Optional[List[Path]] = None,
access_token: Optional[str] = None,
submit_job: bool = True,
) -> MultiFileDownloadJob:
@@ -768,24 +841,61 @@ class ModelInstallService(ModelInstallServiceBase):
# we are installing the "vae" subfolder, we do not want to create an additional folder level, such
# as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo".
# So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
if subfolder:
#
# For multiple subfolders (e.g., text_encoder+tokenizer), we create a combined folder name
# (e.g., sdxl-turbo_text_encoder_tokenizer) and keep each subfolder's contents in its own
# subdirectory within the model folder.
if subfolders and len(subfolders) > 1:
# Multiple subfolders: create combined name and keep subfolder structure
top = Path(remote_files[0].path.parts[0]) # e.g. "Z-Image-Turbo/"
subfolder_names = [sf.name.replace("/", "_").replace("\\", "_") for sf in subfolders]
combined_name = "_".join(subfolder_names)
path_to_add = Path(f"{top}_{combined_name}")
parts: List[RemoteModelFile] = []
for model_file in remote_files:
assert model_file.size is not None
# Determine which subfolder this file belongs to
file_path = model_file.path
new_path: Optional[Path] = None
for sf in subfolders:
try:
# Try to get relative path from this subfolder
relative = file_path.relative_to(top / sf)
# Keep the subfolder name as a subdirectory
new_path = path_to_add / sf.name / relative
break
except ValueError:
continue
if new_path is None:
# File doesn't match any subfolder, keep original path structure
new_path = path_to_add / file_path.relative_to(top)
parts.append(RemoteModelFile(url=model_file.url, path=new_path))
elif subfolder:
# Single subfolder: flatten into renamed folder
top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/"
path_to_remove = top / subfolder # sdxl-turbo/vae/
subfolder_rename = subfolder.name.replace("/", "_").replace("\\", "_")
path_to_add = Path(f"{top}_{subfolder_rename}")
else:
path_to_remove = Path(".")
path_to_add = Path(".")
parts: List[RemoteModelFile] = []
for model_file in remote_files:
assert model_file.size is not None
parts.append(
RemoteModelFile(
url=model_file.url, # if a subfolder, then sdxl-turbo_vae/config.json
path=path_to_add / model_file.path.relative_to(path_to_remove),
parts = []
for model_file in remote_files:
assert model_file.size is not None
parts.append(
RemoteModelFile(
url=model_file.url,
path=path_to_add / model_file.path.relative_to(path_to_remove),
)
)
)
else:
# No subfolder specified - pass through unchanged
parts = []
for model_file in remote_files:
assert model_file.size is not None
parts.append(RemoteModelFile(url=model_file.url, path=model_file.path))
return self._download_queue.multifile_download(
parts=parts,
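A worked sketch of the resulting layout for a two-subfolder download (the file names are illustrative, not taken from a real repo listing):

```python
from pathlib import Path

top = Path("Z-Image-Turbo")
subfolders = [Path("text_encoder"), Path("tokenizer")]
dest_root = Path(f"{top}_{'_'.join(sf.name for sf in subfolders)}")  # Z-Image-Turbo_text_encoder_tokenizer

for original in [Path("Z-Image-Turbo/text_encoder/config.json"),
                 Path("Z-Image-Turbo/tokenizer/tokenizer.json")]:
    for sf in subfolders:
        try:
            rel = original.relative_to(top / sf)
            print(original, "->", dest_root / sf.name / rel)
            break
        except ValueError:
            continue
# Z-Image-Turbo/text_encoder/config.json -> Z-Image-Turbo_text_encoder_tokenizer/text_encoder/config.json
# Z-Image-Turbo/tokenizer/tokenizer.json -> Z-Image-Turbo_text_encoder_tokenizer/tokenizer/tokenizer.json
```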

View File

@@ -60,6 +60,10 @@ class ModelManagerService(ModelManagerServiceBase):
service.start(invoker)
def stop(self, invoker: Invoker) -> None:
# Shutdown the model cache to cancel any pending timers
if hasattr(self._load, "ram_cache"):
self._load.ram_cache.shutdown()
for service in [self._store, self._install, self._load]:
if hasattr(service, "stop"):
service.stop(invoker)
@@ -88,7 +92,10 @@ class ModelManagerService(ModelManagerServiceBase):
max_ram_cache_size_gb=app_config.max_cache_ram_gb,
max_vram_cache_size_gb=app_config.max_cache_vram_gb,
execution_device=execution_device or TorchDevice.choose_torch_device(),
storage_device="cpu",
log_memory_usage=app_config.log_memory_usage,
logger=logger,
keep_alive_minutes=app_config.model_cache_keep_alive_min,
)
loader = ModelLoadService(
app_config=app_config,

View File

@@ -19,11 +19,13 @@ from invokeai.backend.model_manager.configs.main import MainModelDefaultSettings
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ClipVariantType,
Flux2VariantType,
FluxVariantType,
ModelFormat,
ModelSourceType,
ModelType,
ModelVariantType,
Qwen3VariantType,
SchedulerPredictionType,
)
@@ -89,8 +91,8 @@ class ModelRecordChanges(BaseModelExcludeNull):
# Checkpoint-specific changes
# TODO(MM2): Should we expose these? Feels footgun-y...
variant: Optional[ModelVariantType | ClipVariantType | FluxVariantType] = Field(
description="The variant of the model.", default=None
variant: Optional[ModelVariantType | ClipVariantType | FluxVariantType | Flux2VariantType | Qwen3VariantType] = (
Field(description="The variant of the model.", default=None)
)
prediction_type: Optional[SchedulerPredictionType] = Field(
description="The prediction type of the model.", default=None
@@ -138,6 +140,18 @@ class ModelRecordServiceBase(ABC):
"""
pass
@abstractmethod
def replace_model(self, key: str, new_config: AnyModelConfig) -> AnyModelConfig:
"""
Replace the model record entirely, returning the new record.
This is used when we re-identify a model and have a new config object.
:param key: Unique key for the model to be updated.
:param new_config: The new model config to write.
"""
pass
@abstractmethod
def get_model(self, key: str) -> AnyModelConfig:
"""

View File

@@ -179,6 +179,23 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
return self.get_model(key)
def replace_model(self, key: str, new_config: AnyModelConfig) -> AnyModelConfig:
if key != new_config.key:
raise ValueError("key does not match new_config.key")
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
UPDATE models
SET
config=?
WHERE id=?;
""",
(new_config.model_dump_json(), key),
)
if cursor.rowcount == 0:
raise UnknownModelException("model not found")
return self.get_model(key)
def get_model(self, key: str) -> AnyModelConfig:
"""
Retrieve the ModelConfigBase instance for the indicated model.

View File

@@ -0,0 +1,194 @@
# InvokeAI Graph - Design Overview
High-level design for the graph module. Focuses on responsibilities, data flow, and how traversal works.
## 1) Purpose
Provide a typed, acyclic workflow model (**Graph**) plus a runtime scheduler (**GraphExecutionState**) that expands
iterator patterns, tracks readiness via indegree (the number of incoming edges to a node in the directed graph), and
executes nodes in class-grouped batches. Source graphs remain immutable during a run; runtime expansion happens in a
separate execution graph.
## 2) Major Data Types
### EdgeConnection
* Fields: `node_id: str`, `field: str`.
* Hashable; printed as `node.field` for readable diagnostics.
### Edge
* Fields: `source: EdgeConnection`, `destination: EdgeConnection`.
* One directed connection from a specific output port to a specific input port.
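For example (ids and fields are illustrative):

```python
edge = Edge(
    source=EdgeConnection(node_id="noise", field="noise"),
    destination=EdgeConnection(node_id="denoise", field="noise"),
)
```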
### AnyInvocation / AnyInvocationOutput
* Pydantic wrappers that carry concrete invocation models and outputs.
* No registry logic in this file; they are permissive containers for heterogeneous nodes.
### IterateInvocation / CollectInvocation
* Control nodes used by validation and execution:
* **IterateInvocation**: input `collection`, outputs include `item` (and index/total).
* **CollectInvocation**: many `item` inputs aggregated to one `collection` output.
## 3) Graph (author-time model)
A container for declared nodes and edges. Does **not** perform iteration expansion.
### 3.1 Data
* `nodes: dict[str, AnyInvocation]` - key must equal `node.id`.
* `edges: list[Edge]` - zero or more.
* Utility: `_get_input_edges(node_id, field?)`, `_get_output_edges(node_id, field?)`
These scan `self.edges` (no adjacency indices in the current code).
### 3.2 Validation (`validate_self`)
Runs a sequence of checks:
1. **Node ID uniqueness**
No duplicate IDs; map key equals `node.id`.
2. **Endpoint existence**
Source and destination node IDs must exist.
3. **Port existence**
Input ports must exist on the node class; output ports on the node's output model.
4. **Type compatibility**
`get_output_field_type` vs `get_input_field_type` and `are_connection_types_compatible`.
5. **DAG constraint**
Build a *flat* `DiGraph` (no runtime expansion) and assert acyclicity.
6. **Iterator / collector structure**
Enforce special rules:
* Iterator's input must be `collection`; its outgoing edges use `item`.
* Collector accepts many `item` inputs; outputs a single `collection`.
* Edge fan-in to a non-collector input is rejected.
### 3.3 Edge admission (`_validate_edge`)
Checks a single prospective edge before insertion:
* Endpoints/ports exist.
* Destination port is not already occupied unless it's a collector `item`.
* Adding the edge to the flat DAG must keep it acyclic.
* Iterator/collector constraints re-checked when the edge creates relevant patterns.
### 3.4 Topology utilities
* `nx_graph()` - DiGraph of declared nodes and edges.
* `nx_graph_with_data()` - includes node/edge attributes.
* `nx_graph_flat()` - "flattened" DAG (still author-time; no runtime copies).
Used in validation and in `_prepare()` during execution planning.
### 3.5 Mutation helpers
* `add_node`, `update_node` (preserve edges, rewrite endpoints if id changes), `delete_node`.
* `add_edge`, `delete_edge` (with validation).
## 4) GraphExecutionState (runtime)
Holds the state for a single run. Keeps the source graph intact; materializes a separate execution graph.
### 4.1 Data
* `graph: Graph` - immutable source during a run.
* `execution_graph: Graph` - materialized runtime nodes/edges.
* `executed: set[str]`, `executed_history: list[str]`.
* `results: dict[str, AnyInvocationOutput]`, `errors: dict[str, str]`.
* `prepared_source_mapping: dict[str, str]` - exec id → source id.
* `source_prepared_mapping: dict[str, set[str]]` - source id → exec ids.
* `indegree: dict[str, int]` - unmet inputs per exec node.
* **Ready queues grouped by class** (private attrs):
`_ready_queues: dict[class_name, deque[str]]`, `_active_class: Optional[str]`. Optional `ready_order: list[str]` to
prioritize classes.
### 4.2 Core methods
* `next()`
Returns the next ready exec node. If none, calls `_prepare()` to materialize more, then retries. Before returning a
node, `_prepare_inputs()` deep-copies inbound values into the node fields.
* `complete(node_id, output)`
Record result; mark exec node executed; if all exec copies of the same **source** are done, mark the source executed.
For each outgoing exec edge, decrement child indegree and enqueue when it reaches zero.
### 4.3 Preparation (`_prepare()`)
* Build a flat DAG from the **source** graph.
* Choose the **next source node** in topological order that:
1. has not been prepared,
2. if it is an iterator, *its inputs are already executed*,
3. it has *no unexecuted iterator ancestors*.
* If the node is a **CollectInvocation**: collapse all prepared parents into one mapping and create **one** exec node.
* Otherwise: compute all combinations of prepared iterator ancestors. For each combination, pick the matching prepared parent per upstream and create **one** exec node.
* For each new exec node:
* Deep-copy the source node; assign a fresh ID (and `index` for iterators).
* Wire edges from chosen prepared parents.
* Set `indegree = number of unmet inputs` (i.e., parents not yet executed).
* If `indegree == 0`, enqueue into its class queue.
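A toy illustration of the iterator-combination step described in the list above: two prepared iterators with 2 and 3 copies yield 6 exec copies of a downstream node (names are made up):

```python
import itertools

prepared = {"iterate_a": ["a0", "a1"], "iterate_b": ["b0", "b1", "b2"]}
combinations = list(itertools.product(*prepared.values()))
assert len(combinations) == 6   # one exec node is created per combination
```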
### 4.4 Readiness and batching
* `_enqueue_if_ready(nid)` enqueues by class name only when `indegree == 0` and not executed.
* `_get_next_node()` drains the `_active_class` queue FIFO; when empty, selects the next nonempty class queue (by `ready_order` if set, else alphabetical), and continues. Optional fairness knobs can limit batch size per class; default is drain fully.
#### 4.4.1 Indegree (what it is and how it's used)
**Indegree** is the number of incoming edges to a node in the execution graph that are still unmet. In this engine:
* For every materialized exec node, `indegree[node]` equals the count of its prerequisite parents that have **not** finished yet.
* A node is "ready" exactly when `indegree[node] == 0`; only then is it enqueued.
* When a node completes, the scheduler decrements `indegree[child]` for each outgoing edge. Any child that reaches 0 is enqueued.
Example: edges `A→C`, `B→C`, `C→D`. Start: `A:0, B:0, C:2, D:1`. Run `A` → `C:1`. Run `B` → `C:0` → enqueue `C`. Run `C` →
`D:0` → enqueue `D`. Run `D` → done.
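A minimal sketch of this bookkeeping (not the real `GraphExecutionState`), using plain dicts and a deque:

```python
from collections import deque

indegree = {"A": 0, "B": 0, "C": 2, "D": 1}
children = {"A": ["C"], "B": ["C"], "C": ["D"], "D": []}
ready = deque(n for n, d in indegree.items() if d == 0)    # A and B start ready

order = []
while ready:
    node = ready.popleft()
    order.append(node)                                      # "execute" the node
    for child in children[node]:
        indegree[child] -= 1
        if indegree[child] == 0:
            ready.append(child)

assert order == ["A", "B", "C", "D"]
```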
### 4.5 Input hydration (`_prepare_inputs()`)
* For **CollectInvocation**: gather all incoming `item` values into `collection`.
* For all others: deep-copy each incoming edge's value into the destination field.
This prevents cross-node mutation through shared references.
## 5) Traversal Summary
1. Author builds a valid **Graph**.
2. Create **GraphExecutionState** with that graph.
3. Loop:
* `node = state.next()` → may trigger `_prepare()` expansion.
* Execute node externally → `output`.
* `state.complete(node.id, output)` → updates indegrees and queues.
4. Finish when `next()` returns `None`.
The source graph is never mutated; all expansion occurs in `execution_graph` with traceability back to source nodes.
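A minimal driver loop for steps 1-4 above (here `execute_node` stands in for whatever actually runs an invocation and returns its output):

```python
state = GraphExecutionState(graph=graph)
while (node := state.next()) is not None:
    try:
        output = execute_node(node)          # run the invocation externally
        state.complete(node.id, output)      # updates indegrees and ready queues
    except Exception as err:
        state.set_node_error(node.id, str(err))
```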
## 6) Invariants
* Source **Graph** remains a DAG and type-consistent.
* `execution_graph` remains a DAG.
* Nodes are enqueued only when `indegree == 0`.
* `results` and `errors` are keyed by **exec node id**.
* Collectors only aggregate `item` inputs; other inputs behave one-to-one.
## 7) Extensibility
* **New node types**: implement as Pydantic models with typed fields and outputs. Register per your invocation system; this file accepts them as `AnyInvocation`.
* **Scheduling policy**: adjust `ready_order` to batch by class; add a batch cap for fairness without changing complexity.
* **Dynamic behaviors** (future): can be added in `GraphExecutionState` by creating exec nodes and edges at `complete()` time, as long as the DAG invariant holds.
## 8) Error Model (selected)
* `DuplicateNodeIdError`, `NodeAlreadyInGraphError`
* `NodeNotFoundError`, `NodeFieldNotFoundError`
* `InvalidEdgeError`, `CyclicalGraphError`
* `NodeInputError` (raised when preparing inputs for execution)
Messages favor short, precise diagnostics (node id, field, and failing condition).
## 9) Rationale
* **Two-graph approach** isolates authoring from execution expansion and keeps validation simple.
* **Indegree + queues** gives O(1) scheduling decisions with clear batching semantics.
* **Iterator/collector separation** keeps fan-out/fan-in explicit and testable.
* **Deep-copy hydration** avoids incidental aliasing bugs between nodes.

View File

@@ -2,7 +2,8 @@
import copy
import itertools
from typing import Any, Optional, TypeVar, Union, get_args, get_origin
from collections import deque
from typing import Any, Deque, Iterable, Optional, Type, TypeVar, Union, get_args, get_origin
import networkx as nx
from pydantic import (
@@ -10,6 +11,7 @@ from pydantic import (
ConfigDict,
GetCoreSchemaHandler,
GetJsonSchemaHandler,
PrivateAttr,
ValidationError,
field_validator,
)
@@ -33,6 +35,10 @@ from invokeai.app.util.misc import uuid_string
# in 3.10 this would be "from types import NoneType"
NoneType = type(None)
# Port name constants
ITEM_FIELD = "item"
COLLECTION_FIELD = "collection"
class EdgeConnection(BaseModel):
node_id: str = Field(description="The id of the node for this edge connection")
@@ -395,7 +401,7 @@ class Graph(BaseModel):
try:
self.edges.remove(edge)
except KeyError:
except ValueError:
pass
def validate_self(self) -> None:
@@ -414,7 +420,8 @@ class Graph(BaseModel):
# Validate that all node ids are unique
node_ids = [n.id for n in self.nodes.values()]
duplicate_node_ids = {node_id for node_id in node_ids if node_ids.count(node_id) >= 2}
seen = set()
duplicate_node_ids = {nid for nid in node_ids if (nid in seen) or seen.add(nid)}
if duplicate_node_ids:
raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")
@@ -529,19 +536,19 @@ class Graph(BaseModel):
raise InvalidEdgeError(f"Field types are incompatible ({edge})")
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
if isinstance(to_node, IterateInvocation) and edge.destination.field == COLLECTION_FIELD:
err = self._is_iterator_connection_valid(edge.destination.node_id, new_input=edge.source)
if err is not None:
raise InvalidEdgeError(f"Iterator input type does not match iterator output type ({edge}): {err}")
# Validate if iterator input type matches output type (if this edge results in both being set)
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
if isinstance(from_node, IterateInvocation) and edge.source.field == ITEM_FIELD:
err = self._is_iterator_connection_valid(edge.source.node_id, new_output=edge.destination)
if err is not None:
raise InvalidEdgeError(f"Iterator output type does not match iterator input type ({edge}): {err}")
# Validate if collector input type matches output type (if this edge results in both being set)
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
if isinstance(to_node, CollectInvocation) and edge.destination.field == ITEM_FIELD:
err = self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source)
if err is not None:
raise InvalidEdgeError(f"Collector output type does not match collector input type ({edge}): {err}")
@@ -549,7 +556,7 @@ class Graph(BaseModel):
# Validate if collector output type matches input type (if this edge results in both being set) - skip if the destination field is not Any or list[Any]
if (
isinstance(from_node, CollectInvocation)
and edge.source.field == "collection"
and edge.source.field == COLLECTION_FIELD
and not self._is_destination_field_list_of_Any(edge)
and not self._is_destination_field_Any(edge)
):
@@ -639,8 +646,8 @@ class Graph(BaseModel):
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> str | None:
inputs = [e.source for e in self._get_input_edges(node_id, "collection")]
outputs = [e.destination for e in self._get_output_edges(node_id, "item")]
inputs = [e.source for e in self._get_input_edges(node_id, COLLECTION_FIELD)]
outputs = [e.destination for e in self._get_output_edges(node_id, ITEM_FIELD)]
if new_input is not None:
inputs.append(new_input)
@@ -670,7 +677,7 @@ class Graph(BaseModel):
if isinstance(input_node, CollectInvocation):
# Traverse the graph to find the first collector input edge. Collectors validate that their collection
# inputs are all of the same type, so we can use the first input edge to determine the collector's type
first_collector_input_edge = self._get_input_edges(input_node.id, "item")[0]
first_collector_input_edge = self._get_input_edges(input_node.id, ITEM_FIELD)[0]
first_collector_input_type = get_output_field_type(
self.get_node(first_collector_input_edge.source.node_id), first_collector_input_edge.source.field
)
@@ -690,8 +697,8 @@ class Graph(BaseModel):
new_input: Optional[EdgeConnection] = None,
new_output: Optional[EdgeConnection] = None,
) -> str | None:
inputs = [e.source for e in self._get_input_edges(node_id, "item")]
outputs = [e.destination for e in self._get_output_edges(node_id, "collection")]
inputs = [e.source for e in self._get_input_edges(node_id, ITEM_FIELD)]
outputs = [e.destination for e in self._get_output_edges(node_id, COLLECTION_FIELD)]
if new_input is not None:
inputs.append(new_input)
@@ -761,7 +768,7 @@ class Graph(BaseModel):
# TODO: figure out if iteration nodes need to be expanded
unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges}
g.add_edges_from([(e[0], e[1]) for e in unique_edges])
g.add_edges_from(unique_edges)
return g
@@ -802,6 +809,41 @@ class GraphExecutionState(BaseModel):
description="The map of original graph nodes to prepared nodes",
default_factory=dict,
)
# Ready queues grouped by node class name (internal only)
_ready_queues: dict[str, Deque[str]] = PrivateAttr(default_factory=dict)
# Current class being drained; stays until its queue empties
_active_class: Optional[str] = PrivateAttr(default=None)
# Optional priority; others follow in name order
ready_order: list[str] = Field(default_factory=list)
indegree: dict[str, int] = Field(default_factory=dict, description="Remaining unmet input count for exec nodes")
def _type_key(self, node_obj: BaseInvocation) -> str:
return node_obj.__class__.__name__
def _queue_for(self, cls_name: str) -> Deque[str]:
q = self._ready_queues.get(cls_name)
if q is None:
q = deque()
self._ready_queues[cls_name] = q
return q
def set_ready_order(self, order: Iterable[Type[BaseInvocation] | str]) -> None:
names: list[str] = []
for x in order:
names.append(x.__name__ if hasattr(x, "__name__") else str(x))
self.ready_order = names
def _enqueue_if_ready(self, nid: str) -> None:
"""Push nid to its class queue if unmet inputs == 0."""
# Invariants: exec node exists and has an indegree entry
if nid not in self.execution_graph.nodes:
raise KeyError(f"exec node {nid} missing from execution_graph")
if nid not in self.indegree:
raise KeyError(f"indegree missing for exec node {nid}")
if self.indegree[nid] != 0 or nid in self.executed:
return
node_obj = self.execution_graph.nodes[nid]
self._queue_for(self._type_key(node_obj)).append(nid)
model_config = ConfigDict(
json_schema_extra={
@@ -834,12 +876,14 @@ class GraphExecutionState(BaseModel):
# If there are no prepared nodes, prepare some nodes
next_node = self._get_next_node()
if next_node is None:
prepared_id = self._prepare()
base_g = self.graph.nx_graph_flat()
prepared_id = self._prepare(base_g)
# Prepare as many nodes as we can
while prepared_id is not None:
prepared_id = self._prepare()
next_node = self._get_next_node()
prepared_id = self._prepare(base_g)
if next_node is None:
next_node = self._get_next_node()
# Get values from edges
if next_node is not None:
@@ -869,6 +913,18 @@ class GraphExecutionState(BaseModel):
self.executed.add(source_node)
self.executed_history.append(source_node)
# Decrement children indegree and enqueue when ready
for e in self.execution_graph._get_output_edges(node_id):
child = e.destination.node_id
if child not in self.indegree:
raise KeyError(f"indegree missing for exec node {child}")
# Only decrement if there's something to satisfy
if self.indegree[child] == 0:
raise RuntimeError(f"indegree underflow for {child} from parent {node_id}")
self.indegree[child] -= 1
if self.indegree[child] == 0:
self._enqueue_if_ready(child)
def set_node_error(self, node_id: str, error: str):
"""Marks a node as errored"""
self.errors[node_id] = error
@@ -892,7 +948,7 @@ class GraphExecutionState(BaseModel):
# If this is an iterator node, we must create a copy for each iteration
if isinstance(node, IterateInvocation):
# Get input collection edge (should error if there are no inputs)
input_collection_edge = next(iter(self.graph._get_input_edges(node_id, "collection")))
input_collection_edge = next(iter(self.graph._get_input_edges(node_id, COLLECTION_FIELD)))
input_collection_prepared_node_id = next(
n[1] for n in iteration_node_map if n[0] == input_collection_edge.source.node_id
)
@@ -922,7 +978,7 @@ class GraphExecutionState(BaseModel):
# Create a new node (or one for each iteration of this iterator)
for i in range(self_iteration_count) if self_iteration_count > 0 else [-1]:
# Create a new node
new_node = copy.deepcopy(node)
new_node = node.model_copy(deep=True)
# Create the node id (use a random uuid)
new_node.id = uuid_string()
@@ -946,53 +1002,55 @@ class GraphExecutionState(BaseModel):
)
self.execution_graph.add_edge(new_edge)
# Initialize indegree as unmet inputs only and enqueue if ready
inputs = self.execution_graph._get_input_edges(new_node.id)
unmet = sum(1 for e in inputs if e.source.node_id not in self.executed)
self.indegree[new_node.id] = unmet
self._enqueue_if_ready(new_node.id)
new_nodes.append(new_node.id)
return new_nodes
def _iterator_graph(self) -> nx.DiGraph:
def _iterator_graph(self, base: Optional[nx.DiGraph] = None) -> nx.DiGraph:
"""Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
g = self.graph.nx_graph_flat()
g = base.copy() if base is not None else self.graph.nx_graph_flat()
collectors = (n for n in self.graph.nodes if isinstance(self.graph.get_node(n), CollectInvocation))
for c in collectors:
g.remove_edges_from(list(g.in_edges(c)))
return g
def _get_node_iterators(self, node_id: str) -> list[str]:
def _get_node_iterators(self, node_id: str, it_graph: Optional[nx.DiGraph] = None) -> list[str]:
"""Gets iterators for a node"""
g = self._iterator_graph()
iterators = [n for n in nx.ancestors(g, node_id) if isinstance(self.graph.get_node(n), IterateInvocation)]
return iterators
g = it_graph or self._iterator_graph()
return [n for n in nx.ancestors(g, node_id) if isinstance(self.graph.get_node(n), IterateInvocation)]
def _prepare(self) -> Optional[str]:
def _prepare(self, base_g: Optional[nx.DiGraph] = None) -> Optional[str]:
# Get flattened source graph
g = self.graph.nx_graph_flat()
g = base_g or self.graph.nx_graph_flat()
# Find next node that:
# - was not already prepared
# - is not an iterate node whose inputs have not been executed
# - does not have an unexecuted iterate ancestor
sorted_nodes = nx.topological_sort(g)
def unprepared(n: str) -> bool:
return n not in self.source_prepared_mapping
def iter_inputs_ready(n: str) -> bool:
if not isinstance(self.graph.get_node(n), IterateInvocation):
return True
return all(u in self.executed for u, _ in g.in_edges(n))
def no_unexecuted_iter_ancestors(n: str) -> bool:
return not any(
isinstance(self.graph.get_node(a), IterateInvocation) and a not in self.executed
for a in nx.ancestors(g, n)
)
next_node_id = next(
(
n
for n in sorted_nodes
# exclude nodes that have already been prepared
if n not in self.source_prepared_mapping
# exclude iterate nodes whose inputs have not been executed
and not (
isinstance(self.graph.get_node(n), IterateInvocation) # `n` is an iterate node...
and not all((e[0] in self.executed for e in g.in_edges(n))) # ...that has unexecuted inputs
)
# exclude nodes who have unexecuted iterate ancestors
and not any(
(
isinstance(self.graph.get_node(a), IterateInvocation) # `a` is an iterate ancestor of `n`...
and a not in self.executed # ...that is not executed
for a in nx.ancestors(g, n) # for all ancestors `a` of node `n`
)
)
),
(n for n in sorted_nodes if unprepared(n) and iter_inputs_ready(n) and no_unexecuted_iter_ancestors(n)),
None,
)
@@ -1000,7 +1058,7 @@ class GraphExecutionState(BaseModel):
return None
# Get all parents of the next node
next_node_parents = [e[0] for e in g.in_edges(next_node_id)]
next_node_parents = [u for u, _ in g.in_edges(next_node_id)]
# Create execution nodes
next_node = self.graph.get_node(next_node_id)
@@ -1018,7 +1076,8 @@ class GraphExecutionState(BaseModel):
else: # Iterators or normal nodes
# Get all iterator combinations for this node
# Will produce a list of lists of prepared iterator nodes, from which results can be iterated
iterator_nodes = self._get_node_iterators(next_node_id)
it_g = self._iterator_graph(g)
iterator_nodes = self._get_node_iterators(next_node_id, it_g)
iterator_nodes_prepared = [list(self.source_prepared_mapping[n]) for n in iterator_nodes]
iterator_node_prepared_combinations = list(itertools.product(*iterator_nodes_prepared))
@@ -1066,45 +1125,41 @@ class GraphExecutionState(BaseModel):
)
def _get_next_node(self) -> Optional[BaseInvocation]:
"""Gets the deepest node that is ready to be executed"""
g = self.execution_graph.nx_graph()
"""Gets the next ready node: FIFO within class, drain class before switching."""
# 1) Continue draining the active class
if self._active_class:
q = self._ready_queues.get(self._active_class)
while q:
nid = q.popleft()
if nid not in self.executed:
return self.execution_graph.nodes[nid]
# emptied: release active class
self._active_class = None
# Perform a topological sort using depth-first search
topo_order = list(nx.dfs_postorder_nodes(g))
# Get all IterateInvocation nodes
iterate_nodes = [n for n in topo_order if isinstance(self.execution_graph.nodes[n], IterateInvocation)]
# Sort the IterateInvocation nodes based on their index attribute
iterate_nodes.sort(key=lambda x: self.execution_graph.nodes[x].index)
# Prioritize IterateInvocation nodes and their children
for iterate_node in iterate_nodes:
if iterate_node not in self.executed and all((e[0] in self.executed for e in g.in_edges(iterate_node))):
return self.execution_graph.nodes[iterate_node]
# Check the children of the IterateInvocation node
for child_node in nx.dfs_postorder_nodes(g, iterate_node):
if child_node not in self.executed and all((e[0] in self.executed for e in g.in_edges(child_node))):
return self.execution_graph.nodes[child_node]
# If no IterateInvocation node or its children are ready, return the first ready node in the topological order
for node in topo_order:
if node not in self.executed and all((e[0] in self.executed for e in g.in_edges(node))):
return self.execution_graph.nodes[node]
# If no node is found, return None
# 2) Pick next class by priority, then by class name
seen = set(self.ready_order)
for cls_name in self.ready_order:
q = self._ready_queues.get(cls_name)
if q:
self._active_class = cls_name
# recurse to drain newly set active class
return self._get_next_node()
for cls_name in sorted(k for k in self._ready_queues.keys() if k not in seen):
q = self._ready_queues[cls_name]
if q:
self._active_class = cls_name
return self._get_next_node()
return None
def _prepare_inputs(self, node: BaseInvocation):
input_edges = [e for e in self.execution_graph.edges if e.destination.node_id == node.id]
input_edges = self.execution_graph._get_input_edges(node.id)
# Inputs must be deep-copied, else if a node mutates the object, other nodes that get the same input
# will see the mutation.
if isinstance(node, CollectInvocation):
output_collection = [
copydeep(getattr(self.results[edge.source.node_id], edge.source.field))
for edge in input_edges
if edge.destination.field == "item"
if edge.destination.field == ITEM_FIELD
]
node.collection = output_collection
else:

View File

@@ -630,6 +630,21 @@ class UtilInterface(InvocationContextInterface):
is_canceled=self.is_canceled,
)
def flux2_step_callback(self, intermediate_state: PipelineIntermediateState) -> None:
"""
The step callback for FLUX.2 Klein models (32-channel VAE).
Args:
intermediate_state: The intermediate state of the diffusion pipeline.
"""
diffusion_step_callback(
signal_progress=self.signal_progress,
intermediate_state=intermediate_state,
base_model=BaseModelType.Flux2,
is_canceled=self.is_canceled,
)
def signal_progress(
self,
message: str,

View File

@@ -27,6 +27,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_21 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_22 import build_migration_22
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_23 import build_migration_23
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_24 import build_migration_24
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_25 import build_migration_25
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@@ -71,6 +72,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_22(app_config=config, logger=logger))
migrator.register_migration(build_migration_23(app_config=config, logger=logger))
migrator.register_migration(build_migration_24(app_config=config, logger=logger))
migrator.register_migration(build_migration_25(app_config=config, logger=logger))
migrator.run_migrations()
return db

View File

@@ -0,0 +1,61 @@
import json
import sqlite3
from logging import Logger
from typing import Any
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
from invokeai.backend.model_manager.taxonomy import ModelType, Qwen3VariantType
class Migration25Callback:
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
self._app_config = app_config
self._logger = logger
def __call__(self, cursor: sqlite3.Cursor) -> None:
cursor.execute("SELECT id, config FROM models;")
rows = cursor.fetchall()
migrated_count = 0
for model_id, config_json in rows:
try:
config_dict: dict[str, Any] = json.loads(config_json)
if config_dict.get("type") != ModelType.Qwen3Encoder.value:
continue
if "variant" in config_dict:
continue
config_dict["variant"] = Qwen3VariantType.Qwen3_4B.value
cursor.execute(
"UPDATE models SET config = ? WHERE id = ?;",
(json.dumps(config_dict), model_id),
)
migrated_count += 1
except json.JSONDecodeError as e:
self._logger.error("Invalid config JSON for model %s: %s", model_id, e)
raise
if migrated_count > 0:
self._logger.info(f"Migration complete: {migrated_count} Qwen3 encoder configs updated with variant field")
else:
self._logger.info("Migration complete: no Qwen3 encoder configs needed migration")
def build_migration_25(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
"""Builds the migration object for migrating from version 24 to version 25.
This migration adds the variant field to existing Qwen3 encoder models.
Models installed before the variant field was added will default to Qwen3_4B (for Z-Image compatibility).
"""
return Migration(
from_version=24,
to_version=25,
callback=Migration25Callback(app_config=app_config, logger=logger),
)
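# Illustrative before/after of a single stored row (a hedged sketch; the literal enum
# string values and the "name" field are placeholders, not taken from the diff):
#
#   before = {"type": "<ModelType.Qwen3Encoder.value>", "name": "my-qwen3-encoder"}
#   after  = {"type": "<ModelType.Qwen3Encoder.value>", "name": "my-qwen3-encoder",
#             "variant": "<Qwen3VariantType.Qwen3_4B.value>"}
#
# Rows with a different "type", or that already contain a "variant" key, are skipped
# and do not count toward migrated_count.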

View File

@@ -74,3 +74,11 @@ class WorkflowRecordsStorageBase(ABC):
def update_opened_at(self, workflow_id: str) -> None:
"""Open a workflow."""
pass
@abstractmethod
def get_all_tags(
self,
categories: Optional[list[WorkflowCategory]] = None,
) -> list[str]:
"""Gets all unique tags from workflows."""
pass

View File

@@ -332,6 +332,48 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
(workflow_id,),
)
def get_all_tags(
self,
categories: Optional[list[WorkflowCategory]] = None,
) -> list[str]:
with self._db.transaction() as cursor:
conditions: list[str] = []
params: list[str] = []
# Only get workflows that have tags
conditions.append("tags IS NOT NULL AND tags != ''")
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
conditions.append(f"category IN ({placeholders})")
params.extend([category.value for category in categories])
stmt = """--sql
SELECT DISTINCT tags
FROM workflow_library
"""
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
cursor.execute(stmt, params)
rows = cursor.fetchall()
# Parse comma-separated tags and collect unique tags
all_tags: set[str] = set()
for row in rows:
tags_value = row[0]
if tags_value and isinstance(tags_value, str):
# Tags are stored as comma-separated string
for tag in tags_value.split(","):
tag_stripped = tag.strip()
if tag_stripped:
all_tags.add(tag_stripped)
return sorted(all_tags)
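# Illustrative example (assumed data, not from the diff): if two stored rows have tags
# columns "portrait, sdxl" and "sdxl , upscale", the loop above yields the sorted
# unique list ["portrait", "sdxl", "upscale"]; duplicates and surrounding whitespace
# are dropped, and empty fragments from stray commas are skipped.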
def _sync_default_workflows(self) -> None:
"""Syncs default workflows to the database. Internal use only."""

View File

@@ -93,14 +93,60 @@ COGVIEW4_LATENT_RGB_FACTORS = [
[-0.00955853, -0.00980067, -0.00977842],
]
# FLUX.2 uses 32 latent channels.
# Factors from ComfyUI: https://github.com/Comfy-Org/ComfyUI/blob/main/comfy/latent_formats.py
FLUX2_LATENT_RGB_FACTORS = [
# R G B
[0.0058, 0.0113, 0.0073],
[0.0495, 0.0443, 0.0836],
[-0.0099, 0.0096, 0.0644],
[0.2144, 0.3009, 0.3652],
[0.0166, -0.0039, -0.0054],
[0.0157, 0.0103, -0.0160],
[-0.0398, 0.0902, -0.0235],
[-0.0052, 0.0095, 0.0109],
[-0.3527, -0.2712, -0.1666],
[-0.0301, -0.0356, -0.0180],
[-0.0107, 0.0078, 0.0013],
[0.0746, 0.0090, -0.0941],
[0.0156, 0.0169, 0.0070],
[-0.0034, -0.0040, -0.0114],
[0.0032, 0.0181, 0.0080],
[-0.0939, -0.0008, 0.0186],
[0.0018, 0.0043, 0.0104],
[0.0284, 0.0056, -0.0127],
[-0.0024, -0.0022, -0.0030],
[0.1207, -0.0026, 0.0065],
[0.0128, 0.0101, 0.0142],
[0.0137, -0.0072, -0.0007],
[0.0095, 0.0092, -0.0059],
[0.0000, -0.0077, -0.0049],
[-0.0465, -0.0204, -0.0312],
[0.0095, 0.0012, -0.0066],
[0.0290, -0.0034, 0.0025],
[0.0220, 0.0169, -0.0048],
[-0.0332, -0.0457, -0.0468],
[-0.0085, 0.0389, 0.0609],
[-0.0076, 0.0003, -0.0043],
[-0.0111, -0.0460, -0.0614],
]
FLUX2_LATENT_RGB_BIAS = [-0.0329, -0.0718, -0.0851]
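# Illustrative shape walk-through of the projection performed below (a sketch; the exact
# latent spatial size depends on the VAE): a FLUX.2 preview latent of shape (32, h, w)
# is permuted to (h, w, 32), matrix-multiplied by the (32, 3) factor table above to give
# an (h, w, 3) RGB estimate, and the 3-element bias is broadcast-added per pixel.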
def sample_to_lowres_estimated_image(
samples: torch.Tensor, latent_rgb_factors: torch.Tensor, smooth_matrix: Optional[torch.Tensor] = None
samples: torch.Tensor,
latent_rgb_factors: torch.Tensor,
smooth_matrix: Optional[torch.Tensor] = None,
latent_rgb_bias: Optional[torch.Tensor] = None,
):
if samples.dim() == 4:
samples = samples[0]
latent_image = samples.permute(1, 2, 0) @ latent_rgb_factors
if latent_rgb_bias is not None:
latent_image = latent_image + latent_rgb_bias
if smooth_matrix is not None:
latent_image = latent_image.unsqueeze(0).permute(3, 0, 1, 2)
latent_image = torch.nn.functional.conv2d(latent_image, smooth_matrix.reshape((1, 1, 3, 3)), padding=1)
@@ -153,6 +199,7 @@ def diffusion_step_callback(
sample = intermediate_state.latents
smooth_matrix: list[list[float]] | None = None
latent_rgb_bias: list[float] | None = None
if base_model in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
latent_rgb_factors = SD1_5_LATENT_RGB_FACTORS
elif base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]:
@@ -164,6 +211,12 @@ def diffusion_step_callback(
latent_rgb_factors = COGVIEW4_LATENT_RGB_FACTORS
elif base_model == BaseModelType.Flux:
latent_rgb_factors = FLUX_LATENT_RGB_FACTORS
elif base_model == BaseModelType.Flux2:
latent_rgb_factors = FLUX2_LATENT_RGB_FACTORS
latent_rgb_bias = FLUX2_LATENT_RGB_BIAS
elif base_model == BaseModelType.ZImage:
# Z-Image uses FLUX-compatible VAE with 16 latent channels
latent_rgb_factors = FLUX_LATENT_RGB_FACTORS
else:
raise ValueError(f"Unsupported base model: {base_model}")
@@ -171,8 +224,14 @@ def diffusion_step_callback(
smooth_matrix_torch = (
torch.tensor(smooth_matrix, dtype=sample.dtype, device=sample.device) if smooth_matrix else None
)
latent_rgb_bias_torch = (
torch.tensor(latent_rgb_bias, dtype=sample.dtype, device=sample.device) if latent_rgb_bias else None
)
image = sample_to_lowres_estimated_image(
samples=sample, latent_rgb_factors=latent_rgb_factors_torch, smooth_matrix=smooth_matrix_torch
samples=sample,
latent_rgb_factors=latent_rgb_factors_torch,
smooth_matrix=smooth_matrix_torch,
latent_rgb_bias=latent_rgb_bias_torch,
)
width = image.width * 8

View File

@@ -1,10 +1,13 @@
import inspect
import math
from typing import Callable
import torch
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from tqdm import tqdm
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs
from invokeai.backend.flux.extensions.dype_extension import DyPEExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
@@ -35,149 +38,366 @@ def denoise(
# extra img tokens (sequence-wise) - for Kontext conditioning
img_cond_seq: torch.Tensor | None = None,
img_cond_seq_ids: torch.Tensor | None = None,
# DyPE extension for high-resolution generation
dype_extension: DyPEExtension | None = None,
# Optional scheduler for alternative sampling methods
scheduler: SchedulerMixin | None = None,
):
# step 0 is the initial state
total_steps = len(timesteps) - 1
step_callback(
PipelineIntermediateState(
step=0,
order=1,
total_steps=total_steps,
timestep=int(timesteps[0]),
latents=img,
),
)
# Determine if we're using a diffusers scheduler or the built-in Euler method
use_scheduler = scheduler is not None
if use_scheduler:
# Initialize scheduler with timesteps
# The timesteps list contains values in [0, 1] range (sigmas)
# LCM should use num_inference_steps (it has its own sigma schedule),
# while other schedulers can use custom sigmas if supported
is_lcm = scheduler.__class__.__name__ == "FlowMatchLCMScheduler"
set_timesteps_sig = inspect.signature(scheduler.set_timesteps)
if not is_lcm and "sigmas" in set_timesteps_sig.parameters:
# Scheduler supports custom sigmas - use InvokeAI's time-shifted schedule
scheduler.set_timesteps(sigmas=timesteps, device=img.device)
else:
# LCM or scheduler doesn't support custom sigmas - use num_inference_steps
# The schedule will be computed by the scheduler itself
num_inference_steps = len(timesteps) - 1
scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=img.device)
# For schedulers like Heun, the number of actual steps may differ
# (Heun doubles timesteps internally)
num_scheduler_steps = len(scheduler.timesteps)
# For user-facing step count, use the original number of denoising steps
total_steps = len(timesteps) - 1
else:
total_steps = len(timesteps) - 1
num_scheduler_steps = total_steps
# guidance_vec is ignored for schnell.
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
# Store original sequence length for slicing predictions
original_seq_len = img.shape[1]
for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
# DyPE: Patch model with DyPE-aware position embedder
dype_embedder = None
original_pe_embedder = None
if dype_extension is not None:
dype_embedder, original_pe_embedder = dype_extension.patch_model(model)
# Run ControlNet models.
controlnet_residuals: list[ControlNetFluxOutput] = []
for controlnet_extension in controlnet_extensions:
controlnet_residuals.append(
controlnet_extension.run_controlnet(
timestep_index=step_index,
total_num_timesteps=total_steps,
img=img,
img_ids=img_ids,
try:
# Track the actual step for user-facing progress (accounts for Heun's double steps)
user_step = 0
if use_scheduler:
# Use diffusers scheduler for stepping
# Use tqdm with total_steps (user-facing steps) not num_scheduler_steps (internal steps)
# This ensures progress bar shows 1/8, 2/8, etc. even when scheduler uses more internal steps
pbar = tqdm(total=total_steps, desc="Denoising")
for step_index in range(num_scheduler_steps):
timestep = scheduler.timesteps[step_index]
# Convert scheduler timestep (0-1000) to normalized (0-1) for the model
t_curr = timestep.item() / scheduler.config.num_train_timesteps
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
# DyPE: Update step state for timestep-dependent scaling
if dype_extension is not None and dype_embedder is not None:
dype_extension.update_step_state(
embedder=dype_embedder,
timestep=t_curr,
timestep_index=user_step,
total_steps=total_steps,
)
# For Heun scheduler, track if we're in first or second order step
is_heun = hasattr(scheduler, "state_in_first_order")
in_first_order = scheduler.state_in_first_order if is_heun else True
# Run ControlNet models
controlnet_residuals: list[ControlNetFluxOutput] = []
for controlnet_extension in controlnet_extensions:
controlnet_residuals.append(
controlnet_extension.run_controlnet(
timestep_index=user_step,
total_num_timesteps=total_steps,
img=img,
img_ids=img_ids,
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
timesteps=t_vec,
guidance=guidance_vec,
)
)
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
# Prepare input for model
img_input = img
img_input_ids = img_ids
if img_cond is not None:
img_input = torch.cat((img_input, img_cond), dim=-1)
if img_cond_seq is not None:
assert img_cond_seq_ids is not None
img_input = torch.cat((img_input, img_cond_seq), dim=1)
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
pred = model(
img=img_input,
img_ids=img_input_ids,
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
timesteps=t_vec,
guidance=guidance_vec,
timestep_index=user_step,
total_num_timesteps=total_steps,
controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
ip_adapter_extensions=pos_ip_adapter_extensions,
regional_prompting_extension=pos_regional_prompting_extension,
)
)
# Merge the ControlNet residuals from multiple ControlNets.
# TODO(ryand): We may want to calculate the sum just-in-time to keep peak memory low. Keep in mind, that the
# controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
# tensors. Calculating the sum materializes each tensor into its own instance.
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
if img_cond_seq is not None:
pred = pred[:, :original_seq_len]
# Prepare input for model - concatenate fresh each step
img_input = img
img_input_ids = img_ids
# Get CFG scale for current user step
step_cfg_scale = cfg_scale[min(user_step, len(cfg_scale) - 1)]
# Add channel-wise conditioning (for ControlNet, FLUX Fill, etc.)
if img_cond is not None:
img_input = torch.cat((img_input, img_cond), dim=-1)
if not math.isclose(step_cfg_scale, 1.0):
if neg_regional_prompting_extension is None:
raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
# Add sequence-wise conditioning (for Kontext)
if img_cond_seq is not None:
assert img_cond_seq_ids is not None, (
"You need to provide either both or neither of the sequence conditioning"
)
img_input = torch.cat((img_input, img_cond_seq), dim=1)
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
neg_img_input = img
neg_img_input_ids = img_ids
pred = model(
img=img_input,
img_ids=img_input_ids,
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
timesteps=t_vec,
guidance=guidance_vec,
timestep_index=step_index,
total_num_timesteps=total_steps,
controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
ip_adapter_extensions=pos_ip_adapter_extensions,
regional_prompting_extension=pos_regional_prompting_extension,
)
if img_cond is not None:
neg_img_input = torch.cat((neg_img_input, img_cond), dim=-1)
# Slice prediction to only include the main image tokens
if img_cond_seq is not None:
pred = pred[:, :original_seq_len]
if img_cond_seq is not None:
neg_img_input = torch.cat((neg_img_input, img_cond_seq), dim=1)
neg_img_input_ids = torch.cat((neg_img_input_ids, img_cond_seq_ids), dim=1)
step_cfg_scale = cfg_scale[step_index]
neg_pred = model(
img=neg_img_input,
img_ids=neg_img_input_ids,
txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
timesteps=t_vec,
guidance=guidance_vec,
timestep_index=user_step,
total_num_timesteps=total_steps,
controlnet_double_block_residuals=None,
controlnet_single_block_residuals=None,
ip_adapter_extensions=neg_ip_adapter_extensions,
regional_prompting_extension=neg_regional_prompting_extension,
)
# If step_cfg_scale is 1.0, then we don't need to run the negative prediction.
if not math.isclose(step_cfg_scale, 1.0):
# TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance
# on systems with sufficient VRAM.
if img_cond_seq is not None:
neg_pred = neg_pred[:, :original_seq_len]
pred = neg_pred + step_cfg_scale * (pred - neg_pred)
if neg_regional_prompting_extension is None:
raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
# Use scheduler.step() for the update
step_output = scheduler.step(model_output=pred, timestep=timestep, sample=img)
img = step_output.prev_sample
# For negative prediction with Kontext, we need to include the reference images
# to maintain consistency between positive and negative passes. Without this,
# CFG would create artifacts as the attention mechanism would see different
# spatial structures in each pass
neg_img_input = img
neg_img_input_ids = img_ids
# Get t_prev for inpainting (next sigma value)
if step_index + 1 < len(scheduler.sigmas):
t_prev = scheduler.sigmas[step_index + 1].item()
else:
t_prev = 0.0
# Add channel-wise conditioning for negative pass if present
if inpaint_extension is not None:
img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
# For Heun, only increment user step after second-order step completes
if is_heun:
if not in_first_order:
# Second order step completed
user_step += 1
# Only call step_callback if we haven't exceeded total_steps
if user_step <= total_steps:
pbar.update(1)
preview_img = img - t_curr * pred
if inpaint_extension is not None:
preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(
preview_img, 0.0
)
step_callback(
PipelineIntermediateState(
step=user_step,
order=2,
total_steps=total_steps,
timestep=int(t_curr * 1000),
latents=preview_img,
),
)
else:
# For LCM and other first-order schedulers
user_step += 1
# Only call step_callback if we haven't exceeded total_steps
# (LCM scheduler may have more internal steps than user-facing steps)
if user_step <= total_steps:
pbar.update(1)
preview_img = img - t_curr * pred
if inpaint_extension is not None:
preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(
preview_img, 0.0
)
step_callback(
PipelineIntermediateState(
step=user_step,
order=1,
total_steps=total_steps,
timestep=int(t_curr * 1000),
latents=preview_img,
),
)
pbar.close()
return img
# Original Euler implementation (when scheduler is None)
for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
# DyPE: Update step state for timestep-dependent scaling
if dype_extension is not None and dype_embedder is not None:
dype_extension.update_step_state(
embedder=dype_embedder,
timestep=t_curr,
timestep_index=step_index,
total_steps=total_steps,
)
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
# Run ControlNet models.
controlnet_residuals: list[ControlNetFluxOutput] = []
for controlnet_extension in controlnet_extensions:
controlnet_residuals.append(
controlnet_extension.run_controlnet(
timestep_index=step_index,
total_num_timesteps=total_steps,
img=img,
img_ids=img_ids,
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
timesteps=t_vec,
guidance=guidance_vec,
)
)
# Merge the ControlNet residuals from multiple ControlNets.
# TODO(ryand): We may want to calculate the sum just-in-time to keep peak memory low. Keep in mind, that the
# controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
# tensors. Calculating the sum materializes each tensor into its own instance.
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
# Prepare input for model - concatenate fresh each step
img_input = img
img_input_ids = img_ids
# Add channel-wise conditioning (for ControlNet, FLUX Fill, etc.)
if img_cond is not None:
neg_img_input = torch.cat((neg_img_input, img_cond), dim=-1)
img_input = torch.cat((img_input, img_cond), dim=-1)
# Add sequence-wise conditioning (Kontext) for negative pass
# This ensures reference images are processed consistently
# Add sequence-wise conditioning (for Kontext)
if img_cond_seq is not None:
neg_img_input = torch.cat((neg_img_input, img_cond_seq), dim=1)
neg_img_input_ids = torch.cat((neg_img_input_ids, img_cond_seq_ids), dim=1)
assert img_cond_seq_ids is not None, (
"You need to provide either both or neither of the sequence conditioning"
)
img_input = torch.cat((img_input, img_cond_seq), dim=1)
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
neg_pred = model(
img=neg_img_input,
img_ids=neg_img_input_ids,
txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
pred = model(
img=img_input,
img_ids=img_input_ids,
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
timesteps=t_vec,
guidance=guidance_vec,
timestep_index=step_index,
total_num_timesteps=total_steps,
controlnet_double_block_residuals=None,
controlnet_single_block_residuals=None,
ip_adapter_extensions=neg_ip_adapter_extensions,
regional_prompting_extension=neg_regional_prompting_extension,
controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
ip_adapter_extensions=pos_ip_adapter_extensions,
regional_prompting_extension=pos_regional_prompting_extension,
)
# Slice negative prediction to match main image tokens
# Slice prediction to only include the main image tokens
if img_cond_seq is not None:
neg_pred = neg_pred[:, :original_seq_len]
pred = neg_pred + step_cfg_scale * (pred - neg_pred)
pred = pred[:, :original_seq_len]
preview_img = img - t_curr * pred
img = img + (t_prev - t_curr) * pred
step_cfg_scale = cfg_scale[step_index]
if inpaint_extension is not None:
img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0)
# If step_cfg_scale is 1.0, then we don't need to run the negative prediction.
if not math.isclose(step_cfg_scale, 1.0):
# TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance
# on systems with sufficient VRAM.
step_callback(
PipelineIntermediateState(
step=step_index + 1,
order=1,
total_steps=total_steps,
timestep=int(t_curr),
latents=preview_img,
),
)
if neg_regional_prompting_extension is None:
raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
return img
# For negative prediction with Kontext, we need to include the reference images
# to maintain consistency between positive and negative passes. Without this,
# CFG would create artifacts as the attention mechanism would see different
# spatial structures in each pass
neg_img_input = img
neg_img_input_ids = img_ids
# Add channel-wise conditioning for negative pass if present
if img_cond is not None:
neg_img_input = torch.cat((neg_img_input, img_cond), dim=-1)
# Add sequence-wise conditioning (Kontext) for negative pass
# This ensures reference images are processed consistently
if img_cond_seq is not None:
neg_img_input = torch.cat((neg_img_input, img_cond_seq), dim=1)
neg_img_input_ids = torch.cat((neg_img_input_ids, img_cond_seq_ids), dim=1)
neg_pred = model(
img=neg_img_input,
img_ids=neg_img_input_ids,
txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
timesteps=t_vec,
guidance=guidance_vec,
timestep_index=step_index,
total_num_timesteps=total_steps,
controlnet_double_block_residuals=None,
controlnet_single_block_residuals=None,
ip_adapter_extensions=neg_ip_adapter_extensions,
regional_prompting_extension=neg_regional_prompting_extension,
)
# Slice negative prediction to match main image tokens
if img_cond_seq is not None:
neg_pred = neg_pred[:, :original_seq_len]
pred = neg_pred + step_cfg_scale * (pred - neg_pred)
preview_img = img - t_curr * pred
img = img + (t_prev - t_curr) * pred
if inpaint_extension is not None:
img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0)
step_callback(
PipelineIntermediateState(
step=step_index + 1,
order=1,
total_steps=total_steps,
timestep=int(t_curr),
latents=preview_img,
),
)
return img
finally:
# DyPE: Restore original position embedder
if original_pe_embedder is not None:
DyPEExtension.restore_model(model, original_pe_embedder)

View File

@@ -0,0 +1,35 @@
"""Dynamic Position Extrapolation (DyPE) for FLUX models.
DyPE enables high-resolution image generation (4K+) with pretrained FLUX models
by dynamically scaling RoPE position embeddings during the denoising process.
Based on: https://github.com/wildminder/ComfyUI-DyPE
"""
from invokeai.backend.flux.dype.base import DyPEConfig
from invokeai.backend.flux.dype.embed import DyPEEmbedND
from invokeai.backend.flux.dype.presets import (
DYPE_PRESET_4K,
DYPE_PRESET_AREA,
DYPE_PRESET_AUTO,
DYPE_PRESET_LABELS,
DYPE_PRESET_MANUAL,
DYPE_PRESET_OFF,
DyPEPreset,
get_dype_config_for_area,
get_dype_config_for_resolution,
)
__all__ = [
"DyPEConfig",
"DyPEEmbedND",
"DyPEPreset",
"DYPE_PRESET_OFF",
"DYPE_PRESET_MANUAL",
"DYPE_PRESET_AUTO",
"DYPE_PRESET_AREA",
"DYPE_PRESET_4K",
"DYPE_PRESET_LABELS",
"get_dype_config_for_area",
"get_dype_config_for_resolution",
]

View File

@@ -0,0 +1,260 @@
"""DyPE base configuration and utilities."""
import math
from dataclasses import dataclass
from typing import Literal
import torch
from torch import Tensor
@dataclass
class DyPEConfig:
"""Configuration for Dynamic Position Extrapolation."""
enable_dype: bool = True
base_resolution: int = 1024 # Native training resolution
method: Literal["vision_yarn", "yarn", "ntk", "base"] = "vision_yarn"
dype_scale: float = 2.0 # Magnitude λs (0.0-8.0)
dype_exponent: float = 2.0 # Decay speed λt (0.0-1000.0)
dype_start_sigma: float = 1.0 # When DyPE decay starts
def get_mscale(scale: float, mscale_factor: float = 1.0) -> float:
"""Calculate magnitude scaling factor.
Args:
scale: The resolution scaling factor
mscale_factor: Adjustment factor for the scaling
Returns:
The magnitude scaling factor
"""
if scale <= 1.0:
return 1.0
return mscale_factor * math.log(scale) + 1.0
def get_timestep_mscale(
scale: float,
current_sigma: float,
dype_scale: float,
dype_exponent: float,
dype_start_sigma: float,
) -> float:
"""Calculate timestep-dependent magnitude scaling.
The key insight of DyPE: early steps focus on low frequencies (global structure),
late steps on high frequencies (details). This function modulates the scaling
based on the current timestep/sigma.
Args:
scale: Resolution scaling factor
current_sigma: Current noise level (1.0 = full noise, 0.0 = clean)
dype_scale: DyPE magnitude (λs)
dype_exponent: DyPE decay speed (λt)
dype_start_sigma: Sigma threshold to start decay
Returns:
Timestep-modulated scaling factor
"""
if scale <= 1.0:
return 1.0
# Normalize sigma to [0, 1] range relative to start_sigma
if current_sigma >= dype_start_sigma:
t_normalized = 1.0
else:
t_normalized = current_sigma / dype_start_sigma
# Apply exponential decay: stronger extrapolation early, weaker late
# decay = exp(-λt * (1 - t)) where t=1 is early (high sigma), t=0 is late
decay = math.exp(-dype_exponent * (1.0 - t_normalized))
# Base mscale from resolution
base_mscale = get_mscale(scale)
# Interpolate between base_mscale and 1.0 based on decay and dype_scale
# When decay=1 (early): use scaled value
# When decay=0 (late): use base value
scaled_mscale = 1.0 + (base_mscale - 1.0) * dype_scale * decay
return scaled_mscale
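# Worked example (illustrative numbers only): with scale=2.0, dype_scale=2.0,
# dype_exponent=2.0 and dype_start_sigma=1.0, base_mscale = ln(2) + 1 ≈ 1.693, so
#   at sigma=1.0  -> t_normalized=1.0,  decay=exp(0)=1.0,      mscale ≈ 1 + 0.693*2*1.0   ≈ 2.39
#   at sigma=0.25 -> t_normalized=0.25, decay=exp(-1.5)≈0.223, mscale ≈ 1 + 0.693*2*0.223 ≈ 1.31
# i.e. the extrapolation strength decays from ~2.4 toward 1.0 as denoising progresses.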
def compute_vision_yarn_freqs(
pos: Tensor,
dim: int,
theta: int,
scale_h: float,
scale_w: float,
current_sigma: float,
dype_config: DyPEConfig,
) -> tuple[Tensor, Tensor]:
"""Compute RoPE frequencies using NTK-aware scaling for high-resolution.
This method extends FLUX's position encoding to handle resolutions beyond
the 1024px training resolution by scaling the base frequency (theta).
The NTK-aware approach smoothly interpolates frequencies to cover larger
position ranges without breaking the attention patterns.
DyPE (Dynamic Position Extrapolation) modulates the NTK scaling based on
the current timestep - stronger extrapolation in early steps (global structure),
weaker in late steps (fine details).
Args:
pos: Position tensor
dim: Embedding dimension
theta: RoPE base frequency
scale_h: Height scaling factor
scale_w: Width scaling factor
current_sigma: Current noise level (1.0 = full noise, 0.0 = clean)
dype_config: DyPE configuration
Returns:
Tuple of (cos, sin) frequency tensors
"""
assert dim % 2 == 0
# Use the larger scale for NTK calculation
scale = max(scale_h, scale_w)
device = pos.device
dtype = torch.float64 if device.type != "mps" else torch.float32
# NTK-aware theta scaling: extends position coverage for high-res
# Formula: theta_scaled = theta * scale^(dim/(dim-2))
# This increases the wavelength of position encodings proportionally
if scale > 1.0:
ntk_alpha = scale ** (dim / (dim - 2))
# Apply timestep-dependent DyPE modulation
# mscale controls how strongly we apply the NTK extrapolation
# Early steps (high sigma): stronger extrapolation for global structure
# Late steps (low sigma): weaker extrapolation for fine details
mscale = get_timestep_mscale(
scale=scale,
current_sigma=current_sigma,
dype_scale=dype_config.dype_scale,
dype_exponent=dype_config.dype_exponent,
dype_start_sigma=dype_config.dype_start_sigma,
)
# Modulate NTK alpha by mscale
# When mscale > 1: interpolate towards stronger extrapolation
# When mscale = 1: use base NTK alpha
modulated_alpha = 1.0 + (ntk_alpha - 1.0) * mscale
scaled_theta = theta * modulated_alpha
else:
scaled_theta = theta
# Standard RoPE frequency computation
freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
freqs = 1.0 / (scaled_theta**freq_seq)
# Compute angles = position * frequency
angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs)
cos = torch.cos(angles)
sin = torch.sin(angles)
return cos.to(pos.dtype), sin.to(pos.dtype)
def compute_yarn_freqs(
pos: Tensor,
dim: int,
theta: int,
scale: float,
current_sigma: float,
dype_config: DyPEConfig,
) -> tuple[Tensor, Tensor]:
"""Compute RoPE frequencies using YARN/NTK method.
Uses NTK-aware theta scaling for high-resolution support with
timestep-dependent DyPE modulation.
Args:
pos: Position tensor
dim: Embedding dimension
theta: RoPE base frequency
scale: Uniform scaling factor
current_sigma: Current noise level (1.0 = full noise, 0.0 = clean)
dype_config: DyPE configuration
Returns:
Tuple of (cos, sin) frequency tensors
"""
assert dim % 2 == 0
device = pos.device
dtype = torch.float64 if device.type != "mps" else torch.float32
# NTK-aware theta scaling with DyPE modulation
if scale > 1.0:
ntk_alpha = scale ** (dim / (dim - 2))
# Apply timestep-dependent DyPE modulation
mscale = get_timestep_mscale(
scale=scale,
current_sigma=current_sigma,
dype_scale=dype_config.dype_scale,
dype_exponent=dype_config.dype_exponent,
dype_start_sigma=dype_config.dype_start_sigma,
)
# Modulate NTK alpha by mscale
modulated_alpha = 1.0 + (ntk_alpha - 1.0) * mscale
scaled_theta = theta * modulated_alpha
else:
scaled_theta = theta
freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
freqs = 1.0 / (scaled_theta**freq_seq)
angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs)
cos = torch.cos(angles)
sin = torch.sin(angles)
return cos.to(pos.dtype), sin.to(pos.dtype)
def compute_ntk_freqs(
pos: Tensor,
dim: int,
theta: int,
scale: float,
) -> tuple[Tensor, Tensor]:
"""Compute RoPE frequencies using NTK method.
Neural Tangent Kernel approach - continuous frequency scaling without
timestep dependency.
Args:
pos: Position tensor
dim: Embedding dimension
theta: RoPE base frequency
scale: Scaling factor
Returns:
Tuple of (cos, sin) frequency tensors
"""
assert dim % 2 == 0
device = pos.device
dtype = torch.float64 if device.type != "mps" else torch.float32
# NTK scaling
scaled_theta = theta * (scale ** (dim / (dim - 2)))
freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
freqs = 1.0 / (scaled_theta**freq_seq)
angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs)
cos = torch.cos(angles)
sin = torch.sin(angles)
return cos.to(pos.dtype), sin.to(pos.dtype)
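# Worked example for the NTK theta scaling above (illustrative numbers only): with
# dim=56 (one of FLUX's spatial axes), theta=10_000 and scale=2.0, the exponent is
# 56 / 54 ≈ 1.037, so scaled_theta = 10_000 * 2**1.037 ≈ 20_520, stretching every RoPE
# wavelength to cover roughly twice the original position range.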

View File

@@ -0,0 +1,116 @@
"""DyPE-enhanced position embedding module."""
import torch
from torch import Tensor, nn
from invokeai.backend.flux.dype.base import DyPEConfig
from invokeai.backend.flux.dype.rope import rope_dype
class DyPEEmbedND(nn.Module):
"""N-dimensional position embedding with DyPE support.
This class replaces the standard EmbedND from FLUX with a DyPE-aware version
that dynamically scales position embeddings based on resolution and timestep.
The key difference from EmbedND:
- Maintains step state (current_sigma, target dimensions)
- Uses rope_dype() instead of rope() for frequency computation
- Applies timestep-dependent scaling for better high-resolution generation
"""
def __init__(
self,
dim: int,
theta: int,
axes_dim: list[int],
dype_config: DyPEConfig,
):
"""Initialize DyPE position embedder.
Args:
dim: Total embedding dimension (sum of axes_dim)
theta: RoPE base frequency
axes_dim: Dimension allocation per axis (e.g., [16, 56, 56] for FLUX)
dype_config: DyPE configuration
"""
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
self.dype_config = dype_config
# Step state - updated before each denoising step
self._current_sigma: float = 1.0
self._target_height: int = 1024
self._target_width: int = 1024
def set_step_state(self, sigma: float, height: int, width: int) -> None:
"""Update the step state before each denoising step.
This method should be called by the DyPE extension before each step
to update the current noise level and target dimensions.
Args:
sigma: Current noise level (timestep value, 1.0 = full noise)
height: Target image height in pixels
width: Target image width in pixels
"""
self._current_sigma = sigma
self._target_height = height
self._target_width = width
def forward(self, ids: Tensor) -> Tensor:
"""Compute position embeddings with DyPE scaling.
Args:
ids: Position indices tensor with shape (batch, seq_len, n_axes)
For FLUX: n_axes=3 (time/channel, height, width)
Returns:
Position embedding tensor with shape (batch, 1, seq_len, dim)
"""
n_axes = ids.shape[-1]
# Compute RoPE for each axis with DyPE scaling
embeddings = []
for i in range(n_axes):
axis_emb = rope_dype(
pos=ids[..., i],
dim=self.axes_dim[i],
theta=self.theta,
current_sigma=self._current_sigma,
target_height=self._target_height,
target_width=self._target_width,
dype_config=self.dype_config,
)
embeddings.append(axis_emb)
# Concatenate embeddings from all axes
emb = torch.cat(embeddings, dim=-3)
return emb.unsqueeze(1)
@classmethod
def from_embednd(
cls,
embed_nd: nn.Module,
dype_config: DyPEConfig,
) -> "DyPEEmbedND":
"""Create a DyPEEmbedND from an existing EmbedND.
This is a convenience method for patching an existing FLUX model.
Args:
embed_nd: Original EmbedND module from FLUX
dype_config: DyPE configuration
Returns:
New DyPEEmbedND with same parameters
"""
return cls(
dim=embed_nd.dim,
theta=embed_nd.theta,
axes_dim=embed_nd.axes_dim,
dype_config=dype_config,
)

View File

@@ -0,0 +1,203 @@
"""DyPE presets and automatic configuration."""
import math
from dataclasses import dataclass
from typing import Literal
from invokeai.backend.flux.dype.base import DyPEConfig
# DyPE preset type - using Literal for proper frontend dropdown support
DyPEPreset = Literal["off", "manual", "auto", "area", "4k"]
# Constants for preset values
DYPE_PRESET_OFF: DyPEPreset = "off"
DYPE_PRESET_MANUAL: DyPEPreset = "manual"
DYPE_PRESET_AUTO: DyPEPreset = "auto"
DYPE_PRESET_AREA: DyPEPreset = "area"
DYPE_PRESET_4K: DyPEPreset = "4k"
# Human-readable labels for the UI
DYPE_PRESET_LABELS: dict[str, str] = {
"off": "Off",
"manual": "Manual",
"auto": "Auto (>1536px)",
"area": "Area (auto)",
"4k": "4K Optimized",
}
@dataclass
class DyPEPresetConfig:
"""Preset configuration values."""
base_resolution: int
method: str
dype_scale: float
dype_exponent: float
dype_start_sigma: float
# Predefined preset configurations
DYPE_PRESETS: dict[DyPEPreset, DyPEPresetConfig] = {
DYPE_PRESET_4K: DyPEPresetConfig(
base_resolution=1024,
method="vision_yarn",
dype_scale=2.0,
dype_exponent=2.0,
dype_start_sigma=1.0,
),
}
def get_dype_config_for_resolution(
width: int,
height: int,
base_resolution: int = 1024,
activation_threshold: int = 1536,
) -> DyPEConfig | None:
"""Automatically determine DyPE config based on target resolution.
FLUX can handle resolutions up to ~1.5x natively without significant artifacts.
DyPE is only activated when the resolution exceeds the activation threshold.
Args:
width: Target image width in pixels
height: Target image height in pixels
base_resolution: Native training resolution of the model (for scale calculation)
activation_threshold: Resolution threshold above which DyPE is activated
Returns:
DyPEConfig if DyPE should be enabled, None otherwise
"""
max_dim = max(width, height)
if max_dim <= activation_threshold:
return None # FLUX can handle this natively
# Calculate scaling factor based on base_resolution
scale = max_dim / base_resolution
# Dynamic parameters based on scaling
# Higher resolution = higher dype_scale, capped at 8.0
dynamic_dype_scale = min(2.0 * scale, 8.0)
return DyPEConfig(
enable_dype=True,
base_resolution=base_resolution,
method="vision_yarn",
dype_scale=dynamic_dype_scale,
dype_exponent=2.0,
dype_start_sigma=1.0,
)
def get_dype_config_for_area(
width: int,
height: int,
base_resolution: int = 1024,
) -> DyPEConfig | None:
"""Automatically determine DyPE config based on target area.
Uses sqrt(area/base_area) as an effective side-length ratio.
DyPE is enabled only when target area exceeds base area.
Returns:
DyPEConfig if DyPE should be enabled, None otherwise
"""
area = width * height
base_area = base_resolution**2
if area <= base_area:
return None
area_ratio = area / base_area
effective_side_ratio = math.sqrt(area_ratio) # 1.0 at base, 2.0 at 2K (if base is 1K)
# Strength: 0 at base area, 8 at sat_area, clamped thereafter.
sat_area = 2027520 # Determined experimentally: the area at which a vertical-line artifact starts to appear
sat_side_ratio = math.sqrt(sat_area / base_area)
dynamic_dype_scale = 8.0 * (effective_side_ratio - 1.0) / (sat_side_ratio - 1.0)
dynamic_dype_scale = max(0.0, min(dynamic_dype_scale, 8.0))
# Continuous exponent schedule:
# r=1 -> 0.5, r=2 -> 1.0, r=4 -> 2.0 (exact), smoothly varying in between.
x = math.log2(effective_side_ratio)
dype_exponent = 0.25 * (x**2) + 0.25 * x + 0.5
dype_exponent = max(0.5, min(dype_exponent, 2.0))
return DyPEConfig(
enable_dype=True,
base_resolution=base_resolution,
method="vision_yarn",
dype_scale=dynamic_dype_scale,
dype_exponent=dype_exponent,
dype_start_sigma=1.0,
)
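# Worked example (illustrative resolution): for a 2048x2048 target with base_resolution=1024,
# area_ratio = 4.0 and effective_side_ratio = 2.0, so x = log2(2.0) = 1.0 and
# dype_exponent = 0.25 + 0.25 + 0.5 = 1.0; the raw dype_scale
# 8.0 * (2.0 - 1.0) / (sqrt(2027520 / 1048576) - 1.0) ≈ 20.5 is clamped to 8.0.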
def get_dype_config_from_preset(
preset: DyPEPreset,
width: int,
height: int,
custom_scale: float | None = None,
custom_exponent: float | None = None,
) -> DyPEConfig | None:
"""Get DyPE configuration from a preset or custom values.
Args:
preset: The DyPE preset to use
width: Target image width
height: Target image height
custom_scale: Optional custom dype_scale (only used with 'manual' preset)
custom_exponent: Optional custom dype_exponent (only used with 'manual' preset)
Returns:
DyPEConfig if DyPE should be enabled, None otherwise
"""
if preset == DYPE_PRESET_OFF:
return None
if preset == DYPE_PRESET_MANUAL:
# Manual mode - custom values can override defaults
max_dim = max(width, height)
scale = max_dim / 1024
dynamic_dype_scale = min(2.0 * scale, 8.0)
return DyPEConfig(
enable_dype=True,
base_resolution=1024,
method="vision_yarn",
dype_scale=custom_scale if custom_scale is not None else dynamic_dype_scale,
dype_exponent=custom_exponent if custom_exponent is not None else 2.0,
dype_start_sigma=1.0,
)
if preset == DYPE_PRESET_AUTO:
# Auto preset - custom values are ignored
return get_dype_config_for_resolution(
width=width,
height=height,
base_resolution=1024,
activation_threshold=1536,
)
if preset == DYPE_PRESET_AREA:
# Area-based preset - custom values are ignored
return get_dype_config_for_area(
width=width,
height=height,
base_resolution=1024,
)
# Use preset configuration (4K etc.) - custom values are ignored
preset_config = DYPE_PRESETS.get(preset)
if preset_config is None:
return None
return DyPEConfig(
enable_dype=True,
base_resolution=preset_config.base_resolution,
method=preset_config.method,
dype_scale=preset_config.dype_scale,
dype_exponent=preset_config.dype_exponent,
dype_start_sigma=preset_config.dype_start_sigma,
)
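# Minimal usage sketch (an assumed call site, not part of this module): resolving a
# preset into a config before denoising; the call returns None when DyPE should stay off.
#
#   config = get_dype_config_from_preset(preset=DYPE_PRESET_AREA, width=2048, height=2048)
#   if config is not None:
#       ...  # build a DyPEExtension with this config and the target dimensions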

View File

@@ -0,0 +1,110 @@
"""DyPE-enhanced RoPE (Rotary Position Embedding) functions."""
import torch
from einops import rearrange
from torch import Tensor
from invokeai.backend.flux.dype.base import (
DyPEConfig,
compute_ntk_freqs,
compute_vision_yarn_freqs,
compute_yarn_freqs,
)
def rope_dype(
pos: Tensor,
dim: int,
theta: int,
current_sigma: float,
target_height: int,
target_width: int,
dype_config: DyPEConfig,
) -> Tensor:
"""Compute RoPE with Dynamic Position Extrapolation.
This is the core DyPE function that replaces the standard rope() function.
It applies resolution-aware and timestep-aware scaling to position embeddings.
Args:
pos: Position indices tensor
dim: Embedding dimension per axis
theta: RoPE base frequency (typically 10000)
current_sigma: Current noise level (1.0 = full noise, 0.0 = clean)
target_height: Target image height in pixels
target_width: Target image width in pixels
dype_config: DyPE configuration
Returns:
Rotary position embedding tensor with shape suitable for FLUX attention
"""
assert dim % 2 == 0
# Calculate scaling factors
base_res = dype_config.base_resolution
scale_h = target_height / base_res
scale_w = target_width / base_res
scale = max(scale_h, scale_w)
# If DyPE is disabled or no scaling is needed, use the base method
if not dype_config.enable_dype or scale <= 1.0:
return _rope_base(pos, dim, theta)
# Select method and compute frequencies
method = dype_config.method
if method == "vision_yarn":
cos, sin = compute_vision_yarn_freqs(
pos=pos,
dim=dim,
theta=theta,
scale_h=scale_h,
scale_w=scale_w,
current_sigma=current_sigma,
dype_config=dype_config,
)
elif method == "yarn":
cos, sin = compute_yarn_freqs(
pos=pos,
dim=dim,
theta=theta,
scale=scale,
current_sigma=current_sigma,
dype_config=dype_config,
)
elif method == "ntk":
cos, sin = compute_ntk_freqs(
pos=pos,
dim=dim,
theta=theta,
scale=scale,
)
else: # "base"
return _rope_base(pos, dim, theta)
# Construct rotation matrix from cos/sin
# Output shape: (batch, seq_len, dim/2, 2, 2)
out = torch.stack([cos, -sin, sin, cos], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.to(dtype=pos.dtype, device=pos.device)
def _rope_base(pos: Tensor, dim: int, theta: int) -> Tensor:
"""Standard RoPE without DyPE scaling.
This matches the original rope() function from invokeai.backend.flux.math.
"""
assert dim % 2 == 0
device = pos.device
dtype = torch.float64 if device.type != "mps" else torch.float32
scale = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos.to(dtype), omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.to(dtype=pos.dtype, device=pos.device)

View File

@@ -0,0 +1,91 @@
"""DyPE extension for FLUX denoising pipeline."""
from dataclasses import dataclass
from typing import TYPE_CHECKING
from invokeai.backend.flux.dype.base import DyPEConfig
from invokeai.backend.flux.dype.embed import DyPEEmbedND
if TYPE_CHECKING:
from invokeai.backend.flux.model import Flux
@dataclass
class DyPEExtension:
"""Extension for Dynamic Position Extrapolation in FLUX models.
This extension manages the patching of the FLUX model's position embedder
and updates the step state during denoising.
Usage:
1. Create extension with config and target dimensions
2. Call patch_model() to replace pe_embedder with DyPE version
3. Call update_step_state() before each denoising step
4. Call restore_model() after denoising to restore original embedder
"""
config: DyPEConfig
target_height: int
target_width: int
def patch_model(self, model: "Flux") -> tuple[DyPEEmbedND, object]:
"""Patch the model's position embedder with DyPE version.
Args:
model: The FLUX model to patch
Returns:
Tuple of (new DyPE embedder, original embedder for restoration)
"""
original_embedder = model.pe_embedder
dype_embedder = DyPEEmbedND.from_embednd(
embed_nd=original_embedder,
dype_config=self.config,
)
# Set initial state
dype_embedder.set_step_state(
sigma=1.0,
height=self.target_height,
width=self.target_width,
)
# Replace the embedder
model.pe_embedder = dype_embedder
return dype_embedder, original_embedder
def update_step_state(
self,
embedder: DyPEEmbedND,
timestep: float,
timestep_index: int,
total_steps: int,
) -> None:
"""Update the step state in the DyPE embedder.
This should be called before each denoising step to update the
current noise level for timestep-dependent scaling.
Args:
embedder: The DyPE embedder to update
timestep: Current timestep value (sigma/noise level)
timestep_index: Current step index (0-based)
total_steps: Total number of denoising steps
"""
embedder.set_step_state(
sigma=timestep,
height=self.target_height,
width=self.target_width,
)
@staticmethod
def restore_model(model: "Flux", original_embedder: object) -> None:
"""Restore the original position embedder.
Args:
model: The FLUX model to restore
original_embedder: The original embedder saved from patch_model()
"""
model.pe_embedder = original_embedder
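# Hedged usage sketch of the lifecycle listed in the class docstring (the surrounding
# denoise-loop names such as `model` and `sigmas` are assumptions, not part of this file):
#
#   extension = DyPEExtension(config=dype_config, target_height=2048, target_width=2048)
#   dype_embedder, original_embedder = extension.patch_model(model)
#   try:
#       for step_index, sigma in enumerate(sigmas[:-1]):
#           extension.update_step_state(
#               embedder=dype_embedder, timestep=sigma,
#               timestep_index=step_index, total_steps=len(sigmas) - 1,
#           )
#           ...  # run the transformer for this step
#   finally:
#       DyPEExtension.restore_model(model, original_embedder)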

View File

@@ -0,0 +1,62 @@
"""Flow Matching scheduler definitions and mapping.
This module provides the scheduler types and mapping for Flow Matching models
(Flux and Z-Image), supporting multiple schedulers from the diffusers library.
"""
from typing import Literal, Type
from diffusers import (
FlowMatchEulerDiscreteScheduler,
FlowMatchHeunDiscreteScheduler,
)
from diffusers.schedulers.scheduling_utils import SchedulerMixin
# Note: FlowMatchLCMScheduler may not be available in all diffusers versions
try:
from diffusers import FlowMatchLCMScheduler
_HAS_LCM = True
except ImportError:
_HAS_LCM = False
# Scheduler name literal type for type checking
FLUX_SCHEDULER_NAME_VALUES = Literal["euler", "heun", "lcm"]
# Human-readable labels for the UI
FLUX_SCHEDULER_LABELS: dict[str, str] = {
"euler": "Euler",
"heun": "Heun (2nd order)",
"lcm": "LCM",
}
# Mapping from scheduler names to scheduler classes
FLUX_SCHEDULER_MAP: dict[str, Type[SchedulerMixin]] = {
"euler": FlowMatchEulerDiscreteScheduler,
"heun": FlowMatchHeunDiscreteScheduler,
}
if _HAS_LCM:
FLUX_SCHEDULER_MAP["lcm"] = FlowMatchLCMScheduler
# Z-Image scheduler types (same schedulers as Flux, both use Flow Matching)
# Note: Z-Image-Turbo is optimized for ~8 steps with Euler, but other schedulers
# can be used for experimentation.
ZIMAGE_SCHEDULER_NAME_VALUES = Literal["euler", "heun", "lcm"]
# Human-readable labels for the UI
ZIMAGE_SCHEDULER_LABELS: dict[str, str] = {
"euler": "Euler",
"heun": "Heun (2nd order)",
"lcm": "LCM",
}
# Mapping from scheduler names to scheduler classes (same as Flux)
ZIMAGE_SCHEDULER_MAP: dict[str, Type[SchedulerMixin]] = {
"euler": FlowMatchEulerDiscreteScheduler,
"heun": FlowMatchHeunDiscreteScheduler,
}
if _HAS_LCM:
ZIMAGE_SCHEDULER_MAP["lcm"] = FlowMatchLCMScheduler
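# Minimal usage sketch (an assumed call site, not part of this module): resolving a
# user-selected scheduler name to an instance. The diffusers classes above can be
# constructed with their default config; real invocation code would typically pass
# model-specific settings instead.
#
#   scheduler_cls = FLUX_SCHEDULER_MAP.get("euler")
#   scheduler = scheduler_cls() if scheduler_cls is not None else None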

View File

@@ -5,7 +5,7 @@ from typing import Literal
from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams
from invokeai.backend.model_manager.taxonomy import AnyVariant, FluxVariantType
from invokeai.backend.model_manager.taxonomy import AnyVariant, Flux2VariantType, FluxVariantType
@dataclass
@@ -46,6 +46,8 @@ _flux_max_seq_lengths: dict[AnyVariant, Literal[256, 512]] = {
FluxVariantType.Dev: 512,
FluxVariantType.DevFill: 512,
FluxVariantType.Schnell: 256,
Flux2VariantType.Klein4B: 512,
Flux2VariantType.Klein9B: 512,
}
@@ -117,6 +119,38 @@ _flux_transformer_params: dict[AnyVariant, FluxParams] = {
qkv_bias=True,
guidance_embed=True,
),
# Flux2 Klein 4B uses Qwen3 4B text encoder with stacked embeddings from layers [9, 18, 27]
# The context_in_dim is 3 * hidden_size of Qwen3 (3 * 2560 = 7680)
Flux2VariantType.Klein4B: FluxParams(
in_channels=64,
vec_in_dim=2560, # Qwen3-4B hidden size (used for pooled output)
context_in_dim=7680, # 3 layers * 2560 = 7680 for Qwen3-4B
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
),
# Flux2 Klein 9B uses Qwen3 8B text encoder with stacked embeddings from layers [9, 18, 27]
# The context_in_dim is 3 * hidden_size of Qwen3 (3 * 4096 = 12288)
Flux2VariantType.Klein9B: FluxParams(
in_channels=64,
vec_in_dim=4096, # Qwen3-8B hidden size (used for pooled output)
context_in_dim=12288, # 3 layers * 4096 = 12288 for Qwen3-8B
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
),
}

View File

@@ -0,0 +1,4 @@
"""FLUX.2 backend modules.
This package contains modules specific to FLUX.2 models (e.g., Klein).
"""

View File

@@ -0,0 +1,288 @@
"""Flux2 Klein Denoising Function.
This module provides the denoising function for FLUX.2 Klein models,
which use Qwen3 as the text encoder instead of CLIP+T5.
"""
import inspect
import math
from typing import Any, Callable
import numpy as np
import torch
from tqdm import tqdm
from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
def denoise(
model: torch.nn.Module,
# model input
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
# sampling parameters
timesteps: list[float],
step_callback: Callable[[PipelineIntermediateState], None],
cfg_scale: list[float],
# Negative conditioning for CFG
neg_txt: torch.Tensor | None = None,
neg_txt_ids: torch.Tensor | None = None,
# Scheduler for stepping (e.g., FlowMatchEulerDiscreteScheduler, FlowMatchHeunDiscreteScheduler)
scheduler: Any = None,
# Dynamic shifting parameter for FLUX.2 Klein (computed from image resolution)
mu: float | None = None,
# Inpainting extension for merging latents during denoising
inpaint_extension: RectifiedFlowInpaintExtension | None = None,
# Reference image conditioning (multi-reference image editing)
img_cond_seq: torch.Tensor | None = None,
img_cond_seq_ids: torch.Tensor | None = None,
) -> torch.Tensor:
"""Denoise latents using a FLUX.2 Klein transformer model.
This is a simplified denoise function for FLUX.2 Klein models that uses
the diffusers Flux2Transformer2DModel interface.
Note: FLUX.2 Klein has guidance_embeds=False, so no guidance parameter is used.
CFG is applied externally using negative conditioning when cfg_scale != 1.0.
Args:
model: The Flux2Transformer2DModel from diffusers.
img: Packed latent image tensor of shape (B, seq_len, channels).
img_ids: Image position IDs tensor.
txt: Text encoder hidden states (Qwen3 embeddings).
txt_ids: Text position IDs tensor.
timesteps: List of timesteps for denoising schedule (linear sigmas from 1.0 to 1/n).
step_callback: Callback function for progress updates.
cfg_scale: List of CFG scale values per step.
neg_txt: Negative text embeddings for CFG (optional).
neg_txt_ids: Negative text position IDs (optional).
scheduler: Optional diffusers scheduler (Euler, Heun, LCM). If None, uses manual Euler.
mu: Dynamic shifting parameter computed from image resolution. Required when scheduler
has use_dynamic_shifting=True.
Returns:
Denoised latent tensor.
"""
total_steps = len(timesteps) - 1
# Store original sequence length for extracting output later (before concatenating reference images)
original_seq_len = img.shape[1]
# Concatenate reference image conditioning if provided (multi-reference image editing)
if img_cond_seq is not None and img_cond_seq_ids is not None:
img = torch.cat([img, img_cond_seq], dim=1)
img_ids = torch.cat([img_ids, img_cond_seq_ids], dim=1)
# Klein has guidance_embeds=False, but the transformer forward() still requires a guidance tensor
# We pass a dummy value (1.0) since it won't affect the output when guidance_embeds=False
guidance = torch.full((img.shape[0],), 1.0, device=img.device, dtype=img.dtype)
# Use scheduler if provided
use_scheduler = scheduler is not None
if use_scheduler:
# Set up scheduler with sigmas and mu for dynamic shifting
# Convert timesteps (0-1 range) to sigmas for the scheduler
# The scheduler will apply dynamic shifting internally using mu (if enabled in scheduler config)
sigmas = np.array(timesteps[:-1], dtype=np.float32) # Exclude final 0.0
# Check if scheduler supports sigmas parameter using inspect.signature
# FlowMatchHeunDiscreteScheduler and FlowMatchLCMScheduler don't support sigmas
set_timesteps_sig = inspect.signature(scheduler.set_timesteps)
supports_sigmas = "sigmas" in set_timesteps_sig.parameters
if supports_sigmas and mu is not None:
# Pass mu if provided - it will only be used if scheduler has use_dynamic_shifting=True
scheduler.set_timesteps(sigmas=sigmas.tolist(), mu=mu, device=img.device)
elif supports_sigmas:
scheduler.set_timesteps(sigmas=sigmas.tolist(), device=img.device)
else:
# Scheduler doesn't support sigmas (e.g., Heun, LCM) - use num_inference_steps
scheduler.set_timesteps(num_inference_steps=len(sigmas), device=img.device)
num_scheduler_steps = len(scheduler.timesteps)
is_heun = hasattr(scheduler, "state_in_first_order")
user_step = 0
pbar = tqdm(total=total_steps, desc="Denoising")
for step_index in range(num_scheduler_steps):
timestep = scheduler.timesteps[step_index]
# Convert scheduler timestep (0-1000) to normalized (0-1) for the model
t_curr = timestep.item() / scheduler.config.num_train_timesteps
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
# Track if we're in first or second order step (for Heun)
in_first_order = scheduler.state_in_first_order if is_heun else True
# Run the transformer model (matching diffusers: guidance=guidance, return_dict=False)
output = model(
hidden_states=img,
encoder_hidden_states=txt,
timestep=t_vec,
img_ids=img_ids,
txt_ids=txt_ids,
guidance=guidance,
return_dict=False,
)
# Extract the sample from the output (return_dict=False returns tuple)
pred = output[0] if isinstance(output, tuple) else output
step_cfg_scale = cfg_scale[min(user_step, len(cfg_scale) - 1)]
# Apply CFG if scale is not 1.0
if not math.isclose(step_cfg_scale, 1.0):
if neg_txt is None:
raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
neg_output = model(
hidden_states=img,
encoder_hidden_states=neg_txt,
timestep=t_vec,
img_ids=img_ids,
txt_ids=neg_txt_ids if neg_txt_ids is not None else txt_ids,
guidance=guidance,
return_dict=False,
)
neg_pred = neg_output[0] if isinstance(neg_output, tuple) else neg_output
pred = neg_pred + step_cfg_scale * (pred - neg_pred)
# Use scheduler.step() for the update
step_output = scheduler.step(model_output=pred, timestep=timestep, sample=img)
img = step_output.prev_sample
# Get t_prev for inpainting (next sigma value)
if step_index + 1 < len(scheduler.sigmas):
t_prev = scheduler.sigmas[step_index + 1].item()
else:
t_prev = 0.0
# Apply inpainting merge at each step
if inpaint_extension is not None:
# Separate the generated latents from the reference conditioning
gen_img = img[:, :original_seq_len, :]
ref_img = img[:, original_seq_len:, :]
# Merge only the generated part
gen_img = inpaint_extension.merge_intermediate_latents_with_init_latents(gen_img, t_prev)
# Concatenate back together
img = torch.cat([gen_img, ref_img], dim=1)
# For Heun, only increment user step after second-order step completes
if is_heun:
if not in_first_order:
user_step += 1
if user_step <= total_steps:
pbar.update(1)
preview_img = img - t_curr * pred
if inpaint_extension is not None:
preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(
preview_img, 0.0
)
step_callback(
PipelineIntermediateState(
step=user_step,
order=2,
total_steps=total_steps,
timestep=int(t_curr * 1000),
latents=preview_img,
),
)
else:
user_step += 1
if user_step <= total_steps:
pbar.update(1)
preview_img = img - t_curr * pred
if inpaint_extension is not None:
preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0)
# Extract only the generated image portion for preview (exclude reference images)
callback_latents = preview_img[:, :original_seq_len, :] if img_cond_seq is not None else preview_img
step_callback(
PipelineIntermediateState(
step=user_step,
order=1,
total_steps=total_steps,
timestep=int(t_curr * 1000),
latents=callback_latents,
),
)
pbar.close()
else:
# Manual Euler stepping (original behavior)
for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
# Run the transformer model (matching diffusers: guidance=guidance, return_dict=False)
output = model(
hidden_states=img,
encoder_hidden_states=txt,
timestep=t_vec,
img_ids=img_ids,
txt_ids=txt_ids,
guidance=guidance,
return_dict=False,
)
# Extract the sample from the output (return_dict=False returns tuple)
pred = output[0] if isinstance(output, tuple) else output
step_cfg_scale = cfg_scale[step_index]
# Apply CFG if scale is not 1.0
if not math.isclose(step_cfg_scale, 1.0):
if neg_txt is None:
raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
neg_output = model(
hidden_states=img,
encoder_hidden_states=neg_txt,
timestep=t_vec,
img_ids=img_ids,
txt_ids=neg_txt_ids if neg_txt_ids is not None else txt_ids,
guidance=guidance,
return_dict=False,
)
neg_pred = neg_output[0] if isinstance(neg_output, tuple) else neg_output
pred = neg_pred + step_cfg_scale * (pred - neg_pred)
# Euler step
preview_img = img - t_curr * pred
img = img + (t_prev - t_curr) * pred
# Apply inpainting merge at each step
if inpaint_extension is not None:
# Separate the generated latents from the reference conditioning
gen_img = img[:, :original_seq_len, :]
ref_img = img[:, original_seq_len:, :]
# Merge only the generated part
gen_img = inpaint_extension.merge_intermediate_latents_with_init_latents(gen_img, t_prev)
# Concatenate back together
img = torch.cat([gen_img, ref_img], dim=1)
# Apply the inpaint merge to the preview as well, so the callback reflects the merged latents
preview_gen = preview_img[:, :original_seq_len, :]
preview_gen = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_gen, 0.0)
preview_img = torch.cat([preview_gen, preview_img[:, original_seq_len:, :]], dim=1)
# Extract only the generated image portion for preview (exclude reference images)
callback_latents = preview_img[:, :original_seq_len, :] if img_cond_seq is not None else preview_img
step_callback(
PipelineIntermediateState(
step=step_index + 1,
order=1,
total_steps=total_steps,
timestep=int(t_curr * 1000),  # report on the same 0-1000 scale as the scheduler path
latents=callback_latents,
),
)
# Extract only the generated image portion (exclude concatenated reference images)
if img_cond_seq is not None:
img = img[:, :original_seq_len, :]
return img

View File

@@ -0,0 +1,294 @@
"""FLUX.2 Klein Reference Image Extension for multi-reference image editing.
This module provides the Flux2RefImageExtension for FLUX.2 Klein models,
which handles encoding reference images using the FLUX.2 VAE and
generating the appropriate position IDs for multi-reference image editing.
FLUX.2 Klein has built-in support for reference image editing (unlike FLUX.1
which requires a separate Kontext model).
"""
import math
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from einops import repeat
from PIL import Image
from invokeai.app.invocations.fields import FluxKontextConditioningField
from invokeai.app.invocations.model import VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux2.sampling_utils import pack_flux2
from invokeai.backend.util.devices import TorchDevice
# Maximum pixel counts for reference images (matches BFL FLUX.2 sampling.py)
# Single reference image: 2024² pixels, Multiple: 1024² pixels
MAX_PIXELS_SINGLE_REF = 2024**2 # ~4.1M pixels
MAX_PIXELS_MULTI_REF = 1024**2 # ~1M pixels
def resize_image_to_max_pixels(image: Image.Image, max_pixels: int) -> Image.Image:
"""Resize image to fit within max_pixels while preserving aspect ratio.
This matches the BFL FLUX.2 sampling.py cap_pixels() behavior.
Args:
image: PIL Image to resize.
max_pixels: Maximum total pixel count (width * height).
Returns:
Resized PIL Image (or original if already within bounds).
"""
width, height = image.size
pixel_count = width * height
if pixel_count <= max_pixels:
return image
# Calculate scale factor to fit within max_pixels (BFL approach)
scale = math.sqrt(max_pixels / pixel_count)
new_width = int(width * scale)
new_height = int(height * scale)
# Ensure dimensions are at least 1
new_width = max(1, new_width)
new_height = max(1, new_height)
return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
def generate_img_ids_flux2_with_offset(
latent_height: int,
latent_width: int,
batch_size: int,
device: torch.device,
idx_offset: int = 0,
h_offset: int = 0,
w_offset: int = 0,
) -> torch.Tensor:
"""Generate tensor of image position ids with optional offsets for FLUX.2.
FLUX.2 uses 4D position coordinates (T, H, W, L) for its rotary position embeddings.
Position IDs use int64 (long) dtype.
Args:
latent_height: Height of image in latent space (before packing).
latent_width: Width of image in latent space (before packing).
batch_size: Number of images in the batch.
device: Device to create tensors on.
idx_offset: Offset for T (time/index) coordinate - use 1 for reference images.
h_offset: Spatial offset for H coordinate in latent space.
w_offset: Spatial offset for W coordinate in latent space.
Returns:
Image position ids with shape [batch_size, (latent_height//2 * latent_width//2), 4].
"""
# After packing, the spatial dimensions are halved due to the 2x2 patch structure
packed_height = latent_height // 2
packed_width = latent_width // 2
# Convert spatial offsets from latent space to packed space
packed_h_offset = h_offset // 2
packed_w_offset = w_offset // 2
# Create base tensor for position IDs with shape [packed_height, packed_width, 4]
# The 4 channels represent: [T, H, W, L]
img_ids = torch.zeros(packed_height, packed_width, 4, device=device, dtype=torch.long)
# Set T (time/index offset) for all positions - use 1 for reference images
img_ids[..., 0] = idx_offset
# Set H (height/y) coordinates with offset
h_coords = torch.arange(packed_height, device=device, dtype=torch.long) + packed_h_offset
img_ids[..., 1] = h_coords[:, None]
# Set W (width/x) coordinates with offset
w_coords = torch.arange(packed_width, device=device, dtype=torch.long) + packed_w_offset
img_ids[..., 2] = w_coords[None, :]
# L (layer) coordinate stays 0
# Expand to include batch dimension: [batch_size, (packed_height * packed_width), 4]
img_ids = img_ids.reshape(1, packed_height * packed_width, 4)
img_ids = repeat(img_ids, "1 s c -> b s c", b=batch_size)
return img_ids
class Flux2RefImageExtension:
"""Applies FLUX.2 Klein reference image conditioning.
This extension handles encoding reference images using the FLUX.2 VAE
and generating the appropriate 4D position IDs for multi-reference image editing.
FLUX.2 Klein has built-in support for reference image editing, unlike FLUX.1
which requires a separate Kontext model.
"""
def __init__(
self,
ref_image_conditioning: list[FluxKontextConditioningField],
context: InvocationContext,
vae_field: VAEField,
device: torch.device,
dtype: torch.dtype,
bn_mean: torch.Tensor | None = None,
bn_std: torch.Tensor | None = None,
):
"""Initialize the Flux2RefImageExtension.
Args:
ref_image_conditioning: List of reference image conditioning fields.
context: The invocation context for loading models and images.
vae_field: The FLUX.2 VAE field for encoding images.
device: Target device for tensors.
dtype: Target dtype for tensors.
bn_mean: BN running mean for normalizing latents (shape: 128).
bn_std: BN running std for normalizing latents (shape: 128).
"""
self._context = context
self._device = device
self._dtype = dtype
self._vae_field = vae_field
self._bn_mean = bn_mean
self._bn_std = bn_std
self.ref_image_conditioning = ref_image_conditioning
# Pre-process and cache the reference image latents and ids upon initialization
self.ref_image_latents, self.ref_image_ids = self._prepare_ref_images()
def _bn_normalize(self, x: torch.Tensor) -> torch.Tensor:
"""Apply BN normalization to packed latents.
BN formula (affine=False): y = (x - mean) / std
Args:
x: Packed latents of shape (B, seq, 128).
Returns:
Normalized latents of same shape.
"""
assert self._bn_mean is not None and self._bn_std is not None
bn_mean = self._bn_mean.to(x.device, x.dtype)
bn_std = self._bn_std.to(x.device, x.dtype)
return (x - bn_mean) / bn_std
def _prepare_ref_images(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Encode reference images and prepare their concatenated latents and IDs with spatial tiling."""
all_latents = []
all_ids = []
# Track cumulative dimensions for spatial tiling
canvas_h = 0
canvas_w = 0
vae_info = self._context.models.load(self._vae_field.vae)
# Determine max pixels based on number of reference images (BFL FLUX.2 approach)
num_refs = len(self.ref_image_conditioning)
max_pixels = MAX_PIXELS_SINGLE_REF if num_refs == 1 else MAX_PIXELS_MULTI_REF
for idx, ref_image_field in enumerate(self.ref_image_conditioning):
image = self._context.images.get_pil(ref_image_field.image.image_name)
image = image.convert("RGB")
# Resize large images to max pixel count (matches BFL FLUX.2 sampling.py)
image = resize_image_to_max_pixels(image, max_pixels)
# Convert to tensor using torchvision transforms
transformation = T.Compose([T.ToTensor()])
image_tensor = transformation(image)
# Convert from [0, 1] to [-1, 1] range expected by VAE
image_tensor = image_tensor * 2.0 - 1.0
image_tensor = image_tensor.unsqueeze(0) # Add batch dimension
# Encode using FLUX.2 VAE
with vae_info.model_on_device() as (_, vae):
vae_dtype = next(iter(vae.parameters())).dtype
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
# FLUX.2 VAE uses diffusers API
latent_dist = vae.encode(image_tensor, return_dict=False)[0]
# Use mode() for deterministic encoding (no sampling)
if hasattr(latent_dist, "mode"):
ref_image_latents_unpacked = latent_dist.mode()
elif hasattr(latent_dist, "sample"):
ref_image_latents_unpacked = latent_dist.sample()
else:
ref_image_latents_unpacked = latent_dist
TorchDevice.empty_cache()
# Extract tensor dimensions (B, 32, H, W for FLUX.2)
batch_size, _, latent_height, latent_width = ref_image_latents_unpacked.shape
# Pad latents to be compatible with patch_size=2
pad_h = (2 - latent_height % 2) % 2
pad_w = (2 - latent_width % 2) % 2
if pad_h > 0 or pad_w > 0:
ref_image_latents_unpacked = F.pad(ref_image_latents_unpacked, (0, pad_w, 0, pad_h), mode="circular")
_, _, latent_height, latent_width = ref_image_latents_unpacked.shape
# Pack the latents using FLUX.2 pack function (32 channels -> 128)
ref_image_latents_packed = pack_flux2(ref_image_latents_unpacked).to(self._device, self._dtype)
# Apply BN normalization to match the input latents scale
# This is critical - the transformer expects normalized latents
if self._bn_mean is not None and self._bn_std is not None:
ref_image_latents_packed = self._bn_normalize(ref_image_latents_packed)
# Determine spatial offsets for this reference image
h_offset = 0
w_offset = 0
if idx > 0: # First image starts at (0, 0)
# Calculate potential canvas dimensions for each tiling option
potential_h_vertical = canvas_h + latent_height
potential_w_horizontal = canvas_w + latent_width
# Choose arrangement that minimizes the maximum dimension
if potential_h_vertical > potential_w_horizontal:
# Tile horizontally (to the right)
w_offset = canvas_w
canvas_w = canvas_w + latent_width
canvas_h = max(canvas_h, latent_height)
else:
# Tile vertically (below)
h_offset = canvas_h
canvas_h = canvas_h + latent_height
canvas_w = max(canvas_w, latent_width)
else:
canvas_h = latent_height
canvas_w = latent_width
# Generate position IDs with 4D format (T, H, W, L)
# Use T-coordinate offset with scale=10 like diffusers Flux2Pipeline:
# T = scale + scale * idx (so first ref image is T=10, second is T=20, etc.)
# The generated image uses T=0, so this clearly separates reference images
t_offset = 10 + 10 * idx # scale=10 matches diffusers
ref_image_ids = generate_img_ids_flux2_with_offset(
latent_height=latent_height,
latent_width=latent_width,
batch_size=batch_size,
device=self._device,
idx_offset=t_offset, # Reference images use T=10, 20, 30...
h_offset=h_offset,
w_offset=w_offset,
)
all_latents.append(ref_image_latents_packed)
all_ids.append(ref_image_ids)
# Concatenate all latents and IDs along the sequence dimension
concatenated_latents = torch.cat(all_latents, dim=1)
concatenated_ids = torch.cat(all_ids, dim=1)
return concatenated_latents, concatenated_ids
def ensure_batch_size(self, target_batch_size: int) -> None:
"""Ensure the reference image latents and IDs match the target batch size."""
if self.ref_image_latents.shape[0] != target_batch_size:
self.ref_image_latents = self.ref_image_latents.repeat(target_batch_size, 1, 1)
self.ref_image_ids = self.ref_image_ids.repeat(target_batch_size, 1, 1)
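A small worked example may help make the 4D position IDs concrete. The call below is a hedged sketch using `generate_img_ids_flux2_with_offset` exactly as defined above; the tiny latent size and offsets are chosen only for illustration.
```py
import torch

# First reference image (T = 10) tiled to the right of a 4x4-latent canvas:
# w_offset=4 latent units becomes a packed offset of 2.
ids = generate_img_ids_flux2_with_offset(
    latent_height=4,
    latent_width=4,
    batch_size=1,
    device=torch.device("cpu"),
    idx_offset=10,
    h_offset=0,
    w_offset=4,
)
# ids.shape == (1, 4, 4); each token is [T, H, W, L]:
# [[10, 0, 2, 0], [10, 0, 3, 0], [10, 1, 2, 0], [10, 1, 3, 0]]
```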

View File

@@ -0,0 +1,206 @@
"""FLUX.2 Klein Sampling Utilities.
FLUX.2 Klein uses a 32-channel VAE (AutoencoderKLFlux2) instead of the 16-channel VAE
used by FLUX.1. This module provides sampling utilities adapted for FLUX.2.
"""
import math
import torch
from einops import rearrange
def get_noise_flux2(
num_samples: int,
height: int,
width: int,
device: torch.device,
dtype: torch.dtype,
seed: int,
) -> torch.Tensor:
"""Generate noise for FLUX.2 Klein (32 channels).
FLUX.2 uses a 32-channel VAE, so noise must have 32 channels.
The spatial dimensions are calculated to allow for packing.
Args:
num_samples: Batch size.
height: Target image height in pixels.
width: Target image width in pixels.
device: Target device.
dtype: Target dtype.
seed: Random seed.
Returns:
Noise tensor of shape (num_samples, 32, latent_h, latent_w).
"""
# Always generate noise on a fixed device (CPU) and dtype (float16), then cast, so that a given seed yields the same noise regardless of the execution device.
rand_device = "cpu"
rand_dtype = torch.float16
# FLUX.2 uses 32 latent channels
# Latent dimensions: height/8, width/8 (from VAE downsampling)
# Must be divisible by 2 for packing (patchify step)
latent_h = 2 * math.ceil(height / 16)
latent_w = 2 * math.ceil(width / 16)
return torch.randn(
num_samples,
32, # FLUX.2 uses 32 latent channels (vs 16 for FLUX.1)
latent_h,
latent_w,
device=rand_device,
dtype=rand_dtype,
generator=torch.Generator(device=rand_device).manual_seed(seed),
).to(device=device, dtype=dtype)
def pack_flux2(x: torch.Tensor) -> torch.Tensor:
"""Pack latent image to flattened array of patch embeddings for FLUX.2.
This performs the patchify + pack operation in one step:
1. Patchify: Group 2x2 spatial patches into channels (C*4)
2. Pack: Flatten spatial dimensions to sequence
For 32-channel input: (B, 32, H, W) -> (B, H/2*W/2, 128)
Args:
x: Latent tensor of shape (B, 32, H, W).
Returns:
Packed tensor of shape (B, H/2*W/2, 128).
"""
# Same operation as FLUX.1 pack, but input has 32 channels -> output has 128
return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
def unpack_flux2(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""Unpack flat array of patch embeddings back to latent image for FLUX.2.
This reverses the pack_flux2 operation:
1. Unpack: Restore spatial dimensions from sequence
2. Unpatchify: Restore 32 channels from 128
Args:
x: Packed tensor of shape (B, H/2*W/2, 128).
height: Target image height in pixels.
width: Target image width in pixels.
Returns:
Latent tensor of shape (B, 32, H, W).
"""
# Calculate latent dimensions
latent_h = 2 * math.ceil(height / 16)
latent_w = 2 * math.ceil(width / 16)
# Packed dimensions (after patchify)
packed_h = latent_h // 2
packed_w = latent_w // 2
return rearrange(
x,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=packed_h,
w=packed_w,
ph=2,
pw=2,
)
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
"""Compute mu for FLUX.2 schedule shifting.
Uses a fixed mu value of 2.02, matching ComfyUI's proven FLUX.2 configuration.
The previous implementation (from diffusers' FLUX.1 pipeline) computed mu as a
linear function of image_seq_len, which produced excessively high values at
high resolutions (e.g., mu=3.23 at 2048x2048). This over-shifted the sigma
schedule, compressing almost all values above 0.9 and forcing the model to
denoise everything in the final 1-2 steps, causing severe grid/diamond artifacts.
ComfyUI uses a fixed shift=2.02 for FLUX.2 Klein at all resolutions and produces
artifact-free images even at 2048x2048.
Args:
image_seq_len: Number of image tokens (packed_h * packed_w). Currently unused.
num_steps: Number of denoising steps. Currently unused.
Returns:
The mu value (fixed at 2.02).
"""
return 2.02
def get_schedule_flux2(
num_steps: int,
image_seq_len: int,
) -> list[float]:
"""Get linear timestep schedule for FLUX.2.
Returns a linear sigma schedule from 1.0 to 1/num_steps.
The actual schedule shifting is handled by the FlowMatchEulerDiscreteScheduler
using the mu parameter and use_dynamic_shifting=True.
Args:
num_steps: Number of denoising steps.
image_seq_len: Number of image tokens (packed_h * packed_w). Currently unused,
but kept for API compatibility. The scheduler computes shifting internally.
Returns:
List of linear sigmas from 1.0 to 1/num_steps, plus final 0.0.
"""
import numpy as np
# Create linear sigmas from 1.0 to 1/num_steps
# The scheduler will apply dynamic shifting using mu parameter
sigmas = np.linspace(1.0, 1 / num_steps, num_steps)
sigmas_list = [float(s) for s in sigmas]
# Add final 0.0 for the last step (scheduler needs n+1 timesteps for n steps)
sigmas_list.append(0.0)
return sigmas_list
def generate_img_ids_flux2(h: int, w: int, batch_size: int, device: torch.device) -> torch.Tensor:
"""Generate tensor of image position ids for FLUX.2 with RoPE scaling.
FLUX.2 uses 4D position coordinates (T, H, W, L) for its rotary position embeddings.
This is different from FLUX.1 which uses 3D coordinates.
RoPE Scaling: For resolutions >1536x1536, position IDs are scaled down using
Position Interpolation to prevent RoPE degradation and diamond/grid artifacts.
IMPORTANT: Position IDs must use int64 (long) dtype like diffusers, not bfloat16.
Using floating point dtype for position IDs can cause NaN in rotary embeddings.
Args:
h: Height of image in latent space.
w: Width of image in latent space.
batch_size: Batch size.
device: Device.
Returns:
Image position ids tensor of shape (batch_size, h/2*w/2, 4) with int64 dtype.
"""
# After packing, spatial dims are h/2 x w/2
packed_h = h // 2
packed_w = w // 2
# Create coordinate grids - 4D: (T, H, W, L)
# T = time/batch index, H = height, W = width, L = layer/channel
# Use int64 (long) dtype like diffusers
img_ids = torch.zeros(packed_h, packed_w, 4, device=device, dtype=torch.long)
# T (time/batch) coordinate - set to 0 (already initialized)
# H coordinates
img_ids[..., 1] = torch.arange(packed_h, device=device, dtype=torch.long)[:, None]
# W coordinates
img_ids[..., 2] = torch.arange(packed_w, device=device, dtype=torch.long)[None, :]
# L (layer) coordinate - set to 0 (already initialized)
# Flatten and expand for batch
img_ids = img_ids.reshape(1, packed_h * packed_w, 4)
img_ids = img_ids.expand(batch_size, -1, -1)
return img_ids
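For reference, a short usage sketch of the schedule helpers above; the values follow directly from the implementations as written.
```py
# Four denoising steps: the raw schedule is linear, with a trailing 0.0.
sigmas = get_schedule_flux2(num_steps=4, image_seq_len=4096)
# sigmas == [1.0, 0.75, 0.5, 0.25, 0.0]

# mu is fixed for all resolutions and step counts; the scheduler applies the
# actual shift when use_dynamic_shifting=True.
mu = compute_empirical_mu(image_seq_len=4096, num_steps=4)
# mu == 2.02
```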

View File

@@ -0,0 +1,367 @@
# Original: https://github.com/joeyballentine/Material-Map-Generator
# Adopted and optimized for Invoke AI
from collections import OrderedDict
from typing import Any, List, Literal, Optional
import torch
import torch.nn as nn
ACTIVATION_LAYER_TYPE = Literal["relu", "leakyrelu", "prelu"]
NORMALIZATION_LAYER_TYPE = Literal["batch", "instance"]
PADDING_LAYER_TYPE = Literal["zero", "reflect", "replicate"]
BLOCK_MODE = Literal["CNA", "NAC", "CNAC"]
UPCONV_BLOCK_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear"]
def act(act_type: ACTIVATION_LAYER_TYPE, inplace: bool = True, neg_slope: float = 0.2, n_prelu: int = 1):
"""Helper to select Activation Layer"""
if act_type == "relu":
layer = nn.ReLU(inplace)
elif act_type == "leakyrelu":
layer = nn.LeakyReLU(neg_slope, inplace)
elif act_type == "prelu":
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
return layer
def norm(norm_type: NORMALIZATION_LAYER_TYPE, nc: int):
"""Helper to select Normalization Layer"""
if norm_type == "batch":
layer = nn.BatchNorm2d(nc, affine=True)
elif norm_type == "instance":
layer = nn.InstanceNorm2d(nc, affine=False)
return layer
def pad(pad_type: PADDING_LAYER_TYPE, padding: int):
"""Helper to select Padding Layer"""
if padding == 0 or pad_type == "zero":
return None
if pad_type == "reflect":
layer = nn.ReflectionPad2d(padding)
elif pad_type == "replicate":
layer = nn.ReplicationPad2d(padding)
return layer
def get_valid_padding(kernel_size: int, dilation: int):
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
padding = (kernel_size - 1) // 2
return padding
def sequential(*args: Any):
# Flatten Sequential. It unwraps nn.Sequential.
if len(args) == 1:
if isinstance(args[0], OrderedDict):
raise NotImplementedError("sequential does not support OrderedDict input.")
return args[0] # No sequential is needed.
modules: List[nn.Module] = []
for module in args:
if isinstance(module, nn.Sequential):
for submodule in module.children():
modules.append(submodule)
elif isinstance(module, nn.Module):
modules.append(module)
return nn.Sequential(*modules)
def conv_block(
in_nc: int,
out_nc: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
pad_type: Optional[PADDING_LAYER_TYPE] = "zero",
norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
act_type: Optional[ACTIVATION_LAYER_TYPE] = "relu",
mode: BLOCK_MODE = "CNA",
):
"""
Conv layer with padding, normalization, activation
mode: CNA --> Conv -> Norm -> Act
NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
"""
assert mode in ["CNA", "NAC", "CNAC"], f"Wrong conv mode [{mode}]"
padding = get_valid_padding(kernel_size, dilation)
p = pad(pad_type, padding) if pad_type else None
padding = padding if pad_type == "zero" else 0
c = nn.Conv2d(
in_nc,
out_nc,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias,
groups=groups,
)
a = act(act_type) if act_type else None
match mode:
case "CNA":
n = norm(norm_type, out_nc) if norm_type else None
return sequential(p, c, n, a)
case "NAC":
if norm_type is None and act_type is not None:
a = act(act_type, inplace=False)
n = norm(norm_type, in_nc) if norm_type else None
return sequential(n, a, p, c)
case "CNAC":
n = norm(norm_type, in_nc) if norm_type else None
return sequential(n, a, p, c)
class ConcatBlock(nn.Module):
# Concat the output of a submodule to its input
def __init__(self, submodule: nn.Module):
super(ConcatBlock, self).__init__()
self.sub = submodule
def forward(self, x: torch.Tensor):
output = torch.cat((x, self.sub(x)), dim=1)
return output
def __repr__(self):
tmpstr = "Identity .. \n|"
modstr = self.sub.__repr__().replace("\n", "\n|")
tmpstr = tmpstr + modstr
return tmpstr
class ShortcutBlock(nn.Module):
# Elementwise sum the output of a submodule to its input
def __init__(self, submodule: nn.Module):
super(ShortcutBlock, self).__init__()
self.sub = submodule
def forward(self, x: torch.Tensor):
output = x + self.sub(x)
return output
def __repr__(self):
tmpstr = "Identity + \n|"
modstr = self.sub.__repr__().replace("\n", "\n|")
tmpstr = tmpstr + modstr
return tmpstr
class ShortcutBlockSPSR(nn.Module):
# Elementwise sum the output of a submodule to its input
def __init__(self, submodule: nn.Module):
super(ShortcutBlockSPSR, self).__init__()
self.sub = submodule
def forward(self, x: torch.Tensor):
return x, self.sub
def __repr__(self):
tmpstr = "Identity + \n|"
modstr = self.sub.__repr__().replace("\n", "\n|")
tmpstr = tmpstr + modstr
return tmpstr
class ResNetBlock(nn.Module):
"""
ResNet Block, 3-3 style
with extra residual scaling used in EDSR
(Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
"""
def __init__(
self,
in_nc: int,
mid_nc: int,
out_nc: int,
kernel_size: int = 3,
stride: int = 1,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
pad_type: PADDING_LAYER_TYPE = "zero",
norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
act_type: Optional[ACTIVATION_LAYER_TYPE] = "relu",
mode: BLOCK_MODE = "CNA",
res_scale: int = 1,
):
super(ResNetBlock, self).__init__()
conv0 = conv_block(
in_nc, mid_nc, kernel_size, stride, dilation, groups, bias, pad_type, norm_type, act_type, mode
)
if mode == "CNA":
act_type = None
if mode == "CNAC": # Residual path: |-CNAC-|
act_type = None
norm_type = None
conv1 = conv_block(
mid_nc, out_nc, kernel_size, stride, dilation, groups, bias, pad_type, norm_type, act_type, mode
)
self.res = sequential(conv0, conv1)
self.res_scale = res_scale
def forward(self, x: torch.Tensor):
res = self.res(x).mul(self.res_scale)
return x + res
class ResidualDenseBlock_5C(nn.Module):
"""
Residual Dense Block
style: 5 convs
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
"""
def __init__(
self,
nc: int,
kernel_size: int = 3,
gc: int = 32,
stride: int = 1,
bias: bool = True,
pad_type: PADDING_LAYER_TYPE = "zero",
norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
act_type: ACTIVATION_LAYER_TYPE = "leakyrelu",
mode: BLOCK_MODE = "CNA",
):
super(ResidualDenseBlock_5C, self).__init__()
# gc: growth channel, i.e. intermediate channels
self.conv1 = conv_block(
nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode
)
self.conv2 = conv_block(
nc + gc,
gc,
kernel_size,
stride,
bias=bias,
pad_type=pad_type,
norm_type=norm_type,
act_type=act_type,
mode=mode,
)
self.conv3 = conv_block(
nc + 2 * gc,
gc,
kernel_size,
stride,
bias=bias,
pad_type=pad_type,
norm_type=norm_type,
act_type=act_type,
mode=mode,
)
self.conv4 = conv_block(
nc + 3 * gc,
gc,
kernel_size,
stride,
bias=bias,
pad_type=pad_type,
norm_type=norm_type,
act_type=act_type,
mode=mode,
)
if mode == "CNA":
last_act = None
else:
last_act = act_type
self.conv5 = conv_block(
nc + 4 * gc, nc, 3, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=last_act, mode=mode
)
def forward(self, x: torch.Tensor):
x1 = self.conv1(x)
x2 = self.conv2(torch.cat((x, x1), 1))
x3 = self.conv3(torch.cat((x, x1, x2), 1))
x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
return x5.mul(0.2) + x
class RRDB(nn.Module):
"""
Residual in Residual Dense Block
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
"""
def __init__(
self,
nc: int,
kernel_size: int = 3,
gc: int = 32,
stride: int = 1,
bias: bool = True,
pad_type: PADDING_LAYER_TYPE = "zero",
norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
act_type: ACTIVATION_LAYER_TYPE = "leakyrelu",
mode: BLOCK_MODE = "CNA",
):
super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, norm_type, act_type, mode)
self.RDB2 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, norm_type, act_type, mode)
self.RDB3 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, norm_type, act_type, mode)
def forward(self, x: torch.Tensor):
out = self.RDB1(x)
out = self.RDB2(out)
out = self.RDB3(out)
return out.mul(0.2) + x
# Upsampler
def pixelshuffle_block(
in_nc: int,
out_nc: int,
upscale_factor: int = 2,
kernel_size: int = 3,
stride: int = 1,
bias: bool = True,
pad_type: PADDING_LAYER_TYPE = "zero",
norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
act_type: ACTIVATION_LAYER_TYPE = "relu",
):
"""
Pixel shuffle layer
(Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
Neural Network, CVPR17)
"""
conv = conv_block(
in_nc,
out_nc * (upscale_factor**2),
kernel_size,
stride,
bias=bias,
pad_type=pad_type,
norm_type=None,
act_type=None,
)
pixel_shuffle = nn.PixelShuffle(upscale_factor)
n = norm(norm_type, out_nc) if norm_type else None
a = act(act_type) if act_type else None
return sequential(conv, pixel_shuffle, n, a)
def upconv_block(
in_nc: int,
out_nc: int,
upscale_factor: int = 2,
kernel_size: int = 3,
stride: int = 1,
bias: bool = True,
pad_type: PADDING_LAYER_TYPE = "zero",
norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
act_type: ACTIVATION_LAYER_TYPE = "relu",
mode: UPCONV_BLOCK_MODE = "nearest",
):
# Adopted from https://distill.pub/2016/deconv-checkerboard/
upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
conv = conv_block(
in_nc, out_nc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=act_type
)
return sequential(upsample, conv)
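As a quick sanity check of the builder above, here is a hedged usage sketch; with the default `"zero"` padding the explicit pad layer is dropped and `conv_block` reduces to Conv2d -> BatchNorm2d -> LeakyReLU.
```py
import torch

block = conv_block(3, 64, kernel_size=3, norm_type="batch", act_type="leakyrelu", mode="CNA")
y = block(torch.randn(1, 3, 32, 32))
# y.shape == (1, 64, 32, 32) -- padding=1 preserves the spatial size
```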

View File

@@ -0,0 +1,70 @@
# Original: https://github.com/joeyballentine/Material-Map-Generator
# Adopted and optimized for Invoke AI
import math
from typing import Literal, Optional
import torch
import torch.nn as nn
import invokeai.backend.image_util.pbr_maps.architecture.block as B
UPSCALE_MODE = Literal["upconv", "pixelshuffle"]
class PBR_RRDB_Net(nn.Module):
def __init__(
self,
in_nc: int,
out_nc: int,
nf: int,
nb: int,
gc: int = 32,
upscale: int = 4,
norm_type: Optional[B.NORMALIZATION_LAYER_TYPE] = None,
act_type: B.ACTIVATION_LAYER_TYPE = "leakyrelu",
mode: B.BLOCK_MODE = "CNA",
res_scale: int = 1,
upsample_mode: UPSCALE_MODE = "upconv",
):
super(PBR_RRDB_Net, self).__init__()
n_upscale = int(math.log(upscale, 2))
if upscale == 3:
n_upscale = 1
fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
rb_blocks = [
B.RRDB(
nf,
kernel_size=3,
gc=32,
stride=1,
bias=True,
pad_type="zero",
norm_type=norm_type,
act_type=act_type,
mode="CNA",
)
for _ in range(nb)
]
LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)
if upsample_mode == "upconv":
upsample_block = B.upconv_block
elif upsample_mode == "pixelshuffle":
upsample_block = B.pixelshuffle_block
if upscale == 3:
upsampler = upsample_block(nf, nf, 3, act_type=act_type)
else:
upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
self.model = B.sequential(
fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)), *upsampler, HR_conv0, HR_conv1
)
def forward(self, x: torch.Tensor):
return self.model(x)

View File

@@ -0,0 +1,141 @@
# Original: https://github.com/joeyballentine/Material-Map-Generator
# Adopted and optimized for Invoke AI
import pathlib
from typing import Any, Literal
import cv2
import numpy as np
import numpy.typing as npt
import torch
from PIL import Image
from safetensors.torch import load_file
from invokeai.backend.image_util.pbr_maps.architecture.pbr_rrdb_net import PBR_RRDB_Net
from invokeai.backend.image_util.pbr_maps.utils.image_ops import crop_seamless, esrgan_launcher_split_merge
NORMAL_MAP_MODEL = (
"https://huggingface.co/InvokeAI/pbr-material-maps/resolve/main/normal_map_generator.safetensors?download=true"
)
OTHER_MAP_MODEL = (
"https://huggingface.co/InvokeAI/pbr-material-maps/resolve/main/franken_map_generator.safetensors?download=true"
)
class PBRMapsGenerator:
def __init__(self, normal_map_model: PBR_RRDB_Net, other_map_model: PBR_RRDB_Net, device: torch.device) -> None:
self.normal_map_model = normal_map_model
self.other_map_model = other_map_model
self.device = device
@staticmethod
def load_model(model_path: pathlib.Path, device: torch.device) -> PBR_RRDB_Net:
state_dict = load_file(model_path.as_posix(), device=device.type)
model = PBR_RRDB_Net(
3,
3,
32,
12,
gc=32,
upscale=1,
norm_type=None,
act_type="leakyrelu",
mode="CNA",
res_scale=1,
upsample_mode="upconv",
)
model.load_state_dict(state_dict, strict=False)
del state_dict
if torch.cuda.is_available() and device.type == "cuda":
torch.cuda.empty_cache()
model.eval()
for _, v in model.named_parameters():
v.requires_grad = False
return model.to(device)
def process(self, img: npt.NDArray[Any], model: PBR_RRDB_Net):
img = img.astype(np.float32) / np.iinfo(img.dtype).max
img = img[..., ::-1].copy()
tensor_img = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).to(self.device)
with torch.no_grad():
output = model(tensor_img).data.squeeze(0).float().cpu().clamp_(0, 1).numpy()
output = output[[2, 1, 0], :, :]
output = np.transpose(output, (1, 2, 0))
output = (output * 255.0).round()
return output
def _cv2_to_pil(self, image: npt.NDArray[Any]):
return Image.fromarray(cv2.cvtColor(image.astype(np.uint8), cv2.COLOR_RGB2BGR))
def generate_maps(
self,
image: Image.Image,
tile_size: int = 512,
border_mode: Literal["none", "seamless", "mirror", "replicate"] = "none",
):
"""
Generate PBR texture maps (normal, roughness, and displacement) from an input image.
The image can optionally be padded before inference to control how borders are treated,
which can help create seamless or edge-consistent textures.
Args:
image: Source image used to generate the PBR maps.
tile_size: Maximum tile size used for tiled inference. If the image is larger than
this size in either dimension, it will be split into tiles for processing and
then merged.
border_mode: Strategy for padding the image before inference:
- "none": No padding is applied; the image is processed asis.
- "seamless": Pads the image using wraparound tiling
(`cv2.BORDER_WRAP`) to help produce seamless textures.
- "mirror": Pads the image by mirroring border pixels
(`cv2.BORDER_REFLECT_101`) to reduce edge artifacts.
- "replicate": Pads the image by replicating the edge pixels outward
(`cv2.BORDER_REPLICATE`).
Returns:
A tuple of three PIL Images:
- normal_map: RGB normal map generated from the input.
- roughness: Single-channel roughness map extracted from the second model output.
- displacement: Single-channel displacement (height) map extracted from the
second model output.
"""
models = [self.normal_map_model, self.other_map_model]
np_image = np.array(image).astype(np.uint8)
match border_mode:
case "seamless":
np_image = cv2.copyMakeBorder(np_image, 16, 16, 16, 16, cv2.BORDER_WRAP)
case "mirror":
np_image = cv2.copyMakeBorder(np_image, 16, 16, 16, 16, cv2.BORDER_REFLECT_101)
case "replicate":
np_image = cv2.copyMakeBorder(np_image, 16, 16, 16, 16, cv2.BORDER_REPLICATE)
case "none":
pass
img_height, img_width = np_image.shape[:2]
# Checking whether to perform tiled inference
do_split = img_height > tile_size or img_width > tile_size
if do_split:
rlts = esrgan_launcher_split_merge(np_image, self.process, models, scale_factor=1, tile_size=tile_size)
else:
rlts = [self.process(np_image, model) for model in models]
if border_mode != "none":
rlts = [crop_seamless(rlt) for rlt in rlts]
normal_map = self._cv2_to_pil(rlts[0])
roughness = self._cv2_to_pil(rlts[1][:, :, 1])
displacement = self._cv2_to_pil(rlts[1][:, :, 0])
return normal_map, roughness, displacement
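A hedged usage sketch of the class above; the local file names and the input image are placeholders, and the checkpoints themselves are the ones referenced by `NORMAL_MAP_MODEL` and `OTHER_MAP_MODEL`.
```py
import pathlib

import torch
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
normal_model = PBRMapsGenerator.load_model(pathlib.Path("normal_map_generator.safetensors"), device)
other_model = PBRMapsGenerator.load_model(pathlib.Path("franken_map_generator.safetensors"), device)

generator = PBRMapsGenerator(normal_model, other_model, device)
normal_map, roughness, displacement = generator.generate_maps(
    Image.open("texture.png").convert("RGB"),
    tile_size=512,
    border_mode="seamless",  # wrap-around padding (cv2.BORDER_WRAP) for tileable output
)
```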

View File

@@ -0,0 +1,93 @@
# Original: https://github.com/joeyballentine/Material-Map-Generator
# Adopted and optimized for Invoke AI
import math
from typing import Any, Callable, List
import numpy as np
import numpy.typing as npt
from invokeai.backend.image_util.pbr_maps.architecture.pbr_rrdb_net import PBR_RRDB_Net
def crop_seamless(img: npt.NDArray[Any]):
img_height, img_width = img.shape[:2]
y, x = 16, 16
h, w = img_height - 32, img_width - 32
img = img[y : y + h, x : x + w]
return img
# from https://github.com/ata4/esrgan-launcher/blob/master/upscale.py
def esrgan_launcher_split_merge(
input_image: npt.NDArray[Any],
upscale_function: Callable[[npt.NDArray[Any], PBR_RRDB_Net], npt.NDArray[Any]],
models: List[PBR_RRDB_Net],
scale_factor: int = 4,
tile_size: int = 512,
tile_padding: float = 0.125,
):
width, height, depth = input_image.shape
output_width = width * scale_factor
output_height = height * scale_factor
output_shape = (output_width, output_height, depth)
# start with black image
output_images = [np.zeros(output_shape, np.uint8) for _ in range(len(models))]
tile_padding = math.ceil(tile_size * tile_padding)
tile_size = math.ceil(tile_size / scale_factor)
tiles_x = math.ceil(width / tile_size)
tiles_y = math.ceil(height / tile_size)
for y in range(tiles_y):
for x in range(tiles_x):
# extract tile from input image
ofs_x = x * tile_size
ofs_y = y * tile_size
# input tile area on total image
input_start_x = ofs_x
input_end_x = min(ofs_x + tile_size, width)
input_start_y = ofs_y
input_end_y = min(ofs_y + tile_size, height)
# input tile area on total image with padding
input_start_x_pad = max(input_start_x - tile_padding, 0)
input_end_x_pad = min(input_end_x + tile_padding, width)
input_start_y_pad = max(input_start_y - tile_padding, 0)
input_end_y_pad = min(input_end_y + tile_padding, height)
# input tile dimensions
input_tile_width = input_end_x - input_start_x
input_tile_height = input_end_y - input_start_y
input_tile = input_image[input_start_x_pad:input_end_x_pad, input_start_y_pad:input_end_y_pad]
for idx, model in enumerate(models):
# upscale tile
output_tile = upscale_function(input_tile, model)
# output tile area on total image
output_start_x = input_start_x * scale_factor
output_end_x = input_end_x * scale_factor
output_start_y = input_start_y * scale_factor
output_end_y = input_end_y * scale_factor
# output tile area without padding
output_start_x_tile = (input_start_x - input_start_x_pad) * scale_factor
output_end_x_tile = output_start_x_tile + input_tile_width * scale_factor
output_start_y_tile = (input_start_y - input_start_y_pad) * scale_factor
output_end_y_tile = output_start_y_tile + input_tile_height * scale_factor
# put tile into output image
output_images[idx][output_start_x:output_end_x, output_start_y:output_end_y] = output_tile[
output_start_x_tile:output_end_x_tile, output_start_y_tile:output_end_y_tile
]
return output_images

View File

@@ -0,0 +1,212 @@
# Model Management System
This document describes Invoke's model management system and common tasks for extending model support.
## Overview
The model management system handles the full lifecycle of models: identification, loading, and running. The system is extensible and supports multiple model architectures, formats, and quantization schemes.
### Three Major Subsystems
1. **Model Identification** (`configs/`): Determines model type, architecture, format, and metadata when users install models.
2. **Model Loading** (`load/`): Loads models from disk into memory for inference.
3. **Model Running**: Executes inference on loaded models. Implementation is scattered across the codebase, typically in architecture-specific inference code adjacent to `model_manager/`. The inference code is run in nodes in the graph execution system.
## Core Concepts
### Model Taxonomy
The `taxonomy.py` module defines the type system for models:
- `ModelType`: The kind of model (e.g., `Main`, `LoRA`, `ControlNet`, `VAE`).
- `ModelFormat`: Storage format, which may imply a quantization scheme or other storage characteristic (e.g., `Diffusers`, `Checkpoint`, `LyCORIS`, `BnbQuantizednf4b`).
- `BaseModelType`: Associated pipeline architecture (e.g., `StableDiffusion1`, `StableDiffusionXL`, `Flux`). Models without an associated base use `Any` (e.g., `CLIPVision`, which is not tied to any one architecture).
- `ModelVariantType`, `FluxVariantType`, `ClipVariantType`: Architecture-specific variants.
These enums form a discriminated union that uniquely identifies each model configuration class.
### Model "Configs"
Model configs are Pydantic models that describe a model on disk. They include the model taxonomy, path, and any metadata needed for loading or running the model.
Model configs are stored in the database.
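As a rough, hypothetical sketch (the class name is made up; the `Literal`/`Field` pattern mirrors the real classes in `configs/`), the taxonomy fields on a config class look like this:
```py
from typing import Literal

from pydantic import Field

# Base classes and taxonomy enums come from the model_manager package (imports elided here).
class ExampleVAE_Checkpoint_Config(Checkpoint_Config_Base, Config_Base):
    type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
    format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
    base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)
```
Together, `type`, `format`, and `base` (plus any variant) act as the discriminator that uniquely identifies this class within the config union.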
### Model Identification
When a user installs a model, the system attempts to identify it by trying each registered config class until one matches.
**Config Classes** (`configs/`):
- All config classes inherit from `Config_Base`, either directly or indirectly via some intermediary class (e.g., `Diffusers_Config_Base`, `Checkpoint_Config_Base`, or something narrower).
- Each config class represents a specific, unique combination of `type`, `format`, `base`, and optional `variant`.
- Config classes must implement `from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict) -> Self`. This method inspects the model on disk and raises `NotAMatchError` if the model doesn't match the config class, or returns an instance of the config class if it does (a minimal sketch appears at the end of this section).
- `ModelOnDisk` is a helper class that abstracts the model weights. It should be the entrypoint for inspecting the model (e.g., loading state dicts).
- Override fields allow users to provide hints (e.g., when differentiating between SD1/SD2/SDXL VAEs with identical structures).
**Identification Process**:
1. `ModelConfigFactory.from_model_on_disk()` is called with a path to the model.
2. The factory iterates through all registered config classes, calling `from_model_on_disk()` on each.
3. Each config class inspects the model (state dict keys, tensor shapes, config files, etc.).
4. If a match is found, the config instance is returned. If multiple matches are found, they are prioritized (e.g., main models over LoRAs).
5. If no match is found, an `Unknown_Config` is returned as a fallback.
**Utilities** (`identification_utils.py`):
- `NotAMatchError`: Exception raised when a model doesn't match a config class.
- `get_config_dict_or_raise()`: Load JSON config files from diffusers/transformers models.
- `raise_for_class_name()`: Validate class names in config files.
- `raise_for_override_fields()`: Validate user-provided override fields against the config schema.
- `state_dict_has_any_keys_*()`: Helpers for inspecting state dict keys.
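Putting the pieces together, a `from_model_on_disk()` implementation usually looks roughly like the sketch below. The helper calls mirror those used by existing config classes; the discriminating key is hypothetical.
```py
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
    raise_if_not_file(mod)                           # wrong layout (file vs directory) -> not this config
    raise_for_override_fields(cls, override_fields)  # validate user-provided hints
    state_dict = mod.load_state_dict()
    if "example_block.0.weight" not in state_dict:   # hypothetical discriminating key
        raise NotAMatchError("state dict does not look like this model type")
    return cls(**override_fields)
```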
### Model Loading
Model loaders handle instantiating models from disk into memory.
**Loader Classes** (`load/model_loaders/`):
- Loaders register themselves with a decorator `@ModelLoaderRegistry.register(base=..., type=..., format=...)`. The `type`, `format` and `base` indicate which config classes the loader can handle.
- Each loader implements `_load_model(self, config: AnyModelConfig, submodel_type: Optional[SubModelType]) -> AnyModel` (a sketch follows the loading process below).
- Loaders are responsible for:
- Loading model weights from the config's path.
- Instantiating the correct model class (often using diffusers, transformers, or custom implementations).
- Returning the in-memory model representation.
**Model Cache** (`load/model_cache/`):
> This system typically does not require changes to support new model types, but it is important to understand how it works.
- Manages models in memory with RAM and VRAM limits.
- Handles moving models between CPU (storage device) and GPU (execution device).
- Implements LRU eviction for RAM and smallest-first offload for VRAM.
- Supports partial loading for large models on CUDA.
- Thread-safe with locks on all public methods.
**Loading Process**:
1. The appropriate loader is selected based on the model config's `base`, `type`, and `format` attributes.
2. The loader's `_load_model()` method is called with the model config.
3. The loaded model is added to the model cache via `ModelCache.put()`.
4. When needed, the model is moved into VRAM via `ModelCache.get()` and `ModelCache.lock()`.
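A hedged sketch of a loader class; the decorator and method signature follow the description above, while the model class and its load call are placeholders.
```py
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Checkpoint)
class ExampleModelLoader(ModelLoader):
    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        # Load the weights from the path recorded in the config and return the
        # in-memory model; the model cache manages RAM/VRAM from here on.
        return SomePlaceholderModelClass.load_from_checkpoint(config.path)  # placeholder
```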
### Model Running
Model running is architecture-specific and typically implemented in folders adjacent to `model_manager/`.
Inference code doesn't necessarily follow any specific pattern, and doesn't interact directly with the model management system except to receive model configs and loaded models.
At a high level, when a node needs to run a model, it will:
- Receive a model identifier as an input or constant. This is typically the model's database ID (aka the `key`).
- The node will use the `InvocationContext` API to load the model. The request is dispatched to the model manager, which loads the model and returns a loader exposing a context manager that yields the in-memory model, mediating VRAM/RAM management as needed.
- The node will run inference with the loaded model using whatever patterns or libraries it needs (a sketch follows this list).
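A hedged sketch of this pattern, mirroring how the FLUX.2 reference-image extension loads its VAE earlier in this changeset (`context.models.load(...)` plus the `model_on_device()` context manager); the inference call itself is a placeholder.
```py
def invoke(self, context: InvocationContext):
    # self.model is a ModelIdentifierField input on the node
    loaded_model = context.models.load(self.model)
    with loaded_model.model_on_device() as (_, model):
        # The cache has moved the weights onto the execution device at this point.
        result = run_architecture_specific_inference(model)  # placeholder
    # ... build and return the node's output from `result`
```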
## Common Tasks
### Task 1: Improving Identification for a Supported Model Type
When identification fails or produces incorrect results for a model that should be supported, you may need to refine the identification logic.
**Steps**:
1. Obtain the failing model file or directory.
2. Create a test case for it, following the instructions in `tests/model_identification/README.md`.
3. Review the relevant config class in `configs/` (e.g., `configs/lora.py` for LoRA models).
4. Examine the `from_model_on_disk()` method for some existing models to understand the patterns for identification logic.
5. Inspect the failing model's files and structure:
- For checkpoint files: Load the state dict and examine keys and tensor shapes (see the snippet after these steps).
- For diffusers models: Examine the config files and directory structure.
6. Update the identification logic to handle the new model variant. Common approaches:
- Check for specific state dict keys or key patterns.
- Inspect tensor shapes (e.g., `state_dict[key].shape`).
- Parse config files for class names or configuration values.
- Use helper functions from `identification_utils.py`.
7. Run the test suite to verify the new logic works and doesn't break existing tests: `pytest tests/model_identification/test_identification.py`.
- Make sure you have installed the test dependencies (e.g. `uv pip install -e ".[dev,test]"`).
- If the model type is complex or has multiple variants, consider adding more test cases to cover edge cases.
8. If, after successfully adding identification support for the model, it still doesn't work, you may need to update loading and/or inference code as well.
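For step 5, a quick way to inspect a checkpoint is to dump its keys and tensor shapes, for example (the path is a placeholder):
```py
from safetensors.torch import load_file

state_dict = load_file("path/to/failing_model.safetensors")
for key, tensor in sorted(state_dict.items()):
    print(key, tuple(tensor.shape))
```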
**Key Files**:
- Config class: `configs/<model_type>.py`
- Identification utilities: `configs/identification_utils.py`
- Taxonomy: `taxonomy.py`
- Test README: `tests/model_identification/README.md`
### Task 2: Adding Support for a New Model Type
Adding a new model type requires implementing identification and loading logic. Inference and new nodes ("invocations") may be required if the model type doesn't fit into existing architectures or nodes.
**Steps**:
#### 1. Define Taxonomy
- Add a new `ModelType` enum value in `taxonomy.py` if needed.
- Determine the appropriate `BaseModelType` (or use `Any` if not architecture-specific).
- Add a new `ModelFormat` if the model uses a unique storage format.
You may need to add other attributes, depending on the model.
#### 2. Implement Config Class
- Create a new config file in `configs/` (e.g., `configs/new_model.py`).
- Define a config class inheriting from `Config_Base` and appropriate format base class:
- `Diffusers_Config_Base` for diffusers-style models.
- `Checkpoint_Config_Base` for single-file checkpoint models.
- Define `type`, `format`, and `base` as `Literal` fields with defaults. Remember, these must uniquely identify the config class.
- Implement `from_model_on_disk()`:
- Validate the model is the correct format (file vs directory).
- Inspect state dict keys, tensor shapes, or config files.
- Raise `NotAMatchError` if the model doesn't match.
- Extract any additional metadata needed (e.g., variant, prediction type).
- Return an instance of the config class.
- Register the config in `configs/factory.py`:
- Add the config class to the `AnyModelConfig` union.
- Add an `Annotated[YourConfig, YourConfig.get_tag()]` entry.
#### 3. Implement Loader Class
- Create a new loader file in `load/model_loaders/` (e.g., `load/model_loaders/new_model.py`).
- Define a loader class inheriting from `ModelLoader`.
- Decorate with `@ModelLoaderRegistry.register(base=..., type=..., format=...)`.
- Implement `_load_model()`:
- Load model weights from `config.path`.
- Instantiate the model using the appropriate library (diffusers, transformers, or custom).
- Handle `submodel_type` if the model has submodels (e.g., text encoders, VAE).
- Return the in-memory model representation.
#### 4. Add Tests
Follow the instructions in `tests/model_identification/README.md`.
#### 5. Implement Inference and Nodes (if needed)
- If the model type requires new inference logic, implement it in an appropriate location.
- Create nodes for the model if it doesn't fit into existing nodes. Search for subclasses of `BaseInvocation` for many examples.
#### 6. Frontend Support
##### Workflows tab
Typically, you will not need to do anything for the model to work in the Workflow Editor. When you define the node's model field, you can provide constraints on which types of models are selectable, and the UI will automatically filter the model list based on the model taxonomy.
For example, this field definition in a node will allow users to select only "main" (pipeline) Stable Diffusion 1.x or 2.x models:
```py
model: ModelIdentifierField = InputField(
ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2],
ui_model_type=ModelType.Main,
)
```
This same pattern works for any combination of `type`, `base`, `format`, and `variant`.
##### Canvas / Generate tabs
The Canvas and Generate tabs use graphs internally, but they don't expose the full graph editor UI. Instead, they provide a simplified interface for common tasks.
They use "graph builder" functions, which take the user's selected settings and build a graph behind the scenes. We have one graph builder for each model architecture.
Updating or adding a graph builder can be a bit complex, and you'd likely need to update other UI components and state management to support the new model type.
The SDXL graph builder is a good example: `invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts`

View File

@@ -28,17 +28,6 @@ if TYPE_CHECKING:
pass
class URLModelSource(BaseModel):
type: Literal[ModelSourceType.Url] = Field(default=ModelSourceType.Url)
url: str = Field(
description="The URL from which the model was installed.",
)
api_response: str | None = Field(
default=None,
description="The original API response from the source, as stringified JSON.",
)
class Config_Base(ABC, BaseModel):
"""
Abstract base class for model configurations. A model config describes a specific combination of model base, type and

View File

@@ -88,7 +88,9 @@ class ControlNet_Diffusers_Config_Base(Diffusers_Config_Base):
cls._validate_base(mod)
return cls(**override_fields)
repo_variant = {"repo_variant": override_fields.get("repo_variant", cls._get_repo_variant_or_raise(mod))}
args = override_fields | repo_variant
return cls(**args)
@classmethod
def _validate_base(cls, mod: ModelOnDisk) -> None:
@@ -228,3 +230,47 @@ class ControlNet_Checkpoint_SDXL_Config(ControlNet_Checkpoint_Config_Base, Confi
class ControlNet_Checkpoint_FLUX_Config(ControlNet_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
def _has_z_image_control_keys(state_dict: dict) -> bool:
"""Check if state dict contains Z-Image Control specific keys."""
z_image_control_keys = {"control_layers", "control_all_x_embedder", "control_noise_refiner"}
for key in state_dict.keys():
if isinstance(key, str):
prefix = key.split(".")[0]
if prefix in z_image_control_keys:
return True
return False
class ControlNet_Checkpoint_ZImage_Config(Checkpoint_Config_Base, Config_Base):
"""Model config for Z-Image Control adapter models (Safetensors checkpoint).
Z-Image Control models are standalone adapters containing only the control layers
(control_layers, control_all_x_embedder, control_noise_refiner) that extend
the base Z-Image transformer with spatial conditioning capabilities.
Supports: Canny, HED, Depth, Pose, MLSD.
Recommended control_context_scale: 0.65-0.80.
"""
type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet)
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
base: Literal[BaseModelType.ZImage] = Field(default=BaseModelType.ZImage)
default_settings: ControlAdapterDefaultSettings | None = Field(None)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_z_image_control(mod)
return cls(**override_fields)
@classmethod
def _validate_looks_like_z_image_control(cls, mod: ModelOnDisk) -> None:
state_dict = mod.load_state_dict()
if not _has_z_image_control_keys(state_dict):
raise NotAMatchError("state dict does not look like a Z-Image Control model")

View File

@@ -20,6 +20,7 @@ from invokeai.backend.model_manager.configs.controlnet import (
ControlNet_Checkpoint_SD1_Config,
ControlNet_Checkpoint_SD2_Config,
ControlNet_Checkpoint_SDXL_Config,
ControlNet_Checkpoint_ZImage_Config,
ControlNet_Diffusers_FLUX_Config,
ControlNet_Diffusers_SD1_Config,
ControlNet_Diffusers_SD2_Config,
@@ -43,30 +44,44 @@ from invokeai.backend.model_manager.configs.lora import (
LoRA_Diffusers_SD1_Config,
LoRA_Diffusers_SD2_Config,
LoRA_Diffusers_SDXL_Config,
LoRA_Diffusers_ZImage_Config,
LoRA_LyCORIS_FLUX_Config,
LoRA_LyCORIS_SD1_Config,
LoRA_LyCORIS_SD2_Config,
LoRA_LyCORIS_SDXL_Config,
LoRA_LyCORIS_ZImage_Config,
LoRA_OMI_FLUX_Config,
LoRA_OMI_SDXL_Config,
LoraModelDefaultSettings,
)
from invokeai.backend.model_manager.configs.main import (
Main_BnBNF4_FLUX_Config,
Main_Checkpoint_Flux2_Config,
Main_Checkpoint_FLUX_Config,
Main_Checkpoint_SD1_Config,
Main_Checkpoint_SD2_Config,
Main_Checkpoint_SDXL_Config,
Main_Checkpoint_SDXLRefiner_Config,
Main_Checkpoint_ZImage_Config,
Main_Diffusers_CogView4_Config,
Main_Diffusers_Flux2_Config,
Main_Diffusers_FLUX_Config,
Main_Diffusers_SD1_Config,
Main_Diffusers_SD2_Config,
Main_Diffusers_SD3_Config,
Main_Diffusers_SDXL_Config,
Main_Diffusers_SDXLRefiner_Config,
Main_Diffusers_ZImage_Config,
Main_GGUF_Flux2_Config,
Main_GGUF_FLUX_Config,
Main_GGUF_ZImage_Config,
MainModelDefaultSettings,
)
from invokeai.backend.model_manager.configs.qwen3_encoder import (
Qwen3Encoder_Checkpoint_Config,
Qwen3Encoder_GGUF_Config,
Qwen3Encoder_Qwen3Encoder_Config,
)
from invokeai.backend.model_manager.configs.siglip import SigLIP_Diffusers_Config
from invokeai.backend.model_manager.configs.spandrel import Spandrel_Checkpoint_Config
from invokeai.backend.model_manager.configs.t2i_adapter import (
@@ -84,10 +99,12 @@ from invokeai.backend.model_manager.configs.textual_inversion import (
)
from invokeai.backend.model_manager.configs.unknown import Unknown_Config
from invokeai.backend.model_manager.configs.vae import (
VAE_Checkpoint_Flux2_Config,
VAE_Checkpoint_FLUX_Config,
VAE_Checkpoint_SD1_Config,
VAE_Checkpoint_SD2_Config,
VAE_Checkpoint_SDXL_Config,
VAE_Diffusers_Flux2_Config,
VAE_Diffusers_SD1_Config,
VAE_Diffusers_SDXL_Config,
)
@@ -137,29 +154,43 @@ AnyModelConfig = Annotated[
Annotated[Main_Diffusers_SDXL_Config, Main_Diffusers_SDXL_Config.get_tag()],
Annotated[Main_Diffusers_SDXLRefiner_Config, Main_Diffusers_SDXLRefiner_Config.get_tag()],
Annotated[Main_Diffusers_SD3_Config, Main_Diffusers_SD3_Config.get_tag()],
Annotated[Main_Diffusers_FLUX_Config, Main_Diffusers_FLUX_Config.get_tag()],
Annotated[Main_Diffusers_Flux2_Config, Main_Diffusers_Flux2_Config.get_tag()],
Annotated[Main_Diffusers_CogView4_Config, Main_Diffusers_CogView4_Config.get_tag()],
Annotated[Main_Diffusers_ZImage_Config, Main_Diffusers_ZImage_Config.get_tag()],
# Main (Pipeline) - checkpoint format
# IMPORTANT: FLUX.2 must be checked BEFORE FLUX.1 because FLUX.2 has specific validation
# that will reject FLUX.1 models, but FLUX.1 validation may incorrectly match FLUX.2 models
Annotated[Main_Checkpoint_SD1_Config, Main_Checkpoint_SD1_Config.get_tag()],
Annotated[Main_Checkpoint_SD2_Config, Main_Checkpoint_SD2_Config.get_tag()],
Annotated[Main_Checkpoint_SDXL_Config, Main_Checkpoint_SDXL_Config.get_tag()],
Annotated[Main_Checkpoint_SDXLRefiner_Config, Main_Checkpoint_SDXLRefiner_Config.get_tag()],
Annotated[Main_Checkpoint_Flux2_Config, Main_Checkpoint_Flux2_Config.get_tag()],
Annotated[Main_Checkpoint_FLUX_Config, Main_Checkpoint_FLUX_Config.get_tag()],
Annotated[Main_Checkpoint_ZImage_Config, Main_Checkpoint_ZImage_Config.get_tag()],
# Main (Pipeline) - quantized formats
# IMPORTANT: FLUX.2 must be checked BEFORE FLUX.1 because FLUX.2 has specific validation
# that will reject FLUX.1 models, but FLUX.1 validation may incorrectly match FLUX.2 models
Annotated[Main_BnBNF4_FLUX_Config, Main_BnBNF4_FLUX_Config.get_tag()],
Annotated[Main_GGUF_Flux2_Config, Main_GGUF_Flux2_Config.get_tag()],
Annotated[Main_GGUF_FLUX_Config, Main_GGUF_FLUX_Config.get_tag()],
Annotated[Main_GGUF_ZImage_Config, Main_GGUF_ZImage_Config.get_tag()],
# VAE - checkpoint format
Annotated[VAE_Checkpoint_SD1_Config, VAE_Checkpoint_SD1_Config.get_tag()],
Annotated[VAE_Checkpoint_SD2_Config, VAE_Checkpoint_SD2_Config.get_tag()],
Annotated[VAE_Checkpoint_SDXL_Config, VAE_Checkpoint_SDXL_Config.get_tag()],
Annotated[VAE_Checkpoint_FLUX_Config, VAE_Checkpoint_FLUX_Config.get_tag()],
Annotated[VAE_Checkpoint_Flux2_Config, VAE_Checkpoint_Flux2_Config.get_tag()],
# VAE - diffusers format
Annotated[VAE_Diffusers_SD1_Config, VAE_Diffusers_SD1_Config.get_tag()],
Annotated[VAE_Diffusers_SDXL_Config, VAE_Diffusers_SDXL_Config.get_tag()],
Annotated[VAE_Diffusers_Flux2_Config, VAE_Diffusers_Flux2_Config.get_tag()],
# ControlNet - checkpoint format
Annotated[ControlNet_Checkpoint_SD1_Config, ControlNet_Checkpoint_SD1_Config.get_tag()],
Annotated[ControlNet_Checkpoint_SD2_Config, ControlNet_Checkpoint_SD2_Config.get_tag()],
Annotated[ControlNet_Checkpoint_SDXL_Config, ControlNet_Checkpoint_SDXL_Config.get_tag()],
Annotated[ControlNet_Checkpoint_FLUX_Config, ControlNet_Checkpoint_FLUX_Config.get_tag()],
Annotated[ControlNet_Checkpoint_ZImage_Config, ControlNet_Checkpoint_ZImage_Config.get_tag()],
# ControlNet - diffusers format
Annotated[ControlNet_Diffusers_SD1_Config, ControlNet_Diffusers_SD1_Config.get_tag()],
Annotated[ControlNet_Diffusers_SD2_Config, ControlNet_Diffusers_SD2_Config.get_tag()],
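
The "FLUX.2 before FLUX.1" ordering comments matter because identification is effectively first-match-wins: each config class raises NotAMatchError when its heuristics fail, and the first class that accepts the model defines its config. A minimal sketch of that pattern, assuming a simple loop over candidate classes (illustrative only, not the factory's actual code):

class NotAMatch(Exception):
    pass

def identify(candidate_config_classes, mod):
    # Try candidates in declared order; the first one that does not raise wins.
    for config_cls in candidate_config_classes:
        try:
            return config_cls.from_model_on_disk(mod, {})
        except NotAMatch:
            continue
    raise NotAMatch("no config class accepted the model")

Placing a stricter matcher (FLUX.2) ahead of a looser one (FLUX.1) keeps the looser heuristics from claiming the model first.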
@@ -170,6 +201,7 @@ AnyModelConfig = Annotated[
Annotated[LoRA_LyCORIS_SD2_Config, LoRA_LyCORIS_SD2_Config.get_tag()],
Annotated[LoRA_LyCORIS_SDXL_Config, LoRA_LyCORIS_SDXL_Config.get_tag()],
Annotated[LoRA_LyCORIS_FLUX_Config, LoRA_LyCORIS_FLUX_Config.get_tag()],
Annotated[LoRA_LyCORIS_ZImage_Config, LoRA_LyCORIS_ZImage_Config.get_tag()],
# LoRA - OMI format
Annotated[LoRA_OMI_SDXL_Config, LoRA_OMI_SDXL_Config.get_tag()],
Annotated[LoRA_OMI_FLUX_Config, LoRA_OMI_FLUX_Config.get_tag()],
@@ -178,11 +210,16 @@ AnyModelConfig = Annotated[
Annotated[LoRA_Diffusers_SD2_Config, LoRA_Diffusers_SD2_Config.get_tag()],
Annotated[LoRA_Diffusers_SDXL_Config, LoRA_Diffusers_SDXL_Config.get_tag()],
Annotated[LoRA_Diffusers_FLUX_Config, LoRA_Diffusers_FLUX_Config.get_tag()],
Annotated[LoRA_Diffusers_ZImage_Config, LoRA_Diffusers_ZImage_Config.get_tag()],
# ControlLoRA - diffusers format
Annotated[ControlLoRA_LyCORIS_FLUX_Config, ControlLoRA_LyCORIS_FLUX_Config.get_tag()],
# T5 Encoder - all formats
Annotated[T5Encoder_T5Encoder_Config, T5Encoder_T5Encoder_Config.get_tag()],
Annotated[T5Encoder_BnBLLMint8_Config, T5Encoder_BnBLLMint8_Config.get_tag()],
# Qwen3 Encoder
Annotated[Qwen3Encoder_Qwen3Encoder_Config, Qwen3Encoder_Qwen3Encoder_Config.get_tag()],
Annotated[Qwen3Encoder_Checkpoint_Config, Qwen3Encoder_Checkpoint_Config.get_tag()],
Annotated[Qwen3Encoder_GGUF_Config, Qwen3Encoder_GGUF_Config.get_tag()],
# TI - file format
Annotated[TI_File_SD1_Config, TI_File_SD1_Config.get_tag()],
Annotated[TI_File_SD2_Config, TI_File_SD2_Config.get_tag()],
@@ -333,7 +370,11 @@ class ModelConfigFactory:
# For directories, do a quick file count check with early exit
total_files = 0
# Ignore hidden files and directories
paths_to_check = (p for p in path.rglob("*") if not p.name.startswith("."))
paths_to_check = (
p
for p in path.rglob("*")
if not p.name.startswith(".") and not any(part.startswith(".") for part in p.parts)
)
for item in paths_to_check:
if item.is_file():
total_files += 1
@@ -473,7 +514,9 @@ class ModelConfigFactory:
# Now do any post-processing needed for specific model types/bases/etc.
match config.type:
case ModelType.Main:
config.default_settings = MainModelDefaultSettings.from_base(config.base)
# Pass variant if available (e.g., for Flux2 models)
variant = getattr(config, "variant", None)
config.default_settings = MainModelDefaultSettings.from_base(config.base, variant)
case ModelType.ControlNet | ModelType.T2IAdapter | ModelType.ControlLoRa:
config.default_settings = ControlAdapterDefaultSettings.from_model_name(config.name)
case ModelType.LoRA:


@@ -150,11 +150,16 @@ class LoRA_LyCORIS_Config_Base(LoRA_Config_Base):
@classmethod
def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
# First rule out ControlLoRA and Diffusers LoRA
# First rule out ControlLoRA
flux_format = _get_flux_lora_format(mod)
if flux_format in [FluxLoRAFormat.Control]:
raise NotAMatchError("model looks like Control LoRA")
# If it's a recognized Flux LoRA format (Kohya, Diffusers, OneTrainer, AIToolkit, XLabs, etc.),
# it's valid and we skip the heuristic check
if flux_format is not None:
return
# Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA.
# Some main models have these keys, likely due to the creator merging in a LoRA.
has_key_with_lora_prefix = state_dict_has_any_keys_starting_with(
@@ -217,6 +222,73 @@ class LoRA_LyCORIS_FLUX_Config(LoRA_LyCORIS_Config_Base, Config_Base):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base):
"""Model config for Z-Image LoRA models in LyCORIS format."""
base: Literal[BaseModelType.ZImage] = Field(default=BaseModelType.ZImage)
@classmethod
def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
"""Z-Image LoRAs have different key patterns than SD/SDXL LoRAs.
Z-Image LoRAs use keys like:
- diffusion_model.layers.X.attention.to_k.lora_down.weight (DoRA format)
- diffusion_model.layers.X.attention.to_k.lora_A.weight (PEFT format)
- diffusion_model.layers.X.attention.to_k.dora_scale (DoRA scale)
"""
state_dict = mod.load_state_dict()
# Check for Z-Image specific LoRA patterns
has_z_image_lora_keys = state_dict_has_any_keys_starting_with(
state_dict,
{
"diffusion_model.layers.", # Z-Image S3-DiT layer pattern
},
)
# Also check for LoRA weight suffixes (various formats)
has_lora_suffix = state_dict_has_any_keys_ending_with(
state_dict,
{
"lora_A.weight",
"lora_B.weight",
"lora_down.weight",
"lora_up.weight",
"dora_scale",
},
)
if has_z_image_lora_keys and has_lora_suffix:
return
raise NotAMatchError("model does not match Z-Image LoRA heuristics")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
"""Z-Image LoRAs are identified by their diffusion_model.layers structure.
Z-Image uses S3-DiT architecture with layer names like:
- diffusion_model.layers.0.attention.to_k.lora_A.weight
- diffusion_model.layers.0.feed_forward.w1.lora_A.weight
"""
state_dict = mod.load_state_dict()
# Check for Z-Image transformer layer patterns
# Z-Image uses diffusion_model.layers.X structure (unlike Flux which uses double_blocks/single_blocks)
has_z_image_keys = state_dict_has_any_keys_starting_with(
state_dict,
{
"diffusion_model.layers.", # Z-Image S3-DiT layer pattern
},
)
# If it looks like a Z-Image LoRA, return ZImage base
if has_z_image_keys:
return BaseModelType.ZImage
raise NotAMatchError("model does not look like a Z-Image LoRA")
class ControlAdapter_Config_Base(ABC, BaseModel):
default_settings: ControlAdapterDefaultSettings | None = Field(None)
@@ -320,3 +392,9 @@ class LoRA_Diffusers_SDXL_Config(LoRA_Diffusers_Config_Base, Config_Base):
class LoRA_Diffusers_FLUX_Config(LoRA_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
class LoRA_Diffusers_ZImage_Config(LoRA_Diffusers_Config_Base, Config_Base):
"""Model config for Z-Image LoRA models in Diffusers format."""
base: Literal[BaseModelType.ZImage] = Field(default=BaseModelType.ZImage)
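
The LyCORIS Z-Image heuristic requires both signals at once: the S3-DiT layer prefix and a recognized LoRA/DoRA weight suffix. A quick sketch of that two-part check on toy keys (helpers re-implemented inline for illustration):

LORA_SUFFIXES = ("lora_A.weight", "lora_B.weight", "lora_down.weight", "lora_up.weight", "dora_scale")

def looks_like_z_image_lora(keys: list[str]) -> bool:
    has_layer_prefix = any(k.startswith("diffusion_model.layers.") for k in keys)
    has_lora_suffix = any(k.endswith(LORA_SUFFIXES) for k in keys)
    return has_layer_prefix and has_lora_suffix

assert looks_like_z_image_lora(["diffusion_model.layers.0.attention.to_k.lora_A.weight"])
# A full Z-Image main model has the layer prefix but no LoRA suffixes:
assert not looks_like_z_image_lora(["diffusion_model.layers.0.attention.to_k.weight"])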


@@ -23,6 +23,7 @@ from invokeai.backend.model_manager.configs.identification_utils import (
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
Flux2VariantType,
FluxVariantType,
ModelFormat,
ModelType,
@@ -52,7 +53,11 @@ class MainModelDefaultSettings(BaseModel):
model_config = ConfigDict(extra="forbid")
@classmethod
def from_base(cls, base: BaseModelType) -> Self | None:
def from_base(
cls,
base: BaseModelType,
variant: Flux2VariantType | FluxVariantType | ModelVariantType | None = None,
) -> Self | None:
match base:
case BaseModelType.StableDiffusion1:
return cls(width=512, height=512)
@@ -60,6 +65,16 @@ class MainModelDefaultSettings(BaseModel):
return cls(width=768, height=768)
case BaseModelType.StableDiffusionXL:
return cls(width=1024, height=1024)
case BaseModelType.ZImage:
return cls(steps=9, cfg_scale=1.0, width=1024, height=1024)
case BaseModelType.Flux2:
# Different defaults based on variant
if variant == Flux2VariantType.Klein9BBase:
# Undistilled base model needs more steps
return cls(steps=28, cfg_scale=1.0, width=1024, height=1024)
else:
# Distilled models (Klein 4B, Klein 9B) use fewer steps
return cls(steps=4, cfg_scale=1.0, width=1024, height=1024)
case _:
# TODO(psyche): Do we want defaults for other base types?
return None
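
A short usage sketch of the variant-aware defaults, assuming the names above are imported from this module (the comments restate the values chosen in the match statement):

defaults = MainModelDefaultSettings.from_base(BaseModelType.Flux2, Flux2VariantType.Klein4B)
# steps=4, cfg_scale=1.0, 1024x1024 - distilled Klein models need few steps

defaults = MainModelDefaultSettings.from_base(BaseModelType.Flux2, Flux2VariantType.Klein9BBase)
# steps=28 - the undistilled base model needs more steps

defaults = MainModelDefaultSettings.from_base(BaseModelType.StableDiffusion1)
# width=512, height=512 - the variant argument is simply ignored for bases that don't use it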
@@ -111,6 +126,47 @@ def _has_main_keys(state_dict: dict[str | int, Any]) -> bool:
return False
def _has_z_image_keys(state_dict: dict[str | int, Any]) -> bool:
"""Check if state dict contains Z-Image S3-DiT transformer keys.
This function returns True only for Z-Image main models, not LoRAs.
LoRAs are excluded by checking for LoRA-specific weight suffixes.
"""
# Z-Image specific keys that distinguish it from other models
z_image_specific_keys = {
"cap_embedder", # Caption embedder - unique to Z-Image
"context_refiner", # Context refiner blocks
"cap_pad_token", # Caption padding token
}
# LoRA-specific suffixes - if present, this is a LoRA not a main model
lora_suffixes = (
".lora_down.weight",
".lora_up.weight",
".lora_A.weight",
".lora_B.weight",
".dora_scale",
)
for key in state_dict.keys():
if isinstance(key, int):
continue
# If we find any LoRA-specific keys, this is not a main model
if key.endswith(lora_suffixes):
return False
# Check for Z-Image specific key prefixes
# Handle both direct keys (cap_embedder.0.weight) and
# ComfyUI-style keys (model.diffusion_model.cap_embedder.0.weight)
key_parts = key.split(".")
for part in key_parts:
if part in z_image_specific_keys:
return True
return False
class Main_SD_Checkpoint_Config_Base(Checkpoint_Config_Base, Main_Config_Base):
"""Model config for main checkpoint models."""
@@ -225,6 +281,108 @@ class Main_Checkpoint_SDXLRefiner_Config(Main_SD_Checkpoint_Config_Base, Config_
base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(default=BaseModelType.StableDiffusionXLRefiner)
def _is_flux2_model(state_dict: dict[str | int, Any]) -> bool:
"""Check if state dict is a FLUX.2 model by examining context_embedder dimensions.
FLUX.2 Klein uses Qwen3 encoder with larger context dimension:
- FLUX.1: context_in_dim = 4096 (T5)
- FLUX.2 Klein 4B: context_in_dim = 7680 (3×Qwen3-4B hidden size)
- FLUX.2 Klein 9B: context_in_dim = 12288 (3×Qwen3-8B hidden size)
Also checks for FLUX.2-specific 32-channel latent space (in_channels=128 after packing).
"""
# Check context_embedder input dimension (most reliable)
# Weight shape: [hidden_size, context_in_dim]
for key in {"context_embedder.weight", "model.diffusion_model.context_embedder.weight"}:
if key in state_dict:
weight = state_dict[key]
if hasattr(weight, "shape") and len(weight.shape) >= 2:
context_in_dim = weight.shape[1]
# FLUX.2 has context_in_dim > 4096 (Qwen3 vs T5)
if context_in_dim > 4096:
return True
# Also check in_channels - FLUX.2 uses 128 (32 latent channels × 4 packing)
for key in {"img_in.weight", "model.diffusion_model.img_in.weight"}:
if key in state_dict:
in_channels = state_dict[key].shape[1]
# FLUX.2 uses 128 in_channels (32 latent channels × 4)
# FLUX.1 uses 64 in_channels (16 latent channels × 4)
if in_channels == 128:
return True
return False
def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | None:
"""Determine FLUX.2 variant from state dict.
Distinguishes between Klein 4B and Klein 9B based on context embedding dimension:
- Klein 4B: context_in_dim = 7680 (3 × Qwen3-4B hidden_size 2560)
- Klein 9B: context_in_dim = 12288 (3 × Qwen3-8B hidden_size 4096)
Note: Klein 9B Base (undistilled) also has context_in_dim = 12288 but is rare.
We default to Klein9B (distilled) for all 9B models since GGUF models may not
include guidance embedding keys needed to distinguish them.
Supports both BFL format (checkpoint) and diffusers format keys:
- BFL format: txt_in.weight (context embedder)
- Diffusers format: context_embedder.weight
"""
# Context dimensions for each variant
KLEIN_4B_CONTEXT_DIM = 7680 # 3 × 2560
KLEIN_9B_CONTEXT_DIM = 12288 # 3 × 4096
# Check context_embedder to determine variant
# Support both BFL format (txt_in.weight) and diffusers format (context_embedder.weight)
context_keys = {
# Diffusers format
"context_embedder.weight",
"model.diffusion_model.context_embedder.weight",
# BFL format (used by checkpoint/GGUF models)
"txt_in.weight",
"model.diffusion_model.txt_in.weight",
}
for key in context_keys:
if key in state_dict:
weight = state_dict[key]
# Handle GGUF quantized tensors which use tensor_shape instead of shape
if hasattr(weight, "tensor_shape"):
shape = weight.tensor_shape
elif hasattr(weight, "shape"):
shape = weight.shape
else:
continue
if len(shape) >= 2:
context_in_dim = shape[1]
# Determine variant based on context dimension
if context_in_dim == KLEIN_9B_CONTEXT_DIM:
# Default to Klein9B (distilled) - the official/common 9B model
return Flux2VariantType.Klein9B
elif context_in_dim == KLEIN_4B_CONTEXT_DIM:
return Flux2VariantType.Klein4B
elif context_in_dim > 4096:
# Unknown FLUX.2 variant, default to 4B
return Flux2VariantType.Klein4B
# Check in_channels as backup - can only confirm it's FLUX.2, not which variant
for key in {"img_in.weight", "model.diffusion_model.img_in.weight"}:
if key in state_dict:
weight = state_dict[key]
# Handle GGUF quantized tensors
if hasattr(weight, "tensor_shape"):
in_channels = weight.tensor_shape[1]
elif hasattr(weight, "shape"):
in_channels = weight.shape[1]
else:
continue
if in_channels == 128:
# It's FLUX.2 but we can't determine which Klein variant, default to 4B
return Flux2VariantType.Klein4B
return None
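
Since only the second dimension of the context-embedder weight matters, the variant check can be exercised with a stand-in object. A small worked example (the first shape dimension is an arbitrary placeholder, and the variant strings stand in for Flux2VariantType members):

from types import SimpleNamespace

fake_state_dict = {"txt_in.weight": SimpleNamespace(shape=(3072, 7680))}  # BFL-format key

weight = fake_state_dict["txt_in.weight"]
shape = weight.tensor_shape if hasattr(weight, "tensor_shape") else weight.shape  # GGUF tensors expose tensor_shape
context_in_dim = shape[1]
if context_in_dim == 12288:
    variant = "Klein9B"
elif context_in_dim > 4096:
    variant = "Klein4B"  # 7680, or any other non-FLUX.1 width, defaults to 4B
else:
    variant = "not FLUX.2"
assert variant == "Klein4B"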
def _get_flux_variant(state_dict: dict[str | int, Any]) -> FluxVariantType | None:
# FLUX Model variant types are distinguished by input channels and the presence of certain keys.
@@ -298,8 +456,9 @@ class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Conf
@classmethod
def _validate_is_flux(cls, mod: ModelOnDisk) -> None:
state_dict = mod.load_state_dict()
if not state_dict_has_any_keys_exact(
mod.load_state_dict(),
state_dict,
{
"double_blocks.0.img_attn.norm.key_norm.scale",
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
@@ -307,6 +466,10 @@ class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Conf
):
raise NotAMatchError("state dict does not look like a FLUX checkpoint")
# Exclude FLUX.2 models - they have their own config class
if _is_flux2_model(state_dict):
raise NotAMatchError("model is a FLUX.2 model, not FLUX.1")
@classmethod
def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
# FLUX Model variant types are distinguished by input channels and the presence of certain keys.
@@ -340,6 +503,68 @@ class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Conf
raise NotAMatchError("state dict looks like GGUF quantized")
class Main_Checkpoint_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
"""Model config for FLUX.2 checkpoint models (e.g. Klein)."""
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
base: Literal[BaseModelType.Flux2] = Field(default=BaseModelType.Flux2)
variant: Flux2VariantType = Field()
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_main_model(mod)
cls._validate_is_flux2(mod)
cls._validate_does_not_look_like_bnb_quantized(mod)
cls._validate_does_not_look_like_gguf_quantized(mod)
variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
return cls(**override_fields, variant=variant)
@classmethod
def _validate_is_flux2(cls, mod: ModelOnDisk) -> None:
"""Validate that this is a FLUX.2 model, not FLUX.1."""
state_dict = mod.load_state_dict()
if not _is_flux2_model(state_dict):
raise NotAMatchError("state dict does not look like a FLUX.2 model")
@classmethod
def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
state_dict = mod.load_state_dict()
variant = _get_flux2_variant(state_dict)
if variant is None:
raise NotAMatchError("unable to determine FLUX.2 model variant from state dict")
return variant
@classmethod
def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
has_main_model_keys = _has_main_keys(mod.load_state_dict())
if not has_main_model_keys:
raise NotAMatchError("state dict does not look like a main model")
@classmethod
def _validate_does_not_look_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
has_bnb_nf4_keys = _has_bnb_nf4_keys(mod.load_state_dict())
if has_bnb_nf4_keys:
raise NotAMatchError("state dict looks like bnb quantized nf4")
@classmethod
def _validate_does_not_look_like_gguf_quantized(cls, mod: ModelOnDisk):
has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
if has_ggml_tensors:
raise NotAMatchError("state dict looks like GGUF quantized")
class Main_BnBNF4_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
"""Model config for main checkpoint models."""
@@ -407,6 +632,8 @@ class Main_GGUF_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Bas
cls._validate_looks_like_gguf_quantized(mod)
cls._validate_is_not_flux2(mod)
variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
return cls(**override_fields, variant=variant)
@@ -437,6 +664,195 @@ class Main_GGUF_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Bas
if not has_ggml_tensors:
raise NotAMatchError("state dict does not look like GGUF quantized")
@classmethod
def _validate_is_not_flux2(cls, mod: ModelOnDisk) -> None:
"""Validate that this is NOT a FLUX.2 model."""
state_dict = mod.load_state_dict()
if _is_flux2_model(state_dict):
raise NotAMatchError("model is a FLUX.2 model, not FLUX.1")
class Main_GGUF_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
"""Model config for GGUF-quantized FLUX.2 checkpoint models (e.g. Klein)."""
base: Literal[BaseModelType.Flux2] = Field(default=BaseModelType.Flux2)
format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
variant: Flux2VariantType = Field()
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_main_model(mod)
cls._validate_looks_like_gguf_quantized(mod)
cls._validate_is_flux2(mod)
variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
return cls(**override_fields, variant=variant)
@classmethod
def _validate_is_flux2(cls, mod: ModelOnDisk) -> None:
"""Validate that this is a FLUX.2 model, not FLUX.1."""
state_dict = mod.load_state_dict()
if not _is_flux2_model(state_dict):
raise NotAMatchError("state dict does not look like a FLUX.2 model")
@classmethod
def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
state_dict = mod.load_state_dict()
variant = _get_flux2_variant(state_dict)
if variant is None:
raise NotAMatchError("unable to determine FLUX.2 model variant from state dict")
return variant
@classmethod
def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
has_main_model_keys = _has_main_keys(mod.load_state_dict())
if not has_main_model_keys:
raise NotAMatchError("state dict does not look like a main model")
@classmethod
def _validate_looks_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
if not has_ggml_tensors:
raise NotAMatchError("state dict does not look like GGUF quantized")
class Main_Diffusers_FLUX_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
"""Model config for FLUX.1 models in diffusers format."""
base: Literal[BaseModelType.Flux] = Field(BaseModelType.Flux)
variant: FluxVariantType = Field()
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
# Check for FLUX-specific pipeline or transformer class names
raise_for_class_name(
common_config_paths(mod.path),
{
"FluxPipeline",
"FluxFillPipeline",
"FluxTransformer2DModel",
},
)
variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
return cls(
**override_fields,
variant=variant,
repo_variant=repo_variant,
)
@classmethod
def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
"""Determine the FLUX variant from the transformer config.
FLUX variants are distinguished by:
- in_channels: 64 for Dev/Schnell, 384 for DevFill
- guidance_embeds: True for Dev, False for Schnell
"""
transformer_config = get_config_dict_or_raise(mod.path / "transformer" / "config.json")
in_channels = transformer_config.get("in_channels", 64)
guidance_embeds = transformer_config.get("guidance_embeds", False)
# DevFill has 384 input channels
if in_channels == 384:
return FluxVariantType.DevFill
# Dev has guidance_embeds=True, Schnell has guidance_embeds=False
if guidance_embeds:
return FluxVariantType.Dev
else:
return FluxVariantType.Schnell
class Main_Diffusers_Flux2_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
"""Model config for FLUX.2 models in diffusers format (e.g. FLUX.2 Klein)."""
base: Literal[BaseModelType.Flux2] = Field(BaseModelType.Flux2)
variant: Flux2VariantType = Field()
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
# Check for FLUX.2-specific pipeline class names
raise_for_class_name(
common_config_paths(mod.path),
{
"Flux2KleinPipeline",
},
)
variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
return cls(
**override_fields,
variant=variant,
repo_variant=repo_variant,
)
@classmethod
def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
"""Determine the FLUX.2 variant from the transformer config.
FLUX.2 Klein uses Qwen3 text encoder with larger joint_attention_dim:
- Klein 4B: joint_attention_dim = 7680 (3×Qwen3-4B hidden size)
- Klein 9B/9B Base: joint_attention_dim = 12288 (3×Qwen3-8B hidden size)
To distinguish Klein 9B (distilled) from Klein 9B Base (undistilled),
we check guidance_embeds:
- Klein 9B (distilled): guidance_embeds = False (guidance is "baked in" during distillation)
- Klein 9B Base (undistilled): guidance_embeds = True (needs guidance at inference)
Note: The official BFL Klein 9B model is the distilled version with guidance_embeds=False.
"""
KLEIN_4B_CONTEXT_DIM = 7680 # 3 × 2560
KLEIN_9B_CONTEXT_DIM = 12288 # 3 × 4096
transformer_config = get_config_dict_or_raise(mod.path / "transformer" / "config.json")
joint_attention_dim = transformer_config.get("joint_attention_dim", 4096)
guidance_embeds = transformer_config.get("guidance_embeds", False)
# Determine variant based on joint_attention_dim
if joint_attention_dim == KLEIN_9B_CONTEXT_DIM:
# Check guidance_embeds to distinguish distilled from undistilled
# Klein 9B (distilled): guidance_embeds = False (guidance is baked in)
# Klein 9B Base (undistilled): guidance_embeds = True (needs guidance)
if guidance_embeds:
return Flux2VariantType.Klein9BBase
else:
return Flux2VariantType.Klein9B
elif joint_attention_dim == KLEIN_4B_CONTEXT_DIM:
return Flux2VariantType.Klein4B
elif joint_attention_dim > 4096:
# Unknown FLUX.2 variant, default to 4B
return Flux2VariantType.Klein4B
# Default to 4B
return Flux2VariantType.Klein4B
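
The diffusers-side decision reduces to a small table over two config.json fields. A sketch driven by a plain dict (variant strings stand in for Flux2VariantType members):

def flux2_variant_from_transformer_config(cfg: dict) -> str:
    joint_attention_dim = cfg.get("joint_attention_dim", 4096)
    guidance_embeds = cfg.get("guidance_embeds", False)
    if joint_attention_dim == 12288:
        # 9B family: guidance_embeds distinguishes the undistilled base from the distilled model.
        return "Klein9BBase" if guidance_embeds else "Klein9B"
    return "Klein4B"

assert flux2_variant_from_transformer_config({"joint_attention_dim": 12288, "guidance_embeds": True}) == "Klein9BBase"
assert flux2_variant_from_transformer_config({"joint_attention_dim": 7680}) == "Klein4B"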
class Main_SD_Diffusers_Config_Base(Diffusers_Config_Base, Main_Config_Base):
prediction_type: SchedulerPredictionType = Field()
@@ -657,3 +1073,92 @@ class Main_Diffusers_CogView4_Config(Diffusers_Config_Base, Main_Config_Base, Co
**override_fields,
repo_variant=repo_variant,
)
class Main_Diffusers_ZImage_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
"""Model config for Z-Image diffusers models (Z-Image-Turbo, Z-Image-Base, Z-Image-Edit)."""
base: Literal[BaseModelType.ZImage] = Field(BaseModelType.ZImage)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
# This check implies the base type - no further validation needed.
raise_for_class_name(
common_config_paths(mod.path),
{
"ZImagePipeline",
},
)
repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
return cls(
**override_fields,
repo_variant=repo_variant,
)
class Main_Checkpoint_ZImage_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
"""Model config for Z-Image single-file checkpoint models (safetensors, etc)."""
base: Literal[BaseModelType.ZImage] = Field(default=BaseModelType.ZImage)
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_z_image_model(mod)
cls._validate_does_not_look_like_gguf_quantized(mod)
return cls(**override_fields)
@classmethod
def _validate_looks_like_z_image_model(cls, mod: ModelOnDisk) -> None:
has_z_image_keys = _has_z_image_keys(mod.load_state_dict())
if not has_z_image_keys:
raise NotAMatchError("state dict does not look like a Z-Image model")
@classmethod
def _validate_does_not_look_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
if has_ggml_tensors:
raise NotAMatchError("state dict looks like GGUF quantized")
class Main_GGUF_ZImage_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
"""Model config for GGUF-quantized Z-Image transformer models."""
base: Literal[BaseModelType.ZImage] = Field(default=BaseModelType.ZImage)
format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_z_image_model(mod)
cls._validate_looks_like_gguf_quantized(mod)
return cls(**override_fields)
@classmethod
def _validate_looks_like_z_image_model(cls, mod: ModelOnDisk) -> None:
has_z_image_keys = _has_z_image_keys(mod.load_state_dict())
if not has_z_image_keys:
raise NotAMatchError("state dict does not look like a Z-Image model")
@classmethod
def _validate_looks_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
if not has_ggml_tensors:
raise NotAMatchError("state dict does not look like GGUF quantized")


@@ -0,0 +1,265 @@
import json
from typing import Any, Literal, Optional, Self
from pydantic import Field
from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
NotAMatchError,
raise_for_class_name,
raise_for_override_fields,
raise_if_not_dir,
raise_if_not_file,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType, Qwen3VariantType
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
def _has_qwen3_keys(state_dict: dict[str | int, Any]) -> bool:
"""Check if state dict contains Qwen3 model keys.
Supports both:
- PyTorch/diffusers format: model.layers.0., model.embed_tokens.weight
- GGUF/llama.cpp format: blk.0., token_embd.weight
"""
# PyTorch/diffusers format indicators
pytorch_indicators = ["model.layers.0.", "model.embed_tokens.weight"]
# GGUF/llama.cpp format indicators
gguf_indicators = ["blk.0.", "token_embd.weight"]
for key in state_dict.keys():
if isinstance(key, str):
# Check PyTorch format
for indicator in pytorch_indicators:
if key.startswith(indicator) or key == indicator:
return True
# Check GGUF format
for indicator in gguf_indicators:
if key.startswith(indicator) or key == indicator:
return True
return False
def _has_ggml_tensors(state_dict: dict[str | int, Any]) -> bool:
"""Check if state dict contains GGML tensors (GGUF quantized)."""
return any(isinstance(v, GGMLTensor) for v in state_dict.values())
def _get_qwen3_variant_from_state_dict(state_dict: dict[str | int, Any]) -> Optional[Qwen3VariantType]:
"""Determine Qwen3 variant (4B vs 8B) from state dict based on hidden_size.
The hidden_size can be determined from the embed_tokens.weight tensor shape:
- Qwen3 4B: hidden_size = 2560
- Qwen3 8B: hidden_size = 4096
For GGUF format, the key is 'token_embd.weight'.
For PyTorch format, the key is 'model.embed_tokens.weight'.
"""
# Hidden size thresholds
QWEN3_4B_HIDDEN_SIZE = 2560
QWEN3_8B_HIDDEN_SIZE = 4096
# Try to find embed_tokens weight
embed_key = None
for key in state_dict.keys():
if isinstance(key, str):
if key == "model.embed_tokens.weight" or key == "token_embd.weight":
embed_key = key
break
if embed_key is None:
return None
tensor = state_dict[embed_key]
# Get hidden_size from tensor shape
# Shape is [vocab_size, hidden_size]
if isinstance(tensor, GGMLTensor):
# GGUF tensor
if hasattr(tensor, "shape") and len(tensor.shape) >= 2:
hidden_size = tensor.shape[1]
else:
return None
elif hasattr(tensor, "shape"):
# PyTorch tensor
if len(tensor.shape) >= 2:
hidden_size = tensor.shape[1]
else:
return None
else:
return None
# Determine variant based on hidden_size
if hidden_size == QWEN3_4B_HIDDEN_SIZE:
return Qwen3VariantType.Qwen3_4B
elif hidden_size == QWEN3_8B_HIDDEN_SIZE:
return Qwen3VariantType.Qwen3_8B
else:
# Unknown size, default to 4B (more common)
return Qwen3VariantType.Qwen3_4B
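
The variant decision comes down to the second dimension of the embedding weight, whichever naming convention the file uses. A toy example with real tensors (the vocab-size dimension is a placeholder):

import torch

sd_pytorch = {"model.embed_tokens.weight": torch.empty(1000, 2560)}
sd_gguf_style = {"token_embd.weight": torch.empty(1000, 4096)}

def qwen3_hidden_size(state_dict: dict) -> int | None:
    for key in ("model.embed_tokens.weight", "token_embd.weight"):
        if key in state_dict:
            return state_dict[key].shape[1]
    return None

assert qwen3_hidden_size(sd_pytorch) == 2560     # -> Qwen3 4B
assert qwen3_hidden_size(sd_gguf_style) == 4096  # -> Qwen3 8B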
class Qwen3Encoder_Checkpoint_Config(Checkpoint_Config_Base, Config_Base):
"""Configuration for single-file Qwen3 Encoder models (safetensors)."""
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
variant: Qwen3VariantType = Field(description="Qwen3 model size variant (4B or 8B)")
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_qwen3_model(mod)
cls._validate_does_not_look_like_gguf_quantized(mod)
# Determine variant from state dict
variant = cls._get_variant_or_default(mod)
return cls(variant=variant, **override_fields)
@classmethod
def _get_variant_or_default(cls, mod: ModelOnDisk) -> Qwen3VariantType:
"""Get variant from state dict, defaulting to 4B if unknown."""
state_dict = mod.load_state_dict()
variant = _get_qwen3_variant_from_state_dict(state_dict)
return variant if variant is not None else Qwen3VariantType.Qwen3_4B
@classmethod
def _validate_looks_like_qwen3_model(cls, mod: ModelOnDisk) -> None:
has_qwen3_keys = _has_qwen3_keys(mod.load_state_dict())
if not has_qwen3_keys:
raise NotAMatchError("state dict does not look like a Qwen3 model")
@classmethod
def _validate_does_not_look_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
has_ggml = _has_ggml_tensors(mod.load_state_dict())
if has_ggml:
raise NotAMatchError("state dict looks like GGUF quantized")
class Qwen3Encoder_Qwen3Encoder_Config(Config_Base):
"""Configuration for Qwen3 Encoder models in a diffusers-like format.
The model weights are expected to be in a folder called text_encoder inside the model directory,
compatible with Qwen2VLForConditionalGeneration or similar architectures used by Z-Image.
"""
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
format: Literal[ModelFormat.Qwen3Encoder] = Field(default=ModelFormat.Qwen3Encoder)
variant: Qwen3VariantType = Field(description="Qwen3 model size variant (4B or 8B)")
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
# Exclude full pipeline models - these should be matched as main models, not just Qwen3 encoders.
# Full pipelines have model_index.json at root (diffusers format) or a transformer subfolder.
model_index_path = mod.path / "model_index.json"
transformer_path = mod.path / "transformer"
if model_index_path.exists() or transformer_path.exists():
raise NotAMatchError(
"directory looks like a full diffusers pipeline (has model_index.json or transformer folder), "
"not a standalone Qwen3 encoder"
)
# Check for text_encoder config - support both:
# 1. Full model structure: model_root/text_encoder/config.json
# 2. Standalone text_encoder download: model_root/config.json (when text_encoder subfolder is downloaded separately)
config_path_nested = mod.path / "text_encoder" / "config.json"
config_path_direct = mod.path / "config.json"
if config_path_nested.exists():
expected_config_path = config_path_nested
elif config_path_direct.exists():
expected_config_path = config_path_direct
else:
raise NotAMatchError(
f"unable to load config file(s): {{PosixPath('{config_path_nested}'): 'file does not exist'}}"
)
# Qwen3 uses Qwen2VLForConditionalGeneration or similar
raise_for_class_name(
expected_config_path,
{
"Qwen2VLForConditionalGeneration",
"Qwen2ForCausalLM",
"Qwen3ForCausalLM",
},
)
# Determine variant from config.json hidden_size
variant = cls._get_variant_from_config(expected_config_path)
return cls(variant=variant, **override_fields)
@classmethod
def _get_variant_from_config(cls, config_path) -> Qwen3VariantType:
"""Get variant from config.json based on hidden_size."""
QWEN3_4B_HIDDEN_SIZE = 2560
QWEN3_8B_HIDDEN_SIZE = 4096
try:
with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f)
hidden_size = config.get("hidden_size")
if hidden_size == QWEN3_8B_HIDDEN_SIZE:
return Qwen3VariantType.Qwen3_8B
elif hidden_size == QWEN3_4B_HIDDEN_SIZE:
return Qwen3VariantType.Qwen3_4B
else:
# Default to 4B for unknown sizes
return Qwen3VariantType.Qwen3_4B
except (json.JSONDecodeError, OSError):
return Qwen3VariantType.Qwen3_4B
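
The directory checks above accept two layouts and reject full pipelines. A compact sketch of that decision, assuming root is a pathlib.Path pointing at the candidate folder (the returned strings are just labels):

from pathlib import Path

def classify_qwen3_dir(root: Path) -> str:
    if (root / "model_index.json").exists() or (root / "transformer").exists():
        return "full diffusers pipeline - should match as a main model instead"
    if (root / "text_encoder" / "config.json").exists():
        return "full model layout (model_root/text_encoder/config.json)"
    if (root / "config.json").exists():
        return "standalone text_encoder download (config.json at the root)"
    return "no usable config.json - not a match"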
class Qwen3Encoder_GGUF_Config(Checkpoint_Config_Base, Config_Base):
"""Configuration for GGUF-quantized Qwen3 Encoder models."""
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
variant: Qwen3VariantType = Field(description="Qwen3 model size variant (4B or 8B)")
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_qwen3_model(mod)
cls._validate_looks_like_gguf_quantized(mod)
# Determine variant from state dict
variant = cls._get_variant_or_default(mod)
return cls(variant=variant, **override_fields)
@classmethod
def _get_variant_or_default(cls, mod: ModelOnDisk) -> Qwen3VariantType:
"""Get variant from state dict, defaulting to 4B if unknown."""
state_dict = mod.load_state_dict()
variant = _get_qwen3_variant_from_state_dict(state_dict)
return variant if variant is not None else Qwen3VariantType.Qwen3_4B
@classmethod
def _validate_looks_like_qwen3_model(cls, mod: ModelOnDisk) -> None:
has_qwen3_keys = _has_qwen3_keys(mod.load_state_dict())
if not has_qwen3_keys:
raise NotAMatchError("state dict does not look like a Qwen3 model")
@classmethod
def _validate_looks_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
has_ggml = _has_ggml_tensors(mod.load_state_dict())
if not has_ggml:
raise NotAMatchError("state dict does not look like GGUF quantized")


@@ -33,6 +33,25 @@ REGEX_TO_BASE: dict[str, BaseModelType] = {
}
def _is_flux2_vae(state_dict: dict[str | int, Any]) -> bool:
"""Check if state dict is a FLUX.2 VAE (AutoencoderKLFlux2).
FLUX.2 VAE can be identified by:
1. Batch Normalization layers (bn.running_mean, bn.running_var) - unique to FLUX.2
2. 32-dimensional latent space (decoder.conv_in has 32 input channels)
FLUX.1 VAE has 16-dimensional latent space and no BatchNorm layers.
"""
# Check for BN layer which is unique to FLUX.2 VAE
has_bn = "bn.running_mean" in state_dict or "bn.running_var" in state_dict
# Check for 32-channel latent space (FLUX.2 has 32, FLUX.1 has 16)
decoder_conv_in_key = "decoder.conv_in.weight"
has_32_latent_channels = decoder_conv_in_key in state_dict and state_dict[decoder_conv_in_key].shape[1] == 32
return has_bn or has_32_latent_channels
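
Either signal alone is enough to flag a FLUX.2 VAE; feeding toy dicts to the helper above shows both paths (the 512 output-channel count is a placeholder, only key names and the second conv dimension matter):

import torch

flux1_like = {"decoder.conv_in.weight": torch.empty(512, 16, 3, 3)}
flux2_like = {
    "bn.running_mean": torch.empty(32),                     # BatchNorm stats exist only in the FLUX.2 VAE
    "decoder.conv_in.weight": torch.empty(512, 32, 3, 3),   # 32 latent channels vs 16 for FLUX.1
}

assert not _is_flux2_vae(flux1_like)
assert _is_flux2_vae(flux2_like)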
class VAE_Checkpoint_Config_Base(Checkpoint_Config_Base):
"""Model config for standalone VAE models."""
@@ -61,8 +80,9 @@ class VAE_Checkpoint_Config_Base(Checkpoint_Config_Base):
@classmethod
def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None:
state_dict = mod.load_state_dict()
if not state_dict_has_any_keys_starting_with(
mod.load_state_dict(),
state_dict,
{
"encoder.conv_in",
"decoder.conv_in",
@@ -70,9 +90,30 @@ class VAE_Checkpoint_Config_Base(Checkpoint_Config_Base):
):
raise NotAMatchError("model does not match Checkpoint VAE heuristics")
# Exclude FLUX.2 VAEs - they have their own config class
if _is_flux2_vae(state_dict):
raise NotAMatchError("model is a FLUX.2 VAE, not a standard VAE")
@classmethod
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
# Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name
# First, try to identify by latent space dimensions (most reliable)
state_dict = mod.load_state_dict()
decoder_conv_in_key = "decoder.conv_in.weight"
if decoder_conv_in_key in state_dict:
latent_channels = state_dict[decoder_conv_in_key].shape[1]
if latent_channels == 16:
# Flux1 VAE has 16-dimensional latent space
return BaseModelType.Flux
elif latent_channels == 4:
# SD/SDXL VAE has 4-dimensional latent space
# Try to distinguish SD1/SD2/SDXL by name, fallback to SD1
for regexp, base in REGEX_TO_BASE.items():
if re.search(regexp, mod.path.name, re.IGNORECASE):
return base
# Default to SD1 if we can't determine from name
return BaseModelType.StableDiffusion1
# Fallback: guess based on name
for regexp, base in REGEX_TO_BASE.items():
if re.search(regexp, mod.path.name, re.IGNORECASE):
return base
@@ -96,6 +137,44 @@ class VAE_Checkpoint_FLUX_Config(VAE_Checkpoint_Config_Base, Config_Base):
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
class VAE_Checkpoint_Flux2_Config(Checkpoint_Config_Base, Config_Base):
"""Model config for FLUX.2 VAE checkpoint models (AutoencoderKLFlux2)."""
type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
base: Literal[BaseModelType.Flux2] = Field(default=BaseModelType.Flux2)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_file(mod)
raise_for_override_fields(cls, override_fields)
cls._validate_looks_like_vae(mod)
cls._validate_is_flux2_vae(mod)
return cls(**override_fields)
@classmethod
def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None:
if not state_dict_has_any_keys_starting_with(
mod.load_state_dict(),
{
"encoder.conv_in",
"decoder.conv_in",
},
):
raise NotAMatchError("model does not match Checkpoint VAE heuristics")
@classmethod
def _validate_is_flux2_vae(cls, mod: ModelOnDisk) -> None:
"""Validate that this is a FLUX.2 VAE, not FLUX.1."""
state_dict = mod.load_state_dict()
if not _is_flux2_vae(state_dict):
raise NotAMatchError("state dict does not look like a FLUX.2 VAE")
class VAE_Diffusers_Config_Base(Diffusers_Config_Base):
"""Model config for standalone VAE models (diffusers version)."""
@@ -161,3 +240,26 @@ class VAE_Diffusers_SD1_Config(VAE_Diffusers_Config_Base, Config_Base):
class VAE_Diffusers_SDXL_Config(VAE_Diffusers_Config_Base, Config_Base):
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
class VAE_Diffusers_Flux2_Config(Diffusers_Config_Base, Config_Base):
"""Model config for FLUX.2 VAE models in diffusers format (AutoencoderKLFlux2)."""
type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
base: Literal[BaseModelType.Flux2] = Field(default=BaseModelType.Flux2)
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
raise_if_not_dir(mod)
raise_for_override_fields(cls, override_fields)
raise_for_class_name(
common_config_paths(mod.path),
{
"AutoencoderKLFlux2",
},
)
return cls(**override_fields)


@@ -55,6 +55,21 @@ def synchronized(method: Callable[..., Any]) -> Callable[..., Any]:
return wrapper
def record_activity(method: Callable[..., Any]) -> Callable[..., Any]:
"""A decorator that records activity after a method completes successfully.
Note: This decorator should be applied to methods that already hold self._lock.
"""
@wraps(method)
def wrapper(self, *args, **kwargs):
result = method(self, *args, **kwargs)
self._record_activity()
return result
return wrapper
@dataclass
class CacheEntrySnapshot:
cache_key: str
@@ -132,6 +147,7 @@ class ModelCache:
storage_device: torch.device | str = "cpu",
log_memory_usage: bool = False,
logger: Optional[Logger] = None,
keep_alive_minutes: float = 0,
):
"""Initialize the model RAM cache.
@@ -151,6 +167,7 @@ class ModelCache:
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
behaviour.
:param logger: InvokeAILogger to use (otherwise creates one)
:param keep_alive_minutes: How long to keep models in cache after last use (in minutes). 0 means keep indefinitely.
"""
self._enable_partial_loading = enable_partial_loading
self._keep_ram_copy_of_weights = keep_ram_copy_of_weights
@@ -182,6 +199,12 @@ class ModelCache:
self._on_cache_miss_callbacks: set[CacheMissCallback] = set()
self._on_cache_models_cleared_callbacks: set[CacheModelsClearedCallback] = set()
# Keep-alive timeout support
self._keep_alive_minutes = keep_alive_minutes
self._last_activity_time: Optional[float] = None
self._timeout_timer: Optional[threading.Timer] = None
self._shutdown_event = threading.Event()
def on_cache_hit(self, cb: CacheHitCallback) -> Callable[[], None]:
self._on_cache_hit_callbacks.add(cb)
@@ -190,7 +213,7 @@ class ModelCache:
return unsubscribe
def on_cache_miss(self, cb: CacheHitCallback) -> Callable[[], None]:
def on_cache_miss(self, cb: CacheMissCallback) -> Callable[[], None]:
self._on_cache_miss_callbacks.add(cb)
def unsubscribe() -> None:
@@ -217,8 +240,82 @@ class ModelCache:
def stats(self, stats: CacheStats) -> None:
"""Set the CacheStats object for collecting cache statistics."""
self._stats = stats
# Populate the cache size in the stats object when it's set
if self._stats is not None:
self._stats.cache_size = self._ram_cache_size_bytes
def _record_activity(self) -> None:
"""Record model activity and reset the timeout timer if configured.
Note: This method should only be called when self._lock is already held.
"""
if self._keep_alive_minutes <= 0:
return
self._last_activity_time = time.time()
# Cancel any existing timer
if self._timeout_timer is not None:
self._timeout_timer.cancel()
# Start a new timer
timeout_seconds = self._keep_alive_minutes * 60
self._timeout_timer = threading.Timer(timeout_seconds, self._on_timeout)
# Set as daemon so it doesn't prevent application shutdown
self._timeout_timer.daemon = True
self._timeout_timer.start()
self._logger.debug(f"Model cache activity recorded. Timeout set to {self._keep_alive_minutes} minutes.")
@synchronized
@record_activity
def _on_timeout(self) -> None:
"""Called when the keep-alive timeout expires. Clears the model cache."""
if self._shutdown_event.is_set():
return
# Double-check if there has been activity since the timer was set
# This handles the race condition where activity occurred just before the timer fired
if self._last_activity_time is not None and self._keep_alive_minutes > 0:
elapsed_minutes = (time.time() - self._last_activity_time) / 60
if elapsed_minutes < self._keep_alive_minutes:
# Activity occurred, don't clear cache
self._logger.debug(
f"Model cache timeout fired but activity detected {elapsed_minutes:.2f} minutes ago. "
f"Skipping cache clear."
)
return
# Check if there are any unlocked models that can be cleared
unlocked_models = [key for key, entry in self._cached_models.items() if not entry.is_locked]
if len(unlocked_models) > 0:
self._logger.info(
f"Model cache keep-alive timeout of {self._keep_alive_minutes} minutes expired. "
f"Clearing {len(unlocked_models)} unlocked model(s) from cache."
)
# Clear the cache by requesting a very large amount of space.
# This is the same logic used by the "Clear Model Cache" button.
# Using 1000 GB ensures all unlocked models are removed.
self._make_room_internal(1000 * GB)
elif len(self._cached_models) > 0:
# All models are locked, don't log at info level
self._logger.debug(
f"Model cache timeout fired but all {len(self._cached_models)} model(s) are locked. "
f"Skipping cache clear."
)
else:
self._logger.debug("Model cache timeout fired but cache is already empty.")
@synchronized
def shutdown(self) -> None:
"""Shutdown the model cache, cancelling any pending timers."""
self._shutdown_event.set()
if self._timeout_timer is not None:
self._timeout_timer.cancel()
self._timeout_timer = None
@synchronized
@record_activity
def put(self, key: str, model: AnyModel) -> None:
"""Add a model to the cache."""
if key in self._cached_models:
@@ -228,7 +325,7 @@ class ModelCache:
return
size = calc_model_size_by_data(self._logger, model)
self.make_room(size)
self._make_room_internal(size)
# Inject custom modules into the model.
if isinstance(model, torch.nn.Module):
@@ -272,6 +369,7 @@ class ModelCache:
return overview
@synchronized
@record_activity
def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
"""Retrieve a model from the cache.
@@ -309,9 +407,11 @@ class ModelCache:
self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
for cb in self._on_cache_hit_callbacks:
cb(model_key=key, cache_snapshot=self._get_cache_snapshot())
return cache_entry
@synchronized
@record_activity
def lock(self, cache_entry: CacheRecord, working_mem_bytes: Optional[int]) -> None:
"""Lock a model for use and move it into VRAM."""
if cache_entry.key not in self._cached_models:
@@ -348,6 +448,7 @@ class ModelCache:
self._log_cache_state()
@synchronized
@record_activity
def unlock(self, cache_entry: CacheRecord) -> None:
"""Unlock a model."""
if cache_entry.key not in self._cached_models:
@@ -691,6 +792,10 @@ class ModelCache:
external references to the model, there's nothing that the cache can do about it, and those models will not be
garbage-collected.
"""
self._make_room_internal(bytes_needed)
def _make_room_internal(self, bytes_needed: int) -> None:
"""Internal implementation of make_room(). Assumes the lock is already held."""
self._logger.debug(f"Making room for {bytes_needed / MB:.2f}MB of RAM.")
self._log_cache_state(title="Before dropping models:")
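
The keep-alive logic is a restartable one-shot timer: every cache access re-arms it, and when it finally fires the cache is emptied unless activity happened in the meantime. A standalone sketch of the same pattern, independent of ModelCache (class and method names here are illustrative):

import threading
import time

class KeepAlive:
    # Restartable keep-alive timeout; not the ModelCache API, just the pattern it uses.

    def __init__(self, timeout_seconds: float, on_expire) -> None:
        self._timeout = timeout_seconds
        self._on_expire = on_expire
        self._timer: threading.Timer | None = None
        self._last_activity = 0.0

    def touch(self) -> None:
        # Record activity and re-arm the one-shot timer.
        self._last_activity = time.time()
        if self._timer is not None:
            self._timer.cancel()
        self._timer = threading.Timer(self._timeout, self._fire)
        self._timer.daemon = True  # never block interpreter shutdown
        self._timer.start()

    def _fire(self) -> None:
        # Guard against the race where activity arrived just before the timer fired.
        if time.time() - self._last_activity < self._timeout:
            return
        self._on_expire()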


@@ -0,0 +1,40 @@
import torch
from diffusers.models.normalization import RMSNorm as DiffusersRMSNorm
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
class CustomDiffusersRMSNorm(DiffusersRMSNorm, CustomModuleMixin):
"""Custom wrapper for diffusers RMSNorm that supports device autocasting for partial model loading."""
def _autocast_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, hidden_states.device) if self.weight is not None else None
bias = cast_to_device(self.bias, hidden_states.device) if self.bias is not None else None
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
if weight is not None:
# convert into half-precision if necessary
if weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(weight.dtype)
hidden_states = hidden_states * weight
if bias is not None:
hidden_states = hidden_states + bias
else:
hidden_states = hidden_states.to(input_dtype)
return hidden_states
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if len(self._patches_and_weights) > 0:
raise RuntimeError("DiffusersRMSNorm layers do not support patches")
if self._device_autocasting_enabled:
return self._autocast_forward(hidden_states)
else:
return super().forward(hidden_states)


@@ -0,0 +1,25 @@
import torch
import torch.nn.functional as F
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)
class CustomLayerNorm(torch.nn.LayerNorm, CustomModuleMixin):
"""Custom wrapper for torch.nn.LayerNorm that supports device autocasting for partial model loading."""
def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device) if self.weight is not None else None
bias = cast_to_device(self.bias, input.device) if self.bias is not None else None
return F.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
def forward(self, input: torch.Tensor) -> torch.Tensor:
if len(self._patches_and_weights) > 0:
raise RuntimeError("LayerNorm layers do not support patches")
if self._device_autocasting_enabled:
return self._autocast_forward(input)
else:
return super().forward(input)


@@ -1,14 +1,18 @@
from typing import TypeVar
import torch
from diffusers.models.normalization import RMSNorm as DiffusersRMSNorm
from invokeai.backend.flux.modules.layers import RMSNorm
from invokeai.backend.flux.modules.layers import RMSNorm as FluxRMSNorm
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_conv1d import (
CustomConv1d,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_conv2d import (
CustomConv2d,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_diffusers_rms_norm import (
CustomDiffusersRMSNorm,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_embedding import (
CustomEmbedding,
)
@@ -18,6 +22,9 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custo
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_group_norm import (
CustomGroupNorm,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_layer_norm import (
CustomLayerNorm,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import (
CustomLinear,
)
@@ -31,7 +38,9 @@ AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]]
torch.nn.Conv2d: CustomConv2d,
torch.nn.GroupNorm: CustomGroupNorm,
torch.nn.Embedding: CustomEmbedding,
RMSNorm: CustomFluxRMSNorm,
torch.nn.LayerNorm: CustomLayerNorm,
FluxRMSNorm: CustomFluxRMSNorm,
DiffusersRMSNorm: CustomDiffusersRMSNorm,
}
try:
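
One common way a type-to-type mapping like AUTOCAST_MODULE_TYPE_MAPPING is consumed is by re-assigning each module's class to its autocast-aware subclass at load time. The sketch below shows that general pattern only; it assumes the custom classes are layout-compatible subclasses with class-level defaults for their extra state, and the loader's actual wiring may differ:

import torch

def apply_custom_classes(model: torch.nn.Module, mapping: dict[type, type]) -> None:
    # Swap each module's class for its custom subclass, keeping parameters and buffers intact.
    for module in model.modules():
        custom_cls = mapping.get(type(module))
        if custom_cls is not None:
            module.__class__ = custom_cls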


@@ -45,12 +45,13 @@ class CogView4DiffusersModel(GenericDiffusersLoader):
model_path,
torch_dtype=dtype,
variant=variant,
local_files_only=True,
)
except OSError as e:
if variant and "no file named" in str(
e
): # try without the variant, just in case user's preferences changed
result = load_class.from_pretrained(model_path, torch_dtype=dtype)
result = load_class.from_pretrained(model_path, torch_dtype=dtype, local_files_only=True)
else:
raise e

File diff suppressed because it is too large


@@ -37,12 +37,14 @@ class GenericDiffusersLoader(ModelLoader):
repo_variant = config.repo_variant if isinstance(config, Diffusers_Config_Base) else None
variant = repo_variant.value if repo_variant else None
try:
result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant)
result: AnyModel = model_class.from_pretrained(
model_path, torch_dtype=self._torch_dtype, variant=variant, local_files_only=True
)
except OSError as e:
if variant and "no file named" in str(
e
): # try without the variant, just in case user's preferences changed
result = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype)
result = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, local_files_only=True)
else:
raise e
return result


@@ -41,8 +41,13 @@ from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_u
is_state_dict_likely_in_flux_onetrainer_format,
lora_model_from_flux_onetrainer_state_dict,
)
from invokeai.backend.patches.lora_conversions.flux_xlabs_lora_conversion_utils import (
is_state_dict_likely_in_flux_xlabs_format,
lora_model_from_flux_xlabs_state_dict,
)
from invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
from invokeai.backend.patches.lora_conversions.z_image_lora_conversion_utils import lora_model_from_z_image_state_dict
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.LoRA, format=ModelFormat.OMI)
@@ -117,6 +122,8 @@ class LoRALoader(ModelLoader):
model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict=state_dict):
model = lora_model_from_flux_aitoolkit_state_dict(state_dict=state_dict)
elif is_state_dict_likely_in_flux_xlabs_format(state_dict=state_dict):
model = lora_model_from_flux_xlabs_state_dict(state_dict=state_dict)
else:
raise ValueError("LoRA model is in unsupported FLUX format")
else:
@@ -124,6 +131,10 @@ class LoRALoader(ModelLoader):
elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
# Currently, we don't apply any conversions for SD1 and SD2 LoRA models.
model = lora_model_from_sd_state_dict(state_dict=state_dict)
elif self._model_base == BaseModelType.ZImage:
# Z-Image LoRAs use diffusers PEFT format with transformer and/or Qwen3 encoder layers.
# We set alpha=None to use rank as alpha (common default).
model = lora_model_from_z_image_state_dict(state_dict=state_dict, alpha=None)
else:
raise ValueError(f"Unsupported LoRA base model: {self._model_base}")


@@ -38,5 +38,6 @@ class OnnyxDiffusersModel(GenericDiffusersLoader):
model_path,
torch_dtype=self._torch_dtype,
variant=variant,
local_files_only=True,
)
return result


@@ -80,12 +80,13 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
model_path,
torch_dtype=self._torch_dtype,
variant=variant,
local_files_only=True,
)
except OSError as e:
if variant and "no file named" in str(
e
): # try without the variant, just in case user's preferences changed
result = load_class.from_pretrained(model_path, torch_dtype=self._torch_dtype)
result = load_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, local_files_only=True)
else:
raise e

File diff suppressed because it is too large.


@@ -10,7 +10,7 @@ import onnxruntime as ort
import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast
from transformers import CLIPTokenizer, PreTrainedTokenizerBase, T5Tokenizer, T5TokenizerFast
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
@@ -73,6 +73,10 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
# relative to the text encoder that it's used with, so shouldn't matter too much, but we should fix this at some
# point.
return len(model)
elif isinstance(model, PreTrainedTokenizerBase):
# Catch-all for other tokenizer types (e.g., Qwen2Tokenizer, Qwen3Tokenizer).
# Tokenizers are small relative to models, so returning 0 is acceptable.
return 0
else:
# TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
# supported model types.
@@ -156,6 +160,7 @@ def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, var
(".msgpack",), # flax
(".ckpt",), # tf
(".h5",), # tf2
(".gguf",), # gguf quantized
]
for file_format in formats:
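calc_model_size_by_fs estimates a model's footprint by checking one weight-file format at a time; the added (".gguf",) entry lets GGUF-quantized checkpoints be sized the same way as the other formats. A rough standalone sketch of the idea (the leading format entries and the selection logic here are assumptions, not the real function, which also handles variants and subfolders):

# Standalone sketch: size a model by summing the first weight format found on disk.
from pathlib import Path


def estimate_model_size(model_dir: Path) -> int:
    # Ordered by preference; the first group with any matching files wins.
    formats: list[tuple[str, ...]] = [
        (".safetensors",),
        (".bin",),
        (".onnx",),
        (".msgpack",),  # flax
        (".ckpt",),     # tf
        (".h5",),       # tf2
        (".gguf",),     # gguf quantized
    ]
    for suffixes in formats:
        files = [f for f in model_dir.rglob("*") if f.suffix in suffixes]
        if files:
            return sum(f.stat().st_size for f in files)
    return 0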


@@ -95,13 +95,15 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
self,
variant: Optional[ModelRepoVariant] = None,
subfolder: Optional[Path] = None,
subfolders: Optional[List[Path]] = None,
session: Optional[Session] = None,
) -> List[RemoteModelFile]:
"""
Return list of downloadable files, filtering by variant and subfolder, if any.
Return list of downloadable files, filtering by variant and subfolder(s), if any.
:param variant: Return model files needed to reconstruct the indicated variant
:param subfolder: Return model files from the designated subfolder only
:param subfolder: Return model files from the designated subfolder only (deprecated, use subfolders)
:param subfolders: Return model files from the designated subfolders
:param session: A request.Session object used for internet-free testing
Note that there is special variant-filtering behavior here:
@@ -111,10 +113,15 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
session = session or Session()
configure_http_backend(backend_factory=lambda: session) # used in testing
paths = filter_files([x.path for x in self.files], variant, subfolder) # all files in the model
prefix = f"{subfolder}/" if subfolder else ""
paths = filter_files([x.path for x in self.files], variant, subfolder, subfolders) # all files in the model
# Determine prefix for model_index.json check - only applies for single subfolder
prefix = ""
if subfolder and not subfolders:
prefix = f"{subfolder}/"
# the next step reads model_index.json to determine which subdirectories belong
# to the model
# to the model (only for single subfolder case)
if Path(f"{prefix}model_index.json") in paths:
url = hf_hub_url(self.id, filename="model_index.json", subfolder=str(subfolder) if subfolder else None)
resp = session.get(url)
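download_urls now accepts a subfolders list alongside the deprecated single subfolder, and the model_index.json lookup above is only attempted in the single-subfolder case. The sketch below illustrates the kind of prefix filtering that a subfolders list enables; filter_by_subfolders is an illustrative stand-in, not the project's filter_files:

# Standalone sketch: keep only files under the requested subfolders (or all files if none given).
from pathlib import Path
from typing import Optional


def filter_by_subfolders(paths: list[Path], subfolders: Optional[list[Path]] = None) -> list[Path]:
    if not subfolders:
        return list(paths)
    prefixes = tuple(f"{s.as_posix()}/" for s in subfolders)
    return [p for p in paths if p.as_posix().startswith(prefixes)]


if __name__ == "__main__":
    files = [
        Path("text_encoder/model.safetensors"),
        Path("tokenizer/tokenizer.json"),
        Path("vae/diffusion_pytorch_model.safetensors"),
    ]
    # Keeps only the text_encoder and tokenizer files.
    print(filter_by_subfolders(files, [Path("text_encoder"), Path("tokenizer")]))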


@@ -84,6 +84,9 @@ class ModelOnDisk:
path = self.resolve_weight_file(path)
if path in self._state_dict_cache:
return self._state_dict_cache[path]
with SilenceWarnings():
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
scan_result = scan_file_path(path)


@@ -690,6 +690,178 @@ flux_fill = StarterModel(
)
# endregion
# region FLUX.2 Klein
flux2_vae = StarterModel(
name="FLUX.2 VAE",
base=BaseModelType.Flux2,
source="black-forest-labs/FLUX.2-klein-4B::vae",
description="FLUX.2 VAE (16-channel, same architecture as FLUX.1 VAE). ~335MB",
type=ModelType.VAE,
)
flux2_klein_qwen3_4b_encoder = StarterModel(
name="FLUX.2 Klein Qwen3 4B Encoder",
base=BaseModelType.Any,
source="black-forest-labs/FLUX.2-klein-4B::text_encoder+tokenizer",
description="Qwen3 4B text encoder for FLUX.2 Klein 4B (also compatible with Z-Image). ~8GB",
type=ModelType.Qwen3Encoder,
)
flux2_klein_qwen3_8b_encoder = StarterModel(
name="FLUX.2 Klein Qwen3 8B Encoder",
base=BaseModelType.Any,
source="black-forest-labs/FLUX.2-klein-9B::text_encoder+tokenizer",
description="Qwen3 8B text encoder for FLUX.2 Klein 9B models. ~16GB",
type=ModelType.Qwen3Encoder,
)
flux2_klein_4b = StarterModel(
name="FLUX.2 Klein 4B (Diffusers)",
base=BaseModelType.Flux2,
source="black-forest-labs/FLUX.2-klein-4B",
description="FLUX.2 Klein 4B in Diffusers format - includes transformer, VAE and Qwen3 encoder. ~10GB",
type=ModelType.Main,
)
flux2_klein_4b_single = StarterModel(
name="FLUX.2 Klein 4B",
base=BaseModelType.Flux2,
source="https://huggingface.co/black-forest-labs/FLUX.2-klein-4B/resolve/main/flux-2-klein-4b.safetensors",
description="FLUX.2 Klein 4B standalone transformer. Installs with VAE and Qwen3 4B encoder. ~8GB",
type=ModelType.Main,
dependencies=[flux2_vae, flux2_klein_qwen3_4b_encoder],
)
flux2_klein_4b_fp8 = StarterModel(
name="FLUX.2 Klein 4B (FP8)",
base=BaseModelType.Flux2,
source="https://huggingface.co/black-forest-labs/FLUX.2-klein-4b-fp8/resolve/main/flux-2-klein-4b-fp8.safetensors",
description="FLUX.2 Klein 4B FP8 quantized - smaller and faster. Installs with VAE and Qwen3 4B encoder. ~4GB",
type=ModelType.Main,
dependencies=[flux2_vae, flux2_klein_qwen3_4b_encoder],
)
flux2_klein_9b = StarterModel(
name="FLUX.2 Klein 9B (Diffusers)",
base=BaseModelType.Flux2,
source="black-forest-labs/FLUX.2-klein-9B",
description="FLUX.2 Klein 9B in Diffusers format - includes transformer, VAE and Qwen3 encoder. ~20GB",
type=ModelType.Main,
)
flux2_klein_9b_fp8 = StarterModel(
name="FLUX.2 Klein 9B (FP8)",
base=BaseModelType.Flux2,
source="https://huggingface.co/black-forest-labs/FLUX.2-klein-9b-fp8/resolve/main/flux-2-klein-9b-fp8.safetensors",
description="FLUX.2 Klein 9B FP8 quantized - more efficient than full precision. Installs with VAE and Qwen3 8B encoder. ~9.5GB",
type=ModelType.Main,
dependencies=[flux2_vae, flux2_klein_qwen3_8b_encoder],
)
flux2_klein_4b_gguf_q4 = StarterModel(
name="FLUX.2 Klein 4B (GGUF Q4)",
base=BaseModelType.Flux2,
source="https://huggingface.co/unsloth/FLUX.2-klein-4B-GGUF/resolve/main/flux-2-klein-4b-Q4_K_M.gguf",
description="FLUX.2 Klein 4B GGUF Q4_K_M quantized - runs on 6-8GB VRAM. Installs with VAE and Qwen3 4B encoder. ~2.6GB",
type=ModelType.Main,
format=ModelFormat.GGUFQuantized,
dependencies=[flux2_vae, flux2_klein_qwen3_4b_encoder],
)
flux2_klein_4b_gguf_q8 = StarterModel(
name="FLUX.2 Klein 4B (GGUF Q8)",
base=BaseModelType.Flux2,
source="https://huggingface.co/unsloth/FLUX.2-klein-4B-GGUF/resolve/main/flux-2-klein-4b-Q8_0.gguf",
description="FLUX.2 Klein 4B GGUF Q8_0 quantized - higher quality than Q4. Installs with VAE and Qwen3 4B encoder. ~4.3GB",
type=ModelType.Main,
format=ModelFormat.GGUFQuantized,
dependencies=[flux2_vae, flux2_klein_qwen3_4b_encoder],
)
flux2_klein_9b_gguf_q4 = StarterModel(
name="FLUX.2 Klein 9B (GGUF Q4)",
base=BaseModelType.Flux2,
source="https://huggingface.co/unsloth/FLUX.2-klein-9B-GGUF/resolve/main/flux-2-klein-9b-Q4_K_M.gguf",
description="FLUX.2 Klein 9B GGUF Q4_K_M quantized - runs on 12GB+ VRAM. Installs with VAE and Qwen3 8B encoder. ~5.8GB",
type=ModelType.Main,
format=ModelFormat.GGUFQuantized,
dependencies=[flux2_vae, flux2_klein_qwen3_8b_encoder],
)
flux2_klein_9b_gguf_q8 = StarterModel(
name="FLUX.2 Klein 9B (GGUF Q8)",
base=BaseModelType.Flux2,
source="https://huggingface.co/unsloth/FLUX.2-klein-9B-GGUF/resolve/main/flux-2-klein-9b-Q8_0.gguf",
description="FLUX.2 Klein 9B GGUF Q8_0 quantized - higher quality than Q4. Installs with VAE and Qwen3 8B encoder. ~10GB",
type=ModelType.Main,
format=ModelFormat.GGUFQuantized,
dependencies=[flux2_vae, flux2_klein_qwen3_8b_encoder],
)
# endregion
# region Z-Image
z_image_qwen3_encoder = StarterModel(
name="Z-Image Qwen3 Text Encoder",
base=BaseModelType.Any,
source="Tongyi-MAI/Z-Image-Turbo::text_encoder+tokenizer",
description="Qwen3 4B text encoder with tokenizer for Z-Image (full precision). ~8GB",
type=ModelType.Qwen3Encoder,
)
z_image_qwen3_encoder_quantized = StarterModel(
name="Z-Image Qwen3 Text Encoder (quantized)",
base=BaseModelType.Any,
source="https://huggingface.co/worstplayer/Z-Image_Qwen_3_4b_text_encoder_GGUF/resolve/main/Qwen_3_4b-Q6_K.gguf",
description="Qwen3 4B text encoder for Z-Image quantized to GGUF Q6_K format. ~3.3GB",
type=ModelType.Qwen3Encoder,
format=ModelFormat.GGUFQuantized,
)
z_image_turbo = StarterModel(
name="Z-Image Turbo",
base=BaseModelType.ZImage,
source="Tongyi-MAI/Z-Image-Turbo",
description="Z-Image Turbo - fast 6B parameter text-to-image model with 8 inference steps. Supports bilingual prompts (English & Chinese). ~13GB",
type=ModelType.Main,
)
z_image_turbo_quantized = StarterModel(
name="Z-Image Turbo (quantized)",
base=BaseModelType.ZImage,
source="https://huggingface.co/leejet/Z-Image-Turbo-GGUF/resolve/main/z_image_turbo-Q4_K.gguf",
description="Z-Image Turbo quantized to GGUF Q4_K format. Requires standalone Qwen3 text encoder and Flux VAE. ~4GB",
type=ModelType.Main,
format=ModelFormat.GGUFQuantized,
dependencies=[z_image_qwen3_encoder_quantized, flux_vae],
)
z_image_turbo_q8 = StarterModel(
name="Z-Image Turbo (Q8)",
base=BaseModelType.ZImage,
source="https://huggingface.co/leejet/Z-Image-Turbo-GGUF/resolve/main/z_image_turbo-Q8_0.gguf",
description="Z-Image Turbo quantized to GGUF Q8_0 format. Higher quality, larger size. Requires standalone Qwen3 text encoder and Flux VAE. ~6.6GB",
type=ModelType.Main,
format=ModelFormat.GGUFQuantized,
dependencies=[z_image_qwen3_encoder_quantized, flux_vae],
)
z_image_controlnet_union = StarterModel(
name="Z-Image ControlNet Union",
base=BaseModelType.ZImage,
source="https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.1/resolve/main/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors",
description="Unified ControlNet for Z-Image Turbo supporting Canny, HED, Depth, Pose, MLSD, and Inpainting modes.",
type=ModelType.ControlNet,
)
z_image_controlnet_tile = StarterModel(
name="Z-Image ControlNet Tile",
base=BaseModelType.ZImage,
source="https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.1/resolve/main/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors",
description="Dedicated Tile ControlNet for Z-Image Turbo. Useful for upscaling and adding detail. ~6.7GB",
type=ModelType.ControlNet,
)
# endregion
# List of starter models, displayed on the frontend.
# The order/sort of this list is not changed by the frontend - set it how you want it here.
STARTER_MODELS: list[StarterModel] = [
@@ -763,9 +935,28 @@ STARTER_MODELS: list[StarterModel] = [
flux_redux,
llava_onevision,
flux_fill,
flux2_vae,
flux2_klein_4b,
flux2_klein_4b_single,
flux2_klein_4b_fp8,
flux2_klein_9b,
flux2_klein_9b_fp8,
flux2_klein_4b_gguf_q4,
flux2_klein_4b_gguf_q8,
flux2_klein_9b_gguf_q4,
flux2_klein_9b_gguf_q8,
flux2_klein_qwen3_4b_encoder,
flux2_klein_qwen3_8b_encoder,
cogview4,
flux_krea,
flux_krea_quantized,
z_image_turbo,
z_image_turbo_quantized,
z_image_turbo_q8,
z_image_qwen3_encoder,
z_image_qwen3_encoder_quantized,
z_image_controlnet_union,
z_image_controlnet_tile,
]
sd1_bundle: list[StarterModel] = [
@@ -820,10 +1011,26 @@ flux_bundle: list[StarterModel] = [
flux_krea_quantized,
]
zimage_bundle: list[StarterModel] = [
z_image_turbo_quantized,
z_image_qwen3_encoder_quantized,
z_image_controlnet_union,
z_image_controlnet_tile,
flux_vae,
]
flux2_klein_bundle: list[StarterModel] = [
flux2_klein_4b_gguf_q4,
flux2_vae,
flux2_klein_qwen3_4b_encoder,
]
STARTER_BUNDLES: dict[str, StarterModelBundle] = {
BaseModelType.StableDiffusion1: StarterModelBundle(name="Stable Diffusion 1.5", models=sd1_bundle),
BaseModelType.StableDiffusionXL: StarterModelBundle(name="SDXL", models=sdxl_bundle),
BaseModelType.Flux: StarterModelBundle(name="FLUX.1 dev", models=flux_bundle),
BaseModelType.Flux2: StarterModelBundle(name="FLUX.2 Klein", models=flux2_klein_bundle),
BaseModelType.ZImage: StarterModelBundle(name="Z-Image Turbo", models=zimage_bundle),
}
assert len(STARTER_MODELS) == len({m.source for m in STARTER_MODELS}), "Duplicate starter models"
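The starter-model sources above follow a repo_id::subfolder convention, apparently with "+" joining multiple subfolders (e.g. ::text_encoder+tokenizer); each StarterModel can declare dependencies to install alongside it, and the final assert guarantees that sources are unique. Under those assumptions, a standalone sketch of splitting such a source string and flattening a bundle into a deduplicated install list (the Starter dataclass and both helpers are illustrative, not InvokeAI's installer code):

# Standalone sketch: parse "repo::sub1+sub2" sources and flatten bundle dependencies.
from dataclasses import dataclass


@dataclass(frozen=True)
class Starter:
    name: str
    source: str
    dependencies: tuple["Starter", ...] = ()


def split_source(source: str) -> tuple[str, list[str]]:
    # "repo::sub1+sub2" -> ("repo", ["sub1", "sub2"]); plain sources get no subfolders.
    repo, sep, subs = source.partition("::")
    return repo, subs.split("+") if sep else []


def flatten(bundle: list[Starter]) -> list[Starter]:
    # Expand dependencies depth-first and drop duplicate sources, preserving order.
    seen: set[str] = set()
    out: list[Starter] = []

    def visit(model: Starter) -> None:
        for dep in model.dependencies:
            visit(dep)
        if model.source not in seen:
            seen.add(model.source)
            out.append(model)

    for m in bundle:
        visit(m)
    return out


if __name__ == "__main__":
    vae = Starter("FLUX.2 VAE", "black-forest-labs/FLUX.2-klein-4B::vae")
    encoder = Starter("Qwen3 4B Encoder", "black-forest-labs/FLUX.2-klein-4B::text_encoder+tokenizer")
    klein_q4 = Starter("FLUX.2 Klein 4B (GGUF Q4)", "https://example.com/flux-2-klein-4b-Q4_K_M.gguf", (vae, encoder))
    print(split_source(encoder.source))
    # -> ('black-forest-labs/FLUX.2-klein-4B', ['text_encoder', 'tokenizer'])
    print([m.name for m in flatten([klein_q4, vae])])
    # -> dependencies first, and the VAE appears only once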

Some files were not shown because too many files have changed in this diff.