mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Compare commits
300 Commits
v6.10.0rc1
...
external-m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7938d840b2 | ||
|
|
450ba7b7e1 | ||
|
|
3fc981f4b6 | ||
|
|
e252a5bb47 | ||
|
|
ce896678d7 | ||
|
|
37ff6c3743 | ||
|
|
c743106f66 | ||
|
|
1b50c1a79c | ||
|
|
acd4157bdf | ||
|
|
06a1881bbd | ||
|
|
441821ca03 | ||
|
|
9d62bfdf8e | ||
|
|
cd888654d5 | ||
|
|
dd056067a9 | ||
|
|
ec4b87b949 | ||
|
|
8f00759af0 | ||
|
|
5c09c823a9 | ||
|
|
33ec16deb4 | ||
|
|
b42274a57e | ||
|
|
ec90b2fbe9 | ||
|
|
17157d7c60 | ||
|
|
a3507121da | ||
|
|
3c9b282a90 | ||
|
|
a2e4fbb9b5 | ||
|
|
06eff38354 | ||
|
|
d4104be0b8 | ||
|
|
ee600973ed | ||
|
|
d4c0e631e2 | ||
|
|
5f35d0e432 | ||
|
|
f0d09c34a8 | ||
|
|
853c3ef915 | ||
|
|
60d0bcdbc1 | ||
|
|
80be1b7282 | ||
|
|
dbbf28925b | ||
|
|
f08b802968 | ||
|
|
ae42182246 | ||
|
|
3e9e052d5d | ||
|
|
089e2db402 | ||
|
|
4cbd60b4a5 | ||
|
|
c2016bcfb7 | ||
|
|
32002bd37e | ||
|
|
e6f2980d7c | ||
|
|
01c67c5468 | ||
|
|
be015a5434 | ||
|
|
82f3dc9032 | ||
|
|
471ab9d9c0 | ||
|
|
41a542552e | ||
|
|
5596fa0cc8 | ||
|
|
05f4deb68c | ||
|
|
474d85e5e0 | ||
|
|
ed268b1cfc | ||
|
|
6963cd97ba | ||
|
|
813a5e2c2e | ||
|
|
18315db7f0 | ||
|
|
edde0b4737 | ||
|
|
ab6f186f8c | ||
|
|
7f2878f691 | ||
|
|
d32f6b5a56 | ||
|
|
f7aa5fcbbf | ||
|
|
438515bf9a | ||
|
|
27fc650f4f | ||
|
|
a1eef791a1 | ||
|
|
d8d0ebc356 | ||
|
|
8375f95ea9 | ||
|
|
9e4d0bb191 | ||
|
|
20a400cee8 | ||
|
|
40f02aa6c4 | ||
|
|
c3a482e80a | ||
|
|
257994f552 | ||
|
|
bafce41856 | ||
|
|
757bd3d002 | ||
|
|
519575e871 | ||
|
|
17da6bb9c3 | ||
|
|
b120ef5183 | ||
|
|
dc5007fe95 | ||
|
|
f39456e6f0 | ||
|
|
bba207a856 | ||
|
|
a7b367fda2 | ||
|
|
cd47b3baf7 | ||
|
|
c8ac303ad2 | ||
|
|
f01cbd35a8 | ||
|
|
2179d93ce0 | ||
|
|
863fa50551 | ||
|
|
e74d8ab2bb | ||
|
|
2d1dbceae5 | ||
|
|
689725c6e4 | ||
|
|
62b7c7a6e8 | ||
|
|
b8b6798167 | ||
|
|
274d9b3a74 | ||
|
|
3d81edac61 | ||
|
|
df225d3751 | ||
|
|
fcdcd7f46b | ||
|
|
94e04b1e1e | ||
|
|
67669b7fbe | ||
|
|
c7bdaf93b2 | ||
|
|
6b57b004a4 | ||
|
|
6fe7910a90 | ||
|
|
445c6a3c36 | ||
|
|
54c1609687 | ||
|
|
ec46b5cb9e | ||
|
|
4fd5cd26a0 | ||
|
|
c83c4af1ea | ||
|
|
10729f40f2 | ||
|
|
362054120e | ||
|
|
b91a156a3d | ||
|
|
c6b0d45c5f | ||
|
|
dc665e08ac | ||
|
|
0dd72837d3 | ||
|
|
d5a6283f23 | ||
|
|
6fe1a6f1ac | ||
|
|
5d34eab6f0 | ||
|
|
1b43769b95 | ||
|
|
a9d3b4e17c | ||
|
|
74ecc461b9 | ||
|
|
19650f6ada | ||
|
|
146b936844 | ||
|
|
b90969ee88 | ||
|
|
dfc66b7142 | ||
|
|
21efa70b4d | ||
|
|
7aa3c95767 | ||
|
|
afbd45ace7 | ||
|
|
b9f9015214 | ||
|
|
ddaa12b0fd | ||
|
|
c8dfea8681 | ||
|
|
1730193883 | ||
|
|
33c7b2a1f9 | ||
|
|
033ff77f94 | ||
|
|
89df130ca1 | ||
|
|
e9246c1899 | ||
|
|
b0f7b555b7 | ||
|
|
467ae66a87 | ||
|
|
848cc12d63 | ||
|
|
dbb20a011a | ||
|
|
3ada1dc743 | ||
|
|
0fb2ae4fae | ||
|
|
ec2eedb000 | ||
|
|
77e1ac19fc | ||
|
|
b23f18734b | ||
|
|
8c3cc3a970 | ||
|
|
86eff471fd | ||
|
|
a42fdb0f44 | ||
|
|
bacdfecb13 | ||
|
|
76b0838094 | ||
|
|
b7d7cd0748 | ||
|
|
d5c59ee64e | ||
|
|
3f79159249 | ||
|
|
c186e51b30 | ||
|
|
c072fd8261 | ||
|
|
f013fa6ff2 | ||
|
|
9566f9a23f | ||
|
|
62ee1b820d | ||
|
|
33779f3072 | ||
|
|
8cf83a9221 | ||
|
|
1281c9d211 | ||
|
|
4a09594230 | ||
|
|
a873ce0175 | ||
|
|
9ee7baaba5 | ||
|
|
fb5c43a905 | ||
|
|
0f69f4bb9a | ||
|
|
8a355e66fa | ||
|
|
b811602b38 | ||
|
|
0716b2fa75 | ||
|
|
4d71609115 | ||
|
|
0ecb903ae2 | ||
|
|
736f4ffeb1 | ||
|
|
2102b43edc | ||
|
|
5801e59e2b | ||
|
|
5fc950b745 | ||
|
|
63dec985cd | ||
|
|
03cdd6df2e | ||
|
|
99f4070ce7 | ||
|
|
cf07f8be14 | ||
|
|
1f0d92defc | ||
|
|
68089ca688 | ||
|
|
32e2132948 | ||
|
|
bec3586930 | ||
|
|
8bf4d1ea59 | ||
|
|
fd7a3aebd2 | ||
|
|
72491e2153 | ||
|
|
3d0725072d | ||
|
|
0ae7392c81 | ||
|
|
cff20b45f3 | ||
|
|
b92c6ae633 | ||
|
|
729bae19a5 | ||
|
|
fcc81f17a5 | ||
|
|
27ae70a428 | ||
|
|
82819cdadc | ||
|
|
b2b8820519 | ||
|
|
bb6c544603 | ||
|
|
8a18914637 | ||
|
|
d66df9a0d0 | ||
|
|
5c00684701 | ||
|
|
d93ce6ac42 | ||
|
|
13bf5feb4d | ||
|
|
53ab178edd | ||
|
|
2d8317f1aa | ||
|
|
04f815638c | ||
|
|
d6ad6a2dcb | ||
|
|
784503e484 | ||
|
|
da2809b000 | ||
|
|
53c34eb95e | ||
|
|
18fc822d37 | ||
|
|
89dc50bd7c | ||
|
|
d34655fd58 | ||
|
|
c1a8300e96 | ||
|
|
9c5b2f6498 | ||
|
|
dbb4a07a8f | ||
|
|
f66a1a38c8 | ||
|
|
be2635161c | ||
|
|
384a1a689d | ||
|
|
0021404639 | ||
|
|
a05a626644 | ||
|
|
97b82d752e | ||
|
|
f29820a7ba | ||
|
|
47a634d8fb | ||
|
|
768f3dbde0 | ||
|
|
1ca589ea10 | ||
|
|
3a21e7699f | ||
|
|
56fd7bc7c4 | ||
|
|
2425005aad | ||
|
|
2ccadd1834 | ||
|
|
5cef8bd364 | ||
|
|
8a6d593fe8 | ||
|
|
14309562b8 | ||
|
|
9f8f9965f9 | ||
|
|
44a21a348d | ||
|
|
81d83d5aab | ||
|
|
d99707fdcb | ||
|
|
252dd5b426 | ||
|
|
f922f6c634 | ||
|
|
be0cbe046c | ||
|
|
e39b880f6d | ||
|
|
4f8ec07d2f | ||
|
|
689953e3cf | ||
|
|
61c2589e39 | ||
|
|
8cf4c6944a | ||
|
|
db228ddc4f | ||
|
|
858c94b575 | ||
|
|
252794d717 | ||
|
|
7847ccea13 | ||
|
|
1bcf589d19 | ||
|
|
132a48497b | ||
|
|
f49e1b8dae | ||
|
|
e7233efb79 | ||
|
|
3b2d2ef10a | ||
|
|
66974841f1 | ||
|
|
87608ade45 | ||
|
|
1e83aeeb79 | ||
|
|
1c76d295a2 | ||
|
|
384250ff8c | ||
|
|
6c3ce8e7e9 | ||
|
|
d658ef4322 | ||
|
|
8d880ef5a0 | ||
|
|
c6775cc999 | ||
|
|
d44b99ae0a | ||
|
|
1675712094 | ||
|
|
2924d052c5 | ||
|
|
f1624a6215 | ||
|
|
b7e28e4fa6 | ||
|
|
d7d051200f | ||
|
|
0f830ddd00 | ||
|
|
9617140b7f | ||
|
|
bc4783028f | ||
|
|
16fedfb538 | ||
|
|
d781a3b8a2 | ||
|
|
7182ff26dc | ||
|
|
95ee27d5c0 | ||
|
|
b4f05d3fe7 | ||
|
|
8deafabe6b | ||
|
|
1bd1c76a2c | ||
|
|
56fd1da888 | ||
|
|
0956ce0cd3 | ||
|
|
d42bf9c941 | ||
|
|
d403587c7f | ||
|
|
355c985cc3 | ||
|
|
41742146e2 | ||
|
|
eb516e1998 | ||
|
|
0b1befa9ab | ||
|
|
bd678b1c95 | ||
|
|
56bef0b089 | ||
|
|
99fc1243cb | ||
|
|
a7205e4e36 | ||
|
|
ca14c5c9e1 | ||
|
|
5b69403ba8 | ||
|
|
83deb0233e | ||
|
|
8ebb6dd3d9 | ||
|
|
b7afd9b5b3 | ||
|
|
4987b4da1c | ||
|
|
a21b7792d8 | ||
|
|
8819cc30be | ||
|
|
9d1de81fe2 | ||
|
|
1e15b8c106 | ||
|
|
8d76b4e4d4 | ||
|
|
9662d1fdb6 | ||
|
|
b16717bbf8 | ||
|
|
c3217d8a08 | ||
|
|
2500153ed8 | ||
|
|
75a14e2a4b | ||
|
|
9bbd2b3f11 | ||
|
|
c26445253c |
22
.github/CODEOWNERS
vendored
22
.github/CODEOWNERS
vendored
@@ -1,32 +1,32 @@
|
||||
# continuous integration
|
||||
/.github/workflows/ @lstein @blessedcoolant
|
||||
|
||||
# documentation
|
||||
/docs/ @lstein @blessedcoolant
|
||||
/mkdocs.yml @lstein @blessedcoolant
|
||||
# documentation - anyone with write privileges can review
|
||||
/docs/
|
||||
/mkdocs.yml
|
||||
|
||||
# nodes
|
||||
/invokeai/app/ @blessedcoolant @lstein
|
||||
/invokeai/app/ @blessedcoolant @lstein @dunkeroni @JPPhoto
|
||||
|
||||
# installation and configuration
|
||||
/pyproject.toml @lstein @blessedcoolant
|
||||
/docker/ @lstein @blessedcoolant
|
||||
/scripts/ @lstein
|
||||
/installer/ @lstein
|
||||
/invokeai/assets @lstein
|
||||
/invokeai/configs @lstein
|
||||
/scripts/ @lstein @blessedcoolant
|
||||
/installer/ @lstein @blessedcoolant
|
||||
/invokeai/assets @lstein @blessedcoolant
|
||||
/invokeai/configs @lstein @blessedcoolant
|
||||
/invokeai/version @lstein @blessedcoolant
|
||||
|
||||
# web ui
|
||||
/invokeai/frontend @blessedcoolant @lstein
|
||||
/invokeai/frontend @blessedcoolant @lstein @dunkeroni
|
||||
|
||||
# generation, model management, postprocessing
|
||||
/invokeai/backend @lstein @blessedcoolant
|
||||
/invokeai/backend @lstein @blessedcoolant @dunkeroni @JPPhoto @Pfannkuchensack
|
||||
|
||||
# front ends
|
||||
/invokeai/frontend/CLI @lstein
|
||||
/invokeai/frontend/install @lstein
|
||||
/invokeai/frontend/merge @lstein @blessedcoolant
|
||||
/invokeai/frontend/training @lstein @blessedcoolant
|
||||
/invokeai/frontend/web @blessedcoolant @lstein
|
||||
/invokeai/frontend/web @blessedcoolant @lstein @dunkeroni @Pfannkuchensack
|
||||
|
||||
|
||||
6
.github/workflows/build-container.yml
vendored
6
.github/workflows/build-container.yml
vendored
@@ -53,8 +53,10 @@ jobs:
|
||||
df -h
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
|
||||
sudo swapoff /mnt/swapfile
|
||||
sudo rm -rf /mnt/swapfile
|
||||
if [ -f /mnt/swapfile ]; then
|
||||
sudo swapoff /mnt/swapfile
|
||||
sudo rm -rf /mnt/swapfile
|
||||
fi
|
||||
if [ -d /mnt ]; then
|
||||
sudo chmod -R 777 /mnt
|
||||
echo '{"data-root": "/mnt/docker-root"}' | sudo tee /etc/docker/daemon.json
|
||||
|
||||
1
.github/workflows/close-inactive-issues.yml
vendored
1
.github/workflows/close-inactive-issues.yml
vendored
@@ -23,6 +23,7 @@ jobs:
|
||||
close-issue-message: "Due to inactivity, this issue was automatically closed. If you are still experiencing the issue, please recreate the issue."
|
||||
days-before-pr-stale: -1
|
||||
days-before-pr-close: -1
|
||||
only-labels: "bug"
|
||||
exempt-issue-labels: "Active Issue"
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
operations-per-run: 500
|
||||
|
||||
6
.github/workflows/mkdocs-material.yml
vendored
6
.github/workflows/mkdocs-material.yml
vendored
@@ -22,12 +22,12 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: setup python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.10'
|
||||
python-version: '3.12'
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
|
||||
6
.github/workflows/typegen-checks.yml
vendored
6
.github/workflows/typegen-checks.yml
vendored
@@ -46,8 +46,10 @@ jobs:
|
||||
df -h
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
|
||||
sudo swapoff /mnt/swapfile
|
||||
sudo rm -rf /mnt/swapfile
|
||||
if [ -f /mnt/swapfile ]; then
|
||||
sudo swapoff /mnt/swapfile
|
||||
sudo rm -rf /mnt/swapfile
|
||||
fi
|
||||
echo "----- Free space after cleanup"
|
||||
df -h
|
||||
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -192,3 +192,6 @@ installer/InvokeAI-Installer/
|
||||
.aider*
|
||||
|
||||
.claude/
|
||||
|
||||
# Weblate configuration file
|
||||
weblate.ini
|
||||
26
Makefile
26
Makefile
@@ -12,24 +12,25 @@ help:
|
||||
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
|
||||
@echo "test Run the unit tests."
|
||||
@echo "update-config-docstring Update the app's config docstring so mkdocs can autogenerate it correctly."
|
||||
@echo "frontend-install Install the pnpm modules needed for the front end"
|
||||
@echo "frontend-build Build the frontend in order to run on localhost:9090"
|
||||
@echo "frontend-install Install the pnpm modules needed for the frontend"
|
||||
@echo "frontend-build Build the frontend for localhost:9090"
|
||||
@echo "frontend-test Run the frontend test suite once"
|
||||
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
|
||||
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
|
||||
@echo "wheel Build the wheel for the current version"
|
||||
@echo "frontend-lint Run frontend checks and fixable lint/format steps"
|
||||
@echo "wheel Build the wheel for the current version"
|
||||
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
|
||||
@echo "openapi Generate the OpenAPI schema for the app, outputting to stdout"
|
||||
@echo "docs Serve the mkdocs site with live reload"
|
||||
|
||||
# Runs ruff, fixing any safely-fixable errors and formatting
|
||||
ruff:
|
||||
ruff check . --fix
|
||||
ruff format .
|
||||
cd invokeai && uv tool run ruff@0.11.2 format
|
||||
|
||||
# Runs ruff, fixing all errors it can fix and formatting
|
||||
ruff-unsafe:
|
||||
ruff check . --fix --unsafe-fixes
|
||||
ruff format .
|
||||
ruff format
|
||||
|
||||
# Runs mypy, using the config in pyproject.toml
|
||||
mypy:
|
||||
@@ -57,6 +58,10 @@ frontend-install:
|
||||
frontend-build:
|
||||
cd invokeai/frontend/web && pnpm build
|
||||
|
||||
# Run the frontend test suite once
|
||||
frontend-test:
|
||||
cd invokeai/frontend/web && pnpm run test:run
|
||||
|
||||
# Run the frontend in dev mode
|
||||
frontend-dev:
|
||||
cd invokeai/frontend/web && pnpm dev
|
||||
@@ -64,6 +69,13 @@ frontend-dev:
|
||||
frontend-typegen:
|
||||
cd invokeai/frontend/web && python ../../../scripts/generate_openapi_schema.py | pnpm typegen
|
||||
|
||||
frontend-lint:
|
||||
cd invokeai/frontend/web/src && \
|
||||
pnpm lint:tsc && \
|
||||
pnpm lint:dpdm && \
|
||||
pnpm lint:eslint --fix && \
|
||||
pnpm lint:prettier --write
|
||||
|
||||
# Tag the release
|
||||
wheel:
|
||||
cd scripts && ./build_wheel.sh
|
||||
@@ -79,4 +91,4 @@ openapi:
|
||||
# Serve the mkdocs site w/ live reload
|
||||
.PHONY: docs
|
||||
docs:
|
||||
mkdocs serve
|
||||
mkdocs serve
|
||||
|
||||
36
README.md
36
README.md
@@ -16,6 +16,12 @@ Invoke is a leading creative engine built to empower professionals and enthusias
|
||||
|
||||

|
||||
|
||||
---
|
||||
> ## 📣 Are you a new or returning InvokeAI user?
|
||||
> Take our first annual [User's Survey](https://forms.gle/rCE5KuQ7Wfrd1UnS7)
|
||||
|
||||
---
|
||||
|
||||
# Documentation
|
||||
|
||||
| **Quick Links** |
|
||||
@@ -46,21 +52,45 @@ The Unified Canvas is a fully integrated canvas implementation with support for
|
||||
|
||||
### Workflows & Nodes
|
||||
|
||||
Invoke offers a fully featured workflow management solution, enabling users to combine the power of node-based workflows with the easy of a UI. This allows for customizable generation pipelines to be developed and shared by users looking to create specific workflows to support their production use-cases.
|
||||
Invoke offers a fully featured workflow management solution, enabling users to combine the power of node-based workflows with the ease of a UI. This allows for customizable generation pipelines to be developed and shared by users looking to create specific workflows to support their production use-cases.
|
||||
|
||||
### Board & Gallery Management
|
||||
|
||||
Invoke features an organized gallery system for easily storing, accessing, and remixing your content in the Invoke workspace. Images can be dragged/dropped onto any Image-base UI element in the application, and rich metadata within the Image allows for easy recall of key prompts or settings used in your workflow.
|
||||
|
||||
### Model Support
|
||||
- SD 1.5
|
||||
- SD 2.0
|
||||
- SDXL
|
||||
- SD 3.5 Medium
|
||||
- SD 3.5 Large
|
||||
- CogView 4
|
||||
- Flux.1 Dev
|
||||
- Flux.1 Schnell
|
||||
- Flux.1 Kontext
|
||||
- Flux.1 Krea
|
||||
- Flux Redux
|
||||
- Flux Fill
|
||||
- Flux.2 Klein 4B
|
||||
- Flux.2 Klein 9B
|
||||
- Z-Image Turbo
|
||||
- Z-Image Base
|
||||
- Anima
|
||||
- Qwen Image
|
||||
- Qwen Image Edit
|
||||
- Nano Banana (API Only)
|
||||
- GPT Image (API Only)
|
||||
- Wan (API Only)
|
||||
|
||||
### Other features
|
||||
|
||||
- Support for both ckpt and diffusers models
|
||||
- SD1.5, SD2.0, SDXL, and FLUX support
|
||||
- Support for ckpt, diffusers, and some gguf models
|
||||
- Upscaling Tools
|
||||
- Embedding Manager & Support
|
||||
- Model Manager & Support
|
||||
- Workflow creation & management
|
||||
- Node-Based Architecture
|
||||
- Object Segmentation & Selection Models (SAM / SAM2)
|
||||
|
||||
## Contributing
|
||||
|
||||
|
||||
169
USER_ISOLATION_IMPLEMENTATION.md
Normal file
169
USER_ISOLATION_IMPLEMENTATION.md
Normal file
@@ -0,0 +1,169 @@
|
||||
# User Isolation Implementation Summary
|
||||
|
||||
This document describes the implementation of user isolation features in the InvokeAI session queue and processing system to address issues identified in the enhancement request.
|
||||
|
||||
## Issues Addressed
|
||||
|
||||
### 1. Cross-User Image/Preview Visibility
|
||||
**Problem:** When two users are logged in simultaneously and one initiates a generation, the generation preview shows up in both users' browsers and the generated image gets saved to both users' image boards.
|
||||
|
||||
**Solution:** Implemented socket-level event filtering based on user authentication:
|
||||
|
||||
#### Backend Changes (`invokeai/app/api/sockets.py`):
|
||||
- Added socket authentication middleware in `_handle_connect()` method
|
||||
- Extracts JWT token from socket auth data or HTTP headers
|
||||
- Verifies token using existing `verify_token()` function
|
||||
- Stores `user_id` and `is_admin` in socket session for later use
|
||||
- Modified `_handle_queue_event()` to filter events by user:
|
||||
- For `QueueItemEventBase` events, only emit to:
|
||||
- The user who owns the queue item (`user_id` matches)
|
||||
- Admin users (`is_admin` is True)
|
||||
- For general queue events, emit to all subscribers
|
||||
|
||||
#### Event System Changes (`invokeai/app/services/events/events_common.py`):
|
||||
- Added `user_id` field to `QueueItemEventBase` class
|
||||
- Updated all event builders to include `user_id` from queue items:
|
||||
- `InvocationStartedEvent.build()`
|
||||
- `InvocationProgressEvent.build()`
|
||||
- `InvocationCompleteEvent.build()`
|
||||
- `InvocationErrorEvent.build()`
|
||||
- `QueueItemStatusChangedEvent.build()`
|
||||
|
||||
### 2. Batch Field Values Privacy
|
||||
**Problem:** Users can see batch field values from generation processes launched by other users.
|
||||
|
||||
**Solution:** Implemented field value sanitization at the API level:
|
||||
|
||||
#### API Router Changes (`invokeai/app/api/routers/session_queue.py`):
|
||||
- Created `sanitize_queue_item_for_user()` helper function
|
||||
- Clears `field_values` for non-admin users viewing other users' items
|
||||
- Admins and item owners can see all field values
|
||||
- Updated endpoints to require authentication and sanitize responses:
|
||||
- `list_all_queue_items()` - Added `CurrentUser` dependency
|
||||
- `get_queue_items_by_item_ids()` - Added `CurrentUser` dependency
|
||||
- `get_queue_item()` - Added `CurrentUser` dependency
|
||||
|
||||
### 3. Queue Updates Across Browser Windows
|
||||
**Problem:** When the job queue tab is open in multiple browsers and a generation is begun in one browser window, the queue does not update in the other window.
|
||||
|
||||
**Status:** This issue is likely resolved by the socket authentication and event filtering changes. The existing socket subscription mechanism (`subscribe_queue` event) already supports multiple connections per user. Testing is required to confirm this works correctly with the new authentication flow.
|
||||
|
||||
### 4. User Information Display
|
||||
**Problem:** Queue table lacks user identification, making it difficult to know who launched which job.
|
||||
|
||||
**Solution:** Added user information to queue items and UI:
|
||||
|
||||
#### Database Layer (`invokeai/app/services/session_queue/session_queue_sqlite.py`):
|
||||
- Updated SQL queries to JOIN with `users` table
|
||||
- Modified methods to fetch user information:
|
||||
- `get_queue_item()` - Now selects `display_name` and `email` from users table
|
||||
- `dequeue()` - Includes user info
|
||||
- `get_next()` - Includes user info
|
||||
- `get_current()` - Includes user info
|
||||
- `list_all_queue_items()` - Includes user info
|
||||
|
||||
#### Data Model Changes (`invokeai/app/services/session_queue/session_queue_common.py`):
|
||||
- Added optional fields to `SessionQueueItem`:
|
||||
- `user_display_name: Optional[str]` - Display name from users table
|
||||
- `user_email: Optional[str]` - Email from users table
|
||||
- Note: `user_id` field already existed from Migration 25
|
||||
|
||||
#### Frontend UI Changes:
|
||||
- **Constants** (`constants.ts`): Added `user: '8rem'` column width
|
||||
- **Header** (`QueueListHeader.tsx`): Added "User" column header
|
||||
- **Item Component** (`QueueItemComponent.tsx`):
|
||||
- Added logic to display user information (display_name → email → user_id)
|
||||
- Added user column to queue item row
|
||||
- Added tooltip with full username on hover
|
||||
- Added "Hidden for privacy" message when field_values are null for non-owned items
|
||||
- **Localization** (`en.json`): Added translations:
|
||||
- `"user": "User"`
|
||||
- `"fieldValuesHidden": "Hidden for privacy"`
|
||||
|
||||
## Security Considerations
|
||||
|
||||
### Token Verification
|
||||
- Tokens are verified using the existing `verify_token()` function from `invokeai.app.services.auth.token_service`
|
||||
- Invalid or missing tokens default to "system" user with non-admin privileges
|
||||
- Socket connections without valid tokens are still accepted for backward compatibility but have limited access
|
||||
|
||||
### Data Privacy
|
||||
- Field values are only visible to:
|
||||
- The user who created the queue item
|
||||
- Admin users
|
||||
- Non-admin users viewing other users' queue items see "Hidden for privacy" instead of field values
|
||||
|
||||
### Admin Privileges
|
||||
- Admin users can see all queue events and field values across all users
|
||||
- Admin status is determined from the JWT token's `is_admin` field
|
||||
|
||||
## Migration Notes
|
||||
|
||||
No database migration is required. The changes leverage:
|
||||
- Existing `user_id` column in `session_queue` table (added in Migration 25)
|
||||
- Existing `users` table (added in Migration 25)
|
||||
- SQL LEFT JOINs to fetch user information (gracefully handles missing user records)
|
||||
|
||||
## Testing Requirements
|
||||
|
||||
### Backend Testing
|
||||
1. **Socket Authentication:**
|
||||
- Verify valid tokens are accepted and user context is stored
|
||||
- Verify invalid tokens default to system user
|
||||
- Verify expired tokens are rejected
|
||||
|
||||
2. **Event Filtering:**
|
||||
- User A should only receive events for their own queue items
|
||||
- Admin users should receive all events
|
||||
- Non-admin users should not receive events from other users
|
||||
|
||||
3. **Field Value Sanitization:**
|
||||
- Non-admin users should see null field_values for other users' items
|
||||
- Admins should see all field values
|
||||
- Users should see their own field values
|
||||
|
||||
### Frontend Testing
|
||||
1. **UI Display:**
|
||||
- User column should display in queue list
|
||||
- Display name should be shown when available
|
||||
- Email should be shown as fallback when display name is missing
|
||||
- User ID should be shown when both display name and email are missing
|
||||
- Tooltip should show full username on hover
|
||||
|
||||
2. **Field Values Display:**
|
||||
- "Hidden for privacy" message should appear when viewing other users' items
|
||||
- Own items should show field values normally
|
||||
|
||||
3. **Multi-Browser Testing:**
|
||||
- Open queue tab in two browsers with different users
|
||||
- Start generation in one browser
|
||||
- Verify other browser doesn't see the preview/progress
|
||||
- Verify admin user can see all generations
|
||||
|
||||
### Integration Testing
|
||||
1. Multi-user scenarios with simultaneous generations
|
||||
2. Queue updates across multiple browser windows
|
||||
3. Admin vs. non-admin privilege differentiation
|
||||
4. Socket reconnection handling
|
||||
|
||||
## Known Limitations
|
||||
|
||||
1. **TypeScript Types:**
|
||||
- The OpenAPI schema needs to be regenerated to include new fields
|
||||
- Run: `cd invokeai/frontend/web && python ../../../scripts/generate_openapi_schema.py | pnpm typegen`
|
||||
|
||||
2. **Backward Compatibility:**
|
||||
- System user ("system") entries will not have display name or email
|
||||
- Existing queue items from before Migration 25 will have user_id="system"
|
||||
|
||||
3. **Socket.IO Session Storage:**
|
||||
- Socket.IO's in-memory session storage may not persist across server restarts
|
||||
- Consider implementing persistent session storage if needed for production
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
1. Add user filtering to queue list (show only my items vs. all items)
|
||||
2. Add permission system for queue management operations (cancel, retry, delete)
|
||||
3. Implement queue item ownership transfer for administrative purposes
|
||||
4. Add audit logging for queue operations with user attribution
|
||||
5. Consider implementing user-specific queue limits or quotas
|
||||
@@ -16,7 +16,9 @@ The launcher uses GitHub as the source of truth for available releases.
|
||||
|
||||
## General Prep
|
||||
|
||||
Make a developer call-out for PRs to merge. Merge and test things out. Bump the version by editing `invokeai/version/invokeai_version.py`.
|
||||
Make a developer call-out for PRs to merge. Merge and test things
|
||||
out. Create a branch with a name like user/chore/vX.X.X-prep and bump the version by editing
|
||||
`invokeai/version/invokeai_version.py` and commit locally.
|
||||
|
||||
## Release Workflow
|
||||
|
||||
@@ -26,14 +28,14 @@ It is triggered on **tag push**, when the tag matches `v*`.
|
||||
|
||||
### Triggering the Workflow
|
||||
|
||||
Ensure all commits that should be in the release are merged, and you have pulled them locally.
|
||||
|
||||
Double-check that you have checked out the commit that will represent the release (typically the latest commit on `main`).
|
||||
Ensure all commits that should be in the release are merged into this branch, and that you have pulled them locally.
|
||||
|
||||
Run `make tag-release` to tag the current commit and kick off the workflow. You will be prompted to provide a message - use the version specifier.
|
||||
|
||||
If this version's tag already exists for some reason (maybe you had to make a last minute change), the script will overwrite it.
|
||||
|
||||
Push the commit to trigger the workflow.
|
||||
|
||||
> In case you cannot use the Make target, the release may also be dispatched [manually] via GH.
|
||||
|
||||
### Workflow Jobs and Process
|
||||
|
||||
BIN
docs/assets/multiuser/admin-add-user-1.png
Normal file
BIN
docs/assets/multiuser/admin-add-user-1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 13 KiB |
BIN
docs/assets/multiuser/admin-add-user-2.png
Normal file
BIN
docs/assets/multiuser/admin-add-user-2.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 14 KiB |
BIN
docs/assets/multiuser/admin-add-user-3.png
Normal file
BIN
docs/assets/multiuser/admin-add-user-3.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 18 KiB |
BIN
docs/assets/multiuser/admin-setup.png
Normal file
BIN
docs/assets/multiuser/admin-setup.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 40 KiB |
BIN
docs/assets/multiuser/user-login-1.png
Normal file
BIN
docs/assets/multiuser/user-login-1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 17 KiB |
205
docs/contributing/CANVAS_PROJECTS/CANVAS_PROJECTS.md
Normal file
205
docs/contributing/CANVAS_PROJECTS/CANVAS_PROJECTS.md
Normal file
@@ -0,0 +1,205 @@
|
||||
# Canvas Projects — Technical Documentation
|
||||
|
||||
## Overview
|
||||
|
||||
Canvas Projects provide a save/load mechanism for the entire canvas state. The feature serializes all canvas entities, generation parameters, reference images, and their associated image files into a ZIP-based `.invk` file. On load, it restores the full state, handling image deduplication and re-uploading as needed.
|
||||
|
||||
## File Format
|
||||
|
||||
The `.invk` file is a standard ZIP archive with the following structure:
|
||||
|
||||
```
|
||||
project.invk
|
||||
├── manifest.json
|
||||
├── canvas_state.json
|
||||
├── params.json
|
||||
├── ref_images.json
|
||||
├── loras.json
|
||||
└── images/
|
||||
├── {image_name_1}.png
|
||||
├── {image_name_2}.png
|
||||
└── ...
|
||||
```
|
||||
|
||||
### manifest.json
|
||||
|
||||
Schema version and metadata. Validated on load with Zod.
|
||||
|
||||
```json
|
||||
{
|
||||
"version": 1,
|
||||
"appVersion": "5.12.0",
|
||||
"createdAt": "2026-02-26T12:00:00.000Z",
|
||||
"name": "My Canvas Project"
|
||||
}
|
||||
```
|
||||
|
||||
| Field | Type | Description |
|
||||
|---|---|---|
|
||||
| `version` | `number` | Schema version, currently `1`. Used for migration logic on load. |
|
||||
| `appVersion` | `string` | InvokeAI version that created the file. Informational only. |
|
||||
| `createdAt` | `string` | ISO 8601 timestamp. |
|
||||
| `name` | `string` | User-provided project name. Also used as the download filename. |
|
||||
|
||||
### canvas_state.json
|
||||
|
||||
The serialized canvas entity tree. Type: `CanvasProjectState`.
|
||||
|
||||
```typescript
|
||||
type CanvasProjectState = {
|
||||
rasterLayers: CanvasRasterLayerState[];
|
||||
controlLayers: CanvasControlLayerState[];
|
||||
inpaintMasks: CanvasInpaintMaskState[];
|
||||
regionalGuidance: CanvasRegionalGuidanceState[];
|
||||
bbox: CanvasState['bbox'];
|
||||
selectedEntityIdentifier: CanvasState['selectedEntityIdentifier'];
|
||||
bookmarkedEntityIdentifier: CanvasState['bookmarkedEntityIdentifier'];
|
||||
};
|
||||
```
|
||||
|
||||
Each entity contains its full state including all canvas objects (brush lines, eraser lines, rect shapes, images). Image objects reference files by `image_name` which correspond to files in the `images/` folder.
|
||||
|
||||
### params.json
|
||||
|
||||
The complete generation parameters state (`ParamsState`). Optional on load (older files may not have it). This includes all fields from the params Redux slice:
|
||||
|
||||
- Prompts (positive, negative, prompt history)
|
||||
- Core generation settings (seed, steps, CFG scale, guidance, scheduler, iterations)
|
||||
- Model selections (main model, VAE, FLUX VAE, T5 encoder, CLIP embed models, refiner, Z-Image models, Klein models)
|
||||
- Dimensions (width, height, aspect ratio)
|
||||
- Img2img strength
|
||||
- Infill settings (method, tile size, patchmatch downscale, color)
|
||||
- Canvas coherence settings (mode, edge size, min denoise)
|
||||
- Refiner parameters (steps, CFG scale, scheduler, aesthetic scores, start)
|
||||
- FLUX-specific settings (scheduler, DyPE preset/scale/exponent)
|
||||
- Z-Image-specific settings (scheduler, seed variance)
|
||||
- Upscale settings (scheduler, CFG scale)
|
||||
- Seamless tiling, mask blur, CLIP skip, VAE precision, CPU noise, color compensation
|
||||
|
||||
### ref_images.json
|
||||
|
||||
Global reference image entities (`RefImageState[]`). These are IP-Adapter / FLUX Redux configs with `CroppableImageWithDims` containing both original and cropped image references. Optional on load.
|
||||
|
||||
### loras.json
|
||||
|
||||
Array of LoRA configurations (`LoRA[]`). Each entry contains:
|
||||
|
||||
```typescript
|
||||
type LoRA = {
|
||||
id: string;
|
||||
isEnabled: boolean;
|
||||
model: ModelIdentifierField;
|
||||
weight: number;
|
||||
};
|
||||
```
|
||||
|
||||
Optional on load. Like models, LoRA identifiers are stored as-is — if a LoRA is not installed when loading, the entry is restored but may not be usable.
|
||||
|
||||
### images/
|
||||
|
||||
All image files referenced anywhere in the state. Keyed by their original `image_name`. On save, each image is fetched from the backend via `GET /api/v1/images/i/{name}/full` and stored as-is.
|
||||
|
||||
## Key Source Files
|
||||
|
||||
| File | Purpose |
|
||||
|---|---|
|
||||
| `features/controlLayers/util/canvasProjectFile.ts` | Types, constants, image name collection, remapping, existence checking |
|
||||
| `features/controlLayers/hooks/useCanvasProjectSave.ts` | Save hook — collects Redux state, fetches images, builds ZIP |
|
||||
| `features/controlLayers/hooks/useCanvasProjectLoad.ts` | Load hook — parses ZIP, deduplicates images, dispatches state |
|
||||
| `features/controlLayers/components/SaveCanvasProjectDialog.tsx` | Save name dialog + `useSaveCanvasProjectWithDialog` hook |
|
||||
| `features/controlLayers/components/LoadCanvasProjectConfirmationAlertDialog.tsx` | Load confirmation dialog + `useLoadCanvasProjectWithDialog` hook |
|
||||
| `features/controlLayers/components/Toolbar/CanvasToolbarProjectMenuButton.tsx` | Toolbar dropdown UI |
|
||||
| `features/controlLayers/store/canvasSlice.ts` | `canvasProjectRecalled` Redux action |
|
||||
|
||||
## Save Flow
|
||||
|
||||
1. User clicks "Save Canvas Project" → `SaveCanvasProjectDialog` opens asking for a project name
|
||||
2. On confirm, `saveCanvasProject(name)` is called
|
||||
3. Read Redux state via selectors: `selectCanvasSlice()`, `selectParamsSlice()`, `selectRefImagesSlice()`, `selectLoRAsSlice()`
|
||||
4. Build `CanvasProjectState` from the canvas slice; use `paramsState` directly for params
|
||||
5. Walk all entities to collect every `image_name` reference via `collectImageNames()`:
|
||||
- `CanvasImageState.image.image_name` in layer/mask objects
|
||||
- `CroppableImageWithDims.original.image.image_name` in global ref images
|
||||
- `CroppableImageWithDims.crop.image.image_name` in cropped ref images
|
||||
- `ImageWithDims.image_name` in regional guidance ref images
|
||||
6. Fetch each image from the backend API
|
||||
7. Build ZIP with JSZip: add `manifest.json` (including `name`), `canvas_state.json`, `params.json`, `ref_images.json`, and all images into `images/`
|
||||
8. Sanitize the name for filesystem use, generate the ZIP blob, and trigger a download as `{name}.invk`
|
||||
|
||||
## Load Flow
|
||||
|
||||
1. User selects `.invk` file → confirmation dialog opens
|
||||
2. On confirm, parse ZIP with JSZip
|
||||
3. Validate manifest version via Zod schema
|
||||
4. Read `canvas_state.json`, `params.json` (optional), `ref_images.json` (optional)
|
||||
5. Collect all `image_name` references from the loaded state
|
||||
6. **Deduplicate images**: for each referenced image, check if it exists on the server via `getImageDTOSafe(image_name)`
|
||||
- Already exists → skip (no upload)
|
||||
- Missing → upload from ZIP via `uploadImage()`, record `oldName → newName` mapping
|
||||
7. Remap all `image_name` values in the loaded state using the mapping (only for re-uploaded images whose names changed)
|
||||
8. Dispatch Redux actions:
|
||||
- `canvasProjectRecalled()` — restores all canvas entities, bbox, selected/bookmarked entity
|
||||
- `refImagesRecalled()` — restores global reference images
|
||||
- `paramsRecalled()` — replaces the entire params state in one action
|
||||
- `loraAllDeleted()` + `loraRecalled()` — restores LoRAs
|
||||
9. Show success/error toast
|
||||
|
||||
## Image Name Collection & Remapping
|
||||
|
||||
The `canvasProjectFile.ts` utility provides two parallel sets of functions:
|
||||
|
||||
**Collection** (`collectImageNames`): Walks the entire state tree and returns a `Set<string>` of all referenced `image_name` values. This is used by both save (to know which images to fetch) and load (to know which images to check/upload).
|
||||
|
||||
**Remapping** (`remapCanvasState`, `remapRefImages`): Deep-clones state objects and replaces `image_name` values using a `Map<string, string>` mapping. Only images that were re-uploaded with a different name are remapped. Images that already existed on the server are left unchanged.
|
||||
|
||||
Both walk the same paths through the state tree:
|
||||
- Layer/mask objects → `CanvasImageState.image.image_name`
|
||||
- Regional guidance ref images → `ImageWithDims.image_name`
|
||||
- Global ref images → `CroppableImageWithDims.original.image.image_name` and `.crop.image.image_name`
|
||||
|
||||
## Extending the Format
|
||||
|
||||
### Adding new optional data (non-breaking)
|
||||
|
||||
Add a new JSON file to the ZIP. No version bump needed.
|
||||
|
||||
1. **Save**: Add `zip.file('new_data.json', JSON.stringify(data))` in `useCanvasProjectSave.ts`
|
||||
2. **Load**: Read with `zip.file('new_data.json')` in `useCanvasProjectLoad.ts` — check for `null` so older project files without it still load
|
||||
3. **Dispatch**: Add the appropriate Redux action to restore the data
|
||||
|
||||
### Adding new entity types with images
|
||||
|
||||
1. Extend `CanvasProjectState` type in `canvasProjectFile.ts`
|
||||
2. Add collection logic in `collectImageNames()` to walk the new entity's objects
|
||||
3. Add remapping logic in `remapCanvasState()` to update image names
|
||||
4. Include the new entity array in both save and load hooks
|
||||
5. Handle it in the `canvasProjectRecalled` reducer in `canvasSlice.ts`
|
||||
|
||||
### Breaking schema changes
|
||||
|
||||
1. Bump `CANVAS_PROJECT_VERSION` in `canvasProjectFile.ts`
|
||||
2. Update the Zod manifest schema: `version: z.union([z.literal(1), z.literal(2)])`
|
||||
3. Add migration logic in the load hook: check version, transform v1 → v2 before dispatching
|
||||
|
||||
## UI Architecture
|
||||
|
||||
### Save dialog
|
||||
|
||||
The save flow uses a **nanostore atom** (`$isOpen`) to control the `SaveCanvasProjectDialog`:
|
||||
|
||||
1. `useSaveCanvasProjectWithDialog()` — returns a callback that sets `$isOpen` to `true`
|
||||
2. `SaveCanvasProjectDialog` (singleton in `GlobalModalIsolator`) — renders an `AlertDialog` with a name input
|
||||
3. On save → calls `saveCanvasProject(name)` and closes the dialog
|
||||
4. On cancel → closes the dialog
|
||||
|
||||
### Load dialog
|
||||
|
||||
The load flow uses a **nanostore atom** (`$pendingFile`) to decouple the file dialog from the confirmation dialog:
|
||||
|
||||
1. `useLoadCanvasProjectWithDialog()` — opens a programmatic file input (`document.createElement('input')`)
|
||||
2. On file selection → sets `$pendingFile` atom
|
||||
3. `LoadCanvasProjectConfirmationAlertDialog` (singleton in `GlobalModalIsolator`) — subscribes to `$pendingFile` via `useStore()`
|
||||
4. On accept → calls `loadCanvasProject(file)` and clears the atom
|
||||
5. On cancel → clears the atom
|
||||
|
||||
The programmatic file input approach was chosen because the context menu component uses `isLazy: true`, which unmounts the DOM tree when the menu closes — a hidden `<input>` element inside the menu would be destroyed before the file dialog returns.
|
||||
129
docs/contributing/EXTERNAL_PROVIDERS.md
Normal file
129
docs/contributing/EXTERNAL_PROVIDERS.md
Normal file
@@ -0,0 +1,129 @@
|
||||
# External Provider Integration
|
||||
|
||||
This guide covers:
|
||||
|
||||
1. Adding a new **external model** (most common; existing provider).
|
||||
2. Adding a brand-new **external provider** (adapter + config + UI wiring).
|
||||
|
||||
## 1) Add a New External Model (Existing Provider)
|
||||
|
||||
For provider-backed models (for example, OpenAI or Gemini), the source of truth is
|
||||
`invokeai/backend/model_manager/starter_models.py`.
|
||||
|
||||
### Required model fields
|
||||
|
||||
Define a `StarterModel` with:
|
||||
|
||||
- `base=BaseModelType.External`
|
||||
- `type=ModelType.ExternalImageGenerator`
|
||||
- `format=ModelFormat.ExternalApi`
|
||||
- `source="external://<provider_id>/<provider_model_id>"`
|
||||
- `name`, `description`
|
||||
- `capabilities=ExternalModelCapabilities(...)`
|
||||
- optional `default_settings=ExternalApiModelDefaultSettings(...)`
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
new_external_model = StarterModel(
|
||||
name="Provider Model Name",
|
||||
base=BaseModelType.External,
|
||||
source="external://openai/my-model-id",
|
||||
description=(
|
||||
"Provider model (external API). "
|
||||
"Requires a configured OpenAI API key and may incur provider usage costs."
|
||||
),
|
||||
type=ModelType.ExternalImageGenerator,
|
||||
format=ModelFormat.ExternalApi,
|
||||
capabilities=ExternalModelCapabilities(
|
||||
modes=["txt2img", "img2img", "inpaint"],
|
||||
supports_negative_prompt=False,
|
||||
supports_seed=False,
|
||||
supports_guidance=False,
|
||||
supports_steps=False,
|
||||
supports_reference_images=True,
|
||||
max_images_per_request=4,
|
||||
),
|
||||
default_settings=ExternalApiModelDefaultSettings(
|
||||
width=1024,
|
||||
height=1024,
|
||||
num_images=1,
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
Then append it to `STARTER_MODELS`.
|
||||
|
||||
### Required description text
|
||||
|
||||
External starter model descriptions must clearly state:
|
||||
|
||||
- an API key is required
|
||||
- usage may incur provider-side costs
|
||||
|
||||
### Capabilities must be accurate
|
||||
|
||||
These flags directly control UI visibility and request payload fields:
|
||||
|
||||
- `supports_negative_prompt`
|
||||
- `supports_seed`
|
||||
- `supports_guidance`
|
||||
- `supports_steps`
|
||||
- `supports_reference_images`
|
||||
|
||||
`supports_steps` is especially important: if `False`, steps are hidden for that model and `steps` is sent as `null`.
|
||||
|
||||
### Source string stability
|
||||
|
||||
Starter overrides are matched by `source` (`external://provider/model-id`). Keep this stable:
|
||||
|
||||
- runtime capability/default overrides depend on it
|
||||
- installation detection in starter-model APIs depends on it
|
||||
|
||||
`STARTER_MODELS` enforces unique `source` values with an assertion.
|
||||
|
||||
### Install behavior notes
|
||||
|
||||
- External starter models are managed in **External Providers** setup (not the regular Starter Models tab).
|
||||
- External starter models auto-install when a provider is configured.
|
||||
- Removing a provider API key removes installed external models for that provider.
|
||||
|
||||
## 2) Credentials and Config
|
||||
|
||||
External provider API keys are stored separately from `invokeai.yaml`:
|
||||
|
||||
- default file: `~/invokeai/api_keys.yaml`
|
||||
- resolved path: `<INVOKEAI_ROOT>/api_keys.yaml`
|
||||
|
||||
Non-secret provider settings (for example base URL overrides) stay in `invokeai.yaml`.
|
||||
|
||||
Environment variables are still supported, e.g.:
|
||||
|
||||
- `INVOKEAI_EXTERNAL_GEMINI_API_KEY`
|
||||
- `INVOKEAI_EXTERNAL_OPENAI_API_KEY`
|
||||
|
||||
## 3) Add a New Provider (Only If Needed)
|
||||
|
||||
If your model uses a provider that is not already integrated:
|
||||
|
||||
1. Add config fields in `invokeai/app/services/config/config_default.py`
|
||||
`external_<provider>_api_key` and optional `external_<provider>_base_url`.
|
||||
2. Add provider field mapping in `invokeai/app/api/routers/app_info.py`
|
||||
(`EXTERNAL_PROVIDER_FIELDS`).
|
||||
3. Implement provider adapter in `invokeai/app/services/external_generation/providers/`
|
||||
by subclassing `ExternalProvider`.
|
||||
4. Register the provider in `invokeai/app/api/dependencies.py` when building
|
||||
`ExternalGenerationService`.
|
||||
5. Add starter model entries using `source="external://<provider>/<model-id>"`.
|
||||
6. Optional UI ordering tweak:
|
||||
`invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ExternalProviders/ExternalProvidersForm.tsx`
|
||||
(`PROVIDER_SORT_ORDER`).
|
||||
|
||||
## 4) Optional Manual Installation
|
||||
|
||||
You can also install external models directly via:
|
||||
|
||||
`POST /api/v2/models/install?source=external://<provider_id>/<provider_model_id>`
|
||||
|
||||
If these fields are omitted from the install request, `path`, `source`, and `hash` are auto-populated for external model configs.
|
||||
Set capabilities conservatively; the external generation service enforces capability checks at runtime.
|
||||
1254
docs/contributing/NEW_MODEL_INTEGRATION.md
Normal file
1254
docs/contributing/NEW_MODEL_INTEGRATION.md
Normal file
File diff suppressed because it is too large
Load Diff
64
docs/contributing/PR-MERGE-POLICY.md
Normal file
64
docs/contributing/PR-MERGE-POLICY.md
Normal file
@@ -0,0 +1,64 @@
|
||||
# Pull Request Merge Policy
|
||||
|
||||
This document outlines the process for reviewing and merging pull requests (PRs) into the InvokeAI repository.
|
||||
|
||||
## Review Process
|
||||
|
||||
### 1. Assignment
|
||||
|
||||
One of the repository maintainers will assign collaborators to review a pull request. The assigned reviewer(s) will be responsible for conducting the code review.
|
||||
|
||||
### 2. Review and Iteration
|
||||
|
||||
The assignee is responsible for:
|
||||
- Reviewing the PR thoroughly
|
||||
- Providing constructive feedback
|
||||
- Iterating with the PR author until the assignee is satisfied that the PR is fit to merge
|
||||
- Ensuring the PR meets code quality standards, follows project conventions, and doesn't introduce bugs or regressions
|
||||
|
||||
### 3. Approval and Notification
|
||||
|
||||
Once the assignee is satisfied with the PR:
|
||||
- The assignee approves the PR
|
||||
- The assignee alerts one of the maintainers that the PR is ready for merge using the **#request-reviews Discord channel**
|
||||
|
||||
### 4. Final Merge
|
||||
|
||||
One of the maintainers is responsible for:
|
||||
- Performing a final check of the PR
|
||||
- Merging the PR into the appropriate branch
|
||||
|
||||
**Important:** Collaborators are strongly discouraged from merging PRs on their own, except in case of emergency (e.g., critical bug fix and no maintainer is available).
|
||||
|
||||
### 5. Release Policy
|
||||
|
||||
Once a feature release candidate is published, no feature PRs are to be merged into `main`. Only bugfixes are allowed until the final release.
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Clean Commit History
|
||||
|
||||
To encourage a clean development log, PR authors are encouraged to use `git rebase -i` to suppress trivial commit messages (e.g., `ruff` and `prettier` formatting fixes) after the PR is accepted but before it is merged.
|
||||
|
||||
### Merge Strategy
|
||||
|
||||
The maintainer will perform either a **3-way merge** or **squash merge** when merging a PR into the `main` branch. This approach helps avoid rebase conflict hell and maintains a cleaner project history.
|
||||
|
||||
### Attribution
|
||||
|
||||
The PR author should reference any papers, source code, or documentation that they used while creating the code, both in the PR and as comments in the code itself. If there are any licensing restrictions, these should be linked to and/or reproduced in the repo root.
|
||||
|
||||
|
||||
## Summary
|
||||
|
||||
This policy ensures that:
|
||||
- All PRs receive proper review from assigned collaborators
|
||||
- Maintainers have final oversight before code enters the main branch
|
||||
- The commit history remains clean and meaningful
|
||||
- Merge conflicts are minimized through appropriate merge strategies
|
||||
@@ -0,0 +1,375 @@
|
||||
# Recall Parameters API - LoRAs, ControlNets, and IP Adapters with Images
|
||||
|
||||
## Overview
|
||||
|
||||
The Recall Parameters API supports recalling LoRAs, ControlNets (including T2I Adapters and Control LoRAs), and IP Adapters along with their associated weights and settings. Control Layers and IP Adapters can now include image references from the `INVOKEAI_ROOT/outputs/images` directory for fully functional control and image prompt functionality.
|
||||
|
||||
## Key Features
|
||||
|
||||
✅ **LoRAs**: Fully functional - adds to UI, queries model configs, applies weights
|
||||
✅ **Control Layers**: Full support with optional images from outputs/images
|
||||
✅ **IP Adapters**: Full support with optional reference images from outputs/images
|
||||
✅ **Model Name Resolution**: Automatic lookup from human-readable names to internal keys
|
||||
✅ **Image Validation**: Backend validates that image files exist before sending
|
||||
|
||||
## Endpoints
|
||||
|
||||
### POST `/api/v1/recall/{queue_id}`
|
||||
|
||||
Updates recallable parameters for the frontend, including LoRAs, control adapters, and IP adapters with optional images.
|
||||
|
||||
**Path Parameters:**
|
||||
- `queue_id` (string): The queue ID to associate parameters with (typically "default")
|
||||
|
||||
**Request Body:**
|
||||
|
||||
All fields are optional. Include only the parameters you want to update.
|
||||
|
||||
```typescript
|
||||
{
|
||||
// Standard parameters
|
||||
positive_prompt?: string;
|
||||
negative_prompt?: string;
|
||||
model?: string; // Model name or key
|
||||
steps?: number;
|
||||
cfg_scale?: number;
|
||||
width?: number;
|
||||
height?: number;
|
||||
seed?: number;
|
||||
// ... other standard parameters
|
||||
|
||||
// LoRAs
|
||||
loras?: Array<{
|
||||
model_name: string; // LoRA model name
|
||||
weight?: number; // Default: 0.75, Range: -10 to 10
|
||||
is_enabled?: boolean; // Default: true
|
||||
}>;
|
||||
|
||||
// Control Layers (ControlNet, T2I Adapter, Control LoRA)
|
||||
control_layers?: Array<{
|
||||
model_name: string; // Control adapter model name
|
||||
image_name?: string; // Optional image filename from outputs/images
|
||||
weight?: number; // Default: 1.0, Range: -1 to 2
|
||||
begin_step_percent?: number; // Default: 0.0, Range: 0 to 1
|
||||
end_step_percent?: number; // Default: 1.0, Range: 0 to 1
|
||||
control_mode?: "balanced" | "more_prompt" | "more_control"; // ControlNet only
|
||||
}>;
|
||||
|
||||
// IP Adapters
|
||||
ip_adapters?: Array<{
|
||||
model_name: string; // IP Adapter model name
|
||||
image_name?: string; // Optional reference image filename from outputs/images
|
||||
weight?: number; // Default: 1.0, Range: -1 to 2
|
||||
begin_step_percent?: number; // Default: 0.0, Range: 0 to 1
|
||||
end_step_percent?: number; // Default: 1.0, Range: 0 to 1
|
||||
method?: "full" | "style" | "composition"; // Default: "full"
|
||||
influence?: "Lowest" | "Low" | "Medium" | "High" | "Highest"; // Flux Redux only; default: "Highest"
|
||||
}>;
|
||||
}
|
||||
```
|
||||
|
||||
## Model Name Resolution
|
||||
|
||||
The backend automatically resolves model names to their internal keys:
|
||||
|
||||
1. **Main Models**: Resolved from the name to the model key
|
||||
2. **LoRAs**: Searched in the LoRA model database
|
||||
3. **Control Adapters**: Tried in order - ControlNet → T2I Adapter → Control LoRA
|
||||
4. **IP Adapters**: Searched in the IP Adapter model database
|
||||
|
||||
Models that cannot be resolved are skipped with a warning in the logs.
|
||||
|
||||
## Image File Handling
|
||||
|
||||
### Image Path Resolution
|
||||
|
||||
When you specify an `image_name`, the backend:
|
||||
1. Constructs the full path: `{INVOKEAI_ROOT}/outputs/images/{image_name}`
|
||||
2. Validates that the file exists
|
||||
3. Includes the image reference in the event sent to the frontend
|
||||
4. Logs whether the image was found or not
|
||||
|
||||
### Image Naming
|
||||
|
||||
Images should be referenced by their filename as it appears in the outputs/images directory:
|
||||
- ✅ Correct: `"image_name": "example.png"`
|
||||
- ✅ Correct: `"image_name": "my_control_image_20240110.jpg"`
|
||||
- ❌ Incorrect: `"image_name": "outputs/images/example.png"` (use relative filename only)
|
||||
- ❌ Incorrect: `"image_name": "/full/path/to/example.png"` (use relative filename only)
|
||||
|
||||
## Frontend Behavior
|
||||
|
||||
### LoRAs
|
||||
- **Fully Supported**: LoRAs are immediately added to the LoRA list in the UI
|
||||
- Existing LoRAs are cleared before adding new ones
|
||||
- Each LoRA's model config is fetched and applied with the specified weight
|
||||
- LoRAs appear in the LoRA selector panel
|
||||
|
||||
### Control Layers with Images
|
||||
- **Fully Supported**: Control layers now support images from outputs/images
|
||||
- Configuration includes model, weights, step percentages, and image reference
|
||||
- Image availability is logged in frontend console
|
||||
- Images can be used to create actual control layers through the UI
|
||||
|
||||
### IP Adapters with Images
|
||||
- **Fully Supported**: IP Adapters now support reference images from outputs/images
|
||||
- Configuration includes model, weights, step percentages, method, and image reference
|
||||
- Image availability is logged in frontend console
|
||||
- Images can be used to create actual reference image layers through the UI
|
||||
|
||||
## Examples
|
||||
|
||||
### 1. Add LoRAs Only
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:9090/api/v1/recall/default \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"loras": [
|
||||
{
|
||||
"model_name": "add-detail-xl",
|
||||
"weight": 0.8,
|
||||
"is_enabled": true
|
||||
},
|
||||
{
|
||||
"model_name": "sd_xl_offset_example-lora_1.0",
|
||||
"weight": 0.5,
|
||||
"is_enabled": true
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
### 2. Configure Control Layers with Image
|
||||
|
||||
Replace `my_control_image.png` with an actual image filename from your outputs/images directory.
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:9090/api/v1/recall/default \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"control_layers": [
|
||||
{
|
||||
"model_name": "controlnet-canny-sdxl-1.0",
|
||||
"image_name": "my_control_image.png",
|
||||
"weight": 0.75,
|
||||
"begin_step_percent": 0.0,
|
||||
"end_step_percent": 0.8,
|
||||
"control_mode": "balanced"
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
### 3. Configure IP Adapters with Reference Image
|
||||
|
||||
Replace `reference_face.png` with an actual image filename from your outputs/images directory.
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:9090/api/v1/recall/default \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"ip_adapters": [
|
||||
{
|
||||
"model_name": "ip-adapter-plus-face_sd15",
|
||||
"image_name": "reference_face.png",
|
||||
"weight": 0.7,
|
||||
"begin_step_percent": 0.0,
|
||||
"end_step_percent": 1.0,
|
||||
"method": "composition"
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
### 4. Complete Configuration with All Features
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:9090/api/v1/recall/default \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"positive_prompt": "masterpiece, detailed photo with specific style",
|
||||
"negative_prompt": "blurry, low quality",
|
||||
"model": "FLUX Schnell",
|
||||
"steps": 25,
|
||||
"cfg_scale": 8.0,
|
||||
"width": 1024,
|
||||
"height": 768,
|
||||
"seed": 42,
|
||||
"loras": [
|
||||
{
|
||||
"model_name": "add-detail-xl",
|
||||
"weight": 0.6,
|
||||
"is_enabled": true
|
||||
}
|
||||
],
|
||||
"control_layers": [
|
||||
{
|
||||
"model_name": "controlnet-depth-sdxl-1.0",
|
||||
"image_name": "depth_map.png",
|
||||
"weight": 1.0,
|
||||
"begin_step_percent": 0.0,
|
||||
"end_step_percent": 0.7
|
||||
}
|
||||
],
|
||||
"ip_adapters": [
|
||||
{
|
||||
"model_name": "ip-adapter-plus-face_sd15",
|
||||
"image_name": "style_reference.png",
|
||||
"weight": 0.5,
|
||||
"begin_step_percent": 0.0,
|
||||
"end_step_percent": 1.0,
|
||||
"method": "style"
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
## Response Format
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "success",
|
||||
"queue_id": "default",
|
||||
"updated_count": 15,
|
||||
"parameters": {
|
||||
"positive_prompt": "...",
|
||||
"steps": 25,
|
||||
"loras": [
|
||||
{
|
||||
"model_key": "abc123...",
|
||||
"weight": 0.6,
|
||||
"is_enabled": true
|
||||
}
|
||||
],
|
||||
"control_layers": [
|
||||
{
|
||||
"model_key": "controlnet-xyz...",
|
||||
"weight": 1.0,
|
||||
"image": {
|
||||
"image_name": "depth_map.png"
|
||||
}
|
||||
}
|
||||
],
|
||||
"ip_adapters": [
|
||||
{
|
||||
"model_key": "ip-adapter-xyz...",
|
||||
"weight": 0.5,
|
||||
"image": {
|
||||
"image_name": "style_reference.png"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## WebSocket Events
|
||||
|
||||
When parameters are updated, a `recall_parameters_updated` event is emitted via WebSocket to the queue room. The frontend automatically:
|
||||
|
||||
1. Applies standard parameters (prompts, steps, dimensions, etc.)
|
||||
2. Loads and adds LoRAs to the LoRA list
|
||||
3. Logs control layer and IP adapter configurations with image information
|
||||
4. Makes image references available for manual canvas/reference image creation
|
||||
|
||||
## Logging
|
||||
|
||||
### Backend Logs
|
||||
|
||||
Backend logs show:
|
||||
- Model name → key resolution (success/failure)
|
||||
- Image file validation (found/not found)
|
||||
- Parameter storage confirmation
|
||||
- Event emission status
|
||||
|
||||
Example log messages:
|
||||
```
|
||||
INFO: Resolved ControlNet model name 'controlnet-canny-sdxl-1.0' to key 'controlnet-xyz...'
|
||||
INFO: Found image file: depth_map.png
|
||||
INFO: Updated 12 recall parameters for queue default
|
||||
INFO: Resolved 1 LoRA(s)
|
||||
INFO: Resolved 1 control layer(s)
|
||||
INFO: Resolved 1 IP adapter(s)
|
||||
```
|
||||
|
||||
### Frontend Logs
|
||||
|
||||
Frontend logs (check browser console):
|
||||
- Set `localStorage.ROARR_FILTER = 'debug'` to see all debug messages
|
||||
- Look for messages from the `events` namespace
|
||||
- LoRA loading, model resolution, and parameter application are logged
|
||||
|
||||
Example log messages:
|
||||
```
|
||||
INFO: Applied 5 recall parameters to store
|
||||
INFO: Received 1 control layer(s) with image support
|
||||
INFO: Control layer 1: controlnet-xyz... (weight: 0.75, image: depth_map.png)
|
||||
DEBUG: Control layer 1 image available at: outputs/images/depth_map.png
|
||||
INFO: Received 1 IP adapter(s) with image support
|
||||
INFO: IP adapter 1: ip-adapter-xyz... (weight: 0.7, image: style_reference.png)
|
||||
DEBUG: IP adapter 1 image available at: outputs/images/style_reference.png
|
||||
```
|
||||
|
||||
## Limitations
|
||||
|
||||
1. **Canvas Integration**: Control layers and IP adapters with images are currently logged but not automatically added to canvas layers
|
||||
- Users can view the configuration and manually create canvas layers with the provided images
|
||||
- Future enhancement: Auto-create canvas layers with stored images
|
||||
|
||||
2. **Model Availability**: Models must be installed in InvokeAI before they can be recalled
|
||||
|
||||
3. **Image Availability**: Images must exist in the outputs/images directory
|
||||
- Missing images are logged as warnings but don't fail the request
|
||||
- Other parameters are still applied even if images are missing
|
||||
|
||||
4. **Image URLs**: Only local filenames from outputs/images are supported
|
||||
- Remote image URLs are not currently supported
|
||||
|
||||
## Testing
|
||||
|
||||
Use the provided test script:
|
||||
|
||||
```bash
|
||||
./test_recall_loras_controlnets.sh
|
||||
```
|
||||
|
||||
This will test:
|
||||
- LoRA addition with multiple models
|
||||
- Control layer configuration with image references
|
||||
- IP adapter configuration with image references
|
||||
- Combined parameter updates with all features
|
||||
|
||||
Note: Update the image names in the test script to match actual images in your outputs/images directory.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Images Not Found
|
||||
|
||||
If you see "Image file not found" in the logs:
|
||||
1. Verify the image filename matches exactly (case-sensitive)
|
||||
2. Ensure the image is in `{INVOKEAI_ROOT}/outputs/images/`
|
||||
3. Check that the filename doesn't include the `outputs/images/` prefix
|
||||
|
||||
### Models Not Found
|
||||
|
||||
If you see "Could not find model" messages:
|
||||
1. Verify the model name matches exactly (case-sensitive)
|
||||
2. Ensure the model is installed in InvokeAI
|
||||
3. Check the model name using the models browser in the UI
|
||||
|
||||
### Event Not Received
|
||||
|
||||
If the frontend doesn't receive the event:
|
||||
1. Check browser console for connection errors
|
||||
2. Verify the queue_id matches the frontend's queue (usually "default")
|
||||
3. Check backend logs for event emission errors
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
Potential improvements:
|
||||
1. Auto-create canvas layers with provided control layer images
|
||||
2. Auto-create reference image layers with provided IP adapter images
|
||||
3. Support for image URLs
|
||||
4. Batch operations for multiple queue IDs
|
||||
5. Image upload capability (accept base64 or file upload)
|
||||
208
docs/contributing/RECALL_PARAMETERS/RECALL_PARAMETERS_API.md
Normal file
208
docs/contributing/RECALL_PARAMETERS/RECALL_PARAMETERS_API.md
Normal file
@@ -0,0 +1,208 @@
|
||||
# Recall Parameters API
|
||||
|
||||
## Overview
|
||||
|
||||
A new REST API endpoint has been added to the InvokeAI backend that allows programmatic updates to recallable parameters from another process. This enables external applications or scripts to modify frontend parameters like prompts, models, and step counts via HTTP requests.
|
||||
|
||||
When parameters are updated via the API, the backend automatically broadcasts a WebSocket event to all connected frontend clients subscribed to that queue, causing them to update immediately.
|
||||
|
||||
## How It Works
|
||||
|
||||
1. **API Request**: External application sends a POST request with parameters to update
|
||||
2. **Storage**: Parameters are stored in client state persistence, associated with a queue ID
|
||||
3. **Broadcast**: A WebSocket event (`recall_parameters_updated`) is emitted to all frontend clients listening to that queue
|
||||
4. **Frontend Update**: Connected frontend clients receive the event and can process the updated parameters
|
||||
5. **Immediate Display**: The frontend UI updates automatically with the new values
|
||||
|
||||
This means if you have the InvokeAI frontend open in a browser, updating parameters via the API will instantly reflect on the screen without any manual action needed.
|
||||
|
||||
## Endpoint
|
||||
|
||||
**Base URL**: `http://localhost:9090/api/v1/recall/{queue_id}`
|
||||
|
||||
## POST - Update Recall Parameters
|
||||
|
||||
Updates recallable parameters for a given queue ID.
|
||||
|
||||
### Request
|
||||
|
||||
```http
|
||||
POST /api/v1/recall/{queue_id}
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"positive_prompt": "a beautiful landscape",
|
||||
"negative_prompt": "blurry, low quality",
|
||||
"model": "sd-1.5",
|
||||
"steps": 20,
|
||||
"cfg_scale": 7.5,
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
"seed": 12345
|
||||
}
|
||||
```
|
||||
|
||||
The queue id is usually "default".
|
||||
|
||||
### Parameters
|
||||
|
||||
All parameters are optional. Only provide the parameters you want to update:
|
||||
|
||||
| Parameter | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `positive_prompt` | string | Positive prompt text |
|
||||
| `negative_prompt` | string | Negative prompt text |
|
||||
| `model` | string | Main model name/identifier |
|
||||
| `refiner_model` | string | Refiner model name/identifier |
|
||||
| `vae_model` | string | VAE model name/identifier |
|
||||
| `scheduler` | string | Scheduler name |
|
||||
| `steps` | integer | Number of generation steps (≥1) |
|
||||
| `refiner_steps` | integer | Number of refiner steps (≥0) |
|
||||
| `cfg_scale` | number | CFG scale for guidance |
|
||||
| `cfg_rescale_multiplier` | number | CFG rescale multiplier |
|
||||
| `refiner_cfg_scale` | number | Refiner CFG scale |
|
||||
| `guidance` | number | Guidance scale |
|
||||
| `width` | integer | Image width in pixels (≥64) |
|
||||
| `height` | integer | Image height in pixels (≥64) |
|
||||
| `seed` | integer | Random seed (≥0) |
|
||||
| `denoise_strength` | number | Denoising strength (0-1) |
|
||||
| `refiner_denoise_start` | number | Refiner denoising start (0-1) |
|
||||
| `clip_skip` | integer | CLIP skip layers (≥0) |
|
||||
| `seamless_x` | boolean | Enable seamless X tiling |
|
||||
| `seamless_y` | boolean | Enable seamless Y tiling |
|
||||
| `refiner_positive_aesthetic_score` | number | Refiner positive aesthetic score |
|
||||
| `refiner_negative_aesthetic_score` | number | Refiner negative aesthetic score |
|
||||
|
||||
### Response
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "success",
|
||||
"queue_id": "queue_123",
|
||||
"updated_count": 7,
|
||||
"parameters": {
|
||||
"positive_prompt": "a beautiful landscape",
|
||||
"negative_prompt": "blurry, low quality",
|
||||
"model": "sd-1.5",
|
||||
"steps": 20,
|
||||
"cfg_scale": 7.5,
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
"seed": 12345
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## GET - Retrieve Recall Parameters
|
||||
|
||||
Retrieves metadata about stored recall parameters.
|
||||
|
||||
### Request
|
||||
|
||||
```http
|
||||
GET /api/v1/recall/{queue_id}
|
||||
```
|
||||
|
||||
### Response
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "success",
|
||||
"queue_id": "queue_123",
|
||||
"note": "Use the frontend to access stored recall parameters, or set specific parameters using POST"
|
||||
}
|
||||
```
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Using cURL
|
||||
|
||||
```bash
|
||||
# Update prompts and model
|
||||
curl -X POST http://localhost:9090/api/v1/recall/default \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"positive_prompt": "a cyberpunk city at night",
|
||||
"negative_prompt": "dark, unclear",
|
||||
"model": "sd-1.5",
|
||||
"steps": 30
|
||||
}'
|
||||
|
||||
# Update just the seed
|
||||
curl -X POST http://localhost:9090/api/v1/recall/default \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"seed": 99999}'
|
||||
```
|
||||
|
||||
### Using Python
|
||||
|
||||
```python
|
||||
import requests
|
||||
import json
|
||||
|
||||
# Configuration
|
||||
API_URL = "http://localhost:9090/api/v1/recall/default"
|
||||
|
||||
# Update multiple parameters
|
||||
params = {
|
||||
"positive_prompt": "a serene forest",
|
||||
"negative_prompt": "people, buildings",
|
||||
"steps": 25,
|
||||
"cfg_scale": 7.0,
|
||||
"seed": 42
|
||||
}
|
||||
|
||||
response = requests.post(API_URL, json=params)
|
||||
result = response.json()
|
||||
|
||||
print(f"Status: {result['status']}")
|
||||
print(f"Updated {result['updated_count']} parameters")
|
||||
print(json.dumps(result['parameters'], indent=2))
|
||||
```
|
||||
|
||||
### Using Node.js/JavaScript
|
||||
|
||||
```javascript
|
||||
const API_URL = 'http://localhost:9090/api/v1/recall/default';
|
||||
|
||||
const params = {
|
||||
positive_prompt: 'a beautiful sunset',
|
||||
negative_prompt: 'blurry',
|
||||
steps: 20,
|
||||
width: 768,
|
||||
height: 768,
|
||||
seed: 12345
|
||||
};
|
||||
|
||||
fetch(API_URL, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(params)
|
||||
})
|
||||
.then(res => res.json())
|
||||
.then(data => console.log(data));
|
||||
```
|
||||
|
||||
## Implementation Details
|
||||
|
||||
- Parameters are stored in the client state persistence service, using keys prefixed with `recall_`
|
||||
- The parameters are associated with a `queue_id`, allowing multiple concurrent sessions to maintain separate parameter sets
|
||||
- Only non-null parameters are processed and stored
|
||||
- The endpoint provides validation for numeric ranges (e.g., steps ≥ 1, dimensions ≥ 64)
|
||||
- All parameter values are JSON-serialized for storage
|
||||
- When parameter values are changed, the backend emits a WebSocket event that the frontend listens to.
|
||||
|
||||
## Integration with Frontend
|
||||
|
||||
The stored parameters can be accessed by the frontend through the
|
||||
existing client state API or by implementing hooks that read from the
|
||||
recall parameter storage. This allows external applications to
|
||||
pre-populate generation parameters before the user initiates image
|
||||
generation.
|
||||
|
||||
## Error Handling
|
||||
|
||||
- **400 Bad Request**: Invalid parameters or parameter values
|
||||
- **500 Internal Server Error**: Server-side error storing or retrieving parameters
|
||||
|
||||
Errors include detailed messages explaining what went wrong.
|
||||
35
docs/contributing/frontend/canvas-text-tool.md
Normal file
35
docs/contributing/frontend/canvas-text-tool.md
Normal file
@@ -0,0 +1,35 @@
|
||||
# Canvas Text Tool
|
||||
|
||||
## Overview
|
||||
|
||||
The canvas text workflow is split between a Konva module that owns tool state and a React overlay that handles text entry.
|
||||
|
||||
- `invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasTextToolModule.ts`
|
||||
- Owns the tool, cursor preview, and text session state (including the cursor "T" marker).
|
||||
- Manages dynamic cursor contrast, starts sessions on pointer down, and commits sessions by rasterizing the active text block into a new raster layer.
|
||||
- `invokeai/frontend/web/src/features/controlLayers/components/Text/CanvasTextOverlay.tsx`
|
||||
- Renders the on-canvas editor as a `contentEditable` overlay positioned in canvas space.
|
||||
- Syncs keyboard input, suppresses app hotkeys, and forwards commits/cancels to the Konva module.
|
||||
- `invokeai/frontend/web/src/features/controlLayers/components/Text/TextToolOptions.tsx`
|
||||
- Provides the font dropdown, size slider/input, formatting toggles, and alignment buttons that appear when the Text tool is active.
|
||||
|
||||
## Rasterization pipeline
|
||||
|
||||
`renderTextToCanvas()` (`invokeai/frontend/web/src/features/controlLayers/text/textRenderer.ts`) converts the editor contents into a transparent canvas. The Text tool module configures the renderer with the active font stack, weight, styling flags, alignment, and the active canvas color. The resulting canvas is encoded to a PNG data URL and stored in a new raster layer (`image` object) with a transparent background.
|
||||
|
||||
Layer placement preserves the original click location:
|
||||
|
||||
- The session stores the anchor coordinate (where the user clicked) and current alignment.
|
||||
- `calculateLayerPosition()` calculates the top-left position for the raster layer after applying the configured padding and alignment offsets.
|
||||
- New layers are inserted directly above the currently-selected raster layer (when present) and selected automatically.
|
||||
|
||||
## Font stacks
|
||||
|
||||
Font definitions live in `invokeai/frontend/web/src/features/controlLayers/text/textConstants.ts` as ten deterministic stacks (sans, serif, mono, rounded, script, humanist, slab serif, display, narrow, UI serif). Each stack lists system-safe fallbacks so the editor can choose the first available font per platform.
|
||||
|
||||
To add or adjust fonts:
|
||||
|
||||
1. Update `TEXT_FONT_STACKS` with the new `id`, `label`, and CSS `font-family` stack.
|
||||
2. If you add a new stack, extend the `TEXT_FONT_IDS` tuple and update the `canvasTextSlice` schema default (`TEXT_DEFAULT_FONT_ID`).
|
||||
3. Provide translation strings for any new labels in `public/locales/*`.
|
||||
4. The editor and renderer will automatically pick up the new stack via `getFontStackById()`.
|
||||
@@ -8,6 +8,10 @@ We welcome contributions, whether features, bug fixes, code cleanup, testing, co
|
||||
|
||||
If you’d like to help with development, please see our [development guide](contribution_guides/development.md).
|
||||
|
||||
## External Providers
|
||||
|
||||
If you are adding external image generation providers or configs, see our [external provider integration guide](EXTERNAL_PROVIDERS.md).
|
||||
|
||||
**New Contributors:** If you’re unfamiliar with contributing to open source projects, take a look at our [new contributor guide](contribution_guides/newContributorChecklist.md).
|
||||
|
||||
## Nodes
|
||||
@@ -18,7 +22,7 @@ If you’d like to add a Node, please see our [nodes contribution guide](../node
|
||||
|
||||
Helping support other users in [Discord](https://discord.gg/ZmtBAhwWhy) and on GitHub are valuable forms of contribution that we greatly appreciate.
|
||||
|
||||
We receive many issues and requests for help from users. We're limited in bandwidth relative to our the user base, so providing answers to questions or helping identify causes of issues is very helpful. By doing this, you enable us to spend time on the highest priority work.
|
||||
We receive many issues and requests for help from users. We're limited in bandwidth relative to our user base, so providing answers to questions or helping identify causes of issues is very helpful. By doing this, you enable us to spend time on the highest priority work.
|
||||
|
||||
## Documentation
|
||||
|
||||
|
||||
32
docs/features/Lasso_tool.md
Normal file
32
docs/features/Lasso_tool.md
Normal file
@@ -0,0 +1,32 @@
|
||||
Lasso Tool
|
||||
===========
|
||||
|
||||
- The Lasso tool creates selections and inpaint masks by drawing freehand or polygonal regions on the canvas.
|
||||
|
||||
How to open the Lasso tool
|
||||
--------------------------
|
||||
- Click the Lasso icon in the toolbar.
|
||||
- Hotkey: press `L` (default). The hotkey is shown in the tool's tooltip and can be customized in Hotkeys settings.
|
||||
|
||||
Modes
|
||||
-----
|
||||
- Freehand (default)
|
||||
- Hold the pointer and drag to draw a continuous contour.
|
||||
- Long segments are broken into intermediate points to keep the line continuous.
|
||||
- Very long strokes may be simplified after drawing to reduce point count for performance.
|
||||
|
||||
- Polygon
|
||||
- Click to place points; click the first point (or a point near it) to close the polygon.
|
||||
- The tool snaps the closing point to the start for precise closures.
|
||||
|
||||
Basic interactions
|
||||
------------------
|
||||
- Switch modes with the mode toggle in the toolbar.
|
||||
- To close a polygon: click the starting point again or click near it — the tool aligns the final point to the start to complete the shape.
|
||||
- The selection will be added to the current Inpaint Mask layer. If no Inpaint Mask layer exists, a new one will be created automatically.
|
||||
|
||||
Tips & behavior
|
||||
---------------
|
||||
- Hold `Space` to temporarily switch to the View tool for panning and zooming; release `Space` to return to the Lasso tool and continue drawing.
|
||||
- When using the Polygon mode, you can hold `Shift` to snap points to horizontal, vertical, or 45-degree angles for more precise shapes.
|
||||
- Hold `Ctrl` (Windows/Linux) or `Command` (macOS) while drawing to subtract from the current selection instead of adding to it.
|
||||
19
docs/features/Text_tool.md
Normal file
19
docs/features/Text_tool.md
Normal file
@@ -0,0 +1,19 @@
|
||||
# Text Tool
|
||||
|
||||
## Font selection
|
||||
|
||||
The Text tool uses a set of predefined font stacks. When you choose a font, the app resolves the first available font on your system from that stack and uses it for both the editor overlay and the rasterized result. This provides consistent styling across platforms while still falling back to safe system fonts if a preferred font is missing.
|
||||
|
||||
## Size and spacing
|
||||
|
||||
- **Size** controls the font size in pixels.
|
||||
- **Spacing** controls the line height multiplier (Dense, Normal, Spacious). This affects the distance between lines while editing the text.
|
||||
|
||||
## Uncommitted state
|
||||
|
||||
While text is uncommitted, it remains editable on-canvas. Access to other tools is blocked. Switching to other tabs (Generate, Upscaling, Workflows, etc.) discards the text. The uncommitted box can be moved and rotated:
|
||||
|
||||
- **Move:** Hold Ctrl (Windows/Linux) or Command (macOS) and drag to move the text box.
|
||||
- **Rotate:** Drag the rotation handle above the box. Hold **Shift** while rotating to snap to 15 degree increments.
|
||||
|
||||
The text is committed to a raster layer when you press **Enter**. Press **Esc** to discard the current text session.
|
||||
56
docs/features/canvas_projects.md
Normal file
56
docs/features/canvas_projects.md
Normal file
@@ -0,0 +1,56 @@
|
||||
---
|
||||
title: Canvas Projects
|
||||
---
|
||||
|
||||
# :material-folder-zip: Canvas Projects
|
||||
|
||||
## Save and Restore Your Canvas Work
|
||||
|
||||
Canvas Projects let you save your entire canvas setup to a file and load it back later. This is useful when you want to:
|
||||
|
||||
- **Switch between tasks** without losing your current canvas arrangement
|
||||
- **Back up complex setups** with multiple layers, masks, and reference images
|
||||
- **Share canvas layouts** with others or transfer them between machines
|
||||
- **Recover from deleted images** — all images are embedded in the project file
|
||||
|
||||
## What Gets Saved
|
||||
|
||||
A canvas project file (`.invk`) captures everything about your current canvas session:
|
||||
|
||||
- **All layers** — raster layers, control layers, inpaint masks, regional guidance
|
||||
- **All drawn content** — brush strokes, pasted images, eraser marks
|
||||
- **Reference images** — global IP-Adapter / FLUX Redux images with crop settings
|
||||
- **Regional guidance** — per-region prompts and reference images
|
||||
- **Bounding box** — position, size, aspect ratio, and scale settings
|
||||
- **All generation parameters** — prompts, seed, steps, CFG scale, guidance, scheduler, model, VAE, dimensions, img2img strength, infill settings, canvas coherence, refiner settings, FLUX/Z-Image specific parameters, and more
|
||||
- **LoRAs** — all added LoRA models with their weights and enabled/disabled state
|
||||
|
||||
## How to Save a Project
|
||||
|
||||
You can save from two places:
|
||||
|
||||
1. **Toolbar** — Click the **Archive icon** in the canvas toolbar, then select **Save Canvas Project**
|
||||
2. **Context menu** — Right-click the canvas, open the **Project** submenu, then select **Save Canvas Project**
|
||||
|
||||
A dialog will ask you to enter a **project name**. This name is used as the filename (e.g., entering "My Portrait" saves as `My Portrait.invk`) and is stored inside the project file.
|
||||
|
||||
## How to Load a Project
|
||||
|
||||
1. **Toolbar** — Click the **Archive icon**, then select **Load Canvas Project**
|
||||
2. **Context menu** — Right-click the canvas, open the **Project** submenu, then select **Load Canvas Project**
|
||||
|
||||
A file dialog will open. Select your `.invk` file. You will see a confirmation dialog warning that loading will replace your current canvas. Click **Load** to proceed.
|
||||
|
||||
### What Happens on Load
|
||||
|
||||
- Your current canvas is **completely replaced** — all existing layers, masks, reference images, and parameters are overwritten
|
||||
- Images that are already present on your InvokeAI server are reused automatically (no duplicate uploads)
|
||||
- Images that were deleted from the server are re-uploaded from the project file
|
||||
- If the saved model is not installed on your system, the model identifier is still restored — you will need to select an available model manually
|
||||
|
||||
## Good to Know
|
||||
|
||||
- **No undo** — Loading a project replaces your canvas entirely. There is no way to undo this action, so save your current project first if you want to keep it.
|
||||
- **Image deduplication** — When loading, images already on your server are not re-uploaded. Only missing images are uploaded from the project file.
|
||||
- **File size** — The `.invk` file size depends on the number and resolution of images in your canvas. A project with many high-resolution layers can be large.
|
||||
- **Model availability** — The project saves which model was selected, but does not include the model itself. If the model is not installed when you load the project, you will need to select a different one.
|
||||
152
docs/features/orphaned_model_removal.md
Normal file
152
docs/features/orphaned_model_removal.md
Normal file
@@ -0,0 +1,152 @@
|
||||
# Orphaned Models Synchronization Feature
|
||||
|
||||
## Overview
|
||||
This feature adds a UI for synchronizing the models directory by finding and removing orphaned model files. Orphaned models are directories that contain model files but are not referenced in the InvokeAI database.
|
||||
|
||||
## Implementation Summary
|
||||
|
||||
### Backend (Python)
|
||||
|
||||
#### New Service: `OrphanedModelsService`
|
||||
- Location: `invokeai/app/services/orphaned_models/`
|
||||
- Implements the core logic from the CLI script
|
||||
- Methods:
|
||||
- `find_orphaned_models()`: Scans the models directory and database to find orphaned models
|
||||
- `delete_orphaned_models(paths)`: Safely deletes specified orphaned model directories
|
||||
|
||||
#### API Routes
|
||||
Added to `invokeai/app/api/routers/model_manager.py`:
|
||||
- `GET /api/v2/models/sync/orphaned`: Returns list of orphaned models with metadata
|
||||
- `DELETE /api/v2/models/sync/orphaned`: Deletes selected orphaned models
|
||||
|
||||
#### Data Models
|
||||
- `OrphanedModelInfo`: Contains path, absolute_path, files list, and size_bytes
|
||||
- `DeleteOrphanedModelsRequest`: Contains list of paths to delete
|
||||
- `DeleteOrphanedModelsResponse`: Contains deleted paths and errors
|
||||
|
||||
### Frontend (TypeScript/React)
|
||||
|
||||
#### New Components
|
||||
|
||||
1. **SyncModelsButton.tsx**
|
||||
- Red button styled with `colorScheme="error"` for visual prominence
|
||||
- Labeled "Sync Models"
|
||||
- Opens the SyncModelsDialog when clicked
|
||||
- Located next to the "+ Add Models" button
|
||||
|
||||
2. **SyncModelsDialog.tsx**
|
||||
- Modal dialog that displays orphaned models
|
||||
- Features:
|
||||
- List of orphaned models with checkboxes (default: all checked)
|
||||
- "Select All" / "Deselect All" toggle
|
||||
- Shows file count and total size for each model
|
||||
- "Delete" and "Cancel" buttons
|
||||
- Loading spinner while fetching data
|
||||
- Error handling with user-friendly messages
|
||||
- Automatically shows toast if no orphaned models found
|
||||
- Shows success/error toasts after deletion
|
||||
|
||||
#### API Integration
|
||||
- Added `useGetOrphanedModelsQuery` and `useDeleteOrphanedModelsMutation` hooks to `services/api/endpoints/models.ts`
|
||||
- Integrated with RTK Query for efficient data fetching and caching
|
||||
|
||||
#### Translation Strings
|
||||
Added to `public/locales/en.json`:
|
||||
- syncModels, noOrphanedModels, orphanedModelsFound
|
||||
- orphanedModelsDescription, foundOrphanedModels (with pluralization)
|
||||
- filesCount, deleteSelected, deselectAll
|
||||
- Success/error messages for deletion operations
|
||||
|
||||
## User Experience Flow
|
||||
|
||||
1. User clicks the red "Sync Models" button in the Model Manager
|
||||
2. System queries the backend for orphaned models
|
||||
3. If no orphaned models:
|
||||
- Toast message: "The models directory is synchronized. No orphaned files found."
|
||||
- Dialog closes automatically
|
||||
4. If orphaned models found:
|
||||
- Dialog shows list with checkboxes (all selected by default)
|
||||
- User can toggle individual models or use "Select All" / "Deselect All"
|
||||
- Each model shows:
|
||||
- Directory path
|
||||
- File count
|
||||
- Total size (formatted: B, KB, MB, GB)
|
||||
5. User clicks "Delete {{count}} selected"
|
||||
6. System deletes selected models
|
||||
7. Success/error toasts appear
|
||||
8. Dialog closes
|
||||
|
||||
## Safety Features
|
||||
|
||||
1. **Database Backup**: The service creates a backup before any deletion
|
||||
2. **Selective Deletion**: Users choose which models to delete
|
||||
3. **Path Validation**: Ensures paths are within the models directory
|
||||
4. **Error Handling**: Reports which models failed to delete and why
|
||||
5. **Default Selected**: All models are selected by default for convenience
|
||||
6. **Confirmation Required**: User must explicitly click Delete
|
||||
|
||||
## Technical Details
|
||||
|
||||
### Directory-Based Detection
|
||||
The system treats model paths as directories:
|
||||
- If database has `model-id/file.safetensors`, the entire `model-id/` directory belongs to that model
|
||||
- All files and subdirectories within a registered model directory are protected
|
||||
- Only directories with NO registered models are flagged as orphaned
|
||||
|
||||
### Supported File Extensions
|
||||
- .safetensors
|
||||
- .ckpt
|
||||
- .pt
|
||||
- .pth
|
||||
- .bin
|
||||
- .onnx
|
||||
|
||||
### Skipped Directories
|
||||
- .download_cache
|
||||
- .convert_cache
|
||||
- \_\_pycache\_\_
|
||||
- .git
|
||||
|
||||
## Testing Recommendations
|
||||
|
||||
1. **Test with orphaned models**:
|
||||
- Manually copy a model directory to models folder
|
||||
- Verify it appears in the dialog
|
||||
- Delete it and verify removal
|
||||
|
||||
2. **Test with no orphaned models**:
|
||||
- Clean install
|
||||
- Verify toast message appears
|
||||
|
||||
3. **Test partial selection**:
|
||||
- Select only some models
|
||||
- Verify only selected ones are deleted
|
||||
|
||||
4. **Test error scenarios**:
|
||||
- Invalid paths
|
||||
- Permission issues
|
||||
- Verify error messages are clear
|
||||
|
||||
## Files Changed
|
||||
|
||||
### Backend
|
||||
- `invokeai/app/services/orphaned_models/__init__.py` (new)
|
||||
- `invokeai/app/services/orphaned_models/orphaned_models_service.py` (new)
|
||||
- `invokeai/app/api/routers/model_manager.py` (modified)
|
||||
|
||||
### Frontend
|
||||
- `invokeai/frontend/web/src/services/api/endpoints/models.ts` (modified)
|
||||
- `invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx` (modified)
|
||||
- `invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/SyncModelsButton.tsx` (new)
|
||||
- `invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/SyncModelsDialog.tsx` (new)
|
||||
- `invokeai/frontend/web/public/locales/en.json` (modified)
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
Potential improvements for future versions:
|
||||
1. Show preview of what will be deleted before deletion
|
||||
2. Add option to move orphaned models to archive instead of deleting
|
||||
3. Show disk space that will be freed
|
||||
4. Add filter/search in orphaned models list
|
||||
5. Support for undo operation
|
||||
6. Scheduled automatic cleanup
|
||||
@@ -25,12 +25,24 @@ Hardware requirements vary significantly depending on model and image output siz
|
||||
- Memory: At least 16GB RAM.
|
||||
- Disk: 10GB for base installation plus 100GB for models.
|
||||
|
||||
=== "FLUX - 1024×1024"
|
||||
=== "FLUX.1 - 1024×1024"
|
||||
|
||||
- GPU: Nvidia 20xx series or later, 10GB+ VRAM.
|
||||
- Memory: At least 32GB RAM.
|
||||
- Disk: 10GB for base installation plus 200GB for models.
|
||||
|
||||
=== "FLUX.2 Klein - 1024×1024"
|
||||
|
||||
- GPU: Nvidia 20xx series or later, 6GB+ VRAM for GGUF Q4 quantized models, 12GB+ for full precision.
|
||||
- Memory: At least 16GB RAM.
|
||||
- Disk: 10GB for base installation plus 20GB for models.
|
||||
|
||||
=== "Z-Image Turbo - 1024×1024"
|
||||
- GPU: Nvidia 20xx series or later, 8GB+ VRAM for the Q4_K quantized model. 16GB+ needed for the Q8 or BF16 models.
|
||||
- Memory: At least 16GB RAM.
|
||||
- Disk: 10GB for base installation plus 35GB for models.
|
||||
|
||||
|
||||
More detail on system requirements can be found [here](./requirements.md).
|
||||
|
||||
## Step 2: Download and Set Up the Launcher
|
||||
|
||||
@@ -6,7 +6,9 @@ Invoke runs on Windows 10+, macOS 14+ and Linux (Ubuntu 20.04+ is well-tested).
|
||||
|
||||
Hardware requirements vary significantly depending on model and image output size.
|
||||
|
||||
The requirements below are rough guidelines for best performance. GPUs with less VRAM typically still work, if a bit slower. Follow the [Low-VRAM mode guide](./features/low-vram.md) to optimize performance.
|
||||
The requirements below are rough guidelines for best performance. GPUs
|
||||
with less VRAM typically still work, if a bit slower. Follow the
|
||||
[Low-VRAM mode guide](../features/low-vram.md) to optimize performance.
|
||||
|
||||
- All Apple Silicon (M1, M2, etc) Macs work, but 16GB+ memory is recommended.
|
||||
- AMD GPUs are supported on Linux only. The VRAM requirements are the same as Nvidia GPUs.
|
||||
@@ -25,12 +27,29 @@ The requirements below are rough guidelines for best performance. GPUs with less
|
||||
- Memory: At least 16GB RAM.
|
||||
- Disk: 10GB for base installation plus 100GB for models.
|
||||
|
||||
=== "FLUX - 1024×1024"
|
||||
=== "FLUX.1 - 1024×1024"
|
||||
|
||||
- GPU: Nvidia 20xx series or later, 10GB+ VRAM.
|
||||
- Memory: At least 32GB RAM.
|
||||
- Disk: 10GB for base installation plus 200GB for models.
|
||||
|
||||
=== "FLUX.2 Klein 4B - 1024×1024"
|
||||
|
||||
- GPU: Nvidia 30xx series or later, 12GB+ VRAM (e.g. RTX 3090, RTX 4070). FP8 version works with 8GB+ VRAM.
|
||||
- Memory: At least 16GB RAM.
|
||||
- Disk: 10GB for base installation plus 20GB for models (Diffusers format with encoder).
|
||||
|
||||
=== "FLUX.2 Klein 9B - 1024×1024"
|
||||
|
||||
- GPU: Nvidia 40xx series, 24GB+ VRAM (e.g. RTX 4090). FP8 version works with 12GB+ VRAM.
|
||||
- Memory: At least 32GB RAM.
|
||||
- Disk: 10GB for base installation plus 40GB for models (Diffusers format with encoder).
|
||||
|
||||
=== "Z-Image Turbo - 1024×1024"
|
||||
- GPU: Nvidia 20xx series or later, 8GB+ VRAM for the Q4_K quantized model. 16GB+ needed for the Q8 or BF16 models.
|
||||
- Memory: At least 16GB RAM.
|
||||
- Disk: 10GB for base installation plus 35GB for models.
|
||||
|
||||
!!! info "`tmpfs` on Linux"
|
||||
|
||||
If your temporary directory is mounted as a `tmpfs`, ensure it has sufficient space.
|
||||
|
||||
876
docs/multiuser/admin_guide.md
Normal file
876
docs/multiuser/admin_guide.md
Normal file
@@ -0,0 +1,876 @@
|
||||
# InvokeAI Multi-User Administrator Guide
|
||||
|
||||
## Overview
|
||||
|
||||
This guide is for administrators managing a multi-user InvokeAI installation. It covers initial setup, user management, security best practices, and troubleshooting.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before enabling multi-user support, ensure you have:
|
||||
|
||||
- InvokeAI installed and running
|
||||
- Access to the server filesystem (for initial setup)
|
||||
- Understanding of your deployment environment
|
||||
- Backup of your existing data (recommended)
|
||||
|
||||
## Initial Setup
|
||||
|
||||
### Activating Multiuser Mode
|
||||
|
||||
To put InvokeAI into multiuser mode, you will need to add the option
|
||||
`multiuser: true` to its configuration file. This file is located at
|
||||
`INVOKEAI_ROOT/invokeai.yaml`. With the InvokeAI backend halted, add
|
||||
the new configuration option to the end of the file with a text editor
|
||||
so that it looks like this:
|
||||
|
||||
```yaml
|
||||
# Internal metadata - do not edit:
|
||||
schema_version: 4.0.2
|
||||
|
||||
# Enable/disable multi-user mode
|
||||
multiuser: true
|
||||
```
|
||||
|
||||
Then restart the InvokeAI server backend from the command line or
|
||||
using the launcher.
|
||||
|
||||
!!! note "Reverting to single-user mode"
|
||||
If at any time you wish to revert to single-user mode, simply comment
|
||||
out the `multiuser` line, or change "true" to "false". Then
|
||||
restart the server. Because of the way that browsers cache pages,
|
||||
users with open InvokeAI sessions may need to force-refresh their
|
||||
browsers.
|
||||
|
||||
|
||||
### First Administrator Account
|
||||
|
||||
When InvokeAI starts for the first time in multi-user mode, you'll see the **Administrator Setup** dialog.
|
||||
|
||||
**Setup Steps:**
|
||||
|
||||
1. **Email Address**: Enter a valid email address (this becomes your username)
|
||||
|
||||
* Example: `admin@example.com` or `admin@localhost` for testing
|
||||
* Must be a valid email format
|
||||
* Cannot be changed later without database access
|
||||
|
||||
2. **Display Name**: Enter a friendly name
|
||||
|
||||
* Example: "System Administrator" or your real name
|
||||
* Can be changed later in your profile
|
||||
* Visible to other users in shared contexts
|
||||
|
||||
3. **Password**: Create a strong administrator password
|
||||
|
||||
* **Minimum requirements:**
|
||||
|
||||
* At least 8 characters long
|
||||
* Contains uppercase letters (A-Z)
|
||||
* Contains lowercase letters (a-z)
|
||||
* Contains numbers (0-9)
|
||||
|
||||
* **Recommended:**
|
||||
|
||||
* Use 12+ characters
|
||||
* Include special characters (!@#$%^&*)
|
||||
* Use a password manager to generate and store
|
||||
* Don't reuse passwords from other services
|
||||
|
||||
4. **Confirm Password**: Re-enter the password
|
||||
|
||||
5. Click **Create Administrator Account**
|
||||
|
||||
!!! warning "Important"
|
||||
Store these credentials securely! The
|
||||
first administrator account can reset
|
||||
the password to something new, but cannot
|
||||
retrieve a lost one.
|
||||
|
||||
### Configuration
|
||||
|
||||
InvokeAI can run in single-user or multi-user mode, controlled by the `multiuser` configuration option in `invokeai.yaml`:
|
||||
|
||||
```yaml
|
||||
# Enable/disable multi-user mode
|
||||
multiuser: true # Enable multi-user mode (requires authentication)
|
||||
# multiuser: false # Single-user mode (no authentication required)
|
||||
# If the multiuser option is absent, single-user mode is used
|
||||
|
||||
# Database configuration
|
||||
use_memory_db: false # Use persistent database
|
||||
db_path: databases/invokeai.db # Database location
|
||||
|
||||
# Session configuration (multi-user mode only)
|
||||
jwt_secret_key: "your-secret-key-here" # Auto-generated if not specified
|
||||
jwt_token_expiry_hours: 24 # Default session timeout
|
||||
jwt_remember_me_days: 7 # "Remember me" duration
|
||||
```
|
||||
|
||||
**Single-User Mode** (`multiuser: false` or option absent):
|
||||
- No authentication required
|
||||
- All functionality enabled by default
|
||||
- All boards and images visible in unified view
|
||||
- Ideal for personal use or trusted environments
|
||||
|
||||
**Multi-User Mode** (`multiuser: true`):
|
||||
- Authentication required for access
|
||||
- User isolation for boards, images, and workflows
|
||||
- Role-based permissions enforced
|
||||
- Ideal for shared servers or team environments
|
||||
|
||||
!!! warning "Mode Switching Behavior"
|
||||
**Switching to Single-User Mode:** If boards or images were created in multi-user mode, they will all be combined into a single unified view when switching to single-user mode.
|
||||
|
||||
**Switching to Multi-User Mode:** Legacy boards and images created under single-user mode will be owned by an internal user named "system." Only the Administrator will have access to these legacy assets. A utility to migrate these legacy assets to another user will be part of a future release.
|
||||
|
||||
### Migration from Single-User
|
||||
|
||||
When upgrading from a single-user installation or switching modes:
|
||||
|
||||
1. **Automatic Migration**: The database will automatically migrate to multi-user schema when multi-user mode is first enabled
|
||||
2. **Legacy Data Ownership**: Existing data (boards, images, workflows) created in single-user mode is assigned to an internal user named "system"
|
||||
3. **Administrator Access**: Only administrators will have access to legacy "system"-owned assets when in multi-user mode
|
||||
4. **No Data Loss**: All existing content is preserved
|
||||
|
||||
**Migration Process:**
|
||||
|
||||
```bash
|
||||
# Backup your database first
|
||||
cp databases/invokeai.db databases/invokeai.db.backup
|
||||
|
||||
# Enable multi-user mode in invokeai.yaml
|
||||
# multiuser: true
|
||||
|
||||
# Start InvokeAI (migration happens automatically)
|
||||
invokeai-web
|
||||
|
||||
# Complete the administrator setup dialog
|
||||
# Legacy data will be owned by "system" user
|
||||
```
|
||||
|
||||
!!! note "Legacy Asset Migration"
|
||||
A utility to migrate legacy "system"-owned assets to specific user accounts will be available in a future release. Until then, administrators can access and manage all legacy content.
|
||||
|
||||
## User Management
|
||||
|
||||
### Creating Users
|
||||
|
||||
**Via Web Interface (Coming Soon):**
|
||||
|
||||
!!! info "Web UI for User Management"
|
||||
A web-based user interface that allows administrators to manage users is coming in a future release. Until then, use the command-line scripts described below.
|
||||
|
||||
**Via Command Line Scripts:**
|
||||
|
||||
InvokeAI provides several command-line scripts in the `scripts/` directory for user management:
|
||||
|
||||
**useradd.py** - Add a new user:
|
||||
|
||||
```bash
|
||||
# Interactive mode (prompts for details)
|
||||
python scripts/useradd.py
|
||||
|
||||
# Create a regular user
|
||||
python scripts/useradd.py \
|
||||
--email user@example.com \
|
||||
--password TempPass123 \
|
||||
--name "User Name"
|
||||
|
||||
# Create an administrator
|
||||
python scripts/useradd.py \
|
||||
--email admin@example.com \
|
||||
--password AdminPass123 \
|
||||
--name "Admin Name" \
|
||||
--admin
|
||||
```
|
||||
|
||||
**userlist.py** - List all users:
|
||||
|
||||
```bash
|
||||
# List all users
|
||||
python scripts/userlist.py
|
||||
|
||||
# Show detailed information
|
||||
python scripts/userlist.py --verbose
|
||||
```
|
||||
|
||||
**usermod.py** - Modify an existing user:
|
||||
|
||||
```bash
|
||||
# Change display name
|
||||
python scripts/usermod.py --email user@example.com --name "New Name"
|
||||
|
||||
# Promote to administrator
|
||||
python scripts/usermod.py --email user@example.com --admin
|
||||
|
||||
# Demote from administrator
|
||||
python scripts/usermod.py --email user@example.com --no-admin
|
||||
|
||||
# Deactivate account
|
||||
python scripts/usermod.py --email user@example.com --deactivate
|
||||
|
||||
# Reactivate account
|
||||
python scripts/usermod.py --email user@example.com --activate
|
||||
|
||||
# Change password
|
||||
python scripts/usermod.py --email user@example.com --password NewPassword123
|
||||
```
|
||||
|
||||
**userdel.py** - Delete a user:
|
||||
|
||||
```bash
|
||||
# Delete a user (prompts for confirmation)
|
||||
python scripts/userdel.py --email user@example.com
|
||||
|
||||
# Delete without confirmation
|
||||
python scripts/userdel.py --email user@example.com --force
|
||||
```
|
||||
|
||||
!!! tip "Script Usage"
|
||||
Run any script with `--help` to see all available options:
|
||||
```bash
|
||||
python scripts/useradd.py --help
|
||||
```
|
||||
|
||||
!!! warning "Command Line Management"
|
||||
- These scripts directly modify the database
|
||||
- Always backup your database before making changes
|
||||
- Changes take effect immediately (users may need to log in again)
|
||||
- Deleting a user permanently removes all their content
|
||||
|
||||
### Editing Users
|
||||
|
||||
**Via Command Line:**
|
||||
|
||||
Use `usermod.py` as described above to modify user properties.
|
||||
|
||||
!!! warning "Last Administrator"
|
||||
You cannot remove admin privileges from the last remaining administrator account.
|
||||
|
||||
### Resetting User Passwords
|
||||
|
||||
**Via Web Interface (Coming Soon):**
|
||||
|
||||
Web-based password reset functionality for administrators is coming in a future release.
|
||||
|
||||
**Via Command Line:**
|
||||
|
||||
```bash
|
||||
# Reset a user's password
|
||||
python scripts/usermod.py --email user@example.com --password NewTempPassword123
|
||||
```
|
||||
|
||||
**Security Note:** Never send passwords via email or unsecured channels. Use secure communication methods.
|
||||
|
||||
### Deactivating Users
|
||||
|
||||
**Via Command Line:**
|
||||
|
||||
```bash
|
||||
# Deactivate a user account
|
||||
python scripts/usermod.py --email user@example.com --deactivate
|
||||
|
||||
# Reactivate a user account
|
||||
python scripts/usermod.py --email user@example.com --activate
|
||||
```
|
||||
|
||||
**Effects:**
|
||||
|
||||
- User cannot log in when deactivated
|
||||
- Existing sessions are immediately invalidated
|
||||
- User's data is preserved
|
||||
- Can be reactivated at any time
|
||||
|
||||
### Deleting Users
|
||||
|
||||
**Via Command Line:**
|
||||
|
||||
```bash
|
||||
# Delete a user (prompts for confirmation)
|
||||
python scripts/userdel.py --email user@example.com
|
||||
|
||||
# Delete without confirmation prompt
|
||||
python scripts/userdel.py --email user@example.com --force
|
||||
```
|
||||
|
||||
**Important:**
|
||||
|
||||
- ⚠️ This action is **permanent**
|
||||
- User's boards, images, and workflows are deleted
|
||||
- Cannot be undone
|
||||
- Consider deactivating instead of deleting
|
||||
|
||||
!!! warning "Data Loss"
|
||||
Deleting a user permanently removes all their content. Back up the database first if recovery might be needed.
|
||||
|
||||
### Viewing User Activity
|
||||
|
||||
**Queue Management:**
|
||||
|
||||
1. Navigate to **Admin** → **Queue Overview**
|
||||
2. View all users' active and pending generations
|
||||
3. Filter by user
|
||||
4. Cancel stuck or problematic tasks
|
||||
|
||||
**User Statistics:**
|
||||
|
||||
- Number of boards created
|
||||
- Number of images generated
|
||||
- Storage usage (if enabled)
|
||||
- Last login time
|
||||
|
||||
## Model Management
|
||||
|
||||
As an administrator, you have full access to model management.
|
||||
|
||||
### Adding Models
|
||||
|
||||
**Via Model Manager UI:**
|
||||
|
||||
1. Go to **Models** tab
|
||||
2. Click **Add Model**
|
||||
3. Choose installation method:
|
||||
- **From URL**: Provide HuggingFace repo or download URL
|
||||
- **From Local Path**: Scan local directories
|
||||
- **Import**: Import model from filesystem
|
||||
|
||||
**Supported Model Types:**
|
||||
|
||||
- Main models (Stable Diffusion, SDXL, FLUX)
|
||||
- LoRA models
|
||||
- ControlNet models
|
||||
- VAE models
|
||||
- Textual Inversions
|
||||
- IP-Adapters
|
||||
|
||||
### Configuring Models
|
||||
|
||||
**Model Settings:**
|
||||
|
||||
- Display name
|
||||
- Description
|
||||
- Default generation settings (CFG, steps, scheduler)
|
||||
- Variant selection (fp16/fp32)
|
||||
- Model thumbnail image
|
||||
|
||||
**Default Settings:**
|
||||
|
||||
Set default parameters that users will start with:
|
||||
|
||||
1. Select a model
|
||||
2. Go to **Default Settings** tab
|
||||
3. Configure:
|
||||
- CFG Scale
|
||||
- Steps
|
||||
- Scheduler
|
||||
- VAE selection
|
||||
4. Save settings
|
||||
|
||||
### Removing Models
|
||||
|
||||
1. Go to **Models** tab
|
||||
2. Select model(s) to remove
|
||||
3. Click **Delete**
|
||||
4. Confirm deletion
|
||||
|
||||
!!! warning "Impact"
|
||||
Removing a model affects all users who may be using it in workflows or saved settings.
|
||||
|
||||
## Shared Boards
|
||||
|
||||
Shared boards enable collaboration between users while maintaining control.
|
||||
|
||||
!!! note "Future Feature"
|
||||
Board sharing will be implemented in a future release.
|
||||
|
||||
### Creating Shared Boards
|
||||
|
||||
1. Log in as administrator
|
||||
2. Create a new board (or use existing board)
|
||||
3. Right-click the board → **Share Board**
|
||||
4. Add users and set permissions
|
||||
5. Click **Save Sharing Settings**
|
||||
|
||||
### Permission Levels
|
||||
|
||||
| Level | View | Add Images | Edit/Delete | Manage Sharing |
|
||||
|-------|------|------------|-------------|----------------|
|
||||
| **Read** | ✅ | ❌ | ❌ | ❌ |
|
||||
| **Write** | ✅ | ✅ | ✅ | ❌ |
|
||||
| **Admin** | ✅ | ✅ | ✅ | ✅ |
|
||||
|
||||
**Permission Recommendations:**
|
||||
|
||||
- **Read**: For viewers who should see but not modify content
|
||||
- **Write**: For active collaborators who add and organize images
|
||||
- **Admin**: For trusted users who help manage the shared board
|
||||
|
||||
### Managing Shared Boards
|
||||
|
||||
**Add Users to Shared Board:**
|
||||
|
||||
1. Right-click shared board → **Manage Sharing**
|
||||
2. Click **Add User**
|
||||
3. Select user from dropdown
|
||||
4. Choose permission level
|
||||
5. Save changes
|
||||
|
||||
**Remove Users from Shared Board:**
|
||||
|
||||
1. Right-click shared board → **Manage Sharing**
|
||||
2. Find user in list
|
||||
3. Click **Remove**
|
||||
4. Confirm removal
|
||||
|
||||
**Change User Permissions:**
|
||||
|
||||
1. Right-click shared board → **Manage Sharing**
|
||||
2. Find user in list
|
||||
3. Change permission dropdown
|
||||
4. Save changes
|
||||
|
||||
### Shared Board Best Practices
|
||||
|
||||
- Give meaningful names to shared boards
|
||||
- Document the board's purpose in the description
|
||||
- Assign minimum necessary permissions
|
||||
- Regularly audit access lists
|
||||
- Remove users who no longer need access
|
||||
|
||||
## Security
|
||||
|
||||
### Password Policies
|
||||
|
||||
**Enforced Requirements:**
|
||||
|
||||
- Minimum 8 characters
|
||||
- Must contain uppercase letters
|
||||
- Must contain lowercase letters
|
||||
- Must contain numbers
|
||||
|
||||
**Recommended Policies:**
|
||||
|
||||
- Require 12+ character passwords
|
||||
- Include special characters
|
||||
- Implement password rotation every 90 days
|
||||
- Prevent password reuse
|
||||
- Use multi-factor authentication (when available)
|
||||
|
||||
### Session Management
|
||||
|
||||
**Session Security and Token Management:**
|
||||
|
||||
This system uses stateless JWT tokens with HMAC signatures to
|
||||
identify users after they provide their initial credentials. The
|
||||
tokens will persist for 24 hours by default, or for 7 days if the user
|
||||
clicks the "Remember me" checkbox at login. Expired tokens are
|
||||
automatically rejected and the user will have to log in again.
|
||||
|
||||
At the client side, tokens are stored in browser localStorage. Logging
|
||||
out clears them. No server-side session storage is required.
|
||||
|
||||
The tokens include the user's ID, email, and admin status, along with
|
||||
an HMAC signature.
|
||||
|
||||
### Secret Key Management
|
||||
|
||||
**Important:** The JWT secret key must be kept confidential.
|
||||
|
||||
To generate tokens, each InvokeAI instance has a distinct secret JWT key that must be
|
||||
kept confidential. The key is stored in the `app_settings` table of
|
||||
the InvokeAI database in a field named `jwt_secret`.
|
||||
|
||||
The secret key is automatically generated during database creation or
|
||||
migration. If you wish to change the key, you may generate a
|
||||
replacement using either of these commands:
|
||||
|
||||
|
||||
```bash
|
||||
# Python
|
||||
python -c "import secrets; print(secrets.token_urlsafe(32))"
|
||||
|
||||
# OpenSSL
|
||||
openssl rand -base64 32
|
||||
```
|
||||
|
||||
Then cut and paste the printed secret into this `sqlite3` command:
|
||||
|
||||
```bash
|
||||
sqlite3 INVOKE_ROOT/databases/invokeai.db 'update app_settings set value="THE_SECRET" where key="jwt_secret"'
|
||||
```
|
||||
|
||||
(replace INVOKE_ROOT with your InvokeAI root directory and THE_SECRET
|
||||
with the new secret).
|
||||
|
||||
After this, restart the server. All logged in users will be logged out
|
||||
and will need to provide their usernames and passwords again.
|
||||
|
||||
### Hosting a Shared InvokeAI Instance
|
||||
|
||||
The multiuser feature allows you to run an InvokeAI backend that can
|
||||
be accessed by your friends and family across your home network. It is
|
||||
also possible to host a backend that is accessible over the Internet.
|
||||
|
||||
By default, InvokeAI runs on `localhost`, IP address `127.0.0.1`,
|
||||
which is only accessible to browsers running on the same machine as
|
||||
the backend. To make the backend accessible to any machine on your
|
||||
home or work LAN, add the line `host: 0.0.0.0` to the InvokeAI
|
||||
configuration file, usually stored at `INVOKE_ROOT/invokeai.yaml`.
|
||||
|
||||
Here is a minimal example.
|
||||
|
||||
```yaml
|
||||
# Internal metadata - do not edit:
|
||||
schema_version: 4.0.2
|
||||
|
||||
# Put user settings here - see https://invoke-ai.github.io/InvokeAI/configuration/:
|
||||
multiuser: true
|
||||
host: 0.0.0.0
|
||||
```
|
||||
|
||||
After relaunching the backend you will be able to reach the server
|
||||
from other machines on the LAN using the server machine's IP address
|
||||
or hostname and port 9090.
|
||||
|
||||
#### Connecting to the Internet
|
||||
|
||||
!!! warning "Use at your own risk"
|
||||
The InvokeAI team has done its best to make the software free of
|
||||
exploitable bugs, but the software has not undergone a rigorous security
|
||||
audit or intrusion testing. Use at your own risk.
|
||||
|
||||
It is also possible to create a (semi) public server accessible from
|
||||
the Internet. The details of how to do this depend very much on your
|
||||
home or corporate router/firewall system and are beyond the scope of
|
||||
this document.
|
||||
|
||||
If you expose InvokeAI to the Internet, there are a number of
|
||||
precautions to take. Here is a brief list of recommended network
|
||||
security practices.
|
||||
|
||||
**HTTPS Configuration:**
|
||||
|
||||
For internet deployments, always use HTTPS:
|
||||
|
||||
```yaml
|
||||
# Use a reverse proxy like nginx or Traefik
|
||||
# Example nginx configuration:
|
||||
|
||||
server {
|
||||
listen 443 ssl http2;
|
||||
server_name invoke.example.com;
|
||||
|
||||
ssl_certificate /path/to/cert.pem;
|
||||
ssl_certificate_key /path/to/key.pem;
|
||||
|
||||
location / {
|
||||
proxy_pass http://localhost:9090;
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
# WebSocket support
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection "upgrade";
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Firewall Rules:**
|
||||
|
||||
It is best to restrict access to trusted networks and remote IP
|
||||
addresses, or use a VPN to connect to your home network. Rate limit
|
||||
connections to InvokeAI's authentication endpoint
|
||||
`http://your.host:9090/login`.
|
||||
|
||||
**Backup and Recovery:**
|
||||
|
||||
It is a good idea to periodically backup your InvokeAI database,
|
||||
images, and possibly models in the event of unauthorized use of a
|
||||
publicly-accessible server.
|
||||
|
||||
**Manual Backup:**
|
||||
|
||||
```bash
|
||||
# Stop InvokeAI
|
||||
# Copy database file
|
||||
cd INVOKE_ROOT
|
||||
cp databases/invokeai.db databases/invokeai.db.$(date +%Y%m%d)
|
||||
|
||||
# Or create compressed backup
|
||||
tar -czf invokeai_backup_$(date +%Y%m%d).tar.gz databases/
|
||||
```
|
||||
|
||||
**Automated Backup Script:**
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
# backup_invokeai.sh
|
||||
|
||||
INVOKE_ROOT="/path/to/invoke_root"
|
||||
BACKUP_DIR="/path/to/backups"
|
||||
DB_PATH="$INVOKE_ROOT/databases/invokeai.db"
|
||||
DATE=$(date +%Y%m%d_%H%M%S)
|
||||
|
||||
# Create backup directory
|
||||
mkdir -p "$BACKUP_DIR"
|
||||
|
||||
# Copy database
|
||||
cp "$DB_PATH" "$BACKUP_DIR/invokeai_$DATE.db"
|
||||
|
||||
# Keep only last 30 days
|
||||
find "$BACKUP_DIR" -name "invokeai_*.db" -mtime +30 -delete
|
||||
|
||||
echo "Backup completed: invokeai_$DATE.db"
|
||||
```
|
||||
|
||||
**Schedule with cron:**
|
||||
|
||||
```bash
|
||||
# Edit crontab
|
||||
crontab -e
|
||||
|
||||
# Add daily backup at 2 AM
|
||||
0 2 * * * /path/to/backup_invokeai.sh
|
||||
```
|
||||
|
||||
|
||||
|
||||
**Restore from Backup:**

```bash
|
||||
# Stop InvokeAI
|
||||
# Replace current database with backup
|
||||
cd INVOKE_ROOT
|
||||
cp databases/invokeai.db databases/invokeai.db.old # Save current
|
||||
cp databases/invokeai_backup.db databases/invokeai.db
|
||||
|
||||
# Restart InvokeAI
|
||||
invokeai-web
|
||||
```
|
||||
|
||||
**Disaster Recovery - Complete System Backup:**
|
||||
|
||||
Include these directories/files:
|
||||
|
||||
- `databases/` - All database files
|
||||
- `models/` - Installed models (if locally stored)
|
||||
- `outputs/` - Generated images
|
||||
- `invokeai.yaml` - Configuration file
|
||||
- Any custom scripts or modifications
|
||||
|
||||
**Recovery Process:**
|
||||
|
||||
1. Install InvokeAI on new system
|
||||
2. Restore configuration file
|
||||
3. Restore database directory
|
||||
4. Restore models and outputs
|
||||
5. Verify file permissions
|
||||
6. Start InvokeAI and test
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### User Cannot Login
|
||||
|
||||
**Symptom:** User reports unable to log in
|
||||
|
||||
**Diagnosis:**
|
||||
|
||||
1. Verify account exists and is active
|
||||
```bash
|
||||
sqlite3 databases/invokeai.db "SELECT * FROM users WHERE email = 'user@example.com';"
|
||||
```
|
||||
|
||||
2. Check password (have user try resetting)
|
||||
3. Verify account is active (`is_active = 1`)
|
||||
4. Check for account lockout (if implemented)
|
||||
|
||||
**Solutions:**
|
||||
|
||||
- Reset user password
|
||||
- Reactivate disabled account
|
||||
- Verify email address is correct
|
||||
- Check system logs for auth errors
|
||||
|
||||
### Database Locked Errors
|
||||
|
||||
**Symptom:** "Database is locked" errors
|
||||
|
||||
**Causes:**
|
||||
|
||||
- Concurrent write operations
|
||||
- Long-running transactions
|
||||
- Backup process accessing database
|
||||
- File system issues
|
||||
|
||||
**Solutions:**
|
||||
|
||||
```bash
|
||||
# Check for locks
|
||||
fuser databases/invokeai.db
|
||||
|
||||
# Increase timeout (in config)
|
||||
# Or switch to WAL mode:
|
||||
sqlite3 databases/invokeai.db "PRAGMA journal_mode=WAL;"
|
||||
```
|
||||
|
||||
### Forgotten Admin Password
|
||||
|
||||
**Recovery Process:**
|
||||
|
||||
1. Stop InvokeAI
|
||||
2. Direct database access:
|
||||
```bash
|
||||
sqlite3 databases/invokeai.db
|
||||
```
|
||||
|
||||
3. Reset admin password (requires password hash):
|
||||
```sql
|
||||
-- Generate hash first using Python:
|
||||
-- from passlib.context import CryptContext
|
||||
-- pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
-- print(pwd_context.hash("NewPassword123"))
|
||||
|
||||
UPDATE users
|
||||
SET password_hash = '$2b$12$...'
|
||||
WHERE email = 'admin@example.com';
|
||||
```
|
||||
|
||||
4. Restart InvokeAI
|
||||
|
||||
**Alternative:** Remove `jwt_secret_key` from config to trigger setup wizard (will create new admin).
|
||||
|
||||
### Performance Issues
|
||||
|
||||
**Symptom:** Slow generation or UI
|
||||
|
||||
**Diagnosis:**
|
||||
|
||||
1. Check active generation count
|
||||
2. Review resource usage (CPU/GPU/RAM)
|
||||
3. Check database size and performance
|
||||
4. Review network latency
|
||||
|
||||
**Solutions:**
|
||||
|
||||
- Limit concurrent generations
|
||||
- Increase hardware resources
|
||||
- Optimize database (`VACUUM`, `ANALYZE`)
|
||||
- Add indexes for slow queries
|
||||
- Consider load balancing
|
||||
|
||||
### Migration Failures
|
||||
|
||||
**Symptom:** Database migration fails on upgrade
|
||||
|
||||
**Prevention:**
|
||||
|
||||
- Always backup before upgrading
|
||||
- Test migration on copy of database
|
||||
- Review migration logs
|
||||
|
||||
**Recovery:**
|
||||
|
||||
```bash
|
||||
# Restore backup
|
||||
cp databases/invokeai.db.backup databases/invokeai.db
|
||||
|
||||
# Try migration again with verbose logging
|
||||
invokeai-web --log-level DEBUG
|
||||
```
|
||||
|
||||
## Configuration Reference
|
||||
|
||||
### Complete Configuration Example for a Public Site
|
||||
|
||||
```yaml
|
||||
# invokeai.yaml - Multi-user configuration
|
||||
|
||||
# Internal metadata - do not edit:
|
||||
schema_version: 4.0.2
|
||||
|
||||
# Put user settings here
|
||||
multiuser: true
|
||||
|
||||
# Server
|
||||
host: "0.0.0.0"
|
||||
port: 9090
|
||||
|
||||
# Performance
|
||||
enable_partial_loading: true
|
||||
precision: float16
|
||||
pytorch_cuda_alloc_conf: "backend:cudaMallocAsync"
|
||||
hashing_algorithm: blake3_multi
|
||||
```
|
||||
## Frequently Asked Questions
|
||||
|
||||
### How many users can InvokeAI support?
|
||||
|
||||
The backend will support dozens of concurrent users. However, because
|
||||
the image generation queue is single-threaded, image generation tasks
|
||||
are processed on a first-come, first-served basis. This means that a
|
||||
user may have to wait for all the other users' image generation jobs
|
||||
to complete before their generation job starts to execute.
|
||||
|
||||
A future version of InvokeAI may support concurrent execution on
|
||||
systems with multiple GPUs/graphics cards.
|
||||
|
||||
### Can I integrate with existing authentication systems?
|
||||
|
||||
OAuth2/OpenID Connect support is planned for a future release. Currently, InvokeAI uses its own authentication system.
|
||||
|
||||
### How do I audit user actions?
|
||||
|
||||
Full audit logging is planned for a future release. Currently, you can:
|
||||
|
||||
- Monitor the generation queue
|
||||
- Review database changes
|
||||
- Check application logs
|
||||
|
||||
### Can users have different model access?
|
||||
|
||||
Not in the current release. All users can view and use all installed models. Per-user model access is a possible enhancement.
|
||||
|
||||
### How do I handle user data when they leave?
|
||||
|
||||
Best practice:
|
||||
|
||||
1. Deactivate the account first
|
||||
2. Transfer ownership of shared boards
|
||||
3. After transition period, delete the account
|
||||
4. Or keep the account deactivated for audit purposes
|
||||
|
||||
### What's the licensing impact of multi-user mode?
|
||||
|
||||
InvokeAI remains under its existing license. Multi-user mode does not change licensing terms.
|
||||
|
||||
## Getting Help
|
||||
|
||||
### Support Resources
|
||||
|
||||
- **Documentation**: [InvokeAI Docs](https://invoke-ai.github.io/InvokeAI/)
|
||||
- **Discord**: [Join Community](https://discord.gg/ZmtBAhwWhy)
|
||||
- **GitHub Issues**: [Report Problems](https://github.com/invoke-ai/InvokeAI/issues)
|
||||
- **User Guide**: [For Users](user_guide.md)
|
||||
- **API Guide**: [For Developers](api_guide.md)
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
When reporting administrator issues, include:
|
||||
|
||||
- InvokeAI version
|
||||
- Operating system and version
|
||||
- Database size and user count
|
||||
- Relevant log excerpts
|
||||
- Steps to reproduce
|
||||
- Expected vs actual behavior
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [User Guide](user_guide.md) - For end users
|
||||
- [API Guide](api_guide.md) - For API consumers
|
||||
- [Multiuser Specification](specification.md) - Technical details
|
||||
|
||||
---
|
||||
|
||||
**Need additional assistance?** Visit the [InvokeAI Discord](https://discord.gg/ZmtBAhwWhy) or file an issue on [GitHub](https://github.com/invoke-ai/InvokeAI/issues).
|
||||
1224
docs/multiuser/api_guide.md
Normal file
1224
docs/multiuser/api_guide.md
Normal file
File diff suppressed because it is too large
Load Diff
870
docs/multiuser/specification.md
Normal file
870
docs/multiuser/specification.md
Normal file
@@ -0,0 +1,870 @@
|
||||
# InvokeAI Multi-User Support - Detailed Specification
|
||||
|
||||
## 1. Executive Summary
|
||||
|
||||
This document provides a comprehensive specification for adding multi-user support to InvokeAI. The feature will enable a single InvokeAI instance to support multiple isolated users, each with their own generation settings, image boards, and workflows, while maintaining administrative controls for model management and system configuration.
|
||||
|
||||
## 2. Overview
|
||||
|
||||
### 2.1 Goals
|
||||
- Enable multiple users to share a single InvokeAI instance
|
||||
- Provide user isolation for personal content (boards, images, workflows, settings)
|
||||
- Maintain centralized model management by administrators
|
||||
- Support shared boards for collaboration
|
||||
- Provide secure authentication and authorization
|
||||
- Minimize impact on existing single-user installations
|
||||
|
||||
### 2.2 Non-Goals
|
||||
- Real-time collaboration features (multiple users editing same workflow simultaneously)
|
||||
- Advanced team management features (in initial release)
|
||||
- Migration of existing multi-user enterprise edition data
|
||||
- Support for external identity providers (in initial release, can be added later)
|
||||
|
||||
## 3. User Roles and Permissions
|
||||
|
||||
### 3.1 Administrator Role
|
||||
**Capabilities:**
|
||||
|
||||
- Full access to all InvokeAI features
|
||||
- Model management (add, delete, configure models)
|
||||
- User management (create, edit, delete users)
|
||||
- View and manage all users' queue sessions
|
||||
- Access system configuration
|
||||
- Create and manage shared boards
|
||||
- Grant or revoke administrative privileges for other users
|
||||
|
||||
**Restrictions:**
|
||||
|
||||
- Cannot delete their own account if they are the last administrator
|
||||
- Cannot revoke their own admin privileges if they are the last administrator
|
||||
|
||||
### 3.2 Regular User Role
|
||||
**Capabilities:**
|
||||
|
||||
- Create, edit, and delete their own image boards
|
||||
- Upload and manage their own assets
|
||||
- Use all image generation tools (linear, canvas, upscale, workflow tabs)
|
||||
- Create, edit, save, and load workflows
|
||||
- Access public/shared workflows
|
||||
- View and manage their own queue sessions
|
||||
- Adjust personal UI preferences (theme, hotkeys, etc.)
|
||||
- Access shared boards (read/write based on permissions)
|
||||
- **View model configurations** (read-only access to model manager)
|
||||
- **View model details, default settings, and metadata**
|
||||
|
||||
**Restrictions:**
|
||||
|
||||
- Cannot add, delete, or edit models
|
||||
- **Can view but cannot modify model manager settings** (read-only access)
|
||||
- Cannot reidentify, convert, or update model paths
|
||||
- Cannot upload or change model thumbnail images
|
||||
- Cannot save changes to model default settings
|
||||
- Cannot perform bulk delete operations on models
|
||||
- Cannot view or modify other users' boards, images, or workflows
|
||||
- Cannot cancel or modify other users' queue sessions
|
||||
- Cannot access system configuration
|
||||
- Cannot manage users or permissions
|
||||
|
||||
### 3.3 Future Role Considerations
|
||||
- **Viewer Role**: Read-only access (future enhancement)
|
||||
- **Team/Group-based Permissions**: Organizational hierarchy (future enhancement)
|
||||
|
||||
## 4. Authentication System
|
||||
|
||||
### 4.1 Authentication Method
|
||||
- **Primary Method**: Username and password authentication with secure password hashing
|
||||
- **Password Hashing**: Use bcrypt or Argon2 for password storage
|
||||
- **Session Management**: JWT tokens or secure session cookies
|
||||
- **Token Expiration**: Configurable session timeout (default: 7 days for "remember me", 24 hours otherwise)
|
||||
|
||||
### 4.2 Initial Administrator Setup
|
||||
**First-time Launch Flow:**
|
||||
|
||||
1. Application detects no administrator account exists
|
||||
2. Displays mandatory setup dialog (cannot be skipped)
|
||||
3. Prompts for:
|
||||
- Administrator username (email format recommended)
|
||||
- Administrator display name
|
||||
- Strong password (minimum requirements enforced)
|
||||
- Password confirmation
|
||||
4. Stores hashed credentials in configuration
|
||||
5. Creates administrator account in database
|
||||
6. Proceeds to normal login screen
|
||||
|
||||
**Reset Capability:**
|
||||
|
||||
- Administrator credentials can be reset by manually editing the config file
|
||||
- Requires access to server filesystem (intentional security measure)
|
||||
- Database maintains user records; config file contains root admin credentials
|
||||
|
||||
### 4.3 Password Requirements
|
||||
- Minimum 8 characters
|
||||
- At least one uppercase letter
|
||||
- At least one lowercase letter
|
||||
- At least one number
|
||||
- At least one special character (optional but recommended)
|
||||
- Not in common password list
|
||||
|
||||
### 4.4 Login Flow
|
||||
|
||||
1. User navigates to InvokeAI URL
|
||||
2. If not authenticated, redirect to login page
|
||||
3. User enters username/email and password
|
||||
4. Optional "Remember me" checkbox for extended session
|
||||
5. Backend validates credentials
|
||||
6. On success: Generate session token, redirect to application
|
||||
7. On failure: Display error, allow retry with rate limiting (prevent brute force)
|
||||
|
||||
### 4.5 Logout Flow
|
||||
- User clicks logout button
|
||||
- Frontend clears session token
|
||||
- Backend invalidates session (if using server-side sessions)
|
||||
- Redirect to login page
|
||||
|
||||
### 4.6 Future Authentication Enhancements
|
||||
- OAuth2/OpenID Connect support
|
||||
- Two-factor authentication (2FA)
|
||||
- SSO integration
|
||||
- API key authentication for programmatic access
|
||||
|
||||
## 5. User Management
|
||||
|
||||
### 5.1 User Creation (Administrator)
|
||||
**Flow:**
|
||||
|
||||
1. Administrator navigates to user management interface
|
||||
2. Clicks "Add User" button
|
||||
3. Enters user information:
|
||||
- Email address (required, used as username)
|
||||
- Display name (optional, defaults to email)
|
||||
- Role (User or Administrator)
|
||||
- Initial password or "Send invitation email"
|
||||
4. System validates email uniqueness
|
||||
5. System creates user account
|
||||
6. If invitation mode:
|
||||
- Generate one-time secure token
|
||||
- Send email with setup link
|
||||
- Link expires after 7 days
|
||||
7. If direct password mode:
|
||||
- Administrator provides initial password
|
||||
- User must change on first login
|
||||
|
||||
**Invitation Email Flow:**
|
||||
|
||||
1. User receives email with unique link
|
||||
2. Link contains secure token
|
||||
3. User clicks link, redirected to setup page
|
||||
4. User enters desired password
|
||||
5. Token validated and consumed (single-use)
|
||||
6. Account activated
|
||||
7. User redirected to login page
|
||||
|
||||
### 5.2 User Profile Management
|
||||
**User Self-Service:**
|
||||
|
||||
- Update display name
|
||||
- Change password (requires current password)
|
||||
- Update email address (requires verification)
|
||||
- Manage UI preferences
|
||||
- View account creation date and last login
|
||||
|
||||
**Administrator Actions:**
|
||||
|
||||
- Edit user information (name, email)
|
||||
- Reset user password (generates reset link)
|
||||
- Toggle administrator privileges
|
||||
- Assign to groups (future feature)
|
||||
- Suspend/unsuspend account
|
||||
- Delete account (with data retention options)
|
||||
|
||||
### 5.3 Password Reset Flow
|
||||
**User-Initiated (Future Enhancement):**
|
||||
|
||||
1. User clicks "Forgot Password" on login page
|
||||
2. Enters email address
|
||||
3. System sends password reset link (if email exists)
|
||||
4. User clicks link, enters new password
|
||||
5. Password updated, user can login
|
||||
|
||||
**Administrator-Initiated:**
|
||||
|
||||
1. Administrator selects user
|
||||
2. Clicks "Send Password Reset"
|
||||
3. System generates reset token and link
|
||||
4. Email sent to user
|
||||
5. User follows same flow as user-initiated reset
|
||||
|
||||
## 6. Data Model and Database Schema
|
||||
|
||||
### 6.1 New Tables
|
||||
|
||||
#### 6.1.1 users
|
||||
```sql
|
||||
CREATE TABLE users (
|
||||
user_id TEXT NOT NULL PRIMARY KEY,
|
||||
email TEXT NOT NULL UNIQUE,
|
||||
display_name TEXT,
|
||||
password_hash TEXT NOT NULL,
|
||||
is_admin BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
is_active BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
last_login_at DATETIME
|
||||
);
|
||||
CREATE INDEX idx_users_email ON users(email);
|
||||
CREATE INDEX idx_users_is_admin ON users(is_admin);
|
||||
CREATE INDEX idx_users_is_active ON users(is_active);
|
||||
```
|
||||
|
||||
#### 6.1.2 user_sessions
|
||||
```sql
|
||||
CREATE TABLE user_sessions (
|
||||
session_id TEXT NOT NULL PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
token_hash TEXT NOT NULL,
|
||||
expires_at DATETIME NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
last_activity_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
user_agent TEXT,
|
||||
ip_address TEXT,
|
||||
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX idx_user_sessions_user_id ON user_sessions(user_id);
|
||||
CREATE INDEX idx_user_sessions_expires_at ON user_sessions(expires_at);
|
||||
CREATE INDEX idx_user_sessions_token_hash ON user_sessions(token_hash);
|
||||
```
|
||||
|
||||
#### 6.1.3 user_invitations
|
||||
```sql
|
||||
CREATE TABLE user_invitations (
|
||||
invitation_id TEXT NOT NULL PRIMARY KEY,
|
||||
email TEXT NOT NULL,
|
||||
token_hash TEXT NOT NULL,
|
||||
invited_by_user_id TEXT NOT NULL,
|
||||
expires_at DATETIME NOT NULL,
|
||||
used_at DATETIME,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
FOREIGN KEY (invited_by_user_id) REFERENCES users(user_id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX idx_user_invitations_email ON user_invitations(email);
|
||||
CREATE INDEX idx_user_invitations_token_hash ON user_invitations(token_hash);
|
||||
CREATE INDEX idx_user_invitations_expires_at ON user_invitations(expires_at);
|
||||
```
|
||||
|
||||
#### 6.1.4 shared_boards
|
||||
```sql
|
||||
CREATE TABLE shared_boards (
|
||||
board_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
permission TEXT NOT NULL CHECK(permission IN ('read', 'write', 'admin')),
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
PRIMARY KEY (board_id, user_id),
|
||||
FOREIGN KEY (board_id) REFERENCES boards(board_id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX idx_shared_boards_user_id ON shared_boards(user_id);
|
||||
CREATE INDEX idx_shared_boards_board_id ON shared_boards(board_id);
|
||||
```
|
||||
|
||||
### 6.2 Modified Tables
|
||||
|
||||
#### 6.2.1 boards
|
||||
```sql
|
||||
-- Add columns:
|
||||
ALTER TABLE boards ADD COLUMN user_id TEXT NOT NULL DEFAULT 'system';
|
||||
ALTER TABLE boards ADD COLUMN is_shared BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
ALTER TABLE boards ADD COLUMN created_by_user_id TEXT;
|
||||
|
||||
-- Add foreign key (requires recreation in SQLite):
|
||||
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
|
||||
FOREIGN KEY (created_by_user_id) REFERENCES users(user_id) ON DELETE SET NULL
|
||||
|
||||
-- Add indices:
|
||||
CREATE INDEX idx_boards_user_id ON boards(user_id);
|
||||
CREATE INDEX idx_boards_is_shared ON boards(is_shared);
|
||||
```
|
||||
|
||||
#### 6.2.2 images
|
||||
```sql
|
||||
-- Add column:
|
||||
ALTER TABLE images ADD COLUMN user_id TEXT NOT NULL DEFAULT 'system';
|
||||
|
||||
-- Add foreign key:
|
||||
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
|
||||
|
||||
-- Add index:
|
||||
CREATE INDEX idx_images_user_id ON images(user_id);
|
||||
```
|
||||
|
||||
#### 6.2.3 workflows
|
||||
```sql
|
||||
-- Add columns:
|
||||
ALTER TABLE workflows ADD COLUMN user_id TEXT NOT NULL DEFAULT 'system';
|
||||
ALTER TABLE workflows ADD COLUMN is_public BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
|
||||
-- Add foreign key:
|
||||
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
|
||||
|
||||
-- Add indices:
|
||||
CREATE INDEX idx_workflows_user_id ON workflows(user_id);
|
||||
CREATE INDEX idx_workflows_is_public ON workflows(is_public);
|
||||
```
|
||||
|
||||
#### 6.2.4 session_queue
|
||||
```sql
|
||||
-- Add column:
|
||||
ALTER TABLE session_queue ADD COLUMN user_id TEXT NOT NULL DEFAULT 'system';
|
||||
|
||||
-- Add foreign key:
|
||||
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
|
||||
|
||||
-- Add index:
|
||||
CREATE INDEX idx_session_queue_user_id ON session_queue(user_id);
|
||||
```
|
||||
|
||||
#### 6.2.5 style_presets
|
||||
```sql
|
||||
-- Add columns:
|
||||
ALTER TABLE style_presets ADD COLUMN user_id TEXT NOT NULL DEFAULT 'system';
|
||||
ALTER TABLE style_presets ADD COLUMN is_public BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
|
||||
-- Add foreign key:
|
||||
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
|
||||
|
||||
-- Add indices:
|
||||
CREATE INDEX idx_style_presets_user_id ON style_presets(user_id);
|
||||
CREATE INDEX idx_style_presets_is_public ON style_presets(is_public);
|
||||
```
|
||||
|
||||
### 6.3 Migration Strategy
|
||||
|
||||
1. Create new user tables (users, user_sessions, user_invitations, shared_boards)
|
||||
2. Create default 'system' user for backward compatibility
|
||||
3. Update existing data to reference 'system' user
|
||||
4. Add foreign key constraints
|
||||
5. Version as database migration (e.g., migration_25.py)
|
||||
|
||||
### 6.4 Migration for Existing Installations
|
||||
- Single-user installations: Prompt to create admin account on first launch after update
|
||||
- Existing data migration: Administrator can specify an arbitrary user account to hold legacy data (can be the admin account or a separate user)
|
||||
- System provides UI during migration to choose destination user for existing data
|
||||
|
||||
## 7. API Endpoints
|
||||
|
||||
### 7.1 Authentication Endpoints
|
||||
|
||||
#### POST /api/v1/auth/setup
|
||||
- Initialize first administrator account
|
||||
- Only works if no admin exists
|
||||
- Body: `{ email, display_name, password }`
|
||||
- Response: `{ success, user }`
|
||||
|
||||
#### POST /api/v1/auth/login
|
||||
- Authenticate user
|
||||
- Body: `{ email, password, remember_me? }`
|
||||
- Response: `{ token, user, expires_at }`
|
||||
|
||||
#### POST /api/v1/auth/logout
|
||||
- Invalidate current session
|
||||
- Headers: `Authorization: Bearer <token>`
|
||||
- Response: `{ success }`
|
||||
|
||||
#### GET /api/v1/auth/me
|
||||
- Get current user information
|
||||
- Headers: `Authorization: Bearer <token>`
|
||||
- Response: `{ user }`
|
||||
|
||||
#### POST /api/v1/auth/change-password
|
||||
- Change current user's password
|
||||
- Body: `{ current_password, new_password }`
|
||||
- Headers: `Authorization: Bearer <token>`
|
||||
- Response: `{ success }`
|
||||
|
||||
### 7.2 User Management Endpoints (Admin Only)
|
||||
|
||||
#### GET /api/v1/users
|
||||
- List all users (paginated)
|
||||
- Query params: `offset`, `limit`, `search`, `role_filter`
|
||||
- Response: `{ users[], total, offset, limit }`
|
||||
|
||||
#### POST /api/v1/users
|
||||
- Create new user
|
||||
- Body: `{ email, display_name, is_admin, send_invitation?, initial_password? }`
|
||||
- Response: `{ user, invitation_link? }`
|
||||
|
||||
#### GET /api/v1/users/{user_id}
|
||||
- Get user details
|
||||
- Response: `{ user }`
|
||||
|
||||
#### PATCH /api/v1/users/{user_id}
|
||||
- Update user
|
||||
- Body: `{ display_name?, is_admin?, is_active? }`
|
||||
- Response: `{ user }`
|
||||
|
||||
#### DELETE /api/v1/users/{user_id}
|
||||
- Delete user
|
||||
- Query params: `delete_data` (true/false)
|
||||
- Response: `{ success }`
|
||||
|
||||
#### POST /api/v1/users/{user_id}/reset-password
|
||||
- Send password reset email
|
||||
- Response: `{ success, reset_link }`
|
||||
|
||||
### 7.3 Shared Boards Endpoints
|
||||
|
||||
#### POST /api/v1/boards/{board_id}/share
|
||||
- Share board with users
|
||||
- Body: `{ user_ids[], permission: 'read' | 'write' | 'admin' }`
|
||||
- Response: `{ success, shared_with[] }`
|
||||
|
||||
#### GET /api/v1/boards/{board_id}/shares
|
||||
- Get board sharing information
|
||||
- Response: `{ shares[] }`
|
||||
|
||||
#### DELETE /api/v1/boards/{board_id}/share/{user_id}
|
||||
- Remove board sharing
|
||||
- Response: `{ success }`
|
||||
|
||||
### 7.4 Modified Endpoints
|
||||
|
||||
All existing endpoints will be modified to:
|
||||
|
||||
1. Require authentication (except setup/login)
|
||||
2. Filter data by current user (unless admin viewing all)
|
||||
3. Enforce permissions (e.g., model management requires admin)
|
||||
4. Include user context in operations
|
||||
|
||||
Example modifications:
|
||||
- `GET /api/v1/boards` → Returns only user's boards + shared boards
|
||||
- `POST /api/v1/session/queue` → Associates queue item with current user
|
||||
- `GET /api/v1/queue` → Returns all items for admin, only user's items for regular users
|
||||
|
||||
## 8. Frontend Changes
|
||||
|
||||
### 8.1 New Components
|
||||
|
||||
#### LoginPage
|
||||
- Email/password form
|
||||
- "Remember me" checkbox
|
||||
- Login button
|
||||
- Forgot password link (future)
|
||||
- Branding and welcome message
|
||||
|
||||
#### AdministratorSetup
|
||||
- Modal dialog (cannot be dismissed)
|
||||
- Administrator account creation form
|
||||
- Password strength indicator
|
||||
- Terms/welcome message
|
||||
|
||||
#### UserManagementPage (Admin only)
|
||||
- User list table
|
||||
- Add user button
|
||||
- User actions (edit, delete, reset password)
|
||||
- Search and filter
|
||||
- Role toggle
|
||||
|
||||
#### UserProfilePage
|
||||
- Display user information
|
||||
- Change password form
|
||||
- UI preferences
|
||||
- Account details
|
||||
|
||||
#### BoardSharingDialog
|
||||
- User picker/search
|
||||
- Permission selector
|
||||
- Share button
|
||||
- Current shares list
|
||||
|
||||
### 8.2 Modified Components
|
||||
|
||||
#### App Root
|
||||
- Add authentication check
|
||||
- Redirect to login if not authenticated
|
||||
- Handle session expiration
|
||||
- Add global error boundary for auth errors
|
||||
|
||||
#### Navigation/Header
|
||||
- Add user menu with logout
|
||||
- Display current user name
|
||||
- Admin indicator badge
|
||||
|
||||
#### ModelManagerTab
|
||||
- Hide/disable for non-admin users
|
||||
- Show "Admin only" message
|
||||
|
||||
#### QueuePanel
|
||||
- Filter by current user (for non-admin)
|
||||
- Show all with user indicators (for admin)
|
||||
- Disable actions on other users' items (for non-admin)
|
||||
|
||||
#### BoardsPanel
|
||||
- Show personal boards section
|
||||
- Show shared boards section
|
||||
- Add sharing controls to board actions
|
||||
|
||||
### 8.3 State Management
|
||||
|
||||
New Redux slices/zustand stores:
|
||||
- `authSlice`: Current user, authentication status, token
|
||||
- `usersSlice`: User list for admin interface
|
||||
- `sharingSlice`: Board sharing state
|
||||
|
||||
Updated slices:
|
||||
- `boardsSlice`: Include shared boards, ownership info
|
||||
- `queueSlice`: Include user filtering
|
||||
- `workflowsSlice`: Include public/private status
|
||||
|
||||
## 9. Configuration
|
||||
|
||||
### 9.1 New Config Options
|
||||
|
||||
Add to `InvokeAIAppConfig`:
|
||||
|
||||
```python
|
||||
# Authentication
|
||||
auth_enabled: bool = True # Enable/disable multi-user auth
|
||||
session_expiry_hours: int = 24 # Default session expiration
|
||||
session_expiry_hours_remember: int = 168 # "Remember me" expiration (7 days)
|
||||
password_min_length: int = 8 # Minimum password length
|
||||
require_strong_passwords: bool = True # Enforce password complexity
|
||||
|
||||
# Session tracking
|
||||
enable_server_side_sessions: bool = False # Optional server-side session tracking
|
||||
|
||||
# Audit logging
|
||||
audit_log_auth_events: bool = True # Log authentication events
|
||||
audit_log_admin_actions: bool = True # Log administrative actions
|
||||
|
||||
# Email (optional - for invitations and password reset)
|
||||
email_enabled: bool = False
|
||||
smtp_host: str = ""
|
||||
smtp_port: int = 587
|
||||
smtp_username: str = ""
|
||||
smtp_password: str = ""
|
||||
smtp_from_address: str = ""
|
||||
smtp_from_name: str = "InvokeAI"
|
||||
|
||||
# Initial admin (stored as hash)
|
||||
admin_email: Optional[str] = None
|
||||
admin_password_hash: Optional[str] = None
|
||||
```
|
||||
|
||||
### 9.2 Backward Compatibility
|
||||
|
||||
- If `auth_enabled = False`, system runs in legacy single-user mode
|
||||
- All data belongs to implicit "system" user
|
||||
- No authentication required
|
||||
- Smooth upgrade path for existing installations
|
||||
|
||||
## 10. Security Considerations
|
||||
|
||||
### 10.1 Password Security
|
||||
- Never store passwords in plain text
|
||||
- Use bcrypt or Argon2id for password hashing
|
||||
- Implement proper salt generation
|
||||
- Enforce password complexity requirements
|
||||
- Implement rate limiting on login attempts
|
||||
- Consider password breach checking (Have I Been Pwned API)
|
||||
|
||||
### 10.2 Session Security
|
||||
- Use cryptographically secure random tokens
|
||||
- Implement token rotation
|
||||
- Set appropriate cookie flags (HttpOnly, Secure, SameSite)
|
||||
- Implement session timeout and renewal
|
||||
- Invalidate sessions on logout
|
||||
- Clean up expired sessions periodically
|
||||
|
||||
### 10.3 Authorization
|
||||
- Always verify user identity from session token (never trust client)
|
||||
- Check permissions on every API call
|
||||
- Implement principle of least privilege
|
||||
- Validate user ownership of resources before operations
|
||||
- Implement proper error messages (avoid information leakage)
|
||||
|
||||
### 10.4 Data Isolation
|
||||
- Strict separation of user data in database queries
|
||||
- Prevent SQL injection via parameterized queries
|
||||
- Validate all user inputs
|
||||
- Implement proper access control checks
|
||||
- Audit trail for sensitive operations
|
||||
|
||||
### 10.5 API Security
|
||||
- Implement rate limiting on sensitive endpoints
|
||||
- Use HTTPS in production (enforce via config)
|
||||
- Implement CSRF protection
|
||||
- Validate and sanitize all inputs
|
||||
- Implement proper CORS configuration
|
||||
- Add security headers (CSP, X-Frame-Options, etc.)
|
||||
|
||||
### 10.6 Deployment Security
|
||||
- Document secure deployment practices
|
||||
- Recommend reverse proxy configuration (nginx, Apache)
|
||||
- Provide example configurations for HTTPS
|
||||
- Document firewall requirements
|
||||
- Recommend network isolation strategies
|
||||
|
||||
## 11. Email Integration (Optional)
|
||||
|
||||
**Note**: Email/SMTP configuration is optional. Many administrators will not have ready access to an outgoing SMTP server. When email is not configured, the system provides fallback mechanisms by displaying setup links directly in the admin UI.
|
||||
|
||||
### 11.1 Email Templates
|
||||
|
||||
#### User Invitation
|
||||
```
|
||||
Subject: You've been invited to InvokeAI
|
||||
|
||||
Hello,
|
||||
|
||||
You've been invited to join InvokeAI by [Administrator Name].
|
||||
|
||||
Click the link below to set up your account:
|
||||
[Setup Link]
|
||||
|
||||
This link expires in 7 days.
|
||||
|
||||
---
|
||||
InvokeAI
|
||||
```
|
||||
|
||||
#### Password Reset
|
||||
```
|
||||
Subject: Reset your InvokeAI password
|
||||
|
||||
Hello [User Name],
|
||||
|
||||
A password reset was requested for your account.
|
||||
|
||||
Click the link below to reset your password:
|
||||
[Reset Link]
|
||||
|
||||
This link expires in 24 hours.
|
||||
|
||||
If you didn't request this, please ignore this email.
|
||||
|
||||
---
|
||||
InvokeAI
|
||||
```
|
||||
|
||||
### 11.2 Email Service
|
||||
- Support SMTP configuration
|
||||
- Use secure connection (TLS)
|
||||
- Handle email failures gracefully
|
||||
- Implement email queue for reliability
|
||||
- Log email activities (without sensitive data)
|
||||
- Provide fallback for no-email deployments (show links in admin UI)
|
||||
|
||||
## 12. Testing Requirements
|
||||
|
||||
### 12.1 Unit Tests
|
||||
- Authentication service (password hashing, validation)
|
||||
- Authorization checks
|
||||
- Token generation and validation
|
||||
- User management operations
|
||||
- Shared board permissions
|
||||
- Data isolation queries
|
||||
|
||||
### 12.2 Integration Tests
|
||||
- Complete authentication flows
|
||||
- User creation and invitation
|
||||
- Password reset flow
|
||||
- Multi-user data isolation
|
||||
- Shared board access
|
||||
- Session management
|
||||
- Admin operations
|
||||
|
||||
### 12.3 Security Tests
|
||||
- SQL injection prevention
|
||||
- XSS prevention
|
||||
- CSRF protection
|
||||
- Session hijacking prevention
|
||||
- Brute force protection
|
||||
- Authorization bypass attempts
|
||||
|
||||
### 12.4 Performance Tests
|
||||
- Authentication overhead
|
||||
- Query performance with user filters
|
||||
- Concurrent user sessions
|
||||
- Database scalability with many users
|
||||
|
||||
## 13. Documentation Requirements
|
||||
|
||||
### 13.1 User Documentation
|
||||
- Getting started with multi-user InvokeAI
|
||||
- Login and account management
|
||||
- Using shared boards
|
||||
- Understanding permissions
|
||||
- Troubleshooting authentication issues
|
||||
|
||||
### 13.2 Administrator Documentation
|
||||
- Setting up multi-user InvokeAI
|
||||
- User management guide
|
||||
- Creating and managing shared boards
|
||||
- Email configuration
|
||||
- Security best practices
|
||||
- Backup and restore with user data
|
||||
|
||||
### 13.3 Developer Documentation
|
||||
- Authentication architecture
|
||||
- API authentication requirements
|
||||
- Adding new multi-user features
|
||||
- Database schema changes
|
||||
- Testing multi-user features
|
||||
|
||||
### 13.4 Migration Documentation
|
||||
- Upgrading from single-user to multi-user
|
||||
- Data migration strategies
|
||||
- Rollback procedures
|
||||
- Common issues and solutions
|
||||
|
||||
## 14. Future Enhancements
|
||||
|
||||
### 14.1 Phase 2 Features
|
||||
- **OAuth2/OpenID Connect integration** (deferred from initial release to keep scope manageable)
|
||||
- Two-factor authentication
|
||||
- API keys for programmatic access
|
||||
- Enhanced team/group management
|
||||
- Advanced permission system (roles and capabilities)
|
||||
|
||||
### 14.2 Phase 3 Features
|
||||
- SSO integration (SAML, LDAP)
|
||||
- User quotas and limits
|
||||
- Resource usage tracking
|
||||
- Advanced collaboration features
|
||||
- Workflow template library with permissions
|
||||
- Model access controls per user/group
|
||||
|
||||
## 15. Success Metrics
|
||||
|
||||
### 15.1 Functionality Metrics
|
||||
- Successful user authentication rate
|
||||
- Zero unauthorized data access incidents
|
||||
- All tests passing (unit, integration, security)
|
||||
- API response time within acceptable limits
|
||||
|
||||
### 15.2 Usability Metrics
|
||||
- User setup completion time < 2 minutes
|
||||
- Login time < 2 seconds
|
||||
- Clear error messages for all auth failures
|
||||
- Positive user feedback on multi-user features
|
||||
|
||||
### 15.3 Security Metrics
|
||||
- No critical security vulnerabilities identified
|
||||
- CodeQL scan passes
|
||||
- Penetration testing completed
|
||||
- Security best practices followed
|
||||
|
||||
## 16. Risks and Mitigations
|
||||
|
||||
### 16.1 Technical Risks
|
||||
| Risk | Impact | Probability | Mitigation |
|
||||
|------|--------|-------------|------------|
|
||||
| Performance degradation with user filtering | Medium | Low | Index optimization, query caching |
|
||||
| Database migration failures | High | Low | Thorough testing, rollback procedures |
|
||||
| Session management complexity | Medium | Medium | Use proven libraries (PyJWT), extensive testing |
|
||||
| Auth bypass vulnerabilities | High | Low | Security review, penetration testing |
|
||||
|
||||
### 16.2 UX Risks
|
||||
| Risk | Impact | Probability | Mitigation |
|
||||
|------|--------|-------------|------------|
|
||||
| Confusion in migration for existing users | Medium | High | Clear documentation, migration wizard |
|
||||
| Friction from additional login step | Low | High | Remember me option, long session timeout |
|
||||
| Complexity of admin interface | Medium | Medium | Intuitive UI design, user testing |
|
||||
|
||||
### 16.3 Operational Risks
|
||||
| Risk | Impact | Probability | Mitigation |
|
||||
|------|--------|-------------|------------|
|
||||
| Email delivery failures | Low | Medium | Show links in UI, document manual methods |
|
||||
| Lost admin password | High | Low | Document recovery procedure, config reset |
|
||||
| User data conflicts in migration | Medium | Low | Data validation, backup requirements |
|
||||
|
||||
## 17. Implementation Phases
|
||||
|
||||
### Phase 1: Foundation (Weeks 1-2)
|
||||
- Database schema design and migration
|
||||
- Basic authentication service
|
||||
- Password hashing and validation
|
||||
- Session management
|
||||
|
||||
### Phase 2: Backend API (Weeks 3-4)
|
||||
- Authentication endpoints
|
||||
- User management endpoints
|
||||
- Authorization middleware
|
||||
- Update existing endpoints with auth
|
||||
|
||||
### Phase 3: Frontend Auth (Weeks 5-6)
|
||||
- Login page and flow
|
||||
- Administrator setup
|
||||
- Session management
|
||||
- Auth state management
|
||||
|
||||
### Phase 4: Multi-tenancy (Weeks 7-9)
|
||||
- User isolation in all services
|
||||
- Shared boards implementation
|
||||
- Queue permission filtering
|
||||
- Workflow public/private
|
||||
|
||||
### Phase 5: Admin Interface (Weeks 10-11)
|
||||
- User management UI
|
||||
- Board sharing UI
|
||||
- Admin-specific features
|
||||
- User profile page
|
||||
|
||||
### Phase 6: Testing & Polish (Weeks 12-13)
|
||||
- Comprehensive testing
|
||||
- Security audit
|
||||
- Performance optimization
|
||||
- Documentation
|
||||
- Bug fixes
|
||||
|
||||
### Phase 7: Beta & Release (Week 14+)
|
||||
- Beta testing with selected users
|
||||
- Feedback incorporation
|
||||
- Final testing
|
||||
- Release preparation
|
||||
- Documentation finalization
|
||||
|
||||
## 18. Acceptance Criteria
|
||||
|
||||
- [ ] Administrator can set up initial account on first launch
|
||||
- [ ] Users can log in with email and password
|
||||
- [ ] Users can change their password
|
||||
- [ ] Administrators can create, edit, and delete users
|
||||
- [ ] User data is properly isolated (boards, images, workflows)
|
||||
- [ ] Shared boards work correctly with permissions
|
||||
- [ ] Non-admin users cannot access model management
|
||||
- [ ] Queue filtering works correctly for users and admins
|
||||
- [ ] Session management works correctly (expiry, renewal, logout)
|
||||
- [ ] All security tests pass
|
||||
- [ ] API documentation is updated
|
||||
- [ ] User and admin documentation is complete
|
||||
- [ ] Migration from single-user works smoothly
|
||||
- [ ] Performance is acceptable with multiple concurrent users
|
||||
- [ ] Backward compatibility mode works (auth disabled)
|
||||
|
||||
## 19. Design Decisions
|
||||
|
||||
The following design decisions have been approved for implementation:
|
||||
|
||||
1. **OAuth2 Priority**: OAuth2/OpenID Connect integration will be a **future enhancement**. The initial release will focus on username/password authentication to keep scope manageable.
|
||||
|
||||
2. **Email Requirement**: Email/SMTP configuration is **optional**. Many administrators will not have ready access to an outgoing SMTP server. The system will provide fallback mechanisms (showing setup links directly in the admin UI) when email is not configured.
|
||||
|
||||
3. **Data Migration**: During migration from single-user to multi-user mode, the administrator will be given the **option to specify an arbitrary user account** to hold legacy data. The admin account can be used for this purpose if the administrator wishes.
|
||||
|
||||
4. **API Compatibility**: Authentication will be **required on all APIs**, but authentication will not be required if multi-user support is disabled (backward compatibility mode with `auth_enabled: false`).
|
||||
|
||||
5. **Session Storage**: The system will use **JWT tokens with optional server-side session tracking**. This provides scalability while allowing administrators to enable server-side tracking if needed.
|
||||
|
||||
6. **Audit Logging**: The system will **log authentication events and admin actions**. This provides accountability and security monitoring for critical operations.
|
||||
|
||||
## 20. Conclusion
|
||||
|
||||
This specification provides a comprehensive blueprint for implementing multi-user support in InvokeAI. The design prioritizes:
|
||||
|
||||
- **Security**: Proper authentication, authorization, and data isolation
|
||||
- **Usability**: Intuitive UI, smooth migration, minimal friction
|
||||
- **Scalability**: Efficient database design, performant queries
|
||||
- **Maintainability**: Clean architecture, comprehensive testing
|
||||
- **Flexibility**: Future enhancement paths, optional features
|
||||
|
||||
The phased implementation approach allows for iterative development and testing, while the detailed specifications ensure all stakeholders have clear expectations of the final system.
|
||||
406
docs/multiuser/user_guide.md
Normal file
406
docs/multiuser/user_guide.md
Normal file
@@ -0,0 +1,406 @@
|
||||
# InvokeAI Multi-User Guide
|
||||
|
||||
## Overview
|
||||
|
||||
Multi-User mode is a recent feature (introduced in version 6.12), which allows multiple individuals to share a single InvokeAI server while keeping their work separate and organized. Each user has their own username and login password, images, assets, image boards, customization settings and workflows.
|
||||
|
||||
Two types of users are recognized:
|
||||
|
||||
* A user with **Administrator** status can add, remove and modify other users, and can install models. They also have the ability to view the full session queue and pause or kill other users' jobs.
|
||||
* **Non-administrator** users can modify their own profile but not others. They also do not have the ability to install or configure models, but must ask an Administrator to do this task.
|
||||
|
||||
Multiple users can be granted Administrator status.
|
||||
|
||||
***
|
||||
|
||||
## Getting Started
|
||||
|
||||
To activate Multi-User mode, open the `INVOKEAI_ROOT/invokeai.yaml` configuration file in a text editor. Add this line anywhere in the file:
|
||||
```yaml
|
||||
multiuser: true
|
||||
```
|
||||
|
||||
You may also wish to make InvokeAI available to other machines on your local LAN. Add an additional line to `invokeai.yaml`:
|
||||
|
||||
```yaml
|
||||
host: 0.0.0.0
|
||||
```
|
||||
|
||||
Restart the server. It will now be in multi-user mode. If you enabled
|
||||
the `host` option, other users on your home or office LAN will be able
|
||||
to reach it by browsing to the IP address of the machine the backend
|
||||
is running on (`http://host-ip-address:9090`).
|
||||
|
||||
!!! tip "Do not expose InvokeAI to the internet"
|
||||
It is not recommended to expose the InvokeAI host to the internet
|
||||
due to security concerns.
|
||||
|
||||
### Initial Setup (First Time in Multi-User Mode)
|
||||
|
||||
If you're the first person to access a fresh InvokeAI installation in multi-user mode, you'll see the **Administrator Setup** dialog:
|
||||
|
||||

|
||||
|
||||
Now follow these steps:
|
||||
|
||||
1. Enter your email address (this will be your login name)
|
||||
2. Create a display name (this will be the name other users see)
|
||||
3. Choose a strong password that meets the requirements:
|
||||
- At least 8 characters long
|
||||
- Contains uppercase letters
|
||||
- Contains lowercase letters
|
||||
- Contains numbers
|
||||
4. Confirm your password
|
||||
5. Click **Create Administrator Account**
|
||||
|
||||
You'll now be taken to a login screen and can enter the credentials
|
||||
you just created.
|
||||
|
||||
### Adding and Modifying Users
|
||||
|
||||
If you are logged in as Administrator, you can add additional users. Click on the small "person silhouette" icon at the bottom left of the main Invoke screen and select "User Management:"
|
||||
|
||||

|
||||
|
||||
This will take you to the User Management screen...
|
||||
|
||||

|
||||
|
||||
...where you can click "Create User" to add a new user.
|
||||
|
||||

|
||||
|
||||
The User Management screen also allows you to:
|
||||
|
||||
1. Temporarily change a user's status to Inactive, preventing them from logging in to Invoke.
|
||||
2. Edit a user (by clicking on the pencil icon) to change the user's display name or password.
|
||||
3. Permanently delete a user.
|
||||
4. Grant a user Administrator privileges.
|
||||
|
||||
### Command-line User Management Scripts
|
||||
|
||||
Administrators can also use a series of command-line scripts to add, modify, or delete users. If you use the launcher, click the ">" icon to enter the command-line interface. Otherwise, if you are a native command-line user, activate the InvokeAI environment from your terminal.
|
||||
|
||||
The commands are named:
|
||||
|
||||
* **invoke-useradd** -- add a user
|
||||
* **invoke-usermod** -- modify a user
|
||||
* **invoke-userdel** -- delete a user
|
||||
* **invoke-userlist** -- list all users
|
||||
|
||||
Pass the `--help` argument to get the usage of each script. For example:
|
||||
|
||||
```bash
|
||||
> invoke-useradd --help
|
||||
usage: invoke-useradd [-h] [--root ROOT] [--email EMAIL] [--password PASSWORD] [--name NAME] [--admin]
|
||||
|
||||
Add a user to the InvokeAI database
|
||||
|
||||
options:
|
||||
-h, --help show this help message and exit
|
||||
--root ROOT, -r ROOT Path to the InvokeAI root directory. If omitted, the root is resolved in this order: the $INVOKEAI_ROOT environment
|
||||
variable, the active virtual environment's parent directory, or $HOME/invokeai.
|
||||
--email EMAIL, -e EMAIL
|
||||
User email address
|
||||
--password PASSWORD, -p PASSWORD
|
||||
User password
|
||||
--name NAME, -n NAME User display name (optional)
|
||||
--admin, -a Make user an administrator
|
||||
|
||||
If no arguments are provided, the script will run in interactive mode.
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
## Logging in as a Non-Administrative User
|
||||
|
||||
If you are a registered user on the system, enter your email address and password to log in. The Administrator will be able to provide you with the values to use:
|
||||
|
||||

|
||||
|
||||
As an unprivileged user you can do pretty much anything that's allowed under single-user mode -- generating images, using LoRAs, creating and running workflows, creating image boards -- but you are restricted from installing new models, changing low-level server settings, or interfering with other users. More information on user roles is given below.
|
||||
|
||||
### Changing your Profile
|
||||
|
||||
To change your display name or profile, click on the person silhouette icon at the bottom left of the screen and choose "My Profile". This will take you to a screen that lets you change these values. At this time you can change your display name but not your login ID (ordinarily your contact email address).
|
||||
|
||||
***
|
||||
|
||||
## Understanding User Roles
|
||||
|
||||
In single-user mode, you have access to all features without restrictions. In multi-user mode, InvokeAI has two user roles:
|
||||
|
||||
### Regular User
|
||||
|
||||
As a regular user, you can:
|
||||
|
||||
- ✅ Create and manage your own image boards
|
||||
- ✅ Generate images using all AI tools (Linear, Canvas, Upscale, Workflows)
|
||||
- ✅ Create, save, and load your own workflows
|
||||
- ✅ View your own generation queue
|
||||
- ✅ Customize your UI preferences (theme, hotkeys, etc.)
|
||||
- ✅ View available models (read-only access to Model Manager)
|
||||
- ✅ View shared and public boards created by other users
|
||||
- ✅ View and use workflows marked as shared by other users
|
||||
|
||||
You cannot:
|
||||
|
||||
- ❌ Add, delete, or modify models
|
||||
- ❌ View or modify other users' private boards, images, or workflows
|
||||
- ❌ Manage user accounts
|
||||
- ❌ Access system configuration
|
||||
- ❌ View or cancel other users' generation tasks
|
||||
|
||||
!!! tip "The generation queue"
|
||||
When two or more users are accessing InvokeAI at the same time,
|
||||
their image generation jobs will be placed on the session queue on
|
||||
a first-come, first-served basis. This means that you will have to
|
||||
wait for other users' image rendering jobs to complete before
|
||||
yours will start.
|
||||
|
||||
When another user's job is running, you will see the image
|
||||
generation progress bar and a queue badge that reads `X/Y`, where
|
||||
"X" is the number of jobs you have queued and "Y" is the total
|
||||
number of jobs queued, including your own and others.
|
||||
|
||||
You can also pull up the Queue tab in order to see where your job
|
||||
is in relationship to other queued tasks.
|
||||
|
||||
### Administrator
|
||||
|
||||
Administrators have all regular user capabilities, plus:
|
||||
|
||||
- ✅ Full model management (add, delete, configure models)
|
||||
- ✅ Create and manage user accounts
|
||||
- ✅ View and manage all users' generation queues
|
||||
- ✅ View and manage all users' boards, images, and workflows (including system-owned legacy content)
|
||||
- ✅ Access system configuration
|
||||
- ✅ Grant or revoke admin privileges
|
||||
|
||||
***
|
||||
|
||||
## Working with Your Content in Multi-User Mode
|
||||
|
||||
### Image Boards
|
||||
|
||||
In multi-user mode, each user can create an unlimited number of boards and organize their images and assets as they see fit. Boards have three visibility levels:
|
||||
|
||||
- **Private** (default): Only you (and administrators) can see and modify the board.
|
||||
- **Shared**: All users can view the board and its contents, but only you (and administrators) can modify it (rename, archive, delete, or add/remove images).
|
||||
- **Public**: All users can view the board. Only you (and administrators) can modify the board's structure (rename, archive, delete).
|
||||
|
||||
To change a board's visibility, right-click on the board and select the desired visibility option.
|
||||
|
||||
Administrators can see and manage all users' image boards and their contents regardless of visibility settings.
|
||||
|
||||
### Going From Multi-User to Single-User Mode
|
||||
|
||||
If an InvokeAI instance was in multiuser mode and then restarted in single user mode (by setting `multiuser: false` in the configuration file), all users' boards will be consolidated in one place. Any images that were in "Uncategorized" will be merged together into a single Uncategorized board. If, at a later date, the server is restarted in multi-user mode, the boards and images will be separated and restored to their owners.
|
||||
|
||||
### Workflows
|
||||
|
||||
Each user has their own private workflow library. Workflows you create are visible only to you by default.
|
||||
|
||||
You can share a workflow with other users by marking it as **shared** (public). Shared workflows appear in all users' workflow libraries and can be opened by anyone, but only the owner (or an administrator) can modify or delete them.
|
||||
|
||||
To share a workflow, open it and use the sharing controls to toggle its public/shared status.
|
||||
|
||||
!!! warning "Preexisting workflows after enabling multi-user mode"
|
||||
When you enable multi-user mode for the first time on an existing InvokeAI installation, all workflows that were created before multi-user mode was activated will appear in the **shared workflows** section. These preexisting workflows are owned by the internal "system" account and are visible to all users. Administrators can edit or delete these shared legacy workflows. Regular users can view and use them but cannot modify them.
|
||||
|
||||
|
||||
### The Generation Queue
|
||||
|
||||
The queue shows your pending and running generation tasks.
|
||||
|
||||
**Queue Features:**
|
||||
|
||||
- View your current and completed generations
|
||||
- Cancel pending tasks
|
||||
- Re-run previous generations
|
||||
- Monitor progress in real-time
|
||||
|
||||
**Queue Isolation:**
|
||||
|
||||
- You will see your own queue items, as well as the items generated by
  other users, but the generation parameters (e.g. prompts) of other
  users' jobs are hidden for privacy reasons.
|
||||
- Administrators can view all queues for troubleshooting
|
||||
- Your generations won't interfere with other users' tasks
|
||||
|
||||
***
|
||||
|
||||
## Customizing Your Experience
|
||||
|
||||
### Personal Preferences
|
||||
|
||||
Your UI preferences are saved to your account and are restored when you log in:
|
||||
|
||||
- **Theme**: Choose between light and dark modes
|
||||
- **Hotkeys**: Customize keyboard shortcuts
|
||||
- **Canvas Settings**: Default zoom, grid visibility, etc.
|
||||
- **Generation Defaults**: Default values for width, height, steps, etc.
|
||||
|
||||
These settings are stored per-user and won't affect other users.
|
||||
|
||||
***
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Cannot Log In
|
||||
|
||||
**Issue:** Login fails with "Incorrect email or password"
|
||||
|
||||
**Solutions:**
|
||||
|
||||
- Verify you're entering the correct email address
|
||||
- Check that Caps Lock is off
|
||||
- Try typing the password slowly to avoid mistakes
|
||||
- Contact your administrator if you've forgotten your password
|
||||
|
||||
**Issue:** Login fails with "Account is disabled"
|
||||
|
||||
**Solution:** Contact your administrator to reactivate your account
|
||||
|
||||
### Session Expired
|
||||
|
||||
**Issue:** You're suddenly logged out and see "Session expired"
|
||||
|
||||
**Explanation:** Sessions expire after 24 hours (or 7 days with "remember me")
|
||||
|
||||
**Solution:** Simply log in again with your credentials
|
||||
|
||||
### Cannot Access Features
|
||||
|
||||
**Issue:** Features like Model Manager show "Admin privileges required"
|
||||
|
||||
**Explanation:** Some features are restricted to administrators
|
||||
|
||||
**Solution:**
|
||||
|
||||
- For model viewing: You can view but not modify models
|
||||
- For user management: Contact an administrator
|
||||
- For system configuration: Contact an administrator
|
||||
|
||||
### Missing Boards or Images
|
||||
|
||||
**Issue:** Boards or images you created are not visible
|
||||
|
||||
**Possible Causes:**
|
||||
|
||||
1. **Filter Applied:** Check if a filter is hiding content
|
||||
2. **Wrong User:** Ensure you're logged in with the correct account
|
||||
3. **Archived Board:** Check the "Show Archived" option
|
||||
|
||||
**Solution:**
|
||||
|
||||
- Clear any active filters
|
||||
- Verify you're logged in as the right user
|
||||
- Check archived items
|
||||
|
||||
### Slow Performance
|
||||
|
||||
**Issue:** Generation or UI feels slower than expected
|
||||
|
||||
**Possible Causes:**
|
||||
|
||||
- Other users generating images simultaneously
|
||||
- Server resource limits
|
||||
- Network latency
|
||||
|
||||
**Solutions:**
|
||||
|
||||
- Check the queue to see if others are generating
|
||||
- Wait for current generations to complete
|
||||
- Contact administrator if persistent
|
||||
|
||||
### Generation Stuck in Queue
|
||||
|
||||
**Issue:** Your generation is queued but not starting
|
||||
|
||||
**Possible Causes:**
|
||||
|
||||
- Server is processing other users' generations
|
||||
- Server resources are fully utilized
|
||||
- Technical issue with the server
|
||||
|
||||
**Solutions:**
|
||||
|
||||
- Wait for your turn in the queue
|
||||
- Check if your generation is paused
|
||||
- Contact administrator if stuck for extended period
|
||||
|
||||
|
||||
***
|
||||
|
||||
## Frequently Asked Questions
|
||||
|
||||
### Can other users see my images?
|
||||
|
||||
Not unless you change your board's visibility to "shared" or "public". All personal boards and images are private by default.
|
||||
|
||||
### Can I share my workflows with others?
|
||||
|
||||
Yes. You can mark any workflow as shared (public), which makes it visible to all users. Other users can view and use shared workflows, but only you or an administrator can modify or delete them.
|
||||
|
||||
### How long do sessions last?
|
||||
|
||||
- 24 hours by default
|
||||
- 7 days if you check "Remember me" during login
|
||||
|
||||
### Can I use the API with multi-user mode?
|
||||
|
||||
Yes, but you'll need to authenticate with a JWT token. See the [API Guide](api_guide.md) for details.
|
||||
|
||||
### What happens if I forget my password?
|
||||
|
||||
Contact your administrator. They can reset your password for you.
|
||||
|
||||
### Can I have multiple sessions?
|
||||
|
||||
Yes, you can log in from multiple devices or browsers simultaneously. All sessions will use the same account and see the same content.
|
||||
|
||||
### Why can't I see the Model Manager "Add Models" tab?
|
||||
|
||||
Regular users can see the Models tab but with read-only access. Check that you're logged in and try refreshing the page.
|
||||
|
||||
### How do I know if I'm an administrator?
|
||||
|
||||
Administrators see an "Admin" badge next to their name in the top-right corner and have access to additional features like User Management.
|
||||
|
||||
### Can I request admin privileges?
|
||||
|
||||
Yes, ask your current administrator to grant you admin
|
||||
privileges. Admin privileges will give you the ability to see all
|
||||
other users' boards and images, as well as to add models and change
|
||||
various server-wide settings.
|
||||
|
||||
## Getting Help
|
||||
|
||||
### Support Channels
|
||||
|
||||
- **Administrator:** Contact your system administrator for account issues
|
||||
- **Documentation:** Check the [FAQ](../faq.md) for common issues
|
||||
- **Community:** Join the [Discord](https://discord.gg/ZmtBAhwWhy) for help
|
||||
- **Bug Reports:** File issues on [GitHub](https://github.com/invoke-ai/InvokeAI/issues)
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
When reporting an issue, include:
|
||||
|
||||
- Your role (regular user or administrator)
|
||||
- What you were trying to do
|
||||
- What happened instead
|
||||
- Any error messages you saw
|
||||
- Your browser and operating system
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [Administrator Guide](admin_guide.md) - For administrators managing users and the system
|
||||
- [API Guide](api_guide.md) - For developers using the InvokeAI API
|
||||
- [Multiuser Specification](specification.md) - Technical details about the feature
|
||||
- [InvokeAI Documentation](../index.md) - Main documentation hub
|
||||
|
||||
---
|
||||
|
||||
**Need more help?** Contact your administrator or visit the [InvokeAI Discord](https://discord.gg/ZmtBAhwWhy).
|
||||
166
invokeai/app/api/auth_dependencies.py
Normal file
166
invokeai/app/api/auth_dependencies.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""FastAPI dependencies for authentication."""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.auth.token_service import TokenData, verify_token
|
||||
from invokeai.backend.util.logging import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# HTTP Bearer token security scheme
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
async def get_current_user(
    credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)],
) -> TokenData:
    """Resolve the authenticated user from a Bearer token.

    Note: This function accesses ApiDependencies.invoker.services.users directly,
    which is the established pattern in this codebase. The ApiDependencies.invoker
    is initialized in the FastAPI lifespan context before any requests are handled.

    Args:
        credentials: The HTTP authorization credentials containing the Bearer token

    Returns:
        TokenData describing the user encoded in the token

    Raises:
        HTTPException: 401 Unauthorized when the token is missing, invalid, expired,
            or the user record no longer exists or has been deactivated
    """

    def _unauthorized(reason: str) -> HTTPException:
        # Every failure mode is reported the same way: 401 plus the Bearer
        # challenge header expected by RFC 6750 clients.
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=reason,
            headers={"WWW-Authenticate": "Bearer"},
        )

    if credentials is None:
        raise _unauthorized("Missing authentication credentials")

    decoded = verify_token(credentials.credentials)
    if decoded is None:
        raise _unauthorized("Invalid or expired authentication token")

    # The token may outlive the account: confirm the user record still exists
    # and has not been deactivated since the token was issued.
    account = ApiDependencies.invoker.services.users.get(decoded.user_id)
    if account is None or not account.is_active:
        raise _unauthorized("User account is inactive or does not exist")

    return decoded
|
||||
|
||||
|
||||
async def get_current_user_or_default(
    credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)],
) -> TokenData:
    """Get current authenticated user from Bearer token, or return a default system user if not authenticated.

    This dependency is useful for endpoints that should work in both single-user and multiuser modes.

    When multiuser mode is disabled (default), this always returns a system user with admin privileges,
    allowing unrestricted access to all operations.

    When multiuser mode is enabled, authentication is required and this function validates the token,
    returning authenticated user data or raising 401 Unauthorized if no valid credentials are provided.

    Args:
        credentials: The HTTP authorization credentials containing the Bearer token

    Returns:
        TokenData containing user information from the token, or system user in single-user mode

    Raises:
        HTTPException: 401 Unauthorized if in multiuser mode and credentials are missing, invalid, or user is inactive
    """
    # Get configuration to check if multiuser is enabled
    config = ApiDependencies.invoker.services.configuration

    # In single-user mode (multiuser=False), always return system user with admin privileges
    if not config.multiuser:
        return TokenData(user_id="system", email="system@system.invokeai", is_admin=True)

    # Multiuser mode is enabled - validate credentials.
    # Fix: each 401 below now carries the "WWW-Authenticate: Bearer" challenge
    # header, matching get_current_user and RFC 6750/7235 requirements for
    # 401 responses (previously the header was omitted on this code path).
    if credentials is None:
        # In multiuser mode, authentication is required
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication required",
            headers={"WWW-Authenticate": "Bearer"},
        )

    token = credentials.credentials
    token_data = verify_token(token)

    if token_data is None:
        # Invalid token in multiuser mode - reject
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Verify user still exists and is active
    user_service = ApiDependencies.invoker.services.users
    user = user_service.get(token_data.user_id)

    if user is None or not user.is_active:
        # User doesn't exist or is inactive in multiuser mode - reject
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found or inactive",
            headers={"WWW-Authenticate": "Bearer"},
        )

    return token_data
|
||||
|
||||
|
||||
async def require_admin(
    current_user: Annotated[TokenData, Depends(get_current_user)],
) -> TokenData:
    """Require that the authenticated user holds the admin role.

    Args:
        current_user: The current authenticated user's token data

    Returns:
        The same token data, when the user is an administrator

    Raises:
        HTTPException: 403 Forbidden when the user lacks admin privileges
    """
    if current_user.is_admin:
        return current_user
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin privileges required")
|
||||
|
||||
|
||||
async def require_admin_or_default(
    current_user: Annotated[TokenData, Depends(get_current_user_or_default)],
) -> TokenData:
    """Require the admin role, falling back to the system admin in single-user mode.

    This dependency is useful for admin-only endpoints that should work in both
    single-user and multiuser modes.

    When multiuser mode is disabled (default), the upstream dependency supplies a
    system user with admin privileges, so this check always passes. When multiuser
    mode is enabled, the authenticated user must actually be an administrator.

    Args:
        current_user: The current authenticated user's token data (or default system user)

    Returns:
        The same token data, when the user is an administrator (or the system user
        in single-user mode)

    Raises:
        HTTPException: 403 Forbidden when a multiuser-mode user lacks admin privileges
    """
    if current_user.is_admin:
        return current_user
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin privileges required")
|
||||
|
||||
|
||||
# Type aliases for convenient use in route dependencies
# CurrentUser: requires a valid Bearer token; raises 401 otherwise.
CurrentUser = Annotated[TokenData, Depends(get_current_user)]
# CurrentUserOrDefault: like CurrentUser, but resolves to the built-in "system"
# admin user when multiuser mode is disabled.
CurrentUserOrDefault = Annotated[TokenData, Depends(get_current_user_or_default)]
# AdminUser: requires a valid token whose user has admin privileges; 403 otherwise.
AdminUser = Annotated[TokenData, Depends(require_admin)]
# AdminUserOrDefault: admin requirement that is satisfied implicitly (system admin)
# when multiuser mode is disabled.
AdminUserOrDefault = Annotated[TokenData, Depends(require_admin_or_default)]
|
||||
@@ -5,6 +5,8 @@ from logging import Logger
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.app_settings import AppSettingsService
|
||||
from invokeai.app.services.auth.token_service import set_jwt_secret
|
||||
from invokeai.app.services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
|
||||
from invokeai.app.services.board_images.board_images_default import BoardImagesService
|
||||
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
|
||||
@@ -14,6 +16,9 @@ from invokeai.app.services.client_state_persistence.client_state_persistence_sql
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.download.download_default import DownloadQueueService
|
||||
from invokeai.app.services.events.events_fastapievents import FastAPIEventService
|
||||
from invokeai.app.services.external_generation.external_generation_default import ExternalGenerationService
|
||||
from invokeai.app.services.external_generation.providers import AlibabaCloudProvider, GeminiProvider, OpenAIProvider
|
||||
from invokeai.app.services.external_generation.startup import sync_configured_external_starter_models
|
||||
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
|
||||
from invokeai.app.services.image_records.image_records_sqlite import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images.images_default import ImageService
|
||||
@@ -40,13 +45,16 @@ from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.app.services.style_preset_images.style_preset_images_disk import StylePresetImageFileStorageDisk
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_sqlite import SqliteStylePresetRecordsStorage
|
||||
from invokeai.app.services.urls.urls_default import LocalUrlService
|
||||
from invokeai.app.services.users.users_default import UserService
|
||||
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_disk import WorkflowThumbnailFileStorageDisk
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
AnimaConditioningInfo,
|
||||
BasicConditioningInfo,
|
||||
CogView4ConditioningInfo,
|
||||
ConditioningFieldData,
|
||||
FLUXConditioningInfo,
|
||||
QwenImageConditioningInfo,
|
||||
SD3ConditioningInfo,
|
||||
SDXLConditioningInfo,
|
||||
ZImageConditioningInfo,
|
||||
@@ -101,6 +109,12 @@ class ApiDependencies:
|
||||
|
||||
db = init_db(config=config, logger=logger, image_files=image_files)
|
||||
|
||||
# Initialize JWT secret from database
|
||||
app_settings = AppSettingsService(db=db)
|
||||
jwt_secret = app_settings.get_jwt_secret()
|
||||
set_jwt_secret(jwt_secret)
|
||||
logger.info("JWT secret loaded from database")
|
||||
|
||||
configuration = config
|
||||
logger = logger
|
||||
|
||||
@@ -131,18 +145,30 @@ class ApiDependencies:
|
||||
SD3ConditioningInfo,
|
||||
CogView4ConditioningInfo,
|
||||
ZImageConditioningInfo,
|
||||
QwenImageConditioningInfo,
|
||||
AnimaConditioningInfo,
|
||||
],
|
||||
ephemeral=True,
|
||||
),
|
||||
)
|
||||
download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
|
||||
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
|
||||
model_record_service = ModelRecordServiceSQL(db=db, logger=logger)
|
||||
model_manager = ModelManagerService.build_model_manager(
|
||||
app_config=configuration,
|
||||
model_record_service=ModelRecordServiceSQL(db=db, logger=logger),
|
||||
model_record_service=model_record_service,
|
||||
download_queue=download_queue_service,
|
||||
events=events,
|
||||
)
|
||||
external_generation = ExternalGenerationService(
|
||||
providers={
|
||||
AlibabaCloudProvider.provider_id: AlibabaCloudProvider(app_config=configuration, logger=logger),
|
||||
GeminiProvider.provider_id: GeminiProvider(app_config=configuration, logger=logger),
|
||||
OpenAIProvider.provider_id: OpenAIProvider(app_config=configuration, logger=logger),
|
||||
},
|
||||
logger=logger,
|
||||
record_store=model_record_service,
|
||||
)
|
||||
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
|
||||
model_relationships = ModelRelationshipsService()
|
||||
model_relationship_records = SqliteModelRelationshipRecordStorage(db=db)
|
||||
names = SimpleNameService()
|
||||
@@ -155,6 +181,7 @@ class ApiDependencies:
|
||||
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
|
||||
workflow_thumbnails = WorkflowThumbnailFileStorageDisk(workflow_thumbnails_folder)
|
||||
client_state_persistence = ClientStatePersistenceSqlite(db=db)
|
||||
users = UserService(db=db)
|
||||
|
||||
services = InvocationServices(
|
||||
board_image_records=board_image_records,
|
||||
@@ -174,6 +201,7 @@ class ApiDependencies:
|
||||
model_relationships=model_relationships,
|
||||
model_relationship_records=model_relationship_records,
|
||||
download_queue=download_queue_service,
|
||||
external_generation=external_generation,
|
||||
names=names,
|
||||
performance_statistics=performance_statistics,
|
||||
session_processor=session_processor,
|
||||
@@ -186,9 +214,20 @@ class ApiDependencies:
|
||||
style_preset_image_files=style_preset_image_files,
|
||||
workflow_thumbnails=workflow_thumbnails,
|
||||
client_state_persistence=client_state_persistence,
|
||||
users=users,
|
||||
)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
configured_external_providers = {
|
||||
provider_id
|
||||
for provider_id, status in external_generation.get_provider_statuses().items()
|
||||
if status.configured
|
||||
}
|
||||
sync_configured_external_starter_models(
|
||||
configured_provider_ids=configured_external_providers,
|
||||
model_manager=model_manager,
|
||||
logger=logger,
|
||||
)
|
||||
db.clean()
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from typing import Any
|
||||
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.responses import Response
|
||||
from starlette.staticfiles import StaticFiles
|
||||
from starlette.types import Scope
|
||||
|
||||
|
||||
class NoCacheStaticFiles(StaticFiles):
|
||||
@@ -12,6 +14,10 @@ class NoCacheStaticFiles(StaticFiles):
|
||||
|
||||
Static files include the javascript bundles, fonts, locales, and some images. Generated
|
||||
images are not included, as they are served by a router.
|
||||
|
||||
This class also implements proper SPA (Single Page Application) routing by serving index.html
|
||||
for any routes that don't match static files, enabling client-side routing to work correctly
|
||||
in production builds.
|
||||
"""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any):
|
||||
@@ -26,3 +32,19 @@ class NoCacheStaticFiles(StaticFiles):
|
||||
resp.headers.setdefault("Pragma", self.pragma)
|
||||
resp.headers.setdefault("Expires", self.expires)
|
||||
return resp
|
||||
|
||||
async def get_response(self, path: str, scope: Scope) -> Response:
    """Serve *path*, falling back to index.html for SPA routes.

    When the requested file does not exist (HTTP 404) and html mode is
    enabled, the SPA entry point is served instead so that client-side
    routing can resolve the path.
    """
    try:
        response = await super().get_response(path, scope)
    except HTTPException as exc:
        # Only a missing file triggers the SPA fallback; other errors propagate.
        if self.html and exc.status_code == 404:
            return await super().get_response("index.html", scope)
        raise
    return response
|
||||
|
||||
@@ -1,15 +1,30 @@
|
||||
import locale
|
||||
from enum import Enum
|
||||
from importlib.metadata import distributions
|
||||
from pathlib import Path as FilePath
|
||||
from threading import Lock
|
||||
|
||||
import torch
|
||||
from fastapi import Body
|
||||
import yaml
|
||||
from fastapi import Body, HTTPException, Path
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.api.auth_dependencies import AdminUserOrDefault
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config
|
||||
from invokeai.app.services.config.config_default import (
|
||||
EXTERNAL_PROVIDER_CONFIG_FIELDS,
|
||||
DefaultInvokeAIAppConfig,
|
||||
InvokeAIAppConfig,
|
||||
get_config,
|
||||
load_and_migrate_config,
|
||||
load_external_api_keys,
|
||||
)
|
||||
from invokeai.app.services.external_generation.external_generation_common import ExternalProviderStatus
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
|
||||
from invokeai.app.services.model_records.model_records_base import UnknownModelException
|
||||
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
from invokeai.backend.util.logging import logging
|
||||
from invokeai.version import __version__
|
||||
|
||||
@@ -41,7 +56,7 @@ async def get_version() -> AppVersion:
|
||||
async def get_app_deps() -> dict[str, str]:
|
||||
deps: dict[str, str] = {dist.metadata["Name"]: dist.version for dist in distributions()}
|
||||
try:
|
||||
cuda = torch.version.cuda or "N/A"
|
||||
cuda = getattr(getattr(torch, "version", None), "cuda", None) or "N/A" # pyright: ignore[reportAttributeAccessIssue]
|
||||
except Exception:
|
||||
cuda = "N/A"
|
||||
|
||||
@@ -64,6 +79,41 @@ class InvokeAIAppConfigWithSetFields(BaseModel):
|
||||
config: InvokeAIAppConfig = Field(description="The InvokeAI App Config")
|
||||
|
||||
|
||||
class ExternalProviderStatusModel(BaseModel):
|
||||
provider_id: str = Field(description="The external provider identifier")
|
||||
configured: bool = Field(description="Whether credentials are configured for the provider")
|
||||
message: str | None = Field(default=None, description="Optional provider status detail")
|
||||
|
||||
|
||||
class ExternalProviderConfigUpdate(BaseModel):
|
||||
api_key: str | None = Field(default=None, description="API key for the external provider")
|
||||
base_url: str | None = Field(default=None, description="Optional base URL override for the provider")
|
||||
|
||||
|
||||
class ExternalProviderConfigModel(BaseModel):
|
||||
provider_id: str = Field(description="The external provider identifier")
|
||||
api_key_configured: bool = Field(description="Whether an API key is configured")
|
||||
base_url: str | None = Field(default=None, description="Optional base URL override")
|
||||
|
||||
|
||||
EXTERNAL_PROVIDER_FIELDS: dict[str, tuple[str, str]] = {
|
||||
"alibabacloud": ("external_alibabacloud_api_key", "external_alibabacloud_base_url"),
|
||||
"gemini": ("external_gemini_api_key", "external_gemini_base_url"),
|
||||
"openai": ("external_openai_api_key", "external_openai_base_url"),
|
||||
}
|
||||
_EXTERNAL_PROVIDER_CONFIG_LOCK = Lock()
|
||||
|
||||
|
||||
class UpdateAppGenerationSettingsRequest(BaseModel):
|
||||
"""Writable generation-related app settings."""
|
||||
|
||||
max_queue_history: int | None = Field(
|
||||
default=None,
|
||||
ge=0,
|
||||
description="Keep the last N completed, failed, and canceled queue items on startup. Set to 0 to prune all terminal items.",
|
||||
)
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/runtime_config", operation_id="get_runtime_config", status_code=200, response_model=InvokeAIAppConfigWithSetFields
|
||||
)
|
||||
@@ -72,6 +122,190 @@ async def get_runtime_config() -> InvokeAIAppConfigWithSetFields:
|
||||
return InvokeAIAppConfigWithSetFields(set_fields=config.model_fields_set, config=config)
|
||||
|
||||
|
||||
@app_router.patch(
    "/runtime_config",
    operation_id="update_runtime_config",
    status_code=200,
    response_model=InvokeAIAppConfigWithSetFields,
)
async def update_runtime_config(
    _: AdminUserOrDefault,
    changes: UpdateAppGenerationSettingsRequest = Body(description="Writable runtime configuration changes"),
) -> InvokeAIAppConfigWithSetFields:
    """Apply writable runtime settings to the live config and persist them to disk."""
    runtime_config = get_config()
    requested_changes = changes.model_dump(exclude_unset=True)
    runtime_config.update_config(requested_changes)

    # Mirror the change into the on-disk config so it survives a restart.
    config_path = runtime_config.config_file_path
    on_disk_config = load_and_migrate_config(config_path) if config_path.exists() else DefaultInvokeAIAppConfig()
    on_disk_config.update_config(requested_changes)
    on_disk_config.write_file(config_path)

    return InvokeAIAppConfigWithSetFields(set_fields=runtime_config.model_fields_set, config=runtime_config)
|
||||
|
||||
|
||||
@app_router.get(
    "/external_providers/status",
    operation_id="get_external_provider_statuses",
    status_code=200,
    response_model=list[ExternalProviderStatusModel],
)
async def get_external_provider_statuses() -> list[ExternalProviderStatusModel]:
    """Report the configuration status of every known external provider."""
    provider_statuses = ApiDependencies.invoker.services.external_generation.get_provider_statuses()
    return [status_to_model(provider_status) for provider_status in provider_statuses.values()]
|
||||
|
||||
|
||||
@app_router.get(
    "/external_providers/config",
    operation_id="get_external_provider_configs",
    status_code=200,
    response_model=list[ExternalProviderConfigModel],
)
async def get_external_provider_configs() -> list[ExternalProviderConfigModel]:
    """Return the stored (non-secret) configuration for each supported external provider."""
    app_config = get_config()
    return [_build_external_provider_config(pid, app_config) for pid in EXTERNAL_PROVIDER_FIELDS]
|
||||
|
||||
|
||||
@app_router.post(
    "/external_providers/config/{provider_id}",
    operation_id="set_external_provider_config",
    status_code=200,
    response_model=ExternalProviderConfigModel,
)
async def set_external_provider_config(
    provider_id: str = Path(description="The external provider identifier"),
    update: ExternalProviderConfigUpdate = Body(description="External provider configuration settings"),
) -> ExternalProviderConfigModel:
    """Store API-key and/or base-URL settings for one external provider.

    Raises 404 for an unknown provider and 400 when the body carries no fields.
    Clearing the API key also removes the provider's installed external models.
    """
    api_key_field, base_url_field = _get_external_provider_fields(provider_id)

    pending: dict[str, str | None] = {}
    if update.api_key is not None:
        # An empty/whitespace-only value means "clear the credential".
        trimmed_key = update.api_key.strip()
        pending[api_key_field] = trimmed_key if trimmed_key else None
    if update.base_url is not None:
        trimmed_url = update.base_url.strip()
        pending[base_url_field] = trimmed_url if trimmed_url else None

    if not pending:
        raise HTTPException(status_code=400, detail="No external provider config fields provided")

    key_was_cleared = update.api_key is not None and pending.get(api_key_field) is None
    _apply_external_provider_update(pending)
    if key_was_cleared:
        _remove_external_models_for_provider(provider_id)
    return _build_external_provider_config(provider_id, get_config())
|
||||
|
||||
|
||||
@app_router.delete(
    "/external_providers/config/{provider_id}",
    operation_id="reset_external_provider_config",
    status_code=200,
    response_model=ExternalProviderConfigModel,
)
async def reset_external_provider_config(
    provider_id: str = Path(description="The external provider identifier"),
) -> ExternalProviderConfigModel:
    """Clear all stored credentials for a provider and drop its external models."""
    api_key_field, base_url_field = _get_external_provider_fields(provider_id)
    cleared: dict[str, str | None] = {api_key_field: None, base_url_field: None}
    _apply_external_provider_update(cleared)
    _remove_external_models_for_provider(provider_id)
    return _build_external_provider_config(provider_id, get_config())
|
||||
|
||||
|
||||
def status_to_model(status: ExternalProviderStatus) -> ExternalProviderStatusModel:
    """Convert a service-layer provider status into its API response model."""
    return ExternalProviderStatusModel(
        provider_id=status.provider_id,
        configured=status.configured,
        message=status.message,
    )
|
||||
|
||||
|
||||
def _get_external_provider_fields(provider_id: str) -> tuple[str, str]:
    """Look up the (api_key_field, base_url_field) config attribute names for a provider.

    Raises an HTTP 404 when the provider id is not recognized.
    """
    try:
        return EXTERNAL_PROVIDER_FIELDS[provider_id]
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Unknown external provider '{provider_id}'") from None
|
||||
|
||||
|
||||
def _write_external_api_keys_file(api_keys_file_path: FilePath, api_keys: dict[str, str]) -> None:
    """Persist provider API keys as YAML, or delete the file when none remain."""
    if api_keys:
        api_keys_file_path.parent.mkdir(parents=True, exist_ok=True)
        # NOTE(review): written with the platform-preferred encoding — presumably
        # to match how load_external_api_keys reads it back; confirm the reader
        # uses the same encoding.
        with open(api_keys_file_path, "w", encoding=locale.getpreferredencoding()) as api_keys_file:
            yaml.safe_dump(api_keys, api_keys_file, sort_keys=False)
    else:
        if api_keys_file_path.exists():
            api_keys_file_path.unlink()
|
||||
|
||||
|
||||
def _apply_external_provider_update(updates: dict[str, str | None]) -> None:
    """Apply external-provider config updates to the live config and persist them.

    Secrets are persisted to the dedicated API-keys file; everything else goes
    to the main config file. A value of None clears the corresponding field.
    """
    # Serialize writers: this function rewrites both the config file and the
    # external API-keys file, so concurrent requests must not interleave.
    with _EXTERNAL_PROVIDER_CONFIG_LOCK:
        runtime_config = get_config()
        config_path = runtime_config.config_file_path
        api_keys_file_path = runtime_config.api_keys_file_path
        if config_path.exists():
            file_config = load_and_migrate_config(config_path)
        else:
            file_config = DefaultInvokeAIAppConfig()

        # Apply to the live config first so the running app sees the change.
        runtime_config.update_config(updates)
        # Split updates into provider secrets vs. everything else.
        provider_config_fields = set(EXTERNAL_PROVIDER_CONFIG_FIELDS)
        provider_updates = {field: value for field, value in updates.items() if field in provider_config_fields}
        non_provider_updates = {field: value for field, value in updates.items() if field not in provider_config_fields}

        if non_provider_updates:
            file_config.update_config(non_provider_updates)

        # Migrate any provider secrets still living in the config file into the
        # dedicated API-keys file, then layer on the requested updates.
        persisted_api_keys = load_external_api_keys(api_keys_file_path)
        for field_name in EXTERNAL_PROVIDER_CONFIG_FIELDS:
            file_value = getattr(file_config, field_name, None)
            if field_name not in persisted_api_keys and isinstance(file_value, str) and file_value.strip():
                persisted_api_keys[field_name] = file_value

        for field_name, value in provider_updates.items():
            if value is None:
                # None means "clear this credential".
                persisted_api_keys.pop(field_name, None)
            else:
                persisted_api_keys[field_name] = value

        _write_external_api_keys_file(api_keys_file_path, persisted_api_keys)

        # Secrets never live in the main config file: blank them before writing.
        for field_name in EXTERNAL_PROVIDER_CONFIG_FIELDS:
            setattr(file_config, field_name, None)

        file_config_to_write = type(file_config).model_validate(
            file_config.model_dump(exclude_unset=True, exclude_none=True)
        )
        file_config_to_write.write_file(config_path, as_example=False)
|
||||
|
||||
|
||||
def _build_external_provider_config(provider_id: str, config: InvokeAIAppConfig) -> ExternalProviderConfigModel:
    """Build the API response model for one provider from the given config.

    The API key itself is never returned — only whether one is set.
    """
    api_key_field, base_url_field = _get_external_provider_fields(provider_id)
    return ExternalProviderConfigModel(
        provider_id=provider_id,
        api_key_configured=bool(getattr(config, api_key_field)),
        base_url=getattr(config, base_url_field),
    )
|
||||
|
||||
|
||||
def _remove_external_models_for_provider(provider_id: str) -> None:
    """Delete every installed external model belonging to *provider_id*.

    Removal is best-effort: failures are logged and do not abort the loop.
    """
    model_manager = ApiDependencies.invoker.services.model_manager
    external_models = model_manager.store.search_by_attr(
        base_model=BaseModelType.External,
        model_type=ModelType.ExternalImageGenerator,
    )

    matching = (m for m in external_models if getattr(m, "provider_id", None) == provider_id)
    for model in matching:
        try:
            model_manager.install.delete(model.key)
        except UnknownModelException:
            logging.warning(f"External model key '{model.key}' was already removed while resetting '{provider_id}'")
        except Exception as error:
            logging.warning(f"Failed removing external model key '{model.key}' for '{provider_id}': {error}")
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/logging",
|
||||
operation_id="get_log_level",
|
||||
|
||||
536
invokeai/app/api/routers/auth.py
Normal file
536
invokeai/app/api/routers/auth.py
Normal file
@@ -0,0 +1,536 @@
|
||||
"""Authentication endpoints."""
|
||||
|
||||
import secrets
|
||||
import string
|
||||
from datetime import timedelta
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Body, HTTPException, Path, status
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from invokeai.app.api.auth_dependencies import AdminUser, CurrentUser
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.auth.token_service import TokenData, create_access_token
|
||||
from invokeai.app.services.users.users_common import (
|
||||
UserCreateRequest,
|
||||
UserDTO,
|
||||
UserUpdateRequest,
|
||||
validate_email_with_special_domains,
|
||||
)
|
||||
|
||||
# Router for all authentication and user-management endpoints.
auth_router = APIRouter(prefix="/v1/auth", tags=["authentication"])

# Token expiration constants (in days)
TOKEN_EXPIRATION_NORMAL = 1  # 1 day for normal login
TOKEN_EXPIRATION_REMEMBER_ME = 7  # 7 days for "remember me" login
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
    """Request body for user login."""

    email: str = Field(description="User email address")
    password: str = Field(description="User password")
    # When True, the issued token lasts TOKEN_EXPIRATION_REMEMBER_ME days
    # instead of TOKEN_EXPIRATION_NORMAL (see the /login handler).
    remember_me: bool = Field(default=False, description="Whether to extend session duration")

    @field_validator("email")
    @classmethod
    def validate_email(cls, v: str) -> str:
        """Validate email address, allowing special-use domains."""
        return validate_email_with_special_domains(v)
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
    """Response from successful login."""

    token: str = Field(description="JWT access token")
    user: UserDTO = Field(description="User information")
    expires_in: int = Field(description="Token expiration time in seconds")
|
||||
|
||||
|
||||
class SetupRequest(BaseModel):
    """Request body for initial admin setup."""

    email: str = Field(description="Admin email address")
    display_name: str | None = Field(default=None, description="Admin display name")
    password: str = Field(description="Admin password")

    @field_validator("email")
    @classmethod
    def validate_email(cls, v: str) -> str:
        """Validate email address, allowing special-use domains."""
        return validate_email_with_special_domains(v)
|
||||
|
||||
|
||||
class SetupResponse(BaseModel):
    """Response from successful admin setup."""

    success: bool = Field(description="Whether setup was successful")
    user: UserDTO = Field(description="Created admin user information")
|
||||
|
||||
|
||||
class LogoutResponse(BaseModel):
    """Response from logout."""

    success: bool = Field(description="Whether logout was successful")
|
||||
|
||||
|
||||
class SetupStatusResponse(BaseModel):
    """Response for setup status check."""

    setup_required: bool = Field(description="Whether initial setup is required")
    multiuser_enabled: bool = Field(description="Whether multiuser mode is enabled")
    strict_password_checking: bool = Field(description="Whether strict password requirements are enforced")
    # Only populated while setup is still pending (see get_setup_status).
    admin_email: str | None = Field(default=None, description="Email of the first active admin user, if any")
|
||||
|
||||
|
||||
@auth_router.get("/status", response_model=SetupStatusResponse)
async def get_setup_status() -> SetupStatusResponse:
    """Report whether initial administrator setup is still required.

    Returns:
        SetupStatusResponse indicating whether setup is needed and multiuser mode status
    """
    config = ApiDependencies.invoker.services.configuration

    # Single-user deployments never need admin setup.
    if not config.multiuser:
        return SetupStatusResponse(
            setup_required=False,
            multiuser_enabled=False,
            strict_password_checking=config.strict_password_checking,
            admin_email=None,
        )

    user_service = ApiDependencies.invoker.services.users
    needs_setup = not user_service.has_admin()

    # The admin email is only revealed while setup is still pending, so a
    # public deployment does not leak its administrator's identity.
    return SetupStatusResponse(
        setup_required=needs_setup,
        multiuser_enabled=True,
        strict_password_checking=config.strict_password_checking,
        admin_email=user_service.get_admin_email() if needs_setup else None,
    )
|
||||
|
||||
|
||||
@auth_router.post("/login", response_model=LoginResponse)
async def login(
    request: Annotated[LoginRequest, Body(description="Login credentials")],
) -> LoginResponse:
    """Authenticate a user and issue a JWT access token.

    Args:
        request: Login credentials (email and password)

    Returns:
        LoginResponse containing JWT token and user information

    Raises:
        HTTPException: 401 if credentials are invalid
        HTTPException: 403 if multiuser mode is disabled or the account is inactive
    """
    config = ApiDependencies.invoker.services.configuration
    if not config.multiuser:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Multiuser mode is disabled. Authentication is not required in single-user mode.",
        )

    user = ApiDependencies.invoker.services.users.authenticate(request.email, request.password)
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    if not user.is_active:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User account is disabled")

    # "Remember me" extends the token lifetime from 1 day to 7 days.
    lifetime_days = TOKEN_EXPIRATION_REMEMBER_ME if request.remember_me else TOKEN_EXPIRATION_NORMAL
    expires_delta = timedelta(days=lifetime_days)
    token = create_access_token(
        TokenData(
            user_id=user.user_id,
            email=user.email,
            is_admin=user.is_admin,
            remember_me=request.remember_me,
        ),
        expires_delta,
    )

    return LoginResponse(token=token, user=user, expires_in=int(expires_delta.total_seconds()))
|
||||
|
||||
|
||||
@auth_router.post("/logout", response_model=LogoutResponse)
async def logout(
    current_user: CurrentUser,
) -> LogoutResponse:
    """Logout current user.

    Currently a no-op since we use stateless JWT tokens. For token invalidation in
    future implementations, consider:
    - Token blacklist: Store invalidated tokens in Redis/database with expiration
    - Token versioning: Add version field to user record, increment on logout
    - Short-lived tokens: Use refresh token pattern with token rotation
    - Session storage: Track active sessions server-side for revocation

    Args:
        current_user: The authenticated user (validates token)

    Returns:
        LogoutResponse indicating success
    """
    # TODO: Implement token invalidation when server-side session management is added
    # For now, this is a no-op since we use stateless JWT tokens
    return LogoutResponse(success=True)
|
||||
|
||||
|
||||
@auth_router.get("/me", response_model=UserDTO)
async def get_current_user_info(
    current_user: CurrentUser,
) -> UserDTO:
    """Return the profile of the currently authenticated user.

    Args:
        current_user: The authenticated user's token data

    Returns:
        UserDTO containing user information

    Raises:
        HTTPException: 404 if user is not found (should not happen normally)
    """
    user = ApiDependencies.invoker.services.users.get(current_user.user_id)
    if user is None:
        # Token was valid but the user record is gone (e.g. deleted mid-session).
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
    return user
|
||||
|
||||
|
||||
@auth_router.post("/setup", response_model=SetupResponse)
async def setup_admin(
    request: Annotated[SetupRequest, Body(description="Admin account details")],
) -> SetupResponse:
    """Create the first administrator account.

    Only callable while no admin exists; subsequent calls are rejected.

    Args:
        request: Admin account details (email, display_name, password)

    Returns:
        SetupResponse containing the created admin user

    Raises:
        HTTPException: 400 if admin already exists or password is weak
        HTTPException: 403 if multiuser mode is disabled
    """
    config = ApiDependencies.invoker.services.configuration
    if not config.multiuser:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Multiuser mode is disabled. Admin setup is not required in single-user mode.",
        )

    user_service = ApiDependencies.invoker.services.users
    if user_service.has_admin():
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Administrator account already configured",
        )

    # Password strength is validated by the service; validation failures
    # surface as ValueError and become a 400.
    try:
        admin = user_service.create_admin(
            UserCreateRequest(
                email=request.email,
                display_name=request.display_name,
                password=request.password,
                is_admin=True,
            ),
            strict_password_checking=config.strict_password_checking,
        )
    except ValueError as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e

    return SetupResponse(success=True, user=admin)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User management models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Character pool for generated passwords: letters, digits, and punctuation.
_PASSWORD_ALPHABET = string.ascii_letters + string.digits + string.punctuation
|
||||
|
||||
|
||||
class AdminUserCreateRequest(BaseModel):
    """Request body for admin to create a new user."""

    email: str = Field(description="User email address")
    display_name: str | None = Field(default=None, description="Display name")
    password: str = Field(description="User password")
    is_admin: bool = Field(default=False, description="Whether user should have admin privileges")

    @field_validator("email")
    @classmethod
    def validate_email(cls, v: str) -> str:
        """Validate email address, allowing special-use domains."""
        return validate_email_with_special_domains(v)
|
||||
|
||||
|
||||
class AdminUserUpdateRequest(BaseModel):
    """Request body for admin to update any user.

    All fields are optional; presumably a None field is left unchanged by the
    user service — confirm against UserUpdateRequest semantics.
    """

    display_name: str | None = Field(default=None, description="Display name")
    password: str | None = Field(default=None, description="New password")
    is_admin: bool | None = Field(default=None, description="Whether user should have admin privileges")
    is_active: bool | None = Field(default=None, description="Whether user account should be active")
|
||||
|
||||
|
||||
class UserProfileUpdateRequest(BaseModel):
    """Request body for a user to update their own profile."""

    display_name: str | None = Field(default=None, description="New display name")
    current_password: str | None = Field(default=None, description="Current password (required when changing password)")
    new_password: str | None = Field(default=None, description="New password")
|
||||
|
||||
|
||||
class GeneratePasswordResponse(BaseModel):
    """Response containing a generated password."""

    password: str = Field(description="Generated strong password")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User management endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@auth_router.get("/generate-password", response_model=GeneratePasswordResponse)
async def generate_password(
    current_user: CurrentUser,
) -> GeneratePasswordResponse:
    """Generate a strong random password.

    Returns a cryptographically secure random password of 16 characters
    containing uppercase, lowercase, digits, and punctuation.

    Args:
        current_user: The authenticated user (any role)

    Returns:
        GeneratePasswordResponse with the generated password
    """
    # Keep drawing until the password meets every strength requirement the
    # docstring promises: at least one uppercase, one lowercase, one digit,
    # AND one special character. The previous check omitted the punctuation
    # test, so a generated password could contain no special character even
    # though one was documented as guaranteed.
    while True:
        password = "".join(secrets.choice(_PASSWORD_ALPHABET) for _ in range(16))
        if (
            any(c.isupper() for c in password)
            and any(c.islower() for c in password)
            and any(c.isdigit() for c in password)
            and any(c in string.punctuation for c in password)
        ):
            return GeneratePasswordResponse(password=password)
|
||||
|
||||
|
||||
@auth_router.get("/users", response_model=list[UserDTO])
async def list_users(
    current_user: AdminUser,
) -> list[UserDTO]:
    """List all users. Requires admin privileges.

    The internal 'system' user (created for backward compatibility) is excluded
    since it cannot be managed through this interface.

    Returns:
        List of all real users (system user excluded)
    """
    all_users = ApiDependencies.invoker.services.users.list_users()
    return [user for user in all_users if user.user_id != "system"]
|
||||
|
||||
|
||||
@auth_router.post("/users", response_model=UserDTO, status_code=status.HTTP_201_CREATED)
async def create_user(
    request: Annotated[AdminUserCreateRequest, Body(description="New user details")],
    current_user: AdminUser,
) -> UserDTO:
    """Create a new user. Requires admin privileges.

    Args:
        request: New user details

    Returns:
        The created user

    Raises:
        HTTPException: 400 if email already exists or password is weak
    """
    services = ApiDependencies.invoker.services
    # Construction stays inside the try: validation failures raise ValueError
    # and must map to a 400.
    try:
        new_user = UserCreateRequest(
            email=request.email,
            display_name=request.display_name,
            password=request.password,
            is_admin=request.is_admin,
        )
        return services.users.create(new_user, strict_password_checking=services.configuration.strict_password_checking)
    except ValueError as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||
|
||||
|
||||
@auth_router.get("/users/{user_id}", response_model=UserDTO)
async def get_user(
    user_id: Annotated[str, Path(description="User ID")],
    current_user: AdminUser,
) -> UserDTO:
    """Fetch a single user by ID. Requires admin privileges.

    Args:
        user_id: The user ID

    Returns:
        The user

    Raises:
        HTTPException: 404 if user not found
    """
    found = ApiDependencies.invoker.services.users.get(user_id)
    if found is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
    return found
|
||||
|
||||
|
||||
@auth_router.patch("/users/{user_id}", response_model=UserDTO)
async def update_user(
    user_id: Annotated[str, Path(description="User ID")],
    request: Annotated[AdminUserUpdateRequest, Body(description="User fields to update")],
    current_user: AdminUser,
) -> UserDTO:
    """Update any user's account. Requires admin privileges.

    Args:
        user_id: The user ID
        request: Fields to update

    Returns:
        The updated user

    Raises:
        HTTPException: 400 if password is weak
        HTTPException: 404 if user not found
    """
    services = ApiDependencies.invoker.services
    try:
        return services.users.update(
            user_id,
            UserUpdateRequest(
                display_name=request.display_name,
                password=request.password,
                is_admin=request.is_admin,
                is_active=request.is_active,
            ),
            strict_password_checking=services.configuration.strict_password_checking,
        )
    except ValueError as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||
|
||||
|
||||
@auth_router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_user(
    user_id: Annotated[str, Path(description="User ID")],
    current_user: AdminUser,
) -> None:
    """Delete a user. Requires admin privileges.

    Admins can delete any user including other admins, but the last remaining
    active admin can never be deleted.

    Args:
        user_id: The user ID

    Raises:
        HTTPException: 400 if attempting to delete the last admin
        HTTPException: 404 if user not found
    """
    user_service = ApiDependencies.invoker.services.users
    target = user_service.get(user_id)
    if target is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")

    # Guard against locking everyone out of administration.
    if target.is_admin and target.is_active and user_service.count_admins() <= 1:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Cannot delete the last administrator",
        )

    try:
        user_service.delete(user_id)
    except ValueError as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||
|
||||
|
||||
@auth_router.patch("/me", response_model=UserDTO)
async def update_current_user(
    request: Annotated[UserProfileUpdateRequest, Body(description="Profile fields to update")],
    current_user: CurrentUser,
) -> UserDTO:
    """Update the authenticated user's own profile.

    Changing the password requires both ``current_password`` and
    ``new_password``; the current password is verified before the change.

    Args:
        request: Profile fields to update
        current_user: The authenticated user

    Returns:
        The updated user

    Raises:
        HTTPException: 400 if current password is incorrect or new password is weak
        HTTPException: 404 if user not found
    """
    user_service = ApiDependencies.invoker.services.users
    config = ApiDependencies.invoker.services.configuration

    if request.new_password is not None:
        if not request.current_password:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Current password is required to set a new password",
            )

        existing = user_service.get(current_user.user_id)
        if existing is None:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")

        # Re-authenticate with the supplied current password before allowing
        # the change.
        if user_service.authenticate(existing.email, request.current_password) is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Current password is incorrect",
            )

    try:
        return user_service.update(
            current_user.user_id,
            UserUpdateRequest(display_name=request.display_name, password=request.new_password),
            strict_password_checking=config.strict_password_checking,
        )
    except ValueError as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||
@@ -1,12 +1,53 @@
|
||||
from fastapi import Body, HTTPException
|
||||
from fastapi.routing import APIRouter
|
||||
|
||||
from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.images.images_common import AddImagesToBoardResult, RemoveImagesFromBoardResult
|
||||
|
||||
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
|
||||
|
||||
|
||||
def _assert_board_write_access(board_id: str, current_user: CurrentUserOrDefault) -> None:
|
||||
"""Raise 403 if the current user may not mutate the given board.
|
||||
|
||||
Write access is granted when ANY of these hold:
|
||||
- The user is an admin.
|
||||
- The user owns the board.
|
||||
- The board visibility is Public (public boards accept contributions from any user).
|
||||
"""
|
||||
from invokeai.app.services.board_records.board_records_common import BoardVisibility
|
||||
|
||||
try:
|
||||
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail="Board not found")
|
||||
if current_user.is_admin:
|
||||
return
|
||||
if board.user_id == current_user.user_id:
|
||||
return
|
||||
if board.board_visibility == BoardVisibility.Public:
|
||||
return
|
||||
raise HTTPException(status_code=403, detail="Not authorized to modify this board")
|
||||
|
||||
|
||||
def _assert_image_direct_owner(image_name: str, current_user: CurrentUserOrDefault) -> None:
|
||||
"""Raise 403 if the current user is not the direct owner of the image.
|
||||
|
||||
This is intentionally stricter than _assert_image_owner in images.py:
|
||||
board ownership is NOT sufficient here. Allowing a user to add someone
|
||||
else's image to their own board would grant them mutation rights via the
|
||||
board-ownership fallback in _assert_image_owner, escalating read access
|
||||
into write access.
|
||||
"""
|
||||
if current_user.is_admin:
|
||||
return
|
||||
owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name)
|
||||
if owner is not None and owner == current_user.user_id:
|
||||
return
|
||||
raise HTTPException(status_code=403, detail="Not authorized to move this image")
|
||||
|
||||
|
||||
@board_images_router.post(
|
||||
"/",
|
||||
operation_id="add_image_to_board",
|
||||
@@ -17,14 +58,17 @@ board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
|
||||
response_model=AddImagesToBoardResult,
|
||||
)
|
||||
async def add_image_to_board(
|
||||
current_user: CurrentUserOrDefault,
|
||||
board_id: str = Body(description="The id of the board to add to"),
|
||||
image_name: str = Body(description="The name of the image to add"),
|
||||
) -> AddImagesToBoardResult:
|
||||
"""Creates a board_image"""
|
||||
_assert_board_write_access(board_id, current_user)
|
||||
_assert_image_direct_owner(image_name, current_user)
|
||||
try:
|
||||
added_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
|
||||
old_board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name) or "none"
|
||||
ApiDependencies.invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name)
|
||||
added_images.add(image_name)
|
||||
affected_boards.add(board_id)
|
||||
@@ -48,13 +92,16 @@ async def add_image_to_board(
|
||||
response_model=RemoveImagesFromBoardResult,
|
||||
)
|
||||
async def remove_image_from_board(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_name: str = Body(description="The name of the image to remove", embed=True),
|
||||
) -> RemoveImagesFromBoardResult:
|
||||
"""Removes an image from its board, if it had one"""
|
||||
try:
|
||||
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
|
||||
if old_board_id != "none":
|
||||
_assert_board_write_access(old_board_id, current_user)
|
||||
removed_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
|
||||
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
|
||||
removed_images.add(image_name)
|
||||
affected_boards.add("none")
|
||||
@@ -64,6 +111,8 @@ async def remove_image_from_board(
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to remove image from board")
|
||||
|
||||
@@ -78,16 +127,21 @@ async def remove_image_from_board(
|
||||
response_model=AddImagesToBoardResult,
|
||||
)
|
||||
async def add_images_to_board(
|
||||
current_user: CurrentUserOrDefault,
|
||||
board_id: str = Body(description="The id of the board to add to"),
|
||||
image_names: list[str] = Body(description="The names of the images to add", embed=True),
|
||||
) -> AddImagesToBoardResult:
|
||||
"""Adds a list of images to a board"""
|
||||
_assert_board_write_access(board_id, current_user)
|
||||
try:
|
||||
added_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
for image_name in image_names:
|
||||
try:
|
||||
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
|
||||
_assert_image_direct_owner(image_name, current_user)
|
||||
old_board_id = (
|
||||
ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name) or "none"
|
||||
)
|
||||
ApiDependencies.invoker.services.board_images.add_image_to_board(
|
||||
board_id=board_id,
|
||||
image_name=image_name,
|
||||
@@ -96,12 +150,16 @@ async def add_images_to_board(
|
||||
affected_boards.add(board_id)
|
||||
affected_boards.add(old_board_id)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
return AddImagesToBoardResult(
|
||||
added_images=list(added_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to add images to board")
|
||||
|
||||
@@ -116,6 +174,7 @@ async def add_images_to_board(
|
||||
response_model=RemoveImagesFromBoardResult,
|
||||
)
|
||||
async def remove_images_from_board(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_names: list[str] = Body(description="The names of the images to remove", embed=True),
|
||||
) -> RemoveImagesFromBoardResult:
|
||||
"""Removes a list of images from their board, if they had one"""
|
||||
@@ -125,15 +184,21 @@ async def remove_images_from_board(
|
||||
for image_name in image_names:
|
||||
try:
|
||||
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
|
||||
if old_board_id != "none":
|
||||
_assert_board_write_access(old_board_id, current_user)
|
||||
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
|
||||
removed_images.add(image_name)
|
||||
affected_boards.add("none")
|
||||
affected_boards.add(old_board_id)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
return RemoveImagesFromBoardResult(
|
||||
removed_images=list(removed_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to remove images from board")
|
||||
|
||||
@@ -4,8 +4,9 @@ from fastapi import Body, HTTPException, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy
|
||||
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy, BoardVisibility
|
||||
from invokeai.app.services.boards.boards_common import BoardDTO
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
@@ -32,11 +33,12 @@ class DeleteBoardResult(BaseModel):
|
||||
response_model=BoardDTO,
|
||||
)
|
||||
async def create_board(
|
||||
current_user: CurrentUserOrDefault,
|
||||
board_name: str = Query(description="The name of the board to create", max_length=300),
|
||||
) -> BoardDTO:
|
||||
"""Creates a board"""
|
||||
"""Creates a board for the current user"""
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.boards.create(board_name=board_name)
|
||||
result = ApiDependencies.invoker.services.boards.create(board_name=board_name, user_id=current_user.user_id)
|
||||
return result
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to create board")
|
||||
@@ -44,16 +46,28 @@ async def create_board(
|
||||
|
||||
@boards_router.get("/{board_id}", operation_id="get_board", response_model=BoardDTO)
|
||||
async def get_board(
|
||||
current_user: CurrentUserOrDefault,
|
||||
board_id: str = Path(description="The id of board to get"),
|
||||
) -> BoardDTO:
|
||||
"""Gets a board"""
|
||||
"""Gets a board (user must have access to it)"""
|
||||
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
||||
return result
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail="Board not found")
|
||||
|
||||
# Admins can access any board.
|
||||
# Owners can access their own boards.
|
||||
# Shared and public boards are visible to all authenticated users.
|
||||
if (
|
||||
not current_user.is_admin
|
||||
and result.user_id != current_user.user_id
|
||||
and result.board_visibility == BoardVisibility.Private
|
||||
):
|
||||
raise HTTPException(status_code=403, detail="Not authorized to access this board")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@boards_router.patch(
|
||||
"/{board_id}",
|
||||
@@ -67,10 +81,19 @@ async def get_board(
|
||||
response_model=BoardDTO,
|
||||
)
|
||||
async def update_board(
|
||||
current_user: CurrentUserOrDefault,
|
||||
board_id: str = Path(description="The id of board to update"),
|
||||
changes: BoardChanges = Body(description="The changes to apply to the board"),
|
||||
) -> BoardDTO:
|
||||
"""Updates a board"""
|
||||
"""Updates a board (user must have access to it)"""
|
||||
try:
|
||||
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail="Board not found")
|
||||
|
||||
if not current_user.is_admin and board.user_id != current_user.user_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to update this board")
|
||||
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.boards.update(board_id=board_id, changes=changes)
|
||||
return result
|
||||
@@ -80,10 +103,19 @@ async def update_board(
|
||||
|
||||
@boards_router.delete("/{board_id}", operation_id="delete_board", response_model=DeleteBoardResult)
|
||||
async def delete_board(
|
||||
current_user: CurrentUserOrDefault,
|
||||
board_id: str = Path(description="The id of board to delete"),
|
||||
include_images: Optional[bool] = Query(description="Permanently delete all images on the board", default=False),
|
||||
) -> DeleteBoardResult:
|
||||
"""Deletes a board"""
|
||||
"""Deletes a board (user must have access to it)"""
|
||||
try:
|
||||
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail="Board not found")
|
||||
|
||||
if not current_user.is_admin and board.user_id != current_user.user_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to delete this board")
|
||||
|
||||
try:
|
||||
if include_images is True:
|
||||
deleted_images = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
|
||||
@@ -120,6 +152,7 @@ async def delete_board(
|
||||
response_model=Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]],
|
||||
)
|
||||
async def list_boards(
|
||||
current_user: CurrentUserOrDefault,
|
||||
order_by: BoardRecordOrderBy = Query(default=BoardRecordOrderBy.CreatedAt, description="The attribute to order by"),
|
||||
direction: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The direction to order by"),
|
||||
all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
|
||||
@@ -127,11 +160,15 @@ async def list_boards(
|
||||
limit: Optional[int] = Query(default=None, description="The number of boards per page"),
|
||||
include_archived: bool = Query(default=False, description="Whether or not to include archived boards in list"),
|
||||
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
|
||||
"""Gets a list of boards"""
|
||||
"""Gets a list of boards for the current user, including shared boards. Admin users see all boards."""
|
||||
if all:
|
||||
return ApiDependencies.invoker.services.boards.get_all(order_by, direction, include_archived)
|
||||
return ApiDependencies.invoker.services.boards.get_all(
|
||||
current_user.user_id, current_user.is_admin, order_by, direction, include_archived
|
||||
)
|
||||
elif offset is not None and limit is not None:
|
||||
return ApiDependencies.invoker.services.boards.get_many(order_by, direction, offset, limit, include_archived)
|
||||
return ApiDependencies.invoker.services.boards.get_many(
|
||||
current_user.user_id, current_user.is_admin, order_by, direction, offset, limit, include_archived
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@@ -145,15 +182,40 @@ async def list_boards(
|
||||
response_model=list[str],
|
||||
)
|
||||
async def list_all_board_image_names(
|
||||
current_user: CurrentUserOrDefault,
|
||||
board_id: str = Path(description="The id of the board or 'none' for uncategorized images"),
|
||||
categories: list[ImageCategory] | None = Query(default=None, description="The categories of image to include."),
|
||||
is_intermediate: bool | None = Query(default=None, description="Whether to list intermediate images."),
|
||||
) -> list[str]:
|
||||
"""Gets a list of images for a board"""
|
||||
|
||||
if board_id != "none":
|
||||
try:
|
||||
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail="Board not found")
|
||||
|
||||
if (
|
||||
not current_user.is_admin
|
||||
and board.user_id != current_user.user_id
|
||||
and board.board_visibility == BoardVisibility.Private
|
||||
):
|
||||
raise HTTPException(status_code=403, detail="Not authorized to access this board")
|
||||
|
||||
image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
|
||||
board_id,
|
||||
categories,
|
||||
is_intermediate,
|
||||
)
|
||||
|
||||
# For uncategorized images (board_id="none"), filter to only the caller's
|
||||
# images so that one user cannot enumerate another's uncategorized images.
|
||||
# Admin users can see all uncategorized images.
|
||||
if board_id == "none" and not current_user.is_admin:
|
||||
image_names = [
|
||||
name
|
||||
for name in image_names
|
||||
if ApiDependencies.invoker.services.image_records.get_user_id(name) == current_user.user_id
|
||||
]
|
||||
|
||||
return image_names
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from fastapi import Body, HTTPException, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
|
||||
from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.backend.util.logging import logging
|
||||
|
||||
@@ -13,15 +14,16 @@ client_state_router = APIRouter(prefix="/v1/client_state", tags=["client_state"]
|
||||
response_model=str | None,
|
||||
)
|
||||
async def get_client_state_by_key(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id (ignored, kept for backwards compatibility)"),
|
||||
key: str = Query(..., description="Key to get"),
|
||||
) -> str | None:
|
||||
"""Gets the client state"""
|
||||
"""Gets the client state for the current user (or system user if not authenticated)"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.client_state_persistence.get_by_key(queue_id, key)
|
||||
return ApiDependencies.invoker.services.client_state_persistence.get_by_key(current_user.user_id, key)
|
||||
except Exception as e:
|
||||
logging.error(f"Error getting client state: {e}")
|
||||
raise HTTPException(status_code=500, detail="Error setting client state")
|
||||
raise HTTPException(status_code=500, detail="Error getting client state")
|
||||
|
||||
|
||||
@client_state_router.post(
|
||||
@@ -30,13 +32,14 @@ async def get_client_state_by_key(
|
||||
response_model=str,
|
||||
)
|
||||
async def set_client_state(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id (ignored, kept for backwards compatibility)"),
|
||||
key: str = Query(..., description="Key to set"),
|
||||
value: str = Body(..., description="Stringified value to set"),
|
||||
) -> str:
|
||||
"""Sets the client state"""
|
||||
"""Sets the client state for the current user (or system user if not authenticated)"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.client_state_persistence.set_by_key(queue_id, key, value)
|
||||
return ApiDependencies.invoker.services.client_state_persistence.set_by_key(current_user.user_id, key, value)
|
||||
except Exception as e:
|
||||
logging.error(f"Error setting client state: {e}")
|
||||
raise HTTPException(status_code=500, detail="Error setting client state")
|
||||
@@ -48,11 +51,12 @@ async def set_client_state(
|
||||
responses={204: {"description": "Client state deleted"}},
|
||||
)
|
||||
async def delete_client_state(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id (ignored, kept for backwards compatibility)"),
|
||||
) -> None:
|
||||
"""Deletes the client state"""
|
||||
"""Deletes the client state for the current user (or system user if not authenticated)"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.client_state_persistence.delete(queue_id)
|
||||
ApiDependencies.invoker.services.client_state_persistence.delete(current_user.user_id)
|
||||
except Exception as e:
|
||||
logging.error(f"Error deleting client state: {e}")
|
||||
raise HTTPException(status_code=500, detail="Error deleting client state")
|
||||
|
||||
@@ -9,6 +9,7 @@ from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.api.extract_metadata_from_image import extract_metadata_from_image
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
@@ -37,6 +38,96 @@ images_router = APIRouter(prefix="/v1/images", tags=["images"])
|
||||
IMAGE_MAX_AGE = 31536000
|
||||
|
||||
|
||||
def _assert_image_owner(image_name: str, current_user: CurrentUserOrDefault) -> None:
|
||||
"""Raise 403 if the current user does not own the image and is not an admin.
|
||||
|
||||
Ownership is satisfied when ANY of these hold:
|
||||
- The user is an admin.
|
||||
- The user is the image's direct owner (image_records.user_id).
|
||||
- The user owns the board the image sits on.
|
||||
- The image sits on a Public board (public boards grant mutation rights).
|
||||
"""
|
||||
from invokeai.app.services.board_records.board_records_common import BoardVisibility
|
||||
|
||||
if current_user.is_admin:
|
||||
return
|
||||
owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name)
|
||||
if owner is not None and owner == current_user.user_id:
|
||||
return
|
||||
|
||||
# Check whether the user owns the board the image belongs to,
|
||||
# or the board is Public (public boards grant mutation rights).
|
||||
board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name)
|
||||
if board_id is not None:
|
||||
try:
|
||||
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
||||
if board.user_id == current_user.user_id:
|
||||
return
|
||||
if board.board_visibility == BoardVisibility.Public:
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
raise HTTPException(status_code=403, detail="Not authorized to modify this image")
|
||||
|
||||
|
||||
def _assert_image_read_access(image_name: str, current_user: CurrentUserOrDefault) -> None:
|
||||
"""Raise 403 if the current user may not view the image.
|
||||
|
||||
Access is granted when ANY of these hold:
|
||||
- The user is an admin.
|
||||
- The user owns the image.
|
||||
- The image sits on a shared or public board.
|
||||
"""
|
||||
from invokeai.app.services.board_records.board_records_common import BoardVisibility
|
||||
|
||||
if current_user.is_admin:
|
||||
return
|
||||
|
||||
owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name)
|
||||
if owner is not None and owner == current_user.user_id:
|
||||
return
|
||||
|
||||
# Check whether the image's board makes it visible to other users.
|
||||
board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name)
|
||||
if board_id is not None:
|
||||
try:
|
||||
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
||||
if board.board_visibility in (BoardVisibility.Shared, BoardVisibility.Public):
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
raise HTTPException(status_code=403, detail="Not authorized to access this image")
|
||||
|
||||
|
||||
def _assert_board_read_access(board_id: str, current_user: CurrentUserOrDefault) -> None:
|
||||
"""Raise 403 if the current user may not read images from this board.
|
||||
|
||||
Access is granted when ANY of these hold:
|
||||
- The user is an admin.
|
||||
- The user owns the board.
|
||||
- The board visibility is Shared or Public.
|
||||
"""
|
||||
from invokeai.app.services.board_records.board_records_common import BoardVisibility
|
||||
|
||||
if current_user.is_admin:
|
||||
return
|
||||
|
||||
try:
|
||||
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail="Board not found")
|
||||
|
||||
if board.user_id == current_user.user_id:
|
||||
return
|
||||
|
||||
if board.board_visibility in (BoardVisibility.Shared, BoardVisibility.Public):
|
||||
return
|
||||
|
||||
raise HTTPException(status_code=403, detail="Not authorized to access this board")
|
||||
|
||||
|
||||
class ResizeToDimensions(BaseModel):
|
||||
width: int = Field(..., gt=0)
|
||||
height: int = Field(..., gt=0)
|
||||
@@ -61,6 +152,7 @@ class ResizeToDimensions(BaseModel):
|
||||
response_model=ImageDTO,
|
||||
)
|
||||
async def upload_image(
|
||||
current_user: CurrentUserOrDefault,
|
||||
file: UploadFile,
|
||||
request: Request,
|
||||
response: Response,
|
||||
@@ -80,7 +172,23 @@ async def upload_image(
|
||||
embed=True,
|
||||
),
|
||||
) -> ImageDTO:
|
||||
"""Uploads an image"""
|
||||
"""Uploads an image for the current user"""
|
||||
# If uploading into a board, verify the user has write access.
|
||||
# Public boards allow uploads from any authenticated user.
|
||||
if board_id is not None:
|
||||
from invokeai.app.services.board_records.board_records_common import BoardVisibility
|
||||
|
||||
try:
|
||||
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail="Board not found")
|
||||
if (
|
||||
not current_user.is_admin
|
||||
and board.user_id != current_user.user_id
|
||||
and board.board_visibility != BoardVisibility.Public
|
||||
):
|
||||
raise HTTPException(status_code=403, detail="Not authorized to upload to this board")
|
||||
|
||||
if not file.content_type or not file.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
@@ -133,6 +241,7 @@ async def upload_image(
|
||||
workflow=extracted_metadata.invokeai_workflow,
|
||||
graph=extracted_metadata.invokeai_graph,
|
||||
is_intermediate=is_intermediate,
|
||||
user_id=current_user.user_id,
|
||||
)
|
||||
|
||||
response.status_code = 201
|
||||
@@ -162,9 +271,11 @@ async def create_image_upload_entry(
|
||||
|
||||
@images_router.delete("/i/{image_name}", operation_id="delete_image", response_model=DeleteImagesResult)
|
||||
async def delete_image(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_name: str = Path(description="The name of the image to delete"),
|
||||
) -> DeleteImagesResult:
|
||||
"""Deletes an image"""
|
||||
_assert_image_owner(image_name, current_user)
|
||||
|
||||
deleted_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
@@ -186,26 +297,31 @@ async def delete_image(
|
||||
|
||||
|
||||
@images_router.delete("/intermediates", operation_id="clear_intermediates")
|
||||
async def clear_intermediates() -> int:
|
||||
"""Clears all intermediates"""
|
||||
async def clear_intermediates(
|
||||
current_user: CurrentUserOrDefault,
|
||||
) -> int:
|
||||
"""Clears all intermediates. Requires admin."""
|
||||
if not current_user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="Only admins can clear all intermediates")
|
||||
|
||||
try:
|
||||
count_deleted = ApiDependencies.invoker.services.images.delete_intermediates()
|
||||
return count_deleted
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to clear intermediates")
|
||||
pass
|
||||
|
||||
|
||||
@images_router.get("/intermediates", operation_id="get_intermediates_count")
|
||||
async def get_intermediates_count() -> int:
|
||||
"""Gets the count of intermediate images"""
|
||||
async def get_intermediates_count(
|
||||
current_user: CurrentUserOrDefault,
|
||||
) -> int:
|
||||
"""Gets the count of intermediate images. Non-admin users only see their own intermediates."""
|
||||
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images.get_intermediates_count()
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
return ApiDependencies.invoker.services.images.get_intermediates_count(user_id=user_id)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to get intermediates")
|
||||
pass
|
||||
|
||||
|
||||
@images_router.patch(
|
||||
@@ -214,10 +330,12 @@ async def get_intermediates_count() -> int:
|
||||
response_model=ImageDTO,
|
||||
)
|
||||
async def update_image(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_name: str = Path(description="The name of the image to update"),
|
||||
image_changes: ImageRecordChanges = Body(description="The changes to apply to the image"),
|
||||
) -> ImageDTO:
|
||||
"""Updates an image"""
|
||||
_assert_image_owner(image_name, current_user)
|
||||
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images.update(image_name, image_changes)
|
||||
@@ -231,9 +349,11 @@ async def update_image(
|
||||
response_model=ImageDTO,
|
||||
)
|
||||
async def get_image_dto(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_name: str = Path(description="The name of image to get"),
|
||||
) -> ImageDTO:
|
||||
"""Gets an image's DTO"""
|
||||
_assert_image_read_access(image_name, current_user)
|
||||
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images.get_dto(image_name)
|
||||
@@ -247,9 +367,11 @@ async def get_image_dto(
|
||||
response_model=Optional[MetadataField],
|
||||
)
|
||||
async def get_image_metadata(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_name: str = Path(description="The name of image to get"),
|
||||
) -> Optional[MetadataField]:
|
||||
"""Gets an image's metadata"""
|
||||
_assert_image_read_access(image_name, current_user)
|
||||
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images.get_metadata(image_name)
|
||||
@@ -266,8 +388,11 @@ class WorkflowAndGraphResponse(BaseModel):
|
||||
"/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=WorkflowAndGraphResponse
|
||||
)
|
||||
async def get_image_workflow(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_name: str = Path(description="The name of image whose workflow to get"),
|
||||
) -> WorkflowAndGraphResponse:
|
||||
_assert_image_read_access(image_name, current_user)
|
||||
|
||||
try:
|
||||
workflow = ApiDependencies.invoker.services.images.get_workflow(image_name)
|
||||
graph = ApiDependencies.invoker.services.images.get_graph(image_name)
|
||||
@@ -303,8 +428,12 @@ async def get_image_workflow(
|
||||
async def get_image_full(
|
||||
image_name: str = Path(description="The name of full-resolution image file to get"),
|
||||
) -> Response:
|
||||
"""Gets a full-resolution image file"""
|
||||
"""Gets a full-resolution image file.
|
||||
|
||||
This endpoint is intentionally unauthenticated because browsers load images
|
||||
via <img src> tags which cannot send Bearer tokens. Image names are UUIDs,
|
||||
providing security through unguessability.
|
||||
"""
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.images.get_path(image_name)
|
||||
with open(path, "rb") as f:
|
||||
@@ -332,8 +461,12 @@ async def get_image_full(
|
||||
async def get_image_thumbnail(
|
||||
image_name: str = Path(description="The name of thumbnail image file to get"),
|
||||
) -> Response:
|
||||
"""Gets a thumbnail image file"""
|
||||
"""Gets a thumbnail image file.
|
||||
|
||||
This endpoint is intentionally unauthenticated because browsers load images
|
||||
via <img src> tags which cannot send Bearer tokens. Image names are UUIDs,
|
||||
providing security through unguessability.
|
||||
"""
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.images.get_path(image_name, thumbnail=True)
|
||||
with open(path, "rb") as f:
|
||||
@@ -351,9 +484,11 @@ async def get_image_thumbnail(
|
||||
response_model=ImageUrlsDTO,
|
||||
)
|
||||
async def get_image_urls(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_name: str = Path(description="The name of the image whose URL to get"),
|
||||
) -> ImageUrlsDTO:
|
||||
"""Gets an image and thumbnail URL"""
|
||||
_assert_image_read_access(image_name, current_user)
|
||||
|
||||
try:
|
||||
image_url = ApiDependencies.invoker.services.images.get_url(image_name)
|
||||
@@ -373,6 +508,7 @@ async def get_image_urls(
|
||||
response_model=OffsetPaginatedResults[ImageDTO],
|
||||
)
|
||||
async def list_image_dtos(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
|
||||
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
|
||||
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."),
|
||||
@@ -386,10 +522,24 @@ async def list_image_dtos(
|
||||
starred_first: bool = Query(default=True, description="Whether to sort by starred images first"),
|
||||
search_term: Optional[str] = Query(default=None, description="The term to search for"),
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a list of image DTOs"""
|
||||
"""Gets a list of image DTOs for the current user"""
|
||||
|
||||
# Validate that the caller can read from this board before listing its images.
|
||||
# "none" is a sentinel for uncategorized images and is handled by the SQL layer.
|
||||
if board_id is not None and board_id != "none":
|
||||
_assert_board_read_access(board_id, current_user)
|
||||
|
||||
image_dtos = ApiDependencies.invoker.services.images.get_many(
|
||||
offset, limit, starred_first, order_dir, image_origin, categories, is_intermediate, board_id, search_term
|
||||
offset,
|
||||
limit,
|
||||
starred_first,
|
||||
order_dir,
|
||||
image_origin,
|
||||
categories,
|
||||
is_intermediate,
|
||||
board_id,
|
||||
search_term,
|
||||
current_user.user_id,
|
||||
)
|
||||
|
||||
return image_dtos
|
||||
@@ -397,6 +547,7 @@ async def list_image_dtos(
|
||||
|
||||
@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesResult)
|
||||
async def delete_images_from_list(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_names: list[str] = Body(description="The list of names of images to delete", embed=True),
|
||||
) -> DeleteImagesResult:
|
||||
try:
|
||||
@@ -404,24 +555,31 @@ async def delete_images_from_list(
|
||||
affected_boards: set[str] = set()
|
||||
for image_name in image_names:
|
||||
try:
|
||||
_assert_image_owner(image_name, current_user)
|
||||
image_dto = ApiDependencies.invoker.services.images.get_dto(image_name)
|
||||
board_id = image_dto.board_id or "none"
|
||||
ApiDependencies.invoker.services.images.delete(image_name)
|
||||
deleted_images.add(image_name)
|
||||
affected_boards.add(board_id)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
return DeleteImagesResult(
|
||||
deleted_images=list(deleted_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete images")
|
||||
|
||||
|
||||
@images_router.delete("/uncategorized", operation_id="delete_uncategorized_images", response_model=DeleteImagesResult)
|
||||
async def delete_uncategorized_images() -> DeleteImagesResult:
|
||||
"""Deletes all images that are uncategorized"""
|
||||
async def delete_uncategorized_images(
|
||||
current_user: CurrentUserOrDefault,
|
||||
) -> DeleteImagesResult:
|
||||
"""Deletes all uncategorized images owned by the current user (or all if admin)"""
|
||||
|
||||
image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
|
||||
board_id="none", categories=None, is_intermediate=None
|
||||
@@ -432,9 +590,13 @@ async def delete_uncategorized_images() -> DeleteImagesResult:
|
||||
affected_boards: set[str] = set()
|
||||
for image_name in image_names:
|
||||
try:
|
||||
_assert_image_owner(image_name, current_user)
|
||||
ApiDependencies.invoker.services.images.delete(image_name)
|
||||
deleted_images.add(image_name)
|
||||
affected_boards.add("none")
|
||||
except HTTPException:
|
||||
# Skip images not owned by the current user
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
return DeleteImagesResult(
|
||||
@@ -451,6 +613,7 @@ class ImagesUpdatedFromListResult(BaseModel):
|
||||
|
||||
@images_router.post("/star", operation_id="star_images_in_list", response_model=StarredImagesResult)
|
||||
async def star_images_in_list(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_names: list[str] = Body(description="The list of names of images to star", embed=True),
|
||||
) -> StarredImagesResult:
|
||||
try:
|
||||
@@ -458,23 +621,29 @@ async def star_images_in_list(
|
||||
affected_boards: set[str] = set()
|
||||
for image_name in image_names:
|
||||
try:
|
||||
_assert_image_owner(image_name, current_user)
|
||||
updated_image_dto = ApiDependencies.invoker.services.images.update(
|
||||
image_name, changes=ImageRecordChanges(starred=True)
|
||||
)
|
||||
starred_images.add(image_name)
|
||||
affected_boards.add(updated_image_dto.board_id or "none")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
return StarredImagesResult(
|
||||
starred_images=list(starred_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to star images")
|
||||
|
||||
|
||||
@images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=UnstarredImagesResult)
|
||||
async def unstar_images_in_list(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_names: list[str] = Body(description="The list of names of images to unstar", embed=True),
|
||||
) -> UnstarredImagesResult:
|
||||
try:
|
||||
@@ -482,17 +651,22 @@ async def unstar_images_in_list(
|
||||
affected_boards: set[str] = set()
|
||||
for image_name in image_names:
|
||||
try:
|
||||
_assert_image_owner(image_name, current_user)
|
||||
updated_image_dto = ApiDependencies.invoker.services.images.update(
|
||||
image_name, changes=ImageRecordChanges(starred=False)
|
||||
)
|
||||
unstarred_images.add(image_name)
|
||||
affected_boards.add(updated_image_dto.board_id or "none")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
return UnstarredImagesResult(
|
||||
unstarred_images=list(unstarred_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to unstar images")
|
||||
|
||||
@@ -510,6 +684,7 @@ class ImagesDownloaded(BaseModel):
|
||||
"/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202
|
||||
)
|
||||
async def download_images_from_list(
|
||||
current_user: CurrentUserOrDefault,
|
||||
background_tasks: BackgroundTasks,
|
||||
image_names: Optional[list[str]] = Body(
|
||||
default=None, description="The list of names of images to download", embed=True
|
||||
@@ -520,6 +695,16 @@ async def download_images_from_list(
|
||||
) -> ImagesDownloaded:
|
||||
if (image_names is None or len(image_names) == 0) and board_id is None:
|
||||
raise HTTPException(status_code=400, detail="No images or board id specified.")
|
||||
|
||||
# Validate that the caller can read every image they are requesting.
|
||||
# For a board_id request, check board visibility; for explicit image names,
|
||||
# check each image individually.
|
||||
if board_id:
|
||||
_assert_board_read_access(board_id, current_user)
|
||||
if image_names:
|
||||
for name in image_names:
|
||||
_assert_image_read_access(name, current_user)
|
||||
|
||||
bulk_download_item_id: str = ApiDependencies.invoker.services.bulk_download.generate_item_id(board_id)
|
||||
|
||||
background_tasks.add_task(
|
||||
@@ -527,6 +712,7 @@ async def download_images_from_list(
|
||||
image_names,
|
||||
board_id,
|
||||
bulk_download_item_id,
|
||||
current_user.user_id,
|
||||
)
|
||||
return ImagesDownloaded(bulk_download_item_name=bulk_download_item_id + ".zip")
|
||||
|
||||
@@ -545,11 +731,21 @@ async def download_images_from_list(
|
||||
},
|
||||
)
|
||||
async def get_bulk_download_item(
|
||||
current_user: CurrentUserOrDefault,
|
||||
background_tasks: BackgroundTasks,
|
||||
bulk_download_item_name: str = Path(description="The bulk_download_item_name of the bulk download item to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets a bulk download zip file"""
|
||||
"""Gets a bulk download zip file.
|
||||
|
||||
Requires authentication. The caller must be the user who initiated the
|
||||
download (tracked by the bulk download service) or an admin.
|
||||
"""
|
||||
try:
|
||||
# Verify the caller owns this download (or is an admin)
|
||||
owner = ApiDependencies.invoker.services.bulk_download.get_owner(bulk_download_item_name)
|
||||
if owner is not None and owner != current_user.user_id and not current_user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to access this download")
|
||||
|
||||
path = ApiDependencies.invoker.services.bulk_download.get_path(bulk_download_item_name)
|
||||
|
||||
response = FileResponse(
|
||||
@@ -561,12 +757,15 @@ async def get_bulk_download_item(
|
||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||
background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.delete, bulk_download_item_name)
|
||||
return response
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get("/names", operation_id="get_image_names")
|
||||
async def get_image_names(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
|
||||
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
|
||||
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."),
|
||||
@@ -580,6 +779,10 @@ async def get_image_names(
|
||||
) -> ImageNamesResult:
|
||||
"""Gets ordered list of image names with metadata for optimistic updates"""
|
||||
|
||||
# Validate that the caller can read from this board before listing its images.
|
||||
if board_id is not None and board_id != "none":
|
||||
_assert_board_read_access(board_id, current_user)
|
||||
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.images.get_image_names(
|
||||
starred_first=starred_first,
|
||||
@@ -589,6 +792,8 @@ async def get_image_names(
|
||||
is_intermediate=is_intermediate,
|
||||
board_id=board_id,
|
||||
search_term=search_term,
|
||||
user_id=current_user.user_id,
|
||||
is_admin=current_user.is_admin,
|
||||
)
|
||||
return result
|
||||
except Exception:
|
||||
@@ -601,6 +806,7 @@ async def get_image_names(
|
||||
responses={200: {"model": list[ImageDTO]}},
|
||||
)
|
||||
async def get_images_by_names(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_names: list[str] = Body(embed=True, description="Object containing list of image names to fetch DTOs for"),
|
||||
) -> list[ImageDTO]:
|
||||
"""Gets image DTOs for the specified image names. Maintains order of input names."""
|
||||
@@ -612,8 +818,12 @@ async def get_images_by_names(
|
||||
image_dtos: list[ImageDTO] = []
|
||||
for name in image_names:
|
||||
try:
|
||||
_assert_image_read_access(name, current_user)
|
||||
dto = image_service.get_dto(name)
|
||||
image_dtos.append(dto)
|
||||
except HTTPException:
|
||||
# Skip images the user is not authorized to view
|
||||
continue
|
||||
except Exception:
|
||||
# Skip missing images - they may have been deleted between name fetch and DTO fetch
|
||||
continue
|
||||
|
||||
@@ -19,6 +19,7 @@ from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
|
||||
from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.api.auth_dependencies import AdminUserOrDefault
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
||||
@@ -27,7 +28,9 @@ from invokeai.app.services.model_records import (
|
||||
ModelRecordChanges,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.app.services.orphaned_models import OrphanedModelInfo
|
||||
from invokeai.app.util.suppress_output import SuppressOutput
|
||||
from invokeai.backend.model_manager.configs.external_api import ExternalApiModelConfig
|
||||
from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory
|
||||
from invokeai.backend.model_manager.configs.main import (
|
||||
Main_Checkpoint_SD1_Config,
|
||||
@@ -73,8 +76,36 @@ class CacheType(str, Enum):
|
||||
def add_cover_image_to_model_config(config: AnyModelConfig, dependencies: Type[ApiDependencies]) -> AnyModelConfig:
|
||||
"""Add a cover image URL to a model configuration."""
|
||||
cover_image = dependencies.invoker.services.model_images.get_url(config.key)
|
||||
config.cover_image = cover_image
|
||||
return config
|
||||
return config.model_copy(update={"cover_image": cover_image})
|
||||
|
||||
|
||||
def apply_external_starter_model_overrides(config: AnyModelConfig) -> AnyModelConfig:
|
||||
"""Overlay starter-model metadata onto installed external model configs."""
|
||||
if not isinstance(config, ExternalApiModelConfig):
|
||||
return config
|
||||
|
||||
starter_match = next((starter for starter in STARTER_MODELS if starter.source == config.source), None)
|
||||
if starter_match is None:
|
||||
return config
|
||||
|
||||
model_updates: dict[str, object] = {}
|
||||
if starter_match.capabilities is not None:
|
||||
model_updates["capabilities"] = starter_match.capabilities
|
||||
if starter_match.default_settings is not None:
|
||||
model_updates["default_settings"] = starter_match.default_settings
|
||||
if starter_match.panel_schema is not None:
|
||||
model_updates["panel_schema"] = starter_match.panel_schema
|
||||
|
||||
if not model_updates:
|
||||
return config
|
||||
|
||||
return config.model_copy(update=model_updates)
|
||||
|
||||
|
||||
def prepare_model_config_for_response(config: AnyModelConfig, dependencies: Type[ApiDependencies]) -> AnyModelConfig:
|
||||
"""Apply API-only model config overlays before returning a response."""
|
||||
config = apply_external_starter_model_overrides(config)
|
||||
return add_cover_image_to_model_config(config, dependencies)
|
||||
|
||||
|
||||
##############################################################################
|
||||
@@ -143,11 +174,35 @@ async def list_model_records(
|
||||
found_models.extend(
|
||||
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
|
||||
)
|
||||
for model in found_models:
|
||||
model = add_cover_image_to_model_config(model, ApiDependencies)
|
||||
for index, model in enumerate(found_models):
|
||||
found_models[index] = prepare_model_config_for_response(model, ApiDependencies)
|
||||
return ModelsList(models=found_models)
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/missing",
|
||||
operation_id="list_missing_models",
|
||||
responses={200: {"description": "List of models with missing files"}},
|
||||
)
|
||||
async def list_missing_models() -> ModelsList:
|
||||
"""Get models whose files are missing from disk.
|
||||
|
||||
These are models that have database entries but their corresponding
|
||||
weight files have been deleted externally (not via Model Manager).
|
||||
"""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
models_path = ApiDependencies.invoker.services.configuration.models_path
|
||||
|
||||
missing_models: list[AnyModelConfig] = []
|
||||
for model_config in record_store.all_models():
|
||||
if model_config.base == BaseModelType.External or model_config.format == ModelFormat.ExternalApi:
|
||||
continue
|
||||
if not (models_path / model_config.path).resolve().exists():
|
||||
missing_models.append(model_config)
|
||||
|
||||
return ModelsList(models=missing_models)
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/get_by_attrs",
|
||||
operation_id="get_model_records_by_attrs",
|
||||
@@ -166,7 +221,24 @@ async def get_model_records_by_attrs(
|
||||
if not configs:
|
||||
raise HTTPException(status_code=404, detail="No model found with these attributes")
|
||||
|
||||
return configs[0]
|
||||
return prepare_model_config_for_response(configs[0], ApiDependencies)
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/get_by_hash",
|
||||
operation_id="get_model_records_by_hash",
|
||||
response_model=AnyModelConfig,
|
||||
)
|
||||
async def get_model_records_by_hash(
|
||||
hash: str = Query(description="The hash of the model"),
|
||||
) -> AnyModelConfig:
|
||||
"""Gets a model by its hash. This is useful for recalling models that were deleted and reinstalled,
|
||||
as the hash remains stable across reinstallations while the key (UUID) changes."""
|
||||
configs = ApiDependencies.invoker.services.model_manager.store.search_by_hash(hash)
|
||||
if not configs:
|
||||
raise HTTPException(status_code=404, detail="No model found with this hash")
|
||||
|
||||
return prepare_model_config_for_response(configs[0], ApiDependencies)
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
@@ -187,7 +259,7 @@ async def get_model_record(
|
||||
"""Get a model record"""
|
||||
try:
|
||||
config = ApiDependencies.invoker.services.model_manager.store.get_model(key)
|
||||
return add_cover_image_to_model_config(config, ApiDependencies)
|
||||
return prepare_model_config_for_response(config, ApiDependencies)
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
@@ -206,6 +278,7 @@ async def get_model_record(
|
||||
)
|
||||
async def reidentify_model(
|
||||
key: Annotated[str, Path(description="Key of the model to reidentify.")],
|
||||
current_admin: AdminUserOrDefault,
|
||||
) -> AnyModelConfig:
|
||||
"""Attempt to reidentify a model by re-probing its weights file."""
|
||||
try:
|
||||
@@ -219,7 +292,18 @@ async def reidentify_model(
|
||||
result = ModelConfigFactory.from_model_on_disk(mod)
|
||||
if result.config is None:
|
||||
raise InvalidModelException("Unable to identify model format")
|
||||
result.config.key = config.key # retain the same key
|
||||
|
||||
# Retain user-editable fields from the original config
|
||||
result.config.path = config.path
|
||||
result.config.key = config.key
|
||||
result.config.name = config.name
|
||||
result.config.description = config.description
|
||||
result.config.cover_image = config.cover_image
|
||||
if hasattr(result.config, "trigger_phrases") and hasattr(config, "trigger_phrases"):
|
||||
result.config.trigger_phrases = config.trigger_phrases
|
||||
result.config.source = config.source
|
||||
result.config.source_type = config.source_type
|
||||
|
||||
new_config = ApiDependencies.invoker.services.model_manager.store.replace_model(config.key, result.config)
|
||||
return new_config
|
||||
except UnknownModelException as e:
|
||||
@@ -332,13 +416,14 @@ async def get_hugging_face_models(
|
||||
async def update_model_record(
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
changes: Annotated[ModelRecordChanges, Body(description="Model config", examples=[example_model_input])],
|
||||
current_admin: AdminUserOrDefault,
|
||||
) -> AnyModelConfig:
|
||||
"""Update a model's config."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
try:
|
||||
config = record_store.update_model(key, changes=changes, allow_class_change=True)
|
||||
config = add_cover_image_to_model_config(config, ApiDependencies)
|
||||
config = prepare_model_config_for_response(config, ApiDependencies)
|
||||
logger.info(f"Updated model: {key}")
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@@ -394,6 +479,7 @@ async def get_model_image(
|
||||
async def update_model_image(
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
image: UploadFile,
|
||||
current_admin: AdminUserOrDefault,
|
||||
) -> None:
|
||||
if not image.content_type or not image.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
@@ -427,6 +513,7 @@ async def update_model_image(
|
||||
status_code=204,
|
||||
)
|
||||
async def delete_model(
|
||||
current_admin: AdminUserOrDefault,
|
||||
key: str = Path(description="Unique key of model to remove from model registry."),
|
||||
) -> Response:
|
||||
"""
|
||||
@@ -460,6 +547,19 @@ class BulkDeleteModelsResponse(BaseModel):
|
||||
failed: List[dict] = Field(description="List of failed deletions with error messages")
|
||||
|
||||
|
||||
class BulkReidentifyModelsRequest(BaseModel):
|
||||
"""Request body for bulk model reidentification."""
|
||||
|
||||
keys: List[str] = Field(description="List of model keys to reidentify")
|
||||
|
||||
|
||||
class BulkReidentifyModelsResponse(BaseModel):
|
||||
"""Response body for bulk model reidentification."""
|
||||
|
||||
succeeded: List[str] = Field(description="List of successfully reidentified model keys")
|
||||
failed: List[dict] = Field(description="List of failed reidentifications with error messages")
|
||||
|
||||
|
||||
@model_manager_router.post(
|
||||
"/i/bulk_delete",
|
||||
operation_id="bulk_delete_models",
|
||||
@@ -469,6 +569,7 @@ class BulkDeleteModelsResponse(BaseModel):
|
||||
status_code=200,
|
||||
)
|
||||
async def bulk_delete_models(
|
||||
current_admin: AdminUserOrDefault,
|
||||
request: BulkDeleteModelsRequest = Body(description="List of model keys to delete"),
|
||||
) -> BulkDeleteModelsResponse:
|
||||
"""
|
||||
@@ -500,6 +601,67 @@ async def bulk_delete_models(
|
||||
return BulkDeleteModelsResponse(deleted=deleted, failed=failed)
|
||||
|
||||
|
||||
@model_manager_router.post(
|
||||
"/i/bulk_reidentify",
|
||||
operation_id="bulk_reidentify_models",
|
||||
responses={
|
||||
200: {"description": "Models reidentified (possibly with some failures)"},
|
||||
},
|
||||
status_code=200,
|
||||
)
|
||||
async def bulk_reidentify_models(
|
||||
current_admin: AdminUserOrDefault,
|
||||
request: BulkReidentifyModelsRequest = Body(description="List of model keys to reidentify"),
|
||||
) -> BulkReidentifyModelsResponse:
|
||||
"""
|
||||
Reidentify multiple models by re-probing their weights files.
|
||||
|
||||
Returns a list of successfully reidentified keys and failed reidentifications with error messages.
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
store = ApiDependencies.invoker.services.model_manager.store
|
||||
models_path = ApiDependencies.invoker.services.configuration.models_path
|
||||
|
||||
succeeded = []
|
||||
failed = []
|
||||
|
||||
for key in request.keys:
|
||||
try:
|
||||
config = store.get_model(key)
|
||||
if pathlib.Path(config.path).is_relative_to(models_path):
|
||||
model_path = pathlib.Path(config.path)
|
||||
else:
|
||||
model_path = models_path / config.path
|
||||
mod = ModelOnDisk(model_path)
|
||||
result = ModelConfigFactory.from_model_on_disk(mod)
|
||||
if result.config is None:
|
||||
raise InvalidModelException("Unable to identify model format")
|
||||
|
||||
# Retain user-editable fields from the original config
|
||||
result.config.path = config.path
|
||||
result.config.key = config.key
|
||||
result.config.name = config.name
|
||||
result.config.description = config.description
|
||||
result.config.cover_image = config.cover_image
|
||||
if hasattr(config, "trigger_phrases") and hasattr(result.config, "trigger_phrases"):
|
||||
result.config.trigger_phrases = config.trigger_phrases
|
||||
result.config.source = config.source
|
||||
result.config.source_type = config.source_type
|
||||
|
||||
store.replace_model(config.key, result.config)
|
||||
succeeded.append(key)
|
||||
logger.info(f"Reidentified model: {key}")
|
||||
except UnknownModelException as e:
|
||||
logger.error(f"Failed to reidentify model {key}: {str(e)}")
|
||||
failed.append({"key": key, "error": str(e)})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reidentify model {key}: {str(e)}")
|
||||
failed.append({"key": key, "error": str(e)})
|
||||
|
||||
logger.info(f"Bulk reidentify completed: {len(succeeded)} succeeded, {len(failed)} failed")
|
||||
return BulkReidentifyModelsResponse(succeeded=succeeded, failed=failed)
|
||||
|
||||
|
||||
@model_manager_router.delete(
|
||||
"/i/{key}/image",
|
||||
operation_id="delete_model_image",
|
||||
@@ -510,6 +672,7 @@ async def bulk_delete_models(
|
||||
status_code=204,
|
||||
)
|
||||
async def delete_model_image(
|
||||
current_admin: AdminUserOrDefault,
|
||||
key: str = Path(description="Unique key of model image to remove from model_images directory."),
|
||||
) -> None:
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
@@ -535,6 +698,7 @@ async def delete_model_image(
|
||||
status_code=201,
|
||||
)
|
||||
async def install_model(
|
||||
current_admin: AdminUserOrDefault,
|
||||
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
|
||||
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
|
||||
access_token: Optional[str] = Query(description="access token for the remote resource", default=None),
|
||||
@@ -605,6 +769,7 @@ async def install_model(
|
||||
response_class=HTMLResponse,
|
||||
)
|
||||
async def install_hugging_face_model(
|
||||
current_admin: AdminUserOrDefault,
|
||||
source: str = Query(description="HuggingFace repo_id to install"),
|
||||
) -> HTMLResponse:
|
||||
"""Install a Hugging Face model using a string identifier."""
|
||||
@@ -724,7 +889,7 @@ async def install_hugging_face_model(
|
||||
"/install",
|
||||
operation_id="list_model_installs",
|
||||
)
|
||||
async def list_model_installs() -> List[ModelInstallJob]:
|
||||
async def list_model_installs(current_admin: AdminUserOrDefault) -> List[ModelInstallJob]:
|
||||
"""Return the list of model install jobs.
|
||||
|
||||
Install jobs have a numeric `id`, a `status`, and other fields that provide information on
|
||||
@@ -733,6 +898,7 @@ async def list_model_installs() -> List[ModelInstallJob]:
|
||||
* "waiting" -- Job is waiting in the queue to run
|
||||
* "downloading" -- Model file(s) are downloading
|
||||
* "running" -- Model has downloaded and the model probing and registration process is running
|
||||
* "paused" -- Job is paused and can be resumed
|
||||
* "completed" -- Installation completed successfully
|
||||
* "error" -- An error occurred. Details will be in the "error_type" and "error" fields.
|
||||
* "cancelled" -- Job was cancelled before completion.
|
||||
@@ -755,7 +921,9 @@ async def list_model_installs() -> List[ModelInstallJob]:
|
||||
404: {"description": "No such job"},
|
||||
},
|
||||
)
|
||||
async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
|
||||
async def get_model_install_job(
|
||||
current_admin: AdminUserOrDefault, id: int = Path(description="Model install id")
|
||||
) -> ModelInstallJob:
|
||||
"""
|
||||
Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs'
|
||||
for information on the format of the return value.
|
||||
@@ -776,7 +944,10 @@ async def get_model_install_job(id: int = Path(description="Model install id"))
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None:
|
||||
async def cancel_model_install_job(
|
||||
current_admin: AdminUserOrDefault,
|
||||
id: int = Path(description="Model install job ID"),
|
||||
) -> None:
|
||||
"""Cancel the model install job(s) corresponding to the given job ID."""
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
try:
|
||||
@@ -786,6 +957,96 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
|
||||
installer.cancel_job(job)
|
||||
|
||||
|
||||
@model_manager_router.post(
|
||||
"/install/{id}/pause",
|
||||
operation_id="pause_model_install_job",
|
||||
responses={
|
||||
201: {"description": "The job was paused successfully"},
|
||||
415: {"description": "No such job"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def pause_model_install_job(
|
||||
current_admin: AdminUserOrDefault, id: int = Path(description="Model install job ID")
|
||||
) -> ModelInstallJob:
|
||||
"""Pause the model install job corresponding to the given job ID."""
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
try:
|
||||
job = installer.get_job_by_id(id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=415, detail=str(e))
|
||||
installer.pause_job(job)
|
||||
return job
|
||||
|
||||
|
||||
@model_manager_router.post(
|
||||
"/install/{id}/resume",
|
||||
operation_id="resume_model_install_job",
|
||||
responses={
|
||||
201: {"description": "The job was resumed successfully"},
|
||||
415: {"description": "No such job"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def resume_model_install_job(
|
||||
current_admin: AdminUserOrDefault, id: int = Path(description="Model install job ID")
|
||||
) -> ModelInstallJob:
|
||||
"""Resume a paused model install job corresponding to the given job ID."""
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
try:
|
||||
job = installer.get_job_by_id(id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=415, detail=str(e))
|
||||
installer.resume_job(job)
|
||||
return job
|
||||
|
||||
|
||||
@model_manager_router.post(
|
||||
"/install/{id}/restart_failed",
|
||||
operation_id="restart_failed_model_install_job",
|
||||
responses={
|
||||
201: {"description": "Failed files restarted successfully"},
|
||||
415: {"description": "No such job"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def restart_failed_model_install_job(
|
||||
current_admin: AdminUserOrDefault, id: int = Path(description="Model install job ID")
|
||||
) -> ModelInstallJob:
|
||||
"""Restart failed or non-resumable file downloads for the given job."""
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
try:
|
||||
job = installer.get_job_by_id(id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=415, detail=str(e))
|
||||
installer.restart_failed(job)
|
||||
return job
|
||||
|
||||
|
||||
@model_manager_router.post(
|
||||
"/install/{id}/restart_file",
|
||||
operation_id="restart_model_install_file",
|
||||
responses={
|
||||
201: {"description": "File restarted successfully"},
|
||||
415: {"description": "No such job"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def restart_model_install_file(
|
||||
current_admin: AdminUserOrDefault,
|
||||
id: int = Path(description="Model install job ID"),
|
||||
file_source: AnyHttpUrl = Body(description="File download URL to restart"),
|
||||
) -> ModelInstallJob:
|
||||
"""Restart a specific file download for the given job."""
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
try:
|
||||
job = installer.get_job_by_id(id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=415, detail=str(e))
|
||||
installer.restart_file(job, str(file_source))
|
||||
return job
|
||||
|
||||
|
||||
@model_manager_router.delete(
|
||||
"/install",
|
||||
operation_id="prune_model_install_jobs",
|
||||
@@ -794,7 +1055,7 @@ async def cancel_model_install_job(id: int = Path(description="Model install job
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def prune_model_install_jobs() -> Response:
|
||||
async def prune_model_install_jobs(current_admin: AdminUserOrDefault) -> Response:
|
||||
"""Prune all completed and errored jobs from the install job list."""
|
||||
ApiDependencies.invoker.services.model_manager.install.prune_jobs()
|
||||
return Response(status_code=204)
|
||||
@@ -814,6 +1075,7 @@ async def prune_model_install_jobs() -> Response:
|
||||
},
|
||||
)
|
||||
async def convert_model(
|
||||
current_admin: AdminUserOrDefault,
|
||||
key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."),
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
@@ -893,7 +1155,7 @@ async def convert_model(
|
||||
|
||||
# return the config record for the new diffusers directory
|
||||
new_config = store.get_model(new_key)
|
||||
new_config = add_cover_image_to_model_config(new_config, ApiDependencies)
|
||||
new_config = prepare_model_config_for_response(new_config, ApiDependencies)
|
||||
return new_config
|
||||
|
||||
|
||||
@@ -905,15 +1167,48 @@ class StarterModelResponse(BaseModel):
|
||||
def get_is_installed(
|
||||
starter_model: StarterModel | StarterModelWithoutDependencies, installed_models: list[AnyModelConfig]
|
||||
) -> bool:
|
||||
from invokeai.backend.model_manager.taxonomy import ModelType
|
||||
|
||||
for model in installed_models:
|
||||
# Check if source matches exactly
|
||||
if model.source == starter_model.source:
|
||||
return True
|
||||
# Check if name (or previous names), base and type match
|
||||
if (
|
||||
(model.name == starter_model.name or model.name in starter_model.previous_names)
|
||||
and model.base == starter_model.base
|
||||
and model.type == starter_model.type
|
||||
):
|
||||
return True
|
||||
|
||||
# Special handling for Qwen3Encoder models - check by type and variant
|
||||
# This allows renamed models to still be detected as installed
|
||||
if starter_model.type == ModelType.Qwen3Encoder:
|
||||
from invokeai.backend.model_manager.taxonomy import Qwen3VariantType
|
||||
|
||||
# Determine expected variant from source pattern
|
||||
expected_variant: Qwen3VariantType | None = None
|
||||
if "klein-9B" in starter_model.source or "qwen3_8b" in starter_model.source.lower():
|
||||
expected_variant = Qwen3VariantType.Qwen3_8B
|
||||
elif (
|
||||
"klein-4B" in starter_model.source
|
||||
or "qwen3_4b" in starter_model.source.lower()
|
||||
or "Z-Image" in starter_model.source
|
||||
):
|
||||
expected_variant = Qwen3VariantType.Qwen3_4B
|
||||
|
||||
if expected_variant is not None:
|
||||
for model in installed_models:
|
||||
if model.type == ModelType.Qwen3Encoder and hasattr(model, "variant"):
|
||||
model_variant = model.variant
|
||||
# Handle both enum and string values
|
||||
if isinstance(model_variant, Qwen3VariantType):
|
||||
if model_variant == expected_variant:
|
||||
return True
|
||||
elif isinstance(model_variant, str):
|
||||
if model_variant == expected_variant.value:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@@ -962,7 +1257,7 @@ async def get_stats() -> Optional[CacheStats]:
|
||||
operation_id="empty_model_cache",
|
||||
status_code=200,
|
||||
)
|
||||
async def empty_model_cache() -> None:
|
||||
async def empty_model_cache(current_admin: AdminUserOrDefault) -> None:
|
||||
"""Drop all models from the model cache to free RAM/VRAM. 'Locked' models that are in active use will not be dropped."""
|
||||
# Request 1000GB of room in order to force the cache to drop all models.
|
||||
ApiDependencies.invoker.services.logger.info("Emptying model cache.")
|
||||
@@ -979,11 +1274,11 @@ class HFTokenHelper:
|
||||
@classmethod
|
||||
def get_status(cls) -> HFTokenStatus:
|
||||
try:
|
||||
if huggingface_hub.get_token_permission(huggingface_hub.get_token()):
|
||||
# Valid token!
|
||||
return HFTokenStatus.VALID
|
||||
# No token set
|
||||
return HFTokenStatus.INVALID
|
||||
token = huggingface_hub.get_token()
|
||||
if not token:
|
||||
return HFTokenStatus.INVALID
|
||||
huggingface_hub.whoami(token=token)
|
||||
return HFTokenStatus.VALID
|
||||
except Exception:
|
||||
return HFTokenStatus.UNKNOWN
|
||||
|
||||
@@ -1012,6 +1307,7 @@ async def get_hf_login_status() -> HFTokenStatus:
|
||||
|
||||
@model_manager_router.post("/hf_login", operation_id="do_hf_login", response_model=HFTokenStatus)
|
||||
async def do_hf_login(
|
||||
current_admin: AdminUserOrDefault,
|
||||
token: str = Body(description="Hugging Face token to use for login", embed=True),
|
||||
) -> HFTokenStatus:
|
||||
HFTokenHelper.set_token(token)
|
||||
@@ -1024,5 +1320,83 @@ async def do_hf_login(
|
||||
|
||||
|
||||
@model_manager_router.delete("/hf_login", operation_id="reset_hf_token", response_model=HFTokenStatus)
|
||||
async def reset_hf_token() -> HFTokenStatus:
|
||||
async def reset_hf_token(current_admin: AdminUserOrDefault) -> HFTokenStatus:
|
||||
return HFTokenHelper.reset_token()
|
||||
|
||||
|
||||
# Orphaned Models Management Routes
|
||||
|
||||
|
||||
class DeleteOrphanedModelsRequest(BaseModel):
|
||||
"""Request to delete specific orphaned model directories."""
|
||||
|
||||
paths: list[str] = Field(description="List of relative paths to delete")
|
||||
|
||||
|
||||
class DeleteOrphanedModelsResponse(BaseModel):
|
||||
"""Response from deleting orphaned models."""
|
||||
|
||||
deleted: list[str] = Field(description="Paths that were successfully deleted")
|
||||
errors: dict[str, str] = Field(description="Paths that had errors, with error messages")
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/sync/orphaned",
|
||||
operation_id="get_orphaned_models",
|
||||
response_model=list[OrphanedModelInfo],
|
||||
)
|
||||
async def get_orphaned_models(_: AdminUserOrDefault) -> list[OrphanedModelInfo]:
|
||||
"""Find orphaned model directories.
|
||||
|
||||
Orphaned models are directories in the models folder that contain model files
|
||||
but are not referenced in the database. This can happen when models are deleted
|
||||
from the database but the files remain on disk.
|
||||
|
||||
Returns:
|
||||
List of orphaned model directory information
|
||||
"""
|
||||
from invokeai.app.services.orphaned_models import OrphanedModelsService
|
||||
|
||||
# Access the database through the model records service
|
||||
model_records_service = ApiDependencies.invoker.services.model_manager.store
|
||||
|
||||
service = OrphanedModelsService(
|
||||
config=ApiDependencies.invoker.services.configuration,
|
||||
db=model_records_service._db, # Access the database from model records service
|
||||
)
|
||||
return service.find_orphaned_models()
|
||||
|
||||
|
||||
@model_manager_router.delete(
|
||||
"/sync/orphaned",
|
||||
operation_id="delete_orphaned_models",
|
||||
response_model=DeleteOrphanedModelsResponse,
|
||||
)
|
||||
async def delete_orphaned_models(
|
||||
request: DeleteOrphanedModelsRequest, _: AdminUserOrDefault
|
||||
) -> DeleteOrphanedModelsResponse:
|
||||
"""Delete specified orphaned model directories.
|
||||
|
||||
Args:
|
||||
request: Request containing list of relative paths to delete
|
||||
|
||||
Returns:
|
||||
Response indicating which paths were deleted and which had errors
|
||||
"""
|
||||
from invokeai.app.services.orphaned_models import OrphanedModelsService
|
||||
|
||||
# Access the database through the model records service
|
||||
model_records_service = ApiDependencies.invoker.services.model_manager.store
|
||||
|
||||
service = OrphanedModelsService(
|
||||
config=ApiDependencies.invoker.services.configuration,
|
||||
db=model_records_service._db, # Access the database from model records service
|
||||
)
|
||||
|
||||
results = service.delete_orphaned_models(request.paths)
|
||||
|
||||
# Separate successful deletions from errors
|
||||
deleted = [path for path, status in results.items() if status == "deleted"]
|
||||
errors = {path: status for path, status in results.items() if status != "deleted"}
|
||||
|
||||
return DeleteOrphanedModelsResponse(deleted=deleted, errors=errors)
|
||||
|
||||
512
invokeai/app/api/routers/recall_parameters.py
Normal file
512
invokeai/app/api/routers/recall_parameters.py
Normal file
@@ -0,0 +1,512 @@
|
||||
"""Router for updating recallable parameters on the frontend."""
|
||||
|
||||
import json
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from fastapi import Body, HTTPException, Path
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.backend.image_util.controlnet_processor import process_controlnet_image
|
||||
from invokeai.backend.model_manager.taxonomy import ModelType
|
||||
|
||||
recall_parameters_router = APIRouter(prefix="/v1/recall", tags=["recall"])
|
||||
|
||||
|
||||
class LoRARecallParameter(BaseModel):
|
||||
"""LoRA configuration for recall"""
|
||||
|
||||
model_name: str = Field(description="The name of the LoRA model")
|
||||
weight: float = Field(default=0.75, ge=-10, le=10, description="The weight for the LoRA")
|
||||
is_enabled: bool = Field(default=True, description="Whether the LoRA is enabled")
|
||||
|
||||
|
||||
class ControlNetRecallParameter(BaseModel):
|
||||
"""ControlNet configuration for recall"""
|
||||
|
||||
model_name: str = Field(description="The name of the ControlNet/T2I Adapter/Control LoRA model")
|
||||
image_name: Optional[str] = Field(default=None, description="The filename of the control image in outputs/images")
|
||||
weight: float = Field(default=1.0, ge=-1, le=2, description="The weight for the control adapter")
|
||||
begin_step_percent: Optional[float] = Field(
|
||||
default=None, ge=0, le=1, description="When the control adapter is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: Optional[float] = Field(
|
||||
default=None, ge=0, le=1, description="When the control adapter is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: Optional[Literal["balanced", "more_prompt", "more_control"]] = Field(
|
||||
default=None, description="The control mode (ControlNet only)"
|
||||
)
|
||||
|
||||
|
||||
class IPAdapterRecallParameter(BaseModel):
|
||||
"""IP Adapter configuration for recall"""
|
||||
|
||||
model_name: str = Field(description="The name of the IP Adapter model")
|
||||
image_name: Optional[str] = Field(default=None, description="The filename of the reference image in outputs/images")
|
||||
weight: float = Field(default=1.0, ge=-1, le=2, description="The weight for the IP Adapter")
|
||||
begin_step_percent: Optional[float] = Field(
|
||||
default=None, ge=0, le=1, description="When the IP Adapter is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: Optional[float] = Field(
|
||||
default=None, ge=0, le=1, description="When the IP Adapter is last applied (% of total steps)"
|
||||
)
|
||||
method: Optional[Literal["full", "style", "composition"]] = Field(default=None, description="The IP Adapter method")
|
||||
image_influence: Optional[Literal["lowest", "low", "medium", "high", "highest"]] = Field(
|
||||
default=None, description="FLUX Redux image influence (if model is flux_redux)"
|
||||
)
|
||||
|
||||
|
||||
class RecallParameter(BaseModel):
|
||||
"""Request model for updating recallable parameters."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
# Prompts
|
||||
positive_prompt: Optional[str] = Field(None, description="Positive prompt text")
|
||||
negative_prompt: Optional[str] = Field(None, description="Negative prompt text")
|
||||
|
||||
# Model configuration
|
||||
model: Optional[str] = Field(None, description="Main model name/identifier")
|
||||
refiner_model: Optional[str] = Field(None, description="Refiner model name/identifier")
|
||||
vae_model: Optional[str] = Field(None, description="VAE model name/identifier")
|
||||
scheduler: Optional[str] = Field(None, description="Scheduler name")
|
||||
|
||||
# Generation parameters
|
||||
steps: Optional[int] = Field(None, ge=1, description="Number of generation steps")
|
||||
refiner_steps: Optional[int] = Field(None, ge=0, description="Number of refiner steps")
|
||||
cfg_scale: Optional[float] = Field(None, description="CFG scale for guidance")
|
||||
cfg_rescale_multiplier: Optional[float] = Field(None, description="CFG rescale multiplier")
|
||||
refiner_cfg_scale: Optional[float] = Field(None, description="Refiner CFG scale")
|
||||
guidance: Optional[float] = Field(None, description="Guidance scale")
|
||||
|
||||
# Image parameters
|
||||
width: Optional[int] = Field(None, ge=64, description="Image width in pixels")
|
||||
height: Optional[int] = Field(None, ge=64, description="Image height in pixels")
|
||||
seed: Optional[int] = Field(None, ge=0, description="Random seed")
|
||||
|
||||
# Advanced parameters
|
||||
denoise_strength: Optional[float] = Field(None, ge=0, le=1, description="Denoising strength")
|
||||
refiner_denoise_start: Optional[float] = Field(None, ge=0, le=1, description="Refiner denoising start")
|
||||
clip_skip: Optional[int] = Field(None, ge=0, description="CLIP skip layers")
|
||||
seamless_x: Optional[bool] = Field(None, description="Enable seamless X tiling")
|
||||
seamless_y: Optional[bool] = Field(None, description="Enable seamless Y tiling")
|
||||
|
||||
# Refiner aesthetics
|
||||
refiner_positive_aesthetic_score: Optional[float] = Field(None, description="Refiner positive aesthetic score")
|
||||
refiner_negative_aesthetic_score: Optional[float] = Field(None, description="Refiner negative aesthetic score")
|
||||
|
||||
# LoRAs, ControlNets, and IP Adapters
|
||||
loras: Optional[list[LoRARecallParameter]] = Field(None, description="List of LoRAs with their weights")
|
||||
control_layers: Optional[list[ControlNetRecallParameter]] = Field(
|
||||
None, description="List of control adapters (ControlNet, T2I Adapter, Control LoRA) with their settings"
|
||||
)
|
||||
ip_adapters: Optional[list[IPAdapterRecallParameter]] = Field(
|
||||
None, description="List of IP Adapters with their settings"
|
||||
)
|
||||
|
||||
|
||||
def resolve_model_name_to_key(model_name: str, model_type: ModelType = ModelType.Main) -> Optional[str]:
|
||||
"""
|
||||
Look up a model by name and return its key.
|
||||
|
||||
Args:
|
||||
model_name: The name of the model to look up
|
||||
model_type: The type of model to search for (default: Main)
|
||||
|
||||
Returns:
|
||||
The key of the first matching model, or None if not found.
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
models = ApiDependencies.invoker.services.model_manager.store.search_by_attr(
|
||||
model_name=model_name, model_type=model_type
|
||||
)
|
||||
|
||||
if models:
|
||||
logger.info(f"Resolved {model_type.value} model name '{model_name}' to key '{models[0].key}'")
|
||||
return models[0].key
|
||||
|
||||
logger.warning(f"Could not find {model_type.value} model with name '{model_name}'")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Exception during {model_type.value} model lookup: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def load_image_file(image_name: str) -> Optional[dict[str, Any]]:
|
||||
"""
|
||||
Load an image from the outputs/images directory.
|
||||
|
||||
Args:
|
||||
image_name: The filename of the image in outputs/images
|
||||
|
||||
Returns:
|
||||
A dictionary with image_name, width, and height, or None if the image cannot be found
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
# Prefer using the image_files service to validate & open images
|
||||
image_files = ApiDependencies.invoker.services.image_files
|
||||
# Resolve a safe path inside outputs
|
||||
image_path = image_files.get_path(image_name)
|
||||
|
||||
if not image_files.validate_path(str(image_path)):
|
||||
logger.warning(f"Image file not found: {image_name} (searched in {image_path.parent})")
|
||||
return None
|
||||
|
||||
# Open the image via service to leverage caching
|
||||
pil_image = image_files.get(image_name)
|
||||
width, height = pil_image.size
|
||||
logger.info(f"Found image file: {image_name} ({width}x{height})")
|
||||
return {"image_name": image_name, "width": width, "height": height}
|
||||
except Exception as e:
|
||||
logger.warning(f"Error loading image file {image_name}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def resolve_lora_models(loras: list[LoRARecallParameter]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Resolve LoRA model names to keys and build configuration list.
|
||||
|
||||
Args:
|
||||
loras: List of LoRA recall parameters
|
||||
|
||||
Returns:
|
||||
List of resolved LoRA configurations with model keys
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
resolved_loras = []
|
||||
|
||||
for lora in loras:
|
||||
model_key = resolve_model_name_to_key(lora.model_name, ModelType.LoRA)
|
||||
if model_key:
|
||||
resolved_loras.append({"model_key": model_key, "weight": lora.weight, "is_enabled": lora.is_enabled})
|
||||
else:
|
||||
logger.warning(f"Skipping LoRA '{lora.model_name}' - model not found")
|
||||
|
||||
return resolved_loras
|
||||
|
||||
|
||||
def resolve_control_models(control_layers: list[ControlNetRecallParameter]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Resolve control adapter model names to keys and build configuration list.
|
||||
|
||||
Tries to resolve as ControlNet, T2I Adapter, or Control LoRA in that order.
|
||||
|
||||
Args:
|
||||
control_layers: List of control adapter recall parameters
|
||||
|
||||
Returns:
|
||||
List of resolved control adapter configurations with model keys
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
services = ApiDependencies.invoker.services
|
||||
resolved_controls = []
|
||||
|
||||
for control in control_layers:
|
||||
model_key = None
|
||||
|
||||
# Try ControlNet first
|
||||
model_key = resolve_model_name_to_key(control.model_name, ModelType.ControlNet)
|
||||
if not model_key:
|
||||
# Try T2I Adapter
|
||||
model_key = resolve_model_name_to_key(control.model_name, ModelType.T2IAdapter)
|
||||
if not model_key:
|
||||
# Try Control LoRA (also uses LoRA type)
|
||||
model_key = resolve_model_name_to_key(control.model_name, ModelType.LoRA)
|
||||
|
||||
if model_key:
|
||||
config: dict[str, Any] = {"model_key": model_key, "weight": control.weight}
|
||||
if control.image_name is not None:
|
||||
image_data = load_image_file(control.image_name)
|
||||
if image_data:
|
||||
config["image"] = image_data
|
||||
|
||||
# Try to process the image using the model's default processor
|
||||
processed_image_data = process_controlnet_image(control.image_name, model_key, services)
|
||||
if processed_image_data:
|
||||
config["processed_image"] = processed_image_data
|
||||
logger.info(f"Added processed image for control adapter {control.model_name}")
|
||||
else:
|
||||
logger.warning(f"Could not load image for control adapter: {control.image_name}")
|
||||
if control.begin_step_percent is not None:
|
||||
config["begin_step_percent"] = control.begin_step_percent
|
||||
if control.end_step_percent is not None:
|
||||
config["end_step_percent"] = control.end_step_percent
|
||||
if control.control_mode is not None:
|
||||
config["control_mode"] = control.control_mode
|
||||
|
||||
resolved_controls.append(config)
|
||||
else:
|
||||
logger.warning(f"Skipping control adapter '{control.model_name}' - model not found")
|
||||
|
||||
return resolved_controls
|
||||
|
||||
|
||||
def resolve_ip_adapter_models(ip_adapters: list[IPAdapterRecallParameter]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Resolve IP Adapter model names to keys and build configuration list.
|
||||
|
||||
Args:
|
||||
ip_adapters: List of IP Adapter recall parameters
|
||||
|
||||
Returns:
|
||||
List of resolved IP Adapter configurations with model keys
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
resolved_adapters = []
|
||||
|
||||
for adapter in ip_adapters:
|
||||
# Try resolving as IP Adapter; if not found, try FLUX Redux
|
||||
model_key = resolve_model_name_to_key(adapter.model_name, ModelType.IPAdapter)
|
||||
if not model_key:
|
||||
model_key = resolve_model_name_to_key(adapter.model_name, ModelType.FluxRedux)
|
||||
if model_key:
|
||||
config: dict[str, Any] = {
|
||||
"model_key": model_key,
|
||||
# Always include weight; ignored by FLUX Redux on the frontend
|
||||
"weight": adapter.weight,
|
||||
}
|
||||
if adapter.image_name is not None:
|
||||
image_data = load_image_file(adapter.image_name)
|
||||
if image_data:
|
||||
config["image"] = image_data
|
||||
else:
|
||||
logger.warning(f"Could not load image for IP Adapter: {adapter.image_name}")
|
||||
if adapter.begin_step_percent is not None:
|
||||
config["begin_step_percent"] = adapter.begin_step_percent
|
||||
if adapter.end_step_percent is not None:
|
||||
config["end_step_percent"] = adapter.end_step_percent
|
||||
if adapter.method is not None:
|
||||
config["method"] = adapter.method
|
||||
# Include FLUX Redux image influence when provided
|
||||
if adapter.image_influence is not None:
|
||||
config["image_influence"] = adapter.image_influence
|
||||
|
||||
resolved_adapters.append(config)
|
||||
else:
|
||||
logger.warning(f"Skipping IP Adapter '{adapter.model_name}' - model not found")
|
||||
|
||||
return resolved_adapters
|
||||
|
||||
|
||||
def _assert_recall_image_access(parameters: "RecallParameter", current_user: CurrentUserOrDefault) -> None:
|
||||
"""Validate that the caller can read every image referenced in the recall parameters.
|
||||
|
||||
Control layers and IP adapters may reference image_name fields. Without this
|
||||
check an attacker who knows another user's image UUID could use the recall
|
||||
endpoint to extract image dimensions and — for ControlNet preprocessors — mint
|
||||
a derived processed image they can then fetch.
|
||||
"""
|
||||
from invokeai.app.services.board_records.board_records_common import BoardVisibility
|
||||
|
||||
image_names: list[str] = []
|
||||
if parameters.control_layers:
|
||||
for layer in parameters.control_layers:
|
||||
if layer.image_name is not None:
|
||||
image_names.append(layer.image_name)
|
||||
if parameters.ip_adapters:
|
||||
for adapter in parameters.ip_adapters:
|
||||
if adapter.image_name is not None:
|
||||
image_names.append(adapter.image_name)
|
||||
|
||||
if not image_names:
|
||||
return
|
||||
|
||||
# Admin can access all images
|
||||
if current_user.is_admin:
|
||||
return
|
||||
|
||||
for image_name in image_names:
|
||||
owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name)
|
||||
if owner is not None and owner == current_user.user_id:
|
||||
continue
|
||||
|
||||
# Check board visibility
|
||||
board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name)
|
||||
if board_id is not None:
|
||||
try:
|
||||
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
||||
if board.board_visibility in (BoardVisibility.Shared, BoardVisibility.Public):
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
raise HTTPException(status_code=403, detail=f"Not authorized to access image {image_name}")
|
||||
|
||||
|
||||
@recall_parameters_router.post(
|
||||
"/{queue_id}",
|
||||
operation_id="update_recall_parameters",
|
||||
response_model=dict[str, Any],
|
||||
)
|
||||
async def update_recall_parameters(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(..., description="The queue id to perform this operation on"),
|
||||
parameters: RecallParameter = Body(..., description="Recall parameters to update"),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Update recallable parameters that can be recalled on the frontend.
|
||||
|
||||
This endpoint allows updating parameters such as prompt, model, steps, and other
|
||||
generation settings. These parameters are stored in client state and can be
|
||||
accessed by the frontend to populate UI elements.
|
||||
|
||||
Args:
|
||||
queue_id: The queue ID to associate these parameters with
|
||||
parameters: The RecallParameter object containing the parameters to update
|
||||
|
||||
Returns:
|
||||
A dictionary containing the updated parameters and status
|
||||
|
||||
Example:
|
||||
POST /api/v1/recall/{queue_id}
|
||||
{
|
||||
"positive_prompt": "a beautiful landscape",
|
||||
"model": "sd-1.5",
|
||||
"steps": 20,
|
||||
"cfg_scale": 7.5,
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
"seed": 12345
|
||||
}
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
# Validate image access before processing — prevents information leakage
|
||||
# (dimensions) and derived-image minting via ControlNet preprocessors.
|
||||
_assert_recall_image_access(parameters, current_user)
|
||||
|
||||
try:
|
||||
# Get only the parameters that were actually provided (non-None values)
|
||||
provided_params = {k: v for k, v in parameters.model_dump().items() if v is not None}
|
||||
|
||||
if not provided_params:
|
||||
return {"status": "no_parameters_provided", "updated_count": 0}
|
||||
|
||||
# Store each parameter in client state scoped to the current user
|
||||
updated_count = 0
|
||||
for param_key, param_value in provided_params.items():
|
||||
# Convert parameter values to JSON strings for storage
|
||||
value_str = json.dumps(param_value)
|
||||
try:
|
||||
ApiDependencies.invoker.services.client_state_persistence.set_by_key(
|
||||
current_user.user_id, f"recall_{param_key}", value_str
|
||||
)
|
||||
updated_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Error setting recall parameter {param_key}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error setting recall parameter {param_key}",
|
||||
)
|
||||
|
||||
logger.info(f"Updated {updated_count} recall parameters for queue {queue_id}")
|
||||
|
||||
# Resolve model name to key if a model was provided
|
||||
if "model" in provided_params and isinstance(provided_params["model"], str):
|
||||
model_name = provided_params["model"]
|
||||
model_key = resolve_model_name_to_key(model_name, ModelType.Main)
|
||||
|
||||
if model_key:
|
||||
logger.info(f"Resolved model name '{model_name}' to key '{model_key}'")
|
||||
provided_params["model"] = model_key
|
||||
else:
|
||||
logger.warning(f"Could not resolve model name '{model_name}' to a model key")
|
||||
# Remove model from parameters if we couldn't resolve it
|
||||
del provided_params["model"]
|
||||
|
||||
# Process LoRAs if provided
|
||||
if "loras" in provided_params:
|
||||
loras_param = parameters.loras
|
||||
if loras_param is not None:
|
||||
resolved_loras = resolve_lora_models(loras_param)
|
||||
provided_params["loras"] = resolved_loras
|
||||
logger.info(f"Resolved {len(resolved_loras)} LoRA(s)")
|
||||
|
||||
# Process control layers if provided
|
||||
if "control_layers" in provided_params:
|
||||
control_layers_param = parameters.control_layers
|
||||
if control_layers_param is not None:
|
||||
resolved_controls = resolve_control_models(control_layers_param)
|
||||
provided_params["control_layers"] = resolved_controls
|
||||
logger.info(f"Resolved {len(resolved_controls)} control layer(s)")
|
||||
|
||||
# Process IP adapters if provided
|
||||
if "ip_adapters" in provided_params:
|
||||
ip_adapters_param = parameters.ip_adapters
|
||||
if ip_adapters_param is not None:
|
||||
resolved_adapters = resolve_ip_adapter_models(ip_adapters_param)
|
||||
provided_params["ip_adapters"] = resolved_adapters
|
||||
logger.info(f"Resolved {len(resolved_adapters)} IP adapter(s)")
|
||||
|
||||
# Emit event to notify frontend of parameter updates
|
||||
try:
|
||||
logger.info(
|
||||
f"Emitting recall_parameters_updated event for queue {queue_id} with {len(provided_params)} parameters"
|
||||
)
|
||||
ApiDependencies.invoker.services.events.emit_recall_parameters_updated(
|
||||
queue_id, current_user.user_id, provided_params
|
||||
)
|
||||
logger.info("Successfully emitted recall_parameters_updated event")
|
||||
except Exception as e:
|
||||
logger.error(f"Error emitting recall parameters event: {e}", exc_info=True)
|
||||
# Don't fail the request if event emission fails, just log it
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"queue_id": queue_id,
|
||||
"updated_count": updated_count,
|
||||
"parameters": provided_params,
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating recall parameters: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Error updating recall parameters",
|
||||
)
|
||||
|
||||
|
||||
@recall_parameters_router.get(
|
||||
"/{queue_id}",
|
||||
operation_id="get_recall_parameters",
|
||||
response_model=dict[str, Any],
|
||||
)
|
||||
async def get_recall_parameters(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(..., description="The queue id to retrieve parameters for"),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Retrieve all stored recall parameters for a given queue.
|
||||
|
||||
Returns a dictionary of all recall parameters that have been set for the queue.
|
||||
|
||||
Args:
|
||||
queue_id: The queue ID to retrieve parameters for
|
||||
|
||||
Returns:
|
||||
A dictionary containing all stored recall parameters
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
# Retrieve all recall parameters by iterating through expected keys
|
||||
# Since client_state_persistence doesn't have a "get_all" method, we'll
|
||||
# return an informative response
|
||||
return {
|
||||
"status": "success",
|
||||
"queue_id": queue_id,
|
||||
"note": "Use the frontend to access stored recall parameters, or set specific parameters using POST",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving recall parameters: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Error retrieving recall parameters",
|
||||
)
|
||||
@@ -4,6 +4,7 @@ from fastapi import Body, HTTPException, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel
|
||||
|
||||
from invokeai.app.api.auth_dependencies import AdminUserOrDefault, CurrentUserOrDefault
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
@@ -24,6 +25,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
SessionQueueItemNotFoundError,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
|
||||
session_queue_router = APIRouter(prefix="/v1/queue", tags=["queue"])
|
||||
@@ -36,6 +38,51 @@ class SessionQueueAndProcessorStatus(BaseModel):
|
||||
processor: SessionProcessorStatus
|
||||
|
||||
|
||||
def sanitize_queue_item_for_user(
|
||||
queue_item: SessionQueueItem, current_user_id: str, is_admin: bool
|
||||
) -> SessionQueueItem:
|
||||
"""Sanitize queue item for non-admin users viewing other users' items.
|
||||
|
||||
For non-admin users viewing queue items belonging to other users,
|
||||
only timestamps, status, and error information are exposed. All other
|
||||
fields (user identity, generation parameters, graphs, workflows) are stripped.
|
||||
|
||||
Args:
|
||||
queue_item: The queue item to sanitize
|
||||
current_user_id: The ID of the current user viewing the item
|
||||
is_admin: Whether the current user is an admin
|
||||
|
||||
Returns:
|
||||
The sanitized queue item (sensitive fields cleared if necessary)
|
||||
"""
|
||||
# Admins and item owners can see everything
|
||||
if is_admin or queue_item.user_id == current_user_id:
|
||||
return queue_item
|
||||
|
||||
# For non-admins viewing other users' items, strip everything except
|
||||
# item_id, queue_id, status, and timestamps
|
||||
sanitized_item = queue_item.model_copy(deep=False)
|
||||
sanitized_item.user_id = "redacted"
|
||||
sanitized_item.user_display_name = None
|
||||
sanitized_item.user_email = None
|
||||
sanitized_item.batch_id = "redacted"
|
||||
sanitized_item.session_id = "redacted"
|
||||
sanitized_item.origin = None
|
||||
sanitized_item.destination = None
|
||||
sanitized_item.priority = 0
|
||||
sanitized_item.field_values = None
|
||||
sanitized_item.retried_from_item_id = None
|
||||
sanitized_item.workflow = None
|
||||
sanitized_item.error_type = None
|
||||
sanitized_item.error_message = None
|
||||
sanitized_item.error_traceback = None
|
||||
sanitized_item.session = GraphExecutionState(
|
||||
id="redacted",
|
||||
graph=Graph(),
|
||||
)
|
||||
return sanitized_item
|
||||
|
||||
|
||||
@session_queue_router.post(
|
||||
"/{queue_id}/enqueue_batch",
|
||||
operation_id="enqueue_batch",
|
||||
@@ -44,14 +91,15 @@ class SessionQueueAndProcessorStatus(BaseModel):
|
||||
},
|
||||
)
|
||||
async def enqueue_batch(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
batch: Batch = Body(description="Batch to process"),
|
||||
prepend: bool = Body(default=False, description="Whether or not to prepend this batch in the queue"),
|
||||
) -> EnqueueBatchResult:
|
||||
"""Processes a batch and enqueues the output graphs for execution."""
|
||||
"""Processes a batch and enqueues the output graphs for execution for the current user."""
|
||||
try:
|
||||
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
|
||||
queue_id=queue_id, batch=batch, prepend=prepend
|
||||
queue_id=queue_id, batch=batch, prepend=prepend, user_id=current_user.user_id
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while enqueuing batch: {e}")
|
||||
@@ -65,15 +113,18 @@ async def enqueue_batch(
|
||||
},
|
||||
)
|
||||
async def list_all_queue_items(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
destination: Optional[str] = Query(default=None, description="The destination of queue items to fetch"),
|
||||
) -> list[SessionQueueItem]:
|
||||
"""Gets all queue items"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.list_all_queue_items(
|
||||
items = ApiDependencies.invoker.services.session_queue.list_all_queue_items(
|
||||
queue_id=queue_id,
|
||||
destination=destination,
|
||||
)
|
||||
# Sanitize items for non-admin users
|
||||
return [sanitize_queue_item_for_user(item, current_user.user_id, current_user.is_admin) for item in items]
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue items: {e}")
|
||||
|
||||
@@ -86,12 +137,16 @@ async def list_all_queue_items(
|
||||
},
|
||||
)
|
||||
async def get_queue_item_ids(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
|
||||
) -> ItemIdsResult:
|
||||
"""Gets all queue item ids that match the given parameters"""
|
||||
"""Gets all queue item ids that match the given parameters. Non-admin users only see their own items."""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.get_queue_item_ids(queue_id=queue_id, order_dir=order_dir)
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
return ApiDependencies.invoker.services.session_queue.get_queue_item_ids(
|
||||
queue_id=queue_id, order_dir=order_dir, user_id=user_id
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue item ids: {e}")
|
||||
|
||||
@@ -102,6 +157,7 @@ async def get_queue_item_ids(
|
||||
responses={200: {"model": list[SessionQueueItem]}},
|
||||
)
|
||||
async def get_queue_items_by_item_ids(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
item_ids: list[int] = Body(
|
||||
embed=True, description="Object containing list of queue item ids to fetch queue items for"
|
||||
@@ -118,7 +174,9 @@ async def get_queue_items_by_item_ids(
|
||||
queue_item = session_queue_service.get_queue_item(item_id=item_id)
|
||||
if queue_item.queue_id != queue_id: # Auth protection for items from other queues
|
||||
continue
|
||||
queue_items.append(queue_item)
|
||||
# Sanitize item for non-admin users
|
||||
sanitized_item = sanitize_queue_item_for_user(queue_item, current_user.user_id, current_user.is_admin)
|
||||
queue_items.append(sanitized_item)
|
||||
except Exception:
|
||||
# Skip missing queue items - they may have been deleted between item id fetch and queue item fetch
|
||||
continue
|
||||
@@ -134,9 +192,10 @@ async def get_queue_items_by_item_ids(
|
||||
responses={200: {"model": SessionProcessorStatus}},
|
||||
)
|
||||
async def resume(
|
||||
current_user: AdminUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> SessionProcessorStatus:
|
||||
"""Resumes session processor"""
|
||||
"""Resumes session processor. Admin only."""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_processor.resume()
|
||||
except Exception as e:
|
||||
@@ -148,10 +207,11 @@ async def resume(
|
||||
operation_id="pause",
|
||||
responses={200: {"model": SessionProcessorStatus}},
|
||||
)
|
||||
async def Pause(
|
||||
async def pause(
|
||||
current_user: AdminUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> SessionProcessorStatus:
|
||||
"""Pauses session processor"""
|
||||
"""Pauses session processor. Admin only."""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_processor.pause()
|
||||
except Exception as e:
|
||||
@@ -164,11 +224,16 @@ async def Pause(
|
||||
responses={200: {"model": CancelAllExceptCurrentResult}},
|
||||
)
|
||||
async def cancel_all_except_current(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> CancelAllExceptCurrentResult:
|
||||
"""Immediately cancels all queue items except in-processing items"""
|
||||
"""Immediately cancels all queue items except in-processing items. Non-admin users can only cancel their own items."""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_all_except_current(queue_id=queue_id)
|
||||
# Admin users can cancel all items, non-admin users can only cancel their own
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_all_except_current(
|
||||
queue_id=queue_id, user_id=user_id
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling all except current: {e}")
|
||||
|
||||
@@ -179,11 +244,16 @@ async def cancel_all_except_current(
|
||||
responses={200: {"model": DeleteAllExceptCurrentResult}},
|
||||
)
|
||||
async def delete_all_except_current(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> DeleteAllExceptCurrentResult:
|
||||
"""Immediately deletes all queue items except in-processing items"""
|
||||
"""Immediately deletes all queue items except in-processing items. Non-admin users can only delete their own items."""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.delete_all_except_current(queue_id=queue_id)
|
||||
# Admin users can delete all items, non-admin users can only delete their own
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
return ApiDependencies.invoker.services.session_queue.delete_all_except_current(
|
||||
queue_id=queue_id, user_id=user_id
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting all except current: {e}")
|
||||
|
||||
@@ -194,13 +264,16 @@ async def delete_all_except_current(
|
||||
responses={200: {"model": CancelByBatchIDsResult}},
|
||||
)
|
||||
async def cancel_by_batch_ids(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
batch_ids: list[str] = Body(description="The list of batch_ids to cancel all queue items for", embed=True),
|
||||
) -> CancelByBatchIDsResult:
|
||||
"""Immediately cancels all queue items from the given batch ids"""
|
||||
"""Immediately cancels all queue items from the given batch ids. Non-admin users can only cancel their own items."""
|
||||
try:
|
||||
# Admin users can cancel all items, non-admin users can only cancel their own
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(
|
||||
queue_id=queue_id, batch_ids=batch_ids
|
||||
queue_id=queue_id, batch_ids=batch_ids, user_id=user_id
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling by batch id: {e}")
|
||||
@@ -212,13 +285,16 @@ async def cancel_by_batch_ids(
|
||||
responses={200: {"model": CancelByDestinationResult}},
|
||||
)
|
||||
async def cancel_by_destination(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
destination: str = Query(description="The destination to cancel all queue items for"),
|
||||
) -> CancelByDestinationResult:
|
||||
"""Immediately cancels all queue items with the given origin"""
|
||||
"""Immediately cancels all queue items with the given destination. Non-admin users can only cancel their own items."""
|
||||
try:
|
||||
# Admin users can cancel all items, non-admin users can only cancel their own
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_by_destination(
|
||||
queue_id=queue_id, destination=destination
|
||||
queue_id=queue_id, destination=destination, user_id=user_id
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling by destination: {e}")
|
||||
@@ -230,12 +306,28 @@ async def cancel_by_destination(
|
||||
responses={200: {"model": RetryItemsResult}},
|
||||
)
|
||||
async def retry_items_by_id(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
item_ids: list[int] = Body(description="The queue item ids to retry"),
|
||||
) -> RetryItemsResult:
|
||||
"""Immediately cancels all queue items with the given origin"""
|
||||
"""Retries the given queue items. Users can only retry their own items unless they are an admin."""
|
||||
try:
|
||||
# Check authorization: user must own all items or be an admin
|
||||
if not current_user.is_admin:
|
||||
for item_id in item_ids:
|
||||
try:
|
||||
queue_item = ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
|
||||
if queue_item.user_id != current_user.user_id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail=f"You do not have permission to retry queue item {item_id}"
|
||||
)
|
||||
except SessionQueueItemNotFoundError:
|
||||
# Skip items that don't exist - they will be handled by retry_items_by_id
|
||||
continue
|
||||
|
||||
return ApiDependencies.invoker.services.session_queue.retry_items_by_id(queue_id=queue_id, item_ids=item_ids)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while retrying queue items: {e}")
|
||||
|
||||
@@ -248,15 +340,25 @@ async def retry_items_by_id(
|
||||
},
|
||||
)
|
||||
async def clear(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> ClearResult:
|
||||
"""Clears the queue entirely, immediately canceling the currently-executing session"""
|
||||
"""Clears the queue entirely. Admin users clear all items; non-admin users only clear their own items. If there's a currently-executing item, users can only cancel it if they own it or are an admin."""
|
||||
try:
|
||||
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
|
||||
if queue_item is not None:
|
||||
# Check authorization for canceling the current item
|
||||
if queue_item.user_id != current_user.user_id and not current_user.is_admin:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You do not have permission to cancel the currently executing queue item"
|
||||
)
|
||||
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
|
||||
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
|
||||
# Admin users can clear all items, non-admin users can only clear their own
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id, user_id=user_id)
|
||||
return clear_result
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while clearing queue: {e}")
|
||||
|
||||
@@ -269,11 +371,14 @@ async def clear(
|
||||
},
|
||||
)
|
||||
async def prune(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> PruneResult:
|
||||
"""Prunes all completed or errored queue items"""
|
||||
"""Prunes all completed or errored queue items. Non-admin users can only prune their own items."""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
|
||||
# Admin users can prune all items, non-admin users can only prune their own
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
return ApiDependencies.invoker.services.session_queue.prune(queue_id, user_id=user_id)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while pruning queue: {e}")
|
||||
|
||||
@@ -286,11 +391,15 @@ async def prune(
|
||||
},
|
||||
)
|
||||
async def get_current_queue_item(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> Optional[SessionQueueItem]:
|
||||
"""Gets the currently execution queue item"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
|
||||
item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
|
||||
if item is not None:
|
||||
item = sanitize_queue_item_for_user(item, current_user.user_id, current_user.is_admin)
|
||||
return item
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while getting current queue item: {e}")
|
||||
|
||||
@@ -303,11 +412,15 @@ async def get_current_queue_item(
|
||||
},
|
||||
)
|
||||
async def get_next_queue_item(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> Optional[SessionQueueItem]:
|
||||
"""Gets the next queue item, without executing it"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
|
||||
item = ApiDependencies.invoker.services.session_queue.get_next(queue_id)
|
||||
if item is not None:
|
||||
item = sanitize_queue_item_for_user(item, current_user.user_id, current_user.is_admin)
|
||||
return item
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while getting next queue item: {e}")
|
||||
|
||||
@@ -320,11 +433,13 @@ async def get_next_queue_item(
|
||||
},
|
||||
)
|
||||
async def get_queue_status(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> SessionQueueAndProcessorStatus:
|
||||
"""Gets the status of the session queue"""
|
||||
"""Gets the status of the session queue. Non-admin users see only their own counts and cannot see current item details unless they own it."""
|
||||
try:
|
||||
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id, user_id=user_id)
|
||||
processor = ApiDependencies.invoker.services.session_processor.get_status()
|
||||
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
|
||||
except Exception as e:
|
||||
@@ -339,12 +454,16 @@ async def get_queue_status(
|
||||
},
|
||||
)
|
||||
async def get_batch_status(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
batch_id: str = Path(description="The batch to get the status of"),
|
||||
) -> BatchStatus:
|
||||
"""Gets the status of the session queue"""
|
||||
"""Gets the status of a batch. Non-admin users only see their own batches."""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
return ApiDependencies.invoker.services.session_queue.get_batch_status(
|
||||
queue_id=queue_id, batch_id=batch_id, user_id=user_id
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while getting batch status: {e}")
|
||||
|
||||
@@ -358,6 +477,7 @@ async def get_batch_status(
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
async def get_queue_item(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
item_id: int = Path(description="The queue item to get"),
|
||||
) -> SessionQueueItem:
|
||||
@@ -366,7 +486,8 @@ async def get_queue_item(
|
||||
queue_item = ApiDependencies.invoker.services.session_queue.get_queue_item(item_id=item_id)
|
||||
if queue_item.queue_id != queue_id:
|
||||
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
|
||||
return queue_item
|
||||
# Sanitize item for non-admin users
|
||||
return sanitize_queue_item_for_user(queue_item, current_user.user_id, current_user.is_admin)
|
||||
except SessionQueueItemNotFoundError:
|
||||
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
|
||||
except Exception as e:
|
||||
@@ -378,12 +499,24 @@ async def get_queue_item(
|
||||
operation_id="delete_queue_item",
|
||||
)
|
||||
async def delete_queue_item(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
item_id: int = Path(description="The queue item to delete"),
|
||||
) -> None:
|
||||
"""Deletes a queue item"""
|
||||
"""Deletes a queue item. Users can only delete their own items unless they are an admin."""
|
||||
try:
|
||||
# Get the queue item to check ownership
|
||||
queue_item = ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
|
||||
|
||||
# Check authorization: user must own the item or be an admin
|
||||
if queue_item.user_id != current_user.user_id and not current_user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="You do not have permission to delete this queue item")
|
||||
|
||||
ApiDependencies.invoker.services.session_queue.delete_queue_item(item_id)
|
||||
except SessionQueueItemNotFoundError:
|
||||
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting queue item: {e}")
|
||||
|
||||
@@ -396,14 +529,24 @@ async def delete_queue_item(
|
||||
},
|
||||
)
|
||||
async def cancel_queue_item(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
item_id: int = Path(description="The queue item to cancel"),
|
||||
) -> SessionQueueItem:
|
||||
"""Deletes a queue item"""
|
||||
"""Cancels a queue item. Users can only cancel their own items unless they are an admin."""
|
||||
try:
|
||||
# Get the queue item to check ownership
|
||||
queue_item = ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
|
||||
|
||||
# Check authorization: user must own the item or be an admin
|
||||
if queue_item.user_id != current_user.user_id and not current_user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="You do not have permission to cancel this queue item")
|
||||
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)
|
||||
except SessionQueueItemNotFoundError:
|
||||
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling queue item: {e}")
|
||||
|
||||
@@ -414,13 +557,15 @@ async def cancel_queue_item(
|
||||
responses={200: {"model": SessionQueueCountsByDestination}},
|
||||
)
|
||||
async def counts_by_destination(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to query"),
|
||||
destination: str = Query(description="The destination to query"),
|
||||
) -> SessionQueueCountsByDestination:
|
||||
"""Gets the counts of queue items by destination"""
|
||||
"""Gets the counts of queue items by destination. Non-admin users only see their own items."""
|
||||
try:
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
return ApiDependencies.invoker.services.session_queue.get_counts_by_destination(
|
||||
queue_id=queue_id, destination=destination
|
||||
queue_id=queue_id, destination=destination, user_id=user_id
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while fetching counts by destination: {e}")
|
||||
@@ -432,13 +577,16 @@ async def counts_by_destination(
|
||||
responses={200: {"model": DeleteByDestinationResult}},
|
||||
)
|
||||
async def delete_by_destination(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to query"),
|
||||
destination: str = Path(description="The destination to query"),
|
||||
) -> DeleteByDestinationResult:
|
||||
"""Deletes all items with the given destination"""
|
||||
"""Deletes all items with the given destination. Non-admin users can only delete their own items."""
|
||||
try:
|
||||
# Admin users can delete all items, non-admin users can only delete their own
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
return ApiDependencies.invoker.services.session_queue.delete_by_destination(
|
||||
queue_id=queue_id, destination=destination
|
||||
queue_id=queue_id, destination=destination, user_id=user_id
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting by destination: {e}")
|
||||
|
||||
@@ -6,6 +6,7 @@ from fastapi import APIRouter, Body, File, HTTPException, Path, Query, UploadFil
|
||||
from fastapi.responses import FileResponse
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
@@ -33,16 +34,25 @@ workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
|
||||
},
|
||||
)
|
||||
async def get_workflow(
|
||||
current_user: CurrentUserOrDefault,
|
||||
workflow_id: str = Path(description="The workflow to get"),
|
||||
) -> WorkflowRecordWithThumbnailDTO:
|
||||
"""Gets a workflow"""
|
||||
try:
|
||||
thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id)
|
||||
workflow = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
||||
return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump())
|
||||
except WorkflowNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
if config.multiuser:
|
||||
is_default = workflow.workflow.meta.category is WorkflowCategory.Default
|
||||
is_owner = workflow.user_id == current_user.user_id
|
||||
if not (is_default or is_owner or workflow.is_public or current_user.is_admin):
|
||||
raise HTTPException(status_code=403, detail="Not authorized to access this workflow")
|
||||
|
||||
thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id)
|
||||
return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump())
|
||||
|
||||
|
||||
@workflows_router.patch(
|
||||
"/i/{workflow_id}",
|
||||
@@ -52,10 +62,21 @@ async def get_workflow(
|
||||
},
|
||||
)
|
||||
async def update_workflow(
|
||||
current_user: CurrentUserOrDefault,
|
||||
workflow: Workflow = Body(description="The updated workflow", embed=True),
|
||||
) -> WorkflowRecordDTO:
|
||||
"""Updates a workflow"""
|
||||
return ApiDependencies.invoker.services.workflow_records.update(workflow=workflow)
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
if config.multiuser:
|
||||
try:
|
||||
existing = ApiDependencies.invoker.services.workflow_records.get(workflow.id)
|
||||
except WorkflowNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
if not current_user.is_admin and existing.user_id != current_user.user_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
|
||||
# Pass user_id for defense-in-depth SQL scoping; admins pass None to allow any.
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
return ApiDependencies.invoker.services.workflow_records.update(workflow=workflow, user_id=user_id)
|
||||
|
||||
|
||||
@workflows_router.delete(
|
||||
@@ -63,15 +84,25 @@ async def update_workflow(
|
||||
operation_id="delete_workflow",
|
||||
)
|
||||
async def delete_workflow(
|
||||
current_user: CurrentUserOrDefault,
|
||||
workflow_id: str = Path(description="The workflow to delete"),
|
||||
) -> None:
|
||||
"""Deletes a workflow"""
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
if config.multiuser:
|
||||
try:
|
||||
existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
||||
except WorkflowNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
if not current_user.is_admin and existing.user_id != current_user.user_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to delete this workflow")
|
||||
try:
|
||||
ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id)
|
||||
except WorkflowThumbnailFileNotFoundException:
|
||||
# It's OK if the workflow has no thumbnail file. We can still delete the workflow.
|
||||
pass
|
||||
ApiDependencies.invoker.services.workflow_records.delete(workflow_id)
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
ApiDependencies.invoker.services.workflow_records.delete(workflow_id, user_id=user_id)
|
||||
|
||||
|
||||
@workflows_router.post(
|
||||
@@ -82,10 +113,17 @@ async def delete_workflow(
|
||||
},
|
||||
)
|
||||
async def create_workflow(
|
||||
current_user: CurrentUserOrDefault,
|
||||
workflow: WorkflowWithoutID = Body(description="The workflow to create", embed=True),
|
||||
) -> WorkflowRecordDTO:
|
||||
"""Creates a workflow"""
|
||||
return ApiDependencies.invoker.services.workflow_records.create(workflow=workflow)
|
||||
# In single-user mode, workflows are owned by 'system' and shared by default so all legacy/single-user
|
||||
# workflows remain visible. In multiuser mode, workflows are private to the creator by default.
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
is_public = not config.multiuser
|
||||
return ApiDependencies.invoker.services.workflow_records.create(
|
||||
workflow=workflow, user_id=current_user.user_id, is_public=is_public
|
||||
)
|
||||
|
||||
|
||||
@workflows_router.get(
|
||||
@@ -96,6 +134,7 @@ async def create_workflow(
|
||||
},
|
||||
)
|
||||
async def list_workflows(
|
||||
current_user: CurrentUserOrDefault,
|
||||
page: int = Query(default=0, description="The page to get"),
|
||||
per_page: Optional[int] = Query(default=None, description="The number of workflows per page"),
|
||||
order_by: WorkflowRecordOrderBy = Query(
|
||||
@@ -106,8 +145,19 @@ async def list_workflows(
|
||||
tags: Optional[list[str]] = Query(default=None, description="The tags of workflow to get"),
|
||||
query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"),
|
||||
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
|
||||
is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"),
|
||||
) -> PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]:
|
||||
"""Gets a page of workflows"""
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
|
||||
# In multiuser mode, scope user-category workflows to the current user unless fetching shared workflows.
|
||||
# Admins skip the user_id filter so they can see and manage all workflows including system-owned ones.
|
||||
user_id_filter: Optional[str] = None
|
||||
if config.multiuser and not current_user.is_admin:
|
||||
has_user_category = not categories or WorkflowCategory.User in categories
|
||||
if has_user_category and is_public is not True:
|
||||
user_id_filter = current_user.user_id
|
||||
|
||||
workflows_with_thumbnails: list[WorkflowRecordListItemWithThumbnailDTO] = []
|
||||
workflows = ApiDependencies.invoker.services.workflow_records.get_many(
|
||||
order_by=order_by,
|
||||
@@ -118,6 +168,8 @@ async def list_workflows(
|
||||
categories=categories,
|
||||
tags=tags,
|
||||
has_been_opened=has_been_opened,
|
||||
user_id=user_id_filter,
|
||||
is_public=is_public,
|
||||
)
|
||||
for workflow in workflows.items:
|
||||
workflows_with_thumbnails.append(
|
||||
@@ -143,15 +195,20 @@ async def list_workflows(
|
||||
},
|
||||
)
|
||||
async def set_workflow_thumbnail(
|
||||
current_user: CurrentUserOrDefault,
|
||||
workflow_id: str = Path(description="The workflow to update"),
|
||||
image: UploadFile = File(description="The image file to upload"),
|
||||
):
|
||||
"""Sets a workflow's thumbnail image"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
||||
existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
||||
except WorkflowNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
|
||||
|
||||
if not image.content_type or not image.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
@@ -177,14 +234,19 @@ async def set_workflow_thumbnail(
|
||||
},
|
||||
)
|
||||
async def delete_workflow_thumbnail(
|
||||
current_user: CurrentUserOrDefault,
|
||||
workflow_id: str = Path(description="The workflow to update"),
|
||||
):
|
||||
"""Removes a workflow's thumbnail image"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
||||
existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
||||
except WorkflowNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id)
|
||||
except ValueError as e:
|
||||
@@ -206,8 +268,12 @@ async def delete_workflow_thumbnail(
|
||||
async def get_workflow_thumbnail(
|
||||
workflow_id: str = Path(description="The id of the workflow thumbnail to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets a workflow's thumbnail image"""
|
||||
"""Gets a workflow's thumbnail image.
|
||||
|
||||
This endpoint is intentionally unauthenticated because browsers load images
|
||||
via <img src> tags which cannot send Bearer tokens. Workflow IDs are UUIDs,
|
||||
providing security through unguessability.
|
||||
"""
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.workflow_thumbnails.get_path(workflow_id)
|
||||
|
||||
@@ -223,37 +289,91 @@ async def get_workflow_thumbnail(
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@workflows_router.patch(
|
||||
"/i/{workflow_id}/is_public",
|
||||
operation_id="update_workflow_is_public",
|
||||
responses={
|
||||
200: {"model": WorkflowRecordDTO},
|
||||
},
|
||||
)
|
||||
async def update_workflow_is_public(
|
||||
current_user: CurrentUserOrDefault,
|
||||
workflow_id: str = Path(description="The workflow to update"),
|
||||
is_public: bool = Body(description="Whether the workflow should be shared publicly", embed=True),
|
||||
) -> WorkflowRecordDTO:
|
||||
"""Updates whether a workflow is shared publicly"""
|
||||
try:
|
||||
existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
||||
except WorkflowNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
|
||||
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
return ApiDependencies.invoker.services.workflow_records.update_is_public(
|
||||
workflow_id=workflow_id, is_public=is_public, user_id=user_id
|
||||
)
|
||||
|
||||
|
||||
@workflows_router.get("/tags", operation_id="get_all_tags")
|
||||
async def get_all_tags(
|
||||
current_user: CurrentUserOrDefault,
|
||||
categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"),
|
||||
is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"),
|
||||
) -> list[str]:
|
||||
"""Gets all unique tags from workflows"""
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
user_id_filter: Optional[str] = None
|
||||
if config.multiuser and not current_user.is_admin:
|
||||
has_user_category = not categories or WorkflowCategory.User in categories
|
||||
if has_user_category and is_public is not True:
|
||||
user_id_filter = current_user.user_id
|
||||
|
||||
return ApiDependencies.invoker.services.workflow_records.get_all_tags(categories=categories)
|
||||
return ApiDependencies.invoker.services.workflow_records.get_all_tags(
|
||||
categories=categories, user_id=user_id_filter, is_public=is_public
|
||||
)
|
||||
|
||||
|
||||
@workflows_router.get("/counts_by_tag", operation_id="get_counts_by_tag")
|
||||
async def get_counts_by_tag(
|
||||
current_user: CurrentUserOrDefault,
|
||||
tags: list[str] = Query(description="The tags to get counts for"),
|
||||
categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"),
|
||||
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
|
||||
is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"),
|
||||
) -> dict[str, int]:
|
||||
"""Counts workflows by tag"""
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
user_id_filter: Optional[str] = None
|
||||
if config.multiuser and not current_user.is_admin:
|
||||
has_user_category = not categories or WorkflowCategory.User in categories
|
||||
if has_user_category and is_public is not True:
|
||||
user_id_filter = current_user.user_id
|
||||
|
||||
return ApiDependencies.invoker.services.workflow_records.counts_by_tag(
|
||||
tags=tags, categories=categories, has_been_opened=has_been_opened
|
||||
tags=tags, categories=categories, has_been_opened=has_been_opened, user_id=user_id_filter, is_public=is_public
|
||||
)
|
||||
|
||||
|
||||
@workflows_router.get("/counts_by_category", operation_id="counts_by_category")
|
||||
async def counts_by_category(
|
||||
current_user: CurrentUserOrDefault,
|
||||
categories: list[WorkflowCategory] = Query(description="The categories to include"),
|
||||
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
|
||||
is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"),
|
||||
) -> dict[str, int]:
|
||||
"""Counts workflows by category"""
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
user_id_filter: Optional[str] = None
|
||||
if config.multiuser and not current_user.is_admin:
|
||||
has_user_category = WorkflowCategory.User in categories
|
||||
if has_user_category and is_public is not True:
|
||||
user_id_filter = current_user.user_id
|
||||
|
||||
return ApiDependencies.invoker.services.workflow_records.counts_by_category(
|
||||
categories=categories, has_been_opened=has_been_opened
|
||||
categories=categories, has_been_opened=has_been_opened, user_id=user_id_filter, is_public=is_public
|
||||
)
|
||||
|
||||
|
||||
@@ -262,7 +382,18 @@ async def counts_by_category(
|
||||
operation_id="update_opened_at",
|
||||
)
|
||||
async def update_opened_at(
|
||||
current_user: CurrentUserOrDefault,
|
||||
workflow_id: str = Path(description="The workflow to update"),
|
||||
) -> None:
|
||||
"""Updates the opened_at field of a workflow"""
|
||||
ApiDependencies.invoker.services.workflow_records.update_opened_at(workflow_id)
|
||||
try:
|
||||
existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
||||
except WorkflowNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
|
||||
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
ApiDependencies.invoker.services.workflow_records.update_opened_at(workflow_id, user_id=user_id)
|
||||
|
||||
@@ -6,6 +6,7 @@ from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
from socketio import ASGIApp, AsyncServer
|
||||
|
||||
from invokeai.app.services.auth.token_service import verify_token
|
||||
from invokeai.app.services.events.events_common import (
|
||||
BatchEnqueuedEvent,
|
||||
BulkDownloadCompleteEvent,
|
||||
@@ -35,8 +36,12 @@ from invokeai.app.services.events.events_common import (
|
||||
QueueClearedEvent,
|
||||
QueueEventBase,
|
||||
QueueItemStatusChangedEvent,
|
||||
RecallParametersUpdatedEvent,
|
||||
register_events,
|
||||
)
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
logger = InvokeAILogger.get_logger()
|
||||
|
||||
|
||||
class QueueSubscriptionEvent(BaseModel):
|
||||
@@ -61,6 +66,7 @@ QUEUE_EVENTS = {
|
||||
QueueItemStatusChangedEvent,
|
||||
BatchEnqueuedEvent,
|
||||
QueueClearedEvent,
|
||||
RecallParametersUpdatedEvent,
|
||||
}
|
||||
|
||||
MODEL_EVENTS = {
|
||||
@@ -94,6 +100,13 @@ class SocketIO:
|
||||
self._app = ASGIApp(socketio_server=self._sio, socketio_path="/ws/socket.io")
|
||||
app.mount("/ws", self._app)
|
||||
|
||||
# Track user information for each socket connection
|
||||
self._socket_users: dict[str, dict[str, Any]] = {}
|
||||
|
||||
# Set up authentication middleware
|
||||
self._sio.on("connect", handler=self._handle_connect)
|
||||
self._sio.on("disconnect", handler=self._handle_disconnect)
|
||||
|
||||
self._sio.on(self._sub_queue, handler=self._handle_sub_queue)
|
||||
self._sio.on(self._unsub_queue, handler=self._handle_unsub_queue)
|
||||
self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download)
|
||||
@@ -103,23 +116,247 @@ class SocketIO:
|
||||
register_events(MODEL_EVENTS, self._handle_model_event)
|
||||
register_events(BULK_DOWNLOAD_EVENTS, self._handle_bulk_image_download_event)
|
||||
|
||||
async def _handle_connect(self, sid: str, environ: dict, auth: dict | None) -> bool:
|
||||
"""Handle socket connection and authenticate the user.
|
||||
|
||||
Returns True to accept the connection, False to reject it.
|
||||
Stores user_id in the internal socket users dict for later use.
|
||||
|
||||
In multiuser mode, connections without a valid token are rejected outright
|
||||
so that anonymous clients cannot subscribe to queue rooms and observe
|
||||
queue activity belonging to other users. In single-user mode, unauthenticated
|
||||
connections are accepted as the system admin user.
|
||||
"""
|
||||
# Extract token from auth data or headers
|
||||
token = None
|
||||
if auth and isinstance(auth, dict):
|
||||
token = auth.get("token")
|
||||
|
||||
if not token and environ:
|
||||
# Try to get token from headers
|
||||
headers = environ.get("HTTP_AUTHORIZATION", "")
|
||||
if headers.startswith("Bearer "):
|
||||
token = headers[7:]
|
||||
|
||||
# Verify the token
|
||||
if token:
|
||||
token_data = verify_token(token)
|
||||
if token_data:
|
||||
# In multiuser mode, also verify the backing user record still
|
||||
# exists and is active — mirrors the REST auth check in
|
||||
# auth_dependencies.py. A deleted or deactivated user whose
|
||||
# JWT has not yet expired must not be allowed to open a socket.
|
||||
if self._is_multiuser_enabled():
|
||||
try:
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
|
||||
user = ApiDependencies.invoker.services.users.get(token_data.user_id)
|
||||
if user is None or not user.is_active:
|
||||
logger.warning(f"Rejecting socket {sid}: user {token_data.user_id} not found or inactive")
|
||||
return False
|
||||
except Exception:
|
||||
# If user service is unavailable, fail closed
|
||||
logger.warning(f"Rejecting socket {sid}: unable to verify user record")
|
||||
return False
|
||||
|
||||
# Store user_id and is_admin in socket users dict
|
||||
self._socket_users[sid] = {
|
||||
"user_id": token_data.user_id,
|
||||
"is_admin": token_data.is_admin,
|
||||
}
|
||||
logger.info(
|
||||
f"Socket {sid} connected with user_id: {token_data.user_id}, is_admin: {token_data.is_admin}"
|
||||
)
|
||||
return True
|
||||
|
||||
# No valid token provided. In multiuser mode this is not allowed — reject
|
||||
# the connection so anonymous clients cannot subscribe to queue rooms.
|
||||
# In single-user mode, fall through and accept the socket as system admin.
|
||||
if self._is_multiuser_enabled():
|
||||
logger.warning(
|
||||
f"Rejecting socket {sid} connection: multiuser mode is enabled and no valid auth token was provided"
|
||||
)
|
||||
return False
|
||||
|
||||
self._socket_users[sid] = {
|
||||
"user_id": "system",
|
||||
"is_admin": True,
|
||||
}
|
||||
logger.debug(f"Socket {sid} connected as system admin (single-user mode)")
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _is_multiuser_enabled() -> bool:
|
||||
"""Check whether multiuser mode is enabled. Fails closed if configuration
|
||||
is not yet initialized, which should not happen in practice but prevents
|
||||
accidentally opening the socket during startup races."""
|
||||
try:
|
||||
# Imported here to avoid a circular import at module load time.
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
|
||||
return bool(ApiDependencies.invoker.services.configuration.multiuser)
|
||||
except Exception:
|
||||
# If dependencies are not initialized, fail closed (treat as multiuser)
|
||||
# so we never accidentally admit an anonymous socket.
|
||||
return True
|
||||
|
||||
async def _handle_disconnect(self, sid: str) -> None:
|
||||
"""Handle socket disconnection and cleanup user info."""
|
||||
if sid in self._socket_users:
|
||||
del self._socket_users[sid]
|
||||
logger.debug(f"Socket {sid} disconnected and cleaned up")
|
||||
|
||||
async def _handle_sub_queue(self, sid: str, data: Any) -> None:
|
||||
await self._sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id)
|
||||
"""Handle queue subscription and add socket to both queue and user-specific rooms."""
|
||||
queue_id = QueueSubscriptionEvent(**data).queue_id
|
||||
|
||||
# Check if we have user info for this socket. In multiuser mode _handle_connect
|
||||
# will have already rejected any socket without a valid token, so missing user
|
||||
# info here is a bug — refuse the subscription rather than silently falling back
|
||||
# to an anonymous system user who could then receive queue item events.
|
||||
if sid not in self._socket_users:
|
||||
if self._is_multiuser_enabled():
|
||||
logger.warning(
|
||||
f"Refusing queue subscription for socket {sid}: no user info (socket not authenticated via connect event)"
|
||||
)
|
||||
return
|
||||
# Single-user mode: safe to fall back to the system admin user.
|
||||
self._socket_users[sid] = {
|
||||
"user_id": "system",
|
||||
"is_admin": True,
|
||||
}
|
||||
|
||||
user_id = self._socket_users[sid]["user_id"]
|
||||
is_admin = self._socket_users[sid]["is_admin"]
|
||||
|
||||
# Add socket to the queue room
|
||||
await self._sio.enter_room(sid, queue_id)
|
||||
|
||||
# Also add socket to a user-specific room for event filtering
|
||||
user_room = f"user:{user_id}"
|
||||
await self._sio.enter_room(sid, user_room)
|
||||
|
||||
# If admin, also add to admin room to receive all events
|
||||
if is_admin:
|
||||
await self._sio.enter_room(sid, "admin")
|
||||
|
||||
logger.debug(
|
||||
f"Socket {sid} (user_id: {user_id}, is_admin: {is_admin}) subscribed to queue {queue_id} and user room {user_room}"
|
||||
)
|
||||
|
||||
async def _handle_unsub_queue(self, sid: str, data: Any) -> None:
|
||||
await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id)
|
||||
|
||||
async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None:
|
||||
# In multiuser mode, only allow authenticated sockets to subscribe.
|
||||
# Bulk download events are routed to user-specific rooms, so the
|
||||
# bulk_download_id room subscription is only kept for single-user
|
||||
# backward compatibility.
|
||||
if self._is_multiuser_enabled() and sid not in self._socket_users:
|
||||
logger.warning(f"Refusing bulk download subscription for unknown socket {sid} in multiuser mode")
|
||||
return
|
||||
await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
|
||||
|
||||
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
|
||||
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
|
||||
|
||||
async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
|
||||
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id)
|
||||
"""Handle queue events with user isolation.
|
||||
|
||||
All queue item events (invocation events AND QueueItemStatusChangedEvent) are
|
||||
private to the owning user and admins. They carry unsanitized user_id, batch_id,
|
||||
session_id, origin, destination and error metadata, and must never be broadcast
|
||||
to the whole queue room — otherwise any other authenticated subscriber could
|
||||
observe cross-user queue activity.
|
||||
|
||||
RecallParametersUpdatedEvent is also private to the owner + admins.
|
||||
|
||||
BatchEnqueuedEvent carries the enqueuing user's batch_id/origin/counts and
|
||||
is also routed privately. QueueClearedEvent is the only queue event that
|
||||
is still broadcast to the whole queue room.
|
||||
|
||||
IMPORTANT: Check InvocationEventBase BEFORE QueueItemEventBase since InvocationEventBase
|
||||
inherits from QueueItemEventBase. The order of isinstance checks matters!
|
||||
"""
|
||||
try:
|
||||
event_name, event_data = event
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from invokeai.app.services.events.events_common import InvocationEventBase, QueueItemEventBase
|
||||
|
||||
# Check InvocationEventBase FIRST (before QueueItemEventBase) since it's a subclass
|
||||
# Invocation events (progress, started, complete, error) are private to owner + admins
|
||||
if isinstance(event_data, InvocationEventBase) and hasattr(event_data, "user_id"):
|
||||
user_room = f"user:{event_data.user_id}"
|
||||
|
||||
# Emit to the user's room
|
||||
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room)
|
||||
|
||||
# Also emit to admin room so admins can see all events, but strip image preview data
|
||||
# from InvocationProgressEvent to prevent admins from seeing other users' image content
|
||||
if isinstance(event_data, InvocationProgressEvent):
|
||||
admin_event_data = event_data.model_copy(update={"image": None})
|
||||
await self._sio.emit(event=event_name, data=admin_event_data.model_dump(mode="json"), room="admin")
|
||||
else:
|
||||
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin")
|
||||
|
||||
logger.debug(f"Emitted private invocation event {event_name} to user room {user_room} and admin room")
|
||||
|
||||
# Other queue item events (QueueItemStatusChangedEvent) carry unsanitized
|
||||
# user_id, batch_id, session_id, origin, destination and error metadata.
|
||||
# They are private to the owning user + admins — never broadcast to the
|
||||
# full queue room.
|
||||
elif isinstance(event_data, QueueItemEventBase) and hasattr(event_data, "user_id"):
|
||||
user_room = f"user:{event_data.user_id}"
|
||||
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room)
|
||||
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin")
|
||||
|
||||
logger.debug(f"Emitted private queue item event {event_name} to user room {user_room} and admin room")
|
||||
|
||||
# RecallParametersUpdatedEvent is private - only emit to owner + admins
|
||||
elif isinstance(event_data, RecallParametersUpdatedEvent):
|
||||
user_room = f"user:{event_data.user_id}"
|
||||
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room)
|
||||
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin")
|
||||
logger.debug(f"Emitted private recall_parameters_updated event to user room {user_room} and admin room")
|
||||
|
||||
# BatchEnqueuedEvent carries the enqueuing user's batch_id, origin, and
|
||||
# enqueued counts. Route it privately to the owner + admins so other
|
||||
# users do not observe cross-user batch activity.
|
||||
elif isinstance(event_data, BatchEnqueuedEvent):
|
||||
user_room = f"user:{event_data.user_id}"
|
||||
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room)
|
||||
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin")
|
||||
logger.debug(f"Emitted private batch_enqueued event to user room {user_room} and admin room")
|
||||
|
||||
else:
|
||||
# For remaining queue events (e.g. QueueClearedEvent) that do not
|
||||
# carry user identity, emit to all subscribers in the queue room.
|
||||
await self._sio.emit(
|
||||
event=event_name, data=event_data.model_dump(mode="json"), room=event_data.queue_id
|
||||
)
|
||||
logger.debug(
|
||||
f"Emitted general queue event {event_name} to all subscribers in queue {event_data.queue_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
# Log any unhandled exceptions in event handling to prevent silent failures
|
||||
logger.error(f"Error handling queue event {event[0]}: {e}", exc_info=True)
|
||||
|
||||
async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase | DownloadEventBase]) -> None:
|
||||
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"))
|
||||
|
||||
async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None:
|
||||
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].bulk_download_id)
|
||||
event_name, event_data = event
|
||||
# Route to user-specific + admin rooms so that other authenticated
|
||||
# users cannot learn the bulk_download_item_name (the capability token
|
||||
# needed to fetch the zip from the unauthenticated GET endpoint).
|
||||
# In single-user mode (user_id="system"), fall back to the shared
|
||||
# bulk_download_id room for backward compatibility.
|
||||
if hasattr(event_data, "user_id") and event_data.user_id != "system":
|
||||
user_room = f"user:{event_data.user_id}"
|
||||
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room)
|
||||
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin")
|
||||
else:
|
||||
await self._sio.emit(
|
||||
event=event_name, data=event_data.model_dump(mode="json"), room=event_data.bulk_download_id
|
||||
)
|
||||
|
||||
@@ -17,6 +17,7 @@ from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||
from invokeai.app.api.routers import (
|
||||
app_info,
|
||||
auth,
|
||||
board_images,
|
||||
boards,
|
||||
client_state,
|
||||
@@ -24,6 +25,7 @@ from invokeai.app.api.routers import (
|
||||
images,
|
||||
model_manager,
|
||||
model_relationships,
|
||||
recall_parameters,
|
||||
session_queue,
|
||||
style_presets,
|
||||
utilities,
|
||||
@@ -77,6 +79,50 @@ app = FastAPI(
|
||||
)
|
||||
|
||||
|
||||
class SlidingWindowTokenMiddleware(BaseHTTPMiddleware):
|
||||
"""Refresh the JWT token on each authenticated response.
|
||||
|
||||
When a request includes a valid Bearer token, the response includes a
|
||||
X-Refreshed-Token header with a new token that has a fresh expiry.
|
||||
This implements sliding-window session expiry: the session only expires
|
||||
after a period of *inactivity*, not a fixed time after login.
|
||||
"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
|
||||
response = await call_next(request)
|
||||
|
||||
# Only refresh on mutating requests (POST/PUT/PATCH/DELETE) — these indicate
|
||||
# genuine user activity. GET requests are often background fetches (RTK Query
|
||||
# cache revalidation, refetch-on-focus, etc.) and should not reset the
|
||||
# inactivity timer.
|
||||
if response.status_code < 400 and request.method in ("POST", "PUT", "PATCH", "DELETE"):
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
token = auth_header[7:]
|
||||
try:
|
||||
from datetime import timedelta
|
||||
|
||||
from invokeai.app.api.routers.auth import TOKEN_EXPIRATION_NORMAL, TOKEN_EXPIRATION_REMEMBER_ME
|
||||
from invokeai.app.services.auth.token_service import create_access_token, verify_token
|
||||
|
||||
token_data = verify_token(token)
|
||||
if token_data is not None:
|
||||
# Use the remember_me claim from the token to determine the
|
||||
# correct refresh duration. This avoids the bug where a 7-day
|
||||
# token with <24h remaining would be silently downgraded to 1 day.
|
||||
if token_data.remember_me:
|
||||
expires_delta = timedelta(days=TOKEN_EXPIRATION_REMEMBER_ME)
|
||||
else:
|
||||
expires_delta = timedelta(days=TOKEN_EXPIRATION_NORMAL)
|
||||
|
||||
new_token = create_access_token(token_data, expires_delta)
|
||||
response.headers["X-Refreshed-Token"] = new_token
|
||||
except Exception:
|
||||
pass # Don't fail the request if token refresh fails
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class RedirectRootWithQueryStringMiddleware(BaseHTTPMiddleware):
|
||||
"""When a request is made to the root path with a query string, redirect to the root path without the query string.
|
||||
|
||||
@@ -97,6 +143,7 @@ class RedirectRootWithQueryStringMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
# Add the middleware
|
||||
app.add_middleware(RedirectRootWithQueryStringMiddleware)
|
||||
app.add_middleware(SlidingWindowTokenMiddleware)
|
||||
|
||||
|
||||
# Add event handler
|
||||
@@ -115,12 +162,15 @@ app.add_middleware(
|
||||
allow_credentials=app_config.allow_credentials,
|
||||
allow_methods=app_config.allow_methods,
|
||||
allow_headers=app_config.allow_headers,
|
||||
expose_headers=["X-Refreshed-Token"],
|
||||
)
|
||||
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
|
||||
|
||||
# Include all routers
|
||||
# Authentication router should be first so it's registered before protected routes
|
||||
app.include_router(auth.auth_router, prefix="/api")
|
||||
app.include_router(utilities.utilities_router, prefix="/api")
|
||||
app.include_router(model_manager.model_manager_router, prefix="/api")
|
||||
app.include_router(download_queue.download_queue_router, prefix="/api")
|
||||
@@ -133,6 +183,7 @@ app.include_router(session_queue.session_queue_router, prefix="/api")
|
||||
app.include_router(workflows.workflows_router, prefix="/api")
|
||||
app.include_router(style_presets.style_presets_router, prefix="/api")
|
||||
app.include_router(client_state.client_state_router, prefix="/api")
|
||||
app.include_router(recall_parameters.recall_parameters_router, prefix="/api")
|
||||
|
||||
app.openapi = get_openapi_func(app)
|
||||
|
||||
|
||||
715
invokeai/app/invocations/anima_denoise.py
Normal file
715
invokeai/app/invocations/anima_denoise.py
Normal file
@@ -0,0 +1,715 @@
|
||||
"""Anima denoising invocation.
|
||||
|
||||
Implements the rectified flow denoising loop for Anima models:
|
||||
- Direct prediction: denoised = input - output * sigma
|
||||
- Fixed shift=3.0 via loglinear_timestep_shift (Flux paper by Black Forest Labs)
|
||||
- Timestep convention: timestep = sigma * 1.0 (raw sigma, NOT 1-sigma like Z-Image)
|
||||
- NO v-prediction negation (unlike Z-Image)
|
||||
- 3D latent space: [B, C, T, H, W] with T=1 for images
|
||||
- 16 latent channels, 8x spatial compression
|
||||
|
||||
Key differences from Z-Image denoise:
|
||||
- Anima uses fixed shift=3.0, Z-Image uses dynamic shift based on resolution
|
||||
- Anima: timestep = sigma (raw), Z-Image: model_t = 1.0 - sigma
|
||||
- Anima: noise_pred = model_output (direct), Z-Image: noise_pred = -model_output (v-pred)
|
||||
- Anima transformer takes (x, timesteps, context, t5xxl_ids, t5xxl_weights)
|
||||
- Anima uses 3D latents directly, Z-Image converts 4D -> list of 5D
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from contextlib import ExitStack
|
||||
from typing import Callable, Iterator, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as tv_transforms
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
AnimaConditioningField,
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
)
|
||||
from invokeai.app.invocations.model import TransformerField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.anima.anima_transformer_patch import patch_anima_for_regional_prompting
|
||||
from invokeai.backend.anima.conditioning_data import AnimaRegionalTextConditioning, AnimaTextConditioning
|
||||
from invokeai.backend.anima.regional_prompting import AnimaRegionalPromptingExtension
|
||||
from invokeai.backend.flux.schedulers import ANIMA_SCHEDULER_LABELS, ANIMA_SCHEDULER_MAP, ANIMA_SCHEDULER_NAME_VALUES
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.lora_conversions.anima_lora_constants import ANIMA_LORA_TRANSFORMER_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import (
|
||||
RectifiedFlowInpaintExtension,
|
||||
assert_broadcastable,
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import AnimaConditioningInfo, Range
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
# Anima uses 8x spatial compression (VAE downsamples by 2^3)
|
||||
ANIMA_LATENT_SCALE_FACTOR = 8
|
||||
# Anima uses 16 latent channels
|
||||
ANIMA_LATENT_CHANNELS = 16
|
||||
# Anima uses fixed shift=3.0 for the rectified flow schedule
|
||||
ANIMA_SHIFT = 3.0
|
||||
# Anima uses raw sigma values as timesteps (no rescaling)
|
||||
ANIMA_MULTIPLIER = 1.0
|
||||
|
||||
|
||||
def loglinear_timestep_shift(alpha: float, t: float) -> float:
    """Map a linear timestep onto Anima's shifted noise schedule.

    Applies the log-linear shift used by rectified-flow models (introduced
    with Flux, Black Forest Labs 2024): sigma = alpha * t / (1 + (alpha - 1) * t).
    For alpha > 1 the schedule spends proportionally more steps at high noise.

    Args:
        alpha: Shift factor (3.0 for Anima; resolution-dependent for Flux).
        t: Linear timestep in [0, 1].

    Returns:
        The shifted sigma value.
    """
    # alpha == 1 is the identity shift; return the input untouched.
    if alpha == 1.0:
        return t
    numerator = alpha * t
    denominator = 1 + (alpha - 1) * t
    return numerator / denominator
|
||||
|
||||
|
||||
def inverse_loglinear_timestep_shift(alpha: float, sigma: float) -> float:
    """Map a shifted sigma back to its linear timestep.

    Inverts loglinear_timestep_shift: given sigma = alpha * t / (1 + (alpha - 1) * t),
    solves for t = sigma / (alpha - (alpha - 1) * sigma).

    The inpainting extension needs linear t for its gradient-mask thresholding.
    With Anima's shift of 3.0 the gap between sigma and linear t is large
    (t=0.5 maps to sigma=0.75), so thresholding on raw sigma would reveal the
    mask far too aggressively.

    Args:
        alpha: Shift factor (3.0 for Anima).
        sigma: Shifted sigma value in [0, 1].

    Returns:
        The linear timestep t in [0, 1].
    """
    # Identity shift: sigma already equals the linear timestep.
    if alpha == 1.0:
        return sigma
    denom = alpha - (alpha - 1) * sigma
    # Guard the division: a vanishing denominator corresponds to t -> 1.
    if abs(denom) < 1e-8:
        return 1.0
    return sigma / denom
|
||||
|
||||
|
||||
class AnimaInpaintExtension(RectifiedFlowInpaintExtension):
    """Inpainting support for Anima's shifted rectified-flow schedule.

    Anima applies a fixed time-SNR shift (shift=3.0), so the sigma values the
    denoiser sees are much larger than the underlying linear t values. The
    parent class uses a single t value for both its gradient-mask thresholding
    and its noise mixing, assuming a linear schedule. This subclass splits the
    two concerns:

    - gradient-mask thresholding uses the recovered LINEAR t, so the
      progressive reveal happens at the intended pace, while
    - noise mixing uses the SHIFTED sigma, so preserved regions carry the same
      noise level the denoiser expects at that step.
    """

    def __init__(
        self,
        init_latents: torch.Tensor,
        inpaint_mask: torch.Tensor,
        noise: torch.Tensor,
        shift: float = ANIMA_SHIFT,
    ):
        # Fail fast if the three tensors cannot be broadcast together.
        assert_broadcastable(init_latents.shape, inpaint_mask.shape, noise.shape)
        self._shift = shift
        self._noise = noise
        self._inpaint_mask = inpaint_mask
        self._init_latents = init_latents

    def merge_intermediate_latents_with_init_latents(
        self, intermediate_latents: torch.Tensor, sigma_prev: float
    ) -> torch.Tensor:
        """Blend denoised latents with re-noised init latents for this step.

        Args:
            intermediate_latents: Denoised latents after the current step.
            sigma_prev: The SHIFTED sigma value for the next step.
        """
        # Threshold the gradient mask in linear time so the reveal pace
        # matches the user-facing step schedule rather than shifted sigma.
        linear_t = inverse_loglinear_timestep_shift(self._shift, sigma_prev)
        mask = self._apply_mask_gradient_adjustment(linear_t)

        # Re-noise the preserved regions to the denoiser's current noise level
        # (shifted sigma) — the Euler step left the latents at sigma_prev, so
        # the preserved regions must sit at the same level before compositing.
        preserved = sigma_prev * self._noise + (1.0 - sigma_prev) * self._init_latents

        return mask * intermediate_latents + (1.0 - mask) * preserved
|
||||
|
||||
|
||||
@invocation(
    "anima_denoise",
    title="Denoise - Anima",
    tags=["image", "anima"],
    category="image",
    version="1.2.0",
    classification=Classification.Prototype,
)
class AnimaDenoiseInvocation(BaseInvocation):
    """Run the denoising process with an Anima model.

    Uses rectified flow sampling with shift=3.0 and the Cosmos Predict2 DiT
    backbone with integrated LLM Adapter for text conditioning.

    Supports txt2img, img2img (via latents input), and inpainting (via denoise_mask).
    """

    # If latents is provided, this means we are doing image-to-image.
    latents: Optional[LatentsField] = InputField(
        default=None, description=FieldDescriptions.latents, input=Input.Connection
    )
    # denoise_mask is used for inpainting. Only the masked region is modified.
    denoise_mask: Optional[DenoiseMaskField] = InputField(
        default=None, description=FieldDescriptions.denoise_mask, input=Input.Connection
    )
    # Fraction of the schedule at which denoising starts (img2img strength control).
    denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
    # Fraction of the schedule at which denoising stops (for multi-stage pipelines).
    denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
    add_noise: bool = InputField(default=True, description="Add noise based on denoising start.")
    transformer: TransformerField = InputField(
        description="Anima transformer model.", input=Input.Connection, title="Transformer"
    )
    # A list of conditionings enables regional prompting (per-region masks).
    positive_conditioning: AnimaConditioningField | list[AnimaConditioningField] = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection
    )
    negative_conditioning: AnimaConditioningField | list[AnimaConditioningField] | None = InputField(
        default=None, description=FieldDescriptions.negative_cond, input=Input.Connection
    )
    guidance_scale: float = InputField(
        default=4.5,
        ge=1.0,
        description="Guidance scale for classifier-free guidance. Recommended: 4.0-5.0 for Anima.",
        title="Guidance Scale",
    )
    width: int = InputField(default=1024, multiple_of=8, description="Width of the generated image.")
    height: int = InputField(default=1024, multiple_of=8, description="Height of the generated image.")
    steps: int = InputField(default=30, gt=0, description="Number of denoising steps. 30 recommended for Anima.")
    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
    scheduler: ANIMA_SCHEDULER_NAME_VALUES = InputField(
        default="euler",
        description="Scheduler (sampler) for the denoising process.",
        ui_choice_labels=ANIMA_SCHEDULER_LABELS,
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        """Run the full denoising loop and persist the resulting latents."""
        latents = self._run_diffusion(context)
        # Move to CPU before saving so the tensor store never holds GPU memory.
        latents = latents.detach().to("cpu")
        name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

    def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
        """Prepare the inpaint mask for Anima.

        Anima uses 3D latents [B, C, T, H, W] internally but the mask operates
        on the spatial dimensions [B, C, H, W] which match the squeezed output.

        Returns:
            The resized mask on the latents' device/dtype, or None when no
            denoise_mask input was connected.
        """
        if self.denoise_mask is None:
            return None
        mask = context.tensors.load(self.denoise_mask.mask_name)

        # Invert mask: 0.0 = regions to denoise, 1.0 = regions to preserve
        mask = 1.0 - mask

        # Resize the (image-resolution) mask down to the latent grid.
        _, _, latent_height, latent_width = latents.shape
        mask = tv_resize(
            img=mask,
            size=[latent_height, latent_width],
            interpolation=tv_transforms.InterpolationMode.BILINEAR,
            antialias=False,
        )

        mask = mask.to(device=latents.device, dtype=latents.dtype)
        return mask

    def _get_noise(
        self,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: torch.device,
        seed: int,
    ) -> torch.Tensor:
        """Generate initial noise tensor in 3D latent space [B, C, T, H, W]."""
        # Always sample on CPU so results are reproducible across devices.
        rand_device = "cpu"
        return torch.randn(
            1,
            ANIMA_LATENT_CHANNELS,
            1,  # T=1 for single image
            height // ANIMA_LATENT_SCALE_FACTOR,
            width // ANIMA_LATENT_SCALE_FACTOR,
            device=rand_device,
            dtype=torch.float32,
            generator=torch.Generator(device=rand_device).manual_seed(seed),
        ).to(device=device, dtype=dtype)

    def _get_sigmas(self, num_steps: int) -> list[float]:
        """Generate sigma schedule with fixed shift=3.0.

        Uses the log-linear timestep shift from the Flux model (Black Forest Labs)
        with a fixed shift factor of 3.0 (no dynamic resolution-based shift).

        Returns:
            List of num_steps + 1 sigma values from ~1.0 (noise) to 0.0 (clean).
        """
        sigmas = []
        for i in range(num_steps + 1):
            # Linear t descends from 1.0 to 0.0; each value is then shifted.
            t = 1.0 - i / num_steps
            sigma = loglinear_timestep_shift(ANIMA_SHIFT, t)
            sigmas.append(sigma)
        return sigmas

    def _load_conditioning(
        self,
        context: InvocationContext,
        cond_field: AnimaConditioningField,
        dtype: torch.dtype,
        device: torch.device,
    ) -> AnimaConditioningInfo:
        """Load Anima conditioning data from storage."""
        cond_data = context.conditioning.load(cond_field.conditioning_name)
        # Anima conditioning fields are expected to carry exactly one entry.
        assert len(cond_data.conditionings) == 1
        cond_info = cond_data.conditionings[0]
        assert isinstance(cond_info, AnimaConditioningInfo)
        return cond_info.to(dtype=dtype, device=device)

    def _load_text_conditionings(
        self,
        context: InvocationContext,
        cond_field: AnimaConditioningField | list[AnimaConditioningField],
        img_token_height: int,
        img_token_width: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> list[AnimaTextConditioning]:
        """Load Anima text conditioning with optional regional masks.

        Args:
            context: The invocation context.
            cond_field: Single conditioning field or list of fields.
            img_token_height: Height of the image token grid (H // patch_size).
            img_token_width: Width of the image token grid (W // patch_size).
            dtype: Target dtype.
            device: Target device.

        Returns:
            List of AnimaTextConditioning objects with optional masks.
        """
        # Normalize to a list so single- and multi-region inputs share a path.
        cond_list = cond_field if isinstance(cond_field, list) else [cond_field]

        text_conditionings: list[AnimaTextConditioning] = []
        for cond in cond_list:
            cond_info = self._load_conditioning(context, cond, dtype, device)

            # Load the mask, if provided
            mask: torch.Tensor | None = None
            if cond.mask is not None:
                mask = context.tensors.load(cond.mask.tensor_name)
                mask = mask.to(device=device)
                # Resample the mask to the image token grid for cross-attention.
                mask = AnimaRegionalPromptingExtension.preprocess_regional_prompt_mask(
                    mask, img_token_height, img_token_width, dtype, device
                )

            text_conditionings.append(
                AnimaTextConditioning(
                    qwen3_embeds=cond_info.qwen3_embeds,
                    t5xxl_ids=cond_info.t5xxl_ids,
                    t5xxl_weights=cond_info.t5xxl_weights,
                    mask=mask,
                )
            )

        return text_conditionings

    def _run_llm_adapter_for_regions(
        self,
        transformer,
        text_conditionings: list[AnimaTextConditioning],
        dtype: torch.dtype,
    ) -> AnimaRegionalTextConditioning:
        """Run the LLM Adapter separately for each regional conditioning and concatenate.

        Args:
            transformer: The AnimaTransformer instance (must be on device).
            text_conditionings: List of per-region conditioning data.
            dtype: Inference dtype.

        Returns:
            AnimaRegionalTextConditioning with concatenated context and masks.
        """
        context_embeds_list: list[torch.Tensor] = []
        context_ranges: list[Range] = []
        image_masks: list[torch.Tensor | None] = []
        # Running offset into the concatenated context sequence.
        cur_len = 0

        for tc in text_conditionings:
            # Add the batch dim expected by preprocess_text_embeds.
            qwen3_embeds = tc.qwen3_embeds.unsqueeze(0)  # (1, seq_len, 1024)
            t5xxl_ids = tc.t5xxl_ids.unsqueeze(0)  # (1, seq_len)
            t5xxl_weights = None
            if tc.t5xxl_weights is not None:
                t5xxl_weights = tc.t5xxl_weights.unsqueeze(0).unsqueeze(-1)  # (1, seq_len, 1)

            # Run the LLM Adapter to produce context for this region
            context = transformer.preprocess_text_embeds(
                qwen3_embeds.to(dtype=dtype),
                t5xxl_ids,
                t5xxl_weights=t5xxl_weights.to(dtype=dtype) if t5xxl_weights is not None else None,
            )
            # context shape: (1, 512, 1024) — squeeze batch dim
            context_2d = context.squeeze(0)  # (512, 1024)

            # Record this region's slice within the concatenated sequence.
            context_embeds_list.append(context_2d)
            context_ranges.append(Range(start=cur_len, end=cur_len + context_2d.shape[0]))
            image_masks.append(tc.mask)
            cur_len += context_2d.shape[0]

        concatenated_context = torch.cat(context_embeds_list, dim=0)

        return AnimaRegionalTextConditioning(
            context_embeds=concatenated_context,
            image_masks=image_masks,
            context_ranges=context_ranges,
        )

    def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
        """Execute the full denoising loop; returns 4D latents [B, C, H, W]."""
        device = TorchDevice.choose_torch_device()
        inference_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)

        if self.denoising_start >= self.denoising_end:
            raise ValueError(
                f"denoising_start ({self.denoising_start}) must be less than denoising_end ({self.denoising_end})."
            )

        transformer_info = context.models.load(self.transformer.transformer)

        # Compute image token grid dimensions for regional prompting
        # Anima: 8x VAE compression, 2x patch size → 16x total
        patch_size = 2
        latent_height = self.height // ANIMA_LATENT_SCALE_FACTOR
        latent_width = self.width // ANIMA_LATENT_SCALE_FACTOR
        img_token_height = latent_height // patch_size
        img_token_width = latent_width // patch_size
        img_seq_len = img_token_height * img_token_width

        # Load positive conditioning with optional regional masks
        pos_text_conditionings = self._load_text_conditionings(
            context=context,
            cond_field=self.positive_conditioning,
            img_token_height=img_token_height,
            img_token_width=img_token_width,
            dtype=inference_dtype,
            device=device,
        )
        # Regional prompting is active when there are multiple prompts or any mask.
        has_regional = len(pos_text_conditionings) > 1 or any(tc.mask is not None for tc in pos_text_conditionings)

        # Load negative conditioning if CFG is enabled
        do_cfg = not math.isclose(self.guidance_scale, 1.0) and self.negative_conditioning is not None
        neg_text_conditionings: list[AnimaTextConditioning] | None = None
        if do_cfg:
            assert self.negative_conditioning is not None
            neg_text_conditionings = self._load_text_conditionings(
                context=context,
                cond_field=self.negative_conditioning,
                img_token_height=img_token_height,
                img_token_width=img_token_width,
                dtype=inference_dtype,
                device=device,
            )

        # Generate sigma schedule
        sigmas = self._get_sigmas(self.steps)

        # Apply denoising_start and denoising_end clipping (for img2img/inpaint)
        if self.denoising_start > 0 or self.denoising_end < 1:
            total_sigmas = len(sigmas)
            start_idx = int(self.denoising_start * (total_sigmas - 1))
            end_idx = int(self.denoising_end * (total_sigmas - 1)) + 1
            sigmas = sigmas[start_idx:end_idx]

        total_steps = len(sigmas) - 1

        # Load input latents if provided (image-to-image)
        init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
        if init_latents is not None:
            init_latents = init_latents.to(device=device, dtype=inference_dtype)
            # Anima denoiser works in 3D: add temporal dim if needed
            if init_latents.ndim == 4:
                init_latents = init_latents.unsqueeze(2)  # [B, C, H, W] -> [B, C, 1, H, W]

        # Generate initial noise (3D latent: [B, C, T, H, W])
        noise = self._get_noise(self.height, self.width, inference_dtype, device, self.seed)

        # Prepare input latents
        if init_latents is not None:
            if self.add_noise:
                # Rectified-flow noising: linear blend at the starting sigma.
                s_0 = sigmas[0]
                latents = s_0 * noise + (1.0 - s_0) * init_latents
            else:
                latents = init_latents
        else:
            if self.denoising_start > 1e-5:
                raise ValueError("denoising_start should be 0 when initial latents are not provided.")
            latents = noise

        # Clipped schedule can leave zero steps; return input unchanged (4D).
        if total_steps <= 0:
            return latents.squeeze(2)

        # Prepare inpaint extension
        inpaint_mask = self._prep_inpaint_mask(context, latents.squeeze(2))
        inpaint_extension: AnimaInpaintExtension | None = None
        if inpaint_mask is not None:
            if init_latents is None:
                raise ValueError("Initial latents are required when using an inpaint mask (image-to-image inpainting)")
            # The extension operates on 4D tensors, hence the squeeze(2) calls.
            inpaint_extension = AnimaInpaintExtension(
                init_latents=init_latents.squeeze(2),
                inpaint_mask=inpaint_mask,
                noise=noise.squeeze(2),
                shift=ANIMA_SHIFT,
            )

        step_callback = self._build_step_callback(context)

        # Initialize diffusers scheduler if not using built-in Euler
        scheduler: SchedulerMixin | None = None
        use_scheduler = self.scheduler != "euler"

        if use_scheduler:
            scheduler_class = ANIMA_SCHEDULER_MAP[self.scheduler]
            # shift=1.0: our sigmas are already shifted by _get_sigmas.
            scheduler = scheduler_class(num_train_timesteps=1000, shift=1.0)
            is_lcm = self.scheduler == "lcm"
            # Pass explicit sigmas when the scheduler supports it (except LCM,
            # which derives its own schedule from num_inference_steps).
            set_timesteps_sig = inspect.signature(scheduler.set_timesteps)
            if not is_lcm and "sigmas" in set_timesteps_sig.parameters:
                scheduler.set_timesteps(sigmas=sigmas, device=device)
            else:
                scheduler.set_timesteps(num_inference_steps=total_steps, device=device)
            # Second-order schedulers may expose more internal steps than user steps.
            num_scheduler_steps = len(scheduler.timesteps)
        else:
            num_scheduler_steps = total_steps

        with ExitStack() as exit_stack:
            (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())

            # Apply LoRA models to the transformer.
            # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
            exit_stack.enter_context(
                LayerPatcher.apply_smart_model_patches(
                    model=transformer,
                    patches=self._lora_iterator(context),
                    prefix=ANIMA_LORA_TRANSFORMER_PREFIX,
                    dtype=inference_dtype,
                    cached_weights=cached_weights,
                )
            )

            # Run LLM Adapter for each regional conditioning to produce context vectors.
            # This must happen with the transformer on device since it uses the adapter weights.
            if has_regional:
                pos_regional = self._run_llm_adapter_for_regions(transformer, pos_text_conditionings, inference_dtype)
                pos_context = pos_regional.context_embeds.unsqueeze(0)  # (1, total_ctx_len, 1024)

                # Build regional prompting extension with cross-attention mask
                regional_extension = AnimaRegionalPromptingExtension.from_regional_conditioning(
                    pos_regional, img_seq_len
                )

                # For negative, concatenate all regions without masking (matches Z-Image behavior)
                neg_context = None
                if do_cfg and neg_text_conditionings is not None:
                    neg_regional = self._run_llm_adapter_for_regions(
                        transformer, neg_text_conditionings, inference_dtype
                    )
                    neg_context = neg_regional.context_embeds.unsqueeze(0)
            else:
                # Single conditioning — run LLM Adapter via normal forward path
                tc = pos_text_conditionings[0]
                pos_qwen3_embeds = tc.qwen3_embeds.unsqueeze(0)
                pos_t5xxl_ids = tc.t5xxl_ids.unsqueeze(0)
                pos_t5xxl_weights = None
                if tc.t5xxl_weights is not None:
                    pos_t5xxl_weights = tc.t5xxl_weights.unsqueeze(0).unsqueeze(-1)

                # Pre-compute context via LLM Adapter
                pos_context = transformer.preprocess_text_embeds(
                    pos_qwen3_embeds.to(dtype=inference_dtype),
                    pos_t5xxl_ids,
                    t5xxl_weights=pos_t5xxl_weights.to(dtype=inference_dtype)
                    if pos_t5xxl_weights is not None
                    else None,
                )

                neg_context = None
                if do_cfg and neg_text_conditionings is not None:
                    ntc = neg_text_conditionings[0]
                    neg_qwen3 = ntc.qwen3_embeds.unsqueeze(0)
                    neg_ids = ntc.t5xxl_ids.unsqueeze(0)
                    neg_weights = None
                    if ntc.t5xxl_weights is not None:
                        neg_weights = ntc.t5xxl_weights.unsqueeze(0).unsqueeze(-1)
                    neg_context = transformer.preprocess_text_embeds(
                        neg_qwen3.to(dtype=inference_dtype),
                        neg_ids,
                        t5xxl_weights=neg_weights.to(dtype=inference_dtype) if neg_weights is not None else None,
                    )

                regional_extension = None

            # Apply regional prompting patch if we have regional masks
            exit_stack.enter_context(patch_anima_for_regional_prompting(transformer, regional_extension))

            # Helper to run transformer with pre-computed context (bypasses LLM Adapter)
            def _run_transformer(ctx: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
                return transformer(
                    x=x.to(transformer.dtype if hasattr(transformer, "dtype") else inference_dtype),
                    timesteps=t,
                    context=ctx,
                    # t5xxl_ids=None skips the LLM Adapter — context is already pre-computed
                )

            if use_scheduler and scheduler is not None:
                # Scheduler-based denoising
                user_step = 0
                pbar = tqdm(total=total_steps, desc="Denoising (Anima)")
                for step_index in range(num_scheduler_steps):
                    sched_timestep = scheduler.timesteps[step_index]
                    # Scheduler timesteps are in [0, num_train_timesteps); rescale to sigma.
                    sigma_curr = sched_timestep.item() / scheduler.config.num_train_timesteps

                    # Heun-style schedulers evaluate each user step twice; only
                    # the second-order half completes a user-visible step.
                    is_heun = hasattr(scheduler, "state_in_first_order")
                    in_first_order = scheduler.state_in_first_order if is_heun else True

                    timestep = torch.tensor(
                        [sigma_curr * ANIMA_MULTIPLIER], device=device, dtype=inference_dtype
                    ).expand(latents.shape[0])

                    noise_pred_cond = _run_transformer(pos_context, latents, timestep).float()

                    if do_cfg and neg_context is not None:
                        # Classifier-free guidance: extrapolate from unconditional prediction.
                        noise_pred_uncond = _run_transformer(neg_context, latents, timestep).float()
                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
                    else:
                        noise_pred = noise_pred_cond

                    step_output = scheduler.step(model_output=noise_pred, timestep=sched_timestep, sample=latents)
                    latents = step_output.prev_sample

                    # Sigma for the NEXT step; the final step targets sigma=0 (clean).
                    if step_index + 1 < len(scheduler.sigmas):
                        sigma_prev = scheduler.sigmas[step_index + 1].item()
                    else:
                        sigma_prev = 0.0

                    if inpaint_extension is not None:
                        # Inpaint merging works on 4D tensors; drop/restore the T dim.
                        latents_4d = latents.squeeze(2)
                        latents_4d = inpaint_extension.merge_intermediate_latents_with_init_latents(
                            latents_4d, sigma_prev
                        )
                        latents = latents_4d.unsqueeze(2)

                    if is_heun:
                        if not in_first_order:
                            # Second-order half done: count one user-visible step.
                            user_step += 1
                            if user_step <= total_steps:
                                pbar.update(1)
                                step_callback(
                                    PipelineIntermediateState(
                                        step=user_step,
                                        order=2,
                                        total_steps=total_steps,
                                        timestep=int(sigma_curr * 1000),
                                        latents=latents.squeeze(2),
                                    )
                                )
                    else:
                        user_step += 1
                        if user_step <= total_steps:
                            pbar.update(1)
                            step_callback(
                                PipelineIntermediateState(
                                    step=user_step,
                                    order=1,
                                    total_steps=total_steps,
                                    timestep=int(sigma_curr * 1000),
                                    latents=latents.squeeze(2),
                                )
                            )
                pbar.close()
            else:
                # Built-in Euler implementation (default for Anima)
                for step_idx in tqdm(range(total_steps), desc="Denoising (Anima)"):
                    sigma_curr = sigmas[step_idx]
                    sigma_prev = sigmas[step_idx + 1]

                    timestep = torch.tensor(
                        [sigma_curr * ANIMA_MULTIPLIER], device=device, dtype=inference_dtype
                    ).expand(latents.shape[0])

                    noise_pred_cond = _run_transformer(pos_context, latents, timestep).float()

                    if do_cfg and neg_context is not None:
                        noise_pred_uncond = _run_transformer(neg_context, latents, timestep).float()
                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
                    else:
                        noise_pred = noise_pred_cond

                    # Euler step in float32 for accuracy, then restore the original dtype.
                    latents_dtype = latents.dtype
                    latents = latents.to(dtype=torch.float32)
                    latents = latents + (sigma_prev - sigma_curr) * noise_pred
                    latents = latents.to(dtype=latents_dtype)

                    if inpaint_extension is not None:
                        latents_4d = latents.squeeze(2)
                        latents_4d = inpaint_extension.merge_intermediate_latents_with_init_latents(
                            latents_4d, sigma_prev
                        )
                        latents = latents_4d.unsqueeze(2)

                    step_callback(
                        PipelineIntermediateState(
                            step=step_idx + 1,
                            order=1,
                            total_steps=total_steps,
                            timestep=int(sigma_curr * 1000),
                            latents=latents.squeeze(2),
                        ),
                    )

        # Remove temporal dimension for output: [B, C, 1, H, W] -> [B, C, H, W]
        return latents.squeeze(2)

    def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
        """Build a progress callback that reports intermediate states to the UI."""
        def step_callback(state: PipelineIntermediateState) -> None:
            context.util.sd_step_callback(state, BaseModelType.Anima)

        return step_callback

    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
        """Iterate over LoRA models to apply to the transformer."""
        for lora in self.transformer.loras:
            lora_info = context.models.load(lora.lora)
            if not isinstance(lora_info.model, ModelPatchRaw):
                raise TypeError(
                    f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}. "
                    "The LoRA model may be corrupted or incompatible."
                )
            yield (lora_info.model, lora.weight)
            # Drop our reference promptly so the model cache can evict it.
            del lora_info
|
||||
119
invokeai/app/invocations/anima_image_to_latents.py
Normal file
119
invokeai/app/invocations/anima_image_to_latents.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""Anima image-to-latents invocation.
|
||||
|
||||
Encodes an image to latent space using the Anima VAE (AutoencoderKLWan or FLUX VAE).
|
||||
|
||||
For Wan VAE (AutoencoderKLWan):
|
||||
- Input image is converted to 5D tensor [B, C, T, H, W] with T=1
|
||||
- After encoding, latents are normalized: (latents - mean) / std
|
||||
(inverse of the denormalization in anima_latents_to_image.py)
|
||||
|
||||
For FLUX VAE (AutoEncoder):
|
||||
- Encoding is handled internally by the FLUX VAE
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
|
||||
import einops
|
||||
import torch
|
||||
from diffusers.models.autoencoders import AutoencoderKLWan
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder as FluxAutoEncoder
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux
|
||||
|
||||
AnimaVAE = Union[AutoencoderKLWan, FluxAutoEncoder]
|
||||
|
||||
|
||||
@invocation(
    "anima_i2l",
    title="Image to Latents - Anima",
    tags=["image", "latents", "vae", "i2l", "anima"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class AnimaImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generates latents from an image using the Anima VAE (supports Wan 2.1 and FLUX VAE)."""

    # The source image to encode into latent space.
    image: ImageField = InputField(description="The image to encode.")
    vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)

    @staticmethod
    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
        """Encode an image tensor to 4D latents [B, C, H, W] with the loaded VAE.

        Handles both supported VAE types: the FLUX VAE (scales internally) and
        the Wan 2.1 AutoencoderKLWan (5D input, explicit mean/std normalization).
        """
        if not isinstance(vae_info.model, (AutoencoderKLWan, FluxAutoEncoder)):
            raise TypeError(
                f"Expected AutoencoderKLWan or FluxAutoEncoder for Anima VAE, got {type(vae_info.model).__name__}."
            )

        # Reserve working memory before loading the VAE onto the device.
        estimated_working_memory = estimate_vae_working_memory_flux(
            operation="encode",
            image_tensor=image_tensor,
            vae=vae_info.model,
        )

        with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
            # Re-check after load: the on-device object must match the expected types.
            if not isinstance(vae, (AutoencoderKLWan, FluxAutoEncoder)):
                raise TypeError(f"Expected AutoencoderKLWan or FluxAutoEncoder, got {type(vae).__name__}.")

            # Match the VAE's parameter dtype to avoid mixed-precision errors.
            vae_dtype = next(iter(vae.parameters())).dtype
            image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)

            with torch.inference_mode():
                if isinstance(vae, FluxAutoEncoder):
                    # FLUX VAE handles scaling internally
                    # Fixed seed: the sampling generator is deterministic by design here.
                    generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
                    latents = vae.encode(image_tensor, sample=True, generator=generator)
                else:
                    # AutoencoderKLWan expects 5D input [B, C, T, H, W]
                    if image_tensor.ndim == 4:
                        image_tensor = image_tensor.unsqueeze(2)  # [B, C, H, W] -> [B, C, 1, H, W]

                    encoded = vae.encode(image_tensor, return_dict=False)[0]
                    latents = encoded.sample().to(dtype=vae_dtype)

                    # Normalize to denoiser space: (latents - mean) / std
                    # This is the inverse of the denormalization in anima_latents_to_image.py
                    latents_mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(latents)
                    latents_std = torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(latents)
                    latents = (latents - latents_mean) / latents_std

            # Remove temporal dimension: [B, C, 1, H, W] -> [B, C, H, W]
            if latents.ndim == 5:
                latents = latents.squeeze(2)

            return latents

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        """Load the input image, encode it with the Anima VAE, and save the latents."""
        image = context.images.get_pil(self.image.image_name)

        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
        # Ensure a batch dimension: [C, H, W] -> [1, C, H, W].
        if image_tensor.dim() == 3:
            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

        vae_info = context.models.load(self.vae.vae)
        if not isinstance(vae_info.model, (AutoencoderKLWan, FluxAutoEncoder)):
            raise TypeError(
                f"Expected AutoencoderKLWan or FluxAutoEncoder for Anima VAE, got {type(vae_info.model).__name__}."
            )

        context.util.signal_progress("Running Anima VAE encode")
        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)

        # Persist on CPU so the tensor store never references GPU memory.
        latents = latents.to("cpu")
        name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||
108
invokeai/app/invocations/anima_latents_to_image.py
Normal file
108
invokeai/app/invocations/anima_latents_to_image.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Anima latents-to-image invocation.
|
||||
|
||||
Decodes Anima latents using the QwenImage VAE (AutoencoderKLWan) or
|
||||
compatible FLUX VAE as fallback.
|
||||
|
||||
Latents from the denoiser are in normalized space (zero-centered). Before
|
||||
VAE decode, they must be denormalized using the Wan 2.1 per-channel
|
||||
mean/std: latents = latents * std + mean (matching diffusers WanPipeline).
|
||||
|
||||
The VAE expects 5D latents [B, C, T, H, W] — for single images, T=1.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from diffusers.models.autoencoders import AutoencoderKLWan
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder as FluxAutoEncoder
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux
|
||||
|
||||
|
||||
@invocation(
    "anima_l2i",
    title="Latents to Image - Anima",
    tags=["latents", "image", "vae", "l2i", "anima"],
    category="latents",
    version="1.0.2",
    classification=Classification.Prototype,
)
class AnimaLatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generates an image from latents using the Anima VAE.

    Supports the Wan 2.1 QwenImage VAE (AutoencoderKLWan) with explicit
    latent denormalization, and FLUX VAE as fallback.
    """

    # Latents produced by the Anima denoiser (normalized space for the Wan VAE path).
    latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
    # VAE to decode with — AutoencoderKLWan or FLUX AutoEncoder.
    vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Decode the input latents to a PIL image and save it.

        Raises:
            TypeError: If the loaded VAE is neither AutoencoderKLWan nor FluxAutoEncoder.
        """
        latents = context.tensors.load(self.latents.latents_name)

        vae_info = context.models.load(self.vae.vae)
        if not isinstance(vae_info.model, (AutoencoderKLWan, FluxAutoEncoder)):
            raise TypeError(
                f"Expected AutoencoderKLWan or FluxAutoEncoder for Anima VAE, got {type(vae_info.model).__name__}."
            )

        # NOTE(review): working memory is estimated with the FLUX helper for both
        # VAE types — presumably a close-enough upper bound for the Wan VAE; confirm.
        estimated_working_memory = estimate_vae_working_memory_flux(
            operation="decode",
            image_tensor=latents,
            vae=vae_info.model,
        )

        with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
            context.util.signal_progress("Running Anima VAE decode")
            # Re-check on the on-device handle (the cache may hand back a wrapper).
            if not isinstance(vae, (AutoencoderKLWan, FluxAutoEncoder)):
                raise TypeError(f"Expected AutoencoderKLWan or FluxAutoEncoder, got {type(vae).__name__}.")

            # Match the latents to the VAE's own dtype and device before decoding.
            vae_dtype = next(iter(vae.parameters())).dtype
            latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)

            TorchDevice.empty_cache()

            with torch.inference_mode():
                if isinstance(vae, FluxAutoEncoder):
                    # FLUX VAE handles scaling internally, expects 4D [B, C, H, W]
                    img = vae.decode(latents)
                else:
                    # Expects 5D latents [B, C, T, H, W]
                    if latents.ndim == 4:
                        latents = latents.unsqueeze(2)  # [B, C, H, W] -> [B, C, 1, H, W]

                    # Denormalize from denoiser space to raw VAE space
                    # (same as diffusers WanPipeline and ComfyUI Wan21.process_out)
                    latents_mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(latents)
                    latents_std = torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(latents)
                    latents = latents * latents_std + latents_mean

                    decoded = vae.decode(latents, return_dict=False)[0]

                    # Output is 5D [B, C, T, H, W] — squeeze temporal dim
                    if decoded.ndim == 5:
                        decoded = decoded.squeeze(2)
                    img = decoded

            # Map from [-1, 1] model output to 8-bit HWC for PIL.
            img = img.clamp(-1, 1)
            img = rearrange(img[0], "c h w -> h w c")
            img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())

        TorchDevice.empty_cache()

        image_dto = context.images.save(image=img_pil)
        return ImageOutput.build(image_dto)
||||
162
invokeai/app/invocations/anima_lora_loader.py
Normal file
162
invokeai/app/invocations/anima_lora_loader.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, Qwen3EncoderField, TransformerField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
|
||||
|
||||
@invocation_output("anima_lora_loader_output")
class AnimaLoRALoaderOutput(BaseInvocationOutput):
    """Anima LoRA Loader Output.

    Each field is only populated when the corresponding input was connected;
    unconnected targets pass through as None.
    """

    # Transformer with this LoRA appended to its patch list (None if not connected).
    transformer: Optional[TransformerField] = OutputField(
        default=None, description=FieldDescriptions.transformer, title="Anima Transformer"
    )
    # Qwen3 text encoder with this LoRA appended to its patch list (None if not connected).
    qwen3_encoder: Optional[Qwen3EncoderField] = OutputField(
        default=None, description=FieldDescriptions.qwen3_encoder, title="Qwen3 Encoder"
    )
||||
|
||||
|
||||
@invocation(
    "anima_lora_loader",
    title="Apply LoRA - Anima",
    tags=["lora", "model", "anima"],
    category="model",
    version="1.0.0",
    classification=Classification.Prototype,
)
class AnimaLoRALoaderInvocation(BaseInvocation):
    """Apply a LoRA model to an Anima transformer and/or Qwen3 text encoder."""

    lora: ModelIdentifierField = InputField(
        description=FieldDescriptions.lora_model,
        title="LoRA",
        ui_model_base=BaseModelType.Anima,
        ui_model_type=ModelType.LoRA,
    )
    weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
    transformer: TransformerField | None = InputField(
        default=None,
        description=FieldDescriptions.transformer,
        input=Input.Connection,
        title="Anima Transformer",
    )
    qwen3_encoder: Qwen3EncoderField | None = InputField(
        default=None,
        title="Qwen3 Encoder",
        description=FieldDescriptions.qwen3_encoder,
        input=Input.Connection,
    )

    def invoke(self, context: InvocationContext) -> AnimaLoRALoaderOutput:
        """Return deep copies of the connected targets with this LoRA appended.

        Raises:
            ValueError: If the LoRA is unknown, or already applied to a target.
        """
        key = self.lora.key

        if not context.models.exists(key):
            raise ValueError(f"Unknown lora: {key}!")

        # Refuse to apply the same LoRA to a target twice.
        if self.transformer and any(applied.lora.key == key for applied in self.transformer.loras):
            raise ValueError(f'LoRA "{key}" already applied to transformer.')
        if self.qwen3_encoder and any(applied.lora.key == key for applied in self.qwen3_encoder.loras):
            raise ValueError(f'LoRA "{key}" already applied to Qwen3 encoder.')

        result = AnimaLoRALoaderOutput()

        # Deep-copy each connected target so the upstream fields stay untouched,
        # then append the new patch to the copy only.
        if self.transformer is not None:
            result.transformer = self.transformer.model_copy(deep=True)
            result.transformer.loras.append(LoRAField(lora=self.lora, weight=self.weight))
        if self.qwen3_encoder is not None:
            result.qwen3_encoder = self.qwen3_encoder.model_copy(deep=True)
            result.qwen3_encoder.loras.append(LoRAField(lora=self.lora, weight=self.weight))

        return result
||||
|
||||
|
||||
@invocation(
    "anima_lora_collection_loader",
    title="Apply LoRA Collection - Anima",
    tags=["lora", "model", "anima"],
    category="model",
    version="1.0.0",
    classification=Classification.Prototype,
)
class AnimaLoRACollectionLoader(BaseInvocation):
    """Applies a collection of LoRAs to an Anima transformer."""

    loras: Optional[LoRAField | list[LoRAField]] = InputField(
        default=None, description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
    )

    transformer: Optional[TransformerField] = InputField(
        default=None,
        description=FieldDescriptions.transformer,
        input=Input.Connection,
        title="Transformer",
    )
    qwen3_encoder: Qwen3EncoderField | None = InputField(
        default=None,
        title="Qwen3 Encoder",
        description=FieldDescriptions.qwen3_encoder,
        input=Input.Connection,
    )

    def invoke(self, context: InvocationContext) -> AnimaLoRALoaderOutput:
        """Return deep copies of the connected targets with every deduplicated LoRA appended.

        Raises:
            ValueError: If a LoRA is unknown, or not an Anima-base LoRA.
        """
        result = AnimaLoRALoaderOutput()

        # Deep-copy the targets up front; the appends below mutate only the copies.
        if self.transformer is not None:
            result.transformer = self.transformer.model_copy(deep=True)
        if self.qwen3_encoder is not None:
            result.qwen3_encoder = self.qwen3_encoder.model_copy(deep=True)

        # No LoRAs connected: pass the (copied) targets straight through.
        if self.loras is None:
            return result

        lora_fields = self.loras if isinstance(self.loras, list) else [self.loras]
        seen_keys: list[str] = []

        for lora_field in lora_fields:
            if lora_field is None:
                continue
            key = lora_field.lora.key
            # Silently skip duplicates within the collection.
            if key in seen_keys:
                continue

            if not context.models.exists(key):
                raise ValueError(f"Unknown lora: {key}!")

            if lora_field.lora.base is not BaseModelType.Anima:
                raise ValueError(
                    f"LoRA '{key}' is for {lora_field.lora.base.value if lora_field.lora.base else 'unknown'} models, "
                    "not Anima models. Ensure you are using an Anima compatible LoRA."
                )

            seen_keys.append(key)

            if result.transformer is not None:
                result.transformer.loras.append(lora_field)
            if result.qwen3_encoder is not None:
                result.qwen3_encoder.loras.append(lora_field)

        return result
|
||||
102
invokeai/app/invocations/anima_model_loader.py
Normal file
102
invokeai/app/invocations/anima_model_loader.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.model import (
|
||||
ModelIdentifierField,
|
||||
Qwen3EncoderField,
|
||||
T5EncoderField,
|
||||
TransformerField,
|
||||
VAEField,
|
||||
)
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.t5_model_identifier import (
|
||||
preprocess_t5_encoder_model_identifier,
|
||||
preprocess_t5_tokenizer_model_identifier,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
|
||||
|
||||
|
||||
@invocation_output("anima_model_loader_output")
class AnimaModelLoaderOutput(BaseInvocationOutput):
    """Anima model loader output."""

    # DiT + LLM Adapter transformer from the main model (empty LoRA list).
    transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
    # Qwen3 0.6B tokenizer + text encoder pair.
    qwen3_encoder: Qwen3EncoderField = OutputField(description=FieldDescriptions.qwen3_encoder, title="Qwen3 Encoder")
    # Standalone VAE (Wan 2.1 / QwenImage, or FLUX fallback).
    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
    # T5-XXL encoder field — Anima consumes only its tokenizer submodel.
    t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
||||
|
||||
|
||||
@invocation(
    "anima_model_loader",
    title="Main Model - Anima",
    tags=["model", "anima"],
    category="model",
    version="1.3.0",
    classification=Classification.Prototype,
)
class AnimaModelLoaderInvocation(BaseInvocation):
    """Loads an Anima model, outputting its submodels.

    Anima uses:
    - Transformer: Cosmos Predict2 DiT + LLM Adapter (from single-file checkpoint)
    - Qwen3 Encoder: Qwen3 0.6B (standalone single-file)
    - VAE: AutoencoderKLQwenImage / Wan 2.1 VAE (standalone single-file or FLUX VAE)
    - T5 Encoder: T5-XXL model (only the tokenizer submodel is used, for LLM Adapter token IDs)
    """

    model: ModelIdentifierField = InputField(
        description="Anima main model (transformer + LLM adapter).",
        input=Input.Direct,
        ui_model_base=BaseModelType.Anima,
        ui_model_type=ModelType.Main,
        title="Transformer",
    )

    vae_model: ModelIdentifierField = InputField(
        description="Standalone VAE model. Anima uses a Wan 2.1 / QwenImage VAE (16-channel). "
        "A FLUX VAE can also be used as a compatible fallback.",
        input=Input.Direct,
        ui_model_type=ModelType.VAE,
        title="VAE",
    )

    qwen3_encoder_model: ModelIdentifierField = InputField(
        description="Standalone Qwen3 0.6B Encoder model.",
        input=Input.Direct,
        ui_model_type=ModelType.Qwen3Encoder,
        title="Qwen3 Encoder",
    )

    t5_encoder_model: ModelIdentifierField = InputField(
        description="T5-XXL encoder model. The tokenizer submodel is used for Anima text encoding.",
        input=Input.Direct,
        ui_model_type=ModelType.T5Encoder,
        title="T5 Encoder",
    )

    def invoke(self, context: InvocationContext) -> AnimaModelLoaderOutput:
        """Split the selected models into typed submodel fields for downstream nodes."""

        def retarget(identifier: ModelIdentifierField, submodel: SubModelType) -> ModelIdentifierField:
            # Copy the identifier, pointing it at a specific submodel.
            return identifier.model_copy(update={"submodel_type": submodel})

        return AnimaModelLoaderOutput(
            # Transformer always comes from the main model; starts with no LoRAs.
            transformer=TransformerField(transformer=retarget(self.model, SubModelType.Transformer), loras=[]),
            qwen3_encoder=Qwen3EncoderField(
                tokenizer=retarget(self.qwen3_encoder_model, SubModelType.Tokenizer),
                text_encoder=retarget(self.qwen3_encoder_model, SubModelType.TextEncoder),
            ),
            vae=VAEField(vae=retarget(self.vae_model, SubModelType.VAE)),
            # Only the T5 tokenizer submodel is consumed by Anima; the encoder
            # identifier is still carried to satisfy the field's schema.
            t5_encoder=T5EncoderField(
                tokenizer=preprocess_t5_tokenizer_model_identifier(self.t5_encoder_model),
                text_encoder=preprocess_t5_encoder_model_identifier(self.t5_encoder_model),
                loras=[],
            ),
        )
||||
221
invokeai/app/invocations/anima_text_encoder.py
Normal file
221
invokeai/app/invocations/anima_text_encoder.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""Anima text encoder invocation.
|
||||
|
||||
Encodes text using the dual-conditioning pipeline:
|
||||
1. Qwen3 0.6B: Produces hidden states (last layer)
|
||||
2. T5-XXL Tokenizer: Produces token IDs only (no T5 model needed)
|
||||
|
||||
Both outputs are stored together in AnimaConditioningInfo and used by
|
||||
the LLM Adapter inside the transformer during denoising.
|
||||
|
||||
Key differences from Z-Image text encoder:
|
||||
- Anima uses Qwen3 0.6B (base model, NOT instruct) — no chat template
|
||||
- Anima additionally tokenizes with T5-XXL tokenizer to get token IDs
|
||||
- Qwen3 output uses all positions (including padding) for full context
|
||||
"""
|
||||
|
||||
from contextlib import ExitStack
|
||||
from typing import Iterator, Tuple
|
||||
|
||||
import torch
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
AnimaConditioningField,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
TensorField,
|
||||
UIComponent,
|
||||
)
|
||||
from invokeai.app.invocations.model import Qwen3EncoderField, T5EncoderField
|
||||
from invokeai.app.invocations.primitives import AnimaConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.lora_conversions.anima_lora_constants import ANIMA_LORA_QWEN3_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
AnimaConditioningInfo,
|
||||
ConditioningFieldData,
|
||||
)
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
logger = InvokeAILogger.get_logger(__name__)
|
||||
|
||||
# T5-XXL max sequence length for token IDs
|
||||
T5_MAX_SEQ_LEN = 512
|
||||
|
||||
# Safety cap for Qwen3 sequence length to prevent GPU OOM on extremely long prompts.
|
||||
# Qwen3 0.6B supports 32K context but the LLM Adapter doesn't need that much.
|
||||
QWEN3_MAX_SEQ_LEN = 8192
|
||||
|
||||
|
||||
@invocation(
    "anima_text_encoder",
    title="Prompt - Anima",
    tags=["prompt", "conditioning", "anima"],
    category="conditioning",
    version="1.3.0",
    classification=Classification.Prototype,
)
class AnimaTextEncoderInvocation(BaseInvocation):
    """Encodes and preps a prompt for an Anima image.

    Uses Qwen3 0.6B for hidden state extraction and T5-XXL tokenizer for
    token IDs (no T5 model weights needed). Both are combined by the
    LLM Adapter inside the Anima transformer during denoising.
    """

    prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
    qwen3_encoder: Qwen3EncoderField = InputField(
        title="Qwen3 Encoder",
        description=FieldDescriptions.qwen3_encoder,
        input=Input.Connection,
    )
    t5_encoder: T5EncoderField = InputField(
        title="T5 Encoder",
        description=FieldDescriptions.t5_encoder,
        input=Input.Connection,
    )
    mask: TensorField | None = InputField(
        default=None,
        description="A mask defining the region that this conditioning prompt applies to.",
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> AnimaConditioningOutput:
        """Encode the prompt and save the resulting conditioning data."""
        qwen3_embeds, t5xxl_ids, t5xxl_weights = self._encode_prompt(context)

        # Move to CPU for storage
        qwen3_embeds = qwen3_embeds.detach().to("cpu")
        t5xxl_ids = t5xxl_ids.detach().to("cpu")
        t5xxl_weights = t5xxl_weights.detach().to("cpu") if t5xxl_weights is not None else None

        conditioning_data = ConditioningFieldData(
            conditionings=[
                AnimaConditioningInfo(
                    qwen3_embeds=qwen3_embeds,
                    t5xxl_ids=t5xxl_ids,
                    t5xxl_weights=t5xxl_weights,
                )
            ]
        )
        conditioning_name = context.conditioning.save(conditioning_data)
        return AnimaConditioningOutput(
            conditioning=AnimaConditioningField(conditioning_name=conditioning_name, mask=self.mask)
        )

    def _encode_prompt(
        self,
        context: InvocationContext,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
        """Encode prompt using Qwen3 0.6B and T5-XXL tokenizer.

        Returns:
            Tuple of (qwen3_embeds, t5xxl_ids, t5xxl_weights).
            - qwen3_embeds: Shape (seq_len, 1024) — tokenized with padding=False,
              so only real tokens are present. (NOTE(review): an earlier comment
              claimed "includes all positions including padding"; with
              padding=False below there is no padding to include — confirm the
              LLM Adapter is fine with variable-length input.)
            - t5xxl_ids: Shape (seq_len,) — T5-XXL token IDs (unpadded).
            - t5xxl_weights: None (uniform weights for now).
        """
        prompt = self.prompt

        # --- Step 1: Encode with Qwen3 0.6B ---
        text_encoder_info = context.models.load(self.qwen3_encoder.text_encoder)
        tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer)

        with ExitStack() as exit_stack:
            (_, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())
            (_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())

            device = text_encoder.device

            # Apply LoRA models to the text encoder. Registered on the exit
            # stack so patches are unwound even if encoding raises.
            lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
            exit_stack.enter_context(
                LayerPatcher.apply_smart_model_patches(
                    model=text_encoder,
                    patches=self._lora_iterator(context),
                    prefix=ANIMA_LORA_QWEN3_PREFIX,
                    dtype=lora_dtype,
                )
            )

            if not isinstance(text_encoder, PreTrainedModel):
                raise TypeError(f"Expected PreTrainedModel for text encoder, got {type(text_encoder).__name__}.")
            if not isinstance(tokenizer, PreTrainedTokenizerBase):
                raise TypeError(f"Expected PreTrainedTokenizerBase for tokenizer, got {type(tokenizer).__name__}.")

            context.util.signal_progress("Running Qwen3 0.6B text encoder")

            # Anima uses base Qwen3 (not instruct) — tokenize directly, no chat template.
            # A safety cap is applied to prevent GPU OOM on extremely long prompts.
            text_inputs = tokenizer(
                prompt,
                padding=False,
                truncation=True,
                max_length=QWEN3_MAX_SEQ_LEN,
                return_attention_mask=True,
                return_tensors="pt",
            )

            text_input_ids = text_inputs.input_ids
            attention_mask = text_inputs.attention_mask
            if not isinstance(text_input_ids, torch.Tensor) or not isinstance(attention_mask, torch.Tensor):
                raise TypeError("Tokenizer returned unexpected types.")

            # A sequence at exactly the cap almost certainly means truncation occurred.
            if text_input_ids.shape[-1] == QWEN3_MAX_SEQ_LEN:
                logger.warning(
                    f"Prompt was truncated to {QWEN3_MAX_SEQ_LEN} tokens. "
                    "Consider shortening the prompt for best results."
                )

            # Ensure at least 1 token (empty prompts produce 0 tokens with padding=False)
            if text_input_ids.shape[-1] == 0:
                pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
                text_input_ids = torch.tensor([[pad_id]])
                attention_mask = torch.tensor([[1]])

            # Get last hidden state from Qwen3 (final layer output)
            prompt_mask = attention_mask.to(device).bool()
            outputs = text_encoder(
                text_input_ids.to(device),
                attention_mask=prompt_mask,
                output_hidden_states=True,
            )

            if not hasattr(outputs, "hidden_states") or outputs.hidden_states is None:
                raise RuntimeError("Text encoder did not return hidden_states.")
            if len(outputs.hidden_states) < 1:
                raise RuntimeError(f"Expected at least 1 hidden state, got {len(outputs.hidden_states)}.")

            # Use last hidden state; drop the batch dim (batch size is 1).
            qwen3_embeds = outputs.hidden_states[-1][0]  # Shape: (seq_len, 1024)

        # --- Step 2: Tokenize with T5-XXL tokenizer (IDs only, no model) ---
        context.util.signal_progress("Tokenizing with T5-XXL")
        t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
        with t5_tokenizer_info.model_on_device() as (_, t5_tokenizer):
            t5_tokens = t5_tokenizer(
                prompt,
                padding=False,
                truncation=True,
                max_length=T5_MAX_SEQ_LEN,
                return_tensors="pt",
            )
            t5xxl_ids = t5_tokens.input_ids[0]  # Shape: (seq_len,)

        return qwen3_embeds, t5xxl_ids, None

    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
        """Iterate over LoRA models to apply to the Qwen3 text encoder.

        Yields:
            (patch, weight) pairs for LayerPatcher.

        Raises:
            TypeError: If a loaded LoRA is not a ModelPatchRaw.
        """
        for lora in self.qwen3_encoder.loras:
            lora_info = context.models.load(lora.lora)
            if not isinstance(lora_info.model, ModelPatchRaw):
                raise TypeError(
                    f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}. "
                    "The LoRA model may be corrupted or incompatible."
                )
            yield (lora_info.model, lora.weight)
            # Drop the handle promptly between patches to limit cache pressure.
            del lora_info
||||
@@ -56,7 +56,7 @@ class BaseBatchInvocation(BaseInvocation):
|
||||
"image_batch",
|
||||
title="Image Batch",
|
||||
tags=["primitives", "image", "batch", "special"],
|
||||
category="primitives",
|
||||
category="batch",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
@@ -87,7 +87,7 @@ class ImageGeneratorField(BaseModel):
|
||||
"image_generator",
|
||||
title="Image Generator",
|
||||
tags=["primitives", "board", "image", "batch", "special"],
|
||||
category="primitives",
|
||||
category="batch",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
@@ -111,7 +111,7 @@ class ImageGenerator(BaseInvocation):
|
||||
"string_batch",
|
||||
title="String Batch",
|
||||
tags=["primitives", "string", "batch", "special"],
|
||||
category="primitives",
|
||||
category="batch",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
@@ -142,7 +142,7 @@ class StringGeneratorField(BaseModel):
|
||||
"string_generator",
|
||||
title="String Generator",
|
||||
tags=["primitives", "string", "number", "batch", "special"],
|
||||
category="primitives",
|
||||
category="batch",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
@@ -166,7 +166,7 @@ class StringGenerator(BaseInvocation):
|
||||
"integer_batch",
|
||||
title="Integer Batch",
|
||||
tags=["primitives", "integer", "number", "batch", "special"],
|
||||
category="primitives",
|
||||
category="batch",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
@@ -195,7 +195,7 @@ class IntegerGeneratorField(BaseModel):
|
||||
"integer_generator",
|
||||
title="Integer Generator",
|
||||
tags=["primitives", "int", "number", "batch", "special"],
|
||||
category="primitives",
|
||||
category="batch",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
@@ -219,7 +219,7 @@ class IntegerGenerator(BaseInvocation):
|
||||
"float_batch",
|
||||
title="Float Batch",
|
||||
tags=["primitives", "float", "number", "batch", "special"],
|
||||
category="primitives",
|
||||
category="batch",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
@@ -250,7 +250,7 @@ class FloatGeneratorField(BaseModel):
|
||||
"float_generator",
|
||||
title="Float Generator",
|
||||
tags=["primitives", "float", "number", "batch", "special"],
|
||||
category="primitives",
|
||||
category="batch",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
|
||||
@@ -11,7 +11,7 @@ from invokeai.backend.image_util.util import cv2_to_pil, pil_to_cv2
|
||||
"canny_edge_detection",
|
||||
title="Canny Edge Detection",
|
||||
tags=["controlnet", "canny"],
|
||||
category="controlnet",
|
||||
category="controlnet_preprocessors",
|
||||
version="1.0.0",
|
||||
)
|
||||
class CannyEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
27
invokeai/app/invocations/canvas.py
Normal file
27
invokeai/app/invocations/canvas.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
@invocation(
    "canvas_output",
    title="Canvas Output",
    tags=["canvas", "output", "image"],
    category="canvas",
    version="1.0.0",
    use_cache=False,
)
class CanvasOutputInvocation(BaseInvocation):
    """Outputs an image to the canvas staging area.

    Use this node in workflows intended for canvas workflow integration.
    Connect the final image of your workflow to this node to send it
    to the canvas staging area when run via 'Run Workflow on Canvas'."""

    image: ImageField = InputField(description=FieldDescriptions.image)

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Re-save the connected image so it lands in the staging area."""
        pil_image = context.images.get_pil(self.image.image_name)
        saved_dto = context.images.save(image=pil_image)
        return ImageOutput.build(saved_dto)
||||
@@ -33,7 +33,7 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
"cogview4_denoise",
|
||||
title="Denoise - CogView4",
|
||||
tags=["image", "cogview4"],
|
||||
category="image",
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
|
||||
@@ -27,7 +27,7 @@ from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory
|
||||
"cogview4_i2l",
|
||||
title="Image to Latents - CogView4",
|
||||
tags=["image", "latents", "vae", "i2l", "cogview4"],
|
||||
category="image",
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
|
||||
@@ -6,11 +6,11 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
|
||||
from invokeai.app.invocations.model import GlmEncoderField
|
||||
from invokeai.app.invocations.primitives import CogView4ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
CogView4ConditioningInfo,
|
||||
ConditioningFieldData,
|
||||
)
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
# The CogView4 GLM Text Encoder max sequence length set based on the default in diffusers.
|
||||
COGVIEW4_GLM_MAX_SEQ_LEN = 1024
|
||||
@@ -20,7 +20,7 @@ COGVIEW4_GLM_MAX_SEQ_LEN = 1024
|
||||
"cogview4_text_encoder",
|
||||
title="Prompt - CogView4",
|
||||
tags=["prompt", "conditioning", "cogview4"],
|
||||
category="conditioning",
|
||||
category="prompt",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
@@ -37,6 +37,8 @@ class CogView4TextEncoderInvocation(BaseInvocation):
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> CogView4ConditioningOutput:
|
||||
glm_embeds = self._glm_encode(context, max_seq_len=COGVIEW4_GLM_MAX_SEQ_LEN)
|
||||
# Move embeddings to CPU for storage to save VRAM
|
||||
glm_embeds = glm_embeds.detach().to("cpu")
|
||||
conditioning_data = ConditioningFieldData(conditionings=[CogView4ConditioningInfo(glm_embeds=glm_embeds)])
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
return CogView4ConditioningOutput.build(conditioning_name)
|
||||
@@ -45,10 +47,18 @@ class CogView4TextEncoderInvocation(BaseInvocation):
|
||||
prompt = [self.prompt]
|
||||
|
||||
# TODO(ryand): Add model inputs to the invocation rather than hard-coding.
|
||||
glm_text_encoder_info = context.models.load(self.glm_encoder.text_encoder)
|
||||
with (
|
||||
context.models.load(self.glm_encoder.text_encoder).model_on_device() as (_, glm_text_encoder),
|
||||
glm_text_encoder_info.model_on_device() as (_, glm_text_encoder),
|
||||
context.models.load(self.glm_encoder.tokenizer).model_on_device() as (_, glm_tokenizer),
|
||||
):
|
||||
repaired_tensors = glm_text_encoder_info.repair_required_tensors_on_device()
|
||||
device = get_effective_device(glm_text_encoder)
|
||||
if repaired_tensors > 0:
|
||||
context.logger.warning(
|
||||
f"Recovered {repaired_tensors} required GLM tensor(s) onto {device} after a partial device mismatch."
|
||||
)
|
||||
|
||||
context.util.signal_progress("Running GLM text encoder")
|
||||
assert isinstance(glm_text_encoder, GlmModel)
|
||||
assert isinstance(glm_tokenizer, PreTrainedTokenizerFast)
|
||||
@@ -84,9 +94,7 @@ class CogView4TextEncoderInvocation(BaseInvocation):
|
||||
device=text_input_ids.device,
|
||||
)
|
||||
text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
|
||||
prompt_embeds = glm_text_encoder(
|
||||
text_input_ids.to(TorchDevice.choose_torch_device()), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
prompt_embeds = glm_text_encoder(text_input_ids.to(device), output_hidden_states=True).hidden_states[-2]
|
||||
|
||||
assert isinstance(prompt_embeds, torch.Tensor)
|
||||
return prompt_embeds
|
||||
|
||||
@@ -11,9 +11,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.misc import SEED_MAX
|
||||
|
||||
|
||||
@invocation(
|
||||
"range", title="Integer Range", tags=["collection", "integer", "range"], category="collections", version="1.0.0"
|
||||
)
|
||||
@invocation("range", title="Integer Range", tags=["collection", "integer", "range"], category="batch", version="1.0.0")
|
||||
class RangeInvocation(BaseInvocation):
|
||||
"""Creates a range of numbers from start to stop with step"""
|
||||
|
||||
@@ -35,7 +33,7 @@ class RangeInvocation(BaseInvocation):
|
||||
"range_of_size",
|
||||
title="Integer Range of Size",
|
||||
tags=["collection", "integer", "size", "range"],
|
||||
category="collections",
|
||||
category="batch",
|
||||
version="1.0.0",
|
||||
)
|
||||
class RangeOfSizeInvocation(BaseInvocation):
|
||||
@@ -55,7 +53,7 @@ class RangeOfSizeInvocation(BaseInvocation):
|
||||
"random_range",
|
||||
title="Random Range",
|
||||
tags=["range", "integer", "random", "collection"],
|
||||
category="collections",
|
||||
category="batch",
|
||||
version="1.0.1",
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
@@ -11,7 +11,7 @@ from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||
"color_map",
|
||||
title="Color Map",
|
||||
tags=["controlnet"],
|
||||
category="controlnet",
|
||||
category="controlnet_preprocessors",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ColorMapInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@@ -19,6 +19,7 @@ from invokeai.app.invocations.model import CLIPField
|
||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.ti_utils import generate_ti_list
|
||||
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
@@ -42,7 +43,7 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
"compel",
|
||||
title="Prompt - SD1.5",
|
||||
tags=["prompt", "compel"],
|
||||
category="conditioning",
|
||||
category="prompt",
|
||||
version="1.2.1",
|
||||
)
|
||||
class CompelInvocation(BaseInvocation):
|
||||
@@ -103,7 +104,7 @@ class CompelInvocation(BaseInvocation):
|
||||
textual_inversion_manager=ti_manager,
|
||||
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
|
||||
truncate_long_prompts=False,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
device=get_effective_device(text_encoder),
|
||||
split_long_text_mode=SplitLongTextMode.SENTENCES,
|
||||
)
|
||||
|
||||
@@ -212,7 +213,7 @@ class SDXLPromptInvocationBase:
|
||||
truncate_long_prompts=False, # TODO:
|
||||
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
|
||||
requires_pooled=get_pooled,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
device=get_effective_device(text_encoder),
|
||||
split_long_text_mode=SplitLongTextMode.SENTENCES,
|
||||
)
|
||||
|
||||
@@ -247,7 +248,7 @@ class SDXLPromptInvocationBase:
|
||||
"sdxl_compel_prompt",
|
||||
title="Prompt - SDXL",
|
||||
tags=["sdxl", "compel", "prompt"],
|
||||
category="conditioning",
|
||||
category="prompt",
|
||||
version="1.2.1",
|
||||
)
|
||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
@@ -341,7 +342,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"sdxl_refiner_compel_prompt",
|
||||
title="Prompt - SDXL Refiner",
|
||||
tags=["sdxl", "compel", "prompt"],
|
||||
category="conditioning",
|
||||
category="prompt",
|
||||
version="1.1.2",
|
||||
)
|
||||
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
@@ -390,7 +391,7 @@ class CLIPSkipInvocationOutput(BaseInvocationOutput):
|
||||
"clip_skip",
|
||||
title="Apply CLIP Skip - SD1.5, SDXL",
|
||||
tags=["clipskip", "clip", "skip"],
|
||||
category="conditioning",
|
||||
category="prompt",
|
||||
version="1.1.1",
|
||||
)
|
||||
class CLIPSkipInvocation(BaseInvocation):
|
||||
|
||||
@@ -9,7 +9,7 @@ from invokeai.backend.image_util.content_shuffle import content_shuffle
|
||||
"content_shuffle",
|
||||
title="Content Shuffle",
|
||||
tags=["controlnet", "normal"],
|
||||
category="controlnet",
|
||||
category="controlnet_preprocessors",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ContentShuffleInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@@ -64,7 +64,7 @@ class ControlOutput(BaseInvocationOutput):
|
||||
|
||||
|
||||
@invocation(
|
||||
"controlnet", title="ControlNet - SD1.5, SD2, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3"
|
||||
"controlnet", title="ControlNet - SD1.5, SD2, SDXL", tags=["controlnet"], category="conditioning", version="1.1.3"
|
||||
)
|
||||
class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
@@ -116,7 +116,7 @@ class ControlNetInvocation(BaseInvocation):
|
||||
"heuristic_resize",
|
||||
title="Heuristic Resize",
|
||||
tags=["image, controlnet"],
|
||||
category="image",
|
||||
category="controlnet_preprocessors",
|
||||
version="1.1.1",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_t
|
||||
"create_denoise_mask",
|
||||
title="Create Denoise Mask",
|
||||
tags=["mask", "denoise"],
|
||||
category="latents",
|
||||
category="mask",
|
||||
version="1.0.2",
|
||||
)
|
||||
class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
|
||||
@@ -41,7 +41,7 @@ class GradientMaskOutput(BaseInvocationOutput):
|
||||
"create_gradient_mask",
|
||||
title="Create Gradient Mask",
|
||||
tags=["mask", "denoise"],
|
||||
category="latents",
|
||||
category="mask",
|
||||
version="1.3.0",
|
||||
)
|
||||
class CreateGradientMaskInvocation(BaseInvocation):
|
||||
|
||||
@@ -20,7 +20,7 @@ DEPTH_ANYTHING_MODELS = {
|
||||
"depth_anything_depth_estimation",
|
||||
title="Depth Anything Depth Estimation",
|
||||
tags=["controlnet", "depth", "depth anything"],
|
||||
category="controlnet",
|
||||
category="controlnet_preprocessors",
|
||||
version="1.0.0",
|
||||
)
|
||||
class DepthAnythingDepthEstimationInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@@ -11,7 +11,7 @@ from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
|
||||
"dw_openpose_detection",
|
||||
title="DW Openpose Detection",
|
||||
tags=["controlnet", "dwpose", "openpose"],
|
||||
category="controlnet",
|
||||
category="controlnet_preprocessors",
|
||||
version="1.1.1",
|
||||
)
|
||||
class DWOpenposeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
203
invokeai/app/invocations/external_image_generation.py
Normal file
203
invokeai/app/invocations/external_image_generation.py
Normal file
@@ -0,0 +1,203 @@
|
||||
from typing import Any, ClassVar, Literal
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
InputField,
|
||||
MetadataField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageCollectionOutput
|
||||
from invokeai.app.services.external_generation.external_generation_common import (
|
||||
ExternalGenerationRequest,
|
||||
ExternalGenerationResult,
|
||||
ExternalReferenceImage,
|
||||
)
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.configs.external_api import ExternalApiModelConfig, ExternalGenerationMode
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
|
||||
|
||||
|
||||
class BaseExternalImageGenerationInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generate images using an external provider."""
|
||||
|
||||
provider_id: ClassVar[str | None] = None
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.main_model,
|
||||
ui_model_base=[BaseModelType.External],
|
||||
ui_model_type=[ModelType.ExternalImageGenerator],
|
||||
ui_model_format=[ModelFormat.ExternalApi],
|
||||
)
|
||||
mode: ExternalGenerationMode = InputField(default="txt2img", description="Generation mode")
|
||||
prompt: str = InputField(description="Prompt")
|
||||
seed: int | None = InputField(default=None, description=FieldDescriptions.seed)
|
||||
num_images: int = InputField(default=1, gt=0, description="Number of images to generate")
|
||||
width: int = InputField(default=1024, gt=0, description=FieldDescriptions.width)
|
||||
height: int = InputField(default=1024, gt=0, description=FieldDescriptions.height)
|
||||
image_size: str | None = InputField(default=None, description="Image size preset (e.g. 1K, 2K, 4K)")
|
||||
init_image: ImageField | None = InputField(default=None, description="Init image for img2img/inpaint")
|
||||
mask_image: ImageField | None = InputField(default=None, description="Mask image for inpaint")
|
||||
reference_images: list[ImageField] = InputField(default=[], description="Reference images")
|
||||
|
||||
def _build_provider_options(self) -> dict[str, Any] | None:
|
||||
"""Override in provider-specific subclasses to pass extra options."""
|
||||
return None
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
|
||||
model_config = context.models.get_config(self.model)
|
||||
if not isinstance(model_config, ExternalApiModelConfig):
|
||||
raise ValueError("Selected model is not an external API model")
|
||||
|
||||
if self.provider_id is not None and model_config.provider_id != self.provider_id:
|
||||
raise ValueError(
|
||||
f"Selected model provider '{model_config.provider_id}' does not match node provider '{self.provider_id}'"
|
||||
)
|
||||
|
||||
init_image = None
|
||||
if self.init_image is not None:
|
||||
init_image = context.images.get_pil(self.init_image.image_name, mode="RGB")
|
||||
|
||||
mask_image = None
|
||||
if self.mask_image is not None:
|
||||
mask_image = context.images.get_pil(self.mask_image.image_name, mode="L")
|
||||
|
||||
reference_images: list[ExternalReferenceImage] = []
|
||||
for image_field in self.reference_images:
|
||||
reference_image = context.images.get_pil(image_field.image_name, mode="RGB")
|
||||
reference_images.append(ExternalReferenceImage(image=reference_image))
|
||||
|
||||
request = ExternalGenerationRequest(
|
||||
model=model_config,
|
||||
mode=self.mode,
|
||||
prompt=self.prompt,
|
||||
seed=self.seed,
|
||||
num_images=self.num_images,
|
||||
width=self.width,
|
||||
height=self.height,
|
||||
image_size=self.image_size,
|
||||
init_image=init_image,
|
||||
mask_image=mask_image,
|
||||
reference_images=reference_images,
|
||||
metadata=self._build_request_metadata(),
|
||||
provider_options=self._build_provider_options(),
|
||||
)
|
||||
|
||||
result = context._services.external_generation.generate(request)
|
||||
|
||||
outputs: list[ImageField] = []
|
||||
for generated in result.images:
|
||||
metadata = self._build_output_metadata(model_config, result, generated.seed)
|
||||
image_dto = context.images.save(image=generated.image, metadata=metadata)
|
||||
outputs.append(ImageField(image_name=image_dto.image_name))
|
||||
|
||||
return ImageCollectionOutput(collection=outputs)
|
||||
|
||||
def _build_request_metadata(self) -> dict[str, Any] | None:
|
||||
if self.metadata is None:
|
||||
return None
|
||||
return self.metadata.root
|
||||
|
||||
def _build_output_metadata(
|
||||
self,
|
||||
model_config: ExternalApiModelConfig,
|
||||
result: ExternalGenerationResult,
|
||||
image_seed: int | None,
|
||||
) -> MetadataField | None:
|
||||
metadata: dict[str, Any] = {}
|
||||
|
||||
if self.metadata is not None:
|
||||
metadata.update(self.metadata.root)
|
||||
|
||||
metadata.update(
|
||||
{
|
||||
"external_provider": model_config.provider_id,
|
||||
"external_model_id": model_config.provider_model_id,
|
||||
}
|
||||
)
|
||||
|
||||
provider_request_id = getattr(result, "provider_request_id", None)
|
||||
if provider_request_id:
|
||||
metadata["external_request_id"] = provider_request_id
|
||||
|
||||
provider_metadata = getattr(result, "provider_metadata", None)
|
||||
if provider_metadata:
|
||||
metadata["external_provider_metadata"] = provider_metadata
|
||||
|
||||
if image_seed is not None:
|
||||
metadata["external_seed"] = image_seed
|
||||
|
||||
if not metadata:
|
||||
return None
|
||||
return MetadataField(root=metadata)
|
||||
|
||||
|
||||
@invocation(
|
||||
"external_image_generation",
|
||||
title="External Image Generation (Legacy)",
|
||||
tags=["external", "generation"],
|
||||
category="image",
|
||||
version="1.1.0",
|
||||
classification=Classification.Internal,
|
||||
)
|
||||
class ExternalImageGenerationInvocation(BaseExternalImageGenerationInvocation):
|
||||
"""Legacy external image generation node kept for backward compatibility."""
|
||||
|
||||
|
||||
@invocation(
|
||||
"openai_image_generation",
|
||||
title="OpenAI Image Generation",
|
||||
tags=["external", "generation", "openai"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class OpenAIImageGenerationInvocation(BaseExternalImageGenerationInvocation):
|
||||
"""Generate images using an OpenAI-hosted external model."""
|
||||
|
||||
provider_id = "openai"
|
||||
|
||||
quality: Literal["auto", "high", "medium", "low"] = InputField(default="auto", description="Output image quality")
|
||||
background: Literal["auto", "transparent", "opaque"] = InputField(
|
||||
default="auto", description="Background transparency handling"
|
||||
)
|
||||
input_fidelity: Literal["low", "high"] | None = InputField(
|
||||
default=None, description="Fidelity to source images (edits only)"
|
||||
)
|
||||
|
||||
def _build_provider_options(self) -> dict[str, Any]:
|
||||
options: dict[str, Any] = {
|
||||
"quality": self.quality,
|
||||
"background": self.background,
|
||||
}
|
||||
if self.input_fidelity is not None:
|
||||
options["input_fidelity"] = self.input_fidelity
|
||||
return options
|
||||
|
||||
|
||||
@invocation(
|
||||
"gemini_image_generation",
|
||||
title="Gemini Image Generation",
|
||||
tags=["external", "generation", "gemini"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class GeminiImageGenerationInvocation(BaseExternalImageGenerationInvocation):
|
||||
"""Generate images using a Gemini-hosted external model."""
|
||||
|
||||
provider_id = "gemini"
|
||||
|
||||
temperature: float | None = InputField(default=None, ge=0.0, le=2.0, description="Sampling temperature")
|
||||
thinking_level: Literal["minimal", "high"] | None = InputField(
|
||||
default=None, description="Thinking level for image generation"
|
||||
)
|
||||
|
||||
def _build_provider_options(self) -> dict[str, Any] | None:
|
||||
options: dict[str, Any] = {}
|
||||
if self.temperature is not None:
|
||||
options["temperature"] = self.temperature
|
||||
if self.thinking_level is not None:
|
||||
options["thinking_level"] = self.thinking_level
|
||||
return options or None
|
||||
@@ -435,7 +435,9 @@ def get_faces_list(
|
||||
return all_faces
|
||||
|
||||
|
||||
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.2")
|
||||
@invocation(
|
||||
"face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="segmentation", version="1.2.2"
|
||||
)
|
||||
class FaceOffInvocation(BaseInvocation, WithMetadata):
|
||||
"""Bound, extract, and mask a face from an image using MediaPipe detection"""
|
||||
|
||||
@@ -514,7 +516,9 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
|
||||
return output
|
||||
|
||||
|
||||
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.2")
|
||||
@invocation(
|
||||
"face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="segmentation", version="1.2.2"
|
||||
)
|
||||
class FaceMaskInvocation(BaseInvocation, WithMetadata):
|
||||
"""Face mask creation using mediapipe face detection"""
|
||||
|
||||
@@ -617,7 +621,11 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata):
|
||||
|
||||
|
||||
@invocation(
|
||||
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.2"
|
||||
"face_identifier",
|
||||
title="FaceIdentifier",
|
||||
tags=["image", "face", "identifier"],
|
||||
category="segmentation",
|
||||
version="1.2.2",
|
||||
)
|
||||
class FaceIdentifierInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Outputs an image with detected face IDs printed on each face. For use with other FaceTools."""
|
||||
|
||||
@@ -171,6 +171,8 @@ class FieldDescriptions:
|
||||
sd3_model = "SD3 model (MMDiTX) to load"
|
||||
cogview4_model = "CogView4 model (Transformer) to load"
|
||||
z_image_model = "Z-Image model (Transformer) to load"
|
||||
qwen_image_model = "Qwen Image Edit model (Transformer) to load"
|
||||
qwen_vl_encoder = "Qwen2.5-VL tokenizer, processor and text/vision encoder"
|
||||
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
||||
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
||||
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
||||
@@ -340,6 +342,27 @@ class ZImageConditioningField(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class QwenImageConditioningField(BaseModel):
|
||||
"""A Qwen Image Edit conditioning tensor primitive value"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
|
||||
|
||||
class AnimaConditioningField(BaseModel):
|
||||
"""An Anima conditioning tensor primitive value.
|
||||
|
||||
Anima conditioning contains Qwen3 0.6B hidden states and T5-XXL token IDs,
|
||||
which are combined by the LLM Adapter inside the transformer.
|
||||
"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
mask: Optional[TensorField] = Field(
|
||||
default=None,
|
||||
description="The mask associated with this conditioning tensor for regional prompting. "
|
||||
"Excluded regions should be set to False, included regions should be set to True.",
|
||||
)
|
||||
|
||||
|
||||
class ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
@@ -532,7 +555,7 @@ def migrate_model_ui_type(ui_type: UIType | str, json_schema_extra: dict[str, An
|
||||
case UIType.VAEModel:
|
||||
ui_model_type = [ModelType.VAE]
|
||||
case UIType.FluxVAEModel:
|
||||
ui_model_base = [BaseModelType.Flux]
|
||||
ui_model_base = [BaseModelType.Flux, BaseModelType.Flux2]
|
||||
ui_model_type = [ModelType.VAE]
|
||||
case UIType.LoRAModel:
|
||||
ui_model_type = [ModelType.LoRA]
|
||||
|
||||
530
invokeai/app/invocations/flux2_denoise.py
Normal file
530
invokeai/app/invocations/flux2_denoise.py
Normal file
@@ -0,0 +1,530 @@
|
||||
"""Flux2 Klein Denoise Invocation.
|
||||
|
||||
Run denoising process with a FLUX.2 Klein transformer model.
|
||||
Uses Qwen3 conditioning instead of CLIP+T5.
|
||||
"""
|
||||
|
||||
from contextlib import ExitStack
|
||||
from typing import Callable, Iterator, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as tv_transforms
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
FluxConditioningField,
|
||||
FluxKontextConditioningField,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
)
|
||||
from invokeai.app.invocations.model import TransformerField, VAEField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
|
||||
from invokeai.backend.flux.schedulers import FLUX_SCHEDULER_LABELS, FLUX_SCHEDULER_MAP, FLUX_SCHEDULER_NAME_VALUES
|
||||
from invokeai.backend.flux2.denoise import denoise
|
||||
from invokeai.backend.flux2.ref_image_extension import Flux2RefImageExtension
|
||||
from invokeai.backend.flux2.sampling_utils import (
|
||||
compute_empirical_mu,
|
||||
generate_img_ids_flux2,
|
||||
get_noise_flux2,
|
||||
get_schedule_flux2,
|
||||
pack_flux2,
|
||||
unpack_flux2,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.lora_conversions.flux_bfl_peft_lora_conversion_utils import (
|
||||
convert_bfl_lora_patch_to_diffusers,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux2_denoise",
|
||||
title="FLUX2 Denoise",
|
||||
tags=["image", "flux", "flux2", "klein", "denoise"],
|
||||
category="latents",
|
||||
version="1.4.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class Flux2DenoiseInvocation(BaseInvocation):
|
||||
"""Run denoising process with a FLUX.2 Klein transformer model.
|
||||
|
||||
This node is designed for FLUX.2 Klein models which use Qwen3 as the text encoder.
|
||||
It does not support ControlNet, IP-Adapters, or regional prompting.
|
||||
"""
|
||||
|
||||
latents: Optional[LatentsField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.denoise_mask,
|
||||
input=Input.Connection,
|
||||
)
|
||||
denoising_start: float = InputField(
|
||||
default=0.0,
|
||||
ge=0,
|
||||
le=1,
|
||||
description=FieldDescriptions.denoising_start,
|
||||
)
|
||||
denoising_end: float = InputField(
|
||||
default=1.0,
|
||||
ge=0,
|
||||
le=1,
|
||||
description=FieldDescriptions.denoising_end,
|
||||
)
|
||||
add_noise: bool = InputField(default=True, description="Add noise based on denoising start.")
|
||||
transformer: TransformerField = InputField(
|
||||
description=FieldDescriptions.flux_model,
|
||||
input=Input.Connection,
|
||||
title="Transformer",
|
||||
)
|
||||
positive_text_conditioning: FluxConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond,
|
||||
input=Input.Connection,
|
||||
)
|
||||
negative_text_conditioning: Optional[FluxConditioningField] = InputField(
|
||||
default=None,
|
||||
description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
|
||||
input=Input.Connection,
|
||||
)
|
||||
cfg_scale: float = InputField(
|
||||
default=1.0,
|
||||
description=FieldDescriptions.cfg_scale,
|
||||
title="CFG Scale",
|
||||
)
|
||||
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
|
||||
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
|
||||
num_steps: int = InputField(
|
||||
default=4,
|
||||
description="Number of diffusion steps. Use 4 for distilled models, 28+ for base models.",
|
||||
)
|
||||
scheduler: FLUX_SCHEDULER_NAME_VALUES = InputField(
|
||||
default="euler",
|
||||
description="Scheduler (sampler) for the denoising process. 'euler' is fast and standard. "
|
||||
"'heun' is 2nd-order (better quality, 2x slower). 'lcm' is optimized for few steps.",
|
||||
ui_choice_labels=FLUX_SCHEDULER_LABELS,
|
||||
)
|
||||
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
|
||||
vae: VAEField = InputField(
|
||||
description="FLUX.2 VAE model (required for BN statistics).",
|
||||
input=Input.Connection,
|
||||
)
|
||||
kontext_conditioning: FluxKontextConditioningField | list[FluxKontextConditioningField] | None = InputField(
|
||||
default=None,
|
||||
description="FLUX Kontext conditioning (reference images for multi-reference image editing).",
|
||||
input=Input.Connection,
|
||||
title="Reference Images",
|
||||
)
|
||||
|
||||
def _get_bn_stats(self, context: InvocationContext) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""Extract BN statistics from the FLUX.2 VAE.
|
||||
|
||||
The FLUX.2 VAE uses batch normalization on the patchified 128-channel representation.
|
||||
IMPORTANT: BFL FLUX.2 VAE uses affine=False, so there are NO learnable weight/bias.
|
||||
|
||||
BN formula (affine=False): y = (x - mean) / std
|
||||
Inverse: x = y * std + mean
|
||||
|
||||
Returns:
|
||||
Tuple of (bn_mean, bn_std) tensors of shape (128,), or None if BN layer not found.
|
||||
"""
|
||||
with context.models.load(self.vae.vae).model_on_device() as (_, vae):
|
||||
# Ensure VAE is in eval mode to prevent BN stats from being updated
|
||||
vae.eval()
|
||||
|
||||
# Try to find the BN layer - it may be at different locations depending on model format
|
||||
bn_layer = None
|
||||
if hasattr(vae, "bn"):
|
||||
bn_layer = vae.bn
|
||||
elif hasattr(vae, "batch_norm"):
|
||||
bn_layer = vae.batch_norm
|
||||
elif hasattr(vae, "encoder") and hasattr(vae.encoder, "bn"):
|
||||
bn_layer = vae.encoder.bn
|
||||
|
||||
if bn_layer is None:
|
||||
return None
|
||||
|
||||
# Verify running statistics are initialized
|
||||
if bn_layer.running_mean is None or bn_layer.running_var is None:
|
||||
return None
|
||||
|
||||
# Get BN running statistics from VAE
|
||||
bn_mean = bn_layer.running_mean.clone() # Shape: (128,)
|
||||
bn_var = bn_layer.running_var.clone() # Shape: (128,)
|
||||
bn_eps = bn_layer.eps if hasattr(bn_layer, "eps") else 1e-4 # BFL uses 1e-4
|
||||
bn_std = torch.sqrt(bn_var + bn_eps)
|
||||
|
||||
return bn_mean, bn_std
|
||||
|
||||
def _bn_normalize(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
bn_mean: torch.Tensor,
|
||||
bn_std: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Apply BN normalization to packed latents.
|
||||
|
||||
BN formula (affine=False): y = (x - mean) / std
|
||||
|
||||
Args:
|
||||
x: Packed latents of shape (B, seq, 128).
|
||||
bn_mean: BN running mean of shape (128,).
|
||||
bn_std: BN running std of shape (128,).
|
||||
|
||||
Returns:
|
||||
Normalized latents of same shape.
|
||||
"""
|
||||
# x: (B, seq, 128), params: (128,) -> broadcast over batch and sequence dims
|
||||
bn_mean = bn_mean.to(x.device, x.dtype)
|
||||
bn_std = bn_std.to(x.device, x.dtype)
|
||||
return (x - bn_mean) / bn_std
|
||||
|
||||
def _bn_denormalize(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
bn_mean: torch.Tensor,
|
||||
bn_std: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Apply BN denormalization to packed latents (inverse of normalization).
|
||||
|
||||
Inverse BN (affine=False): x = y * std + mean
|
||||
|
||||
Args:
|
||||
x: Packed latents of shape (B, seq, 128).
|
||||
bn_mean: BN running mean of shape (128,).
|
||||
bn_std: BN running std of shape (128,).
|
||||
|
||||
Returns:
|
||||
Denormalized latents of same shape.
|
||||
"""
|
||||
# x: (B, seq, 128), params: (128,) -> broadcast over batch and sequence dims
|
||||
bn_mean = bn_mean.to(x.device, x.dtype)
|
||||
bn_std = bn_std.to(x.device, x.dtype)
|
||||
return x * bn_std + bn_mean
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = self._run_diffusion(context)
|
||||
latents = latents.detach().to("cpu")
|
||||
|
||||
name = context.tensors.save(tensor=latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||
|
||||
def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
    """Execute the FLUX.2 rectified-flow denoising loop.

    High-level flow:
      1. Fetch optional BN running statistics from the VAE (used to move
         latents between the transformer's normalized space and VAE space).
      2. Prepare noise / initial latents, conditioning, and the timestep
         schedule (clipped by denoising_start/denoising_end).
      3. Pack latents and apply BN normalization where required.
      4. Load the transformer (with LoRA patches applied) and denoise.
      5. BN-denormalize and unpack the result for VAE decode.

    Args:
        context: Invocation context for model/tensor/conditioning access.

    Returns:
        Unpacked latent tensor (float32), ready for VAE decode.
    """
    inference_dtype = torch.bfloat16
    device = TorchDevice.choose_torch_device()

    # Get BN statistics from VAE for latent denormalization (optional)
    # BFL FLUX.2 VAE uses affine=False, so only mean/std are needed
    # Some VAE formats (e.g. diffusers) may not expose BN stats directly
    bn_stats = self._get_bn_stats(context)
    bn_mean, bn_std = bn_stats if bn_stats is not None else (None, None)

    # Load the input latents, if provided
    init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
    if init_latents is not None:
        init_latents = init_latents.to(device=device, dtype=inference_dtype)

    # Prepare input noise (FLUX.2 uses 32 channels)
    noise = get_noise_flux2(
        num_samples=1,
        height=self.height,
        width=self.width,
        device=device,
        dtype=inference_dtype,
        seed=self.seed,
    )
    b, _c, latent_h, latent_w = noise.shape
    # Packing groups 2x2 latent patches, so packed dims are halved.
    packed_h = latent_h // 2
    packed_w = latent_w // 2

    # Load the conditioning data
    pos_cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
    assert len(pos_cond_data.conditionings) == 1
    pos_flux_conditioning = pos_cond_data.conditionings[0]
    assert isinstance(pos_flux_conditioning, FLUXConditioningInfo)
    pos_flux_conditioning = pos_flux_conditioning.to(dtype=inference_dtype, device=device)

    # Qwen3 stacked embeddings (stored in t5_embeds field for compatibility)
    txt = pos_flux_conditioning.t5_embeds

    # Generate text position IDs (4D format for FLUX.2: T, H, W, L)
    # FLUX.2 uses 4D position coordinates for its rotary position embeddings
    # IMPORTANT: Position IDs must be int64 (long) dtype
    # Diffusers uses: T=0, H=0, W=0, L=0..seq_len-1
    seq_len = txt.shape[1]
    txt_ids = torch.zeros(1, seq_len, 4, device=device, dtype=torch.long)
    txt_ids[..., 3] = torch.arange(seq_len, device=device, dtype=torch.long)  # L coordinate varies

    # Load negative conditioning if provided
    neg_txt = None
    neg_txt_ids = None
    if self.negative_text_conditioning is not None:
        neg_cond_data = context.conditioning.load(self.negative_text_conditioning.conditioning_name)
        assert len(neg_cond_data.conditionings) == 1
        neg_flux_conditioning = neg_cond_data.conditionings[0]
        assert isinstance(neg_flux_conditioning, FLUXConditioningInfo)
        neg_flux_conditioning = neg_flux_conditioning.to(dtype=inference_dtype, device=device)
        neg_txt = neg_flux_conditioning.t5_embeds
        # For text tokens: T=0, H=0, W=0, L=0..seq_len-1 (only L varies per token)
        neg_seq_len = neg_txt.shape[1]
        neg_txt_ids = torch.zeros(1, neg_seq_len, 4, device=device, dtype=torch.long)
        neg_txt_ids[..., 3] = torch.arange(neg_seq_len, device=device, dtype=torch.long)

    # Validate transformer config
    transformer_config = context.models.get_config(self.transformer.transformer)
    assert transformer_config.base == BaseModelType.Flux2 and transformer_config.type == ModelType.Main

    # Calculate the timestep schedule using FLUX.2 specific schedule
    # This matches diffusers' Flux2Pipeline implementation
    # Note: Schedule shifting is handled by the scheduler via mu parameter
    image_seq_len = packed_h * packed_w
    timesteps = get_schedule_flux2(
        num_steps=self.num_steps,
        image_seq_len=image_seq_len,
    )
    # Compute mu for dynamic schedule shifting (used by FlowMatchEulerDiscreteScheduler)
    mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=self.num_steps)

    # Clip the timesteps schedule based on denoising_start and denoising_end
    timesteps = clip_timestep_schedule_fractional(timesteps, self.denoising_start, self.denoising_end)

    # Prepare input latent image
    if init_latents is not None:
        if self.add_noise:
            # Noise the init latents to the first (clipped) timestep.
            t_0 = timesteps[0]
            x = t_0 * noise + (1.0 - t_0) * init_latents
        else:
            x = init_latents
    else:
        if self.denoising_start > 1e-5:
            raise ValueError("denoising_start should be 0 when initial latents are not provided.")
        x = noise

    # If len(timesteps) == 1, then short-circuit
    if len(timesteps) <= 1:
        return x

    # Generate image position IDs (FLUX.2 uses 4D coordinates)
    # Position IDs use int64 dtype like diffusers
    img_ids = generate_img_ids_flux2(h=latent_h, w=latent_w, batch_size=b, device=device)

    # Prepare inpaint mask
    inpaint_mask = self._prep_inpaint_mask(context, x)

    # Pack all latent tensors
    init_latents_packed = pack_flux2(init_latents) if init_latents is not None else None
    inpaint_mask_packed = pack_flux2(inpaint_mask) if inpaint_mask is not None else None
    noise_packed = pack_flux2(noise)
    x = pack_flux2(x)

    # BN normalization for img2img/inpainting:
    # - The init_latents from VAE encode are NOT BN-normalized
    # - The transformer operates in BN-normalized space
    # - We must normalize x, init_latents, AND noise for InpaintExtension
    # - Output MUST be denormalized after denoising before VAE decode
    #
    # This ensures that:
    # 1. x starts in the correct normalized space for the transformer
    # 2. When InpaintExtension merges intermediate_latents with noised_init_latents,
    #    both are in the same scale/space (noise and init_latents must be in same space
    #    for the linear interpolation: noised = noise * t + init * (1-t))
    if bn_mean is not None and bn_std is not None:
        if init_latents_packed is not None:
            init_latents_packed = self._bn_normalize(init_latents_packed, bn_mean, bn_std)
            # Also normalize noise for InpaintExtension - it's used to compute
            # noised_init_latents = noise * t + init_latents * (1-t)
            # Both operands must be in the same normalized space
            noise_packed = self._bn_normalize(noise_packed, bn_mean, bn_std)
        # For img2img/inpainting, x is computed from init_latents and must also be normalized
        # For txt2img, x is pure noise (already N(0,1)) - normalizing it would be incorrect
        # We detect img2img by checking if init_latents was provided
        if init_latents is not None:
            x = self._bn_normalize(x, bn_mean, bn_std)

    # Verify packed dimensions
    assert packed_h * packed_w == x.shape[1]

    # Prepare inpaint extension
    inpaint_extension: Optional[RectifiedFlowInpaintExtension] = None
    if inpaint_mask_packed is not None:
        assert init_latents_packed is not None
        inpaint_extension = RectifiedFlowInpaintExtension(
            init_latents=init_latents_packed,
            inpaint_mask=inpaint_mask_packed,
            noise=noise_packed,
        )

    # Prepare CFG scale list (one entry per denoising step)
    num_steps = len(timesteps) - 1
    cfg_scale_list = [self.cfg_scale] * num_steps

    # Check if we're doing inpainting (have a mask or a clipped schedule)
    is_inpainting = self.denoise_mask is not None or self.denoising_start > 1e-5

    # Create scheduler with FLUX.2 Klein configuration
    # For inpainting/img2img, use manual Euler stepping to preserve the exact timestep schedule
    # For txt2img, use the scheduler with dynamic shifting for optimal results
    scheduler = None
    if self.scheduler in FLUX_SCHEDULER_MAP and not is_inpainting:
        # Only use scheduler for txt2img - use manual Euler for inpainting to preserve exact timesteps
        scheduler_class = FLUX_SCHEDULER_MAP[self.scheduler]
        # FlowMatchHeunDiscreteScheduler only supports num_train_timesteps and shift parameters
        # FlowMatchEulerDiscreteScheduler and FlowMatchLCMScheduler support dynamic shifting
        if self.scheduler == "heun":
            scheduler = scheduler_class(
                num_train_timesteps=1000,
                shift=3.0,
            )
        else:
            scheduler = scheduler_class(
                num_train_timesteps=1000,
                shift=3.0,
                use_dynamic_shifting=True,
                base_shift=0.5,
                max_shift=1.15,
                base_image_seq_len=256,
                max_image_seq_len=4096,
                time_shift_type="exponential",
            )

    # Prepare reference image extension for FLUX.2 Klein built-in editing
    ref_image_extension = None
    if self.kontext_conditioning:
        ref_image_extension = Flux2RefImageExtension(
            context=context,
            ref_image_conditioning=self.kontext_conditioning
            if isinstance(self.kontext_conditioning, list)
            else [self.kontext_conditioning],
            vae_field=self.vae,
            device=device,
            dtype=inference_dtype,
            bn_mean=bn_mean,
            bn_std=bn_std,
        )

    with ExitStack() as exit_stack:
        # Load the transformer model
        (cached_weights, transformer) = exit_stack.enter_context(
            context.models.load(self.transformer.transformer).model_on_device()
        )
        config = transformer_config

        # Determine if the model is quantized
        if config.format in [ModelFormat.Diffusers]:
            model_is_quantized = False
        elif config.format in [
            ModelFormat.BnbQuantizedLlmInt8b,
            ModelFormat.BnbQuantizednf4b,
            ModelFormat.GGUFQuantized,
        ]:
            model_is_quantized = True
        else:
            model_is_quantized = False

        # Apply LoRA models to the transformer
        # (quantized models require sidecar patching; weights can't be modified in place)
        exit_stack.enter_context(
            LayerPatcher.apply_smart_model_patches(
                model=transformer,
                patches=self._lora_iterator(context),
                prefix=FLUX_LORA_TRANSFORMER_PREFIX,
                dtype=inference_dtype,
                cached_weights=cached_weights,
                force_sidecar_patching=model_is_quantized,
            )
        )

        # Prepare reference image conditioning if provided
        img_cond_seq = None
        img_cond_seq_ids = None
        if ref_image_extension is not None:
            # Ensure batch sizes match
            ref_image_extension.ensure_batch_size(x.shape[0])
            img_cond_seq, img_cond_seq_ids = (
                ref_image_extension.ref_image_latents,
                ref_image_extension.ref_image_ids,
            )

        x = denoise(
            model=transformer,
            img=x,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            timesteps=timesteps,
            step_callback=self._build_step_callback(context),
            cfg_scale=cfg_scale_list,
            neg_txt=neg_txt,
            neg_txt_ids=neg_txt_ids,
            scheduler=scheduler,
            mu=mu,
            inpaint_extension=inpaint_extension,
            img_cond_seq=img_cond_seq,
            img_cond_seq_ids=img_cond_seq_ids,
        )

    # Apply BN denormalization if BN stats are available
    # The diffusers Flux2KleinPipeline applies: latents = latents * bn_std + bn_mean
    # This transforms latents from normalized space to VAE's expected input space
    if bn_mean is not None and bn_std is not None:
        x = self._bn_denormalize(x, bn_mean, bn_std)

    x = unpack_flux2(x.float(), self.height, self.width)
    return x
|
||||
|
||||
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> Optional[torch.Tensor]:
    """Load and prepare the denoise mask, or return None when no mask is set.

    The stored mask is inverted (1.0 - mask), resized to the latent
    spatial size with bilinear interpolation, moved to the latents'
    device/dtype, and expanded to the latents' full shape.
    """
    if self.denoise_mask is None:
        return None

    raw_mask = context.tensors.load(self.denoise_mask.mask_name)
    inverted = 1.0 - raw_mask

    latent_height, latent_width = latents.shape[2], latents.shape[3]
    resized = tv_resize(
        img=inverted,
        size=[latent_height, latent_width],
        interpolation=tv_transforms.InterpolationMode.BILINEAR,
        antialias=False,
    )

    return resized.to(device=latents.device, dtype=latents.dtype).expand_as(latents)
|
||||
|
||||
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
    """Yield (patch, weight) pairs for every LoRA attached to the transformer.

    Each raw patch is passed through the BFL->diffusers key conversion,
    since FLUX.2 Klein uses Flux2Transformer2DModel (diffusers naming) but
    LoRAs may have been loaded with BFL naming (e.g. when a Klein 4B LoRA
    is misidentified as FLUX.1).
    """
    for lora_field in self.transformer.loras:
        loaded = context.models.load(lora_field.lora)
        assert isinstance(loaded.model, ModelPatchRaw)
        yield (convert_bfl_lora_patch_to_diffusers(loaded.model), lora_field.weight)
        # Drop our reference promptly so the model cache can reclaim it.
        del loaded
|
||||
|
||||
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
    """Return a callback that reports denoising progress for each step."""

    def on_step(state: PipelineIntermediateState) -> None:
        # Unpack the intermediate latents back to spatial layout before
        # handing the state to the progress-reporting machinery.
        unpacked = unpack_flux2(state.latents.float(), self.height, self.width)
        state.latents = unpacked.squeeze()
        context.util.flux2_step_callback(state)

    return on_step
|
||||
182
invokeai/app/invocations/flux2_klein_lora_loader.py
Normal file
182
invokeai/app/invocations/flux2_klein_lora_loader.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""FLUX.2 Klein LoRA Loader Invocation.
|
||||
|
||||
Applies LoRA models to a FLUX.2 Klein transformer and/or Qwen3 text encoder.
|
||||
Unlike standard FLUX which uses CLIP+T5, Klein uses only Qwen3 for text encoding.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, Qwen3EncoderField, TransformerField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
|
||||
|
||||
@invocation_output("flux2_klein_lora_loader_output")
class Flux2KleinLoRALoaderOutput(BaseInvocationOutput):
    """FLUX.2 Klein LoRA Loader Output.

    Either field may be None when the corresponding input was not connected.
    """

    # Copy of the input transformer field with this LoRA appended to its `loras` list.
    transformer: Optional[TransformerField] = OutputField(
        default=None, description=FieldDescriptions.transformer, title="Transformer"
    )
    # Copy of the input Qwen3 encoder field with this LoRA appended to its `loras` list.
    qwen3_encoder: Optional[Qwen3EncoderField] = OutputField(
        default=None, description=FieldDescriptions.qwen3_encoder, title="Qwen3 Encoder"
    )
|
||||
|
||||
|
||||
@invocation(
    "flux2_klein_lora_loader",
    title="Apply LoRA - Flux2 Klein",
    tags=["lora", "model", "flux", "klein", "flux2"],
    category="model",
    version="1.0.0",
    classification=Classification.Prototype,
)
class Flux2KleinLoRALoaderInvocation(BaseInvocation):
    """Apply a LoRA model to a FLUX.2 Klein transformer and/or Qwen3 text encoder."""

    lora: ModelIdentifierField = InputField(
        description=FieldDescriptions.lora_model,
        title="LoRA",
        ui_model_base=BaseModelType.Flux2,
        ui_model_type=ModelType.LoRA,
    )
    weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
    transformer: TransformerField | None = InputField(
        default=None,
        description=FieldDescriptions.transformer,
        input=Input.Connection,
        title="Transformer",
    )
    qwen3_encoder: Qwen3EncoderField | None = InputField(
        default=None,
        title="Qwen3 Encoder",
        description=FieldDescriptions.qwen3_encoder,
        input=Input.Connection,
    )

    def invoke(self, context: InvocationContext) -> Flux2KleinLoRALoaderOutput:
        """Attach this LoRA to deep copies of the connected model fields.

        Raises:
            ValueError: If the LoRA is unknown, or already applied to a target.
        """
        key = self.lora.key

        if not context.models.exists(key):
            raise ValueError(f"Unknown lora: {key}!")

        # Warn if LoRA variant doesn't match transformer variant
        lora_config = context.models.get_config(key)
        lora_variant = getattr(lora_config, "variant", None)
        if lora_variant and self.transformer is not None:
            transformer_config = context.models.get_config(self.transformer.transformer.key)
            transformer_variant = getattr(transformer_config, "variant", None)
            if transformer_variant and lora_variant != transformer_variant:
                context.logger.warning(
                    f"LoRA variant mismatch: LoRA '{lora_config.name}' is for {lora_variant.value} "
                    f"but transformer is {transformer_variant.value}. This may cause shape errors."
                )

        # Refuse to apply the same LoRA twice to either target.
        if self.transformer and any(existing.lora.key == key for existing in self.transformer.loras):
            raise ValueError(f'LoRA "{key}" already applied to transformer.')
        if self.qwen3_encoder and any(existing.lora.key == key for existing in self.qwen3_encoder.loras):
            raise ValueError(f'LoRA "{key}" already applied to Qwen3 encoder.')

        output = Flux2KleinLoRALoaderOutput()

        # Attach LoRA layers to deep copies so upstream fields stay untouched.
        if self.transformer is not None:
            output.transformer = self.transformer.model_copy(deep=True)
            output.transformer.loras.append(LoRAField(lora=self.lora, weight=self.weight))
        if self.qwen3_encoder is not None:
            output.qwen3_encoder = self.qwen3_encoder.model_copy(deep=True)
            output.qwen3_encoder.loras.append(LoRAField(lora=self.lora, weight=self.weight))

        return output
|
||||
|
||||
|
||||
@invocation(
    "flux2_klein_lora_collection_loader",
    title="Apply LoRA Collection - Flux2 Klein",
    tags=["lora", "model", "flux", "klein", "flux2"],
    category="model",
    version="1.0.0",
    classification=Classification.Prototype,
)
class Flux2KleinLoRACollectionLoader(BaseInvocation):
    """Applies a collection of LoRAs to a FLUX.2 Klein transformer and/or Qwen3 text encoder."""

    loras: Optional[LoRAField | list[LoRAField]] = InputField(
        default=None, description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
    )

    transformer: Optional[TransformerField] = InputField(
        default=None,
        description=FieldDescriptions.transformer,
        input=Input.Connection,
        title="Transformer",
    )
    qwen3_encoder: Qwen3EncoderField | None = InputField(
        default=None,
        title="Qwen3 Encoder",
        description=FieldDescriptions.qwen3_encoder,
        input=Input.Connection,
    )

    def invoke(self, context: InvocationContext) -> Flux2KleinLoRALoaderOutput:
        """Append each unique LoRA in the collection to the connected targets.

        Duplicate keys within the collection are silently skipped.

        Raises:
            ValueError: If a LoRA key does not exist in the model manager.
        """
        output = Flux2KleinLoRALoaderOutput()
        # Normalize input to a list; a single LoRA (or None) is wrapped.
        loras = self.loras if isinstance(self.loras, list) else [self.loras]
        # Keys already applied this invocation; O(1) duplicate detection.
        added_loras: set[str] = set()

        # Deep-copy the targets once up front so upstream fields stay untouched.
        if self.transformer is not None:
            output.transformer = self.transformer.model_copy(deep=True)

        if self.qwen3_encoder is not None:
            output.qwen3_encoder = self.qwen3_encoder.model_copy(deep=True)

        for lora in loras:
            if lora is None:
                continue
            if lora.lora.key in added_loras:
                continue

            if not context.models.exists(lora.lora.key):
                # ValueError (not bare Exception) for consistency with
                # Flux2KleinLoRALoaderInvocation's unknown-LoRA handling.
                raise ValueError(f"Unknown lora: {lora.lora.key}!")

            assert lora.lora.base in (BaseModelType.Flux, BaseModelType.Flux2)

            # Warn if LoRA variant doesn't match transformer variant
            lora_config = context.models.get_config(lora.lora.key)
            lora_variant = getattr(lora_config, "variant", None)
            if lora_variant and self.transformer is not None:
                transformer_config = context.models.get_config(self.transformer.transformer.key)
                transformer_variant = getattr(transformer_config, "variant", None)
                if transformer_variant and lora_variant != transformer_variant:
                    context.logger.warning(
                        f"LoRA variant mismatch: LoRA '{lora_config.name}' is for {lora_variant.value} "
                        f"but transformer is {transformer_variant.value}. This may cause shape errors."
                    )

            added_loras.add(lora.lora.key)

            if self.transformer is not None and output.transformer is not None:
                output.transformer.loras.append(lora)

            if self.qwen3_encoder is not None and output.qwen3_encoder is not None:
                output.qwen3_encoder.loras.append(lora)

        return output
|
||||
222
invokeai/app/invocations/flux2_klein_model_loader.py
Normal file
222
invokeai/app/invocations/flux2_klein_model_loader.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""Flux2 Klein Model Loader Invocation.
|
||||
|
||||
Loads a Flux2 Klein model with its Qwen3 text encoder and VAE.
|
||||
Unlike standard FLUX which uses CLIP+T5, Klein uses only Qwen3.
|
||||
"""
|
||||
|
||||
from typing import Literal, Optional
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.model import (
|
||||
ModelIdentifierField,
|
||||
Qwen3EncoderField,
|
||||
TransformerField,
|
||||
VAEField,
|
||||
)
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
BaseModelType,
|
||||
Flux2VariantType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
Qwen3VariantType,
|
||||
SubModelType,
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("flux2_klein_model_loader_output")
class Flux2KleinModelLoaderOutput(BaseInvocationOutput):
    """Flux2 Klein model loader output."""

    # Transformer submodel of the main model (starts with an empty LoRA list).
    transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
    # Qwen3 tokenizer + text encoder pair (Klein uses Qwen3 instead of CLIP+T5).
    qwen3_encoder: Qwen3EncoderField = OutputField(description=FieldDescriptions.qwen3_encoder, title="Qwen3 Encoder")
    # VAE used for latent encode/decode, resolved from the chosen source.
    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
    # Pass-through of the requested max sequence length for downstream encoding.
    max_seq_len: Literal[256, 512] = OutputField(
        description="The max sequence length for the Qwen3 encoder.",
        title="Max Seq Length",
    )
|
||||
|
||||
|
||||
@invocation(
    "flux2_klein_model_loader",
    title="Main Model - Flux2 Klein",
    tags=["model", "flux", "klein", "qwen3"],
    category="model",
    version="1.0.0",
    classification=Classification.Prototype,
)
class Flux2KleinModelLoaderInvocation(BaseInvocation):
    """Loads a Flux2 Klein model, outputting its submodels.

    Flux2 Klein uses Qwen3 as the text encoder instead of CLIP+T5.
    It uses a 32-channel VAE (AutoencoderKLFlux2) instead of the 16-channel FLUX.1 VAE.

    When using a Diffusers format model, both VAE and Qwen3 encoder are extracted
    automatically from the main model. You can override with standalone models:
    - Transformer: Always from Flux2 Klein main model
    - VAE: From main model (Diffusers) or standalone VAE
    - Qwen3 Encoder: From main model (Diffusers) or standalone Qwen3 model
    """

    model: ModelIdentifierField = InputField(
        description=FieldDescriptions.flux_model,
        input=Input.Direct,
        ui_model_base=BaseModelType.Flux2,
        ui_model_type=ModelType.Main,
        title="Transformer",
    )

    # NOTE: Description corrected — FLUX.2 Klein uses a 32-channel VAE
    # (AutoencoderKLFlux2), not the FLUX.1 16-channel VAE; the previous text
    # contradicted the class docstring and the resolution logic below.
    vae_model: Optional[ModelIdentifierField] = InputField(
        default=None,
        description="Standalone VAE model. FLUX.2 Klein uses a 32-channel VAE (AutoencoderKLFlux2), "
        "not the 16-channel FLUX.1 VAE. If not provided, the VAE is extracted from the main model "
        "(Diffusers format) or the Qwen3 Source model.",
        input=Input.Direct,
        ui_model_base=[BaseModelType.Flux, BaseModelType.Flux2],
        ui_model_type=ModelType.VAE,
        title="VAE",
    )

    qwen3_encoder_model: Optional[ModelIdentifierField] = InputField(
        default=None,
        description="Standalone Qwen3 Encoder model. "
        "If not provided, encoder will be loaded from the Qwen3 Source model.",
        input=Input.Direct,
        ui_model_type=ModelType.Qwen3Encoder,
        title="Qwen3 Encoder",
    )

    qwen3_source_model: Optional[ModelIdentifierField] = InputField(
        default=None,
        description="Diffusers Flux2 Klein model to extract VAE and/or Qwen3 encoder from. "
        "Use this if you don't have separate VAE/Qwen3 models. "
        "Ignored if both VAE and Qwen3 Encoder are provided separately.",
        input=Input.Direct,
        ui_model_base=BaseModelType.Flux2,
        ui_model_type=ModelType.Main,
        ui_model_format=ModelFormat.Diffusers,
        title="Qwen3 Source (Diffusers)",
    )

    max_seq_len: Literal[256, 512] = InputField(
        default=512,
        description="Max sequence length for the Qwen3 encoder.",
        title="Max Seq Length",
    )

    def invoke(self, context: InvocationContext) -> Flux2KleinModelLoaderOutput:
        """Resolve transformer, VAE, and Qwen3 encoder identifiers.

        Resolution order for VAE and Qwen3 encoder: explicit standalone model,
        then the main model (if Diffusers), then the Qwen3 Source model.

        Raises:
            ValueError: If no VAE or Qwen3 encoder source can be resolved, or
                a source model has an incompatible format/variant.
        """
        # Transformer always comes from the main model
        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})

        # Check if main model is Diffusers format (can extract VAE directly)
        main_config = context.models.get_config(self.model)
        main_is_diffusers = main_config.format == ModelFormat.Diffusers

        # Determine VAE source
        # IMPORTANT: FLUX.2 Klein uses a 32-channel VAE (AutoencoderKLFlux2), not the 16-channel FLUX.1 VAE.
        # The VAE should come from the FLUX.2 Klein Diffusers model, not a separate FLUX VAE.
        if self.vae_model is not None:
            # Use standalone VAE (user explicitly selected one)
            vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
        elif main_is_diffusers:
            # Extract VAE from main model (recommended for FLUX.2)
            vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
        elif self.qwen3_source_model is not None:
            # Extract from Qwen3 source Diffusers model
            self._validate_diffusers_format(context, self.qwen3_source_model, "Qwen3 Source")
            vae = self.qwen3_source_model.model_copy(update={"submodel_type": SubModelType.VAE})
        else:
            raise ValueError(
                "No VAE source provided. Standalone safetensors/GGUF models require a separate VAE. "
                "Options:\n"
                "  1. Set 'VAE' to a standalone FLUX VAE model\n"
                "  2. Set 'Qwen3 Source' to a Diffusers Flux2 Klein model to extract the VAE from"
            )

        # Determine Qwen3 Encoder source
        if self.qwen3_encoder_model is not None:
            # Use standalone Qwen3 Encoder - validate it matches the FLUX.2 Klein variant
            self._validate_qwen3_encoder_variant(context, main_config)
            qwen3_tokenizer = self.qwen3_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
            qwen3_encoder = self.qwen3_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
        elif main_is_diffusers:
            # Extract from main model (recommended for FLUX.2 Klein)
            qwen3_tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
            qwen3_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
        elif self.qwen3_source_model is not None:
            # Extract from separate Diffusers model
            self._validate_diffusers_format(context, self.qwen3_source_model, "Qwen3 Source")
            qwen3_tokenizer = self.qwen3_source_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
            qwen3_encoder = self.qwen3_source_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
        else:
            raise ValueError(
                "No Qwen3 Encoder source provided. Standalone safetensors/GGUF models require a separate text encoder. "
                "Options:\n"
                "  1. Set 'Qwen3 Encoder' to a standalone Qwen3 text encoder model "
                "(Klein 4B needs Qwen3 4B, Klein 9B needs Qwen3 8B)\n"
                "  2. Set 'Qwen3 Source' to a Diffusers Flux2 Klein model to extract the encoder from"
            )

        return Flux2KleinModelLoaderOutput(
            transformer=TransformerField(transformer=transformer, loras=[]),
            qwen3_encoder=Qwen3EncoderField(tokenizer=qwen3_tokenizer, text_encoder=qwen3_encoder),
            vae=VAEField(vae=vae),
            max_seq_len=self.max_seq_len,
        )

    def _validate_diffusers_format(
        self, context: InvocationContext, model: ModelIdentifierField, model_name: str
    ) -> None:
        """Validate that a model is in Diffusers format.

        Raises:
            ValueError: If the model is in any other format.
        """
        config = context.models.get_config(model)
        if config.format != ModelFormat.Diffusers:
            raise ValueError(
                f"The {model_name} model must be a Diffusers format model. "
                f"The selected model '{config.name}' is in {config.format.value} format."
            )

    def _validate_qwen3_encoder_variant(self, context: InvocationContext, main_config) -> None:
        """Validate that the standalone Qwen3 encoder variant matches the FLUX.2 Klein variant.

        - FLUX.2 Klein 4B requires Qwen3 4B encoder
        - FLUX.2 Klein 9B requires Qwen3 8B encoder

        Validation is skipped when either config lacks a `variant` attribute.

        Raises:
            ValueError: If the encoder variant does not match the Klein variant.
        """
        if self.qwen3_encoder_model is None:
            return

        # Get the Qwen3 encoder config
        qwen3_config = context.models.get_config(self.qwen3_encoder_model)

        # Check if the config has a variant field
        if not hasattr(qwen3_config, "variant"):
            # Can't validate, skip
            return

        qwen3_variant = qwen3_config.variant

        # Get the FLUX.2 Klein variant from the main model config
        if not hasattr(main_config, "variant"):
            return

        flux2_variant = main_config.variant

        # Validate the variants match
        # Klein4B requires Qwen3_4B, Klein9B/Klein9BBase requires Qwen3_8B
        expected_qwen3_variant = None
        if flux2_variant == Flux2VariantType.Klein4B:
            expected_qwen3_variant = Qwen3VariantType.Qwen3_4B
        elif flux2_variant in (Flux2VariantType.Klein9B, Flux2VariantType.Klein9BBase):
            expected_qwen3_variant = Qwen3VariantType.Qwen3_8B

        if expected_qwen3_variant is not None and qwen3_variant != expected_qwen3_variant:
            raise ValueError(
                f"Qwen3 encoder variant mismatch: FLUX.2 Klein {flux2_variant.value} requires "
                f"{expected_qwen3_variant.value} encoder, but {qwen3_variant.value} was selected. "
                f"Please select a matching Qwen3 encoder or use a Diffusers format model which includes the correct encoder."
            )
|
||||
200
invokeai/app/invocations/flux2_klein_text_encoder.py
Normal file
200
invokeai/app/invocations/flux2_klein_text_encoder.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""Flux2 Klein Text Encoder Invocation.
|
||||
|
||||
Flux2 Klein uses Qwen3 as the text encoder instead of CLIP+T5.
|
||||
The key difference is that it extracts hidden states from layers (9, 18, 27)
|
||||
and stacks them together for richer text representations.
|
||||
|
||||
This implementation matches the diffusers Flux2KleinPipeline exactly.
|
||||
"""
|
||||
|
||||
from contextlib import ExitStack
|
||||
from typing import Iterator, Literal, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
FluxConditioningField,
|
||||
Input,
|
||||
InputField,
|
||||
TensorField,
|
||||
UIComponent,
|
||||
)
|
||||
from invokeai.app.invocations.model import Qwen3EncoderField
|
||||
from invokeai.app.invocations.primitives import FluxConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_T5_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
# FLUX.2 Klein extracts hidden states from these specific layers
|
||||
# Matching diffusers Flux2KleinPipeline: (9, 18, 27)
|
||||
# hidden_states[0] is embedding layer, so layer N is at index N
|
||||
KLEIN_EXTRACTION_LAYERS = (9, 18, 27)
|
||||
|
||||
# Default max sequence length for Klein models
|
||||
KLEIN_MAX_SEQ_LEN = 512
|
||||
|
||||
|
||||
@invocation(
    "flux2_klein_text_encoder",
    title="Prompt - Flux2 Klein",
    tags=["prompt", "conditioning", "flux", "klein", "qwen3"],
    category="prompt",
    version="1.1.1",
    classification=Classification.Prototype,
)
class Flux2KleinTextEncoderInvocation(BaseInvocation):
    """Encodes and preps a prompt for Flux2 Klein image generation.

    Flux2 Klein uses Qwen3 as the text encoder, extracting hidden states from
    layers (9, 18, 27) and stacking them for richer text representations.
    This matches the diffusers Flux2KleinPipeline implementation exactly.
    """

    # The text prompt to encode.
    prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
    # The Qwen3 encoder bundle (tokenizer + text encoder + LoRAs).
    qwen3_encoder: Qwen3EncoderField = InputField(
        title="Qwen3 Encoder",
        description=FieldDescriptions.qwen3_encoder,
        input=Input.Connection,
    )
    # Token budget; prompts are padded/truncated to exactly this length.
    max_seq_len: Literal[256, 512] = InputField(
        default=512,
        description="Max sequence length for the Qwen3 encoder.",
    )
    # Optional regional-prompting mask, carried through to the conditioning field.
    mask: Optional[TensorField] = InputField(
        default=None,
        description="A mask defining the region that this conditioning prompt applies to.",
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
        """Encode the prompt and save the resulting conditioning tensors."""
        # Open the exitstack here to lock models for the duration of the node
        with ExitStack() as exit_stack:
            # Pass the locked stack down to the helper function
            qwen3_embeds, pooled_embeds = self._encode_prompt(context, exit_stack)

            # Move embeddings to CPU for storage to save VRAM; they are moved back to
            # the appropriate device when the denoiser consumes them. This mirrors
            # FluxTextEncoderInvocation, which applies the same CPU offload before saving.
            qwen3_embeds = qwen3_embeds.detach().to("cpu")
            pooled_embeds = pooled_embeds.detach().to("cpu")

            conditioning_data = ConditioningFieldData(
                conditionings=[FLUXConditioningInfo(clip_embeds=pooled_embeds, t5_embeds=qwen3_embeds)]
            )

            # The models are still locked while we save the data
            conditioning_name = context.conditioning.save(conditioning_data)
            return FluxConditioningOutput(
                conditioning=FluxConditioningField(conditioning_name=conditioning_name, mask=self.mask)
            )

    def _encode_prompt(self, context: InvocationContext, exit_stack: ExitStack) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the Qwen3 encoder over ``self.prompt``.

        Returns:
            ``(prompt_embeds, pooled_embeds)``:
            - ``prompt_embeds``: (B, seq_len, num_layers * hidden_dim) — hidden states
              from KLEIN_EXTRACTION_LAYERS stacked per token.
            - ``pooled_embeds``: attention-mask-weighted mean of the final hidden state.

        Raises:
            TypeError: if the loaded encoder/tokenizer are not the expected HF types.
            RuntimeError: if the encoder does not return hidden states.
        """
        prompt = self.prompt

        # Load the text encoder first so it cannot be evicted from the model cache
        # while we look up the tokenizer.
        text_encoder_info = context.models.load(self.qwen3_encoder.text_encoder)
        (cached_weights, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())

        # Now it is safe to load and lock the tokenizer
        tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer)
        (_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())

        # Recover any tensors that ended up on the wrong device after a partial load.
        repaired_tensors = text_encoder_info.repair_required_tensors_on_device()
        device = get_effective_device(text_encoder)
        if repaired_tensors > 0:
            context.logger.warning(
                f"Recovered {repaired_tensors} required Qwen3 tensor(s) onto {device} after a partial device mismatch."
            )

        # Apply LoRA models
        lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
        exit_stack.enter_context(
            LayerPatcher.apply_smart_model_patches(
                model=text_encoder,
                patches=self._lora_iterator(context),
                prefix=FLUX_LORA_T5_PREFIX,
                dtype=lora_dtype,
                cached_weights=cached_weights,
            )
        )

        context.util.signal_progress("Running Qwen3 text encoder (Klein)")

        if not isinstance(text_encoder, PreTrainedModel):
            raise TypeError(
                f"Expected PreTrainedModel for text encoder, got {type(text_encoder).__name__}. "
                "The Qwen3 encoder model may be corrupted or incompatible."
            )
        if not isinstance(tokenizer, PreTrainedTokenizerBase):
            raise TypeError(
                f"Expected PreTrainedTokenizerBase for tokenizer, got {type(tokenizer).__name__}. "
                "The Qwen3 tokenizer may be corrupted or incompatible."
            )

        # Qwen3 is a chat model: wrap the prompt in the chat template (thinking disabled).
        messages = [{"role": "user", "content": prompt}]

        text: str = tokenizer.apply_chat_template(  # type: ignore[assignment]
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )

        inputs = tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_seq_len,
        )

        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        # Forward pass through the model; all hidden states are needed for extraction.
        outputs = text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )
        if not hasattr(outputs, "hidden_states") or outputs.hidden_states is None:
            raise RuntimeError(
                "Text encoder did not return hidden_states. "
                "Ensure output_hidden_states=True is supported by this model."
            )
        num_hidden_layers = len(outputs.hidden_states)

        # Gather the Klein extraction layers, clamping to the last available layer
        # for models with fewer layers than the configured indices.
        hidden_states_list = []
        for layer_idx in KLEIN_EXTRACTION_LAYERS:
            if layer_idx >= num_hidden_layers:
                layer_idx = num_hidden_layers - 1
            hidden_states_list.append(outputs.hidden_states[layer_idx])

        # Stack to (B, num_layers, seq_len, hidden_dim), then interleave per token to
        # (B, seq_len, num_layers * hidden_dim), matching Flux2KleinPipeline.
        out = torch.stack(hidden_states_list, dim=1)
        out = out.to(dtype=text_encoder.dtype, device=device)

        batch_size, num_channels, seq_len, hidden_dim = out.shape
        prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)

        # Pooled embedding: masked mean over the final hidden state (padding excluded;
        # clamp prevents division by zero for an all-padding mask).
        last_hidden_state = outputs.hidden_states[-1]
        expanded_mask = attention_mask.unsqueeze(-1).expand_as(last_hidden_state).float()
        sum_embeds = (last_hidden_state * expanded_mask).sum(dim=1)
        num_tokens = expanded_mask.sum(dim=1).clamp(min=1)
        pooled_embeds = sum_embeds / num_tokens

        return prompt_embeds, pooled_embeds

    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
        """Iterate over LoRA models to apply to the Qwen3 text encoder."""
        for lora in self.qwen3_encoder.loras:
            lora_info = context.models.load(lora.lora)
            if not isinstance(lora_info.model, ModelPatchRaw):
                raise TypeError(
                    f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}. "
                    "The LoRA model may be corrupted or incompatible."
                )
            yield (lora_info.model, lora.weight)
            del lora_info
|
||||
92
invokeai/app/invocations/flux2_vae_decode.py
Normal file
92
invokeai/app/invocations/flux2_vae_decode.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Flux2 Klein VAE Decode Invocation.
|
||||
|
||||
Decodes latents to images using the FLUX.2 32-channel VAE (AutoencoderKLFlux2).
|
||||
"""
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
    "flux2_vae_decode",
    title="Latents to Image - FLUX2",
    tags=["latents", "image", "vae", "l2i", "flux2", "klein"],
    category="latents",
    version="1.0.0",
    classification=Classification.Prototype,
)
class Flux2VaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generates an image from latents using FLUX.2 Klein's 32-channel VAE."""

    # Latents to decode; expected layout is (B, 32, H, W).
    latents: LatentsField = InputField(
        description=FieldDescriptions.latents,
        input=Input.Connection,
    )
    # The FLUX.2 VAE (AutoencoderKLFlux2) to decode with.
    vae: VAEField = InputField(
        description=FieldDescriptions.vae,
        input=Input.Connection,
    )

    def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
        """Decode latents to image using FLUX.2 VAE.

        Input latents should already be in the correct space after BN denormalization
        was applied in the denoiser. The VAE expects (B, 32, H, W) format.
        """
        with vae_info.model_on_device() as (_, vae):
            vae_dtype = next(iter(vae.parameters())).dtype
            device = TorchDevice.choose_torch_device()
            latents = latents.to(device=device, dtype=vae_dtype)

            # Decode using diffusers API
            decoded = vae.decode(latents, return_dict=False)[0]

            # Convert from [-1, 1] to [0, 1] then to [0, 255] PIL image
            img = (decoded / 2 + 0.5).clamp(0, 1)
            img = rearrange(img[0], "c h w -> h w c")
            # Round before the uint8 cast: a bare .byte() truncates toward zero, which
            # systematically darkens output by up to one intensity level. Rounding
            # matches diffusers' `(image * 255).round().astype("uint8")` post-processing.
            img_np = (img * 255).round().byte().cpu().numpy()
            # Explicitly create RGB image (not grayscale)
            img_pil = Image.fromarray(img_np, mode="RGB")
            return img_pil

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Load latents, decode them through the VAE, and save the image."""
        latents = context.tensors.load(self.latents.latents_name)

        # Log latent statistics for debugging black image issues
        context.logger.debug(
            f"FLUX.2 VAE decode input: shape={latents.shape}, "
            f"min={latents.min().item():.4f}, max={latents.max().item():.4f}, "
            f"mean={latents.mean().item():.4f}"
        )

        # Warn if input latents are all zeros or very small (would cause black images)
        if latents.abs().max() < 1e-6:
            context.logger.warning(
                "FLUX.2 VAE decode received near-zero latents! This will cause black images. "
                "The latent cache may be corrupted - try clearing the cache."
            )

        vae_info = context.models.load(self.vae.vae)
        context.util.signal_progress("Running VAE")
        image = self._vae_decode(vae_info=vae_info, latents=latents)

        # Free GPU memory before the (potentially large) image save.
        TorchDevice.empty_cache()
        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)
|
||||
88
invokeai/app/invocations/flux2_vae_encode.py
Normal file
88
invokeai/app/invocations/flux2_vae_encode.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""Flux2 Klein VAE Encode Invocation.
|
||||
|
||||
Encodes images to latents using the FLUX.2 32-channel VAE (AutoencoderKLFlux2).
|
||||
"""
|
||||
|
||||
import einops
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
)
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
    "flux2_vae_encode",
    title="Image to Latents - FLUX2",
    tags=["latents", "image", "vae", "i2l", "flux2", "klein"],
    category="latents",
    version="1.0.0",
    classification=Classification.Prototype,
)
class Flux2VaeEncodeInvocation(BaseInvocation):
    """Encodes an image into latents using FLUX.2 Klein's 32-channel VAE."""

    # Source image to be encoded into latent space.
    image: ImageField = InputField(
        description="The image to encode.",
    )
    # The FLUX.2 VAE (AutoencoderKLFlux2) to encode with.
    vae: VAEField = InputField(
        description=FieldDescriptions.vae,
        input=Input.Connection,
    )

    def _vae_encode(self, vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
        """Encode an image tensor into the FLUX.2 32-channel latent space.

        Output latents shape: (B, 32, H/8, W/8).
        """
        with vae_info.model_on_device() as (_, vae):
            target_dtype = next(iter(vae.parameters())).dtype
            target_device = TorchDevice.choose_torch_device()
            pixels = image_tensor.to(device=target_device, dtype=target_dtype)

            # vae.encode() returns a DiagonalGaussianDistribution-like object (or,
            # for some VAE implementations, a plain tensor).
            encoded = vae.encode(pixels, return_dict=False)[0]

            if hasattr(encoded, "mode"):
                # Prefer the distribution mode for deterministic encoding.
                return encoded.mode()
            if hasattr(encoded, "sample"):
                # Fall back to sampling with a fixed seed so results stay reproducible.
                rng = torch.Generator(device=target_device).manual_seed(0)
                return encoded.sample(generator=rng)
            # Direct tensor output (some VAE implementations).
            return encoded

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        """Load the image, encode it through the VAE, and save the latents."""
        pil_image = context.images.get_pil(self.image.image_name)

        vae_info = context.models.load(self.vae.vae)

        # Convert image to tensor (HWC -> CHW, normalized to [-1, 1]).
        image_tensor = image_resized_to_grid_as_tensor(pil_image.convert("RGB"))
        if image_tensor.dim() == 3:
            # Add the batch dimension expected by the VAE.
            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

        context.util.signal_progress("Running VAE Encode")
        latents = self._vae_encode(vae_info=vae_info, image_tensor=image_tensor)

        # Store latents on CPU so they do not pin VRAM while cached.
        latents = latents.to("cpu")
        name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||
@@ -50,7 +50,7 @@ class FluxControlNetOutput(BaseInvocationOutput):
|
||||
"flux_controlnet",
|
||||
title="FLUX ControlNet",
|
||||
tags=["controlnet", "flux"],
|
||||
category="controlnet",
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxControlNetInvocation(BaseInvocation):
|
||||
|
||||
@@ -32,6 +32,13 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
|
||||
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
|
||||
from invokeai.backend.flux.denoise import denoise
|
||||
from invokeai.backend.flux.dype.presets import (
|
||||
DYPE_PRESET_LABELS,
|
||||
DYPE_PRESET_OFF,
|
||||
DyPEPreset,
|
||||
get_dype_config_from_preset,
|
||||
)
|
||||
from invokeai.backend.flux.extensions.dype_extension import DyPEExtension
|
||||
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
|
||||
from invokeai.backend.flux.extensions.kontext_extension import KontextExtension
|
||||
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
|
||||
@@ -47,6 +54,7 @@ from invokeai.backend.flux.sampling_utils import (
|
||||
pack,
|
||||
unpack,
|
||||
)
|
||||
from invokeai.backend.flux.schedulers import FLUX_SCHEDULER_LABELS, FLUX_SCHEDULER_MAP, FLUX_SCHEDULER_NAME_VALUES
|
||||
from invokeai.backend.flux.text_conditioning import FluxReduxConditioning, FluxTextConditioning
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, FluxVariantType, ModelFormat, ModelType
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
@@ -62,8 +70,8 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
"flux_denoise",
|
||||
title="FLUX Denoise",
|
||||
tags=["image", "flux"],
|
||||
category="image",
|
||||
version="4.1.0",
|
||||
category="latents",
|
||||
version="4.5.1",
|
||||
)
|
||||
class FluxDenoiseInvocation(BaseInvocation):
|
||||
"""Run denoising process with a FLUX transformer model."""
|
||||
@@ -132,6 +140,12 @@ class FluxDenoiseInvocation(BaseInvocation):
|
||||
num_steps: int = InputField(
|
||||
default=4, description="Number of diffusion steps. Recommended values are schnell: 4, dev: 50."
|
||||
)
|
||||
scheduler: FLUX_SCHEDULER_NAME_VALUES = InputField(
|
||||
default="euler",
|
||||
description="Scheduler (sampler) for the denoising process. 'euler' is fast and standard. "
|
||||
"'heun' is 2nd-order (better quality, 2x slower). 'lcm' is optimized for few steps.",
|
||||
ui_choice_labels=FLUX_SCHEDULER_LABELS,
|
||||
)
|
||||
guidance: float = InputField(
|
||||
default=4.0,
|
||||
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
|
||||
@@ -159,6 +173,31 @@ class FluxDenoiseInvocation(BaseInvocation):
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
# DyPE (Dynamic Position Extrapolation) for high-resolution generation
|
||||
dype_preset: DyPEPreset = InputField(
|
||||
default=DYPE_PRESET_OFF,
|
||||
description=(
|
||||
"DyPE preset for high-resolution generation. 'auto' enables automatically for resolutions > 1536px. "
|
||||
"'area' enables automatically based on image area. '4k' uses optimized settings for 4K output."
|
||||
),
|
||||
ui_order=100,
|
||||
ui_choice_labels=DYPE_PRESET_LABELS,
|
||||
)
|
||||
dype_scale: Optional[float] = InputField(
|
||||
default=None,
|
||||
ge=0.0,
|
||||
le=8.0,
|
||||
description="DyPE magnitude (λs). Higher values = stronger extrapolation. Only used when dype_preset is not 'off'.",
|
||||
ui_order=101,
|
||||
)
|
||||
dype_exponent: Optional[float] = InputField(
|
||||
default=None,
|
||||
ge=0.0,
|
||||
le=1000.0,
|
||||
description="DyPE decay speed (λt). Controls transition from low to high frequency detail. Only used when dype_preset is not 'off'.",
|
||||
ui_order=102,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = self._run_diffusion(context)
|
||||
@@ -232,8 +271,14 @@ class FluxDenoiseInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
transformer_config = context.models.get_config(self.transformer.transformer)
|
||||
assert transformer_config.base is BaseModelType.Flux and transformer_config.type is ModelType.Main
|
||||
is_schnell = transformer_config.variant is FluxVariantType.Schnell
|
||||
assert (
|
||||
transformer_config.base in (BaseModelType.Flux, BaseModelType.Flux2)
|
||||
and transformer_config.type is ModelType.Main
|
||||
)
|
||||
# Schnell is only for FLUX.1, FLUX.2 Klein behaves like Dev (with guidance)
|
||||
is_schnell = (
|
||||
transformer_config.base is BaseModelType.Flux and transformer_config.variant is FluxVariantType.Schnell
|
||||
)
|
||||
|
||||
# Calculate the timestep schedule.
|
||||
timesteps = get_schedule(
|
||||
@@ -242,6 +287,12 @@ class FluxDenoiseInvocation(BaseInvocation):
|
||||
shift=not is_schnell,
|
||||
)
|
||||
|
||||
# Create scheduler if not using default euler
|
||||
scheduler = None
|
||||
if self.scheduler in FLUX_SCHEDULER_MAP:
|
||||
scheduler_class = FLUX_SCHEDULER_MAP[self.scheduler]
|
||||
scheduler = scheduler_class(num_train_timesteps=1000)
|
||||
|
||||
# Clip the timesteps schedule based on denoising_start and denoising_end.
|
||||
timesteps = clip_timestep_schedule_fractional(timesteps, self.denoising_start, self.denoising_end)
|
||||
|
||||
@@ -409,6 +460,30 @@ class FluxDenoiseInvocation(BaseInvocation):
|
||||
kontext_extension.ensure_batch_size(x.shape[0])
|
||||
img_cond_seq, img_cond_seq_ids = kontext_extension.kontext_latents, kontext_extension.kontext_ids
|
||||
|
||||
# Prepare DyPE extension for high-resolution generation
|
||||
dype_extension: DyPEExtension | None = None
|
||||
dype_config = get_dype_config_from_preset(
|
||||
preset=self.dype_preset,
|
||||
width=self.width,
|
||||
height=self.height,
|
||||
custom_scale=self.dype_scale,
|
||||
custom_exponent=self.dype_exponent,
|
||||
)
|
||||
if dype_config is not None:
|
||||
dype_extension = DyPEExtension(
|
||||
config=dype_config,
|
||||
target_height=self.height,
|
||||
target_width=self.width,
|
||||
)
|
||||
context.logger.info(
|
||||
f"DyPE enabled: resolution={self.width}x{self.height}, preset={self.dype_preset}, "
|
||||
f"method={dype_config.method}, scale={dype_config.dype_scale:.2f}, "
|
||||
f"exponent={dype_config.dype_exponent:.2f}, start_sigma={dype_config.dype_start_sigma:.2f}, "
|
||||
f"base_resolution={dype_config.base_resolution}"
|
||||
)
|
||||
else:
|
||||
context.logger.debug(f"DyPE disabled: resolution={self.width}x{self.height}, preset={self.dype_preset}")
|
||||
|
||||
x = denoise(
|
||||
model=transformer,
|
||||
img=x,
|
||||
@@ -426,6 +501,8 @@ class FluxDenoiseInvocation(BaseInvocation):
|
||||
img_cond=img_cond,
|
||||
img_cond_seq=img_cond_seq,
|
||||
img_cond_seq_ids=img_cond_seq_ids,
|
||||
dype_extension=dype_extension,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
x = unpack(x.float(), self.height, self.width)
|
||||
|
||||
@@ -29,7 +29,7 @@ class FluxFillOutput(BaseInvocationOutput):
|
||||
"flux_fill",
|
||||
title="FLUX Fill Conditioning",
|
||||
tags=["inpaint"],
|
||||
category="inpaint",
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
|
||||
@@ -24,7 +24,7 @@ from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
"flux_ip_adapter",
|
||||
title="FLUX IP-Adapter",
|
||||
tags=["ip_adapter", "control"],
|
||||
category="ip_adapter",
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxIPAdapterInvocation(BaseInvocation):
|
||||
|
||||
@@ -162,7 +162,7 @@ class FLUXLoRACollectionLoader(BaseInvocation):
|
||||
if not context.models.exists(lora.lora.key):
|
||||
raise Exception(f"Unknown lora: {lora.lora.key}!")
|
||||
|
||||
assert lora.lora.base is BaseModelType.Flux
|
||||
assert lora.lora.base in (BaseModelType.Flux, BaseModelType.Flux2)
|
||||
|
||||
added_loras.append(lora.lora.key)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField
|
||||
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.t5_model_identifier import (
|
||||
@@ -37,28 +37,25 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
|
||||
title="Main Model - FLUX",
|
||||
tags=["model", "flux"],
|
||||
category="model",
|
||||
version="1.0.6",
|
||||
version="1.0.7",
|
||||
)
|
||||
class FluxModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a flux base model, outputting its submodels."""
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.flux_model,
|
||||
input=Input.Direct,
|
||||
ui_model_base=BaseModelType.Flux,
|
||||
ui_model_type=ModelType.Main,
|
||||
)
|
||||
|
||||
t5_encoder_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.t5_encoder,
|
||||
input=Input.Direct,
|
||||
title="T5 Encoder",
|
||||
ui_model_type=ModelType.T5Encoder,
|
||||
)
|
||||
|
||||
clip_embed_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.clip_embed_model,
|
||||
input=Input.Direct,
|
||||
title="CLIP Embed",
|
||||
ui_model_type=ModelType.CLIPEmbed,
|
||||
)
|
||||
|
||||
@@ -47,7 +47,7 @@ DOWNSAMPLING_FUNCTIONS = Literal["nearest", "bilinear", "bicubic", "area", "near
|
||||
"flux_redux",
|
||||
title="FLUX Redux",
|
||||
tags=["ip_adapter", "control"],
|
||||
category="ip_adapter",
|
||||
category="conditioning",
|
||||
version="2.1.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
|
||||
@@ -28,7 +28,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
|
||||
"flux_text_encoder",
|
||||
title="Prompt - FLUX",
|
||||
tags=["prompt", "conditioning", "flux"],
|
||||
category="conditioning",
|
||||
category="prompt",
|
||||
version="1.1.2",
|
||||
)
|
||||
class FluxTextEncoderInvocation(BaseInvocation):
|
||||
@@ -58,6 +58,12 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
# scoped. This ensures that the T5 model can be freed and gc'd before loading the CLIP model (if necessary).
|
||||
t5_embeddings = self._t5_encode(context)
|
||||
clip_embeddings = self._clip_encode(context)
|
||||
|
||||
# Move embeddings to CPU for storage to save VRAM
|
||||
# They will be moved to the appropriate device when used by the denoiser
|
||||
t5_embeddings = t5_embeddings.detach().to("cpu")
|
||||
clip_embeddings = clip_embeddings.detach().to("cpu")
|
||||
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
|
||||
)
|
||||
|
||||
@@ -24,7 +24,7 @@ GROUNDING_DINO_MODEL_IDS: dict[GroundingDinoModelKey, str] = {
|
||||
"grounding_dino",
|
||||
title="Grounding DINO (Text Prompt Object Detection)",
|
||||
tags=["prompt", "object detection"],
|
||||
category="image",
|
||||
category="segmentation",
|
||||
version="1.0.0",
|
||||
)
|
||||
class GroundingDinoInvocation(BaseInvocation):
|
||||
|
||||
@@ -11,7 +11,7 @@ from invokeai.backend.image_util.hed import ControlNetHED_Apache2, HEDEdgeDetect
|
||||
"hed_edge_detection",
|
||||
title="HED Edge Detection",
|
||||
tags=["controlnet", "hed", "softedge"],
|
||||
category="controlnet",
|
||||
category="controlnet_preprocessors",
|
||||
version="1.0.0",
|
||||
)
|
||||
class HEDEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@@ -21,6 +21,7 @@ class IdealSizeOutput(BaseInvocationOutput):
|
||||
"ideal_size",
|
||||
title="Ideal Size - SD1.5, SDXL",
|
||||
tags=["latents", "math", "ideal_size"],
|
||||
category="latents",
|
||||
version="1.0.6",
|
||||
)
|
||||
class IdealSizeInvocation(BaseInvocation):
|
||||
@@ -46,7 +47,12 @@ class IdealSizeInvocation(BaseInvocation):
|
||||
dimension = 512
|
||||
elif unet_config.base == BaseModelType.StableDiffusion2:
|
||||
dimension = 768
|
||||
elif unet_config.base in (BaseModelType.StableDiffusionXL, BaseModelType.Flux, BaseModelType.StableDiffusion3):
|
||||
elif unet_config.base in (
|
||||
BaseModelType.StableDiffusionXL,
|
||||
BaseModelType.Flux,
|
||||
BaseModelType.Flux2,
|
||||
BaseModelType.StableDiffusion3,
|
||||
):
|
||||
dimension = 1024
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {unet_config.base}")
|
||||
|
||||
@@ -21,7 +21,7 @@ from invokeai.app.invocations.fields import (
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.invocations.primitives import ImageOutput, StringOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.misc import SEED_MAX
|
||||
@@ -197,7 +197,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"tomask",
|
||||
title="Mask from Alpha",
|
||||
tags=["image", "mask"],
|
||||
category="image",
|
||||
category="mask",
|
||||
version="1.2.2",
|
||||
)
|
||||
class MaskFromAlphaInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
@@ -581,11 +581,30 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation(
    "decode_watermark",
    title="Decode Invisible Watermark",
    tags=["image", "watermark"],
    category="image",
    version="1.0.0",
)
class DecodeInvisibleWatermarkInvocation(BaseInvocation):
    """Decode an invisible watermark from an image."""

    # Source image that may carry an invisible watermark.
    image: ImageField = InputField(description="The image to decode the watermark from")
    # Number of watermark bytes the decoder should attempt to recover.
    length: int = InputField(default=8, description="The expected watermark length in bytes")

    def invoke(self, context: InvocationContext) -> StringOutput:
        """Load the image and return the decoded watermark text."""
        pil_image = context.images.get_pil(self.image.image_name)
        decoded = InvisibleWatermark.decode_watermark(pil_image, self.length)
        return StringOutput(value=decoded)
|
||||
|
||||
|
||||
@invocation(
|
||||
"mask_edge",
|
||||
title="Mask Edge",
|
||||
tags=["image", "mask", "inpaint"],
|
||||
category="image",
|
||||
category="mask",
|
||||
version="1.2.2",
|
||||
)
|
||||
class MaskEdgeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
@@ -624,7 +643,7 @@ class MaskEdgeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"mask_combine",
|
||||
title="Combine Masks",
|
||||
tags=["image", "mask", "multiply"],
|
||||
category="image",
|
||||
category="mask",
|
||||
version="1.2.2",
|
||||
)
|
||||
class MaskCombineInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
@@ -955,7 +974,7 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"save_image",
|
||||
title="Save Image",
|
||||
tags=["primitives", "image"],
|
||||
category="primitives",
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
use_cache=False,
|
||||
)
|
||||
@@ -976,7 +995,7 @@ class SaveImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"canvas_paste_back",
|
||||
title="Canvas Paste Back",
|
||||
tags=["image", "combine"],
|
||||
category="image",
|
||||
category="canvas",
|
||||
version="1.0.1",
|
||||
)
|
||||
class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
@@ -1013,7 +1032,7 @@ class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"mask_from_id",
|
||||
title="Mask from Segmented Image",
|
||||
tags=["image", "mask", "id"],
|
||||
category="image",
|
||||
category="mask",
|
||||
version="1.0.1",
|
||||
)
|
||||
class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
@@ -1050,7 +1069,7 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"canvas_v2_mask_and_crop",
|
||||
title="Canvas V2 Mask and Crop",
|
||||
tags=["image", "mask", "id"],
|
||||
category="image",
|
||||
category="canvas",
|
||||
version="1.0.0",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
@@ -1091,7 +1110,7 @@ class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
|
||||
@invocation(
|
||||
"expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="image", version="1.0.1"
|
||||
"expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="mask", version="1.0.1"
|
||||
)
|
||||
class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Expands a mask with a fade effect. The mask uses black to indicate areas to keep from the generated image and white for areas to discard.
|
||||
@@ -1180,7 +1199,7 @@ class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"apply_mask_to_image",
|
||||
title="Apply Mask to Image",
|
||||
tags=["image", "mask", "blend"],
|
||||
category="image",
|
||||
category="mask",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ApplyMaskToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
@@ -1355,7 +1374,7 @@ class PasteImageIntoBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoar
|
||||
"flux_kontext_image_prep",
|
||||
title="FLUX Kontext Image Prep",
|
||||
tags=["image", "concatenate", "flux", "kontext"],
|
||||
category="image",
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxKontextConcatenateImagesInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@@ -23,7 +23,7 @@ class ImagePanelCoordinateOutput(BaseInvocationOutput):
|
||||
"image_panel_layout",
|
||||
title="Image Panel Layout",
|
||||
tags=["image", "panel", "layout"],
|
||||
category="image",
|
||||
category="canvas",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
|
||||
@@ -73,7 +73,7 @@ CLIP_VISION_MODEL_MAP: dict[Literal["ViT-L", "ViT-H", "ViT-G"], StarterModel] =
|
||||
"ip_adapter",
|
||||
title="IP-Adapter - SD1.5, SDXL",
|
||||
tags=["ip_adapter", "control"],
|
||||
category="ip_adapter",
|
||||
category="conditioning",
|
||||
version="1.5.1",
|
||||
)
|
||||
class IPAdapterInvocation(BaseInvocation):
|
||||
|
||||
@@ -11,7 +11,7 @@ from invokeai.backend.image_util.lineart import Generator, LineartEdgeDetector
|
||||
"lineart_edge_detection",
|
||||
title="Lineart Edge Detection",
|
||||
tags=["controlnet", "lineart"],
|
||||
category="controlnet",
|
||||
category="controlnet_preprocessors",
|
||||
version="1.0.0",
|
||||
)
|
||||
class LineartEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@@ -9,7 +9,7 @@ from invokeai.backend.image_util.lineart_anime import LineartAnimeEdgeDetector,
|
||||
"lineart_anime_edge_detection",
|
||||
title="Lineart Anime Edge Detection",
|
||||
tags=["controlnet", "lineart"],
|
||||
category="controlnet",
|
||||
category="controlnet_preprocessors",
|
||||
version="1.0.0",
|
||||
)
|
||||
class LineartAnimeEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@@ -19,7 +19,7 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
"llava_onevision_vllm",
|
||||
title="LLaVA OneVision VLLM",
|
||||
tags=["vllm"],
|
||||
category="vllm",
|
||||
category="multimodal",
|
||||
version="1.0.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
|
||||
34
invokeai/app/invocations/logic.py
Normal file
34
invokeai/app/invocations/logic.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from invokeai.app.invocations.fields import InputField, OutputField, UIType
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
@invocation_output("if_output")
|
||||
class IfInvocationOutput(BaseInvocationOutput):
|
||||
value: Optional[Any] = OutputField(
|
||||
default=None, description="The selected value", title="Output", ui_type=UIType.Any
|
||||
)
|
||||
|
||||
|
||||
@invocation("if", title="If", tags=["logic", "conditional"], category="math", version="1.0.0")
|
||||
class IfInvocation(BaseInvocation):
|
||||
"""Selects between two optional inputs based on a boolean condition."""
|
||||
|
||||
condition: bool = InputField(default=False, description="The condition used to select an input", title="Condition")
|
||||
true_input: Optional[Any] = InputField(
|
||||
default=None,
|
||||
description="Selected when the condition is true",
|
||||
title="True Input",
|
||||
ui_type=UIType.Any,
|
||||
)
|
||||
false_input: Optional[Any] = InputField(
|
||||
default=None,
|
||||
description="Selected when the condition is false",
|
||||
title="False Input",
|
||||
ui_type=UIType.Any,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IfInvocationOutput:
|
||||
return IfInvocationOutput(value=self.true_input if self.condition else self.false_input)
|
||||
@@ -24,7 +24,7 @@ from invokeai.backend.image_util.util import pil_to_np
|
||||
"rectangle_mask",
|
||||
title="Create Rectangle Mask",
|
||||
tags=["conditioning"],
|
||||
category="conditioning",
|
||||
category="mask",
|
||||
version="1.0.1",
|
||||
)
|
||||
class RectangleMaskInvocation(BaseInvocation, WithMetadata):
|
||||
@@ -55,7 +55,7 @@ class RectangleMaskInvocation(BaseInvocation, WithMetadata):
|
||||
"alpha_mask_to_tensor",
|
||||
title="Alpha Mask to Tensor",
|
||||
tags=["conditioning"],
|
||||
category="conditioning",
|
||||
category="mask",
|
||||
version="1.0.0",
|
||||
)
|
||||
class AlphaMaskToTensorInvocation(BaseInvocation):
|
||||
@@ -83,7 +83,7 @@ class AlphaMaskToTensorInvocation(BaseInvocation):
|
||||
"invert_tensor_mask",
|
||||
title="Invert Tensor Mask",
|
||||
tags=["conditioning"],
|
||||
category="conditioning",
|
||||
category="mask",
|
||||
version="1.1.0",
|
||||
)
|
||||
class InvertTensorMaskInvocation(BaseInvocation):
|
||||
@@ -115,7 +115,7 @@ class InvertTensorMaskInvocation(BaseInvocation):
|
||||
"image_mask_to_tensor",
|
||||
title="Image Mask to Tensor",
|
||||
tags=["conditioning"],
|
||||
category="conditioning",
|
||||
category="mask",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
|
||||
|
||||
@@ -9,7 +9,7 @@ from invokeai.backend.image_util.mediapipe_face import detect_faces
|
||||
"mediapipe_face_detection",
|
||||
title="MediaPipe Face Detection",
|
||||
tags=["controlnet", "face"],
|
||||
category="controlnet",
|
||||
category="controlnet_preprocessors",
|
||||
version="1.0.0",
|
||||
)
|
||||
class MediaPipeFaceDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@@ -150,6 +150,10 @@ GENERATION_MODES = Literal[
|
||||
"flux_img2img",
|
||||
"flux_inpaint",
|
||||
"flux_outpaint",
|
||||
"flux2_txt2img",
|
||||
"flux2_img2img",
|
||||
"flux2_inpaint",
|
||||
"flux2_outpaint",
|
||||
"sd3_txt2img",
|
||||
"sd3_img2img",
|
||||
"sd3_inpaint",
|
||||
@@ -162,6 +166,14 @@ GENERATION_MODES = Literal[
|
||||
"z_image_img2img",
|
||||
"z_image_inpaint",
|
||||
"z_image_outpaint",
|
||||
"qwen_image_txt2img",
|
||||
"qwen_image_img2img",
|
||||
"qwen_image_inpaint",
|
||||
"qwen_image_outpaint",
|
||||
"anima_txt2img",
|
||||
"anima_img2img",
|
||||
"anima_inpaint",
|
||||
"anima_outpaint",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -52,6 +52,7 @@ from invokeai.app.invocations.primitives import (
|
||||
)
|
||||
from invokeai.app.invocations.scheduler import SchedulerOutput
|
||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField, T2IAdapterInvocation
|
||||
from invokeai.app.invocations.z_image_denoise import ZImageDenoiseInvocation
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
||||
@@ -620,7 +621,7 @@ class LatentsMetaOutput(LatentsOutput, MetadataOutput):
|
||||
"denoise_latents_meta",
|
||||
title=f"{DenoiseLatentsInvocation.UIConfig.title} + Metadata",
|
||||
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
||||
category="latents",
|
||||
category="metadata",
|
||||
version="1.1.1",
|
||||
)
|
||||
class DenoiseLatentsMetaInvocation(DenoiseLatentsInvocation, WithMetadata):
|
||||
@@ -685,7 +686,7 @@ class DenoiseLatentsMetaInvocation(DenoiseLatentsInvocation, WithMetadata):
|
||||
"flux_denoise_meta",
|
||||
title=f"{FluxDenoiseInvocation.UIConfig.title} + Metadata",
|
||||
tags=["flux", "latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
||||
category="latents",
|
||||
category="metadata",
|
||||
version="1.0.1",
|
||||
)
|
||||
class FluxDenoiseLatentsMetaInvocation(FluxDenoiseInvocation, WithMetadata):
|
||||
@@ -729,6 +730,52 @@ class FluxDenoiseLatentsMetaInvocation(FluxDenoiseInvocation, WithMetadata):
|
||||
return LatentsMetaOutput(**params, metadata=MetadataField.model_validate(md))
|
||||
|
||||
|
||||
@invocation(
|
||||
"z_image_denoise_meta",
|
||||
title=f"{ZImageDenoiseInvocation.UIConfig.title} + Metadata",
|
||||
tags=["z-image", "latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
||||
category="metadata",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ZImageDenoiseMetaInvocation(ZImageDenoiseInvocation, WithMetadata):
|
||||
"""Run denoising process with a Z-Image transformer model + metadata."""
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsMetaOutput:
|
||||
def _loras_to_json(obj: Union[Any, list[Any]]):
|
||||
if not isinstance(obj, list):
|
||||
obj = [obj]
|
||||
|
||||
output: list[dict[str, Any]] = []
|
||||
for item in obj:
|
||||
output.append(
|
||||
LoRAMetadataField(
|
||||
model=item.lora,
|
||||
weight=item.weight,
|
||||
).model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
|
||||
)
|
||||
return output
|
||||
|
||||
obj = super().invoke(context)
|
||||
|
||||
md: Dict[str, Any] = {} if self.metadata is None else self.metadata.root
|
||||
md.update({"width": obj.width})
|
||||
md.update({"height": obj.height})
|
||||
md.update({"steps": self.steps})
|
||||
md.update({"guidance": self.guidance_scale})
|
||||
md.update({"denoising_start": self.denoising_start})
|
||||
md.update({"denoising_end": self.denoising_end})
|
||||
md.update({"scheduler": self.scheduler})
|
||||
md.update({"model": self.transformer.transformer})
|
||||
md.update({"seed": self.seed})
|
||||
if len(self.transformer.loras) > 0:
|
||||
md.update({"loras": _loras_to_json(self.transformer.loras)})
|
||||
|
||||
params = obj.__dict__.copy()
|
||||
del params["type"]
|
||||
|
||||
return LatentsMetaOutput(**params, metadata=MetadataField.model_validate(md))
|
||||
|
||||
|
||||
@invocation(
|
||||
"metadata_to_vae",
|
||||
title="Metadata To VAE",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user