Merge branch 'external-models' into alibabacloud/dashscope

# Conflicts:
#	invokeai/backend/model_manager/starter_models.py
Alexander Eichhorn
2026-04-14 23:08:06 +02:00
491 changed files with 31802 additions and 2771 deletions


@@ -12,12 +12,13 @@ help:
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
@echo "test Run the unit tests."
@echo "update-config-docstring Update the app's config docstring so mkdocs can autogenerate it correctly."
@echo "frontend-install Install the pnpm modules needed for the front end"
@echo "frontend-build Build the frontend in order to run on localhost:9090"
@echo "frontend-install Install the pnpm modules needed for the frontend"
@echo "frontend-build Build the frontend for localhost:9090"
@echo "frontend-test Run the frontend test suite once"
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
@echo "frontend-prettier Format the frontend using lint:prettier"
@echo "wheel Build the wheel for the current version"
@echo "frontend-lint Run frontend checks and fixable lint/format steps"
@echo "wheel Build the wheel for the current version"
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
@echo "openapi Generate the OpenAPI schema for the app, outputting to stdout"
@echo "docs Serve the mkdocs site with live reload"
@@ -57,6 +58,10 @@ frontend-install:
frontend-build:
cd invokeai/frontend/web && pnpm build
# Run the frontend test suite once
frontend-test:
cd invokeai/frontend/web && pnpm run test:run
# Run the frontend in dev mode
frontend-dev:
cd invokeai/frontend/web && pnpm dev


@@ -58,15 +58,39 @@ Invoke offers a fully featured workflow management solution, enabling users to c
Invoke features an organized gallery system for easily storing, accessing, and remixing your content in the Invoke workspace. Images can be dragged/dropped onto any image-based UI element in the application, and rich metadata within the Image allows for easy recall of key prompts or settings used in your workflow.
### Model Support
- SD 1.5
- SD 2.0
- SDXL
- SD 3.5 Medium
- SD 3.5 Large
- CogView 4
- Flux.1 Dev
- Flux.1 Schnell
- Flux.1 Kontext
- Flux.1 Krea
- Flux Redux
- Flux Fill
- Flux.2 Klein 4B
- Flux.2 Klein 9B
- Z-Image Turbo
- Z-Image Base
- Anima
- Qwen Image
- Qwen Image Edit
- Nano Banana (API Only)
- GPT Image (API Only)
- Wan (API Only)
### Other features
- Support for both ckpt and diffusers models
- SD1.5, SD2.0, SDXL, and FLUX support
- Support for ckpt, diffusers, and some gguf models
- Upscaling Tools
- Embedding Manager & Support
- Model Manager & Support
- Workflow creation & management
- Node-Based Architecture
- Object Segmentation & Selection Models (SAM / SAM2)
## Contributing


@@ -0,0 +1,205 @@
# Canvas Projects — Technical Documentation
## Overview
Canvas Projects provide a save/load mechanism for the entire canvas state. The feature serializes all canvas entities, generation parameters, reference images, and their associated image files into a ZIP-based `.invk` file. On load, it restores the full state, handling image deduplication and re-uploading as needed.
## File Format
The `.invk` file is a standard ZIP archive with the following structure:
```
project.invk
├── manifest.json
├── canvas_state.json
├── params.json
├── ref_images.json
├── loras.json
└── images/
    ├── {image_name_1}.png
    ├── {image_name_2}.png
    └── ...
```
### manifest.json
Schema version and metadata. Validated on load with Zod.
```json
{
"version": 1,
"appVersion": "5.12.0",
"createdAt": "2026-02-26T12:00:00.000Z",
"name": "My Canvas Project"
}
```
| Field | Type | Description |
|---|---|---|
| `version` | `number` | Schema version, currently `1`. Used for migration logic on load. |
| `appVersion` | `string` | InvokeAI version that created the file. Informational only. |
| `createdAt` | `string` | ISO 8601 timestamp. |
| `name` | `string` | User-provided project name. Also used as the download filename. |
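The manifest is the only file that is strictly validated before anything else is read. A minimal Zod sketch of that validation, matching the fields in the table above (the schema and helper names here are illustrative, not the exact ones in `canvasProjectFile.ts`):
```typescript
import { z } from 'zod';

// Minimal sketch of the manifest validation described above. The schema name and
// parse helper are assumptions; only the field shapes come from the table.
const zCanvasProjectManifest = z.object({
  version: z.literal(1), // schema version; bumped only for breaking changes
  appVersion: z.string(), // informational only
  createdAt: z.string(), // ISO 8601 timestamp
  name: z.string(), // user-provided project name, reused as the download filename
});

type CanvasProjectManifest = z.infer<typeof zCanvasProjectManifest>;

// Throws (and the load is aborted) if manifest.json does not match the expected shape.
const parseManifest = (raw: string): CanvasProjectManifest =>
  zCanvasProjectManifest.parse(JSON.parse(raw));
```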
### canvas_state.json
The serialized canvas entity tree. Type: `CanvasProjectState`.
```typescript
type CanvasProjectState = {
  rasterLayers: CanvasRasterLayerState[];
  controlLayers: CanvasControlLayerState[];
  inpaintMasks: CanvasInpaintMaskState[];
  regionalGuidance: CanvasRegionalGuidanceState[];
  bbox: CanvasState['bbox'];
  selectedEntityIdentifier: CanvasState['selectedEntityIdentifier'];
  bookmarkedEntityIdentifier: CanvasState['bookmarkedEntityIdentifier'];
};
```
Each entity contains its full state including all canvas objects (brush lines, eraser lines, rect shapes, images). Image objects reference files by `image_name`, which corresponds to a file in the `images/` folder.
### params.json
The complete generation parameters state (`ParamsState`). Optional on load (older files may not have it). This includes all fields from the params Redux slice:
- Prompts (positive, negative, prompt history)
- Core generation settings (seed, steps, CFG scale, guidance, scheduler, iterations)
- Model selections (main model, VAE, FLUX VAE, T5 encoder, CLIP embed models, refiner, Z-Image models, Klein models)
- Dimensions (width, height, aspect ratio)
- Img2img strength
- Infill settings (method, tile size, patchmatch downscale, color)
- Canvas coherence settings (mode, edge size, min denoise)
- Refiner parameters (steps, CFG scale, scheduler, aesthetic scores, start)
- FLUX-specific settings (scheduler, DyPE preset/scale/exponent)
- Z-Image-specific settings (scheduler, seed variance)
- Upscale settings (scheduler, CFG scale)
- Seamless tiling, mask blur, CLIP skip, VAE precision, CPU noise, color compensation
### ref_images.json
Global reference image entities (`RefImageState[]`). These are IP-Adapter / FLUX Redux configs with `CroppableImageWithDims` containing both original and cropped image references. Optional on load.
### loras.json
Array of LoRA configurations (`LoRA[]`). Each entry contains:
```typescript
type LoRA = {
  id: string;
  isEnabled: boolean;
  model: ModelIdentifierField;
  weight: number;
};
```
Optional on load. Like models, LoRA identifiers are stored as-is — if a LoRA is not installed when loading, the entry is restored but may not be usable.
### images/
All image files referenced anywhere in the state. Keyed by their original `image_name`. On save, each image is fetched from the backend via `GET /api/v1/images/i/{name}/full` and stored as-is.
## Key Source Files
| File | Purpose |
|---|---|
| `features/controlLayers/util/canvasProjectFile.ts` | Types, constants, image name collection, remapping, existence checking |
| `features/controlLayers/hooks/useCanvasProjectSave.ts` | Save hook — collects Redux state, fetches images, builds ZIP |
| `features/controlLayers/hooks/useCanvasProjectLoad.ts` | Load hook — parses ZIP, deduplicates images, dispatches state |
| `features/controlLayers/components/SaveCanvasProjectDialog.tsx` | Save name dialog + `useSaveCanvasProjectWithDialog` hook |
| `features/controlLayers/components/LoadCanvasProjectConfirmationAlertDialog.tsx` | Load confirmation dialog + `useLoadCanvasProjectWithDialog` hook |
| `features/controlLayers/components/Toolbar/CanvasToolbarProjectMenuButton.tsx` | Toolbar dropdown UI |
| `features/controlLayers/store/canvasSlice.ts` | `canvasProjectRecalled` Redux action |
## Save Flow
1. User clicks "Save Canvas Project" → `SaveCanvasProjectDialog` opens asking for a project name
2. On confirm, `saveCanvasProject(name)` is called
3. Read Redux state via selectors: `selectCanvasSlice()`, `selectParamsSlice()`, `selectRefImagesSlice()`, `selectLoRAsSlice()`
4. Build `CanvasProjectState` from the canvas slice; use `paramsState` directly for params
5. Walk all entities to collect every `image_name` reference via `collectImageNames()`:
- `CanvasImageState.image.image_name` in layer/mask objects
- `CroppableImageWithDims.original.image.image_name` in global ref images
- `CroppableImageWithDims.crop.image.image_name` in cropped ref images
- `ImageWithDims.image_name` in regional guidance ref images
6. Fetch each image from the backend API
7. Build ZIP with JSZip: add `manifest.json` (including `name`), `canvas_state.json`, `params.json`, `ref_images.json`, and all images into `images/`
8. Sanitize the name for filesystem use, generate a Blob, and trigger a download as `{name}.invk` (a sketch of steps 3-8 follows this list)
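A condensed sketch of steps 3-8, assuming JSZip and the image endpoint described in the `images/` section. File and entry names come from this document; `fetchImageBlob`, the argument shapes, and the sanitization regex are illustrative assumptions:
```typescript
import JSZip from 'jszip';

// Assumed helper wrapping GET /api/v1/images/i/{name}/full (step 6).
const fetchImageBlob = async (imageName: string): Promise<Blob> => {
  const res = await fetch(`/api/v1/images/i/${imageName}/full`);
  if (!res.ok) throw new Error(`Failed to fetch ${imageName}`);
  return res.blob();
};

// Sketch of steps 3-8; the real hook reads these values from Redux selectors.
const saveCanvasProjectSketch = async (
  name: string,
  canvasState: unknown,
  paramsState: unknown,
  refImages: unknown,
  loras: unknown,
  imageNames: Set<string> // result of collectImageNames()
): Promise<void> => {
  const zip = new JSZip();
  zip.file(
    'manifest.json',
    JSON.stringify({ version: 1, appVersion: '5.12.0', createdAt: new Date().toISOString(), name })
  );
  zip.file('canvas_state.json', JSON.stringify(canvasState));
  zip.file('params.json', JSON.stringify(paramsState));
  zip.file('ref_images.json', JSON.stringify(refImages));
  zip.file('loras.json', JSON.stringify(loras));

  // Steps 6-7: embed every referenced image under images/.
  const images = zip.folder('images');
  for (const imageName of imageNames) {
    images?.file(imageName, await fetchImageBlob(imageName));
  }

  // Step 8: sanitize the name and trigger a browser download of {name}.invk.
  const safeName = name.replace(/[^a-zA-Z0-9 _-]/g, '_').trim() || 'canvas-project';
  const blob = await zip.generateAsync({ type: 'blob' });
  const url = URL.createObjectURL(blob);
  const a = document.createElement('a');
  a.href = url;
  a.download = `${safeName}.invk`;
  a.click();
  URL.revokeObjectURL(url);
};
```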
## Load Flow
1. User selects `.invk` file → confirmation dialog opens
2. On confirm, parse ZIP with JSZip
3. Validate manifest version via Zod schema
4. Read `canvas_state.json`, `params.json` (optional), `ref_images.json` (optional)
5. Collect all `image_name` references from the loaded state
6. **Deduplicate images**: for each referenced image, check if it exists on the server via `getImageDTOSafe(image_name)`
- Already exists → skip (no upload)
- Missing → upload from ZIP via `uploadImage()`, record `oldName → newName` mapping
7. Remap all `image_name` values in the loaded state using the mapping (only for re-uploaded images whose names changed)
8. Dispatch Redux actions:
- `canvasProjectRecalled()` — restores all canvas entities, bbox, selected/bookmarked entity
- `refImagesRecalled()` — restores global reference images
- `paramsRecalled()` — replaces the entire params state in one action
- `loraAllDeleted()` + `loraRecalled()` — restores LoRAs
9. Show success/error toast
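A sketch of steps 5-7, the deduplication and remapping core of the load flow. `getImageDTOSafe` and `uploadImage` are the helpers named above, but their signatures here are assumptions:
```typescript
import JSZip from 'jszip';

// Assumed signatures for the helpers named in step 6.
declare function getImageDTOSafe(imageName: string): Promise<{ image_name: string } | null>;
declare function uploadImage(file: File): Promise<{ image_name: string }>;

// Returns the oldName -> newName mapping used in step 7. Images already on the
// server are skipped; only missing images are uploaded from the ZIP.
const dedupeAndUploadImages = async (
  zip: JSZip,
  referencedNames: Set<string>
): Promise<Map<string, string>> => {
  const remapping = new Map<string, string>();
  for (const name of referencedNames) {
    if (await getImageDTOSafe(name)) continue; // already exists -> reuse, no upload

    const entry = zip.file(`images/${name}`);
    if (!entry) continue; // referenced but not in the archive; leave the reference as-is

    const blob = await entry.async('blob');
    const uploaded = await uploadImage(new File([blob], name, { type: 'image/png' }));
    if (uploaded.image_name !== name) {
      remapping.set(name, uploaded.image_name); // only renamed re-uploads need remapping
    }
  }
  return remapping;
};
```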
## Image Name Collection & Remapping
The `canvasProjectFile.ts` utility provides two parallel sets of functions:
**Collection** (`collectImageNames`): Walks the entire state tree and returns a `Set<string>` of all referenced `image_name` values. This is used by both save (to know which images to fetch) and load (to know which images to check/upload).
**Remapping** (`remapCanvasState`, `remapRefImages`): Deep-clones state objects and replaces `image_name` values using a `Map<string, string>` mapping. Only images that were re-uploaded with a different name are remapped. Images that already existed on the server are left unchanged.
Both walk the same paths through the state tree:
- Layer/mask objects → `CanvasImageState.image.image_name`
- Regional guidance ref images → `ImageWithDims.image_name`
- Global ref images → `CroppableImageWithDims.original.image.image_name` and `.crop.image.image_name`
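A condensed sketch of the collection walk over those three paths. The entity and ref-image shapes are simplified stand-ins for the real types, and `remapName` only illustrates how the remapping functions consume the mapping:
```typescript
// Simplified stand-ins for the real state types; only the image_name paths matter here.
type ImageRef = { image_name: string };
type SketchEntity = { objects: Array<{ type: string; image?: ImageRef }> };
type SketchGlobalRefImage = { original: { image: ImageRef }; crop: { image: ImageRef } | null };

const collectImageNamesSketch = (
  entities: SketchEntity[], // raster/control layers, inpaint masks, regional guidance
  regionalRefImages: ImageRef[],
  globalRefImages: SketchGlobalRefImage[]
): Set<string> => {
  const names = new Set<string>();
  // Layer/mask objects -> CanvasImageState.image.image_name
  for (const entity of entities) {
    for (const obj of entity.objects) {
      if (obj.type === 'image' && obj.image) names.add(obj.image.image_name);
    }
  }
  // Regional guidance ref images -> ImageWithDims.image_name
  for (const ref of regionalRefImages) names.add(ref.image_name);
  // Global ref images -> original and (optional) crop image names
  for (const ref of globalRefImages) {
    names.add(ref.original.image.image_name);
    if (ref.crop) names.add(ref.crop.image.image_name);
  }
  return names;
};

// Remapping walks the same paths over a deep clone, replacing each name like this:
const remapName = (name: string, mapping: Map<string, string>): string => mapping.get(name) ?? name;
```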
## Extending the Format
### Adding new optional data (non-breaking)
Add a new JSON file to the ZIP. No version bump needed.
1. **Save**: Add `zip.file('new_data.json', JSON.stringify(data))` in `useCanvasProjectSave.ts`
2. **Load**: Read with `zip.file('new_data.json')` in `useCanvasProjectLoad.ts` — check for `null` so older project files without it still load
3. **Dispatch**: Add the appropriate Redux action to restore the data
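As a sketch under assumptions: a hypothetical `notes.json` payload with an invented `notesRecalled` action, showing the three steps without touching the schema version:
```typescript
import JSZip from 'jszip';

// Hypothetical example payload; the file name and action are not part of the real format.
type NotesState = { text: string };

// 1. Save (useCanvasProjectSave.ts): add the new file to the ZIP.
const addNotesToZip = (zip: JSZip, notes: NotesState): void => {
  zip.file('notes.json', JSON.stringify(notes));
};

// 2. Load (useCanvasProjectLoad.ts): tolerate older project files that lack it.
const readNotesFromZip = async (zip: JSZip): Promise<NotesState | null> => {
  const entry = zip.file('notes.json');
  return entry ? (JSON.parse(await entry.async('string')) as NotesState) : null;
};

// 3. Dispatch only when the data was present, e.g. `if (notes) dispatch(notesRecalled(notes))`.
```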
### Adding new entity types with images
1. Extend `CanvasProjectState` type in `canvasProjectFile.ts`
2. Add collection logic in `collectImageNames()` to walk the new entity's objects
3. Add remapping logic in `remapCanvasState()` to update image names
4. Include the new entity array in both save and load hooks
5. Handle it in the `canvasProjectRecalled` reducer in `canvasSlice.ts`
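A simplified sketch of step 5, showing where a hypothetical new entity array would be restored in the `canvasProjectRecalled` reducer (state and payload shapes are reduced to the relevant fields):
```typescript
import type { PayloadAction } from '@reduxjs/toolkit';

// Reduced shapes; `myNewEntities` is the hypothetical addition from steps 1-4.
type SketchProjectPayload = {
  rasterLayers: unknown[];
  controlLayers: unknown[];
  inpaintMasks: unknown[];
  regionalGuidance: unknown[];
  myNewEntities?: unknown[]; // optional so project files saved before the change still load
};

const canvasProjectRecalledSketch = (
  state: Record<string, unknown>,
  action: PayloadAction<SketchProjectPayload>
): void => {
  const p = action.payload;
  state.rasterLayers = p.rasterLayers;
  state.controlLayers = p.controlLayers;
  state.inpaintMasks = p.inpaintMasks;
  state.regionalGuidance = p.regionalGuidance;
  state.myNewEntities = p.myNewEntities ?? []; // step 5: restore the new entity array
};
```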
### Breaking schema changes
1. Bump `CANVAS_PROJECT_VERSION` in `canvasProjectFile.ts`
2. Update the Zod manifest schema: `version: z.union([z.literal(1), z.literal(2)])`
3. Add migration logic in the load hook: check version, transform v1 → v2 before dispatching
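A sketch of the version gate under assumptions: v2 does not exist, so the v2 shape and the renamed `projectName` field are purely illustrative; the point is the check-then-transform pattern in the load hook:
```typescript
import { z } from 'zod';

// Accept both versions at the schema level, then migrate before dispatching.
const zVersionedManifest = z.object({ version: z.union([z.literal(1), z.literal(2)]) }).passthrough();

type ManifestV2 = { version: 2; projectName: string }; // hypothetical v2 shape

const migrateManifest = (raw: unknown): ManifestV2 => {
  const manifest = zVersionedManifest.parse(raw);
  if (manifest.version === 1) {
    // v1 stored the project name under `name`; v2 (hypothetically) renames it.
    const v1 = manifest as Record<string, unknown>;
    return { version: 2, projectName: typeof v1.name === 'string' ? v1.name : 'Untitled' };
  }
  return manifest as unknown as ManifestV2;
};
```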
## UI Architecture
### Save dialog
The save flow uses a **nanostore atom** (`$isOpen`) to control the `SaveCanvasProjectDialog`:
1. `useSaveCanvasProjectWithDialog()` — returns a callback that sets `$isOpen` to `true`
2. `SaveCanvasProjectDialog` (singleton in `GlobalModalIsolator`) — renders an `AlertDialog` with a name input
3. On save → calls `saveCanvasProject(name)` and closes the dialog
4. On cancel → closes the dialog
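A sketch of the atom wiring, assuming nanostores and `@nanostores/react`; the hook names are simplified and `saveCanvasProject` stands in for the real save callback:
```typescript
import { useCallback } from 'react';
import { atom } from 'nanostores';
import { useStore } from '@nanostores/react';

// Singleton open/closed flag for the save dialog.
export const $isOpen = atom(false);

// What useSaveCanvasProjectWithDialog() boils down to: opening the dialog is a set().
export const useOpenSaveDialog = () => useCallback(() => $isOpen.set(true), []);

// Inside the singleton SaveCanvasProjectDialog component:
export const useSaveDialogState = (saveCanvasProject: (name: string) => Promise<void>) => {
  const isOpen = useStore($isOpen);
  const cancel = useCallback(() => $isOpen.set(false), []);
  const confirm = useCallback(
    async (name: string) => {
      await saveCanvasProject(name); // step 3
      $isOpen.set(false);
    },
    [saveCanvasProject]
  );
  return { isOpen, confirm, cancel };
};
```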
### Load dialog
The load flow uses a **nanostore atom** (`$pendingFile`) to decouple the file dialog from the confirmation dialog:
1. `useLoadCanvasProjectWithDialog()` — opens a programmatic file input (`document.createElement('input')`)
2. On file selection → sets `$pendingFile` atom
3. `LoadCanvasProjectConfirmationAlertDialog` (singleton in `GlobalModalIsolator`) — subscribes to `$pendingFile` via `useStore()`
4. On accept → calls `loadCanvasProject(file)` and clears the atom
5. On cancel → clears the atom
The programmatic file input approach was chosen because the context menu component uses `isLazy: true`, which unmounts the DOM tree when the menu closes — a hidden `<input>` element inside the menu would be destroyed before the file dialog returns.
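A sketch of that approach, assuming the `$pendingFile` atom described above; creating the `<input>` element outside the menu's DOM tree keeps it alive while the native file dialog is open:
```typescript
import { atom } from 'nanostores';

// The confirmation dialog subscribes to this atom via useStore($pendingFile).
export const $pendingFile = atom<File | null>(null);

// What useLoadCanvasProjectWithDialog() boils down to: a detached, programmatic <input>.
export const pickCanvasProjectFile = (): void => {
  const input = document.createElement('input');
  input.type = 'file';
  input.accept = '.invk';
  input.onchange = () => {
    const file = input.files?.[0];
    if (file) $pendingFile.set(file); // hands off to the confirmation dialog
  };
  input.click();
};
```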


@@ -0,0 +1,32 @@
Lasso Tool
===========
- The Lasso tool creates selections and inpaint masks by drawing freehand or polygonal regions on the canvas.
How to open the Lasso tool
--------------------------
- Click the Lasso icon in the toolbar.
- Hotkey: press `L` (default). The hotkey is shown in the tool's tooltip and can be customized in Hotkeys settings.
Modes
-----
- Freehand (default)
- Hold the pointer and drag to draw a continuous contour.
- Long segments are broken into intermediate points to keep the line continuous.
- Very long strokes may be simplified after drawing to reduce point count for performance.
- Polygon
- Click to place points; click the first point (or a point near it) to close the polygon.
- The tool snaps the closing point to the start for precise closures.
Basic interactions
------------------
- Switch modes with the mode toggle in the toolbar.
- To close a polygon: click the starting point again or click near it — the tool aligns the final point to the start to complete the shape.
- The selection will be added to the current Inpaint Mask layer. If no Inpaint Mask layer exists, a new one will be created automatically.
Tips & behavior
---------------
- Hold `Space` to temporarily switch to the View tool for panning and zooming; release `Space` to return to the Lasso tool and continue drawing.
- When using the Polygon mode, you can hold `Shift` to snap points to horizontal, vertical, or 45-degree angles for more precise shapes.
- Hold `Ctrl` (Windows/Linux) or `Command` (macOS) while drawing to subtract from the current selection instead of adding to it.


@@ -0,0 +1,56 @@
---
title: Canvas Projects
---
# :material-folder-zip: Canvas Projects
## Save and Restore Your Canvas Work
Canvas Projects let you save your entire canvas setup to a file and load it back later. This is useful when you want to:
- **Switch between tasks** without losing your current canvas arrangement
- **Back up complex setups** with multiple layers, masks, and reference images
- **Share canvas layouts** with others or transfer them between machines
- **Recover from deleted images** — all images are embedded in the project file
## What Gets Saved
A canvas project file (`.invk`) captures everything about your current canvas session:
- **All layers** — raster layers, control layers, inpaint masks, regional guidance
- **All drawn content** — brush strokes, pasted images, eraser marks
- **Reference images** — global IP-Adapter / FLUX Redux images with crop settings
- **Regional guidance** — per-region prompts and reference images
- **Bounding box** — position, size, aspect ratio, and scale settings
- **All generation parameters** — prompts, seed, steps, CFG scale, guidance, scheduler, model, VAE, dimensions, img2img strength, infill settings, canvas coherence, refiner settings, FLUX/Z-Image specific parameters, and more
- **LoRAs** — all added LoRA models with their weights and enabled/disabled state
## How to Save a Project
You can save from two places:
1. **Toolbar** — Click the **Archive icon** in the canvas toolbar, then select **Save Canvas Project**
2. **Context menu** — Right-click the canvas, open the **Project** submenu, then select **Save Canvas Project**
A dialog will ask you to enter a **project name**. This name is used as the filename (e.g., entering "My Portrait" saves as `My Portrait.invk`) and is stored inside the project file.
## How to Load a Project
1. **Toolbar** — Click the **Archive icon**, then select **Load Canvas Project**
2. **Context menu** — Right-click the canvas, open the **Project** submenu, then select **Load Canvas Project**
A file dialog will open. Select your `.invk` file. You will see a confirmation dialog warning that loading will replace your current canvas. Click **Load** to proceed.
### What Happens on Load
- Your current canvas is **completely replaced** — all existing layers, masks, reference images, and parameters are overwritten
- Images that are already present on your InvokeAI server are reused automatically (no duplicate uploads)
- Images that were deleted from the server are re-uploaded from the project file
- If the saved model is not installed on your system, the model identifier is still restored — you will need to select an available model manually
## Good to Know
- **No undo** — Loading a project replaces your canvas entirely. There is no way to undo this action, so save your current project first if you want to keep it.
- **Image deduplication** — When loading, images already on your server are not re-uploaded. Only missing images are uploaded from the project file.
- **File size** — The `.invk` file size depends on the number and resolution of images in your canvas. A project with many high-resolution layers can be large.
- **Model availability** — The project saves which model was selected, but does not include the model itself. If the model is not installed when you load the project, you will need to select a different one.


@@ -140,13 +140,13 @@ As a regular user, you can:
- ✅ View your own generation queue
- ✅ Customize your UI preferences (theme, hotkeys, etc.)
- ✅ View available models (read-only access to Model Manager)
-Access shared boards (based on permissions granted to you) (FUTURE FEATURE)
-Access workflows marked as public (FUTURE FEATURE)
-View shared and public boards created by other users
-View and use workflows marked as shared by other users
You cannot:
- ❌ Add, delete, or modify models
- ❌ View or modify other users' boards, images, or workflows
- ❌ View or modify other users' private boards, images, or workflows
- ❌ Manage user accounts
- ❌ Access system configuration
- ❌ View or cancel other users' generation tasks
@@ -173,7 +173,7 @@ Administrators have all regular user capabilities, plus:
- ✅ Full model management (add, delete, configure models)
- ✅ Create and manage user accounts
- ✅ View and manage all users' generation queues
-Create and manage shared boards (FUTURE FEATURE)
-View and manage all users' boards, images, and workflows (including system-owned legacy content)
- ✅ Access system configuration
- ✅ Grant or revoke admin privileges
@@ -183,23 +183,30 @@ Administrators have all regular user capabilities, plus:
### Image Boards
In multi-user mode, Image Boards work as before. Each user can create an unlimited number of boards and organize their images and assets as they see fit. Boards are private: you cannot see a board owned by a different user.
In multi-user mode, each user can create an unlimited number of boards and organize their images and assets as they see fit. Boards have three visibility levels:
!!! tip "Shared Boards"
InvokeAI 6.13 will add support for creating public boards that are accessible to all users.
- **Private** (default): Only you (and administrators) can see and modify the board.
- **Shared**: All users can view the board and its contents, but only you (and administrators) can modify it (rename, archive, delete, or add/remove images).
- **Public**: All users can view the board. Only you (and administrators) can modify the board's structure (rename, archive, delete).
The Administrator can see all users' Image Boards and their contents.
To change a board's visibility, right-click on the board and select the desired visibility option.
### Going From Multi-User to Single-User mode
Administrators can see and manage all users' image boards and their contents regardless of visibility settings.
### Going From Multi-User to Single-User Mode
If an InvokeAI instance was in multi-user mode and then restarted in single-user mode (by setting `multiuser: false` in the configuration file), all users' boards will be consolidated in one place. Any images that were in "Uncategorized" will be merged together into a single Uncategorized board. If, at a later date, the server is restarted in multi-user mode, the boards and images will be separated and restored to their owners.
### Workflows
In the current released version (6.12) workflows are always shared among users. Any workflow that you create will be visible to other users and vice-versa, and there is no protection against one user modifying another user's workflow.
Each user has their own private workflow library. Workflows you create are visible only to you by default.
!!! tip "Private and Shared Workflows"
InvokeAI 6.13 will provide the ability to create private and shared workflows. A private workflow can only be viewed by the user who created it. At any time, however, the user can designate the workflow *shared*, in which case it can be opened on a read-only basis by all logged-in users.
You can share a workflow with other users by marking it as **shared** (public). Shared workflows appear in all users' workflow libraries and can be opened by anyone, but only the owner (or an administrator) can modify or delete them.
To share a workflow, open it and use the sharing controls to toggle its public/shared status.
!!! warning "Preexisting workflows after enabling multi-user mode"
When you enable multi-user mode for the first time on an existing InvokeAI installation, all workflows that were created before multi-user mode was activated will appear in the **shared workflows** section. These preexisting workflows are owned by the internal "system" account and are visible to all users. Administrators can edit or delete these shared legacy workflows. Regular users can view and use them but cannot modify them.
### The Generation Queue
@@ -330,11 +337,11 @@ These settings are stored per-user and won't affect other users.
### Can other users see my images?
No, unless you add them to a shared board (FUTURE FEATURE). All your personal boards and images are private.
Not unless you change your board's visibility to "shared" or "public". All personal boards and images are private by default.
### Can I share my workflows with others?
Not directly. Ask your administrator to mark workflows as public if you want to share them.
Yes. You can mark any workflow as shared (public), which makes it visible to all users. Other users can view and use shared workflows, but only you or an administrator can modify or delete them.
### How long do sessions last?


@@ -49,10 +49,12 @@ from invokeai.app.services.users.users_default import UserService
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_disk import WorkflowThumbnailFileStorageDisk
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
AnimaConditioningInfo,
BasicConditioningInfo,
CogView4ConditioningInfo,
ConditioningFieldData,
FLUXConditioningInfo,
QwenImageConditioningInfo,
SD3ConditioningInfo,
SDXLConditioningInfo,
ZImageConditioningInfo,
@@ -143,6 +145,8 @@ class ApiDependencies:
SD3ConditioningInfo,
CogView4ConditioningInfo,
ZImageConditioningInfo,
QwenImageConditioningInfo,
AnimaConditioningInfo,
],
ephemeral=True,
),


@@ -10,14 +10,15 @@ from fastapi import Body, HTTPException, Path
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
from invokeai.app.api.auth_dependencies import AdminUserOrDefault
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.config.config_default import (
DefaultInvokeAIAppConfig,
EXTERNAL_PROVIDER_CONFIG_FIELDS,
DefaultInvokeAIAppConfig,
InvokeAIAppConfig,
get_config,
load_external_api_keys,
load_and_migrate_config,
load_external_api_keys,
)
from invokeai.app.services.external_generation.external_generation_common import ExternalProviderStatus
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
@@ -103,6 +104,16 @@ EXTERNAL_PROVIDER_FIELDS: dict[str, tuple[str, str]] = {
_EXTERNAL_PROVIDER_CONFIG_LOCK = Lock()
class UpdateAppGenerationSettingsRequest(BaseModel):
"""Writable generation-related app settings."""
max_queue_history: int | None = Field(
default=None,
ge=0,
description="Keep the last N completed, failed, and canceled queue items on startup. Set to 0 to prune all terminal items.",
)
@app_router.get(
"/runtime_config", operation_id="get_runtime_config", status_code=200, response_model=InvokeAIAppConfigWithSetFields
)
@@ -111,6 +122,30 @@ async def get_runtime_config() -> InvokeAIAppConfigWithSetFields:
return InvokeAIAppConfigWithSetFields(set_fields=config.model_fields_set, config=config)
@app_router.patch(
"/runtime_config",
operation_id="update_runtime_config",
status_code=200,
response_model=InvokeAIAppConfigWithSetFields,
)
async def update_runtime_config(
_: AdminUserOrDefault,
changes: UpdateAppGenerationSettingsRequest = Body(description="Writable runtime configuration changes"),
) -> InvokeAIAppConfigWithSetFields:
config = get_config()
update_dict = changes.model_dump(exclude_unset=True)
config.update_config(update_dict)
if config.config_file_path.exists():
persisted_config = load_and_migrate_config(config.config_file_path)
else:
persisted_config = DefaultInvokeAIAppConfig()
persisted_config.update_config(update_dict)
persisted_config.write_file(config.config_file_path)
return InvokeAIAppConfigWithSetFields(set_fields=config.model_fields_set, config=config)
@app_router.get(
"/external_providers/status",
operation_id="get_external_provider_statuses",


@@ -80,6 +80,7 @@ class SetupStatusResponse(BaseModel):
setup_required: bool = Field(description="Whether initial setup is required")
multiuser_enabled: bool = Field(description="Whether multiuser mode is enabled")
strict_password_checking: bool = Field(description="Whether strict password requirements are enforced")
admin_email: str | None = Field(default=None, description="Email of the first active admin user, if any")
@auth_router.get("/status", response_model=SetupStatusResponse)
@@ -94,15 +95,25 @@ async def get_setup_status() -> SetupStatusResponse:
# If multiuser is disabled, setup is never required
if not config.multiuser:
return SetupStatusResponse(
setup_required=False, multiuser_enabled=False, strict_password_checking=config.strict_password_checking
setup_required=False,
multiuser_enabled=False,
strict_password_checking=config.strict_password_checking,
admin_email=None,
)
# In multiuser mode, check if an admin exists
user_service = ApiDependencies.invoker.services.users
setup_required = not user_service.has_admin()
# Only expose admin_email during initial setup to avoid leaking
# administrator identity on public deployments.
admin_email = user_service.get_admin_email() if setup_required else None
return SetupStatusResponse(
setup_required=setup_required, multiuser_enabled=True, strict_password_checking=config.strict_password_checking
setup_required=setup_required,
multiuser_enabled=True,
strict_password_checking=config.strict_password_checking,
admin_email=admin_email,
)
@@ -150,6 +161,7 @@ async def login(
user_id=user.user_id,
email=user.email,
is_admin=user.is_admin,
remember_me=request.remember_me,
)
token = create_access_token(token_data, expires_delta)


@@ -1,12 +1,53 @@
from fastapi import Body, HTTPException
from fastapi.routing import APIRouter
from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.images.images_common import AddImagesToBoardResult, RemoveImagesFromBoardResult
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
def _assert_board_write_access(board_id: str, current_user: CurrentUserOrDefault) -> None:
"""Raise 403 if the current user may not mutate the given board.
Write access is granted when ANY of these hold:
- The user is an admin.
- The user owns the board.
- The board visibility is Public (public boards accept contributions from any user).
"""
from invokeai.app.services.board_records.board_records_common import BoardVisibility
try:
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
except Exception:
raise HTTPException(status_code=404, detail="Board not found")
if current_user.is_admin:
return
if board.user_id == current_user.user_id:
return
if board.board_visibility == BoardVisibility.Public:
return
raise HTTPException(status_code=403, detail="Not authorized to modify this board")
def _assert_image_direct_owner(image_name: str, current_user: CurrentUserOrDefault) -> None:
"""Raise 403 if the current user is not the direct owner of the image.
This is intentionally stricter than _assert_image_owner in images.py:
board ownership is NOT sufficient here. Allowing a user to add someone
else's image to their own board would grant them mutation rights via the
board-ownership fallback in _assert_image_owner, escalating read access
into write access.
"""
if current_user.is_admin:
return
owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name)
if owner is not None and owner == current_user.user_id:
return
raise HTTPException(status_code=403, detail="Not authorized to move this image")
@board_images_router.post(
"/",
operation_id="add_image_to_board",
@@ -17,14 +58,17 @@ board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
response_model=AddImagesToBoardResult,
)
async def add_image_to_board(
current_user: CurrentUserOrDefault,
board_id: str = Body(description="The id of the board to add to"),
image_name: str = Body(description="The name of the image to add"),
) -> AddImagesToBoardResult:
"""Creates a board_image"""
_assert_board_write_access(board_id, current_user)
_assert_image_direct_owner(image_name, current_user)
try:
added_images: set[str] = set()
affected_boards: set[str] = set()
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
old_board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name) or "none"
ApiDependencies.invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name)
added_images.add(image_name)
affected_boards.add(board_id)
@@ -48,13 +92,16 @@ async def add_image_to_board(
response_model=RemoveImagesFromBoardResult,
)
async def remove_image_from_board(
current_user: CurrentUserOrDefault,
image_name: str = Body(description="The name of the image to remove", embed=True),
) -> RemoveImagesFromBoardResult:
"""Removes an image from its board, if it had one"""
try:
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
if old_board_id != "none":
_assert_board_write_access(old_board_id, current_user)
removed_images: set[str] = set()
affected_boards: set[str] = set()
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
removed_images.add(image_name)
affected_boards.add("none")
@@ -64,6 +111,8 @@ async def remove_image_from_board(
affected_boards=list(affected_boards),
)
except HTTPException:
raise
except Exception:
raise HTTPException(status_code=500, detail="Failed to remove image from board")
@@ -78,16 +127,21 @@ async def remove_image_from_board(
response_model=AddImagesToBoardResult,
)
async def add_images_to_board(
current_user: CurrentUserOrDefault,
board_id: str = Body(description="The id of the board to add to"),
image_names: list[str] = Body(description="The names of the images to add", embed=True),
) -> AddImagesToBoardResult:
"""Adds a list of images to a board"""
_assert_board_write_access(board_id, current_user)
try:
added_images: set[str] = set()
affected_boards: set[str] = set()
for image_name in image_names:
try:
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
_assert_image_direct_owner(image_name, current_user)
old_board_id = (
ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name) or "none"
)
ApiDependencies.invoker.services.board_images.add_image_to_board(
board_id=board_id,
image_name=image_name,
@@ -96,12 +150,16 @@ async def add_images_to_board(
affected_boards.add(board_id)
affected_boards.add(old_board_id)
except HTTPException:
raise
except Exception:
pass
return AddImagesToBoardResult(
added_images=list(added_images),
affected_boards=list(affected_boards),
)
except HTTPException:
raise
except Exception:
raise HTTPException(status_code=500, detail="Failed to add images to board")
@@ -116,6 +174,7 @@ async def add_images_to_board(
response_model=RemoveImagesFromBoardResult,
)
async def remove_images_from_board(
current_user: CurrentUserOrDefault,
image_names: list[str] = Body(description="The names of the images to remove", embed=True),
) -> RemoveImagesFromBoardResult:
"""Removes a list of images from their board, if they had one"""
@@ -125,15 +184,21 @@ async def remove_images_from_board(
for image_name in image_names:
try:
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
if old_board_id != "none":
_assert_board_write_access(old_board_id, current_user)
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
removed_images.add(image_name)
affected_boards.add("none")
affected_boards.add(old_board_id)
except HTTPException:
raise
except Exception:
pass
return RemoveImagesFromBoardResult(
removed_images=list(removed_images),
affected_boards=list(affected_boards),
)
except HTTPException:
raise
except Exception:
raise HTTPException(status_code=500, detail="Failed to remove images from board")


@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field
from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy, BoardVisibility
from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.image_records.image_records_common import ImageCategory
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
@@ -56,7 +56,14 @@ async def get_board(
except Exception:
raise HTTPException(status_code=404, detail="Board not found")
if not current_user.is_admin and result.user_id != current_user.user_id:
# Admins can access any board.
# Owners can access their own boards.
# Shared and public boards are visible to all authenticated users.
if (
not current_user.is_admin
and result.user_id != current_user.user_id
and result.board_visibility == BoardVisibility.Private
):
raise HTTPException(status_code=403, detail="Not authorized to access this board")
return result
@@ -188,7 +195,11 @@ async def list_all_board_image_names(
except Exception:
raise HTTPException(status_code=404, detail="Board not found")
if not current_user.is_admin and board.user_id != current_user.user_id:
if (
not current_user.is_admin
and board.user_id != current_user.user_id
and board.board_visibility == BoardVisibility.Private
):
raise HTTPException(status_code=403, detail="Not authorized to access this board")
image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
@@ -196,4 +207,15 @@ async def list_all_board_image_names(
categories,
is_intermediate,
)
# For uncategorized images (board_id="none"), filter to only the caller's
# images so that one user cannot enumerate another's uncategorized images.
# Admin users can see all uncategorized images.
if board_id == "none" and not current_user.is_admin:
image_names = [
name
for name in image_names
if ApiDependencies.invoker.services.image_records.get_user_id(name) == current_user.user_id
]
return image_names


@@ -38,6 +38,96 @@ images_router = APIRouter(prefix="/v1/images", tags=["images"])
IMAGE_MAX_AGE = 31536000
def _assert_image_owner(image_name: str, current_user: CurrentUserOrDefault) -> None:
"""Raise 403 if the current user does not own the image and is not an admin.
Ownership is satisfied when ANY of these hold:
- The user is an admin.
- The user is the image's direct owner (image_records.user_id).
- The user owns the board the image sits on.
- The image sits on a Public board (public boards grant mutation rights).
"""
from invokeai.app.services.board_records.board_records_common import BoardVisibility
if current_user.is_admin:
return
owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name)
if owner is not None and owner == current_user.user_id:
return
# Check whether the user owns the board the image belongs to,
# or the board is Public (public boards grant mutation rights).
board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name)
if board_id is not None:
try:
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
if board.user_id == current_user.user_id:
return
if board.board_visibility == BoardVisibility.Public:
return
except Exception:
pass
raise HTTPException(status_code=403, detail="Not authorized to modify this image")
def _assert_image_read_access(image_name: str, current_user: CurrentUserOrDefault) -> None:
"""Raise 403 if the current user may not view the image.
Access is granted when ANY of these hold:
- The user is an admin.
- The user owns the image.
- The image sits on a shared or public board.
"""
from invokeai.app.services.board_records.board_records_common import BoardVisibility
if current_user.is_admin:
return
owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name)
if owner is not None and owner == current_user.user_id:
return
# Check whether the image's board makes it visible to other users.
board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name)
if board_id is not None:
try:
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
if board.board_visibility in (BoardVisibility.Shared, BoardVisibility.Public):
return
except Exception:
pass
raise HTTPException(status_code=403, detail="Not authorized to access this image")
def _assert_board_read_access(board_id: str, current_user: CurrentUserOrDefault) -> None:
"""Raise 403 if the current user may not read images from this board.
Access is granted when ANY of these hold:
- The user is an admin.
- The user owns the board.
- The board visibility is Shared or Public.
"""
from invokeai.app.services.board_records.board_records_common import BoardVisibility
if current_user.is_admin:
return
try:
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
except Exception:
raise HTTPException(status_code=404, detail="Board not found")
if board.user_id == current_user.user_id:
return
if board.board_visibility in (BoardVisibility.Shared, BoardVisibility.Public):
return
raise HTTPException(status_code=403, detail="Not authorized to access this board")
class ResizeToDimensions(BaseModel):
width: int = Field(..., gt=0)
height: int = Field(..., gt=0)
@@ -83,6 +173,22 @@ async def upload_image(
),
) -> ImageDTO:
"""Uploads an image for the current user"""
# If uploading into a board, verify the user has write access.
# Public boards allow uploads from any authenticated user.
if board_id is not None:
from invokeai.app.services.board_records.board_records_common import BoardVisibility
try:
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
except Exception:
raise HTTPException(status_code=404, detail="Board not found")
if (
not current_user.is_admin
and board.user_id != current_user.user_id
and board.board_visibility != BoardVisibility.Public
):
raise HTTPException(status_code=403, detail="Not authorized to upload to this board")
if not file.content_type or not file.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
@@ -165,9 +271,11 @@ async def create_image_upload_entry(
@images_router.delete("/i/{image_name}", operation_id="delete_image", response_model=DeleteImagesResult)
async def delete_image(
current_user: CurrentUserOrDefault,
image_name: str = Path(description="The name of the image to delete"),
) -> DeleteImagesResult:
"""Deletes an image"""
_assert_image_owner(image_name, current_user)
deleted_images: set[str] = set()
affected_boards: set[str] = set()
@@ -189,26 +297,31 @@ async def delete_image(
@images_router.delete("/intermediates", operation_id="clear_intermediates")
async def clear_intermediates() -> int:
"""Clears all intermediates"""
async def clear_intermediates(
current_user: CurrentUserOrDefault,
) -> int:
"""Clears all intermediates. Requires admin."""
if not current_user.is_admin:
raise HTTPException(status_code=403, detail="Only admins can clear all intermediates")
try:
count_deleted = ApiDependencies.invoker.services.images.delete_intermediates()
return count_deleted
except Exception:
raise HTTPException(status_code=500, detail="Failed to clear intermediates")
pass
@images_router.get("/intermediates", operation_id="get_intermediates_count")
async def get_intermediates_count() -> int:
"""Gets the count of intermediate images"""
async def get_intermediates_count(
current_user: CurrentUserOrDefault,
) -> int:
"""Gets the count of intermediate images. Non-admin users only see their own intermediates."""
try:
return ApiDependencies.invoker.services.images.get_intermediates_count()
user_id = None if current_user.is_admin else current_user.user_id
return ApiDependencies.invoker.services.images.get_intermediates_count(user_id=user_id)
except Exception:
raise HTTPException(status_code=500, detail="Failed to get intermediates")
pass
@images_router.patch(
@@ -217,10 +330,12 @@ async def get_intermediates_count() -> int:
response_model=ImageDTO,
)
async def update_image(
current_user: CurrentUserOrDefault,
image_name: str = Path(description="The name of the image to update"),
image_changes: ImageRecordChanges = Body(description="The changes to apply to the image"),
) -> ImageDTO:
"""Updates an image"""
_assert_image_owner(image_name, current_user)
try:
return ApiDependencies.invoker.services.images.update(image_name, image_changes)
@@ -234,9 +349,11 @@ async def update_image(
response_model=ImageDTO,
)
async def get_image_dto(
current_user: CurrentUserOrDefault,
image_name: str = Path(description="The name of image to get"),
) -> ImageDTO:
"""Gets an image's DTO"""
_assert_image_read_access(image_name, current_user)
try:
return ApiDependencies.invoker.services.images.get_dto(image_name)
@@ -250,9 +367,11 @@ async def get_image_dto(
response_model=Optional[MetadataField],
)
async def get_image_metadata(
current_user: CurrentUserOrDefault,
image_name: str = Path(description="The name of image to get"),
) -> Optional[MetadataField]:
"""Gets an image's metadata"""
_assert_image_read_access(image_name, current_user)
try:
return ApiDependencies.invoker.services.images.get_metadata(image_name)
@@ -269,8 +388,11 @@ class WorkflowAndGraphResponse(BaseModel):
"/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=WorkflowAndGraphResponse
)
async def get_image_workflow(
current_user: CurrentUserOrDefault,
image_name: str = Path(description="The name of image whose workflow to get"),
) -> WorkflowAndGraphResponse:
_assert_image_read_access(image_name, current_user)
try:
workflow = ApiDependencies.invoker.services.images.get_workflow(image_name)
graph = ApiDependencies.invoker.services.images.get_graph(image_name)
@@ -306,8 +428,12 @@ async def get_image_workflow(
async def get_image_full(
image_name: str = Path(description="The name of full-resolution image file to get"),
) -> Response:
"""Gets a full-resolution image file"""
"""Gets a full-resolution image file.
This endpoint is intentionally unauthenticated because browsers load images
via <img src> tags which cannot send Bearer tokens. Image names are UUIDs,
providing security through unguessability.
"""
try:
path = ApiDependencies.invoker.services.images.get_path(image_name)
with open(path, "rb") as f:
@@ -335,8 +461,12 @@ async def get_image_full(
async def get_image_thumbnail(
image_name: str = Path(description="The name of thumbnail image file to get"),
) -> Response:
"""Gets a thumbnail image file"""
"""Gets a thumbnail image file.
This endpoint is intentionally unauthenticated because browsers load images
via <img src> tags which cannot send Bearer tokens. Image names are UUIDs,
providing security through unguessability.
"""
try:
path = ApiDependencies.invoker.services.images.get_path(image_name, thumbnail=True)
with open(path, "rb") as f:
@@ -354,9 +484,11 @@ async def get_image_thumbnail(
response_model=ImageUrlsDTO,
)
async def get_image_urls(
current_user: CurrentUserOrDefault,
image_name: str = Path(description="The name of the image whose URL to get"),
) -> ImageUrlsDTO:
"""Gets an image and thumbnail URL"""
_assert_image_read_access(image_name, current_user)
try:
image_url = ApiDependencies.invoker.services.images.get_url(image_name)
@@ -392,6 +524,11 @@ async def list_image_dtos(
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a list of image DTOs for the current user"""
# Validate that the caller can read from this board before listing its images.
# "none" is a sentinel for uncategorized images and is handled by the SQL layer.
if board_id is not None and board_id != "none":
_assert_board_read_access(board_id, current_user)
image_dtos = ApiDependencies.invoker.services.images.get_many(
offset,
limit,
@@ -410,6 +547,7 @@ async def list_image_dtos(
@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesResult)
async def delete_images_from_list(
current_user: CurrentUserOrDefault,
image_names: list[str] = Body(description="The list of names of images to delete", embed=True),
) -> DeleteImagesResult:
try:
@@ -417,24 +555,31 @@ async def delete_images_from_list(
affected_boards: set[str] = set()
for image_name in image_names:
try:
_assert_image_owner(image_name, current_user)
image_dto = ApiDependencies.invoker.services.images.get_dto(image_name)
board_id = image_dto.board_id or "none"
ApiDependencies.invoker.services.images.delete(image_name)
deleted_images.add(image_name)
affected_boards.add(board_id)
except HTTPException:
raise
except Exception:
pass
return DeleteImagesResult(
deleted_images=list(deleted_images),
affected_boards=list(affected_boards),
)
except HTTPException:
raise
except Exception:
raise HTTPException(status_code=500, detail="Failed to delete images")
@images_router.delete("/uncategorized", operation_id="delete_uncategorized_images", response_model=DeleteImagesResult)
async def delete_uncategorized_images() -> DeleteImagesResult:
"""Deletes all images that are uncategorized"""
async def delete_uncategorized_images(
current_user: CurrentUserOrDefault,
) -> DeleteImagesResult:
"""Deletes all uncategorized images owned by the current user (or all if admin)"""
image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
board_id="none", categories=None, is_intermediate=None
@@ -445,9 +590,13 @@ async def delete_uncategorized_images() -> DeleteImagesResult:
affected_boards: set[str] = set()
for image_name in image_names:
try:
_assert_image_owner(image_name, current_user)
ApiDependencies.invoker.services.images.delete(image_name)
deleted_images.add(image_name)
affected_boards.add("none")
except HTTPException:
# Skip images not owned by the current user
pass
except Exception:
pass
return DeleteImagesResult(
@@ -464,6 +613,7 @@ class ImagesUpdatedFromListResult(BaseModel):
@images_router.post("/star", operation_id="star_images_in_list", response_model=StarredImagesResult)
async def star_images_in_list(
current_user: CurrentUserOrDefault,
image_names: list[str] = Body(description="The list of names of images to star", embed=True),
) -> StarredImagesResult:
try:
@@ -471,23 +621,29 @@ async def star_images_in_list(
affected_boards: set[str] = set()
for image_name in image_names:
try:
_assert_image_owner(image_name, current_user)
updated_image_dto = ApiDependencies.invoker.services.images.update(
image_name, changes=ImageRecordChanges(starred=True)
)
starred_images.add(image_name)
affected_boards.add(updated_image_dto.board_id or "none")
except HTTPException:
raise
except Exception:
pass
return StarredImagesResult(
starred_images=list(starred_images),
affected_boards=list(affected_boards),
)
except HTTPException:
raise
except Exception:
raise HTTPException(status_code=500, detail="Failed to star images")
@images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=UnstarredImagesResult)
async def unstar_images_in_list(
current_user: CurrentUserOrDefault,
image_names: list[str] = Body(description="The list of names of images to unstar", embed=True),
) -> UnstarredImagesResult:
try:
@@ -495,17 +651,22 @@ async def unstar_images_in_list(
affected_boards: set[str] = set()
for image_name in image_names:
try:
_assert_image_owner(image_name, current_user)
updated_image_dto = ApiDependencies.invoker.services.images.update(
image_name, changes=ImageRecordChanges(starred=False)
)
unstarred_images.add(image_name)
affected_boards.add(updated_image_dto.board_id or "none")
except HTTPException:
raise
except Exception:
pass
return UnstarredImagesResult(
unstarred_images=list(unstarred_images),
affected_boards=list(affected_boards),
)
except HTTPException:
raise
except Exception:
raise HTTPException(status_code=500, detail="Failed to unstar images")
@@ -523,6 +684,7 @@ class ImagesDownloaded(BaseModel):
"/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202
)
async def download_images_from_list(
current_user: CurrentUserOrDefault,
background_tasks: BackgroundTasks,
image_names: Optional[list[str]] = Body(
default=None, description="The list of names of images to download", embed=True
@@ -533,6 +695,16 @@ async def download_images_from_list(
) -> ImagesDownloaded:
if (image_names is None or len(image_names) == 0) and board_id is None:
raise HTTPException(status_code=400, detail="No images or board id specified.")
# Validate that the caller can read every image they are requesting.
# For a board_id request, check board visibility; for explicit image names,
# check each image individually.
if board_id:
_assert_board_read_access(board_id, current_user)
if image_names:
for name in image_names:
_assert_image_read_access(name, current_user)
bulk_download_item_id: str = ApiDependencies.invoker.services.bulk_download.generate_item_id(board_id)
background_tasks.add_task(
@@ -540,6 +712,7 @@ async def download_images_from_list(
image_names,
board_id,
bulk_download_item_id,
current_user.user_id,
)
return ImagesDownloaded(bulk_download_item_name=bulk_download_item_id + ".zip")
@@ -558,11 +731,21 @@ async def download_images_from_list(
},
)
async def get_bulk_download_item(
current_user: CurrentUserOrDefault,
background_tasks: BackgroundTasks,
bulk_download_item_name: str = Path(description="The bulk_download_item_name of the bulk download item to get"),
) -> FileResponse:
"""Gets a bulk download zip file"""
"""Gets a bulk download zip file.
Requires authentication. The caller must be the user who initiated the
download (tracked by the bulk download service) or an admin.
"""
try:
# Verify the caller owns this download (or is an admin)
owner = ApiDependencies.invoker.services.bulk_download.get_owner(bulk_download_item_name)
if owner is not None and owner != current_user.user_id and not current_user.is_admin:
raise HTTPException(status_code=403, detail="Not authorized to access this download")
path = ApiDependencies.invoker.services.bulk_download.get_path(bulk_download_item_name)
response = FileResponse(
@@ -574,6 +757,8 @@ async def get_bulk_download_item(
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.delete, bulk_download_item_name)
return response
except HTTPException:
raise
except Exception:
raise HTTPException(status_code=404)
@@ -594,6 +779,10 @@ async def get_image_names(
) -> ImageNamesResult:
"""Gets ordered list of image names with metadata for optimistic updates"""
# Validate that the caller can read from this board before listing its images.
if board_id is not None and board_id != "none":
_assert_board_read_access(board_id, current_user)
try:
result = ApiDependencies.invoker.services.images.get_image_names(
starred_first=starred_first,
@@ -617,6 +806,7 @@ async def get_image_names(
responses={200: {"model": list[ImageDTO]}},
)
async def get_images_by_names(
current_user: CurrentUserOrDefault,
image_names: list[str] = Body(embed=True, description="Object containing list of image names to fetch DTOs for"),
) -> list[ImageDTO]:
"""Gets image DTOs for the specified image names. Maintains order of input names."""
@@ -628,8 +818,12 @@ async def get_images_by_names(
image_dtos: list[ImageDTO] = []
for name in image_names:
try:
_assert_image_read_access(name, current_user)
dto = image_service.get_dto(name)
image_dtos.append(dto)
except HTTPException:
# Skip images the user is not authorized to view
continue
except Exception:
# Skip missing images - they may have been deleted between name fetch and DTO fetch
continue


@@ -547,6 +547,19 @@ class BulkDeleteModelsResponse(BaseModel):
failed: List[dict] = Field(description="List of failed deletions with error messages")
class BulkReidentifyModelsRequest(BaseModel):
"""Request body for bulk model reidentification."""
keys: List[str] = Field(description="List of model keys to reidentify")
class BulkReidentifyModelsResponse(BaseModel):
"""Response body for bulk model reidentification."""
succeeded: List[str] = Field(description="List of successfully reidentified model keys")
failed: List[dict] = Field(description="List of failed reidentifications with error messages")
@model_manager_router.post(
"/i/bulk_delete",
operation_id="bulk_delete_models",
@@ -588,6 +601,67 @@ async def bulk_delete_models(
return BulkDeleteModelsResponse(deleted=deleted, failed=failed)
@model_manager_router.post(
"/i/bulk_reidentify",
operation_id="bulk_reidentify_models",
responses={
200: {"description": "Models reidentified (possibly with some failures)"},
},
status_code=200,
)
async def bulk_reidentify_models(
current_admin: AdminUserOrDefault,
request: BulkReidentifyModelsRequest = Body(description="List of model keys to reidentify"),
) -> BulkReidentifyModelsResponse:
"""
Reidentify multiple models by re-probing their weights files.
Returns a list of successfully reidentified keys and failed reidentifications with error messages.
"""
logger = ApiDependencies.invoker.services.logger
store = ApiDependencies.invoker.services.model_manager.store
models_path = ApiDependencies.invoker.services.configuration.models_path
succeeded = []
failed = []
for key in request.keys:
try:
config = store.get_model(key)
if pathlib.Path(config.path).is_relative_to(models_path):
model_path = pathlib.Path(config.path)
else:
model_path = models_path / config.path
mod = ModelOnDisk(model_path)
result = ModelConfigFactory.from_model_on_disk(mod)
if result.config is None:
raise InvalidModelException("Unable to identify model format")
# Retain user-editable fields from the original config
result.config.path = config.path
result.config.key = config.key
result.config.name = config.name
result.config.description = config.description
result.config.cover_image = config.cover_image
if hasattr(config, "trigger_phrases") and hasattr(result.config, "trigger_phrases"):
result.config.trigger_phrases = config.trigger_phrases
result.config.source = config.source
result.config.source_type = config.source_type
store.replace_model(config.key, result.config)
succeeded.append(key)
logger.info(f"Reidentified model: {key}")
except UnknownModelException as e:
logger.error(f"Failed to reidentify model {key}: {str(e)}")
failed.append({"key": key, "error": str(e)})
except Exception as e:
logger.error(f"Failed to reidentify model {key}: {str(e)}")
failed.append({"key": key, "error": str(e)})
logger.info(f"Bulk reidentify completed: {len(succeeded)} succeeded, {len(failed)} failed")
return BulkReidentifyModelsResponse(succeeded=succeeded, failed=failed)
@model_manager_router.delete(
"/i/{key}/image",
operation_id="delete_model_image",
@@ -815,7 +889,7 @@ async def install_hugging_face_model(
"/install",
operation_id="list_model_installs",
)
async def list_model_installs() -> List[ModelInstallJob]:
async def list_model_installs(current_admin: AdminUserOrDefault) -> List[ModelInstallJob]:
"""Return the list of model install jobs.
Install jobs have a numeric `id`, a `status`, and other fields that provide information on
@@ -847,7 +921,9 @@ async def list_model_installs() -> List[ModelInstallJob]:
404: {"description": "No such job"},
},
)
async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
async def get_model_install_job(
current_admin: AdminUserOrDefault, id: int = Path(description="Model install id")
) -> ModelInstallJob:
"""
Return the model install job corresponding to the given id. See the documentation for 'List Model Install Jobs'
for information on the format of the return value.
@@ -890,7 +966,9 @@ async def cancel_model_install_job(
},
status_code=201,
)
async def pause_model_install_job(id: int = Path(description="Model install job ID")) -> ModelInstallJob:
async def pause_model_install_job(
current_admin: AdminUserOrDefault, id: int = Path(description="Model install job ID")
) -> ModelInstallJob:
"""Pause the model install job corresponding to the given job ID."""
installer = ApiDependencies.invoker.services.model_manager.install
try:
@@ -910,7 +988,9 @@ async def pause_model_install_job(id: int = Path(description="Model install job
},
status_code=201,
)
async def resume_model_install_job(id: int = Path(description="Model install job ID")) -> ModelInstallJob:
async def resume_model_install_job(
current_admin: AdminUserOrDefault, id: int = Path(description="Model install job ID")
) -> ModelInstallJob:
"""Resume a paused model install job corresponding to the given job ID."""
installer = ApiDependencies.invoker.services.model_manager.install
try:
@@ -930,7 +1010,9 @@ async def resume_model_install_job(id: int = Path(description="Model install job
},
status_code=201,
)
async def restart_failed_model_install_job(id: int = Path(description="Model install job ID")) -> ModelInstallJob:
async def restart_failed_model_install_job(
current_admin: AdminUserOrDefault, id: int = Path(description="Model install job ID")
) -> ModelInstallJob:
"""Restart failed or non-resumable file downloads for the given job."""
installer = ApiDependencies.invoker.services.model_manager.install
try:
@@ -951,6 +1033,7 @@ async def restart_failed_model_install_job(id: int = Path(description="Model ins
status_code=201,
)
async def restart_model_install_file(
current_admin: AdminUserOrDefault,
id: int = Path(description="Model install job ID"),
file_source: AnyHttpUrl = Body(description="File download URL to restart"),
) -> ModelInstallJob:
@@ -1262,7 +1345,7 @@ class DeleteOrphanedModelsResponse(BaseModel):
operation_id="get_orphaned_models",
response_model=list[OrphanedModelInfo],
)
async def get_orphaned_models() -> list[OrphanedModelInfo]:
async def get_orphaned_models(_: AdminUserOrDefault) -> list[OrphanedModelInfo]:
"""Find orphaned model directories.
Orphaned models are directories in the models folder that contain model files
@@ -1289,7 +1372,9 @@ async def get_orphaned_models() -> list[OrphanedModelInfo]:
operation_id="delete_orphaned_models",
response_model=DeleteOrphanedModelsResponse,
)
async def delete_orphaned_models(request: DeleteOrphanedModelsRequest) -> DeleteOrphanedModelsResponse:
async def delete_orphaned_models(
request: DeleteOrphanedModelsRequest, _: AdminUserOrDefault
) -> DeleteOrphanedModelsResponse:
"""Delete specified orphaned model directories.
Args:

View File

@@ -7,6 +7,7 @@ from fastapi import Body, HTTPException, Path
from fastapi.routing import APIRouter
from pydantic import BaseModel, ConfigDict, Field
from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.backend.image_util.controlnet_processor import process_controlnet_image
from invokeai.backend.model_manager.taxonomy import ModelType
@@ -291,12 +292,58 @@ def resolve_ip_adapter_models(ip_adapters: list[IPAdapterRecallParameter]) -> li
return resolved_adapters
def _assert_recall_image_access(parameters: "RecallParameter", current_user: CurrentUserOrDefault) -> None:
"""Validate that the caller can read every image referenced in the recall parameters.
Control layers and IP adapters may reference image_name fields. Without this
check an attacker who knows another user's image UUID could use the recall
endpoint to extract image dimensions and — for ControlNet preprocessors — mint
a derived processed image they can then fetch.
"""
from invokeai.app.services.board_records.board_records_common import BoardVisibility
image_names: list[str] = []
if parameters.control_layers:
for layer in parameters.control_layers:
if layer.image_name is not None:
image_names.append(layer.image_name)
if parameters.ip_adapters:
for adapter in parameters.ip_adapters:
if adapter.image_name is not None:
image_names.append(adapter.image_name)
if not image_names:
return
# Admin can access all images
if current_user.is_admin:
return
for image_name in image_names:
owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name)
if owner is not None and owner == current_user.user_id:
continue
# Check board visibility
board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name)
if board_id is not None:
try:
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
if board.board_visibility in (BoardVisibility.Shared, BoardVisibility.Public):
continue
except Exception:
pass
raise HTTPException(status_code=403, detail=f"Not authorized to access image {image_name}")
@recall_parameters_router.post(
"/{queue_id}",
operation_id="update_recall_parameters",
response_model=dict[str, Any],
)
async def update_recall_parameters(
current_user: CurrentUserOrDefault,
queue_id: str = Path(..., description="The queue id to perform this operation on"),
parameters: RecallParameter = Body(..., description="Recall parameters to update"),
) -> dict[str, Any]:
@@ -328,6 +375,10 @@ async def update_recall_parameters(
"""
logger = ApiDependencies.invoker.services.logger
# Validate image access before processing — prevents information leakage
# (dimensions) and derived-image minting via ControlNet preprocessors.
_assert_recall_image_access(parameters, current_user)
try:
# Get only the parameters that were actually provided (non-None values)
provided_params = {k: v for k, v in parameters.model_dump().items() if v is not None}
@@ -335,14 +386,14 @@ async def update_recall_parameters(
if not provided_params:
return {"status": "no_parameters_provided", "updated_count": 0}
# Store each parameter in client state using a consistent key format
# Store each parameter in client state scoped to the current user
updated_count = 0
for param_key, param_value in provided_params.items():
# Convert parameter values to JSON strings for storage
value_str = json.dumps(param_value)
try:
ApiDependencies.invoker.services.client_state_persistence.set_by_key(
queue_id, f"recall_{param_key}", value_str
current_user.user_id, f"recall_{param_key}", value_str
)
updated_count += 1
except Exception as e:
@@ -396,7 +447,9 @@ async def update_recall_parameters(
logger.info(
f"Emitting recall_parameters_updated event for queue {queue_id} with {len(provided_params)} parameters"
)
ApiDependencies.invoker.services.events.emit_recall_parameters_updated(queue_id, provided_params)
ApiDependencies.invoker.services.events.emit_recall_parameters_updated(
queue_id, current_user.user_id, provided_params
)
logger.info("Successfully emitted recall_parameters_updated event")
except Exception as e:
logger.error(f"Error emitting recall parameters event: {e}", exc_info=True)
@@ -425,6 +478,7 @@ async def update_recall_parameters(
response_model=dict[str, Any],
)
async def get_recall_parameters(
current_user: CurrentUserOrDefault,
queue_id: str = Path(..., description="The queue id to retrieve parameters for"),
) -> dict[str, Any]:
"""

View File

@@ -44,7 +44,8 @@ def sanitize_queue_item_for_user(
"""Sanitize queue item for non-admin users viewing other users' items.
For non-admin users viewing queue items belonging to other users,
the field_values, session graph, and workflow should be hidden/cleared to protect privacy.
only the item id, queue id, status, and timestamps are exposed. All other
fields (user identity, generation parameters, errors, graphs, workflows) are stripped.
Args:
queue_item: The queue item to sanitize
@@ -58,15 +59,25 @@ def sanitize_queue_item_for_user(
if is_admin or queue_item.user_id == current_user_id:
return queue_item
# For non-admins viewing other users' items, clear sensitive fields
# Create a shallow copy to avoid mutating the original
# For non-admins viewing other users' items, strip everything except
# item_id, queue_id, status, and timestamps
sanitized_item = queue_item.model_copy(deep=False)
sanitized_item.user_id = "redacted"
sanitized_item.user_display_name = None
sanitized_item.user_email = None
sanitized_item.batch_id = "redacted"
sanitized_item.session_id = "redacted"
sanitized_item.origin = None
sanitized_item.destination = None
sanitized_item.priority = 0
sanitized_item.field_values = None
sanitized_item.retried_from_item_id = None
sanitized_item.workflow = None
# Clear the session graph by replacing it with an empty graph execution state
# This prevents information leakage through the generation graph
sanitized_item.error_type = None
sanitized_item.error_message = None
sanitized_item.error_traceback = None
sanitized_item.session = GraphExecutionState(
id=queue_item.session.id,
id="redacted",
graph=Graph(),
)
return sanitized_item
@@ -126,12 +137,16 @@ async def list_all_queue_items(
},
)
async def get_queue_item_ids(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
) -> ItemIdsResult:
"""Gets all queue item ids that match the given parameters"""
"""Gets all queue item ids that match the given parameters. Non-admin users only see their own items."""
try:
return ApiDependencies.invoker.services.session_queue.get_queue_item_ids(queue_id=queue_id, order_dir=order_dir)
user_id = None if current_user.is_admin else current_user.user_id
return ApiDependencies.invoker.services.session_queue.get_queue_item_ids(
queue_id=queue_id, order_dir=order_dir, user_id=user_id
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue item ids: {e}")
@@ -376,11 +391,15 @@ async def prune(
},
)
async def get_current_queue_item(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> Optional[SessionQueueItem]:
"""Gets the currently execution queue item"""
try:
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
if item is not None:
item = sanitize_queue_item_for_user(item, current_user.user_id, current_user.is_admin)
return item
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting current queue item: {e}")
@@ -393,11 +412,15 @@ async def get_current_queue_item(
},
)
async def get_next_queue_item(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> Optional[SessionQueueItem]:
"""Gets the next queue item, without executing it"""
try:
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
item = ApiDependencies.invoker.services.session_queue.get_next(queue_id)
if item is not None:
item = sanitize_queue_item_for_user(item, current_user.user_id, current_user.is_admin)
return item
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting next queue item: {e}")
@@ -413,9 +436,10 @@ async def get_queue_status(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionQueueAndProcessorStatus:
"""Gets the status of the session queue"""
"""Gets the status of the session queue. Non-admin users see only their own counts and cannot see current item details unless they own it."""
try:
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id, user_id=current_user.user_id)
user_id = None if current_user.is_admin else current_user.user_id
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id, user_id=user_id)
processor = ApiDependencies.invoker.services.session_processor.get_status()
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
except Exception as e:
@@ -430,12 +454,16 @@ async def get_queue_status(
},
)
async def get_batch_status(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
batch_id: str = Path(description="The batch to get the status of"),
) -> BatchStatus:
"""Gets the status of the session queue"""
"""Gets the status of a batch. Non-admin users only see their own batches."""
try:
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
user_id = None if current_user.is_admin else current_user.user_id
return ApiDependencies.invoker.services.session_queue.get_batch_status(
queue_id=queue_id, batch_id=batch_id, user_id=user_id
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting batch status: {e}")
@@ -529,13 +557,15 @@ async def cancel_queue_item(
responses={200: {"model": SessionQueueCountsByDestination}},
)
async def counts_by_destination(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to query"),
destination: str = Query(description="The destination to query"),
) -> SessionQueueCountsByDestination:
"""Gets the counts of queue items by destination"""
"""Gets the counts of queue items by destination. Non-admin users only see their own items."""
try:
user_id = None if current_user.is_admin else current_user.user_id
return ApiDependencies.invoker.services.session_queue.get_counts_by_destination(
queue_id=queue_id, destination=destination
queue_id=queue_id, destination=destination, user_id=user_id
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while fetching counts by destination: {e}")

View File

@@ -6,6 +6,7 @@ from fastapi import APIRouter, Body, File, HTTPException, Path, Query, UploadFil
from fastapi.responses import FileResponse
from PIL import Image
from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
@@ -33,16 +34,25 @@ workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
},
)
async def get_workflow(
current_user: CurrentUserOrDefault,
workflow_id: str = Path(description="The workflow to get"),
) -> WorkflowRecordWithThumbnailDTO:
"""Gets a workflow"""
try:
thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id)
workflow = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump())
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")
config = ApiDependencies.invoker.services.configuration
if config.multiuser:
is_default = workflow.workflow.meta.category is WorkflowCategory.Default
is_owner = workflow.user_id == current_user.user_id
if not (is_default or is_owner or workflow.is_public or current_user.is_admin):
raise HTTPException(status_code=403, detail="Not authorized to access this workflow")
thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id)
return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump())
@workflows_router.patch(
"/i/{workflow_id}",
@@ -52,10 +62,21 @@ async def get_workflow(
},
)
async def update_workflow(
current_user: CurrentUserOrDefault,
workflow: Workflow = Body(description="The updated workflow", embed=True),
) -> WorkflowRecordDTO:
"""Updates a workflow"""
return ApiDependencies.invoker.services.workflow_records.update(workflow=workflow)
config = ApiDependencies.invoker.services.configuration
if config.multiuser:
try:
existing = ApiDependencies.invoker.services.workflow_records.get(workflow.id)
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")
if not current_user.is_admin and existing.user_id != current_user.user_id:
raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
# Pass user_id for defense-in-depth SQL scoping; admins pass None to allow any.
user_id = None if current_user.is_admin else current_user.user_id
return ApiDependencies.invoker.services.workflow_records.update(workflow=workflow, user_id=user_id)
@workflows_router.delete(
@@ -63,15 +84,25 @@ async def update_workflow(
operation_id="delete_workflow",
)
async def delete_workflow(
current_user: CurrentUserOrDefault,
workflow_id: str = Path(description="The workflow to delete"),
) -> None:
"""Deletes a workflow"""
config = ApiDependencies.invoker.services.configuration
if config.multiuser:
try:
existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")
if not current_user.is_admin and existing.user_id != current_user.user_id:
raise HTTPException(status_code=403, detail="Not authorized to delete this workflow")
try:
ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id)
except WorkflowThumbnailFileNotFoundException:
# It's OK if the workflow has no thumbnail file. We can still delete the workflow.
pass
ApiDependencies.invoker.services.workflow_records.delete(workflow_id)
user_id = None if current_user.is_admin else current_user.user_id
ApiDependencies.invoker.services.workflow_records.delete(workflow_id, user_id=user_id)
@workflows_router.post(
@@ -82,10 +113,17 @@ async def delete_workflow(
},
)
async def create_workflow(
current_user: CurrentUserOrDefault,
workflow: WorkflowWithoutID = Body(description="The workflow to create", embed=True),
) -> WorkflowRecordDTO:
"""Creates a workflow"""
return ApiDependencies.invoker.services.workflow_records.create(workflow=workflow)
# In single-user mode, workflows are owned by 'system' and shared by default so all legacy/single-user
# workflows remain visible. In multiuser mode, workflows are private to the creator by default.
config = ApiDependencies.invoker.services.configuration
is_public = not config.multiuser
return ApiDependencies.invoker.services.workflow_records.create(
workflow=workflow, user_id=current_user.user_id, is_public=is_public
)
@workflows_router.get(
@@ -96,6 +134,7 @@ async def create_workflow(
},
)
async def list_workflows(
current_user: CurrentUserOrDefault,
page: int = Query(default=0, description="The page to get"),
per_page: Optional[int] = Query(default=None, description="The number of workflows per page"),
order_by: WorkflowRecordOrderBy = Query(
@@ -106,8 +145,19 @@ async def list_workflows(
tags: Optional[list[str]] = Query(default=None, description="The tags of workflow to get"),
query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"),
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"),
) -> PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]:
"""Gets a page of workflows"""
config = ApiDependencies.invoker.services.configuration
# In multiuser mode, scope user-category workflows to the current user unless fetching shared workflows.
# Admins skip the user_id filter so they can see and manage all workflows including system-owned ones.
user_id_filter: Optional[str] = None
if config.multiuser and not current_user.is_admin:
has_user_category = not categories or WorkflowCategory.User in categories
if has_user_category and is_public is not True:
user_id_filter = current_user.user_id
workflows_with_thumbnails: list[WorkflowRecordListItemWithThumbnailDTO] = []
workflows = ApiDependencies.invoker.services.workflow_records.get_many(
order_by=order_by,
@@ -118,6 +168,8 @@ async def list_workflows(
categories=categories,
tags=tags,
has_been_opened=has_been_opened,
user_id=user_id_filter,
is_public=is_public,
)
for workflow in workflows.items:
workflows_with_thumbnails.append(
@@ -143,15 +195,20 @@ async def list_workflows(
},
)
async def set_workflow_thumbnail(
current_user: CurrentUserOrDefault,
workflow_id: str = Path(description="The workflow to update"),
image: UploadFile = File(description="The image file to upload"),
):
"""Sets a workflow's thumbnail image"""
try:
ApiDependencies.invoker.services.workflow_records.get(workflow_id)
existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")
config = ApiDependencies.invoker.services.configuration
if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id:
raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
if not image.content_type or not image.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
@@ -177,14 +234,19 @@ async def set_workflow_thumbnail(
},
)
async def delete_workflow_thumbnail(
current_user: CurrentUserOrDefault,
workflow_id: str = Path(description="The workflow to update"),
):
"""Removes a workflow's thumbnail image"""
try:
ApiDependencies.invoker.services.workflow_records.get(workflow_id)
existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")
config = ApiDependencies.invoker.services.configuration
if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id:
raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
try:
ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id)
except ValueError as e:
@@ -206,8 +268,12 @@ async def delete_workflow_thumbnail(
async def get_workflow_thumbnail(
workflow_id: str = Path(description="The id of the workflow thumbnail to get"),
) -> FileResponse:
"""Gets a workflow's thumbnail image"""
"""Gets a workflow's thumbnail image.
This endpoint is intentionally unauthenticated because browsers load images
via <img src> tags which cannot send Bearer tokens. Workflow IDs are UUIDs,
providing security through unguessability.
"""
try:
path = ApiDependencies.invoker.services.workflow_thumbnails.get_path(workflow_id)
@@ -223,37 +289,91 @@ async def get_workflow_thumbnail(
raise HTTPException(status_code=404)
@workflows_router.patch(
"/i/{workflow_id}/is_public",
operation_id="update_workflow_is_public",
responses={
200: {"model": WorkflowRecordDTO},
},
)
async def update_workflow_is_public(
current_user: CurrentUserOrDefault,
workflow_id: str = Path(description="The workflow to update"),
is_public: bool = Body(description="Whether the workflow should be shared publicly", embed=True),
) -> WorkflowRecordDTO:
"""Updates whether a workflow is shared publicly"""
try:
existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")
config = ApiDependencies.invoker.services.configuration
if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id:
raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
user_id = None if current_user.is_admin else current_user.user_id
return ApiDependencies.invoker.services.workflow_records.update_is_public(
workflow_id=workflow_id, is_public=is_public, user_id=user_id
)
@workflows_router.get("/tags", operation_id="get_all_tags")
async def get_all_tags(
current_user: CurrentUserOrDefault,
categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"),
is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"),
) -> list[str]:
"""Gets all unique tags from workflows"""
config = ApiDependencies.invoker.services.configuration
user_id_filter: Optional[str] = None
if config.multiuser and not current_user.is_admin:
has_user_category = not categories or WorkflowCategory.User in categories
if has_user_category and is_public is not True:
user_id_filter = current_user.user_id
return ApiDependencies.invoker.services.workflow_records.get_all_tags(categories=categories)
return ApiDependencies.invoker.services.workflow_records.get_all_tags(
categories=categories, user_id=user_id_filter, is_public=is_public
)
@workflows_router.get("/counts_by_tag", operation_id="get_counts_by_tag")
async def get_counts_by_tag(
current_user: CurrentUserOrDefault,
tags: list[str] = Query(description="The tags to get counts for"),
categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"),
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"),
) -> dict[str, int]:
"""Counts workflows by tag"""
config = ApiDependencies.invoker.services.configuration
user_id_filter: Optional[str] = None
if config.multiuser and not current_user.is_admin:
has_user_category = not categories or WorkflowCategory.User in categories
if has_user_category and is_public is not True:
user_id_filter = current_user.user_id
return ApiDependencies.invoker.services.workflow_records.counts_by_tag(
tags=tags, categories=categories, has_been_opened=has_been_opened
tags=tags, categories=categories, has_been_opened=has_been_opened, user_id=user_id_filter, is_public=is_public
)
@workflows_router.get("/counts_by_category", operation_id="counts_by_category")
async def counts_by_category(
current_user: CurrentUserOrDefault,
categories: list[WorkflowCategory] = Query(description="The categories to include"),
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"),
) -> dict[str, int]:
"""Counts workflows by category"""
config = ApiDependencies.invoker.services.configuration
user_id_filter: Optional[str] = None
if config.multiuser and not current_user.is_admin:
has_user_category = WorkflowCategory.User in categories
if has_user_category and is_public is not True:
user_id_filter = current_user.user_id
return ApiDependencies.invoker.services.workflow_records.counts_by_category(
categories=categories, has_been_opened=has_been_opened
categories=categories, has_been_opened=has_been_opened, user_id=user_id_filter, is_public=is_public
)
@@ -262,7 +382,18 @@ async def counts_by_category(
operation_id="update_opened_at",
)
async def update_opened_at(
current_user: CurrentUserOrDefault,
workflow_id: str = Path(description="The workflow to update"),
) -> None:
"""Updates the opened_at field of a workflow"""
ApiDependencies.invoker.services.workflow_records.update_opened_at(workflow_id)
try:
existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")
config = ApiDependencies.invoker.services.configuration
if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id:
raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
user_id = None if current_user.is_admin else current_user.user_id
ApiDependencies.invoker.services.workflow_records.update_opened_at(workflow_id, user_id=user_id)

View File

@@ -121,6 +121,11 @@ class SocketIO:
Returns True to accept the connection, False to reject it.
Stores user_id in the internal socket users dict for later use.
In multiuser mode, connections without a valid token are rejected outright
so that anonymous clients cannot subscribe to queue rooms and observe
queue activity belonging to other users. In single-user mode, unauthenticated
connections are accepted as the system admin user.
"""
# Extract token from auth data or headers
token = None
@@ -137,6 +142,23 @@ class SocketIO:
if token:
token_data = verify_token(token)
if token_data:
# In multiuser mode, also verify the backing user record still
# exists and is active — mirrors the REST auth check in
# auth_dependencies.py. A deleted or deactivated user whose
# JWT has not yet expired must not be allowed to open a socket.
if self._is_multiuser_enabled():
try:
from invokeai.app.api.dependencies import ApiDependencies
user = ApiDependencies.invoker.services.users.get(token_data.user_id)
if user is None or not user.is_active:
logger.warning(f"Rejecting socket {sid}: user {token_data.user_id} not found or inactive")
return False
except Exception:
# If user service is unavailable, fail closed
logger.warning(f"Rejecting socket {sid}: unable to verify user record")
return False
# Store user_id and is_admin in socket users dict
self._socket_users[sid] = {
"user_id": token_data.user_id,
@@ -147,14 +169,37 @@ class SocketIO:
)
return True
# If no valid token, store system user for backward compatibility
# No valid token provided. In multiuser mode this is not allowed — reject
# the connection so anonymous clients cannot subscribe to queue rooms.
# In single-user mode, fall through and accept the socket as system admin.
if self._is_multiuser_enabled():
logger.warning(
f"Rejecting socket {sid} connection: multiuser mode is enabled and no valid auth token was provided"
)
return False
self._socket_users[sid] = {
"user_id": "system",
"is_admin": False,
"is_admin": True,
}
logger.debug(f"Socket {sid} connected as system user (no valid token)")
logger.debug(f"Socket {sid} connected as system admin (single-user mode)")
return True
@staticmethod
def _is_multiuser_enabled() -> bool:
"""Check whether multiuser mode is enabled. Fails closed if configuration
is not yet initialized, which should not happen in practice but prevents
accidentally opening the socket during startup races."""
try:
# Imported here to avoid a circular import at module load time.
from invokeai.app.api.dependencies import ApiDependencies
return bool(ApiDependencies.invoker.services.configuration.multiuser)
except Exception:
# If dependencies are not initialized, fail closed (treat as multiuser)
# so we never accidentally admit an anonymous socket.
return True
async def _handle_disconnect(self, sid: str) -> None:
"""Handle socket disconnection and cleanup user info."""
if sid in self._socket_users:
@@ -165,15 +210,20 @@ class SocketIO:
"""Handle queue subscription and add socket to both queue and user-specific rooms."""
queue_id = QueueSubscriptionEvent(**data).queue_id
# Check if we have user info for this socket
# Check if we have user info for this socket. In multiuser mode _handle_connect
# will have already rejected any socket without a valid token, so missing user
# info here is a bug — refuse the subscription rather than silently falling back
# to an anonymous system user who could then receive queue item events.
if sid not in self._socket_users:
logger.warning(
f"Socket {sid} subscribing to queue {queue_id} but has no user info - need to authenticate via connect event"
)
# Store as system user temporarily - real auth should happen in connect
if self._is_multiuser_enabled():
logger.warning(
f"Refusing queue subscription for socket {sid}: no user info (socket not authenticated via connect event)"
)
return
# Single-user mode: safe to fall back to the system admin user.
self._socket_users[sid] = {
"user_id": "system",
"is_admin": False,
"is_admin": True,
}
user_id = self._socket_users[sid]["user_id"]
@@ -198,6 +248,13 @@ class SocketIO:
await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id)
async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None:
# In multiuser mode, only allow authenticated sockets to subscribe.
# Bulk download events are routed to user-specific rooms, so the
# bulk_download_id room subscription is only kept for single-user
# backward compatibility.
if self._is_multiuser_enabled() and sid not in self._socket_users:
logger.warning(f"Refusing bulk download subscription for unknown socket {sid} in multiuser mode")
return
await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
@@ -206,9 +263,17 @@ class SocketIO:
async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
"""Handle queue events with user isolation.
Invocation events (progress, started, complete) are private - only emit to owner and admins.
Queue item status events are public - emit to all users (field values hidden via API).
Other queue events emit to all subscribers.
All queue item events (invocation events AND QueueItemStatusChangedEvent) are
private to the owning user and admins. They carry unsanitized user_id, batch_id,
session_id, origin, destination and error metadata, and must never be broadcast
to the whole queue room — otherwise any other authenticated subscriber could
observe cross-user queue activity.
RecallParametersUpdatedEvent is also private to the owner + admins.
BatchEnqueuedEvent carries the enqueuing user's batch_id/origin/counts and
is also routed privately. QueueClearedEvent is the only queue event that
is still broadcast to the whole queue room.
IMPORTANT: Check InvocationEventBase BEFORE QueueItemEventBase since InvocationEventBase
inherits from QueueItemEventBase. The order of isinstance checks matters!
@@ -237,24 +302,40 @@ class SocketIO:
logger.debug(f"Emitted private invocation event {event_name} to user room {user_room} and admin room")
# Queue item status events are visible to all users (field values masked via API)
# This catches QueueItemStatusChangedEvent but NOT InvocationEvents (already handled above)
# Other queue item events (QueueItemStatusChangedEvent) carry unsanitized
# user_id, batch_id, session_id, origin, destination and error metadata.
# They are private to the owning user + admins — never broadcast to the
# full queue room.
elif isinstance(event_data, QueueItemEventBase) and hasattr(event_data, "user_id"):
# Emit to all subscribers in the queue
await self._sio.emit(
event=event_name, data=event_data.model_dump(mode="json"), room=event_data.queue_id
)
user_room = f"user:{event_data.user_id}"
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room)
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin")
logger.info(
f"Emitted public queue item event {event_name} to all subscribers in queue {event_data.queue_id}"
)
logger.debug(f"Emitted private queue item event {event_name} to user room {user_room} and admin room")
# RecallParametersUpdatedEvent is private - only emit to owner + admins
elif isinstance(event_data, RecallParametersUpdatedEvent):
user_room = f"user:{event_data.user_id}"
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room)
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin")
logger.debug(f"Emitted private recall_parameters_updated event to user room {user_room} and admin room")
# BatchEnqueuedEvent carries the enqueuing user's batch_id, origin, and
# enqueued counts. Route it privately to the owner + admins so other
# users do not observe cross-user batch activity.
elif isinstance(event_data, BatchEnqueuedEvent):
user_room = f"user:{event_data.user_id}"
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room)
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin")
logger.debug(f"Emitted private batch_enqueued event to user room {user_room} and admin room")
else:
# For other queue events (like QueueClearedEvent, BatchEnqueuedEvent), emit to all subscribers
# For remaining queue events (e.g. QueueClearedEvent) that do not
# carry user identity, emit to all subscribers in the queue room.
await self._sio.emit(
event=event_name, data=event_data.model_dump(mode="json"), room=event_data.queue_id
)
logger.info(
logger.debug(
f"Emitted general queue event {event_name} to all subscribers in queue {event_data.queue_id}"
)
except Exception as e:
@@ -265,4 +346,17 @@ class SocketIO:
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"))
async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None:
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].bulk_download_id)
event_name, event_data = event
# Route to user-specific + admin rooms so that other authenticated
# users cannot learn the bulk_download_item_name (the capability token
# needed to fetch the zip from the unauthenticated GET endpoint).
# In single-user mode (user_id="system"), fall back to the shared
# bulk_download_id room for backward compatibility.
if hasattr(event_data, "user_id") and event_data.user_id != "system":
user_room = f"user:{event_data.user_id}"
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room)
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin")
else:
await self._sio.emit(
event=event_name, data=event_data.model_dump(mode="json"), room=event_data.bulk_download_id
)

View File

@@ -79,6 +79,50 @@ app = FastAPI(
)
class SlidingWindowTokenMiddleware(BaseHTTPMiddleware):
"""Refresh the JWT token on each authenticated response.
When a request includes a valid Bearer token, the response includes a
X-Refreshed-Token header with a new token that has a fresh expiry.
This implements sliding-window session expiry: the session only expires
after a period of *inactivity*, not a fixed time after login.
"""
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
response = await call_next(request)
# Only refresh on mutating requests (POST/PUT/PATCH/DELETE) — these indicate
# genuine user activity. GET requests are often background fetches (RTK Query
# cache revalidation, refetch-on-focus, etc.) and should not reset the
# inactivity timer.
if response.status_code < 400 and request.method in ("POST", "PUT", "PATCH", "DELETE"):
auth_header = request.headers.get("authorization", "")
if auth_header.startswith("Bearer "):
token = auth_header[7:]
try:
from datetime import timedelta
from invokeai.app.api.routers.auth import TOKEN_EXPIRATION_NORMAL, TOKEN_EXPIRATION_REMEMBER_ME
from invokeai.app.services.auth.token_service import create_access_token, verify_token
token_data = verify_token(token)
if token_data is not None:
# Use the remember_me claim from the token to determine the
# correct refresh duration. This avoids the bug where a 7-day
# token with <24h remaining would be silently downgraded to 1 day.
if token_data.remember_me:
expires_delta = timedelta(days=TOKEN_EXPIRATION_REMEMBER_ME)
else:
expires_delta = timedelta(days=TOKEN_EXPIRATION_NORMAL)
new_token = create_access_token(token_data, expires_delta)
response.headers["X-Refreshed-Token"] = new_token
except Exception:
pass # Don't fail the request if token refresh fails
return response
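# Client-side sketch (hypothetical, not part of this diff): callers are expected to
# swap in the refreshed token whenever the header is present, e.g. with requests:
#   resp = session.post(url, json=payload, headers={"Authorization": f"Bearer {token}"})
#   token = resp.headers.get("X-Refreshed-Token", token)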
class RedirectRootWithQueryStringMiddleware(BaseHTTPMiddleware):
"""When a request is made to the root path with a query string, redirect to the root path without the query string.
@@ -99,6 +143,7 @@ class RedirectRootWithQueryStringMiddleware(BaseHTTPMiddleware):
# Add the middleware
app.add_middleware(RedirectRootWithQueryStringMiddleware)
app.add_middleware(SlidingWindowTokenMiddleware)
# Add event handler
@@ -117,6 +162,7 @@ app.add_middleware(
allow_credentials=app_config.allow_credentials,
allow_methods=app_config.allow_methods,
allow_headers=app_config.allow_headers,
expose_headers=["X-Refreshed-Token"],
)
app.add_middleware(GZipMiddleware, minimum_size=1000)

View File

@@ -0,0 +1,715 @@
"""Anima denoising invocation.
Implements the rectified flow denoising loop for Anima models:
- Direct prediction: denoised = input - output * sigma
- Fixed shift=3.0 via loglinear_timestep_shift (Flux paper by Black Forest Labs)
- Timestep convention: timestep = sigma * 1.0 (raw sigma, NOT 1-sigma like Z-Image)
- NO v-prediction negation (unlike Z-Image)
- 3D latent space: [B, C, T, H, W] with T=1 for images
- 16 latent channels, 8x spatial compression
Key differences from Z-Image denoise:
- Anima uses fixed shift=3.0, Z-Image uses dynamic shift based on resolution
- Anima: timestep = sigma (raw), Z-Image: model_t = 1.0 - sigma
- Anima: noise_pred = model_output (direct), Z-Image: noise_pred = -model_output (v-pred)
- Anima transformer takes (x, timesteps, context, t5xxl_ids, t5xxl_weights)
- Anima uses 3D latents directly, Z-Image converts 4D -> list of 5D
"""
import inspect
import math
from contextlib import ExitStack
from typing import Callable, Iterator, Optional, Tuple
import torch
import torchvision.transforms as tv_transforms
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from torchvision.transforms.functional import resize as tv_resize
from tqdm import tqdm
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
AnimaConditioningField,
DenoiseMaskField,
FieldDescriptions,
Input,
InputField,
LatentsField,
)
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.anima.anima_transformer_patch import patch_anima_for_regional_prompting
from invokeai.backend.anima.conditioning_data import AnimaRegionalTextConditioning, AnimaTextConditioning
from invokeai.backend.anima.regional_prompting import AnimaRegionalPromptingExtension
from invokeai.backend.flux.schedulers import ANIMA_SCHEDULER_LABELS, ANIMA_SCHEDULER_MAP, ANIMA_SCHEDULER_NAME_VALUES
from invokeai.backend.model_manager.taxonomy import BaseModelType
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.anima_lora_constants import ANIMA_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import (
RectifiedFlowInpaintExtension,
assert_broadcastable,
)
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import AnimaConditioningInfo, Range
from invokeai.backend.util.devices import TorchDevice
# Anima uses 8x spatial compression (VAE downsamples by 2^3)
ANIMA_LATENT_SCALE_FACTOR = 8
# Anima uses 16 latent channels
ANIMA_LATENT_CHANNELS = 16
# Anima uses fixed shift=3.0 for the rectified flow schedule
ANIMA_SHIFT = 3.0
# Anima uses raw sigma values as timesteps (no rescaling)
ANIMA_MULTIPLIER = 1.0
def loglinear_timestep_shift(alpha: float, t: float) -> float:
"""Apply log-linear timestep shift to a noise schedule value.
This shift biases the noise schedule toward higher noise levels, as described
in the Flux model (Black Forest Labs, 2024). With alpha > 1, the model spends
proportionally more denoising steps at higher noise levels.
Formula: sigma = alpha * t / (1 + (alpha - 1) * t)
Args:
alpha: Shift factor (3.0 for Anima, resolution-dependent for Flux).
t: Timestep value in [0, 1].
Returns:
Shifted timestep value.
"""
if alpha == 1.0:
return t
return alpha * t / (1 + (alpha - 1) * t)
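# Worked example with Anima's fixed shift of 3.0: the midpoint of the linear
# schedule is pushed to a much noisier sigma.
#   >>> loglinear_timestep_shift(3.0, 0.5)
#   0.75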
def inverse_loglinear_timestep_shift(alpha: float, sigma: float) -> float:
"""Recover linear t from a shifted sigma value.
Inverse of loglinear_timestep_shift: given sigma = alpha * t / (1 + (alpha-1) * t),
solve for t = sigma / (alpha - (alpha-1) * sigma).
This is needed for the inpainting extension, which expects linear t values
for gradient mask thresholding. With Anima's shift=3.0, the difference
between shifted sigma and linear t is large (e.g. at t=0.5, sigma=0.75),
causing overly aggressive mask thresholding if sigma is used directly.
Args:
alpha: Shift factor (3.0 for Anima).
sigma: Shifted sigma value in [0, 1].
Returns:
Linear t value in [0, 1].
"""
if alpha == 1.0:
return sigma
denominator = alpha - (alpha - 1) * sigma
if abs(denominator) < 1e-8:
return 1.0
return sigma / denominator
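# Sanity check: inverting the example above recovers the linear t.
#   >>> inverse_loglinear_timestep_shift(3.0, 0.75)
#   0.5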
class AnimaInpaintExtension(RectifiedFlowInpaintExtension):
"""Inpaint extension for Anima that accounts for the time-SNR shift.
Anima uses a fixed shift=3.0 which makes sigma values significantly larger
than the corresponding linear t values. The base RectifiedFlowInpaintExtension
uses t_prev for both gradient mask thresholding and noise mixing, which assumes
linear t values.
This subclass:
- Uses the LINEAR t for gradient mask thresholding (correct progressive reveal)
- Uses the SHIFTED sigma for noise mixing (matches the denoiser's noise level)
"""
def __init__(
self,
init_latents: torch.Tensor,
inpaint_mask: torch.Tensor,
noise: torch.Tensor,
shift: float = ANIMA_SHIFT,
):
assert_broadcastable(init_latents.shape, inpaint_mask.shape, noise.shape)
self._init_latents = init_latents
self._inpaint_mask = inpaint_mask
self._noise = noise
self._shift = shift
def merge_intermediate_latents_with_init_latents(
self, intermediate_latents: torch.Tensor, sigma_prev: float
) -> torch.Tensor:
"""Merge intermediate latents with init latents, correcting for Anima's shift.
Args:
intermediate_latents: The denoised latents at the current step.
sigma_prev: The SHIFTED sigma value for the next step.
"""
# Recover linear t from shifted sigma for gradient mask thresholding.
# This ensures the gradient mask is revealed at the correct pace.
t_prev = inverse_loglinear_timestep_shift(self._shift, sigma_prev)
mask = self._apply_mask_gradient_adjustment(t_prev)
# Use shifted sigma for noise mixing to match the denoiser's noise level.
# The Euler step produces latents at noise level sigma_prev, so the
# preserved regions must also be at sigma_prev noise level.
noised_init_latents = self._noise * sigma_prev + (1.0 - sigma_prev) * self._init_latents
return intermediate_latents * mask + noised_init_latents * (1.0 - mask)
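# Concrete illustration (shift=3.0): if the next step's shifted sigma_prev is 0.75,
# the gradient mask is thresholded at the recovered linear t of 0.5, while the
# preserved region is re-noised as 0.75 * noise + 0.25 * init_latents so it sits at
# the noise level the denoiser expects for that step.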
@invocation(
"anima_denoise",
title="Denoise - Anima",
tags=["image", "anima"],
category="image",
version="1.2.0",
classification=Classification.Prototype,
)
class AnimaDenoiseInvocation(BaseInvocation):
"""Run the denoising process with an Anima model.
Uses rectified flow sampling with shift=3.0 and the Cosmos Predict2 DiT
backbone with integrated LLM Adapter for text conditioning.
Supports txt2img, img2img (via latents input), and inpainting (via denoise_mask).
"""
# If latents is provided, this means we are doing image-to-image.
latents: Optional[LatentsField] = InputField(
default=None, description=FieldDescriptions.latents, input=Input.Connection
)
# denoise_mask is used for inpainting. Only the masked region is modified.
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None, description=FieldDescriptions.denoise_mask, input=Input.Connection
)
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
add_noise: bool = InputField(default=True, description="Add noise based on denoising start.")
transformer: TransformerField = InputField(
description="Anima transformer model.", input=Input.Connection, title="Transformer"
)
positive_conditioning: AnimaConditioningField | list[AnimaConditioningField] = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_conditioning: AnimaConditioningField | list[AnimaConditioningField] | None = InputField(
default=None, description=FieldDescriptions.negative_cond, input=Input.Connection
)
guidance_scale: float = InputField(
default=4.5,
ge=1.0,
description="Guidance scale for classifier-free guidance. Recommended: 4.0-5.0 for Anima.",
title="Guidance Scale",
)
width: int = InputField(default=1024, multiple_of=8, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=8, description="Height of the generated image.")
steps: int = InputField(default=30, gt=0, description="Number of denoising steps. 30 recommended for Anima.")
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
scheduler: ANIMA_SCHEDULER_NAME_VALUES = InputField(
default="euler",
description="Scheduler (sampler) for the denoising process.",
ui_choice_labels=ANIMA_SCHEDULER_LABELS,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = self._run_diffusion(context)
latents = latents.detach().to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
"""Prepare the inpaint mask for Anima.
Anima uses 3D latents [B, C, T, H, W] internally but the mask operates
on the spatial dimensions [B, C, H, W] which match the squeezed output.
"""
if self.denoise_mask is None:
return None
mask = context.tensors.load(self.denoise_mask.mask_name)
# Invert mask: 0.0 = regions to denoise, 1.0 = regions to preserve
mask = 1.0 - mask
_, _, latent_height, latent_width = latents.shape
mask = tv_resize(
img=mask,
size=[latent_height, latent_width],
interpolation=tv_transforms.InterpolationMode.BILINEAR,
antialias=False,
)
mask = mask.to(device=latents.device, dtype=latents.dtype)
return mask
def _get_noise(
self,
height: int,
width: int,
dtype: torch.dtype,
device: torch.device,
seed: int,
) -> torch.Tensor:
"""Generate initial noise tensor in 3D latent space [B, C, T, H, W]."""
rand_device = "cpu"
return torch.randn(
1,
ANIMA_LATENT_CHANNELS,
1, # T=1 for single image
height // ANIMA_LATENT_SCALE_FACTOR,
width // ANIMA_LATENT_SCALE_FACTOR,
device=rand_device,
dtype=torch.float32,
generator=torch.Generator(device=rand_device).manual_seed(seed),
).to(device=device, dtype=dtype)
def _get_sigmas(self, num_steps: int) -> list[float]:
"""Generate sigma schedule with fixed shift=3.0.
Uses the log-linear timestep shift from the Flux model (Black Forest Labs)
with a fixed shift factor of 3.0 (no dynamic resolution-based shift).
Returns:
List of num_steps + 1 sigma values from ~1.0 (noise) to 0.0 (clean).
"""
sigmas = []
for i in range(num_steps + 1):
t = 1.0 - i / num_steps
sigma = loglinear_timestep_shift(ANIMA_SHIFT, t)
sigmas.append(sigma)
return sigmas
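# Illustrative output for num_steps=4 with the fixed shift of 3.0:
#   linear t: 1.0, 0.75, 0.5, 0.25, 0.0
#   sigmas:   1.0, 0.9,  0.75, 0.5,  0.0
# More of the trajectory is spent at high noise levels, as intended by the shift.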
def _load_conditioning(
self,
context: InvocationContext,
cond_field: AnimaConditioningField,
dtype: torch.dtype,
device: torch.device,
) -> AnimaConditioningInfo:
"""Load Anima conditioning data from storage."""
cond_data = context.conditioning.load(cond_field.conditioning_name)
assert len(cond_data.conditionings) == 1
cond_info = cond_data.conditionings[0]
assert isinstance(cond_info, AnimaConditioningInfo)
return cond_info.to(dtype=dtype, device=device)
def _load_text_conditionings(
self,
context: InvocationContext,
cond_field: AnimaConditioningField | list[AnimaConditioningField],
img_token_height: int,
img_token_width: int,
dtype: torch.dtype,
device: torch.device,
) -> list[AnimaTextConditioning]:
"""Load Anima text conditioning with optional regional masks.
Args:
context: The invocation context.
cond_field: Single conditioning field or list of fields.
img_token_height: Height of the image token grid (H // patch_size).
img_token_width: Width of the image token grid (W // patch_size).
dtype: Target dtype.
device: Target device.
Returns:
List of AnimaTextConditioning objects with optional masks.
"""
cond_list = cond_field if isinstance(cond_field, list) else [cond_field]
text_conditionings: list[AnimaTextConditioning] = []
for cond in cond_list:
cond_info = self._load_conditioning(context, cond, dtype, device)
# Load the mask, if provided
mask: torch.Tensor | None = None
if cond.mask is not None:
mask = context.tensors.load(cond.mask.tensor_name)
mask = mask.to(device=device)
mask = AnimaRegionalPromptingExtension.preprocess_regional_prompt_mask(
mask, img_token_height, img_token_width, dtype, device
)
text_conditionings.append(
AnimaTextConditioning(
qwen3_embeds=cond_info.qwen3_embeds,
t5xxl_ids=cond_info.t5xxl_ids,
t5xxl_weights=cond_info.t5xxl_weights,
mask=mask,
)
)
return text_conditionings
def _run_llm_adapter_for_regions(
self,
transformer,
text_conditionings: list[AnimaTextConditioning],
dtype: torch.dtype,
) -> AnimaRegionalTextConditioning:
"""Run the LLM Adapter separately for each regional conditioning and concatenate.
Args:
transformer: The AnimaTransformer instance (must be on device).
text_conditionings: List of per-region conditioning data.
dtype: Inference dtype.
Returns:
AnimaRegionalTextConditioning with concatenated context and masks.
"""
context_embeds_list: list[torch.Tensor] = []
context_ranges: list[Range] = []
image_masks: list[torch.Tensor | None] = []
cur_len = 0
for tc in text_conditionings:
qwen3_embeds = tc.qwen3_embeds.unsqueeze(0) # (1, seq_len, 1024)
t5xxl_ids = tc.t5xxl_ids.unsqueeze(0) # (1, seq_len)
t5xxl_weights = None
if tc.t5xxl_weights is not None:
t5xxl_weights = tc.t5xxl_weights.unsqueeze(0).unsqueeze(-1) # (1, seq_len, 1)
# Run the LLM Adapter to produce context for this region
context = transformer.preprocess_text_embeds(
qwen3_embeds.to(dtype=dtype),
t5xxl_ids,
t5xxl_weights=t5xxl_weights.to(dtype=dtype) if t5xxl_weights is not None else None,
)
# context shape: (1, 512, 1024) — squeeze batch dim
context_2d = context.squeeze(0) # (512, 1024)
context_embeds_list.append(context_2d)
context_ranges.append(Range(start=cur_len, end=cur_len + context_2d.shape[0]))
image_masks.append(tc.mask)
cur_len += context_2d.shape[0]
concatenated_context = torch.cat(context_embeds_list, dim=0)
return AnimaRegionalTextConditioning(
context_embeds=concatenated_context,
image_masks=image_masks,
context_ranges=context_ranges,
)
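# Shape sketch, assuming the per-region context length of 512 noted above: with two
# regions, context_embeds is (1024, 1024), context_ranges is
# [Range(start=0, end=512), Range(start=512, end=1024)], and image_masks holds one
# entry per region.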
def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
device = TorchDevice.choose_torch_device()
inference_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
if self.denoising_start >= self.denoising_end:
raise ValueError(
f"denoising_start ({self.denoising_start}) must be less than denoising_end ({self.denoising_end})."
)
transformer_info = context.models.load(self.transformer.transformer)
# Compute image token grid dimensions for regional prompting
# Anima: 8x VAE compression, 2x patch size → 16x total
patch_size = 2
latent_height = self.height // ANIMA_LATENT_SCALE_FACTOR
latent_width = self.width // ANIMA_LATENT_SCALE_FACTOR
img_token_height = latent_height // patch_size
img_token_width = latent_width // patch_size
img_seq_len = img_token_height * img_token_width
# Load positive conditioning with optional regional masks
pos_text_conditionings = self._load_text_conditionings(
context=context,
cond_field=self.positive_conditioning,
img_token_height=img_token_height,
img_token_width=img_token_width,
dtype=inference_dtype,
device=device,
)
has_regional = len(pos_text_conditionings) > 1 or any(tc.mask is not None for tc in pos_text_conditionings)
# Load negative conditioning if CFG is enabled
do_cfg = not math.isclose(self.guidance_scale, 1.0) and self.negative_conditioning is not None
neg_text_conditionings: list[AnimaTextConditioning] | None = None
if do_cfg:
assert self.negative_conditioning is not None
neg_text_conditionings = self._load_text_conditionings(
context=context,
cond_field=self.negative_conditioning,
img_token_height=img_token_height,
img_token_width=img_token_width,
dtype=inference_dtype,
device=device,
)
# Generate sigma schedule
sigmas = self._get_sigmas(self.steps)
# Apply denoising_start and denoising_end clipping (for img2img/inpaint)
if self.denoising_start > 0 or self.denoising_end < 1:
total_sigmas = len(sigmas)
start_idx = int(self.denoising_start * (total_sigmas - 1))
end_idx = int(self.denoising_end * (total_sigmas - 1)) + 1
sigmas = sigmas[start_idx:end_idx]
total_steps = len(sigmas) - 1
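# Example: steps=30 with denoising_start=0.5 and denoising_end=1.0 yields
# start_idx=15 and end_idx=31, leaving 16 sigmas and therefore 15 denoising steps
# over the second half of the (shifted) schedule.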
# Load input latents if provided (image-to-image)
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
if init_latents is not None:
init_latents = init_latents.to(device=device, dtype=inference_dtype)
# Anima denoiser works in 3D: add temporal dim if needed
if init_latents.ndim == 4:
init_latents = init_latents.unsqueeze(2) # [B, C, H, W] -> [B, C, 1, H, W]
# Generate initial noise (3D latent: [B, C, T, H, W])
noise = self._get_noise(self.height, self.width, inference_dtype, device, self.seed)
# Prepare input latents
if init_latents is not None:
if self.add_noise:
s_0 = sigmas[0]
latents = s_0 * noise + (1.0 - s_0) * init_latents
else:
latents = init_latents
else:
if self.denoising_start > 1e-5:
raise ValueError("denoising_start should be 0 when initial latents are not provided.")
latents = noise
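# Note on the blend above (flow-matching forward process): x_s = s * noise + (1 - s) * x_0,
# so at s_0 = 1.0 the input is pure noise (text-to-image) and smaller s_0 values keep
# proportionally more of the init latents; e.g. s_0 = 0.6 keeps 40% of the image signal.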
if total_steps <= 0:
return latents.squeeze(2)
# Prepare inpaint extension
inpaint_mask = self._prep_inpaint_mask(context, latents.squeeze(2))
inpaint_extension: AnimaInpaintExtension | None = None
if inpaint_mask is not None:
if init_latents is None:
raise ValueError("Initial latents are required when using an inpaint mask (image-to-image inpainting)")
inpaint_extension = AnimaInpaintExtension(
init_latents=init_latents.squeeze(2),
inpaint_mask=inpaint_mask,
noise=noise.squeeze(2),
shift=ANIMA_SHIFT,
)
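# Assumed behaviour of AnimaInpaintExtension (sketch, inferred from how it is called in
# the denoising loops below): at each step the init latents are re-noised to the current
# sigma and pasted back outside the inpaint mask, roughly
#   merged = mask * intermediate + (1 - mask) * (sigma * noise + (1 - sigma) * init)
# so only the masked region is actually denoised.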
step_callback = self._build_step_callback(context)
# Initialize diffusers scheduler if not using built-in Euler
scheduler: SchedulerMixin | None = None
use_scheduler = self.scheduler != "euler"
if use_scheduler:
scheduler_class = ANIMA_SCHEDULER_MAP[self.scheduler]
scheduler = scheduler_class(num_train_timesteps=1000, shift=1.0)
is_lcm = self.scheduler == "lcm"
set_timesteps_sig = inspect.signature(scheduler.set_timesteps)
if not is_lcm and "sigmas" in set_timesteps_sig.parameters:
scheduler.set_timesteps(sigmas=sigmas, device=device)
else:
scheduler.set_timesteps(num_inference_steps=total_steps, device=device)
num_scheduler_steps = len(scheduler.timesteps)
else:
num_scheduler_steps = total_steps
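# Note: second-order schedulers such as Heun report roughly twice as many internal
# timesteps as user-facing steps, so num_scheduler_steps may exceed total_steps; for
# such schedulers the loop below advances the visible progress counter only when the
# second-order half of a step completes.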
with ExitStack() as exit_stack:
(cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
# Apply LoRA models to the transformer.
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
exit_stack.enter_context(
LayerPatcher.apply_smart_model_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=ANIMA_LORA_TRANSFORMER_PREFIX,
dtype=inference_dtype,
cached_weights=cached_weights,
)
)
# Run LLM Adapter for each regional conditioning to produce context vectors.
# This must happen with the transformer on device since it uses the adapter weights.
if has_regional:
pos_regional = self._run_llm_adapter_for_regions(transformer, pos_text_conditionings, inference_dtype)
pos_context = pos_regional.context_embeds.unsqueeze(0) # (1, total_ctx_len, 1024)
# Build regional prompting extension with cross-attention mask
regional_extension = AnimaRegionalPromptingExtension.from_regional_conditioning(
pos_regional, img_seq_len
)
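# Sketch of the intended masking (an assumption about AnimaRegionalPromptingExtension,
# not verified here): the extension builds a cross-attention mask of shape
# (total_ctx_len, img_seq_len) in which each region's Range of context tokens may attend
# only to the image tokens covered by that region's mask, while regions without a mask
# attend everywhere.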
# For negative, concatenate all regions without masking (matches Z-Image behavior)
neg_context = None
if do_cfg and neg_text_conditionings is not None:
neg_regional = self._run_llm_adapter_for_regions(
transformer, neg_text_conditionings, inference_dtype
)
neg_context = neg_regional.context_embeds.unsqueeze(0)
else:
# Single conditioning — run LLM Adapter via normal forward path
tc = pos_text_conditionings[0]
pos_qwen3_embeds = tc.qwen3_embeds.unsqueeze(0)
pos_t5xxl_ids = tc.t5xxl_ids.unsqueeze(0)
pos_t5xxl_weights = None
if tc.t5xxl_weights is not None:
pos_t5xxl_weights = tc.t5xxl_weights.unsqueeze(0).unsqueeze(-1)
# Pre-compute context via LLM Adapter
pos_context = transformer.preprocess_text_embeds(
pos_qwen3_embeds.to(dtype=inference_dtype),
pos_t5xxl_ids,
t5xxl_weights=pos_t5xxl_weights.to(dtype=inference_dtype)
if pos_t5xxl_weights is not None
else None,
)
neg_context = None
if do_cfg and neg_text_conditionings is not None:
ntc = neg_text_conditionings[0]
neg_qwen3 = ntc.qwen3_embeds.unsqueeze(0)
neg_ids = ntc.t5xxl_ids.unsqueeze(0)
neg_weights = None
if ntc.t5xxl_weights is not None:
neg_weights = ntc.t5xxl_weights.unsqueeze(0).unsqueeze(-1)
neg_context = transformer.preprocess_text_embeds(
neg_qwen3.to(dtype=inference_dtype),
neg_ids,
t5xxl_weights=neg_weights.to(dtype=inference_dtype) if neg_weights is not None else None,
)
regional_extension = None
# Apply regional prompting patch if we have regional masks
exit_stack.enter_context(patch_anima_for_regional_prompting(transformer, regional_extension))
# Helper to run transformer with pre-computed context (bypasses LLM Adapter)
def _run_transformer(ctx: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
return transformer(
x=x.to(transformer.dtype if hasattr(transformer, "dtype") else inference_dtype),
timesteps=t,
context=ctx,
# t5xxl_ids=None skips the LLM Adapter — context is already pre-computed
)
if use_scheduler and scheduler is not None:
# Scheduler-based denoising
user_step = 0
pbar = tqdm(total=total_steps, desc="Denoising (Anima)")
for step_index in range(num_scheduler_steps):
sched_timestep = scheduler.timesteps[step_index]
sigma_curr = sched_timestep.item() / scheduler.config.num_train_timesteps
is_heun = hasattr(scheduler, "state_in_first_order")
in_first_order = scheduler.state_in_first_order if is_heun else True
timestep = torch.tensor(
[sigma_curr * ANIMA_MULTIPLIER], device=device, dtype=inference_dtype
).expand(latents.shape[0])
noise_pred_cond = _run_transformer(pos_context, latents, timestep).float()
if do_cfg and neg_context is not None:
noise_pred_uncond = _run_transformer(neg_context, latents, timestep).float()
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
else:
noise_pred = noise_pred_cond
step_output = scheduler.step(model_output=noise_pred, timestep=sched_timestep, sample=latents)
latents = step_output.prev_sample
if step_index + 1 < len(scheduler.sigmas):
sigma_prev = scheduler.sigmas[step_index + 1].item()
else:
sigma_prev = 0.0
if inpaint_extension is not None:
latents_4d = latents.squeeze(2)
latents_4d = inpaint_extension.merge_intermediate_latents_with_init_latents(
latents_4d, sigma_prev
)
latents = latents_4d.unsqueeze(2)
if is_heun:
if not in_first_order:
user_step += 1
if user_step <= total_steps:
pbar.update(1)
step_callback(
PipelineIntermediateState(
step=user_step,
order=2,
total_steps=total_steps,
timestep=int(sigma_curr * 1000),
latents=latents.squeeze(2),
)
)
else:
user_step += 1
if user_step <= total_steps:
pbar.update(1)
step_callback(
PipelineIntermediateState(
step=user_step,
order=1,
total_steps=total_steps,
timestep=int(sigma_curr * 1000),
latents=latents.squeeze(2),
)
)
pbar.close()
else:
# Built-in Euler implementation (default for Anima)
for step_idx in tqdm(range(total_steps), desc="Denoising (Anima)"):
sigma_curr = sigmas[step_idx]
sigma_prev = sigmas[step_idx + 1]
timestep = torch.tensor(
[sigma_curr * ANIMA_MULTIPLIER], device=device, dtype=inference_dtype
).expand(latents.shape[0])
noise_pred_cond = _run_transformer(pos_context, latents, timestep).float()
if do_cfg and neg_context is not None:
noise_pred_uncond = _run_transformer(neg_context, latents, timestep).float()
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
else:
noise_pred = noise_pred_cond
latents_dtype = latents.dtype
latents = latents.to(dtype=torch.float32)
latents = latents + (sigma_prev - sigma_curr) * noise_pred
latents = latents.to(dtype=latents_dtype)
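# Worked Euler step (illustrative): if sigma goes from 0.8 to 0.7 the update is
# latents += (0.7 - 0.8) * noise_pred, i.e. the latents move 0.1 of the predicted
# velocity away from noise and toward data, matching the rectified-flow formulation.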
if inpaint_extension is not None:
latents_4d = latents.squeeze(2)
latents_4d = inpaint_extension.merge_intermediate_latents_with_init_latents(
latents_4d, sigma_prev
)
latents = latents_4d.unsqueeze(2)
step_callback(
PipelineIntermediateState(
step=step_idx + 1,
order=1,
total_steps=total_steps,
timestep=int(sigma_curr * 1000),
latents=latents.squeeze(2),
),
)
# Remove temporal dimension for output: [B, C, 1, H, W] -> [B, C, H, W]
return latents.squeeze(2)
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, BaseModelType.Anima)
return step_callback
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
"""Iterate over LoRA models to apply to the transformer."""
for lora in self.transformer.loras:
lora_info = context.models.load(lora.lora)
if not isinstance(lora_info.model, ModelPatchRaw):
raise TypeError(
f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}. "
"The LoRA model may be corrupted or incompatible."
)
yield (lora_info.model, lora.weight)
del lora_info

View File

@@ -0,0 +1,119 @@
"""Anima image-to-latents invocation.
Encodes an image to latent space using the Anima VAE (AutoencoderKLWan or FLUX VAE).
For Wan VAE (AutoencoderKLWan):
- Input image is converted to 5D tensor [B, C, T, H, W] with T=1
- After encoding, latents are normalized: (latents - mean) / std
(inverse of the denormalization in anima_latents_to_image.py)
For FLUX VAE (AutoEncoder):
- Encoding is handled internally by the FLUX VAE
"""
from typing import Union
import einops
import torch
from diffusers.models.autoencoders import AutoencoderKLWan
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
Input,
InputField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder as FluxAutoEncoder
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux
AnimaVAE = Union[AutoencoderKLWan, FluxAutoEncoder]
@invocation(
"anima_i2l",
title="Image to Latents - Anima",
tags=["image", "latents", "vae", "i2l", "anima"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class AnimaImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates latents from an image using the Anima VAE (supports Wan 2.1 and FLUX VAE)."""
image: ImageField = InputField(description="The image to encode.")
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
@staticmethod
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
if not isinstance(vae_info.model, (AutoencoderKLWan, FluxAutoEncoder)):
raise TypeError(
f"Expected AutoencoderKLWan or FluxAutoEncoder for Anima VAE, got {type(vae_info.model).__name__}."
)
estimated_working_memory = estimate_vae_working_memory_flux(
operation="encode",
image_tensor=image_tensor,
vae=vae_info.model,
)
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
if not isinstance(vae, (AutoencoderKLWan, FluxAutoEncoder)):
raise TypeError(f"Expected AutoencoderKLWan or FluxAutoEncoder, got {type(vae).__name__}.")
vae_dtype = next(iter(vae.parameters())).dtype
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
with torch.inference_mode():
if isinstance(vae, FluxAutoEncoder):
# FLUX VAE handles scaling internally
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
latents = vae.encode(image_tensor, sample=True, generator=generator)
else:
# AutoencoderKLWan expects 5D input [B, C, T, H, W]
if image_tensor.ndim == 4:
image_tensor = image_tensor.unsqueeze(2) # [B, C, H, W] -> [B, C, 1, H, W]
encoded = vae.encode(image_tensor, return_dict=False)[0]
latents = encoded.sample().to(dtype=vae_dtype)
# Normalize to denoiser space: (latents - mean) / std
# This is the inverse of the denormalization in anima_latents_to_image.py
latents_mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(latents)
latents_std = torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(latents)
latents = (latents - latents_mean) / latents_std
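# Illustrative effect of the normalization (made-up per-channel stats): if channel c
# has latents_mean[c] = -0.7 and latents_std[c] = 2.5, a raw VAE value of 1.8 becomes
# (1.8 - (-0.7)) / 2.5 = 1.0 in denoiser space; anima_latents_to_image.py applies the
# exact inverse before decoding.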
# Remove temporal dimension: [B, C, 1, H, W] -> [B, C, H, W]
if latents.ndim == 5:
latents = latents.squeeze(2)
return latents
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
vae_info = context.models.load(self.vae.vae)
if not isinstance(vae_info.model, (AutoencoderKLWan, FluxAutoEncoder)):
raise TypeError(
f"Expected AutoencoderKLWan or FluxAutoEncoder for Anima VAE, got {type(vae_info.model).__name__}."
)
context.util.signal_progress("Running Anima VAE encode")
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
latents = latents.to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

View File

@@ -0,0 +1,108 @@
"""Anima latents-to-image invocation.
Decodes Anima latents using the QwenImage VAE (AutoencoderKLWan) or
compatible FLUX VAE as fallback.
Latents from the denoiser are in normalized space (zero-centered). Before
VAE decode, they must be denormalized using the Wan 2.1 per-channel
mean/std: latents = latents * std + mean (matching diffusers WanPipeline).
The VAE expects 5D latents [B, C, T, H, W] — for single images, T=1.
"""
import torch
from diffusers.models.autoencoders import AutoencoderKLWan
from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder as FluxAutoEncoder
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux
@invocation(
"anima_l2i",
title="Latents to Image - Anima",
tags=["latents", "image", "vae", "l2i", "anima"],
category="latents",
version="1.0.2",
classification=Classification.Prototype,
)
class AnimaLatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents using the Anima VAE.
Supports the Wan 2.1 QwenImage VAE (AutoencoderKLWan) with explicit
latent denormalization, and FLUX VAE as fallback.
"""
latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
if not isinstance(vae_info.model, (AutoencoderKLWan, FluxAutoEncoder)):
raise TypeError(
f"Expected AutoencoderKLWan or FluxAutoEncoder for Anima VAE, got {type(vae_info.model).__name__}."
)
estimated_working_memory = estimate_vae_working_memory_flux(
operation="decode",
image_tensor=latents,
vae=vae_info.model,
)
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
context.util.signal_progress("Running Anima VAE decode")
if not isinstance(vae, (AutoencoderKLWan, FluxAutoEncoder)):
raise TypeError(f"Expected AutoencoderKLWan or FluxAutoEncoder, got {type(vae).__name__}.")
vae_dtype = next(iter(vae.parameters())).dtype
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
TorchDevice.empty_cache()
with torch.inference_mode():
if isinstance(vae, FluxAutoEncoder):
# FLUX VAE handles scaling internally, expects 4D [B, C, H, W]
img = vae.decode(latents)
else:
# Expects 5D latents [B, C, T, H, W]
if latents.ndim == 4:
latents = latents.unsqueeze(2) # [B, C, H, W] -> [B, C, 1, H, W]
# Denormalize from denoiser space to raw VAE space
# (same as diffusers WanPipeline and ComfyUI Wan21.process_out)
latents_mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(latents)
latents_std = torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(latents)
latents = latents * latents_std + latents_mean
decoded = vae.decode(latents, return_dict=False)[0]
# Output is 5D [B, C, T, H, W] — squeeze temporal dim
if decoded.ndim == 5:
decoded = decoded.squeeze(2)
img = decoded
img = img.clamp(-1, 1)
img = rearrange(img[0], "c h w -> h w c")
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
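# The conversion above maps the [-1, 1] decoder output to 8-bit pixels:
# -1.0 -> 0, 0.0 -> 127 (after truncation by .byte()), 1.0 -> 255.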
TorchDevice.empty_cache()
image_dto = context.images.save(image=img_pil)
return ImageOutput.build(image_dto)

View File

@@ -0,0 +1,162 @@
from typing import Optional
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, Qwen3EncoderField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
@invocation_output("anima_lora_loader_output")
class AnimaLoRALoaderOutput(BaseInvocationOutput):
"""Anima LoRA Loader Output"""
transformer: Optional[TransformerField] = OutputField(
default=None, description=FieldDescriptions.transformer, title="Anima Transformer"
)
qwen3_encoder: Optional[Qwen3EncoderField] = OutputField(
default=None, description=FieldDescriptions.qwen3_encoder, title="Qwen3 Encoder"
)
@invocation(
"anima_lora_loader",
title="Apply LoRA - Anima",
tags=["lora", "model", "anima"],
category="model",
version="1.0.0",
classification=Classification.Prototype,
)
class AnimaLoRALoaderInvocation(BaseInvocation):
"""Apply a LoRA model to an Anima transformer and/or Qwen3 text encoder."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model,
title="LoRA",
ui_model_base=BaseModelType.Anima,
ui_model_type=ModelType.LoRA,
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
transformer: TransformerField | None = InputField(
default=None,
description=FieldDescriptions.transformer,
input=Input.Connection,
title="Anima Transformer",
)
qwen3_encoder: Qwen3EncoderField | None = InputField(
default=None,
title="Qwen3 Encoder",
description=FieldDescriptions.qwen3_encoder,
input=Input.Connection,
)
def invoke(self, context: InvocationContext) -> AnimaLoRALoaderOutput:
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise ValueError(f"Unknown lora: {lora_key}!")
if self.transformer and any(lora.lora.key == lora_key for lora in self.transformer.loras):
raise ValueError(f'LoRA "{lora_key}" already applied to transformer.')
if self.qwen3_encoder and any(lora.lora.key == lora_key for lora in self.qwen3_encoder.loras):
raise ValueError(f'LoRA "{lora_key}" already applied to Qwen3 encoder.')
output = AnimaLoRALoaderOutput()
if self.transformer is not None:
output.transformer = self.transformer.model_copy(deep=True)
output.transformer.loras.append(
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
if self.qwen3_encoder is not None:
output.qwen3_encoder = self.qwen3_encoder.model_copy(deep=True)
output.qwen3_encoder.loras.append(
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
return output
@invocation(
"anima_lora_collection_loader",
title="Apply LoRA Collection - Anima",
tags=["lora", "model", "anima"],
category="model",
version="1.0.0",
classification=Classification.Prototype,
)
class AnimaLoRACollectionLoader(BaseInvocation):
"""Applies a collection of LoRAs to an Anima transformer."""
loras: Optional[LoRAField | list[LoRAField]] = InputField(
default=None, description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
)
transformer: Optional[TransformerField] = InputField(
default=None,
description=FieldDescriptions.transformer,
input=Input.Connection,
title="Transformer",
)
qwen3_encoder: Qwen3EncoderField | None = InputField(
default=None,
title="Qwen3 Encoder",
description=FieldDescriptions.qwen3_encoder,
input=Input.Connection,
)
def invoke(self, context: InvocationContext) -> AnimaLoRALoaderOutput:
output = AnimaLoRALoaderOutput()
if self.loras is None:
if self.transformer is not None:
output.transformer = self.transformer.model_copy(deep=True)
if self.qwen3_encoder is not None:
output.qwen3_encoder = self.qwen3_encoder.model_copy(deep=True)
return output
loras = self.loras if isinstance(self.loras, list) else [self.loras]
added_loras: list[str] = []
if self.transformer is not None:
output.transformer = self.transformer.model_copy(deep=True)
if self.qwen3_encoder is not None:
output.qwen3_encoder = self.qwen3_encoder.model_copy(deep=True)
for lora in loras:
if lora is None:
continue
if lora.lora.key in added_loras:
continue
if not context.models.exists(lora.lora.key):
raise ValueError(f"Unknown lora: {lora.lora.key}!")
if lora.lora.base is not BaseModelType.Anima:
raise ValueError(
f"LoRA '{lora.lora.key}' is for {lora.lora.base.value if lora.lora.base else 'unknown'} models, "
"not Anima models. Ensure you are using an Anima compatible LoRA."
)
added_loras.append(lora.lora.key)
if self.transformer is not None and output.transformer is not None:
output.transformer.loras.append(lora)
if self.qwen3_encoder is not None and output.qwen3_encoder is not None:
output.qwen3_encoder.loras.append(lora)
return output

View File

@@ -0,0 +1,102 @@
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.model import (
ModelIdentifierField,
Qwen3EncoderField,
T5EncoderField,
TransformerField,
VAEField,
)
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.t5_model_identifier import (
preprocess_t5_encoder_model_identifier,
preprocess_t5_tokenizer_model_identifier,
)
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
@invocation_output("anima_model_loader_output")
class AnimaModelLoaderOutput(BaseInvocationOutput):
"""Anima model loader output."""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
qwen3_encoder: Qwen3EncoderField = OutputField(description=FieldDescriptions.qwen3_encoder, title="Qwen3 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
@invocation(
"anima_model_loader",
title="Main Model - Anima",
tags=["model", "anima"],
category="model",
version="1.3.0",
classification=Classification.Prototype,
)
class AnimaModelLoaderInvocation(BaseInvocation):
"""Loads an Anima model, outputting its submodels.
Anima uses:
- Transformer: Cosmos Predict2 DiT + LLM Adapter (from single-file checkpoint)
- Qwen3 Encoder: Qwen3 0.6B (standalone single-file)
- VAE: AutoencoderKLQwenImage / Wan 2.1 VAE (standalone single-file or FLUX VAE)
- T5 Encoder: T5-XXL model (only the tokenizer submodel is used, for LLM Adapter token IDs)
"""
model: ModelIdentifierField = InputField(
description="Anima main model (transformer + LLM adapter).",
input=Input.Direct,
ui_model_base=BaseModelType.Anima,
ui_model_type=ModelType.Main,
title="Transformer",
)
vae_model: ModelIdentifierField = InputField(
description="Standalone VAE model. Anima uses a Wan 2.1 / QwenImage VAE (16-channel). "
"A FLUX VAE can also be used as a compatible fallback.",
input=Input.Direct,
ui_model_type=ModelType.VAE,
title="VAE",
)
qwen3_encoder_model: ModelIdentifierField = InputField(
description="Standalone Qwen3 0.6B Encoder model.",
input=Input.Direct,
ui_model_type=ModelType.Qwen3Encoder,
title="Qwen3 Encoder",
)
t5_encoder_model: ModelIdentifierField = InputField(
description="T5-XXL encoder model. The tokenizer submodel is used for Anima text encoding.",
input=Input.Direct,
ui_model_type=ModelType.T5Encoder,
title="T5 Encoder",
)
def invoke(self, context: InvocationContext) -> AnimaModelLoaderOutput:
# Transformer always comes from the main model
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
# VAE
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
# Qwen3 Encoder
qwen3_tokenizer = self.qwen3_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
qwen3_encoder = self.qwen3_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
# T5 Encoder (only tokenizer submodel is used by Anima)
t5_tokenizer = preprocess_t5_tokenizer_model_identifier(self.t5_encoder_model)
t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model)
return AnimaModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
qwen3_encoder=Qwen3EncoderField(tokenizer=qwen3_tokenizer, text_encoder=qwen3_encoder),
vae=VAEField(vae=vae),
t5_encoder=T5EncoderField(tokenizer=t5_tokenizer, text_encoder=t5_encoder, loras=[]),
)

View File

@@ -0,0 +1,221 @@
"""Anima text encoder invocation.
Encodes text using the dual-conditioning pipeline:
1. Qwen3 0.6B: Produces hidden states (last layer)
2. T5-XXL Tokenizer: Produces token IDs only (no T5 model needed)
Both outputs are stored together in AnimaConditioningInfo and used by
the LLM Adapter inside the transformer during denoising.
Key differences from Z-Image text encoder:
- Anima uses Qwen3 0.6B (base model, NOT instruct) — no chat template
- Anima additionally tokenizes with T5-XXL tokenizer to get token IDs
- Qwen3 output uses all positions (including padding) for full context
"""
from contextlib import ExitStack
from typing import Iterator, Tuple
import torch
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
AnimaConditioningField,
FieldDescriptions,
Input,
InputField,
TensorField,
UIComponent,
)
from invokeai.app.invocations.model import Qwen3EncoderField, T5EncoderField
from invokeai.app.invocations.primitives import AnimaConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.anima_lora_constants import ANIMA_LORA_QWEN3_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
AnimaConditioningInfo,
ConditioningFieldData,
)
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger(__name__)
# T5-XXL max sequence length for token IDs
T5_MAX_SEQ_LEN = 512
# Safety cap for Qwen3 sequence length to prevent GPU OOM on extremely long prompts.
# Qwen3 0.6B supports 32K context but the LLM Adapter doesn't need that much.
QWEN3_MAX_SEQ_LEN = 8192
@invocation(
"anima_text_encoder",
title="Prompt - Anima",
tags=["prompt", "conditioning", "anima"],
category="conditioning",
version="1.3.0",
classification=Classification.Prototype,
)
class AnimaTextEncoderInvocation(BaseInvocation):
"""Encodes and preps a prompt for an Anima image.
Uses Qwen3 0.6B for hidden state extraction and T5-XXL tokenizer for
token IDs (no T5 model weights needed). Both are combined by the
LLM Adapter inside the Anima transformer during denoising.
"""
prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
qwen3_encoder: Qwen3EncoderField = InputField(
title="Qwen3 Encoder",
description=FieldDescriptions.qwen3_encoder,
input=Input.Connection,
)
t5_encoder: T5EncoderField = InputField(
title="T5 Encoder",
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
mask: TensorField | None = InputField(
default=None,
description="A mask defining the region that this conditioning prompt applies to.",
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> AnimaConditioningOutput:
qwen3_embeds, t5xxl_ids, t5xxl_weights = self._encode_prompt(context)
# Move to CPU for storage
qwen3_embeds = qwen3_embeds.detach().to("cpu")
t5xxl_ids = t5xxl_ids.detach().to("cpu")
t5xxl_weights = t5xxl_weights.detach().to("cpu") if t5xxl_weights is not None else None
conditioning_data = ConditioningFieldData(
conditionings=[
AnimaConditioningInfo(
qwen3_embeds=qwen3_embeds,
t5xxl_ids=t5xxl_ids,
t5xxl_weights=t5xxl_weights,
)
]
)
conditioning_name = context.conditioning.save(conditioning_data)
return AnimaConditioningOutput(
conditioning=AnimaConditioningField(conditioning_name=conditioning_name, mask=self.mask)
)
def _encode_prompt(
self,
context: InvocationContext,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
"""Encode prompt using Qwen3 0.6B and T5-XXL tokenizer.
Returns:
Tuple of (qwen3_embeds, t5xxl_ids, t5xxl_weights).
- qwen3_embeds: Shape (seq_len, 1024). Includes every position of the tokenized
prompt (tokenized without padding) so the LLM Adapter receives the full sequence context.
- t5xxl_ids: Shape (seq_len,) — T5-XXL token IDs (unpadded).
- t5xxl_weights: None (uniform weights for now).
"""
prompt = self.prompt
# --- Step 1: Encode with Qwen3 0.6B ---
text_encoder_info = context.models.load(self.qwen3_encoder.text_encoder)
tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer)
with ExitStack() as exit_stack:
(_, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())
(_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())
device = text_encoder.device
# Apply LoRA models to the text encoder
lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
exit_stack.enter_context(
LayerPatcher.apply_smart_model_patches(
model=text_encoder,
patches=self._lora_iterator(context),
prefix=ANIMA_LORA_QWEN3_PREFIX,
dtype=lora_dtype,
)
)
if not isinstance(text_encoder, PreTrainedModel):
raise TypeError(f"Expected PreTrainedModel for text encoder, got {type(text_encoder).__name__}.")
if not isinstance(tokenizer, PreTrainedTokenizerBase):
raise TypeError(f"Expected PreTrainedTokenizerBase for tokenizer, got {type(tokenizer).__name__}.")
context.util.signal_progress("Running Qwen3 0.6B text encoder")
# Anima uses base Qwen3 (not instruct) — tokenize directly, no chat template.
# A safety cap is applied to prevent GPU OOM on extremely long prompts.
text_inputs = tokenizer(
prompt,
padding=False,
truncation=True,
max_length=QWEN3_MAX_SEQ_LEN,
return_attention_mask=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask
if not isinstance(text_input_ids, torch.Tensor) or not isinstance(attention_mask, torch.Tensor):
raise TypeError("Tokenizer returned unexpected types.")
if text_input_ids.shape[-1] == QWEN3_MAX_SEQ_LEN:
logger.warning(
f"Prompt was truncated to {QWEN3_MAX_SEQ_LEN} tokens. "
"Consider shortening the prompt for best results."
)
# Ensure at least 1 token (empty prompts produce 0 tokens with padding=False)
if text_input_ids.shape[-1] == 0:
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
text_input_ids = torch.tensor([[pad_id]])
attention_mask = torch.tensor([[1]])
# Get last hidden state from Qwen3 (final layer output)
prompt_mask = attention_mask.to(device).bool()
outputs = text_encoder(
text_input_ids.to(device),
attention_mask=prompt_mask,
output_hidden_states=True,
)
if not hasattr(outputs, "hidden_states") or outputs.hidden_states is None:
raise RuntimeError("Text encoder did not return hidden_states.")
if len(outputs.hidden_states) < 1:
raise RuntimeError(f"Expected at least 1 hidden state, got {len(outputs.hidden_states)}.")
# Use last hidden state — only real tokens, no padding
qwen3_embeds = outputs.hidden_states[-1][0] # Shape: (seq_len, 1024)
# --- Step 2: Tokenize with T5-XXL tokenizer (IDs only, no model) ---
context.util.signal_progress("Tokenizing with T5-XXL")
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
with t5_tokenizer_info.model_on_device() as (_, t5_tokenizer):
t5_tokens = t5_tokenizer(
prompt,
padding=False,
truncation=True,
max_length=T5_MAX_SEQ_LEN,
return_tensors="pt",
)
t5xxl_ids = t5_tokens.input_ids[0] # Shape: (seq_len,)
return qwen3_embeds, t5xxl_ids, None
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
"""Iterate over LoRA models to apply to the Qwen3 text encoder."""
for lora in self.qwen3_encoder.loras:
lora_info = context.models.load(lora.lora)
if not isinstance(lora_info.model, ModelPatchRaw):
raise TypeError(
f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}. "
"The LoRA model may be corrupted or incompatible."
)
yield (lora_info.model, lora.weight)
del lora_info

View File

@@ -56,7 +56,7 @@ class BaseBatchInvocation(BaseInvocation):
"image_batch",
title="Image Batch",
tags=["primitives", "image", "batch", "special"],
category="primitives",
category="batch",
version="1.0.0",
classification=Classification.Special,
)
@@ -87,7 +87,7 @@ class ImageGeneratorField(BaseModel):
"image_generator",
title="Image Generator",
tags=["primitives", "board", "image", "batch", "special"],
category="primitives",
category="batch",
version="1.0.0",
classification=Classification.Special,
)
@@ -111,7 +111,7 @@ class ImageGenerator(BaseInvocation):
"string_batch",
title="String Batch",
tags=["primitives", "string", "batch", "special"],
category="primitives",
category="batch",
version="1.0.0",
classification=Classification.Special,
)
@@ -142,7 +142,7 @@ class StringGeneratorField(BaseModel):
"string_generator",
title="String Generator",
tags=["primitives", "string", "number", "batch", "special"],
category="primitives",
category="batch",
version="1.0.0",
classification=Classification.Special,
)
@@ -166,7 +166,7 @@ class StringGenerator(BaseInvocation):
"integer_batch",
title="Integer Batch",
tags=["primitives", "integer", "number", "batch", "special"],
category="primitives",
category="batch",
version="1.0.0",
classification=Classification.Special,
)
@@ -195,7 +195,7 @@ class IntegerGeneratorField(BaseModel):
"integer_generator",
title="Integer Generator",
tags=["primitives", "int", "number", "batch", "special"],
category="primitives",
category="batch",
version="1.0.0",
classification=Classification.Special,
)
@@ -219,7 +219,7 @@ class IntegerGenerator(BaseInvocation):
"float_batch",
title="Float Batch",
tags=["primitives", "float", "number", "batch", "special"],
category="primitives",
category="batch",
version="1.0.0",
classification=Classification.Special,
)
@@ -250,7 +250,7 @@ class FloatGeneratorField(BaseModel):
"float_generator",
title="Float Generator",
tags=["primitives", "float", "number", "batch", "special"],
category="primitives",
category="batch",
version="1.0.0",
classification=Classification.Special,
)

View File

@@ -11,7 +11,7 @@ from invokeai.backend.image_util.util import cv2_to_pil, pil_to_cv2
"canny_edge_detection",
title="Canny Edge Detection",
tags=["controlnet", "canny"],
category="controlnet",
category="controlnet_preprocessors",
version="1.0.0",
)
class CannyEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -0,0 +1,27 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation(
"canvas_output",
title="Canvas Output",
tags=["canvas", "output", "image"],
category="canvas",
version="1.0.0",
use_cache=False,
)
class CanvasOutputInvocation(BaseInvocation):
"""Outputs an image to the canvas staging area.
Use this node in workflows that are meant to run on the canvas.
Connect the final image of your workflow to this node to send it
to the canvas staging area when run via 'Run Workflow on Canvas'."""
image: ImageField = InputField(description=FieldDescriptions.image)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)

View File

@@ -33,7 +33,7 @@ from invokeai.backend.util.devices import TorchDevice
"cogview4_denoise",
title="Denoise - CogView4",
tags=["image", "cogview4"],
category="image",
category="latents",
version="1.0.0",
classification=Classification.Prototype,
)

View File

@@ -27,7 +27,7 @@ from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory
"cogview4_i2l",
title="Image to Latents - CogView4",
tags=["image", "latents", "vae", "i2l", "cogview4"],
category="image",
category="latents",
version="1.0.0",
classification=Classification.Prototype,
)

View File

@@ -6,6 +6,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import GlmEncoderField
from invokeai.app.invocations.primitives import CogView4ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
CogView4ConditioningInfo,
ConditioningFieldData,
@@ -19,7 +20,7 @@ COGVIEW4_GLM_MAX_SEQ_LEN = 1024
"cogview4_text_encoder",
title="Prompt - CogView4",
tags=["prompt", "conditioning", "cogview4"],
category="conditioning",
category="prompt",
version="1.0.0",
classification=Classification.Prototype,
)
@@ -46,10 +47,18 @@ class CogView4TextEncoderInvocation(BaseInvocation):
prompt = [self.prompt]
# TODO(ryand): Add model inputs to the invocation rather than hard-coding.
glm_text_encoder_info = context.models.load(self.glm_encoder.text_encoder)
with (
context.models.load(self.glm_encoder.text_encoder).model_on_device() as (_, glm_text_encoder),
glm_text_encoder_info.model_on_device() as (_, glm_text_encoder),
context.models.load(self.glm_encoder.tokenizer).model_on_device() as (_, glm_tokenizer),
):
repaired_tensors = glm_text_encoder_info.repair_required_tensors_on_device()
device = get_effective_device(glm_text_encoder)
if repaired_tensors > 0:
context.logger.warning(
f"Recovered {repaired_tensors} required GLM tensor(s) onto {device} after a partial device mismatch."
)
context.util.signal_progress("Running GLM text encoder")
assert isinstance(glm_text_encoder, GlmModel)
assert isinstance(glm_tokenizer, PreTrainedTokenizerFast)
@@ -85,9 +94,7 @@ class CogView4TextEncoderInvocation(BaseInvocation):
device=text_input_ids.device,
)
text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
prompt_embeds = glm_text_encoder(
text_input_ids.to(glm_text_encoder.device), output_hidden_states=True
).hidden_states[-2]
prompt_embeds = glm_text_encoder(text_input_ids.to(device), output_hidden_states=True).hidden_states[-2]
assert isinstance(prompt_embeds, torch.Tensor)
return prompt_embeds

View File

@@ -11,9 +11,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import SEED_MAX
@invocation(
"range", title="Integer Range", tags=["collection", "integer", "range"], category="collections", version="1.0.0"
)
@invocation("range", title="Integer Range", tags=["collection", "integer", "range"], category="batch", version="1.0.0")
class RangeInvocation(BaseInvocation):
"""Creates a range of numbers from start to stop with step"""
@@ -35,7 +33,7 @@ class RangeInvocation(BaseInvocation):
"range_of_size",
title="Integer Range of Size",
tags=["collection", "integer", "size", "range"],
category="collections",
category="batch",
version="1.0.0",
)
class RangeOfSizeInvocation(BaseInvocation):
@@ -55,7 +53,7 @@ class RangeOfSizeInvocation(BaseInvocation):
"random_range",
title="Random Range",
tags=["range", "integer", "random", "collection"],
category="collections",
category="batch",
version="1.0.1",
use_cache=False,
)

View File

@@ -11,7 +11,7 @@ from invokeai.backend.image_util.util import np_to_pil, pil_to_np
"color_map",
title="Color Map",
tags=["controlnet"],
category="controlnet",
category="controlnet_preprocessors",
version="1.0.0",
)
class ColorMapInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -19,6 +19,7 @@ from invokeai.app.invocations.model import CLIPField
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -42,7 +43,7 @@ from invokeai.backend.util.devices import TorchDevice
"compel",
title="Prompt - SD1.5",
tags=["prompt", "compel"],
category="conditioning",
category="prompt",
version="1.2.1",
)
class CompelInvocation(BaseInvocation):
@@ -103,7 +104,7 @@ class CompelInvocation(BaseInvocation):
textual_inversion_manager=ti_manager,
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
truncate_long_prompts=False,
device=text_encoder.device, # Use the device the model is actually on
device=get_effective_device(text_encoder),
split_long_text_mode=SplitLongTextMode.SENTENCES,
)
@@ -212,7 +213,7 @@ class SDXLPromptInvocationBase:
truncate_long_prompts=False, # TODO:
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
requires_pooled=get_pooled,
device=text_encoder.device, # Use the device the model is actually on
device=get_effective_device(text_encoder),
split_long_text_mode=SplitLongTextMode.SENTENCES,
)
@@ -247,7 +248,7 @@ class SDXLPromptInvocationBase:
"sdxl_compel_prompt",
title="Prompt - SDXL",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
category="prompt",
version="1.2.1",
)
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
@@ -341,7 +342,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"sdxl_refiner_compel_prompt",
title="Prompt - SDXL Refiner",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
category="prompt",
version="1.1.2",
)
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
@@ -390,7 +391,7 @@ class CLIPSkipInvocationOutput(BaseInvocationOutput):
"clip_skip",
title="Apply CLIP Skip - SD1.5, SDXL",
tags=["clipskip", "clip", "skip"],
category="conditioning",
category="prompt",
version="1.1.1",
)
class CLIPSkipInvocation(BaseInvocation):

View File

@@ -9,7 +9,7 @@ from invokeai.backend.image_util.content_shuffle import content_shuffle
"content_shuffle",
title="Content Shuffle",
tags=["controlnet", "normal"],
category="controlnet",
category="controlnet_preprocessors",
version="1.0.0",
)
class ContentShuffleInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -64,7 +64,7 @@ class ControlOutput(BaseInvocationOutput):
@invocation(
"controlnet", title="ControlNet - SD1.5, SD2, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3"
"controlnet", title="ControlNet - SD1.5, SD2, SDXL", tags=["controlnet"], category="conditioning", version="1.1.3"
)
class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
@@ -116,7 +116,7 @@ class ControlNetInvocation(BaseInvocation):
"heuristic_resize",
title="Heuristic Resize",
tags=["image, controlnet"],
category="image",
category="controlnet_preprocessors",
version="1.1.1",
classification=Classification.Prototype,
)

View File

@@ -18,7 +18,7 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_t
"create_denoise_mask",
title="Create Denoise Mask",
tags=["mask", "denoise"],
category="latents",
category="mask",
version="1.0.2",
)
class CreateDenoiseMaskInvocation(BaseInvocation):

View File

@@ -41,7 +41,7 @@ class GradientMaskOutput(BaseInvocationOutput):
"create_gradient_mask",
title="Create Gradient Mask",
tags=["mask", "denoise"],
category="latents",
category="mask",
version="1.3.0",
)
class CreateGradientMaskInvocation(BaseInvocation):

View File

@@ -20,7 +20,7 @@ DEPTH_ANYTHING_MODELS = {
"depth_anything_depth_estimation",
title="Depth Anything Depth Estimation",
tags=["controlnet", "depth", "depth anything"],
category="controlnet",
category="controlnet_preprocessors",
version="1.0.0",
)
class DepthAnythingDepthEstimationInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -11,7 +11,7 @@ from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
"dw_openpose_detection",
title="DW Openpose Detection",
tags=["controlnet", "dwpose", "openpose"],
category="controlnet",
category="controlnet_preprocessors",
version="1.1.1",
)
class DWOpenposeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -1,4 +1,4 @@
from typing import Any, ClassVar
from typing import Any, ClassVar, Literal
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
@@ -34,18 +34,18 @@ class BaseExternalImageGenerationInvocation(BaseInvocation, WithMetadata, WithBo
)
mode: ExternalGenerationMode = InputField(default="txt2img", description="Generation mode")
prompt: str = InputField(description="Prompt")
negative_prompt: str | None = InputField(default=None, description="Negative prompt")
seed: int | None = InputField(default=None, description=FieldDescriptions.seed)
num_images: int = InputField(default=1, gt=0, description="Number of images to generate")
width: int = InputField(default=1024, gt=0, description=FieldDescriptions.width)
height: int = InputField(default=1024, gt=0, description=FieldDescriptions.height)
steps: int | None = InputField(default=None, gt=0, description=FieldDescriptions.steps)
guidance: float | None = InputField(default=None, ge=0, description="Guidance strength")
image_size: str | None = InputField(default=None, description="Image size preset (e.g. 1K, 2K, 4K)")
init_image: ImageField | None = InputField(default=None, description="Init image for img2img/inpaint")
mask_image: ImageField | None = InputField(default=None, description="Mask image for inpaint")
reference_images: list[ImageField] = InputField(default=[], description="Reference images")
reference_image_weights: list[float] | None = InputField(default=None, description="Reference image weights")
reference_image_modes: list[str] | None = InputField(default=None, description="Reference image modes")
def _build_provider_options(self) -> dict[str, Any] | None:
"""Override in provider-specific subclasses to pass extra options."""
return None
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
model_config = context.models.get_config(self.model)
@@ -65,38 +65,25 @@ class BaseExternalImageGenerationInvocation(BaseInvocation, WithMetadata, WithBo
if self.mask_image is not None:
mask_image = context.images.get_pil(self.mask_image.image_name, mode="L")
if self.reference_image_weights is not None and len(self.reference_image_weights) != len(self.reference_images):
raise ValueError("reference_image_weights must match reference_images length")
if self.reference_image_modes is not None and len(self.reference_image_modes) != len(self.reference_images):
raise ValueError("reference_image_modes must match reference_images length")
reference_images: list[ExternalReferenceImage] = []
for index, image_field in enumerate(self.reference_images):
for image_field in self.reference_images:
reference_image = context.images.get_pil(image_field.image_name, mode="RGB")
weight = None
mode = None
if self.reference_image_weights is not None:
weight = self.reference_image_weights[index]
if self.reference_image_modes is not None:
mode = self.reference_image_modes[index]
reference_images.append(ExternalReferenceImage(image=reference_image, weight=weight, mode=mode))
reference_images.append(ExternalReferenceImage(image=reference_image))
request = ExternalGenerationRequest(
model=model_config,
mode=self.mode,
prompt=self.prompt,
negative_prompt=self.negative_prompt,
seed=self.seed,
num_images=self.num_images,
width=self.width,
height=self.height,
steps=self.steps,
guidance=self.guidance,
image_size=self.image_size,
init_image=init_image,
mask_image=mask_image,
reference_images=reference_images,
metadata=self._build_request_metadata(),
provider_options=self._build_provider_options(),
)
result = context._services.external_generation.generate(request)
@@ -172,6 +159,23 @@ class OpenAIImageGenerationInvocation(BaseExternalImageGenerationInvocation):
provider_id = "openai"
quality: Literal["auto", "high", "medium", "low"] = InputField(default="auto", description="Output image quality")
background: Literal["auto", "transparent", "opaque"] = InputField(
default="auto", description="Background transparency handling"
)
input_fidelity: Literal["low", "high"] | None = InputField(
default=None, description="Fidelity to source images (edits only)"
)
def _build_provider_options(self) -> dict[str, Any]:
options: dict[str, Any] = {
"quality": self.quality,
"background": self.background,
}
if self.input_fidelity is not None:
options["input_fidelity"] = self.input_fidelity
return options
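# Example of the resulting provider options (illustrative values):
#   {"quality": "high", "background": "transparent", "input_fidelity": "high"}
# input_fidelity is omitted entirely when left unset, matching the check above.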
@invocation(
"gemini_image_generation",
@@ -184,3 +188,16 @@ class GeminiImageGenerationInvocation(BaseExternalImageGenerationInvocation):
"""Generate images using a Gemini-hosted external model."""
provider_id = "gemini"
temperature: float | None = InputField(default=None, ge=0.0, le=2.0, description="Sampling temperature")
thinking_level: Literal["minimal", "high"] | None = InputField(
default=None, description="Thinking level for image generation"
)
def _build_provider_options(self) -> dict[str, Any] | None:
options: dict[str, Any] = {}
if self.temperature is not None:
options["temperature"] = self.temperature
if self.thinking_level is not None:
options["thinking_level"] = self.thinking_level
return options or None

View File

@@ -435,7 +435,9 @@ def get_faces_list(
return all_faces
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.2")
@invocation(
"face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="segmentation", version="1.2.2"
)
class FaceOffInvocation(BaseInvocation, WithMetadata):
"""Bound, extract, and mask a face from an image using MediaPipe detection"""
@@ -514,7 +516,9 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
return output
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.2")
@invocation(
"face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="segmentation", version="1.2.2"
)
class FaceMaskInvocation(BaseInvocation, WithMetadata):
"""Face mask creation using mediapipe face detection"""
@@ -617,7 +621,11 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata):
@invocation(
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.2"
"face_identifier",
title="FaceIdentifier",
tags=["image", "face", "identifier"],
category="segmentation",
version="1.2.2",
)
class FaceIdentifierInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Outputs an image with detected face IDs printed on each face. For use with other FaceTools."""

View File

@@ -171,6 +171,8 @@ class FieldDescriptions:
sd3_model = "SD3 model (MMDiTX) to load"
cogview4_model = "CogView4 model (Transformer) to load"
z_image_model = "Z-Image model (Transformer) to load"
qwen_image_model = "Qwen Image Edit model (Transformer) to load"
qwen_vl_encoder = "Qwen2.5-VL tokenizer, processor and text/vision encoder"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
@@ -340,6 +342,27 @@ class ZImageConditioningField(BaseModel):
)
class QwenImageConditioningField(BaseModel):
"""A Qwen Image Edit conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor")
class AnimaConditioningField(BaseModel):
"""An Anima conditioning tensor primitive value.
Anima conditioning contains Qwen3 0.6B hidden states and T5-XXL token IDs,
which are combined by the LLM Adapter inside the transformer.
"""
conditioning_name: str = Field(description="The name of conditioning tensor")
mask: Optional[TensorField] = Field(
default=None,
description="The mask associated with this conditioning tensor for regional prompting. "
"Excluded regions should be set to False, included regions should be set to True.",
)
class ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""

View File

@@ -53,7 +53,7 @@ from invokeai.backend.util.devices import TorchDevice
"flux2_denoise",
title="FLUX2 Denoise",
tags=["image", "flux", "flux2", "klein", "denoise"],
category="image",
category="latents",
version="1.4.0",
classification=Classification.Prototype,
)

View File

@@ -25,6 +25,7 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import Qwen3EncoderField
from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_T5_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -44,7 +45,7 @@ KLEIN_MAX_SEQ_LEN = 512
"flux2_klein_text_encoder",
title="Prompt - Flux2 Klein",
tags=["prompt", "conditioning", "flux", "klein", "qwen3"],
category="conditioning",
category="prompt",
version="1.1.1",
classification=Classification.Prototype,
)
@@ -100,7 +101,12 @@ class Flux2KleinTextEncoderInvocation(BaseInvocation):
tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer)
(_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())
device = text_encoder.device
repaired_tensors = text_encoder_info.repair_required_tensors_on_device()
device = get_effective_device(text_encoder)
if repaired_tensors > 0:
context.logger.warning(
f"Recovered {repaired_tensors} required Qwen3 tensor(s) onto {device} after a partial device mismatch."
)
# Apply LoRA models
lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)

View File

@@ -50,7 +50,7 @@ class FluxControlNetOutput(BaseInvocationOutput):
"flux_controlnet",
title="FLUX ControlNet",
tags=["controlnet", "flux"],
category="controlnet",
category="conditioning",
version="1.0.0",
)
class FluxControlNetInvocation(BaseInvocation):

View File

@@ -70,7 +70,7 @@ from invokeai.backend.util.devices import TorchDevice
"flux_denoise",
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
category="latents",
version="4.5.1",
)
class FluxDenoiseInvocation(BaseInvocation):

View File

@@ -29,7 +29,7 @@ class FluxFillOutput(BaseInvocationOutput):
"flux_fill",
title="FLUX Fill Conditioning",
tags=["inpaint"],
category="inpaint",
category="conditioning",
version="1.0.0",
classification=Classification.Beta,
)

View File

@@ -24,7 +24,7 @@ from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
"flux_ip_adapter",
title="FLUX IP-Adapter",
tags=["ip_adapter", "control"],
category="ip_adapter",
category="conditioning",
version="1.0.0",
)
class FluxIPAdapterInvocation(BaseInvocation):

View File

@@ -47,7 +47,7 @@ DOWNSAMPLING_FUNCTIONS = Literal["nearest", "bilinear", "bicubic", "area", "near
"flux_redux",
title="FLUX Redux",
tags=["ip_adapter", "control"],
category="ip_adapter",
category="conditioning",
version="2.1.0",
classification=Classification.Beta,
)

View File

@@ -28,7 +28,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
"flux_text_encoder",
title="Prompt - FLUX",
tags=["prompt", "conditioning", "flux"],
category="conditioning",
category="prompt",
version="1.1.2",
)
class FluxTextEncoderInvocation(BaseInvocation):

View File

@@ -24,7 +24,7 @@ GROUNDING_DINO_MODEL_IDS: dict[GroundingDinoModelKey, str] = {
"grounding_dino",
title="Grounding DINO (Text Prompt Object Detection)",
tags=["prompt", "object detection"],
category="image",
category="segmentation",
version="1.0.0",
)
class GroundingDinoInvocation(BaseInvocation):

View File

@@ -11,7 +11,7 @@ from invokeai.backend.image_util.hed import ControlNetHED_Apache2, HEDEdgeDetect
"hed_edge_detection",
title="HED Edge Detection",
tags=["controlnet", "hed", "softedge"],
category="controlnet",
category="controlnet_preprocessors",
version="1.0.0",
)
class HEDEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -21,6 +21,7 @@ class IdealSizeOutput(BaseInvocationOutput):
"ideal_size",
title="Ideal Size - SD1.5, SDXL",
tags=["latents", "math", "ideal_size"],
category="latents",
version="1.0.6",
)
class IdealSizeInvocation(BaseInvocation):

View File

@@ -21,7 +21,7 @@ from invokeai.app.invocations.fields import (
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.invocations.primitives import ImageOutput, StringOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import SEED_MAX
@@ -197,7 +197,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
"tomask",
title="Mask from Alpha",
tags=["image", "mask"],
category="image",
category="mask",
version="1.2.2",
)
class MaskFromAlphaInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -581,11 +581,30 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithBoard):
return ImageOutput.build(image_dto)
@invocation(
"decode_watermark",
title="Decode Invisible Watermark",
tags=["image", "watermark"],
category="image",
version="1.0.0",
)
class DecodeInvisibleWatermarkInvocation(BaseInvocation):
"""Decode an invisible watermark from an image."""
image: ImageField = InputField(description="The image to decode the watermark from")
length: int = InputField(default=8, description="The expected watermark length in bytes")
def invoke(self, context: InvocationContext) -> StringOutput:
image = context.images.get_pil(self.image.image_name)
watermark = InvisibleWatermark.decode_watermark(image, self.length)
return StringOutput(value=watermark)
@invocation(
"mask_edge",
title="Mask Edge",
tags=["image", "mask", "inpaint"],
category="image",
category="mask",
version="1.2.2",
)
class MaskEdgeInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -624,7 +643,7 @@ class MaskEdgeInvocation(BaseInvocation, WithMetadata, WithBoard):
"mask_combine",
title="Combine Masks",
tags=["image", "mask", "multiply"],
category="image",
category="mask",
version="1.2.2",
)
class MaskCombineInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -955,7 +974,7 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard):
"save_image",
title="Save Image",
tags=["primitives", "image"],
category="primitives",
category="image",
version="1.2.2",
use_cache=False,
)
@@ -976,7 +995,7 @@ class SaveImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"canvas_paste_back",
title="Canvas Paste Back",
tags=["image", "combine"],
category="image",
category="canvas",
version="1.0.1",
)
class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -1013,7 +1032,7 @@ class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard):
"mask_from_id",
title="Mask from Segmented Image",
tags=["image", "mask", "id"],
category="image",
category="mask",
version="1.0.1",
)
class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -1050,7 +1069,7 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
"canvas_v2_mask_and_crop",
title="Canvas V2 Mask and Crop",
tags=["image", "mask", "id"],
category="image",
category="canvas",
version="1.0.0",
classification=Classification.Deprecated,
)
@@ -1091,7 +1110,7 @@ class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
@invocation(
"expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="image", version="1.0.1"
"expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="mask", version="1.0.1"
)
class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Expands a mask with a fade effect. The mask uses black to indicate areas to keep from the generated image and white for areas to discard.
@@ -1180,7 +1199,7 @@ class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
"apply_mask_to_image",
title="Apply Mask to Image",
tags=["image", "mask", "blend"],
category="image",
category="mask",
version="1.0.0",
)
class ApplyMaskToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -1355,7 +1374,7 @@ class PasteImageIntoBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoar
"flux_kontext_image_prep",
title="FLUX Kontext Image Prep",
tags=["image", "concatenate", "flux", "kontext"],
category="image",
category="conditioning",
version="1.0.0",
)
class FluxKontextConcatenateImagesInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -23,7 +23,7 @@ class ImagePanelCoordinateOutput(BaseInvocationOutput):
"image_panel_layout",
title="Image Panel Layout",
tags=["image", "panel", "layout"],
category="image",
category="canvas",
version="1.0.0",
classification=Classification.Prototype,
)

View File

@@ -73,7 +73,7 @@ CLIP_VISION_MODEL_MAP: dict[Literal["ViT-L", "ViT-H", "ViT-G"], StarterModel] =
"ip_adapter",
title="IP-Adapter - SD1.5, SDXL",
tags=["ip_adapter", "control"],
category="ip_adapter",
category="conditioning",
version="1.5.1",
)
class IPAdapterInvocation(BaseInvocation):

View File

@@ -11,7 +11,7 @@ from invokeai.backend.image_util.lineart import Generator, LineartEdgeDetector
"lineart_edge_detection",
title="Lineart Edge Detection",
tags=["controlnet", "lineart"],
category="controlnet",
category="controlnet_preprocessors",
version="1.0.0",
)
class LineartEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -9,7 +9,7 @@ from invokeai.backend.image_util.lineart_anime import LineartAnimeEdgeDetector,
"lineart_anime_edge_detection",
title="Lineart Anime Edge Detection",
tags=["controlnet", "lineart"],
category="controlnet",
category="controlnet_preprocessors",
version="1.0.0",
)
class LineartAnimeEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -19,7 +19,7 @@ from invokeai.backend.util.devices import TorchDevice
"llava_onevision_vllm",
title="LLaVA OneVision VLLM",
tags=["vllm"],
category="vllm",
category="multimodal",
version="1.0.0",
classification=Classification.Beta,
)

View File

@@ -0,0 +1,34 @@
from typing import Any, Optional
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation_output("if_output")
class IfInvocationOutput(BaseInvocationOutput):
value: Optional[Any] = OutputField(
default=None, description="The selected value", title="Output", ui_type=UIType.Any
)
@invocation("if", title="If", tags=["logic", "conditional"], category="math", version="1.0.0")
class IfInvocation(BaseInvocation):
"""Selects between two optional inputs based on a boolean condition."""
condition: bool = InputField(default=False, description="The condition used to select an input", title="Condition")
true_input: Optional[Any] = InputField(
default=None,
description="Selected when the condition is true",
title="True Input",
ui_type=UIType.Any,
)
false_input: Optional[Any] = InputField(
default=None,
description="Selected when the condition is false",
title="False Input",
ui_type=UIType.Any,
)
def invoke(self, context: InvocationContext) -> IfInvocationOutput:
return IfInvocationOutput(value=self.true_input if self.condition else self.false_input)

View File

@@ -24,7 +24,7 @@ from invokeai.backend.image_util.util import pil_to_np
"rectangle_mask",
title="Create Rectangle Mask",
tags=["conditioning"],
category="conditioning",
category="mask",
version="1.0.1",
)
class RectangleMaskInvocation(BaseInvocation, WithMetadata):
@@ -55,7 +55,7 @@ class RectangleMaskInvocation(BaseInvocation, WithMetadata):
"alpha_mask_to_tensor",
title="Alpha Mask to Tensor",
tags=["conditioning"],
category="conditioning",
category="mask",
version="1.0.0",
)
class AlphaMaskToTensorInvocation(BaseInvocation):
@@ -83,7 +83,7 @@ class AlphaMaskToTensorInvocation(BaseInvocation):
"invert_tensor_mask",
title="Invert Tensor Mask",
tags=["conditioning"],
category="conditioning",
category="mask",
version="1.1.0",
)
class InvertTensorMaskInvocation(BaseInvocation):
@@ -115,7 +115,7 @@ class InvertTensorMaskInvocation(BaseInvocation):
"image_mask_to_tensor",
title="Image Mask to Tensor",
tags=["conditioning"],
category="conditioning",
category="mask",
version="1.0.0",
)
class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):

View File

@@ -9,7 +9,7 @@ from invokeai.backend.image_util.mediapipe_face import detect_faces
"mediapipe_face_detection",
title="MediaPipe Face Detection",
tags=["controlnet", "face"],
category="controlnet",
category="controlnet_preprocessors",
version="1.0.0",
)
class MediaPipeFaceDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -166,6 +166,14 @@ GENERATION_MODES = Literal[
"z_image_img2img",
"z_image_inpaint",
"z_image_outpaint",
"qwen_image_txt2img",
"qwen_image_img2img",
"qwen_image_inpaint",
"qwen_image_outpaint",
"anima_txt2img",
"anima_img2img",
"anima_inpaint",
"anima_outpaint",
]

View File

@@ -621,7 +621,7 @@ class LatentsMetaOutput(LatentsOutput, MetadataOutput):
"denoise_latents_meta",
title=f"{DenoiseLatentsInvocation.UIConfig.title} + Metadata",
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
category="latents",
category="metadata",
version="1.1.1",
)
class DenoiseLatentsMetaInvocation(DenoiseLatentsInvocation, WithMetadata):
@@ -686,7 +686,7 @@ class DenoiseLatentsMetaInvocation(DenoiseLatentsInvocation, WithMetadata):
"flux_denoise_meta",
title=f"{FluxDenoiseInvocation.UIConfig.title} + Metadata",
tags=["flux", "latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
category="latents",
category="metadata",
version="1.0.1",
)
class FluxDenoiseLatentsMetaInvocation(FluxDenoiseInvocation, WithMetadata):
@@ -734,7 +734,7 @@ class FluxDenoiseLatentsMetaInvocation(FluxDenoiseInvocation, WithMetadata):
"z_image_denoise_meta",
title=f"{ZImageDenoiseInvocation.UIConfig.title} + Metadata",
tags=["z-image", "latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
category="latents",
category="metadata",
version="1.0.0",
)
class ZImageDenoiseMetaInvocation(ZImageDenoiseInvocation, WithMetadata):

View File

@@ -10,7 +10,7 @@ from invokeai.backend.image_util.mlsd.models.mbv2_mlsd_large import MobileV2_MLS
"mlsd_detection",
title="MLSD Detection",
tags=["controlnet", "mlsd", "edge"],
category="controlnet",
category="controlnet_preprocessors",
version="1.0.0",
)
class MLSDDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -72,6 +72,13 @@ class GlmEncoderField(BaseModel):
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
class QwenVLEncoderField(BaseModel):
"""Field for Qwen2.5-VL encoder used by Qwen Image Edit models."""
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
class Qwen3EncoderField(BaseModel):
"""Field for Qwen3 text encoder used by Z-Image models."""
@@ -577,7 +584,7 @@ class SeamlessModeInvocation(BaseInvocation):
return SeamlessModeOutput(unet=unet, vae=vae)
@invocation("freeu", title="Apply FreeU - SD1.5, SDXL", tags=["freeu"], category="unet", version="1.0.2")
@invocation("freeu", title="Apply FreeU - SD1.5, SDXL", tags=["freeu"], category="model", version="1.0.2")
class FreeUInvocation(BaseInvocation):
"""
Applies FreeU to the UNet. Suggested values (b1/b2/s1/s2):

View File

@@ -10,7 +10,7 @@ from invokeai.backend.image_util.normal_bae.nets.NNET import NNET
"normal_map",
title="Normal Map",
tags=["controlnet", "normal"],
category="controlnet",
category="controlnet_preprocessors",
version="1.0.0",
)
class NormalMapInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -16,7 +16,9 @@ class PBRMapsOutput(BaseInvocationOutput):
displacement_map: ImageField = OutputField(default=None, description="The generated displacement map")
@invocation("pbr_maps", title="PBR Maps", tags=["image", "material"], category="image", version="1.0.0")
@invocation(
"pbr_maps", title="PBR Maps", tags=["image", "material"], category="controlnet_preprocessors", version="1.0.0"
)
class PBRMapsInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generate Normal, Displacement and Roughness Map from a given image"""

View File

@@ -10,7 +10,7 @@ from invokeai.backend.image_util.pidi.model import PiDiNet
"pidi_edge_detection",
title="PiDiNet Edge Detection",
tags=["controlnet", "edge"],
category="controlnet",
category="controlnet_preprocessors",
version="1.0.0",
)
class PiDiNetEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -12,6 +12,7 @@ from invokeai.app.invocations.baseinvocation import (
)
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
AnimaConditioningField,
BoundingBoxField,
CogView4ConditioningField,
ColorField,
@@ -24,6 +25,7 @@ from invokeai.app.invocations.fields import (
InputField,
LatentsField,
OutputField,
QwenImageConditioningField,
SD3ConditioningField,
TensorField,
UIComponent,
@@ -473,6 +475,28 @@ class ZImageConditioningOutput(BaseInvocationOutput):
return cls(conditioning=ZImageConditioningField(conditioning_name=conditioning_name))
@invocation_output("qwen_image_conditioning_output")
class QwenImageConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a Qwen Image Edit conditioning tensor."""
conditioning: QwenImageConditioningField = OutputField(description=FieldDescriptions.cond)
@classmethod
def build(cls, conditioning_name: str) -> "QwenImageConditioningOutput":
return cls(conditioning=QwenImageConditioningField(conditioning_name=conditioning_name))
@invocation_output("anima_conditioning_output")
class AnimaConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output an Anima text conditioning tensor."""
conditioning: AnimaConditioningField = OutputField(description=FieldDescriptions.cond)
@classmethod
def build(cls, conditioning_name: str) -> "AnimaConditioningOutput":
return cls(conditioning=AnimaConditioningField(conditioning_name=conditioning_name))
@invocation_output("conditioning_output")
class ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single conditioning tensor"""

View File

@@ -0,0 +1,490 @@
from contextlib import ExitStack
from typing import Callable, Iterator, Optional, Tuple
import torch
import torchvision.transforms as tv_transforms
from diffusers.models.transformers.transformer_qwenimage import QwenImageTransformer2DModel
from torchvision.transforms.functional import resize as tv_resize
from tqdm import tqdm
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
Input,
InputField,
LatentsField,
QwenImageConditioningField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.qwen_image_lora_constants import (
QWEN_IMAGE_EDIT_LORA_TRANSFORMER_PREFIX,
)
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import QwenImageConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@invocation(
"qwen_image_denoise",
title="Denoise - Qwen Image",
tags=["image", "qwen_image"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class QwenImageDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run the denoising process with a Qwen Image model."""
# If latents is provided, this means we are doing image-to-image.
latents: Optional[LatentsField] = InputField(
default=None, description=FieldDescriptions.latents, input=Input.Connection
)
# Reference image latents (encoded through VAE) to concatenate with noisy latents.
reference_latents: Optional[LatentsField] = InputField(
default=None,
description="Reference image latents to guide generation. Encoded through the VAE.",
input=Input.Connection,
)
# denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None, description=FieldDescriptions.denoise_mask, input=Input.Connection
)
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
transformer: TransformerField = InputField(
description=FieldDescriptions.qwen_image_model, input=Input.Connection, title="Transformer"
)
positive_conditioning: QwenImageConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_conditioning: Optional[QwenImageConditioningField] = InputField(
default=None, description=FieldDescriptions.negative_cond, input=Input.Connection
)
cfg_scale: float | list[float] = InputField(default=4.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
steps: int = InputField(default=40, gt=0, description=FieldDescriptions.steps)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
shift: Optional[float] = InputField(
default=None,
description="Override the sigma schedule shift. "
"When set, uses a fixed shift (e.g. 3.0 for Lightning LoRAs) instead of the default dynamic shifting. "
"Leave unset for the base model's default schedule.",
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = self._run_diffusion(context)
latents = latents.detach().to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
if self.denoise_mask is None:
return None
mask = context.tensors.load(self.denoise_mask.mask_name)
mask = 1.0 - mask
_, _, latent_height, latent_width = latents.shape
mask = tv_resize(
img=mask,
size=[latent_height, latent_width],
interpolation=tv_transforms.InterpolationMode.BILINEAR,
antialias=False,
)
mask = mask.to(device=latents.device, dtype=latents.dtype)
return mask
def _load_text_conditioning(
self,
context: InvocationContext,
conditioning_name: str,
dtype: torch.dtype,
device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor | None]:
cond_data = context.conditioning.load(conditioning_name)
assert len(cond_data.conditionings) == 1
conditioning = cond_data.conditionings[0]
assert isinstance(conditioning, QwenImageConditioningInfo)
conditioning = conditioning.to(dtype=dtype, device=device)
return conditioning.prompt_embeds, conditioning.prompt_embeds_mask
def _get_noise(
self,
batch_size: int,
num_channels_latents: int,
height: int,
width: int,
dtype: torch.dtype,
device: torch.device,
seed: int,
) -> torch.Tensor:
rand_device = "cpu"
rand_dtype = torch.float32
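# Noise is drawn on CPU with a seeded generator so results are reproducible
# across devices, then moved to the target device/dtype below.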
return torch.randn(
batch_size,
num_channels_latents,
int(height) // LATENT_SCALE_FACTOR,
int(width) // LATENT_SCALE_FACTOR,
device=rand_device,
dtype=rand_dtype,
generator=torch.Generator(device=rand_device).manual_seed(seed),
).to(device=device, dtype=dtype)
def _prepare_cfg_scale(self, num_timesteps: int) -> list[float]:
if isinstance(self.cfg_scale, float):
cfg_scale = [self.cfg_scale] * num_timesteps
elif isinstance(self.cfg_scale, list):
assert len(self.cfg_scale) == num_timesteps
cfg_scale = self.cfg_scale
else:
raise ValueError(f"Invalid CFG scale type: {type(self.cfg_scale)}")
return cfg_scale
@staticmethod
def _pack_latents(
latents: torch.Tensor, batch_size: int, num_channels: int, height: int, width: int
) -> torch.Tensor:
"""Pack 4D latents (B, C, H, W) into 2x2-patched 3D (B, H/2*W/2, C*4)."""
latents = latents.view(batch_size, num_channels, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels * 4)
return latents
@staticmethod
def _unpack_latents(latents: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""Unpack 3D patched latents (B, seq, C*4) back to 4D (B, C, H, W)."""
batch_size, _num_patches, channels = latents.shape
# height/width are in latent space; they must be divisible by 2 for packing
h = 2 * (height // 2)
w = 2 * (width // 2)
latents = latents.view(batch_size, h // 2, w // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // 4, h, w)
return latents
def _run_diffusion(self, context: InvocationContext):
inference_dtype = torch.bfloat16
device = TorchDevice.choose_torch_device()
transformer_info = context.models.load(self.transformer.transformer)
assert isinstance(transformer_info.model, QwenImageTransformer2DModel)
# Load conditioning
pos_prompt_embeds, pos_prompt_mask = self._load_text_conditioning(
context=context,
conditioning_name=self.positive_conditioning.conditioning_name,
dtype=inference_dtype,
device=device,
)
neg_prompt_embeds = None
neg_prompt_mask = None
# Match the diffusers pipeline: only enable CFG when cfg_scale > 1 AND negative conditioning is provided.
# With cfg_scale <= 1, the negative prediction is unused, so skip it entirely.
# For per-step arrays, enable CFG if any step has scale > 1.
if isinstance(self.cfg_scale, list):
any_cfg_above_one = any(v > 1.0 for v in self.cfg_scale)
else:
any_cfg_above_one = self.cfg_scale > 1.0
do_classifier_free_guidance = self.negative_conditioning is not None and any_cfg_above_one
if do_classifier_free_guidance:
neg_prompt_embeds, neg_prompt_mask = self._load_text_conditioning(
context=context,
conditioning_name=self.negative_conditioning.conditioning_name,
dtype=inference_dtype,
device=device,
)
# Prepare the timestep / sigma schedule
patch_size = transformer_info.model.config.patch_size
assert isinstance(patch_size, int)
# The transformer's out_channels (16 for Qwen Image) is the actual latent channel count
out_channels = transformer_info.model.config.out_channels
assert isinstance(out_channels, int)
latent_height = self.height // LATENT_SCALE_FACTOR
latent_width = self.width // LATENT_SCALE_FACTOR
image_seq_len = (latent_height * latent_width) // (patch_size**2)
# Use the actual FlowMatchEulerDiscreteScheduler to compute sigmas/timesteps,
# exactly matching the diffusers pipeline.
import math
import numpy as np
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
# Try to load the scheduler config from the model's directory (Diffusers models
# have a scheduler/ subdir). For GGUF models this path doesn't exist, so fall
# back to instantiating the scheduler with the known Qwen Image defaults.
model_path = context.models.get_absolute_path(context.models.get_config(self.transformer.transformer))
scheduler_path = model_path / "scheduler"
if scheduler_path.is_dir() and (scheduler_path / "scheduler_config.json").exists():
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(str(scheduler_path), local_files_only=True)
else:
scheduler = FlowMatchEulerDiscreteScheduler(
use_dynamic_shifting=True,
base_shift=0.5,
max_shift=0.9,
base_image_seq_len=256,
max_image_seq_len=8192,
shift_terminal=0.02,
num_train_timesteps=1000,
time_shift_type="exponential",
)
if self.shift is not None:
# Lightning LoRA: fixed shift
mu = math.log(self.shift)
else:
# Default dynamic shifting
# Linear interpolation matching diffusers' calculate_shift
base_shift = scheduler.config.get("base_shift", 0.5)
max_shift = scheduler.config.get("max_shift", 0.9)
base_seq = scheduler.config.get("base_image_seq_len", 256)
max_seq = scheduler.config.get("max_image_seq_len", 4096)
m = (max_shift - base_shift) / (max_seq - base_seq)
b = base_shift - m * base_seq
mu = image_seq_len * m + b
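# e.g. assuming base/max shift of 0.5/0.9 and base/max seq len of 256/4096,
# a 1024x1024 image (image_seq_len = (128*128)//4 = 4096) yields mu = 0.9.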
init_sigmas = np.linspace(1.0, 1.0 / self.steps, self.steps).tolist()
scheduler.set_timesteps(sigmas=init_sigmas, mu=mu, device=device)
# Clip the schedule based on denoising_start/denoising_end to support img2img strength.
# The scheduler's sigmas go from high (noisy) to 0 (clean). We clip to the fractional range.
sigmas_sched = scheduler.sigmas # (N+1,) including terminal 0
if self.denoising_start > 0 or self.denoising_end < 1:
total_sigmas = len(sigmas_sched) - 1 # exclude terminal
start_idx = int(round(self.denoising_start * total_sigmas))
end_idx = int(round(self.denoising_end * total_sigmas))
sigmas_sched = sigmas_sched[start_idx : end_idx + 1] # +1 to include the next sigma for dt
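# e.g. with 40 steps and denoising_start=0.3, start_idx=12: the 12 noisiest
# sigmas are skipped and 28 denoising steps remain.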
# Rebuild timesteps from clipped sigmas (exclude terminal 0)
timesteps_sched = sigmas_sched[:-1] * scheduler.config.num_train_timesteps
else:
timesteps_sched = scheduler.timesteps
total_steps = len(timesteps_sched)
cfg_scale = self._prepare_cfg_scale(total_steps)
# Load initial latents if provided (for img2img)
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
if init_latents is not None:
init_latents = init_latents.to(device=device, dtype=inference_dtype)
if init_latents.dim() == 5:
init_latents = init_latents.squeeze(2)
# Load reference image latents if provided
ref_latents = None
if self.reference_latents is not None:
ref_latents = context.tensors.load(self.reference_latents.latents_name)
ref_latents = ref_latents.to(device=device, dtype=inference_dtype)
# The VAE encoder produces 5D latents (B, C, 1, H, W); squeeze the frame dim
# so we have 4D (B, C, H, W) for packing.
if ref_latents.dim() == 5:
ref_latents = ref_latents.squeeze(2)
# Generate noise (16 channels - the output latent channels)
noise = self._get_noise(
batch_size=1,
num_channels_latents=out_channels,
height=self.height,
width=self.width,
dtype=inference_dtype,
device=device,
seed=self.seed,
)
# Prepare input latent image
if init_latents is not None:
s_0 = sigmas_sched[0].item()
latents = s_0 * noise + (1.0 - s_0) * init_latents
else:
if self.denoising_start > 1e-5:
raise ValueError("denoising_start should be 0 when initial latents are not provided.")
latents = noise
if total_steps <= 0:
return latents
# Pack latents into 2x2 patches: (B, C, H, W) -> (B, H/2*W/2, C*4)
latents = self._pack_latents(latents, 1, out_channels, latent_height, latent_width)
# Determine whether the model uses reference latent conditioning (zero_cond_t).
# Edit models (zero_cond_t=True) expect [noisy_patches ; ref_patches] in the sequence.
# Txt2img models (zero_cond_t=False) only take noisy patches.
has_zero_cond_t = getattr(transformer_info.model, "zero_cond_t", False) or getattr(
transformer_info.model.config, "zero_cond_t", False
)
use_ref_latents = has_zero_cond_t
ref_latents_packed = None
if use_ref_latents:
if ref_latents is not None:
_, ref_ch, rh, rw = ref_latents.shape
if rh != latent_height or rw != latent_width:
ref_latents = torch.nn.functional.interpolate(
ref_latents, size=(latent_height, latent_width), mode="bilinear"
)
else:
# No reference image provided — use zeros so the model still gets the
# expected sequence layout.
ref_latents = torch.zeros(
1, out_channels, latent_height, latent_width, device=device, dtype=inference_dtype
)
ref_latents_packed = self._pack_latents(ref_latents, 1, out_channels, latent_height, latent_width)
# img_shapes tells the transformer the spatial layout of patches.
if use_ref_latents:
img_shapes = [
[
(1, latent_height // 2, latent_width // 2),
(1, latent_height // 2, latent_width // 2),
]
]
else:
img_shapes = [
[
(1, latent_height // 2, latent_width // 2),
]
]
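# e.g. at 1024x1024 this is [[(1, 64, 64)]] for txt2img models, or
# [[(1, 64, 64), (1, 64, 64)]] when reference latents are appended to the sequence.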
# Prepare inpaint extension (operates in 4D space, so unpack/repack around it)
inpaint_mask = self._prep_inpaint_mask(context, noise) # noise has the right 4D shape
inpaint_extension: RectifiedFlowInpaintExtension | None = None
if inpaint_mask is not None:
assert init_latents is not None
inpaint_extension = RectifiedFlowInpaintExtension(
init_latents=init_latents,
inpaint_mask=inpaint_mask,
noise=noise,
)
step_callback = self._build_step_callback(context)
step_callback(
PipelineIntermediateState(
step=0,
order=1,
total_steps=total_steps,
timestep=int(timesteps_sched[0].item()) if len(timesteps_sched) > 0 else 0,
latents=self._unpack_latents(latents, latent_height, latent_width),
),
)
noisy_seq_len = latents.shape[1]
# Determine if the model is quantized — GGUF models need sidecar patching for LoRAs
transformer_config = context.models.get_config(self.transformer.transformer)
model_is_quantized = transformer_config.format in (ModelFormat.GGUFQuantized,)
with ExitStack() as exit_stack:
(cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
assert isinstance(transformer, QwenImageTransformer2DModel)
# Apply LoRA patches to the transformer
exit_stack.enter_context(
LayerPatcher.apply_smart_model_patches(
model=transformer,
patches=self._lora_iterator(context),
prefix=QWEN_IMAGE_EDIT_LORA_TRANSFORMER_PREFIX,
dtype=inference_dtype,
cached_weights=cached_weights,
force_sidecar_patching=model_is_quantized,
)
)
for step_idx, t in enumerate(tqdm(timesteps_sched)):
# The pipeline passes timestep / 1000 to the transformer
timestep = t.expand(latents.shape[0]).to(inference_dtype)
# For edit models: concatenate noisy and reference patches along the sequence dim
# For txt2img models: just use noisy patches
if ref_latents_packed is not None:
model_input = torch.cat([latents, ref_latents_packed], dim=1)
else:
model_input = latents
noise_pred_cond = transformer(
hidden_states=model_input,
encoder_hidden_states=pos_prompt_embeds,
encoder_hidden_states_mask=pos_prompt_mask,
timestep=timestep / 1000,
img_shapes=img_shapes,
return_dict=False,
)[0]
# Only keep the noisy-latent portion of the output
noise_pred_cond = noise_pred_cond[:, :noisy_seq_len]
if do_classifier_free_guidance and neg_prompt_embeds is not None:
noise_pred_uncond = transformer(
hidden_states=model_input,
encoder_hidden_states=neg_prompt_embeds,
encoder_hidden_states_mask=neg_prompt_mask,
timestep=timestep / 1000,
img_shapes=img_shapes,
return_dict=False,
)[0]
noise_pred_uncond = noise_pred_uncond[:, :noisy_seq_len]
noise_pred = noise_pred_uncond + cfg_scale[step_idx] * (noise_pred_cond - noise_pred_uncond)
else:
noise_pred = noise_pred_cond
# Euler step using the (possibly clipped) sigma schedule
sigma_curr = sigmas_sched[step_idx]
sigma_next = sigmas_sched[step_idx + 1]
dt = sigma_next - sigma_curr
latents = latents.to(torch.float32) + dt * noise_pred.to(torch.float32)
latents = latents.to(inference_dtype)
if inpaint_extension is not None:
sigma_next = sigmas_sched[step_idx + 1].item()
latents_4d = self._unpack_latents(latents, latent_height, latent_width)
latents_4d = inpaint_extension.merge_intermediate_latents_with_init_latents(latents_4d, sigma_next)
latents = self._pack_latents(latents_4d, 1, out_channels, latent_height, latent_width)
step_callback(
PipelineIntermediateState(
step=step_idx + 1,
order=1,
total_steps=total_steps,
timestep=int(t.item()),
latents=self._unpack_latents(latents, latent_height, latent_width),
),
)
# Unpack back to 4D then add frame dim for the video-style VAE: (B, C, 1, H, W)
latents = self._unpack_latents(latents, latent_height, latent_width)
latents = latents.unsqueeze(2)
return latents
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, BaseModelType.QwenImage)
return step_callback
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
"""Iterate over LoRA models to apply to the transformer."""
for lora in self.transformer.loras:
lora_info = context.models.load(lora.lora)
if not isinstance(lora_info.model, ModelPatchRaw):
raise TypeError(
f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}."
)
yield (lora_info.model, lora.weight)
del lora_info

View File

@@ -0,0 +1,96 @@
import einops
import torch
from diffusers.models.autoencoders.autoencoder_kl_qwenimage import AutoencoderKLQwenImage
from PIL import Image as PILImage
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
Input,
InputField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
@invocation(
"qwen_image_i2l",
title="Image to Latents - Qwen Image",
tags=["image", "latents", "vae", "i2l", "qwen_image"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class QwenImageImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates latents from an image using the Qwen Image VAE."""
image: ImageField = InputField(description="The image to encode.")
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
width: int | None = InputField(
default=None,
description="Resize the image to this width before encoding. If not set, encodes at the image's original size.",
)
height: int | None = InputField(
default=None,
description="Resize the image to this height before encoding. If not set, encodes at the image's original size.",
)
@staticmethod
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
with vae_info.model_on_device() as (_, vae):
assert isinstance(vae, AutoencoderKLQwenImage)
vae.disable_tiling()
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype)
with torch.inference_mode():
# The Qwen Image VAE expects 5D input: (B, C, num_frames, H, W)
if image_tensor.dim() == 4:
image_tensor = image_tensor.unsqueeze(2)
posterior = vae.encode(image_tensor).latent_dist
# Use mode (argmax) for deterministic encoding, matching diffusers
latents: torch.Tensor = posterior.mode().to(dtype=vae.dtype)
# Normalize with per-channel latents_mean / latents_std
latents_mean = (
torch.tensor(vae.config.latents_mean)
.view(1, vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = (
torch.tensor(vae.config.latents_std)
.view(1, vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents = (latents - latents_mean) / latents_std
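# The matching Latents-to-Image invocation inverts this with x * std + mean.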
return latents
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
# If target dimensions are specified, resize the image BEFORE encoding
# (matching the diffusers pipeline which resizes in pixel space, not latent space).
if self.width is not None and self.height is not None:
image = image.convert("RGB").resize((self.width, self.height), resample=PILImage.LANCZOS)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
vae_info = context.models.load(self.vae.vae)
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
latents = latents.to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

View File

@@ -0,0 +1,85 @@
from contextlib import nullcontext
import torch
from diffusers.models.autoencoders.autoencoder_kl_qwenimage import AutoencoderKLQwenImage
from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.util.devices import TorchDevice
@invocation(
"qwen_image_l2i",
title="Latents to Image - Qwen Image",
tags=["latents", "image", "vae", "l2i", "qwen_image"],
category="latents",
version="1.0.0",
classification=Classification.Prototype,
)
class QwenImageLatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents using the Qwen Image VAE."""
latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, AutoencoderKLQwenImage)
with (
SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
vae_info.model_on_device() as (_, vae),
):
context.util.signal_progress("Running VAE")
assert isinstance(vae, AutoencoderKLQwenImage)
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype)
vae.disable_tiling()
tiling_context = nullcontext()
TorchDevice.empty_cache()
with torch.inference_mode(), tiling_context:
# The Qwen Image VAE uses per-channel latents_mean / latents_std
# instead of a single scaling_factor.
# Latents are 5D: (B, C, num_frames, H, W); the denoise invocation already
# outputs this shape (4D unpack followed by a frame-dim unsqueeze).
latents_mean = (
torch.tensor(vae.config.latents_mean)
.view(1, vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to(
latents.device, latents.dtype
)
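# latents_std here holds the reciprocal, so the next line computes x * std + mean,
# inverting the (x - mean) / std normalization applied at encode time.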
latents = latents / latents_std + latents_mean
img = vae.decode(latents, return_dict=False)[0]
# Drop the temporal frame dimension: (B, C, 1, H, W) -> (B, C, H, W)
img = img[:, :, 0]
img = img.clamp(-1, 1)
img = rearrange(img[0], "c h w -> h w c")
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
TorchDevice.empty_cache()
image_dto = context.images.save(image=img_pil)
return ImageOutput.build(image_dto)

View File

@@ -0,0 +1,115 @@
from typing import Optional
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
@invocation_output("qwen_image_lora_loader_output")
class QwenImageLoRALoaderOutput(BaseInvocationOutput):
"""Qwen Image LoRA Loader Output"""
transformer: Optional[TransformerField] = OutputField(
default=None, description=FieldDescriptions.transformer, title="Transformer"
)
@invocation(
"qwen_image_lora_loader",
title="Apply LoRA - Qwen Image",
tags=["lora", "model", "qwen_image"],
category="model",
version="1.0.0",
classification=Classification.Prototype,
)
class QwenImageLoRALoaderInvocation(BaseInvocation):
"""Apply a LoRA model to a Qwen Image transformer."""
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model,
title="LoRA",
ui_model_base=BaseModelType.QwenImage,
ui_model_type=ModelType.LoRA,
)
weight: float = InputField(default=1.0, description=FieldDescriptions.lora_weight)
transformer: TransformerField | None = InputField(
default=None,
description=FieldDescriptions.transformer,
input=Input.Connection,
title="Transformer",
)
def invoke(self, context: InvocationContext) -> QwenImageLoRALoaderOutput:
lora_key = self.lora.key
if not context.models.exists(lora_key):
raise ValueError(f"Unknown lora: {lora_key}!")
if self.transformer and any(lora.lora.key == lora_key for lora in self.transformer.loras):
raise ValueError(f'LoRA "{lora_key}" already applied to transformer.')
output = QwenImageLoRALoaderOutput()
if self.transformer is not None:
output.transformer = self.transformer.model_copy(deep=True)
output.transformer.loras.append(
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
return output
@invocation(
"qwen_image_lora_collection_loader",
title="Apply LoRA Collection - Qwen Image",
tags=["lora", "model", "qwen_image"],
category="model",
version="1.0.0",
classification=Classification.Prototype,
)
class QwenImageLoRACollectionLoader(BaseInvocation):
"""Applies a collection of LoRAs to a Qwen Image transformer."""
loras: Optional[LoRAField | list[LoRAField]] = InputField(
default=None, description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
)
transformer: Optional[TransformerField] = InputField(
default=None,
description=FieldDescriptions.transformer,
input=Input.Connection,
title="Transformer",
)
def invoke(self, context: InvocationContext) -> QwenImageLoRALoaderOutput:
output = QwenImageLoRALoaderOutput()
loras = self.loras if isinstance(self.loras, list) else [self.loras]
added_loras: list[str] = []
if self.transformer is not None:
output.transformer = self.transformer.model_copy(deep=True)
for lora in loras:
if lora is None:
continue
if lora.lora.key in added_loras:
continue
if not context.models.exists(lora.lora.key):
raise Exception(f"Unknown lora: {lora.lora.key}!")
added_loras.append(lora.lora.key)
if self.transformer is not None and output.transformer is not None:
output.transformer.loras.append(lora)
return output

View File

@@ -0,0 +1,107 @@
from typing import Optional
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.model import (
ModelIdentifierField,
QwenVLEncoderField,
TransformerField,
VAEField,
)
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType, SubModelType
@invocation_output("qwen_image_model_loader_output")
class QwenImageModelLoaderOutput(BaseInvocationOutput):
"""Qwen Image model loader output."""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
qwen_vl_encoder: QwenVLEncoderField = OutputField(
description=FieldDescriptions.qwen_vl_encoder, title="Qwen VL Encoder"
)
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation(
"qwen_image_model_loader",
title="Main Model - Qwen Image",
tags=["model", "qwen_image"],
category="model",
version="1.1.0",
classification=Classification.Prototype,
)
class QwenImageModelLoaderInvocation(BaseInvocation):
"""Loads a Qwen Image model, outputting its submodels.
The transformer is always loaded from the main model (Diffusers or GGUF).
For GGUF quantized models, the VAE and Qwen VL encoder must come from a
separate Diffusers model specified in the "Component Source" field.
For Diffusers models, all components are extracted from the main model
automatically. The "Component Source" field is ignored.
"""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.qwen_image_model,
input=Input.Direct,
ui_model_base=BaseModelType.QwenImage,
ui_model_type=ModelType.Main,
title="Transformer",
)
component_source: Optional[ModelIdentifierField] = InputField(
default=None,
description="Diffusers Qwen Image model to extract the VAE and Qwen VL encoder from. "
"Required when using a GGUF quantized transformer. "
"Ignored when the main model is already in Diffusers format.",
input=Input.Direct,
ui_model_base=BaseModelType.QwenImage,
ui_model_type=ModelType.Main,
ui_model_format=ModelFormat.Diffusers,
title="Component Source (Diffusers)",
)
def invoke(self, context: InvocationContext) -> QwenImageModelLoaderOutput:
main_config = context.models.get_config(self.model)
main_is_diffusers = main_config.format == ModelFormat.Diffusers
# Transformer always comes from the main model
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
if main_is_diffusers:
# Diffusers model: extract all components directly
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
elif self.component_source is not None:
# GGUF/checkpoint transformer: get VAE + encoder from the component source
source_config = context.models.get_config(self.component_source)
if source_config.format != ModelFormat.Diffusers:
raise ValueError(
f"The Component Source model must be in Diffusers format. "
f"The selected model '{source_config.name}' is in {source_config.format.value} format."
)
vae = self.component_source.model_copy(update={"submodel_type": SubModelType.VAE})
tokenizer = self.component_source.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder = self.component_source.model_copy(update={"submodel_type": SubModelType.TextEncoder})
else:
raise ValueError(
"No source for VAE and Qwen VL encoder. "
"GGUF quantized models only contain the transformer — "
"please set 'Component Source' to a Diffusers Qwen Image model "
"to provide the VAE and text encoder."
)
return QwenImageModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
qwen_vl_encoder=QwenVLEncoderField(tokenizer=tokenizer, text_encoder=text_encoder),
vae=VAEField(vae=vae),
)

View File

@@ -0,0 +1,298 @@
from typing import Literal
import torch
from PIL import Image as PILImage
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
Input,
InputField,
UIComponent,
)
from invokeai.app.invocations.model import QwenVLEncoderField
from invokeai.app.invocations.primitives import QwenImageConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ConditioningFieldData,
QwenImageConditioningInfo,
)
# Prompt templates and drop indices for the two Qwen Image model modes.
# These are taken directly from the diffusers pipelines.
# Image editing mode (QwenImagePipeline)
_EDIT_SYSTEM_PROMPT = (
"Describe the key features of the input image (color, shape, size, texture, objects, background), "
"then explain how the user's text instruction should alter or modify the image. "
"Generate a new image that meets the user's requirements while maintaining consistency "
"with the original input where appropriate."
)
_EDIT_DROP_IDX = 64
# Text-to-image mode (QwenImagePipeline)
_GENERATE_SYSTEM_PROMPT = (
"Describe the image by detailing the color, shape, size, texture, quantity, "
"text, spatial relationships of the objects and background:"
)
_GENERATE_DROP_IDX = 34
_IMAGE_PLACEHOLDER = "<|vision_start|><|image_pad|><|vision_end|>"
def _build_prompt(user_prompt: str, num_images: int) -> str:
"""Build the full prompt with the appropriate template based on whether reference images are provided."""
if num_images > 0:
# Edit mode: include vision placeholders for reference images
image_tokens = _IMAGE_PLACEHOLDER * num_images
return (
f"<|im_start|>system\n{_EDIT_SYSTEM_PROMPT}<|im_end|>\n"
f"<|im_start|>user\n{image_tokens}{user_prompt}<|im_end|>\n"
"<|im_start|>assistant\n"
)
else:
# Generate mode: text-only prompt
return (
f"<|im_start|>system\n{_GENERATE_SYSTEM_PROMPT}<|im_end|>\n"
f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
"<|im_start|>assistant\n"
)
@invocation(
"qwen_image_text_encoder",
title="Prompt - Qwen Image",
tags=["prompt", "conditioning", "qwen_image"],
category="conditioning",
version="1.2.0",
classification=Classification.Prototype,
)
class QwenImageTextEncoderInvocation(BaseInvocation):
"""Encodes text and reference images for Qwen Image using Qwen2.5-VL."""
prompt: str = InputField(description="Text prompt describing the desired edit.", ui_component=UIComponent.Textarea)
reference_images: list[ImageField] = InputField(
default=[],
description="Reference images to guide the edit. The model can use multiple reference images.",
)
qwen_vl_encoder: QwenVLEncoderField = InputField(
title="Qwen VL Encoder",
description=FieldDescriptions.qwen_vl_encoder,
input=Input.Connection,
)
quantization: Literal["none", "int8", "nf4"] = InputField(
default="none",
description="Quantize the Qwen VL encoder to reduce VRAM usage. "
"'nf4' (4-bit) saves the most memory, 'int8' (8-bit) is a middle ground.",
)
@staticmethod
def _resize_for_vl_encoder(image: PILImage.Image, target_pixels: int = 512 * 512) -> PILImage.Image:
"""Resize image to fit within target_pixels while preserving aspect ratio.
Matches the diffusers pipeline's calculate_dimensions logic: the image is resized
so its total pixel count is approximately target_pixels, with each dimension rounded
down to a multiple of 32. This prevents large images from producing too many vision tokens
which can overwhelm the text prompt.
"""
w, h = image.size
aspect = w / h
# Compute dimensions that preserve aspect ratio at ~target_pixels total
new_w = int((target_pixels * aspect) ** 0.5)
new_h = int(target_pixels / new_w)
# Round to multiples of 32
new_w = max(32, (new_w // 32) * 32)
new_h = max(32, (new_h // 32) * 32)
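# e.g. a 2048x1536 reference (4:3) becomes 576x416 (~240K px) with the default 512*512 target.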
if new_w != w or new_h != h:
image = image.resize((new_w, new_h), resample=PILImage.LANCZOS)
return image
@torch.no_grad()
def invoke(self, context: InvocationContext) -> QwenImageConditioningOutput:
# Load and resize reference images to the VL encoder's target area (512*512 px by default, following the diffusers resize logic)
pil_images: list[PILImage.Image] = []
for img_field in self.reference_images:
pil_img = context.images.get_pil(img_field.image_name)
pil_img = self._resize_for_vl_encoder(pil_img.convert("RGB"))
pil_images.append(pil_img)
prompt_embeds, prompt_mask = self._encode(context, pil_images)
prompt_embeds = prompt_embeds.detach().to("cpu")
prompt_mask = prompt_mask.detach().to("cpu") if prompt_mask is not None else None
conditioning_data = ConditioningFieldData(
conditionings=[QwenImageConditioningInfo(prompt_embeds=prompt_embeds, prompt_embeds_mask=prompt_mask)]
)
conditioning_name = context.conditioning.save(conditioning_data)
return QwenImageConditioningOutput.build(conditioning_name)
def _encode(
self, context: InvocationContext, images: list[PILImage.Image]
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Encode text prompt and reference images using Qwen2.5-VL.
Matches the diffusers QwenImagePipeline._get_qwen_prompt_embeds logic:
1. Format prompt with the edit-specific system template
2. Run through Qwen2.5-VL to get hidden states
3. Extract valid (non-padding) tokens and drop the system prefix
4. Return padded embeddings + attention mask
"""
from transformers import AutoTokenizer, Qwen2_5_VLProcessor
try:
from transformers import Qwen2_5_VLImageProcessor as _ImageProcessorCls
except ImportError:
from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( # type: ignore[no-redef]
Qwen2VLImageProcessor as _ImageProcessorCls,
)
try:
from transformers import Qwen2_5_VLVideoProcessor as _VideoProcessorCls
except ImportError:
from transformers.models.qwen2_vl.video_processing_qwen2_vl import ( # type: ignore[no-redef]
Qwen2VLVideoProcessor as _VideoProcessorCls,
)
# Format the prompt with one vision placeholder per reference image
text = _build_prompt(self.prompt, len(images))
# Build the processor
tokenizer_config = context.models.get_config(self.qwen_vl_encoder.tokenizer)
model_root = context.models.get_absolute_path(tokenizer_config)
tokenizer_dir = model_root / "tokenizer"
tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_dir), local_files_only=True)
image_processor = None
for search_dir in [model_root / "processor", tokenizer_dir, model_root, model_root / "image_processor"]:
if (search_dir / "preprocessor_config.json").exists():
image_processor = _ImageProcessorCls.from_pretrained(str(search_dir), local_files_only=True)
break
if image_processor is None:
image_processor = _ImageProcessorCls()
processor = Qwen2_5_VLProcessor(
tokenizer=tokenizer,
image_processor=image_processor,
video_processor=_VideoProcessorCls(),
)
context.util.signal_progress("Running Qwen2.5-VL text/vision encoder")
if self.quantization != "none":
text_encoder, device, cleanup = self._load_quantized_encoder(context)
else:
text_encoder, device, cleanup = self._load_cached_encoder(context)
try:
model_inputs = processor(
text=[text],
images=images if images else None,
padding=True,
return_tensors="pt",
).to(device=device)
outputs = text_encoder(
input_ids=model_inputs.input_ids,
attention_mask=model_inputs.attention_mask,
pixel_values=getattr(model_inputs, "pixel_values", None),
image_grid_thw=getattr(model_inputs, "image_grid_thw", None),
output_hidden_states=True,
)
# Use last hidden state (matching diffusers pipeline)
hidden_states = outputs.hidden_states[-1]
# Extract valid (non-padding) tokens using the attention mask,
# then drop the system prompt prefix tokens.
# The drop index differs between edit mode (64) and generate mode (34).
drop_idx = _EDIT_DROP_IDX if images else _GENERATE_DROP_IDX
attn_mask = model_inputs.attention_mask
bool_mask = attn_mask.bool()
valid_lengths = bool_mask.sum(dim=1)
selected = hidden_states[bool_mask]
split_hidden = torch.split(selected, valid_lengths.tolist(), dim=0)
# Drop system prefix tokens and build padded output
trimmed = [h[drop_idx:] for h in split_hidden]
attn_mask_list = [torch.ones(h.size(0), dtype=torch.long, device=device) for h in trimmed]
max_seq_len = max(h.size(0) for h in trimmed)
prompt_embeds = torch.stack(
[torch.cat([h, h.new_zeros(max_seq_len - h.size(0), h.size(1))]) for h in trimmed]
)
encoder_attention_mask = torch.stack(
[torch.cat([m, m.new_zeros(max_seq_len - m.size(0))]) for m in attn_mask_list]
)
prompt_embeds = prompt_embeds.to(dtype=torch.bfloat16)
finally:
if cleanup is not None:
cleanup()
# If all tokens are valid (no padding), mask is not needed
if encoder_attention_mask.all():
encoder_attention_mask = None
return prompt_embeds, encoder_attention_mask
def _load_cached_encoder(self, context: InvocationContext):
"""Load the text encoder through the model cache (no quantization)."""
from transformers import Qwen2_5_VLForConditionalGeneration
text_encoder_info = context.models.load(self.qwen_vl_encoder.text_encoder)
ctx = text_encoder_info.model_on_device()
_, text_encoder = ctx.__enter__()
device = get_effective_device(text_encoder)
assert isinstance(text_encoder, Qwen2_5_VLForConditionalGeneration)
return text_encoder, device, lambda: ctx.__exit__(None, None, None)
def _load_quantized_encoder(self, context: InvocationContext):
"""Load the text encoder with BitsAndBytes quantization, bypassing the model cache.
BnB-quantized models are pinned to GPU and can't be moved between devices,
so they can't go through the standard model cache. The model is loaded fresh
each time and freed after use via the cleanup callback.
"""
import gc
import warnings
from transformers import BitsAndBytesConfig, Qwen2_5_VLForConditionalGeneration
encoder_config = context.models.get_config(self.qwen_vl_encoder.text_encoder)
model_root = context.models.get_absolute_path(encoder_config)
encoder_path = model_root / "text_encoder"
if self.quantization == "nf4":
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_quant_type="nf4",
)
else: # int8
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
context.util.signal_progress("Loading Qwen2.5-VL encoder (quantized)")
with warnings.catch_warnings():
# BnB int8 internally casts bfloat16→float16; the warning is harmless
warnings.filterwarnings("ignore", message="MatMul8bitLt.*cast.*float16")
text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
str(encoder_path),
quantization_config=bnb_config,
device_map="auto",
torch_dtype=torch.bfloat16,
local_files_only=True,
)
device = next(text_encoder.parameters()).device
def cleanup():
nonlocal text_encoder
del text_encoder
gc.collect()
torch.cuda.empty_cache()
return text_encoder, device, cleanup
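
The prompt-embedding post-processing above (mask out padding, drop the system-prompt prefix, re-pad to a common length) can be sketched in isolation. A minimal standalone example with toy tensors, assuming only `torch`; the drop index of 2 is arbitrary and stands in for the 34/64 constants referenced above:

```
import torch

def trim_and_pad(hidden_states: torch.Tensor, attention_mask: torch.Tensor, drop_idx: int):
    """Select non-padding tokens, drop the first `drop_idx` prefix tokens per sample,
    then right-pad all samples to the longest remaining length."""
    bool_mask = attention_mask.bool()
    valid_lengths = bool_mask.sum(dim=1)
    # Flatten valid tokens across the batch, then split back per sample
    selected = hidden_states[bool_mask]
    per_sample = torch.split(selected, valid_lengths.tolist(), dim=0)
    trimmed = [h[drop_idx:] for h in per_sample]
    max_len = max(h.size(0) for h in trimmed)
    embeds = torch.stack(
        [torch.cat([h, h.new_zeros(max_len - h.size(0), h.size(1))]) for h in trimmed]
    )
    mask = torch.stack(
        [torch.cat([torch.ones(h.size(0), dtype=torch.long), torch.zeros(max_len - h.size(0), dtype=torch.long)])
         for h in trimmed]
    )
    return embeds, mask

# Toy example: batch of 2, seq len 6, hidden size 4, prefix of 2 tokens to drop
hs = torch.randn(2, 6, 4)
am = torch.tensor([[1, 1, 1, 1, 1, 0], [1, 1, 1, 0, 0, 0]])
embeds, mask = trim_and_pad(hs, am, drop_idx=2)
print(embeds.shape, mask)  # torch.Size([2, 3, 4]) and a right-padded mask
```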

View File

@@ -34,7 +34,7 @@ from invokeai.backend.util.devices import TorchDevice
"sd3_denoise",
title="Denoise - SD3",
tags=["image", "sd3"],
category="image",
category="latents",
version="1.1.1",
)
class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -24,7 +24,7 @@ from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory
"sd3_i2l",
title="Image to Latents - SD3",
tags=["image", "latents", "vae", "i2l", "sd3"],
category="image",
category="latents",
version="1.0.1",
)
class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -16,6 +16,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import SD3ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
from invokeai.backend.model_manager.taxonomy import ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
@@ -30,7 +31,7 @@ SD3_T5_MAX_SEQ_LEN = 256
"sd3_text_encoder",
title="Prompt - SD3",
tags=["prompt", "conditioning", "sd3"],
category="conditioning",
category="prompt",
version="1.0.1",
)
class Sd3TextEncoderInvocation(BaseInvocation):
@@ -103,6 +104,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
context.util.signal_progress("Running T5 encoder")
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))
t5_device = get_effective_device(t5_text_encoder)
text_inputs = t5_tokenizer(
prompt,
@@ -125,7 +127,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
f" {max_seq_len} tokens: {removed_text}"
)
prompt_embeds = t5_text_encoder(text_input_ids.to(t5_text_encoder.device))[0]
prompt_embeds = t5_text_encoder(text_input_ids.to(t5_device))[0]
assert isinstance(prompt_embeds, torch.Tensor)
return prompt_embeds
@@ -144,6 +146,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
context.util.signal_progress("Running CLIP encoder")
assert isinstance(clip_text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
assert isinstance(clip_tokenizer, CLIPTokenizer)
clip_device = get_effective_device(clip_text_encoder)
clip_text_encoder_config = clip_text_encoder_info.config
assert clip_text_encoder_config is not None
@@ -187,9 +190,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = clip_text_encoder(
input_ids=text_input_ids.to(clip_text_encoder.device), output_hidden_states=True
)
prompt_embeds = clip_text_encoder(input_ids=text_input_ids.to(clip_device), output_hidden_states=True)
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]

View File

@@ -20,7 +20,7 @@ class StringPosNegOutput(BaseInvocationOutput):
"string_split_neg",
title="String Split Negative",
tags=["string", "split", "negative"],
category="string",
category="strings",
version="1.0.1",
)
class StringSplitNegInvocation(BaseInvocation):
@@ -63,7 +63,7 @@ class String2Output(BaseInvocationOutput):
string_2: str = OutputField(description="string 2")
@invocation("string_split", title="String Split", tags=["string", "split"], category="string", version="1.0.1")
@invocation("string_split", title="String Split", tags=["string", "split"], category="strings", version="1.0.1")
class StringSplitInvocation(BaseInvocation):
"""Splits string into two strings, based on the first occurance of the delimiter. The delimiter will be removed from the string"""
@@ -83,7 +83,7 @@ class StringSplitInvocation(BaseInvocation):
return String2Output(string_1=part1, string_2=part2)
@invocation("string_join", title="String Join", tags=["string", "join"], category="string", version="1.0.1")
@invocation("string_join", title="String Join", tags=["string", "join"], category="strings", version="1.0.1")
class StringJoinInvocation(BaseInvocation):
"""Joins string left to string right"""
@@ -94,7 +94,9 @@ class StringJoinInvocation(BaseInvocation):
return StringOutput(value=((self.string_left or "") + (self.string_right or "")))
@invocation("string_join_three", title="String Join Three", tags=["string", "join"], category="string", version="1.0.1")
@invocation(
"string_join_three", title="String Join Three", tags=["string", "join"], category="strings", version="1.0.1"
)
class StringJoinThreeInvocation(BaseInvocation):
"""Joins string left to string middle to string right"""
@@ -107,7 +109,7 @@ class StringJoinThreeInvocation(BaseInvocation):
@invocation(
"string_replace", title="String Replace", tags=["string", "replace", "regex"], category="string", version="1.0.1"
"string_replace", title="String Replace", tags=["string", "replace", "regex"], category="strings", version="1.0.1"
)
class StringReplaceInvocation(BaseInvocation):
"""Replaces the search string with the replace string"""

View File

@@ -49,7 +49,7 @@ class T2IAdapterOutput(BaseInvocationOutput):
"t2i_adapter",
title="T2I-Adapter - SD1.5, SDXL",
tags=["t2i_adapter", "control"],
category="t2i_adapter",
category="conditioning",
version="1.0.4",
)
class T2IAdapterInvocation(BaseInvocation):

View File

@@ -30,7 +30,7 @@ ESRGAN_MODEL_URLS: dict[str, str] = {
}
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.2")
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="upscale", version="1.3.2")
class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Upscales an image using RealESRGAN."""

View File

@@ -57,7 +57,7 @@ class ZImageControlOutput(BaseInvocationOutput):
"z_image_control",
title="Z-Image ControlNet",
tags=["image", "z-image", "control", "controlnet"],
category="control",
category="conditioning",
version="1.1.0",
classification=Classification.Prototype,
)

View File

@@ -49,8 +49,8 @@ from invokeai.backend.z_image.z_image_transformer_patch import patch_transformer
"z_image_denoise",
title="Denoise - Z-Image",
tags=["image", "z-image"],
category="image",
version="1.4.0",
category="latents",
version="1.5.0",
classification=Classification.Prototype,
)
class ZImageDenoiseInvocation(BaseInvocation):
@@ -104,6 +104,15 @@ class ZImageDenoiseInvocation(BaseInvocation):
description=FieldDescriptions.vae + " Required for control conditioning.",
input=Input.Connection,
)
# Shift override for the sigma schedule. If None, shift is auto-calculated from image dimensions.
shift: Optional[float] = InputField(
default=None,
ge=0.0,
description="Override the timestep shift (mu) for the sigma schedule. "
"Leave blank to auto-calculate based on image dimensions (recommended). "
"Lower values (~0.5) produce less noise shifting, higher values (~1.15) produce more.",
title="Shift",
)
# Scheduler selection for the denoising process
scheduler: ZIMAGE_SCHEDULER_NAME_VALUES = InputField(
default="euler",
@@ -225,34 +234,36 @@ class ZImageDenoiseInvocation(BaseInvocation):
"""Calculate timestep shift based on image sequence length.
Based on diffusers ZImagePipeline.calculate_shift method.
"""
m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
b = base_shift - m * base_image_seq_len
mu = image_seq_len * m + b
return mu
def _get_sigmas(self, mu: float, num_steps: int) -> list[float]:
"""Generate sigma schedule with time shift.
Based on FlowMatchEulerDiscreteScheduler with shift.
Generates num_steps + 1 sigma values (including terminal 0.0).
Returns a linear shift value (exp(mu) from the original formula).
"""
import math
def time_shift(mu: float, sigma: float, t: float) -> float:
"""Apply time shift to a single timestep value."""
m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
b = base_shift - m * base_image_seq_len
mu = image_seq_len * m + b
# Convert from exponential mu to linear shift value
return math.exp(mu)
def _get_sigmas(self, shift: float, num_steps: int) -> list[float]:
"""Generate sigma schedule with linear time shift.
Uses linear time shift: shift / (shift + (1/t - 1)).
The shift value is used directly as a multiplier.
Generates num_steps + 1 sigma values (including terminal 0.0).
"""
def time_shift(shift: float, t: float) -> float:
"""Apply linear time shift to a single timestep value."""
if t <= 0:
return 0.0
if t >= 1:
return 1.0
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
return shift / (shift + (1 / t - 1))
        # Generate linearly spaced values from 1.0 to 0.0 (endpoints are clamped inside time_shift),
        # then apply the time shift
sigmas = []
for i in range(num_steps + 1):
t = 1.0 - i / num_steps # Goes from 1.0 to 0.0
sigma = time_shift(mu, 1.0, t)
sigma = time_shift(shift, t)
sigmas.append(sigma)
return sigmas
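
The change above replaces the exponential mu formulation with a linear time shift, sigma = shift / (shift + (1/t - 1)). A minimal standalone sketch of the two helpers; the base/max shift values and sequence-length bounds in the example call are illustrative assumptions, not values taken from this diff:

```
import math

def calculate_shift(image_seq_len: int, base_image_seq_len: int, max_image_seq_len: int,
                    base_shift: float, max_shift: float) -> float:
    # Interpolate mu from the image sequence length, then return the linear shift exp(mu)
    m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
    mu = image_seq_len * m + (base_shift - m * base_image_seq_len)
    return math.exp(mu)

def get_sigmas(shift: float, num_steps: int) -> list[float]:
    # sigma(t) = shift / (shift + (1/t - 1)), evaluated on t from 1.0 down to 0.0
    sigmas = []
    for i in range(num_steps + 1):
        t = 1.0 - i / num_steps
        if t <= 0:
            sigmas.append(0.0)
        elif t >= 1:
            sigmas.append(1.0)
        else:
            sigmas.append(shift / (shift + (1 / t - 1)))
    return sigmas

# Example values only (the node derives these from model config and image size)
shift = calculate_shift(image_seq_len=1024, base_image_seq_len=256, max_image_seq_len=4096,
                        base_shift=0.5, max_shift=1.15)
print(round(shift, 3))                               # ~1.878 with these example inputs
print([round(s, 3) for s in get_sigmas(shift, 4)])   # ~[1.0, 0.849, 0.652, 0.385, 0.0]
```

Larger shift values push more of the schedule toward high-noise timesteps, which is what the optional `shift` override on the node exposes directly.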
@@ -313,11 +324,14 @@ class ZImageDenoiseInvocation(BaseInvocation):
# Concatenate all negative embeddings
neg_prompt_embeds = torch.cat([tc.prompt_embeds for tc in neg_text_conditionings], dim=0)
# Calculate shift based on image sequence length
mu = self._calculate_shift(img_seq_len)
# Calculate shift based on image sequence length, or use override
if self.shift is not None:
shift = self.shift
else:
shift = self._calculate_shift(img_seq_len)
# Generate sigma schedule with time shift
sigmas = self._get_sigmas(mu, self.steps)
sigmas = self._get_sigmas(shift, self.steps)
# Apply denoising_start and denoising_end clipping
if self.denoising_start > 0 or self.denoising_end < 1:

View File

@@ -30,7 +30,7 @@ ZImageVAE = Union[AutoencoderKL, FluxAutoEncoder]
"z_image_i2l",
title="Image to Latents - Z-Image",
tags=["image", "latents", "vae", "i2l", "z-image"],
category="image",
category="latents",
version="1.1.0",
classification=Classification.Prototype,
)

View File

@@ -19,7 +19,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
"z_image_seed_variance_enhancer",
title="Seed Variance Enhancer - Z-Image",
tags=["conditioning", "z-image", "variance", "seed"],
category="conditioning",
category="prompt",
version="1.0.0",
classification=Classification.Prototype,
)

View File

@@ -16,6 +16,7 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import Qwen3EncoderField
from invokeai.app.invocations.primitives import ZImageConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.z_image_lora_constants import Z_IMAGE_LORA_QWEN3_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -33,7 +34,7 @@ Z_IMAGE_MAX_SEQ_LEN = 512
"z_image_text_encoder",
title="Prompt - Z-Image",
tags=["prompt", "conditioning", "z-image"],
category="conditioning",
category="prompt",
version="1.1.0",
classification=Classification.Prototype,
)
@@ -76,11 +77,17 @@ class ZImageTextEncoderInvocation(BaseInvocation):
tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer)
with ExitStack() as exit_stack:
(_, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())
(cached_weights, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())
(_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())
# Use the device that the text_encoder is actually on
device = text_encoder.device
# Use the device that the text encoder is effectively executing on, and repair any required tensors left on
# the CPU by a previous interrupted run.
repaired_tensors = text_encoder_info.repair_required_tensors_on_device()
device = get_effective_device(text_encoder)
if repaired_tensors > 0:
context.logger.warning(
f"Recovered {repaired_tensors} required Qwen3 tensor(s) onto {device} after a partial device mismatch."
)
# Apply LoRA models to the text encoder
lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
@@ -90,6 +97,7 @@ class ZImageTextEncoderInvocation(BaseInvocation):
patches=self._lora_iterator(context),
prefix=Z_IMAGE_LORA_QWEN3_PREFIX,
dtype=lora_dtype,
cached_weights=cached_weights,
)
)

View File

@@ -21,6 +21,7 @@ class TokenData(BaseModel):
user_id: str
email: str
is_admin: bool
remember_me: bool = False
def set_jwt_secret(secret: str) -> None:

View File

@@ -9,6 +9,17 @@ from invokeai.app.util.misc import get_iso_timestamp
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
class BoardVisibility(str, Enum, metaclass=MetaEnum):
"""The visibility options for a board."""
Private = "private"
"""Only the board owner (and admins) can see and modify this board."""
Shared = "shared"
"""All users can view this board, but only the owner (and admins) can modify it."""
Public = "public"
"""All users can view this board; only the owner (and admins) can modify its structure."""
class BoardRecord(BaseModelExcludeNull):
"""Deserialized board record."""
@@ -28,6 +39,10 @@ class BoardRecord(BaseModelExcludeNull):
"""The name of the cover image of the board."""
archived: bool = Field(description="Whether or not the board is archived.")
"""Whether or not the board is archived."""
board_visibility: BoardVisibility = Field(
default=BoardVisibility.Private, description="The visibility of the board."
)
"""The visibility of the board (private, shared, or public)."""
def deserialize_board_record(board_dict: dict) -> BoardRecord:
@@ -44,6 +59,11 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
updated_at = board_dict.get("updated_at", get_iso_timestamp())
deleted_at = board_dict.get("deleted_at", get_iso_timestamp())
archived = board_dict.get("archived", False)
board_visibility_raw = board_dict.get("board_visibility", BoardVisibility.Private.value)
try:
board_visibility = BoardVisibility(board_visibility_raw)
except ValueError:
board_visibility = BoardVisibility.Private
return BoardRecord(
board_id=board_id,
@@ -54,6 +74,7 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
updated_at=updated_at,
deleted_at=deleted_at,
archived=archived,
board_visibility=board_visibility,
)
@@ -61,6 +82,7 @@ class BoardChanges(BaseModel, extra="forbid"):
board_name: Optional[str] = Field(default=None, description="The board's new name.", max_length=300)
cover_image_name: Optional[str] = Field(default=None, description="The name of the board's new cover image.")
archived: Optional[bool] = Field(default=None, description="Whether or not the board is archived")
board_visibility: Optional[BoardVisibility] = Field(default=None, description="The visibility of the board.")
class BoardRecordOrderBy(str, Enum, metaclass=MetaEnum):

View File

@@ -116,6 +116,17 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
(changes.archived, board_id),
)
# Change the visibility of a board
if changes.board_visibility is not None:
cursor.execute(
"""--sql
UPDATE boards
SET board_visibility = ?
WHERE board_id = ?;
""",
(changes.board_visibility.value, board_id),
)
except sqlite3.Error as e:
raise BoardRecordSaveException from e
return self.get(board_id)
@@ -155,7 +166,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
SELECT DISTINCT boards.*
FROM boards
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1)
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public'))
{archived_filter}
ORDER BY {order_by} {direction}
LIMIT ? OFFSET ?;
@@ -194,14 +205,14 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
SELECT COUNT(DISTINCT boards.board_id)
FROM boards
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1);
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public'));
"""
else:
count_query = """
SELECT COUNT(DISTINCT boards.board_id)
FROM boards
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1)
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public'))
AND boards.archived = 0;
"""
@@ -251,7 +262,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
SELECT DISTINCT boards.*
FROM boards
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1)
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public'))
{archived_filter}
ORDER BY LOWER(boards.board_name) {direction}
"""
@@ -260,7 +271,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
SELECT DISTINCT boards.*
FROM boards
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1)
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public'))
{archived_filter}
ORDER BY {order_by} {direction}
"""

View File

@@ -7,7 +7,11 @@ class BulkDownloadBase(ABC):
@abstractmethod
def handler(
self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str]
self,
image_names: Optional[list[str]],
board_id: Optional[str],
bulk_download_item_id: Optional[str],
user_id: str = "system",
) -> None:
"""
Create a zip file containing the images specified by the given image names or board id.
@@ -15,6 +19,7 @@ class BulkDownloadBase(ABC):
:param image_names: A list of image names to include in the zip file.
:param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
:param bulk_download_item_id: The bulk_download_item_id that will be used to retrieve the bulk download item when it is prepared, if none is provided a uuid will be generated.
:param user_id: The ID of the user who initiated the download.
"""
@abstractmethod
@@ -42,3 +47,12 @@ class BulkDownloadBase(ABC):
:param bulk_download_item_name: The name of the bulk download item.
"""
@abstractmethod
def get_owner(self, bulk_download_item_name: str) -> Optional[str]:
"""
Get the user_id of the user who initiated the download.
:param bulk_download_item_name: The name of the bulk download item.
:return: The user_id of the owner, or None if not tracked.
"""

View File

@@ -25,15 +25,24 @@ class BulkDownloadService(BulkDownloadBase):
self._temp_directory = TemporaryDirectory()
self._bulk_downloads_folder = Path(self._temp_directory.name) / "bulk_downloads"
self._bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
# Track which user owns each download so the fetch endpoint can enforce ownership
self._download_owners: dict[str, str] = {}
def handler(
self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str]
self,
image_names: Optional[list[str]],
board_id: Optional[str],
bulk_download_item_id: Optional[str],
user_id: str = "system",
) -> None:
bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID
bulk_download_item_id = bulk_download_item_id or uuid_string()
bulk_download_item_name = bulk_download_item_id + ".zip"
self._signal_job_started(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
# Record ownership so the fetch endpoint can verify the caller
self._download_owners[bulk_download_item_name] = user_id
self._signal_job_started(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id)
try:
image_dtos: list[ImageDTO] = []
@@ -46,16 +55,16 @@ class BulkDownloadService(BulkDownloadBase):
raise BulkDownloadParametersException()
bulk_download_item_name: str = self._create_zip_file(image_dtos, bulk_download_item_id)
self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id)
except (
ImageRecordNotFoundException,
BoardRecordNotFoundException,
BulkDownloadException,
BulkDownloadParametersException,
) as e:
self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e)
self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e, user_id)
except Exception as e:
self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e)
self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e, user_id)
self._invoker.services.logger.error("Problem bulk downloading images.")
raise e
@@ -103,43 +112,60 @@ class BulkDownloadService(BulkDownloadBase):
return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " " or c == "_" or c == "-"]).rstrip()
def _signal_job_started(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
self,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
user_id: str = "system",
) -> None:
"""Signal that a bulk download job has started."""
if self._invoker:
assert bulk_download_id is not None
self._invoker.services.events.emit_bulk_download_started(
bulk_download_id, bulk_download_item_id, bulk_download_item_name
bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id=user_id
)
def _signal_job_completed(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
self,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
user_id: str = "system",
) -> None:
"""Signal that a bulk download job has completed."""
if self._invoker:
assert bulk_download_id is not None
assert bulk_download_item_name is not None
self._invoker.services.events.emit_bulk_download_complete(
bulk_download_id, bulk_download_item_id, bulk_download_item_name
bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id=user_id
)
def _signal_job_failed(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, exception: Exception
self,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
exception: Exception,
user_id: str = "system",
) -> None:
"""Signal that a bulk download job has failed."""
if self._invoker:
assert bulk_download_id is not None
assert exception is not None
self._invoker.services.events.emit_bulk_download_error(
bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception)
bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception), user_id=user_id
)
def stop(self, *args, **kwargs):
self._temp_directory.cleanup()
def get_owner(self, bulk_download_item_name: str) -> Optional[str]:
return self._download_owners.get(bulk_download_item_name)
def delete(self, bulk_download_item_name: str) -> None:
path = self.get_path(bulk_download_item_name)
Path(path).unlink()
self._download_owners.pop(bulk_download_item_name, None)
def get_path(self, bulk_download_item_name: str) -> str:
path = str(self._bulk_downloads_folder / bulk_download_item_name)
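
The `get_owner` hook above exists so the fetch endpoint can verify the caller before serving a zip. A minimal sketch of that check; the function shape and the 403 response are illustrative assumptions, not the actual router code:

```
from fastapi import HTTPException

def check_bulk_download_access(bulk_download_service, bulk_download_item_name: str, caller_user_id: str) -> str:
    """Return the path to the zip if the caller owns it (or ownership is untracked); otherwise raise 403."""
    owner = bulk_download_service.get_owner(bulk_download_item_name)
    if owner is not None and owner != caller_user_id:
        raise HTTPException(status_code=403, detail="Not authorized to fetch this bulk download")
    return bulk_download_service.get_path(bulk_download_item_name)
```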

View File

@@ -110,7 +110,8 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
max_queue_size: Maximum number of items in the session queue.
clear_queue_on_startup: Empties session queue on startup.
clear_queue_on_startup: Empties session queue on startup. If true, disables `max_queue_history`.
max_queue_history: Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true.
allow_nodes: List of nodes to allow. Omit to allow all.
deny_nodes: List of nodes to deny. Omit to deny none.
node_cache_size: How many cached nodes to keep in memory.
@@ -204,7 +205,8 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.")
clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup. If true, disables `max_queue_history`.")
max_queue_history: Optional[int] = Field(default=None, ge=0, description="Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true.")
# NODES
allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.")

View File

@@ -100,9 +100,9 @@ class EventServiceBase:
"""Emitted when a queue item's status changes"""
self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status))
def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None:
def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult", user_id: str = "system") -> None:
"""Emitted when a batch is enqueued"""
self.dispatch(BatchEnqueuedEvent.build(enqueue_result))
self.dispatch(BatchEnqueuedEvent.build(enqueue_result, user_id))
def emit_queue_items_retried(self, retry_result: "RetryItemsResult") -> None:
"""Emitted when a list of queue items are retried"""
@@ -112,9 +112,9 @@ class EventServiceBase:
"""Emitted when a queue is cleared"""
self.dispatch(QueueClearedEvent.build(queue_id))
def emit_recall_parameters_updated(self, queue_id: str, parameters: dict) -> None:
def emit_recall_parameters_updated(self, queue_id: str, user_id: str, parameters: dict) -> None:
"""Emitted when recall parameters are updated"""
self.dispatch(RecallParametersUpdatedEvent.build(queue_id, parameters))
self.dispatch(RecallParametersUpdatedEvent.build(queue_id, user_id, parameters))
# endregion
@@ -194,23 +194,42 @@ class EventServiceBase:
# region Bulk image download
def emit_bulk_download_started(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
self,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
user_id: str = "system",
) -> None:
"""Emitted when a bulk image download is started"""
self.dispatch(BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
self.dispatch(
BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id)
)
def emit_bulk_download_complete(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
self,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
user_id: str = "system",
) -> None:
"""Emitted when a bulk image download is complete"""
self.dispatch(BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
self.dispatch(
BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id)
)
def emit_bulk_download_error(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
self,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
error: str,
user_id: str = "system",
) -> None:
"""Emitted when a bulk image download has an error"""
self.dispatch(
BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error)
BulkDownloadErrorEvent.build(
bulk_download_id, bulk_download_item_id, bulk_download_item_name, error, user_id
)
)
# endregion

View File

@@ -281,9 +281,10 @@ class BatchEnqueuedEvent(QueueEventBase):
)
priority: int = Field(description="The priority of the batch")
origin: str | None = Field(default=None, description="The origin of the batch")
user_id: str = Field(default="system", description="The ID of the user who enqueued the batch")
@classmethod
def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
def build(cls, enqueue_result: EnqueueBatchResult, user_id: str = "system") -> "BatchEnqueuedEvent":
return cls(
queue_id=enqueue_result.queue_id,
batch_id=enqueue_result.batch.batch_id,
@@ -291,6 +292,7 @@ class BatchEnqueuedEvent(QueueEventBase):
enqueued=enqueue_result.enqueued,
requested=enqueue_result.requested,
priority=enqueue_result.priority,
user_id=user_id,
)
@@ -609,6 +611,7 @@ class BulkDownloadEventBase(EventBase):
bulk_download_id: str = Field(description="The ID of the bulk image download")
bulk_download_item_id: str = Field(description="The ID of the bulk image download item")
bulk_download_item_name: str = Field(description="The name of the bulk image download item")
user_id: str = Field(default="system", description="The ID of the user who initiated the download")
@payload_schema.register
@@ -619,12 +622,17 @@ class BulkDownloadStartedEvent(BulkDownloadEventBase):
@classmethod
def build(
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
cls,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
user_id: str = "system",
) -> "BulkDownloadStartedEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
user_id=user_id,
)
@@ -636,12 +644,17 @@ class BulkDownloadCompleteEvent(BulkDownloadEventBase):
@classmethod
def build(
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
cls,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
user_id: str = "system",
) -> "BulkDownloadCompleteEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
user_id=user_id,
)
@@ -655,13 +668,19 @@ class BulkDownloadErrorEvent(BulkDownloadEventBase):
@classmethod
def build(
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
cls,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
error: str,
user_id: str = "system",
) -> "BulkDownloadErrorEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
error=error,
user_id=user_id,
)
@@ -671,8 +690,9 @@ class RecallParametersUpdatedEvent(QueueEventBase):
__event_name__ = "recall_parameters_updated"
user_id: str = Field(description="The ID of the user whose recall parameters were updated")
parameters: dict[str, Any] = Field(description="The recall parameters that were updated")
@classmethod
def build(cls, queue_id: str, parameters: dict[str, Any]) -> "RecallParametersUpdatedEvent":
return cls(queue_id=queue_id, parameters=parameters)
def build(cls, queue_id: str, user_id: str, parameters: dict[str, Any]) -> "RecallParametersUpdatedEvent":
return cls(queue_id=queue_id, user_id=user_id, parameters=parameters)

View File

@@ -46,3 +46,9 @@ class FastAPIEventService(EventServiceBase):
except asyncio.CancelledError as e:
raise e # Raise a proper error
except Exception:
import logging
logging.getLogger("InvokeAI").error(
f"Error dispatching event {getattr(event, '__event_name__', event)}", exc_info=True
)

View File

@@ -3,9 +3,9 @@ from invokeai.app.services.external_generation.external_generation_base import (
ExternalProvider,
)
from invokeai.app.services.external_generation.external_generation_common import (
ExternalGeneratedImage,
ExternalGenerationRequest,
ExternalGenerationResult,
ExternalGeneratedImage,
ExternalProviderStatus,
ExternalReferenceImage,
)

View File

@@ -16,3 +16,13 @@ class ExternalProviderCapabilityError(ExternalGenerationError):
class ExternalProviderRequestError(ExternalGenerationError):
"""Raised when a provider rejects the request or returns an error."""
class ExternalProviderRateLimitError(ExternalProviderRequestError):
"""Raised when a provider returns HTTP 429 (rate limit exceeded)."""
retry_after: float | None
def __init__(self, message: str, retry_after: float | None = None) -> None:
super().__init__(message)
self.retry_after = retry_after

View File

@@ -11,8 +11,6 @@ from invokeai.backend.model_manager.configs.external_api import ExternalApiModel
@dataclass(frozen=True)
class ExternalReferenceImage:
image: PILImageType
weight: float | None = None
mode: str | None = None
@dataclass(frozen=True)
@@ -20,17 +18,16 @@ class ExternalGenerationRequest:
model: ExternalApiModelConfig
mode: ExternalGenerationMode
prompt: str
negative_prompt: str | None
seed: int | None
num_images: int
width: int
height: int
steps: int | None
guidance: float | None
image_size: str | None
init_image: PILImageType | None
mask_image: PILImageType | None
reference_images: list[ExternalReferenceImage]
metadata: dict[str, Any] | None
provider_options: dict[str, Any] | None = None
@dataclass(frozen=True)

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import time
from logging import Logger
from typing import TYPE_CHECKING
@@ -10,6 +11,7 @@ from invokeai.app.services.external_generation.errors import (
ExternalProviderCapabilityError,
ExternalProviderNotConfiguredError,
ExternalProviderNotFoundError,
ExternalProviderRateLimitError,
)
from invokeai.app.services.external_generation.external_generation_base import (
ExternalGenerationServiceBase,
@@ -52,7 +54,7 @@ class ExternalGenerationService(ExternalGenerationServiceBase):
request = self._bucket_request(request)
self._validate_request(request)
result = provider.generate(request)
result = self._generate_with_retry(provider, request)
if resize_to_original_inpaint_size is None:
return result
@@ -60,6 +62,30 @@ class ExternalGenerationService(ExternalGenerationServiceBase):
width, height = resize_to_original_inpaint_size
return _resize_result_images(result, width, height)
_MAX_RETRIES = 3
_DEFAULT_RETRY_DELAY = 10.0
_MAX_RETRY_DELAY = 60.0
def _generate_with_retry(
self, provider: ExternalProvider, request: ExternalGenerationRequest
) -> ExternalGenerationResult:
for attempt in range(self._MAX_RETRIES):
try:
return provider.generate(request)
except ExternalProviderRateLimitError as exc:
if attempt == self._MAX_RETRIES - 1:
raise
delay = min(exc.retry_after or self._DEFAULT_RETRY_DELAY, self._MAX_RETRY_DELAY)
self._logger.warning(
"Rate limited by %s (attempt %d/%d), retrying in %.0fs",
request.model.provider_id,
attempt + 1,
self._MAX_RETRIES,
delay,
)
time.sleep(delay)
raise ExternalProviderRateLimitError("Rate limit exceeded after all retries")
def get_provider_statuses(self) -> dict[str, ExternalProviderStatus]:
return {provider_id: provider.get_status() for provider_id, provider in self._providers.items()}
@@ -77,15 +103,9 @@ class ExternalGenerationService(ExternalGenerationServiceBase):
if request.mode not in capabilities.modes:
raise ExternalProviderCapabilityError(f"Mode '{request.mode}' is not supported by {request.model.name}")
if request.negative_prompt and not capabilities.supports_negative_prompt:
raise ExternalProviderCapabilityError(f"Negative prompts are not supported by {request.model.name}")
if request.seed is not None and not capabilities.supports_seed:
raise ExternalProviderCapabilityError(f"Seed control is not supported by {request.model.name}")
if request.guidance is not None and not capabilities.supports_guidance:
raise ExternalProviderCapabilityError(f"Guidance is not supported by {request.model.name}")
if request.reference_images and not capabilities.supports_reference_images:
raise ExternalProviderCapabilityError(f"Reference images are not supported by {request.model.name}")
@@ -159,17 +179,16 @@ class ExternalGenerationService(ExternalGenerationServiceBase):
model=record,
mode=request.mode,
prompt=request.prompt,
negative_prompt=request.negative_prompt,
seed=request.seed,
num_images=request.num_images,
width=request.width,
height=request.height,
steps=request.steps,
guidance=request.guidance,
image_size=request.image_size,
init_image=request.init_image,
mask_image=request.mask_image,
reference_images=request.reference_images,
metadata=request.metadata,
provider_options=request.provider_options,
)
def _bucket_request(self, request: ExternalGenerationRequest) -> ExternalGenerationRequest:
@@ -229,17 +248,16 @@ class ExternalGenerationService(ExternalGenerationServiceBase):
model=request.model,
mode=request.mode,
prompt=request.prompt,
negative_prompt=request.negative_prompt,
seed=request.seed,
num_images=request.num_images,
width=width,
height=height,
steps=request.steps,
guidance=request.guidance,
image_size=request.image_size,
init_image=_resize_image(request.init_image, width, height, "RGB"),
mask_image=_resize_image(request.mask_image, width, height, "L"),
reference_images=request.reference_images,
metadata=request.metadata,
provider_options=request.provider_options,
)
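
The retry loop above honors the `retry_after` value carried by `ExternalProviderRateLimitError`. A minimal sketch of the provider side of that contract; the `response` object stands in for whichever HTTP client a provider actually uses, and the header parsing is an illustrative assumption:

```
from invokeai.app.services.external_generation.errors import ExternalProviderRateLimitError

def raise_for_rate_limit(response) -> None:
    """Translate an HTTP 429 into the rate-limit error the service-level retry loop understands."""
    if response.status_code != 429:
        return
    retry_after: float | None = None
    header = response.headers.get("Retry-After")
    if header is not None:
        try:
            retry_after = float(header)  # seconds form; an HTTP-date falls back to the service default delay
        except ValueError:
            retry_after = None
    raise ExternalProviderRateLimitError("Provider rate limit exceeded", retry_after=retry_after)
```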

Some files were not shown because too many files have changed in this diff.