Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-04-23 03:00:31 -04:00)

Compare commits (255 commits)
Commits by SHA1, in listed order (author and date columns were not recorded in this mirror):

7a6760acad, 91c1e64f0b, cbe528eef7, 4081f8701e, 5649b60672, 714eeed74d, 656b50e6ad, 0263f4032c, dd87e0a946, 438eea1159,
d93e451831, efc7a262b7, a873ce0175, 9ee7baaba5, fb5c43a905, 0f69f4bb9a, 8a355e66fa, b811602b38, 0716b2fa75, 4d71609115,
0ecb903ae2, 736f4ffeb1, 2102b43edc, 5801e59e2b, 5fc950b745, 63dec985cd, 03cdd6df2e, 99f4070ce7, cf07f8be14, 1f0d92defc,
68089ca688, 32e2132948, bec3586930, 8bf4d1ea59, fd7a3aebd2, 72491e2153, 3d0725072d, 0ae7392c81, cff20b45f3, b92c6ae633,
729bae19a5, fcc81f17a5, 27ae70a428, 82819cdadc, b2b8820519, bb6c544603, 8a18914637, d66df9a0d0, 5c00684701, d93ce6ac42,
13bf5feb4d, 53ab178edd, 2d8317f1aa, 04f815638c, d6ad6a2dcb, 784503e484, da2809b000, 53c34eb95e, 18fc822d37, 89dc50bd7c,
d34655fd58, c1a8300e96, 9c5b2f6498, dbb4a07a8f, f66a1a38c8, be2635161c, 384a1a689d, 0021404639, a05a626644, 97b82d752e,
f29820a7ba, 47a634d8fb, 768f3dbde0, 1ca589ea10, 3a21e7699f, 56fd7bc7c4, 2425005aad, 2ccadd1834, 5cef8bd364, 8a6d593fe8,
14309562b8, 9f8f9965f9, 44a21a348d, 81d83d5aab, d99707fdcb, 252dd5b426, f922f6c634, be0cbe046c, e39b880f6d, 4f8ec07d2f,
689953e3cf, 61c2589e39, 8cf4c6944a, db228ddc4f, 858c94b575, 252794d717, 7847ccea13, 1bcf589d19, 132a48497b, f49e1b8dae,
e7233efb79, 3b2d2ef10a, 66974841f1, 87608ade45, 1e83aeeb79, 1c76d295a2, 384250ff8c, 6c3ce8e7e9, d658ef4322, 8d880ef5a0,
c6775cc999, d44b99ae0a, 1675712094, 2924d052c5, f1624a6215, b7e28e4fa6, d7d051200f, 0f830ddd00, 9617140b7f, bc4783028f,
16fedfb538, d781a3b8a2, 7182ff26dc, 95ee27d5c0, b4f05d3fe7, 8deafabe6b, 1bd1c76a2c, 56fd1da888, 0956ce0cd3, d42bf9c941,
d403587c7f, 355c985cc3, 41742146e2, eb516e1998, 0b1befa9ab, bd678b1c95, 56bef0b089, 99fc1243cb, a7205e4e36, 65efc3db7d,
de1aa557b8, b9493ddce7, ca14c5c9e1, ddb85ca669, 5b69403ba8, ac245cbf6c, 5be1e03d73, 87314142b5, 4cb9b8d97d, 83deb0233e,
8ebb6dd3d9, b7afd9b5b3, 4987b4da1c, a21b7792d8, 8819cc30be, 9d1de81fe2, 1e15b8c106, 21138e5d52, 8d76b4e4d4, 9662d1fdb6,
39114b0ad0, 3fe5f62c48, 73c6b31011, b16717bbf8, c3217d8a08, f82bcd40fc, 2500153ed8, 75a14e2a4b, 9bbd2b3f11, c26445253c,
5a0b227256, 1b5d91d1cf, a748519e92, 90e34002f0, 7068cf956a, aa764f8bf4, 73be5e5d35, 259304bac5, 2be701cfe3, 874b547598,
7b9ce35806, 84f3e44a5d, 5264b7511c, f8b1f42f6d, e1acb636d8, b08accd4be, 3668d5b83b, 1c13ca8159, 3ed0e55d9d, 8db8aa8594,
456d578f20, ab6b6721dc, 93a587da90, 87bebf9c28, f417c269d1, 4ce0ef5260, 39cdcdc9e8, 926923bb2b, 8785d9a3a9, 1e72feb744,
3ee24cbdde, f9605e18a0, 8551ff8569, fb1a99b650, 3b5d9c26d3, 0a986c2720, 3e862ced25, ba2475c3f0, 841372944f, e9d52734d1,
2e0cd4d68c, b28d58b8ce, 4a1710b795, 9f6d04c690, 66729ea9eb, 280202908a, 2b062b21cd, 6f9f8e57ac, eaf4742799, f05ea28cbd,
13ac16e2c0, eb3f1c9a61, c6a9847bbd, a2e109b3c2, 5642099a40, 382d85ee23, abcc987f6f, 36e400dd5d, 0113931956, 8d6e00533e,
10eebb6c0c, 68bcf2ebe0, ad0b09c738, 737cf795e8, 6192ff5abb, 066ba5fb19, 2fb4c92310, 3fdceba5fc, ae4bcc08f2, e1d88f93ca,
4ad2574835, 0e3d4beb48, dcfd4ea756, 093f8d6720, 22fdfab764, 7a0b157fb8, 563da9ee8e, c8d9cdc22e, e9c2411da9, 90989291ed,
d04fc343f0, 437594915a, 875aba8979, 61d13f20ea, 3b0dd5768b
.github/CODEOWNERS (vendored, 39 lines changed)

@@ -1,31 +1,32 @@
 # continuous integration
-/.github/workflows/ @lstein @blessedcoolant @hipsterusername @ebr @jazzhaiku @psychedelicious
+/.github/workflows/ @lstein @blessedcoolant

-# documentation
-/docs/ @lstein @blessedcoolant @hipsterusername @psychedelicious
-/mkdocs.yml @lstein @blessedcoolant @hipsterusername @psychedelicious
+# documentation - anyone with write privileges can review
+/docs/
+/mkdocs.yml

 # nodes
-/invokeai/app/ @blessedcoolant @psychedelicious @hipsterusername @jazzhaiku
+/invokeai/app/ @blessedcoolant @lstein @dunkeroni @JPPhoto

 # installation and configuration
-/pyproject.toml @lstein @blessedcoolant @psychedelicious @hipsterusername
-/docker/ @lstein @blessedcoolant @psychedelicious @hipsterusername @ebr
-/scripts/ @ebr @lstein @psychedelicious @hipsterusername
-/installer/ @lstein @ebr @psychedelicious @hipsterusername
-/invokeai/assets @lstein @ebr @psychedelicious @hipsterusername
-/invokeai/configs @lstein @psychedelicious @hipsterusername
-/invokeai/version @lstein @blessedcoolant @psychedelicious @hipsterusername
+/pyproject.toml @lstein @blessedcoolant
+/docker/ @lstein @blessedcoolant
+/scripts/ @lstein
+/installer/ @lstein
+/invokeai/assets @lstein
+/invokeai/configs @lstein
+/invokeai/version @lstein @blessedcoolant

 # web ui
-/invokeai/frontend @blessedcoolant @psychedelicious @lstein @maryhipp @hipsterusername
+/invokeai/frontend @blessedcoolant @lstein @dunkeroni

 # generation, model management, postprocessing
-/invokeai/backend @lstein @blessedcoolant @hipsterusername @jazzhaiku @psychedelicious @maryhipp
+/invokeai/backend @lstein @blessedcoolant @dunkeroni @JPPhoto

 # front ends
-/invokeai/frontend/CLI @lstein @psychedelicious @hipsterusername
-/invokeai/frontend/install @lstein @ebr @psychedelicious @hipsterusername
-/invokeai/frontend/merge @lstein @blessedcoolant @psychedelicious @hipsterusername
-/invokeai/frontend/training @lstein @blessedcoolant @psychedelicious @hipsterusername
-/invokeai/frontend/web @psychedelicious @blessedcoolant @maryhipp @hipsterusername
+/invokeai/frontend/CLI @lstein
+/invokeai/frontend/install @lstein
+/invokeai/frontend/merge @lstein @blessedcoolant
+/invokeai/frontend/training @lstein @blessedcoolant
+/invokeai/frontend/web @blessedcoolant @lstein @dunkeroni @Pfannkuchensack
+
.github/workflows/build-container.yml (vendored, 6 lines changed)

@@ -53,8 +53,10 @@ jobs:
           df -h
           sudo rm -rf /usr/share/dotnet
           sudo rm -rf "$AGENT_TOOLSDIRECTORY"
-          sudo swapoff /mnt/swapfile
-          sudo rm -rf /mnt/swapfile
+          if [ -f /mnt/swapfile ]; then
+            sudo swapoff /mnt/swapfile
+            sudo rm -rf /mnt/swapfile
+          fi
           if [ -d /mnt ]; then
             sudo chmod -R 777 /mnt
           echo '{"data-root": "/mnt/docker-root"}' | sudo tee /etc/docker/daemon.json
.github/workflows/close-inactive-issues.yml (vendored, 1 line changed)

@@ -23,6 +23,7 @@ jobs:
           close-issue-message: "Due to inactivity, this issue was automatically closed. If you are still experiencing the issue, please recreate the issue."
           days-before-pr-stale: -1
           days-before-pr-close: -1
+          only-labels: "bug"
           exempt-issue-labels: "Active Issue"
           repo-token: ${{ secrets.GITHUB_TOKEN }}
           operations-per-run: 500
.github/workflows/mkdocs-material.yml (vendored, 6 lines changed)

@@ -22,12 +22,12 @@ jobs:

     steps:
       - name: checkout
-        uses: actions/checkout@v4
+        uses: actions/checkout@v5

       - name: setup python
-        uses: actions/setup-python@v5
+        uses: actions/setup-python@v6
         with:
-          python-version: '3.10'
+          python-version: '3.12'
           cache: pip
           cache-dependency-path: pyproject.toml
.github/workflows/typegen-checks.yml (vendored, 6 lines changed)

@@ -46,8 +46,10 @@ jobs:
           df -h
           sudo rm -rf /usr/share/dotnet
           sudo rm -rf "$AGENT_TOOLSDIRECTORY"
-          sudo swapoff /mnt/swapfile
-          sudo rm -rf /mnt/swapfile
+          if [ -f /mnt/swapfile ]; then
+            sudo swapoff /mnt/swapfile
+            sudo rm -rf /mnt/swapfile
+          fi
           echo "----- Free space after cleanup"
           df -h
.gitignore (vendored, 3 lines changed)

@@ -192,3 +192,6 @@ installer/InvokeAI-Installer/
 .aider*
+
+.claude/
+
 # Weblate configuration file
 weblate.ini
@@ -16,6 +16,12 @@ Invoke is a leading creative engine built to empower professionals and enthusias
 ![project logo]()

+---
+
+> ## 📣 Are you a new or returning InvokeAI user?
+> Take our first annual [User's Survey](https://forms.gle/rCE5KuQ7Wfrd1UnS7)
+
+---

 # Documentation

 | **Quick Links** |
@@ -16,7 +16,9 @@ The launcher uses GitHub as the source of truth for available releases.

 ## General Prep

-Make a developer call-out for PRs to merge. Merge and test things out. Bump the version by editing `invokeai/version/invokeai_version.py`.
+Make a developer call-out for PRs to merge. Merge and test things
+out. Create a branch with a name like user/chore/vX.X.X-prep and bump the version by editing
+`invokeai/version/invokeai_version.py` and commit locally.

 ## Release Workflow

@@ -26,14 +28,14 @@ It is triggered on **tag push**, when the tag matches `v*`.

 ### Triggering the Workflow

-Ensure all commits that should be in the release are merged, and you have pulled them locally.
-
-Double-check that you have checked out the commit that will represent the release (typically the latest commit on `main`).
+Ensure all commits that should be in the release are merged into this branch, and that you have pulled them locally.

 Run `make tag-release` to tag the current commit and kick off the workflow. You will be prompted to provide a message - use the version specifier.

 If this version's tag already exists for some reason (maybe you had to make a last minute change), the script will overwrite it.

+Push the commit to trigger the workflow.
+
 > In case you cannot use the Make target, the release may also be dispatched [manually] via GH.

 ### Workflow Jobs and Process

@@ -89,7 +91,7 @@ The publish jobs will not run if any of the previous jobs fail.

 They use [GitHub environments], which are configured as [trusted publishers] on PyPI.

-Both jobs require a @hipsterusername or @psychedelicious to approve them from the workflow's **Summary** tab.
+Both jobs require a @lstein or @blessedcoolant to approve them from the workflow's **Summary** tab.

 - Click the **Review deployments** button
 - Select the environment (either `testpypi` or `pypi` - typically you select both)

@@ -101,7 +103,7 @@ Both jobs require a @hipsterusername or @psychedelicious to approve them from th

 Check the [python infrastructure status page] for incidents.

-If there are no incidents, contact @hipsterusername or @lstein, who have owner access to GH and PyPI, to see if access has expired or something like that.
+If there are no incidents, contact @lstein or @blessedcoolant, who have owner access to GH and PyPI, to see if access has expired or something like that.

 #### `publish-testpypi` Job
docs/contributing/HOTKEYS.md (new file, 295 lines)

# Hotkeys System

This document describes the technical implementation of the customizable hotkeys system in InvokeAI.

> **Note:** For user-facing documentation on how to use customizable hotkeys, see [Hotkeys Feature Documentation](../features/hotkeys.md).

## Overview

The hotkeys system allows users to customize keyboard shortcuts throughout the application. All hotkeys are:

- Centrally defined and managed
- Customizable by users
- Persisted across sessions
- Type-safe and validated

## Architecture

The customizable hotkeys feature is built on top of the existing hotkey system with the following components:

### 1. Hotkeys State Slice (`hotkeysSlice.ts`)

Location: `invokeai/frontend/web/src/features/system/store/hotkeysSlice.ts`

**Responsibilities:**

- Stores custom hotkey mappings in Redux state
- Persisted to IndexedDB using `redux-remember`
- Provides actions to change, reset individual, or reset all hotkeys

**State Shape:**

```typescript
{
  _version: 1,
  customHotkeys: {
    'app.invoke': ['mod+enter'],
    'canvas.undo': ['mod+z'],
    // ...
  }
}
```

**Actions:**

- `hotkeyChanged(id, hotkeys)` - Update a single hotkey
- `hotkeyReset(id)` - Reset a single hotkey to default
- `allHotkeysReset()` - Reset all hotkeys to defaults
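
To make the state shape and actions concrete, here is a minimal sketch of how such a slice could be built with Redux Toolkit's `createSlice`. The action names match the list above; the payload shapes and export names are illustrative, not necessarily the project's exact implementation.

```typescript
import { createSlice, type PayloadAction } from '@reduxjs/toolkit';

// Illustrative state type matching the shape described above.
type HotkeysState = {
  _version: 1;
  customHotkeys: Record<string, string[]>;
};

const initialState: HotkeysState = { _version: 1, customHotkeys: {} };

const slice = createSlice({
  name: 'hotkeys',
  initialState,
  reducers: {
    // Store (or overwrite) the custom binding for one hotkey id.
    hotkeyChanged: (state, action: PayloadAction<{ id: string; hotkeys: string[] }>) => {
      state.customHotkeys[action.payload.id] = action.payload.hotkeys;
    },
    // Remove the custom binding so the default applies again.
    hotkeyReset: (state, action: PayloadAction<string>) => {
      delete state.customHotkeys[action.payload];
    },
    // Drop all custom bindings at once.
    allHotkeysReset: (state) => {
      state.customHotkeys = {};
    },
  },
});

export const { hotkeyChanged, hotkeyReset, allHotkeysReset } = slice.actions;
export const hotkeysReducer = slice.reducer;
```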

### 2. useHotkeyData Hook (`useHotkeyData.ts`)

Location: `invokeai/frontend/web/src/features/system/components/HotkeysModal/useHotkeyData.ts`

**Responsibilities:**

- Defines all default hotkeys
- Merges default hotkeys with custom hotkeys from the store
- Returns the effective hotkeys that should be used throughout the app
- Provides platform-specific key translations (Ctrl/Cmd, Alt/Option)

**Key Functions:**

- `useHotkeyData()` - Returns all hotkeys organized by category
- `useRegisteredHotkeys()` - Hook to register a hotkey in a component
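
The merge rule is the core idea: defaults always exist in code, and a custom binding, when present, simply wins per hotkey id. A minimal sketch of that logic, with `defaultHotkeys` as a hypothetical stand-in for the real default table (the actual hook also groups hotkeys by category and applies platform-specific key translations):

```typescript
// Hypothetical table of default bindings, keyed by '<category>.<id>'.
const defaultHotkeys: Record<string, string[]> = {
  'app.invoke': ['mod+enter'],
  'canvas.undo': ['mod+z'],
};

// Custom bindings from the hotkeys slice override defaults per id.
const getEffectiveHotkeys = (customHotkeys: Record<string, string[]>): Record<string, string[]> => {
  const effective: Record<string, string[]> = {};
  for (const [id, keys] of Object.entries(defaultHotkeys)) {
    effective[id] = customHotkeys[id] ?? keys;
  }
  return effective;
};
```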

### 3. HotkeyEditor Component (`HotkeyEditor.tsx`)

Location: `invokeai/frontend/web/src/features/system/components/HotkeysModal/HotkeyEditor.tsx`

**Features:**

- Inline editor with input field
- Modifier buttons (Mod, Ctrl, Shift, Alt) for quick insertion
- Live preview of hotkey combinations
- Validation with visual feedback
- Help tooltip with syntax examples
- Save/cancel/reset buttons

**Smart Features:**

- Automatic `+` insertion between modifiers
- Cursor position preservation
- Validation prevents invalid combinations (e.g., modifier-only keys)

### 4. HotkeysModal Component (`HotkeysModal.tsx`)

Location: `invokeai/frontend/web/src/features/system/components/HotkeysModal/HotkeysModal.tsx`

**Features:**

- View Mode / Edit Mode toggle
- Search functionality
- Category-based organization
- Shows HotkeyEditor components when in edit mode
- "Reset All to Default" button in edit mode

## Data Flow

```
┌─────────────────────────────────────────────────────────────┐
│ 1. User opens Hotkeys Modal                                 │
│ 2. User clicks "Edit Mode" button                           │
│ 3. User clicks edit icon next to a hotkey                   │
│ 4. User enters new hotkey(s) using editor                   │
│ 5. User clicks save or presses Enter                        │
│ 6. Custom hotkey stored via hotkeyChanged() action          │
│ 7. Redux state persisted to IndexedDB (redux-remember)      │
│ 8. useHotkeyData() hook picks up the change                 │
│ 9. All components using useRegisteredHotkeys() get update   │
└─────────────────────────────────────────────────────────────┘
```

## Hotkey Format

Hotkeys use the format of the `react-hotkeys-hook` library:

- **Modifiers:** `mod`, `ctrl`, `shift`, `alt`, `meta`
- **Keys:** Letters, numbers, function keys, special keys
- **Separator:** `+` between keys in a combination
- **Multiple hotkeys:** Comma-separated (e.g., `mod+a, ctrl+b`)

**Examples:**

- `mod+enter` - Mod key + Enter
- `shift+x` - Shift + X
- `ctrl+shift+a` - Control + Shift + A
- `f1, f2` - F1 or F2 (alternatives)
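
A combination is valid only when every part before the last is a modifier and the last part is a real key. A small validator sketch for the format described above (illustrative, not the app's actual validation code):

```typescript
const MODIFIERS = new Set(['mod', 'ctrl', 'shift', 'alt', 'meta']);

// Validate one comma-separated hotkey string, e.g. 'mod+shift+z, mod+y'.
const isValidHotkeyString = (value: string): boolean =>
  value.split(',').every((combo) => {
    const parts = combo.trim().split('+').map((p) => p.trim().toLowerCase());
    if (parts.some((p) => p.length === 0)) {
      return false; // e.g. 'mod+' or '+z'
    }
    // Every part except the last must be a modifier; the last must not be one.
    const last = parts[parts.length - 1];
    return parts.slice(0, -1).every((p) => MODIFIERS.has(p)) && !MODIFIERS.has(last);
  });
```

For example, `mod+enter` and `f1, f2` pass, while `mod+` and `shift+ctrl+` fail, matching the invalid cases called out in the user-facing docs below.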

## Developer Guide

### Using Hotkeys in Components

To use a hotkey in a component:

```tsx
import { useCallback } from 'react';

import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';

const MyComponent = () => {
  const handleAction = useCallback(() => {
    // Your action here
  }, []);

  // This automatically uses custom hotkeys if configured
  useRegisteredHotkeys({
    id: 'myAction',
    category: 'app', // or 'canvas', 'viewer', 'gallery', 'workflows'
    callback: handleAction,
    options: { enabled: true, preventDefault: true },
    dependencies: [handleAction]
  });

  // ...
};
```

**Options:**

- `enabled` - Whether the hotkey is active
- `preventDefault` - Prevent default browser behavior
- `enableOnFormTags` - Allow hotkey in form elements (default: false)

### Adding New Hotkeys

To add a new hotkey to the system:

#### 1. Add Translation Strings

In `invokeai/frontend/web/public/locales/en.json`:

```json
{
  "hotkeys": {
    "app": {
      "myAction": {
        "title": "My Action",
        "desc": "Description of what this hotkey does"
      }
    }
  }
}
```

#### 2. Register the Hotkey

In `invokeai/frontend/web/src/features/system/components/HotkeysModal/useHotkeyData.ts`:

```typescript
// Inside the appropriate category builder function
addHotkey('app', 'myAction', ['mod+k']); // Default binding
```

#### 3. Use the Hotkey

In your component:

```typescript
useRegisteredHotkeys({
  id: 'myAction',
  category: 'app',
  callback: handleMyAction,
  options: { enabled: true },
  dependencies: [handleMyAction]
});
```

### Hotkey Categories

Current categories:

- **app** - Global application hotkeys
- **canvas** - Canvas/drawing operations
- **viewer** - Image viewer operations
- **gallery** - Gallery/image grid operations
- **workflows** - Node workflow editor

To add a new category, update `useHotkeyData.ts` and add translations.

## Testing

Tests are located in `invokeai/frontend/web/src/features/system/store/hotkeysSlice.test.ts`.

**Test Coverage:**

- Adding custom hotkeys
- Updating existing custom hotkeys
- Resetting individual hotkeys
- Resetting all hotkeys
- State persistence and migration

Run tests with:

```bash
cd invokeai/frontend/web
pnpm test:no-watch
```
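
A reducer test along these lines would cover the add/reset cases. This sketch assumes Vitest and the payload shape used in the slice sketch earlier in this document, so treat names and import paths as illustrative:

```typescript
import { describe, expect, it } from 'vitest';

import { allHotkeysReset, hotkeyChanged, hotkeyReset, hotkeysReducer } from './hotkeysSlice';

describe('hotkeysSlice', () => {
  it('stores a custom hotkey and resets it to default', () => {
    // Passing undefined state yields the initial state.
    const initial = hotkeysReducer(undefined, { type: 'init' });
    const changed = hotkeysReducer(initial, hotkeyChanged({ id: 'app.invoke', hotkeys: ['mod+i'] }));
    expect(changed.customHotkeys['app.invoke']).toEqual(['mod+i']);

    // Resetting removes the entry so the default binding applies again.
    const reset = hotkeysReducer(changed, hotkeyReset('app.invoke'));
    expect(reset.customHotkeys['app.invoke']).toBeUndefined();
  });

  it('resets all custom hotkeys', () => {
    const initial = hotkeysReducer(undefined, { type: 'init' });
    const changed = hotkeysReducer(initial, hotkeyChanged({ id: 'canvas.undo', hotkeys: ['mod+u'] }));
    const cleared = hotkeysReducer(changed, allHotkeysReset());
    expect(cleared.customHotkeys).toEqual({});
  });
});
```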

## Persistence

Custom hotkeys are persisted using the same mechanism as other app settings:

- Stored in Redux state under the `hotkeys` slice
- Persisted to IndexedDB via `redux-remember`
- Automatically loaded when the app starts
- Survives page refreshes and browser restarts
- Includes migration support for state schema changes

**State Location:**

- IndexedDB database: `invoke`
- Store key: `hotkeys`
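
The wiring follows the usual `redux-remember` pattern. A rough sketch, assuming an IndexedDB-backed driver built on `idb-keyval` (the real store setup registers many more slices and options):

```typescript
import { combineReducers, configureStore } from '@reduxjs/toolkit';
import { get, set } from 'idb-keyval';
import { rememberEnhancer, rememberReducer } from 'redux-remember';

import { hotkeysReducer } from './hotkeysSlice';

// Minimal IndexedDB driver; redux-remember only needs getItem/setItem.
const idbDriver = {
  getItem: (key: string) => get(key),
  setItem: (key: string, value: unknown) => set(key, value),
};

// Wrap the root reducer so remembered slices rehydrate on startup.
const rootReducer = rememberReducer(combineReducers({ hotkeys: hotkeysReducer }));

export const store = configureStore({
  reducer: rootReducer,
  // Persist only the 'hotkeys' slice key.
  enhancers: (getDefaultEnhancers) => getDefaultEnhancers().concat(rememberEnhancer(idbDriver, ['hotkeys'])),
});
```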

## Dependencies

- **react-hotkeys-hook** (v4.5.0) - Core hotkey handling
- **@reduxjs/toolkit** - State management
- **redux-remember** - Persistence
- **zod** - State validation
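
For the validation piece, a `zod` schema along these lines would describe the persisted shape given above; this is a sketch, not the project's exact schema:

```typescript
import { z } from 'zod';

// Persisted shape: a version stamp plus a map of hotkey id -> bindings.
const zHotkeysState = z.object({
  _version: z.literal(1),
  customHotkeys: z.record(z.string(), z.array(z.string())),
});

type HotkeysState = z.infer<typeof zHotkeysState>;

// On rehydration, anything that fails validation falls back to defaults.
const parseHotkeysState = (data: unknown): HotkeysState => {
  const result = zHotkeysState.safeParse(data);
  return result.success ? result.data : { _version: 1, customHotkeys: {} };
};
```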

## Best Practices

1. **Use `mod` instead of `ctrl`** - Automatically maps to Cmd on Mac, Ctrl elsewhere
2. **Provide descriptive translations** - Help users understand what each hotkey does
3. **Avoid conflicts** - Check existing hotkeys before adding new ones
4. **Use preventDefault** - Prevent browser default behavior when appropriate
5. **Check enabled state** - Only activate hotkeys when the action is available
6. **Use dependencies correctly** - Ensure callbacks are stable with useCallback

## Common Patterns

### Conditional Hotkeys

```typescript
useRegisteredHotkeys({
  id: 'save',
  category: 'app',
  callback: handleSave,
  options: {
    enabled: hasUnsavedChanges && !isLoading, // Only when valid
    preventDefault: true
  },
  dependencies: [hasUnsavedChanges, isLoading, handleSave]
});
```

### Multiple Hotkeys for Same Action

```typescript
// In useHotkeyData.ts
addHotkey('canvas', 'redo', ['mod+shift+z', 'mod+y']); // Two alternatives
```

### Focus-Scoped Hotkeys

```typescript
import { useFocusRegion } from 'common/hooks/focus';

const MyComponent = () => {
  const focusRegionRef = useFocusRegion('myRegion');

  // Hotkey only works when this region has focus
  useRegisteredHotkeys({
    id: 'myAction',
    category: 'app',
    callback: handleAction,
    options: { enabled: true }
  });

  return <div ref={focusRegionRef}>...</div>;
};
```
docs/contributing/NEW_MODEL_INTEGRATION.md (new file, 1254 lines)

File diff suppressed because it is too large.
docs/contributing/PR-MERGE-POLICY.md (new file, 64 lines)

# Pull Request Merge Policy

This document outlines the process for reviewing and merging pull requests (PRs) into the InvokeAI repository.

## Review Process

### 1. Assignment

One of the repository maintainers will assign collaborators to review a pull request. The assigned reviewer(s) will be responsible for conducting the code review.

### 2. Review and Iteration

The assignee is responsible for:

- Reviewing the PR thoroughly
- Providing constructive feedback
- Iterating with the PR author until the assignee is satisfied that the PR is fit to merge
- Ensuring the PR meets code quality standards, follows project conventions, and doesn't introduce bugs or regressions

### 3. Approval and Notification

Once the assignee is satisfied with the PR:

- The assignee approves the PR
- The assignee alerts one of the maintainers that the PR is ready for merge using the **#request-reviews Discord channel**

### 4. Final Merge

One of the maintainers is responsible for:

- Performing a final check of the PR
- Merging the PR into the appropriate branch

**Important:** Collaborators are strongly discouraged from merging PRs on their own, except in case of emergency (e.g., critical bug fix and no maintainer is available).

### 5. Release Policy

Once a feature release candidate is published, no feature PRs are to be merged into main. Only bugfixes are allowed until the final release.

## Best Practices

### Clean Commit History

To encourage a clean development log, PR authors are encouraged to use `git rebase -i` to suppress trivial commit messages (e.g., `ruff` and `prettier` formatting fixes) after the PR is accepted but before it is merged.

### Merge Strategy

The maintainer will perform either a **3-way merge** or **squash merge** when merging a PR into the `main` branch. This approach helps avoid rebase conflict hell and maintains a cleaner project history.

### Attribution

The PR author should reference any papers, source code or documentation that they used while creating the code both in the PR and as comments in the code itself. If there are any licensing restrictions, these should be linked to and/or reproduced in the repo root.

## Summary

This policy ensures that:

- All PRs receive proper review from assigned collaborators
- Maintainers have final oversight before code enters the main branch
- The commit history remains clean and meaningful
- Merge conflicts are minimized through appropriate merge strategies
docs/features/hotkeys.md (new file, 80 lines)

# Customizable Hotkeys

InvokeAI allows you to customize all keyboard shortcuts (hotkeys) to match your workflow preferences.

## Features

- **View All Hotkeys**: See all available keyboard shortcuts in one place
- **Customize Any Hotkey**: Change any shortcut to your preference
- **Multiple Bindings**: Assign multiple key combinations to the same action
- **Smart Validation**: Built-in validation prevents invalid combinations
- **Persistent Settings**: Your custom hotkeys are saved and restored across sessions
- **Easy Reset**: Reset individual hotkeys or all hotkeys back to defaults

## How to Use

### Opening the Hotkeys Modal

Press `Shift+?` or click the keyboard icon in the application to open the Hotkeys Modal.

### Viewing Hotkeys

In **View Mode** (default), you can:

- Browse all available hotkeys organized by category (App, Canvas, Gallery, Workflows, etc.)
- Search for specific hotkeys using the search bar
- See the current key combination for each action

### Customizing Hotkeys

1. Click the **Edit Mode** button at the bottom of the Hotkeys Modal
2. Find the hotkey you want to change
3. Click the **pencil icon** next to it
4. The editor will appear with:
   - **Input field**: Enter your new hotkey combination
   - **Modifier buttons**: Quick-insert Mod, Ctrl, Shift, Alt keys
   - **Help icon** (?): Shows syntax examples and valid keys
   - **Live preview**: See how your hotkey will look
5. Enter your new hotkey using the format:
   - `mod+a` - Mod key + A (Mod = Ctrl on Windows/Linux, Cmd on Mac)
   - `ctrl+shift+k` - Multiple modifiers
   - `f1` - Function keys
   - `mod+enter, ctrl+enter` - Multiple alternatives (separated by comma)
6. Click the **checkmark** or press Enter to save
7. Click the **X** or press Escape to cancel

### Resetting Hotkeys

**Reset a single hotkey:**

- Click the counter-clockwise arrow icon that appears next to customized hotkeys

**Reset all hotkeys:**

- In Edit Mode, click the **Reset All to Default** button at the bottom

### Hotkey Format Reference

**Valid Modifiers:**

- `mod` - Context-aware: Ctrl (Windows/Linux) or Cmd (Mac)
- `ctrl` - Control key
- `shift` - Shift key
- `alt` - Alt key (Option on Mac)

**Valid Keys:**

- Letters: `a-z`
- Numbers: `0-9`
- Function keys: `f1-f12`
- Special keys: `enter`, `space`, `tab`, `backspace`, `delete`, `escape`
- Arrow keys: `up`, `down`, `left`, `right`
- And more...

**Examples:**

- ✅ `mod+s` - Save action
- ✅ `ctrl+shift+p` - Command palette
- ✅ `f5, mod+r` - Two alternatives for refresh
- ❌ `mod+` - Invalid (no key after modifier)
- ❌ `shift+ctrl+` - Invalid (ends with modifier)

## For Developers

For technical implementation details, architecture, and how to add new hotkeys to the system, see the [Hotkeys Developer Documentation](../contributing/HOTKEYS.md).
@@ -70,7 +70,7 @@ Prior to installing PyPatchMatch, you need to take the following steps:

 `from patchmatch import patch_match`: It should look like the following:

 ```py
-Python 3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0] on linux
+Python 3.12.3 (main, Aug 14 2025, 17:47:21) [GCC 13.3.0] on linux
 Type "help", "copyright", "credits" or "license" for more information.
 >>> from patchmatch import patch_match
 Compiling and loading c extensions from "/home/lstein/Projects/InvokeAI/.invokeai-env/src/pypatchmatch/patchmatch".
@@ -25,12 +25,24 @@ Hardware requirements vary significantly depending on model and image output siz

     - Memory: At least 16GB RAM.
     - Disk: 10GB for base installation plus 100GB for models.

-=== "FLUX - 1024×1024"
+=== "FLUX.1 - 1024×1024"

     - GPU: Nvidia 20xx series or later, 10GB+ VRAM.
     - Memory: At least 32GB RAM.
     - Disk: 10GB for base installation plus 200GB for models.

+=== "FLUX.2 Klein - 1024×1024"
+
+    - GPU: Nvidia 20xx series or later, 6GB+ VRAM for GGUF Q4 quantized models, 12GB+ for full precision.
+    - Memory: At least 16GB RAM.
+    - Disk: 10GB for base installation plus 20GB for models.
+
+=== "Z-Image Turbo - 1024x1024"
+
+    - GPU: Nvidia 20xx series or later, 8GB+ VRAM for the Q4_K quantized model. 16GB+ needed for the Q8 or BF16 models.
+    - Memory: At least 16GB RAM.
+    - Disk: 10GB for base installation plus 35GB for models.
+

 More detail on system requirements can be found [here](./requirements.md).

 ## Step 2: Download and Set Up the Launcher
@@ -25,12 +25,29 @@ The requirements below are rough guidelines for best performance. GPUs with less

     - Memory: At least 16GB RAM.
     - Disk: 10GB for base installation plus 100GB for models.

-=== "FLUX - 1024×1024"
+=== "FLUX.1 - 1024×1024"

     - GPU: Nvidia 20xx series or later, 10GB+ VRAM.
     - Memory: At least 32GB RAM.
     - Disk: 10GB for base installation plus 200GB for models.

+=== "FLUX.2 Klein 4B - 1024×1024"
+
+    - GPU: Nvidia 30xx series or later, 12GB+ VRAM (e.g. RTX 3090, RTX 4070). FP8 version works with 8GB+ VRAM.
+    - Memory: At least 16GB RAM.
+    - Disk: 10GB for base installation plus 20GB for models (Diffusers format with encoder).
+
+=== "FLUX.2 Klein 9B - 1024×1024"
+
+    - GPU: Nvidia 40xx series, 24GB+ VRAM (e.g. RTX 4090). FP8 version works with 12GB+ VRAM.
+    - Memory: At least 32GB RAM.
+    - Disk: 10GB for base installation plus 40GB for models (Diffusers format with encoder).
+
+=== "Z-Image Turbo - 1024x1024"
+
+    - GPU: Nvidia 20xx series or later, 8GB+ VRAM for the Q4_K quantized model. 16GB+ needed for the Q8 or BF16 models.
+    - Memory: At least 16GB RAM.
+    - Disk: 10GB for base installation plus 35GB for models.
+
 !!! info "`tmpfs` on Linux"

     If your temporary directory is mounted as a `tmpfs`, ensure it has sufficient space.

@@ -41,7 +58,7 @@ The requirements below are rough guidelines for best performance. GPUs with less

 You don't need to do this if you are installing with the [Invoke Launcher](./quick_start.md).

-Invoke requires python 3.10 through 3.12. If you don't already have one of these versions installed, we suggest installing 3.12, as it will be supported for longer.
+Invoke requires python 3.11 through 3.12. If you don't already have one of these versions installed, we suggest installing 3.12, as it will be supported for longer.

 Check that your system has an up-to-date Python installed by running `python3 --version` in the terminal (Linux, macOS) or cmd/powershell (Windows).

@@ -56,7 +73,7 @@ Check that your system has an up-to-date Python installed by running `python3 --

 === "macOS"

     - Install python with [an official installer].
-    - If model installs fail with a certificate error, you may need to run this command (changing the python version to match what you have installed): `/Applications/Python\ 3.10/Install\ Certificates.command`
+    - If model installs fail with a certificate error, you may need to run this command (changing the python version to match what you have installed): `/Applications/Python\ 3.11/Install\ Certificates.command`
     - If you haven't already, you will need to install the XCode CLI Tools by running `xcode-select --install` in a terminal.

 === "Linux"
@@ -49,6 +49,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     FLUXConditioningInfo,
     SD3ConditioningInfo,
     SDXLConditioningInfo,
+    ZImageConditioningInfo,
 )
 from invokeai.backend.util.logging import InvokeAILogger
 from invokeai.version.invokeai_version import __version__

@@ -129,6 +130,7 @@ class ApiDependencies:
                     FLUXConditioningInfo,
                     SD3ConditioningInfo,
                     CogView4ConditioningInfo,
+                    ZImageConditioningInfo,
                 ],
                 ephemeral=True,
             ),
@@ -28,7 +28,7 @@ from invokeai.app.services.model_records import (
     UnknownModelException,
 )
 from invokeai.app.util.suppress_output import SuppressOutput
-from invokeai.backend.model_manager.configs.factory import AnyModelConfig
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory
 from invokeai.backend.model_manager.configs.main import (
     Main_Checkpoint_SD1_Config,
     Main_Checkpoint_SD2_Config,

@@ -38,6 +38,7 @@ from invokeai.backend.model_manager.configs.main import (
 from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
 from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
 from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
+from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
 from invokeai.backend.model_manager.search import ModelSearch
 from invokeai.backend.model_manager.starter_models import (
     STARTER_BUNDLES,

@@ -191,6 +192,49 @@ async def get_model_record(
         raise HTTPException(status_code=404, detail=str(e))


+@model_manager_router.post(
+    "/i/{key}/reidentify",
+    operation_id="reidentify_model",
+    responses={
+        200: {
+            "description": "The model configuration was retrieved successfully",
+            "content": {"application/json": {"example": example_model_config}},
+        },
+        400: {"description": "Bad request"},
+        404: {"description": "The model could not be found"},
+    },
+)
+async def reidentify_model(
+    key: Annotated[str, Path(description="Key of the model to reidentify.")],
+) -> AnyModelConfig:
+    """Attempt to reidentify a model by re-probing its weights file."""
+    try:
+        config = ApiDependencies.invoker.services.model_manager.store.get_model(key)
+        models_path = ApiDependencies.invoker.services.configuration.models_path
+        if pathlib.Path(config.path).is_relative_to(models_path):
+            model_path = pathlib.Path(config.path)
+        else:
+            model_path = models_path / config.path
+        mod = ModelOnDisk(model_path)
+        result = ModelConfigFactory.from_model_on_disk(mod)
+        if result.config is None:
+            raise InvalidModelException("Unable to identify model format")
+
+        # Retain user-editable fields from the original config
+        result.config.key = config.key
+        result.config.name = config.name
+        result.config.description = config.description
+        result.config.cover_image = config.cover_image
+        result.config.trigger_phrases = config.trigger_phrases
+        result.config.source = config.source
+        result.config.source_type = config.source_type
+
+        new_config = ApiDependencies.invoker.services.model_manager.store.replace_model(config.key, result.config)
+        return new_config
+    except UnknownModelException as e:
+        raise HTTPException(status_code=404, detail=str(e))
+
+
 class FoundModel(BaseModel):
     path: str = Field(description="Path to the model")
     is_installed: bool = Field(description="Whether or not the model is already installed")
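
A client would call the new endpoint roughly as follows. The `/api/v2/models` prefix is an assumption (the router prefix is not visible in this diff), so verify it against the OpenAPI schema:

```typescript
// Re-probe a model's weights and get the refreshed config back.
// NOTE: the '/api/v2/models' prefix is assumed; it is not visible in this diff.
const reidentifyModel = async (key: string) => {
  const res = await fetch(`/api/v2/models/i/${encodeURIComponent(key)}/reidentify`, {
    method: 'POST',
  });
  if (!res.ok) {
    throw new Error(`Reidentify failed: ${res.status}`); // 404 if the key is unknown
  }
  return await res.json(); // AnyModelConfig with user-editable fields preserved
};
```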

@@ -238,9 +282,10 @@ async def scan_for_models(
                 found_model = FoundModel(path=path, is_installed=is_installed)
                 scan_results.append(found_model)
     except Exception as e:
+        error_type = type(e).__name__
         raise HTTPException(
             status_code=500,
-            detail=f"An error occurred while searching the directory: {e}",
+            detail=f"An error occurred while searching the directory: {error_type}",
         )
     return scan_results

@@ -411,6 +456,59 @@ async def delete_model(
         raise HTTPException(status_code=404, detail=str(e))


+class BulkDeleteModelsRequest(BaseModel):
+    """Request body for bulk model deletion."""
+
+    keys: List[str] = Field(description="List of model keys to delete")
+
+
+class BulkDeleteModelsResponse(BaseModel):
+    """Response body for bulk model deletion."""
+
+    deleted: List[str] = Field(description="List of successfully deleted model keys")
+    failed: List[dict] = Field(description="List of failed deletions with error messages")
+
+
+@model_manager_router.post(
+    "/i/bulk_delete",
+    operation_id="bulk_delete_models",
+    responses={
+        200: {"description": "Models deleted (possibly with some failures)"},
+    },
+    status_code=200,
+)
+async def bulk_delete_models(
+    request: BulkDeleteModelsRequest = Body(description="List of model keys to delete"),
+) -> BulkDeleteModelsResponse:
+    """
+    Delete multiple model records from database.
+
+    The configuration records will be removed. The corresponding weights files will be
+    deleted as well if they reside within the InvokeAI "models" directory.
+    Returns a list of successfully deleted keys and failed deletions with error messages.
+    """
+    logger = ApiDependencies.invoker.services.logger
+    installer = ApiDependencies.invoker.services.model_manager.install
+
+    deleted = []
+    failed = []
+
+    for key in request.keys:
+        try:
+            installer.delete(key)
+            deleted.append(key)
+            logger.info(f"Deleted model: {key}")
+        except UnknownModelException as e:
+            logger.error(f"Failed to delete model {key}: {str(e)}")
+            failed.append({"key": key, "error": str(e)})
+        except Exception as e:
+            logger.error(f"Failed to delete model {key}: {str(e)}")
+            failed.append({"key": key, "error": str(e)})
+
+    logger.info(f"Bulk delete completed: {len(deleted)} deleted, {len(failed)} failed")
+    return BulkDeleteModelsResponse(deleted=deleted, failed=failed)
+
+
 @model_manager_router.delete(
     "/i/{key}/image",
     operation_id="delete_model_image",

@@ -816,15 +914,48 @@ class StarterModelResponse(BaseModel):
 def get_is_installed(
     starter_model: StarterModel | StarterModelWithoutDependencies, installed_models: list[AnyModelConfig]
 ) -> bool:
+    from invokeai.backend.model_manager.taxonomy import ModelType
+
     for model in installed_models:
         # Check if source matches exactly
         if model.source == starter_model.source:
             return True
         # Check if name (or previous names), base and type match
         if (
             (model.name == starter_model.name or model.name in starter_model.previous_names)
             and model.base == starter_model.base
             and model.type == starter_model.type
         ):
             return True

+    # Special handling for Qwen3Encoder models - check by type and variant
+    # This allows renamed models to still be detected as installed
+    if starter_model.type == ModelType.Qwen3Encoder:
+        from invokeai.backend.model_manager.taxonomy import Qwen3VariantType
+
+        # Determine expected variant from source pattern
+        expected_variant: Qwen3VariantType | None = None
+        if "klein-9B" in starter_model.source or "qwen3_8b" in starter_model.source.lower():
+            expected_variant = Qwen3VariantType.Qwen3_8B
+        elif (
+            "klein-4B" in starter_model.source
+            or "qwen3_4b" in starter_model.source.lower()
+            or "Z-Image" in starter_model.source
+        ):
+            expected_variant = Qwen3VariantType.Qwen3_4B
+
+        if expected_variant is not None:
+            for model in installed_models:
+                if model.type == ModelType.Qwen3Encoder and hasattr(model, "variant"):
+                    model_variant = model.variant
+                    # Handle both enum and string values
+                    if isinstance(model_variant, Qwen3VariantType):
+                        if model_variant == expected_variant:
+                            return True
+                    elif isinstance(model_variant, str):
+                        if model_variant == expected_variant.value:
+                            return True
+
     return False
@@ -223,6 +223,15 @@ async def get_workflow_thumbnail(
         raise HTTPException(status_code=404)


+@workflows_router.get("/tags", operation_id="get_all_tags")
+async def get_all_tags(
+    categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"),
+) -> list[str]:
+    """Gets all unique tags from workflows"""
+
+    return ApiDependencies.invoker.services.workflow_records.get_all_tags(categories=categories)
+
+
 @workflows_router.get("/counts_by_tag", operation_id="get_counts_by_tag")
 async def get_counts_by_tag(
     tags: list[str] = Query(description="The tags to get counts for"),
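
A sketch of how a client might consume the new tags endpoint; the `/api/v1/workflows` prefix is an assumption, as the router prefix is not visible in this diff:

```typescript
// Fetch all unique workflow tags, optionally filtered by category.
// NOTE: the '/api/v1/workflows' prefix is assumed; it is not visible in this diff.
const getAllTags = async (categories?: string[]) => {
  const params = new URLSearchParams();
  for (const c of categories ?? []) {
    params.append('categories', c); // repeated query param, one per category
  }
  const res = await fetch(`/api/v1/workflows/tags?${params.toString()}`);
  return (await res.json()) as string[];
};
```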

@@ -154,6 +154,7 @@ class FieldDescriptions:
     clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
     t5_encoder = "T5 tokenizer and text encoder"
     glm_encoder = "GLM (THUDM) tokenizer and text encoder"
+    qwen3_encoder = "Qwen3 tokenizer and text encoder"
     clip_embed_model = "CLIP Embed loader"
     clip_g_model = "CLIP-G Embed loader"
     unet = "UNet (scheduler, LoRAs)"

@@ -169,6 +170,7 @@ class FieldDescriptions:
     flux_model = "Flux model (Transformer) to load"
     sd3_model = "SD3 model (MMDiTX) to load"
     cogview4_model = "CogView4 model (Transformer) to load"
+    z_image_model = "Z-Image model (Transformer) to load"
     sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
     sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
     onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"

@@ -241,6 +243,12 @@ class BoardField(BaseModel):
     board_id: str = Field(description="The id of the board")


+class StylePresetField(BaseModel):
+    """A style preset primitive field"""
+
+    style_preset_id: str = Field(description="The id of the style preset")
+
+
 class DenoiseMaskField(BaseModel):
     """An inpaint mask field"""

@@ -321,6 +329,17 @@ class CogView4ConditioningField(BaseModel):
     conditioning_name: str = Field(description="The name of conditioning tensor")


+class ZImageConditioningField(BaseModel):
+    """A Z-Image conditioning tensor primitive value"""
+
+    conditioning_name: str = Field(description="The name of conditioning tensor")
+    mask: Optional[TensorField] = Field(
+        default=None,
+        description="The mask associated with this conditioning tensor for regional prompting. "
+        "Excluded regions should be set to False, included regions should be set to True.",
+    )
+
+
 class ConditioningField(BaseModel):
     """A conditioning tensor primitive value"""

@@ -513,7 +532,7 @@ def migrate_model_ui_type(ui_type: UIType | str, json_schema_extra: dict[str, An
         case UIType.VAEModel:
             ui_model_type = [ModelType.VAE]
         case UIType.FluxVAEModel:
-            ui_model_base = [BaseModelType.Flux]
+            ui_model_base = [BaseModelType.Flux, BaseModelType.Flux2]
             ui_model_type = [ModelType.VAE]
         case UIType.LoRAModel:
             ui_model_type = [ModelType.LoRA]
505
invokeai/app/invocations/flux2_denoise.py
Normal file
505
invokeai/app/invocations/flux2_denoise.py
Normal file
@@ -0,0 +1,505 @@
|
||||
"""Flux2 Klein Denoise Invocation.
|
||||
|
||||
Run denoising process with a FLUX.2 Klein transformer model.
|
||||
Uses Qwen3 conditioning instead of CLIP+T5.
|
||||
"""
|
||||
|
||||
from contextlib import ExitStack
|
||||
from typing import Callable, Iterator, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as tv_transforms
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
FluxConditioningField,
|
||||
FluxKontextConditioningField,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
)
|
||||
from invokeai.app.invocations.model import TransformerField, VAEField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
|
||||
from invokeai.backend.flux.schedulers import FLUX_SCHEDULER_LABELS, FLUX_SCHEDULER_MAP, FLUX_SCHEDULER_NAME_VALUES
|
||||
from invokeai.backend.flux2.denoise import denoise
|
||||
from invokeai.backend.flux2.ref_image_extension import Flux2RefImageExtension
|
||||
from invokeai.backend.flux2.sampling_utils import (
|
||||
compute_empirical_mu,
|
||||
generate_img_ids_flux2,
|
||||
get_noise_flux2,
|
||||
get_schedule_flux2,
|
||||
pack_flux2,
|
||||
unpack_flux2,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux2_denoise",
|
||||
title="FLUX2 Denoise",
|
||||
tags=["image", "flux", "flux2", "klein", "denoise"],
|
||||
category="image",
|
||||
version="1.3.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class Flux2DenoiseInvocation(BaseInvocation):
|
||||
"""Run denoising process with a FLUX.2 Klein transformer model.
|
||||
|
||||
This node is designed for FLUX.2 Klein models which use Qwen3 as the text encoder.
|
||||
It does not support ControlNet, IP-Adapters, or regional prompting.
|
||||
"""
|
||||
|
||||
latents: Optional[LatentsField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.denoise_mask,
|
||||
input=Input.Connection,
|
||||
)
|
||||
denoising_start: float = InputField(
|
||||
default=0.0,
|
||||
ge=0,
|
||||
le=1,
|
||||
description=FieldDescriptions.denoising_start,
|
||||
)
|
||||
denoising_end: float = InputField(
|
||||
default=1.0,
|
||||
ge=0,
|
||||
le=1,
|
||||
description=FieldDescriptions.denoising_end,
|
||||
)
|
||||
add_noise: bool = InputField(default=True, description="Add noise based on denoising start.")
|
||||
transformer: TransformerField = InputField(
|
||||
description=FieldDescriptions.flux_model,
|
||||
input=Input.Connection,
|
||||
title="Transformer",
|
||||
)
|
||||
positive_text_conditioning: FluxConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond,
|
||||
input=Input.Connection,
|
||||
)
|
||||
negative_text_conditioning: Optional[FluxConditioningField] = InputField(
|
||||
default=None,
|
||||
description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
|
||||
input=Input.Connection,
|
||||
)
|
||||
cfg_scale: float = InputField(
|
||||
default=1.0,
|
||||
description=FieldDescriptions.cfg_scale,
|
||||
title="CFG Scale",
|
||||
)
|
||||
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
|
||||
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
|
||||
num_steps: int = InputField(
|
||||
default=4,
|
||||
description="Number of diffusion steps. Use 4 for distilled models, 28+ for base models.",
|
||||
)
|
||||
scheduler: FLUX_SCHEDULER_NAME_VALUES = InputField(
|
||||
default="euler",
|
||||
description="Scheduler (sampler) for the denoising process. 'euler' is fast and standard. "
|
||||
"'heun' is 2nd-order (better quality, 2x slower). 'lcm' is optimized for few steps.",
|
||||
ui_choice_labels=FLUX_SCHEDULER_LABELS,
|
||||
)
|
||||
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
|
||||
vae: VAEField = InputField(
|
||||
description="FLUX.2 VAE model (required for BN statistics).",
|
||||
input=Input.Connection,
|
||||
)
|
||||
kontext_conditioning: FluxKontextConditioningField | list[FluxKontextConditioningField] | None = InputField(
|
||||
default=None,
|
||||
description="FLUX Kontext conditioning (reference images for multi-reference image editing).",
|
||||
input=Input.Connection,
|
||||
title="Reference Images",
|
||||
)
|
||||
|
||||
def _get_bn_stats(self, context: InvocationContext) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""Extract BN statistics from the FLUX.2 VAE.
|
||||
|
||||
The FLUX.2 VAE uses batch normalization on the patchified 128-channel representation.
|
||||
IMPORTANT: BFL FLUX.2 VAE uses affine=False, so there are NO learnable weight/bias.
|
||||
|
||||
BN formula (affine=False): y = (x - mean) / std
|
||||
Inverse: x = y * std + mean
|
||||
|
||||
Returns:
|
||||
Tuple of (bn_mean, bn_std) tensors of shape (128,), or None if BN layer not found.
|
||||
"""
|
||||
with context.models.load(self.vae.vae).model_on_device() as (_, vae):
|
||||
# Ensure VAE is in eval mode to prevent BN stats from being updated
|
||||
vae.eval()
|
||||
|
||||
# Try to find the BN layer - it may be at different locations depending on model format
|
||||
bn_layer = None
|
||||
if hasattr(vae, "bn"):
|
||||
bn_layer = vae.bn
|
||||
elif hasattr(vae, "batch_norm"):
|
||||
bn_layer = vae.batch_norm
|
||||
elif hasattr(vae, "encoder") and hasattr(vae.encoder, "bn"):
|
||||
bn_layer = vae.encoder.bn
|
||||
|
||||
if bn_layer is None:
|
||||
return None
|
||||
|
||||
# Verify running statistics are initialized
|
||||
if bn_layer.running_mean is None or bn_layer.running_var is None:
|
||||
return None
|
||||
|
||||
# Get BN running statistics from VAE
|
||||
bn_mean = bn_layer.running_mean.clone() # Shape: (128,)
|
||||
bn_var = bn_layer.running_var.clone() # Shape: (128,)
|
||||
bn_eps = bn_layer.eps if hasattr(bn_layer, "eps") else 1e-4 # BFL uses 1e-4
|
||||
bn_std = torch.sqrt(bn_var + bn_eps)
|
||||
|
||||
return bn_mean, bn_std
|
||||
|
||||
def _bn_normalize(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
bn_mean: torch.Tensor,
|
||||
bn_std: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Apply BN normalization to packed latents.
|
||||
|
||||
BN formula (affine=False): y = (x - mean) / std
|
||||
|
||||
Args:
|
||||
x: Packed latents of shape (B, seq, 128).
|
||||
bn_mean: BN running mean of shape (128,).
|
||||
bn_std: BN running std of shape (128,).
|
||||
|
||||
Returns:
|
||||
Normalized latents of same shape.
|
||||
"""
|
||||
# x: (B, seq, 128), params: (128,) -> broadcast over batch and sequence dims
|
||||
bn_mean = bn_mean.to(x.device, x.dtype)
|
||||
bn_std = bn_std.to(x.device, x.dtype)
|
||||
return (x - bn_mean) / bn_std
|
||||
|
||||
def _bn_denormalize(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
bn_mean: torch.Tensor,
|
||||
bn_std: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Apply BN denormalization to packed latents (inverse of normalization).
|
||||
|
||||
Inverse BN (affine=False): x = y * std + mean
|
||||
|
||||
Args:
|
||||
x: Packed latents of shape (B, seq, 128).
|
||||
bn_mean: BN running mean of shape (128,).
|
||||
bn_std: BN running std of shape (128,).
|
||||
|
||||
Returns:
|
||||
Denormalized latents of same shape.
|
||||
"""
|
||||
# x: (B, seq, 128), params: (128,) -> broadcast over batch and sequence dims
|
||||
bn_mean = bn_mean.to(x.device, x.dtype)
|
||||
bn_std = bn_std.to(x.device, x.dtype)
|
||||
return x * bn_std + bn_mean
|
||||
|
||||
    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        latents = self._run_diffusion(context)
        latents = latents.detach().to("cpu")

        name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

    def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
        inference_dtype = torch.bfloat16
        device = TorchDevice.choose_torch_device()

        # Get BN statistics from VAE for latent denormalization (optional)
        # BFL FLUX.2 VAE uses affine=False, so only mean/std are needed
        # Some VAE formats (e.g. diffusers) may not expose BN stats directly
        bn_stats = self._get_bn_stats(context)
        bn_mean, bn_std = bn_stats if bn_stats is not None else (None, None)

        # Load the input latents, if provided
        init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
        if init_latents is not None:
            init_latents = init_latents.to(device=device, dtype=inference_dtype)

        # Prepare input noise (FLUX.2 uses 32 channels)
        noise = get_noise_flux2(
            num_samples=1,
            height=self.height,
            width=self.width,
            device=device,
            dtype=inference_dtype,
            seed=self.seed,
        )
        b, _c, latent_h, latent_w = noise.shape
        packed_h = latent_h // 2
        packed_w = latent_w // 2

        # Load the conditioning data
        pos_cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
        assert len(pos_cond_data.conditionings) == 1
        pos_flux_conditioning = pos_cond_data.conditionings[0]
        assert isinstance(pos_flux_conditioning, FLUXConditioningInfo)
        pos_flux_conditioning = pos_flux_conditioning.to(dtype=inference_dtype, device=device)

        # Qwen3 stacked embeddings (stored in t5_embeds field for compatibility)
        txt = pos_flux_conditioning.t5_embeds

        # Generate text position IDs (4D format for FLUX.2: T, H, W, L)
        # FLUX.2 uses 4D position coordinates for its rotary position embeddings
        # IMPORTANT: Position IDs must be int64 (long) dtype
        # Diffusers uses: T=0, H=0, W=0, L=0..seq_len-1
        seq_len = txt.shape[1]
        txt_ids = torch.zeros(1, seq_len, 4, device=device, dtype=torch.long)
        txt_ids[..., 3] = torch.arange(seq_len, device=device, dtype=torch.long)  # L coordinate varies

        # Load negative conditioning if provided
        neg_txt = None
        neg_txt_ids = None
        if self.negative_text_conditioning is not None:
            neg_cond_data = context.conditioning.load(self.negative_text_conditioning.conditioning_name)
            assert len(neg_cond_data.conditionings) == 1
            neg_flux_conditioning = neg_cond_data.conditionings[0]
            assert isinstance(neg_flux_conditioning, FLUXConditioningInfo)
            neg_flux_conditioning = neg_flux_conditioning.to(dtype=inference_dtype, device=device)
            neg_txt = neg_flux_conditioning.t5_embeds
            # For text tokens: T=0, H=0, W=0, L=0..seq_len-1 (only L varies per token)
            neg_seq_len = neg_txt.shape[1]
            neg_txt_ids = torch.zeros(1, neg_seq_len, 4, device=device, dtype=torch.long)
            neg_txt_ids[..., 3] = torch.arange(neg_seq_len, device=device, dtype=torch.long)

        # Validate transformer config
        transformer_config = context.models.get_config(self.transformer.transformer)
        assert transformer_config.base == BaseModelType.Flux2 and transformer_config.type == ModelType.Main

        # Calculate the timestep schedule using FLUX.2 specific schedule
        # This matches diffusers' Flux2Pipeline implementation
        # Note: Schedule shifting is handled by the scheduler via mu parameter
        image_seq_len = packed_h * packed_w
        timesteps = get_schedule_flux2(
            num_steps=self.num_steps,
            image_seq_len=image_seq_len,
        )
        # Compute mu for dynamic schedule shifting (used by FlowMatchEulerDiscreteScheduler)
        mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=self.num_steps)
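The mu-based shifting mentioned above can be made concrete. A hedged sketch of the exponential time shift used by diffusers' FlowMatchEulerDiscreteScheduler, with the shift parameter interpolated linearly from the image sequence length (the constants mirror the scheduler configuration created further down in this method; this illustrates the mechanism and is not InvokeAI's compute_empirical_mu, whose exact fit is not shown in this hunk):

import math
import torch

def linear_mu(image_seq_len: int, base_seq: int = 256, max_seq: int = 4096,
              base_shift: float = 0.5, max_shift: float = 1.15) -> float:
    # Linear interpolation of the shift by sequence length, in the style of
    # diffusers' calculate_shift helper.
    m = (max_shift - base_shift) / (max_seq - base_seq)
    return image_seq_len * m + base_shift - base_seq * m

def exponential_time_shift(mu: float, t: torch.Tensor) -> torch.Tensor:
    # time_shift_type="exponential": larger mu pushes more of the schedule
    # toward t=1 (the noise end), which helps at large image_seq_len.
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1))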
        # Clip the timesteps schedule based on denoising_start and denoising_end
        timesteps = clip_timestep_schedule_fractional(timesteps, self.denoising_start, self.denoising_end)

        # Prepare input latent image
        if init_latents is not None:
            if self.add_noise:
                t_0 = timesteps[0]
                x = t_0 * noise + (1.0 - t_0) * init_latents
            else:
                x = init_latents
        else:
            if self.denoising_start > 1e-5:
                raise ValueError("denoising_start should be 0 when initial latents are not provided.")
            x = noise

        # Short-circuit if the clipped schedule has one or zero timesteps
        if len(timesteps) <= 1:
            return x

        # Generate image position IDs (FLUX.2 uses 4D coordinates)
        # Position IDs use int64 dtype like diffusers
        img_ids = generate_img_ids_flux2(h=latent_h, w=latent_w, batch_size=b, device=device)
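For reference, the image-side counterpart of the text position IDs above is a flat grid over the packed 2x2 patches. A minimal sketch of what generate_img_ids_flux2 plausibly produces, assuming the (T, H, W, L) layout described in the comments (T and L fixed at 0, H and W varying per patch; the helper's real signature is the one in the call above, and its internals are not shown here):

import torch

def sketch_img_ids(h: int, w: int, batch_size: int, device: torch.device) -> torch.Tensor:
    ph, pw = h // 2, w // 2  # packed patch grid, matching packed_h/packed_w above
    ids = torch.zeros(ph, pw, 4, device=device, dtype=torch.long)
    ids[..., 1] = torch.arange(ph, device=device)[:, None]  # H coordinate
    ids[..., 2] = torch.arange(pw, device=device)[None, :]  # W coordinate
    # T (index 0) and L (index 3) stay 0 for image tokens
    return ids.reshape(1, ph * pw, 4).expand(batch_size, -1, -1)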
        # Prepare inpaint mask
        inpaint_mask = self._prep_inpaint_mask(context, x)

        # Pack all latent tensors
        init_latents_packed = pack_flux2(init_latents) if init_latents is not None else None
        inpaint_mask_packed = pack_flux2(inpaint_mask) if inpaint_mask is not None else None
        noise_packed = pack_flux2(noise)
        x = pack_flux2(x)
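pack_flux2 is not defined in this hunk; given the shapes used throughout this file ((B, 32, H, W) latents in, (B, (H/2)*(W/2), 128) tokens out), it is presumably the FLUX-style 2x2 patchification. A hedged sketch with einops:

import torch
from einops import rearrange

def sketch_pack_flux2(x: torch.Tensor) -> torch.Tensor:
    # (B, 32, H, W) -> (B, (H/2)*(W/2), 128): each 2x2 patch of 32 channels
    # becomes one 128-dim token, consistent with packed_h = latent_h // 2 above.
    return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)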
        # BN normalization for txt2img:
        # - DO NOT normalize random noise (it's already N(0,1) distributed)
        # - Diffusers only normalizes image latents from VAE (for img2img/kontext)
        # - Output MUST be denormalized after denoising before VAE decode
        #
        # For img2img with init_latents, we should normalize init_latents on unpacked
        # shape (B, 128, H/16, W/16) - this is handled by _bn_normalize_unpacked

        # Verify packed dimensions
        assert packed_h * packed_w == x.shape[1]

        # Prepare inpaint extension
        inpaint_extension: Optional[RectifiedFlowInpaintExtension] = None
        if inpaint_mask_packed is not None:
            assert init_latents_packed is not None
            inpaint_extension = RectifiedFlowInpaintExtension(
                init_latents=init_latents_packed,
                inpaint_mask=inpaint_mask_packed,
                noise=noise_packed,
            )

        # Prepare CFG scale list
        num_steps = len(timesteps) - 1
        cfg_scale_list = [self.cfg_scale] * num_steps

        # Check if we're doing inpainting (have a mask or a clipped schedule)
        is_inpainting = self.denoise_mask is not None or self.denoising_start > 1e-5

        # Create scheduler with FLUX.2 Klein configuration
        # For inpainting/img2img, use manual Euler stepping to preserve the exact timestep schedule
        # For txt2img, use the scheduler with dynamic shifting for optimal results
        scheduler = None
        if self.scheduler in FLUX_SCHEDULER_MAP and not is_inpainting:
            # Only use scheduler for txt2img - use manual Euler for inpainting to preserve exact timesteps
            scheduler_class = FLUX_SCHEDULER_MAP[self.scheduler]
            # FlowMatchHeunDiscreteScheduler only supports num_train_timesteps and shift parameters
            # FlowMatchEulerDiscreteScheduler and FlowMatchLCMScheduler support dynamic shifting
            if self.scheduler == "heun":
                scheduler = scheduler_class(
                    num_train_timesteps=1000,
                    shift=3.0,
                )
            else:
                scheduler = scheduler_class(
                    num_train_timesteps=1000,
                    shift=3.0,
                    use_dynamic_shifting=True,
                    base_shift=0.5,
                    max_shift=1.15,
                    base_image_seq_len=256,
                    max_image_seq_len=4096,
                    time_shift_type="exponential",
                )
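The "manual Euler stepping" fallback mentioned above is straightforward for rectified flow, where the model predicts a velocity field. A minimal sketch of a single such step (illustrative; the actual loop lives in the denoise helper imported by this module):

import torch

def euler_step(x: torch.Tensor, velocity: torch.Tensor,
               t_curr: float, t_next: float) -> torch.Tensor:
    # Rectified flow: move x along the predicted velocity, with step size
    # equal to the timestep decrement (t_next < t_curr as t runs 1 -> 0).
    return x + (t_next - t_curr) * velocity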
        # Prepare reference image extension for FLUX.2 Klein built-in editing
        ref_image_extension = None
        if self.kontext_conditioning:
            ref_image_extension = Flux2RefImageExtension(
                context=context,
                ref_image_conditioning=self.kontext_conditioning
                if isinstance(self.kontext_conditioning, list)
                else [self.kontext_conditioning],
                vae_field=self.vae,
                device=device,
                dtype=inference_dtype,
                bn_mean=bn_mean,
                bn_std=bn_std,
            )

        with ExitStack() as exit_stack:
            # Load the transformer model
            (cached_weights, transformer) = exit_stack.enter_context(
                context.models.load(self.transformer.transformer).model_on_device()
            )
            config = transformer_config

            # Determine if the model is quantized
            if config.format in [ModelFormat.Diffusers]:
                model_is_quantized = False
            elif config.format in [
                ModelFormat.BnbQuantizedLlmInt8b,
                ModelFormat.BnbQuantizednf4b,
                ModelFormat.GGUFQuantized,
            ]:
                model_is_quantized = True
            else:
                model_is_quantized = False

            # Apply LoRA models to the transformer
            exit_stack.enter_context(
                LayerPatcher.apply_smart_model_patches(
                    model=transformer,
                    patches=self._lora_iterator(context),
                    prefix=FLUX_LORA_TRANSFORMER_PREFIX,
                    dtype=inference_dtype,
                    cached_weights=cached_weights,
                    force_sidecar_patching=model_is_quantized,
                )
            )

            # Prepare reference image conditioning if provided
            img_cond_seq = None
            img_cond_seq_ids = None
            if ref_image_extension is not None:
                # Ensure batch sizes match
                ref_image_extension.ensure_batch_size(x.shape[0])
                img_cond_seq, img_cond_seq_ids = (
                    ref_image_extension.ref_image_latents,
                    ref_image_extension.ref_image_ids,
                )

            x = denoise(
                model=transformer,
                img=x,
                img_ids=img_ids,
                txt=txt,
                txt_ids=txt_ids,
                timesteps=timesteps,
                step_callback=self._build_step_callback(context),
                cfg_scale=cfg_scale_list,
                neg_txt=neg_txt,
                neg_txt_ids=neg_txt_ids,
                scheduler=scheduler,
                mu=mu,
                inpaint_extension=inpaint_extension,
                img_cond_seq=img_cond_seq,
                img_cond_seq_ids=img_cond_seq_ids,
            )

        # Apply BN denormalization if BN stats are available
        # The diffusers Flux2KleinPipeline applies: latents = latents * bn_std + bn_mean
        # This transforms latents from normalized space to VAE's expected input space
        if bn_mean is not None and bn_std is not None:
            x = self._bn_denormalize(x, bn_mean, bn_std)

        x = unpack_flux2(x.float(), self.height, self.width)
        return x

    def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> Optional[torch.Tensor]:
        """Prepare the inpaint mask."""
        if self.denoise_mask is None:
            return None

        mask = context.tensors.load(self.denoise_mask.mask_name)
        mask = 1.0 - mask

        _, _, latent_height, latent_width = latents.shape
        mask = tv_resize(
            img=mask,
            size=[latent_height, latent_width],
            interpolation=tv_transforms.InterpolationMode.BILINEAR,
            antialias=False,
        )

        mask = mask.to(device=latents.device, dtype=latents.dtype)
        return mask.expand_as(latents)
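RectifiedFlowInpaintExtension is imported from the FLUX.1 backend, so its merge rule presumably carries over: at each step, regions outside the (inverted) mask are re-anchored to a noised copy of the init latents. A hedged sketch of that per-step merge (illustrative only; the real implementation lives in the extension class):

import torch

def merge_inpaint(x_t: torch.Tensor, init_latents: torch.Tensor,
                  noise: torch.Tensor, mask: torch.Tensor, t: float) -> torch.Tensor:
    # Inside the mask (mask == 1), keep the denoiser's output; outside it,
    # pin latents to the rectified-flow interpolant of the original image at t.
    init_noised = t * noise + (1.0 - t) * init_latents
    return mask * x_t + (1.0 - mask) * init_noised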
    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
        """Iterate over LoRA models to apply."""
        for lora in self.transformer.loras:
            lora_info = context.models.load(lora.lora)
            assert isinstance(lora_info.model, ModelPatchRaw)
            yield (lora_info.model, lora.weight)
            del lora_info

    def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
        """Build a callback for step progress updates."""

        def step_callback(state: PipelineIntermediateState) -> None:
            latents = state.latents.float()
            state.latents = unpack_flux2(latents, self.height, self.width).squeeze()
            context.util.flux2_step_callback(state)

        return step_callback
222
invokeai/app/invocations/flux2_klein_model_loader.py
Normal file
@@ -0,0 +1,222 @@
"""Flux2 Klein Model Loader Invocation.

Loads a Flux2 Klein model with its Qwen3 text encoder and VAE.
Unlike standard FLUX which uses CLIP+T5, Klein uses only Qwen3.
"""

from typing import Literal, Optional

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    Classification,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.model import (
    ModelIdentifierField,
    Qwen3EncoderField,
    TransformerField,
    VAEField,
)
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import (
    BaseModelType,
    Flux2VariantType,
    ModelFormat,
    ModelType,
    Qwen3VariantType,
    SubModelType,
)


@invocation_output("flux2_klein_model_loader_output")
class Flux2KleinModelLoaderOutput(BaseInvocationOutput):
    """Flux2 Klein model loader output."""

    transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
    qwen3_encoder: Qwen3EncoderField = OutputField(description=FieldDescriptions.qwen3_encoder, title="Qwen3 Encoder")
    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
    max_seq_len: Literal[256, 512] = OutputField(
        description="The max sequence length for the Qwen3 encoder.",
        title="Max Seq Length",
    )


@invocation(
    "flux2_klein_model_loader",
    title="Main Model - Flux2 Klein",
    tags=["model", "flux", "klein", "qwen3"],
    category="model",
    version="1.0.0",
    classification=Classification.Prototype,
)
class Flux2KleinModelLoaderInvocation(BaseInvocation):
    """Loads a Flux2 Klein model, outputting its submodels.

    Flux2 Klein uses Qwen3 as the text encoder instead of CLIP+T5.
    It uses a 32-channel VAE (AutoencoderKLFlux2) instead of the 16-channel FLUX.1 VAE.

    When using a Diffusers format model, both VAE and Qwen3 encoder are extracted
    automatically from the main model. You can override with standalone models:
    - Transformer: Always from Flux2 Klein main model
    - VAE: From main model (Diffusers) or standalone VAE
    - Qwen3 Encoder: From main model (Diffusers) or standalone Qwen3 model
    """

    model: ModelIdentifierField = InputField(
        description=FieldDescriptions.flux_model,
        input=Input.Direct,
        ui_model_base=BaseModelType.Flux2,
        ui_model_type=ModelType.Main,
        title="Transformer",
    )

    vae_model: Optional[ModelIdentifierField] = InputField(
        default=None,
        description="Standalone VAE model. Flux2 Klein uses a 32-channel VAE (AutoencoderKLFlux2), not the 16-channel FLUX.1 VAE. "
        "If not provided, the VAE will be extracted from the main model (Diffusers) or the Qwen3 Source model.",
        input=Input.Direct,
        ui_model_base=[BaseModelType.Flux, BaseModelType.Flux2],
        ui_model_type=ModelType.VAE,
        title="VAE",
    )

    qwen3_encoder_model: Optional[ModelIdentifierField] = InputField(
        default=None,
        description="Standalone Qwen3 Encoder model. "
        "If not provided, the encoder will be extracted from the main model (Diffusers) or the Qwen3 Source model.",
        input=Input.Direct,
        ui_model_type=ModelType.Qwen3Encoder,
        title="Qwen3 Encoder",
    )

    qwen3_source_model: Optional[ModelIdentifierField] = InputField(
        default=None,
        description="Diffusers Flux2 Klein model to extract VAE and/or Qwen3 encoder from. "
        "Use this if you don't have separate VAE/Qwen3 models. "
        "Ignored if both VAE and Qwen3 Encoder are provided separately.",
        input=Input.Direct,
        ui_model_base=BaseModelType.Flux2,
        ui_model_type=ModelType.Main,
        ui_model_format=ModelFormat.Diffusers,
        title="Qwen3 Source (Diffusers)",
    )

    max_seq_len: Literal[256, 512] = InputField(
        default=512,
        description="Max sequence length for the Qwen3 encoder.",
        title="Max Seq Length",
    )

    def invoke(self, context: InvocationContext) -> Flux2KleinModelLoaderOutput:
        # Transformer always comes from the main model
        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})

        # Check if main model is Diffusers format (can extract VAE directly)
        main_config = context.models.get_config(self.model)
        main_is_diffusers = main_config.format == ModelFormat.Diffusers

        # Determine VAE source
        # IMPORTANT: FLUX.2 Klein uses a 32-channel VAE (AutoencoderKLFlux2), not the 16-channel FLUX.1 VAE.
        # The VAE should come from the FLUX.2 Klein Diffusers model, not a separate FLUX VAE.
        if self.vae_model is not None:
            # Use standalone VAE (user explicitly selected one)
            vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
        elif main_is_diffusers:
            # Extract VAE from main model (recommended for FLUX.2)
            vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
        elif self.qwen3_source_model is not None:
            # Extract from Qwen3 source Diffusers model
            self._validate_diffusers_format(context, self.qwen3_source_model, "Qwen3 Source")
            vae = self.qwen3_source_model.model_copy(update={"submodel_type": SubModelType.VAE})
        else:
            raise ValueError(
                "No VAE source provided. Standalone safetensors/GGUF models require a separate VAE. "
                "Options:\n"
                " 1. Set 'VAE' to a standalone FLUX.2 VAE model\n"
                " 2. Set 'Qwen3 Source' to a Diffusers Flux2 Klein model to extract the VAE from"
            )

        # Determine Qwen3 Encoder source
        if self.qwen3_encoder_model is not None:
            # Use standalone Qwen3 Encoder - validate it matches the FLUX.2 Klein variant
            self._validate_qwen3_encoder_variant(context, main_config)
            qwen3_tokenizer = self.qwen3_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
            qwen3_encoder = self.qwen3_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
        elif main_is_diffusers:
            # Extract from main model (recommended for FLUX.2 Klein)
            qwen3_tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
            qwen3_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
        elif self.qwen3_source_model is not None:
            # Extract from separate Diffusers model
            self._validate_diffusers_format(context, self.qwen3_source_model, "Qwen3 Source")
            qwen3_tokenizer = self.qwen3_source_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
            qwen3_encoder = self.qwen3_source_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
        else:
            raise ValueError(
                "No Qwen3 Encoder source provided. Standalone safetensors/GGUF models require a separate text encoder. "
                "Options:\n"
                " 1. Set 'Qwen3 Encoder' to a standalone Qwen3 text encoder model "
                "(Klein 4B needs Qwen3 4B, Klein 9B needs Qwen3 8B)\n"
                " 2. Set 'Qwen3 Source' to a Diffusers Flux2 Klein model to extract the encoder from"
            )

        return Flux2KleinModelLoaderOutput(
            transformer=TransformerField(transformer=transformer, loras=[]),
            qwen3_encoder=Qwen3EncoderField(tokenizer=qwen3_tokenizer, text_encoder=qwen3_encoder),
            vae=VAEField(vae=vae),
            max_seq_len=self.max_seq_len,
        )

    def _validate_diffusers_format(
        self, context: InvocationContext, model: ModelIdentifierField, model_name: str
    ) -> None:
        """Validate that a model is in Diffusers format."""
        config = context.models.get_config(model)
        if config.format != ModelFormat.Diffusers:
            raise ValueError(
                f"The {model_name} model must be a Diffusers format model. "
                f"The selected model '{config.name}' is in {config.format.value} format."
            )

    def _validate_qwen3_encoder_variant(self, context: InvocationContext, main_config) -> None:
        """Validate that the standalone Qwen3 encoder variant matches the FLUX.2 Klein variant.

        - FLUX.2 Klein 4B requires Qwen3 4B encoder
        - FLUX.2 Klein 9B requires Qwen3 8B encoder
        """
        if self.qwen3_encoder_model is None:
            return

        # Get the Qwen3 encoder config
        qwen3_config = context.models.get_config(self.qwen3_encoder_model)

        # Check if the config has a variant field
        if not hasattr(qwen3_config, "variant"):
            # Can't validate, skip
            return

        qwen3_variant = qwen3_config.variant

        # Get the FLUX.2 Klein variant from the main model config
        if not hasattr(main_config, "variant"):
            return

        flux2_variant = main_config.variant

        # Validate the variants match
        # Klein4B requires Qwen3_4B; Klein9B/Klein9BBase require Qwen3_8B
        expected_qwen3_variant = None
        if flux2_variant == Flux2VariantType.Klein4B:
            expected_qwen3_variant = Qwen3VariantType.Qwen3_4B
        elif flux2_variant in (Flux2VariantType.Klein9B, Flux2VariantType.Klein9BBase):
            expected_qwen3_variant = Qwen3VariantType.Qwen3_8B

        if expected_qwen3_variant is not None and qwen3_variant != expected_qwen3_variant:
            raise ValueError(
                f"Qwen3 encoder variant mismatch: FLUX.2 Klein {flux2_variant.value} requires "
                f"{expected_qwen3_variant.value} encoder, but {qwen3_variant.value} was selected. "
                f"Please select a matching Qwen3 encoder or use a Diffusers format model which includes the correct encoder."
            )
222
invokeai/app/invocations/flux2_klein_text_encoder.py
Normal file
@@ -0,0 +1,222 @@
"""Flux2 Klein Text Encoder Invocation.

Flux2 Klein uses Qwen3 as the text encoder instead of CLIP+T5.
The key difference is that it extracts hidden states from layers (9, 18, 27)
and stacks them together for richer text representations.

This implementation matches the diffusers Flux2KleinPipeline exactly.
"""

from contextlib import ExitStack
from typing import Iterator, Literal, Optional, Tuple

import torch
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    FluxConditioningField,
    Input,
    InputField,
    TensorField,
    UIComponent,
)
from invokeai.app.invocations.model import Qwen3EncoderField
from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_T5_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice

# FLUX.2 Klein extracts hidden states from these specific layers
# Matching diffusers Flux2KleinPipeline: (9, 18, 27)
# hidden_states[0] is embedding layer, so layer N is at index N
KLEIN_EXTRACTION_LAYERS = (9, 18, 27)

# Default max sequence length for Klein models
KLEIN_MAX_SEQ_LEN = 512
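The shape effect of this multi-layer extraction is easy to state. A minimal sketch, assuming a hidden size D (the real D depends on the Qwen3 variant; these sizes are illustrative): three (B, S, D) hidden states become one (B, S, 3*D) embedding via the stack/permute/reshape performed in _encode_prompt below:

import torch

B, S, D = 1, 512, 2048  # illustrative sizes only
hidden = [torch.randn(B, S, D) for _ in KLEIN_EXTRACTION_LAYERS]

out = torch.stack(hidden, dim=1)  # (B, 3, S, D)
embeds = out.permute(0, 2, 1, 3).reshape(B, S, 3 * D)  # (B, S, 3*D)
assert embeds.shape == (B, S, 3 * D)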

@invocation(
    "flux2_klein_text_encoder",
    title="Prompt - Flux2 Klein",
    tags=["prompt", "conditioning", "flux", "klein", "qwen3"],
    category="conditioning",
    version="1.1.0",
    classification=Classification.Prototype,
)
class Flux2KleinTextEncoderInvocation(BaseInvocation):
    """Encodes and preps a prompt for Flux2 Klein image generation.

    Flux2 Klein uses Qwen3 as the text encoder, extracting hidden states from
    layers (9, 18, 27) and stacking them for richer text representations.
    This matches the diffusers Flux2KleinPipeline implementation exactly.
    """

    prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
    qwen3_encoder: Qwen3EncoderField = InputField(
        title="Qwen3 Encoder",
        description=FieldDescriptions.qwen3_encoder,
        input=Input.Connection,
    )
    max_seq_len: Literal[256, 512] = InputField(
        default=512,
        description="Max sequence length for the Qwen3 encoder.",
    )
    mask: Optional[TensorField] = InputField(
        default=None,
        description="A mask defining the region that this conditioning prompt applies to.",
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
        qwen3_embeds, pooled_embeds = self._encode_prompt(context)

        # Use FLUXConditioningInfo for compatibility with existing Flux denoiser
        # t5_embeds -> qwen3 stacked embeddings
        # clip_embeds -> pooled qwen3 embedding
        conditioning_data = ConditioningFieldData(
            conditionings=[FLUXConditioningInfo(clip_embeds=pooled_embeds, t5_embeds=qwen3_embeds)]
        )

        conditioning_name = context.conditioning.save(conditioning_data)
        return FluxConditioningOutput(
            conditioning=FluxConditioningField(conditioning_name=conditioning_name, mask=self.mask)
        )

    def _encode_prompt(self, context: InvocationContext) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode prompt using Qwen3 text encoder with Klein-style layer extraction.

        This matches the diffusers Flux2KleinPipeline._get_qwen3_prompt_embeds() exactly.

        Returns:
            Tuple of (stacked_embeddings, pooled_embedding):
            - stacked_embeddings: Hidden states from layers (9, 18, 27) stacked together.
              Shape: (1, seq_len, hidden_size * 3)
            - pooled_embedding: Pooled representation for global conditioning.
              Shape: (1, hidden_size)
        """
        prompt = self.prompt
        device = TorchDevice.choose_torch_device()

        text_encoder_info = context.models.load(self.qwen3_encoder.text_encoder)
        tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer)

        with ExitStack() as exit_stack:
            (cached_weights, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())
            (_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())

            # Apply LoRA models to the text encoder
            lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
            exit_stack.enter_context(
                LayerPatcher.apply_smart_model_patches(
                    model=text_encoder,
                    patches=self._lora_iterator(context),
                    prefix=FLUX_LORA_T5_PREFIX,  # Reuse T5 prefix for Qwen3 LoRAs
                    dtype=lora_dtype,
                    cached_weights=cached_weights,
                )
            )

            context.util.signal_progress("Running Qwen3 text encoder (Klein)")

            if not isinstance(text_encoder, PreTrainedModel):
                raise TypeError(
                    f"Expected PreTrainedModel for text encoder, got {type(text_encoder).__name__}. "
                    "The Qwen3 encoder model may be corrupted or incompatible."
                )
            if not isinstance(tokenizer, PreTrainedTokenizerBase):
                raise TypeError(
                    f"Expected PreTrainedTokenizerBase for tokenizer, got {type(tokenizer).__name__}. "
                    "The Qwen3 tokenizer may be corrupted or incompatible."
                )

            # Format messages exactly like diffusers Flux2KleinPipeline:
            # - Only user message, NO system message
            # - add_generation_prompt=True (adds assistant prefix)
            # - enable_thinking=False
            messages = [{"role": "user", "content": prompt}]

            # Step 1: Apply chat template to get formatted text (tokenize=False)
            text: str = tokenizer.apply_chat_template(  # type: ignore[assignment]
                messages,
                tokenize=False,
                add_generation_prompt=True,  # Adds assistant prefix like diffusers
                enable_thinking=False,  # Disable thinking mode
            )

            # Step 2: Tokenize the formatted text
            inputs = tokenizer(
                text,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.max_seq_len,
            )

            input_ids = inputs["input_ids"]
            attention_mask = inputs["attention_mask"]

            # Move to device
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)

            # Forward pass through the model - matching diffusers exactly
            outputs = text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                use_cache=False,
            )

            # Validate hidden_states output
            if not hasattr(outputs, "hidden_states") or outputs.hidden_states is None:
                raise RuntimeError(
                    "Text encoder did not return hidden_states. "
                    "Ensure output_hidden_states=True is supported by this model."
                )

            num_hidden_layers = len(outputs.hidden_states)

            # Extract and stack hidden states - EXACTLY like diffusers:
            # out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
            # prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
            hidden_states_list = []
            for layer_idx in KLEIN_EXTRACTION_LAYERS:
                if layer_idx >= num_hidden_layers:
                    layer_idx = num_hidden_layers - 1
                hidden_states_list.append(outputs.hidden_states[layer_idx])

            # Stack along dim=1, then permute and reshape - exactly like diffusers
            out = torch.stack(hidden_states_list, dim=1)
            out = out.to(dtype=text_encoder.dtype, device=device)

            batch_size, num_channels, seq_len, hidden_dim = out.shape
            prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)

            # Create pooled embedding for global conditioning
            # Use mean pooling over the sequence (excluding padding)
            # This serves a similar role to CLIP's pooled output in standard FLUX
            last_hidden_state = outputs.hidden_states[-1]  # Use last layer for pooling
            # Expand mask to match hidden state dimensions
            expanded_mask = attention_mask.unsqueeze(-1).expand_as(last_hidden_state).float()
            sum_embeds = (last_hidden_state * expanded_mask).sum(dim=1)
            num_tokens = expanded_mask.sum(dim=1).clamp(min=1)
            pooled_embeds = sum_embeds / num_tokens

            return prompt_embeds, pooled_embeds

    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
        """Iterate over LoRA models to apply to the Qwen3 text encoder."""
        for lora in self.qwen3_encoder.loras:
            lora_info = context.models.load(lora.lora)
            if not isinstance(lora_info.model, ModelPatchRaw):
                raise TypeError(
                    f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}. "
                    "The LoRA model may be corrupted or incompatible."
                )
            yield (lora_info.model, lora.weight)
            del lora_info
92
invokeai/app/invocations/flux2_vae_decode.py
Normal file
@@ -0,0 +1,92 @@
"""Flux2 Klein VAE Decode Invocation.

Decodes latents to images using the FLUX.2 32-channel VAE (AutoencoderKLFlux2).
"""

import torch
from einops import rearrange
from PIL import Image

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    Input,
    InputField,
    LatentsField,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.util.devices import TorchDevice


@invocation(
    "flux2_vae_decode",
    title="Latents to Image - FLUX2",
    tags=["latents", "image", "vae", "l2i", "flux2", "klein"],
    category="latents",
    version="1.0.0",
    classification=Classification.Prototype,
)
class Flux2VaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generates an image from latents using FLUX.2 Klein's 32-channel VAE."""

    latents: LatentsField = InputField(
        description=FieldDescriptions.latents,
        input=Input.Connection,
    )
    vae: VAEField = InputField(
        description=FieldDescriptions.vae,
        input=Input.Connection,
    )

    def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
        """Decode latents to image using FLUX.2 VAE.

        Input latents should already be in the correct space after BN denormalization
        was applied in the denoiser. The VAE expects (B, 32, H, W) format.
        """
        with vae_info.model_on_device() as (_, vae):
            vae_dtype = next(iter(vae.parameters())).dtype
            device = TorchDevice.choose_torch_device()
            latents = latents.to(device=device, dtype=vae_dtype)

            # Decode using diffusers API
            decoded = vae.decode(latents, return_dict=False)[0]

            # Convert from [-1, 1] to [0, 1] then to [0, 255] PIL image
            img = (decoded / 2 + 0.5).clamp(0, 1)
            img = rearrange(img[0], "c h w -> h w c")
            img_np = (img * 255).byte().cpu().numpy()
            # Explicitly create RGB image (not grayscale)
            img_pil = Image.fromarray(img_np, mode="RGB")
            return img_pil

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        latents = context.tensors.load(self.latents.latents_name)

        # Log latent statistics for debugging black image issues
        context.logger.debug(
            f"FLUX.2 VAE decode input: shape={latents.shape}, "
            f"min={latents.min().item():.4f}, max={latents.max().item():.4f}, "
            f"mean={latents.mean().item():.4f}"
        )

        # Warn if input latents are all zeros or very small (would cause black images)
        if latents.abs().max() < 1e-6:
            context.logger.warning(
                "FLUX.2 VAE decode received near-zero latents! This will cause black images. "
                "The latent cache may be corrupted - try clearing the cache."
            )

        vae_info = context.models.load(self.vae.vae)
        context.util.signal_progress("Running VAE")
        image = self._vae_decode(vae_info=vae_info, latents=latents)

        TorchDevice.empty_cache()
        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)
88
invokeai/app/invocations/flux2_vae_encode.py
Normal file
@@ -0,0 +1,88 @@
"""Flux2 Klein VAE Encode Invocation.

Encodes images to latents using the FLUX.2 32-channel VAE (AutoencoderKLFlux2).
"""

import einops
import torch

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    ImageField,
    Input,
    InputField,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice


@invocation(
    "flux2_vae_encode",
    title="Image to Latents - FLUX2",
    tags=["latents", "image", "vae", "i2l", "flux2", "klein"],
    category="latents",
    version="1.0.0",
    classification=Classification.Prototype,
)
class Flux2VaeEncodeInvocation(BaseInvocation):
    """Encodes an image into latents using FLUX.2 Klein's 32-channel VAE."""

    image: ImageField = InputField(
        description="The image to encode.",
    )
    vae: VAEField = InputField(
        description=FieldDescriptions.vae,
        input=Input.Connection,
    )

    def _vae_encode(self, vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
        """Encode image to latents using FLUX.2 VAE.

        The VAE encodes to 32-channel latent space.
        Output latents shape: (B, 32, H/8, W/8).
        """
        with vae_info.model_on_device() as (_, vae):
            vae_dtype = next(iter(vae.parameters())).dtype
            device = TorchDevice.choose_torch_device()
            image_tensor = image_tensor.to(device=device, dtype=vae_dtype)

            # Encode using diffusers API
            # The VAE.encode() returns a DiagonalGaussianDistribution-like object
            latent_dist = vae.encode(image_tensor, return_dict=False)[0]

            # Sample from the distribution (or use mode for deterministic output)
            # Using mode() for deterministic encoding
            if hasattr(latent_dist, "mode"):
                latents = latent_dist.mode()
            elif hasattr(latent_dist, "sample"):
                # Fall back to sampling if mode is not available
                generator = torch.Generator(device=device).manual_seed(0)
                latents = latent_dist.sample(generator=generator)
            else:
                # Direct tensor output (some VAE implementations)
                latents = latent_dist

            return latents

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        image = context.images.get_pil(self.image.image_name)

        vae_info = context.models.load(self.vae.vae)

        # Convert image to tensor (HWC -> CHW, normalize to [-1, 1])
        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
        if image_tensor.dim() == 3:
            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

        context.util.signal_progress("Running VAE Encode")
        latents = self._vae_encode(vae_info=vae_info, image_tensor=image_tensor)

        latents = latents.to("cpu")
        name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
@@ -32,6 +32,13 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.dype.presets import (
    DYPE_PRESET_LABELS,
    DYPE_PRESET_OFF,
    DyPEPreset,
    get_dype_config_from_preset,
)
from invokeai.backend.flux.extensions.dype_extension import DyPEExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.kontext_extension import KontextExtension
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
@@ -47,6 +54,7 @@ from invokeai.backend.flux.sampling_utils import (
    pack,
    unpack,
)
from invokeai.backend.flux.schedulers import FLUX_SCHEDULER_LABELS, FLUX_SCHEDULER_MAP, FLUX_SCHEDULER_NAME_VALUES
from invokeai.backend.flux.text_conditioning import FluxReduxConditioning, FluxTextConditioning
from invokeai.backend.model_manager.taxonomy import BaseModelType, FluxVariantType, ModelFormat, ModelType
from invokeai.backend.patches.layer_patcher import LayerPatcher
@@ -63,7 +71,7 @@ from invokeai.backend.util.devices import TorchDevice
    title="FLUX Denoise",
    tags=["image", "flux"],
    category="image",
    version="4.1.0",
    version="4.5.1",
)
class FluxDenoiseInvocation(BaseInvocation):
    """Run denoising process with a FLUX transformer model."""
@@ -132,6 +140,12 @@ class FluxDenoiseInvocation(BaseInvocation):
    num_steps: int = InputField(
        default=4, description="Number of diffusion steps. Recommended values are schnell: 4, dev: 50."
    )
    scheduler: FLUX_SCHEDULER_NAME_VALUES = InputField(
        default="euler",
        description="Scheduler (sampler) for the denoising process. 'euler' is fast and standard. "
        "'heun' is 2nd-order (better quality, 2x slower). 'lcm' is optimized for few steps.",
        ui_choice_labels=FLUX_SCHEDULER_LABELS,
    )
    guidance: float = InputField(
        default=4.0,
        description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
@@ -159,6 +173,31 @@ class FluxDenoiseInvocation(BaseInvocation):
        input=Input.Connection,
    )

    # DyPE (Dynamic Position Extrapolation) for high-resolution generation
    dype_preset: DyPEPreset = InputField(
        default=DYPE_PRESET_OFF,
        description=(
            "DyPE preset for high-resolution generation. 'auto' enables automatically for resolutions > 1536px. "
            "'area' enables automatically based on image area. '4k' uses optimized settings for 4K output."
        ),
        ui_order=100,
        ui_choice_labels=DYPE_PRESET_LABELS,
    )
    dype_scale: Optional[float] = InputField(
        default=None,
        ge=0.0,
        le=8.0,
        description="DyPE magnitude (λs). Higher values = stronger extrapolation. Only used when dype_preset is not 'off'.",
        ui_order=101,
    )
    dype_exponent: Optional[float] = InputField(
        default=None,
        ge=0.0,
        le=1000.0,
        description="DyPE decay speed (λt). Controls transition from low to high frequency detail. Only used when dype_preset is not 'off'.",
        ui_order=102,
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        latents = self._run_diffusion(context)
@@ -232,8 +271,14 @@ class FluxDenoiseInvocation(BaseInvocation):
        )

        transformer_config = context.models.get_config(self.transformer.transformer)
        assert transformer_config.base is BaseModelType.Flux and transformer_config.type is ModelType.Main
        is_schnell = transformer_config.variant is FluxVariantType.Schnell
        assert (
            transformer_config.base in (BaseModelType.Flux, BaseModelType.Flux2)
            and transformer_config.type is ModelType.Main
        )
        # Schnell is only for FLUX.1, FLUX.2 Klein behaves like Dev (with guidance)
        is_schnell = (
            transformer_config.base is BaseModelType.Flux and transformer_config.variant is FluxVariantType.Schnell
        )

        # Calculate the timestep schedule.
        timesteps = get_schedule(
@@ -242,6 +287,12 @@ class FluxDenoiseInvocation(BaseInvocation):
            shift=not is_schnell,
        )

        # Create scheduler if not using default euler
        scheduler = None
        if self.scheduler in FLUX_SCHEDULER_MAP:
            scheduler_class = FLUX_SCHEDULER_MAP[self.scheduler]
            scheduler = scheduler_class(num_train_timesteps=1000)

        # Clip the timesteps schedule based on denoising_start and denoising_end.
        timesteps = clip_timestep_schedule_fractional(timesteps, self.denoising_start, self.denoising_end)

@@ -409,6 +460,30 @@ class FluxDenoiseInvocation(BaseInvocation):
            kontext_extension.ensure_batch_size(x.shape[0])
            img_cond_seq, img_cond_seq_ids = kontext_extension.kontext_latents, kontext_extension.kontext_ids

        # Prepare DyPE extension for high-resolution generation
        dype_extension: DyPEExtension | None = None
        dype_config = get_dype_config_from_preset(
            preset=self.dype_preset,
            width=self.width,
            height=self.height,
            custom_scale=self.dype_scale,
            custom_exponent=self.dype_exponent,
        )
        if dype_config is not None:
            dype_extension = DyPEExtension(
                config=dype_config,
                target_height=self.height,
                target_width=self.width,
            )
            context.logger.info(
                f"DyPE enabled: resolution={self.width}x{self.height}, preset={self.dype_preset}, "
                f"method={dype_config.method}, scale={dype_config.dype_scale:.2f}, "
                f"exponent={dype_config.dype_exponent:.2f}, start_sigma={dype_config.dype_start_sigma:.2f}, "
                f"base_resolution={dype_config.base_resolution}"
            )
        else:
            context.logger.debug(f"DyPE disabled: resolution={self.width}x{self.height}, preset={self.dype_preset}")

        x = denoise(
            model=transformer,
            img=x,
@@ -426,6 +501,8 @@ class FluxDenoiseInvocation(BaseInvocation):
            img_cond=img_cond,
            img_cond_seq=img_cond_seq,
            img_cond_seq_ids=img_cond_seq_ids,
            dype_extension=dype_extension,
            scheduler=scheduler,
        )

        x = unpack(x.float(), self.height, self.width)

@@ -162,7 +162,7 @@ class FLUXLoRACollectionLoader(BaseInvocation):
        if not context.models.exists(lora.lora.key):
            raise Exception(f"Unknown lora: {lora.lora.key}!")

        assert lora.lora.base is BaseModelType.Flux
        assert lora.lora.base in (BaseModelType.Flux, BaseModelType.Flux2)

        added_loras.append(lora.lora.key)


@@ -6,7 +6,7 @@ from invokeai.app.invocations.baseinvocation import (
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.t5_model_identifier import (
@@ -37,28 +37,25 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
    title="Main Model - FLUX",
    tags=["model", "flux"],
    category="model",
    version="1.0.6",
    version="1.0.7",
)
class FluxModelLoaderInvocation(BaseInvocation):
    """Loads a flux base model, outputting its submodels."""

    model: ModelIdentifierField = InputField(
        description=FieldDescriptions.flux_model,
        input=Input.Direct,
        ui_model_base=BaseModelType.Flux,
        ui_model_type=ModelType.Main,
    )

    t5_encoder_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.t5_encoder,
        input=Input.Direct,
        title="T5 Encoder",
        ui_model_type=ModelType.T5Encoder,
    )

    clip_embed_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.clip_embed_model,
        input=Input.Direct,
        title="CLIP Embed",
        ui_model_type=ModelType.CLIPEmbed,
    )

@@ -46,7 +46,12 @@ class IdealSizeInvocation(BaseInvocation):
            dimension = 512
        elif unet_config.base == BaseModelType.StableDiffusion2:
            dimension = 768
        elif unet_config.base in (BaseModelType.StableDiffusionXL, BaseModelType.Flux, BaseModelType.StableDiffusion3):
        elif unet_config.base in (
            BaseModelType.StableDiffusionXL,
            BaseModelType.Flux,
            BaseModelType.Flux2,
            BaseModelType.StableDiffusion3,
        ):
            dimension = 1024
        else:
            raise ValueError(f"Unsupported model type: {unet_config.base}")

@@ -1,5 +1,6 @@
from contextlib import nullcontext
from functools import singledispatchmethod
from typing import Literal

import einops
import torch
@@ -20,7 +21,7 @@ from invokeai.app.invocations.fields import (
    Input,
    InputField,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.model import BaseModelType, VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
@@ -29,13 +30,21 @@ from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_sd15_sdxl

"""
SDXL VAE color compensation values determined experimentally to reduce color drift.
If more reliable values are found in the future (e.g. individual color channels), they can be updated.
SD1.5, TAESD, TAESDXL VAEs distort in less predictable ways, so no compensation is offered at this time.
"""
COMPENSATION_OPTIONS = Literal["None", "SDXL"]
COLOR_COMPENSATION_MAP = {"None": [1, 0], "SDXL": [1.015, -0.002]}
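The compensation pairs above are [scale, bias] factors applied to the image tensor before encoding. A minimal sketch of the arithmetic (the same multiply-add appears in the invoke() hunk below; the input tensor here is illustrative):

import torch

image_tensor = torch.rand(1, 3, 64, 64) * 2 - 1  # illustrative [-1, 1] input
scale, bias = COLOR_COMPENSATION_MAP["SDXL"]  # [1.015, -0.002]
compensated = image_tensor * scale + bias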
|
||||
@invocation(
|
||||
"i2l",
|
||||
title="Image to Latents - SD1.5, SDXL",
|
||||
tags=["latents", "image", "vae", "i2l"],
|
||||
category="latents",
|
||||
version="1.1.1",
|
||||
version="1.2.0",
|
||||
)
|
||||
class ImageToLatentsInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
@@ -52,6 +61,10 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
# offer a way to directly set None values.
|
||||
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
|
||||
fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
|
||||
color_compensation: COMPENSATION_OPTIONS = InputField(
|
||||
default="None",
|
||||
description="Apply VAE scaling compensation when encoding images (reduces color drift).",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def vae_encode(
|
||||
@@ -62,7 +75,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
image_tensor: torch.Tensor,
|
||||
tile_size: int = 0,
|
||||
) -> torch.Tensor:
|
||||
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
|
||||
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny)), "VAE must be of type SD-1.5 or SDXL"
|
||||
estimated_working_memory = estimate_vae_working_memory_sd15_sdxl(
|
||||
operation="encode",
|
||||
image_tensor=image_tensor,
|
||||
@@ -71,7 +84,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
fp32=upcast,
|
||||
)
|
||||
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
|
||||
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
|
||||
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny)), "VAE must be of type SD-1.5 or SDXL"
|
||||
orig_dtype = vae.dtype
|
||||
if upcast:
|
||||
vae.to(dtype=torch.float32)
|
||||
@@ -127,9 +140,14 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
|
||||
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny)), "VAE must be of type SD-1.5 or SDXL"
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
|
||||
if self.color_compensation != "None" and vae_info.config.base == BaseModelType.StableDiffusionXL:
|
||||
scale, bias = COLOR_COMPENSATION_MAP[self.color_compensation]
|
||||
image_tensor = image_tensor * scale + bias
|
||||
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
|
||||
@@ -2,12 +2,6 @@ from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
|
||||
|
||||
@@ -77,26 +71,9 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
|
||||
latents = latents.to(TorchDevice.choose_torch_device())
|
||||
if self.fp32:
|
||||
# FP32 mode: convert everything to float32 for maximum precision
|
||||
vae.to(dtype=torch.float32)
|
||||
|
||||
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
|
||||
vae.decoder.mid_block.attentions[0].processor,
|
||||
(
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
LoRAXFormersAttnProcessor,
|
||||
LoRAAttnProcessor2_0,
|
||||
),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
# to be in float32 which can save lots of memory
|
||||
if use_torch_2_0_or_xformers:
|
||||
vae.post_quant_conv.to(latents.dtype)
|
||||
vae.decoder.conv_in.to(latents.dtype)
|
||||
vae.decoder.mid_block.to(latents.dtype)
|
||||
else:
|
||||
latents = latents.float()
|
||||
|
||||
latents = latents.float()
|
||||
else:
|
||||
vae.to(dtype=torch.float16)
|
||||
latents = latents.half()
|
||||
|
||||
@@ -150,6 +150,10 @@ GENERATION_MODES = Literal[
|
||||
"flux_img2img",
|
||||
"flux_inpaint",
|
||||
"flux_outpaint",
|
||||
"flux2_txt2img",
|
||||
"flux2_img2img",
|
||||
"flux2_inpaint",
|
||||
"flux2_outpaint",
|
||||
"sd3_txt2img",
|
||||
"sd3_img2img",
|
||||
"sd3_inpaint",
|
||||
@@ -158,6 +162,10 @@ GENERATION_MODES = Literal[
|
||||
"cogview4_img2img",
|
||||
"cogview4_inpaint",
|
||||
"cogview4_outpaint",
|
||||
"z_image_txt2img",
|
||||
"z_image_img2img",
|
||||
"z_image_inpaint",
|
||||
"z_image_outpaint",
|
||||
]
|
||||
|
||||
|
||||
@@ -166,7 +174,7 @@ GENERATION_MODES = Literal[
|
||||
title="Core Metadata",
|
||||
tags=["metadata"],
|
||||
category="metadata",
|
||||
version="2.0.0",
|
||||
version="2.1.0",
|
||||
classification=Classification.Internal,
|
||||
)
|
||||
class CoreMetadataInvocation(BaseInvocation):
|
||||
@@ -217,6 +225,10 @@ class CoreMetadataInvocation(BaseInvocation):
|
||||
default=None,
|
||||
description="The VAE used for decoding, if the main model's default was not used",
|
||||
)
|
||||
qwen3_encoder: Optional[ModelIdentifierField] = InputField(
|
||||
default=None,
|
||||
description="The Qwen3 text encoder model used for Z-Image inference",
|
||||
)
|
||||
|
||||
# High resolution fix metadata.
|
||||
hrf_enabled: Optional[bool] = InputField(
|
||||
|
||||
@@ -52,6 +52,7 @@ from invokeai.app.invocations.primitives import (
|
||||
)
|
||||
from invokeai.app.invocations.scheduler import SchedulerOutput
|
||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField, T2IAdapterInvocation
|
||||
from invokeai.app.invocations.z_image_denoise import ZImageDenoiseInvocation
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
||||
@@ -729,6 +730,52 @@ class FluxDenoiseLatentsMetaInvocation(FluxDenoiseInvocation, WithMetadata):
        return LatentsMetaOutput(**params, metadata=MetadataField.model_validate(md))


@invocation(
    "z_image_denoise_meta",
    title=f"{ZImageDenoiseInvocation.UIConfig.title} + Metadata",
    tags=["z-image", "latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
    category="latents",
    version="1.0.0",
)
class ZImageDenoiseMetaInvocation(ZImageDenoiseInvocation, WithMetadata):
    """Run the denoising process with a Z-Image transformer model + metadata."""

    def invoke(self, context: InvocationContext) -> LatentsMetaOutput:
        def _loras_to_json(obj: Union[Any, list[Any]]) -> list[dict[str, Any]]:
            if not isinstance(obj, list):
                obj = [obj]

            output: list[dict[str, Any]] = []
            for item in obj:
                output.append(
                    LoRAMetadataField(
                        model=item.lora,
                        weight=item.weight,
                    ).model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
                )
            return output

        obj = super().invoke(context)

        md: Dict[str, Any] = {} if self.metadata is None else self.metadata.root
        md.update({"width": obj.width})
        md.update({"height": obj.height})
        md.update({"steps": self.steps})
        md.update({"guidance": self.guidance_scale})
        md.update({"denoising_start": self.denoising_start})
        md.update({"denoising_end": self.denoising_end})
        md.update({"scheduler": self.scheduler})
        md.update({"model": self.transformer.transformer})
        md.update({"seed": self.seed})
        if len(self.transformer.loras) > 0:
            md.update({"loras": _loras_to_json(self.transformer.loras)})

        params = obj.__dict__.copy()
        del params["type"]

        return LatentsMetaOutput(**params, metadata=MetadataField.model_validate(md))
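
# For orientation: a hypothetical metadata payload produced by the updates above
# (all values are illustrative, not taken from a real run):
#
#   {
#       "width": 1024, "height": 1024, "steps": 8, "guidance": 1.0,
#       "denoising_start": 0.0, "denoising_end": 1.0, "scheduler": "euler",
#       "model": <ModelIdentifierField>, "seed": 42,
#       "loras": [{"model": <ModelIdentifierField>, "weight": 0.75}],
#   }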


@invocation(
    "metadata_to_vae",
    title="Metadata To VAE",
@@ -72,6 +72,14 @@ class GlmEncoderField(BaseModel):
    text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")


class Qwen3EncoderField(BaseModel):
    """Field for the Qwen3 text encoder used by Z-Image models."""

    tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
    text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
    loras: List[LoRAField] = Field(default_factory=list, description="LoRAs to apply on model loading")


class VAEField(BaseModel):
    vae: ModelIdentifierField = Field(description="Info to load vae submodel")
    seamless_axes: List[str] = Field(default_factory=list, description='Axes ("x" and "y") to which to apply seamless tiling')
@@ -502,6 +510,7 @@ class VAELoaderInvocation(BaseInvocation):
            BaseModelType.StableDiffusionXL,
            BaseModelType.StableDiffusion3,
            BaseModelType.Flux,
            BaseModelType.Flux2,
        ],
        ui_model_type=ModelType.VAE,
    )

59
invokeai/app/invocations/pbr_maps.py
Normal file
@@ -0,0 +1,59 @@
import pathlib
from typing import Literal

from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import ImageField, InputField, OutputField, WithBoard, WithMetadata
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.pbr_maps.architecture.pbr_rrdb_net import PBR_RRDB_Net
from invokeai.backend.image_util.pbr_maps.pbr_maps import NORMAL_MAP_MODEL, OTHER_MAP_MODEL, PBRMapsGenerator
from invokeai.backend.util.devices import TorchDevice


@invocation_output("pbr_maps-output")
class PBRMapsOutput(BaseInvocationOutput):
    normal_map: ImageField = OutputField(default=None, description="The generated normal map")
    roughness_map: ImageField = OutputField(default=None, description="The generated roughness map")
    displacement_map: ImageField = OutputField(default=None, description="The generated displacement map")


@invocation("pbr_maps", title="PBR Maps", tags=["image", "material"], category="image", version="1.0.0")
class PBRMapsInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generate normal, displacement, and roughness maps from a given image."""

    image: ImageField = InputField(description="Input image")
    tile_size: int = InputField(default=512, description="Tile size")
    border_mode: Literal["none", "seamless", "mirror", "replicate"] = InputField(
        default="none", description="Border mode to apply to eliminate any artifacts or seams"
    )

    def invoke(self, context: InvocationContext) -> PBRMapsOutput:
        image_pil = context.images.get_pil(self.image.image_name, mode="RGB")

        def loader(model_path: pathlib.Path):
            return PBRMapsGenerator.load_model(model_path, TorchDevice.choose_torch_device())

        torch_device = TorchDevice.choose_torch_device()

        with (
            context.models.load_remote_model(NORMAL_MAP_MODEL, loader) as normal_map_model,
            context.models.load_remote_model(OTHER_MAP_MODEL, loader) as other_map_model,
        ):
            assert isinstance(normal_map_model, PBR_RRDB_Net)
            assert isinstance(other_map_model, PBR_RRDB_Net)
            pbr_pipeline = PBRMapsGenerator(normal_map_model, other_map_model, torch_device)
            normal_map, roughness_map, displacement_map = pbr_pipeline.generate_maps(
                image_pil, self.tile_size, self.border_mode
            )

        normal_map = context.images.save(normal_map)
        normal_map_field = ImageField(image_name=normal_map.image_name)

        roughness_map = context.images.save(roughness_map)
        roughness_map_field = ImageField(image_name=roughness_map.image_name)

        displacement_map = context.images.save(displacement_map)
        displacement_map_field = ImageField(image_name=displacement_map.image_name)

        return PBRMapsOutput(
            normal_map=normal_map_field, roughness_map=roughness_map_field, displacement_map=displacement_map_field
        )
@@ -27,6 +27,7 @@ from invokeai.app.invocations.fields import (
    SD3ConditioningField,
    TensorField,
    UIComponent,
    ZImageConditioningField,
)
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -461,6 +462,17 @@ class CogView4ConditioningOutput(BaseInvocationOutput):
        return cls(conditioning=CogView4ConditioningField(conditioning_name=conditioning_name))


@invocation_output("z_image_conditioning_output")
class ZImageConditioningOutput(BaseInvocationOutput):
    """Base class for nodes that output a Z-Image text conditioning tensor."""

    conditioning: ZImageConditioningField = OutputField(description=FieldDescriptions.cond)

    @classmethod
    def build(cls, conditioning_name: str) -> "ZImageConditioningOutput":
        return cls(conditioning=ZImageConditioningField(conditioning_name=conditioning_name))


@invocation_output("conditioning_output")
class ConditioningOutput(BaseInvocationOutput):
    """Base class for nodes that output a single conditioning tensor"""

57
invokeai/app/invocations/prompt_template.py
Normal file
@@ -0,0 +1,57 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import InputField, OutputField, StylePresetField, UIComponent
from invokeai.app.services.shared.invocation_context import InvocationContext


@invocation_output("prompt_template_output")
class PromptTemplateOutput(BaseInvocationOutput):
    """Output for the Prompt Template node"""

    positive_prompt: str = OutputField(description="The positive prompt with the template applied")
    negative_prompt: str = OutputField(description="The negative prompt with the template applied")


@invocation(
    "prompt_template",
    title="Prompt Template",
    tags=["prompt", "template", "style", "preset"],
    category="prompt",
    version="1.0.0",
)
class PromptTemplateInvocation(BaseInvocation):
    """Applies a Style Preset template to positive and negative prompts.

    Select a Style Preset and provide positive/negative prompts. The node replaces
    {prompt} placeholders in the template with your input prompts.
    """

    style_preset: StylePresetField = InputField(
        description="The Style Preset to use as a template",
    )
    positive_prompt: str = InputField(
        default="",
        description="The positive prompt to insert into the template's {prompt} placeholder",
        ui_component=UIComponent.Textarea,
    )
    negative_prompt: str = InputField(
        default="",
        description="The negative prompt to insert into the template's {prompt} placeholder",
        ui_component=UIComponent.Textarea,
    )

    def invoke(self, context: InvocationContext) -> PromptTemplateOutput:
        # Fetch the style preset from the database
        style_preset = context._services.style_preset_records.get(self.style_preset.style_preset_id)

        # Get the template prompts
        positive_template = style_preset.preset_data.positive_prompt
        negative_template = style_preset.preset_data.negative_prompt

        # Replace {prompt} placeholder with the input prompts
        rendered_positive = positive_template.replace("{prompt}", self.positive_prompt)
        rendered_negative = negative_template.replace("{prompt}", self.negative_prompt)

        return PromptTemplateOutput(
            positive_prompt=rendered_positive,
            negative_prompt=rendered_negative,
        )
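
Since the substitution is a plain str.replace, a template renders as in this minimal sketch (the template text and prompt are hypothetical):

    positive_template = "cinematic photo of {prompt}, 35mm film grain"
    rendered = positive_template.replace("{prompt}", "a red fox")
    # -> "cinematic photo of a red fox, 35mm film grain"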
112
invokeai/app/invocations/z_image_control.py
Normal file
@@ -0,0 +1,112 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Z-Image Control invocation for spatial conditioning."""

from pydantic import BaseModel, Field

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    Classification,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    ImageField,
    InputField,
    OutputField,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType


class ZImageControlField(BaseModel):
    """A Z-Image control conditioning field for spatial control (Canny, HED, Depth, Pose, MLSD)."""

    image_name: str = Field(description="The name of the preprocessed control image")
    control_model: ModelIdentifierField = Field(description="The Z-Image ControlNet adapter model")
    control_context_scale: float = Field(
        default=0.75,
        ge=0.0,
        le=2.0,
        description="The strength of the control signal. Recommended range: 0.65-0.80.",
    )
    begin_step_percent: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="When the control is first applied (% of total steps)",
    )
    end_step_percent: float = Field(
        default=1.0,
        ge=0.0,
        le=1.0,
        description="When the control is last applied (% of total steps)",
    )


@invocation_output("z_image_control_output")
class ZImageControlOutput(BaseInvocationOutput):
    """Z-Image Control output containing the control configuration."""

    control: ZImageControlField = OutputField(description="Z-Image control conditioning")


@invocation(
    "z_image_control",
    title="Z-Image ControlNet",
    tags=["image", "z-image", "control", "controlnet"],
    category="control",
    version="1.1.0",
    classification=Classification.Prototype,
)
class ZImageControlInvocation(BaseInvocation):
    """Configure Z-Image ControlNet for spatial conditioning.

    Takes a preprocessed control image (e.g., Canny edges, depth map, pose)
    and a Z-Image ControlNet adapter model to enable spatial control.

    Supports 5 control modes: Canny, HED, Depth, Pose, MLSD.
    Recommended control_context_scale: 0.65-0.80.
    """

    image: ImageField = InputField(
        description="The preprocessed control image (Canny, HED, Depth, Pose, or MLSD)",
    )
    control_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.controlnet_model,
        title="Control Model",
        ui_model_base=BaseModelType.ZImage,
        ui_model_type=ModelType.ControlNet,
    )
    control_context_scale: float = InputField(
        default=0.75,
        ge=0.0,
        le=2.0,
        description="Strength of the control signal. Recommended range: 0.65-0.80.",
        title="Control Scale",
    )
    begin_step_percent: float = InputField(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="When the control is first applied (% of total steps)",
    )
    end_step_percent: float = InputField(
        default=1.0,
        ge=0.0,
        le=1.0,
        description="When the control is last applied (% of total steps)",
    )

    def invoke(self, context: InvocationContext) -> ZImageControlOutput:
        return ZImageControlOutput(
            control=ZImageControlField(
                image_name=self.image.image_name,
                control_model=self.control_model,
                control_context_scale=self.control_context_scale,
                begin_step_percent=self.begin_step_percent,
                end_step_percent=self.end_step_percent,
            )
        )
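
begin_step_percent and end_step_percent bound the window of denoising steps in which the control signal is applied. A minimal sketch of that gating, assuming the boundary handling used by ZImageControlNetExtension.should_apply (the exact rounding there may differ):

    def control_window_active(step: int, total_steps: int, begin: float, end: float) -> bool:
        # Express the current step as a fraction of the run and test it against the window.
        frac = step / max(total_steps - 1, 1)
        return begin <= frac <= end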
770
invokeai/app/invocations/z_image_denoise.py
Normal file
@@ -0,0 +1,770 @@
import inspect
import math
from contextlib import ExitStack
from typing import Callable, Iterator, Optional, Tuple

import einops
import torch
import torchvision.transforms as tv_transforms
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from PIL import Image
from torchvision.transforms.functional import resize as tv_resize
from tqdm import tqdm

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
    DenoiseMaskField,
    FieldDescriptions,
    Input,
    InputField,
    LatentsField,
    ZImageConditioningField,
)
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.invocations.z_image_control import ZImageControlField
from invokeai.app.invocations.z_image_image_to_latents import ZImageImageToLatentsInvocation
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.schedulers import ZIMAGE_SCHEDULER_LABELS, ZIMAGE_SCHEDULER_MAP, ZIMAGE_SCHEDULER_NAME_VALUES
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.z_image_lora_constants import Z_IMAGE_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ZImageConditioningInfo
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.z_image.extensions.regional_prompting_extension import ZImageRegionalPromptingExtension
from invokeai.backend.z_image.text_conditioning import ZImageTextConditioning
from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter
from invokeai.backend.z_image.z_image_controlnet_extension import (
    ZImageControlNetExtension,
    z_image_forward_with_control,
)
from invokeai.backend.z_image.z_image_transformer_patch import patch_transformer_for_regional_prompting


@invocation(
    "z_image_denoise",
    title="Denoise - Z-Image",
    tags=["image", "z-image"],
    category="image",
    version="1.4.0",
    classification=Classification.Prototype,
)
class ZImageDenoiseInvocation(BaseInvocation):
    """Run the denoising process with a Z-Image model.

    Supports regional prompting by connecting multiple conditioning inputs with masks.
    """

    # If latents is provided, this means we are doing image-to-image.
    latents: Optional[LatentsField] = InputField(
        default=None, description=FieldDescriptions.latents, input=Input.Connection
    )
    # denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
    denoise_mask: Optional[DenoiseMaskField] = InputField(
        default=None, description=FieldDescriptions.denoise_mask, input=Input.Connection
    )
    denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
    denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
    add_noise: bool = InputField(default=True, description="Add noise based on denoising start.")
    transformer: TransformerField = InputField(
        description=FieldDescriptions.z_image_model, input=Input.Connection, title="Transformer"
    )
    positive_conditioning: ZImageConditioningField | list[ZImageConditioningField] = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection
    )
    negative_conditioning: ZImageConditioningField | list[ZImageConditioningField] | None = InputField(
        default=None, description=FieldDescriptions.negative_cond, input=Input.Connection
    )
    # Z-Image-Turbo works best without CFG (guidance_scale=1.0)
    guidance_scale: float = InputField(
        default=1.0,
        ge=1.0,
        description="Guidance scale for classifier-free guidance. 1.0 = no CFG (recommended for Z-Image-Turbo). "
        "Values > 1.0 amplify guidance.",
        title="Guidance Scale",
    )
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
    # Z-Image-Turbo uses 8 steps by default
    steps: int = InputField(default=8, gt=0, description="Number of denoising steps. 8 recommended for Z-Image-Turbo.")
    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
    # Z-Image Control support
    control: Optional[ZImageControlField] = InputField(
        default=None,
        description="Z-Image control conditioning for spatial control (Canny, HED, Depth, Pose, MLSD).",
        input=Input.Connection,
    )
    # VAE for encoding control images (required when using control)
    vae: Optional[VAEField] = InputField(
        default=None,
        description=FieldDescriptions.vae + " Required for control conditioning.",
        input=Input.Connection,
    )
    # Scheduler selection for the denoising process
    scheduler: ZIMAGE_SCHEDULER_NAME_VALUES = InputField(
        default="euler",
        description="Scheduler (sampler) for the denoising process. Euler is the default and recommended for "
        "Z-Image-Turbo. Heun is 2nd-order (better quality, 2x slower). LCM is optimized for few steps.",
        ui_choice_labels=ZIMAGE_SCHEDULER_LABELS,
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        latents = self._run_diffusion(context)
        latents = latents.detach().to("cpu")

        name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

    def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
        """Prepare the inpaint mask."""
        if self.denoise_mask is None:
            return None
        mask = context.tensors.load(self.denoise_mask.mask_name)

        # Invert mask: 0.0 = regions to denoise, 1.0 = regions to preserve
        mask = 1.0 - mask

        _, _, latent_height, latent_width = latents.shape
        mask = tv_resize(
            img=mask,
            size=[latent_height, latent_width],
            interpolation=tv_transforms.InterpolationMode.BILINEAR,
            antialias=False,
        )

        mask = mask.to(device=latents.device, dtype=latents.dtype)
        return mask

    def _load_text_conditioning(
        self,
        context: InvocationContext,
        cond_field: ZImageConditioningField | list[ZImageConditioningField],
        img_height: int,
        img_width: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> list[ZImageTextConditioning]:
        """Load Z-Image text conditioning with optional regional masks.

        Args:
            context: The invocation context.
            cond_field: Single conditioning field or list of fields.
            img_height: Height of the image token grid (H // patch_size).
            img_width: Width of the image token grid (W // patch_size).
            dtype: Target dtype.
            device: Target device.

        Returns:
            List of ZImageTextConditioning objects with embeddings and masks.
        """
        # Normalize to a list
        cond_list = [cond_field] if isinstance(cond_field, ZImageConditioningField) else cond_field

        text_conditionings: list[ZImageTextConditioning] = []
        for cond in cond_list:
            # Load the text embeddings
            cond_data = context.conditioning.load(cond.conditioning_name)
            assert len(cond_data.conditionings) == 1
            z_image_conditioning = cond_data.conditionings[0]
            assert isinstance(z_image_conditioning, ZImageConditioningInfo)
            z_image_conditioning = z_image_conditioning.to(dtype=dtype, device=device)
            prompt_embeds = z_image_conditioning.prompt_embeds

            # Load the mask, if provided
            mask: torch.Tensor | None = None
            if cond.mask is not None:
                mask = context.tensors.load(cond.mask.tensor_name)
                mask = mask.to(device=device)
                mask = ZImageRegionalPromptingExtension.preprocess_regional_prompt_mask(
                    mask, img_height, img_width, dtype, device
                )

            text_conditionings.append(ZImageTextConditioning(prompt_embeds=prompt_embeds, mask=mask))

        return text_conditionings

    def _get_noise(
        self,
        batch_size: int,
        num_channels_latents: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: torch.device,
        seed: int,
    ) -> torch.Tensor:
        """Generate the initial noise tensor."""
        # Generate noise as float32 on CPU for maximum compatibility,
        # then cast to target dtype/device
        rand_device = "cpu"
        rand_dtype = torch.float32

        return torch.randn(
            batch_size,
            num_channels_latents,
            int(height) // LATENT_SCALE_FACTOR,
            int(width) // LATENT_SCALE_FACTOR,
            device=rand_device,
            dtype=rand_dtype,
            generator=torch.Generator(device=rand_device).manual_seed(seed),
        ).to(device=device, dtype=dtype)
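
    # Generating on CPU with a seeded torch.Generator keeps results reproducible
    # across CUDA/MPS/CPU backends; a quick sanity sketch (shapes illustrative):
    #
    #   g = torch.Generator(device="cpu").manual_seed(42)
    #   a = torch.randn(1, 16, 128, 128, generator=g)
    #   g = torch.Generator(device="cpu").manual_seed(42)
    #   b = torch.randn(1, 16, 128, 128, generator=g)
    #   assert torch.equal(a, b)  # identical regardless of the eventual target device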

    def _calculate_shift(
        self,
        image_seq_len: int,
        base_image_seq_len: int = 256,
        max_image_seq_len: int = 4096,
        base_shift: float = 0.5,
        max_shift: float = 1.15,
    ) -> float:
        """Calculate the timestep shift based on image sequence length.

        Based on the diffusers ZImagePipeline.calculate_shift method.
        """
        m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
        b = base_shift - m * base_image_seq_len
        mu = image_seq_len * m + b
        return mu
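
    # Worked example using the defaults above: a 1024x1024 image has latent size
    # 128x128 and a 64x64 token grid (patch_size=2), so image_seq_len = 4096.
    #   m  = (1.15 - 0.5) / (4096 - 256) ~= 1.693e-4
    #   b  = 0.5 - m * 256               ~= 0.4567
    #   mu = 4096 * m + b                 = 1.15  (i.e. max_shift at the largest grid)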

    def _get_sigmas(self, mu: float, num_steps: int) -> list[float]:
        """Generate the sigma schedule with time shift.

        Based on FlowMatchEulerDiscreteScheduler with shift.
        Generates num_steps + 1 sigma values (including the terminal 0.0).
        """

        def time_shift(mu: float, sigma: float, t: float) -> float:
            """Apply the time shift to a single timestep value."""
            if t <= 0:
                return 0.0
            if t >= 1:
                return 1.0
            return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

        # Generate linearly spaced t values from 1.0 to 0.0 (the endpoints are
        # handled by the clamps in time_shift), then apply the time shift.
        sigmas = []
        for i in range(num_steps + 1):
            t = 1.0 - i / num_steps  # Goes from 1.0 to 0.0
            sigma = time_shift(mu, 1.0, t)
            sigmas.append(sigma)

        return sigmas
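
    # Worked example: with mu = 1.15 and the exponent fixed at 1.0, the midpoint
    # t = 0.5 maps to sigma = e^1.15 / (e^1.15 + 1) ~= 3.158 / 4.158 ~= 0.76,
    # i.e. the shift keeps the schedule at higher noise levels for longer.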

    def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
        device = TorchDevice.choose_torch_device()
        inference_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)

        transformer_info = context.models.load(self.transformer.transformer)

        # Calculate the image token grid dimensions
        patch_size = 2  # Z-Image uses patch_size=2
        latent_height = self.height // LATENT_SCALE_FACTOR
        latent_width = self.width // LATENT_SCALE_FACTOR
        img_token_height = latent_height // patch_size
        img_token_width = latent_width // patch_size
        img_seq_len = img_token_height * img_token_width

        # Load positive conditioning with regional masks
        pos_text_conditionings = self._load_text_conditioning(
            context=context,
            cond_field=self.positive_conditioning,
            img_height=img_token_height,
            img_width=img_token_width,
            dtype=inference_dtype,
            device=device,
        )

        # Create the regional prompting extension
        regional_extension = ZImageRegionalPromptingExtension.from_text_conditionings(
            text_conditionings=pos_text_conditionings,
            img_seq_len=img_seq_len,
        )

        # Get the concatenated prompt embeddings for the transformer
        pos_prompt_embeds = regional_extension.regional_text_conditioning.prompt_embeds

        # Load negative conditioning if provided and guidance_scale != 1.0.
        # CFG formula: pred = pred_uncond + cfg_scale * (pred_cond - pred_uncond)
        # At cfg_scale=1.0: pred = pred_cond (no effect, so skip the uncond computation).
        # This matches FLUX's convention where 1.0 means "no CFG".
        neg_prompt_embeds: torch.Tensor | None = None
        do_classifier_free_guidance = (
            not math.isclose(self.guidance_scale, 1.0) and self.negative_conditioning is not None
        )
        if do_classifier_free_guidance:
            assert self.negative_conditioning is not None
            # Load all negative conditionings and concatenate the embeddings.
            # Note: We ignore masks for negative conditioning, as regional negative prompting is not fully supported.
            neg_text_conditionings = self._load_text_conditioning(
                context=context,
                cond_field=self.negative_conditioning,
                img_height=img_token_height,
                img_width=img_token_width,
                dtype=inference_dtype,
                device=device,
            )
            # Concatenate all negative embeddings
            neg_prompt_embeds = torch.cat([tc.prompt_embeds for tc in neg_text_conditionings], dim=0)

        # Calculate the shift based on image sequence length
        mu = self._calculate_shift(img_seq_len)

        # Generate the sigma schedule with time shift
        sigmas = self._get_sigmas(mu, self.steps)

        # Apply denoising_start and denoising_end clipping
        if self.denoising_start > 0 or self.denoising_end < 1:
            # Calculate start and end indices based on the denoising range
            total_sigmas = len(sigmas)
            start_idx = int(self.denoising_start * (total_sigmas - 1))
            end_idx = int(self.denoising_end * (total_sigmas - 1)) + 1
            sigmas = sigmas[start_idx:end_idx]
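            # E.g. steps=8 yields 9 sigmas; denoising_start=0.5 and denoising_end=1.0
            # give start_idx = int(0.5 * 8) = 4 and end_idx = 8 + 1 = 9, so the loop
            # runs the last 4 of 8 steps (a typical img2img refinement window).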

        total_steps = len(sigmas) - 1

        # Load input latents if provided (image-to-image)
        init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
        if init_latents is not None:
            init_latents = init_latents.to(device=device, dtype=inference_dtype)

        # Generate the initial noise
        num_channels_latents = 16  # Z-Image uses 16 latent channels
        noise = self._get_noise(
            batch_size=1,
            num_channels_latents=num_channels_latents,
            height=self.height,
            width=self.width,
            dtype=inference_dtype,
            device=device,
            seed=self.seed,
        )

        # Prepare the input latent image
        if init_latents is not None:
            if self.add_noise:
                # Noise the init_latents by the appropriate amount for the first timestep.
                s_0 = sigmas[0]
                latents = s_0 * noise + (1.0 - s_0) * init_latents
            else:
                latents = init_latents
        else:
            if self.denoising_start > 1e-5:
                raise ValueError("denoising_start should be 0 when initial latents are not provided.")
            latents = noise

        # Short-circuit if there are no denoising steps
        if total_steps <= 0:
            return latents

        # Prepare the inpaint extension
        inpaint_mask = self._prep_inpaint_mask(context, latents)
        inpaint_extension: RectifiedFlowInpaintExtension | None = None
        if inpaint_mask is not None:
            if init_latents is None:
                raise ValueError("Initial latents are required when using an inpaint mask (image-to-image inpainting)")
            inpaint_extension = RectifiedFlowInpaintExtension(
                init_latents=init_latents,
                inpaint_mask=inpaint_mask,
                noise=noise,
            )

        step_callback = self._build_step_callback(context)

        # Initialize a diffusers scheduler if not using the built-in Euler loop
        scheduler: SchedulerMixin | None = None
        use_scheduler = self.scheduler != "euler"

        if use_scheduler:
            scheduler_class = ZIMAGE_SCHEDULER_MAP[self.scheduler]
            scheduler = scheduler_class(
                num_train_timesteps=1000,
                shift=1.0,
            )
            # Set timesteps - LCM should use num_inference_steps (it has its own sigma schedule),
            # while other schedulers can use our custom sigmas if supported.
            is_lcm = self.scheduler == "lcm"
            set_timesteps_sig = inspect.signature(scheduler.set_timesteps)
            if not is_lcm and "sigmas" in set_timesteps_sig.parameters:
                # Pass the custom (shifted) sigma schedule to the scheduler
                scheduler.set_timesteps(sigmas=sigmas, device=device)
            else:
                # LCM, or the scheduler doesn't support custom sigmas - use num_inference_steps
                scheduler.set_timesteps(num_inference_steps=total_steps, device=device)

            # For the Heun scheduler, the number of actual steps may differ
            num_scheduler_steps = len(scheduler.timesteps)
        else:
            num_scheduler_steps = total_steps

        with ExitStack() as exit_stack:
            # Get the transformer config to determine whether it is quantized
            transformer_config = context.models.get_config(self.transformer.transformer)

            # Determine if the model is quantized.
            # If the model is quantized, then we need to apply the LoRA weights as sidecar layers. This results in
            # slower inference than direct patching, but is agnostic to the quantization format.
            if transformer_config.format in [ModelFormat.Diffusers, ModelFormat.Checkpoint]:
                model_is_quantized = False
            elif transformer_config.format in [ModelFormat.GGUFQuantized]:
                model_is_quantized = True
            else:
                raise ValueError(f"Unsupported Z-Image model format: {transformer_config.format}")

            # Load the transformer - always use the base transformer; control is handled via an extension
            (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())

            # Prepare the control extension if control is provided
            control_extension: ZImageControlNetExtension | None = None

            if self.control is not None:
                # Load the control adapter using a context manager (proper GPU memory management)
                control_model_info = context.models.load(self.control.control_model)
                (_, control_adapter) = exit_stack.enter_context(control_model_info.model_on_device())
                assert isinstance(control_adapter, ZImageControlAdapter)

                # Get control_in_dim from the adapter config (16 for V1, 33 for V2.0)
                adapter_config = control_adapter.config
                control_in_dim = adapter_config.get("control_in_dim", 16)
                num_control_blocks = adapter_config.get("num_control_blocks", 6)

                # Log the control configuration for debugging
                version = "V2.0" if control_in_dim > 16 else "V1"
                context.util.signal_progress(
                    f"Using Z-Image ControlNet {version} (Extension): control_in_dim={control_in_dim}, "
                    f"num_blocks={num_control_blocks}, scale={self.control.control_context_scale}"
                )

                # Load and prepare the control image - it must be VAE-encoded!
                if self.vae is None:
                    raise ValueError("VAE is required when using Z-Image Control. Connect a VAE to the 'vae' input.")

                control_image = context.images.get_pil(self.control.image_name)

                # Resize the control image to match the output dimensions
                control_image = control_image.convert("RGB")
                control_image = control_image.resize((self.width, self.height), Image.Resampling.LANCZOS)

                # Convert to tensor format for VAE encoding
                from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor

                control_image_tensor = image_resized_to_grid_as_tensor(control_image)
                if control_image_tensor.dim() == 3:
                    control_image_tensor = einops.rearrange(control_image_tensor, "c h w -> 1 c h w")

                # Encode the control image through the VAE to get latents
                vae_info = context.models.load(self.vae.vae)
                control_latents = ZImageImageToLatentsInvocation.vae_encode(
                    vae_info=vae_info,
                    image_tensor=control_image_tensor,
                )

                # Move to the inference device/dtype
                control_latents = control_latents.to(device=device, dtype=inference_dtype)

                # Add a frame dimension: [B, C, H, W] -> [C, 1, H, W] (single image)
                control_latents = control_latents.squeeze(0).unsqueeze(1)

                # Prepare control_cond based on control_in_dim.
                # V1: 16 channels (just the control latents).
                # V2.0: 33 channels = 16 control + 16 reference + 1 mask
                #   - Channels 0-15: control image latents (from VAE encoding)
                #   - Channels 16-31: reference/inpaint image latents (zeros for pure control)
                #   - Channel 32: inpaint mask (1.0 = don't inpaint, 0.0 = inpaint region)
                # For pure control (no inpainting), we set mask=1 to tell the model "use control, don't inpaint".
                c, f, h, w = control_latents.shape
                if c < control_in_dim:
                    padding_channels = control_in_dim - c
                    if padding_channels == 17:
                        # V2.0: 16 reference channels (zeros) + 1 mask channel (ones)
                        ref_padding = torch.zeros(
                            (16, f, h, w),
                            device=device,
                            dtype=inference_dtype,
                        )
                        # Mask channel = 1.0 means "don't inpaint this region, use the control signal"
                        mask_channel = torch.ones(
                            (1, f, h, w),
                            device=device,
                            dtype=inference_dtype,
                        )
                        control_latents = torch.cat([control_latents, ref_padding, mask_channel], dim=0)
                    else:
                        # Generic zero padding for other cases
                        zero_padding = torch.zeros(
                            (padding_channels, f, h, w),
                            device=device,
                            dtype=inference_dtype,
                        )
                        control_latents = torch.cat([control_latents, zero_padding], dim=0)

                # Create the control extension (the adapter is already on device from model_on_device)
                control_extension = ZImageControlNetExtension(
                    control_adapter=control_adapter,
                    control_cond=control_latents,
                    weight=self.control.control_context_scale,
                    begin_step_percent=self.control.begin_step_percent,
                    end_step_percent=self.control.end_step_percent,
                )

            # Apply LoRA models to the transformer.
            # Note: We apply the LoRAs after the transformer has been moved to its target device, for faster patching.
            exit_stack.enter_context(
                LayerPatcher.apply_smart_model_patches(
                    model=transformer,
                    patches=self._lora_iterator(context),
                    prefix=Z_IMAGE_LORA_TRANSFORMER_PREFIX,
                    dtype=inference_dtype,
                    cached_weights=cached_weights,
                    force_sidecar_patching=model_is_quantized,
                )
            )

            # Apply the regional prompting patch if we have regional masks
            exit_stack.enter_context(
                patch_transformer_for_regional_prompting(
                    transformer=transformer,
                    regional_attn_mask=regional_extension.regional_attn_mask,
                    img_seq_len=img_seq_len,
                )
            )

            # Denoising loop - supports both the built-in Euler loop and diffusers schedulers.
            # Track the user-facing step for progress reporting (accounts for Heun's double steps).
            user_step = 0

            if use_scheduler and scheduler is not None:
                # Use a diffusers scheduler for stepping.
                # Use tqdm with total_steps (user-facing steps), not num_scheduler_steps (internal steps).
                # This ensures the progress bar shows 1/8, 2/8, etc. even when the scheduler uses more internal steps.
                pbar = tqdm(total=total_steps, desc="Denoising")
                for step_index in range(num_scheduler_steps):
                    sched_timestep = scheduler.timesteps[step_index]
                    # Convert the scheduler timestep (0-1000) to a normalized sigma (0-1)
                    sigma_curr = sched_timestep.item() / scheduler.config.num_train_timesteps

                    # For the Heun scheduler, track whether we are in a first- or second-order step
                    is_heun = hasattr(scheduler, "state_in_first_order")
                    in_first_order = scheduler.state_in_first_order if is_heun else True

                    # Timestep tensor for the Z-Image model.
                    # The model expects t=0 at the start (noise) and t=1 at the end (clean).
                    model_t = 1.0 - sigma_curr
                    timestep = torch.tensor([model_t], device=device, dtype=inference_dtype).expand(latents.shape[0])

                    # Run the transformer for the positive prediction
                    latent_model_input = latents.to(transformer.dtype)
                    latent_model_input = latent_model_input.unsqueeze(2)  # Add frame dimension
                    latent_model_input_list = list(latent_model_input.unbind(dim=0))

                    # Determine whether control should be applied at this step
                    apply_control = control_extension is not None and control_extension.should_apply(
                        user_step, total_steps
                    )

                    # Run the forward pass
                    if apply_control:
                        model_out_list, _ = z_image_forward_with_control(
                            transformer=transformer,
                            x=latent_model_input_list,
                            t=timestep,
                            cap_feats=[pos_prompt_embeds],
                            control_extension=control_extension,
                        )
                    else:
                        model_output = transformer(
                            x=latent_model_input_list,
                            t=timestep,
                            cap_feats=[pos_prompt_embeds],
                        )
                        model_out_list = model_output[0]

                    noise_pred_cond = torch.stack([t.float() for t in model_out_list], dim=0)
                    noise_pred_cond = noise_pred_cond.squeeze(2)
                    noise_pred_cond = -noise_pred_cond  # Z-Image uses v-prediction with negation

                    # Apply CFG if enabled
                    if do_classifier_free_guidance and neg_prompt_embeds is not None:
                        if apply_control:
                            model_out_list_uncond, _ = z_image_forward_with_control(
                                transformer=transformer,
                                x=latent_model_input_list,
                                t=timestep,
                                cap_feats=[neg_prompt_embeds],
                                control_extension=control_extension,
                            )
                        else:
                            model_output_uncond = transformer(
                                x=latent_model_input_list,
                                t=timestep,
                                cap_feats=[neg_prompt_embeds],
                            )
                            model_out_list_uncond = model_output_uncond[0]

                        noise_pred_uncond = torch.stack([t.float() for t in model_out_list_uncond], dim=0)
                        noise_pred_uncond = noise_pred_uncond.squeeze(2)
                        noise_pred_uncond = -noise_pred_uncond
                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
                    else:
                        noise_pred = noise_pred_cond

                    # Use scheduler.step() for the update
                    step_output = scheduler.step(model_output=noise_pred, timestep=sched_timestep, sample=latents)
                    latents = step_output.prev_sample

                    # Get sigma_prev (the next sigma value) for inpainting
                    if step_index + 1 < len(scheduler.sigmas):
                        sigma_prev = scheduler.sigmas[step_index + 1].item()
                    else:
                        sigma_prev = 0.0

                    if inpaint_extension is not None:
                        latents = inpaint_extension.merge_intermediate_latents_with_init_latents(latents, sigma_prev)

                    # For Heun, only increment the user-facing step after the second-order step completes
                    if is_heun:
                        if not in_first_order:
                            user_step += 1
                            # Only call step_callback if we haven't exceeded total_steps
                            if user_step <= total_steps:
                                pbar.update(1)
                                step_callback(
                                    PipelineIntermediateState(
                                        step=user_step,
                                        order=2,
                                        total_steps=total_steps,
                                        timestep=int(sigma_curr * 1000),
                                        latents=latents,
                                    ),
                                )
                    else:
                        # For LCM and other first-order schedulers
                        user_step += 1
                        # Only call step_callback if we haven't exceeded total_steps
                        # (the LCM scheduler may have more internal steps than user-facing steps)
                        if user_step <= total_steps:
                            pbar.update(1)
                            step_callback(
                                PipelineIntermediateState(
                                    step=user_step,
                                    order=1,
                                    total_steps=total_steps,
                                    timestep=int(sigma_curr * 1000),
                                    latents=latents,
                                ),
                            )
                pbar.close()
            else:
                # Original Euler implementation (default, optimized for Z-Image)
                for step_idx in tqdm(range(total_steps)):
                    sigma_curr = sigmas[step_idx]
                    sigma_prev = sigmas[step_idx + 1]

                    # Timestep tensor for the Z-Image model.
                    # The model expects t=0 at the start (noise) and t=1 at the end (clean).
                    # Sigma goes from 1 (noise) to 0 (clean), so model_t = 1 - sigma.
                    model_t = 1.0 - sigma_curr
                    timestep = torch.tensor([model_t], device=device, dtype=inference_dtype).expand(latents.shape[0])

                    # Run the transformer for the positive prediction.
                    # The Z-Image transformer expects: x as a list of [C, 1, H, W] tensors, t, and cap_feats as a list.
                    # Prepare the latent input: [B, C, H, W] -> [B, C, 1, H, W] -> list of [C, 1, H, W].
                    latent_model_input = latents.to(transformer.dtype)
                    latent_model_input = latent_model_input.unsqueeze(2)  # Add frame dimension
                    latent_model_input_list = list(latent_model_input.unbind(dim=0))

                    # Determine whether control should be applied at this step
                    apply_control = control_extension is not None and control_extension.should_apply(
                        step_idx, total_steps
                    )

                    # Run the forward pass - use the custom forward with control if the extension is active
                    if apply_control:
                        model_out_list, _ = z_image_forward_with_control(
                            transformer=transformer,
                            x=latent_model_input_list,
                            t=timestep,
                            cap_feats=[pos_prompt_embeds],
                            control_extension=control_extension,
                        )
                    else:
                        model_output = transformer(
                            x=latent_model_input_list,
                            t=timestep,
                            cap_feats=[pos_prompt_embeds],
                        )
                        model_out_list = model_output[0]  # Extract the list of tensors from the tuple

                    noise_pred_cond = torch.stack([t.float() for t in model_out_list], dim=0)
                    noise_pred_cond = noise_pred_cond.squeeze(2)  # Remove frame dimension
                    noise_pred_cond = -noise_pred_cond  # Z-Image uses v-prediction with negation

                    # Apply CFG if enabled
                    if do_classifier_free_guidance and neg_prompt_embeds is not None:
                        if apply_control:
                            model_out_list_uncond, _ = z_image_forward_with_control(
                                transformer=transformer,
                                x=latent_model_input_list,
                                t=timestep,
                                cap_feats=[neg_prompt_embeds],
                                control_extension=control_extension,
                            )
                        else:
                            model_output_uncond = transformer(
                                x=latent_model_input_list,
                                t=timestep,
                                cap_feats=[neg_prompt_embeds],
                            )
                            model_out_list_uncond = model_output_uncond[0]  # Extract the list of tensors from the tuple

                        noise_pred_uncond = torch.stack([t.float() for t in model_out_list_uncond], dim=0)
                        noise_pred_uncond = noise_pred_uncond.squeeze(2)
                        noise_pred_uncond = -noise_pred_uncond
                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
                    else:
                        noise_pred = noise_pred_cond

                    # Euler step (computed in float32 for numerical stability)
                    latents_dtype = latents.dtype
                    latents = latents.to(dtype=torch.float32)
                    latents = latents + (sigma_prev - sigma_curr) * noise_pred
                    latents = latents.to(dtype=latents_dtype)
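
                    # Reading of the update above: it is the explicit Euler step for the
                    # rectified-flow ODE dx/dsigma = v(x, sigma), i.e.
                    #   x_next = x + (sigma_prev - sigma_curr) * v.
                    # Since sigma decreases toward 0, the step size is negative and the
                    # sample moves from noise toward the clean image.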

                    if inpaint_extension is not None:
                        latents = inpaint_extension.merge_intermediate_latents_with_init_latents(latents, sigma_prev)

                    step_callback(
                        PipelineIntermediateState(
                            step=step_idx + 1,
                            order=1,
                            total_steps=total_steps,
                            timestep=int(sigma_curr * 1000),
                            latents=latents,
                        ),
                    )

            return latents

    def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
        def step_callback(state: PipelineIntermediateState) -> None:
            context.util.sd_step_callback(state, BaseModelType.ZImage)

        return step_callback

    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
        """Iterate over the LoRA models to apply to the transformer."""
        for lora in self.transformer.loras:
            lora_info = context.models.load(lora.lora)
            if not isinstance(lora_info.model, ModelPatchRaw):
                raise TypeError(
                    f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}. "
                    "The LoRA model may be corrupted or incompatible."
                )
            yield (lora_info.model, lora.weight)
            del lora_info
110
invokeai/app/invocations/z_image_image_to_latents.py
Normal file
@@ -0,0 +1,110 @@
from typing import Union

import einops
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    ImageField,
    Input,
    InputField,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder as FluxAutoEncoder
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux

# Z-Image can use either the Diffusers AutoencoderKL or the FLUX AutoEncoder
ZImageVAE = Union[AutoencoderKL, FluxAutoEncoder]


@invocation(
    "z_image_i2l",
    title="Image to Latents - Z-Image",
    tags=["image", "latents", "vae", "i2l", "z-image"],
    category="image",
    version="1.1.0",
    classification=Classification.Prototype,
)
class ZImageImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generates latents from an image using the Z-Image VAE (supports both Diffusers and FLUX VAEs)."""

    image: ImageField = InputField(description="The image to encode.")
    vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)

    @staticmethod
    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
        if not isinstance(vae_info.model, (AutoencoderKL, FluxAutoEncoder)):
            raise TypeError(
                f"Expected AutoencoderKL or FluxAutoEncoder for Z-Image VAE, got {type(vae_info.model).__name__}. "
                "Ensure you are using a compatible VAE model."
            )

        # Estimate the working memory needed for the VAE encode
        estimated_working_memory = estimate_vae_working_memory_flux(
            operation="encode",
            image_tensor=image_tensor,
            vae=vae_info.model,
        )

        with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
            if not isinstance(vae, (AutoencoderKL, FluxAutoEncoder)):
                raise TypeError(
                    f"Expected AutoencoderKL or FluxAutoEncoder, got {type(vae).__name__}. "
                    "VAE model type changed unexpectedly after loading."
                )

            vae_dtype = next(iter(vae.parameters())).dtype
            image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)

            with torch.inference_mode():
                if isinstance(vae, FluxAutoEncoder):
                    # The FLUX VAE handles scaling internally
                    generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
                    latents = vae.encode(image_tensor, sample=True, generator=generator)
                else:
                    # AutoencoderKL - needs manual scaling
                    vae.disable_tiling()
                    image_tensor_dist = vae.encode(image_tensor).latent_dist
                    latents: torch.Tensor = image_tensor_dist.sample().to(dtype=vae.dtype)

                    # Apply scaling_factor and shift_factor from the VAE config.
                    # Z-Image uses: latents = (latents - shift_factor) * scaling_factor
                    scaling_factor = vae.config.scaling_factor
                    shift_factor = getattr(vae.config, "shift_factor", None)

                    if shift_factor is not None:
                        latents = latents - shift_factor
                    latents = latents * scaling_factor

        return latents
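
    # The encode-side affine transform is the exact inverse of the decode side
    # (see z_image_latents_to_image.py); a round-trip sketch with symbolic
    # scaling_factor/shift_factor values:
    #
    #   z = (x - shift) * scale    # encode
    #   x = z / scale + shift      # decode recovers x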

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        image = context.images.get_pil(self.image.image_name)

        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
        if image_tensor.dim() == 3:
            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

        vae_info = context.models.load(self.vae.vae)
        if not isinstance(vae_info.model, (AutoencoderKL, FluxAutoEncoder)):
            raise TypeError(
                f"Expected AutoencoderKL or FluxAutoEncoder for Z-Image VAE, got {type(vae_info.model).__name__}. "
                "Ensure you are using a compatible VAE model."
            )

        context.util.signal_progress("Running VAE")
        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)

        latents = latents.to("cpu")
        name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
111
invokeai/app/invocations/z_image_latents_to_image.py
Normal file
@@ -0,0 +1,111 @@
from contextlib import nullcontext
from typing import Union

import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from einops import rearrange
from PIL import Image

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    Input,
    InputField,
    LatentsField,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder as FluxAutoEncoder
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux

# Z-Image can use either the Diffusers AutoencoderKL or the FLUX AutoEncoder
ZImageVAE = Union[AutoencoderKL, FluxAutoEncoder]


@invocation(
    "z_image_l2i",
    title="Latents to Image - Z-Image",
    tags=["latents", "image", "vae", "l2i", "z-image"],
    category="latents",
    version="1.1.0",
    classification=Classification.Prototype,
)
class ZImageLatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generates an image from latents using the Z-Image VAE (supports both Diffusers and FLUX VAEs)."""

    latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
    vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        latents = context.tensors.load(self.latents.latents_name)

        vae_info = context.models.load(self.vae.vae)
        if not isinstance(vae_info.model, (AutoencoderKL, FluxAutoEncoder)):
            raise TypeError(
                f"Expected AutoencoderKL or FluxAutoEncoder for Z-Image VAE, got {type(vae_info.model).__name__}. "
                "Ensure you are using a compatible VAE model."
            )

        is_flux_vae = isinstance(vae_info.model, FluxAutoEncoder)

        # Estimate the working memory needed for the VAE decode
        estimated_working_memory = estimate_vae_working_memory_flux(
            operation="decode",
            image_tensor=latents,
            vae=vae_info.model,
        )

        # The FLUX VAE doesn't support seamless, so only apply it for AutoencoderKL
        seamless_context = (
            nullcontext() if is_flux_vae else SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes)
        )

        with seamless_context, vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
            context.util.signal_progress("Running VAE")
            if not isinstance(vae, (AutoencoderKL, FluxAutoEncoder)):
                raise TypeError(
                    f"Expected AutoencoderKL or FluxAutoEncoder, got {type(vae).__name__}. "
                    "VAE model type changed unexpectedly after loading."
                )

            vae_dtype = next(iter(vae.parameters())).dtype
            latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)

            # Disable tiling for AutoencoderKL
            if isinstance(vae, AutoencoderKL):
                vae.disable_tiling()

            # Clear memory, as the VAE decode can request a lot
            TorchDevice.empty_cache()

            with torch.inference_mode():
                if isinstance(vae, FluxAutoEncoder):
                    # The FLUX VAE handles scaling internally
                    img = vae.decode(latents)
                else:
                    # AutoencoderKL - apply scaling_factor and shift_factor from the VAE config.
                    # Z-Image uses: latents = latents / scaling_factor + shift_factor
                    scaling_factor = vae.config.scaling_factor
                    shift_factor = getattr(vae.config, "shift_factor", None)

                    latents = latents / scaling_factor
                    if shift_factor is not None:
                        latents = latents + shift_factor

                    img = vae.decode(latents, return_dict=False)[0]

            img = img.clamp(-1, 1)
            img = rearrange(img[0], "c h w -> h w c")
            img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())

        TorchDevice.empty_cache()

        image_dto = context.images.save(image=img_pil)

        return ImageOutput.build(image_dto)
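
The final conversion maps the decoder's [-1, 1] output linearly onto 8-bit pixels via 127.5 * (x + 1). A quick check of that mapping (values illustrative):

    import torch

    x = torch.tensor([-1.0, 0.0, 1.0])
    print((127.5 * (x + 1.0)).byte().tolist())  # [0, 127, 255] - 127.5 truncates to 127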
153
invokeai/app/invocations/z_image_lora_loader.py
Normal file
@@ -0,0 +1,153 @@
from typing import Optional

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, Qwen3EncoderField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType


@invocation_output("z_image_lora_loader_output")
class ZImageLoRALoaderOutput(BaseInvocationOutput):
    """Z-Image LoRA Loader Output"""

    transformer: Optional[TransformerField] = OutputField(
        default=None, description=FieldDescriptions.transformer, title="Z-Image Transformer"
    )
    qwen3_encoder: Optional[Qwen3EncoderField] = OutputField(
        default=None, description=FieldDescriptions.qwen3_encoder, title="Qwen3 Encoder"
    )


@invocation(
    "z_image_lora_loader",
    title="Apply LoRA - Z-Image",
    tags=["lora", "model", "z-image"],
    category="model",
    version="1.0.0",
)
class ZImageLoRALoaderInvocation(BaseInvocation):
    """Apply a LoRA model to a Z-Image transformer and/or Qwen3 text encoder."""

    lora: ModelIdentifierField = InputField(
        description=FieldDescriptions.lora_model,
        title="LoRA",
        ui_model_base=BaseModelType.ZImage,
        ui_model_type=ModelType.LoRA,
    )
    weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
    transformer: TransformerField | None = InputField(
        default=None,
        description=FieldDescriptions.transformer,
        input=Input.Connection,
        title="Z-Image Transformer",
    )
    qwen3_encoder: Qwen3EncoderField | None = InputField(
        default=None,
        title="Qwen3 Encoder",
        description=FieldDescriptions.qwen3_encoder,
        input=Input.Connection,
    )

    def invoke(self, context: InvocationContext) -> ZImageLoRALoaderOutput:
        lora_key = self.lora.key

        if not context.models.exists(lora_key):
            raise ValueError(f"Unknown LoRA: {lora_key}!")

        # Check for existing LoRAs with the same key.
        if self.transformer and any(lora.lora.key == lora_key for lora in self.transformer.loras):
            raise ValueError(f'LoRA "{lora_key}" already applied to transformer.')
        if self.qwen3_encoder and any(lora.lora.key == lora_key for lora in self.qwen3_encoder.loras):
            raise ValueError(f'LoRA "{lora_key}" already applied to Qwen3 encoder.')

        output = ZImageLoRALoaderOutput()

        # Attach the LoRA layers to the models.
        if self.transformer is not None:
            output.transformer = self.transformer.model_copy(deep=True)
            output.transformer.loras.append(
                LoRAField(
                    lora=self.lora,
                    weight=self.weight,
                )
            )
        if self.qwen3_encoder is not None:
            output.qwen3_encoder = self.qwen3_encoder.model_copy(deep=True)
            output.qwen3_encoder.loras.append(
                LoRAField(
                    lora=self.lora,
                    weight=self.weight,
                )
            )

        return output
|
||||
|
||||
|
||||
@invocation(
|
||||
"z_image_lora_collection_loader",
|
||||
title="Apply LoRA Collection - Z-Image",
|
||||
tags=["lora", "model", "z-image"],
|
||||
category="model",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ZImageLoRACollectionLoader(BaseInvocation):
|
||||
"""Applies a collection of LoRAs to a Z-Image transformer."""
|
||||
|
||||
loras: Optional[LoRAField | list[LoRAField]] = InputField(
|
||||
default=None, description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
|
||||
)
|
||||
|
||||
transformer: Optional[TransformerField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.transformer,
|
||||
input=Input.Connection,
|
||||
title="Transformer",
|
||||
)
|
||||
qwen3_encoder: Qwen3EncoderField | None = InputField(
|
||||
default=None,
|
||||
title="Qwen3 Encoder",
|
||||
description=FieldDescriptions.qwen3_encoder,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ZImageLoRALoaderOutput:
|
||||
output = ZImageLoRALoaderOutput()
|
||||
loras = self.loras if isinstance(self.loras, list) else [self.loras]
|
||||
added_loras: list[str] = []
|
||||
|
||||
if self.transformer is not None:
|
||||
output.transformer = self.transformer.model_copy(deep=True)
|
||||
|
||||
if self.qwen3_encoder is not None:
|
||||
output.qwen3_encoder = self.qwen3_encoder.model_copy(deep=True)
|
||||
|
||||
for lora in loras:
|
||||
if lora is None:
|
||||
continue
|
||||
if lora.lora.key in added_loras:
|
||||
continue
|
||||
|
||||
if not context.models.exists(lora.lora.key):
|
||||
raise Exception(f"Unknown lora: {lora.lora.key}!")
|
||||
|
||||
if lora.lora.base is not BaseModelType.ZImage:
|
||||
raise ValueError(
|
||||
f"LoRA '{lora.lora.key}' is for {lora.lora.base.value if lora.lora.base else 'unknown'} models, "
|
||||
"not Z-Image models. Ensure you are using a Z-Image compatible LoRA."
|
||||
)
|
||||
|
||||
added_loras.append(lora.lora.key)
|
||||
|
||||
if self.transformer is not None and output.transformer is not None:
|
||||
output.transformer.loras.append(lora)
|
||||
|
||||
if self.qwen3_encoder is not None and output.qwen3_encoder is not None:
|
||||
output.qwen3_encoder.loras.append(lora)
|
||||
|
||||
return output
|
||||
invokeai/app/invocations/z_image_model_loader.py (new file, 135 lines)
@@ -0,0 +1,135 @@
from typing import Optional

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    Classification,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.model import (
    ModelIdentifierField,
    Qwen3EncoderField,
    TransformerField,
    VAEField,
)
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType, SubModelType


@invocation_output("z_image_model_loader_output")
class ZImageModelLoaderOutput(BaseInvocationOutput):
    """Z-Image base model loader output."""

    transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
    qwen3_encoder: Qwen3EncoderField = OutputField(description=FieldDescriptions.qwen3_encoder, title="Qwen3 Encoder")
    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")


@invocation(
    "z_image_model_loader",
    title="Main Model - Z-Image",
    tags=["model", "z-image"],
    category="model",
    version="3.0.0",
    classification=Classification.Prototype,
)
class ZImageModelLoaderInvocation(BaseInvocation):
    """Loads a Z-Image model, outputting its submodels.

    Similar to FLUX, you can mix and match components:
    - Transformer: From Z-Image main model (GGUF quantized or Diffusers format)
    - VAE: Separate FLUX VAE (shared with FLUX models) or from a Diffusers Z-Image model
    - Qwen3 Encoder: Separate Qwen3Encoder model or from a Diffusers Z-Image model
    """

    model: ModelIdentifierField = InputField(
        description=FieldDescriptions.z_image_model,
        input=Input.Direct,
        ui_model_base=BaseModelType.ZImage,
        ui_model_type=ModelType.Main,
        title="Transformer",
    )

    vae_model: Optional[ModelIdentifierField] = InputField(
        default=None,
        description="Standalone VAE model. Z-Image uses the same VAE as FLUX (16-channel). "
        "If not provided, the VAE will be loaded from the Qwen3 Source model.",
        input=Input.Direct,
        ui_model_base=BaseModelType.Flux,
        ui_model_type=ModelType.VAE,
        title="VAE",
    )

    qwen3_encoder_model: Optional[ModelIdentifierField] = InputField(
        default=None,
        description="Standalone Qwen3 Encoder model. "
        "If not provided, the encoder will be loaded from the Qwen3 Source model.",
        input=Input.Direct,
        ui_model_type=ModelType.Qwen3Encoder,
        title="Qwen3 Encoder",
    )

    qwen3_source_model: Optional[ModelIdentifierField] = InputField(
        default=None,
        description="Diffusers Z-Image model to extract the VAE and/or Qwen3 encoder from. "
        "Use this if you don't have separate VAE/Qwen3 models. "
        "Ignored if both the VAE and the Qwen3 Encoder are provided separately.",
        input=Input.Direct,
        ui_model_base=BaseModelType.ZImage,
        ui_model_type=ModelType.Main,
        ui_model_format=ModelFormat.Diffusers,
        title="Qwen3 Source (Diffusers)",
    )

    def invoke(self, context: InvocationContext) -> ZImageModelLoaderOutput:
        # The transformer always comes from the main model
        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})

        # Determine the VAE source
        if self.vae_model is not None:
            # Use the standalone FLUX VAE
            vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
        elif self.qwen3_source_model is not None:
            # Extract from the Diffusers Z-Image model
            self._validate_diffusers_format(context, self.qwen3_source_model, "Qwen3 Source")
            vae = self.qwen3_source_model.model_copy(update={"submodel_type": SubModelType.VAE})
        else:
            raise ValueError(
                "No VAE source provided. Either set 'VAE' to a FLUX VAE model, "
                "or set 'Qwen3 Source' to a Diffusers Z-Image model."
            )

        # Determine the Qwen3 Encoder source
        if self.qwen3_encoder_model is not None:
            # Use the standalone Qwen3 Encoder
            qwen3_tokenizer = self.qwen3_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
            qwen3_encoder = self.qwen3_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
        elif self.qwen3_source_model is not None:
            # Extract from the Diffusers Z-Image model
            self._validate_diffusers_format(context, self.qwen3_source_model, "Qwen3 Source")
            qwen3_tokenizer = self.qwen3_source_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
            qwen3_encoder = self.qwen3_source_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
        else:
            raise ValueError(
                "No Qwen3 Encoder source provided. Either set 'Qwen3 Encoder' to a standalone model, "
                "or set 'Qwen3 Source' to a Diffusers Z-Image model."
            )

        return ZImageModelLoaderOutput(
            transformer=TransformerField(transformer=transformer, loras=[]),
            qwen3_encoder=Qwen3EncoderField(tokenizer=qwen3_tokenizer, text_encoder=qwen3_encoder),
            vae=VAEField(vae=vae),
        )

    def _validate_diffusers_format(
        self, context: InvocationContext, model: ModelIdentifierField, model_name: str
    ) -> None:
        """Validate that a model is in Diffusers format."""
        config = context.models.get_config(model)
        if config.format != ModelFormat.Diffusers:
            raise ValueError(
                f"The {model_name} model must be a Diffusers-format Z-Image model. "
                f"The selected model '{config.name}' is in {config.format.value} format."
            )
invokeai/app/invocations/z_image_seed_variance_enhancer.py (new file, 110 lines)
@@ -0,0 +1,110 @@
import torch

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    Input,
    InputField,
    ZImageConditioningField,
)
from invokeai.app.invocations.primitives import ZImageConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    ConditioningFieldData,
    ZImageConditioningInfo,
)


@invocation(
    "z_image_seed_variance_enhancer",
    title="Seed Variance Enhancer - Z-Image",
    tags=["conditioning", "z-image", "variance", "seed"],
    category="conditioning",
    version="1.0.0",
    classification=Classification.Prototype,
)
class ZImageSeedVarianceEnhancerInvocation(BaseInvocation):
    """Adds seed-based noise to Z-Image conditioning to increase variance between seeds.

    Z-Image-Turbo can produce relatively similar images with different seeds,
    making it harder to explore variations of a prompt. This node implements
    reproducible, seed-based noise injection into text embeddings to increase
    visual variation while maintaining reproducibility.

    The noise strength is auto-calibrated relative to the embedding's standard
    deviation, ensuring consistent results across different prompts.
    """

    conditioning: ZImageConditioningField = InputField(
        description=FieldDescriptions.cond,
        input=Input.Connection,
        title="Conditioning",
    )
    seed: int = InputField(
        default=0,
        ge=0,
        description="Seed for reproducible noise generation. Different seeds produce different noise patterns.",
    )
    strength: float = InputField(
        default=0.1,
        ge=0.0,
        le=2.0,
        description="Noise strength as multiplier of embedding std. 0=off, 0.1=subtle, 0.5=strong.",
    )
    randomize_percent: float = InputField(
        default=50.0,
        ge=1.0,
        le=100.0,
        description="Percentage of embedding values to add noise to (1-100). Lower values create more selective noise patterns.",
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ZImageConditioningOutput:
        # Load conditioning data
        cond_data = context.conditioning.load(self.conditioning.conditioning_name)
        assert len(cond_data.conditionings) == 1, "Expected exactly one conditioning tensor"
        z_image_conditioning = cond_data.conditionings[0]
        assert isinstance(z_image_conditioning, ZImageConditioningInfo), "Expected ZImageConditioningInfo"

        # Early return if strength is zero (no modification needed)
        if self.strength == 0:
            return ZImageConditioningOutput(conditioning=self.conditioning)

        # Clone embeddings to avoid modifying the original
        prompt_embeds = z_image_conditioning.prompt_embeds.clone()

        # Calculate the actual noise strength based on embedding statistics.
        # This auto-calibration ensures consistent results across different prompts.
        embed_std = torch.std(prompt_embeds).item()
        actual_strength = self.strength * embed_std

        # Generate deterministic noise using the seed
        generator = torch.Generator(device=prompt_embeds.device)
        generator.manual_seed(self.seed)
        noise = torch.rand(
            prompt_embeds.shape, generator=generator, device=prompt_embeds.device, dtype=prompt_embeds.dtype
        )
        noise = noise * 2 - 1  # Scale to [-1, 1)
        noise = noise * actual_strength

        # Create a selective mask for noise application
        generator.manual_seed(self.seed + 1)
        noise_mask = torch.bernoulli(
            torch.ones_like(prompt_embeds) * (self.randomize_percent / 100.0),
            generator=generator,
        ).bool()

        # Apply noise only to masked positions
        prompt_embeds = prompt_embeds + (noise * noise_mask)

        # Save the modified conditioning
        new_conditioning = ZImageConditioningInfo(prompt_embeds=prompt_embeds)
        conditioning_data = ConditioningFieldData(conditionings=[new_conditioning])
        conditioning_name = context.conditioning.save(conditioning_data)

        return ZImageConditioningOutput(
            conditioning=ZImageConditioningField(
                conditioning_name=conditioning_name,
                mask=self.conditioning.mask,
            )
        )
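The noise-injection scheme above can be shown standalone. This is a minimal sketch of the same idea (std-calibrated uniform noise gated by a Bernoulli mask), not the node itself; the tensor shape and parameter values are arbitrary:

```python
import torch

seed, strength, randomize_percent = 42, 0.1, 50.0
prompt_embeds = torch.randn(77, 2560)  # dummy [seq_len, hidden_dim] embeddings

# Auto-calibrate: the noise amplitude is a fraction of the embedding std.
actual_strength = strength * torch.std(prompt_embeds).item()

gen = torch.Generator(device=prompt_embeds.device)
gen.manual_seed(seed)
noise = torch.rand(prompt_embeds.shape, generator=gen) * 2 - 1  # uniform [-1, 1)
noise = noise * actual_strength

# A Bernoulli mask selects ~randomize_percent% of positions to perturb.
gen.manual_seed(seed + 1)
mask = torch.bernoulli(torch.full_like(prompt_embeds, randomize_percent / 100.0), generator=gen).bool()

perturbed = prompt_embeds + noise * mask
print(f"changed {mask.float().mean().item():.0%} of values, max delta {noise.abs().max().item():.4f}")
```

Re-running with the same seed reproduces the same perturbation exactly, which is the point of seeding both the noise and the mask.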
invokeai/app/invocations/z_image_text_encoder.py (new file, 197 lines)
@@ -0,0 +1,197 @@
from contextlib import ExitStack
from typing import Iterator, Optional, Tuple

import torch
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    Input,
    InputField,
    TensorField,
    UIComponent,
    ZImageConditioningField,
)
from invokeai.app.invocations.model import Qwen3EncoderField
from invokeai.app.invocations.primitives import ZImageConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.z_image_lora_constants import Z_IMAGE_LORA_QWEN3_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    ConditioningFieldData,
    ZImageConditioningInfo,
)
from invokeai.backend.util.devices import TorchDevice

# Z-Image max sequence length based on the diffusers default
Z_IMAGE_MAX_SEQ_LEN = 512


@invocation(
    "z_image_text_encoder",
    title="Prompt - Z-Image",
    tags=["prompt", "conditioning", "z-image"],
    category="conditioning",
    version="1.1.0",
    classification=Classification.Prototype,
)
class ZImageTextEncoderInvocation(BaseInvocation):
    """Encodes and preps a prompt for a Z-Image image.

    Supports regional prompting by connecting a mask input.
    """

    prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
    qwen3_encoder: Qwen3EncoderField = InputField(
        title="Qwen3 Encoder",
        description=FieldDescriptions.qwen3_encoder,
        input=Input.Connection,
    )
    mask: Optional[TensorField] = InputField(
        default=None,
        description="A mask defining the region that this conditioning prompt applies to.",
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ZImageConditioningOutput:
        prompt_embeds = self._encode_prompt(context, max_seq_len=Z_IMAGE_MAX_SEQ_LEN)
        conditioning_data = ConditioningFieldData(conditionings=[ZImageConditioningInfo(prompt_embeds=prompt_embeds)])
        conditioning_name = context.conditioning.save(conditioning_data)
        return ZImageConditioningOutput(
            conditioning=ZImageConditioningField(conditioning_name=conditioning_name, mask=self.mask)
        )

    def _encode_prompt(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
        """Encode the prompt using the Qwen3 text encoder.

        Based on the ZImagePipeline._encode_prompt method from diffusers.
        """
        prompt = self.prompt
        device = TorchDevice.choose_torch_device()

        text_encoder_info = context.models.load(self.qwen3_encoder.text_encoder)
        tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer)

        with ExitStack() as exit_stack:
            (_, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())
            (_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())

            # Apply LoRA models to the text encoder
            lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
            exit_stack.enter_context(
                LayerPatcher.apply_smart_model_patches(
                    model=text_encoder,
                    patches=self._lora_iterator(context),
                    prefix=Z_IMAGE_LORA_QWEN3_PREFIX,
                    dtype=lora_dtype,
                )
            )

            context.util.signal_progress("Running Qwen3 text encoder")
            if not isinstance(text_encoder, PreTrainedModel):
                raise TypeError(
                    f"Expected PreTrainedModel for text encoder, got {type(text_encoder).__name__}. "
                    "The Qwen3 encoder model may be corrupted or incompatible."
                )
            if not isinstance(tokenizer, PreTrainedTokenizerBase):
                raise TypeError(
                    f"Expected PreTrainedTokenizerBase for tokenizer, got {type(tokenizer).__name__}. "
                    "The Qwen3 tokenizer may be corrupted or incompatible."
                )

            # Apply a chat template, similar to the diffusers ZImagePipeline.
            # The chat template formats the prompt for the Qwen3 model.
            try:
                prompt_formatted = tokenizer.apply_chat_template(
                    [{"role": "user", "content": prompt}],
                    tokenize=False,
                    add_generation_prompt=True,
                    enable_thinking=True,
                )
            except (AttributeError, TypeError) as e:
                # Fallback if the tokenizer doesn't support apply_chat_template or enable_thinking
                context.logger.warning(f"Chat template failed ({e}), using raw prompt.")
                prompt_formatted = prompt

            # Tokenize the formatted prompt
            text_inputs = tokenizer(
                prompt_formatted,
                padding="max_length",
                max_length=max_seq_len,
                truncation=True,
                return_attention_mask=True,
                return_tensors="pt",
            )

            text_input_ids = text_inputs.input_ids
            attention_mask = text_inputs.attention_mask
            if not isinstance(text_input_ids, torch.Tensor):
                raise TypeError(
                    f"Expected torch.Tensor for input_ids, got {type(text_input_ids).__name__}. "
                    "Tokenizer returned unexpected type."
                )
            if not isinstance(attention_mask, torch.Tensor):
                raise TypeError(
                    f"Expected torch.Tensor for attention_mask, got {type(attention_mask).__name__}. "
                    "Tokenizer returned unexpected type."
                )

            # Check for truncation
            untruncated_ids = tokenizer(prompt_formatted, padding="longest", return_tensors="pt").input_ids
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = tokenizer.batch_decode(untruncated_ids[:, max_seq_len - 1 : -1])
                context.logger.warning(
                    f"The following part of your input was truncated because `max_sequence_length` is set to "
                    f"{max_seq_len} tokens: {removed_text}"
                )

            # Get hidden states from the text encoder.
            # Use the second-to-last hidden state, as diffusers does.
            prompt_mask = attention_mask.to(device).bool()
            outputs = text_encoder(
                text_input_ids.to(device),
                attention_mask=prompt_mask,
                output_hidden_states=True,
            )

            # Validate the hidden_states output
            if not hasattr(outputs, "hidden_states") or outputs.hidden_states is None:
                raise RuntimeError(
                    "Text encoder did not return hidden_states. "
                    "Ensure output_hidden_states=True is supported by this model."
                )
            if len(outputs.hidden_states) < 2:
                raise RuntimeError(
                    f"Expected at least 2 hidden states from text encoder, got {len(outputs.hidden_states)}. "
                    "This may indicate an incompatible model or configuration."
                )
            prompt_embeds = outputs.hidden_states[-2]

            # Z-Image expects a 2D tensor [seq_len, hidden_dim] with only valid tokens.
            # Based on the diffusers ZImagePipeline implementation:
            #   embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
            # Since batch_size=1, we take the first item and filter by mask.
            prompt_embeds = prompt_embeds[0][prompt_mask[0]]

            if not isinstance(prompt_embeds, torch.Tensor):
                raise TypeError(
                    f"Expected torch.Tensor for prompt embeddings, got {type(prompt_embeds).__name__}. "
                    "Text encoder returned unexpected type."
                )
            return prompt_embeds

    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
        """Iterate over LoRA models to apply to the Qwen3 text encoder."""
        for lora in self.qwen3_encoder.loras:
            lora_info = context.models.load(lora.lora)
            if not isinstance(lora_info.model, ModelPatchRaw):
                raise TypeError(
                    f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}. "
                    "The LoRA model may be corrupted or incompatible."
                )
            yield (lora_info.model, lora.weight)
            del lora_info
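The final embedding extraction step above (take the second-to-last hidden state, then keep only valid tokens) can be demonstrated with dummy tensors. This is a sketch of the indexing logic only; the shapes are illustrative, not the real Qwen3 dimensions:

```python
import torch

batch, seq_len, hidden = 1, 8, 16
# Stand-ins for the encoder's hidden_states tuple and the tokenizer's attention mask.
hidden_states = [torch.randn(batch, seq_len, hidden) for _ in range(4)]
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0]])  # 5 valid tokens, 3 padding

prompt_embeds = hidden_states[-2]                  # second-to-last hidden state
prompt_mask = attention_mask.bool()
prompt_embeds = prompt_embeds[0][prompt_mask[0]]   # [seq_len, hidden] -> valid rows only

assert prompt_embeds.shape == (5, hidden)  # only the 5 valid tokens remain
```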
@@ -85,6 +85,7 @@ class InvokeAIAppConfig(BaseSettings):
         max_cache_ram_gb: The maximum amount of CPU RAM to use for model caching in GB. If unset, the limit will be configured based on the available RAM. In most cases, it is recommended to leave this unset.
         max_cache_vram_gb: The amount of VRAM to use for model caching in GB. If unset, the limit will be configured based on the available VRAM and the device_working_mem_gb. In most cases, it is recommended to leave this unset.
         log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
+        model_cache_keep_alive_min: How long to keep models in cache after last use, in minutes. A value of 0 (the default) means models are kept in cache indefinitely. If no model generations occur within the timeout period, the model cache is cleared using the same logic as the 'Clear Model Cache' button.
         device_working_mem_gb: The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.
         enable_partial_loading: Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as it's used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.
         keep_ram_copy_of_weights: Whether to keep a full RAM copy of a model's weights when the model is loaded in VRAM. Keeping a RAM copy increases average RAM usage, but speeds up model switching and LoRA patching (assuming there is sufficient RAM). Set this to False if RAM pressure is consistently high.
@@ -165,9 +166,10 @@ class InvokeAIAppConfig(BaseSettings):
     max_cache_ram_gb: Optional[float] = Field(default=None, gt=0, description="The maximum amount of CPU RAM to use for model caching in GB. If unset, the limit will be configured based on the available RAM. In most cases, it is recommended to leave this unset.")
     max_cache_vram_gb: Optional[float] = Field(default=None, ge=0, description="The amount of VRAM to use for model caching in GB. If unset, the limit will be configured based on the available VRAM and the device_working_mem_gb. In most cases, it is recommended to leave this unset.")
     log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")
+    model_cache_keep_alive_min: float = Field(default=0, ge=0, description="How long to keep models in cache after last use, in minutes. A value of 0 (the default) means models are kept in cache indefinitely. If no model generations occur within the timeout period, the model cache is cleared using the same logic as the 'Clear Model Cache' button.")
     device_working_mem_gb: float = Field(default=3, description="The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.")
     enable_partial_loading: bool = Field(default=False, description="Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as it's used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.")
     keep_ram_copy_of_weights: bool = Field(default=True, description="Whether to keep a full RAM copy of a model's weights when the model is loaded in VRAM. Keeping a RAM copy increases average RAM usage, but speeds up model switching and LoRA patching (assuming there is sufficient RAM). Set this to False if RAM pressure is consistently high.")

     # Deprecated CACHE configs
     ram: Optional[float] = Field(default=None, gt=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")
     vram: Optional[float] = Field(default=None, ge=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")
@@ -14,7 +14,7 @@ class NodeExecutionStatsSummary:
     node_type: str
     num_calls: int
     time_used_seconds: float
-    peak_vram_gb: float
+    delta_vram_gb: float


 @dataclass
@@ -58,10 +58,10 @@ class InvocationStatsSummary:
     def __str__(self) -> str:
         _str = ""
         _str = f"Graph stats: {self.graph_stats.graph_execution_state_id}\n"
-        _str += f"{'Node':>30} {'Calls':>7} {'Seconds':>9} {'VRAM Used':>10}\n"
+        _str += f"{'Node':>30} {'Calls':>7} {'Seconds':>9} {'VRAM Change':+>10}\n"

         for summary in self.node_stats:
-            _str += f"{summary.node_type:>30} {summary.num_calls:>7} {summary.time_used_seconds:>8.3f}s {summary.peak_vram_gb:>9.3f}G\n"
+            _str += f"{summary.node_type:>30} {summary.num_calls:>7} {summary.time_used_seconds:>8.3f}s {summary.delta_vram_gb:+10.3f}G\n"

         _str += f"TOTAL GRAPH EXECUTION TIME: {self.graph_stats.execution_time_seconds:7.3f}s\n"

@@ -100,7 +100,7 @@ class NodeExecutionStats:
     start_ram_gb: float  # GB
     end_ram_gb: float  # GB

     peak_vram_gb: float  # GB
+    delta_vram_gb: float  # GB

     def total_time(self) -> float:
         return self.end_time - self.start_time
@@ -174,9 +174,9 @@ class GraphExecutionStats:
         for node_type, node_type_stats_list in node_stats_by_type.items():
             num_calls = len(node_type_stats_list)
             time_used = sum([n.total_time() for n in node_type_stats_list])
-            peak_vram = max([n.peak_vram_gb for n in node_type_stats_list])
+            delta_vram = max([n.delta_vram_gb for n in node_type_stats_list])
             summary = NodeExecutionStatsSummary(
-                node_type=node_type, num_calls=num_calls, time_used_seconds=time_used, peak_vram_gb=peak_vram
+                node_type=node_type, num_calls=num_calls, time_used_seconds=time_used, delta_vram_gb=delta_vram
             )
             summaries.append(summary)

@@ -52,8 +52,9 @@ class InvocationStatsService(InvocationStatsServiceBase):
         # Record state before the invocation.
         start_time = time.time()
         start_ram = psutil.Process().memory_info().rss
         if torch.cuda.is_available():
             torch.cuda.reset_peak_memory_stats()

+        # Remember current VRAM usage
+        vram_in_use = torch.cuda.memory_allocated() if torch.cuda.is_available() else 0.0
+
         assert services.model_manager.load is not None
         services.model_manager.load.ram_cache.stats = self._cache_stats[graph_execution_state_id]
@@ -62,14 +63,16 @@ class InvocationStatsService(InvocationStatsServiceBase):
             # Let the invocation run.
             yield None
         finally:
             # Record state after the invocation.
+            # Record delta VRAM
+            delta_vram_gb = ((torch.cuda.memory_allocated() - vram_in_use) / GB) if torch.cuda.is_available() else 0.0
+
             node_stats = NodeExecutionStats(
                 invocation_type=invocation.get_type(),
                 start_time=start_time,
                 end_time=time.time(),
                 start_ram_gb=start_ram / GB,
                 end_ram_gb=psutil.Process().memory_info().rss / GB,
                 peak_vram_gb=torch.cuda.max_memory_allocated() / GB if torch.cuda.is_available() else 0.0,
+                delta_vram_gb=delta_vram_gb,
             )
             self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)

@@ -81,6 +84,8 @@ class InvocationStatsService(InvocationStatsServiceBase):
         graph_stats_summary = self._get_graph_summary(graph_execution_state_id)
         node_stats_summaries = self._get_node_summaries(graph_execution_state_id)
         model_cache_stats_summary = self._get_model_cache_summary(graph_execution_state_id)
+        # Note: We use memory_allocated() here (not memory_reserved()) because we want to show
+        # the current actively-used VRAM, not the total reserved memory including PyTorch's cache.
         vram_usage_gb = torch.cuda.memory_allocated() / GB if torch.cuda.is_available() else None

         return InvocationStatsSummary(
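The delta-VRAM bookkeeping added above boils down to sampling `torch.cuda.memory_allocated()` before and after a step. A minimal sketch of that measurement pattern (guarded so it also runs on CPU-only machines; the `step` workload here is made up):

```python
import torch

GB = 2**30

def measure_delta_vram(fn):
    """Run fn() and report how much allocated VRAM it left behind, in GB."""
    before = torch.cuda.memory_allocated() if torch.cuda.is_available() else 0
    fn()
    after = torch.cuda.memory_allocated() if torch.cuda.is_available() else 0
    return (after - before) / GB

held = []  # keeps the allocation alive past the call, so the delta is visible

def step():
    if torch.cuda.is_available():
        held.append(torch.empty(256, 1024, 1024, device="cuda"))  # ~1 GiB of float32

print(f"delta VRAM: {measure_delta_vram(step):+.3f} GB")
```

Unlike `max_memory_allocated()` (the peak), the delta captures what a node *retained*, which is why the summary table now prints a signed "VRAM Change" column.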
@@ -85,9 +85,12 @@ class LocalModelSource(StringLikeSource):

 class HFModelSource(StringLikeSource):
     """
-    A HuggingFace repo_id with optional variant, sub-folder and access token.
+    A HuggingFace repo_id with optional variant, sub-folder(s) and access token.
     Note that the variant option, if not provided to the constructor, will default to fp16, which is
     what people (almost) always want.
+
+    The subfolder can be a single path or multiple paths joined by '+' (e.g., "text_encoder+tokenizer").
+    When multiple subfolders are specified, all of them will be downloaded and combined into the model directory.
     """

     repo_id: str
@@ -103,6 +106,16 @@ class HFModelSource(StringLikeSource):
             raise ValueError(f"{v}: invalid repo_id format")
         return v

+    @property
+    def subfolders(self) -> list[Path]:
+        """Return list of subfolders (supports '+' separated multiple subfolders)."""
+        if self.subfolder is None:
+            return []
+        subfolder_str = self.subfolder.as_posix()
+        if "+" in subfolder_str:
+            return [Path(s.strip()) for s in subfolder_str.split("+")]
+        return [self.subfolder]
+
     def __str__(self) -> str:
         """Return string version of repoid when string rep needed."""
         base: str = self.repo_id
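For example, the new `subfolders` property splits a '+'-joined subfolder path into separate `Path` objects. A quick sketch of the same parsing, standalone and without the Pydantic model:

```python
from pathlib import Path
from typing import Optional

def parse_subfolders(subfolder: Optional[Path]) -> list[Path]:
    """Mirror of HFModelSource.subfolders: split on '+' into separate paths."""
    if subfolder is None:
        return []
    s = subfolder.as_posix()
    if "+" in s:
        return [Path(part.strip()) for part in s.split("+")]
    return [subfolder]

assert parse_subfolders(None) == []
assert parse_subfolders(Path("vae")) == [Path("vae")]
assert parse_subfolders(Path("text_encoder+tokenizer")) == [Path("text_encoder"), Path("tokenizer")]
```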
@@ -1,8 +1,10 @@
 """Model installation class."""

+import gc
 import locale
 import os
 import re
+import sys
 import threading
 import time
 from copy import deepcopy
@@ -135,6 +137,8 @@ class ModelInstallService(ModelInstallServiceBase):
         for model in self._scan_for_missing_models():
             self._logger.warning(f"Missing model file: {model.name} at {model.path}")

+        self._write_invoke_managed_models_dir_readme()
+
     def stop(self, invoker: Optional[Invoker] = None) -> None:
         """Stop the installer thread; after this the object can be deleted and garbage collected."""
         if not self._running:
@@ -147,6 +151,14 @@ class ModelInstallService(ModelInstallServiceBase):
         self._install_thread.join()
         self._running = False

+    def _write_invoke_managed_models_dir_readme(self) -> None:
+        """Write a README file to the Invoke-managed models directory warning users to not fiddle with it."""
+        readme_path = self.app_config.models_path / "README.txt"
+        with open(readme_path, "wt", encoding=locale.getpreferredencoding()) as f:
+            f.write(
+                "This directory is managed by Invoke. Do not add, delete or move files in this directory.\n\nTo manage models, use the web interface.\n"
+            )
+
     def _clear_pending_jobs(self) -> None:
         for job in self.list_jobs():
             if not job.in_terminal_state:
@@ -177,6 +189,22 @@ class ModelInstallService(ModelInstallServiceBase):
             config.source_type = ModelSourceType.Path
         return self._register(model_path, config)

+    # TODO: Replace this with a proper fix for the underlying problem of Windows holding open
+    # the file when it needs to be moved.
+    @staticmethod
+    def _move_with_retries(src: Path, dst: Path, attempts: int = 5, delay: float = 0.5) -> None:
+        """Workaround for Windows file-handle issues when moving files."""
+        for tries_left in range(attempts, 0, -1):
+            try:
+                move(src, dst)
+                return
+            except PermissionError:
+                gc.collect()
+                if tries_left == 1:
+                    raise
+                time.sleep(delay)
+                delay *= 2  # Exponential backoff
+
     def install_path(
         self,
         model_path: Union[Path, str],
@@ -195,7 +223,7 @@ class ModelInstallService(ModelInstallServiceBase):
             dest_dir.mkdir(parents=True)
         dest_path = dest_dir / model_path.name if model_path.is_file() else dest_dir
         if model_path.is_file():
-            move(model_path, dest_path)
+            self._move_with_retries(model_path, dest_path)  # Windows workaround TODO: fix root cause
         elif model_path.is_dir():
             # Move the contents of the directory, not the directory itself
             for item in model_path.iterdir():
@@ -407,10 +435,15 @@ class ModelInstallService(ModelInstallServiceBase):
         model_path.mkdir(parents=True, exist_ok=True)
         model_source = self._guess_source(str(source))
         remote_files, _ = self._remote_files_from_source(model_source)
+        # Handle multiple subfolders for HFModelSource
+        subfolders = model_source.subfolders if isinstance(model_source, HFModelSource) else []
         job = self._multifile_download(
             dest=model_path,
             remote_files=remote_files,
-            subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None,
+            subfolder=model_source.subfolder
+            if isinstance(model_source, HFModelSource) and len(subfolders) <= 1
+            else None,
+            subfolders=subfolders if len(subfolders) > 1 else None,
         )
         files_string = "file" if len(remote_files) == 1 else "files"
         self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})")
@@ -428,10 +461,13 @@ class ModelInstallService(ModelInstallServiceBase):
         if isinstance(source, HFModelSource):
             metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
             assert isinstance(metadata, ModelMetadataWithFiles)
+            # Use the subfolders property, which handles '+' separated multiple subfolders
+            subfolders = source.subfolders
             return (
                 metadata.download_urls(
                     variant=source.variant or self._guess_variant(),
-                    subfolder=source.subfolder,
+                    subfolder=source.subfolder if len(subfolders) <= 1 else None,
+                    subfolders=subfolders if len(subfolders) > 1 else None,
                     session=self._session,
                 ),
                 metadata,
@@ -482,6 +518,39 @@ class ModelInstallService(ModelInstallServiceBase):
         self._install_thread.start()
         self._running = True

+    @staticmethod
+    def _safe_rmtree(path: Path, logger: Any) -> None:
+        """Remove a directory tree with retry logic for Windows file locking issues.
+
+        On Windows, memory-mapped files may not be immediately released even after
+        the file handle is closed. This function retries the removal with garbage
+        collection to help release any lingering references.
+        """
+        max_retries = 3
+        retry_delay = 0.5  # seconds
+
+        for attempt in range(max_retries):
+            try:
+                # Force garbage collection to release any lingering file references
+                gc.collect()
+                rmtree(path)
+                return
+            except PermissionError as e:
+                if attempt < max_retries - 1 and sys.platform == "win32":
+                    logger.warning(
+                        f"Failed to remove {path} (attempt {attempt + 1}/{max_retries}): {e}. "
+                        f"Retrying in {retry_delay}s..."
+                    )
+                    time.sleep(retry_delay)
+                    retry_delay *= 2  # Exponential backoff
+                else:
+                    logger.error(f"Failed to remove temporary directory {path}: {e}")
+                    # On final failure, don't raise - the temp dir will be cleaned up on next startup
+                    return
+            except Exception as e:
+                logger.error(f"Unexpected error removing {path}: {e}")
+                return
+
     def _install_next_item(self) -> None:
         self._logger.debug(f"Installer thread {threading.get_ident()} starting")
         while True:
@@ -511,7 +580,7 @@ class ModelInstallService(ModelInstallServiceBase):
             finally:
                 # if this is an install of a remote file, then clean up the temporary directory
                 if job._install_tmpdir is not None:
-                    rmtree(job._install_tmpdir)
+                    self._safe_rmtree(job._install_tmpdir, self._logger)
                 self._install_completed_event.set()
             self._install_queue.task_done()
         self._logger.info(f"Installer thread {threading.get_ident()} exiting")
@@ -556,7 +625,7 @@ class ModelInstallService(ModelInstallServiceBase):
         path = self._app_config.models_path
         for tmpdir in path.glob(f"{TMPDIR_PREFIX}*"):
             self._logger.info(f"Removing dangling temporary directory {tmpdir}")
-            rmtree(tmpdir)
+            self._safe_rmtree(tmpdir, self._logger)

     def _scan_for_missing_models(self) -> list[AnyModelConfig]:
         """Scan the models directory for missing models and return a list of them."""
@@ -731,10 +800,13 @@ class ModelInstallService(ModelInstallServiceBase):
         install_job._install_tmpdir = destdir
         install_job.total_bytes = sum((x.size or 0) for x in remote_files)

+        # Handle multiple subfolders for HFModelSource
+        subfolders = source.subfolders if isinstance(source, HFModelSource) else []
         multifile_job = self._multifile_download(
             remote_files=remote_files,
             dest=destdir,
-            subfolder=source.subfolder if isinstance(source, HFModelSource) else None,
+            subfolder=source.subfolder if isinstance(source, HFModelSource) and len(subfolders) <= 1 else None,
+            subfolders=subfolders if len(subfolders) > 1 else None,
             access_token=source.access_token,
             submit_job=False,  # Important! Don't submit the job until we have set our _download_cache dict
         )
@@ -761,6 +833,7 @@ class ModelInstallService(ModelInstallServiceBase):
         remote_files: List[RemoteModelFile],
         dest: Path,
         subfolder: Optional[Path] = None,
+        subfolders: Optional[List[Path]] = None,
         access_token: Optional[str] = None,
         submit_job: bool = True,
     ) -> MultiFileDownloadJob:
@@ -768,24 +841,61 @@ class ModelInstallService(ModelInstallServiceBase):
         # we are installing the "vae" subfolder, we do not want to create an additional folder level, such
         # as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo".
         # So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
-        if subfolder:
+        #
+        # For multiple subfolders (e.g., text_encoder+tokenizer), we create a combined folder name
+        # (e.g., sdxl-turbo_text_encoder_tokenizer) and keep each subfolder's contents in its own
+        # subdirectory within the model folder.
+
+        if subfolders and len(subfolders) > 1:
+            # Multiple subfolders: create combined name and keep subfolder structure
+            top = Path(remote_files[0].path.parts[0])  # e.g. "Z-Image-Turbo/"
+            subfolder_names = [sf.name.replace("/", "_").replace("\\", "_") for sf in subfolders]
+            combined_name = "_".join(subfolder_names)
+            path_to_add = Path(f"{top}_{combined_name}")
+
+            parts: List[RemoteModelFile] = []
+            for model_file in remote_files:
+                assert model_file.size is not None
+                # Determine which subfolder this file belongs to
+                file_path = model_file.path
+                new_path: Optional[Path] = None
+                for sf in subfolders:
+                    try:
+                        # Try to get the relative path from this subfolder
+                        relative = file_path.relative_to(top / sf)
+                        # Keep the subfolder name as a subdirectory
+                        new_path = path_to_add / sf.name / relative
+                        break
+                    except ValueError:
+                        continue
+
+                if new_path is None:
+                    # File doesn't match any subfolder, keep original path structure
+                    new_path = path_to_add / file_path.relative_to(top)
+
+                parts.append(RemoteModelFile(url=model_file.url, path=new_path))
+        elif subfolder:
+            # Single subfolder: flatten into renamed folder
             top = Path(remote_files[0].path.parts[0])  # e.g. "sdxl-turbo/"
             path_to_remove = top / subfolder  # sdxl-turbo/vae/
             subfolder_rename = subfolder.name.replace("/", "_").replace("\\", "_")
             path_to_add = Path(f"{top}_{subfolder_rename}")
+
+            parts = []
+            for model_file in remote_files:
+                assert model_file.size is not None
+                parts.append(
+                    RemoteModelFile(
+                        url=model_file.url,  # if a subfolder, then sdxl-turbo_vae/config.json
+                        path=path_to_add / model_file.path.relative_to(path_to_remove),
+                    )
+                )
         else:
-            path_to_remove = Path(".")
-            path_to_add = Path(".")
-
-        parts: List[RemoteModelFile] = []
-        for model_file in remote_files:
-            assert model_file.size is not None
-            parts.append(
-                RemoteModelFile(
-                    url=model_file.url,
-                    path=path_to_add / model_file.path.relative_to(path_to_remove),
-                )
-            )
+            # No subfolder specified - pass through unchanged
+            parts = []
+            for model_file in remote_files:
+                assert model_file.size is not None
+                parts.append(RemoteModelFile(url=model_file.url, path=model_file.path))

         return self._download_queue.multifile_download(
             parts=parts,
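To make the multi-subfolder remapping above concrete, here is a small worked example of where files land when the subfolders are `text_encoder` and `tokenizer`. The repo and file names are made up for illustration:

```python
from pathlib import Path
from typing import Optional

top = Path("Z-Image-Turbo")  # hypothetical repo folder
subfolders = [Path("text_encoder"), Path("tokenizer")]
path_to_add = Path(f"{top}_" + "_".join(sf.name for sf in subfolders))

files = [
    Path("Z-Image-Turbo/text_encoder/config.json"),
    Path("Z-Image-Turbo/tokenizer/tokenizer.json"),
    Path("Z-Image-Turbo/model_index.json"),  # matches no subfolder
]

for file_path in files:
    new_path: Optional[Path] = None
    for sf in subfolders:
        try:
            # relative_to raises ValueError when the file is not under this subfolder
            new_path = path_to_add / sf.name / file_path.relative_to(top / sf)
            break
        except ValueError:
            continue
    if new_path is None:
        new_path = path_to_add / file_path.relative_to(top)
    print(f"{file_path} -> {new_path}")

# Z-Image-Turbo/text_encoder/config.json -> Z-Image-Turbo_text_encoder_tokenizer/text_encoder/config.json
# Z-Image-Turbo/tokenizer/tokenizer.json -> Z-Image-Turbo_text_encoder_tokenizer/tokenizer/tokenizer.json
# Z-Image-Turbo/model_index.json        -> Z-Image-Turbo_text_encoder_tokenizer/model_index.json
```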
@@ -60,6 +60,10 @@ class ModelManagerService(ModelManagerServiceBase):
             service.start(invoker)

     def stop(self, invoker: Invoker) -> None:
+        # Shutdown the model cache to cancel any pending timers
+        if hasattr(self._load, "ram_cache"):
+            self._load.ram_cache.shutdown()
+
         for service in [self._store, self._install, self._load]:
             if hasattr(service, "stop"):
                 service.stop(invoker)
@@ -88,7 +92,10 @@ class ModelManagerService(ModelManagerServiceBase):
             max_ram_cache_size_gb=app_config.max_cache_ram_gb,
             max_vram_cache_size_gb=app_config.max_cache_vram_gb,
             execution_device=execution_device or TorchDevice.choose_torch_device(),
             storage_device="cpu",
             log_memory_usage=app_config.log_memory_usage,
             logger=logger,
+            keep_alive_minutes=app_config.model_cache_keep_alive_min,
         )
         loader = ModelLoadService(
             app_config=app_config,
@@ -19,11 +19,13 @@ from invokeai.backend.model_manager.configs.main import MainModelDefaultSettings
 from invokeai.backend.model_manager.taxonomy import (
     BaseModelType,
     ClipVariantType,
+    Flux2VariantType,
     FluxVariantType,
     ModelFormat,
     ModelSourceType,
     ModelType,
     ModelVariantType,
+    Qwen3VariantType,
     SchedulerPredictionType,
 )
@@ -89,8 +91,8 @@ class ModelRecordChanges(BaseModelExcludeNull):

     # Checkpoint-specific changes
     # TODO(MM2): Should we expose these? Feels footgun-y...
-    variant: Optional[ModelVariantType | ClipVariantType | FluxVariantType] = Field(
-        description="The variant of the model.", default=None
+    variant: Optional[ModelVariantType | ClipVariantType | FluxVariantType | Flux2VariantType | Qwen3VariantType] = (
+        Field(description="The variant of the model.", default=None)
     )
     prediction_type: Optional[SchedulerPredictionType] = Field(
         description="The prediction type of the model.", default=None
@@ -138,6 +140,18 @@ class ModelRecordServiceBase(ABC):
         """
         pass

+    @abstractmethod
+    def replace_model(self, key: str, new_config: AnyModelConfig) -> AnyModelConfig:
+        """
+        Replace the model record entirely, returning the new record.
+
+        This is used when we re-identify a model and have a new config object.
+
+        :param key: Unique key for the model to be updated.
+        :param new_config: The new model config to write.
+        """
+        pass
+
     @abstractmethod
     def get_model(self, key: str) -> AnyModelConfig:
         """
@@ -179,6 +179,23 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):

         return self.get_model(key)

+    def replace_model(self, key: str, new_config: AnyModelConfig) -> AnyModelConfig:
+        if key != new_config.key:
+            raise ValueError("key does not match new_config.key")
+        with self._db.transaction() as cursor:
+            cursor.execute(
+                """--sql
+                UPDATE models
+                SET
+                    config=?
+                WHERE id=?;
+                """,
+                (new_config.model_dump_json(), key),
+            )
+            if cursor.rowcount == 0:
+                raise UnknownModelException("model not found")
+        return self.get_model(key)
+
     def get_model(self, key: str) -> AnyModelConfig:
         """
         Retrieve the ModelConfigBase instance for the indicated model.
invokeai/app/services/shared/README.md (new file, 194 lines)
@@ -0,0 +1,194 @@
# InvokeAI Graph - Design Overview

High-level design for the graph module. Focuses on responsibilities, data flow, and how traversal works.

## 1) Purpose

Provide a typed, acyclic workflow model (**Graph**) plus a runtime scheduler (**GraphExecutionState**) that expands
iterator patterns, tracks readiness via indegree (the number of incoming edges to a node in the directed graph), and
executes nodes in class-grouped batches. Source graphs remain immutable during a run; runtime expansion happens in a
separate execution graph.

## 2) Major Data Types

### EdgeConnection

* Fields: `node_id: str`, `field: str`.
* Hashable; printed as `node.field` for readable diagnostics.

### Edge

* Fields: `source: EdgeConnection`, `destination: EdgeConnection`.
* One directed connection from a specific output port to a specific input port (see the sketch at the end of this section).

### AnyInvocation / AnyInvocationOutput

* Pydantic wrappers that carry concrete invocation models and outputs.
* No registry logic in this file; they are permissive containers for heterogeneous nodes.

### IterateInvocation / CollectInvocation

* Control nodes used by validation and execution:

  * **IterateInvocation**: input `collection`, outputs include `item` (and index/total).
  * **CollectInvocation**: many `item` inputs aggregated to one `collection` output.
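A minimal sketch of the two edge types and the `node.field` printing convention. The real classes are Pydantic models; frozen dataclasses stand in here, and the `image` field names are illustrative:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class EdgeConnection:
    node_id: str
    field: str

    def __str__(self) -> str:
        return f"{self.node_id}.{self.field}"

@dataclass(frozen=True)  # frozen -> hashable, usable as dict keys / set members
class Edge:
    source: EdgeConnection
    destination: EdgeConnection

# One directed connection: output port "image" of node "a" feeds input "image" of node "b".
edge = Edge(EdgeConnection("a", "image"), EdgeConnection("b", "image"))
print(edge.source, "->", edge.destination)  # a.image -> b.image
```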
## 3) Graph (author-time model)
|
||||
|
||||
A container for declared nodes and edges. Does **not** perform iteration expansion.
|
||||
|
||||
### 3.1 Data
|
||||
|
||||
* `nodes: dict[str, AnyInvocation]` - key must equal `node.id`.
|
||||
* `edges: list[Edge]` - zero or more.
|
||||
* Utility: `_get_input_edges(node_id, field?)`, `_get_output_edges(node_id, field?)`
|
||||
These scan `self.edges` (no adjacency indices in the current code).
|
||||
|
||||
### 3.2 Validation (`validate_self`)
|
||||
|
||||
Runs a sequence of checks:
|
||||
|
||||
1. **Node ID uniqueness**
|
||||
No duplicate IDs; map key equals `node.id`.
|
||||
2. **Endpoint existence**
|
||||
Source and destination node IDs must exist.
|
||||
3. **Port existence**
|
||||
Input ports must exist on the node class; output ports on the node's output model.
|
||||
4. **Type compatibility**
|
||||
`get_output_field_type` vs `get_input_field_type` and `are_connection_types_compatible`.
|
||||
5. **DAG constraint**
|
||||
Build a *flat* `DiGraph` (no runtime expansion) and assert acyclicity.
|
||||
6. **Iterator / collector structure**
|
||||
Enforce special rules:
|
||||
|
||||
* Iterator's input must be `collection`; its outgoing edges use `item`.
|
||||
* Collector accepts many `item` inputs; outputs a single `collection`.
|
||||
* Edge fan-in to a non-collector input is rejected.
|
||||
|
||||
### 3.3 Edge admission (`_validate_edge`)
|
||||
|
||||
Checks a single prospective edge before insertion:
|
||||
|
||||
* Endpoints/ports exist.
|
||||
* Destination port is not already occupied unless it's a collector `item`.
|
||||
* Adding the edge to the flat DAG must keep it acyclic.
|
||||
* Iterator/collector constraints re-checked when the edge creates relevant patterns.
|
||||
|
||||
### 3.4 Topology utilities
|
||||
|
||||
* `nx_graph()` - DiGraph of declared nodes and edges.
|
||||
* `nx_graph_with_data()` - includes node/edge attributes.
|
||||
* `nx_graph_flat()` - "flattened" DAG (still author-time; no runtime copies).
|
||||
Used in validation and in `_prepare()` during execution planning.
|
||||
|
||||
### 3.5 Mutation helpers
|
||||
|
||||
* `add_node`, `update_node` (preserve edges, rewrite endpoints if id changes), `delete_node`.
|
||||
* `add_edge`, `delete_edge` (with validation).
|
||||
|
||||
## 4) GraphExecutionState (runtime)
|
||||
|
||||
Holds the state for a single run. Keeps the source graph intact; materializes a separate execution graph.
|
||||
|
||||
### 4.1 Data
|
||||
|
||||
* `graph: Graph` - immutable source during a run.
|
||||
* `execution_graph: Graph` - materialized runtime nodes/edges.
|
||||
* `executed: set[str]`, `executed_history: list[str]`.
|
||||
* `results: dict[str, AnyInvocationOutput]`, `errors: dict[str, str]`.
|
||||
* `prepared_source_mapping: dict[str, str]` - exec id → source id.
|
||||
* `source_prepared_mapping: dict[str, set[str]]` - source id → exec ids.
|
||||
* `indegree: dict[str, int]` - unmet inputs per exec node.
|
||||
* **Ready queues grouped by class** (private attrs):
|
||||
`_ready_queues: dict[class_name, deque[str]]`, `_active_class: Optional[str]`. Optional `ready_order: list[str]` to
|
||||
prioritize classes.
|
||||
|
||||
### 4.2 Core methods
|
||||
|
||||
* `next()`
|
||||
Returns the next ready exec node. If none, calls `_prepare()` to materialize more, then retries. Before returning a
|
||||
node, `_prepare_inputs()` deep-copies inbound values into the node fields.
|
||||
* `complete(node_id, output)`
|
||||
Record result; mark exec node executed; if all exec copies of the same **source** are done, mark the source executed.
|
||||
For each outgoing exec edge, decrement child indegree and enqueue when it reaches zero.
|
||||
|
||||
### 4.3 Preparation (`_prepare()`)
|
||||
|
||||
* Build a flat DAG from the **source** graph.
|
||||
* Choose the **next source node** in topological order that:
|
||||
|
||||
1. has not been prepared,
|
||||
2. if it is an iterator, *its inputs are already executed*,
|
||||
3. it has *no unexecuted iterator ancestors*.
|
||||
* If the node is a **CollectInvocation**: collapse all prepared parents into one mapping and create **one** exec node.
|
||||
* Otherwise: compute all combinations of prepared iterator ancestors. For each combination, pick the matching prepared parent per upstream and create **one** exec node.
|
||||
* For each new exec node:
|
||||
|
||||
* Deep-copy the source node; assign a fresh ID (and `index` for iterators).
|
||||
* Wire edges from chosen prepared parents.
|
||||
* Set `indegree = number of unmet inputs` (i.e., parents not yet executed).
|
||||
* If `indegree == 0`, enqueue into its class queue.
|
||||
|
||||
### 4.4 Readiness and batching
|
||||
|
||||
* `_enqueue_if_ready(nid)` enqueues by class name only when `indegree == 0` and not executed.
|
||||
* `_get_next_node()` drains the `_active_class` queue FIFO; when empty, selects the next nonempty class queue (by `ready_order` if set, else alphabetical), and continues. Optional fairness knobs can limit batch size per class; default is drain fully.

#### 4.4.1 Indegree (what it is and how it's used)

**Indegree** is the number of incoming edges to a node in the execution graph that are still unmet. In this engine:

* For every materialized exec node, `indegree[node]` equals the count of its prerequisite parents that have **not** finished yet.
* A node is "ready" exactly when `indegree[node] == 0`; only then is it enqueued.
* When a node completes, the scheduler decrements `indegree[child]` for each outgoing edge. Any child that reaches 0 is enqueued.

Example: edges `A→C`, `B→C`, `C→D`. Start: `A:0, B:0, C:2, D:1`. Run `A` → `C:1`. Run `B` → `C:0` → enqueue `C`. Run `C` → `D:0` → enqueue `D`. Run `D` → done.
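The same worked example, expressed as a minimal Kahn-style loop (a standalone sketch, not the engine's code):

```python
from collections import deque

edges = {"A": ["C"], "B": ["C"], "C": ["D"], "D": []}
indegree = {"A": 0, "B": 0, "C": 2, "D": 1}

ready = deque(n for n, d in indegree.items() if d == 0)  # A and B start ready
order = []
while ready:
    node = ready.popleft()
    order.append(node)                # "execute" the node
    for child in edges[node]:
        indegree[child] -= 1          # one more input satisfied
        if indegree[child] == 0:
            ready.append(child)       # child becomes ready exactly once

assert order == ["A", "B", "C", "D"]
```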

### 4.5 Input hydration (`_prepare_inputs()`)

* For **CollectInvocation**: gather all incoming `item` values into `collection`.
* For all others: deep-copy each incoming edge's value into the destination field.

This prevents cross-node mutation through shared references.
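A minimal sketch of the hydration rule (assuming `results` holds each parent's output object and edges expose `source`/`destination` fields; the class-name check stands in for the real type test):

```python
import copy

def prepare_inputs(node, input_edges, results) -> None:
    """Sketch: copy upstream outputs into this node's fields before execution."""
    if node.__class__.__name__ == "CollectInvocation":
        node.collection = [
            copy.deepcopy(getattr(results[e.source.node_id], e.source.field))
            for e in input_edges
            if e.destination.field == "item"
        ]
    else:
        for e in input_edges:
            value = getattr(results[e.source.node_id], e.source.field)
            setattr(node, e.destination.field, copy.deepcopy(value))  # no shared references
```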

## 5) Traversal Summary

1. Author builds a valid **Graph**.
2. Create a **GraphExecutionState** with that graph.
3. Loop:

   * `node = state.next()` → may trigger `_prepare()` expansion.
   * Execute the node externally → `output`.
   * `state.complete(node.id, output)` → updates indegrees and queues.

4. Finish when `next()` returns `None`.

The source graph is never mutated; all expansion occurs in `execution_graph` with traceability back to source nodes.
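That loop, as a driver-side sketch (`run_invocation` is a stand-in for whatever actually executes a node):

```python
def run_graph(state, run_invocation) -> None:
    """Sketch of the external driver loop around GraphExecutionState."""
    while True:
        node = state.next()                  # may internally call _prepare()
        if node is None:
            break                            # nothing left to prepare or execute
        output = run_invocation(node)        # execute outside the engine
        state.complete(node.id, output)      # record result, update indegrees/queues
```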

## 6) Invariants

* The source **Graph** remains a DAG and type-consistent.
* `execution_graph` remains a DAG.
* Nodes are enqueued only when `indegree == 0`.
* `results` and `errors` are keyed by **exec node id**.
* Collectors only aggregate `item` inputs; other inputs behave one-to-one.

## 7) Extensibility

* **New node types**: implement as Pydantic models with typed fields and outputs (a sketch follows this list). Register them per your invocation system; this file accepts them as `AnyInvocation`.
* **Scheduling policy**: adjust `ready_order` to batch by class; add a batch cap for fairness without changing complexity.
* **Dynamic behaviors** (future): can be added in `GraphExecutionState` by creating exec nodes and edges at `complete()` time, as long as the DAG invariant holds.
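For instance, a new node type might look like the following sketch (field and registration conventions vary by invocation system; the names here are illustrative, not the project's exact API):

```python
from pydantic import BaseModel, Field

class UpcaseOutput(BaseModel):
    """Typed output: downstream edges connect to `text`."""
    text: str = Field(description="The upper-cased text")

class UpcaseInvocation(BaseModel):
    """Typed node: inbound edges hydrate `text` before invoke() runs."""
    id: str = Field(description="Unique node id")
    text: str = Field(default="", description="Text to transform")

    def invoke(self) -> UpcaseOutput:
        return UpcaseOutput(text=self.text.upper())
```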

## 8) Error Model (selected)

* `DuplicateNodeIdError`, `NodeAlreadyInGraphError`
* `NodeNotFoundError`, `NodeFieldNotFoundError`
* `InvalidEdgeError`, `CyclicalGraphError`
* `NodeInputError` (raised when preparing inputs for execution)

Messages favor short, precise diagnostics (node id, field, and failing condition).

## 9) Rationale

* **Two-graph approach** isolates authoring from execution expansion and keeps validation simple.
* **Indegree + queues** gives O(1) scheduling decisions with clear batching semantics.
* **Iterator/collector separation** keeps fan-out/fan-in explicit and testable.
* **Deep-copy hydration** avoids incidental aliasing bugs between nodes.

@@ -2,7 +2,8 @@

import copy
import itertools
from typing import Any, Optional, TypeVar, Union, get_args, get_origin
from collections import deque
from typing import Any, Deque, Iterable, Optional, Type, TypeVar, Union, get_args, get_origin

import networkx as nx
from pydantic import (
@@ -10,6 +11,7 @@ from pydantic import (
    ConfigDict,
    GetCoreSchemaHandler,
    GetJsonSchemaHandler,
    PrivateAttr,
    ValidationError,
    field_validator,
)
@@ -33,6 +35,10 @@ from invokeai.app.util.misc import uuid_string
# in 3.10 this would be "from types import NoneType"
NoneType = type(None)

# Port name constants
ITEM_FIELD = "item"
COLLECTION_FIELD = "collection"


class EdgeConnection(BaseModel):
    node_id: str = Field(description="The id of the node for this edge connection")
@@ -395,7 +401,7 @@ class Graph(BaseModel):

        try:
            self.edges.remove(edge)
        except KeyError:
        except ValueError:
            pass

    def validate_self(self) -> None:
@@ -414,7 +420,8 @@ class Graph(BaseModel):

        # Validate that all node ids are unique
        node_ids = [n.id for n in self.nodes.values()]
        duplicate_node_ids = {node_id for node_id in node_ids if node_ids.count(node_id) >= 2}
        seen = set()
        duplicate_node_ids = {nid for nid in node_ids if (nid in seen) or seen.add(nid)}
        if duplicate_node_ids:
            raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")

@@ -529,19 +536,19 @@ class Graph(BaseModel):
            raise InvalidEdgeError(f"Field types are incompatible ({edge})")

        # Validate if iterator output type matches iterator input type (if this edge results in both being set)
        if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
        if isinstance(to_node, IterateInvocation) and edge.destination.field == COLLECTION_FIELD:
            err = self._is_iterator_connection_valid(edge.destination.node_id, new_input=edge.source)
            if err is not None:
                raise InvalidEdgeError(f"Iterator input type does not match iterator output type ({edge}): {err}")

        # Validate if iterator input type matches output type (if this edge results in both being set)
        if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
        if isinstance(from_node, IterateInvocation) and edge.source.field == ITEM_FIELD:
            err = self._is_iterator_connection_valid(edge.source.node_id, new_output=edge.destination)
            if err is not None:
                raise InvalidEdgeError(f"Iterator output type does not match iterator input type ({edge}): {err}")

        # Validate if collector input type matches output type (if this edge results in both being set)
        if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
        if isinstance(to_node, CollectInvocation) and edge.destination.field == ITEM_FIELD:
            err = self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source)
            if err is not None:
                raise InvalidEdgeError(f"Collector output type does not match collector input type ({edge}): {err}")
@@ -549,7 +556,7 @@ class Graph(BaseModel):
        # Validate if collector output type matches input type (if this edge results in both being set) - skip if the destination field is not Any or list[Any]
        if (
            isinstance(from_node, CollectInvocation)
            and edge.source.field == "collection"
            and edge.source.field == COLLECTION_FIELD
            and not self._is_destination_field_list_of_Any(edge)
            and not self._is_destination_field_Any(edge)
        ):
@@ -639,8 +646,8 @@ class Graph(BaseModel):
        new_input: Optional[EdgeConnection] = None,
        new_output: Optional[EdgeConnection] = None,
    ) -> str | None:
        inputs = [e.source for e in self._get_input_edges(node_id, "collection")]
        outputs = [e.destination for e in self._get_output_edges(node_id, "item")]
        inputs = [e.source for e in self._get_input_edges(node_id, COLLECTION_FIELD)]
        outputs = [e.destination for e in self._get_output_edges(node_id, ITEM_FIELD)]

        if new_input is not None:
            inputs.append(new_input)
@@ -670,7 +677,7 @@ class Graph(BaseModel):
        if isinstance(input_node, CollectInvocation):
            # Traverse the graph to find the first collector input edge. Collectors validate that their collection
            # inputs are all of the same type, so we can use the first input edge to determine the collector's type
            first_collector_input_edge = self._get_input_edges(input_node.id, "item")[0]
            first_collector_input_edge = self._get_input_edges(input_node.id, ITEM_FIELD)[0]
            first_collector_input_type = get_output_field_type(
                self.get_node(first_collector_input_edge.source.node_id), first_collector_input_edge.source.field
            )
@@ -690,8 +697,8 @@ class Graph(BaseModel):
        new_input: Optional[EdgeConnection] = None,
        new_output: Optional[EdgeConnection] = None,
    ) -> str | None:
        inputs = [e.source for e in self._get_input_edges(node_id, "item")]
        outputs = [e.destination for e in self._get_output_edges(node_id, "collection")]
        inputs = [e.source for e in self._get_input_edges(node_id, ITEM_FIELD)]
        outputs = [e.destination for e in self._get_output_edges(node_id, COLLECTION_FIELD)]

        if new_input is not None:
            inputs.append(new_input)
@@ -761,7 +768,7 @@ class Graph(BaseModel):
        # TODO: figure out if iteration nodes need to be expanded

        unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges}
        g.add_edges_from([(e[0], e[1]) for e in unique_edges])
        g.add_edges_from(unique_edges)
        return g


@@ -802,6 +809,41 @@ class GraphExecutionState(BaseModel):
        description="The map of original graph nodes to prepared nodes",
        default_factory=dict,
    )
    # Ready queues grouped by node class name (internal only)
    _ready_queues: dict[str, Deque[str]] = PrivateAttr(default_factory=dict)
    # Current class being drained; stays until its queue empties
    _active_class: Optional[str] = PrivateAttr(default=None)
    # Optional priority; others follow in name order
    ready_order: list[str] = Field(default_factory=list)
    indegree: dict[str, int] = Field(default_factory=dict, description="Remaining unmet input count for exec nodes")

    def _type_key(self, node_obj: BaseInvocation) -> str:
        return node_obj.__class__.__name__

    def _queue_for(self, cls_name: str) -> Deque[str]:
        q = self._ready_queues.get(cls_name)
        if q is None:
            q = deque()
            self._ready_queues[cls_name] = q
        return q

    def set_ready_order(self, order: Iterable[Type[BaseInvocation] | str]) -> None:
        names: list[str] = []
        for x in order:
            names.append(x.__name__ if hasattr(x, "__name__") else str(x))
        self.ready_order = names

    def _enqueue_if_ready(self, nid: str) -> None:
        """Push nid to its class queue if unmet inputs == 0."""
        # Invariants: exec node exists and has an indegree entry
        if nid not in self.execution_graph.nodes:
            raise KeyError(f"exec node {nid} missing from execution_graph")
        if nid not in self.indegree:
            raise KeyError(f"indegree missing for exec node {nid}")
        if self.indegree[nid] != 0 or nid in self.executed:
            return
        node_obj = self.execution_graph.nodes[nid]
        self._queue_for(self._type_key(node_obj)).append(nid)

    model_config = ConfigDict(
        json_schema_extra={
@@ -834,12 +876,14 @@ class GraphExecutionState(BaseModel):
        # If there are no prepared nodes, prepare some nodes
        next_node = self._get_next_node()
        if next_node is None:
            prepared_id = self._prepare()
            base_g = self.graph.nx_graph_flat()
            prepared_id = self._prepare(base_g)

            # Prepare as many nodes as we can
            while prepared_id is not None:
                prepared_id = self._prepare()
            next_node = self._get_next_node()
                prepared_id = self._prepare(base_g)
            if next_node is None:
                next_node = self._get_next_node()

        # Get values from edges
        if next_node is not None:
@@ -869,6 +913,18 @@ class GraphExecutionState(BaseModel):
            self.executed.add(source_node)
            self.executed_history.append(source_node)

        # Decrement children indegree and enqueue when ready
        for e in self.execution_graph._get_output_edges(node_id):
            child = e.destination.node_id
            if child not in self.indegree:
                raise KeyError(f"indegree missing for exec node {child}")
            # Only decrement if there's something to satisfy
            if self.indegree[child] == 0:
                raise RuntimeError(f"indegree underflow for {child} from parent {node_id}")
            self.indegree[child] -= 1
            if self.indegree[child] == 0:
                self._enqueue_if_ready(child)

    def set_node_error(self, node_id: str, error: str):
        """Marks a node as errored"""
        self.errors[node_id] = error
@@ -892,7 +948,7 @@ class GraphExecutionState(BaseModel):
        # If this is an iterator node, we must create a copy for each iteration
        if isinstance(node, IterateInvocation):
            # Get input collection edge (should error if there are no inputs)
            input_collection_edge = next(iter(self.graph._get_input_edges(node_id, "collection")))
            input_collection_edge = next(iter(self.graph._get_input_edges(node_id, COLLECTION_FIELD)))
            input_collection_prepared_node_id = next(
                n[1] for n in iteration_node_map if n[0] == input_collection_edge.source.node_id
            )
@@ -922,7 +978,7 @@ class GraphExecutionState(BaseModel):
        # Create a new node (or one for each iteration of this iterator)
        for i in range(self_iteration_count) if self_iteration_count > 0 else [-1]:
            # Create a new node
            new_node = copy.deepcopy(node)
            new_node = node.model_copy(deep=True)

            # Create the node id (use a random uuid)
            new_node.id = uuid_string()
@@ -946,53 +1002,55 @@ class GraphExecutionState(BaseModel):
                )
                self.execution_graph.add_edge(new_edge)

            # Initialize indegree as unmet inputs only and enqueue if ready
            inputs = self.execution_graph._get_input_edges(new_node.id)
            unmet = sum(1 for e in inputs if e.source.node_id not in self.executed)
            self.indegree[new_node.id] = unmet
            self._enqueue_if_ready(new_node.id)

            new_nodes.append(new_node.id)

        return new_nodes

    def _iterator_graph(self) -> nx.DiGraph:
    def _iterator_graph(self, base: Optional[nx.DiGraph] = None) -> nx.DiGraph:
        """Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
        g = self.graph.nx_graph_flat()
        g = base.copy() if base is not None else self.graph.nx_graph_flat()
        collectors = (n for n in self.graph.nodes if isinstance(self.graph.get_node(n), CollectInvocation))
        for c in collectors:
            g.remove_edges_from(list(g.in_edges(c)))
        return g

    def _get_node_iterators(self, node_id: str) -> list[str]:
    def _get_node_iterators(self, node_id: str, it_graph: Optional[nx.DiGraph] = None) -> list[str]:
        """Gets iterators for a node"""
        g = self._iterator_graph()
        iterators = [n for n in nx.ancestors(g, node_id) if isinstance(self.graph.get_node(n), IterateInvocation)]
        return iterators
        g = it_graph or self._iterator_graph()
        return [n for n in nx.ancestors(g, node_id) if isinstance(self.graph.get_node(n), IterateInvocation)]

    def _prepare(self) -> Optional[str]:
    def _prepare(self, base_g: Optional[nx.DiGraph] = None) -> Optional[str]:
        # Get flattened source graph
        g = self.graph.nx_graph_flat()
        g = base_g or self.graph.nx_graph_flat()

        # Find next node that:
        # - was not already prepared
        # - is not an iterate node whose inputs have not been executed
        # - does not have an unexecuted iterate ancestor
        sorted_nodes = nx.topological_sort(g)

        def unprepared(n: str) -> bool:
            return n not in self.source_prepared_mapping

        def iter_inputs_ready(n: str) -> bool:
            if not isinstance(self.graph.get_node(n), IterateInvocation):
                return True
            return all(u in self.executed for u, _ in g.in_edges(n))

        def no_unexecuted_iter_ancestors(n: str) -> bool:
            return not any(
                isinstance(self.graph.get_node(a), IterateInvocation) and a not in self.executed
                for a in nx.ancestors(g, n)
            )

        next_node_id = next(
            (
                n
                for n in sorted_nodes
                # exclude nodes that have already been prepared
                if n not in self.source_prepared_mapping
                # exclude iterate nodes whose inputs have not been executed
                and not (
                    isinstance(self.graph.get_node(n), IterateInvocation)  # `n` is an iterate node...
                    and not all((e[0] in self.executed for e in g.in_edges(n)))  # ...that has unexecuted inputs
                )
                # exclude nodes who have unexecuted iterate ancestors
                and not any(
                    (
                        isinstance(self.graph.get_node(a), IterateInvocation)  # `a` is an iterate ancestor of `n`...
                        and a not in self.executed  # ...that is not executed
                        for a in nx.ancestors(g, n)  # for all ancestors `a` of node `n`
                    )
                )
            ),
            (n for n in sorted_nodes if unprepared(n) and iter_inputs_ready(n) and no_unexecuted_iter_ancestors(n)),
            None,
        )
@@ -1000,7 +1058,7 @@ class GraphExecutionState(BaseModel):
            return None

        # Get all parents of the next node
        next_node_parents = [e[0] for e in g.in_edges(next_node_id)]
        next_node_parents = [u for u, _ in g.in_edges(next_node_id)]

        # Create execution nodes
        next_node = self.graph.get_node(next_node_id)
@@ -1018,7 +1076,8 @@ class GraphExecutionState(BaseModel):
        else:  # Iterators or normal nodes
            # Get all iterator combinations for this node
            # Will produce a list of lists of prepared iterator nodes, from which results can be iterated
            iterator_nodes = self._get_node_iterators(next_node_id)
            it_g = self._iterator_graph(g)
            iterator_nodes = self._get_node_iterators(next_node_id, it_g)
            iterator_nodes_prepared = [list(self.source_prepared_mapping[n]) for n in iterator_nodes]
            iterator_node_prepared_combinations = list(itertools.product(*iterator_nodes_prepared))
@@ -1066,45 +1125,41 @@ class GraphExecutionState(BaseModel):
        )

    def _get_next_node(self) -> Optional[BaseInvocation]:
        """Gets the deepest node that is ready to be executed"""
        g = self.execution_graph.nx_graph()
        """Gets the next ready node: FIFO within class, drain class before switching."""
        # 1) Continue draining the active class
        if self._active_class:
            q = self._ready_queues.get(self._active_class)
            while q:
                nid = q.popleft()
                if nid not in self.executed:
                    return self.execution_graph.nodes[nid]
            # emptied: release active class
            self._active_class = None

        # Perform a topological sort using depth-first search
        topo_order = list(nx.dfs_postorder_nodes(g))

        # Get all IterateInvocation nodes
        iterate_nodes = [n for n in topo_order if isinstance(self.execution_graph.nodes[n], IterateInvocation)]

        # Sort the IterateInvocation nodes based on their index attribute
        iterate_nodes.sort(key=lambda x: self.execution_graph.nodes[x].index)

        # Prioritize IterateInvocation nodes and their children
        for iterate_node in iterate_nodes:
            if iterate_node not in self.executed and all((e[0] in self.executed for e in g.in_edges(iterate_node))):
                return self.execution_graph.nodes[iterate_node]

            # Check the children of the IterateInvocation node
            for child_node in nx.dfs_postorder_nodes(g, iterate_node):
                if child_node not in self.executed and all((e[0] in self.executed for e in g.in_edges(child_node))):
                    return self.execution_graph.nodes[child_node]

        # If no IterateInvocation node or its children are ready, return the first ready node in the topological order
        for node in topo_order:
            if node not in self.executed and all((e[0] in self.executed for e in g.in_edges(node))):
                return self.execution_graph.nodes[node]

        # If no node is found, return None
        # 2) Pick next class by priority, then by class name
        seen = set(self.ready_order)
        for cls_name in self.ready_order:
            q = self._ready_queues.get(cls_name)
            if q:
                self._active_class = cls_name
                # recurse to drain newly set active class
                return self._get_next_node()
        for cls_name in sorted(k for k in self._ready_queues.keys() if k not in seen):
            q = self._ready_queues[cls_name]
            if q:
                self._active_class = cls_name
                return self._get_next_node()
        return None

    def _prepare_inputs(self, node: BaseInvocation):
        input_edges = [e for e in self.execution_graph.edges if e.destination.node_id == node.id]
        input_edges = self.execution_graph._get_input_edges(node.id)
        # Inputs must be deep-copied, else if a node mutates the object, other nodes that get the same input
        # will see the mutation.
        if isinstance(node, CollectInvocation):
            output_collection = [
                copy.deepcopy(getattr(self.results[edge.source.node_id], edge.source.field))
                for edge in input_edges
                if edge.destination.field == "item"
                if edge.destination.field == ITEM_FIELD
            ]
            node.collection = output_collection
        else:
@@ -630,6 +630,21 @@ class UtilInterface(InvocationContextInterface):
            is_canceled=self.is_canceled,
        )

    def flux2_step_callback(self, intermediate_state: PipelineIntermediateState) -> None:
        """
        The step callback for FLUX.2 Klein models (32-channel VAE).

        Args:
            intermediate_state: The intermediate state of the diffusion pipeline.
        """

        diffusion_step_callback(
            signal_progress=self.signal_progress,
            intermediate_state=intermediate_state,
            base_model=BaseModelType.Flux2,
            is_canceled=self.is_canceled,
        )

    def signal_progress(
        self,
        message: str,
@@ -27,6 +27,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_21 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_22 import build_migration_22
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_23 import build_migration_23
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_24 import build_migration_24
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_25 import build_migration_25
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator


@@ -71,6 +72,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
    migrator.register_migration(build_migration_22(app_config=config, logger=logger))
    migrator.register_migration(build_migration_23(app_config=config, logger=logger))
    migrator.register_migration(build_migration_24(app_config=config, logger=logger))
    migrator.register_migration(build_migration_25(app_config=config, logger=logger))
    migrator.run_migrations()

    return db
@@ -0,0 +1,61 @@
import json
import sqlite3
from logging import Logger
from typing import Any

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
from invokeai.backend.model_manager.taxonomy import ModelType, Qwen3VariantType


class Migration25Callback:
    def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
        self._app_config = app_config
        self._logger = logger

    def __call__(self, cursor: sqlite3.Cursor) -> None:
        cursor.execute("SELECT id, config FROM models;")
        rows = cursor.fetchall()

        migrated_count = 0

        for model_id, config_json in rows:
            try:
                config_dict: dict[str, Any] = json.loads(config_json)

                if config_dict.get("type") != ModelType.Qwen3Encoder.value:
                    continue

                if "variant" in config_dict:
                    continue

                config_dict["variant"] = Qwen3VariantType.Qwen3_4B.value

                cursor.execute(
                    "UPDATE models SET config = ? WHERE id = ?;",
                    (json.dumps(config_dict), model_id),
                )
                migrated_count += 1

            except json.JSONDecodeError as e:
                self._logger.error("Invalid config JSON for model %s: %s", model_id, e)
                raise

        if migrated_count > 0:
            self._logger.info(f"Migration complete: {migrated_count} Qwen3 encoder configs updated with variant field")
        else:
            self._logger.info("Migration complete: no Qwen3 encoder configs needed migration")


def build_migration_25(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
    """Builds the migration object for migrating from version 24 to version 25.

    This migration adds the variant field to existing Qwen3 encoder models.
    Models installed before the variant field was added will default to Qwen3_4B (for Z-Image compatibility).
    """

    return Migration(
        from_version=24,
        to_version=25,
        callback=Migration25Callback(app_config=app_config, logger=logger),
    )
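For orientation, wiring this migration into the migrator mirrors the `init_db` hunk above. A minimal sketch (the `SqliteMigrator(db=db)` construction is an assumption; only `register_migration` and `run_migrations` are confirmed by the diff):

```python
# Sketch: registration and execution, mirroring the init_db wiring shown earlier.
migrator = SqliteMigrator(db=db)  # assumption: a migrator built around the app database
migrator.register_migration(build_migration_25(app_config=config, logger=logger))
migrator.run_migrations()  # applies version 24 -> 25 if the database is behind
```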

@@ -74,3 +74,11 @@ class WorkflowRecordsStorageBase(ABC):
    def update_opened_at(self, workflow_id: str) -> None:
        """Open a workflow."""
        pass

    @abstractmethod
    def get_all_tags(
        self,
        categories: Optional[list[WorkflowCategory]] = None,
    ) -> list[str]:
        """Gets all unique tags from workflows."""
        pass

@@ -332,6 +332,48 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
            (workflow_id,),
        )

    def get_all_tags(
        self,
        categories: Optional[list[WorkflowCategory]] = None,
    ) -> list[str]:
        with self._db.transaction() as cursor:
            conditions: list[str] = []
            params: list[str] = []

            # Only get workflows that have tags
            conditions.append("tags IS NOT NULL AND tags != ''")

            if categories:
                assert all(c in WorkflowCategory for c in categories)
                placeholders = ", ".join("?" for _ in categories)
                conditions.append(f"category IN ({placeholders})")
                params.extend([category.value for category in categories])

            stmt = """--sql
                SELECT DISTINCT tags
                FROM workflow_library
            """

            if conditions:
                stmt += " WHERE " + " AND ".join(conditions)

            cursor.execute(stmt, params)
            rows = cursor.fetchall()

            # Parse comma-separated tags and collect unique tags
            all_tags: set[str] = set()

            for row in rows:
                tags_value = row[0]
                if tags_value and isinstance(tags_value, str):
                    # Tags are stored as a comma-separated string
                    for tag in tags_value.split(","):
                        tag_stripped = tag.strip()
                        if tag_stripped:
                            all_tags.add(tag_stripped)

            return sorted(all_tags)

    def _sync_default_workflows(self) -> None:
        """Syncs default workflows to the database. Internal use only."""

@@ -93,14 +93,60 @@ COGVIEW4_LATENT_RGB_FACTORS = [
    [-0.00955853, -0.00980067, -0.00977842],
]

# FLUX.2 uses 32 latent channels.
# Factors from ComfyUI: https://github.com/Comfy-Org/ComfyUI/blob/main/comfy/latent_formats.py
FLUX2_LATENT_RGB_FACTORS = [
    #     R        G        B
    [0.0058, 0.0113, 0.0073],
    [0.0495, 0.0443, 0.0836],
    [-0.0099, 0.0096, 0.0644],
    [0.2144, 0.3009, 0.3652],
    [0.0166, -0.0039, -0.0054],
    [0.0157, 0.0103, -0.0160],
    [-0.0398, 0.0902, -0.0235],
    [-0.0052, 0.0095, 0.0109],
    [-0.3527, -0.2712, -0.1666],
    [-0.0301, -0.0356, -0.0180],
    [-0.0107, 0.0078, 0.0013],
    [0.0746, 0.0090, -0.0941],
    [0.0156, 0.0169, 0.0070],
    [-0.0034, -0.0040, -0.0114],
    [0.0032, 0.0181, 0.0080],
    [-0.0939, -0.0008, 0.0186],
    [0.0018, 0.0043, 0.0104],
    [0.0284, 0.0056, -0.0127],
    [-0.0024, -0.0022, -0.0030],
    [0.1207, -0.0026, 0.0065],
    [0.0128, 0.0101, 0.0142],
    [0.0137, -0.0072, -0.0007],
    [0.0095, 0.0092, -0.0059],
    [0.0000, -0.0077, -0.0049],
    [-0.0465, -0.0204, -0.0312],
    [0.0095, 0.0012, -0.0066],
    [0.0290, -0.0034, 0.0025],
    [0.0220, 0.0169, -0.0048],
    [-0.0332, -0.0457, -0.0468],
    [-0.0085, 0.0389, 0.0609],
    [-0.0076, 0.0003, -0.0043],
    [-0.0111, -0.0460, -0.0614],
]

FLUX2_LATENT_RGB_BIAS = [-0.0329, -0.0718, -0.0851]

def sample_to_lowres_estimated_image(
    samples: torch.Tensor, latent_rgb_factors: torch.Tensor, smooth_matrix: Optional[torch.Tensor] = None
    samples: torch.Tensor,
    latent_rgb_factors: torch.Tensor,
    smooth_matrix: Optional[torch.Tensor] = None,
    latent_rgb_bias: Optional[torch.Tensor] = None,
):
    if samples.dim() == 4:
        samples = samples[0]
    latent_image = samples.permute(1, 2, 0) @ latent_rgb_factors

    if latent_rgb_bias is not None:
        latent_image = latent_image + latent_rgb_bias

    if smooth_matrix is not None:
        latent_image = latent_image.unsqueeze(0).permute(3, 0, 1, 2)
        latent_image = torch.nn.functional.conv2d(latent_image, smooth_matrix.reshape((1, 1, 3, 3)), padding=1)
@@ -153,6 +199,7 @@ def diffusion_step_callback(
    sample = intermediate_state.latents

    smooth_matrix: list[list[float]] | None = None
    latent_rgb_bias: list[float] | None = None
    if base_model in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
        latent_rgb_factors = SD1_5_LATENT_RGB_FACTORS
    elif base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]:
@@ -164,6 +211,12 @@
        latent_rgb_factors = COGVIEW4_LATENT_RGB_FACTORS
    elif base_model == BaseModelType.Flux:
        latent_rgb_factors = FLUX_LATENT_RGB_FACTORS
    elif base_model == BaseModelType.Flux2:
        latent_rgb_factors = FLUX2_LATENT_RGB_FACTORS
        latent_rgb_bias = FLUX2_LATENT_RGB_BIAS
    elif base_model == BaseModelType.ZImage:
        # Z-Image uses FLUX-compatible VAE with 16 latent channels
        latent_rgb_factors = FLUX_LATENT_RGB_FACTORS
    else:
        raise ValueError(f"Unsupported base model: {base_model}")

@@ -171,8 +224,14 @@
    smooth_matrix_torch = (
        torch.tensor(smooth_matrix, dtype=sample.dtype, device=sample.device) if smooth_matrix else None
    )
    latent_rgb_bias_torch = (
        torch.tensor(latent_rgb_bias, dtype=sample.dtype, device=sample.device) if latent_rgb_bias else None
    )
    image = sample_to_lowres_estimated_image(
        samples=sample, latent_rgb_factors=latent_rgb_factors_torch, smooth_matrix=smooth_matrix_torch
        samples=sample,
        latent_rgb_factors=latent_rgb_factors_torch,
        smooth_matrix=smooth_matrix_torch,
        latent_rgb_bias=latent_rgb_bias_torch,
    )

    width = image.width * 8
@@ -1,10 +1,13 @@
import inspect
import math
from typing import Callable

import torch
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from tqdm import tqdm

from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs
from invokeai.backend.flux.extensions.dype_extension import DyPEExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
@@ -35,149 +38,366 @@ def denoise(
    # extra img tokens (sequence-wise) - for Kontext conditioning
    img_cond_seq: torch.Tensor | None = None,
    img_cond_seq_ids: torch.Tensor | None = None,
    # DyPE extension for high-resolution generation
    dype_extension: DyPEExtension | None = None,
    # Optional scheduler for alternative sampling methods
    scheduler: SchedulerMixin | None = None,
):
    # step 0 is the initial state
    total_steps = len(timesteps) - 1
    step_callback(
        PipelineIntermediateState(
            step=0,
            order=1,
            total_steps=total_steps,
            timestep=int(timesteps[0]),
            latents=img,
        ),
    )
    # Determine if we're using a diffusers scheduler or the built-in Euler method
    use_scheduler = scheduler is not None

    if use_scheduler:
        # Initialize scheduler with timesteps
        # The timesteps list contains values in [0, 1] range (sigmas)
        # LCM should use num_inference_steps (it has its own sigma schedule),
        # while other schedulers can use custom sigmas if supported
        is_lcm = scheduler.__class__.__name__ == "FlowMatchLCMScheduler"
        set_timesteps_sig = inspect.signature(scheduler.set_timesteps)
        if not is_lcm and "sigmas" in set_timesteps_sig.parameters:
            # Scheduler supports custom sigmas - use InvokeAI's time-shifted schedule
            scheduler.set_timesteps(sigmas=timesteps, device=img.device)
        else:
            # LCM or scheduler doesn't support custom sigmas - use num_inference_steps
            # The schedule will be computed by the scheduler itself
            num_inference_steps = len(timesteps) - 1
            scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=img.device)

        # For schedulers like Heun, the number of actual steps may differ
        # (Heun doubles timesteps internally)
        num_scheduler_steps = len(scheduler.timesteps)
        # For user-facing step count, use the original number of denoising steps
        total_steps = len(timesteps) - 1
    else:
        total_steps = len(timesteps) - 1
        num_scheduler_steps = total_steps

    # guidance_vec is ignored for schnell.
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

    # Store original sequence length for slicing predictions
    original_seq_len = img.shape[1]

    for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
    # DyPE: Patch model with DyPE-aware position embedder
    dype_embedder = None
    original_pe_embedder = None
    if dype_extension is not None:
        dype_embedder, original_pe_embedder = dype_extension.patch_model(model)

        # Run ControlNet models.
        controlnet_residuals: list[ControlNetFluxOutput] = []
        for controlnet_extension in controlnet_extensions:
            controlnet_residuals.append(
                controlnet_extension.run_controlnet(
                    timestep_index=step_index,
                    total_num_timesteps=total_steps,
                    img=img,
                    img_ids=img_ids,
    try:
        # Track the actual step for user-facing progress (accounts for Heun's double steps)
        user_step = 0

        if use_scheduler:
            # Use diffusers scheduler for stepping
            # Use tqdm with total_steps (user-facing steps) not num_scheduler_steps (internal steps)
            # This ensures progress bar shows 1/8, 2/8, etc. even when scheduler uses more internal steps
            pbar = tqdm(total=total_steps, desc="Denoising")
            for step_index in range(num_scheduler_steps):
                timestep = scheduler.timesteps[step_index]
                # Convert scheduler timestep (0-1000) to normalized (0-1) for the model
                t_curr = timestep.item() / scheduler.config.num_train_timesteps
                t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)

                # DyPE: Update step state for timestep-dependent scaling
                if dype_extension is not None and dype_embedder is not None:
                    dype_extension.update_step_state(
                        embedder=dype_embedder,
                        timestep=t_curr,
                        timestep_index=user_step,
                        total_steps=total_steps,
                    )

                # For Heun scheduler, track if we're in first or second order step
                is_heun = hasattr(scheduler, "state_in_first_order")
                in_first_order = scheduler.state_in_first_order if is_heun else True

                # Run ControlNet models
                controlnet_residuals: list[ControlNetFluxOutput] = []
                for controlnet_extension in controlnet_extensions:
                    controlnet_residuals.append(
                        controlnet_extension.run_controlnet(
                            timestep_index=user_step,
                            total_num_timesteps=total_steps,
                            img=img,
                            img_ids=img_ids,
                            txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
                            txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
                            y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
                            timesteps=t_vec,
                            guidance=guidance_vec,
                        )
                    )

                merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)

                # Prepare input for model
                img_input = img
                img_input_ids = img_ids

                if img_cond is not None:
                    img_input = torch.cat((img_input, img_cond), dim=-1)

                if img_cond_seq is not None:
                    assert img_cond_seq_ids is not None
                    img_input = torch.cat((img_input, img_cond_seq), dim=1)
                    img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)

                pred = model(
                    img=img_input,
                    img_ids=img_input_ids,
                    txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
                    txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
                    y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
                    timesteps=t_vec,
                    guidance=guidance_vec,
                    timestep_index=user_step,
                    total_num_timesteps=total_steps,
                    controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
                    controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
                    ip_adapter_extensions=pos_ip_adapter_extensions,
                    regional_prompting_extension=pos_regional_prompting_extension,
                )
            )

        # Merge the ControlNet residuals from multiple ControlNets.
        # TODO(ryand): We may want to calculate the sum just-in-time to keep peak memory low. Keep in mind, that the
        # controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
        # tensors. Calculating the sum materializes each tensor into its own instance.
        merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
                if img_cond_seq is not None:
                    pred = pred[:, :original_seq_len]

        # Prepare input for model - concatenate fresh each step
        img_input = img
        img_input_ids = img_ids
                # Get CFG scale for current user step
                step_cfg_scale = cfg_scale[min(user_step, len(cfg_scale) - 1)]

        # Add channel-wise conditioning (for ControlNet, FLUX Fill, etc.)
        if img_cond is not None:
            img_input = torch.cat((img_input, img_cond), dim=-1)
                if not math.isclose(step_cfg_scale, 1.0):
                    if neg_regional_prompting_extension is None:
                        raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")

        # Add sequence-wise conditioning (for Kontext)
        if img_cond_seq is not None:
            assert img_cond_seq_ids is not None, (
                "You need to provide either both or neither of the sequence conditioning"
            )
            img_input = torch.cat((img_input, img_cond_seq), dim=1)
            img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
                    neg_img_input = img
                    neg_img_input_ids = img_ids

        pred = model(
            img=img_input,
            img_ids=img_input_ids,
            txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
            txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
            y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
            timesteps=t_vec,
            guidance=guidance_vec,
            timestep_index=step_index,
            total_num_timesteps=total_steps,
            controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
            controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
            ip_adapter_extensions=pos_ip_adapter_extensions,
            regional_prompting_extension=pos_regional_prompting_extension,
        )
                    if img_cond is not None:
                        neg_img_input = torch.cat((neg_img_input, img_cond), dim=-1)

        # Slice prediction to only include the main image tokens
        if img_cond_seq is not None:
            pred = pred[:, :original_seq_len]
                    if img_cond_seq is not None:
                        neg_img_input = torch.cat((neg_img_input, img_cond_seq), dim=1)
                        neg_img_input_ids = torch.cat((neg_img_input_ids, img_cond_seq_ids), dim=1)

        step_cfg_scale = cfg_scale[step_index]
                    neg_pred = model(
                        img=neg_img_input,
                        img_ids=neg_img_input_ids,
                        txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
                        txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
                        y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
                        timesteps=t_vec,
                        guidance=guidance_vec,
                        timestep_index=user_step,
                        total_num_timesteps=total_steps,
                        controlnet_double_block_residuals=None,
                        controlnet_single_block_residuals=None,
                        ip_adapter_extensions=neg_ip_adapter_extensions,
                        regional_prompting_extension=neg_regional_prompting_extension,
                    )

        # If step_cfg_scale is 1.0, then we don't need to run the negative prediction.
        if not math.isclose(step_cfg_scale, 1.0):
            # TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance
            # on systems with sufficient VRAM.
                    if img_cond_seq is not None:
                        neg_pred = neg_pred[:, :original_seq_len]
                    pred = neg_pred + step_cfg_scale * (pred - neg_pred)

            if neg_regional_prompting_extension is None:
                raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
                # Use scheduler.step() for the update
                step_output = scheduler.step(model_output=pred, timestep=timestep, sample=img)
                img = step_output.prev_sample

            # For negative prediction with Kontext, we need to include the reference images
            # to maintain consistency between positive and negative passes. Without this,
            # CFG would create artifacts as the attention mechanism would see different
            # spatial structures in each pass
            neg_img_input = img
            neg_img_input_ids = img_ids
                # Get t_prev for inpainting (next sigma value)
                if step_index + 1 < len(scheduler.sigmas):
                    t_prev = scheduler.sigmas[step_index + 1].item()
                else:
                    t_prev = 0.0

            # Add channel-wise conditioning for negative pass if present
                if inpaint_extension is not None:
                    img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)

                # For Heun, only increment user step after second-order step completes
                if is_heun:
                    if not in_first_order:
                        # Second order step completed
                        user_step += 1
                        # Only call step_callback if we haven't exceeded total_steps
                        if user_step <= total_steps:
                            pbar.update(1)
                            preview_img = img - t_curr * pred
                            if inpaint_extension is not None:
                                preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(
                                    preview_img, 0.0
                                )
                            step_callback(
                                PipelineIntermediateState(
                                    step=user_step,
                                    order=2,
                                    total_steps=total_steps,
                                    timestep=int(t_curr * 1000),
                                    latents=preview_img,
                                ),
                            )
                else:
                    # For LCM and other first-order schedulers
                    user_step += 1
                    # Only call step_callback if we haven't exceeded total_steps
                    # (LCM scheduler may have more internal steps than user-facing steps)
                    if user_step <= total_steps:
                        pbar.update(1)
                        preview_img = img - t_curr * pred
                        if inpaint_extension is not None:
                            preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(
                                preview_img, 0.0
                            )
                        step_callback(
                            PipelineIntermediateState(
                                step=user_step,
                                order=1,
                                total_steps=total_steps,
                                timestep=int(t_curr * 1000),
                                latents=preview_img,
                            ),
                        )

            pbar.close()
            return img

        # Original Euler implementation (when scheduler is None)
        for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
            # DyPE: Update step state for timestep-dependent scaling
            if dype_extension is not None and dype_embedder is not None:
                dype_extension.update_step_state(
                    embedder=dype_embedder,
                    timestep=t_curr,
                    timestep_index=step_index,
                    total_steps=total_steps,
                )

            t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)

            # Run ControlNet models.
            controlnet_residuals: list[ControlNetFluxOutput] = []
            for controlnet_extension in controlnet_extensions:
                controlnet_residuals.append(
                    controlnet_extension.run_controlnet(
                        timestep_index=step_index,
                        total_num_timesteps=total_steps,
                        img=img,
                        img_ids=img_ids,
                        txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
                        txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
                        y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
                        timesteps=t_vec,
                        guidance=guidance_vec,
                    )
                )

            # Merge the ControlNet residuals from multiple ControlNets.
            # TODO(ryand): We may want to calculate the sum just-in-time to keep peak memory low. Keep in mind, that the
            # controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
            # tensors. Calculating the sum materializes each tensor into its own instance.
            merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)

            # Prepare input for model - concatenate fresh each step
            img_input = img
            img_input_ids = img_ids

            # Add channel-wise conditioning (for ControlNet, FLUX Fill, etc.)
            if img_cond is not None:
                neg_img_input = torch.cat((neg_img_input, img_cond), dim=-1)
                img_input = torch.cat((img_input, img_cond), dim=-1)

            # Add sequence-wise conditioning (Kontext) for negative pass
            # This ensures reference images are processed consistently
            # Add sequence-wise conditioning (for Kontext)
            if img_cond_seq is not None:
                neg_img_input = torch.cat((neg_img_input, img_cond_seq), dim=1)
                neg_img_input_ids = torch.cat((neg_img_input_ids, img_cond_seq_ids), dim=1)
                assert img_cond_seq_ids is not None, (
                    "You need to provide either both or neither of the sequence conditioning"
                )
                img_input = torch.cat((img_input, img_cond_seq), dim=1)
                img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)

            neg_pred = model(
                img=neg_img_input,
                img_ids=neg_img_input_ids,
                txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
                txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
                y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
            pred = model(
                img=img_input,
                img_ids=img_input_ids,
                txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
                txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
                y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
                timesteps=t_vec,
                guidance=guidance_vec,
                timestep_index=step_index,
                total_num_timesteps=total_steps,
                controlnet_double_block_residuals=None,
                controlnet_single_block_residuals=None,
                ip_adapter_extensions=neg_ip_adapter_extensions,
                regional_prompting_extension=neg_regional_prompting_extension,
                controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
                controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
                ip_adapter_extensions=pos_ip_adapter_extensions,
                regional_prompting_extension=pos_regional_prompting_extension,
            )

            # Slice negative prediction to match main image tokens
            # Slice prediction to only include the main image tokens
            if img_cond_seq is not None:
                neg_pred = neg_pred[:, :original_seq_len]
            pred = neg_pred + step_cfg_scale * (pred - neg_pred)
                pred = pred[:, :original_seq_len]

            preview_img = img - t_curr * pred
            img = img + (t_prev - t_curr) * pred
            step_cfg_scale = cfg_scale[step_index]

            if inpaint_extension is not None:
                img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
                preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0)
            # If step_cfg_scale is 1.0, then we don't need to run the negative prediction.
            if not math.isclose(step_cfg_scale, 1.0):
                # TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance
                # on systems with sufficient VRAM.

            step_callback(
                PipelineIntermediateState(
                    step=step_index + 1,
                    order=1,
                    total_steps=total_steps,
                    timestep=int(t_curr),
                    latents=preview_img,
                ),
            )
                if neg_regional_prompting_extension is None:
                    raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")

    return img
                # For negative prediction with Kontext, we need to include the reference images
                # to maintain consistency between positive and negative passes. Without this,
                # CFG would create artifacts as the attention mechanism would see different
                # spatial structures in each pass
                neg_img_input = img
                neg_img_input_ids = img_ids

                # Add channel-wise conditioning for negative pass if present
                if img_cond is not None:
                    neg_img_input = torch.cat((neg_img_input, img_cond), dim=-1)

                # Add sequence-wise conditioning (Kontext) for negative pass
                # This ensures reference images are processed consistently
                if img_cond_seq is not None:
                    neg_img_input = torch.cat((neg_img_input, img_cond_seq), dim=1)
                    neg_img_input_ids = torch.cat((neg_img_input_ids, img_cond_seq_ids), dim=1)

                neg_pred = model(
                    img=neg_img_input,
                    img_ids=neg_img_input_ids,
                    txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
                    txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
                    y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
                    timesteps=t_vec,
                    guidance=guidance_vec,
                    timestep_index=step_index,
                    total_num_timesteps=total_steps,
                    controlnet_double_block_residuals=None,
                    controlnet_single_block_residuals=None,
                    ip_adapter_extensions=neg_ip_adapter_extensions,
                    regional_prompting_extension=neg_regional_prompting_extension,
                )

                # Slice negative prediction to match main image tokens
                if img_cond_seq is not None:
                    neg_pred = neg_pred[:, :original_seq_len]
                pred = neg_pred + step_cfg_scale * (pred - neg_pred)

            preview_img = img - t_curr * pred
            img = img + (t_prev - t_curr) * pred

            if inpaint_extension is not None:
                img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
                preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0)

            step_callback(
                PipelineIntermediateState(
                    step=step_index + 1,
                    order=1,
                    total_steps=total_steps,
                    timestep=int(t_curr),
                    latents=preview_img,
                ),
            )

        return img

    finally:
        # DyPE: Restore original position embedder
        if original_pe_embedder is not None:
            DyPEExtension.restore_model(model, original_pe_embedder)
invokeai/backend/flux/dype/__init__.py (new file, 35 lines)
@@ -0,0 +1,35 @@
"""Dynamic Position Extrapolation (DyPE) for FLUX models.

DyPE enables high-resolution image generation (4K+) with pretrained FLUX models
by dynamically scaling RoPE position embeddings during the denoising process.

Based on: https://github.com/wildminder/ComfyUI-DyPE
"""

from invokeai.backend.flux.dype.base import DyPEConfig
from invokeai.backend.flux.dype.embed import DyPEEmbedND
from invokeai.backend.flux.dype.presets import (
    DYPE_PRESET_4K,
    DYPE_PRESET_AREA,
    DYPE_PRESET_AUTO,
    DYPE_PRESET_LABELS,
    DYPE_PRESET_MANUAL,
    DYPE_PRESET_OFF,
    DyPEPreset,
    get_dype_config_for_area,
    get_dype_config_for_resolution,
)

__all__ = [
    "DyPEConfig",
    "DyPEEmbedND",
    "DyPEPreset",
    "DYPE_PRESET_OFF",
    "DYPE_PRESET_MANUAL",
    "DYPE_PRESET_AUTO",
    "DYPE_PRESET_AREA",
    "DYPE_PRESET_4K",
    "DYPE_PRESET_LABELS",
    "get_dype_config_for_area",
    "get_dype_config_for_resolution",
]
260  invokeai/backend/flux/dype/base.py  Normal file
@@ -0,0 +1,260 @@
"""DyPE base configuration and utilities."""

import math
from dataclasses import dataclass
from typing import Literal

import torch
from torch import Tensor


@dataclass
class DyPEConfig:
    """Configuration for Dynamic Position Extrapolation."""

    enable_dype: bool = True
    base_resolution: int = 1024  # Native training resolution
    method: Literal["vision_yarn", "yarn", "ntk", "base"] = "vision_yarn"
    dype_scale: float = 2.0  # Magnitude λs (0.0-8.0)
    dype_exponent: float = 2.0  # Decay speed λt (0.0-1000.0)
    dype_start_sigma: float = 1.0  # When the DyPE decay starts


def get_mscale(scale: float, mscale_factor: float = 1.0) -> float:
    """Calculate the magnitude scaling factor.

    Args:
        scale: The resolution scaling factor
        mscale_factor: Adjustment factor for the scaling

    Returns:
        The magnitude scaling factor
    """
    if scale <= 1.0:
        return 1.0
    return mscale_factor * math.log(scale) + 1.0


def get_timestep_mscale(
    scale: float,
    current_sigma: float,
    dype_scale: float,
    dype_exponent: float,
    dype_start_sigma: float,
) -> float:
    """Calculate timestep-dependent magnitude scaling.

    The key insight of DyPE: early steps focus on low frequencies (global structure),
    late steps on high frequencies (details). This function modulates the scaling
    based on the current timestep/sigma.

    Args:
        scale: Resolution scaling factor
        current_sigma: Current noise level (1.0 = full noise, 0.0 = clean)
        dype_scale: DyPE magnitude (λs)
        dype_exponent: DyPE decay speed (λt)
        dype_start_sigma: Sigma threshold at which the decay starts

    Returns:
        Timestep-modulated scaling factor
    """
    if scale <= 1.0:
        return 1.0

    # Normalize sigma to the [0, 1] range relative to start_sigma
    if current_sigma >= dype_start_sigma:
        t_normalized = 1.0
    else:
        t_normalized = current_sigma / dype_start_sigma

    # Apply exponential decay: stronger extrapolation early, weaker late.
    # decay = exp(-λt * (1 - t)) where t=1 is early (high sigma), t=0 is late.
    decay = math.exp(-dype_exponent * (1.0 - t_normalized))

    # Base mscale from the resolution
    base_mscale = get_mscale(scale)

    # Interpolate between the full DyPE-scaled magnitude and 1.0 based on decay and dype_scale.
    # When decay=1 (early): full scaled value, 1 + (base_mscale - 1) * dype_scale
    # When decay=0 (late): falls back to 1.0 (no extra magnitude scaling)
    scaled_mscale = 1.0 + (base_mscale - 1.0) * dype_scale * decay

    return scaled_mscale


def compute_vision_yarn_freqs(
    pos: Tensor,
    dim: int,
    theta: int,
    scale_h: float,
    scale_w: float,
    current_sigma: float,
    dype_config: DyPEConfig,
) -> tuple[Tensor, Tensor]:
    """Compute RoPE frequencies using NTK-aware scaling for high resolution.

    This method extends FLUX's position encoding to handle resolutions beyond
    the 1024px training resolution by scaling the base frequency (theta).

    The NTK-aware approach smoothly interpolates frequencies to cover larger
    position ranges without breaking the attention patterns.

    DyPE (Dynamic Position Extrapolation) modulates the NTK scaling based on
    the current timestep - stronger extrapolation in early steps (global structure),
    weaker in late steps (fine details).

    Args:
        pos: Position tensor
        dim: Embedding dimension
        theta: RoPE base frequency
        scale_h: Height scaling factor
        scale_w: Width scaling factor
        current_sigma: Current noise level (1.0 = full noise, 0.0 = clean)
        dype_config: DyPE configuration

    Returns:
        Tuple of (cos, sin) frequency tensors
    """
    assert dim % 2 == 0

    # Use the larger scale for the NTK calculation
    scale = max(scale_h, scale_w)

    device = pos.device
    dtype = torch.float64 if device.type != "mps" else torch.float32

    # NTK-aware theta scaling: extends position coverage for high-res.
    # Formula: theta_scaled = theta * scale^(dim/(dim-2))
    # This increases the wavelength of the position encodings proportionally.
    if scale > 1.0:
        ntk_alpha = scale ** (dim / (dim - 2))

        # Apply timestep-dependent DyPE modulation.
        # mscale controls how strongly we apply the NTK extrapolation:
        # early steps (high sigma) get stronger extrapolation for global structure,
        # late steps (low sigma) get weaker extrapolation for fine details.
        mscale = get_timestep_mscale(
            scale=scale,
            current_sigma=current_sigma,
            dype_scale=dype_config.dype_scale,
            dype_exponent=dype_config.dype_exponent,
            dype_start_sigma=dype_config.dype_start_sigma,
        )

        # Modulate the NTK alpha by mscale:
        # when mscale > 1, interpolate towards stronger extrapolation;
        # when mscale = 1, use the base NTK alpha.
        modulated_alpha = 1.0 + (ntk_alpha - 1.0) * mscale
        scaled_theta = theta * modulated_alpha
    else:
        scaled_theta = theta

    # Standard RoPE frequency computation
    freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
    freqs = 1.0 / (scaled_theta**freq_seq)

    # Compute angles = position * frequency
    angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs)

    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos.to(pos.dtype), sin.to(pos.dtype)


def compute_yarn_freqs(
    pos: Tensor,
    dim: int,
    theta: int,
    scale: float,
    current_sigma: float,
    dype_config: DyPEConfig,
) -> tuple[Tensor, Tensor]:
    """Compute RoPE frequencies using the YARN/NTK method.

    Uses NTK-aware theta scaling for high-resolution support with
    timestep-dependent DyPE modulation.

    Args:
        pos: Position tensor
        dim: Embedding dimension
        theta: RoPE base frequency
        scale: Uniform scaling factor
        current_sigma: Current noise level (1.0 = full noise, 0.0 = clean)
        dype_config: DyPE configuration

    Returns:
        Tuple of (cos, sin) frequency tensors
    """
    assert dim % 2 == 0

    device = pos.device
    dtype = torch.float64 if device.type != "mps" else torch.float32

    # NTK-aware theta scaling with DyPE modulation
    if scale > 1.0:
        ntk_alpha = scale ** (dim / (dim - 2))

        # Apply timestep-dependent DyPE modulation
        mscale = get_timestep_mscale(
            scale=scale,
            current_sigma=current_sigma,
            dype_scale=dype_config.dype_scale,
            dype_exponent=dype_config.dype_exponent,
            dype_start_sigma=dype_config.dype_start_sigma,
        )

        # Modulate the NTK alpha by mscale
        modulated_alpha = 1.0 + (ntk_alpha - 1.0) * mscale
        scaled_theta = theta * modulated_alpha
    else:
        scaled_theta = theta

    freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
    freqs = 1.0 / (scaled_theta**freq_seq)

    angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs)

    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos.to(pos.dtype), sin.to(pos.dtype)


def compute_ntk_freqs(
    pos: Tensor,
    dim: int,
    theta: int,
    scale: float,
) -> tuple[Tensor, Tensor]:
    """Compute RoPE frequencies using the NTK method.

    Neural Tangent Kernel approach - continuous frequency scaling without
    timestep dependency.

    Args:
        pos: Position tensor
        dim: Embedding dimension
        theta: RoPE base frequency
        scale: Scaling factor

    Returns:
        Tuple of (cos, sin) frequency tensors
    """
    assert dim % 2 == 0

    device = pos.device
    dtype = torch.float64 if device.type != "mps" else torch.float32

    # NTK scaling
    scaled_theta = theta * (scale ** (dim / (dim - 2)))

    freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
    freqs = 1.0 / (scaled_theta**freq_seq)

    angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs)

    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos.to(pos.dtype), sin.to(pos.dtype)
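To make the timestep decay concrete, here is a minimal sketch (not part of the diff) that evaluates the helpers above for a 2x upscale. It assumes this branch's `invokeai.backend.flux.dype.base` module is importable; the numbers in the comments follow directly from the formulas.

```python
# Sketch: how the DyPE magnitude decays over the denoising trajectory for a
# 2x upscale (2048px on a 1024px base), using the defaults λs=2.0, λt=2.0.
from invokeai.backend.flux.dype.base import get_mscale, get_timestep_mscale

scale = 2.0  # 2048 / 1024
print(f"base mscale: {get_mscale(scale):.3f}")  # log(2) + 1 ≈ 1.693

for sigma in (1.0, 0.75, 0.5, 0.25, 0.0):
    m = get_timestep_mscale(
        scale=scale,
        current_sigma=sigma,
        dype_scale=2.0,
        dype_exponent=2.0,
        dype_start_sigma=1.0,
    )
    print(f"sigma={sigma:.2f} -> mscale={m:.3f}")

# sigma=1.0 (early): 1 + (1.693 - 1) * 2.0 * exp(0)  ≈ 2.386 (strong extrapolation)
# sigma=0.0 (late):  1 + (1.693 - 1) * 2.0 * exp(-2) ≈ 1.188 (near-base frequencies)
```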
116  invokeai/backend/flux/dype/embed.py  Normal file
@@ -0,0 +1,116 @@
"""DyPE-enhanced position embedding module."""

import torch
from torch import Tensor, nn

from invokeai.backend.flux.dype.base import DyPEConfig
from invokeai.backend.flux.dype.rope import rope_dype


class DyPEEmbedND(nn.Module):
    """N-dimensional position embedding with DyPE support.

    This class replaces the standard EmbedND from FLUX with a DyPE-aware version
    that dynamically scales position embeddings based on resolution and timestep.

    The key differences from EmbedND:
    - Maintains step state (current_sigma, target dimensions)
    - Uses rope_dype() instead of rope() for frequency computation
    - Applies timestep-dependent scaling for better high-resolution generation
    """

    def __init__(
        self,
        dim: int,
        theta: int,
        axes_dim: list[int],
        dype_config: DyPEConfig,
    ):
        """Initialize the DyPE position embedder.

        Args:
            dim: Total embedding dimension (sum of axes_dim)
            theta: RoPE base frequency
            axes_dim: Dimension allocation per axis (e.g., [16, 56, 56] for FLUX)
            dype_config: DyPE configuration
        """
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim
        self.dype_config = dype_config

        # Step state - updated before each denoising step
        self._current_sigma: float = 1.0
        self._target_height: int = 1024
        self._target_width: int = 1024

    def set_step_state(self, sigma: float, height: int, width: int) -> None:
        """Update the step state before each denoising step.

        This method should be called by the DyPE extension before each step
        to update the current noise level and target dimensions.

        Args:
            sigma: Current noise level (timestep value, 1.0 = full noise)
            height: Target image height in pixels
            width: Target image width in pixels
        """
        self._current_sigma = sigma
        self._target_height = height
        self._target_width = width

    def forward(self, ids: Tensor) -> Tensor:
        """Compute position embeddings with DyPE scaling.

        Args:
            ids: Position indices tensor with shape (batch, seq_len, n_axes).
                For FLUX: n_axes=3 (time/channel, height, width)

        Returns:
            Position embedding tensor with shape (batch, 1, seq_len, dim)
        """
        n_axes = ids.shape[-1]

        # Compute RoPE for each axis with DyPE scaling
        embeddings = []
        for i in range(n_axes):
            axis_emb = rope_dype(
                pos=ids[..., i],
                dim=self.axes_dim[i],
                theta=self.theta,
                current_sigma=self._current_sigma,
                target_height=self._target_height,
                target_width=self._target_width,
                dype_config=self.dype_config,
            )
            embeddings.append(axis_emb)

        # Concatenate embeddings from all axes
        emb = torch.cat(embeddings, dim=-3)

        return emb.unsqueeze(1)

    @classmethod
    def from_embednd(
        cls,
        embed_nd: nn.Module,
        dype_config: DyPEConfig,
    ) -> "DyPEEmbedND":
        """Create a DyPEEmbedND from an existing EmbedND.

        This is a convenience method for patching an existing FLUX model.

        Args:
            embed_nd: Original EmbedND module from FLUX
            dype_config: DyPE configuration

        Returns:
            New DyPEEmbedND with the same parameters
        """
        return cls(
            dim=embed_nd.dim,
            theta=embed_nd.theta,
            axes_dim=embed_nd.axes_dim,
            dype_config=dype_config,
        )
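A hypothetical sketch of the patching flow this class enables; `model` stands in for a loaded FLUX transformer whose `pe_embedder` exposes `dim`/`theta`/`axes_dim`, as `from_embednd()` expects:

```python
from invokeai.backend.flux.dype.base import DyPEConfig
from invokeai.backend.flux.dype.embed import DyPEEmbedND

config = DyPEConfig(enable_dype=True, base_resolution=1024)

# Swap the standard embedder for the DyPE-aware one (same dim/theta/axes_dim).
dype_embedder = DyPEEmbedND.from_embednd(model.pe_embedder, dype_config=config)
model.pe_embedder = dype_embedder

# Before each denoising step, refresh the step state so rope_dype() sees the
# current noise level and the true target resolution:
dype_embedder.set_step_state(sigma=0.8, height=2048, width=2048)
```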
203  invokeai/backend/flux/dype/presets.py  Normal file
@@ -0,0 +1,203 @@
"""DyPE presets and automatic configuration."""

import math
from dataclasses import dataclass
from typing import Literal

from invokeai.backend.flux.dype.base import DyPEConfig

# DyPE preset type - using Literal for proper frontend dropdown support
DyPEPreset = Literal["off", "manual", "auto", "area", "4k"]

# Constants for preset values
DYPE_PRESET_OFF: DyPEPreset = "off"
DYPE_PRESET_MANUAL: DyPEPreset = "manual"
DYPE_PRESET_AUTO: DyPEPreset = "auto"
DYPE_PRESET_AREA: DyPEPreset = "area"
DYPE_PRESET_4K: DyPEPreset = "4k"

# Human-readable labels for the UI
DYPE_PRESET_LABELS: dict[str, str] = {
    "off": "Off",
    "manual": "Manual",
    "auto": "Auto (>1536px)",
    "area": "Area (auto)",
    "4k": "4K Optimized",
}


@dataclass
class DyPEPresetConfig:
    """Preset configuration values."""

    base_resolution: int
    method: str
    dype_scale: float
    dype_exponent: float
    dype_start_sigma: float


# Predefined preset configurations
DYPE_PRESETS: dict[DyPEPreset, DyPEPresetConfig] = {
    DYPE_PRESET_4K: DyPEPresetConfig(
        base_resolution=1024,
        method="vision_yarn",
        dype_scale=2.0,
        dype_exponent=2.0,
        dype_start_sigma=1.0,
    ),
}


def get_dype_config_for_resolution(
    width: int,
    height: int,
    base_resolution: int = 1024,
    activation_threshold: int = 1536,
) -> DyPEConfig | None:
    """Automatically determine the DyPE config based on the target resolution.

    FLUX can handle resolutions up to ~1.5x natively without significant artifacts.
    DyPE is only activated when the resolution exceeds the activation threshold.

    Args:
        width: Target image width in pixels
        height: Target image height in pixels
        base_resolution: Native training resolution of the model (for scale calculation)
        activation_threshold: Resolution threshold above which DyPE is activated

    Returns:
        DyPEConfig if DyPE should be enabled, None otherwise
    """
    max_dim = max(width, height)

    if max_dim <= activation_threshold:
        return None  # FLUX can handle this natively

    # Calculate the scaling factor based on base_resolution
    scale = max_dim / base_resolution

    # Dynamic parameters based on the scaling:
    # higher resolution means a higher dype_scale, capped at 8.0.
    dynamic_dype_scale = min(2.0 * scale, 8.0)

    return DyPEConfig(
        enable_dype=True,
        base_resolution=base_resolution,
        method="vision_yarn",
        dype_scale=dynamic_dype_scale,
        dype_exponent=2.0,
        dype_start_sigma=1.0,
    )


def get_dype_config_for_area(
    width: int,
    height: int,
    base_resolution: int = 1024,
) -> DyPEConfig | None:
    """Automatically determine the DyPE config based on the target area.

    Uses sqrt(area/base_area) as an effective side-length ratio.
    DyPE is enabled only when the target area exceeds the base area.

    Returns:
        DyPEConfig if DyPE should be enabled, None otherwise
    """
    area = width * height
    base_area = base_resolution**2

    if area <= base_area:
        return None

    area_ratio = area / base_area
    effective_side_ratio = math.sqrt(area_ratio)  # 1.0 at base, 2.0 at 2048x2048 (with a 1024 base)

    # Strength: 0 at the base area, 8 at sat_area, clamped thereafter.
    sat_area = 2027520  # Experimentally determined area at which a vertical-line artifact appears
    sat_side_ratio = math.sqrt(sat_area / base_area)
    dynamic_dype_scale = 8.0 * (effective_side_ratio - 1.0) / (sat_side_ratio - 1.0)
    dynamic_dype_scale = max(0.0, min(dynamic_dype_scale, 8.0))

    # Continuous exponent schedule:
    # r=1 -> 0.5, r=2 -> 1.0, r=4 -> 2.0 (exact), smoothly varying in between.
    x = math.log2(effective_side_ratio)
    dype_exponent = 0.25 * (x**2) + 0.25 * x + 0.5
    dype_exponent = max(0.5, min(dype_exponent, 2.0))

    return DyPEConfig(
        enable_dype=True,
        base_resolution=base_resolution,
        method="vision_yarn",
        dype_scale=dynamic_dype_scale,
        dype_exponent=dype_exponent,
        dype_start_sigma=1.0,
    )


def get_dype_config_from_preset(
    preset: DyPEPreset,
    width: int,
    height: int,
    custom_scale: float | None = None,
    custom_exponent: float | None = None,
) -> DyPEConfig | None:
    """Get a DyPE configuration from a preset or custom values.

    Args:
        preset: The DyPE preset to use
        width: Target image width
        height: Target image height
        custom_scale: Optional custom dype_scale (only used with the 'manual' preset)
        custom_exponent: Optional custom dype_exponent (only used with the 'manual' preset)

    Returns:
        DyPEConfig if DyPE should be enabled, None otherwise
    """
    if preset == DYPE_PRESET_OFF:
        return None

    if preset == DYPE_PRESET_MANUAL:
        # Manual mode - custom values can override the defaults
        max_dim = max(width, height)
        scale = max_dim / 1024
        dynamic_dype_scale = min(2.0 * scale, 8.0)
        return DyPEConfig(
            enable_dype=True,
            base_resolution=1024,
            method="vision_yarn",
            dype_scale=custom_scale if custom_scale is not None else dynamic_dype_scale,
            dype_exponent=custom_exponent if custom_exponent is not None else 2.0,
            dype_start_sigma=1.0,
        )

    if preset == DYPE_PRESET_AUTO:
        # Auto preset - custom values are ignored
        return get_dype_config_for_resolution(
            width=width,
            height=height,
            base_resolution=1024,
            activation_threshold=1536,
        )

    if preset == DYPE_PRESET_AREA:
        # Area-based preset - custom values are ignored
        return get_dype_config_for_area(
            width=width,
            height=height,
            base_resolution=1024,
        )

    # Use a preset configuration (4K etc.) - custom values are ignored
    preset_config = DYPE_PRESETS.get(preset)
    if preset_config is None:
        return None

    return DyPEConfig(
        enable_dype=True,
        base_resolution=preset_config.base_resolution,
        method=preset_config.method,
        dype_scale=preset_config.dype_scale,
        dype_exponent=preset_config.dype_exponent,
        dype_start_sigma=preset_config.dype_start_sigma,
    )
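A minimal sketch (not part of the diff) of the preset dispatch above; the values in the comments are derived from `get_dype_config_for_resolution` (dype_scale = min(2.0 * scale, 8.0)):

```python
from invokeai.backend.flux.dype.presets import (
    DYPE_PRESET_AUTO,
    DYPE_PRESET_OFF,
    get_dype_config_from_preset,
)

# Below the 1536px activation threshold, "auto" returns None (DyPE stays off).
assert get_dype_config_from_preset(DYPE_PRESET_AUTO, width=1024, height=1024) is None

# At 4096px the scale is 4096/1024 = 4.0, so dype_scale = min(2.0 * 4.0, 8.0) = 8.0.
cfg = get_dype_config_from_preset(DYPE_PRESET_AUTO, width=4096, height=4096)
assert cfg is not None and cfg.dype_scale == 8.0

# "off" always disables DyPE, regardless of resolution.
assert get_dype_config_from_preset(DYPE_PRESET_OFF, width=4096, height=4096) is None
```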
110  invokeai/backend/flux/dype/rope.py  Normal file
@@ -0,0 +1,110 @@
"""DyPE-enhanced RoPE (Rotary Position Embedding) functions."""

import torch
from einops import rearrange
from torch import Tensor

from invokeai.backend.flux.dype.base import (
    DyPEConfig,
    compute_ntk_freqs,
    compute_vision_yarn_freqs,
    compute_yarn_freqs,
)


def rope_dype(
    pos: Tensor,
    dim: int,
    theta: int,
    current_sigma: float,
    target_height: int,
    target_width: int,
    dype_config: DyPEConfig,
) -> Tensor:
    """Compute RoPE with Dynamic Position Extrapolation.

    This is the core DyPE function that replaces the standard rope() function.
    It applies resolution-aware and timestep-aware scaling to position embeddings.

    Args:
        pos: Position indices tensor
        dim: Embedding dimension per axis
        theta: RoPE base frequency (typically 10000)
        current_sigma: Current noise level (1.0 = full noise, 0.0 = clean)
        target_height: Target image height in pixels
        target_width: Target image width in pixels
        dype_config: DyPE configuration

    Returns:
        Rotary position embedding tensor with a shape suitable for FLUX attention
    """
    assert dim % 2 == 0

    # Calculate scaling factors
    base_res = dype_config.base_resolution
    scale_h = target_height / base_res
    scale_w = target_width / base_res
    scale = max(scale_h, scale_w)

    # If DyPE is disabled or no upscaling is needed, use the base method
    if not dype_config.enable_dype or scale <= 1.0:
        return _rope_base(pos, dim, theta)

    # Select the method and compute frequencies
    method = dype_config.method

    if method == "vision_yarn":
        cos, sin = compute_vision_yarn_freqs(
            pos=pos,
            dim=dim,
            theta=theta,
            scale_h=scale_h,
            scale_w=scale_w,
            current_sigma=current_sigma,
            dype_config=dype_config,
        )
    elif method == "yarn":
        cos, sin = compute_yarn_freqs(
            pos=pos,
            dim=dim,
            theta=theta,
            scale=scale,
            current_sigma=current_sigma,
            dype_config=dype_config,
        )
    elif method == "ntk":
        cos, sin = compute_ntk_freqs(
            pos=pos,
            dim=dim,
            theta=theta,
            scale=scale,
        )
    else:  # "base"
        return _rope_base(pos, dim, theta)

    # Construct the rotation matrix from cos/sin.
    # Output shape: (batch, seq_len, dim/2, 2, 2)
    out = torch.stack([cos, -sin, sin, cos], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)

    return out.to(dtype=pos.dtype, device=pos.device)


def _rope_base(pos: Tensor, dim: int, theta: int) -> Tensor:
    """Standard RoPE without DyPE scaling.

    This matches the original rope() function from invokeai.backend.flux.math.
    """
    assert dim % 2 == 0

    device = pos.device
    dtype = torch.float64 if device.type != "mps" else torch.float32

    scale = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
    omega = 1.0 / (theta**scale)

    out = torch.einsum("...n,d->...nd", pos.to(dtype), omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)

    return out.to(dtype=pos.dtype, device=pos.device)
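A minimal sketch (not part of the diff) exercising `rope_dype()` at the base resolution, where it falls back to the standard RoPE path; the shapes follow the rearrange above, (batch, seq_len) positions becoming (batch, seq_len, dim/2, 2, 2) rotations:

```python
import torch

from invokeai.backend.flux.dype.base import DyPEConfig
from invokeai.backend.flux.dype.rope import rope_dype

pos = torch.arange(64, dtype=torch.float32).reshape(1, 64)  # (batch=1, seq_len=64)
out = rope_dype(
    pos=pos,
    dim=56,  # one FLUX spatial axis from axes_dim=[16, 56, 56]
    theta=10_000,
    current_sigma=1.0,
    target_height=1024,
    target_width=1024,  # scale == 1.0, so the _rope_base() path is taken
    dype_config=DyPEConfig(),
)
assert out.shape == (1, 64, 28, 2, 2)  # dim/2 = 28 rotation blocks per position
```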
91  invokeai/backend/flux/extensions/dype_extension.py  Normal file
@@ -0,0 +1,91 @@
"""DyPE extension for the FLUX denoising pipeline."""

from dataclasses import dataclass
from typing import TYPE_CHECKING

from invokeai.backend.flux.dype.base import DyPEConfig
from invokeai.backend.flux.dype.embed import DyPEEmbedND

if TYPE_CHECKING:
    from invokeai.backend.flux.model import Flux


@dataclass
class DyPEExtension:
    """Extension for Dynamic Position Extrapolation in FLUX models.

    This extension manages the patching of the FLUX model's position embedder
    and updates the step state during denoising.

    Usage:
        1. Create the extension with a config and target dimensions
        2. Call patch_model() to replace pe_embedder with the DyPE version
        3. Call update_step_state() before each denoising step
        4. Call restore_model() after denoising to restore the original embedder
    """

    config: DyPEConfig
    target_height: int
    target_width: int

    def patch_model(self, model: "Flux") -> tuple[DyPEEmbedND, object]:
        """Patch the model's position embedder with the DyPE version.

        Args:
            model: The FLUX model to patch

        Returns:
            Tuple of (new DyPE embedder, original embedder for restoration)
        """
        original_embedder = model.pe_embedder

        dype_embedder = DyPEEmbedND.from_embednd(
            embed_nd=original_embedder,
            dype_config=self.config,
        )

        # Set the initial state
        dype_embedder.set_step_state(
            sigma=1.0,
            height=self.target_height,
            width=self.target_width,
        )

        # Replace the embedder
        model.pe_embedder = dype_embedder

        return dype_embedder, original_embedder

    def update_step_state(
        self,
        embedder: DyPEEmbedND,
        timestep: float,
        timestep_index: int,
        total_steps: int,
    ) -> None:
        """Update the step state in the DyPE embedder.

        This should be called before each denoising step to update the
        current noise level for timestep-dependent scaling.

        Args:
            embedder: The DyPE embedder to update
            timestep: Current timestep value (sigma/noise level)
            timestep_index: Current step index (0-based)
            total_steps: Total number of denoising steps
        """
        embedder.set_step_state(
            sigma=timestep,
            height=self.target_height,
            width=self.target_width,
        )

    @staticmethod
    def restore_model(model: "Flux", original_embedder: object) -> None:
        """Restore the original position embedder.

        Args:
            model: The FLUX model to restore
            original_embedder: The original embedder saved from patch_model()
        """
        model.pe_embedder = original_embedder
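A sketch of the four-step lifecycle from the docstring above, mirroring the guarded restore seen at the top of this section; `model` is assumed to be a loaded FLUX transformer and `timesteps` a descending sigma list:

```python
from invokeai.backend.flux.dype.base import DyPEConfig
from invokeai.backend.flux.extensions.dype_extension import DyPEExtension

ext = DyPEExtension(config=DyPEConfig(), target_height=2048, target_width=2048)
dype_embedder, original_embedder = ext.patch_model(model)
try:
    for i, sigma in enumerate(timesteps[:-1]):
        # Refresh sigma/resolution state so the patched embedder scales correctly.
        ext.update_step_state(
            dype_embedder, timestep=sigma, timestep_index=i, total_steps=len(timesteps) - 1
        )
        # ... run one denoising step with the patched model ...
finally:
    # Always restore the original embedder, even if denoising fails.
    DyPEExtension.restore_model(model, original_embedder)
```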
62  invokeai/backend/flux/schedulers.py  Normal file
@@ -0,0 +1,62 @@
"""Flow Matching scheduler definitions and mapping.

This module provides the scheduler types and mapping for Flow Matching models
(Flux and Z-Image), supporting multiple schedulers from the diffusers library.
"""

from typing import Literal, Type

from diffusers import (
    FlowMatchEulerDiscreteScheduler,
    FlowMatchHeunDiscreteScheduler,
)
from diffusers.schedulers.scheduling_utils import SchedulerMixin

# Note: FlowMatchLCMScheduler may not be available in all diffusers versions
try:
    from diffusers import FlowMatchLCMScheduler

    _HAS_LCM = True
except ImportError:
    _HAS_LCM = False

# Scheduler name literal type for type checking
FLUX_SCHEDULER_NAME_VALUES = Literal["euler", "heun", "lcm"]

# Human-readable labels for the UI
FLUX_SCHEDULER_LABELS: dict[str, str] = {
    "euler": "Euler",
    "heun": "Heun (2nd order)",
    "lcm": "LCM",
}

# Mapping from scheduler names to scheduler classes
FLUX_SCHEDULER_MAP: dict[str, Type[SchedulerMixin]] = {
    "euler": FlowMatchEulerDiscreteScheduler,
    "heun": FlowMatchHeunDiscreteScheduler,
}

if _HAS_LCM:
    FLUX_SCHEDULER_MAP["lcm"] = FlowMatchLCMScheduler


# Z-Image scheduler types (same schedulers as Flux; both use Flow Matching)
# Note: Z-Image-Turbo is optimized for ~8 steps with Euler, but other schedulers
# can be used for experimentation.
ZIMAGE_SCHEDULER_NAME_VALUES = Literal["euler", "heun", "lcm"]

# Human-readable labels for the UI
ZIMAGE_SCHEDULER_LABELS: dict[str, str] = {
    "euler": "Euler",
    "heun": "Heun (2nd order)",
    "lcm": "LCM",
}

# Mapping from scheduler names to scheduler classes (same as Flux)
ZIMAGE_SCHEDULER_MAP: dict[str, Type[SchedulerMixin]] = {
    "euler": FlowMatchEulerDiscreteScheduler,
    "heun": FlowMatchHeunDiscreteScheduler,
}

if _HAS_LCM:
    ZIMAGE_SCHEDULER_MAP["lcm"] = FlowMatchLCMScheduler
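A short sketch (not part of the diff) of how the mapping above can be used to resolve a scheduler by name, including the LCM fallback:

```python
from invokeai.backend.flux.schedulers import FLUX_SCHEDULER_MAP

scheduler_cls = FLUX_SCHEDULER_MAP["euler"]  # FlowMatchEulerDiscreteScheduler
scheduler = scheduler_cls()

# "lcm" is only present when the installed diffusers provides FlowMatchLCMScheduler.
lcm_cls = FLUX_SCHEDULER_MAP.get("lcm")
if lcm_cls is None:
    print("FlowMatchLCMScheduler is not available in this diffusers version")
```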
@@ -5,7 +5,7 @@ from typing import Literal
 from invokeai.backend.flux.model import FluxParams
 from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams
-from invokeai.backend.model_manager.taxonomy import AnyVariant, FluxVariantType
+from invokeai.backend.model_manager.taxonomy import AnyVariant, Flux2VariantType, FluxVariantType


 @dataclass
@@ -46,6 +46,8 @@ _flux_max_seq_lengths: dict[AnyVariant, Literal[256, 512]] = {
     FluxVariantType.Dev: 512,
     FluxVariantType.DevFill: 512,
     FluxVariantType.Schnell: 256,
+    Flux2VariantType.Klein4B: 512,
+    Flux2VariantType.Klein9B: 512,
 }
@@ -117,6 +119,38 @@ _flux_transformer_params: dict[AnyVariant, FluxParams] = {
         qkv_bias=True,
         guidance_embed=True,
     ),
+    # Flux2 Klein 4B uses the Qwen3 4B text encoder with stacked embeddings from layers [9, 18, 27].
+    # The context_in_dim is 3 * hidden_size of Qwen3 (3 * 2560 = 7680).
+    Flux2VariantType.Klein4B: FluxParams(
+        in_channels=64,
+        vec_in_dim=2560,  # Qwen3-4B hidden size (used for pooled output)
+        context_in_dim=7680,  # 3 layers * 2560 = 7680 for Qwen3-4B
+        hidden_size=3072,
+        mlp_ratio=4.0,
+        num_heads=24,
+        depth=19,
+        depth_single_blocks=38,
+        axes_dim=[16, 56, 56],
+        theta=10_000,
+        qkv_bias=True,
+        guidance_embed=True,
+    ),
+    # Flux2 Klein 9B uses the Qwen3 8B text encoder with stacked embeddings from layers [9, 18, 27].
+    # The context_in_dim is 3 * hidden_size of Qwen3 (3 * 4096 = 12288).
+    Flux2VariantType.Klein9B: FluxParams(
+        in_channels=64,
+        vec_in_dim=4096,  # Qwen3-8B hidden size (used for pooled output)
+        context_in_dim=12288,  # 3 layers * 4096 = 12288 for Qwen3-8B
+        hidden_size=3072,
+        mlp_ratio=4.0,
+        num_heads=24,
+        depth=19,
+        depth_single_blocks=38,
+        axes_dim=[16, 56, 56],
+        theta=10_000,
+        qkv_bias=True,
+        guidance_embed=True,
+    ),
 }
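The `context_in_dim` arithmetic in the comments above can be checked directly; a trivial sketch:

```python
# Three Qwen3 hidden-state layers ([9, 18, 27]) are stacked, so the text
# conditioning dimension is 3 * hidden_size.
QWEN3_4B_HIDDEN = 2560
QWEN3_8B_HIDDEN = 4096

assert 3 * QWEN3_4B_HIDDEN == 7680   # Klein 4B context_in_dim
assert 3 * QWEN3_8B_HIDDEN == 12288  # Klein 9B context_in_dim
```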
4  invokeai/backend/flux2/__init__.py  Normal file
@@ -0,0 +1,4 @@
"""FLUX.2 backend modules.

This package contains modules specific to FLUX.2 models (e.g., Klein).
"""
288  invokeai/backend/flux2/denoise.py  Normal file
@@ -0,0 +1,288 @@
"""FLUX.2 Klein denoising function.

This module provides the denoising function for FLUX.2 Klein models,
which use Qwen3 as the text encoder instead of CLIP+T5.
"""

import inspect
import math
from typing import Any, Callable

import numpy as np
import torch
from tqdm import tqdm

from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState


def denoise(
    model: torch.nn.Module,
    # model input
    img: torch.Tensor,
    img_ids: torch.Tensor,
    txt: torch.Tensor,
    txt_ids: torch.Tensor,
    # sampling parameters
    timesteps: list[float],
    step_callback: Callable[[PipelineIntermediateState], None],
    cfg_scale: list[float],
    # Negative conditioning for CFG
    neg_txt: torch.Tensor | None = None,
    neg_txt_ids: torch.Tensor | None = None,
    # Scheduler for stepping (e.g., FlowMatchEulerDiscreteScheduler, FlowMatchHeunDiscreteScheduler)
    scheduler: Any = None,
    # Dynamic shifting parameter for FLUX.2 Klein (computed from the image resolution)
    mu: float | None = None,
    # Inpainting extension for merging latents during denoising
    inpaint_extension: RectifiedFlowInpaintExtension | None = None,
    # Reference image conditioning (multi-reference image editing)
    img_cond_seq: torch.Tensor | None = None,
    img_cond_seq_ids: torch.Tensor | None = None,
) -> torch.Tensor:
    """Denoise latents using a FLUX.2 Klein transformer model.

    This is a simplified denoise function for FLUX.2 Klein models that uses
    the diffusers Flux2Transformer2DModel interface.

    Note: FLUX.2 Klein has guidance_embeds=False, so no guidance parameter is used.
    CFG is applied externally using negative conditioning when cfg_scale != 1.0.

    Args:
        model: The Flux2Transformer2DModel from diffusers.
        img: Packed latent image tensor of shape (B, seq_len, channels).
        img_ids: Image position IDs tensor.
        txt: Text encoder hidden states (Qwen3 embeddings).
        txt_ids: Text position IDs tensor.
        timesteps: List of timesteps for the denoising schedule (linear sigmas from 1.0 to 1/n).
        step_callback: Callback function for progress updates.
        cfg_scale: List of CFG scale values per step.
        neg_txt: Negative text embeddings for CFG (optional).
        neg_txt_ids: Negative text position IDs (optional).
        scheduler: Optional diffusers scheduler (Euler, Heun, LCM). If None, uses manual Euler.
        mu: Dynamic shifting parameter computed from the image resolution. Required when the
            scheduler has use_dynamic_shifting=True.

    Returns:
        Denoised latent tensor.
    """
    total_steps = len(timesteps) - 1

    # Store the original sequence length for extracting the output later (before concatenating reference images)
    original_seq_len = img.shape[1]

    # Concatenate reference image conditioning if provided (multi-reference image editing)
    if img_cond_seq is not None and img_cond_seq_ids is not None:
        img = torch.cat([img, img_cond_seq], dim=1)
        img_ids = torch.cat([img_ids, img_cond_seq_ids], dim=1)

    # Klein has guidance_embeds=False, but the transformer forward() still requires a guidance tensor.
    # We pass a dummy value (1.0) since it won't affect the output when guidance_embeds=False.
    guidance = torch.full((img.shape[0],), 1.0, device=img.device, dtype=img.dtype)

    # Use the scheduler if provided
    use_scheduler = scheduler is not None
    if use_scheduler:
        # Set up the scheduler with sigmas and mu for dynamic shifting.
        # Convert timesteps (0-1 range) to sigmas for the scheduler; the scheduler
        # applies dynamic shifting internally using mu (if enabled in its config).
        sigmas = np.array(timesteps[:-1], dtype=np.float32)  # Exclude the final 0.0

        # Check whether the scheduler supports a sigmas parameter using inspect.signature.
        # FlowMatchHeunDiscreteScheduler and FlowMatchLCMScheduler don't support sigmas.
        set_timesteps_sig = inspect.signature(scheduler.set_timesteps)
        supports_sigmas = "sigmas" in set_timesteps_sig.parameters
        if supports_sigmas and mu is not None:
            # Pass mu if provided - it is only used if the scheduler has use_dynamic_shifting=True
            scheduler.set_timesteps(sigmas=sigmas.tolist(), mu=mu, device=img.device)
        elif supports_sigmas:
            scheduler.set_timesteps(sigmas=sigmas.tolist(), device=img.device)
        else:
            # The scheduler doesn't support sigmas (e.g., Heun, LCM) - use num_inference_steps
            scheduler.set_timesteps(num_inference_steps=len(sigmas), device=img.device)
        num_scheduler_steps = len(scheduler.timesteps)
        is_heun = hasattr(scheduler, "state_in_first_order")
        user_step = 0

        pbar = tqdm(total=total_steps, desc="Denoising")
        for step_index in range(num_scheduler_steps):
            timestep = scheduler.timesteps[step_index]
            # Convert the scheduler timestep (0-1000) to a normalized sigma (0-1) for the model
            t_curr = timestep.item() / scheduler.config.num_train_timesteps
            t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)

            # Track whether we're in a first- or second-order step (for Heun)
            in_first_order = scheduler.state_in_first_order if is_heun else True

            # Run the transformer model (matching diffusers: guidance=guidance, return_dict=False)
            output = model(
                hidden_states=img,
                encoder_hidden_states=txt,
                timestep=t_vec,
                img_ids=img_ids,
                txt_ids=txt_ids,
                guidance=guidance,
                return_dict=False,
            )

            # Extract the sample from the output (return_dict=False returns a tuple)
            pred = output[0] if isinstance(output, tuple) else output

            step_cfg_scale = cfg_scale[min(user_step, len(cfg_scale) - 1)]

            # Apply CFG if the scale is not 1.0
            if not math.isclose(step_cfg_scale, 1.0):
                if neg_txt is None:
                    raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")

                neg_output = model(
                    hidden_states=img,
                    encoder_hidden_states=neg_txt,
                    timestep=t_vec,
                    img_ids=img_ids,
                    txt_ids=neg_txt_ids if neg_txt_ids is not None else txt_ids,
                    guidance=guidance,
                    return_dict=False,
                )

                neg_pred = neg_output[0] if isinstance(neg_output, tuple) else neg_output
                pred = neg_pred + step_cfg_scale * (pred - neg_pred)

            # Use scheduler.step() for the update
            step_output = scheduler.step(model_output=pred, timestep=timestep, sample=img)
            img = step_output.prev_sample

            # Get t_prev for inpainting (the next sigma value)
            if step_index + 1 < len(scheduler.sigmas):
                t_prev = scheduler.sigmas[step_index + 1].item()
            else:
                t_prev = 0.0

            # Apply the inpainting merge at each step
            if inpaint_extension is not None:
                # Separate the generated latents from the reference conditioning
                gen_img = img[:, :original_seq_len, :]
                ref_img = img[:, original_seq_len:, :]

                # Merge only the generated part
                gen_img = inpaint_extension.merge_intermediate_latents_with_init_latents(gen_img, t_prev)

                # Concatenate back together
                img = torch.cat([gen_img, ref_img], dim=1)

            # For Heun, only increment the user step after the second-order step completes
            if is_heun:
                if not in_first_order:
                    user_step += 1
                    if user_step <= total_steps:
                        pbar.update(1)
                        preview_img = img - t_curr * pred
                        if inpaint_extension is not None:
                            preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(
                                preview_img, 0.0
                            )
                        step_callback(
                            PipelineIntermediateState(
                                step=user_step,
                                order=2,
                                total_steps=total_steps,
                                timestep=int(t_curr * 1000),
                                latents=preview_img,
                            ),
                        )
            else:
                user_step += 1
                if user_step <= total_steps:
                    pbar.update(1)
                    preview_img = img - t_curr * pred
                    if inpaint_extension is not None:
                        preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0)
                    # Extract only the generated image portion for the preview (exclude reference images)
                    callback_latents = preview_img[:, :original_seq_len, :] if img_cond_seq is not None else preview_img
                    step_callback(
                        PipelineIntermediateState(
                            step=user_step,
                            order=1,
                            total_steps=total_steps,
                            timestep=int(t_curr * 1000),
                            latents=callback_latents,
                        ),
                    )

        pbar.close()
    else:
        # Manual Euler stepping (original behavior)
        for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
            t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)

            # Run the transformer model (matching diffusers: guidance=guidance, return_dict=False)
            output = model(
                hidden_states=img,
                encoder_hidden_states=txt,
                timestep=t_vec,
                img_ids=img_ids,
                txt_ids=txt_ids,
                guidance=guidance,
                return_dict=False,
            )

            # Extract the sample from the output (return_dict=False returns a tuple)
            pred = output[0] if isinstance(output, tuple) else output

            step_cfg_scale = cfg_scale[step_index]

            # Apply CFG if the scale is not 1.0
            if not math.isclose(step_cfg_scale, 1.0):
                if neg_txt is None:
                    raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")

                neg_output = model(
                    hidden_states=img,
                    encoder_hidden_states=neg_txt,
                    timestep=t_vec,
                    img_ids=img_ids,
                    txt_ids=neg_txt_ids if neg_txt_ids is not None else txt_ids,
                    guidance=guidance,
                    return_dict=False,
                )

                neg_pred = neg_output[0] if isinstance(neg_output, tuple) else neg_output
                pred = neg_pred + step_cfg_scale * (pred - neg_pred)

            # Euler step
            preview_img = img - t_curr * pred
            img = img + (t_prev - t_curr) * pred

            # Apply the inpainting merge at each step
            if inpaint_extension is not None:
                # Separate the generated latents from the reference conditioning
                gen_img = img[:, :original_seq_len, :]
                ref_img = img[:, original_seq_len:, :]

                # Merge only the generated part
                gen_img = inpaint_extension.merge_intermediate_latents_with_init_latents(gen_img, t_prev)

                # Concatenate back together
                img = torch.cat([gen_img, ref_img], dim=1)

                # Handle the preview image as well
                preview_gen = preview_img[:, :original_seq_len, :]
                preview_gen = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_gen, 0.0)

            # Extract only the generated image portion for the preview (exclude reference images)
            callback_latents = preview_img[:, :original_seq_len, :] if img_cond_seq is not None else preview_img
            step_callback(
                PipelineIntermediateState(
                    step=step_index + 1,
                    order=1,
                    total_steps=total_steps,
                    timestep=int(t_curr * 1000),
                    latents=callback_latents,
                ),
            )

    # Extract only the generated image portion (exclude the concatenated reference images)
    if img_cond_seq is not None:
        img = img[:, :original_seq_len, :]

    return img
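A small sketch (not part of the diff) of the two per-step conversions used above, assuming a flow-matching scheduler with num_train_timesteps=1000: scheduler timesteps live on a 0-1000 scale while the transformer expects a normalized sigma in [0, 1], and CFG mixes the conditional and unconditional predictions with the same formula as the loops above:

```python
import torch

num_train_timesteps = 1000
timestep = torch.tensor(750.0)                   # as taken from scheduler.timesteps
t_curr = timestep.item() / num_train_timesteps   # 0.75, fed to the model as sigma

cfg_scale = 4.0
pred = torch.randn(1, 4096, 128)       # conditional prediction (dummy packed-latent shape)
neg_pred = torch.randn(1, 4096, 128)   # unconditional (negative-prompt) prediction
guided = neg_pred + cfg_scale * (pred - neg_pred)  # CFG mix, identical to the loop above
```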
294  invokeai/backend/flux2/ref_image_extension.py  Normal file
@@ -0,0 +1,294 @@
"""FLUX.2 Klein reference image extension for multi-reference image editing.

This module provides the Flux2RefImageExtension for FLUX.2 Klein models,
which handles encoding reference images using the FLUX.2 VAE and
generating the appropriate position IDs for multi-reference image editing.

FLUX.2 Klein has built-in support for reference image editing (unlike FLUX.1,
which requires a separate Kontext model).
"""

import math

import torch
import torch.nn.functional as F
import torchvision.transforms as T
from einops import repeat
from PIL import Image

from invokeai.app.invocations.fields import FluxKontextConditioningField
from invokeai.app.invocations.model import VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux2.sampling_utils import pack_flux2
from invokeai.backend.util.devices import TorchDevice

# Maximum pixel counts for reference images (matches BFL FLUX.2 sampling.py):
# single reference image: 2024^2 pixels; multiple: 1024^2 pixels.
MAX_PIXELS_SINGLE_REF = 2024**2  # ~4.1M pixels
MAX_PIXELS_MULTI_REF = 1024**2  # ~1M pixels


def resize_image_to_max_pixels(image: Image.Image, max_pixels: int) -> Image.Image:
    """Resize an image to fit within max_pixels while preserving the aspect ratio.

    This matches the BFL FLUX.2 sampling.py cap_pixels() behavior.

    Args:
        image: PIL Image to resize.
        max_pixels: Maximum total pixel count (width * height).

    Returns:
        Resized PIL Image (or the original if already within bounds).
    """
    width, height = image.size
    pixel_count = width * height

    if pixel_count <= max_pixels:
        return image

    # Calculate the scale factor to fit within max_pixels (BFL approach)
    scale = math.sqrt(max_pixels / pixel_count)
    new_width = int(width * scale)
    new_height = int(height * scale)

    # Ensure dimensions are at least 1
    new_width = max(1, new_width)
    new_height = max(1, new_height)

    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)


def generate_img_ids_flux2_with_offset(
    latent_height: int,
    latent_width: int,
    batch_size: int,
    device: torch.device,
    idx_offset: int = 0,
    h_offset: int = 0,
    w_offset: int = 0,
) -> torch.Tensor:
    """Generate a tensor of image position ids with optional offsets for FLUX.2.

    FLUX.2 uses 4D position coordinates (T, H, W, L) for its rotary position embeddings.
    Position IDs use the int64 (long) dtype.

    Args:
        latent_height: Height of the image in latent space (before packing).
        latent_width: Width of the image in latent space (before packing).
        batch_size: Number of images in the batch.
        device: Device to create tensors on.
        idx_offset: Offset for the T (time/index) coordinate - non-zero for reference
            images (the generated image uses T=0).
        h_offset: Spatial offset for the H coordinate in latent space.
        w_offset: Spatial offset for the W coordinate in latent space.

    Returns:
        Image position ids with shape [batch_size, (latent_height//2 * latent_width//2), 4].
    """
    # After packing, the spatial dimensions are halved due to the 2x2 patch structure
    packed_height = latent_height // 2
    packed_width = latent_width // 2

    # Convert spatial offsets from latent space to packed space
    packed_h_offset = h_offset // 2
    packed_w_offset = w_offset // 2

    # Create the base tensor for position IDs with shape [packed_height, packed_width, 4].
    # The 4 channels represent: [T, H, W, L].
    img_ids = torch.zeros(packed_height, packed_width, 4, device=device, dtype=torch.long)

    # Set T (time/index offset) for all positions - non-zero for reference images
    img_ids[..., 0] = idx_offset

    # Set H (height/y) coordinates with offset
    h_coords = torch.arange(packed_height, device=device, dtype=torch.long) + packed_h_offset
    img_ids[..., 1] = h_coords[:, None]

    # Set W (width/x) coordinates with offset
    w_coords = torch.arange(packed_width, device=device, dtype=torch.long) + packed_w_offset
    img_ids[..., 2] = w_coords[None, :]

    # The L (layer) coordinate stays 0

    # Expand to include the batch dimension: [batch_size, (packed_height * packed_width), 4]
    img_ids = img_ids.reshape(1, packed_height * packed_width, 4)
    img_ids = repeat(img_ids, "1 s c -> b s c", b=batch_size)

    return img_ids


class Flux2RefImageExtension:
    """Applies FLUX.2 Klein reference image conditioning.

    This extension handles encoding reference images using the FLUX.2 VAE
    and generating the appropriate 4D position IDs for multi-reference image editing.

    FLUX.2 Klein has built-in support for reference image editing, unlike FLUX.1,
    which requires a separate Kontext model.
    """

    def __init__(
        self,
        ref_image_conditioning: list[FluxKontextConditioningField],
        context: InvocationContext,
        vae_field: VAEField,
        device: torch.device,
        dtype: torch.dtype,
        bn_mean: torch.Tensor | None = None,
        bn_std: torch.Tensor | None = None,
    ):
        """Initialize the Flux2RefImageExtension.

        Args:
            ref_image_conditioning: List of reference image conditioning fields.
            context: The invocation context for loading models and images.
            vae_field: The FLUX.2 VAE field for encoding images.
            device: Target device for tensors.
            dtype: Target dtype for tensors.
            bn_mean: BN running mean for normalizing latents (shape: 128).
            bn_std: BN running std for normalizing latents (shape: 128).
        """
        self._context = context
        self._device = device
        self._dtype = dtype
        self._vae_field = vae_field
        self._bn_mean = bn_mean
        self._bn_std = bn_std
        self.ref_image_conditioning = ref_image_conditioning

        # Pre-process and cache the reference image latents and ids upon initialization
        self.ref_image_latents, self.ref_image_ids = self._prepare_ref_images()

    def _bn_normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Apply BN normalization to packed latents.

        BN formula (affine=False): y = (x - mean) / std

        Args:
            x: Packed latents of shape (B, seq, 128).

        Returns:
            Normalized latents of the same shape.
        """
        assert self._bn_mean is not None and self._bn_std is not None
        bn_mean = self._bn_mean.to(x.device, x.dtype)
        bn_std = self._bn_std.to(x.device, x.dtype)
        return (x - bn_mean) / bn_std

    def _prepare_ref_images(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode reference images and prepare their concatenated latents and IDs with spatial tiling."""
        all_latents = []
        all_ids = []

        # Track cumulative dimensions for spatial tiling
        canvas_h = 0
        canvas_w = 0

        vae_info = self._context.models.load(self._vae_field.vae)

        # Determine the max pixels based on the number of reference images (BFL FLUX.2 approach)
        num_refs = len(self.ref_image_conditioning)
        max_pixels = MAX_PIXELS_SINGLE_REF if num_refs == 1 else MAX_PIXELS_MULTI_REF

        for idx, ref_image_field in enumerate(self.ref_image_conditioning):
            image = self._context.images.get_pil(ref_image_field.image.image_name)
            image = image.convert("RGB")

            # Resize large images to the max pixel count (matches BFL FLUX.2 sampling.py)
            image = resize_image_to_max_pixels(image, max_pixels)

            # Convert to a tensor using torchvision transforms
            transformation = T.Compose([T.ToTensor()])
            image_tensor = transformation(image)
            # Convert from [0, 1] to the [-1, 1] range expected by the VAE
            image_tensor = image_tensor * 2.0 - 1.0
            image_tensor = image_tensor.unsqueeze(0)  # Add the batch dimension

            # Encode using the FLUX.2 VAE
            with vae_info.model_on_device() as (_, vae):
                vae_dtype = next(iter(vae.parameters())).dtype
                image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)

                # The FLUX.2 VAE uses the diffusers API
                latent_dist = vae.encode(image_tensor, return_dict=False)[0]

                # Use mode() for deterministic encoding (no sampling)
                if hasattr(latent_dist, "mode"):
                    ref_image_latents_unpacked = latent_dist.mode()
                elif hasattr(latent_dist, "sample"):
                    ref_image_latents_unpacked = latent_dist.sample()
                else:
                    ref_image_latents_unpacked = latent_dist

            TorchDevice.empty_cache()

            # Extract tensor dimensions (B, 32, H, W for FLUX.2)
            batch_size, _, latent_height, latent_width = ref_image_latents_unpacked.shape

            # Pad latents to be compatible with patch_size=2
            pad_h = (2 - latent_height % 2) % 2
            pad_w = (2 - latent_width % 2) % 2
            if pad_h > 0 or pad_w > 0:
                ref_image_latents_unpacked = F.pad(ref_image_latents_unpacked, (0, pad_w, 0, pad_h), mode="circular")
                _, _, latent_height, latent_width = ref_image_latents_unpacked.shape

            # Pack the latents using the FLUX.2 pack function (32 channels -> 128)
            ref_image_latents_packed = pack_flux2(ref_image_latents_unpacked).to(self._device, self._dtype)

            # Apply BN normalization to match the scale of the input latents.
            # This is critical - the transformer expects normalized latents.
            if self._bn_mean is not None and self._bn_std is not None:
                ref_image_latents_packed = self._bn_normalize(ref_image_latents_packed)

            # Determine spatial offsets for this reference image
            h_offset = 0
            w_offset = 0

            if idx > 0:  # The first image starts at (0, 0)
                # Calculate potential canvas dimensions for each tiling option
                potential_h_vertical = canvas_h + latent_height
                potential_w_horizontal = canvas_w + latent_width

                # Choose the arrangement that minimizes the maximum dimension
                if potential_h_vertical > potential_w_horizontal:
                    # Tile horizontally (to the right)
                    w_offset = canvas_w
                    canvas_w = canvas_w + latent_width
                    canvas_h = max(canvas_h, latent_height)
                else:
                    # Tile vertically (below)
                    h_offset = canvas_h
                    canvas_h = canvas_h + latent_height
                    canvas_w = max(canvas_w, latent_width)
            else:
                canvas_h = latent_height
                canvas_w = latent_width

            # Generate position IDs in the 4D format (T, H, W, L).
            # Use a T-coordinate offset with scale=10 like the diffusers Flux2Pipeline:
            # T = scale + scale * idx (so the first ref image is T=10, the second T=20, etc.).
            # The generated image uses T=0, so this clearly separates reference images.
            t_offset = 10 + 10 * idx  # scale=10 matches diffusers
            ref_image_ids = generate_img_ids_flux2_with_offset(
                latent_height=latent_height,
                latent_width=latent_width,
                batch_size=batch_size,
                device=self._device,
                idx_offset=t_offset,  # Reference images use T=10, 20, 30...
                h_offset=h_offset,
                w_offset=w_offset,
            )

            all_latents.append(ref_image_latents_packed)
            all_ids.append(ref_image_ids)

        # Concatenate all latents and IDs along the sequence dimension
        concatenated_latents = torch.cat(all_latents, dim=1)
        concatenated_ids = torch.cat(all_ids, dim=1)

        return concatenated_latents, concatenated_ids

    def ensure_batch_size(self, target_batch_size: int) -> None:
        """Ensure the reference image latents and IDs match the target batch size."""
        if self.ref_image_latents.shape[0] != target_batch_size:
            self.ref_image_latents = self.ref_image_latents.repeat(target_batch_size, 1, 1)
            self.ref_image_ids = self.ref_image_ids.repeat(target_batch_size, 1, 1)
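A minimal sketch (not part of the diff) of the pixel-cap behavior defined above: a 4000x3000 image (~12M px) exceeds the single-reference cap of 2024^2 (~4.1M px), so it is scaled by sqrt(max_pixels / pixel_count) with the aspect ratio preserved:

```python
from PIL import Image

from invokeai.backend.flux2.ref_image_extension import (
    MAX_PIXELS_SINGLE_REF,
    resize_image_to_max_pixels,
)

image = Image.new("RGB", (4000, 3000))  # 12M px, over the ~4.1M px cap
resized = resize_image_to_max_pixels(image, MAX_PIXELS_SINGLE_REF)

w, h = resized.size
assert w * h <= MAX_PIXELS_SINGLE_REF          # now within the cap
assert abs(w / h - 4000 / 3000) < 0.01         # aspect ratio preserved
```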
206
invokeai/backend/flux2/sampling_utils.py
Normal file
206
invokeai/backend/flux2/sampling_utils.py
Normal file
@@ -0,0 +1,206 @@
"""FLUX.2 Klein Sampling Utilities.

FLUX.2 Klein uses a 32-channel VAE (AutoencoderKLFlux2) instead of the 16-channel VAE
used by FLUX.1. This module provides sampling utilities adapted for FLUX.2.
"""

import math

import numpy as np
import torch
from einops import rearrange


def get_noise_flux2(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
) -> torch.Tensor:
    """Generate noise for FLUX.2 Klein (32 channels).

    FLUX.2 uses a 32-channel VAE, so noise must have 32 channels.
    The spatial dimensions are calculated to allow for packing.

    Args:
        num_samples: Batch size.
        height: Target image height in pixels.
        width: Target image width in pixels.
        device: Target device.
        dtype: Target dtype.
        seed: Random seed.

    Returns:
        Noise tensor of shape (num_samples, 32, latent_h, latent_w).
    """
    # Always generate noise on the same device and dtype, then cast, so that a given
    # seed produces identical noise regardless of the requested device and dtype.
    rand_device = "cpu"
    rand_dtype = torch.float16

    # FLUX.2 uses 32 latent channels.
    # Latent dimensions: height/8, width/8 (from VAE downsampling).
    # They must be divisible by 2 for packing (the patchify step), hence 2 * ceil(size / 16).
    latent_h = 2 * math.ceil(height / 16)
    latent_w = 2 * math.ceil(width / 16)

    return torch.randn(
        num_samples,
        32,  # FLUX.2 uses 32 latent channels (vs 16 for FLUX.1)
        latent_h,
        latent_w,
        device=rand_device,
        dtype=rand_dtype,
        generator=torch.Generator(device=rand_device).manual_seed(seed),
    ).to(device=device, dtype=dtype)


def pack_flux2(x: torch.Tensor) -> torch.Tensor:
    """Pack latent image to flattened array of patch embeddings for FLUX.2.

    This performs the patchify + pack operation in one step:
    1. Patchify: Group 2x2 spatial patches into channels (C*4)
    2. Pack: Flatten spatial dimensions to sequence

    For 32-channel input: (B, 32, H, W) -> (B, H/2*W/2, 128)

    Args:
        x: Latent tensor of shape (B, 32, H, W).

    Returns:
        Packed tensor of shape (B, H/2*W/2, 128).
    """
    # Same operation as the FLUX.1 pack, but the input has 32 channels -> the output has 128.
    return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)


def unpack_flux2(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """Unpack flat array of patch embeddings back to latent image for FLUX.2.

    This reverses the pack_flux2 operation:
    1. Unpack: Restore spatial dimensions from sequence
    2. Unpatchify: Restore 32 channels from 128

    Args:
        x: Packed tensor of shape (B, H/2*W/2, 128).
        height: Target image height in pixels.
        width: Target image width in pixels.

    Returns:
        Latent tensor of shape (B, 32, H, W).
    """
    # Calculate latent dimensions
    latent_h = 2 * math.ceil(height / 16)
    latent_w = 2 * math.ceil(width / 16)

    # Packed dimensions (after patchify)
    packed_h = latent_h // 2
    packed_w = latent_w // 2

    return rearrange(
        x,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=packed_h,
        w=packed_w,
        ph=2,
        pw=2,
    )

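# Example (illustrative; not part of the original module): pack/unpack round trip for a
# 1024x1024 image. latent_h = latent_w = 2 * ceil(1024 / 16) = 128, so the packed
# sequence has 64 * 64 = 4096 tokens of dimension 32 * 2 * 2 = 128.
#
#   x = torch.randn(1, 32, 128, 128)
#   packed = pack_flux2(x)                                    # (1, 4096, 128)
#   restored = unpack_flux2(packed, height=1024, width=1024)  # (1, 32, 128, 128)
#   assert torch.equal(x, restored)  # rearrange is a pure reshape, so the round trip is exact
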
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
    """Compute mu for FLUX.2 schedule shifting.

    Uses a fixed mu value of 2.02, matching ComfyUI's proven FLUX.2 configuration.

    The previous implementation (from diffusers' FLUX.1 pipeline) computed mu as a
    linear function of image_seq_len, which produced excessively high values at
    high resolutions (e.g., mu=3.23 at 2048x2048). This over-shifted the sigma
    schedule, compressing almost all values above 0.9 and forcing the model to
    denoise everything in the final 1-2 steps, causing severe grid/diamond artifacts.

    ComfyUI uses a fixed shift=2.02 for FLUX.2 Klein at all resolutions and produces
    artifact-free images even at 2048x2048.

    Args:
        image_seq_len: Number of image tokens (packed_h * packed_w). Currently unused.
        num_steps: Number of denoising steps. Currently unused.

    Returns:
        The mu value (fixed at 2.02).
    """
    return 2.02


def get_schedule_flux2(
    num_steps: int,
    image_seq_len: int,
) -> list[float]:
    """Get linear timestep schedule for FLUX.2.

    Returns a linear sigma schedule from 1.0 to 1/num_steps.
    The actual schedule shifting is handled by the FlowMatchEulerDiscreteScheduler
    using the mu parameter and use_dynamic_shifting=True.

    Args:
        num_steps: Number of denoising steps.
        image_seq_len: Number of image tokens (packed_h * packed_w). Currently unused,
            but kept for API compatibility. The scheduler computes shifting internally.

    Returns:
        List of linear sigmas from 1.0 to 1/num_steps, plus a final 0.0.
    """
    # Create linear sigmas from 1.0 to 1/num_steps.
    # The scheduler will apply dynamic shifting using the mu parameter.
    sigmas = np.linspace(1.0, 1 / num_steps, num_steps)
    sigmas_list = [float(s) for s in sigmas]

    # Add a final 0.0 for the last step (the scheduler needs n+1 timesteps for n steps).
    sigmas_list.append(0.0)

    return sigmas_list

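# Example (illustrative; not part of the original module): for num_steps=4 the schedule is
# np.linspace(1.0, 0.25, 4) plus the terminal 0.0, i.e. [1.0, 0.75, 0.5, 0.25, 0.0].
# The FlowMatchEulerDiscreteScheduler then warps these sigmas using the mu value from
# compute_empirical_mu() before stepping.
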
def generate_img_ids_flux2(h: int, w: int, batch_size: int, device: torch.device) -> torch.Tensor:
    """Generate tensor of image position ids for FLUX.2 with RoPE scaling.

    FLUX.2 uses 4D position coordinates (T, H, W, L) for its rotary position embeddings.
    This is different from FLUX.1 which uses 3D coordinates.

    RoPE Scaling: For resolutions >1536x1536, position IDs are scaled down using
    Position Interpolation to prevent RoPE degradation and diamond/grid artifacts.

    IMPORTANT: Position IDs must use int64 (long) dtype like diffusers, not bfloat16.
    Using floating point dtype for position IDs can cause NaN in rotary embeddings.

    Args:
        h: Height of image in latent space.
        w: Width of image in latent space.
        batch_size: Batch size.
        device: Device.

    Returns:
        Image position ids tensor of shape (batch_size, h/2*w/2, 4) with int64 dtype.
    """
    # After packing, spatial dims are h/2 x w/2
    packed_h = h // 2
    packed_w = w // 2

    # Create coordinate grids - 4D: (T, H, W, L)
    # T = time/batch index, H = height, W = width, L = layer/channel
    # Use int64 (long) dtype like diffusers
    img_ids = torch.zeros(packed_h, packed_w, 4, device=device, dtype=torch.long)

    # T (time/batch) coordinate - set to 0 (already initialized)
    # H coordinates
    img_ids[..., 1] = torch.arange(packed_h, device=device, dtype=torch.long)[:, None]
    # W coordinates
    img_ids[..., 2] = torch.arange(packed_w, device=device, dtype=torch.long)[None, :]
    # L (layer) coordinate - set to 0 (already initialized)

    # Flatten and expand for batch
    img_ids = img_ids.reshape(1, packed_h * packed_w, 4)
    img_ids = img_ids.expand(batch_size, -1, -1)

    return img_ids
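# Example (illustrative; not part of the original module): for a 1024x1024 image the latent
# is 128x128, so packed_h = packed_w = 64 and the ids tensor has shape (batch_size, 4096, 4),
# with column 1 holding row indices 0..63 and column 2 holding column indices 0..63.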
367
invokeai/backend/image_util/pbr_maps/architecture/block.py
Normal file
@@ -0,0 +1,367 @@
# Original: https://github.com/joeyballentine/Material-Map-Generator
# Adapted and optimized for Invoke AI

from collections import OrderedDict
from typing import Any, List, Literal, Optional

import torch
import torch.nn as nn

ACTIVATION_LAYER_TYPE = Literal["relu", "leakyrelu", "prelu"]
NORMALIZATION_LAYER_TYPE = Literal["batch", "instance"]
PADDING_LAYER_TYPE = Literal["zero", "reflect", "replicate"]
BLOCK_MODE = Literal["CNA", "NAC", "CNAC"]
UPCONV_BLOCK_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear"]


def act(act_type: ACTIVATION_LAYER_TYPE, inplace: bool = True, neg_slope: float = 0.2, n_prelu: int = 1):
    """Helper to select an activation layer."""
    if act_type == "relu":
        layer = nn.ReLU(inplace)
    elif act_type == "leakyrelu":
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act_type == "prelu":
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    else:
        # Guard against an unknown activation type at runtime; without this branch,
        # `layer` would be unbound below.
        raise ValueError(f"Unknown activation layer type [{act_type}]")
    return layer


def norm(norm_type: NORMALIZATION_LAYER_TYPE, nc: int):
    """Helper to select a normalization layer."""
    if norm_type == "batch":
        layer = nn.BatchNorm2d(nc, affine=True)
    elif norm_type == "instance":
        layer = nn.InstanceNorm2d(nc, affine=False)
    else:
        raise ValueError(f"Unknown normalization layer type [{norm_type}]")
    return layer


def pad(pad_type: PADDING_LAYER_TYPE, padding: int):
    """Helper to select a padding layer. Returns None for zero padding (handled by Conv2d itself)."""
    if padding == 0 or pad_type == "zero":
        return None
    if pad_type == "reflect":
        layer = nn.ReflectionPad2d(padding)
    elif pad_type == "replicate":
        layer = nn.ReplicationPad2d(padding)
    else:
        raise ValueError(f"Unknown padding layer type [{pad_type}]")
    return layer


def get_valid_padding(kernel_size: int, dilation: int):
    # The effective kernel size grows with dilation; this padding preserves spatial size for stride 1.
    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
    padding = (kernel_size - 1) // 2
    return padding


def sequential(*args: Any):
    # Flatten Sequential. It unwraps nn.Sequential children, skips None entries, and
    # rebuilds a single flat nn.Sequential.
    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            raise NotImplementedError("sequential does not support OrderedDict input.")
        return args[0]  # No sequential is needed.
    modules: List[nn.Module] = []
    for module in args:
        if isinstance(module, nn.Sequential):
            for submodule in module.children():
                modules.append(submodule)
        elif isinstance(module, nn.Module):
            modules.append(module)
    return nn.Sequential(*modules)


def conv_block(
    in_nc: int,
    out_nc: int,
    kernel_size: int,
    stride: int = 1,
    dilation: int = 1,
    groups: int = 1,
    bias: bool = True,
    pad_type: Optional[PADDING_LAYER_TYPE] = "zero",
    norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
    act_type: Optional[ACTIVATION_LAYER_TYPE] = "relu",
    mode: BLOCK_MODE = "CNA",
):
    """
    Conv layer with padding, normalization, and activation.
    mode: CNA --> Conv -> Norm -> Act
          NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
          CNAC --> like NAC; used for the residual paths in ResNetBlock
    """
    assert mode in ["CNA", "NAC", "CNAC"], f"Wrong conv mode [{mode}]"
    padding = get_valid_padding(kernel_size, dilation)
    p = pad(pad_type, padding) if pad_type else None
    padding = padding if pad_type == "zero" else 0

    c = nn.Conv2d(
        in_nc,
        out_nc,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias,
        groups=groups,
    )
    a = act(act_type) if act_type else None
    match mode:
        case "CNA":
            n = norm(norm_type, out_nc) if norm_type else None
            return sequential(p, c, n, a)
        case "NAC":
            if norm_type is None and act_type is not None:
                # An inplace activation applied first would modify the block's input,
                # so use a non-inplace activation here.
                a = act(act_type, inplace=False)
            n = norm(norm_type, in_nc) if norm_type else None
            return sequential(n, a, p, c)
        case "CNAC":
            n = norm(norm_type, in_nc) if norm_type else None
            return sequential(n, a, p, c)

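# Example (illustrative; not part of the original module): a CNA block is Conv -> Norm -> Act.
#
#   block = conv_block(3, 64, kernel_size=3, norm_type="batch", act_type="leakyrelu", mode="CNA")
#   # -> nn.Sequential(Conv2d(3, 64, 3, padding=1), BatchNorm2d(64), LeakyReLU(0.2, inplace=True))
#   y = block(torch.randn(1, 3, 32, 32))  # (1, 64, 32, 32)
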
class ConcatBlock(nn.Module):
    # Concatenate the output of a submodule to its input along the channel dimension
    def __init__(self, submodule: nn.Module):
        super(ConcatBlock, self).__init__()
        self.sub = submodule

    def forward(self, x: torch.Tensor):
        output = torch.cat((x, self.sub(x)), dim=1)
        return output

    def __repr__(self):
        tmpstr = "Identity .. \n|"
        modstr = self.sub.__repr__().replace("\n", "\n|")
        tmpstr = tmpstr + modstr
        return tmpstr


class ShortcutBlock(nn.Module):
    # Elementwise sum the output of a submodule to its input
    def __init__(self, submodule: nn.Module):
        super(ShortcutBlock, self).__init__()
        self.sub = submodule

    def forward(self, x: torch.Tensor):
        output = x + self.sub(x)
        return output

    def __repr__(self):
        tmpstr = "Identity + \n|"
        modstr = self.sub.__repr__().replace("\n", "\n|")
        tmpstr = tmpstr + modstr
        return tmpstr


class ShortcutBlockSPSR(nn.Module):
    # Return the input together with the (unapplied) submodule, for SPSR-style branching
    def __init__(self, submodule: nn.Module):
        super(ShortcutBlockSPSR, self).__init__()
        self.sub = submodule

    def forward(self, x: torch.Tensor):
        return x, self.sub

    def __repr__(self):
        tmpstr = "Identity + \n|"
        modstr = self.sub.__repr__().replace("\n", "\n|")
        tmpstr = tmpstr + modstr
        return tmpstr


class ResNetBlock(nn.Module):
    """
    ResNet Block, 3-3 style
    with extra residual scaling used in EDSR
    (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
    """

    def __init__(
        self,
        in_nc: int,
        mid_nc: int,
        out_nc: int,
        kernel_size: int = 3,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        pad_type: PADDING_LAYER_TYPE = "zero",
        norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
        act_type: Optional[ACTIVATION_LAYER_TYPE] = "relu",
        mode: BLOCK_MODE = "CNA",
        res_scale: int = 1,
    ):
        super(ResNetBlock, self).__init__()
        conv0 = conv_block(
            in_nc, mid_nc, kernel_size, stride, dilation, groups, bias, pad_type, norm_type, act_type, mode
        )
        if mode == "CNA":
            act_type = None
        if mode == "CNAC":  # Residual path: |-CNAC-|
            act_type = None
            norm_type = None
        conv1 = conv_block(
            mid_nc, out_nc, kernel_size, stride, dilation, groups, bias, pad_type, norm_type, act_type, mode
        )

        self.res = sequential(conv0, conv1)
        self.res_scale = res_scale

    def forward(self, x: torch.Tensor):
        res = self.res(x).mul(self.res_scale)
        return x + res


class ResidualDenseBlock_5C(nn.Module):
    """
    Residual Dense Block
    style: 5 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    """

    def __init__(
        self,
        nc: int,
        kernel_size: int = 3,
        gc: int = 32,
        stride: int = 1,
        bias: bool = True,
        pad_type: PADDING_LAYER_TYPE = "zero",
        norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
        act_type: ACTIVATION_LAYER_TYPE = "leakyrelu",
        mode: BLOCK_MODE = "CNA",
    ):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = conv_block(
            nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode
        )
        self.conv2 = conv_block(
            nc + gc,
            gc,
            kernel_size,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=act_type,
            mode=mode,
        )
        self.conv3 = conv_block(
            nc + 2 * gc,
            gc,
            kernel_size,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=act_type,
            mode=mode,
        )
        self.conv4 = conv_block(
            nc + 3 * gc,
            gc,
            kernel_size,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=act_type,
            mode=mode,
        )
        if mode == "CNA":
            last_act = None
        else:
            last_act = act_type
        self.conv5 = conv_block(
            nc + 4 * gc, nc, 3, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=last_act, mode=mode
        )

    def forward(self, x: torch.Tensor):
        # Dense connectivity: each conv sees the input concatenated with all previous outputs.
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5.mul(0.2) + x


class RRDB(nn.Module):
    """
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
    """

    def __init__(
        self,
        nc: int,
        kernel_size: int = 3,
        gc: int = 32,
        stride: int = 1,
        bias: bool = True,
        pad_type: PADDING_LAYER_TYPE = "zero",
        norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
        act_type: ACTIVATION_LAYER_TYPE = "leakyrelu",
        mode: BLOCK_MODE = "CNA",
    ):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, norm_type, act_type, mode)
        self.RDB2 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, norm_type, act_type, mode)
        self.RDB3 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, norm_type, act_type, mode)

    def forward(self, x: torch.Tensor):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out.mul(0.2) + x

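# Example (illustrative; not part of the original module): RRDB blocks preserve shape,
# so they can be stacked arbitrarily deep.
#
#   block = RRDB(nc=64, gc=32)
#   y = block(torch.randn(1, 64, 32, 32))  # (1, 64, 32, 32)
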
# Upsampler
def pixelshuffle_block(
    in_nc: int,
    out_nc: int,
    upscale_factor: int = 2,
    kernel_size: int = 3,
    stride: int = 1,
    bias: bool = True,
    pad_type: PADDING_LAYER_TYPE = "zero",
    norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
    act_type: ACTIVATION_LAYER_TYPE = "relu",
):
    """
    Pixel shuffle layer
    (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
    Neural Network, CVPR17)
    """
    # The conv expands channels by upscale_factor**2; PixelShuffle then trades those
    # extra channels for spatial resolution.
    conv = conv_block(
        in_nc,
        out_nc * (upscale_factor**2),
        kernel_size,
        stride,
        bias=bias,
        pad_type=pad_type,
        norm_type=None,
        act_type=None,
    )
    pixel_shuffle = nn.PixelShuffle(upscale_factor)

    n = norm(norm_type, out_nc) if norm_type else None
    a = act(act_type) if act_type else None
    return sequential(conv, pixel_shuffle, n, a)

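# Example (illustrative; not part of the original module): with in_nc=64, out_nc=64 and
# upscale_factor=2, the conv produces 256 channels and PixelShuffle(2) rearranges them:
#
#   up = pixelshuffle_block(64, 64, upscale_factor=2)
#   y = up(torch.randn(1, 64, 32, 32))  # (1, 64, 64, 64)
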
def upconv_block(
    in_nc: int,
    out_nc: int,
    upscale_factor: int = 2,
    kernel_size: int = 3,
    stride: int = 1,
    bias: bool = True,
    pad_type: PADDING_LAYER_TYPE = "zero",
    norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
    act_type: ACTIVATION_LAYER_TYPE = "relu",
    mode: UPCONV_BLOCK_MODE = "nearest",
):
    # Adapted from https://distill.pub/2016/deconv-checkerboard/
    # (upsample-then-conv avoids the checkerboard artifacts of transposed convolutions)
    upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
    conv = conv_block(
        in_nc, out_nc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=act_type
    )
    return sequential(upsample, conv)
70
invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py
Normal file
@@ -0,0 +1,70 @@
# Original: https://github.com/joeyballentine/Material-Map-Generator
# Adapted and optimized for Invoke AI

import math
from typing import Literal, Optional

import torch
import torch.nn as nn

import invokeai.backend.image_util.pbr_maps.architecture.block as B

UPSCALE_MODE = Literal["upconv", "pixelshuffle"]


class PBR_RRDB_Net(nn.Module):
    def __init__(
        self,
        in_nc: int,
        out_nc: int,
        nf: int,
        nb: int,
        gc: int = 32,
        upscale: int = 4,
        norm_type: Optional[B.NORMALIZATION_LAYER_TYPE] = None,
        act_type: B.ACTIVATION_LAYER_TYPE = "leakyrelu",
        mode: B.BLOCK_MODE = "CNA",
        res_scale: int = 1,
        upsample_mode: UPSCALE_MODE = "upconv",
    ):
        super(PBR_RRDB_Net, self).__init__()
        # Number of 2x upsampling stages; a 3x upscale is handled by a single 3x stage instead.
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
        rb_blocks = [
            B.RRDB(
                nf,
                kernel_size=3,
                gc=gc,
                stride=1,
                bias=True,
                pad_type="zero",
                norm_type=norm_type,
                act_type=act_type,
                mode="CNA",
            )
            for _ in range(nb)
        ]
        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)

        if upsample_mode == "upconv":
            upsample_block = B.upconv_block
        elif upsample_mode == "pixelshuffle":
            upsample_block = B.pixelshuffle_block
        else:
            raise ValueError(f"Unknown upsample mode [{upsample_mode}]")

        # Keep `upsampler` a list in both branches so the unpacking below is uniform.
        if upscale == 3:
            upsampler = [upsample_block(nf, nf, 3, act_type=act_type)]
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]

        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        self.model = B.sequential(
            fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)), *upsampler, HR_conv0, HR_conv1
        )

    def forward(self, x: torch.Tensor):
        return self.model(x)
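# Example (illustrative; not part of the original module): the configuration used by
# PBRMapsGenerator.load_model in pbr_maps.py is a 1x (no upscale) network:
#
#   net = PBR_RRDB_Net(3, 3, 32, 12, gc=32, upscale=1, act_type="leakyrelu", mode="CNA")
#   y = net(torch.randn(1, 3, 256, 256))  # (1, 3, 256, 256)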
141
invokeai/backend/image_util/pbr_maps/pbr_maps.py
Normal file
@@ -0,0 +1,141 @@
# Original: https://github.com/joeyballentine/Material-Map-Generator
# Adapted and optimized for Invoke AI

import pathlib
from typing import Any, Literal

import cv2
import numpy as np
import numpy.typing as npt
import torch
from PIL import Image
from safetensors.torch import load_file

from invokeai.backend.image_util.pbr_maps.architecture.pbr_rrdb_net import PBR_RRDB_Net
from invokeai.backend.image_util.pbr_maps.utils.image_ops import crop_seamless, esrgan_launcher_split_merge

NORMAL_MAP_MODEL = (
    "https://huggingface.co/InvokeAI/pbr-material-maps/resolve/main/normal_map_generator.safetensors?download=true"
)
OTHER_MAP_MODEL = (
    "https://huggingface.co/InvokeAI/pbr-material-maps/resolve/main/franken_map_generator.safetensors?download=true"
)


class PBRMapsGenerator:
    def __init__(self, normal_map_model: PBR_RRDB_Net, other_map_model: PBR_RRDB_Net, device: torch.device) -> None:
        self.normal_map_model = normal_map_model
        self.other_map_model = other_map_model
        self.device = device

    @staticmethod
    def load_model(model_path: pathlib.Path, device: torch.device) -> PBR_RRDB_Net:
        state_dict = load_file(model_path.as_posix(), device=device.type)

        model = PBR_RRDB_Net(
            3,
            3,
            32,
            12,
            gc=32,
            upscale=1,
            norm_type=None,
            act_type="leakyrelu",
            mode="CNA",
            res_scale=1,
            upsample_mode="upconv",
        )

        model.load_state_dict(state_dict, strict=False)

        del state_dict
        if torch.cuda.is_available() and device.type == "cuda":
            torch.cuda.empty_cache()

        # Inference only: switch to eval mode and freeze all parameters.
        model.eval()

        for _, v in model.named_parameters():
            v.requires_grad = False

        return model.to(device)
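    # Example (illustrative; not part of the original module): assuming the safetensors files
    # from NORMAL_MAP_MODEL and OTHER_MAP_MODEL have already been downloaded locally:
    #
    #   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #   normal_model = PBRMapsGenerator.load_model(pathlib.Path("normal_map_generator.safetensors"), device)
    #   other_model = PBRMapsGenerator.load_model(pathlib.Path("franken_map_generator.safetensors"), device)
    #   generator = PBRMapsGenerator(normal_model, other_model, device)
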
    def process(self, img: npt.NDArray[Any], model: PBR_RRDB_Net):
        # Normalize to [0, 1] floats and flip RGB -> BGR to match the channel order
        # of the original cv2-based pipeline.
        img = img.astype(np.float32) / np.iinfo(img.dtype).max
        img = img[..., ::-1].copy()
        tensor_img = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).to(self.device)

        with torch.no_grad():
            output = model(tensor_img).data.squeeze(0).float().cpu().clamp_(0, 1).numpy()
        # Reorder channels back, convert CHW -> HWC, and rescale to the 8-bit range.
        output = output[[2, 1, 0], :, :]
        output = np.transpose(output, (1, 2, 0))
        output = (output * 255.0).round()
        return output

    def _cv2_to_pil(self, image: npt.NDArray[Any]):
        image = image.astype(np.uint8)
        if image.ndim == 2:
            # Single-channel maps (roughness, displacement) need no color conversion;
            # cv2.cvtColor would reject a 2D array here.
            return Image.fromarray(image)
        return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

    def generate_maps(
        self,
        image: Image.Image,
        tile_size: int = 512,
        border_mode: Literal["none", "seamless", "mirror", "replicate"] = "none",
    ):
        """
        Generate PBR texture maps (normal, roughness, and displacement) from an input image.
        The image can optionally be padded before inference to control how borders are treated,
        which can help create seamless or edge‑consistent textures.

        Args:
            image: Source image used to generate the PBR maps.
            tile_size: Maximum tile size used for tiled inference. If the image is larger than
                this size in either dimension, it will be split into tiles for processing and
                then merged.
            border_mode: Strategy for padding the image before inference:
                - "none": No padding is applied; the image is processed as‑is.
                - "seamless": Pads the image using wrap‑around tiling
                  (`cv2.BORDER_WRAP`) to help produce seamless textures.
                - "mirror": Pads the image by mirroring border pixels
                  (`cv2.BORDER_REFLECT_101`) to reduce edge artifacts.
                - "replicate": Pads the image by replicating the edge pixels outward
                  (`cv2.BORDER_REPLICATE`).

        Returns:
            A tuple of three PIL Images:
                - normal_map: RGB normal map generated from the input.
                - roughness: Single‑channel roughness map extracted from the second model output.
                - displacement: Single‑channel displacement (height) map extracted from the
                  second model output.
        """

        models = [self.normal_map_model, self.other_map_model]
        np_image = np.array(image).astype(np.uint8)

        match border_mode:
            case "seamless":
                np_image = cv2.copyMakeBorder(np_image, 16, 16, 16, 16, cv2.BORDER_WRAP)
            case "mirror":
                np_image = cv2.copyMakeBorder(np_image, 16, 16, 16, 16, cv2.BORDER_REFLECT_101)
            case "replicate":
                np_image = cv2.copyMakeBorder(np_image, 16, 16, 16, 16, cv2.BORDER_REPLICATE)
            case "none":
                pass

        img_height, img_width = np_image.shape[:2]

        # Check whether to perform tiled inference
        do_split = img_height > tile_size or img_width > tile_size

        if do_split:
            rlts = esrgan_launcher_split_merge(np_image, self.process, models, scale_factor=1, tile_size=tile_size)
        else:
            rlts = [self.process(np_image, model) for model in models]

        # The 16-pixel border added above is cropped back off after inference.
        if border_mode != "none":
            rlts = [crop_seamless(rlt) for rlt in rlts]

        normal_map = self._cv2_to_pil(rlts[0])
        roughness = self._cv2_to_pil(rlts[1][:, :, 1])
        displacement = self._cv2_to_pil(rlts[1][:, :, 0])

        return normal_map, roughness, displacement
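# Example (illustrative; not part of the original module): end-to-end generation with
# mirrored borders, assuming `generator` was built as in the load_model example above:
#
#   source = Image.open("brick_albedo.png").convert("RGB")
#   normal_map, roughness, displacement = generator.generate_maps(source, tile_size=512, border_mode="mirror")
#   normal_map.save("brick_normal.png")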
93
invokeai/backend/image_util/pbr_maps/utils/image_ops.py
Normal file
@@ -0,0 +1,93 @@
# Original: https://github.com/joeyballentine/Material-Map-Generator
# Adapted and optimized for Invoke AI

import math
from typing import Any, Callable, List

import numpy as np
import numpy.typing as npt

from invokeai.backend.image_util.pbr_maps.architecture.pbr_rrdb_net import PBR_RRDB_Net


def crop_seamless(img: npt.NDArray[Any]):
    # Crop off the 16-pixel border padding added before inference.
    img_height, img_width = img.shape[:2]
    y, x = 16, 16
    h, w = img_height - 32, img_width - 32
    img = img[y : y + h, x : x + w]
    return img


# from https://github.com/ata4/esrgan-launcher/blob/master/upscale.py
def esrgan_launcher_split_merge(
    input_image: npt.NDArray[Any],
    upscale_function: Callable[[npt.NDArray[Any], PBR_RRDB_Net], npt.NDArray[Any]],
    models: List[PBR_RRDB_Net],
    scale_factor: int = 4,
    tile_size: int = 512,
    tile_padding: float = 0.125,
):
    # Note: axis 0 is named "width" and axis 1 "height" here, which is unconventional for
    # numpy images, but the naming is used consistently throughout this function.
    width, height, depth = input_image.shape
    output_width = width * scale_factor
    output_height = height * scale_factor
    output_shape = (output_width, output_height, depth)

    # start with black images
    output_images = [np.zeros(output_shape, np.uint8) for _ in range(len(models))]

    # Padding is a fraction of the tile size; tile coordinates are computed in input-image space.
    tile_padding = math.ceil(tile_size * tile_padding)
    tile_size = math.ceil(tile_size / scale_factor)

    tiles_x = math.ceil(width / tile_size)
    tiles_y = math.ceil(height / tile_size)

    for y in range(tiles_y):
        for x in range(tiles_x):
            # extract tile from input image
            ofs_x = x * tile_size
            ofs_y = y * tile_size

            # input tile area on total image
            input_start_x = ofs_x
            input_end_x = min(ofs_x + tile_size, width)

            input_start_y = ofs_y
            input_end_y = min(ofs_y + tile_size, height)

            # input tile area on total image with padding
            input_start_x_pad = max(input_start_x - tile_padding, 0)
            input_end_x_pad = min(input_end_x + tile_padding, width)

            input_start_y_pad = max(input_start_y - tile_padding, 0)
            input_end_y_pad = min(input_end_y + tile_padding, height)

            # input tile dimensions
            input_tile_width = input_end_x - input_start_x
            input_tile_height = input_end_y - input_start_y

            input_tile = input_image[input_start_x_pad:input_end_x_pad, input_start_y_pad:input_end_y_pad]

            for idx, model in enumerate(models):
                # upscale tile
                output_tile = upscale_function(input_tile, model)

                # output tile area on total image
                output_start_x = input_start_x * scale_factor
                output_end_x = input_end_x * scale_factor

                output_start_y = input_start_y * scale_factor
                output_end_y = input_end_y * scale_factor

                # output tile area without padding
                output_start_x_tile = (input_start_x - input_start_x_pad) * scale_factor
                output_end_x_tile = output_start_x_tile + input_tile_width * scale_factor

                output_start_y_tile = (input_start_y - input_start_y_pad) * scale_factor
                output_end_y_tile = output_start_y_tile + input_tile_height * scale_factor

                # put tile into output image
                output_images[idx][output_start_x:output_end_x, output_start_y:output_end_y] = output_tile[
                    output_start_x_tile:output_end_x_tile, output_start_y_tile:output_end_y_tile
                ]

    return output_images
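# Example (illustrative; not part of the original module): for a 1280x768 input with
# scale_factor=1, tile_size=512 and tile_padding=0.125, tiles are 512px with 64px of
# padded context, giving ceil(1280/512) * ceil(768/512) = 3 * 2 = 6 tiles per model.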
212
invokeai/backend/model_manager/README.md
Normal file
@@ -0,0 +1,212 @@
# Model Management System

This document describes Invoke's model management system and common tasks for extending model support.

## Overview

The model management system handles the full lifecycle of models: identification, loading, and running. The system is extensible and supports multiple model architectures, formats, and quantization schemes.

### Three Major Subsystems

1. **Model Identification** (`configs/`): Determines model type, architecture, format, and metadata when users install models.
2. **Model Loading** (`load/`): Loads models from disk into memory for inference.
3. **Model Running**: Executes inference on loaded models. The implementation is scattered across the codebase, typically in architecture-specific inference code adjacent to `model_manager/`. The inference code runs in nodes in the graph execution system.

## Core Concepts

### Model Taxonomy

The `taxonomy.py` module defines the type system for models:

- `ModelType`: The kind of model (e.g., `Main`, `LoRA`, `ControlNet`, `VAE`).
- `ModelFormat`: Storage format, which may imply a quantization scheme or other storage property (e.g., `Diffusers`, `Checkpoint`, `LyCORIS`, `BnbQuantizednf4b`).
- `BaseModelType`: Associated pipeline architecture (e.g., `StableDiffusion1`, `StableDiffusionXL`, `Flux`). Models without an associated base use `Any` (e.g., `CLIPVision`, which is not tied to a single pipeline architecture).
- `ModelVariantType`, `FluxVariantType`, `ClipVariantType`: Architecture-specific variants.

These enums form a discriminated union that uniquely identifies each model configuration class.

### Model "Configs"

Model configs are Pydantic models that describe a model on disk. They include the model taxonomy, path, and any metadata needed for loading or running the model.

Model configs are stored in the database.

### Model Identification

When a user installs a model, the system attempts to identify it by trying each registered config class until one matches.

**Config Classes** (`configs/`):

- All config classes inherit from `Config_Base`, either directly or indirectly via some intermediary class (e.g., `Diffusers_Config_Base`, `Checkpoint_Config_Base`, or something narrower).
- Each config class represents a specific, unique combination of `type`, `format`, `base`, and optional `variant`.
- Config classes must implement `from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict) -> Self`. This method inspects the model on disk and raises `NotAMatchError` if the model doesn't match the config class, or returns an instance of the config class if it does (see the sketch after this list).
- `ModelOnDisk` is a helper class that abstracts the model weights. It should be the entrypoint for inspecting the model (e.g., loading state dicts).
- Override fields allow users to provide hints (e.g., when differentiating between SD1/SD2/SDXL VAEs with identical structures).
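A minimal, hypothetical sketch of this pattern (the class name and the inspected key are invented for illustration; see the real config classes in `configs/` for reference):

```py
class ExampleAdapter_Checkpoint_Config(Checkpoint_Config_Base, Config_Base):
    type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet)
    format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
    base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1)

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        # Single-file checkpoint configs expect a file, not a directory.
        raise_if_not_file(mod)
        # Reject user-provided override fields that don't fit this config's schema.
        raise_for_override_fields(cls, override_fields)

        # Inspect the weights via ModelOnDisk and bail out if they don't match.
        state_dict = mod.load_state_dict()
        if "example_adapter.proj.weight" not in state_dict:  # invented key, for illustration
            raise NotAMatchError("state dict does not look like an example adapter")

        return cls(**override_fields)
```
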
**Identification Process**:

1. `ModelConfigFactory.from_model_on_disk()` is called with a path to the model.
2. The factory iterates through all registered config classes, calling `from_model_on_disk()` on each.
3. Each config class inspects the model (state dict keys, tensor shapes, config files, etc.).
4. If a match is found, the config instance is returned. If multiple matches are found, they are prioritized (e.g., main models over LoRAs).
5. If no match is found, an `Unknown_Config` is returned as a fallback.

**Utilities** (`identification_utils.py`):

- `NotAMatchError`: Exception raised when a model doesn't match a config class.
- `get_config_dict_or_raise()`: Load JSON config files from diffusers/transformers models.
- `raise_for_class_name()`: Validate class names in config files.
- `raise_for_override_fields()`: Validate user-provided override fields against the config schema.
- `state_dict_has_any_keys_*()`: Helpers for inspecting state dict keys.

### Model Loading

Model loaders handle instantiating models from disk into memory.

**Loader Classes** (`load/model_loaders/`):

- Loaders register themselves with the decorator `@ModelLoaderRegistry.register(base=..., type=..., format=...)`. The `type`, `format`, and `base` indicate which config classes the loader can handle (see the sketch after this list).
- Each loader implements `_load_model(self, config: AnyModelConfig, submodel_type: Optional[SubModelType]) -> AnyModel`.
- Loaders are responsible for:
  - Loading model weights from the config's path.
  - Instantiating the correct model class (often using diffusers, transformers, or custom implementations).
  - Returning the in-memory model representation.
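A minimal, hypothetical loader sketch (the loader class and the `ExampleAdapterModel` runtime class are invented for illustration):

```py
@ModelLoaderRegistry.register(
    base=BaseModelType.StableDiffusion1, type=ModelType.ControlNet, format=ModelFormat.Checkpoint
)
class ExampleAdapterLoader(ModelLoader):
    def _load_model(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> AnyModel:
        # Load the weights from the path recorded in the config and instantiate the
        # runtime class. The returned model is then managed by the model cache.
        return ExampleAdapterModel.from_checkpoint(config.path)  # invented runtime class
```
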
**Model Cache** (`load/model_cache/`):

> This system typically does not require changes to support new model types, but it is important to understand how it works.

- Manages models in memory with RAM and VRAM limits.
- Handles moving models between CPU (storage device) and GPU (execution device).
- Implements LRU eviction for RAM and smallest-first offload for VRAM.
- Supports partial loading for large models on CUDA.
- Thread-safe with locks on all public methods.

**Loading Process**:

1. The appropriate loader is selected based on the model config's `base`, `type`, and `format` attributes.
2. The loader's `_load_model()` method is called with the model config.
3. The loaded model is added to the model cache via `ModelCache.put()`.
4. When needed, the model is moved into VRAM via `ModelCache.get()` and `ModelCache.lock()`.

### Model Running

Model running is architecture-specific and typically implemented in folders adjacent to `model_manager/`.

Inference code doesn't necessarily follow any specific pattern, and doesn't interact directly with the model management system except to receive model configs and loaded models.

At a high level, when a node needs to run a model, it will:

- Receive a model identifier as an input or constant. This is typically the model's database ID (aka the `key`).
- Use the `InvocationContext` API to load the model. The request is dispatched to the model manager, which loads the model and returns a context manager that yields the in-memory model, mediating VRAM/RAM management as needed.
- Run inference on the loaded model using whatever patterns or libraries it needs. A hypothetical sketch follows this list.
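Sketched as hypothetical node code (the invocation class, the `run_inference` helper, and the exact `context.models` method names are illustrative, not a definitive API reference):

```py
class ExampleInferenceInvocation(BaseInvocation):
    model: ModelIdentifierField = InputField(ui_model_type=ModelType.Main)

    def invoke(self, context: InvocationContext) -> ImageOutput:
        # The context manager mediates RAM/VRAM placement for the duration of inference.
        with context.models.load(self.model) as loaded_model:
            image = run_inference(loaded_model)  # architecture-specific, invented helper
        return ImageOutput.build(context.images.save(image))
```
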
## Common Tasks

### Task 1: Improving Identification for a Supported Model Type

When identification fails or produces incorrect results for a model that should be supported, you may need to refine the identification logic.

**Steps**:

1. Obtain the failing model file or directory.
2. Create a test case for it, following the instructions in `tests/model_identification/README.md`.
3. Review the relevant config class in `configs/` (e.g., `configs/lora.py` for LoRA models).
4. Examine the `from_model_on_disk()` method for some existing models to understand the patterns for identification logic.
5. Inspect the failing model's files and structure:
   - For checkpoint files: Load the state dict and examine keys and tensor shapes.
   - For diffusers models: Examine the config files and directory structure.
6. Update the identification logic to handle the new model variant. Common approaches:
   - Check for specific state dict keys or key patterns.
   - Inspect tensor shapes (e.g., `state_dict[key].shape`).
   - Parse config files for class names or configuration values.
   - Use helper functions from `identification_utils.py`.
7. Run the test suite to verify the new logic works and doesn't break existing tests: `pytest tests/model_identification/test_identification.py`.
   - Make sure you have installed the test dependencies (e.g., `uv pip install -e ".[dev,test]"`).
   - If the model type is complex or has multiple variants, consider adding more test cases to cover edge cases.
8. If the model still doesn't work after identification support is in place, you may need to update the loading and/or inference code as well.

**Key Files**:

- Config class: `configs/<model_type>.py`
- Identification utilities: `configs/identification_utils.py`
- Taxonomy: `taxonomy.py`
- Test README: `tests/model_identification/README.md`

### Task 2: Adding Support for a New Model Type

Adding a new model type requires implementing identification and loading logic. New inference code and nodes ("invocations") may be required if the model type doesn't fit into existing architectures or nodes.

**Steps**:

#### 1. Define Taxonomy

- Add a new `ModelType` enum value in `taxonomy.py` if needed.
- Determine the appropriate `BaseModelType` (or use `Any` if not architecture-specific).
- Add a new `ModelFormat` if the model uses a unique storage format.

You may need to add other attributes, depending on the model.

#### 2. Implement Config Class

- Create a new config file in `configs/` (e.g., `configs/new_model.py`).
- Define a config class inheriting from `Config_Base` and the appropriate format base class:
  - `Diffusers_Config_Base` for diffusers-style models.
  - `Checkpoint_Config_Base` for single-file checkpoint models.
- Define `type`, `format`, and `base` as `Literal` fields with defaults. Remember, these must uniquely identify the config class.
- Implement `from_model_on_disk()`:
  - Validate the model is the correct format (file vs directory).
  - Inspect state dict keys, tensor shapes, or config files.
  - Raise `NotAMatchError` if the model doesn't match.
  - Extract any additional metadata needed (e.g., variant, prediction type).
  - Return an instance of the config class.
- Register the config in `configs/factory.py` (see the snippet after this list):
  - Add the config class to the `AnyModelConfig` union.
  - Add an `Annotated[YourConfig, YourConfig.get_tag()]` entry.
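The registration entry follows the same `Annotated` pattern used for every existing config (shown here in isolation; the surrounding union in `factory.py` is elided):

```py
# Inside the AnyModelConfig union in configs/factory.py:
Annotated[YourConfig, YourConfig.get_tag()],
```
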
#### 3. Implement Loader Class

- Create a new loader file in `load/model_loaders/` (e.g., `load/model_loaders/new_model.py`).
- Define a loader class inheriting from `ModelLoader`.
- Decorate it with `@ModelLoaderRegistry.register(base=..., type=..., format=...)`.
- Implement `_load_model()`:
  - Load model weights from `config.path`.
  - Instantiate the model using the appropriate library (diffusers, transformers, or custom).
  - Handle `submodel_type` if the model has submodels (e.g., text encoders, VAE).
  - Return the in-memory model representation.

#### 4. Add Tests

Follow the instructions in `tests/model_identification/README.md`.

#### 5. Implement Inference and Nodes (if needed)

- If the model type requires new inference logic, implement it in an appropriate location.
- Create nodes for the model if it doesn't fit into existing nodes. Search for subclasses of `BaseInvocation` for many examples.

#### 6. Frontend Support

##### Workflows tab

Typically, you will not need to do anything for the model to work in the Workflow Editor. When you define the node's model field, you can provide constraints on which types of models are selectable. The UI will automatically filter the list of models based on the model taxonomy.

For example, this field definition in a node will allow users to select only "main" (pipeline) Stable Diffusion 1.x or 2.x models:

```py
model: ModelIdentifierField = InputField(
    ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2],
    ui_model_type=ModelType.Main,
)
```

This same pattern works for any combination of `type`, `base`, `format`, and `variant`.

##### Canvas / Generate tabs

The Canvas and Generate tabs use graphs internally, but they don't expose the full graph editor UI. Instead, they provide a simplified interface for common tasks.

They use "graph builder" functions, which take the user's selected settings and build a graph behind the scenes. We have one graph builder for each model architecture.

Updating or adding a graph builder can be a bit complex, and you'd likely need to update other UI components and state management to support the new model type.

The SDXL graph builder is a good example: `invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts`
@@ -28,17 +28,6 @@ if TYPE_CHECKING:
     pass


-class URLModelSource(BaseModel):
-    type: Literal[ModelSourceType.Url] = Field(default=ModelSourceType.Url)
-    url: str = Field(
-        description="The URL from which the model was installed.",
-    )
-    api_response: str | None = Field(
-        default=None,
-        description="The original API response from the source, as stringified JSON.",
-    )
-
-
 class Config_Base(ABC, BaseModel):
     """
     Abstract base class for model configurations. A model config describes a specific combination of model base, type and
@@ -88,7 +88,9 @@ class ControlNet_Diffusers_Config_Base(Diffusers_Config_Base):

         cls._validate_base(mod)

-        return cls(**override_fields)
+        repo_variant = {"repo_variant": override_fields.get("repo_variant", cls._get_repo_variant_or_raise(mod))}
+        args = override_fields | repo_variant
+        return cls(**args)

     @classmethod
     def _validate_base(cls, mod: ModelOnDisk) -> None:
@@ -228,3 +230,47 @@ class ControlNet_Checkpoint_SDXL_Config(ControlNet_Checkpoint_Config_Base, Config_Base):

class ControlNet_Checkpoint_FLUX_Config(ControlNet_Checkpoint_Config_Base, Config_Base):
    base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)


def _has_z_image_control_keys(state_dict: dict) -> bool:
    """Check if state dict contains Z-Image Control specific keys."""
    z_image_control_keys = {"control_layers", "control_all_x_embedder", "control_noise_refiner"}
    for key in state_dict.keys():
        if isinstance(key, str):
            prefix = key.split(".")[0]
            if prefix in z_image_control_keys:
                return True
    return False


class ControlNet_Checkpoint_ZImage_Config(Checkpoint_Config_Base, Config_Base):
    """Model config for Z-Image Control adapter models (Safetensors checkpoint).

    Z-Image Control models are standalone adapters containing only the control layers
    (control_layers, control_all_x_embedder, control_noise_refiner) that extend
    the base Z-Image transformer with spatial conditioning capabilities.

    Supports: Canny, HED, Depth, Pose, MLSD.
    Recommended control_context_scale: 0.65-0.80.
    """

    type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet)
    format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
    base: Literal[BaseModelType.ZImage] = Field(default=BaseModelType.ZImage)
    default_settings: ControlAdapterDefaultSettings | None = Field(None)

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        raise_if_not_file(mod)

        raise_for_override_fields(cls, override_fields)

        cls._validate_looks_like_z_image_control(mod)

        return cls(**override_fields)

    @classmethod
    def _validate_looks_like_z_image_control(cls, mod: ModelOnDisk) -> None:
        state_dict = mod.load_state_dict()
        if not _has_z_image_control_keys(state_dict):
            raise NotAMatchError("state dict does not look like a Z-Image Control model")
@@ -20,6 +20,7 @@ from invokeai.backend.model_manager.configs.controlnet import (
    ControlNet_Checkpoint_SD1_Config,
    ControlNet_Checkpoint_SD2_Config,
    ControlNet_Checkpoint_SDXL_Config,
    ControlNet_Checkpoint_ZImage_Config,
    ControlNet_Diffusers_FLUX_Config,
    ControlNet_Diffusers_SD1_Config,
    ControlNet_Diffusers_SD2_Config,
@@ -43,30 +44,44 @@ from invokeai.backend.model_manager.configs.lora import (
    LoRA_Diffusers_SD1_Config,
    LoRA_Diffusers_SD2_Config,
    LoRA_Diffusers_SDXL_Config,
    LoRA_Diffusers_ZImage_Config,
    LoRA_LyCORIS_FLUX_Config,
    LoRA_LyCORIS_SD1_Config,
    LoRA_LyCORIS_SD2_Config,
    LoRA_LyCORIS_SDXL_Config,
    LoRA_LyCORIS_ZImage_Config,
    LoRA_OMI_FLUX_Config,
    LoRA_OMI_SDXL_Config,
    LoraModelDefaultSettings,
)
from invokeai.backend.model_manager.configs.main import (
    Main_BnBNF4_FLUX_Config,
    Main_Checkpoint_Flux2_Config,
    Main_Checkpoint_FLUX_Config,
    Main_Checkpoint_SD1_Config,
    Main_Checkpoint_SD2_Config,
    Main_Checkpoint_SDXL_Config,
    Main_Checkpoint_SDXLRefiner_Config,
    Main_Checkpoint_ZImage_Config,
    Main_Diffusers_CogView4_Config,
    Main_Diffusers_Flux2_Config,
    Main_Diffusers_FLUX_Config,
    Main_Diffusers_SD1_Config,
    Main_Diffusers_SD2_Config,
    Main_Diffusers_SD3_Config,
    Main_Diffusers_SDXL_Config,
    Main_Diffusers_SDXLRefiner_Config,
    Main_Diffusers_ZImage_Config,
    Main_GGUF_Flux2_Config,
    Main_GGUF_FLUX_Config,
    Main_GGUF_ZImage_Config,
    MainModelDefaultSettings,
)
from invokeai.backend.model_manager.configs.qwen3_encoder import (
    Qwen3Encoder_Checkpoint_Config,
    Qwen3Encoder_GGUF_Config,
    Qwen3Encoder_Qwen3Encoder_Config,
)
from invokeai.backend.model_manager.configs.siglip import SigLIP_Diffusers_Config
from invokeai.backend.model_manager.configs.spandrel import Spandrel_Checkpoint_Config
from invokeai.backend.model_manager.configs.t2i_adapter import (
@@ -84,10 +99,12 @@ from invokeai.backend.model_manager.configs.textual_inversion import (
)
from invokeai.backend.model_manager.configs.unknown import Unknown_Config
from invokeai.backend.model_manager.configs.vae import (
    VAE_Checkpoint_Flux2_Config,
    VAE_Checkpoint_FLUX_Config,
    VAE_Checkpoint_SD1_Config,
    VAE_Checkpoint_SD2_Config,
    VAE_Checkpoint_SDXL_Config,
    VAE_Diffusers_Flux2_Config,
    VAE_Diffusers_SD1_Config,
    VAE_Diffusers_SDXL_Config,
)
@@ -137,29 +154,43 @@ AnyModelConfig = Annotated[
    Annotated[Main_Diffusers_SDXL_Config, Main_Diffusers_SDXL_Config.get_tag()],
    Annotated[Main_Diffusers_SDXLRefiner_Config, Main_Diffusers_SDXLRefiner_Config.get_tag()],
    Annotated[Main_Diffusers_SD3_Config, Main_Diffusers_SD3_Config.get_tag()],
    Annotated[Main_Diffusers_FLUX_Config, Main_Diffusers_FLUX_Config.get_tag()],
    Annotated[Main_Diffusers_Flux2_Config, Main_Diffusers_Flux2_Config.get_tag()],
    Annotated[Main_Diffusers_CogView4_Config, Main_Diffusers_CogView4_Config.get_tag()],
    Annotated[Main_Diffusers_ZImage_Config, Main_Diffusers_ZImage_Config.get_tag()],
    # Main (Pipeline) - checkpoint format
    # IMPORTANT: FLUX.2 must be checked BEFORE FLUX.1 because FLUX.2 has specific validation
    # that will reject FLUX.1 models, but FLUX.1 validation may incorrectly match FLUX.2 models
    Annotated[Main_Checkpoint_SD1_Config, Main_Checkpoint_SD1_Config.get_tag()],
    Annotated[Main_Checkpoint_SD2_Config, Main_Checkpoint_SD2_Config.get_tag()],
    Annotated[Main_Checkpoint_SDXL_Config, Main_Checkpoint_SDXL_Config.get_tag()],
    Annotated[Main_Checkpoint_SDXLRefiner_Config, Main_Checkpoint_SDXLRefiner_Config.get_tag()],
    Annotated[Main_Checkpoint_Flux2_Config, Main_Checkpoint_Flux2_Config.get_tag()],
    Annotated[Main_Checkpoint_FLUX_Config, Main_Checkpoint_FLUX_Config.get_tag()],
    Annotated[Main_Checkpoint_ZImage_Config, Main_Checkpoint_ZImage_Config.get_tag()],
    # Main (Pipeline) - quantized formats
    # IMPORTANT: FLUX.2 must be checked BEFORE FLUX.1 because FLUX.2 has specific validation
    # that will reject FLUX.1 models, but FLUX.1 validation may incorrectly match FLUX.2 models
    Annotated[Main_BnBNF4_FLUX_Config, Main_BnBNF4_FLUX_Config.get_tag()],
    Annotated[Main_GGUF_Flux2_Config, Main_GGUF_Flux2_Config.get_tag()],
    Annotated[Main_GGUF_FLUX_Config, Main_GGUF_FLUX_Config.get_tag()],
    Annotated[Main_GGUF_ZImage_Config, Main_GGUF_ZImage_Config.get_tag()],
    # VAE - checkpoint format
    Annotated[VAE_Checkpoint_SD1_Config, VAE_Checkpoint_SD1_Config.get_tag()],
    Annotated[VAE_Checkpoint_SD2_Config, VAE_Checkpoint_SD2_Config.get_tag()],
    Annotated[VAE_Checkpoint_SDXL_Config, VAE_Checkpoint_SDXL_Config.get_tag()],
    Annotated[VAE_Checkpoint_FLUX_Config, VAE_Checkpoint_FLUX_Config.get_tag()],
    Annotated[VAE_Checkpoint_Flux2_Config, VAE_Checkpoint_Flux2_Config.get_tag()],
    # VAE - diffusers format
    Annotated[VAE_Diffusers_SD1_Config, VAE_Diffusers_SD1_Config.get_tag()],
    Annotated[VAE_Diffusers_SDXL_Config, VAE_Diffusers_SDXL_Config.get_tag()],
    Annotated[VAE_Diffusers_Flux2_Config, VAE_Diffusers_Flux2_Config.get_tag()],
    # ControlNet - checkpoint format
    Annotated[ControlNet_Checkpoint_SD1_Config, ControlNet_Checkpoint_SD1_Config.get_tag()],
    Annotated[ControlNet_Checkpoint_SD2_Config, ControlNet_Checkpoint_SD2_Config.get_tag()],
    Annotated[ControlNet_Checkpoint_SDXL_Config, ControlNet_Checkpoint_SDXL_Config.get_tag()],
    Annotated[ControlNet_Checkpoint_FLUX_Config, ControlNet_Checkpoint_FLUX_Config.get_tag()],
    Annotated[ControlNet_Checkpoint_ZImage_Config, ControlNet_Checkpoint_ZImage_Config.get_tag()],
    # ControlNet - diffusers format
    Annotated[ControlNet_Diffusers_SD1_Config, ControlNet_Diffusers_SD1_Config.get_tag()],
    Annotated[ControlNet_Diffusers_SD2_Config, ControlNet_Diffusers_SD2_Config.get_tag()],
@@ -170,6 +201,7 @@ AnyModelConfig = Annotated[
    Annotated[LoRA_LyCORIS_SD2_Config, LoRA_LyCORIS_SD2_Config.get_tag()],
    Annotated[LoRA_LyCORIS_SDXL_Config, LoRA_LyCORIS_SDXL_Config.get_tag()],
    Annotated[LoRA_LyCORIS_FLUX_Config, LoRA_LyCORIS_FLUX_Config.get_tag()],
    Annotated[LoRA_LyCORIS_ZImage_Config, LoRA_LyCORIS_ZImage_Config.get_tag()],
    # LoRA - OMI format
    Annotated[LoRA_OMI_SDXL_Config, LoRA_OMI_SDXL_Config.get_tag()],
    Annotated[LoRA_OMI_FLUX_Config, LoRA_OMI_FLUX_Config.get_tag()],
@@ -178,11 +210,16 @@ AnyModelConfig = Annotated[
    Annotated[LoRA_Diffusers_SD2_Config, LoRA_Diffusers_SD2_Config.get_tag()],
    Annotated[LoRA_Diffusers_SDXL_Config, LoRA_Diffusers_SDXL_Config.get_tag()],
    Annotated[LoRA_Diffusers_FLUX_Config, LoRA_Diffusers_FLUX_Config.get_tag()],
    Annotated[LoRA_Diffusers_ZImage_Config, LoRA_Diffusers_ZImage_Config.get_tag()],
    # ControlLoRA - diffusers format
    Annotated[ControlLoRA_LyCORIS_FLUX_Config, ControlLoRA_LyCORIS_FLUX_Config.get_tag()],
    # T5 Encoder - all formats
    Annotated[T5Encoder_T5Encoder_Config, T5Encoder_T5Encoder_Config.get_tag()],
    Annotated[T5Encoder_BnBLLMint8_Config, T5Encoder_BnBLLMint8_Config.get_tag()],
    # Qwen3 Encoder
    Annotated[Qwen3Encoder_Qwen3Encoder_Config, Qwen3Encoder_Qwen3Encoder_Config.get_tag()],
    Annotated[Qwen3Encoder_Checkpoint_Config, Qwen3Encoder_Checkpoint_Config.get_tag()],
    Annotated[Qwen3Encoder_GGUF_Config, Qwen3Encoder_GGUF_Config.get_tag()],
    # TI - file format
    Annotated[TI_File_SD1_Config, TI_File_SD1_Config.get_tag()],
    Annotated[TI_File_SD2_Config, TI_File_SD2_Config.get_tag()],
@@ -333,7 +370,11 @@ class ModelConfigFactory:
         # For directories, do a quick file count check with early exit
         total_files = 0
         # Ignore hidden files and directories
-        paths_to_check = (p for p in path.rglob("*") if not p.name.startswith("."))
+        paths_to_check = (
+            p
+            for p in path.rglob("*")
+            if not p.name.startswith(".") and not any(part.startswith(".") for part in p.parts)
+        )
         for item in paths_to_check:
             if item.is_file():
                 total_files += 1
@@ -473,7 +514,9 @@ class ModelConfigFactory:
         # Now do any post-processing needed for specific model types/bases/etc.
         match config.type:
             case ModelType.Main:
-                config.default_settings = MainModelDefaultSettings.from_base(config.base)
+                # Pass variant if available (e.g., for Flux2 models)
+                variant = getattr(config, "variant", None)
+                config.default_settings = MainModelDefaultSettings.from_base(config.base, variant)
             case ModelType.ControlNet | ModelType.T2IAdapter | ModelType.ControlLoRa:
                 config.default_settings = ControlAdapterDefaultSettings.from_model_name(config.name)
             case ModelType.LoRA:
@@ -150,11 +150,16 @@ class LoRA_LyCORIS_Config_Base(LoRA_Config_Base):

     @classmethod
     def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
-        # First rule out ControlLoRA and Diffusers LoRA
+        # First rule out ControlLoRA
         flux_format = _get_flux_lora_format(mod)
         if flux_format in [FluxLoRAFormat.Control]:
             raise NotAMatchError("model looks like Control LoRA")

+        # If it's a recognized Flux LoRA format (Kohya, Diffusers, OneTrainer, AIToolkit, XLabs, etc.),
+        # it's valid and we skip the heuristic check
+        if flux_format is not None:
+            return
+
         # Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA.
         # Some main models have these keys, likely due to the creator merging in a LoRA.
         has_key_with_lora_prefix = state_dict_has_any_keys_starting_with(
@@ -217,6 +222,73 @@ class LoRA_LyCORIS_FLUX_Config(LoRA_LyCORIS_Config_Base, Config_Base):
    base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)


class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base):
    """Model config for Z-Image LoRA models in LyCORIS format."""

    base: Literal[BaseModelType.ZImage] = Field(default=BaseModelType.ZImage)

    @classmethod
    def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
        """Z-Image LoRAs have different key patterns than SD/SDXL LoRAs.

        Z-Image LoRAs use keys like:
        - diffusion_model.layers.X.attention.to_k.lora_down.weight (LyCORIS/Kohya format)
        - diffusion_model.layers.X.attention.to_k.lora_A.weight (PEFT format)
        - diffusion_model.layers.X.attention.to_k.dora_scale (DoRA scale)
        """
        state_dict = mod.load_state_dict()

        # Check for Z-Image-specific LoRA patterns
        has_z_image_lora_keys = state_dict_has_any_keys_starting_with(
            state_dict,
            {
                "diffusion_model.layers.",  # Z-Image S3-DiT layer pattern
            },
        )

        # Also check for LoRA weight suffixes (various formats)
        has_lora_suffix = state_dict_has_any_keys_ending_with(
            state_dict,
            {
                "lora_A.weight",
                "lora_B.weight",
                "lora_down.weight",
                "lora_up.weight",
                "dora_scale",
            },
        )

        if has_z_image_lora_keys and has_lora_suffix:
            return

        raise NotAMatchError("model does not match Z-Image LoRA heuristics")

    @classmethod
    def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
        """Z-Image LoRAs are identified by their diffusion_model.layers structure.

        Z-Image uses the S3-DiT architecture with layer names like:
        - diffusion_model.layers.0.attention.to_k.lora_A.weight
        - diffusion_model.layers.0.feed_forward.w1.lora_A.weight
        """
        state_dict = mod.load_state_dict()

        # Check for Z-Image transformer layer patterns.
        # Z-Image uses a diffusion_model.layers.X structure (unlike FLUX, which uses double_blocks/single_blocks).
        has_z_image_keys = state_dict_has_any_keys_starting_with(
            state_dict,
            {
                "diffusion_model.layers.",  # Z-Image S3-DiT layer pattern
            },
        )

        # If it looks like a Z-Image LoRA, return the ZImage base
        if has_z_image_keys:
            return BaseModelType.ZImage

        raise NotAMatchError("model does not look like a Z-Image LoRA")

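A minimal sketch of how the two checks in _validate_looks_like_lora interact, using fake state dicts. The helper functions below are illustrative stand-ins for the real state_dict_has_any_keys_* utilities, not code from this PR:

    def has_any_starting_with(sd: dict[str, object], prefixes: set[str]) -> bool:
        return any(k.startswith(tuple(prefixes)) for k in sd)

    def has_any_ending_with(sd: dict[str, object], suffixes: set[str]) -> bool:
        return any(k.endswith(tuple(suffixes)) for k in sd)

    fake_lora = {"diffusion_model.layers.0.attention.to_k.lora_A.weight": None}
    fake_main = {"diffusion_model.layers.0.attention.to_k.weight": None}

    # A LoRA needs both the Z-Image layer prefix AND a LoRA weight suffix:
    assert has_any_starting_with(fake_lora, {"diffusion_model.layers."})
    assert has_any_ending_with(fake_lora, {"lora_A.weight"})
    # A main model has the prefix but no LoRA suffix, so it is rejected:
    assert not has_any_ending_with(fake_main, {"lora_A.weight"})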
class ControlAdapter_Config_Base(ABC, BaseModel):
    default_settings: ControlAdapterDefaultSettings | None = Field(None)

@@ -320,3 +392,9 @@ class LoRA_Diffusers_SDXL_Config(LoRA_Diffusers_Config_Base, Config_Base):

class LoRA_Diffusers_FLUX_Config(LoRA_Diffusers_Config_Base, Config_Base):
    base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)


class LoRA_Diffusers_ZImage_Config(LoRA_Diffusers_Config_Base, Config_Base):
    """Model config for Z-Image LoRA models in Diffusers format."""

    base: Literal[BaseModelType.ZImage] = Field(default=BaseModelType.ZImage)

@@ -23,6 +23,7 @@ from invokeai.backend.model_manager.configs.identification_utils import (
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
    BaseModelType,
    Flux2VariantType,
    FluxVariantType,
    ModelFormat,
    ModelType,
@@ -52,7 +53,11 @@ class MainModelDefaultSettings(BaseModel):
    model_config = ConfigDict(extra="forbid")

    @classmethod
-   def from_base(cls, base: BaseModelType) -> Self | None:
+   def from_base(
+       cls,
+       base: BaseModelType,
+       variant: Flux2VariantType | FluxVariantType | ModelVariantType | None = None,
+   ) -> Self | None:
        match base:
            case BaseModelType.StableDiffusion1:
                return cls(width=512, height=512)
@@ -60,6 +65,16 @@ class MainModelDefaultSettings(BaseModel):
                return cls(width=768, height=768)
            case BaseModelType.StableDiffusionXL:
                return cls(width=1024, height=1024)
            case BaseModelType.ZImage:
                return cls(steps=9, cfg_scale=1.0, width=1024, height=1024)
            case BaseModelType.Flux2:
                # Different defaults depending on the variant
                if variant == Flux2VariantType.Klein9BBase:
                    # The undistilled base model needs more steps
                    return cls(steps=28, cfg_scale=1.0, width=1024, height=1024)
                else:
                    # Distilled models (Klein 4B, Klein 9B) use fewer steps
                    return cls(steps=4, cfg_scale=1.0, width=1024, height=1024)
            case _:
                # TODO(psyche): Do we want defaults for other base types?
                return None
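For reference, a sketch of how the variant-aware defaults surface to a caller. This is hypothetical usage under the diff's enum and field names, not code from this PR:

    from invokeai.backend.model_manager.taxonomy import BaseModelType, Flux2VariantType

    base_defaults = MainModelDefaultSettings.from_base(BaseModelType.Flux2, Flux2VariantType.Klein9BBase)
    assert base_defaults is not None and base_defaults.steps == 28  # undistilled: more steps

    klein_defaults = MainModelDefaultSettings.from_base(BaseModelType.Flux2, Flux2VariantType.Klein4B)
    assert klein_defaults is not None and klein_defaults.steps == 4  # distilled: few steps

    legacy_defaults = MainModelDefaultSettings.from_base(BaseModelType.StableDiffusion1)
    assert legacy_defaults is not None and legacy_defaults.width == 512  # variant arg is optional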
@@ -111,6 +126,47 @@ def _has_main_keys(state_dict: dict[str | int, Any]) -> bool:
    return False


def _has_z_image_keys(state_dict: dict[str | int, Any]) -> bool:
    """Check if the state dict contains Z-Image S3-DiT transformer keys.

    This function returns True only for Z-Image main models, not LoRAs.
    LoRAs are excluded by checking for LoRA-specific weight suffixes.
    """
    # Z-Image-specific keys that distinguish it from other models
    z_image_specific_keys = {
        "cap_embedder",  # Caption embedder - unique to Z-Image
        "context_refiner",  # Context refiner blocks
        "cap_pad_token",  # Caption padding token
    }

    # LoRA-specific suffixes - if present, this is a LoRA, not a main model
    lora_suffixes = (
        ".lora_down.weight",
        ".lora_up.weight",
        ".lora_A.weight",
        ".lora_B.weight",
        ".dora_scale",
    )

    for key in state_dict.keys():
        if isinstance(key, int):
            continue

        # If we find any LoRA-specific keys, this is not a main model
        if key.endswith(lora_suffixes):
            return False

        # Check for Z-Image-specific key prefixes.
        # Handle both direct keys (cap_embedder.0.weight) and
        # ComfyUI-style keys (model.diffusion_model.cap_embedder.0.weight).
        key_parts = key.split(".")
        for part in key_parts:
            if part in z_image_specific_keys:
                return True

    return False

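Note that _has_z_image_keys checks each key in iteration order, so a LoRA suffix encountered before any Z-Image marker short-circuits to False. A sketch with hypothetical state dicts (only the keys matter):

    z_image_main = {"cap_embedder.0.weight": None, "layers.0.attention.to_k.weight": None}
    z_image_lora = {"diffusion_model.layers.0.attention.to_k.lora_A.weight": None}
    mixed = {"layers.0.to_k.lora_A.weight": None, "cap_embedder.0.weight": None}

    assert _has_z_image_keys(z_image_main) is True   # Z-Image marker key, no LoRA suffix
    assert _has_z_image_keys(z_image_lora) is False  # LoRA suffix -> rejected immediately
    assert _has_z_image_keys(mixed) is False         # LoRA suffix seen before the marker -> rejected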

class Main_SD_Checkpoint_Config_Base(Checkpoint_Config_Base, Main_Config_Base):
    """Model config for main checkpoint models."""

@@ -225,6 +281,108 @@ class Main_Checkpoint_SDXLRefiner_Config(Main_SD_Checkpoint_Config_Base, Config_
    base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(default=BaseModelType.StableDiffusionXLRefiner)


def _is_flux2_model(state_dict: dict[str | int, Any]) -> bool:
    """Check if a state dict is a FLUX.2 model by examining context_embedder dimensions.

    FLUX.2 Klein uses a Qwen3 encoder with a larger context dimension:
    - FLUX.1: context_in_dim = 4096 (T5)
    - FLUX.2 Klein 4B: context_in_dim = 7680 (3×Qwen3-4B hidden size)
    - FLUX.2 Klein 9B: context_in_dim = 12288 (3×Qwen3-8B hidden size)

    Also checks for the FLUX.2-specific 32-channel latent space (in_channels=128 after packing).
    """
    # Check the context_embedder input dimension (most reliable).
    # Weight shape: [hidden_size, context_in_dim]
    for key in {"context_embedder.weight", "model.diffusion_model.context_embedder.weight"}:
        if key in state_dict:
            weight = state_dict[key]
            if hasattr(weight, "shape") and len(weight.shape) >= 2:
                context_in_dim = weight.shape[1]
                # FLUX.2 has context_in_dim > 4096 (Qwen3 vs T5)
                if context_in_dim > 4096:
                    return True

    # Also check in_channels - FLUX.2 uses 128 (32 latent channels × 4 packing)
    for key in {"img_in.weight", "model.diffusion_model.img_in.weight"}:
        if key in state_dict:
            in_channels = state_dict[key].shape[1]
            # FLUX.2 uses 128 in_channels (32 latent channels × 4)
            # FLUX.1 uses 64 in_channels (16 latent channels × 4)
            if in_channels == 128:
                return True

    return False

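A shape-only sketch of the detection logic above, with dummy tensors standing in for real weights (the 3072 hidden size is illustrative):

    import torch

    flux1_sd = {"context_embedder.weight": torch.zeros(3072, 4096)}  # T5 context dim
    flux2_sd = {"context_embedder.weight": torch.zeros(3072, 7680)}  # 3 × Qwen3-4B context dim

    assert _is_flux2_model(flux1_sd) is False  # 4096 is not > 4096
    assert _is_flux2_model(flux2_sd) is True   # Qwen3 context dim exceeds T5's

    # The in_channels fallback fires even when no context_embedder key is present:
    packed_sd = {"img_in.weight": torch.zeros(3072, 128)}  # 32 latent channels × 4 packing
    assert _is_flux2_model(packed_sd) is True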


def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | None:
    """Determine the FLUX.2 variant from a state dict.

    Distinguishes Klein 4B from Klein 9B based on the context embedding dimension:
    - Klein 4B: context_in_dim = 7680 (3 × Qwen3-4B hidden_size 2560)
    - Klein 9B: context_in_dim = 12288 (3 × Qwen3-8B hidden_size 4096)

    Note: Klein 9B Base (undistilled) also has context_in_dim = 12288 but is rare.
    We default to Klein9B (distilled) for all 9B models, since GGUF models may not
    include the guidance embedding keys needed to distinguish them.

    Supports both BFL-format (checkpoint) and diffusers-format keys:
    - BFL format: txt_in.weight (context embedder)
    - Diffusers format: context_embedder.weight
    """
    # Context dimensions for each variant
    KLEIN_4B_CONTEXT_DIM = 7680  # 3 × 2560
    KLEIN_9B_CONTEXT_DIM = 12288  # 3 × 4096

    # Check the context embedder to determine the variant.
    # Support both BFL format (txt_in.weight) and diffusers format (context_embedder.weight).
    context_keys = {
        # Diffusers format
        "context_embedder.weight",
        "model.diffusion_model.context_embedder.weight",
        # BFL format (used by checkpoint/GGUF models)
        "txt_in.weight",
        "model.diffusion_model.txt_in.weight",
    }
    for key in context_keys:
        if key in state_dict:
            weight = state_dict[key]
            # Handle GGUF quantized tensors, which use tensor_shape instead of shape
            if hasattr(weight, "tensor_shape"):
                shape = weight.tensor_shape
            elif hasattr(weight, "shape"):
                shape = weight.shape
            else:
                continue
            if len(shape) >= 2:
                context_in_dim = shape[1]
                # Determine the variant based on the context dimension
                if context_in_dim == KLEIN_9B_CONTEXT_DIM:
                    # Default to Klein9B (distilled) - the official/common 9B model
                    return Flux2VariantType.Klein9B
                elif context_in_dim == KLEIN_4B_CONTEXT_DIM:
                    return Flux2VariantType.Klein4B
                elif context_in_dim > 4096:
                    # Unknown FLUX.2 variant; default to 4B
                    return Flux2VariantType.Klein4B

    # Check in_channels as a backup - this can only confirm it's FLUX.2, not which variant
    for key in {"img_in.weight", "model.diffusion_model.img_in.weight"}:
        if key in state_dict:
            weight = state_dict[key]
            # Handle GGUF quantized tensors
            if hasattr(weight, "tensor_shape"):
                in_channels = weight.tensor_shape[1]
            elif hasattr(weight, "shape"):
                in_channels = weight.shape[1]
            else:
                continue
            if in_channels == 128:
                # It's FLUX.2, but we can't determine which Klein variant; default to 4B
                return Flux2VariantType.Klein4B

    return None

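The tensor_shape-before-shape probe is plain duck typing, so the same code path serves both eager tensors and GGUF-quantized ones. A sketch of why it works (FakeGGMLTensor is an illustration, not the real GGMLTensor API):

    import torch

    class FakeGGMLTensor:
        """Stand-in for a quantized tensor that exposes tensor_shape rather than a usable shape."""

        def __init__(self, *dims: int) -> None:
            self.tensor_shape = torch.Size(dims)

    def context_dim(weight) -> int | None:
        # Same probe order as _get_flux2_variant: prefer tensor_shape, fall back to shape.
        if hasattr(weight, "tensor_shape"):
            return weight.tensor_shape[1]
        if hasattr(weight, "shape"):
            return weight.shape[1]
        return None

    assert context_dim(torch.zeros(3072, 12288)) == 12288   # eager tensor -> Klein 9B dim
    assert context_dim(FakeGGMLTensor(3072, 7680)) == 7680  # GGUF-style tensor -> Klein 4B dim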


def _get_flux_variant(state_dict: dict[str | int, Any]) -> FluxVariantType | None:
    # FLUX model variant types are distinguished by input channels and the presence of certain keys.

@@ -298,8 +456,9 @@ class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Conf

    @classmethod
    def _validate_is_flux(cls, mod: ModelOnDisk) -> None:
+       state_dict = mod.load_state_dict()
        if not state_dict_has_any_keys_exact(
-           mod.load_state_dict(),
+           state_dict,
            {
                "double_blocks.0.img_attn.norm.key_norm.scale",
                "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
@@ -307,6 +466,10 @@ class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Conf
        ):
            raise NotAMatchError("state dict does not look like a FLUX checkpoint")

        # Exclude FLUX.2 models - they have their own config class
        if _is_flux2_model(state_dict):
            raise NotAMatchError("model is a FLUX.2 model, not FLUX.1")

    @classmethod
    def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
        # FLUX model variant types are distinguished by input channels and the presence of certain keys.
@@ -340,6 +503,68 @@ class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Conf
        raise NotAMatchError("state dict looks like GGUF quantized")

class Main_Checkpoint_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
    """Model config for FLUX.2 checkpoint models (e.g. Klein)."""

    format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
    base: Literal[BaseModelType.Flux2] = Field(default=BaseModelType.Flux2)

    variant: Flux2VariantType = Field()

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        raise_if_not_file(mod)

        raise_for_override_fields(cls, override_fields)

        cls._validate_looks_like_main_model(mod)

        cls._validate_is_flux2(mod)

        cls._validate_does_not_look_like_bnb_quantized(mod)

        cls._validate_does_not_look_like_gguf_quantized(mod)

        variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)

        return cls(**override_fields, variant=variant)

    @classmethod
    def _validate_is_flux2(cls, mod: ModelOnDisk) -> None:
        """Validate that this is a FLUX.2 model, not FLUX.1."""
        state_dict = mod.load_state_dict()
        if not _is_flux2_model(state_dict):
            raise NotAMatchError("state dict does not look like a FLUX.2 model")

    @classmethod
    def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
        state_dict = mod.load_state_dict()
        variant = _get_flux2_variant(state_dict)

        if variant is None:
            raise NotAMatchError("unable to determine FLUX.2 model variant from state dict")

        return variant

    @classmethod
    def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
        has_main_model_keys = _has_main_keys(mod.load_state_dict())
        if not has_main_model_keys:
            raise NotAMatchError("state dict does not look like a main model")

    @classmethod
    def _validate_does_not_look_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
        has_bnb_nf4_keys = _has_bnb_nf4_keys(mod.load_state_dict())
        if has_bnb_nf4_keys:
            raise NotAMatchError("state dict looks like bnb quantized nf4")

    @classmethod
    def _validate_does_not_look_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
        has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
        if has_ggml_tensors:
            raise NotAMatchError("state dict looks like GGUF quantized")

class Main_BnBNF4_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
    """Model config for main checkpoint models."""

@@ -407,6 +632,8 @@ class Main_GGUF_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Bas

        cls._validate_looks_like_gguf_quantized(mod)

        cls._validate_is_not_flux2(mod)

        variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)

        return cls(**override_fields, variant=variant)
@@ -437,6 +664,195 @@ class Main_GGUF_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Bas
        if not has_ggml_tensors:
            raise NotAMatchError("state dict does not look like GGUF quantized")

    @classmethod
    def _validate_is_not_flux2(cls, mod: ModelOnDisk) -> None:
        """Validate that this is NOT a FLUX.2 model."""
        state_dict = mod.load_state_dict()
        if _is_flux2_model(state_dict):
            raise NotAMatchError("model is a FLUX.2 model, not FLUX.1")

class Main_GGUF_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
    """Model config for GGUF-quantized FLUX.2 checkpoint models (e.g. Klein)."""

    base: Literal[BaseModelType.Flux2] = Field(default=BaseModelType.Flux2)
    format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)

    variant: Flux2VariantType = Field()

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        raise_if_not_file(mod)

        raise_for_override_fields(cls, override_fields)

        cls._validate_looks_like_main_model(mod)

        cls._validate_looks_like_gguf_quantized(mod)

        cls._validate_is_flux2(mod)

        variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)

        return cls(**override_fields, variant=variant)

    @classmethod
    def _validate_is_flux2(cls, mod: ModelOnDisk) -> None:
        """Validate that this is a FLUX.2 model, not FLUX.1."""
        state_dict = mod.load_state_dict()
        if not _is_flux2_model(state_dict):
            raise NotAMatchError("state dict does not look like a FLUX.2 model")

    @classmethod
    def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
        state_dict = mod.load_state_dict()
        variant = _get_flux2_variant(state_dict)

        if variant is None:
            raise NotAMatchError("unable to determine FLUX.2 model variant from state dict")

        return variant

    @classmethod
    def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
        has_main_model_keys = _has_main_keys(mod.load_state_dict())
        if not has_main_model_keys:
            raise NotAMatchError("state dict does not look like a main model")

    @classmethod
    def _validate_looks_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
        has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
        if not has_ggml_tensors:
            raise NotAMatchError("state dict does not look like GGUF quantized")

class Main_Diffusers_FLUX_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
    """Model config for FLUX.1 models in diffusers format."""

    base: Literal[BaseModelType.Flux] = Field(BaseModelType.Flux)
    variant: FluxVariantType = Field()

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        raise_if_not_dir(mod)

        raise_for_override_fields(cls, override_fields)

        # Check for FLUX-specific pipeline or transformer class names
        raise_for_class_name(
            common_config_paths(mod.path),
            {
                "FluxPipeline",
                "FluxFillPipeline",
                "FluxTransformer2DModel",
            },
        )

        variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)

        repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)

        return cls(
            **override_fields,
            variant=variant,
            repo_variant=repo_variant,
        )

    @classmethod
    def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
        """Determine the FLUX variant from the transformer config.

        FLUX variants are distinguished by:
        - in_channels: 64 for Dev/Schnell, 384 for DevFill
        - guidance_embeds: True for Dev, False for Schnell
        """
        transformer_config = get_config_dict_or_raise(mod.path / "transformer" / "config.json")

        in_channels = transformer_config.get("in_channels", 64)
        guidance_embeds = transformer_config.get("guidance_embeds", False)

        # DevFill has 384 input channels
        if in_channels == 384:
            return FluxVariantType.DevFill

        # Dev has guidance_embeds=True; Schnell has guidance_embeds=False
        if guidance_embeds:
            return FluxVariantType.Dev
        else:
            return FluxVariantType.Schnell

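The FLUX.1 variant decision reduces to two transformer-config fields. A sketch of the decision table, with dicts mimicking transformer/config.json (the helper is illustrative, not code from this PR):

    def flux_variant(config: dict) -> str:
        if config.get("in_channels", 64) == 384:
            return "DevFill"  # inpainting variant: extra conditioning channels
        return "Dev" if config.get("guidance_embeds", False) else "Schnell"

    assert flux_variant({"in_channels": 384}) == "DevFill"
    assert flux_variant({"in_channels": 64, "guidance_embeds": True}) == "Dev"
    assert flux_variant({"in_channels": 64}) == "Schnell"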


class Main_Diffusers_Flux2_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
    """Model config for FLUX.2 models in diffusers format (e.g. FLUX.2 Klein)."""

    base: Literal[BaseModelType.Flux2] = Field(BaseModelType.Flux2)
    variant: Flux2VariantType = Field()

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        raise_if_not_dir(mod)

        raise_for_override_fields(cls, override_fields)

        # Check for FLUX.2-specific pipeline class names
        raise_for_class_name(
            common_config_paths(mod.path),
            {
                "Flux2KleinPipeline",
            },
        )

        variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)

        repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)

        return cls(
            **override_fields,
            variant=variant,
            repo_variant=repo_variant,
        )

    @classmethod
    def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
        """Determine the FLUX.2 variant from the transformer config.

        FLUX.2 Klein uses a Qwen3 text encoder with a larger joint_attention_dim:
        - Klein 4B: joint_attention_dim = 7680 (3×Qwen3-4B hidden size)
        - Klein 9B/9B Base: joint_attention_dim = 12288 (3×Qwen3-8B hidden size)

        To distinguish Klein 9B (distilled) from Klein 9B Base (undistilled),
        we check guidance_embeds:
        - Klein 9B (distilled): guidance_embeds = False (guidance is "baked in" during distillation)
        - Klein 9B Base (undistilled): guidance_embeds = True (needs guidance at inference)

        Note: The official BFL Klein 9B model is the distilled version, with guidance_embeds=False.
        """
        KLEIN_4B_CONTEXT_DIM = 7680  # 3 × 2560
        KLEIN_9B_CONTEXT_DIM = 12288  # 3 × 4096

        transformer_config = get_config_dict_or_raise(mod.path / "transformer" / "config.json")

        joint_attention_dim = transformer_config.get("joint_attention_dim", 4096)
        guidance_embeds = transformer_config.get("guidance_embeds", False)

        # Determine the variant based on joint_attention_dim
        if joint_attention_dim == KLEIN_9B_CONTEXT_DIM:
            # Check guidance_embeds to distinguish distilled from undistilled:
            # Klein 9B (distilled): guidance_embeds = False (guidance is baked in)
            # Klein 9B Base (undistilled): guidance_embeds = True (needs guidance)
            if guidance_embeds:
                return Flux2VariantType.Klein9BBase
            else:
                return Flux2VariantType.Klein9B
        elif joint_attention_dim == KLEIN_4B_CONTEXT_DIM:
            return Flux2VariantType.Klein4B
        elif joint_attention_dim > 4096:
            # Unknown FLUX.2 variant; default to 4B
            return Flux2VariantType.Klein4B

        # Default to 4B
        return Flux2VariantType.Klein4B

class Main_SD_Diffusers_Config_Base(Diffusers_Config_Base, Main_Config_Base):
    prediction_type: SchedulerPredictionType = Field()
@@ -657,3 +1073,92 @@ class Main_Diffusers_CogView4_Config(Diffusers_Config_Base, Main_Config_Base, Co
            **override_fields,
            repo_variant=repo_variant,
        )

class Main_Diffusers_ZImage_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
    """Model config for Z-Image diffusers models (Z-Image-Turbo, Z-Image-Base, Z-Image-Edit)."""

    base: Literal[BaseModelType.ZImage] = Field(BaseModelType.ZImage)

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        raise_if_not_dir(mod)

        raise_for_override_fields(cls, override_fields)

        # This check implies the base type - no further validation needed.
        raise_for_class_name(
            common_config_paths(mod.path),
            {
                "ZImagePipeline",
            },
        )

        repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)

        return cls(
            **override_fields,
            repo_variant=repo_variant,
        )


class Main_Checkpoint_ZImage_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
    """Model config for Z-Image single-file checkpoint models (safetensors, etc)."""

    base: Literal[BaseModelType.ZImage] = Field(default=BaseModelType.ZImage)
    format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        raise_if_not_file(mod)

        raise_for_override_fields(cls, override_fields)

        cls._validate_looks_like_z_image_model(mod)

        cls._validate_does_not_look_like_gguf_quantized(mod)

        return cls(**override_fields)

    @classmethod
    def _validate_looks_like_z_image_model(cls, mod: ModelOnDisk) -> None:
        has_z_image_keys = _has_z_image_keys(mod.load_state_dict())
        if not has_z_image_keys:
            raise NotAMatchError("state dict does not look like a Z-Image model")

    @classmethod
    def _validate_does_not_look_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
        has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
        if has_ggml_tensors:
            raise NotAMatchError("state dict looks like GGUF quantized")


class Main_GGUF_ZImage_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
    """Model config for GGUF-quantized Z-Image transformer models."""

    base: Literal[BaseModelType.ZImage] = Field(default=BaseModelType.ZImage)
    format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        raise_if_not_file(mod)

        raise_for_override_fields(cls, override_fields)

        cls._validate_looks_like_z_image_model(mod)

        cls._validate_looks_like_gguf_quantized(mod)

        return cls(**override_fields)

    @classmethod
    def _validate_looks_like_z_image_model(cls, mod: ModelOnDisk) -> None:
        has_z_image_keys = _has_z_image_keys(mod.load_state_dict())
        if not has_z_image_keys:
            raise NotAMatchError("state dict does not look like a Z-Image model")

    @classmethod
    def _validate_looks_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
        has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
        if not has_ggml_tensors:
            raise NotAMatchError("state dict does not look like GGUF quantized")

invokeai/backend/model_manager/configs/qwen3_encoder.py (new file, 265 lines)
@@ -0,0 +1,265 @@
import json
from typing import Any, Literal, Optional, Self

from pydantic import Field

from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Config_Base
from invokeai.backend.model_manager.configs.identification_utils import (
    NotAMatchError,
    raise_for_class_name,
    raise_for_override_fields,
    raise_if_not_dir,
    raise_if_not_file,
)
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType, Qwen3VariantType
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor


def _has_qwen3_keys(state_dict: dict[str | int, Any]) -> bool:
    """Check if the state dict contains Qwen3 model keys.

    Supports both:
    - PyTorch/diffusers format: model.layers.0., model.embed_tokens.weight
    - GGUF/llama.cpp format: blk.0., token_embd.weight
    """
    # PyTorch/diffusers format indicators
    pytorch_indicators = ["model.layers.0.", "model.embed_tokens.weight"]
    # GGUF/llama.cpp format indicators
    gguf_indicators = ["blk.0.", "token_embd.weight"]

    for key in state_dict.keys():
        if isinstance(key, str):
            # Check PyTorch format
            for indicator in pytorch_indicators:
                if key.startswith(indicator) or key == indicator:
                    return True
            # Check GGUF format
            for indicator in gguf_indicators:
                if key.startswith(indicator) or key == indicator:
                    return True
    return False


def _has_ggml_tensors(state_dict: dict[str | int, Any]) -> bool:
    """Check if the state dict contains GGML tensors (GGUF quantized)."""
    return any(isinstance(v, GGMLTensor) for v in state_dict.values())


def _get_qwen3_variant_from_state_dict(state_dict: dict[str | int, Any]) -> Optional[Qwen3VariantType]:
    """Determine the Qwen3 variant (4B vs 8B) from a state dict based on hidden_size.

    The hidden_size can be read from the embed_tokens.weight tensor shape:
    - Qwen3 4B: hidden_size = 2560
    - Qwen3 8B: hidden_size = 4096

    For the GGUF format, the key is 'token_embd.weight'.
    For the PyTorch format, the key is 'model.embed_tokens.weight'.
    """
    # Hidden size thresholds
    QWEN3_4B_HIDDEN_SIZE = 2560
    QWEN3_8B_HIDDEN_SIZE = 4096

    # Try to find the embed_tokens weight
    embed_key = None
    for key in state_dict.keys():
        if isinstance(key, str):
            if key == "model.embed_tokens.weight" or key == "token_embd.weight":
                embed_key = key
                break

    if embed_key is None:
        return None

    tensor = state_dict[embed_key]

    # Get hidden_size from the tensor shape.
    # Shape is [vocab_size, hidden_size].
    if isinstance(tensor, GGMLTensor):
        # GGUF tensor
        if hasattr(tensor, "shape") and len(tensor.shape) >= 2:
            hidden_size = tensor.shape[1]
        else:
            return None
    elif hasattr(tensor, "shape"):
        # PyTorch tensor
        if len(tensor.shape) >= 2:
            hidden_size = tensor.shape[1]
        else:
            return None
    else:
        return None

    # Determine the variant based on hidden_size
    if hidden_size == QWEN3_4B_HIDDEN_SIZE:
        return Qwen3VariantType.Qwen3_4B
    elif hidden_size == QWEN3_8B_HIDDEN_SIZE:
        return Qwen3VariantType.Qwen3_8B
    else:
        # Unknown size; default to 4B (more common)
        return Qwen3VariantType.Qwen3_4B

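A shape-only sketch of the hidden-size lookup, with dummy embedding tables (the vocab size is illustrative):

    import torch

    qwen3_4b_sd = {"model.embed_tokens.weight": torch.zeros(151936, 2560)}  # PyTorch-style key
    qwen3_8b_sd = {"token_embd.weight": torch.zeros(151936, 4096)}          # GGUF-style key

    assert _get_qwen3_variant_from_state_dict(qwen3_4b_sd) == Qwen3VariantType.Qwen3_4B
    assert _get_qwen3_variant_from_state_dict(qwen3_8b_sd) == Qwen3VariantType.Qwen3_8B
    assert _get_qwen3_variant_from_state_dict({}) is None  # no embedding key -> undecidable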


class Qwen3Encoder_Checkpoint_Config(Checkpoint_Config_Base, Config_Base):
    """Configuration for single-file Qwen3 Encoder models (safetensors)."""

    base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
    type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
    format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
    variant: Qwen3VariantType = Field(description="Qwen3 model size variant (4B or 8B)")

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        raise_if_not_file(mod)

        raise_for_override_fields(cls, override_fields)

        cls._validate_looks_like_qwen3_model(mod)

        cls._validate_does_not_look_like_gguf_quantized(mod)

        # Determine the variant from the state dict
        variant = cls._get_variant_or_default(mod)

        return cls(variant=variant, **override_fields)

    @classmethod
    def _get_variant_or_default(cls, mod: ModelOnDisk) -> Qwen3VariantType:
        """Get the variant from the state dict, defaulting to 4B if unknown."""
        state_dict = mod.load_state_dict()
        variant = _get_qwen3_variant_from_state_dict(state_dict)
        return variant if variant is not None else Qwen3VariantType.Qwen3_4B

    @classmethod
    def _validate_looks_like_qwen3_model(cls, mod: ModelOnDisk) -> None:
        has_qwen3_keys = _has_qwen3_keys(mod.load_state_dict())
        if not has_qwen3_keys:
            raise NotAMatchError("state dict does not look like a Qwen3 model")

    @classmethod
    def _validate_does_not_look_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
        has_ggml = _has_ggml_tensors(mod.load_state_dict())
        if has_ggml:
            raise NotAMatchError("state dict looks like GGUF quantized")


class Qwen3Encoder_Qwen3Encoder_Config(Config_Base):
    """Configuration for Qwen3 Encoder models in a diffusers-like format.

    The model weights are expected to be in a folder called text_encoder inside the model directory,
    compatible with Qwen2VLForConditionalGeneration or similar architectures used by Z-Image.
    """

    base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
    type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
    format: Literal[ModelFormat.Qwen3Encoder] = Field(default=ModelFormat.Qwen3Encoder)
    variant: Qwen3VariantType = Field(description="Qwen3 model size variant (4B or 8B)")

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        raise_if_not_dir(mod)

        raise_for_override_fields(cls, override_fields)

        # Exclude full pipeline models - these should be matched as main models, not as bare Qwen3 encoders.
        # Full pipelines have model_index.json at the root (diffusers format) or a transformer subfolder.
        model_index_path = mod.path / "model_index.json"
        transformer_path = mod.path / "transformer"
        if model_index_path.exists() or transformer_path.exists():
            raise NotAMatchError(
                "directory looks like a full diffusers pipeline (has model_index.json or transformer folder), "
                "not a standalone Qwen3 encoder"
            )

        # Check for the text_encoder config - support both:
        # 1. Full model structure: model_root/text_encoder/config.json
        # 2. Standalone text_encoder download: model_root/config.json (when the text_encoder subfolder is downloaded separately)
        config_path_nested = mod.path / "text_encoder" / "config.json"
        config_path_direct = mod.path / "config.json"

        if config_path_nested.exists():
            expected_config_path = config_path_nested
        elif config_path_direct.exists():
            expected_config_path = config_path_direct
        else:
            raise NotAMatchError(
                f"unable to load config file(s): {{PosixPath('{config_path_nested}'): 'file does not exist'}}"
            )

        # Qwen3 uses Qwen2VLForConditionalGeneration or a similar class
        raise_for_class_name(
            expected_config_path,
            {
                "Qwen2VLForConditionalGeneration",
                "Qwen2ForCausalLM",
                "Qwen3ForCausalLM",
            },
        )

        # Determine the variant from the config.json hidden_size
        variant = cls._get_variant_from_config(expected_config_path)

        return cls(variant=variant, **override_fields)

    @classmethod
    def _get_variant_from_config(cls, config_path) -> Qwen3VariantType:
        """Get the variant from config.json based on hidden_size."""
        QWEN3_4B_HIDDEN_SIZE = 2560
        QWEN3_8B_HIDDEN_SIZE = 4096

        try:
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
            hidden_size = config.get("hidden_size")
            if hidden_size == QWEN3_8B_HIDDEN_SIZE:
                return Qwen3VariantType.Qwen3_8B
            elif hidden_size == QWEN3_4B_HIDDEN_SIZE:
                return Qwen3VariantType.Qwen3_4B
            else:
                # Default to 4B for unknown sizes
                return Qwen3VariantType.Qwen3_4B
        except (json.JSONDecodeError, OSError):
            return Qwen3VariantType.Qwen3_4B


class Qwen3Encoder_GGUF_Config(Checkpoint_Config_Base, Config_Base):
    """Configuration for GGUF-quantized Qwen3 Encoder models."""

    base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
    type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
    format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
    variant: Qwen3VariantType = Field(description="Qwen3 model size variant (4B or 8B)")

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        raise_if_not_file(mod)

        raise_for_override_fields(cls, override_fields)

        cls._validate_looks_like_qwen3_model(mod)

        cls._validate_looks_like_gguf_quantized(mod)

        # Determine the variant from the state dict
        variant = cls._get_variant_or_default(mod)

        return cls(variant=variant, **override_fields)

    @classmethod
    def _get_variant_or_default(cls, mod: ModelOnDisk) -> Qwen3VariantType:
        """Get the variant from the state dict, defaulting to 4B if unknown."""
        state_dict = mod.load_state_dict()
        variant = _get_qwen3_variant_from_state_dict(state_dict)
        return variant if variant is not None else Qwen3VariantType.Qwen3_4B

    @classmethod
    def _validate_looks_like_qwen3_model(cls, mod: ModelOnDisk) -> None:
        has_qwen3_keys = _has_qwen3_keys(mod.load_state_dict())
        if not has_qwen3_keys:
            raise NotAMatchError("state dict does not look like a Qwen3 model")

    @classmethod
    def _validate_looks_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
        has_ggml = _has_ggml_tensors(mod.load_state_dict())
        if not has_ggml:
            raise NotAMatchError("state dict does not look like GGUF quantized")

@@ -33,6 +33,25 @@ REGEX_TO_BASE: dict[str, BaseModelType] = {
}


def _is_flux2_vae(state_dict: dict[str | int, Any]) -> bool:
    """Check if a state dict is a FLUX.2 VAE (AutoencoderKLFlux2).

    A FLUX.2 VAE can be identified by:
    1. Batch normalization layers (bn.running_mean, bn.running_var) - unique to FLUX.2
    2. A 32-dimensional latent space (decoder.conv_in has 32 input channels)

    The FLUX.1 VAE has a 16-dimensional latent space and no BatchNorm layers.
    """
    # Check for the BatchNorm layer, which is unique to the FLUX.2 VAE
    has_bn = "bn.running_mean" in state_dict or "bn.running_var" in state_dict

    # Check for the 32-channel latent space (FLUX.2 has 32, FLUX.1 has 16)
    decoder_conv_in_key = "decoder.conv_in.weight"
    has_32_latent_channels = decoder_conv_in_key in state_dict and state_dict[decoder_conv_in_key].shape[1] == 32

    return has_bn or has_32_latent_channels

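A sketch of the two FLUX.2 VAE signals with dummy state dicts. Either signal alone is enough; the output-channel counts are illustrative:

    import torch

    flux1_vae = {"decoder.conv_in.weight": torch.zeros(512, 16, 3, 3)}  # 16 latent channels
    flux2_vae = {"decoder.conv_in.weight": torch.zeros(512, 32, 3, 3)}  # 32 latent channels
    flux2_bn = {"bn.running_mean": torch.zeros(32)}                     # BatchNorm marker alone

    assert _is_flux2_vae(flux1_vae) is False
    assert _is_flux2_vae(flux2_vae) is True
    assert _is_flux2_vae(flux2_bn) is True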


class VAE_Checkpoint_Config_Base(Checkpoint_Config_Base):
    """Model config for standalone VAE models."""

@@ -61,8 +80,9 @@ class VAE_Checkpoint_Config_Base(Checkpoint_Config_Base):

    @classmethod
    def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None:
+       state_dict = mod.load_state_dict()
        if not state_dict_has_any_keys_starting_with(
-           mod.load_state_dict(),
+           state_dict,
            {
                "encoder.conv_in",
                "decoder.conv_in",
@@ -70,9 +90,30 @@ class VAE_Checkpoint_Config_Base(Checkpoint_Config_Base):
        ):
            raise NotAMatchError("model does not match Checkpoint VAE heuristics")

        # Exclude FLUX.2 VAEs - they have their own config class
        if _is_flux2_vae(state_dict):
            raise NotAMatchError("model is a FLUX.2 VAE, not a standard VAE")

    @classmethod
    def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
        # Heuristic: VAEs of all architectures have a similar structure, so we have to guess.
        # First, try to identify the base by latent space dimensions (most reliable).
        state_dict = mod.load_state_dict()
        decoder_conv_in_key = "decoder.conv_in.weight"
        if decoder_conv_in_key in state_dict:
            latent_channels = state_dict[decoder_conv_in_key].shape[1]
            if latent_channels == 16:
                # The FLUX.1 VAE has a 16-dimensional latent space
                return BaseModelType.Flux
            elif latent_channels == 4:
                # SD/SDXL VAEs have a 4-dimensional latent space.
                # Try to distinguish SD1/SD2/SDXL by name, falling back to SD1.
                for regexp, base in REGEX_TO_BASE.items():
                    if re.search(regexp, mod.path.name, re.IGNORECASE):
                        return base
                # Default to SD1 if we can't determine the base from the name
                return BaseModelType.StableDiffusion1

        # Fallback: guess based on the name
        for regexp, base in REGEX_TO_BASE.items():
            if re.search(regexp, mod.path.name, re.IGNORECASE):
                return base
@@ -96,6 +137,44 @@ class VAE_Checkpoint_FLUX_Config(VAE_Checkpoint_Config_Base, Config_Base):
    base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)


class VAE_Checkpoint_Flux2_Config(Checkpoint_Config_Base, Config_Base):
    """Model config for FLUX.2 VAE checkpoint models (AutoencoderKLFlux2)."""

    type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
    format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
    base: Literal[BaseModelType.Flux2] = Field(default=BaseModelType.Flux2)

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        raise_if_not_file(mod)

        raise_for_override_fields(cls, override_fields)

        cls._validate_looks_like_vae(mod)

        cls._validate_is_flux2_vae(mod)

        return cls(**override_fields)

    @classmethod
    def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None:
        if not state_dict_has_any_keys_starting_with(
            mod.load_state_dict(),
            {
                "encoder.conv_in",
                "decoder.conv_in",
            },
        ):
            raise NotAMatchError("model does not match Checkpoint VAE heuristics")

    @classmethod
    def _validate_is_flux2_vae(cls, mod: ModelOnDisk) -> None:
        """Validate that this is a FLUX.2 VAE, not FLUX.1."""
        state_dict = mod.load_state_dict()
        if not _is_flux2_vae(state_dict):
            raise NotAMatchError("state dict does not look like a FLUX.2 VAE")

class VAE_Diffusers_Config_Base(Diffusers_Config_Base):
    """Model config for standalone VAE models (diffusers version)."""

@@ -161,3 +240,26 @@ class VAE_Diffusers_SD1_Config(VAE_Diffusers_Config_Base, Config_Base):

class VAE_Diffusers_SDXL_Config(VAE_Diffusers_Config_Base, Config_Base):
    base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)


class VAE_Diffusers_Flux2_Config(Diffusers_Config_Base, Config_Base):
    """Model config for FLUX.2 VAE models in diffusers format (AutoencoderKLFlux2)."""

    type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
    format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
    base: Literal[BaseModelType.Flux2] = Field(default=BaseModelType.Flux2)

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        raise_if_not_dir(mod)

        raise_for_override_fields(cls, override_fields)

        raise_for_class_name(
            common_config_paths(mod.path),
            {
                "AutoencoderKLFlux2",
            },
        )

        return cls(**override_fields)

@@ -55,6 +55,21 @@ def synchronized(method: Callable[..., Any]) -> Callable[..., Any]:
    return wrapper


def record_activity(method: Callable[..., Any]) -> Callable[..., Any]:
    """A decorator that records activity after a method completes successfully.

    Note: this decorator should only be applied to methods that already hold self._lock.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        result = method(self, *args, **kwargs)
        self._record_activity()
        return result

    return wrapper

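Decorator order matters here: @synchronized must be outermost so the lock is already held when record_activity's wrapper calls self._record_activity(). A toy sketch of the composition (illustrative, not the real ModelCache):

    import threading
    from functools import wraps

    def synchronized(method):
        @wraps(method)
        def wrapper(self, *args, **kwargs):
            with self._lock:
                return method(self, *args, **kwargs)
        return wrapper

    def record_activity(method):
        @wraps(method)
        def wrapper(self, *args, **kwargs):
            result = method(self, *args, **kwargs)
            self._record_activity()  # runs while the lock taken by @synchronized is held
            return result
        return wrapper

    class ToyCache:
        def __init__(self) -> None:
            self._lock = threading.RLock()
            self.activity_count = 0

        def _record_activity(self) -> None:
            self.activity_count += 1

        @synchronized      # outermost: acquires the lock first
        @record_activity   # inner: records activity while the lock is held
        def get(self, key: str) -> str:
            return key

    cache = ToyCache()
    cache.get("model-a")
    assert cache.activity_count == 1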


@dataclass
class CacheEntrySnapshot:
    cache_key: str
@@ -132,6 +147,7 @@ class ModelCache:
        storage_device: torch.device | str = "cpu",
        log_memory_usage: bool = False,
        logger: Optional[Logger] = None,
        keep_alive_minutes: float = 0,
    ):
        """Initialize the model RAM cache.

@@ -151,6 +167,7 @@ class ModelCache:
            snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
            behaviour.
        :param logger: InvokeAILogger to use (otherwise creates one)
        :param keep_alive_minutes: How long to keep models in the cache after last use (in minutes). 0 means keep indefinitely.
        """
        self._enable_partial_loading = enable_partial_loading
        self._keep_ram_copy_of_weights = keep_ram_copy_of_weights
@@ -182,6 +199,12 @@ class ModelCache:
        self._on_cache_miss_callbacks: set[CacheMissCallback] = set()
        self._on_cache_models_cleared_callbacks: set[CacheModelsClearedCallback] = set()

        # Keep-alive timeout support
        self._keep_alive_minutes = keep_alive_minutes
        self._last_activity_time: Optional[float] = None
        self._timeout_timer: Optional[threading.Timer] = None
        self._shutdown_event = threading.Event()

    def on_cache_hit(self, cb: CacheHitCallback) -> Callable[[], None]:
        self._on_cache_hit_callbacks.add(cb)

@@ -190,7 +213,7 @@ class ModelCache:

        return unsubscribe

-   def on_cache_miss(self, cb: CacheHitCallback) -> Callable[[], None]:
+   def on_cache_miss(self, cb: CacheMissCallback) -> Callable[[], None]:
        self._on_cache_miss_callbacks.add(cb)

        def unsubscribe() -> None:
@@ -217,8 +240,82 @@ class ModelCache:
    def stats(self, stats: CacheStats) -> None:
        """Set the CacheStats object for collecting cache statistics."""
        self._stats = stats
        # Populate the cache size in the stats object when it's set
        if self._stats is not None:
            self._stats.cache_size = self._ram_cache_size_bytes

    def _record_activity(self) -> None:
        """Record model activity and reset the timeout timer if configured.

        Note: this method should only be called when self._lock is already held.
        """
        if self._keep_alive_minutes <= 0:
            return

        self._last_activity_time = time.time()

        # Cancel any existing timer
        if self._timeout_timer is not None:
            self._timeout_timer.cancel()

        # Start a new timer
        timeout_seconds = self._keep_alive_minutes * 60
        self._timeout_timer = threading.Timer(timeout_seconds, self._on_timeout)
        # Set as a daemon thread so it doesn't prevent application shutdown
        self._timeout_timer.daemon = True
        self._timeout_timer.start()
        self._logger.debug(f"Model cache activity recorded. Timeout set to {self._keep_alive_minutes} minutes.")

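The cancel-then-recreate sequence is the standard threading.Timer reset idiom, since Timer objects are single-shot and cannot be restarted. A standalone sketch of the pattern:

    import threading

    class ResettableTimeout:
        """Single-shot timeout whose deadline is pushed out each time touch() is called."""

        def __init__(self, seconds: float, callback) -> None:
            self._seconds = seconds
            self._callback = callback
            self._timer: threading.Timer | None = None

        def touch(self) -> None:
            if self._timer is not None:
                self._timer.cancel()  # a Timer cannot be restarted, only replaced
            self._timer = threading.Timer(self._seconds, self._callback)
            self._timer.daemon = True  # don't block interpreter shutdown
            self._timer.start()

    # The callback fires only after `seconds` of continuous inactivity:
    idle = ResettableTimeout(60.0, lambda: print("idle for a minute"))
    idle.touch()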

    @synchronized
    @record_activity
    def _on_timeout(self) -> None:
        """Called when the keep-alive timeout expires. Clears the model cache."""
        if self._shutdown_event.is_set():
            return

        # Double-check whether there has been activity since the timer was set.
        # This handles the race condition where activity occurred just before the timer fired.
        if self._last_activity_time is not None and self._keep_alive_minutes > 0:
            elapsed_minutes = (time.time() - self._last_activity_time) / 60
            if elapsed_minutes < self._keep_alive_minutes:
                # Activity occurred; don't clear the cache
                self._logger.debug(
                    f"Model cache timeout fired but activity detected {elapsed_minutes:.2f} minutes ago. "
                    f"Skipping cache clear."
                )
                return

        # Check whether there are any unlocked models that can be cleared
        unlocked_models = [key for key, entry in self._cached_models.items() if not entry.is_locked]

        if len(unlocked_models) > 0:
            self._logger.info(
                f"Model cache keep-alive timeout of {self._keep_alive_minutes} minutes expired. "
                f"Clearing {len(unlocked_models)} unlocked model(s) from cache."
            )
            # Clear the cache by requesting a very large amount of space.
            # This is the same logic used by the "Clear Model Cache" button.
            # Using 1000 GB ensures all unlocked models are removed.
            self._make_room_internal(1000 * GB)
        elif len(self._cached_models) > 0:
            # All models are locked; don't log at info level
            self._logger.debug(
                f"Model cache timeout fired but all {len(self._cached_models)} model(s) are locked. "
                f"Skipping cache clear."
            )
        else:
            self._logger.debug("Model cache timeout fired but cache is already empty.")


    @synchronized
    def shutdown(self) -> None:
        """Shut down the model cache, cancelling any pending timers."""
        self._shutdown_event.set()
        if self._timeout_timer is not None:
            self._timeout_timer.cancel()
            self._timeout_timer = None

    @synchronized
    @record_activity
    def put(self, key: str, model: AnyModel) -> None:
        """Add a model to the cache."""
        if key in self._cached_models:
@@ -228,7 +325,7 @@ class ModelCache:
            return

        size = calc_model_size_by_data(self._logger, model)
-       self.make_room(size)
+       self._make_room_internal(size)

        # Inject custom modules into the model.
        if isinstance(model, torch.nn.Module):
@@ -272,6 +369,7 @@ class ModelCache:
        return overview

    @synchronized
    @record_activity
    def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
        """Retrieve a model from the cache.

@@ -309,9 +407,11 @@ class ModelCache:
        self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
        for cb in self._on_cache_hit_callbacks:
            cb(model_key=key, cache_snapshot=self._get_cache_snapshot())

        return cache_entry

    @synchronized
    @record_activity
    def lock(self, cache_entry: CacheRecord, working_mem_bytes: Optional[int]) -> None:
        """Lock a model for use and move it into VRAM."""
        if cache_entry.key not in self._cached_models:
@@ -348,6 +448,7 @@ class ModelCache:
        self._log_cache_state()

    @synchronized
    @record_activity
    def unlock(self, cache_entry: CacheRecord) -> None:
        """Unlock a model."""
        if cache_entry.key not in self._cached_models:
@@ -691,6 +792,10 @@ class ModelCache:
        external references to the model, there's nothing that the cache can do about it, and those models will not be
        garbage-collected.
        """
        self._make_room_internal(bytes_needed)

    def _make_room_internal(self, bytes_needed: int) -> None:
        """Internal implementation of make_room(). Assumes the lock is already held."""
        self._logger.debug(f"Making room for {bytes_needed / MB:.2f}MB of RAM.")
        self._log_cache_state(title="Before dropping models:")

@@ -0,0 +1,40 @@
import torch
from diffusers.models.normalization import RMSNorm as DiffusersRMSNorm

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
    CustomModuleMixin,
)


class CustomDiffusersRMSNorm(DiffusersRMSNorm, CustomModuleMixin):
    """Custom wrapper for diffusers RMSNorm that supports device autocasting for partial model loading."""

    def _autocast_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, hidden_states.device) if self.weight is not None else None
        bias = cast_to_device(self.bias, hidden_states.device) if self.bias is not None else None

        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

        if weight is not None:
            # Convert to half precision if necessary
            if weight.dtype in [torch.float16, torch.bfloat16]:
                hidden_states = hidden_states.to(weight.dtype)
            hidden_states = hidden_states * weight
            if bias is not None:
                hidden_states = hidden_states + bias
        else:
            hidden_states = hidden_states.to(input_dtype)

        return hidden_states

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if len(self._patches_and_weights) > 0:
            raise RuntimeError("DiffusersRMSNorm layers do not support patches")

        if self._device_autocasting_enabled:
            return self._autocast_forward(hidden_states)
        else:
            return super().forward(hidden_states)
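A quick numeric check of the RMSNorm math duplicated in _autocast_forward, against a reference formula. This sketch exercises only the math, not the device-casting path:

    import torch

    def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # RMSNorm: x / sqrt(mean(x^2) + eps), then a learned per-channel scale
        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        return (x * torch.rsqrt(variance + eps)) * weight

    x = torch.randn(2, 8)
    out = rms_norm_reference(x, torch.ones(8))
    # Each row now has (approximately) unit root-mean-square:
    assert torch.allclose(out.pow(2).mean(-1).sqrt(), torch.ones(2), atol=1e-3)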
@@ -0,0 +1,25 @@
import torch
import torch.nn.functional as F

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
    CustomModuleMixin,
)


class CustomLayerNorm(torch.nn.LayerNorm, CustomModuleMixin):
    """Custom wrapper for torch.nn.LayerNorm that supports device autocasting for partial model loading."""

    def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
        weight = cast_to_device(self.weight, input.device) if self.weight is not None else None
        bias = cast_to_device(self.bias, input.device) if self.bias is not None else None
        return F.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if len(self._patches_and_weights) > 0:
            raise RuntimeError("LayerNorm layers do not support patches")

        if self._device_autocasting_enabled:
            return self._autocast_forward(input)
        else:
            return super().forward(input)
@@ -1,14 +1,18 @@
from typing import TypeVar

import torch
from diffusers.models.normalization import RMSNorm as DiffusersRMSNorm

-from invokeai.backend.flux.modules.layers import RMSNorm
+from invokeai.backend.flux.modules.layers import RMSNorm as FluxRMSNorm
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_conv1d import (
    CustomConv1d,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_conv2d import (
    CustomConv2d,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_diffusers_rms_norm import (
    CustomDiffusersRMSNorm,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_embedding import (
    CustomEmbedding,
)
@@ -18,6 +22,9 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custo
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_group_norm import (
    CustomGroupNorm,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_layer_norm import (
    CustomLayerNorm,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import (
    CustomLinear,
)
@@ -31,7 +38,9 @@ AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]]
    torch.nn.Conv2d: CustomConv2d,
    torch.nn.GroupNorm: CustomGroupNorm,
    torch.nn.Embedding: CustomEmbedding,
-   RMSNorm: CustomFluxRMSNorm,
+   torch.nn.LayerNorm: CustomLayerNorm,
+   FluxRMSNorm: CustomFluxRMSNorm,
+   DiffusersRMSNorm: CustomDiffusersRMSNorm,
}
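For context on how this mapping is consumed: the custom classes subclass the originals, so swapping a module's class changes forward() behaviour without touching its parameters. A toy sketch of that mechanism, which is an assumption for illustration (the real swap helper lives elsewhere in torch_module_autocast):

    import torch

    # Toy stand-ins for the mapping above (assumed mechanism, illustration only):
    class CustomLinearDemo(torch.nn.Linear):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # A real custom module would cast weights to x.device here
            return super().forward(x)

    DEMO_MAPPING = {torch.nn.Linear: CustomLinearDemo}

    def apply_custom_layers(module: torch.nn.Module) -> None:
        # Reassigning __class__ swaps behaviour without copying parameters
        for child in module.modules():
            custom = DEMO_MAPPING.get(type(child))
            if custom is not None:
                child.__class__ = custom

    model = torch.nn.Sequential(torch.nn.Linear(4, 4))
    apply_custom_layers(model)
    assert isinstance(model[0], CustomLinearDemo)  # weights unchanged, class swapped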

try:
@@ -45,12 +45,13 @@ class CogView4DiffusersModel(GenericDiffusersLoader):
                model_path,
                torch_dtype=dtype,
                variant=variant,
                local_files_only=True,
            )
        except OSError as e:
            if variant and "no file named" in str(
                e
            ):  # try without the variant, just in case user's preferences changed
-               result = load_class.from_pretrained(model_path, torch_dtype=dtype)
+               result = load_class.from_pretrained(model_path, torch_dtype=dtype, local_files_only=True)
            else:
                raise e

(File diff suppressed because it is too large.)
@@ -37,12 +37,14 @@ class GenericDiffusersLoader(ModelLoader):
        repo_variant = config.repo_variant if isinstance(config, Diffusers_Config_Base) else None
        variant = repo_variant.value if repo_variant else None
        try:
-           result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant)
+           result: AnyModel = model_class.from_pretrained(
+               model_path, torch_dtype=self._torch_dtype, variant=variant, local_files_only=True
+           )
        except OSError as e:
            if variant and "no file named" in str(
                e
            ):  # try without the variant, just in case user's preferences changed
-               result = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype)
+               result = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, local_files_only=True)
            else:
                raise e
        return result

@@ -41,8 +41,13 @@ from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_u
     is_state_dict_likely_in_flux_onetrainer_format,
     lora_model_from_flux_onetrainer_state_dict,
 )
+from invokeai.backend.patches.lora_conversions.flux_xlabs_lora_conversion_utils import (
+    is_state_dict_likely_in_flux_xlabs_format,
+    lora_model_from_flux_xlabs_state_dict,
+)
 from invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
 from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
+from invokeai.backend.patches.lora_conversions.z_image_lora_conversion_utils import lora_model_from_z_image_state_dict
 
 
 @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.LoRA, format=ModelFormat.OMI)
@@ -117,6 +122,8 @@ class LoRALoader(ModelLoader):
                 model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
             elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict=state_dict):
                 model = lora_model_from_flux_aitoolkit_state_dict(state_dict=state_dict)
+            elif is_state_dict_likely_in_flux_xlabs_format(state_dict=state_dict):
+                model = lora_model_from_flux_xlabs_state_dict(state_dict=state_dict)
             else:
                 raise ValueError("LoRA model is in unsupported FLUX format")
         else:
@@ -124,6 +131,10 @@ class LoRALoader(ModelLoader):
         elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
             # Currently, we don't apply any conversions for SD1 and SD2 LoRA models.
             model = lora_model_from_sd_state_dict(state_dict=state_dict)
+        elif self._model_base == BaseModelType.ZImage:
+            # Z-Image LoRAs use diffusers PEFT format with transformer and/or Qwen3 encoder layers.
+            # We set alpha=None to use rank as alpha (common default).
+            model = lora_model_from_z_image_state_dict(state_dict=state_dict, alpha=None)
         else:
             raise ValueError(f"Unsupported LoRA base model: {self._model_base}")
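The dispatch above relies on key-shape heuristics: each is_state_dict_likely_in_*_format helper decides a checkpoint's origin from its tensor names alone, since LoRA files carry no format metadata. An illustrative sketch of such a detector (the key pattern shown is an assumption for illustration; the real detectors in lora_conversions check more cases):

import torch

def is_state_dict_likely_in_flux_xlabs_format_sketch(state_dict: dict[str, torch.Tensor]) -> bool:
    # Hypothetical heuristic: xlabs FLUX LoRAs name their weights after attention
    # processors, e.g. "double_blocks.0.processor.proj_lora1.down.weight".
    # Checking every key, not just one, keeps false positives from mixed or
    # renamed checkpoints low.
    if not state_dict:
        return False
    return all(k.startswith("double_blocks.") and ".processor." in k for k in state_dict)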
@@ -38,5 +38,6 @@ class OnnyxDiffusersModel(GenericDiffusersLoader):
             model_path,
             torch_dtype=self._torch_dtype,
             variant=variant,
+            local_files_only=True,
         )
         return result
@@ -80,12 +80,13 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
                 model_path,
                 torch_dtype=self._torch_dtype,
                 variant=variant,
+                local_files_only=True,
             )
         except OSError as e:
             if variant and "no file named" in str(
                 e
             ):  # try without the variant, just in case user's preferences changed
-                result = load_class.from_pretrained(model_path, torch_dtype=self._torch_dtype)
+                result = load_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, local_files_only=True)
             else:
                 raise e
invokeai/backend/model_manager/load/model_loaders/z_image.py (new file, 1062 lines)
File diff suppressed because it is too large
@@ -10,7 +10,7 @@ import onnxruntime as ort
 import torch
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers.scheduling_utils import SchedulerMixin
-from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast
+from transformers import CLIPTokenizer, PreTrainedTokenizerBase, T5Tokenizer, T5TokenizerFast
 
 from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
 from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
@@ -73,6 +73,10 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
         # relative to the text encoder that it's used with, so shouldn't matter too much, but we should fix this at some
         # point.
         return len(model)
+    elif isinstance(model, PreTrainedTokenizerBase):
+        # Catch-all for other tokenizer types (e.g., Qwen2Tokenizer, Qwen3Tokenizer).
+        # Tokenizers are small relative to models, so returning 0 is acceptable.
+        return 0
     else:
         # TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
         # supported model types.
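For context, calc_model_size_by_data is a type dispatch over everything the model cache may hold; the hunk above adds a tokenizer catch-all that deliberately reports 0 bytes. A simplified sketch of the dispatch (assuming the torch-module branch sums parameter and buffer bytes, as is conventional; not the full set of branches):

import logging

import torch
from transformers import PreTrainedTokenizerBase

def calc_model_size_sketch(logger: logging.Logger, model) -> int:
    if isinstance(model, torch.nn.Module):
        # Sum the bytes of all parameters and buffers.
        param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
        buffer_bytes = sum(b.numel() * b.element_size() for b in model.buffers())
        return param_bytes + buffer_bytes
    if isinstance(model, PreTrainedTokenizerBase):
        return 0  # tokenizers are negligibly small next to their models
    logger.warning("Unrecognized model type for size estimation: %s", type(model))
    return 0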
@@ -156,6 +160,7 @@ def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, var
         (".msgpack",),  # flax
         (".ckpt",),  # tf
         (".h5",),  # tf2
+        (".gguf",),  # gguf quantized
     ]
 
     for file_format in formats:
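calc_model_size_by_fs takes the other route: with no instantiated model available, it prices a model directory by its on-disk weight files, trying formats in preference order; the hunk registers .gguf so GGUF-quantized models are sized correctly. A rough sketch of that approach (hypothetical calc_size_by_fs_sketch; the real function also handles subfolders and variants):

from pathlib import Path

def calc_size_by_fs_sketch(model_path: Path) -> int:
    # Try weight formats in preference order and sum the bytes of the first
    # format actually present in the directory.
    formats = [
        (".safetensors",),
        (".bin",),
        (".gguf",),  # newly recognized gguf quantized weights
    ]
    files = [f for f in model_path.iterdir() if f.is_file()]
    for extensions in formats:
        matching = [f for f in files if f.suffix in extensions]
        if matching:
            return sum(f.stat().st_size for f in matching)
    return 0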
@@ -95,13 +95,15 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
         self,
         variant: Optional[ModelRepoVariant] = None,
         subfolder: Optional[Path] = None,
+        subfolders: Optional[List[Path]] = None,
         session: Optional[Session] = None,
     ) -> List[RemoteModelFile]:
         """
-        Return list of downloadable files, filtering by variant and subfolder, if any.
+        Return list of downloadable files, filtering by variant and subfolder(s), if any.
 
         :param variant: Return model files needed to reconstruct the indicated variant
-        :param subfolder: Return model files from the designated subfolder only
+        :param subfolder: Return model files from the designated subfolder only (deprecated, use subfolders)
+        :param subfolders: Return model files from the designated subfolders
         :param session: A request.Session object used for internet-free testing
 
         Note that there is special variant-filtering behavior here:
@@ -111,10 +113,15 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
         session = session or Session()
         configure_http_backend(backend_factory=lambda: session)  # used in testing
 
-        paths = filter_files([x.path for x in self.files], variant, subfolder)  # all files in the model
-        prefix = f"{subfolder}/" if subfolder else ""
+        paths = filter_files([x.path for x in self.files], variant, subfolder, subfolders)  # all files in the model
+
+        # Determine prefix for model_index.json check - only applies for single subfolder
+        prefix = ""
+        if subfolder and not subfolders:
+            prefix = f"{subfolder}/"
+
         # the next step reads model_index.json to determine which subdirectories belong
-        # to the model
+        # to the model (only for single subfolder case)
         if Path(f"{prefix}model_index.json") in paths:
             url = hf_hub_url(self.id, filename="model_index.json", subfolder=str(subfolder) if subfolder else None)
             resp = session.get(url)
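The subtlety in the hunk above is that the model_index.json probe only makes sense for the legacy single-subfolder argument; with the new subfolders list the prefix stays empty. A small sketch of just that rule (hypothetical model_index_prefix helper):

from pathlib import Path
from typing import List, Optional

def model_index_prefix(subfolder: Optional[Path], subfolders: Optional[List[Path]]) -> str:
    # Mirrors the rule above: prefix the model_index.json lookup only when
    # exactly one subfolder is requested via the deprecated argument.
    if subfolder and not subfolders:
        return f"{subfolder}/"
    return ""

assert model_index_prefix(Path("transformer"), None) == "transformer/"
assert model_index_prefix(None, [Path("vae"), Path("text_encoder")]) == ""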
@@ -84,6 +84,9 @@ class ModelOnDisk:
 
         path = self.resolve_weight_file(path)
 
+        if path in self._state_dict_cache:
+            return self._state_dict_cache[path]
+
         with SilenceWarnings():
             if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
                 scan_result = scan_file_path(path)
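The hunk above memoizes state-dict loads per resolved weight path, so repeated probing of the same file during model identification does not re-read it from disk. A self-contained sketch of the same idea (hypothetical class; the real ModelOnDisk also routes through a malware scan and format-specific loaders):

from pathlib import Path

import torch

class StateDictCacheSketch:
    # Load each weight file at most once, then serve repeat lookups from an
    # in-memory dict keyed by the resolved path.
    def __init__(self) -> None:
        self._state_dict_cache: dict[Path, dict[str, torch.Tensor]] = {}

    def load_state_dict(self, path: Path) -> dict[str, torch.Tensor]:
        if path in self._state_dict_cache:
            return self._state_dict_cache[path]
        state_dict = torch.load(path, map_location="cpu", weights_only=True)
        self._state_dict_cache[path] = state_dict
        return state_dict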
@@ -690,6 +690,178 @@ flux_fill = StarterModel(
 )
 # endregion
 
+# region FLUX.2 Klein
+flux2_vae = StarterModel(
+    name="FLUX.2 VAE",
+    base=BaseModelType.Flux2,
+    source="black-forest-labs/FLUX.2-klein-4B::vae",
+    description="FLUX.2 VAE (16-channel, same architecture as FLUX.1 VAE). ~335MB",
+    type=ModelType.VAE,
+)
+
+flux2_klein_qwen3_4b_encoder = StarterModel(
+    name="FLUX.2 Klein Qwen3 4B Encoder",
+    base=BaseModelType.Any,
+    source="black-forest-labs/FLUX.2-klein-4B::text_encoder+tokenizer",
+    description="Qwen3 4B text encoder for FLUX.2 Klein 4B (also compatible with Z-Image). ~8GB",
+    type=ModelType.Qwen3Encoder,
+)
+
+flux2_klein_qwen3_8b_encoder = StarterModel(
+    name="FLUX.2 Klein Qwen3 8B Encoder",
+    base=BaseModelType.Any,
+    source="black-forest-labs/FLUX.2-klein-9B::text_encoder+tokenizer",
+    description="Qwen3 8B text encoder for FLUX.2 Klein 9B models. ~16GB",
+    type=ModelType.Qwen3Encoder,
+)
+
+flux2_klein_4b = StarterModel(
+    name="FLUX.2 Klein 4B (Diffusers)",
+    base=BaseModelType.Flux2,
+    source="black-forest-labs/FLUX.2-klein-4B",
+    description="FLUX.2 Klein 4B in Diffusers format - includes transformer, VAE and Qwen3 encoder. ~10GB",
+    type=ModelType.Main,
+)
+
+flux2_klein_4b_single = StarterModel(
+    name="FLUX.2 Klein 4B",
+    base=BaseModelType.Flux2,
+    source="https://huggingface.co/black-forest-labs/FLUX.2-klein-4B/resolve/main/flux-2-klein-4b.safetensors",
+    description="FLUX.2 Klein 4B standalone transformer. Installs with VAE and Qwen3 4B encoder. ~8GB",
+    type=ModelType.Main,
+    dependencies=[flux2_vae, flux2_klein_qwen3_4b_encoder],
+)
+
+flux2_klein_4b_fp8 = StarterModel(
+    name="FLUX.2 Klein 4B (FP8)",
+    base=BaseModelType.Flux2,
+    source="https://huggingface.co/black-forest-labs/FLUX.2-klein-4b-fp8/resolve/main/flux-2-klein-4b-fp8.safetensors",
+    description="FLUX.2 Klein 4B FP8 quantized - smaller and faster. Installs with VAE and Qwen3 4B encoder. ~4GB",
+    type=ModelType.Main,
+    dependencies=[flux2_vae, flux2_klein_qwen3_4b_encoder],
+)
+
+flux2_klein_9b = StarterModel(
+    name="FLUX.2 Klein 9B (Diffusers)",
+    base=BaseModelType.Flux2,
+    source="black-forest-labs/FLUX.2-klein-9B",
+    description="FLUX.2 Klein 9B in Diffusers format - includes transformer, VAE and Qwen3 encoder. ~20GB",
+    type=ModelType.Main,
+)
+
+flux2_klein_9b_fp8 = StarterModel(
+    name="FLUX.2 Klein 9B (FP8)",
+    base=BaseModelType.Flux2,
+    source="https://huggingface.co/black-forest-labs/FLUX.2-klein-9b-fp8/resolve/main/flux-2-klein-9b-fp8.safetensors",
+    description="FLUX.2 Klein 9B FP8 quantized - more efficient than full precision. Installs with VAE and Qwen3 8B encoder. ~9.5GB",
+    type=ModelType.Main,
+    dependencies=[flux2_vae, flux2_klein_qwen3_8b_encoder],
+)
+
+flux2_klein_4b_gguf_q4 = StarterModel(
+    name="FLUX.2 Klein 4B (GGUF Q4)",
+    base=BaseModelType.Flux2,
+    source="https://huggingface.co/unsloth/FLUX.2-klein-4B-GGUF/resolve/main/flux-2-klein-4b-Q4_K_M.gguf",
+    description="FLUX.2 Klein 4B GGUF Q4_K_M quantized - runs on 6-8GB VRAM. Installs with VAE and Qwen3 4B encoder. ~2.6GB",
+    type=ModelType.Main,
+    format=ModelFormat.GGUFQuantized,
+    dependencies=[flux2_vae, flux2_klein_qwen3_4b_encoder],
+)
+
+flux2_klein_4b_gguf_q8 = StarterModel(
+    name="FLUX.2 Klein 4B (GGUF Q8)",
+    base=BaseModelType.Flux2,
+    source="https://huggingface.co/unsloth/FLUX.2-klein-4B-GGUF/resolve/main/flux-2-klein-4b-Q8_0.gguf",
+    description="FLUX.2 Klein 4B GGUF Q8_0 quantized - higher quality than Q4. Installs with VAE and Qwen3 4B encoder. ~4.3GB",
+    type=ModelType.Main,
+    format=ModelFormat.GGUFQuantized,
+    dependencies=[flux2_vae, flux2_klein_qwen3_4b_encoder],
+)
+
+flux2_klein_9b_gguf_q4 = StarterModel(
+    name="FLUX.2 Klein 9B (GGUF Q4)",
+    base=BaseModelType.Flux2,
+    source="https://huggingface.co/unsloth/FLUX.2-klein-9B-GGUF/resolve/main/flux-2-klein-9b-Q4_K_M.gguf",
+    description="FLUX.2 Klein 9B GGUF Q4_K_M quantized - runs on 12GB+ VRAM. Installs with VAE and Qwen3 8B encoder. ~5.8GB",
+    type=ModelType.Main,
+    format=ModelFormat.GGUFQuantized,
+    dependencies=[flux2_vae, flux2_klein_qwen3_8b_encoder],
+)
+
+flux2_klein_9b_gguf_q8 = StarterModel(
+    name="FLUX.2 Klein 9B (GGUF Q8)",
+    base=BaseModelType.Flux2,
+    source="https://huggingface.co/unsloth/FLUX.2-klein-9B-GGUF/resolve/main/flux-2-klein-9b-Q8_0.gguf",
+    description="FLUX.2 Klein 9B GGUF Q8_0 quantized - higher quality than Q4. Installs with VAE and Qwen3 8B encoder. ~10GB",
+    type=ModelType.Main,
+    format=ModelFormat.GGUFQuantized,
+    dependencies=[flux2_vae, flux2_klein_qwen3_8b_encoder],
+)
+# endregion
+
+# region Z-Image
+z_image_qwen3_encoder = StarterModel(
+    name="Z-Image Qwen3 Text Encoder",
+    base=BaseModelType.Any,
+    source="Tongyi-MAI/Z-Image-Turbo::text_encoder+tokenizer",
+    description="Qwen3 4B text encoder with tokenizer for Z-Image (full precision). ~8GB",
+    type=ModelType.Qwen3Encoder,
+)
+
+z_image_qwen3_encoder_quantized = StarterModel(
+    name="Z-Image Qwen3 Text Encoder (quantized)",
+    base=BaseModelType.Any,
+    source="https://huggingface.co/worstplayer/Z-Image_Qwen_3_4b_text_encoder_GGUF/resolve/main/Qwen_3_4b-Q6_K.gguf",
+    description="Qwen3 4B text encoder for Z-Image quantized to GGUF Q6_K format. ~3.3GB",
+    type=ModelType.Qwen3Encoder,
+    format=ModelFormat.GGUFQuantized,
+)
+
+z_image_turbo = StarterModel(
+    name="Z-Image Turbo",
+    base=BaseModelType.ZImage,
+    source="Tongyi-MAI/Z-Image-Turbo",
+    description="Z-Image Turbo - fast 6B parameter text-to-image model with 8 inference steps. Supports bilingual prompts (English & Chinese). ~13GB",
+    type=ModelType.Main,
+)
+
+z_image_turbo_quantized = StarterModel(
+    name="Z-Image Turbo (quantized)",
+    base=BaseModelType.ZImage,
+    source="https://huggingface.co/leejet/Z-Image-Turbo-GGUF/resolve/main/z_image_turbo-Q4_K.gguf",
+    description="Z-Image Turbo quantized to GGUF Q4_K format. Requires standalone Qwen3 text encoder and Flux VAE. ~4GB",
+    type=ModelType.Main,
+    format=ModelFormat.GGUFQuantized,
+    dependencies=[z_image_qwen3_encoder_quantized, flux_vae],
+)
+
+z_image_turbo_q8 = StarterModel(
+    name="Z-Image Turbo (Q8)",
+    base=BaseModelType.ZImage,
+    source="https://huggingface.co/leejet/Z-Image-Turbo-GGUF/resolve/main/z_image_turbo-Q8_0.gguf",
+    description="Z-Image Turbo quantized to GGUF Q8_0 format. Higher quality, larger size. Requires standalone Qwen3 text encoder and Flux VAE. ~6.6GB",
+    type=ModelType.Main,
+    format=ModelFormat.GGUFQuantized,
+    dependencies=[z_image_qwen3_encoder_quantized, flux_vae],
+)
+
+z_image_controlnet_union = StarterModel(
+    name="Z-Image ControlNet Union",
+    base=BaseModelType.ZImage,
+    source="https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.1/resolve/main/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors",
+    description="Unified ControlNet for Z-Image Turbo supporting Canny, HED, Depth, Pose, MLSD, and Inpainting modes.",
+    type=ModelType.ControlNet,
+)
+
+z_image_controlnet_tile = StarterModel(
+    name="Z-Image ControlNet Tile",
+    base=BaseModelType.ZImage,
+    source="https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.1/resolve/main/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.safetensors",
+    description="Dedicated Tile ControlNet for Z-Image Turbo. Useful for upscaling and adding detail. ~6.7GB",
+    type=ModelType.ControlNet,
+)
+# endregion
+
 # List of starter models, displayed on the frontend.
 # The order/sort of this list is not changed by the frontend - set it how you want it here.
 STARTER_MODELS: list[StarterModel] = [
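A note on the source strings above: three forms appear - plain repo ids ("org/repo"), direct file URLs, and a "repo::sub1+sub2" form naming specific subfolders of a repo. A sketch of how that last form could be split apart (the "::" and "+" separators are read off the strings above; the parsing helper itself is hypothetical, not InvokeAI's installer code):

def split_source(source: str) -> tuple[str, list[str]]:
    # "org/repo::text_encoder+tokenizer" -> ("org/repo", ["text_encoder", "tokenizer"])
    repo_id, _, subfolder_spec = source.partition("::")
    subfolders = subfolder_spec.split("+") if subfolder_spec else []
    return repo_id, subfolders

assert split_source("black-forest-labs/FLUX.2-klein-4B::text_encoder+tokenizer") == (
    "black-forest-labs/FLUX.2-klein-4B",
    ["text_encoder", "tokenizer"],
)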
@@ -763,9 +935,28 @@ STARTER_MODELS: list[StarterModel] = [
     flux_redux,
     llava_onevision,
     flux_fill,
+    flux2_vae,
+    flux2_klein_4b,
+    flux2_klein_4b_single,
+    flux2_klein_4b_fp8,
+    flux2_klein_9b,
+    flux2_klein_9b_fp8,
+    flux2_klein_4b_gguf_q4,
+    flux2_klein_4b_gguf_q8,
+    flux2_klein_9b_gguf_q4,
+    flux2_klein_9b_gguf_q8,
+    flux2_klein_qwen3_4b_encoder,
+    flux2_klein_qwen3_8b_encoder,
     cogview4,
     flux_krea,
     flux_krea_quantized,
+    z_image_turbo,
+    z_image_turbo_quantized,
+    z_image_turbo_q8,
+    z_image_qwen3_encoder,
+    z_image_qwen3_encoder_quantized,
+    z_image_controlnet_union,
+    z_image_controlnet_tile,
 ]
 
 sd1_bundle: list[StarterModel] = [
@@ -820,10 +1011,26 @@ flux_bundle: list[StarterModel] = [
     flux_krea_quantized,
 ]
 
+zimage_bundle: list[StarterModel] = [
+    z_image_turbo_quantized,
+    z_image_qwen3_encoder_quantized,
+    z_image_controlnet_union,
+    z_image_controlnet_tile,
+    flux_vae,
+]
+
+flux2_klein_bundle: list[StarterModel] = [
+    flux2_klein_4b_gguf_q4,
+    flux2_vae,
+    flux2_klein_qwen3_4b_encoder,
+]
+
 STARTER_BUNDLES: dict[str, StarterModelBundle] = {
     BaseModelType.StableDiffusion1: StarterModelBundle(name="Stable Diffusion 1.5", models=sd1_bundle),
     BaseModelType.StableDiffusionXL: StarterModelBundle(name="SDXL", models=sdxl_bundle),
     BaseModelType.Flux: StarterModelBundle(name="FLUX.1 dev", models=flux_bundle),
+    BaseModelType.Flux2: StarterModelBundle(name="FLUX.2 Klein", models=flux2_klein_bundle),
+    BaseModelType.ZImage: StarterModelBundle(name="Z-Image Turbo", models=zimage_bundle),
 }
 
 assert len(STARTER_MODELS) == len({m.source for m in STARTER_MODELS}), "Duplicate starter models"
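The dependencies field is what makes a single starter entry self-sufficient: picking a quantized main pulls in its VAE and text encoder automatically, and the bundles above lean on that. A sketch of how an installer might flatten a selection while de-duplicating shared dependencies (hypothetical types; not InvokeAI's installer):

from dataclasses import dataclass, field

@dataclass
class StarterModelSketch:
    name: str
    source: str
    dependencies: list["StarterModelSketch"] = field(default_factory=list)

def flatten_install_set(selection: list[StarterModelSketch]) -> list[StarterModelSketch]:
    # Depth-first walk that de-duplicates by source, so a dependency shared by
    # several mains (e.g. one VAE) is only installed once.
    seen: set[str] = set()
    ordered: list[StarterModelSketch] = []

    def visit(model: StarterModelSketch) -> None:
        for dep in model.dependencies:
            visit(dep)
        if model.source not in seen:
            seen.add(model.source)
            ordered.append(model)

    for m in selection:
        visit(m)
    return ordered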
Some files were not shown because too many files have changed in this diff.