mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Compare commits
93 Commits
v6.12.0rc1
...
external-m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7938d840b2 | ||
|
|
450ba7b7e1 | ||
|
|
3fc981f4b6 | ||
|
|
e252a5bb47 | ||
|
|
ce896678d7 | ||
|
|
37ff6c3743 | ||
|
|
c743106f66 | ||
|
|
1b50c1a79c | ||
|
|
acd4157bdf | ||
|
|
06a1881bbd | ||
|
|
441821ca03 | ||
|
|
9d62bfdf8e | ||
|
|
cd888654d5 | ||
|
|
dd056067a9 | ||
|
|
ec4b87b949 | ||
|
|
8f00759af0 | ||
|
|
5c09c823a9 | ||
|
|
33ec16deb4 | ||
|
|
b42274a57e | ||
|
|
ec90b2fbe9 | ||
|
|
17157d7c60 | ||
|
|
a3507121da | ||
|
|
3c9b282a90 | ||
|
|
a2e4fbb9b5 | ||
|
|
06eff38354 | ||
|
|
d4104be0b8 | ||
|
|
ee600973ed | ||
|
|
d4c0e631e2 | ||
|
|
5f35d0e432 | ||
|
|
f0d09c34a8 | ||
|
|
853c3ef915 | ||
|
|
60d0bcdbc1 | ||
|
|
80be1b7282 | ||
|
|
dbbf28925b | ||
|
|
f08b802968 | ||
|
|
ae42182246 | ||
|
|
3e9e052d5d | ||
|
|
089e2db402 | ||
|
|
4cbd60b4a5 | ||
|
|
c2016bcfb7 | ||
|
|
32002bd37e | ||
|
|
e6f2980d7c | ||
|
|
01c67c5468 | ||
|
|
be015a5434 | ||
|
|
82f3dc9032 | ||
|
|
471ab9d9c0 | ||
|
|
41a542552e | ||
|
|
5596fa0cc8 | ||
|
|
05f4deb68c | ||
|
|
474d85e5e0 | ||
|
|
ed268b1cfc | ||
|
|
6963cd97ba | ||
|
|
813a5e2c2e | ||
|
|
18315db7f0 | ||
|
|
edde0b4737 | ||
|
|
ab6f186f8c | ||
|
|
7f2878f691 | ||
|
|
d32f6b5a56 | ||
|
|
f7aa5fcbbf | ||
|
|
438515bf9a | ||
|
|
27fc650f4f | ||
|
|
a1eef791a1 | ||
|
|
d8d0ebc356 | ||
|
|
8375f95ea9 | ||
|
|
9e4d0bb191 | ||
|
|
20a400cee8 | ||
|
|
40f02aa6c4 | ||
|
|
c3a482e80a | ||
|
|
257994f552 | ||
|
|
bafce41856 | ||
|
|
757bd3d002 | ||
|
|
519575e871 | ||
|
|
17da6bb9c3 | ||
|
|
b120ef5183 | ||
|
|
dc5007fe95 | ||
|
|
f39456e6f0 | ||
|
|
bba207a856 | ||
|
|
a7b367fda2 | ||
|
|
cd47b3baf7 | ||
|
|
689725c6e4 | ||
|
|
10729f40f2 | ||
|
|
362054120e | ||
|
|
b91a156a3d | ||
|
|
c6b0d45c5f | ||
|
|
dc665e08ac | ||
|
|
0dd72837d3 | ||
|
|
d5a6283f23 | ||
|
|
6fe1a6f1ac | ||
|
|
5d34eab6f0 | ||
|
|
1b43769b95 | ||
|
|
a9d3b4e17c | ||
|
|
74ecc461b9 | ||
|
|
19650f6ada |
13
Makefile
13
Makefile
@@ -12,12 +12,13 @@ help:
|
||||
@echo "mypy-all Run mypy ignoring the config in pyproject.toml but still ignoring missing imports"
|
||||
@echo "test Run the unit tests."
|
||||
@echo "update-config-docstring Update the app's config docstring so mkdocs can autogenerate it correctly."
|
||||
@echo "frontend-install Install the pnpm modules needed for the front end"
|
||||
@echo "frontend-build Build the frontend in order to run on localhost:9090"
|
||||
@echo "frontend-install Install the pnpm modules needed for the frontend"
|
||||
@echo "frontend-build Build the frontend for localhost:9090"
|
||||
@echo "frontend-test Run the frontend test suite once"
|
||||
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
|
||||
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
|
||||
@echo "frontend-prettier Format the frontend using lint:prettier"
|
||||
@echo "wheel Build the wheel for the current version"
|
||||
@echo "frontend-lint Run frontend checks and fixable lint/format steps"
|
||||
@echo "wheel Build the wheel for the current version"
|
||||
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
|
||||
@echo "openapi Generate the OpenAPI schema for the app, outputting to stdout"
|
||||
@echo "docs Serve the mkdocs site with live reload"
|
||||
@@ -57,6 +58,10 @@ frontend-install:
|
||||
frontend-build:
|
||||
cd invokeai/frontend/web && pnpm build
|
||||
|
||||
# Run the frontend test suite once
|
||||
frontend-test:
|
||||
cd invokeai/frontend/web && pnpm run test:run
|
||||
|
||||
# Run the frontend in dev mode
|
||||
frontend-dev:
|
||||
cd invokeai/frontend/web && pnpm dev
|
||||
|
||||
28
README.md
28
README.md
@@ -58,15 +58,39 @@ Invoke offers a fully featured workflow management solution, enabling users to c
|
||||
|
||||
Invoke features an organized gallery system for easily storing, accessing, and remixing your content in the Invoke workspace. Images can be dragged/dropped onto any Image-based UI element in the application, and rich metadata within the Image allows for easy recall of key prompts or settings used in your workflow.
|
||||
|
||||
### Model Support
|
||||
- SD 1.5
|
||||
- SD 2.0
|
||||
- SDXL
|
||||
- SD 3.5 Medium
|
||||
- SD 3.5 Large
|
||||
- CogView 4
|
||||
- Flux.1 Dev
|
||||
- Flux.1 Schnell
|
||||
- Flux.1 Kontext
|
||||
- Flux.1 Krea
|
||||
- Flux Redux
|
||||
- Flux Fill
|
||||
- Flux.2 Klein 4B
|
||||
- Flux.2 Klein 9B
|
||||
- Z-Image Turbo
|
||||
- Z-Image Base
|
||||
- Anima
|
||||
- Qwen Image
|
||||
- Qwen Image Edit
|
||||
- Nano Banana (API Only)
|
||||
- GPT Image (API Only)
|
||||
- Wan (API Only)
|
||||
|
||||
### Other features
|
||||
|
||||
- Support for both ckpt and diffusers models
|
||||
- SD1.5, SD2.0, SDXL, and FLUX support
|
||||
- Support for ckpt, diffusers, and some gguf models
|
||||
- Upscaling Tools
|
||||
- Embedding Manager & Support
|
||||
- Model Manager & Support
|
||||
- Workflow creation & management
|
||||
- Node-Based Architecture
|
||||
- Object Segmentation & Selection Models (SAM / SAM2)
|
||||
|
||||
## Contributing
|
||||
|
||||
|
||||
205
docs/contributing/CANVAS_PROJECTS/CANVAS_PROJECTS.md
Normal file
205
docs/contributing/CANVAS_PROJECTS/CANVAS_PROJECTS.md
Normal file
@@ -0,0 +1,205 @@
|
||||
# Canvas Projects — Technical Documentation
|
||||
|
||||
## Overview
|
||||
|
||||
Canvas Projects provide a save/load mechanism for the entire canvas state. The feature serializes all canvas entities, generation parameters, reference images, and their associated image files into a ZIP-based `.invk` file. On load, it restores the full state, handling image deduplication and re-uploading as needed.
|
||||
|
||||
## File Format
|
||||
|
||||
The `.invk` file is a standard ZIP archive with the following structure:
|
||||
|
||||
```
|
||||
project.invk
|
||||
├── manifest.json
|
||||
├── canvas_state.json
|
||||
├── params.json
|
||||
├── ref_images.json
|
||||
├── loras.json
|
||||
└── images/
|
||||
├── {image_name_1}.png
|
||||
├── {image_name_2}.png
|
||||
└── ...
|
||||
```
|
||||
|
||||
### manifest.json
|
||||
|
||||
Schema version and metadata. Validated on load with Zod.
|
||||
|
||||
```json
|
||||
{
|
||||
"version": 1,
|
||||
"appVersion": "5.12.0",
|
||||
"createdAt": "2026-02-26T12:00:00.000Z",
|
||||
"name": "My Canvas Project"
|
||||
}
|
||||
```
|
||||
|
||||
| Field | Type | Description |
|
||||
|---|---|---|
|
||||
| `version` | `number` | Schema version, currently `1`. Used for migration logic on load. |
|
||||
| `appVersion` | `string` | InvokeAI version that created the file. Informational only. |
|
||||
| `createdAt` | `string` | ISO 8601 timestamp. |
|
||||
| `name` | `string` | User-provided project name. Also used as the download filename. |
|
||||
|
||||
### canvas_state.json
|
||||
|
||||
The serialized canvas entity tree. Type: `CanvasProjectState`.
|
||||
|
||||
```typescript
|
||||
type CanvasProjectState = {
|
||||
rasterLayers: CanvasRasterLayerState[];
|
||||
controlLayers: CanvasControlLayerState[];
|
||||
inpaintMasks: CanvasInpaintMaskState[];
|
||||
regionalGuidance: CanvasRegionalGuidanceState[];
|
||||
bbox: CanvasState['bbox'];
|
||||
selectedEntityIdentifier: CanvasState['selectedEntityIdentifier'];
|
||||
bookmarkedEntityIdentifier: CanvasState['bookmarkedEntityIdentifier'];
|
||||
};
|
||||
```
|
||||
|
||||
Each entity contains its full state including all canvas objects (brush lines, eraser lines, rect shapes, images). Image objects reference files by `image_name` which correspond to files in the `images/` folder.
|
||||
|
||||
### params.json
|
||||
|
||||
The complete generation parameters state (`ParamsState`). Optional on load (older files may not have it). This includes all fields from the params Redux slice:
|
||||
|
||||
- Prompts (positive, negative, prompt history)
|
||||
- Core generation settings (seed, steps, CFG scale, guidance, scheduler, iterations)
|
||||
- Model selections (main model, VAE, FLUX VAE, T5 encoder, CLIP embed models, refiner, Z-Image models, Klein models)
|
||||
- Dimensions (width, height, aspect ratio)
|
||||
- Img2img strength
|
||||
- Infill settings (method, tile size, patchmatch downscale, color)
|
||||
- Canvas coherence settings (mode, edge size, min denoise)
|
||||
- Refiner parameters (steps, CFG scale, scheduler, aesthetic scores, start)
|
||||
- FLUX-specific settings (scheduler, DyPE preset/scale/exponent)
|
||||
- Z-Image-specific settings (scheduler, seed variance)
|
||||
- Upscale settings (scheduler, CFG scale)
|
||||
- Seamless tiling, mask blur, CLIP skip, VAE precision, CPU noise, color compensation
|
||||
|
||||
### ref_images.json
|
||||
|
||||
Global reference image entities (`RefImageState[]`). These are IP-Adapter / FLUX Redux configs with `CroppableImageWithDims` containing both original and cropped image references. Optional on load.
|
||||
|
||||
### loras.json
|
||||
|
||||
Array of LoRA configurations (`LoRA[]`). Each entry contains:
|
||||
|
||||
```typescript
|
||||
type LoRA = {
|
||||
id: string;
|
||||
isEnabled: boolean;
|
||||
model: ModelIdentifierField;
|
||||
weight: number;
|
||||
};
|
||||
```
|
||||
|
||||
Optional on load. Like models, LoRA identifiers are stored as-is — if a LoRA is not installed when loading, the entry is restored but may not be usable.
|
||||
|
||||
### images/
|
||||
|
||||
All image files referenced anywhere in the state. Keyed by their original `image_name`. On save, each image is fetched from the backend via `GET /api/v1/images/i/{name}/full` and stored as-is.
|
||||
|
||||
## Key Source Files
|
||||
|
||||
| File | Purpose |
|
||||
|---|---|
|
||||
| `features/controlLayers/util/canvasProjectFile.ts` | Types, constants, image name collection, remapping, existence checking |
|
||||
| `features/controlLayers/hooks/useCanvasProjectSave.ts` | Save hook — collects Redux state, fetches images, builds ZIP |
|
||||
| `features/controlLayers/hooks/useCanvasProjectLoad.ts` | Load hook — parses ZIP, deduplicates images, dispatches state |
|
||||
| `features/controlLayers/components/SaveCanvasProjectDialog.tsx` | Save name dialog + `useSaveCanvasProjectWithDialog` hook |
|
||||
| `features/controlLayers/components/LoadCanvasProjectConfirmationAlertDialog.tsx` | Load confirmation dialog + `useLoadCanvasProjectWithDialog` hook |
|
||||
| `features/controlLayers/components/Toolbar/CanvasToolbarProjectMenuButton.tsx` | Toolbar dropdown UI |
|
||||
| `features/controlLayers/store/canvasSlice.ts` | `canvasProjectRecalled` Redux action |
|
||||
|
||||
## Save Flow
|
||||
|
||||
1. User clicks "Save Canvas Project" → `SaveCanvasProjectDialog` opens asking for a project name
|
||||
2. On confirm, `saveCanvasProject(name)` is called
|
||||
3. Read Redux state via selectors: `selectCanvasSlice()`, `selectParamsSlice()`, `selectRefImagesSlice()`, `selectLoRAsSlice()`
|
||||
4. Build `CanvasProjectState` from the canvas slice; use `paramsState` directly for params
|
||||
5. Walk all entities to collect every `image_name` reference via `collectImageNames()`:
|
||||
- `CanvasImageState.image.image_name` in layer/mask objects
|
||||
- `CroppableImageWithDims.original.image.image_name` in global ref images
|
||||
- `CroppableImageWithDims.crop.image.image_name` in cropped ref images
|
||||
- `ImageWithDims.image_name` in regional guidance ref images
|
||||
6. Fetch each image from the backend API
|
||||
7. Build ZIP with JSZip: add `manifest.json` (including `name`), `canvas_state.json`, `params.json`, `ref_images.json`, and all images into `images/`
|
||||
8. Sanitize the name for filesystem use and generate blob, trigger download as `{name}.invk`
|
||||
|
||||
## Load Flow
|
||||
|
||||
1. User selects `.invk` file → confirmation dialog opens
|
||||
2. On confirm, parse ZIP with JSZip
|
||||
3. Validate manifest version via Zod schema
|
||||
4. Read `canvas_state.json`, `params.json` (optional), `ref_images.json` (optional)
|
||||
5. Collect all `image_name` references from the loaded state
|
||||
6. **Deduplicate images**: for each referenced image, check if it exists on the server via `getImageDTOSafe(image_name)`
|
||||
- Already exists → skip (no upload)
|
||||
- Missing → upload from ZIP via `uploadImage()`, record `oldName → newName` mapping
|
||||
7. Remap all `image_name` values in the loaded state using the mapping (only for re-uploaded images whose names changed)
|
||||
8. Dispatch Redux actions:
|
||||
- `canvasProjectRecalled()` — restores all canvas entities, bbox, selected/bookmarked entity
|
||||
- `refImagesRecalled()` — restores global reference images
|
||||
- `paramsRecalled()` — replaces the entire params state in one action
|
||||
- `loraAllDeleted()` + `loraRecalled()` — restores LoRAs
|
||||
9. Show success/error toast
|
||||
|
||||
## Image Name Collection & Remapping
|
||||
|
||||
The `canvasProjectFile.ts` utility provides two parallel sets of functions:
|
||||
|
||||
**Collection** (`collectImageNames`): Walks the entire state tree and returns a `Set<string>` of all referenced `image_name` values. This is used by both save (to know which images to fetch) and load (to know which images to check/upload).
|
||||
|
||||
**Remapping** (`remapCanvasState`, `remapRefImages`): Deep-clones state objects and replaces `image_name` values using a `Map<string, string>` mapping. Only images that were re-uploaded with a different name are remapped. Images that already existed on the server are left unchanged.
|
||||
|
||||
Both walk the same paths through the state tree:
|
||||
- Layer/mask objects → `CanvasImageState.image.image_name`
|
||||
- Regional guidance ref images → `ImageWithDims.image_name`
|
||||
- Global ref images → `CroppableImageWithDims.original.image.image_name` and `.crop.image.image_name`
|
||||
|
||||
## Extending the Format
|
||||
|
||||
### Adding new optional data (non-breaking)
|
||||
|
||||
Add a new JSON file to the ZIP. No version bump needed.
|
||||
|
||||
1. **Save**: Add `zip.file('new_data.json', JSON.stringify(data))` in `useCanvasProjectSave.ts`
|
||||
2. **Load**: Read with `zip.file('new_data.json')` in `useCanvasProjectLoad.ts` — check for `null` so older project files without it still load
|
||||
3. **Dispatch**: Add the appropriate Redux action to restore the data
|
||||
|
||||
### Adding new entity types with images
|
||||
|
||||
1. Extend `CanvasProjectState` type in `canvasProjectFile.ts`
|
||||
2. Add collection logic in `collectImageNames()` to walk the new entity's objects
|
||||
3. Add remapping logic in `remapCanvasState()` to update image names
|
||||
4. Include the new entity array in both save and load hooks
|
||||
5. Handle it in the `canvasProjectRecalled` reducer in `canvasSlice.ts`
|
||||
|
||||
### Breaking schema changes
|
||||
|
||||
1. Bump `CANVAS_PROJECT_VERSION` in `canvasProjectFile.ts`
|
||||
2. Update the Zod manifest schema: `version: z.union([z.literal(1), z.literal(2)])`
|
||||
3. Add migration logic in the load hook: check version, transform v1 → v2 before dispatching
|
||||
|
||||
## UI Architecture
|
||||
|
||||
### Save dialog
|
||||
|
||||
The save flow uses a **nanostore atom** (`$isOpen`) to control the `SaveCanvasProjectDialog`:
|
||||
|
||||
1. `useSaveCanvasProjectWithDialog()` — returns a callback that sets `$isOpen` to `true`
|
||||
2. `SaveCanvasProjectDialog` (singleton in `GlobalModalIsolator`) — renders an `AlertDialog` with a name input
|
||||
3. On save → calls `saveCanvasProject(name)` and closes the dialog
|
||||
4. On cancel → closes the dialog
|
||||
|
||||
### Load dialog
|
||||
|
||||
The load flow uses a **nanostore atom** (`$pendingFile`) to decouple the file dialog from the confirmation dialog:
|
||||
|
||||
1. `useLoadCanvasProjectWithDialog()` — opens a programmatic file input (`document.createElement('input')`)
|
||||
2. On file selection → sets `$pendingFile` atom
|
||||
3. `LoadCanvasProjectConfirmationAlertDialog` (singleton in `GlobalModalIsolator`) — subscribes to `$pendingFile` via `useStore()`
|
||||
4. On accept → calls `loadCanvasProject(file)` and clears the atom
|
||||
5. On cancel → clears the atom
|
||||
|
||||
The programmatic file input approach was chosen because the context menu component uses `isLazy: true`, which unmounts the DOM tree when the menu closes — a hidden `<input>` element inside the menu would be destroyed before the file dialog returns.
|
||||
129
docs/contributing/EXTERNAL_PROVIDERS.md
Normal file
129
docs/contributing/EXTERNAL_PROVIDERS.md
Normal file
@@ -0,0 +1,129 @@
|
||||
# External Provider Integration
|
||||
|
||||
This guide covers:
|
||||
|
||||
1. Adding a new **external model** (most common; existing provider).
|
||||
2. Adding a brand-new **external provider** (adapter + config + UI wiring).
|
||||
|
||||
## 1) Add a New External Model (Existing Provider)
|
||||
|
||||
For provider-backed models (for example, OpenAI or Gemini), the source of truth is
|
||||
`invokeai/backend/model_manager/starter_models.py`.
|
||||
|
||||
### Required model fields
|
||||
|
||||
Define a `StarterModel` with:
|
||||
|
||||
- `base=BaseModelType.External`
|
||||
- `type=ModelType.ExternalImageGenerator`
|
||||
- `format=ModelFormat.ExternalApi`
|
||||
- `source="external://<provider_id>/<provider_model_id>"`
|
||||
- `name`, `description`
|
||||
- `capabilities=ExternalModelCapabilities(...)`
|
||||
- optional `default_settings=ExternalApiModelDefaultSettings(...)`
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
new_external_model = StarterModel(
|
||||
name="Provider Model Name",
|
||||
base=BaseModelType.External,
|
||||
source="external://openai/my-model-id",
|
||||
description=(
|
||||
"Provider model (external API). "
|
||||
"Requires a configured OpenAI API key and may incur provider usage costs."
|
||||
),
|
||||
type=ModelType.ExternalImageGenerator,
|
||||
format=ModelFormat.ExternalApi,
|
||||
capabilities=ExternalModelCapabilities(
|
||||
modes=["txt2img", "img2img", "inpaint"],
|
||||
supports_negative_prompt=False,
|
||||
supports_seed=False,
|
||||
supports_guidance=False,
|
||||
supports_steps=False,
|
||||
supports_reference_images=True,
|
||||
max_images_per_request=4,
|
||||
),
|
||||
default_settings=ExternalApiModelDefaultSettings(
|
||||
width=1024,
|
||||
height=1024,
|
||||
num_images=1,
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
Then append it to `STARTER_MODELS`.
|
||||
|
||||
### Required description text
|
||||
|
||||
External starter model descriptions must clearly state:
|
||||
|
||||
- an API key is required
|
||||
- usage may incur provider-side costs
|
||||
|
||||
### Capabilities must be accurate
|
||||
|
||||
These flags directly control UI visibility and request payload fields:
|
||||
|
||||
- `supports_negative_prompt`
|
||||
- `supports_seed`
|
||||
- `supports_guidance`
|
||||
- `supports_steps`
|
||||
- `supports_reference_images`
|
||||
|
||||
`supports_steps` is especially important: if `False`, steps are hidden for that model and `steps` is sent as `null`.
|
||||
|
||||
### Source string stability
|
||||
|
||||
Starter overrides are matched by `source` (`external://provider/model-id`). Keep this stable:
|
||||
|
||||
- runtime capability/default overrides depend on it
|
||||
- installation detection in starter-model APIs depends on it
|
||||
|
||||
`STARTER_MODELS` enforces unique `source` values with an assertion.
|
||||
|
||||
### Install behavior notes
|
||||
|
||||
- External starter models are managed in **External Providers** setup (not the regular Starter Models tab).
|
||||
- External starter models auto-install when a provider is configured.
|
||||
- Removing a provider API key removes installed external models for that provider.
|
||||
|
||||
## 2) Credentials and Config
|
||||
|
||||
External provider API keys are stored separately from `invokeai.yaml`:
|
||||
|
||||
- default file: `~/invokeai/api_keys.yaml`
|
||||
- resolved path: `<INVOKEAI_ROOT>/api_keys.yaml`
|
||||
|
||||
Non-secret provider settings (for example base URL overrides) stay in `invokeai.yaml`.
|
||||
|
||||
Environment variables are still supported, e.g.:
|
||||
|
||||
- `INVOKEAI_EXTERNAL_GEMINI_API_KEY`
|
||||
- `INVOKEAI_EXTERNAL_OPENAI_API_KEY`
|
||||
|
||||
## 3) Add a New Provider (Only If Needed)
|
||||
|
||||
If your model uses a provider that is not already integrated:
|
||||
|
||||
1. Add config fields in `invokeai/app/services/config/config_default.py`
|
||||
`external_<provider>_api_key` and optional `external_<provider>_base_url`.
|
||||
2. Add provider field mapping in `invokeai/app/api/routers/app_info.py`
|
||||
(`EXTERNAL_PROVIDER_FIELDS`).
|
||||
3. Implement provider adapter in `invokeai/app/services/external_generation/providers/`
|
||||
by subclassing `ExternalProvider`.
|
||||
4. Register the provider in `invokeai/app/api/dependencies.py` when building
|
||||
`ExternalGenerationService`.
|
||||
5. Add starter model entries using `source="external://<provider>/<model-id>"`.
|
||||
6. Optional UI ordering tweak:
|
||||
`invokeai/frontend/web/src/features/modelManagerV2/subpanels/AddModelPanel/ExternalProviders/ExternalProvidersForm.tsx`
|
||||
(`PROVIDER_SORT_ORDER`).
|
||||
|
||||
## 4) Optional Manual Installation
|
||||
|
||||
You can also install external models directly via:
|
||||
|
||||
`POST /api/v2/models/install?source=external://<provider_id>/<provider_model_id>`
|
||||
|
||||
If omitted, `path`, `source`, and `hash` are auto-populated for external model configs.
|
||||
Set capabilities conservatively; the external generation service enforces capability checks at runtime.
|
||||
@@ -8,6 +8,10 @@ We welcome contributions, whether features, bug fixes, code cleanup, testing, co
|
||||
|
||||
If you’d like to help with development, please see our [development guide](contribution_guides/development.md).
|
||||
|
||||
## External Providers
|
||||
|
||||
If you are adding external image generation providers or configs, see our [external provider integration guide](EXTERNAL_PROVIDERS.md).
|
||||
|
||||
**New Contributors:** If you’re unfamiliar with contributing to open source projects, take a look at our [new contributor guide](contribution_guides/newContributorChecklist.md).
|
||||
|
||||
## Nodes
|
||||
|
||||
32
docs/features/Lasso_tool.md
Normal file
32
docs/features/Lasso_tool.md
Normal file
@@ -0,0 +1,32 @@
|
||||
Lasso Tool
|
||||
===========
|
||||
|
||||
- The Lasso tool creates selections and inpaint masks by drawing freehand or polygonal regions on the canvas.
|
||||
|
||||
How to open the Lasso tool
|
||||
--------------------------
|
||||
- Click the Lasso icon in the toolbar.
|
||||
- Hotkey: press `L` (default). The hotkey is shown in the tool's tooltip and can be customized in Hotkeys settings.
|
||||
|
||||
Modes
|
||||
-----
|
||||
- Freehand (default)
|
||||
- Hold the pointer and drag to draw a continuous contour.
|
||||
- Long segments are broken into intermediate points to keep the line continuous.
|
||||
- Very long strokes may be simplified after drawing to reduce point count for performance.
|
||||
|
||||
- Polygon
|
||||
- Click to place points; click the first point (or a point near it) to close the polygon.
|
||||
- The tool snaps the closing point to the start for precise closures.
|
||||
|
||||
Basic interactions
|
||||
------------------
|
||||
- Switch modes with the mode toggle in the toolbar.
|
||||
- To close a polygon: click the starting point again or click near it — the tool aligns the final point to the start to complete the shape.
|
||||
- The selection will be added to the current Inpaint Mask layer. If no Inpaint Mask layer exists, a new one will be created automatically.
|
||||
|
||||
Tips & behavior
|
||||
---------------
|
||||
- Hold `Space` to temporarily switch to the View tool for panning and zooming; release `Space` to return to the Lasso tool and continue drawing.
|
||||
- When using the Polygon mode, you can hold `Shift` to snap points to horizontal, vertical, or 45-degree angles for more precise shapes.
|
||||
- Hold `Ctrl` (Windows/Linux) or `Command` (macOS) while drawing to subtract from the current selection instead of adding to it.
|
||||
56
docs/features/canvas_projects.md
Normal file
56
docs/features/canvas_projects.md
Normal file
@@ -0,0 +1,56 @@
|
||||
---
|
||||
title: Canvas Projects
|
||||
---
|
||||
|
||||
# :material-folder-zip: Canvas Projects
|
||||
|
||||
## Save and Restore Your Canvas Work
|
||||
|
||||
Canvas Projects let you save your entire canvas setup to a file and load it back later. This is useful when you want to:
|
||||
|
||||
- **Switch between tasks** without losing your current canvas arrangement
|
||||
- **Back up complex setups** with multiple layers, masks, and reference images
|
||||
- **Share canvas layouts** with others or transfer them between machines
|
||||
- **Recover from deleted images** — all images are embedded in the project file
|
||||
|
||||
## What Gets Saved
|
||||
|
||||
A canvas project file (`.invk`) captures everything about your current canvas session:
|
||||
|
||||
- **All layers** — raster layers, control layers, inpaint masks, regional guidance
|
||||
- **All drawn content** — brush strokes, pasted images, eraser marks
|
||||
- **Reference images** — global IP-Adapter / FLUX Redux images with crop settings
|
||||
- **Regional guidance** — per-region prompts and reference images
|
||||
- **Bounding box** — position, size, aspect ratio, and scale settings
|
||||
- **All generation parameters** — prompts, seed, steps, CFG scale, guidance, scheduler, model, VAE, dimensions, img2img strength, infill settings, canvas coherence, refiner settings, FLUX/Z-Image specific parameters, and more
|
||||
- **LoRAs** — all added LoRA models with their weights and enabled/disabled state
|
||||
|
||||
## How to Save a Project
|
||||
|
||||
You can save from two places:
|
||||
|
||||
1. **Toolbar** — Click the **Archive icon** in the canvas toolbar, then select **Save Canvas Project**
|
||||
2. **Context menu** — Right-click the canvas, open the **Project** submenu, then select **Save Canvas Project**
|
||||
|
||||
A dialog will ask you to enter a **project name**. This name is used as the filename (e.g., entering "My Portrait" saves as `My Portrait.invk`) and is stored inside the project file.
|
||||
|
||||
## How to Load a Project
|
||||
|
||||
1. **Toolbar** — Click the **Archive icon**, then select **Load Canvas Project**
|
||||
2. **Context menu** — Right-click the canvas, open the **Project** submenu, then select **Load Canvas Project**
|
||||
|
||||
A file dialog will open. Select your `.invk` file. You will see a confirmation dialog warning that loading will replace your current canvas. Click **Load** to proceed.
|
||||
|
||||
### What Happens on Load
|
||||
|
||||
- Your current canvas is **completely replaced** — all existing layers, masks, reference images, and parameters are overwritten
|
||||
- Images that are already present on your InvokeAI server are reused automatically (no duplicate uploads)
|
||||
- Images that were deleted from the server are re-uploaded from the project file
|
||||
- If the saved model is not installed on your system, the model identifier is still restored — you will need to select an available model manually
|
||||
|
||||
## Good to Know
|
||||
|
||||
- **No undo** — Loading a project replaces your canvas entirely. There is no way to undo this action, so save your current project first if you want to keep it.
|
||||
- **Image deduplication** — When loading, images already on your server are not re-uploaded. Only missing images are uploaded from the project file.
|
||||
- **File size** — The `.invk` file size depends on the number and resolution of images in your canvas. A project with many high-resolution layers can be large.
|
||||
- **Model availability** — The project saves which model was selected, but does not include the model itself. If the model is not installed when you load the project, you will need to select a different one.
|
||||
@@ -140,13 +140,13 @@ As a regular user, you can:
|
||||
- ✅ View your own generation queue
|
||||
- ✅ Customize your UI preferences (theme, hotkeys, etc.)
|
||||
- ✅ View available models (read-only access to Model Manager)
|
||||
- ✅ Access shared boards (based on permissions granted to you) (FUTURE FEATURE)
|
||||
- ✅ Access workflows marked as public (FUTURE FEATURE)
|
||||
- ✅ View shared and public boards created by other users
|
||||
- ✅ View and use workflows marked as shared by other users
|
||||
|
||||
You cannot:
|
||||
|
||||
- ❌ Add, delete, or modify models
|
||||
- ❌ View or modify other users' boards, images, or workflows
|
||||
- ❌ View or modify other users' private boards, images, or workflows
|
||||
- ❌ Manage user accounts
|
||||
- ❌ Access system configuration
|
||||
- ❌ View or cancel other users' generation tasks
|
||||
@@ -173,7 +173,7 @@ Administrators have all regular user capabilities, plus:
|
||||
- ✅ Full model management (add, delete, configure models)
|
||||
- ✅ Create and manage user accounts
|
||||
- ✅ View and manage all users' generation queues
|
||||
- ✅ Create and manage shared boards (FUTURE FEATURE)
|
||||
- ✅ View and manage all users' boards, images, and workflows (including system-owned legacy content)
|
||||
- ✅ Access system configuration
|
||||
- ✅ Grant or revoke admin privileges
|
||||
|
||||
@@ -183,23 +183,30 @@ Administrators have all regular user capabilities, plus:
|
||||
|
||||
### Image Boards
|
||||
|
||||
In multi-user mode, Image Boards work as before. Each user can create an unlimited number of boards and organize their images and assets as they see fit. Boards are private: you cannot see a board owned by a different user.
|
||||
In multi-user mode, each user can create an unlimited number of boards and organize their images and assets as they see fit. Boards have three visibility levels:
|
||||
|
||||
!!! tip "Shared Boards"
|
||||
InvokeAI 6.13 will add support for creating public boards that are accessible to all users.
|
||||
- **Private** (default): Only you (and administrators) can see and modify the board.
|
||||
- **Shared**: All users can view the board and its contents, but only you (and administrators) can modify it (rename, archive, delete, or add/remove images).
|
||||
- **Public**: All users can view the board. Only you (and administrators) can modify the board's structure (rename, archive, delete).
|
||||
|
||||
The Administrator can see all users' Image Boards and their contents.
|
||||
To change a board's visibility, right-click on the board and select the desired visibility option.
|
||||
|
||||
### Going From Multi-User to Single-User mode
|
||||
Administrators can see and manage all users' image boards and their contents regardless of visibility settings.
|
||||
|
||||
### Going From Multi-User to Single-User Mode
|
||||
|
||||
If an InvokeAI instance was in multi-user mode and then restarted in single-user mode (by setting `multiuser: false` in the configuration file), all users' boards will be consolidated in one place. Any images that were in "Uncategorized" will be merged together into a single Uncategorized board. If, at a later date, the server is restarted in multi-user mode, the boards and images will be separated and restored to their owners.
|
||||
|
||||
### Workflows
|
||||
|
||||
In the current released version (6.12), workflows are always shared among users. Any workflow that you create will be visible to other users and vice versa, and there is no protection against one user modifying another user's workflow.
|
||||
Each user has their own private workflow library. Workflows you create are visible only to you by default.
|
||||
|
||||
!!! tip "Private and Shared Workflows"
|
||||
InvokeAI 6.13 will provide the ability to create private and shared workflows. A private workflow can only be viewed by the user who created it. At any time, however, the user can designate the workflow *shared*, in which case it can be opened on a read-only basis by all logged-in users.
|
||||
You can share a workflow with other users by marking it as **shared** (public). Shared workflows appear in all users' workflow libraries and can be opened by anyone, but only the owner (or an administrator) can modify or delete them.
|
||||
|
||||
To share a workflow, open it and use the sharing controls to toggle its public/shared status.
|
||||
|
||||
!!! warning "Preexisting workflows after enabling multi-user mode"
|
||||
When you enable multi-user mode for the first time on an existing InvokeAI installation, all workflows that were created before multi-user mode was activated will appear in the **shared workflows** section. These preexisting workflows are owned by the internal "system" account and are visible to all users. Administrators can edit or delete these shared legacy workflows. Regular users can view and use them but cannot modify them.
|
||||
|
||||
|
||||
### The Generation Queue
|
||||
@@ -330,11 +337,11 @@ These settings are stored per-user and won't affect other users.
|
||||
|
||||
### Can other users see my images?
|
||||
|
||||
No, unless you add them to a shared board (FUTURE FEATURE). All your personal boards and images are private.
|
||||
Not unless you change your board's visibility to "shared" or "public". All personal boards and images are private by default.
|
||||
|
||||
### Can I share my workflows with others?
|
||||
|
||||
Not directly. Ask your administrator to mark workflows as public if you want to share them.
|
||||
Yes. You can mark any workflow as shared (public), which makes it visible to all users. Other users can view and use shared workflows, but only you or an administrator can modify or delete them.
|
||||
|
||||
### How long do sessions last?
|
||||
|
||||
|
||||
@@ -16,6 +16,9 @@ from invokeai.app.services.client_state_persistence.client_state_persistence_sql
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.download.download_default import DownloadQueueService
|
||||
from invokeai.app.services.events.events_fastapievents import FastAPIEventService
|
||||
from invokeai.app.services.external_generation.external_generation_default import ExternalGenerationService
|
||||
from invokeai.app.services.external_generation.providers import AlibabaCloudProvider, GeminiProvider, OpenAIProvider
|
||||
from invokeai.app.services.external_generation.startup import sync_configured_external_starter_models
|
||||
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
|
||||
from invokeai.app.services.image_records.image_records_sqlite import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images.images_default import ImageService
|
||||
@@ -46,10 +49,12 @@ from invokeai.app.services.users.users_default import UserService
|
||||
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_disk import WorkflowThumbnailFileStorageDisk
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
AnimaConditioningInfo,
|
||||
BasicConditioningInfo,
|
||||
CogView4ConditioningInfo,
|
||||
ConditioningFieldData,
|
||||
FLUXConditioningInfo,
|
||||
QwenImageConditioningInfo,
|
||||
SD3ConditioningInfo,
|
||||
SDXLConditioningInfo,
|
||||
ZImageConditioningInfo,
|
||||
@@ -140,18 +145,30 @@ class ApiDependencies:
|
||||
SD3ConditioningInfo,
|
||||
CogView4ConditioningInfo,
|
||||
ZImageConditioningInfo,
|
||||
QwenImageConditioningInfo,
|
||||
AnimaConditioningInfo,
|
||||
],
|
||||
ephemeral=True,
|
||||
),
|
||||
)
|
||||
download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
|
||||
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
|
||||
model_record_service = ModelRecordServiceSQL(db=db, logger=logger)
|
||||
model_manager = ModelManagerService.build_model_manager(
|
||||
app_config=configuration,
|
||||
model_record_service=ModelRecordServiceSQL(db=db, logger=logger),
|
||||
model_record_service=model_record_service,
|
||||
download_queue=download_queue_service,
|
||||
events=events,
|
||||
)
|
||||
external_generation = ExternalGenerationService(
|
||||
providers={
|
||||
AlibabaCloudProvider.provider_id: AlibabaCloudProvider(app_config=configuration, logger=logger),
|
||||
GeminiProvider.provider_id: GeminiProvider(app_config=configuration, logger=logger),
|
||||
OpenAIProvider.provider_id: OpenAIProvider(app_config=configuration, logger=logger),
|
||||
},
|
||||
logger=logger,
|
||||
record_store=model_record_service,
|
||||
)
|
||||
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
|
||||
model_relationships = ModelRelationshipsService()
|
||||
model_relationship_records = SqliteModelRelationshipRecordStorage(db=db)
|
||||
names = SimpleNameService()
|
||||
@@ -184,6 +201,7 @@ class ApiDependencies:
|
||||
model_relationships=model_relationships,
|
||||
model_relationship_records=model_relationship_records,
|
||||
download_queue=download_queue_service,
|
||||
external_generation=external_generation,
|
||||
names=names,
|
||||
performance_statistics=performance_statistics,
|
||||
session_processor=session_processor,
|
||||
@@ -200,6 +218,16 @@ class ApiDependencies:
|
||||
)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
configured_external_providers = {
|
||||
provider_id
|
||||
for provider_id, status in external_generation.get_provider_statuses().items()
|
||||
if status.configured
|
||||
}
|
||||
sync_configured_external_starter_models(
|
||||
configured_provider_ids=configured_external_providers,
|
||||
model_manager=model_manager,
|
||||
logger=logger,
|
||||
)
|
||||
db.clean()
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,15 +1,30 @@
|
||||
import locale
|
||||
from enum import Enum
|
||||
from importlib.metadata import distributions
|
||||
from pathlib import Path as FilePath
|
||||
from threading import Lock
|
||||
|
||||
import torch
|
||||
from fastapi import Body
|
||||
import yaml
|
||||
from fastapi import Body, HTTPException, Path
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.api.auth_dependencies import AdminUserOrDefault
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config
|
||||
from invokeai.app.services.config.config_default import (
|
||||
EXTERNAL_PROVIDER_CONFIG_FIELDS,
|
||||
DefaultInvokeAIAppConfig,
|
||||
InvokeAIAppConfig,
|
||||
get_config,
|
||||
load_and_migrate_config,
|
||||
load_external_api_keys,
|
||||
)
|
||||
from invokeai.app.services.external_generation.external_generation_common import ExternalProviderStatus
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
|
||||
from invokeai.app.services.model_records.model_records_base import UnknownModelException
|
||||
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
from invokeai.backend.util.logging import logging
|
||||
from invokeai.version import __version__
|
||||
|
||||
@@ -41,7 +56,7 @@ async def get_version() -> AppVersion:
|
||||
async def get_app_deps() -> dict[str, str]:
|
||||
deps: dict[str, str] = {dist.metadata["Name"]: dist.version for dist in distributions()}
|
||||
try:
|
||||
cuda = torch.version.cuda or "N/A"
|
||||
cuda = getattr(getattr(torch, "version", None), "cuda", None) or "N/A" # pyright: ignore[reportAttributeAccessIssue]
|
||||
except Exception:
|
||||
cuda = "N/A"
|
||||
|
||||
@@ -64,6 +79,41 @@ class InvokeAIAppConfigWithSetFields(BaseModel):
|
||||
config: InvokeAIAppConfig = Field(description="The InvokeAI App Config")
|
||||
|
||||
|
||||
class ExternalProviderStatusModel(BaseModel):
|
||||
provider_id: str = Field(description="The external provider identifier")
|
||||
configured: bool = Field(description="Whether credentials are configured for the provider")
|
||||
message: str | None = Field(default=None, description="Optional provider status detail")
|
||||
|
||||
|
||||
class ExternalProviderConfigUpdate(BaseModel):
|
||||
api_key: str | None = Field(default=None, description="API key for the external provider")
|
||||
base_url: str | None = Field(default=None, description="Optional base URL override for the provider")
|
||||
|
||||
|
||||
class ExternalProviderConfigModel(BaseModel):
|
||||
provider_id: str = Field(description="The external provider identifier")
|
||||
api_key_configured: bool = Field(description="Whether an API key is configured")
|
||||
base_url: str | None = Field(default=None, description="Optional base URL override")
|
||||
|
||||
|
||||
EXTERNAL_PROVIDER_FIELDS: dict[str, tuple[str, str]] = {
|
||||
"alibabacloud": ("external_alibabacloud_api_key", "external_alibabacloud_base_url"),
|
||||
"gemini": ("external_gemini_api_key", "external_gemini_base_url"),
|
||||
"openai": ("external_openai_api_key", "external_openai_base_url"),
|
||||
}
|
||||
_EXTERNAL_PROVIDER_CONFIG_LOCK = Lock()
|
||||
|
||||
|
||||
class UpdateAppGenerationSettingsRequest(BaseModel):
|
||||
"""Writable generation-related app settings."""
|
||||
|
||||
max_queue_history: int | None = Field(
|
||||
default=None,
|
||||
ge=0,
|
||||
description="Keep the last N completed, failed, and canceled queue items on startup. Set to 0 to prune all terminal items.",
|
||||
)
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/runtime_config", operation_id="get_runtime_config", status_code=200, response_model=InvokeAIAppConfigWithSetFields
|
||||
)
|
||||
@@ -72,6 +122,190 @@ async def get_runtime_config() -> InvokeAIAppConfigWithSetFields:
|
||||
return InvokeAIAppConfigWithSetFields(set_fields=config.model_fields_set, config=config)
|
||||
|
||||
|
||||
@app_router.patch(
|
||||
"/runtime_config",
|
||||
operation_id="update_runtime_config",
|
||||
status_code=200,
|
||||
response_model=InvokeAIAppConfigWithSetFields,
|
||||
)
|
||||
async def update_runtime_config(
|
||||
_: AdminUserOrDefault,
|
||||
changes: UpdateAppGenerationSettingsRequest = Body(description="Writable runtime configuration changes"),
|
||||
) -> InvokeAIAppConfigWithSetFields:
|
||||
config = get_config()
|
||||
update_dict = changes.model_dump(exclude_unset=True)
|
||||
config.update_config(update_dict)
|
||||
|
||||
if config.config_file_path.exists():
|
||||
persisted_config = load_and_migrate_config(config.config_file_path)
|
||||
else:
|
||||
persisted_config = DefaultInvokeAIAppConfig()
|
||||
|
||||
persisted_config.update_config(update_dict)
|
||||
persisted_config.write_file(config.config_file_path)
|
||||
return InvokeAIAppConfigWithSetFields(set_fields=config.model_fields_set, config=config)
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/external_providers/status",
|
||||
operation_id="get_external_provider_statuses",
|
||||
status_code=200,
|
||||
response_model=list[ExternalProviderStatusModel],
|
||||
)
|
||||
async def get_external_provider_statuses() -> list[ExternalProviderStatusModel]:
|
||||
statuses = ApiDependencies.invoker.services.external_generation.get_provider_statuses()
|
||||
return [status_to_model(status) for status in statuses.values()]
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/external_providers/config",
|
||||
operation_id="get_external_provider_configs",
|
||||
status_code=200,
|
||||
response_model=list[ExternalProviderConfigModel],
|
||||
)
|
||||
async def get_external_provider_configs() -> list[ExternalProviderConfigModel]:
|
||||
config = get_config()
|
||||
return [_build_external_provider_config(provider_id, config) for provider_id in EXTERNAL_PROVIDER_FIELDS]
|
||||
|
||||
|
||||
@app_router.post(
|
||||
"/external_providers/config/{provider_id}",
|
||||
operation_id="set_external_provider_config",
|
||||
status_code=200,
|
||||
response_model=ExternalProviderConfigModel,
|
||||
)
|
||||
async def set_external_provider_config(
|
||||
provider_id: str = Path(description="The external provider identifier"),
|
||||
update: ExternalProviderConfigUpdate = Body(description="External provider configuration settings"),
|
||||
) -> ExternalProviderConfigModel:
|
||||
api_key_field, base_url_field = _get_external_provider_fields(provider_id)
|
||||
updates: dict[str, str | None] = {}
|
||||
|
||||
if update.api_key is not None:
|
||||
api_key = update.api_key.strip()
|
||||
updates[api_key_field] = api_key or None
|
||||
if update.base_url is not None:
|
||||
base_url = update.base_url.strip()
|
||||
updates[base_url_field] = base_url or None
|
||||
|
||||
if not updates:
|
||||
raise HTTPException(status_code=400, detail="No external provider config fields provided")
|
||||
|
||||
api_key_removed = update.api_key is not None and updates.get(api_key_field) is None
|
||||
_apply_external_provider_update(updates)
|
||||
if api_key_removed:
|
||||
_remove_external_models_for_provider(provider_id)
|
||||
return _build_external_provider_config(provider_id, get_config())
|
||||
|
||||
|
||||
@app_router.delete(
|
||||
"/external_providers/config/{provider_id}",
|
||||
operation_id="reset_external_provider_config",
|
||||
status_code=200,
|
||||
response_model=ExternalProviderConfigModel,
|
||||
)
|
||||
async def reset_external_provider_config(
|
||||
provider_id: str = Path(description="The external provider identifier"),
|
||||
) -> ExternalProviderConfigModel:
|
||||
api_key_field, base_url_field = _get_external_provider_fields(provider_id)
|
||||
_apply_external_provider_update({api_key_field: None, base_url_field: None})
|
||||
_remove_external_models_for_provider(provider_id)
|
||||
return _build_external_provider_config(provider_id, get_config())
|
||||
|
||||
|
||||
def status_to_model(status: ExternalProviderStatus) -> ExternalProviderStatusModel:
|
||||
return ExternalProviderStatusModel(
|
||||
provider_id=status.provider_id,
|
||||
configured=status.configured,
|
||||
message=status.message,
|
||||
)
|
||||
|
||||
|
||||
def _get_external_provider_fields(provider_id: str) -> tuple[str, str]:
|
||||
if provider_id not in EXTERNAL_PROVIDER_FIELDS:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown external provider '{provider_id}'")
|
||||
return EXTERNAL_PROVIDER_FIELDS[provider_id]
|
||||
|
||||
|
||||
def _write_external_api_keys_file(api_keys_file_path: FilePath, api_keys: dict[str, str]) -> None:
|
||||
if not api_keys:
|
||||
if api_keys_file_path.exists():
|
||||
api_keys_file_path.unlink()
|
||||
return
|
||||
|
||||
api_keys_file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(api_keys_file_path, "w", encoding=locale.getpreferredencoding()) as api_keys_file:
|
||||
yaml.safe_dump(api_keys, api_keys_file, sort_keys=False)
|
||||
|
||||
|
||||
def _apply_external_provider_update(updates: dict[str, str | None]) -> None:
|
||||
with _EXTERNAL_PROVIDER_CONFIG_LOCK:
|
||||
runtime_config = get_config()
|
||||
config_path = runtime_config.config_file_path
|
||||
api_keys_file_path = runtime_config.api_keys_file_path
|
||||
if config_path.exists():
|
||||
file_config = load_and_migrate_config(config_path)
|
||||
else:
|
||||
file_config = DefaultInvokeAIAppConfig()
|
||||
|
||||
runtime_config.update_config(updates)
|
||||
provider_config_fields = set(EXTERNAL_PROVIDER_CONFIG_FIELDS)
|
||||
provider_updates = {field: value for field, value in updates.items() if field in provider_config_fields}
|
||||
non_provider_updates = {field: value for field, value in updates.items() if field not in provider_config_fields}
|
||||
|
||||
if non_provider_updates:
|
||||
file_config.update_config(non_provider_updates)
|
||||
|
||||
persisted_api_keys = load_external_api_keys(api_keys_file_path)
|
||||
for field_name in EXTERNAL_PROVIDER_CONFIG_FIELDS:
|
||||
file_value = getattr(file_config, field_name, None)
|
||||
if field_name not in persisted_api_keys and isinstance(file_value, str) and file_value.strip():
|
||||
persisted_api_keys[field_name] = file_value
|
||||
|
||||
for field_name, value in provider_updates.items():
|
||||
if value is None:
|
||||
persisted_api_keys.pop(field_name, None)
|
||||
else:
|
||||
persisted_api_keys[field_name] = value
|
||||
|
||||
_write_external_api_keys_file(api_keys_file_path, persisted_api_keys)
|
||||
|
||||
for field_name in EXTERNAL_PROVIDER_CONFIG_FIELDS:
|
||||
setattr(file_config, field_name, None)
|
||||
|
||||
file_config_to_write = type(file_config).model_validate(
|
||||
file_config.model_dump(exclude_unset=True, exclude_none=True)
|
||||
)
|
||||
file_config_to_write.write_file(config_path, as_example=False)
|
||||
|
||||
|
||||
def _build_external_provider_config(provider_id: str, config: InvokeAIAppConfig) -> ExternalProviderConfigModel:
|
||||
api_key_field, base_url_field = _get_external_provider_fields(provider_id)
|
||||
return ExternalProviderConfigModel(
|
||||
provider_id=provider_id,
|
||||
api_key_configured=bool(getattr(config, api_key_field)),
|
||||
base_url=getattr(config, base_url_field),
|
||||
)
|
||||
|
||||
|
||||
def _remove_external_models_for_provider(provider_id: str) -> None:
|
||||
model_manager = ApiDependencies.invoker.services.model_manager
|
||||
external_models = model_manager.store.search_by_attr(
|
||||
base_model=BaseModelType.External,
|
||||
model_type=ModelType.ExternalImageGenerator,
|
||||
)
|
||||
|
||||
for model in external_models:
|
||||
if getattr(model, "provider_id", None) != provider_id:
|
||||
continue
|
||||
try:
|
||||
model_manager.install.delete(model.key)
|
||||
except UnknownModelException:
|
||||
logging.warning(f"External model key '{model.key}' was already removed while resetting '{provider_id}'")
|
||||
except Exception as error:
|
||||
logging.warning(f"Failed removing external model key '{model.key}' for '{provider_id}': {error}")
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/logging",
|
||||
operation_id="get_log_level",
|
||||
|
||||
@@ -79,6 +79,8 @@ class SetupStatusResponse(BaseModel):
|
||||
|
||||
setup_required: bool = Field(description="Whether initial setup is required")
|
||||
multiuser_enabled: bool = Field(description="Whether multiuser mode is enabled")
|
||||
strict_password_checking: bool = Field(description="Whether strict password requirements are enforced")
|
||||
admin_email: str | None = Field(default=None, description="Email of the first active admin user, if any")
|
||||
|
||||
|
||||
@auth_router.get("/status", response_model=SetupStatusResponse)
|
||||
@@ -92,13 +94,27 @@ async def get_setup_status() -> SetupStatusResponse:
|
||||
|
||||
# If multiuser is disabled, setup is never required
|
||||
if not config.multiuser:
|
||||
return SetupStatusResponse(setup_required=False, multiuser_enabled=False)
|
||||
return SetupStatusResponse(
|
||||
setup_required=False,
|
||||
multiuser_enabled=False,
|
||||
strict_password_checking=config.strict_password_checking,
|
||||
admin_email=None,
|
||||
)
|
||||
|
||||
# In multiuser mode, check if an admin exists
|
||||
user_service = ApiDependencies.invoker.services.users
|
||||
setup_required = not user_service.has_admin()
|
||||
|
||||
return SetupStatusResponse(setup_required=setup_required, multiuser_enabled=True)
|
||||
# Only expose admin_email during initial setup to avoid leaking
|
||||
# administrator identity on public deployments.
|
||||
admin_email = user_service.get_admin_email() if setup_required else None
|
||||
|
||||
return SetupStatusResponse(
|
||||
setup_required=setup_required,
|
||||
multiuser_enabled=True,
|
||||
strict_password_checking=config.strict_password_checking,
|
||||
admin_email=admin_email,
|
||||
)
|
||||
|
||||
|
||||
@auth_router.post("/login", response_model=LoginResponse)
|
||||
@@ -145,6 +161,7 @@ async def login(
|
||||
user_id=user.user_id,
|
||||
email=user.email,
|
||||
is_admin=user.is_admin,
|
||||
remember_me=request.remember_me,
|
||||
)
|
||||
token = create_access_token(token_data, expires_delta)
|
||||
|
||||
@@ -248,7 +265,7 @@ async def setup_admin(
|
||||
password=request.password,
|
||||
is_admin=True,
|
||||
)
|
||||
user = user_service.create_admin(user_data)
|
||||
user = user_service.create_admin(user_data, strict_password_checking=config.strict_password_checking)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||
|
||||
@@ -359,6 +376,7 @@ async def create_user(
|
||||
HTTPException: 400 if email already exists or password is weak
|
||||
"""
|
||||
user_service = ApiDependencies.invoker.services.users
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
try:
|
||||
user_data = UserCreateRequest(
|
||||
email=request.email,
|
||||
@@ -366,7 +384,7 @@ async def create_user(
|
||||
password=request.password,
|
||||
is_admin=request.is_admin,
|
||||
)
|
||||
return user_service.create(user_data)
|
||||
return user_service.create(user_data, strict_password_checking=config.strict_password_checking)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||
|
||||
@@ -414,6 +432,7 @@ async def update_user(
|
||||
HTTPException: 404 if user not found
|
||||
"""
|
||||
user_service = ApiDependencies.invoker.services.users
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
try:
|
||||
changes = UserUpdateRequest(
|
||||
display_name=request.display_name,
|
||||
@@ -421,7 +440,7 @@ async def update_user(
|
||||
is_admin=request.is_admin,
|
||||
is_active=request.is_active,
|
||||
)
|
||||
return user_service.update(user_id, changes)
|
||||
return user_service.update(user_id, changes, strict_password_checking=config.strict_password_checking)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||
|
||||
@@ -483,6 +502,7 @@ async def update_current_user(
|
||||
HTTPException: 404 if user not found
|
||||
"""
|
||||
user_service = ApiDependencies.invoker.services.users
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
|
||||
# Verify current password when attempting a password change
|
||||
if request.new_password is not None:
|
||||
@@ -509,6 +529,8 @@ async def update_current_user(
|
||||
display_name=request.display_name,
|
||||
password=request.new_password,
|
||||
)
|
||||
return user_service.update(current_user.user_id, changes)
|
||||
return user_service.update(
|
||||
current_user.user_id, changes, strict_password_checking=config.strict_password_checking
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
|
||||
|
||||
@@ -1,12 +1,53 @@
|
||||
from fastapi import Body, HTTPException
|
||||
from fastapi.routing import APIRouter
|
||||
|
||||
from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.images.images_common import AddImagesToBoardResult, RemoveImagesFromBoardResult
|
||||
|
||||
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
|
||||
|
||||
|
||||
def _assert_board_write_access(board_id: str, current_user: CurrentUserOrDefault) -> None:
|
||||
"""Raise 403 if the current user may not mutate the given board.
|
||||
|
||||
Write access is granted when ANY of these hold:
|
||||
- The user is an admin.
|
||||
- The user owns the board.
|
||||
- The board visibility is Public (public boards accept contributions from any user).
|
||||
"""
|
||||
from invokeai.app.services.board_records.board_records_common import BoardVisibility
|
||||
|
||||
try:
|
||||
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail="Board not found")
|
||||
if current_user.is_admin:
|
||||
return
|
||||
if board.user_id == current_user.user_id:
|
||||
return
|
||||
if board.board_visibility == BoardVisibility.Public:
|
||||
return
|
||||
raise HTTPException(status_code=403, detail="Not authorized to modify this board")
|
||||
|
||||
|
||||
def _assert_image_direct_owner(image_name: str, current_user: CurrentUserOrDefault) -> None:
|
||||
"""Raise 403 if the current user is not the direct owner of the image.
|
||||
|
||||
This is intentionally stricter than _assert_image_owner in images.py:
|
||||
board ownership is NOT sufficient here. Allowing a user to add someone
|
||||
else's image to their own board would grant them mutation rights via the
|
||||
board-ownership fallback in _assert_image_owner, escalating read access
|
||||
into write access.
|
||||
"""
|
||||
if current_user.is_admin:
|
||||
return
|
||||
owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name)
|
||||
if owner is not None and owner == current_user.user_id:
|
||||
return
|
||||
raise HTTPException(status_code=403, detail="Not authorized to move this image")
|
||||
|
||||
|
||||
@board_images_router.post(
|
||||
"/",
|
||||
operation_id="add_image_to_board",
|
||||
@@ -17,14 +58,17 @@ board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
|
||||
response_model=AddImagesToBoardResult,
|
||||
)
|
||||
async def add_image_to_board(
|
||||
current_user: CurrentUserOrDefault,
|
||||
board_id: str = Body(description="The id of the board to add to"),
|
||||
image_name: str = Body(description="The name of the image to add"),
|
||||
) -> AddImagesToBoardResult:
|
||||
"""Creates a board_image"""
|
||||
_assert_board_write_access(board_id, current_user)
|
||||
_assert_image_direct_owner(image_name, current_user)
|
||||
try:
|
||||
added_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
|
||||
old_board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name) or "none"
|
||||
ApiDependencies.invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name)
|
||||
added_images.add(image_name)
|
||||
affected_boards.add(board_id)
|
||||
@@ -48,13 +92,16 @@ async def add_image_to_board(
|
||||
response_model=RemoveImagesFromBoardResult,
|
||||
)
|
||||
async def remove_image_from_board(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_name: str = Body(description="The name of the image to remove", embed=True),
|
||||
) -> RemoveImagesFromBoardResult:
|
||||
"""Removes an image from its board, if it had one"""
|
||||
try:
|
||||
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
|
||||
if old_board_id != "none":
|
||||
_assert_board_write_access(old_board_id, current_user)
|
||||
removed_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
|
||||
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
|
||||
removed_images.add(image_name)
|
||||
affected_boards.add("none")
|
||||
@@ -64,6 +111,8 @@ async def remove_image_from_board(
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to remove image from board")
|
||||
|
||||
@@ -78,16 +127,21 @@ async def remove_image_from_board(
|
||||
response_model=AddImagesToBoardResult,
|
||||
)
|
||||
async def add_images_to_board(
|
||||
current_user: CurrentUserOrDefault,
|
||||
board_id: str = Body(description="The id of the board to add to"),
|
||||
image_names: list[str] = Body(description="The names of the images to add", embed=True),
|
||||
) -> AddImagesToBoardResult:
|
||||
"""Adds a list of images to a board"""
|
||||
_assert_board_write_access(board_id, current_user)
|
||||
try:
|
||||
added_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
for image_name in image_names:
|
||||
try:
|
||||
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
|
||||
_assert_image_direct_owner(image_name, current_user)
|
||||
old_board_id = (
|
||||
ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name) or "none"
|
||||
)
|
||||
ApiDependencies.invoker.services.board_images.add_image_to_board(
|
||||
board_id=board_id,
|
||||
image_name=image_name,
|
||||
@@ -96,12 +150,16 @@ async def add_images_to_board(
|
||||
affected_boards.add(board_id)
|
||||
affected_boards.add(old_board_id)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
return AddImagesToBoardResult(
|
||||
added_images=list(added_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to add images to board")
|
||||
|
||||
@@ -116,6 +174,7 @@ async def add_images_to_board(
|
||||
response_model=RemoveImagesFromBoardResult,
|
||||
)
|
||||
async def remove_images_from_board(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_names: list[str] = Body(description="The names of the images to remove", embed=True),
|
||||
) -> RemoveImagesFromBoardResult:
|
||||
"""Removes a list of images from their board, if they had one"""
|
||||
@@ -125,15 +184,21 @@ async def remove_images_from_board(
|
||||
for image_name in image_names:
|
||||
try:
|
||||
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
|
||||
if old_board_id != "none":
|
||||
_assert_board_write_access(old_board_id, current_user)
|
||||
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
|
||||
removed_images.add(image_name)
|
||||
affected_boards.add("none")
|
||||
affected_boards.add(old_board_id)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
return RemoveImagesFromBoardResult(
|
||||
removed_images=list(removed_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to remove images from board")
|
||||
|
||||
@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy
|
||||
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy, BoardVisibility
|
||||
from invokeai.app.services.boards.boards_common import BoardDTO
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
@@ -56,7 +56,14 @@ async def get_board(
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail="Board not found")
|
||||
|
||||
if not current_user.is_admin and result.user_id != current_user.user_id:
|
||||
# Admins can access any board.
|
||||
# Owners can access their own boards.
|
||||
# Shared and public boards are visible to all authenticated users.
|
||||
if (
|
||||
not current_user.is_admin
|
||||
and result.user_id != current_user.user_id
|
||||
and result.board_visibility == BoardVisibility.Private
|
||||
):
|
||||
raise HTTPException(status_code=403, detail="Not authorized to access this board")
|
||||
|
||||
return result
|
||||
@@ -188,7 +195,11 @@ async def list_all_board_image_names(
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail="Board not found")
|
||||
|
||||
if not current_user.is_admin and board.user_id != current_user.user_id:
|
||||
if (
|
||||
not current_user.is_admin
|
||||
and board.user_id != current_user.user_id
|
||||
and board.board_visibility == BoardVisibility.Private
|
||||
):
|
||||
raise HTTPException(status_code=403, detail="Not authorized to access this board")
|
||||
|
||||
image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
|
||||
@@ -196,4 +207,15 @@ async def list_all_board_image_names(
|
||||
categories,
|
||||
is_intermediate,
|
||||
)
|
||||
|
||||
# For uncategorized images (board_id="none"), filter to only the caller's
|
||||
# images so that one user cannot enumerate another's uncategorized images.
|
||||
# Admin users can see all uncategorized images.
|
||||
if board_id == "none" and not current_user.is_admin:
|
||||
image_names = [
|
||||
name
|
||||
for name in image_names
|
||||
if ApiDependencies.invoker.services.image_records.get_user_id(name) == current_user.user_id
|
||||
]
|
||||
|
||||
return image_names
|
||||
|
||||
@@ -38,6 +38,96 @@ images_router = APIRouter(prefix="/v1/images", tags=["images"])
|
||||
IMAGE_MAX_AGE = 31536000
|
||||
|
||||
|
||||
def _assert_image_owner(image_name: str, current_user: CurrentUserOrDefault) -> None:
|
||||
"""Raise 403 if the current user does not own the image and is not an admin.
|
||||
|
||||
Ownership is satisfied when ANY of these hold:
|
||||
- The user is an admin.
|
||||
- The user is the image's direct owner (image_records.user_id).
|
||||
- The user owns the board the image sits on.
|
||||
- The image sits on a Public board (public boards grant mutation rights).
|
||||
"""
|
||||
from invokeai.app.services.board_records.board_records_common import BoardVisibility
|
||||
|
||||
if current_user.is_admin:
|
||||
return
|
||||
owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name)
|
||||
if owner is not None and owner == current_user.user_id:
|
||||
return
|
||||
|
||||
# Check whether the user owns the board the image belongs to,
|
||||
# or the board is Public (public boards grant mutation rights).
|
||||
board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name)
|
||||
if board_id is not None:
|
||||
try:
|
||||
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
||||
if board.user_id == current_user.user_id:
|
||||
return
|
||||
if board.board_visibility == BoardVisibility.Public:
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
raise HTTPException(status_code=403, detail="Not authorized to modify this image")
|
||||
|
||||
|
||||
def _assert_image_read_access(image_name: str, current_user: CurrentUserOrDefault) -> None:
|
||||
"""Raise 403 if the current user may not view the image.
|
||||
|
||||
Access is granted when ANY of these hold:
|
||||
- The user is an admin.
|
||||
- The user owns the image.
|
||||
- The image sits on a shared or public board.
|
||||
"""
|
||||
from invokeai.app.services.board_records.board_records_common import BoardVisibility
|
||||
|
||||
if current_user.is_admin:
|
||||
return
|
||||
|
||||
owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name)
|
||||
if owner is not None and owner == current_user.user_id:
|
||||
return
|
||||
|
||||
# Check whether the image's board makes it visible to other users.
|
||||
board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name)
|
||||
if board_id is not None:
|
||||
try:
|
||||
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
||||
if board.board_visibility in (BoardVisibility.Shared, BoardVisibility.Public):
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
raise HTTPException(status_code=403, detail="Not authorized to access this image")
|
||||
|
||||
|
||||
def _assert_board_read_access(board_id: str, current_user: CurrentUserOrDefault) -> None:
|
||||
"""Raise 403 if the current user may not read images from this board.
|
||||
|
||||
Access is granted when ANY of these hold:
|
||||
- The user is an admin.
|
||||
- The user owns the board.
|
||||
- The board visibility is Shared or Public.
|
||||
"""
|
||||
from invokeai.app.services.board_records.board_records_common import BoardVisibility
|
||||
|
||||
if current_user.is_admin:
|
||||
return
|
||||
|
||||
try:
|
||||
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail="Board not found")
|
||||
|
||||
if board.user_id == current_user.user_id:
|
||||
return
|
||||
|
||||
if board.board_visibility in (BoardVisibility.Shared, BoardVisibility.Public):
|
||||
return
|
||||
|
||||
raise HTTPException(status_code=403, detail="Not authorized to access this board")
|
||||
|
||||
|
||||
class ResizeToDimensions(BaseModel):
|
||||
width: int = Field(..., gt=0)
|
||||
height: int = Field(..., gt=0)
|
||||
@@ -83,6 +173,22 @@ async def upload_image(
|
||||
),
|
||||
) -> ImageDTO:
|
||||
"""Uploads an image for the current user"""
|
||||
# If uploading into a board, verify the user has write access.
|
||||
# Public boards allow uploads from any authenticated user.
|
||||
if board_id is not None:
|
||||
from invokeai.app.services.board_records.board_records_common import BoardVisibility
|
||||
|
||||
try:
|
||||
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail="Board not found")
|
||||
if (
|
||||
not current_user.is_admin
|
||||
and board.user_id != current_user.user_id
|
||||
and board.board_visibility != BoardVisibility.Public
|
||||
):
|
||||
raise HTTPException(status_code=403, detail="Not authorized to upload to this board")
|
||||
|
||||
if not file.content_type or not file.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
@@ -165,9 +271,11 @@ async def create_image_upload_entry(
|
||||
|
||||
@images_router.delete("/i/{image_name}", operation_id="delete_image", response_model=DeleteImagesResult)
|
||||
async def delete_image(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_name: str = Path(description="The name of the image to delete"),
|
||||
) -> DeleteImagesResult:
|
||||
"""Deletes an image"""
|
||||
_assert_image_owner(image_name, current_user)
|
||||
|
||||
deleted_images: set[str] = set()
|
||||
affected_boards: set[str] = set()
|
||||
@@ -189,26 +297,31 @@ async def delete_image(
|
||||
|
||||
|
||||
@images_router.delete("/intermediates", operation_id="clear_intermediates")
|
||||
async def clear_intermediates() -> int:
|
||||
"""Clears all intermediates"""
|
||||
async def clear_intermediates(
|
||||
current_user: CurrentUserOrDefault,
|
||||
) -> int:
|
||||
"""Clears all intermediates. Requires admin."""
|
||||
if not current_user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="Only admins can clear all intermediates")
|
||||
|
||||
try:
|
||||
count_deleted = ApiDependencies.invoker.services.images.delete_intermediates()
|
||||
return count_deleted
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to clear intermediates")
|
||||
pass
|
||||
|
||||
|
||||
@images_router.get("/intermediates", operation_id="get_intermediates_count")
|
||||
async def get_intermediates_count() -> int:
|
||||
"""Gets the count of intermediate images"""
|
||||
async def get_intermediates_count(
|
||||
current_user: CurrentUserOrDefault,
|
||||
) -> int:
|
||||
"""Gets the count of intermediate images. Non-admin users only see their own intermediates."""
|
||||
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images.get_intermediates_count()
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
return ApiDependencies.invoker.services.images.get_intermediates_count(user_id=user_id)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to get intermediates")
|
||||
pass
|
||||
|
||||
|
||||
@images_router.patch(
|
||||
@@ -217,10 +330,12 @@ async def get_intermediates_count() -> int:
|
||||
response_model=ImageDTO,
|
||||
)
|
||||
async def update_image(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_name: str = Path(description="The name of the image to update"),
|
||||
image_changes: ImageRecordChanges = Body(description="The changes to apply to the image"),
|
||||
) -> ImageDTO:
|
||||
"""Updates an image"""
|
||||
_assert_image_owner(image_name, current_user)
|
||||
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images.update(image_name, image_changes)
|
||||
@@ -234,9 +349,11 @@ async def update_image(
|
||||
response_model=ImageDTO,
|
||||
)
|
||||
async def get_image_dto(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_name: str = Path(description="The name of image to get"),
|
||||
) -> ImageDTO:
|
||||
"""Gets an image's DTO"""
|
||||
_assert_image_read_access(image_name, current_user)
|
||||
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images.get_dto(image_name)
|
||||
@@ -250,9 +367,11 @@ async def get_image_dto(
|
||||
response_model=Optional[MetadataField],
|
||||
)
|
||||
async def get_image_metadata(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_name: str = Path(description="The name of image to get"),
|
||||
) -> Optional[MetadataField]:
|
||||
"""Gets an image's metadata"""
|
||||
_assert_image_read_access(image_name, current_user)
|
||||
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images.get_metadata(image_name)
|
||||
@@ -269,8 +388,11 @@ class WorkflowAndGraphResponse(BaseModel):
|
||||
"/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=WorkflowAndGraphResponse
|
||||
)
|
||||
async def get_image_workflow(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_name: str = Path(description="The name of image whose workflow to get"),
|
||||
) -> WorkflowAndGraphResponse:
|
||||
_assert_image_read_access(image_name, current_user)
|
||||
|
||||
try:
|
||||
workflow = ApiDependencies.invoker.services.images.get_workflow(image_name)
|
||||
graph = ApiDependencies.invoker.services.images.get_graph(image_name)
|
||||
@@ -306,8 +428,12 @@ async def get_image_workflow(
|
||||
async def get_image_full(
|
||||
image_name: str = Path(description="The name of full-resolution image file to get"),
|
||||
) -> Response:
|
||||
"""Gets a full-resolution image file"""
|
||||
"""Gets a full-resolution image file.
|
||||
|
||||
This endpoint is intentionally unauthenticated because browsers load images
|
||||
via <img src> tags which cannot send Bearer tokens. Image names are UUIDs,
|
||||
providing security through unguessability.
|
||||
"""
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.images.get_path(image_name)
|
||||
with open(path, "rb") as f:
|
||||
@@ -335,8 +461,12 @@ async def get_image_full(
|
||||
async def get_image_thumbnail(
|
||||
image_name: str = Path(description="The name of thumbnail image file to get"),
|
||||
) -> Response:
|
||||
"""Gets a thumbnail image file"""
|
||||
"""Gets a thumbnail image file.
|
||||
|
||||
This endpoint is intentionally unauthenticated because browsers load images
|
||||
via <img src> tags which cannot send Bearer tokens. Image names are UUIDs,
|
||||
providing security through unguessability.
|
||||
"""
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.images.get_path(image_name, thumbnail=True)
|
||||
with open(path, "rb") as f:
|
||||
@@ -354,9 +484,11 @@ async def get_image_thumbnail(
|
||||
response_model=ImageUrlsDTO,
|
||||
)
|
||||
async def get_image_urls(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_name: str = Path(description="The name of the image whose URL to get"),
|
||||
) -> ImageUrlsDTO:
|
||||
"""Gets an image and thumbnail URL"""
|
||||
_assert_image_read_access(image_name, current_user)
|
||||
|
||||
try:
|
||||
image_url = ApiDependencies.invoker.services.images.get_url(image_name)
|
||||
@@ -392,6 +524,11 @@ async def list_image_dtos(
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a list of image DTOs for the current user"""
|
||||
|
||||
# Validate that the caller can read from this board before listing its images.
|
||||
# "none" is a sentinel for uncategorized images and is handled by the SQL layer.
|
||||
if board_id is not None and board_id != "none":
|
||||
_assert_board_read_access(board_id, current_user)
|
||||
|
||||
image_dtos = ApiDependencies.invoker.services.images.get_many(
|
||||
offset,
|
||||
limit,
|
||||
@@ -410,6 +547,7 @@ async def list_image_dtos(
|
||||
|
||||
@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesResult)
|
||||
async def delete_images_from_list(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_names: list[str] = Body(description="The list of names of images to delete", embed=True),
|
||||
) -> DeleteImagesResult:
|
||||
try:
|
||||
@@ -417,24 +555,31 @@ async def delete_images_from_list(
|
||||
affected_boards: set[str] = set()
|
||||
for image_name in image_names:
|
||||
try:
|
||||
_assert_image_owner(image_name, current_user)
|
||||
image_dto = ApiDependencies.invoker.services.images.get_dto(image_name)
|
||||
board_id = image_dto.board_id or "none"
|
||||
ApiDependencies.invoker.services.images.delete(image_name)
|
||||
deleted_images.add(image_name)
|
||||
affected_boards.add(board_id)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
return DeleteImagesResult(
|
||||
deleted_images=list(deleted_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete images")
|
||||
|
||||
|
||||
@images_router.delete("/uncategorized", operation_id="delete_uncategorized_images", response_model=DeleteImagesResult)
|
||||
async def delete_uncategorized_images() -> DeleteImagesResult:
|
||||
"""Deletes all images that are uncategorized"""
|
||||
async def delete_uncategorized_images(
|
||||
current_user: CurrentUserOrDefault,
|
||||
) -> DeleteImagesResult:
|
||||
"""Deletes all uncategorized images owned by the current user (or all if admin)"""
|
||||
|
||||
image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
|
||||
board_id="none", categories=None, is_intermediate=None
|
||||
@@ -445,9 +590,13 @@ async def delete_uncategorized_images() -> DeleteImagesResult:
|
||||
affected_boards: set[str] = set()
|
||||
for image_name in image_names:
|
||||
try:
|
||||
_assert_image_owner(image_name, current_user)
|
||||
ApiDependencies.invoker.services.images.delete(image_name)
|
||||
deleted_images.add(image_name)
|
||||
affected_boards.add("none")
|
||||
except HTTPException:
|
||||
# Skip images not owned by the current user
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
return DeleteImagesResult(
|
||||
@@ -464,6 +613,7 @@ class ImagesUpdatedFromListResult(BaseModel):
|
||||
|
||||
@images_router.post("/star", operation_id="star_images_in_list", response_model=StarredImagesResult)
|
||||
async def star_images_in_list(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_names: list[str] = Body(description="The list of names of images to star", embed=True),
|
||||
) -> StarredImagesResult:
|
||||
try:
|
||||
@@ -471,23 +621,29 @@ async def star_images_in_list(
|
||||
affected_boards: set[str] = set()
|
||||
for image_name in image_names:
|
||||
try:
|
||||
_assert_image_owner(image_name, current_user)
|
||||
updated_image_dto = ApiDependencies.invoker.services.images.update(
|
||||
image_name, changes=ImageRecordChanges(starred=True)
|
||||
)
|
||||
starred_images.add(image_name)
|
||||
affected_boards.add(updated_image_dto.board_id or "none")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
return StarredImagesResult(
|
||||
starred_images=list(starred_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to star images")
|
||||
|
||||
|
||||
@images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=UnstarredImagesResult)
|
||||
async def unstar_images_in_list(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_names: list[str] = Body(description="The list of names of images to unstar", embed=True),
|
||||
) -> UnstarredImagesResult:
|
||||
try:
|
||||
@@ -495,17 +651,22 @@ async def unstar_images_in_list(
|
||||
affected_boards: set[str] = set()
|
||||
for image_name in image_names:
|
||||
try:
|
||||
_assert_image_owner(image_name, current_user)
|
||||
updated_image_dto = ApiDependencies.invoker.services.images.update(
|
||||
image_name, changes=ImageRecordChanges(starred=False)
|
||||
)
|
||||
unstarred_images.add(image_name)
|
||||
affected_boards.add(updated_image_dto.board_id or "none")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
return UnstarredImagesResult(
|
||||
unstarred_images=list(unstarred_images),
|
||||
affected_boards=list(affected_boards),
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to unstar images")
|
||||
|
||||
@@ -523,6 +684,7 @@ class ImagesDownloaded(BaseModel):
|
||||
"/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202
|
||||
)
|
||||
async def download_images_from_list(
|
||||
current_user: CurrentUserOrDefault,
|
||||
background_tasks: BackgroundTasks,
|
||||
image_names: Optional[list[str]] = Body(
|
||||
default=None, description="The list of names of images to download", embed=True
|
||||
@@ -533,6 +695,16 @@ async def download_images_from_list(
|
||||
) -> ImagesDownloaded:
|
||||
if (image_names is None or len(image_names) == 0) and board_id is None:
|
||||
raise HTTPException(status_code=400, detail="No images or board id specified.")
|
||||
|
||||
# Validate that the caller can read every image they are requesting.
|
||||
# For a board_id request, check board visibility; for explicit image names,
|
||||
# check each image individually.
|
||||
if board_id:
|
||||
_assert_board_read_access(board_id, current_user)
|
||||
if image_names:
|
||||
for name in image_names:
|
||||
_assert_image_read_access(name, current_user)
|
||||
|
||||
bulk_download_item_id: str = ApiDependencies.invoker.services.bulk_download.generate_item_id(board_id)
|
||||
|
||||
background_tasks.add_task(
|
||||
@@ -540,6 +712,7 @@ async def download_images_from_list(
|
||||
image_names,
|
||||
board_id,
|
||||
bulk_download_item_id,
|
||||
current_user.user_id,
|
||||
)
|
||||
return ImagesDownloaded(bulk_download_item_name=bulk_download_item_id + ".zip")
|
||||
|
||||
@@ -558,11 +731,21 @@ async def download_images_from_list(
|
||||
},
|
||||
)
|
||||
async def get_bulk_download_item(
|
||||
current_user: CurrentUserOrDefault,
|
||||
background_tasks: BackgroundTasks,
|
||||
bulk_download_item_name: str = Path(description="The bulk_download_item_name of the bulk download item to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets a bulk download zip file"""
|
||||
"""Gets a bulk download zip file.
|
||||
|
||||
Requires authentication. The caller must be the user who initiated the
|
||||
download (tracked by the bulk download service) or an admin.
|
||||
"""
|
||||
try:
|
||||
# Verify the caller owns this download (or is an admin)
|
||||
owner = ApiDependencies.invoker.services.bulk_download.get_owner(bulk_download_item_name)
|
||||
if owner is not None and owner != current_user.user_id and not current_user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to access this download")
|
||||
|
||||
path = ApiDependencies.invoker.services.bulk_download.get_path(bulk_download_item_name)
|
||||
|
||||
response = FileResponse(
|
||||
@@ -574,6 +757,8 @@ async def get_bulk_download_item(
|
||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||
background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.delete, bulk_download_item_name)
|
||||
return response
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
@@ -594,6 +779,10 @@ async def get_image_names(
|
||||
) -> ImageNamesResult:
|
||||
"""Gets ordered list of image names with metadata for optimistic updates"""
|
||||
|
||||
# Validate that the caller can read from this board before listing its images.
|
||||
if board_id is not None and board_id != "none":
|
||||
_assert_board_read_access(board_id, current_user)
|
||||
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.images.get_image_names(
|
||||
starred_first=starred_first,
|
||||
@@ -617,6 +806,7 @@ async def get_image_names(
|
||||
responses={200: {"model": list[ImageDTO]}},
|
||||
)
|
||||
async def get_images_by_names(
|
||||
current_user: CurrentUserOrDefault,
|
||||
image_names: list[str] = Body(embed=True, description="Object containing list of image names to fetch DTOs for"),
|
||||
) -> list[ImageDTO]:
|
||||
"""Gets image DTOs for the specified image names. Maintains order of input names."""
|
||||
@@ -628,8 +818,12 @@ async def get_images_by_names(
|
||||
image_dtos: list[ImageDTO] = []
|
||||
for name in image_names:
|
||||
try:
|
||||
_assert_image_read_access(name, current_user)
|
||||
dto = image_service.get_dto(name)
|
||||
image_dtos.append(dto)
|
||||
except HTTPException:
|
||||
# Skip images the user is not authorized to view
|
||||
continue
|
||||
except Exception:
|
||||
# Skip missing images - they may have been deleted between name fetch and DTO fetch
|
||||
continue
|
||||
|
||||
@@ -30,6 +30,7 @@ from invokeai.app.services.model_records import (
|
||||
)
|
||||
from invokeai.app.services.orphaned_models import OrphanedModelInfo
|
||||
from invokeai.app.util.suppress_output import SuppressOutput
|
||||
from invokeai.backend.model_manager.configs.external_api import ExternalApiModelConfig
|
||||
from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory
|
||||
from invokeai.backend.model_manager.configs.main import (
|
||||
Main_Checkpoint_SD1_Config,
|
||||
@@ -75,8 +76,36 @@ class CacheType(str, Enum):
|
||||
def add_cover_image_to_model_config(config: AnyModelConfig, dependencies: Type[ApiDependencies]) -> AnyModelConfig:
|
||||
"""Add a cover image URL to a model configuration."""
|
||||
cover_image = dependencies.invoker.services.model_images.get_url(config.key)
|
||||
config.cover_image = cover_image
|
||||
return config
|
||||
return config.model_copy(update={"cover_image": cover_image})
|
||||
|
||||
|
||||
def apply_external_starter_model_overrides(config: AnyModelConfig) -> AnyModelConfig:
|
||||
"""Overlay starter-model metadata onto installed external model configs."""
|
||||
if not isinstance(config, ExternalApiModelConfig):
|
||||
return config
|
||||
|
||||
starter_match = next((starter for starter in STARTER_MODELS if starter.source == config.source), None)
|
||||
if starter_match is None:
|
||||
return config
|
||||
|
||||
model_updates: dict[str, object] = {}
|
||||
if starter_match.capabilities is not None:
|
||||
model_updates["capabilities"] = starter_match.capabilities
|
||||
if starter_match.default_settings is not None:
|
||||
model_updates["default_settings"] = starter_match.default_settings
|
||||
if starter_match.panel_schema is not None:
|
||||
model_updates["panel_schema"] = starter_match.panel_schema
|
||||
|
||||
if not model_updates:
|
||||
return config
|
||||
|
||||
return config.model_copy(update=model_updates)
|
||||
|
||||
|
||||
def prepare_model_config_for_response(config: AnyModelConfig, dependencies: Type[ApiDependencies]) -> AnyModelConfig:
|
||||
"""Apply API-only model config overlays before returning a response."""
|
||||
config = apply_external_starter_model_overrides(config)
|
||||
return add_cover_image_to_model_config(config, dependencies)
|
||||
|
||||
|
||||
##############################################################################
|
||||
@@ -145,8 +174,8 @@ async def list_model_records(
|
||||
found_models.extend(
|
||||
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
|
||||
)
|
||||
for model in found_models:
|
||||
model = add_cover_image_to_model_config(model, ApiDependencies)
|
||||
for index, model in enumerate(found_models):
|
||||
found_models[index] = prepare_model_config_for_response(model, ApiDependencies)
|
||||
return ModelsList(models=found_models)
|
||||
|
||||
|
||||
@@ -166,6 +195,8 @@ async def list_missing_models() -> ModelsList:
|
||||
|
||||
missing_models: list[AnyModelConfig] = []
|
||||
for model_config in record_store.all_models():
|
||||
if model_config.base == BaseModelType.External or model_config.format == ModelFormat.ExternalApi:
|
||||
continue
|
||||
if not (models_path / model_config.path).resolve().exists():
|
||||
missing_models.append(model_config)
|
||||
|
||||
@@ -190,7 +221,24 @@ async def get_model_records_by_attrs(
|
||||
if not configs:
|
||||
raise HTTPException(status_code=404, detail="No model found with these attributes")
|
||||
|
||||
return configs[0]
|
||||
return prepare_model_config_for_response(configs[0], ApiDependencies)
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/get_by_hash",
|
||||
operation_id="get_model_records_by_hash",
|
||||
response_model=AnyModelConfig,
|
||||
)
|
||||
async def get_model_records_by_hash(
|
||||
hash: str = Query(description="The hash of the model"),
|
||||
) -> AnyModelConfig:
|
||||
"""Gets a model by its hash. This is useful for recalling models that were deleted and reinstalled,
|
||||
as the hash remains stable across reinstallations while the key (UUID) changes."""
|
||||
configs = ApiDependencies.invoker.services.model_manager.store.search_by_hash(hash)
|
||||
if not configs:
|
||||
raise HTTPException(status_code=404, detail="No model found with this hash")
|
||||
|
||||
return prepare_model_config_for_response(configs[0], ApiDependencies)
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
@@ -211,7 +259,7 @@ async def get_model_record(
|
||||
"""Get a model record"""
|
||||
try:
|
||||
config = ApiDependencies.invoker.services.model_manager.store.get_model(key)
|
||||
return add_cover_image_to_model_config(config, ApiDependencies)
|
||||
return prepare_model_config_for_response(config, ApiDependencies)
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
@@ -251,7 +299,7 @@ async def reidentify_model(
|
||||
result.config.name = config.name
|
||||
result.config.description = config.description
|
||||
result.config.cover_image = config.cover_image
|
||||
if hasattr(config, "trigger_phrases") and hasattr(result.config, "trigger_phrases"):
|
||||
if hasattr(result.config, "trigger_phrases") and hasattr(config, "trigger_phrases"):
|
||||
result.config.trigger_phrases = config.trigger_phrases
|
||||
result.config.source = config.source
|
||||
result.config.source_type = config.source_type
|
||||
@@ -375,7 +423,7 @@ async def update_model_record(
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
try:
|
||||
config = record_store.update_model(key, changes=changes, allow_class_change=True)
|
||||
config = add_cover_image_to_model_config(config, ApiDependencies)
|
||||
config = prepare_model_config_for_response(config, ApiDependencies)
|
||||
logger.info(f"Updated model: {key}")
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@@ -499,6 +547,19 @@ class BulkDeleteModelsResponse(BaseModel):
|
||||
failed: List[dict] = Field(description="List of failed deletions with error messages")
|
||||
|
||||
|
||||
class BulkReidentifyModelsRequest(BaseModel):
|
||||
"""Request body for bulk model reidentification."""
|
||||
|
||||
keys: List[str] = Field(description="List of model keys to reidentify")
|
||||
|
||||
|
||||
class BulkReidentifyModelsResponse(BaseModel):
|
||||
"""Response body for bulk model reidentification."""
|
||||
|
||||
succeeded: List[str] = Field(description="List of successfully reidentified model keys")
|
||||
failed: List[dict] = Field(description="List of failed reidentifications with error messages")
|
||||
|
||||
|
||||
@model_manager_router.post(
|
||||
"/i/bulk_delete",
|
||||
operation_id="bulk_delete_models",
|
||||
@@ -540,6 +601,67 @@ async def bulk_delete_models(
|
||||
return BulkDeleteModelsResponse(deleted=deleted, failed=failed)
|
||||
|
||||
|
||||
@model_manager_router.post(
|
||||
"/i/bulk_reidentify",
|
||||
operation_id="bulk_reidentify_models",
|
||||
responses={
|
||||
200: {"description": "Models reidentified (possibly with some failures)"},
|
||||
},
|
||||
status_code=200,
|
||||
)
|
||||
async def bulk_reidentify_models(
|
||||
current_admin: AdminUserOrDefault,
|
||||
request: BulkReidentifyModelsRequest = Body(description="List of model keys to reidentify"),
|
||||
) -> BulkReidentifyModelsResponse:
|
||||
"""
|
||||
Reidentify multiple models by re-probing their weights files.
|
||||
|
||||
Returns a list of successfully reidentified keys and failed reidentifications with error messages.
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
store = ApiDependencies.invoker.services.model_manager.store
|
||||
models_path = ApiDependencies.invoker.services.configuration.models_path
|
||||
|
||||
succeeded = []
|
||||
failed = []
|
||||
|
||||
for key in request.keys:
|
||||
try:
|
||||
config = store.get_model(key)
|
||||
if pathlib.Path(config.path).is_relative_to(models_path):
|
||||
model_path = pathlib.Path(config.path)
|
||||
else:
|
||||
model_path = models_path / config.path
|
||||
mod = ModelOnDisk(model_path)
|
||||
result = ModelConfigFactory.from_model_on_disk(mod)
|
||||
if result.config is None:
|
||||
raise InvalidModelException("Unable to identify model format")
|
||||
|
||||
# Retain user-editable fields from the original config
|
||||
result.config.path = config.path
|
||||
result.config.key = config.key
|
||||
result.config.name = config.name
|
||||
result.config.description = config.description
|
||||
result.config.cover_image = config.cover_image
|
||||
if hasattr(config, "trigger_phrases") and hasattr(result.config, "trigger_phrases"):
|
||||
result.config.trigger_phrases = config.trigger_phrases
|
||||
result.config.source = config.source
|
||||
result.config.source_type = config.source_type
|
||||
|
||||
store.replace_model(config.key, result.config)
|
||||
succeeded.append(key)
|
||||
logger.info(f"Reidentified model: {key}")
|
||||
except UnknownModelException as e:
|
||||
logger.error(f"Failed to reidentify model {key}: {str(e)}")
|
||||
failed.append({"key": key, "error": str(e)})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reidentify model {key}: {str(e)}")
|
||||
failed.append({"key": key, "error": str(e)})
|
||||
|
||||
logger.info(f"Bulk reidentify completed: {len(succeeded)} succeeded, {len(failed)} failed")
|
||||
return BulkReidentifyModelsResponse(succeeded=succeeded, failed=failed)
|
||||
|
||||
|
||||
@model_manager_router.delete(
|
||||
"/i/{key}/image",
|
||||
operation_id="delete_model_image",
|
||||
@@ -767,7 +889,7 @@ async def install_hugging_face_model(
|
||||
"/install",
|
||||
operation_id="list_model_installs",
|
||||
)
|
||||
async def list_model_installs() -> List[ModelInstallJob]:
|
||||
async def list_model_installs(current_admin: AdminUserOrDefault) -> List[ModelInstallJob]:
|
||||
"""Return the list of model install jobs.
|
||||
|
||||
Install jobs have a numeric `id`, a `status`, and other fields that provide information on
|
||||
@@ -799,7 +921,9 @@ async def list_model_installs() -> List[ModelInstallJob]:
|
||||
404: {"description": "No such job"},
|
||||
},
|
||||
)
|
||||
async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
|
||||
async def get_model_install_job(
|
||||
current_admin: AdminUserOrDefault, id: int = Path(description="Model install id")
|
||||
) -> ModelInstallJob:
|
||||
"""
|
||||
Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs'
|
||||
for information on the format of the return value.
|
||||
@@ -842,7 +966,9 @@ async def cancel_model_install_job(
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def pause_model_install_job(id: int = Path(description="Model install job ID")) -> ModelInstallJob:
|
||||
async def pause_model_install_job(
|
||||
current_admin: AdminUserOrDefault, id: int = Path(description="Model install job ID")
|
||||
) -> ModelInstallJob:
|
||||
"""Pause the model install job corresponding to the given job ID."""
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
try:
|
||||
@@ -862,7 +988,9 @@ async def pause_model_install_job(id: int = Path(description="Model install job
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def resume_model_install_job(id: int = Path(description="Model install job ID")) -> ModelInstallJob:
|
||||
async def resume_model_install_job(
|
||||
current_admin: AdminUserOrDefault, id: int = Path(description="Model install job ID")
|
||||
) -> ModelInstallJob:
|
||||
"""Resume a paused model install job corresponding to the given job ID."""
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
try:
|
||||
@@ -882,7 +1010,9 @@ async def resume_model_install_job(id: int = Path(description="Model install job
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def restart_failed_model_install_job(id: int = Path(description="Model install job ID")) -> ModelInstallJob:
|
||||
async def restart_failed_model_install_job(
|
||||
current_admin: AdminUserOrDefault, id: int = Path(description="Model install job ID")
|
||||
) -> ModelInstallJob:
|
||||
"""Restart failed or non-resumable file downloads for the given job."""
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
try:
|
||||
@@ -903,6 +1033,7 @@ async def restart_failed_model_install_job(id: int = Path(description="Model ins
|
||||
status_code=201,
|
||||
)
|
||||
async def restart_model_install_file(
|
||||
current_admin: AdminUserOrDefault,
|
||||
id: int = Path(description="Model install job ID"),
|
||||
file_source: AnyHttpUrl = Body(description="File download URL to restart"),
|
||||
) -> ModelInstallJob:
|
||||
@@ -1024,7 +1155,7 @@ async def convert_model(
|
||||
|
||||
# return the config record for the new diffusers directory
|
||||
new_config = store.get_model(new_key)
|
||||
new_config = add_cover_image_to_model_config(new_config, ApiDependencies)
|
||||
new_config = prepare_model_config_for_response(new_config, ApiDependencies)
|
||||
return new_config
|
||||
|
||||
|
||||
@@ -1214,7 +1345,7 @@ class DeleteOrphanedModelsResponse(BaseModel):
|
||||
operation_id="get_orphaned_models",
|
||||
response_model=list[OrphanedModelInfo],
|
||||
)
|
||||
async def get_orphaned_models() -> list[OrphanedModelInfo]:
|
||||
async def get_orphaned_models(_: AdminUserOrDefault) -> list[OrphanedModelInfo]:
|
||||
"""Find orphaned model directories.
|
||||
|
||||
Orphaned models are directories in the models folder that contain model files
|
||||
@@ -1241,7 +1372,9 @@ async def get_orphaned_models() -> list[OrphanedModelInfo]:
|
||||
operation_id="delete_orphaned_models",
|
||||
response_model=DeleteOrphanedModelsResponse,
|
||||
)
|
||||
async def delete_orphaned_models(request: DeleteOrphanedModelsRequest) -> DeleteOrphanedModelsResponse:
|
||||
async def delete_orphaned_models(
|
||||
request: DeleteOrphanedModelsRequest, _: AdminUserOrDefault
|
||||
) -> DeleteOrphanedModelsResponse:
|
||||
"""Delete specified orphaned model directories.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -7,6 +7,7 @@ from fastapi import Body, HTTPException, Path
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.backend.image_util.controlnet_processor import process_controlnet_image
|
||||
from invokeai.backend.model_manager.taxonomy import ModelType
|
||||
@@ -291,12 +292,58 @@ def resolve_ip_adapter_models(ip_adapters: list[IPAdapterRecallParameter]) -> li
|
||||
return resolved_adapters
|
||||
|
||||
|
||||
def _assert_recall_image_access(parameters: "RecallParameter", current_user: CurrentUserOrDefault) -> None:
|
||||
"""Validate that the caller can read every image referenced in the recall parameters.
|
||||
|
||||
Control layers and IP adapters may reference image_name fields. Without this
|
||||
check an attacker who knows another user's image UUID could use the recall
|
||||
endpoint to extract image dimensions and — for ControlNet preprocessors — mint
|
||||
a derived processed image they can then fetch.
|
||||
"""
|
||||
from invokeai.app.services.board_records.board_records_common import BoardVisibility
|
||||
|
||||
image_names: list[str] = []
|
||||
if parameters.control_layers:
|
||||
for layer in parameters.control_layers:
|
||||
if layer.image_name is not None:
|
||||
image_names.append(layer.image_name)
|
||||
if parameters.ip_adapters:
|
||||
for adapter in parameters.ip_adapters:
|
||||
if adapter.image_name is not None:
|
||||
image_names.append(adapter.image_name)
|
||||
|
||||
if not image_names:
|
||||
return
|
||||
|
||||
# Admin can access all images
|
||||
if current_user.is_admin:
|
||||
return
|
||||
|
||||
for image_name in image_names:
|
||||
owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name)
|
||||
if owner is not None and owner == current_user.user_id:
|
||||
continue
|
||||
|
||||
# Check board visibility
|
||||
board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name)
|
||||
if board_id is not None:
|
||||
try:
|
||||
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
||||
if board.board_visibility in (BoardVisibility.Shared, BoardVisibility.Public):
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
raise HTTPException(status_code=403, detail=f"Not authorized to access image {image_name}")
|
||||
|
||||
|
||||
@recall_parameters_router.post(
|
||||
"/{queue_id}",
|
||||
operation_id="update_recall_parameters",
|
||||
response_model=dict[str, Any],
|
||||
)
|
||||
async def update_recall_parameters(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(..., description="The queue id to perform this operation on"),
|
||||
parameters: RecallParameter = Body(..., description="Recall parameters to update"),
|
||||
) -> dict[str, Any]:
|
||||
@@ -328,6 +375,10 @@ async def update_recall_parameters(
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
# Validate image access before processing — prevents information leakage
|
||||
# (dimensions) and derived-image minting via ControlNet preprocessors.
|
||||
_assert_recall_image_access(parameters, current_user)
|
||||
|
||||
try:
|
||||
# Get only the parameters that were actually provided (non-None values)
|
||||
provided_params = {k: v for k, v in parameters.model_dump().items() if v is not None}
|
||||
@@ -335,14 +386,14 @@ async def update_recall_parameters(
|
||||
if not provided_params:
|
||||
return {"status": "no_parameters_provided", "updated_count": 0}
|
||||
|
||||
# Store each parameter in client state using a consistent key format
|
||||
# Store each parameter in client state scoped to the current user
|
||||
updated_count = 0
|
||||
for param_key, param_value in provided_params.items():
|
||||
# Convert parameter values to JSON strings for storage
|
||||
value_str = json.dumps(param_value)
|
||||
try:
|
||||
ApiDependencies.invoker.services.client_state_persistence.set_by_key(
|
||||
queue_id, f"recall_{param_key}", value_str
|
||||
current_user.user_id, f"recall_{param_key}", value_str
|
||||
)
|
||||
updated_count += 1
|
||||
except Exception as e:
|
||||
@@ -396,7 +447,9 @@ async def update_recall_parameters(
|
||||
logger.info(
|
||||
f"Emitting recall_parameters_updated event for queue {queue_id} with {len(provided_params)} parameters"
|
||||
)
|
||||
ApiDependencies.invoker.services.events.emit_recall_parameters_updated(queue_id, provided_params)
|
||||
ApiDependencies.invoker.services.events.emit_recall_parameters_updated(
|
||||
queue_id, current_user.user_id, provided_params
|
||||
)
|
||||
logger.info("Successfully emitted recall_parameters_updated event")
|
||||
except Exception as e:
|
||||
logger.error(f"Error emitting recall parameters event: {e}", exc_info=True)
|
||||
@@ -425,6 +478,7 @@ async def update_recall_parameters(
|
||||
response_model=dict[str, Any],
|
||||
)
|
||||
async def get_recall_parameters(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(..., description="The queue id to retrieve parameters for"),
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
|
||||
@@ -44,7 +44,8 @@ def sanitize_queue_item_for_user(
|
||||
"""Sanitize queue item for non-admin users viewing other users' items.
|
||||
|
||||
For non-admin users viewing queue items belonging to other users,
|
||||
the field_values, session graph, and workflow should be hidden/cleared to protect privacy.
|
||||
only timestamps, status, and error information are exposed. All other
|
||||
fields (user identity, generation parameters, graphs, workflows) are stripped.
|
||||
|
||||
Args:
|
||||
queue_item: The queue item to sanitize
|
||||
@@ -58,15 +59,25 @@ def sanitize_queue_item_for_user(
|
||||
if is_admin or queue_item.user_id == current_user_id:
|
||||
return queue_item
|
||||
|
||||
# For non-admins viewing other users' items, clear sensitive fields
|
||||
# Create a shallow copy to avoid mutating the original
|
||||
# For non-admins viewing other users' items, strip everything except
|
||||
# item_id, queue_id, status, and timestamps
|
||||
sanitized_item = queue_item.model_copy(deep=False)
|
||||
sanitized_item.user_id = "redacted"
|
||||
sanitized_item.user_display_name = None
|
||||
sanitized_item.user_email = None
|
||||
sanitized_item.batch_id = "redacted"
|
||||
sanitized_item.session_id = "redacted"
|
||||
sanitized_item.origin = None
|
||||
sanitized_item.destination = None
|
||||
sanitized_item.priority = 0
|
||||
sanitized_item.field_values = None
|
||||
sanitized_item.retried_from_item_id = None
|
||||
sanitized_item.workflow = None
|
||||
# Clear the session graph by replacing it with an empty graph execution state
|
||||
# This prevents information leakage through the generation graph
|
||||
sanitized_item.error_type = None
|
||||
sanitized_item.error_message = None
|
||||
sanitized_item.error_traceback = None
|
||||
sanitized_item.session = GraphExecutionState(
|
||||
id=queue_item.session.id,
|
||||
id="redacted",
|
||||
graph=Graph(),
|
||||
)
|
||||
return sanitized_item
|
||||
@@ -126,12 +137,16 @@ async def list_all_queue_items(
|
||||
},
|
||||
)
|
||||
async def get_queue_item_ids(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
|
||||
) -> ItemIdsResult:
|
||||
"""Gets all queue item ids that match the given parameters"""
|
||||
"""Gets all queue item ids that match the given parameters. Non-admin users only see their own items."""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.get_queue_item_ids(queue_id=queue_id, order_dir=order_dir)
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
return ApiDependencies.invoker.services.session_queue.get_queue_item_ids(
|
||||
queue_id=queue_id, order_dir=order_dir, user_id=user_id
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue item ids: {e}")
|
||||
|
||||
@@ -376,11 +391,15 @@ async def prune(
|
||||
},
|
||||
)
|
||||
async def get_current_queue_item(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> Optional[SessionQueueItem]:
|
||||
"""Gets the currently execution queue item"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
|
||||
item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
|
||||
if item is not None:
|
||||
item = sanitize_queue_item_for_user(item, current_user.user_id, current_user.is_admin)
|
||||
return item
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while getting current queue item: {e}")
|
||||
|
||||
@@ -393,11 +412,15 @@ async def get_current_queue_item(
|
||||
},
|
||||
)
|
||||
async def get_next_queue_item(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> Optional[SessionQueueItem]:
|
||||
"""Gets the next queue item, without executing it"""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
|
||||
item = ApiDependencies.invoker.services.session_queue.get_next(queue_id)
|
||||
if item is not None:
|
||||
item = sanitize_queue_item_for_user(item, current_user.user_id, current_user.is_admin)
|
||||
return item
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while getting next queue item: {e}")
|
||||
|
||||
@@ -413,9 +436,10 @@ async def get_queue_status(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> SessionQueueAndProcessorStatus:
|
||||
"""Gets the status of the session queue"""
|
||||
"""Gets the status of the session queue. Non-admin users see only their own counts and cannot see current item details unless they own it."""
|
||||
try:
|
||||
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id, user_id=current_user.user_id)
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id, user_id=user_id)
|
||||
processor = ApiDependencies.invoker.services.session_processor.get_status()
|
||||
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
|
||||
except Exception as e:
|
||||
@@ -430,12 +454,16 @@ async def get_queue_status(
|
||||
},
|
||||
)
|
||||
async def get_batch_status(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
batch_id: str = Path(description="The batch to get the status of"),
|
||||
) -> BatchStatus:
|
||||
"""Gets the status of the session queue"""
|
||||
"""Gets the status of a batch. Non-admin users only see their own batches."""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
return ApiDependencies.invoker.services.session_queue.get_batch_status(
|
||||
queue_id=queue_id, batch_id=batch_id, user_id=user_id
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while getting batch status: {e}")
|
||||
|
||||
@@ -529,13 +557,15 @@ async def cancel_queue_item(
|
||||
responses={200: {"model": SessionQueueCountsByDestination}},
|
||||
)
|
||||
async def counts_by_destination(
|
||||
current_user: CurrentUserOrDefault,
|
||||
queue_id: str = Path(description="The queue id to query"),
|
||||
destination: str = Query(description="The destination to query"),
|
||||
) -> SessionQueueCountsByDestination:
|
||||
"""Gets the counts of queue items by destination"""
|
||||
"""Gets the counts of queue items by destination. Non-admin users only see their own items."""
|
||||
try:
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
return ApiDependencies.invoker.services.session_queue.get_counts_by_destination(
|
||||
queue_id=queue_id, destination=destination
|
||||
queue_id=queue_id, destination=destination, user_id=user_id
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Unexpected error while fetching counts by destination: {e}")
|
||||
|
||||
@@ -6,6 +6,7 @@ from fastapi import APIRouter, Body, File, HTTPException, Path, Query, UploadFil
|
||||
from fastapi.responses import FileResponse
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
@@ -33,16 +34,25 @@ workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
|
||||
},
|
||||
)
|
||||
async def get_workflow(
|
||||
current_user: CurrentUserOrDefault,
|
||||
workflow_id: str = Path(description="The workflow to get"),
|
||||
) -> WorkflowRecordWithThumbnailDTO:
|
||||
"""Gets a workflow"""
|
||||
try:
|
||||
thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id)
|
||||
workflow = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
||||
return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump())
|
||||
except WorkflowNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
if config.multiuser:
|
||||
is_default = workflow.workflow.meta.category is WorkflowCategory.Default
|
||||
is_owner = workflow.user_id == current_user.user_id
|
||||
if not (is_default or is_owner or workflow.is_public or current_user.is_admin):
|
||||
raise HTTPException(status_code=403, detail="Not authorized to access this workflow")
|
||||
|
||||
thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id)
|
||||
return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump())
|
||||
|
||||
|
||||
@workflows_router.patch(
|
||||
"/i/{workflow_id}",
|
||||
@@ -52,10 +62,21 @@ async def get_workflow(
|
||||
},
|
||||
)
|
||||
async def update_workflow(
|
||||
current_user: CurrentUserOrDefault,
|
||||
workflow: Workflow = Body(description="The updated workflow", embed=True),
|
||||
) -> WorkflowRecordDTO:
|
||||
"""Updates a workflow"""
|
||||
return ApiDependencies.invoker.services.workflow_records.update(workflow=workflow)
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
if config.multiuser:
|
||||
try:
|
||||
existing = ApiDependencies.invoker.services.workflow_records.get(workflow.id)
|
||||
except WorkflowNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
if not current_user.is_admin and existing.user_id != current_user.user_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
|
||||
# Pass user_id for defense-in-depth SQL scoping; admins pass None to allow any.
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
return ApiDependencies.invoker.services.workflow_records.update(workflow=workflow, user_id=user_id)
|
||||
|
||||
|
||||
@workflows_router.delete(
|
||||
@@ -63,15 +84,25 @@ async def update_workflow(
|
||||
operation_id="delete_workflow",
|
||||
)
|
||||
async def delete_workflow(
|
||||
current_user: CurrentUserOrDefault,
|
||||
workflow_id: str = Path(description="The workflow to delete"),
|
||||
) -> None:
|
||||
"""Deletes a workflow"""
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
if config.multiuser:
|
||||
try:
|
||||
existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
||||
except WorkflowNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
if not current_user.is_admin and existing.user_id != current_user.user_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to delete this workflow")
|
||||
try:
|
||||
ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id)
|
||||
except WorkflowThumbnailFileNotFoundException:
|
||||
# It's OK if the workflow has no thumbnail file. We can still delete the workflow.
|
||||
pass
|
||||
ApiDependencies.invoker.services.workflow_records.delete(workflow_id)
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
ApiDependencies.invoker.services.workflow_records.delete(workflow_id, user_id=user_id)
|
||||
|
||||
|
||||
@workflows_router.post(
|
||||
@@ -82,10 +113,17 @@ async def delete_workflow(
|
||||
},
|
||||
)
|
||||
async def create_workflow(
|
||||
current_user: CurrentUserOrDefault,
|
||||
workflow: WorkflowWithoutID = Body(description="The workflow to create", embed=True),
|
||||
) -> WorkflowRecordDTO:
|
||||
"""Creates a workflow"""
|
||||
return ApiDependencies.invoker.services.workflow_records.create(workflow=workflow)
|
||||
# In single-user mode, workflows are owned by 'system' and shared by default so all legacy/single-user
|
||||
# workflows remain visible. In multiuser mode, workflows are private to the creator by default.
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
is_public = not config.multiuser
|
||||
return ApiDependencies.invoker.services.workflow_records.create(
|
||||
workflow=workflow, user_id=current_user.user_id, is_public=is_public
|
||||
)
|
||||
|
||||
|
||||
@workflows_router.get(
|
||||
@@ -96,6 +134,7 @@ async def create_workflow(
|
||||
},
|
||||
)
|
||||
async def list_workflows(
|
||||
current_user: CurrentUserOrDefault,
|
||||
page: int = Query(default=0, description="The page to get"),
|
||||
per_page: Optional[int] = Query(default=None, description="The number of workflows per page"),
|
||||
order_by: WorkflowRecordOrderBy = Query(
|
||||
@@ -106,8 +145,19 @@ async def list_workflows(
|
||||
tags: Optional[list[str]] = Query(default=None, description="The tags of workflow to get"),
|
||||
query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"),
|
||||
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
|
||||
is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"),
|
||||
) -> PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]:
|
||||
"""Gets a page of workflows"""
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
|
||||
# In multiuser mode, scope user-category workflows to the current user unless fetching shared workflows.
|
||||
# Admins skip the user_id filter so they can see and manage all workflows including system-owned ones.
|
||||
user_id_filter: Optional[str] = None
|
||||
if config.multiuser and not current_user.is_admin:
|
||||
has_user_category = not categories or WorkflowCategory.User in categories
|
||||
if has_user_category and is_public is not True:
|
||||
user_id_filter = current_user.user_id
|
||||
|
||||
workflows_with_thumbnails: list[WorkflowRecordListItemWithThumbnailDTO] = []
|
||||
workflows = ApiDependencies.invoker.services.workflow_records.get_many(
|
||||
order_by=order_by,
|
||||
@@ -118,6 +168,8 @@ async def list_workflows(
|
||||
categories=categories,
|
||||
tags=tags,
|
||||
has_been_opened=has_been_opened,
|
||||
user_id=user_id_filter,
|
||||
is_public=is_public,
|
||||
)
|
||||
for workflow in workflows.items:
|
||||
workflows_with_thumbnails.append(
|
||||
@@ -143,15 +195,20 @@ async def list_workflows(
|
||||
},
|
||||
)
|
||||
async def set_workflow_thumbnail(
|
||||
current_user: CurrentUserOrDefault,
|
||||
workflow_id: str = Path(description="The workflow to update"),
|
||||
image: UploadFile = File(description="The image file to upload"),
|
||||
):
|
||||
"""Sets a workflow's thumbnail image"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
||||
existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
||||
except WorkflowNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
|
||||
|
||||
if not image.content_type or not image.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
@@ -177,14 +234,19 @@ async def set_workflow_thumbnail(
|
||||
},
|
||||
)
|
||||
async def delete_workflow_thumbnail(
|
||||
current_user: CurrentUserOrDefault,
|
||||
workflow_id: str = Path(description="The workflow to update"),
|
||||
):
|
||||
"""Removes a workflow's thumbnail image"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
||||
existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
||||
except WorkflowNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id)
|
||||
except ValueError as e:
|
||||
@@ -206,8 +268,12 @@ async def delete_workflow_thumbnail(
|
||||
async def get_workflow_thumbnail(
|
||||
workflow_id: str = Path(description="The id of the workflow thumbnail to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets a workflow's thumbnail image"""
|
||||
"""Gets a workflow's thumbnail image.
|
||||
|
||||
This endpoint is intentionally unauthenticated because browsers load images
|
||||
via <img src> tags which cannot send Bearer tokens. Workflow IDs are UUIDs,
|
||||
providing security through unguessability.
|
||||
"""
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.workflow_thumbnails.get_path(workflow_id)
|
||||
|
||||
@@ -223,37 +289,91 @@ async def get_workflow_thumbnail(
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@workflows_router.patch(
|
||||
"/i/{workflow_id}/is_public",
|
||||
operation_id="update_workflow_is_public",
|
||||
responses={
|
||||
200: {"model": WorkflowRecordDTO},
|
||||
},
|
||||
)
|
||||
async def update_workflow_is_public(
|
||||
current_user: CurrentUserOrDefault,
|
||||
workflow_id: str = Path(description="The workflow to update"),
|
||||
is_public: bool = Body(description="Whether the workflow should be shared publicly", embed=True),
|
||||
) -> WorkflowRecordDTO:
|
||||
"""Updates whether a workflow is shared publicly"""
|
||||
try:
|
||||
existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
||||
except WorkflowNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
|
||||
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
return ApiDependencies.invoker.services.workflow_records.update_is_public(
|
||||
workflow_id=workflow_id, is_public=is_public, user_id=user_id
|
||||
)
|
||||
|
||||
|
||||
@workflows_router.get("/tags", operation_id="get_all_tags")
|
||||
async def get_all_tags(
|
||||
current_user: CurrentUserOrDefault,
|
||||
categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"),
|
||||
is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"),
|
||||
) -> list[str]:
|
||||
"""Gets all unique tags from workflows"""
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
user_id_filter: Optional[str] = None
|
||||
if config.multiuser and not current_user.is_admin:
|
||||
has_user_category = not categories or WorkflowCategory.User in categories
|
||||
if has_user_category and is_public is not True:
|
||||
user_id_filter = current_user.user_id
|
||||
|
||||
return ApiDependencies.invoker.services.workflow_records.get_all_tags(categories=categories)
|
||||
return ApiDependencies.invoker.services.workflow_records.get_all_tags(
|
||||
categories=categories, user_id=user_id_filter, is_public=is_public
|
||||
)
|
||||
|
||||
|
||||
@workflows_router.get("/counts_by_tag", operation_id="get_counts_by_tag")
|
||||
async def get_counts_by_tag(
|
||||
current_user: CurrentUserOrDefault,
|
||||
tags: list[str] = Query(description="The tags to get counts for"),
|
||||
categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"),
|
||||
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
|
||||
is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"),
|
||||
) -> dict[str, int]:
|
||||
"""Counts workflows by tag"""
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
user_id_filter: Optional[str] = None
|
||||
if config.multiuser and not current_user.is_admin:
|
||||
has_user_category = not categories or WorkflowCategory.User in categories
|
||||
if has_user_category and is_public is not True:
|
||||
user_id_filter = current_user.user_id
|
||||
|
||||
return ApiDependencies.invoker.services.workflow_records.counts_by_tag(
|
||||
tags=tags, categories=categories, has_been_opened=has_been_opened
|
||||
tags=tags, categories=categories, has_been_opened=has_been_opened, user_id=user_id_filter, is_public=is_public
|
||||
)
|
||||
|
||||
|
||||
@workflows_router.get("/counts_by_category", operation_id="counts_by_category")
|
||||
async def counts_by_category(
|
||||
current_user: CurrentUserOrDefault,
|
||||
categories: list[WorkflowCategory] = Query(description="The categories to include"),
|
||||
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
|
||||
is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"),
|
||||
) -> dict[str, int]:
|
||||
"""Counts workflows by category"""
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
user_id_filter: Optional[str] = None
|
||||
if config.multiuser and not current_user.is_admin:
|
||||
has_user_category = WorkflowCategory.User in categories
|
||||
if has_user_category and is_public is not True:
|
||||
user_id_filter = current_user.user_id
|
||||
|
||||
return ApiDependencies.invoker.services.workflow_records.counts_by_category(
|
||||
categories=categories, has_been_opened=has_been_opened
|
||||
categories=categories, has_been_opened=has_been_opened, user_id=user_id_filter, is_public=is_public
|
||||
)
|
||||
|
||||
|
||||
@@ -262,7 +382,18 @@ async def counts_by_category(
|
||||
operation_id="update_opened_at",
|
||||
)
|
||||
async def update_opened_at(
|
||||
current_user: CurrentUserOrDefault,
|
||||
workflow_id: str = Path(description="The workflow to update"),
|
||||
) -> None:
|
||||
"""Updates the opened_at field of a workflow"""
|
||||
ApiDependencies.invoker.services.workflow_records.update_opened_at(workflow_id)
|
||||
try:
|
||||
existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
||||
except WorkflowNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
config = ApiDependencies.invoker.services.configuration
|
||||
if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
|
||||
|
||||
user_id = None if current_user.is_admin else current_user.user_id
|
||||
ApiDependencies.invoker.services.workflow_records.update_opened_at(workflow_id, user_id=user_id)
|
||||
|
||||
@@ -121,6 +121,11 @@ class SocketIO:
|
||||
|
||||
Returns True to accept the connection, False to reject it.
|
||||
Stores user_id in the internal socket users dict for later use.
|
||||
|
||||
In multiuser mode, connections without a valid token are rejected outright
|
||||
so that anonymous clients cannot subscribe to queue rooms and observe
|
||||
queue activity belonging to other users. In single-user mode, unauthenticated
|
||||
connections are accepted as the system admin user.
|
||||
"""
|
||||
# Extract token from auth data or headers
|
||||
token = None
|
||||
@@ -137,6 +142,23 @@ class SocketIO:
|
||||
if token:
|
||||
token_data = verify_token(token)
|
||||
if token_data:
|
||||
# In multiuser mode, also verify the backing user record still
|
||||
# exists and is active — mirrors the REST auth check in
|
||||
# auth_dependencies.py. A deleted or deactivated user whose
|
||||
# JWT has not yet expired must not be allowed to open a socket.
|
||||
if self._is_multiuser_enabled():
|
||||
try:
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
|
||||
user = ApiDependencies.invoker.services.users.get(token_data.user_id)
|
||||
if user is None or not user.is_active:
|
||||
logger.warning(f"Rejecting socket {sid}: user {token_data.user_id} not found or inactive")
|
||||
return False
|
||||
except Exception:
|
||||
# If user service is unavailable, fail closed
|
||||
logger.warning(f"Rejecting socket {sid}: unable to verify user record")
|
||||
return False
|
||||
|
||||
# Store user_id and is_admin in socket users dict
|
||||
self._socket_users[sid] = {
|
||||
"user_id": token_data.user_id,
|
||||
@@ -147,14 +169,37 @@ class SocketIO:
|
||||
)
|
||||
return True
|
||||
|
||||
# If no valid token, store system user for backward compatibility
|
||||
# No valid token provided. In multiuser mode this is not allowed — reject
|
||||
# the connection so anonymous clients cannot subscribe to queue rooms.
|
||||
# In single-user mode, fall through and accept the socket as system admin.
|
||||
if self._is_multiuser_enabled():
|
||||
logger.warning(
|
||||
f"Rejecting socket {sid} connection: multiuser mode is enabled and no valid auth token was provided"
|
||||
)
|
||||
return False
|
||||
|
||||
self._socket_users[sid] = {
|
||||
"user_id": "system",
|
||||
"is_admin": False,
|
||||
"is_admin": True,
|
||||
}
|
||||
logger.debug(f"Socket {sid} connected as system user (no valid token)")
|
||||
logger.debug(f"Socket {sid} connected as system admin (single-user mode)")
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _is_multiuser_enabled() -> bool:
|
||||
"""Check whether multiuser mode is enabled. Fails closed if configuration
|
||||
is not yet initialized, which should not happen in practice but prevents
|
||||
accidentally opening the socket during startup races."""
|
||||
try:
|
||||
# Imported here to avoid a circular import at module load time.
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
|
||||
return bool(ApiDependencies.invoker.services.configuration.multiuser)
|
||||
except Exception:
|
||||
# If dependencies are not initialized, fail closed (treat as multiuser)
|
||||
# so we never accidentally admit an anonymous socket.
|
||||
return True
|
||||
|
||||
async def _handle_disconnect(self, sid: str) -> None:
|
||||
"""Handle socket disconnection and cleanup user info."""
|
||||
if sid in self._socket_users:
|
||||
@@ -165,15 +210,20 @@ class SocketIO:
|
||||
"""Handle queue subscription and add socket to both queue and user-specific rooms."""
|
||||
queue_id = QueueSubscriptionEvent(**data).queue_id
|
||||
|
||||
# Check if we have user info for this socket
|
||||
# Check if we have user info for this socket. In multiuser mode _handle_connect
|
||||
# will have already rejected any socket without a valid token, so missing user
|
||||
# info here is a bug — refuse the subscription rather than silently falling back
|
||||
# to an anonymous system user who could then receive queue item events.
|
||||
if sid not in self._socket_users:
|
||||
logger.warning(
|
||||
f"Socket {sid} subscribing to queue {queue_id} but has no user info - need to authenticate via connect event"
|
||||
)
|
||||
# Store as system user temporarily - real auth should happen in connect
|
||||
if self._is_multiuser_enabled():
|
||||
logger.warning(
|
||||
f"Refusing queue subscription for socket {sid}: no user info (socket not authenticated via connect event)"
|
||||
)
|
||||
return
|
||||
# Single-user mode: safe to fall back to the system admin user.
|
||||
self._socket_users[sid] = {
|
||||
"user_id": "system",
|
||||
"is_admin": False,
|
||||
"is_admin": True,
|
||||
}
|
||||
|
||||
user_id = self._socket_users[sid]["user_id"]
|
||||
@@ -198,6 +248,13 @@ class SocketIO:
|
||||
await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id)
|
||||
|
||||
async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None:
|
||||
# In multiuser mode, only allow authenticated sockets to subscribe.
|
||||
# Bulk download events are routed to user-specific rooms, so the
|
||||
# bulk_download_id room subscription is only kept for single-user
|
||||
# backward compatibility.
|
||||
if self._is_multiuser_enabled() and sid not in self._socket_users:
|
||||
logger.warning(f"Refusing bulk download subscription for unknown socket {sid} in multiuser mode")
|
||||
return
|
||||
await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
|
||||
|
||||
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
|
||||
@@ -206,9 +263,17 @@ class SocketIO:
|
||||
async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
|
||||
"""Handle queue events with user isolation.
|
||||
|
||||
Invocation events (progress, started, complete) are private - only emit to owner and admins.
|
||||
Queue item status events are public - emit to all users (field values hidden via API).
|
||||
Other queue events emit to all subscribers.
|
||||
All queue item events (invocation events AND QueueItemStatusChangedEvent) are
|
||||
private to the owning user and admins. They carry unsanitized user_id, batch_id,
|
||||
session_id, origin, destination and error metadata, and must never be broadcast
|
||||
to the whole queue room — otherwise any other authenticated subscriber could
|
||||
observe cross-user queue activity.
|
||||
|
||||
RecallParametersUpdatedEvent is also private to the owner + admins.
|
||||
|
||||
BatchEnqueuedEvent carries the enqueuing user's batch_id/origin/counts and
|
||||
is also routed privately. QueueClearedEvent is the only queue event that
|
||||
is still broadcast to the whole queue room.
|
||||
|
||||
IMPORTANT: Check InvocationEventBase BEFORE QueueItemEventBase since InvocationEventBase
|
||||
inherits from QueueItemEventBase. The order of isinstance checks matters!
|
||||
@@ -237,24 +302,40 @@ class SocketIO:
|
||||
|
||||
logger.debug(f"Emitted private invocation event {event_name} to user room {user_room} and admin room")
|
||||
|
||||
# Queue item status events are visible to all users (field values masked via API)
|
||||
# This catches QueueItemStatusChangedEvent but NOT InvocationEvents (already handled above)
|
||||
# Other queue item events (QueueItemStatusChangedEvent) carry unsanitized
|
||||
# user_id, batch_id, session_id, origin, destination and error metadata.
|
||||
# They are private to the owning user + admins — never broadcast to the
|
||||
# full queue room.
|
||||
elif isinstance(event_data, QueueItemEventBase) and hasattr(event_data, "user_id"):
|
||||
# Emit to all subscribers in the queue
|
||||
await self._sio.emit(
|
||||
event=event_name, data=event_data.model_dump(mode="json"), room=event_data.queue_id
|
||||
)
|
||||
user_room = f"user:{event_data.user_id}"
|
||||
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room)
|
||||
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin")
|
||||
|
||||
logger.info(
|
||||
f"Emitted public queue item event {event_name} to all subscribers in queue {event_data.queue_id}"
|
||||
)
|
||||
logger.debug(f"Emitted private queue item event {event_name} to user room {user_room} and admin room")
|
||||
|
||||
# RecallParametersUpdatedEvent is private - only emit to owner + admins
|
||||
elif isinstance(event_data, RecallParametersUpdatedEvent):
|
||||
user_room = f"user:{event_data.user_id}"
|
||||
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room)
|
||||
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin")
|
||||
logger.debug(f"Emitted private recall_parameters_updated event to user room {user_room} and admin room")
|
||||
|
||||
# BatchEnqueuedEvent carries the enqueuing user's batch_id, origin, and
|
||||
# enqueued counts. Route it privately to the owner + admins so other
|
||||
# users do not observe cross-user batch activity.
|
||||
elif isinstance(event_data, BatchEnqueuedEvent):
|
||||
user_room = f"user:{event_data.user_id}"
|
||||
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room)
|
||||
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin")
|
||||
logger.debug(f"Emitted private batch_enqueued event to user room {user_room} and admin room")
|
||||
|
||||
else:
|
||||
# For other queue events (like QueueClearedEvent, BatchEnqueuedEvent), emit to all subscribers
|
||||
# For remaining queue events (e.g. QueueClearedEvent) that do not
|
||||
# carry user identity, emit to all subscribers in the queue room.
|
||||
await self._sio.emit(
|
||||
event=event_name, data=event_data.model_dump(mode="json"), room=event_data.queue_id
|
||||
)
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"Emitted general queue event {event_name} to all subscribers in queue {event_data.queue_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -265,4 +346,17 @@ class SocketIO:
|
||||
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"))
|
||||
|
||||
async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None:
|
||||
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].bulk_download_id)
|
||||
event_name, event_data = event
|
||||
# Route to user-specific + admin rooms so that other authenticated
|
||||
# users cannot learn the bulk_download_item_name (the capability token
|
||||
# needed to fetch the zip from the unauthenticated GET endpoint).
|
||||
# In single-user mode (user_id="system"), fall back to the shared
|
||||
# bulk_download_id room for backward compatibility.
|
||||
if hasattr(event_data, "user_id") and event_data.user_id != "system":
|
||||
user_room = f"user:{event_data.user_id}"
|
||||
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room)
|
||||
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin")
|
||||
else:
|
||||
await self._sio.emit(
|
||||
event=event_name, data=event_data.model_dump(mode="json"), room=event_data.bulk_download_id
|
||||
)
|
||||
|
||||
@@ -79,6 +79,50 @@ app = FastAPI(
|
||||
)
|
||||
|
||||
|
||||
class SlidingWindowTokenMiddleware(BaseHTTPMiddleware):
|
||||
"""Refresh the JWT token on each authenticated response.
|
||||
|
||||
When a request includes a valid Bearer token, the response includes a
|
||||
X-Refreshed-Token header with a new token that has a fresh expiry.
|
||||
This implements sliding-window session expiry: the session only expires
|
||||
after a period of *inactivity*, not a fixed time after login.
|
||||
"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
|
||||
response = await call_next(request)
|
||||
|
||||
# Only refresh on mutating requests (POST/PUT/PATCH/DELETE) — these indicate
|
||||
# genuine user activity. GET requests are often background fetches (RTK Query
|
||||
# cache revalidation, refetch-on-focus, etc.) and should not reset the
|
||||
# inactivity timer.
|
||||
if response.status_code < 400 and request.method in ("POST", "PUT", "PATCH", "DELETE"):
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
token = auth_header[7:]
|
||||
try:
|
||||
from datetime import timedelta
|
||||
|
||||
from invokeai.app.api.routers.auth import TOKEN_EXPIRATION_NORMAL, TOKEN_EXPIRATION_REMEMBER_ME
|
||||
from invokeai.app.services.auth.token_service import create_access_token, verify_token
|
||||
|
||||
token_data = verify_token(token)
|
||||
if token_data is not None:
|
||||
# Use the remember_me claim from the token to determine the
|
||||
# correct refresh duration. This avoids the bug where a 7-day
|
||||
# token with <24h remaining would be silently downgraded to 1 day.
|
||||
if token_data.remember_me:
|
||||
expires_delta = timedelta(days=TOKEN_EXPIRATION_REMEMBER_ME)
|
||||
else:
|
||||
expires_delta = timedelta(days=TOKEN_EXPIRATION_NORMAL)
|
||||
|
||||
new_token = create_access_token(token_data, expires_delta)
|
||||
response.headers["X-Refreshed-Token"] = new_token
|
||||
except Exception:
|
||||
pass # Don't fail the request if token refresh fails
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class RedirectRootWithQueryStringMiddleware(BaseHTTPMiddleware):
|
||||
"""When a request is made to the root path with a query string, redirect to the root path without the query string.
|
||||
|
||||
@@ -99,6 +143,7 @@ class RedirectRootWithQueryStringMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
# Add the middleware
|
||||
app.add_middleware(RedirectRootWithQueryStringMiddleware)
|
||||
app.add_middleware(SlidingWindowTokenMiddleware)
|
||||
|
||||
|
||||
# Add event handler
|
||||
@@ -117,6 +162,7 @@ app.add_middleware(
|
||||
allow_credentials=app_config.allow_credentials,
|
||||
allow_methods=app_config.allow_methods,
|
||||
allow_headers=app_config.allow_headers,
|
||||
expose_headers=["X-Refreshed-Token"],
|
||||
)
|
||||
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
|
||||
715
invokeai/app/invocations/anima_denoise.py
Normal file
715
invokeai/app/invocations/anima_denoise.py
Normal file
@@ -0,0 +1,715 @@
|
||||
"""Anima denoising invocation.
|
||||
|
||||
Implements the rectified flow denoising loop for Anima models:
|
||||
- Direct prediction: denoised = input - output * sigma
|
||||
- Fixed shift=3.0 via loglinear_timestep_shift (Flux paper by Black Forest Labs)
|
||||
- Timestep convention: timestep = sigma * 1.0 (raw sigma, NOT 1-sigma like Z-Image)
|
||||
- NO v-prediction negation (unlike Z-Image)
|
||||
- 3D latent space: [B, C, T, H, W] with T=1 for images
|
||||
- 16 latent channels, 8x spatial compression
|
||||
|
||||
Key differences from Z-Image denoise:
|
||||
- Anima uses fixed shift=3.0, Z-Image uses dynamic shift based on resolution
|
||||
- Anima: timestep = sigma (raw), Z-Image: model_t = 1.0 - sigma
|
||||
- Anima: noise_pred = model_output (direct), Z-Image: noise_pred = -model_output (v-pred)
|
||||
- Anima transformer takes (x, timesteps, context, t5xxl_ids, t5xxl_weights)
|
||||
- Anima uses 3D latents directly, Z-Image converts 4D -> list of 5D
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from contextlib import ExitStack
|
||||
from typing import Callable, Iterator, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as tv_transforms
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
AnimaConditioningField,
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
)
|
||||
from invokeai.app.invocations.model import TransformerField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.anima.anima_transformer_patch import patch_anima_for_regional_prompting
|
||||
from invokeai.backend.anima.conditioning_data import AnimaRegionalTextConditioning, AnimaTextConditioning
|
||||
from invokeai.backend.anima.regional_prompting import AnimaRegionalPromptingExtension
|
||||
from invokeai.backend.flux.schedulers import ANIMA_SCHEDULER_LABELS, ANIMA_SCHEDULER_MAP, ANIMA_SCHEDULER_NAME_VALUES
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.lora_conversions.anima_lora_constants import ANIMA_LORA_TRANSFORMER_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import (
|
||||
RectifiedFlowInpaintExtension,
|
||||
assert_broadcastable,
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import AnimaConditioningInfo, Range
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
# Anima uses 8x spatial compression (VAE downsamples by 2^3)
|
||||
ANIMA_LATENT_SCALE_FACTOR = 8
|
||||
# Anima uses 16 latent channels
|
||||
ANIMA_LATENT_CHANNELS = 16
|
||||
# Anima uses fixed shift=3.0 for the rectified flow schedule
|
||||
ANIMA_SHIFT = 3.0
|
||||
# Anima uses raw sigma values as timesteps (no rescaling)
|
||||
ANIMA_MULTIPLIER = 1.0
|
||||
|
||||
|
||||
def loglinear_timestep_shift(alpha: float, t: float) -> float:
|
||||
"""Apply log-linear timestep shift to a noise schedule value.
|
||||
|
||||
This shift biases the noise schedule toward higher noise levels, as described
|
||||
in the Flux model (Black Forest Labs, 2024). With alpha > 1, the model spends
|
||||
proportionally more denoising steps at higher noise levels.
|
||||
|
||||
Formula: sigma = alpha * t / (1 + (alpha - 1) * t)
|
||||
|
||||
Args:
|
||||
alpha: Shift factor (3.0 for Anima, resolution-dependent for Flux).
|
||||
t: Timestep value in [0, 1].
|
||||
|
||||
Returns:
|
||||
Shifted timestep value.
|
||||
"""
|
||||
if alpha == 1.0:
|
||||
return t
|
||||
return alpha * t / (1 + (alpha - 1) * t)
|
||||
|
||||
|
||||
def inverse_loglinear_timestep_shift(alpha: float, sigma: float) -> float:
|
||||
"""Recover linear t from a shifted sigma value.
|
||||
|
||||
Inverse of loglinear_timestep_shift: given sigma = alpha * t / (1 + (alpha-1) * t),
|
||||
solve for t = sigma / (alpha - (alpha-1) * sigma).
|
||||
|
||||
This is needed for the inpainting extension, which expects linear t values
|
||||
for gradient mask thresholding. With Anima's shift=3.0, the difference
|
||||
between shifted sigma and linear t is large (e.g. at t=0.5, sigma=0.75),
|
||||
causing overly aggressive mask thresholding if sigma is used directly.
|
||||
|
||||
Args:
|
||||
alpha: Shift factor (3.0 for Anima).
|
||||
sigma: Shifted sigma value in [0, 1].
|
||||
|
||||
Returns:
|
||||
Linear t value in [0, 1].
|
||||
"""
|
||||
if alpha == 1.0:
|
||||
return sigma
|
||||
denominator = alpha - (alpha - 1) * sigma
|
||||
if abs(denominator) < 1e-8:
|
||||
return 1.0
|
||||
return sigma / denominator
|
||||
|
||||
|
||||
class AnimaInpaintExtension(RectifiedFlowInpaintExtension):
|
||||
"""Inpaint extension for Anima that accounts for the time-SNR shift.
|
||||
|
||||
Anima uses a fixed shift=3.0 which makes sigma values significantly larger
|
||||
than the corresponding linear t values. The base RectifiedFlowInpaintExtension
|
||||
uses t_prev for both gradient mask thresholding and noise mixing, which assumes
|
||||
linear t values.
|
||||
|
||||
This subclass:
|
||||
- Uses the LINEAR t for gradient mask thresholding (correct progressive reveal)
|
||||
- Uses the SHIFTED sigma for noise mixing (matches the denoiser's noise level)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
init_latents: torch.Tensor,
|
||||
inpaint_mask: torch.Tensor,
|
||||
noise: torch.Tensor,
|
||||
shift: float = ANIMA_SHIFT,
|
||||
):
|
||||
assert_broadcastable(init_latents.shape, inpaint_mask.shape, noise.shape)
|
||||
self._init_latents = init_latents
|
||||
self._inpaint_mask = inpaint_mask
|
||||
self._noise = noise
|
||||
self._shift = shift
|
||||
|
||||
def merge_intermediate_latents_with_init_latents(
|
||||
self, intermediate_latents: torch.Tensor, sigma_prev: float
|
||||
) -> torch.Tensor:
|
||||
"""Merge intermediate latents with init latents, correcting for Anima's shift.
|
||||
|
||||
Args:
|
||||
intermediate_latents: The denoised latents at the current step.
|
||||
sigma_prev: The SHIFTED sigma value for the next step.
|
||||
"""
|
||||
# Recover linear t from shifted sigma for gradient mask thresholding.
|
||||
# This ensures the gradient mask is revealed at the correct pace.
|
||||
t_prev = inverse_loglinear_timestep_shift(self._shift, sigma_prev)
|
||||
mask = self._apply_mask_gradient_adjustment(t_prev)
|
||||
|
||||
# Use shifted sigma for noise mixing to match the denoiser's noise level.
|
||||
# The Euler step produces latents at noise level sigma_prev, so the
|
||||
# preserved regions must also be at sigma_prev noise level.
|
||||
noised_init_latents = self._noise * sigma_prev + (1.0 - sigma_prev) * self._init_latents
|
||||
|
||||
return intermediate_latents * mask + noised_init_latents * (1.0 - mask)
|
||||
|
||||
|
||||
@invocation(
|
||||
"anima_denoise",
|
||||
title="Denoise - Anima",
|
||||
tags=["image", "anima"],
|
||||
category="image",
|
||||
version="1.2.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class AnimaDenoiseInvocation(BaseInvocation):
|
||||
"""Run the denoising process with an Anima model.
|
||||
|
||||
Uses rectified flow sampling with shift=3.0 and the Cosmos Predict2 DiT
|
||||
backbone with integrated LLM Adapter for text conditioning.
|
||||
|
||||
Supports txt2img, img2img (via latents input), and inpainting (via denoise_mask).
|
||||
"""
|
||||
|
||||
# If latents is provided, this means we are doing image-to-image.
|
||||
latents: Optional[LatentsField] = InputField(
|
||||
default=None, description=FieldDescriptions.latents, input=Input.Connection
|
||||
)
|
||||
# denoise_mask is used for inpainting. Only the masked region is modified.
|
||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||
default=None, description=FieldDescriptions.denoise_mask, input=Input.Connection
|
||||
)
|
||||
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
|
||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||
add_noise: bool = InputField(default=True, description="Add noise based on denoising start.")
|
||||
transformer: TransformerField = InputField(
|
||||
description="Anima transformer model.", input=Input.Connection, title="Transformer"
|
||||
)
|
||||
positive_conditioning: AnimaConditioningField | list[AnimaConditioningField] = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||
)
|
||||
negative_conditioning: AnimaConditioningField | list[AnimaConditioningField] | None = InputField(
|
||||
default=None, description=FieldDescriptions.negative_cond, input=Input.Connection
|
||||
)
|
||||
guidance_scale: float = InputField(
|
||||
default=4.5,
|
||||
ge=1.0,
|
||||
description="Guidance scale for classifier-free guidance. Recommended: 4.0-5.0 for Anima.",
|
||||
title="Guidance Scale",
|
||||
)
|
||||
width: int = InputField(default=1024, multiple_of=8, description="Width of the generated image.")
|
||||
height: int = InputField(default=1024, multiple_of=8, description="Height of the generated image.")
|
||||
steps: int = InputField(default=30, gt=0, description="Number of denoising steps. 30 recommended for Anima.")
|
||||
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
|
||||
scheduler: ANIMA_SCHEDULER_NAME_VALUES = InputField(
|
||||
default="euler",
|
||||
description="Scheduler (sampler) for the denoising process.",
|
||||
ui_choice_labels=ANIMA_SCHEDULER_LABELS,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = self._run_diffusion(context)
|
||||
latents = latents.detach().to("cpu")
|
||||
name = context.tensors.save(tensor=latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||
|
||||
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
|
||||
"""Prepare the inpaint mask for Anima.
|
||||
|
||||
Anima uses 3D latents [B, C, T, H, W] internally but the mask operates
|
||||
on the spatial dimensions [B, C, H, W] which match the squeezed output.
|
||||
"""
|
||||
if self.denoise_mask is None:
|
||||
return None
|
||||
mask = context.tensors.load(self.denoise_mask.mask_name)
|
||||
|
||||
# Invert mask: 0.0 = regions to denoise, 1.0 = regions to preserve
|
||||
mask = 1.0 - mask
|
||||
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
mask = tv_resize(
|
||||
img=mask,
|
||||
size=[latent_height, latent_width],
|
||||
interpolation=tv_transforms.InterpolationMode.BILINEAR,
|
||||
antialias=False,
|
||||
)
|
||||
|
||||
mask = mask.to(device=latents.device, dtype=latents.dtype)
|
||||
return mask
|
||||
|
||||
def _get_noise(
|
||||
self,
|
||||
height: int,
|
||||
width: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
seed: int,
|
||||
) -> torch.Tensor:
|
||||
"""Generate initial noise tensor in 3D latent space [B, C, T, H, W]."""
|
||||
rand_device = "cpu"
|
||||
return torch.randn(
|
||||
1,
|
||||
ANIMA_LATENT_CHANNELS,
|
||||
1, # T=1 for single image
|
||||
height // ANIMA_LATENT_SCALE_FACTOR,
|
||||
width // ANIMA_LATENT_SCALE_FACTOR,
|
||||
device=rand_device,
|
||||
dtype=torch.float32,
|
||||
generator=torch.Generator(device=rand_device).manual_seed(seed),
|
||||
).to(device=device, dtype=dtype)
|
||||
|
||||
def _get_sigmas(self, num_steps: int) -> list[float]:
|
||||
"""Generate sigma schedule with fixed shift=3.0.
|
||||
|
||||
Uses the log-linear timestep shift from the Flux model (Black Forest Labs)
|
||||
with a fixed shift factor of 3.0 (no dynamic resolution-based shift).
|
||||
|
||||
Returns:
|
||||
List of num_steps + 1 sigma values from ~1.0 (noise) to 0.0 (clean).
|
||||
"""
|
||||
sigmas = []
|
||||
for i in range(num_steps + 1):
|
||||
t = 1.0 - i / num_steps
|
||||
sigma = loglinear_timestep_shift(ANIMA_SHIFT, t)
|
||||
sigmas.append(sigma)
|
||||
return sigmas
|
||||
|
||||
def _load_conditioning(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
cond_field: AnimaConditioningField,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> AnimaConditioningInfo:
|
||||
"""Load Anima conditioning data from storage."""
|
||||
cond_data = context.conditioning.load(cond_field.conditioning_name)
|
||||
assert len(cond_data.conditionings) == 1
|
||||
cond_info = cond_data.conditionings[0]
|
||||
assert isinstance(cond_info, AnimaConditioningInfo)
|
||||
return cond_info.to(dtype=dtype, device=device)
|
||||
|
||||
def _load_text_conditionings(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
cond_field: AnimaConditioningField | list[AnimaConditioningField],
|
||||
img_token_height: int,
|
||||
img_token_width: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> list[AnimaTextConditioning]:
|
||||
"""Load Anima text conditioning with optional regional masks.
|
||||
|
||||
Args:
|
||||
context: The invocation context.
|
||||
cond_field: Single conditioning field or list of fields.
|
||||
img_token_height: Height of the image token grid (H // patch_size).
|
||||
img_token_width: Width of the image token grid (W // patch_size).
|
||||
dtype: Target dtype.
|
||||
device: Target device.
|
||||
|
||||
Returns:
|
||||
List of AnimaTextConditioning objects with optional masks.
|
||||
"""
|
||||
cond_list = cond_field if isinstance(cond_field, list) else [cond_field]
|
||||
|
||||
text_conditionings: list[AnimaTextConditioning] = []
|
||||
for cond in cond_list:
|
||||
cond_info = self._load_conditioning(context, cond, dtype, device)
|
||||
|
||||
# Load the mask, if provided
|
||||
mask: torch.Tensor | None = None
|
||||
if cond.mask is not None:
|
||||
mask = context.tensors.load(cond.mask.tensor_name)
|
||||
mask = mask.to(device=device)
|
||||
mask = AnimaRegionalPromptingExtension.preprocess_regional_prompt_mask(
|
||||
mask, img_token_height, img_token_width, dtype, device
|
||||
)
|
||||
|
||||
text_conditionings.append(
|
||||
AnimaTextConditioning(
|
||||
qwen3_embeds=cond_info.qwen3_embeds,
|
||||
t5xxl_ids=cond_info.t5xxl_ids,
|
||||
t5xxl_weights=cond_info.t5xxl_weights,
|
||||
mask=mask,
|
||||
)
|
||||
)
|
||||
|
||||
return text_conditionings
|
||||
|
||||
def _run_llm_adapter_for_regions(
|
||||
self,
|
||||
transformer,
|
||||
text_conditionings: list[AnimaTextConditioning],
|
||||
dtype: torch.dtype,
|
||||
) -> AnimaRegionalTextConditioning:
|
||||
"""Run the LLM Adapter separately for each regional conditioning and concatenate.
|
||||
|
||||
Args:
|
||||
transformer: The AnimaTransformer instance (must be on device).
|
||||
text_conditionings: List of per-region conditioning data.
|
||||
dtype: Inference dtype.
|
||||
|
||||
Returns:
|
||||
AnimaRegionalTextConditioning with concatenated context and masks.
|
||||
"""
|
||||
context_embeds_list: list[torch.Tensor] = []
|
||||
context_ranges: list[Range] = []
|
||||
image_masks: list[torch.Tensor | None] = []
|
||||
cur_len = 0
|
||||
|
||||
for tc in text_conditionings:
|
||||
qwen3_embeds = tc.qwen3_embeds.unsqueeze(0) # (1, seq_len, 1024)
|
||||
t5xxl_ids = tc.t5xxl_ids.unsqueeze(0) # (1, seq_len)
|
||||
t5xxl_weights = None
|
||||
if tc.t5xxl_weights is not None:
|
||||
t5xxl_weights = tc.t5xxl_weights.unsqueeze(0).unsqueeze(-1) # (1, seq_len, 1)
|
||||
|
||||
# Run the LLM Adapter to produce context for this region
|
||||
context = transformer.preprocess_text_embeds(
|
||||
qwen3_embeds.to(dtype=dtype),
|
||||
t5xxl_ids,
|
||||
t5xxl_weights=t5xxl_weights.to(dtype=dtype) if t5xxl_weights is not None else None,
|
||||
)
|
||||
# context shape: (1, 512, 1024) — squeeze batch dim
|
||||
context_2d = context.squeeze(0) # (512, 1024)
|
||||
|
||||
context_embeds_list.append(context_2d)
|
||||
context_ranges.append(Range(start=cur_len, end=cur_len + context_2d.shape[0]))
|
||||
image_masks.append(tc.mask)
|
||||
cur_len += context_2d.shape[0]
|
||||
|
||||
concatenated_context = torch.cat(context_embeds_list, dim=0)
|
||||
|
||||
return AnimaRegionalTextConditioning(
|
||||
context_embeds=concatenated_context,
|
||||
image_masks=image_masks,
|
||||
context_ranges=context_ranges,
|
||||
)
|
||||
|
||||
def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
|
||||
device = TorchDevice.choose_torch_device()
|
||||
inference_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
|
||||
|
||||
if self.denoising_start >= self.denoising_end:
|
||||
raise ValueError(
|
||||
f"denoising_start ({self.denoising_start}) must be less than denoising_end ({self.denoising_end})."
|
||||
)
|
||||
|
||||
transformer_info = context.models.load(self.transformer.transformer)
|
||||
|
||||
# Compute image token grid dimensions for regional prompting
|
||||
# Anima: 8x VAE compression, 2x patch size → 16x total
|
||||
patch_size = 2
|
||||
latent_height = self.height // ANIMA_LATENT_SCALE_FACTOR
|
||||
latent_width = self.width // ANIMA_LATENT_SCALE_FACTOR
|
||||
img_token_height = latent_height // patch_size
|
||||
img_token_width = latent_width // patch_size
|
||||
img_seq_len = img_token_height * img_token_width
|
||||
|
||||
# Load positive conditioning with optional regional masks
|
||||
pos_text_conditionings = self._load_text_conditionings(
|
||||
context=context,
|
||||
cond_field=self.positive_conditioning,
|
||||
img_token_height=img_token_height,
|
||||
img_token_width=img_token_width,
|
||||
dtype=inference_dtype,
|
||||
device=device,
|
||||
)
|
||||
has_regional = len(pos_text_conditionings) > 1 or any(tc.mask is not None for tc in pos_text_conditionings)
|
||||
|
||||
# Load negative conditioning if CFG is enabled
|
||||
do_cfg = not math.isclose(self.guidance_scale, 1.0) and self.negative_conditioning is not None
|
||||
neg_text_conditionings: list[AnimaTextConditioning] | None = None
|
||||
if do_cfg:
|
||||
assert self.negative_conditioning is not None
|
||||
neg_text_conditionings = self._load_text_conditionings(
|
||||
context=context,
|
||||
cond_field=self.negative_conditioning,
|
||||
img_token_height=img_token_height,
|
||||
img_token_width=img_token_width,
|
||||
dtype=inference_dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Generate sigma schedule
|
||||
sigmas = self._get_sigmas(self.steps)
|
||||
|
||||
# Apply denoising_start and denoising_end clipping (for img2img/inpaint)
|
||||
if self.denoising_start > 0 or self.denoising_end < 1:
|
||||
total_sigmas = len(sigmas)
|
||||
start_idx = int(self.denoising_start * (total_sigmas - 1))
|
||||
end_idx = int(self.denoising_end * (total_sigmas - 1)) + 1
|
||||
sigmas = sigmas[start_idx:end_idx]
|
||||
|
||||
total_steps = len(sigmas) - 1
|
||||
|
||||
# Load input latents if provided (image-to-image)
|
||||
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
|
||||
if init_latents is not None:
|
||||
init_latents = init_latents.to(device=device, dtype=inference_dtype)
|
||||
# Anima denoiser works in 3D: add temporal dim if needed
|
||||
if init_latents.ndim == 4:
|
||||
init_latents = init_latents.unsqueeze(2) # [B, C, H, W] -> [B, C, 1, H, W]
|
||||
|
||||
# Generate initial noise (3D latent: [B, C, T, H, W])
|
||||
noise = self._get_noise(self.height, self.width, inference_dtype, device, self.seed)
|
||||
|
||||
# Prepare input latents
|
||||
if init_latents is not None:
|
||||
if self.add_noise:
|
||||
s_0 = sigmas[0]
|
||||
latents = s_0 * noise + (1.0 - s_0) * init_latents
|
||||
else:
|
||||
latents = init_latents
|
||||
else:
|
||||
if self.denoising_start > 1e-5:
|
||||
raise ValueError("denoising_start should be 0 when initial latents are not provided.")
|
||||
latents = noise
|
||||
|
||||
if total_steps <= 0:
|
||||
return latents.squeeze(2)
|
||||
|
||||
# Prepare inpaint extension
|
||||
inpaint_mask = self._prep_inpaint_mask(context, latents.squeeze(2))
|
||||
inpaint_extension: AnimaInpaintExtension | None = None
|
||||
if inpaint_mask is not None:
|
||||
if init_latents is None:
|
||||
raise ValueError("Initial latents are required when using an inpaint mask (image-to-image inpainting)")
|
||||
inpaint_extension = AnimaInpaintExtension(
|
||||
init_latents=init_latents.squeeze(2),
|
||||
inpaint_mask=inpaint_mask,
|
||||
noise=noise.squeeze(2),
|
||||
shift=ANIMA_SHIFT,
|
||||
)
|
||||
|
||||
step_callback = self._build_step_callback(context)
|
||||
|
||||
# Initialize diffusers scheduler if not using built-in Euler
|
||||
scheduler: SchedulerMixin | None = None
|
||||
use_scheduler = self.scheduler != "euler"
|
||||
|
||||
if use_scheduler:
|
||||
scheduler_class = ANIMA_SCHEDULER_MAP[self.scheduler]
|
||||
scheduler = scheduler_class(num_train_timesteps=1000, shift=1.0)
|
||||
is_lcm = self.scheduler == "lcm"
|
||||
set_timesteps_sig = inspect.signature(scheduler.set_timesteps)
|
||||
if not is_lcm and "sigmas" in set_timesteps_sig.parameters:
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps=total_steps, device=device)
|
||||
num_scheduler_steps = len(scheduler.timesteps)
|
||||
else:
|
||||
num_scheduler_steps = total_steps
|
||||
|
||||
with ExitStack() as exit_stack:
|
||||
(cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
|
||||
|
||||
# Apply LoRA models to the transformer.
|
||||
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
|
||||
exit_stack.enter_context(
|
||||
LayerPatcher.apply_smart_model_patches(
|
||||
model=transformer,
|
||||
patches=self._lora_iterator(context),
|
||||
prefix=ANIMA_LORA_TRANSFORMER_PREFIX,
|
||||
dtype=inference_dtype,
|
||||
cached_weights=cached_weights,
|
||||
)
|
||||
)
|
||||
|
||||
# Run LLM Adapter for each regional conditioning to produce context vectors.
|
||||
# This must happen with the transformer on device since it uses the adapter weights.
|
||||
if has_regional:
|
||||
pos_regional = self._run_llm_adapter_for_regions(transformer, pos_text_conditionings, inference_dtype)
|
||||
pos_context = pos_regional.context_embeds.unsqueeze(0) # (1, total_ctx_len, 1024)
|
||||
|
||||
# Build regional prompting extension with cross-attention mask
|
||||
regional_extension = AnimaRegionalPromptingExtension.from_regional_conditioning(
|
||||
pos_regional, img_seq_len
|
||||
)
|
||||
|
||||
# For negative, concatenate all regions without masking (matches Z-Image behavior)
|
||||
neg_context = None
|
||||
if do_cfg and neg_text_conditionings is not None:
|
||||
neg_regional = self._run_llm_adapter_for_regions(
|
||||
transformer, neg_text_conditionings, inference_dtype
|
||||
)
|
||||
neg_context = neg_regional.context_embeds.unsqueeze(0)
|
||||
else:
|
||||
# Single conditioning — run LLM Adapter via normal forward path
|
||||
tc = pos_text_conditionings[0]
|
||||
pos_qwen3_embeds = tc.qwen3_embeds.unsqueeze(0)
|
||||
pos_t5xxl_ids = tc.t5xxl_ids.unsqueeze(0)
|
||||
pos_t5xxl_weights = None
|
||||
if tc.t5xxl_weights is not None:
|
||||
pos_t5xxl_weights = tc.t5xxl_weights.unsqueeze(0).unsqueeze(-1)
|
||||
|
||||
# Pre-compute context via LLM Adapter
|
||||
pos_context = transformer.preprocess_text_embeds(
|
||||
pos_qwen3_embeds.to(dtype=inference_dtype),
|
||||
pos_t5xxl_ids,
|
||||
t5xxl_weights=pos_t5xxl_weights.to(dtype=inference_dtype)
|
||||
if pos_t5xxl_weights is not None
|
||||
else None,
|
||||
)
|
||||
|
||||
neg_context = None
|
||||
if do_cfg and neg_text_conditionings is not None:
|
||||
ntc = neg_text_conditionings[0]
|
||||
neg_qwen3 = ntc.qwen3_embeds.unsqueeze(0)
|
||||
neg_ids = ntc.t5xxl_ids.unsqueeze(0)
|
||||
neg_weights = None
|
||||
if ntc.t5xxl_weights is not None:
|
||||
neg_weights = ntc.t5xxl_weights.unsqueeze(0).unsqueeze(-1)
|
||||
neg_context = transformer.preprocess_text_embeds(
|
||||
neg_qwen3.to(dtype=inference_dtype),
|
||||
neg_ids,
|
||||
t5xxl_weights=neg_weights.to(dtype=inference_dtype) if neg_weights is not None else None,
|
||||
)
|
||||
|
||||
regional_extension = None
|
||||
|
||||
# Apply regional prompting patch if we have regional masks
|
||||
exit_stack.enter_context(patch_anima_for_regional_prompting(transformer, regional_extension))
|
||||
|
||||
# Helper to run transformer with pre-computed context (bypasses LLM Adapter)
|
||||
def _run_transformer(ctx: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
||||
return transformer(
|
||||
x=x.to(transformer.dtype if hasattr(transformer, "dtype") else inference_dtype),
|
||||
timesteps=t,
|
||||
context=ctx,
|
||||
# t5xxl_ids=None skips the LLM Adapter — context is already pre-computed
|
||||
)
|
||||
|
||||
if use_scheduler and scheduler is not None:
|
||||
# Scheduler-based denoising
|
||||
user_step = 0
|
||||
pbar = tqdm(total=total_steps, desc="Denoising (Anima)")
|
||||
for step_index in range(num_scheduler_steps):
|
||||
sched_timestep = scheduler.timesteps[step_index]
|
||||
sigma_curr = sched_timestep.item() / scheduler.config.num_train_timesteps
|
||||
|
||||
is_heun = hasattr(scheduler, "state_in_first_order")
|
||||
in_first_order = scheduler.state_in_first_order if is_heun else True
|
||||
|
||||
timestep = torch.tensor(
|
||||
[sigma_curr * ANIMA_MULTIPLIER], device=device, dtype=inference_dtype
|
||||
).expand(latents.shape[0])
|
||||
|
||||
noise_pred_cond = _run_transformer(pos_context, latents, timestep).float()
|
||||
|
||||
if do_cfg and neg_context is not None:
|
||||
noise_pred_uncond = _run_transformer(neg_context, latents, timestep).float()
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
||||
else:
|
||||
noise_pred = noise_pred_cond
|
||||
|
||||
step_output = scheduler.step(model_output=noise_pred, timestep=sched_timestep, sample=latents)
|
||||
latents = step_output.prev_sample
|
||||
|
||||
if step_index + 1 < len(scheduler.sigmas):
|
||||
sigma_prev = scheduler.sigmas[step_index + 1].item()
|
||||
else:
|
||||
sigma_prev = 0.0
|
||||
|
||||
if inpaint_extension is not None:
|
||||
latents_4d = latents.squeeze(2)
|
||||
latents_4d = inpaint_extension.merge_intermediate_latents_with_init_latents(
|
||||
latents_4d, sigma_prev
|
||||
)
|
||||
latents = latents_4d.unsqueeze(2)
|
||||
|
||||
if is_heun:
|
||||
if not in_first_order:
|
||||
user_step += 1
|
||||
if user_step <= total_steps:
|
||||
pbar.update(1)
|
||||
step_callback(
|
||||
PipelineIntermediateState(
|
||||
step=user_step,
|
||||
order=2,
|
||||
total_steps=total_steps,
|
||||
timestep=int(sigma_curr * 1000),
|
||||
latents=latents.squeeze(2),
|
||||
)
|
||||
)
|
||||
else:
|
||||
user_step += 1
|
||||
if user_step <= total_steps:
|
||||
pbar.update(1)
|
||||
step_callback(
|
||||
PipelineIntermediateState(
|
||||
step=user_step,
|
||||
order=1,
|
||||
total_steps=total_steps,
|
||||
timestep=int(sigma_curr * 1000),
|
||||
latents=latents.squeeze(2),
|
||||
)
|
||||
)
|
||||
pbar.close()
|
||||
else:
|
||||
# Built-in Euler implementation (default for Anima)
|
||||
for step_idx in tqdm(range(total_steps), desc="Denoising (Anima)"):
|
||||
sigma_curr = sigmas[step_idx]
|
||||
sigma_prev = sigmas[step_idx + 1]
|
||||
|
||||
timestep = torch.tensor(
|
||||
[sigma_curr * ANIMA_MULTIPLIER], device=device, dtype=inference_dtype
|
||||
).expand(latents.shape[0])
|
||||
|
||||
noise_pred_cond = _run_transformer(pos_context, latents, timestep).float()
|
||||
|
||||
if do_cfg and neg_context is not None:
|
||||
noise_pred_uncond = _run_transformer(neg_context, latents, timestep).float()
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
||||
else:
|
||||
noise_pred = noise_pred_cond
|
||||
|
||||
latents_dtype = latents.dtype
|
||||
latents = latents.to(dtype=torch.float32)
|
||||
latents = latents + (sigma_prev - sigma_curr) * noise_pred
|
||||
latents = latents.to(dtype=latents_dtype)
|
||||
|
||||
if inpaint_extension is not None:
|
||||
latents_4d = latents.squeeze(2)
|
||||
latents_4d = inpaint_extension.merge_intermediate_latents_with_init_latents(
|
||||
latents_4d, sigma_prev
|
||||
)
|
||||
latents = latents_4d.unsqueeze(2)
|
||||
|
||||
step_callback(
|
||||
PipelineIntermediateState(
|
||||
step=step_idx + 1,
|
||||
order=1,
|
||||
total_steps=total_steps,
|
||||
timestep=int(sigma_curr * 1000),
|
||||
latents=latents.squeeze(2),
|
||||
),
|
||||
)
|
||||
|
||||
# Remove temporal dimension for output: [B, C, 1, H, W] -> [B, C, H, W]
|
||||
return latents.squeeze(2)
|
||||
|
||||
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
context.util.sd_step_callback(state, BaseModelType.Anima)
|
||||
|
||||
return step_callback
|
||||
|
||||
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
|
||||
"""Iterate over LoRA models to apply to the transformer."""
|
||||
for lora in self.transformer.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
if not isinstance(lora_info.model, ModelPatchRaw):
|
||||
raise TypeError(
|
||||
f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}. "
|
||||
"The LoRA model may be corrupted or incompatible."
|
||||
)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
119
invokeai/app/invocations/anima_image_to_latents.py
Normal file
119
invokeai/app/invocations/anima_image_to_latents.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""Anima image-to-latents invocation.
|
||||
|
||||
Encodes an image to latent space using the Anima VAE (AutoencoderKLWan or FLUX VAE).
|
||||
|
||||
For Wan VAE (AutoencoderKLWan):
|
||||
- Input image is converted to 5D tensor [B, C, T, H, W] with T=1
|
||||
- After encoding, latents are normalized: (latents - mean) / std
|
||||
(inverse of the denormalization in anima_latents_to_image.py)
|
||||
|
||||
For FLUX VAE (AutoEncoder):
|
||||
- Encoding is handled internally by the FLUX VAE
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
|
||||
import einops
|
||||
import torch
|
||||
from diffusers.models.autoencoders import AutoencoderKLWan
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder as FluxAutoEncoder
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux
|
||||
|
||||
AnimaVAE = Union[AutoencoderKLWan, FluxAutoEncoder]
|
||||
|
||||
|
||||
@invocation(
|
||||
"anima_i2l",
|
||||
title="Image to Latents - Anima",
|
||||
tags=["image", "latents", "vae", "i2l", "anima"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class AnimaImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates latents from an image using the Anima VAE (supports Wan 2.1 and FLUX VAE)."""
|
||||
|
||||
image: ImageField = InputField(description="The image to encode.")
|
||||
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
|
||||
|
||||
@staticmethod
|
||||
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
if not isinstance(vae_info.model, (AutoencoderKLWan, FluxAutoEncoder)):
|
||||
raise TypeError(
|
||||
f"Expected AutoencoderKLWan or FluxAutoEncoder for Anima VAE, got {type(vae_info.model).__name__}."
|
||||
)
|
||||
|
||||
estimated_working_memory = estimate_vae_working_memory_flux(
|
||||
operation="encode",
|
||||
image_tensor=image_tensor,
|
||||
vae=vae_info.model,
|
||||
)
|
||||
|
||||
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
|
||||
if not isinstance(vae, (AutoencoderKLWan, FluxAutoEncoder)):
|
||||
raise TypeError(f"Expected AutoencoderKLWan or FluxAutoEncoder, got {type(vae).__name__}.")
|
||||
|
||||
vae_dtype = next(iter(vae.parameters())).dtype
|
||||
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
|
||||
|
||||
with torch.inference_mode():
|
||||
if isinstance(vae, FluxAutoEncoder):
|
||||
# FLUX VAE handles scaling internally
|
||||
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
|
||||
latents = vae.encode(image_tensor, sample=True, generator=generator)
|
||||
else:
|
||||
# AutoencoderKLWan expects 5D input [B, C, T, H, W]
|
||||
if image_tensor.ndim == 4:
|
||||
image_tensor = image_tensor.unsqueeze(2) # [B, C, H, W] -> [B, C, 1, H, W]
|
||||
|
||||
encoded = vae.encode(image_tensor, return_dict=False)[0]
|
||||
latents = encoded.sample().to(dtype=vae_dtype)
|
||||
|
||||
# Normalize to denoiser space: (latents - mean) / std
|
||||
# This is the inverse of the denormalization in anima_latents_to_image.py
|
||||
latents_mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(latents)
|
||||
latents_std = torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(latents)
|
||||
latents = (latents - latents_mean) / latents_std
|
||||
|
||||
# Remove temporal dimension: [B, C, 1, H, W] -> [B, C, H, W]
|
||||
if latents.ndim == 5:
|
||||
latents = latents.squeeze(2)
|
||||
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
if not isinstance(vae_info.model, (AutoencoderKLWan, FluxAutoEncoder)):
|
||||
raise TypeError(
|
||||
f"Expected AutoencoderKLWan or FluxAutoEncoder for Anima VAE, got {type(vae_info.model).__name__}."
|
||||
)
|
||||
|
||||
context.util.signal_progress("Running Anima VAE encode")
|
||||
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
|
||||
|
||||
latents = latents.to("cpu")
|
||||
name = context.tensors.save(tensor=latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||
108
invokeai/app/invocations/anima_latents_to_image.py
Normal file
108
invokeai/app/invocations/anima_latents_to_image.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Anima latents-to-image invocation.
|
||||
|
||||
Decodes Anima latents using the QwenImage VAE (AutoencoderKLWan) or
|
||||
compatible FLUX VAE as fallback.
|
||||
|
||||
Latents from the denoiser are in normalized space (zero-centered). Before
|
||||
VAE decode, they must be denormalized using the Wan 2.1 per-channel
|
||||
mean/std: latents = latents * std + mean (matching diffusers WanPipeline).
|
||||
|
||||
The VAE expects 5D latents [B, C, T, H, W] — for single images, T=1.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from diffusers.models.autoencoders import AutoencoderKLWan
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder as FluxAutoEncoder
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux
|
||||
|
||||
|
||||
@invocation(
|
||||
"anima_l2i",
|
||||
title="Latents to Image - Anima",
|
||||
tags=["latents", "image", "vae", "l2i", "anima"],
|
||||
category="latents",
|
||||
version="1.0.2",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class AnimaLatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates an image from latents using the Anima VAE.
|
||||
|
||||
Supports the Wan 2.1 QwenImage VAE (AutoencoderKLWan) with explicit
|
||||
latent denormalization, and FLUX VAE as fallback.
|
||||
"""
|
||||
|
||||
latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
||||
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
if not isinstance(vae_info.model, (AutoencoderKLWan, FluxAutoEncoder)):
|
||||
raise TypeError(
|
||||
f"Expected AutoencoderKLWan or FluxAutoEncoder for Anima VAE, got {type(vae_info.model).__name__}."
|
||||
)
|
||||
|
||||
estimated_working_memory = estimate_vae_working_memory_flux(
|
||||
operation="decode",
|
||||
image_tensor=latents,
|
||||
vae=vae_info.model,
|
||||
)
|
||||
|
||||
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
|
||||
context.util.signal_progress("Running Anima VAE decode")
|
||||
if not isinstance(vae, (AutoencoderKLWan, FluxAutoEncoder)):
|
||||
raise TypeError(f"Expected AutoencoderKLWan or FluxAutoEncoder, got {type(vae).__name__}.")
|
||||
|
||||
vae_dtype = next(iter(vae.parameters())).dtype
|
||||
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
with torch.inference_mode():
|
||||
if isinstance(vae, FluxAutoEncoder):
|
||||
# FLUX VAE handles scaling internally, expects 4D [B, C, H, W]
|
||||
img = vae.decode(latents)
|
||||
else:
|
||||
# Expects 5D latents [B, C, T, H, W]
|
||||
if latents.ndim == 4:
|
||||
latents = latents.unsqueeze(2) # [B, C, H, W] -> [B, C, 1, H, W]
|
||||
|
||||
# Denormalize from denoiser space to raw VAE space
|
||||
# (same as diffusers WanPipeline and ComfyUI Wan21.process_out)
|
||||
latents_mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(latents)
|
||||
latents_std = torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(latents)
|
||||
latents = latents * latents_std + latents_mean
|
||||
|
||||
decoded = vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
# Output is 5D [B, C, T, H, W] — squeeze temporal dim
|
||||
if decoded.ndim == 5:
|
||||
decoded = decoded.squeeze(2)
|
||||
img = decoded
|
||||
|
||||
img = img.clamp(-1, 1)
|
||||
img = rearrange(img[0], "c h w -> h w c")
|
||||
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
image_dto = context.images.save(image=img_pil)
|
||||
return ImageOutput.build(image_dto)
|
||||
162
invokeai/app/invocations/anima_lora_loader.py
Normal file
162
invokeai/app/invocations/anima_lora_loader.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, Qwen3EncoderField, TransformerField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
|
||||
|
||||
@invocation_output("anima_lora_loader_output")
|
||||
class AnimaLoRALoaderOutput(BaseInvocationOutput):
|
||||
"""Anima LoRA Loader Output"""
|
||||
|
||||
transformer: Optional[TransformerField] = OutputField(
|
||||
default=None, description=FieldDescriptions.transformer, title="Anima Transformer"
|
||||
)
|
||||
qwen3_encoder: Optional[Qwen3EncoderField] = OutputField(
|
||||
default=None, description=FieldDescriptions.qwen3_encoder, title="Qwen3 Encoder"
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"anima_lora_loader",
|
||||
title="Apply LoRA - Anima",
|
||||
tags=["lora", "model", "anima"],
|
||||
category="model",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class AnimaLoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply a LoRA model to an Anima transformer and/or Qwen3 text encoder."""
|
||||
|
||||
lora: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.lora_model,
|
||||
title="LoRA",
|
||||
ui_model_base=BaseModelType.Anima,
|
||||
ui_model_type=ModelType.LoRA,
|
||||
)
|
||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||
transformer: TransformerField | None = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.transformer,
|
||||
input=Input.Connection,
|
||||
title="Anima Transformer",
|
||||
)
|
||||
qwen3_encoder: Qwen3EncoderField | None = InputField(
|
||||
default=None,
|
||||
title="Qwen3 Encoder",
|
||||
description=FieldDescriptions.qwen3_encoder,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> AnimaLoRALoaderOutput:
|
||||
lora_key = self.lora.key
|
||||
|
||||
if not context.models.exists(lora_key):
|
||||
raise ValueError(f"Unknown lora: {lora_key}!")
|
||||
|
||||
if self.transformer and any(lora.lora.key == lora_key for lora in self.transformer.loras):
|
||||
raise ValueError(f'LoRA "{lora_key}" already applied to transformer.')
|
||||
if self.qwen3_encoder and any(lora.lora.key == lora_key for lora in self.qwen3_encoder.loras):
|
||||
raise ValueError(f'LoRA "{lora_key}" already applied to Qwen3 encoder.')
|
||||
|
||||
output = AnimaLoRALoaderOutput()
|
||||
|
||||
if self.transformer is not None:
|
||||
output.transformer = self.transformer.model_copy(deep=True)
|
||||
output.transformer.loras.append(
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
if self.qwen3_encoder is not None:
|
||||
output.qwen3_encoder = self.qwen3_encoder.model_copy(deep=True)
|
||||
output.qwen3_encoder.loras.append(
|
||||
LoRAField(
|
||||
lora=self.lora,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@invocation(
|
||||
"anima_lora_collection_loader",
|
||||
title="Apply LoRA Collection - Anima",
|
||||
tags=["lora", "model", "anima"],
|
||||
category="model",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class AnimaLoRACollectionLoader(BaseInvocation):
|
||||
"""Applies a collection of LoRAs to an Anima transformer."""
|
||||
|
||||
loras: Optional[LoRAField | list[LoRAField]] = InputField(
|
||||
default=None, description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
|
||||
)
|
||||
|
||||
transformer: Optional[TransformerField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.transformer,
|
||||
input=Input.Connection,
|
||||
title="Transformer",
|
||||
)
|
||||
qwen3_encoder: Qwen3EncoderField | None = InputField(
|
||||
default=None,
|
||||
title="Qwen3 Encoder",
|
||||
description=FieldDescriptions.qwen3_encoder,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> AnimaLoRALoaderOutput:
|
||||
output = AnimaLoRALoaderOutput()
|
||||
|
||||
if self.loras is None:
|
||||
if self.transformer is not None:
|
||||
output.transformer = self.transformer.model_copy(deep=True)
|
||||
if self.qwen3_encoder is not None:
|
||||
output.qwen3_encoder = self.qwen3_encoder.model_copy(deep=True)
|
||||
return output
|
||||
|
||||
loras = self.loras if isinstance(self.loras, list) else [self.loras]
|
||||
added_loras: list[str] = []
|
||||
|
||||
if self.transformer is not None:
|
||||
output.transformer = self.transformer.model_copy(deep=True)
|
||||
|
||||
if self.qwen3_encoder is not None:
|
||||
output.qwen3_encoder = self.qwen3_encoder.model_copy(deep=True)
|
||||
|
||||
for lora in loras:
|
||||
if lora is None:
|
||||
continue
|
||||
if lora.lora.key in added_loras:
|
||||
continue
|
||||
|
||||
if not context.models.exists(lora.lora.key):
|
||||
raise ValueError(f"Unknown lora: {lora.lora.key}!")
|
||||
|
||||
if lora.lora.base is not BaseModelType.Anima:
|
||||
raise ValueError(
|
||||
f"LoRA '{lora.lora.key}' is for {lora.lora.base.value if lora.lora.base else 'unknown'} models, "
|
||||
"not Anima models. Ensure you are using an Anima compatible LoRA."
|
||||
)
|
||||
|
||||
added_loras.append(lora.lora.key)
|
||||
|
||||
if self.transformer is not None and output.transformer is not None:
|
||||
output.transformer.loras.append(lora)
|
||||
|
||||
if self.qwen3_encoder is not None and output.qwen3_encoder is not None:
|
||||
output.qwen3_encoder.loras.append(lora)
|
||||
|
||||
return output
|
||||
102
invokeai/app/invocations/anima_model_loader.py
Normal file
102
invokeai/app/invocations/anima_model_loader.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.model import (
|
||||
ModelIdentifierField,
|
||||
Qwen3EncoderField,
|
||||
T5EncoderField,
|
||||
TransformerField,
|
||||
VAEField,
|
||||
)
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.t5_model_identifier import (
|
||||
preprocess_t5_encoder_model_identifier,
|
||||
preprocess_t5_tokenizer_model_identifier,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
|
||||
|
||||
|
||||
@invocation_output("anima_model_loader_output")
|
||||
class AnimaModelLoaderOutput(BaseInvocationOutput):
|
||||
"""Anima model loader output."""
|
||||
|
||||
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
|
||||
qwen3_encoder: Qwen3EncoderField = OutputField(description=FieldDescriptions.qwen3_encoder, title="Qwen3 Encoder")
|
||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
|
||||
|
||||
|
||||
@invocation(
|
||||
"anima_model_loader",
|
||||
title="Main Model - Anima",
|
||||
tags=["model", "anima"],
|
||||
category="model",
|
||||
version="1.3.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class AnimaModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an Anima model, outputting its submodels.
|
||||
|
||||
Anima uses:
|
||||
- Transformer: Cosmos Predict2 DiT + LLM Adapter (from single-file checkpoint)
|
||||
- Qwen3 Encoder: Qwen3 0.6B (standalone single-file)
|
||||
- VAE: AutoencoderKLQwenImage / Wan 2.1 VAE (standalone single-file or FLUX VAE)
|
||||
- T5 Encoder: T5-XXL model (only the tokenizer submodel is used, for LLM Adapter token IDs)
|
||||
"""
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description="Anima main model (transformer + LLM adapter).",
|
||||
input=Input.Direct,
|
||||
ui_model_base=BaseModelType.Anima,
|
||||
ui_model_type=ModelType.Main,
|
||||
title="Transformer",
|
||||
)
|
||||
|
||||
vae_model: ModelIdentifierField = InputField(
|
||||
description="Standalone VAE model. Anima uses a Wan 2.1 / QwenImage VAE (16-channel). "
|
||||
"A FLUX VAE can also be used as a compatible fallback.",
|
||||
input=Input.Direct,
|
||||
ui_model_type=ModelType.VAE,
|
||||
title="VAE",
|
||||
)
|
||||
|
||||
qwen3_encoder_model: ModelIdentifierField = InputField(
|
||||
description="Standalone Qwen3 0.6B Encoder model.",
|
||||
input=Input.Direct,
|
||||
ui_model_type=ModelType.Qwen3Encoder,
|
||||
title="Qwen3 Encoder",
|
||||
)
|
||||
|
||||
t5_encoder_model: ModelIdentifierField = InputField(
|
||||
description="T5-XXL encoder model. The tokenizer submodel is used for Anima text encoding.",
|
||||
input=Input.Direct,
|
||||
ui_model_type=ModelType.T5Encoder,
|
||||
title="T5 Encoder",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> AnimaModelLoaderOutput:
|
||||
# Transformer always comes from the main model
|
||||
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
|
||||
|
||||
# VAE
|
||||
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
|
||||
# Qwen3 Encoder
|
||||
qwen3_tokenizer = self.qwen3_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
qwen3_encoder = self.qwen3_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
|
||||
# T5 Encoder (only tokenizer submodel is used by Anima)
|
||||
t5_tokenizer = preprocess_t5_tokenizer_model_identifier(self.t5_encoder_model)
|
||||
t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model)
|
||||
|
||||
return AnimaModelLoaderOutput(
|
||||
transformer=TransformerField(transformer=transformer, loras=[]),
|
||||
qwen3_encoder=Qwen3EncoderField(tokenizer=qwen3_tokenizer, text_encoder=qwen3_encoder),
|
||||
vae=VAEField(vae=vae),
|
||||
t5_encoder=T5EncoderField(tokenizer=t5_tokenizer, text_encoder=t5_encoder, loras=[]),
|
||||
)
|
||||
221
invokeai/app/invocations/anima_text_encoder.py
Normal file
221
invokeai/app/invocations/anima_text_encoder.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""Anima text encoder invocation.
|
||||
|
||||
Encodes text using the dual-conditioning pipeline:
|
||||
1. Qwen3 0.6B: Produces hidden states (last layer)
|
||||
2. T5-XXL Tokenizer: Produces token IDs only (no T5 model needed)
|
||||
|
||||
Both outputs are stored together in AnimaConditioningInfo and used by
|
||||
the LLM Adapter inside the transformer during denoising.
|
||||
|
||||
Key differences from Z-Image text encoder:
|
||||
- Anima uses Qwen3 0.6B (base model, NOT instruct) — no chat template
|
||||
- Anima additionally tokenizes with T5-XXL tokenizer to get token IDs
|
||||
- Qwen3 output uses all positions (including padding) for full context
|
||||
"""
|
||||
|
||||
from contextlib import ExitStack
|
||||
from typing import Iterator, Tuple
|
||||
|
||||
import torch
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
AnimaConditioningField,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
TensorField,
|
||||
UIComponent,
|
||||
)
|
||||
from invokeai.app.invocations.model import Qwen3EncoderField, T5EncoderField
|
||||
from invokeai.app.invocations.primitives import AnimaConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.lora_conversions.anima_lora_constants import ANIMA_LORA_QWEN3_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
AnimaConditioningInfo,
|
||||
ConditioningFieldData,
|
||||
)
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
logger = InvokeAILogger.get_logger(__name__)
|
||||
|
||||
# T5-XXL max sequence length for token IDs
|
||||
T5_MAX_SEQ_LEN = 512
|
||||
|
||||
# Safety cap for Qwen3 sequence length to prevent GPU OOM on extremely long prompts.
|
||||
# Qwen3 0.6B supports 32K context but the LLM Adapter doesn't need that much.
|
||||
QWEN3_MAX_SEQ_LEN = 8192
|
||||
|
||||
|
||||
@invocation(
|
||||
"anima_text_encoder",
|
||||
title="Prompt - Anima",
|
||||
tags=["prompt", "conditioning", "anima"],
|
||||
category="conditioning",
|
||||
version="1.3.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class AnimaTextEncoderInvocation(BaseInvocation):
|
||||
"""Encodes and preps a prompt for an Anima image.
|
||||
|
||||
Uses Qwen3 0.6B for hidden state extraction and T5-XXL tokenizer for
|
||||
token IDs (no T5 model weights needed). Both are combined by the
|
||||
LLM Adapter inside the Anima transformer during denoising.
|
||||
"""
|
||||
|
||||
prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
|
||||
qwen3_encoder: Qwen3EncoderField = InputField(
|
||||
title="Qwen3 Encoder",
|
||||
description=FieldDescriptions.qwen3_encoder,
|
||||
input=Input.Connection,
|
||||
)
|
||||
t5_encoder: T5EncoderField = InputField(
|
||||
title="T5 Encoder",
|
||||
description=FieldDescriptions.t5_encoder,
|
||||
input=Input.Connection,
|
||||
)
|
||||
mask: TensorField | None = InputField(
|
||||
default=None,
|
||||
description="A mask defining the region that this conditioning prompt applies to.",
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> AnimaConditioningOutput:
|
||||
qwen3_embeds, t5xxl_ids, t5xxl_weights = self._encode_prompt(context)
|
||||
|
||||
# Move to CPU for storage
|
||||
qwen3_embeds = qwen3_embeds.detach().to("cpu")
|
||||
t5xxl_ids = t5xxl_ids.detach().to("cpu")
|
||||
t5xxl_weights = t5xxl_weights.detach().to("cpu") if t5xxl_weights is not None else None
|
||||
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[
|
||||
AnimaConditioningInfo(
|
||||
qwen3_embeds=qwen3_embeds,
|
||||
t5xxl_ids=t5xxl_ids,
|
||||
t5xxl_weights=t5xxl_weights,
|
||||
)
|
||||
]
|
||||
)
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
return AnimaConditioningOutput(
|
||||
conditioning=AnimaConditioningField(conditioning_name=conditioning_name, mask=self.mask)
|
||||
)
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
||||
"""Encode prompt using Qwen3 0.6B and T5-XXL tokenizer.
|
||||
|
||||
Returns:
|
||||
Tuple of (qwen3_embeds, t5xxl_ids, t5xxl_weights).
|
||||
- qwen3_embeds: Shape (max_seq_len, 1024) — includes all positions (including padding)
|
||||
to preserve full sequence context for the LLM Adapter.
|
||||
- t5xxl_ids: Shape (seq_len,) — T5-XXL token IDs (unpadded).
|
||||
- t5xxl_weights: None (uniform weights for now).
|
||||
"""
|
||||
prompt = self.prompt
|
||||
|
||||
# --- Step 1: Encode with Qwen3 0.6B ---
|
||||
text_encoder_info = context.models.load(self.qwen3_encoder.text_encoder)
|
||||
tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer)
|
||||
|
||||
with ExitStack() as exit_stack:
|
||||
(_, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())
|
||||
(_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())
|
||||
|
||||
device = text_encoder.device
|
||||
|
||||
# Apply LoRA models to the text encoder
|
||||
lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
|
||||
exit_stack.enter_context(
|
||||
LayerPatcher.apply_smart_model_patches(
|
||||
model=text_encoder,
|
||||
patches=self._lora_iterator(context),
|
||||
prefix=ANIMA_LORA_QWEN3_PREFIX,
|
||||
dtype=lora_dtype,
|
||||
)
|
||||
)
|
||||
|
||||
if not isinstance(text_encoder, PreTrainedModel):
|
||||
raise TypeError(f"Expected PreTrainedModel for text encoder, got {type(text_encoder).__name__}.")
|
||||
if not isinstance(tokenizer, PreTrainedTokenizerBase):
|
||||
raise TypeError(f"Expected PreTrainedTokenizerBase for tokenizer, got {type(tokenizer).__name__}.")
|
||||
|
||||
context.util.signal_progress("Running Qwen3 0.6B text encoder")
|
||||
|
||||
# Anima uses base Qwen3 (not instruct) — tokenize directly, no chat template.
|
||||
# A safety cap is applied to prevent GPU OOM on extremely long prompts.
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding=False,
|
||||
truncation=True,
|
||||
max_length=QWEN3_MAX_SEQ_LEN,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
attention_mask = text_inputs.attention_mask
|
||||
if not isinstance(text_input_ids, torch.Tensor) or not isinstance(attention_mask, torch.Tensor):
|
||||
raise TypeError("Tokenizer returned unexpected types.")
|
||||
|
||||
if text_input_ids.shape[-1] == QWEN3_MAX_SEQ_LEN:
|
||||
logger.warning(
|
||||
f"Prompt was truncated to {QWEN3_MAX_SEQ_LEN} tokens. "
|
||||
"Consider shortening the prompt for best results."
|
||||
)
|
||||
|
||||
# Ensure at least 1 token (empty prompts produce 0 tokens with padding=False)
|
||||
if text_input_ids.shape[-1] == 0:
|
||||
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
|
||||
text_input_ids = torch.tensor([[pad_id]])
|
||||
attention_mask = torch.tensor([[1]])
|
||||
|
||||
# Get last hidden state from Qwen3 (final layer output)
|
||||
prompt_mask = attention_mask.to(device).bool()
|
||||
outputs = text_encoder(
|
||||
text_input_ids.to(device),
|
||||
attention_mask=prompt_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
if not hasattr(outputs, "hidden_states") or outputs.hidden_states is None:
|
||||
raise RuntimeError("Text encoder did not return hidden_states.")
|
||||
if len(outputs.hidden_states) < 1:
|
||||
raise RuntimeError(f"Expected at least 1 hidden state, got {len(outputs.hidden_states)}.")
|
||||
|
||||
# Use last hidden state — only real tokens, no padding
|
||||
qwen3_embeds = outputs.hidden_states[-1][0] # Shape: (seq_len, 1024)
|
||||
|
||||
# --- Step 2: Tokenize with T5-XXL tokenizer (IDs only, no model) ---
|
||||
context.util.signal_progress("Tokenizing with T5-XXL")
|
||||
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
|
||||
with t5_tokenizer_info.model_on_device() as (_, t5_tokenizer):
|
||||
t5_tokens = t5_tokenizer(
|
||||
prompt,
|
||||
padding=False,
|
||||
truncation=True,
|
||||
max_length=T5_MAX_SEQ_LEN,
|
||||
return_tensors="pt",
|
||||
)
|
||||
t5xxl_ids = t5_tokens.input_ids[0] # Shape: (seq_len,)
|
||||
|
||||
return qwen3_embeds, t5xxl_ids, None
|
||||
|
||||
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
|
||||
"""Iterate over LoRA models to apply to the Qwen3 text encoder."""
|
||||
for lora in self.qwen3_encoder.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
if not isinstance(lora_info.model, ModelPatchRaw):
|
||||
raise TypeError(
|
||||
f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}. "
|
||||
"The LoRA model may be corrupted or incompatible."
|
||||
)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
@@ -56,7 +56,7 @@ class BaseBatchInvocation(BaseInvocation):
|
||||
"image_batch",
|
||||
title="Image Batch",
|
||||
tags=["primitives", "image", "batch", "special"],
|
||||
category="primitives",
|
||||
category="batch",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
@@ -87,7 +87,7 @@ class ImageGeneratorField(BaseModel):
|
||||
"image_generator",
|
||||
title="Image Generator",
|
||||
tags=["primitives", "board", "image", "batch", "special"],
|
||||
category="primitives",
|
||||
category="batch",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
@@ -111,7 +111,7 @@ class ImageGenerator(BaseInvocation):
|
||||
"string_batch",
|
||||
title="String Batch",
|
||||
tags=["primitives", "string", "batch", "special"],
|
||||
category="primitives",
|
||||
category="batch",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
@@ -142,7 +142,7 @@ class StringGeneratorField(BaseModel):
|
||||
"string_generator",
|
||||
title="String Generator",
|
||||
tags=["primitives", "string", "number", "batch", "special"],
|
||||
category="primitives",
|
||||
category="batch",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
@@ -166,7 +166,7 @@ class StringGenerator(BaseInvocation):
|
||||
"integer_batch",
|
||||
title="Integer Batch",
|
||||
tags=["primitives", "integer", "number", "batch", "special"],
|
||||
category="primitives",
|
||||
category="batch",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
@@ -195,7 +195,7 @@ class IntegerGeneratorField(BaseModel):
|
||||
"integer_generator",
|
||||
title="Integer Generator",
|
||||
tags=["primitives", "int", "number", "batch", "special"],
|
||||
category="primitives",
|
||||
category="batch",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
@@ -219,7 +219,7 @@ class IntegerGenerator(BaseInvocation):
|
||||
"float_batch",
|
||||
title="Float Batch",
|
||||
tags=["primitives", "float", "number", "batch", "special"],
|
||||
category="primitives",
|
||||
category="batch",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
@@ -250,7 +250,7 @@ class FloatGeneratorField(BaseModel):
|
||||
"float_generator",
|
||||
title="Float Generator",
|
||||
tags=["primitives", "float", "number", "batch", "special"],
|
||||
category="primitives",
|
||||
category="batch",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
|
||||
@@ -11,7 +11,7 @@ from invokeai.backend.image_util.util import cv2_to_pil, pil_to_cv2
|
||||
"canny_edge_detection",
|
||||
title="Canny Edge Detection",
|
||||
tags=["controlnet", "canny"],
|
||||
category="controlnet",
|
||||
category="controlnet_preprocessors",
|
||||
version="1.0.0",
|
||||
)
|
||||
class CannyEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
27
invokeai/app/invocations/canvas.py
Normal file
27
invokeai/app/invocations/canvas.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
@invocation(
|
||||
"canvas_output",
|
||||
title="Canvas Output",
|
||||
tags=["canvas", "output", "image"],
|
||||
category="canvas",
|
||||
version="1.0.0",
|
||||
use_cache=False,
|
||||
)
|
||||
class CanvasOutputInvocation(BaseInvocation):
|
||||
"""Outputs an image to the canvas staging area.
|
||||
|
||||
Use this node in workflows intended for canvas workflow integration.
|
||||
Connect the final image of your workflow to this node to send it
|
||||
to the canvas staging area when run via 'Run Workflow on Canvas'."""
|
||||
|
||||
image: ImageField = InputField(description=FieldDescriptions.image)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
image_dto = context.images.save(image=image)
|
||||
return ImageOutput.build(image_dto)
|
||||
@@ -33,7 +33,7 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
"cogview4_denoise",
|
||||
title="Denoise - CogView4",
|
||||
tags=["image", "cogview4"],
|
||||
category="image",
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
|
||||
@@ -27,7 +27,7 @@ from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory
|
||||
"cogview4_i2l",
|
||||
title="Image to Latents - CogView4",
|
||||
tags=["image", "latents", "vae", "i2l", "cogview4"],
|
||||
category="image",
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
|
||||
from invokeai.app.invocations.model import GlmEncoderField
|
||||
from invokeai.app.invocations.primitives import CogView4ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
CogView4ConditioningInfo,
|
||||
ConditioningFieldData,
|
||||
@@ -19,7 +20,7 @@ COGVIEW4_GLM_MAX_SEQ_LEN = 1024
|
||||
"cogview4_text_encoder",
|
||||
title="Prompt - CogView4",
|
||||
tags=["prompt", "conditioning", "cogview4"],
|
||||
category="conditioning",
|
||||
category="prompt",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
@@ -46,10 +47,18 @@ class CogView4TextEncoderInvocation(BaseInvocation):
|
||||
prompt = [self.prompt]
|
||||
|
||||
# TODO(ryand): Add model inputs to the invocation rather than hard-coding.
|
||||
glm_text_encoder_info = context.models.load(self.glm_encoder.text_encoder)
|
||||
with (
|
||||
context.models.load(self.glm_encoder.text_encoder).model_on_device() as (_, glm_text_encoder),
|
||||
glm_text_encoder_info.model_on_device() as (_, glm_text_encoder),
|
||||
context.models.load(self.glm_encoder.tokenizer).model_on_device() as (_, glm_tokenizer),
|
||||
):
|
||||
repaired_tensors = glm_text_encoder_info.repair_required_tensors_on_device()
|
||||
device = get_effective_device(glm_text_encoder)
|
||||
if repaired_tensors > 0:
|
||||
context.logger.warning(
|
||||
f"Recovered {repaired_tensors} required GLM tensor(s) onto {device} after a partial device mismatch."
|
||||
)
|
||||
|
||||
context.util.signal_progress("Running GLM text encoder")
|
||||
assert isinstance(glm_text_encoder, GlmModel)
|
||||
assert isinstance(glm_tokenizer, PreTrainedTokenizerFast)
|
||||
@@ -85,9 +94,7 @@ class CogView4TextEncoderInvocation(BaseInvocation):
|
||||
device=text_input_ids.device,
|
||||
)
|
||||
text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
|
||||
prompt_embeds = glm_text_encoder(
|
||||
text_input_ids.to(glm_text_encoder.device), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
prompt_embeds = glm_text_encoder(text_input_ids.to(device), output_hidden_states=True).hidden_states[-2]
|
||||
|
||||
assert isinstance(prompt_embeds, torch.Tensor)
|
||||
return prompt_embeds
|
||||
|
||||
@@ -11,9 +11,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.misc import SEED_MAX
|
||||
|
||||
|
||||
@invocation(
|
||||
"range", title="Integer Range", tags=["collection", "integer", "range"], category="collections", version="1.0.0"
|
||||
)
|
||||
@invocation("range", title="Integer Range", tags=["collection", "integer", "range"], category="batch", version="1.0.0")
|
||||
class RangeInvocation(BaseInvocation):
|
||||
"""Creates a range of numbers from start to stop with step"""
|
||||
|
||||
@@ -35,7 +33,7 @@ class RangeInvocation(BaseInvocation):
|
||||
"range_of_size",
|
||||
title="Integer Range of Size",
|
||||
tags=["collection", "integer", "size", "range"],
|
||||
category="collections",
|
||||
category="batch",
|
||||
version="1.0.0",
|
||||
)
|
||||
class RangeOfSizeInvocation(BaseInvocation):
|
||||
@@ -55,7 +53,7 @@ class RangeOfSizeInvocation(BaseInvocation):
|
||||
"random_range",
|
||||
title="Random Range",
|
||||
tags=["range", "integer", "random", "collection"],
|
||||
category="collections",
|
||||
category="batch",
|
||||
version="1.0.1",
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
@@ -11,7 +11,7 @@ from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||
"color_map",
|
||||
title="Color Map",
|
||||
tags=["controlnet"],
|
||||
category="controlnet",
|
||||
category="controlnet_preprocessors",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ColorMapInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@@ -19,6 +19,7 @@ from invokeai.app.invocations.model import CLIPField
|
||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.ti_utils import generate_ti_list
|
||||
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
@@ -42,7 +43,7 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
"compel",
|
||||
title="Prompt - SD1.5",
|
||||
tags=["prompt", "compel"],
|
||||
category="conditioning",
|
||||
category="prompt",
|
||||
version="1.2.1",
|
||||
)
|
||||
class CompelInvocation(BaseInvocation):
|
||||
@@ -103,7 +104,7 @@ class CompelInvocation(BaseInvocation):
|
||||
textual_inversion_manager=ti_manager,
|
||||
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
|
||||
truncate_long_prompts=False,
|
||||
device=text_encoder.device, # Use the device the model is actually on
|
||||
device=get_effective_device(text_encoder),
|
||||
split_long_text_mode=SplitLongTextMode.SENTENCES,
|
||||
)
|
||||
|
||||
@@ -212,7 +213,7 @@ class SDXLPromptInvocationBase:
|
||||
truncate_long_prompts=False, # TODO:
|
||||
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
|
||||
requires_pooled=get_pooled,
|
||||
device=text_encoder.device, # Use the device the model is actually on
|
||||
device=get_effective_device(text_encoder),
|
||||
split_long_text_mode=SplitLongTextMode.SENTENCES,
|
||||
)
|
||||
|
||||
@@ -247,7 +248,7 @@ class SDXLPromptInvocationBase:
|
||||
"sdxl_compel_prompt",
|
||||
title="Prompt - SDXL",
|
||||
tags=["sdxl", "compel", "prompt"],
|
||||
category="conditioning",
|
||||
category="prompt",
|
||||
version="1.2.1",
|
||||
)
|
||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
@@ -341,7 +342,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"sdxl_refiner_compel_prompt",
|
||||
title="Prompt - SDXL Refiner",
|
||||
tags=["sdxl", "compel", "prompt"],
|
||||
category="conditioning",
|
||||
category="prompt",
|
||||
version="1.1.2",
|
||||
)
|
||||
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
@@ -390,7 +391,7 @@ class CLIPSkipInvocationOutput(BaseInvocationOutput):
|
||||
"clip_skip",
|
||||
title="Apply CLIP Skip - SD1.5, SDXL",
|
||||
tags=["clipskip", "clip", "skip"],
|
||||
category="conditioning",
|
||||
category="prompt",
|
||||
version="1.1.1",
|
||||
)
|
||||
class CLIPSkipInvocation(BaseInvocation):
|
||||
|
||||
@@ -9,7 +9,7 @@ from invokeai.backend.image_util.content_shuffle import content_shuffle
|
||||
"content_shuffle",
|
||||
title="Content Shuffle",
|
||||
tags=["controlnet", "normal"],
|
||||
category="controlnet",
|
||||
category="controlnet_preprocessors",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ContentShuffleInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@@ -64,7 +64,7 @@ class ControlOutput(BaseInvocationOutput):
|
||||
|
||||
|
||||
@invocation(
|
||||
"controlnet", title="ControlNet - SD1.5, SD2, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3"
|
||||
"controlnet", title="ControlNet - SD1.5, SD2, SDXL", tags=["controlnet"], category="conditioning", version="1.1.3"
|
||||
)
|
||||
class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
@@ -116,7 +116,7 @@ class ControlNetInvocation(BaseInvocation):
|
||||
"heuristic_resize",
|
||||
title="Heuristic Resize",
|
||||
tags=["image, controlnet"],
|
||||
category="image",
|
||||
category="controlnet_preprocessors",
|
||||
version="1.1.1",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_t
|
||||
"create_denoise_mask",
|
||||
title="Create Denoise Mask",
|
||||
tags=["mask", "denoise"],
|
||||
category="latents",
|
||||
category="mask",
|
||||
version="1.0.2",
|
||||
)
|
||||
class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
|
||||
@@ -41,7 +41,7 @@ class GradientMaskOutput(BaseInvocationOutput):
|
||||
"create_gradient_mask",
|
||||
title="Create Gradient Mask",
|
||||
tags=["mask", "denoise"],
|
||||
category="latents",
|
||||
category="mask",
|
||||
version="1.3.0",
|
||||
)
|
||||
class CreateGradientMaskInvocation(BaseInvocation):
|
||||
|
||||
@@ -20,7 +20,7 @@ DEPTH_ANYTHING_MODELS = {
|
||||
"depth_anything_depth_estimation",
|
||||
title="Depth Anything Depth Estimation",
|
||||
tags=["controlnet", "depth", "depth anything"],
|
||||
category="controlnet",
|
||||
category="controlnet_preprocessors",
|
||||
version="1.0.0",
|
||||
)
|
||||
class DepthAnythingDepthEstimationInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@@ -11,7 +11,7 @@ from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
|
||||
"dw_openpose_detection",
|
||||
title="DW Openpose Detection",
|
||||
tags=["controlnet", "dwpose", "openpose"],
|
||||
category="controlnet",
|
||||
category="controlnet_preprocessors",
|
||||
version="1.1.1",
|
||||
)
|
||||
class DWOpenposeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
203
invokeai/app/invocations/external_image_generation.py
Normal file
203
invokeai/app/invocations/external_image_generation.py
Normal file
@@ -0,0 +1,203 @@
|
||||
from typing import Any, ClassVar, Literal
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
InputField,
|
||||
MetadataField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageCollectionOutput
|
||||
from invokeai.app.services.external_generation.external_generation_common import (
|
||||
ExternalGenerationRequest,
|
||||
ExternalGenerationResult,
|
||||
ExternalReferenceImage,
|
||||
)
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.configs.external_api import ExternalApiModelConfig, ExternalGenerationMode
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
|
||||
|
||||
|
||||
class BaseExternalImageGenerationInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generate images using an external provider."""
|
||||
|
||||
provider_id: ClassVar[str | None] = None
|
||||
|
||||
model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.main_model,
|
||||
ui_model_base=[BaseModelType.External],
|
||||
ui_model_type=[ModelType.ExternalImageGenerator],
|
||||
ui_model_format=[ModelFormat.ExternalApi],
|
||||
)
|
||||
mode: ExternalGenerationMode = InputField(default="txt2img", description="Generation mode")
|
||||
prompt: str = InputField(description="Prompt")
|
||||
seed: int | None = InputField(default=None, description=FieldDescriptions.seed)
|
||||
num_images: int = InputField(default=1, gt=0, description="Number of images to generate")
|
||||
width: int = InputField(default=1024, gt=0, description=FieldDescriptions.width)
|
||||
height: int = InputField(default=1024, gt=0, description=FieldDescriptions.height)
|
||||
image_size: str | None = InputField(default=None, description="Image size preset (e.g. 1K, 2K, 4K)")
|
||||
init_image: ImageField | None = InputField(default=None, description="Init image for img2img/inpaint")
|
||||
mask_image: ImageField | None = InputField(default=None, description="Mask image for inpaint")
|
||||
reference_images: list[ImageField] = InputField(default=[], description="Reference images")
|
||||
|
||||
def _build_provider_options(self) -> dict[str, Any] | None:
|
||||
"""Override in provider-specific subclasses to pass extra options."""
|
||||
return None
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
|
||||
model_config = context.models.get_config(self.model)
|
||||
if not isinstance(model_config, ExternalApiModelConfig):
|
||||
raise ValueError("Selected model is not an external API model")
|
||||
|
||||
if self.provider_id is not None and model_config.provider_id != self.provider_id:
|
||||
raise ValueError(
|
||||
f"Selected model provider '{model_config.provider_id}' does not match node provider '{self.provider_id}'"
|
||||
)
|
||||
|
||||
init_image = None
|
||||
if self.init_image is not None:
|
||||
init_image = context.images.get_pil(self.init_image.image_name, mode="RGB")
|
||||
|
||||
mask_image = None
|
||||
if self.mask_image is not None:
|
||||
mask_image = context.images.get_pil(self.mask_image.image_name, mode="L")
|
||||
|
||||
reference_images: list[ExternalReferenceImage] = []
|
||||
for image_field in self.reference_images:
|
||||
reference_image = context.images.get_pil(image_field.image_name, mode="RGB")
|
||||
reference_images.append(ExternalReferenceImage(image=reference_image))
|
||||
|
||||
request = ExternalGenerationRequest(
|
||||
model=model_config,
|
||||
mode=self.mode,
|
||||
prompt=self.prompt,
|
||||
seed=self.seed,
|
||||
num_images=self.num_images,
|
||||
width=self.width,
|
||||
height=self.height,
|
||||
image_size=self.image_size,
|
||||
init_image=init_image,
|
||||
mask_image=mask_image,
|
||||
reference_images=reference_images,
|
||||
metadata=self._build_request_metadata(),
|
||||
provider_options=self._build_provider_options(),
|
||||
)
|
||||
|
||||
result = context._services.external_generation.generate(request)
|
||||
|
||||
outputs: list[ImageField] = []
|
||||
for generated in result.images:
|
||||
metadata = self._build_output_metadata(model_config, result, generated.seed)
|
||||
image_dto = context.images.save(image=generated.image, metadata=metadata)
|
||||
outputs.append(ImageField(image_name=image_dto.image_name))
|
||||
|
||||
return ImageCollectionOutput(collection=outputs)
|
||||
|
||||
def _build_request_metadata(self) -> dict[str, Any] | None:
|
||||
if self.metadata is None:
|
||||
return None
|
||||
return self.metadata.root
|
||||
|
||||
def _build_output_metadata(
|
||||
self,
|
||||
model_config: ExternalApiModelConfig,
|
||||
result: ExternalGenerationResult,
|
||||
image_seed: int | None,
|
||||
) -> MetadataField | None:
|
||||
metadata: dict[str, Any] = {}
|
||||
|
||||
if self.metadata is not None:
|
||||
metadata.update(self.metadata.root)
|
||||
|
||||
metadata.update(
|
||||
{
|
||||
"external_provider": model_config.provider_id,
|
||||
"external_model_id": model_config.provider_model_id,
|
||||
}
|
||||
)
|
||||
|
||||
provider_request_id = getattr(result, "provider_request_id", None)
|
||||
if provider_request_id:
|
||||
metadata["external_request_id"] = provider_request_id
|
||||
|
||||
provider_metadata = getattr(result, "provider_metadata", None)
|
||||
if provider_metadata:
|
||||
metadata["external_provider_metadata"] = provider_metadata
|
||||
|
||||
if image_seed is not None:
|
||||
metadata["external_seed"] = image_seed
|
||||
|
||||
if not metadata:
|
||||
return None
|
||||
return MetadataField(root=metadata)
|
||||
|
||||
|
||||
@invocation(
|
||||
"external_image_generation",
|
||||
title="External Image Generation (Legacy)",
|
||||
tags=["external", "generation"],
|
||||
category="image",
|
||||
version="1.1.0",
|
||||
classification=Classification.Internal,
|
||||
)
|
||||
class ExternalImageGenerationInvocation(BaseExternalImageGenerationInvocation):
|
||||
"""Legacy external image generation node kept for backward compatibility."""
|
||||
|
||||
|
||||
@invocation(
|
||||
"openai_image_generation",
|
||||
title="OpenAI Image Generation",
|
||||
tags=["external", "generation", "openai"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class OpenAIImageGenerationInvocation(BaseExternalImageGenerationInvocation):
|
||||
"""Generate images using an OpenAI-hosted external model."""
|
||||
|
||||
provider_id = "openai"
|
||||
|
||||
quality: Literal["auto", "high", "medium", "low"] = InputField(default="auto", description="Output image quality")
|
||||
background: Literal["auto", "transparent", "opaque"] = InputField(
|
||||
default="auto", description="Background transparency handling"
|
||||
)
|
||||
input_fidelity: Literal["low", "high"] | None = InputField(
|
||||
default=None, description="Fidelity to source images (edits only)"
|
||||
)
|
||||
|
||||
def _build_provider_options(self) -> dict[str, Any]:
|
||||
options: dict[str, Any] = {
|
||||
"quality": self.quality,
|
||||
"background": self.background,
|
||||
}
|
||||
if self.input_fidelity is not None:
|
||||
options["input_fidelity"] = self.input_fidelity
|
||||
return options
|
||||
|
||||
|
||||
@invocation(
|
||||
"gemini_image_generation",
|
||||
title="Gemini Image Generation",
|
||||
tags=["external", "generation", "gemini"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class GeminiImageGenerationInvocation(BaseExternalImageGenerationInvocation):
|
||||
"""Generate images using a Gemini-hosted external model."""
|
||||
|
||||
provider_id = "gemini"
|
||||
|
||||
temperature: float | None = InputField(default=None, ge=0.0, le=2.0, description="Sampling temperature")
|
||||
thinking_level: Literal["minimal", "high"] | None = InputField(
|
||||
default=None, description="Thinking level for image generation"
|
||||
)
|
||||
|
||||
def _build_provider_options(self) -> dict[str, Any] | None:
|
||||
options: dict[str, Any] = {}
|
||||
if self.temperature is not None:
|
||||
options["temperature"] = self.temperature
|
||||
if self.thinking_level is not None:
|
||||
options["thinking_level"] = self.thinking_level
|
||||
return options or None
|
||||
@@ -435,7 +435,9 @@ def get_faces_list(
|
||||
return all_faces
|
||||
|
||||
|
||||
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.2")
|
||||
@invocation(
|
||||
"face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="segmentation", version="1.2.2"
|
||||
)
|
||||
class FaceOffInvocation(BaseInvocation, WithMetadata):
|
||||
"""Bound, extract, and mask a face from an image using MediaPipe detection"""
|
||||
|
||||
@@ -514,7 +516,9 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
|
||||
return output
|
||||
|
||||
|
||||
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.2")
|
||||
@invocation(
|
||||
"face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="segmentation", version="1.2.2"
|
||||
)
|
||||
class FaceMaskInvocation(BaseInvocation, WithMetadata):
|
||||
"""Face mask creation using mediapipe face detection"""
|
||||
|
||||
@@ -617,7 +621,11 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata):
|
||||
|
||||
|
||||
@invocation(
|
||||
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.2"
|
||||
"face_identifier",
|
||||
title="FaceIdentifier",
|
||||
tags=["image", "face", "identifier"],
|
||||
category="segmentation",
|
||||
version="1.2.2",
|
||||
)
|
||||
class FaceIdentifierInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Outputs an image with detected face IDs printed on each face. For use with other FaceTools."""
|
||||
|
||||
@@ -171,6 +171,8 @@ class FieldDescriptions:
|
||||
sd3_model = "SD3 model (MMDiTX) to load"
|
||||
cogview4_model = "CogView4 model (Transformer) to load"
|
||||
z_image_model = "Z-Image model (Transformer) to load"
|
||||
qwen_image_model = "Qwen Image Edit model (Transformer) to load"
|
||||
qwen_vl_encoder = "Qwen2.5-VL tokenizer, processor and text/vision encoder"
|
||||
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
||||
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
||||
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
||||
@@ -340,6 +342,27 @@ class ZImageConditioningField(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class QwenImageConditioningField(BaseModel):
|
||||
"""A Qwen Image Edit conditioning tensor primitive value"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
|
||||
|
||||
class AnimaConditioningField(BaseModel):
|
||||
"""An Anima conditioning tensor primitive value.
|
||||
|
||||
Anima conditioning contains Qwen3 0.6B hidden states and T5-XXL token IDs,
|
||||
which are combined by the LLM Adapter inside the transformer.
|
||||
"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
mask: Optional[TensorField] = Field(
|
||||
default=None,
|
||||
description="The mask associated with this conditioning tensor for regional prompting. "
|
||||
"Excluded regions should be set to False, included regions should be set to True.",
|
||||
)
|
||||
|
||||
|
||||
class ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
"flux2_denoise",
|
||||
title="FLUX2 Denoise",
|
||||
tags=["image", "flux", "flux2", "klein", "denoise"],
|
||||
category="image",
|
||||
category="latents",
|
||||
version="1.4.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
|
||||
@@ -25,6 +25,7 @@ from invokeai.app.invocations.fields import (
|
||||
from invokeai.app.invocations.model import Qwen3EncoderField
|
||||
from invokeai.app.invocations.primitives import FluxConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_T5_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
@@ -44,7 +45,7 @@ KLEIN_MAX_SEQ_LEN = 512
|
||||
"flux2_klein_text_encoder",
|
||||
title="Prompt - Flux2 Klein",
|
||||
tags=["prompt", "conditioning", "flux", "klein", "qwen3"],
|
||||
category="conditioning",
|
||||
category="prompt",
|
||||
version="1.1.1",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
@@ -100,7 +101,12 @@ class Flux2KleinTextEncoderInvocation(BaseInvocation):
|
||||
tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer)
|
||||
(_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())
|
||||
|
||||
device = text_encoder.device
|
||||
repaired_tensors = text_encoder_info.repair_required_tensors_on_device()
|
||||
device = get_effective_device(text_encoder)
|
||||
if repaired_tensors > 0:
|
||||
context.logger.warning(
|
||||
f"Recovered {repaired_tensors} required Qwen3 tensor(s) onto {device} after a partial device mismatch."
|
||||
)
|
||||
|
||||
# Apply LoRA models
|
||||
lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
|
||||
|
||||
@@ -50,7 +50,7 @@ class FluxControlNetOutput(BaseInvocationOutput):
|
||||
"flux_controlnet",
|
||||
title="FLUX ControlNet",
|
||||
tags=["controlnet", "flux"],
|
||||
category="controlnet",
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxControlNetInvocation(BaseInvocation):
|
||||
|
||||
@@ -70,7 +70,7 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
"flux_denoise",
|
||||
title="FLUX Denoise",
|
||||
tags=["image", "flux"],
|
||||
category="image",
|
||||
category="latents",
|
||||
version="4.5.1",
|
||||
)
|
||||
class FluxDenoiseInvocation(BaseInvocation):
|
||||
|
||||
@@ -29,7 +29,7 @@ class FluxFillOutput(BaseInvocationOutput):
|
||||
"flux_fill",
|
||||
title="FLUX Fill Conditioning",
|
||||
tags=["inpaint"],
|
||||
category="inpaint",
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
|
||||
@@ -24,7 +24,7 @@ from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
"flux_ip_adapter",
|
||||
title="FLUX IP-Adapter",
|
||||
tags=["ip_adapter", "control"],
|
||||
category="ip_adapter",
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxIPAdapterInvocation(BaseInvocation):
|
||||
|
||||
@@ -47,7 +47,7 @@ DOWNSAMPLING_FUNCTIONS = Literal["nearest", "bilinear", "bicubic", "area", "near
|
||||
"flux_redux",
|
||||
title="FLUX Redux",
|
||||
tags=["ip_adapter", "control"],
|
||||
category="ip_adapter",
|
||||
category="conditioning",
|
||||
version="2.1.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
|
||||
@@ -28,7 +28,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
|
||||
"flux_text_encoder",
|
||||
title="Prompt - FLUX",
|
||||
tags=["prompt", "conditioning", "flux"],
|
||||
category="conditioning",
|
||||
category="prompt",
|
||||
version="1.1.2",
|
||||
)
|
||||
class FluxTextEncoderInvocation(BaseInvocation):
|
||||
|
||||
@@ -24,7 +24,7 @@ GROUNDING_DINO_MODEL_IDS: dict[GroundingDinoModelKey, str] = {
|
||||
"grounding_dino",
|
||||
title="Grounding DINO (Text Prompt Object Detection)",
|
||||
tags=["prompt", "object detection"],
|
||||
category="image",
|
||||
category="segmentation",
|
||||
version="1.0.0",
|
||||
)
|
||||
class GroundingDinoInvocation(BaseInvocation):
|
||||
|
||||
@@ -11,7 +11,7 @@ from invokeai.backend.image_util.hed import ControlNetHED_Apache2, HEDEdgeDetect
|
||||
"hed_edge_detection",
|
||||
title="HED Edge Detection",
|
||||
tags=["controlnet", "hed", "softedge"],
|
||||
category="controlnet",
|
||||
category="controlnet_preprocessors",
|
||||
version="1.0.0",
|
||||
)
|
||||
class HEDEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@@ -21,6 +21,7 @@ class IdealSizeOutput(BaseInvocationOutput):
|
||||
"ideal_size",
|
||||
title="Ideal Size - SD1.5, SDXL",
|
||||
tags=["latents", "math", "ideal_size"],
|
||||
category="latents",
|
||||
version="1.0.6",
|
||||
)
|
||||
class IdealSizeInvocation(BaseInvocation):
|
||||
|
||||
@@ -21,7 +21,7 @@ from invokeai.app.invocations.fields import (
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.invocations.primitives import ImageOutput, StringOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.misc import SEED_MAX
|
||||
@@ -197,7 +197,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"tomask",
|
||||
title="Mask from Alpha",
|
||||
tags=["image", "mask"],
|
||||
category="image",
|
||||
category="mask",
|
||||
version="1.2.2",
|
||||
)
|
||||
class MaskFromAlphaInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
@@ -581,11 +581,30 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
|
||||
@invocation(
|
||||
"decode_watermark",
|
||||
title="Decode Invisible Watermark",
|
||||
tags=["image", "watermark"],
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
)
|
||||
class DecodeInvisibleWatermarkInvocation(BaseInvocation):
|
||||
"""Decode an invisible watermark from an image."""
|
||||
|
||||
image: ImageField = InputField(description="The image to decode the watermark from")
|
||||
length: int = InputField(default=8, description="The expected watermark length in bytes")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
watermark = InvisibleWatermark.decode_watermark(image, self.length)
|
||||
return StringOutput(value=watermark)
|
||||
|
||||
|
||||
@invocation(
|
||||
"mask_edge",
|
||||
title="Mask Edge",
|
||||
tags=["image", "mask", "inpaint"],
|
||||
category="image",
|
||||
category="mask",
|
||||
version="1.2.2",
|
||||
)
|
||||
class MaskEdgeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
@@ -624,7 +643,7 @@ class MaskEdgeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"mask_combine",
|
||||
title="Combine Masks",
|
||||
tags=["image", "mask", "multiply"],
|
||||
category="image",
|
||||
category="mask",
|
||||
version="1.2.2",
|
||||
)
|
||||
class MaskCombineInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
@@ -955,7 +974,7 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"save_image",
|
||||
title="Save Image",
|
||||
tags=["primitives", "image"],
|
||||
category="primitives",
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
use_cache=False,
|
||||
)
|
||||
@@ -976,7 +995,7 @@ class SaveImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"canvas_paste_back",
|
||||
title="Canvas Paste Back",
|
||||
tags=["image", "combine"],
|
||||
category="image",
|
||||
category="canvas",
|
||||
version="1.0.1",
|
||||
)
|
||||
class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
@@ -1013,7 +1032,7 @@ class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"mask_from_id",
|
||||
title="Mask from Segmented Image",
|
||||
tags=["image", "mask", "id"],
|
||||
category="image",
|
||||
category="mask",
|
||||
version="1.0.1",
|
||||
)
|
||||
class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
@@ -1050,7 +1069,7 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"canvas_v2_mask_and_crop",
|
||||
title="Canvas V2 Mask and Crop",
|
||||
tags=["image", "mask", "id"],
|
||||
category="image",
|
||||
category="canvas",
|
||||
version="1.0.0",
|
||||
classification=Classification.Deprecated,
|
||||
)
|
||||
@@ -1091,7 +1110,7 @@ class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
|
||||
@invocation(
|
||||
"expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="image", version="1.0.1"
|
||||
"expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="mask", version="1.0.1"
|
||||
)
|
||||
class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Expands a mask with a fade effect. The mask uses black to indicate areas to keep from the generated image and white for areas to discard.
|
||||
@@ -1180,7 +1199,7 @@ class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"apply_mask_to_image",
|
||||
title="Apply Mask to Image",
|
||||
tags=["image", "mask", "blend"],
|
||||
category="image",
|
||||
category="mask",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ApplyMaskToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
@@ -1355,7 +1374,7 @@ class PasteImageIntoBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoar
|
||||
"flux_kontext_image_prep",
|
||||
title="FLUX Kontext Image Prep",
|
||||
tags=["image", "concatenate", "flux", "kontext"],
|
||||
category="image",
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
)
|
||||
class FluxKontextConcatenateImagesInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@@ -23,7 +23,7 @@ class ImagePanelCoordinateOutput(BaseInvocationOutput):
|
||||
"image_panel_layout",
|
||||
title="Image Panel Layout",
|
||||
tags=["image", "panel", "layout"],
|
||||
category="image",
|
||||
category="canvas",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
|
||||
@@ -73,7 +73,7 @@ CLIP_VISION_MODEL_MAP: dict[Literal["ViT-L", "ViT-H", "ViT-G"], StarterModel] =
|
||||
"ip_adapter",
|
||||
title="IP-Adapter - SD1.5, SDXL",
|
||||
tags=["ip_adapter", "control"],
|
||||
category="ip_adapter",
|
||||
category="conditioning",
|
||||
version="1.5.1",
|
||||
)
|
||||
class IPAdapterInvocation(BaseInvocation):
|
||||
|
||||
@@ -11,7 +11,7 @@ from invokeai.backend.image_util.lineart import Generator, LineartEdgeDetector
|
||||
"lineart_edge_detection",
|
||||
title="Lineart Edge Detection",
|
||||
tags=["controlnet", "lineart"],
|
||||
category="controlnet",
|
||||
category="controlnet_preprocessors",
|
||||
version="1.0.0",
|
||||
)
|
||||
class LineartEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@@ -9,7 +9,7 @@ from invokeai.backend.image_util.lineart_anime import LineartAnimeEdgeDetector,
|
||||
"lineart_anime_edge_detection",
|
||||
title="Lineart Anime Edge Detection",
|
||||
tags=["controlnet", "lineart"],
|
||||
category="controlnet",
|
||||
category="controlnet_preprocessors",
|
||||
version="1.0.0",
|
||||
)
|
||||
class LineartAnimeEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@@ -19,7 +19,7 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
"llava_onevision_vllm",
|
||||
title="LLaVA OneVision VLLM",
|
||||
tags=["vllm"],
|
||||
category="vllm",
|
||||
category="multimodal",
|
||||
version="1.0.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
|
||||
34
invokeai/app/invocations/logic.py
Normal file
34
invokeai/app/invocations/logic.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from invokeai.app.invocations.fields import InputField, OutputField, UIType
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
@invocation_output("if_output")
|
||||
class IfInvocationOutput(BaseInvocationOutput):
|
||||
value: Optional[Any] = OutputField(
|
||||
default=None, description="The selected value", title="Output", ui_type=UIType.Any
|
||||
)
|
||||
|
||||
|
||||
@invocation("if", title="If", tags=["logic", "conditional"], category="math", version="1.0.0")
|
||||
class IfInvocation(BaseInvocation):
|
||||
"""Selects between two optional inputs based on a boolean condition."""
|
||||
|
||||
condition: bool = InputField(default=False, description="The condition used to select an input", title="Condition")
|
||||
true_input: Optional[Any] = InputField(
|
||||
default=None,
|
||||
description="Selected when the condition is true",
|
||||
title="True Input",
|
||||
ui_type=UIType.Any,
|
||||
)
|
||||
false_input: Optional[Any] = InputField(
|
||||
default=None,
|
||||
description="Selected when the condition is false",
|
||||
title="False Input",
|
||||
ui_type=UIType.Any,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IfInvocationOutput:
|
||||
return IfInvocationOutput(value=self.true_input if self.condition else self.false_input)
|
||||
@@ -24,7 +24,7 @@ from invokeai.backend.image_util.util import pil_to_np
|
||||
"rectangle_mask",
|
||||
title="Create Rectangle Mask",
|
||||
tags=["conditioning"],
|
||||
category="conditioning",
|
||||
category="mask",
|
||||
version="1.0.1",
|
||||
)
|
||||
class RectangleMaskInvocation(BaseInvocation, WithMetadata):
|
||||
@@ -55,7 +55,7 @@ class RectangleMaskInvocation(BaseInvocation, WithMetadata):
|
||||
"alpha_mask_to_tensor",
|
||||
title="Alpha Mask to Tensor",
|
||||
tags=["conditioning"],
|
||||
category="conditioning",
|
||||
category="mask",
|
||||
version="1.0.0",
|
||||
)
|
||||
class AlphaMaskToTensorInvocation(BaseInvocation):
|
||||
@@ -83,7 +83,7 @@ class AlphaMaskToTensorInvocation(BaseInvocation):
|
||||
"invert_tensor_mask",
|
||||
title="Invert Tensor Mask",
|
||||
tags=["conditioning"],
|
||||
category="conditioning",
|
||||
category="mask",
|
||||
version="1.1.0",
|
||||
)
|
||||
class InvertTensorMaskInvocation(BaseInvocation):
|
||||
@@ -115,7 +115,7 @@ class InvertTensorMaskInvocation(BaseInvocation):
|
||||
"image_mask_to_tensor",
|
||||
title="Image Mask to Tensor",
|
||||
tags=["conditioning"],
|
||||
category="conditioning",
|
||||
category="mask",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
|
||||
|
||||
@@ -9,7 +9,7 @@ from invokeai.backend.image_util.mediapipe_face import detect_faces
|
||||
"mediapipe_face_detection",
|
||||
title="MediaPipe Face Detection",
|
||||
tags=["controlnet", "face"],
|
||||
category="controlnet",
|
||||
category="controlnet_preprocessors",
|
||||
version="1.0.0",
|
||||
)
|
||||
class MediaPipeFaceDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@@ -166,6 +166,14 @@ GENERATION_MODES = Literal[
|
||||
"z_image_img2img",
|
||||
"z_image_inpaint",
|
||||
"z_image_outpaint",
|
||||
"qwen_image_txt2img",
|
||||
"qwen_image_img2img",
|
||||
"qwen_image_inpaint",
|
||||
"qwen_image_outpaint",
|
||||
"anima_txt2img",
|
||||
"anima_img2img",
|
||||
"anima_inpaint",
|
||||
"anima_outpaint",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -621,7 +621,7 @@ class LatentsMetaOutput(LatentsOutput, MetadataOutput):
|
||||
"denoise_latents_meta",
|
||||
title=f"{DenoiseLatentsInvocation.UIConfig.title} + Metadata",
|
||||
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
||||
category="latents",
|
||||
category="metadata",
|
||||
version="1.1.1",
|
||||
)
|
||||
class DenoiseLatentsMetaInvocation(DenoiseLatentsInvocation, WithMetadata):
|
||||
@@ -686,7 +686,7 @@ class DenoiseLatentsMetaInvocation(DenoiseLatentsInvocation, WithMetadata):
|
||||
"flux_denoise_meta",
|
||||
title=f"{FluxDenoiseInvocation.UIConfig.title} + Metadata",
|
||||
tags=["flux", "latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
||||
category="latents",
|
||||
category="metadata",
|
||||
version="1.0.1",
|
||||
)
|
||||
class FluxDenoiseLatentsMetaInvocation(FluxDenoiseInvocation, WithMetadata):
|
||||
@@ -734,7 +734,7 @@ class FluxDenoiseLatentsMetaInvocation(FluxDenoiseInvocation, WithMetadata):
|
||||
"z_image_denoise_meta",
|
||||
title=f"{ZImageDenoiseInvocation.UIConfig.title} + Metadata",
|
||||
tags=["z-image", "latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
||||
category="latents",
|
||||
category="metadata",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ZImageDenoiseMetaInvocation(ZImageDenoiseInvocation, WithMetadata):
|
||||
|
||||
@@ -10,7 +10,7 @@ from invokeai.backend.image_util.mlsd.models.mbv2_mlsd_large import MobileV2_MLS
|
||||
"mlsd_detection",
|
||||
title="MLSD Detection",
|
||||
tags=["controlnet", "mlsd", "edge"],
|
||||
category="controlnet",
|
||||
category="controlnet_preprocessors",
|
||||
version="1.0.0",
|
||||
)
|
||||
class MLSDDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@@ -72,6 +72,13 @@ class GlmEncoderField(BaseModel):
|
||||
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
|
||||
|
||||
|
||||
class QwenVLEncoderField(BaseModel):
|
||||
"""Field for Qwen2.5-VL encoder used by Qwen Image Edit models."""
|
||||
|
||||
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
|
||||
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
|
||||
|
||||
|
||||
class Qwen3EncoderField(BaseModel):
|
||||
"""Field for Qwen3 text encoder used by Z-Image models."""
|
||||
|
||||
@@ -577,7 +584,7 @@ class SeamlessModeInvocation(BaseInvocation):
|
||||
return SeamlessModeOutput(unet=unet, vae=vae)
|
||||
|
||||
|
||||
@invocation("freeu", title="Apply FreeU - SD1.5, SDXL", tags=["freeu"], category="unet", version="1.0.2")
|
||||
@invocation("freeu", title="Apply FreeU - SD1.5, SDXL", tags=["freeu"], category="model", version="1.0.2")
|
||||
class FreeUInvocation(BaseInvocation):
|
||||
"""
|
||||
Applies FreeU to the UNet. Suggested values (b1/b2/s1/s2):
|
||||
|
||||
@@ -10,7 +10,7 @@ from invokeai.backend.image_util.normal_bae.nets.NNET import NNET
|
||||
"normal_map",
|
||||
title="Normal Map",
|
||||
tags=["controlnet", "normal"],
|
||||
category="controlnet",
|
||||
category="controlnet_preprocessors",
|
||||
version="1.0.0",
|
||||
)
|
||||
class NormalMapInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@@ -16,7 +16,9 @@ class PBRMapsOutput(BaseInvocationOutput):
|
||||
displacement_map: ImageField = OutputField(default=None, description="The generated displacement map")
|
||||
|
||||
|
||||
@invocation("pbr_maps", title="PBR Maps", tags=["image", "material"], category="image", version="1.0.0")
|
||||
@invocation(
|
||||
"pbr_maps", title="PBR Maps", tags=["image", "material"], category="controlnet_preprocessors", version="1.0.0"
|
||||
)
|
||||
class PBRMapsInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generate Normal, Displacement and Roughness Map from a given image"""
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from invokeai.backend.image_util.pidi.model import PiDiNet
|
||||
"pidi_edge_detection",
|
||||
title="PiDiNet Edge Detection",
|
||||
tags=["controlnet", "edge"],
|
||||
category="controlnet",
|
||||
category="controlnet_preprocessors",
|
||||
version="1.0.0",
|
||||
)
|
||||
class PiDiNetEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@@ -12,6 +12,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
)
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
AnimaConditioningField,
|
||||
BoundingBoxField,
|
||||
CogView4ConditioningField,
|
||||
ColorField,
|
||||
@@ -24,6 +25,7 @@ from invokeai.app.invocations.fields import (
|
||||
InputField,
|
||||
LatentsField,
|
||||
OutputField,
|
||||
QwenImageConditioningField,
|
||||
SD3ConditioningField,
|
||||
TensorField,
|
||||
UIComponent,
|
||||
@@ -473,6 +475,28 @@ class ZImageConditioningOutput(BaseInvocationOutput):
|
||||
return cls(conditioning=ZImageConditioningField(conditioning_name=conditioning_name))
|
||||
|
||||
|
||||
@invocation_output("qwen_image_conditioning_output")
|
||||
class QwenImageConditioningOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a Qwen Image Edit conditioning tensor."""
|
||||
|
||||
conditioning: QwenImageConditioningField = OutputField(description=FieldDescriptions.cond)
|
||||
|
||||
@classmethod
|
||||
def build(cls, conditioning_name: str) -> "QwenImageConditioningOutput":
|
||||
return cls(conditioning=QwenImageConditioningField(conditioning_name=conditioning_name))
|
||||
|
||||
|
||||
@invocation_output("anima_conditioning_output")
|
||||
class AnimaConditioningOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output an Anima text conditioning tensor."""
|
||||
|
||||
conditioning: AnimaConditioningField = OutputField(description=FieldDescriptions.cond)
|
||||
|
||||
@classmethod
|
||||
def build(cls, conditioning_name: str) -> "AnimaConditioningOutput":
|
||||
return cls(conditioning=AnimaConditioningField(conditioning_name=conditioning_name))
|
||||
|
||||
|
||||
@invocation_output("conditioning_output")
|
||||
class ConditioningOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single conditioning tensor"""
|
||||
|
||||
490
invokeai/app/invocations/qwen_image_denoise.py
Normal file
490
invokeai/app/invocations/qwen_image_denoise.py
Normal file
@@ -0,0 +1,490 @@
|
||||
from contextlib import ExitStack
|
||||
from typing import Callable, Iterator, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as tv_transforms
|
||||
from diffusers.models.transformers.transformer_qwenimage import QwenImageTransformer2DModel
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
QwenImageConditioningField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import TransformerField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.lora_conversions.qwen_image_lora_constants import (
|
||||
QWEN_IMAGE_EDIT_LORA_TRANSFORMER_PREFIX,
|
||||
)
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import QwenImageConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
    "qwen_image_denoise",
    title="Denoise - Qwen Image",
    tags=["image", "qwen_image"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class QwenImageDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Run the denoising process with a Qwen Image model.

    Supports text-to-image, image-to-image (via `latents` + `denoising_start`),
    inpainting (via `denoise_mask`), and edit-style reference conditioning
    (via `reference_latents`, consumed only by models with `zero_cond_t`).
    """

    # If latents is provided, this means we are doing image-to-image.
    latents: Optional[LatentsField] = InputField(
        default=None, description=FieldDescriptions.latents, input=Input.Connection
    )
    # Reference image latents (encoded through VAE) to concatenate with noisy latents.
    reference_latents: Optional[LatentsField] = InputField(
        default=None,
        description="Reference image latents to guide generation. Encoded through the VAE.",
        input=Input.Connection,
    )
    # denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
    denoise_mask: Optional[DenoiseMaskField] = InputField(
        default=None, description=FieldDescriptions.denoise_mask, input=Input.Connection
    )
    denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
    denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
    transformer: TransformerField = InputField(
        description=FieldDescriptions.qwen_image_model, input=Input.Connection, title="Transformer"
    )
    positive_conditioning: QwenImageConditioningField = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection
    )
    negative_conditioning: Optional[QwenImageConditioningField] = InputField(
        default=None, description=FieldDescriptions.negative_cond, input=Input.Connection
    )
    cfg_scale: float | list[float] = InputField(default=4.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
    steps: int = InputField(default=40, gt=0, description=FieldDescriptions.steps)
    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
    shift: Optional[float] = InputField(
        default=None,
        description="Override the sigma schedule shift. "
        "When set, uses a fixed shift (e.g. 3.0 for Lightning LoRAs) instead of the default dynamic shifting. "
        "Leave unset for the base model's default schedule.",
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        """Run the full denoise loop and save the resulting latents to the tensor store."""
        latents = self._run_diffusion(context)
        # Move to CPU before saving; downstream nodes reload onto their own device.
        latents = latents.detach().to("cpu")

        name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

    def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
        """Load and prepare the inpaint mask, or return None when no mask is set.

        The stored mask convention is inverted here (1.0 - mask) so that 1 marks
        the region to denoise. The mask is resized to the latent resolution and
        cast to the latents' device/dtype.
        """
        if self.denoise_mask is None:
            return None
        mask = context.tensors.load(self.denoise_mask.mask_name)
        # Invert: after this, 1 = region to regenerate, 0 = region to keep.
        mask = 1.0 - mask

        _, _, latent_height, latent_width = latents.shape
        mask = tv_resize(
            img=mask,
            size=[latent_height, latent_width],
            interpolation=tv_transforms.InterpolationMode.BILINEAR,
            antialias=False,
        )

        mask = mask.to(device=latents.device, dtype=latents.dtype)
        return mask

    def _load_text_conditioning(
        self,
        context: InvocationContext,
        conditioning_name: str,
        dtype: torch.dtype,
        device: torch.device,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Load a single QwenImageConditioningInfo and return (prompt_embeds, prompt_embeds_mask)."""
        cond_data = context.conditioning.load(conditioning_name)
        # Exactly one conditioning per field is expected for Qwen Image.
        assert len(cond_data.conditionings) == 1
        conditioning = cond_data.conditionings[0]
        assert isinstance(conditioning, QwenImageConditioningInfo)
        conditioning = conditioning.to(dtype=dtype, device=device)
        return conditioning.prompt_embeds, conditioning.prompt_embeds_mask

    def _get_noise(
        self,
        batch_size: int,
        num_channels_latents: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: torch.device,
        seed: int,
    ) -> torch.Tensor:
        """Generate seeded Gaussian noise of shape (B, C, H/8, W/8).

        Noise is always generated on the CPU in float32 so that a given seed
        produces the same noise regardless of the target device/dtype, then
        moved to the requested device/dtype.
        """
        rand_device = "cpu"
        rand_dtype = torch.float32

        return torch.randn(
            batch_size,
            num_channels_latents,
            int(height) // LATENT_SCALE_FACTOR,
            int(width) // LATENT_SCALE_FACTOR,
            device=rand_device,
            dtype=rand_dtype,
            generator=torch.Generator(device=rand_device).manual_seed(seed),
        ).to(device=device, dtype=dtype)

    def _prepare_cfg_scale(self, num_timesteps: int) -> list[float]:
        """Expand self.cfg_scale to a per-timestep list of length num_timesteps.

        A scalar is broadcast; a list must already match num_timesteps exactly.
        """
        if isinstance(self.cfg_scale, float):
            cfg_scale = [self.cfg_scale] * num_timesteps
        elif isinstance(self.cfg_scale, list):
            assert len(self.cfg_scale) == num_timesteps
            cfg_scale = self.cfg_scale
        else:
            raise ValueError(f"Invalid CFG scale type: {type(self.cfg_scale)}")
        return cfg_scale

    @staticmethod
    def _pack_latents(
        latents: torch.Tensor, batch_size: int, num_channels: int, height: int, width: int
    ) -> torch.Tensor:
        """Pack 4D latents (B, C, H, W) into 2x2-patched 3D (B, H/2*W/2, C*4)."""
        latents = latents.view(batch_size, num_channels, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels * 4)
        return latents

    @staticmethod
    def _unpack_latents(latents: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """Unpack 3D patched latents (B, seq, C*4) back to 4D (B, C, H, W)."""
        batch_size, _num_patches, channels = latents.shape
        # height/width are in latent space; they must be divisible by 2 for packing
        h = 2 * (height // 2)
        w = 2 * (width // 2)
        latents = latents.view(batch_size, h // 2, w // 2, channels // 4, 2, 2)
        latents = latents.permute(0, 3, 1, 4, 2, 5)
        latents = latents.reshape(batch_size, channels // 4, h, w)
        return latents

    def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
        """Run the rectified-flow Euler denoise loop.

        Returns latents of shape (B, C, 1, H, W) — the frame dim is appended at
        the end for the video-style Qwen Image VAE.
        """
        inference_dtype = torch.bfloat16
        device = TorchDevice.choose_torch_device()

        transformer_info = context.models.load(self.transformer.transformer)
        assert isinstance(transformer_info.model, QwenImageTransformer2DModel)

        # Load conditioning
        pos_prompt_embeds, pos_prompt_mask = self._load_text_conditioning(
            context=context,
            conditioning_name=self.positive_conditioning.conditioning_name,
            dtype=inference_dtype,
            device=device,
        )

        neg_prompt_embeds = None
        neg_prompt_mask = None
        # Match the diffusers pipeline: only enable CFG when cfg_scale > 1 AND negative conditioning is provided.
        # With cfg_scale <= 1, the negative prediction is unused, so skip it entirely.
        # For per-step arrays, enable CFG if any step has scale > 1.
        if isinstance(self.cfg_scale, list):
            any_cfg_above_one = any(v > 1.0 for v in self.cfg_scale)
        else:
            any_cfg_above_one = self.cfg_scale > 1.0
        do_classifier_free_guidance = self.negative_conditioning is not None and any_cfg_above_one
        if do_classifier_free_guidance:
            neg_prompt_embeds, neg_prompt_mask = self._load_text_conditioning(
                context=context,
                conditioning_name=self.negative_conditioning.conditioning_name,
                dtype=inference_dtype,
                device=device,
            )

        # Prepare the timestep / sigma schedule
        patch_size = transformer_info.model.config.patch_size
        assert isinstance(patch_size, int)
        # Output channels is 16 (the actual latent channels)
        out_channels = transformer_info.model.config.out_channels
        assert isinstance(out_channels, int)

        latent_height = self.height // LATENT_SCALE_FACTOR
        latent_width = self.width // LATENT_SCALE_FACTOR
        image_seq_len = (latent_height * latent_width) // (patch_size**2)

        # Use the actual FlowMatchEulerDiscreteScheduler to compute sigmas/timesteps,
        # exactly matching the diffusers pipeline.
        import math

        import numpy as np
        from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler

        # Try to load the scheduler config from the model's directory (Diffusers models
        # have a scheduler/ subdir). For GGUF models this path doesn't exist, so fall
        # back to instantiating the scheduler with the known Qwen Image defaults.
        model_path = context.models.get_absolute_path(context.models.get_config(self.transformer.transformer))
        scheduler_path = model_path / "scheduler"
        if scheduler_path.is_dir() and (scheduler_path / "scheduler_config.json").exists():
            scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(str(scheduler_path), local_files_only=True)
        else:
            scheduler = FlowMatchEulerDiscreteScheduler(
                use_dynamic_shifting=True,
                base_shift=0.5,
                max_shift=0.9,
                base_image_seq_len=256,
                max_image_seq_len=8192,
                shift_terminal=0.02,
                num_train_timesteps=1000,
                time_shift_type="exponential",
            )

        if self.shift is not None:
            # Lightning LoRA: fixed shift
            mu = math.log(self.shift)
        else:
            # Default dynamic shifting
            # Linear interpolation matching diffusers' calculate_shift
            base_shift = scheduler.config.get("base_shift", 0.5)
            max_shift = scheduler.config.get("max_shift", 0.9)
            base_seq = scheduler.config.get("base_image_seq_len", 256)
            # NOTE(review): fallback default 4096 disagrees with the 8192 used when the
            # scheduler is constructed above. Only reachable when the loaded scheduler
            # config omits max_image_seq_len — confirm which value is intended.
            max_seq = scheduler.config.get("max_image_seq_len", 4096)
            m = (max_shift - base_shift) / (max_seq - base_seq)
            b = base_shift - m * base_seq
            mu = image_seq_len * m + b

        init_sigmas = np.linspace(1.0, 1.0 / self.steps, self.steps).tolist()
        scheduler.set_timesteps(sigmas=init_sigmas, mu=mu, device=device)

        # Clip the schedule based on denoising_start/denoising_end to support img2img strength.
        # The scheduler's sigmas go from high (noisy) to 0 (clean). We clip to the fractional range.
        sigmas_sched = scheduler.sigmas  # (N+1,) including terminal 0
        if self.denoising_start > 0 or self.denoising_end < 1:
            total_sigmas = len(sigmas_sched) - 1  # exclude terminal
            start_idx = int(round(self.denoising_start * total_sigmas))
            end_idx = int(round(self.denoising_end * total_sigmas))
            sigmas_sched = sigmas_sched[start_idx : end_idx + 1]  # +1 to include the next sigma for dt
            # Rebuild timesteps from clipped sigmas (exclude terminal 0)
            timesteps_sched = sigmas_sched[:-1] * scheduler.config.num_train_timesteps
        else:
            timesteps_sched = scheduler.timesteps

        total_steps = len(timesteps_sched)

        cfg_scale = self._prepare_cfg_scale(total_steps)

        # Load initial latents if provided (for img2img)
        init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
        if init_latents is not None:
            init_latents = init_latents.to(device=device, dtype=inference_dtype)
            # Stored latents may be 5D (B, C, 1, H, W) from the VAE; drop the frame dim.
            if init_latents.dim() == 5:
                init_latents = init_latents.squeeze(2)

        # Load reference image latents if provided
        ref_latents = None
        if self.reference_latents is not None:
            ref_latents = context.tensors.load(self.reference_latents.latents_name)
            ref_latents = ref_latents.to(device=device, dtype=inference_dtype)
            # The VAE encoder produces 5D latents (B, C, 1, H, W); squeeze the frame dim
            # so we have 4D (B, C, H, W) for packing.
            if ref_latents.dim() == 5:
                ref_latents = ref_latents.squeeze(2)

        # Generate noise (16 channels - the output latent channels)
        noise = self._get_noise(
            batch_size=1,
            num_channels_latents=out_channels,
            height=self.height,
            width=self.width,
            dtype=inference_dtype,
            device=device,
            seed=self.seed,
        )

        # Prepare input latent image
        if init_latents is not None:
            # img2img: blend noise with the init latents at the starting sigma.
            s_0 = sigmas_sched[0].item()
            latents = s_0 * noise + (1.0 - s_0) * init_latents
        else:
            if self.denoising_start > 1e-5:
                raise ValueError("denoising_start should be 0 when initial latents are not provided.")
            latents = noise

        if total_steps <= 0:
            return latents

        # Pack latents into 2x2 patches: (B, C, H, W) -> (B, H/2*W/2, C*4)
        latents = self._pack_latents(latents, 1, out_channels, latent_height, latent_width)

        # Determine whether the model uses reference latent conditioning (zero_cond_t).
        # Edit models (zero_cond_t=True) expect [noisy_patches ; ref_patches] in the sequence.
        # Txt2img models (zero_cond_t=False) only take noisy patches.
        has_zero_cond_t = getattr(transformer_info.model, "zero_cond_t", False) or getattr(
            transformer_info.model.config, "zero_cond_t", False
        )
        use_ref_latents = has_zero_cond_t

        ref_latents_packed = None
        if use_ref_latents:
            if ref_latents is not None:
                _, ref_ch, rh, rw = ref_latents.shape
                # Resize reference latents to the generation resolution if they differ.
                if rh != latent_height or rw != latent_width:
                    ref_latents = torch.nn.functional.interpolate(
                        ref_latents, size=(latent_height, latent_width), mode="bilinear"
                    )
            else:
                # No reference image provided — use zeros so the model still gets the
                # expected sequence layout.
                ref_latents = torch.zeros(
                    1, out_channels, latent_height, latent_width, device=device, dtype=inference_dtype
                )
            ref_latents_packed = self._pack_latents(ref_latents, 1, out_channels, latent_height, latent_width)

        # img_shapes tells the transformer the spatial layout of patches.
        if use_ref_latents:
            img_shapes = [
                [
                    (1, latent_height // 2, latent_width // 2),
                    (1, latent_height // 2, latent_width // 2),
                ]
            ]
        else:
            img_shapes = [
                [
                    (1, latent_height // 2, latent_width // 2),
                ]
            ]

        # Prepare inpaint extension (operates in 4D space, so unpack/repack around it)
        inpaint_mask = self._prep_inpaint_mask(context, noise)  # noise has the right 4D shape
        inpaint_extension: RectifiedFlowInpaintExtension | None = None
        if inpaint_mask is not None:
            assert init_latents is not None
            inpaint_extension = RectifiedFlowInpaintExtension(
                init_latents=init_latents,
                inpaint_mask=inpaint_mask,
                noise=noise,
            )

        step_callback = self._build_step_callback(context)

        # Emit the initial (step 0) preview state before the loop starts.
        step_callback(
            PipelineIntermediateState(
                step=0,
                order=1,
                total_steps=total_steps,
                timestep=int(timesteps_sched[0].item()) if len(timesteps_sched) > 0 else 0,
                latents=self._unpack_latents(latents, latent_height, latent_width),
            ),
        )

        # Sequence length of the noisy-latent portion — used to strip the reference
        # portion from the model output below.
        noisy_seq_len = latents.shape[1]

        # Determine if the model is quantized — GGUF models need sidecar patching for LoRAs
        transformer_config = context.models.get_config(self.transformer.transformer)
        model_is_quantized = transformer_config.format in (ModelFormat.GGUFQuantized,)

        with ExitStack() as exit_stack:
            (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
            assert isinstance(transformer, QwenImageTransformer2DModel)

            # Apply LoRA patches to the transformer
            exit_stack.enter_context(
                LayerPatcher.apply_smart_model_patches(
                    model=transformer,
                    patches=self._lora_iterator(context),
                    prefix=QWEN_IMAGE_EDIT_LORA_TRANSFORMER_PREFIX,
                    dtype=inference_dtype,
                    cached_weights=cached_weights,
                    force_sidecar_patching=model_is_quantized,
                )
            )

            for step_idx, t in enumerate(tqdm(timesteps_sched)):
                # The pipeline passes timestep / 1000 to the transformer
                timestep = t.expand(latents.shape[0]).to(inference_dtype)

                # For edit models: concatenate noisy and reference patches along the sequence dim
                # For txt2img models: just use noisy patches
                if ref_latents_packed is not None:
                    model_input = torch.cat([latents, ref_latents_packed], dim=1)
                else:
                    model_input = latents

                noise_pred_cond = transformer(
                    hidden_states=model_input,
                    encoder_hidden_states=pos_prompt_embeds,
                    encoder_hidden_states_mask=pos_prompt_mask,
                    timestep=timestep / 1000,
                    img_shapes=img_shapes,
                    return_dict=False,
                )[0]
                # Only keep the noisy-latent portion of the output
                noise_pred_cond = noise_pred_cond[:, :noisy_seq_len]

                if do_classifier_free_guidance and neg_prompt_embeds is not None:
                    noise_pred_uncond = transformer(
                        hidden_states=model_input,
                        encoder_hidden_states=neg_prompt_embeds,
                        encoder_hidden_states_mask=neg_prompt_mask,
                        timestep=timestep / 1000,
                        img_shapes=img_shapes,
                        return_dict=False,
                    )[0]
                    noise_pred_uncond = noise_pred_uncond[:, :noisy_seq_len]

                    # Standard classifier-free guidance combination.
                    noise_pred = noise_pred_uncond + cfg_scale[step_idx] * (noise_pred_cond - noise_pred_uncond)
                else:
                    noise_pred = noise_pred_cond

                # Euler step using the (possibly clipped) sigma schedule
                sigma_curr = sigmas_sched[step_idx]
                sigma_next = sigmas_sched[step_idx + 1]
                dt = sigma_next - sigma_curr
                # Step in float32 for numerical stability, then cast back.
                latents = latents.to(torch.float32) + dt * noise_pred.to(torch.float32)
                latents = latents.to(inference_dtype)

                if inpaint_extension is not None:
                    sigma_next = sigmas_sched[step_idx + 1].item()
                    # The inpaint extension works on 4D latents, so unpack/repack around it.
                    latents_4d = self._unpack_latents(latents, latent_height, latent_width)
                    latents_4d = inpaint_extension.merge_intermediate_latents_with_init_latents(latents_4d, sigma_next)
                    latents = self._pack_latents(latents_4d, 1, out_channels, latent_height, latent_width)

                step_callback(
                    PipelineIntermediateState(
                        step=step_idx + 1,
                        order=1,
                        total_steps=total_steps,
                        timestep=int(t.item()),
                        latents=self._unpack_latents(latents, latent_height, latent_width),
                    ),
                )

        # Unpack back to 4D then add frame dim for the video-style VAE: (B, C, 1, H, W)
        latents = self._unpack_latents(latents, latent_height, latent_width)
        latents = latents.unsqueeze(2)
        return latents

    def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
        """Return a callback that forwards intermediate states to the UI progress hook."""
        def step_callback(state: PipelineIntermediateState) -> None:
            context.util.sd_step_callback(state, BaseModelType.QwenImage)

        return step_callback

    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
        """Iterate over LoRA models to apply to the transformer."""
        for lora in self.transformer.loras:
            lora_info = context.models.load(lora.lora)
            if not isinstance(lora_info.model, ModelPatchRaw):
                raise TypeError(
                    f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}."
                )
            yield (lora_info.model, lora.weight)
            # Drop the local reference so the model cache can evict the LoRA.
            del lora_info
|
||||
96
invokeai/app/invocations/qwen_image_image_to_latents.py
Normal file
96
invokeai/app/invocations/qwen_image_image_to_latents.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import einops
|
||||
import torch
|
||||
from diffusers.models.autoencoders.autoencoder_kl_qwenimage import AutoencoderKLQwenImage
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
    "qwen_image_i2l",
    title="Image to Latents - Qwen Image",
    tags=["image", "latents", "vae", "i2l", "qwen_image"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class QwenImageImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generates latents from an image using the Qwen Image VAE."""

    image: ImageField = InputField(description="The image to encode.")
    vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
    width: int | None = InputField(
        default=None,
        description="Resize the image to this width before encoding. If not set, encodes at the image's original size.",
    )
    height: int | None = InputField(
        default=None,
        description="Resize the image to this height before encoding. If not set, encodes at the image's original size.",
    )

    @staticmethod
    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
        """Encode an image tensor to normalized Qwen Image latents.

        Uses the deterministic posterior mode (matching diffusers) and normalizes
        with the VAE's per-channel latents_mean / latents_std.
        """
        with vae_info.model_on_device() as (_, vae):
            assert isinstance(vae, AutoencoderKLQwenImage)

            vae.disable_tiling()

            image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype)
            with torch.inference_mode():
                # The Qwen Image VAE expects 5D input: (B, C, num_frames, H, W).
                if image_tensor.dim() == 4:
                    image_tensor = image_tensor.unsqueeze(2)

                dist = vae.encode(image_tensor).latent_dist
                # mode (argmax) gives a deterministic encoding, matching diffusers.
                encoded: torch.Tensor = dist.mode().to(dtype=vae.dtype)

                # Normalize with the VAE's per-channel statistics.
                z_dim = vae.config.z_dim
                mean = torch.tensor(vae.config.latents_mean).view(1, z_dim, 1, 1, 1).to(encoded.device, encoded.dtype)
                std = torch.tensor(vae.config.latents_std).view(1, z_dim, 1, 1, 1).to(encoded.device, encoded.dtype)
                normalized = (encoded - mean) / std

        return normalized

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        """Load the input image, optionally resize it in pixel space, encode, and save."""
        pil_image = context.images.get_pil(self.image.image_name)

        # If target dimensions are specified, resize the image BEFORE encoding
        # (matching the diffusers pipeline which resizes in pixel space, not latent space).
        if self.width is not None and self.height is not None:
            pil_image = pil_image.convert("RGB").resize((self.width, self.height), resample=PILImage.LANCZOS)

        tensor = image_resized_to_grid_as_tensor(pil_image.convert("RGB"))
        if tensor.dim() == 3:
            # Add the batch dimension: (C, H, W) -> (1, C, H, W).
            tensor = tensor.unsqueeze(0)

        vae_info = context.models.load(self.vae.vae)

        encoded = self.vae_encode(vae_info=vae_info, image_tensor=tensor)

        encoded = encoded.to("cpu")
        name = context.tensors.save(tensor=encoded)
        return LatentsOutput.build(latents_name=name, latents=encoded, seed=None)
|
||||
85
invokeai/app/invocations/qwen_image_latents_to_image.py
Normal file
85
invokeai/app/invocations/qwen_image_latents_to_image.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from diffusers.models.autoencoders.autoencoder_kl_qwenimage import AutoencoderKLQwenImage
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation(
    "qwen_image_l2i",
    title="Latents to Image - Qwen Image",
    tags=["latents", "image", "vae", "l2i", "qwen_image"],
    category="latents",
    version="1.0.0",
    classification=Classification.Prototype,
)
class QwenImageLatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generates an image from latents using the Qwen Image VAE."""

    latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
    vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Denormalize the latents, decode them with the VAE, and save a PIL image."""
        latents = context.tensors.load(self.latents.latents_name)

        vae_info = context.models.load(self.vae.vae)
        assert isinstance(vae_info.model, AutoencoderKLQwenImage)
        with (
            SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
            vae_info.model_on_device() as (_, vae),
        ):
            context.util.signal_progress("Running VAE")
            assert isinstance(vae, AutoencoderKLQwenImage)
            latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype)

            vae.disable_tiling()

            # Placeholder: no tiling context is used yet; kept for parity with other
            # l2i nodes that conditionally enable VAE tiling.
            tiling_context = nullcontext()

            TorchDevice.empty_cache()

            with torch.inference_mode(), tiling_context:
                # The Qwen Image VAE uses per-channel latents_mean / latents_std
                # instead of a single scaling_factor.
                # Latents are 5D: (B, C, num_frames, H, W) — the unpack from the
                # denoise step already produces this shape.
                latents_mean = (
                    torch.tensor(vae.config.latents_mean)
                    .view(1, vae.config.z_dim, 1, 1, 1)
                    .to(latents.device, latents.dtype)
                )
                # NOTE: latents_std holds the RECIPROCAL of the configured std, so the
                # division below effectively multiplies by the std (denormalization).
                latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to(
                    latents.device, latents.dtype
                )
                latents = latents / latents_std + latents_mean

                img = vae.decode(latents, return_dict=False)[0]
                # Drop the temporal frame dimension: (B, C, 1, H, W) -> (B, C, H, W)
                img = img[:, :, 0]

            # Map decoder output from [-1, 1] to an 8-bit HWC image.
            img = img.clamp(-1, 1)
            img = rearrange(img[0], "c h w -> h w c")
            img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())

        TorchDevice.empty_cache()

        image_dto = context.images.save(image=img_pil)

        return ImageOutput.build(image_dto)
|
||||
115
invokeai/app/invocations/qwen_image_lora_loader.py
Normal file
115
invokeai/app/invocations/qwen_image_lora_loader.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
|
||||
|
||||
@invocation_output("qwen_image_lora_loader_output")
class QwenImageLoRALoaderOutput(BaseInvocationOutput):
    """Qwen Image LoRA Loader Output"""

    # Pass-through of the input transformer with the applied LoRA(s) appended to
    # its `loras` list. None when no transformer was connected.
    transformer: Optional[TransformerField] = OutputField(
        default=None, description=FieldDescriptions.transformer, title="Transformer"
    )
|
||||
|
||||
|
||||
@invocation(
    "qwen_image_lora_loader",
    title="Apply LoRA - Qwen Image",
    tags=["lora", "model", "qwen_image"],
    category="model",
    version="1.0.0",
    classification=Classification.Prototype,
)
class QwenImageLoRALoaderInvocation(BaseInvocation):
    """Apply a LoRA model to a Qwen Image transformer."""

    lora: ModelIdentifierField = InputField(
        description=FieldDescriptions.lora_model,
        title="LoRA",
        ui_model_base=BaseModelType.QwenImage,
        ui_model_type=ModelType.LoRA,
    )
    weight: float = InputField(default=1.0, description=FieldDescriptions.lora_weight)
    transformer: TransformerField | None = InputField(
        default=None,
        description=FieldDescriptions.transformer,
        input=Input.Connection,
        title="Transformer",
    )

    def invoke(self, context: InvocationContext) -> QwenImageLoRALoaderOutput:
        """Validate the selected LoRA, then append it to a deep copy of the transformer field."""
        key = self.lora.key

        # Fail fast if the referenced LoRA model is not installed.
        if not context.models.exists(key):
            raise ValueError(f"Unknown lora: {key}!")

        # Reject double-application of the same LoRA to this transformer.
        if self.transformer:
            for applied in self.transformer.loras:
                if applied.lora.key == key:
                    raise ValueError(f'LoRA "{key}" already applied to transformer.')

        result = QwenImageLoRALoaderOutput()

        if self.transformer is not None:
            # Deep-copy so the upstream field is never mutated.
            result.transformer = self.transformer.model_copy(deep=True)
            result.transformer.loras.append(
                LoRAField(
                    lora=self.lora,
                    weight=self.weight,
                )
            )

        return result
|
||||
|
||||
|
||||
@invocation(
    "qwen_image_lora_collection_loader",
    title="Apply LoRA Collection - Qwen Image",
    tags=["lora", "model", "qwen_image"],
    category="model",
    version="1.0.0",
    classification=Classification.Prototype,
)
class QwenImageLoRACollectionLoader(BaseInvocation):
    """Applies a collection of LoRAs to a Qwen Image transformer."""

    loras: Optional[LoRAField | list[LoRAField]] = InputField(
        default=None, description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
    )
    transformer: Optional[TransformerField] = InputField(
        default=None,
        description=FieldDescriptions.transformer,
        input=Input.Connection,
        title="Transformer",
    )

    def invoke(self, context: InvocationContext) -> QwenImageLoRALoaderOutput:
        """Append each unique, installed LoRA to a deep copy of the transformer field.

        Duplicate keys within the collection are silently skipped (first wins).

        Raises:
            ValueError: If a referenced LoRA model is not installed.
        """
        output = QwenImageLoRALoaderOutput()

        # Normalize the input: it may be None, a single LoRAField, or a list.
        if self.loras is None:
            loras: list[LoRAField] = []
        elif isinstance(self.loras, list):
            loras = self.loras
        else:
            loras = [self.loras]

        if self.transformer is not None:
            # Deep-copy so the upstream field is never mutated.
            output.transformer = self.transformer.model_copy(deep=True)

        # Track keys in a set for O(1) dedup (was a list with O(n) membership tests).
        added_loras: set[str] = set()

        for lora in loras:
            if lora is None:
                continue
            key = lora.lora.key
            if key in added_loras:
                continue
            # Raise ValueError (not bare Exception) for consistency with
            # QwenImageLoRALoaderInvocation's unknown-LoRA error.
            if not context.models.exists(key):
                raise ValueError(f"Unknown lora: {key}!")

            added_loras.add(key)

            if output.transformer is not None:
                output.transformer.loras.append(lora)

        return output
|
||||
107
invokeai/app/invocations/qwen_image_model_loader.py
Normal file
107
invokeai/app/invocations/qwen_image_model_loader.py
Normal file
@@ -0,0 +1,107 @@
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.model import (
|
||||
ModelIdentifierField,
|
||||
QwenVLEncoderField,
|
||||
TransformerField,
|
||||
VAEField,
|
||||
)
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
|
||||
|
||||
@invocation_output("qwen_image_model_loader_output")
class QwenImageModelLoaderOutput(BaseInvocationOutput):
    """Qwen Image model loader output."""

    # Identifier for the transformer (denoising) submodel.
    transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
    # Tokenizer + text-encoder pair for the Qwen VL prompt encoder.
    qwen_vl_encoder: QwenVLEncoderField = OutputField(
        description=FieldDescriptions.qwen_vl_encoder, title="Qwen VL Encoder"
    )
    # VAE submodel used to encode/decode latents.
    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation(
    "qwen_image_model_loader",
    title="Main Model - Qwen Image",
    tags=["model", "qwen_image"],
    category="model",
    version="1.1.0",
    classification=Classification.Prototype,
)
class QwenImageModelLoaderInvocation(BaseInvocation):
    """Loads a Qwen Image model, outputting its submodels.

    The transformer is always loaded from the main model (Diffusers or GGUF).

    For GGUF quantized models, the VAE and Qwen VL encoder must come from a
    separate Diffusers model specified in the "Component Source" field.

    For Diffusers models, all components are extracted from the main model
    automatically. The "Component Source" field is ignored.
    """

    model: ModelIdentifierField = InputField(
        description=FieldDescriptions.qwen_image_model,
        input=Input.Direct,
        ui_model_base=BaseModelType.QwenImage,
        ui_model_type=ModelType.Main,
        title="Transformer",
    )

    component_source: Optional[ModelIdentifierField] = InputField(
        default=None,
        description="Diffusers Qwen Image model to extract the VAE and Qwen VL encoder from. "
        "Required when using a GGUF quantized transformer. "
        "Ignored when the main model is already in Diffusers format.",
        input=Input.Direct,
        ui_model_base=BaseModelType.QwenImage,
        ui_model_type=ModelType.Main,
        ui_model_format=ModelFormat.Diffusers,
        title="Component Source (Diffusers)",
    )

    def invoke(self, context: InvocationContext) -> QwenImageModelLoaderOutput:
        def submodel(base: ModelIdentifierField, kind: SubModelType) -> ModelIdentifierField:
            # Point at a specific submodel within `base` without mutating it.
            return base.model_copy(update={"submodel_type": kind})

        # The transformer always comes from the main model, whatever its format.
        transformer = submodel(self.model, SubModelType.Transformer)

        # Pick the model that supplies the VAE / tokenizer / text encoder.
        if context.models.get_config(self.model).format == ModelFormat.Diffusers:
            # A Diffusers main model carries all of its components itself.
            component_base = self.model
        elif self.component_source is not None:
            # GGUF/checkpoint transformer: components come from the separate
            # source model, which must itself be in Diffusers format.
            source_config = context.models.get_config(self.component_source)
            if source_config.format != ModelFormat.Diffusers:
                raise ValueError(
                    f"The Component Source model must be in Diffusers format. "
                    f"The selected model '{source_config.name}' is in {source_config.format.value} format."
                )
            component_base = self.component_source
        else:
            raise ValueError(
                "No source for VAE and Qwen VL encoder. "
                "GGUF quantized models only contain the transformer — "
                "please set 'Component Source' to a Diffusers Qwen Image model "
                "to provide the VAE and text encoder."
            )

        return QwenImageModelLoaderOutput(
            transformer=TransformerField(transformer=transformer, loras=[]),
            qwen_vl_encoder=QwenVLEncoderField(
                tokenizer=submodel(component_base, SubModelType.Tokenizer),
                text_encoder=submodel(component_base, SubModelType.TextEncoder),
            ),
            vae=VAEField(vae=submodel(component_base, SubModelType.VAE)),
        )
|
||||
298
invokeai/app/invocations/qwen_image_text_encoder.py
Normal file
298
invokeai/app/invocations/qwen_image_text_encoder.py
Normal file
@@ -0,0 +1,298 @@
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
UIComponent,
|
||||
)
|
||||
from invokeai.app.invocations.model import QwenVLEncoderField
|
||||
from invokeai.app.invocations.primitives import QwenImageConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
ConditioningFieldData,
|
||||
QwenImageConditioningInfo,
|
||||
)
|
||||
|
||||
# Prompt templates and drop indices for the two Qwen Image model modes.
# These are taken directly from the diffusers pipelines.

# Image editing mode (QwenImagePipeline)
_EDIT_SYSTEM_PROMPT = (
    "Describe the key features of the input image (color, shape, size, texture, objects, background), "
    "then explain how the user's text instruction should alter or modify the image. "
    "Generate a new image that meets the user's requirements while maintaining consistency "
    "with the original input where appropriate."
)
_EDIT_DROP_IDX = 64

# Text-to-image mode (QwenImagePipeline)
_GENERATE_SYSTEM_PROMPT = (
    "Describe the image by detailing the color, shape, size, texture, quantity, "
    "text, spatial relationships of the objects and background:"
)
_GENERATE_DROP_IDX = 34

_IMAGE_PLACEHOLDER = "<|vision_start|><|image_pad|><|vision_end|>"


def _build_prompt(user_prompt: str, num_images: int) -> str:
    """Assemble the chat-formatted Qwen2.5-VL prompt for the given user text.

    When reference images are present, the edit-mode system prompt is used and
    one vision placeholder per image is prepended to the user text; with no
    images the text-to-image system prompt is used instead.
    """
    if num_images > 0:
        # Edit mode: one vision placeholder per reference image.
        system_prompt = _EDIT_SYSTEM_PROMPT
        user_content = _IMAGE_PLACEHOLDER * num_images + user_prompt
    else:
        # Generate mode: text-only prompt.
        system_prompt = _GENERATE_SYSTEM_PROMPT
        user_content = user_prompt
    return (
        f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
        f"<|im_start|>user\n{user_content}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
|
||||
|
||||
|
||||
@invocation(
    "qwen_image_text_encoder",
    title="Prompt - Qwen Image",
    tags=["prompt", "conditioning", "qwen_image"],
    category="conditioning",
    version="1.2.0",
    classification=Classification.Prototype,
)
class QwenImageTextEncoderInvocation(BaseInvocation):
    """Encodes text and reference images for Qwen Image using Qwen2.5-VL.

    The prompt (plus any reference images) is formatted into a chat template,
    run through the Qwen2.5-VL model, trimmed of system-prompt tokens, and
    saved as conditioning data for the Qwen Image denoise step.
    """

    # User prompt text; wrapped in a chat template by _build_prompt before encoding.
    prompt: str = InputField(description="Text prompt describing the desired edit.", ui_component=UIComponent.Textarea)
    # Reference images; their presence switches encoding into "edit" mode.
    reference_images: list[ImageField] = InputField(
        default=[],
        description="Reference images to guide the edit. The model can use multiple reference images.",
    )
    # Tokenizer + text-encoder submodels, wired from the model loader node.
    qwen_vl_encoder: QwenVLEncoderField = InputField(
        title="Qwen VL Encoder",
        description=FieldDescriptions.qwen_vl_encoder,
        input=Input.Connection,
    )
    # Optional BitsAndBytes quantization of the encoder (bypasses the model cache).
    quantization: Literal["none", "int8", "nf4"] = InputField(
        default="none",
        description="Quantize the Qwen VL encoder to reduce VRAM usage. "
        "'nf4' (4-bit) saves the most memory, 'int8' (8-bit) is a middle ground.",
    )

    @staticmethod
    def _resize_for_vl_encoder(image: PILImage.Image, target_pixels: int = 512 * 512) -> PILImage.Image:
        """Resize image to fit within target_pixels while preserving aspect ratio.

        Matches the diffusers pipeline's calculate_dimensions logic: the image is resized
        so its total pixel count is approximately target_pixels, with dimensions rounded to
        multiples of 32. This prevents large images from producing too many vision tokens
        which can overwhelm the text prompt.
        """
        w, h = image.size
        aspect = w / h
        # Compute dimensions that preserve aspect ratio at ~target_pixels total
        new_w = int((target_pixels * aspect) ** 0.5)
        new_h = int(target_pixels / new_w)
        # Round to multiples of 32
        new_w = max(32, (new_w // 32) * 32)
        new_h = max(32, (new_h // 32) * 32)
        if new_w != w or new_h != h:
            image = image.resize((new_w, new_h), resample=PILImage.LANCZOS)
        return image

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> QwenImageConditioningOutput:
        # Load reference images and downscale each to ~512*512 total pixels
        # (see _resize_for_vl_encoder; this bounds the number of vision tokens).
        pil_images: list[PILImage.Image] = []
        for img_field in self.reference_images:
            pil_img = context.images.get_pil(img_field.image_name)
            pil_img = self._resize_for_vl_encoder(pil_img.convert("RGB"))
            pil_images.append(pil_img)

        # Encode, then move results to CPU so they can be persisted safely.
        prompt_embeds, prompt_mask = self._encode(context, pil_images)
        prompt_embeds = prompt_embeds.detach().to("cpu")
        prompt_mask = prompt_mask.detach().to("cpu") if prompt_mask is not None else None

        conditioning_data = ConditioningFieldData(
            conditionings=[QwenImageConditioningInfo(prompt_embeds=prompt_embeds, prompt_embeds_mask=prompt_mask)]
        )
        conditioning_name = context.conditioning.save(conditioning_data)
        return QwenImageConditioningOutput.build(conditioning_name)

    def _encode(
        self, context: InvocationContext, images: list[PILImage.Image]
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Encode text prompt and reference images using Qwen2.5-VL.

        Matches the diffusers QwenImagePipeline._get_qwen_prompt_embeds logic:
        1. Format prompt with the edit-specific system template
        2. Run through Qwen2.5-VL to get hidden states
        3. Extract valid (non-padding) tokens and drop the system prefix
        4. Return padded embeddings + attention mask

        Returns:
            (prompt_embeds, attention_mask) where the mask is None when every
            token is valid (no padding).
        """
        from transformers import AutoTokenizer, Qwen2_5_VLProcessor

        # The Qwen2.5-VL processor class names are not available in every
        # transformers release — presumably older versions only ship the
        # Qwen2-VL implementations, so fall back to those on ImportError.
        try:
            from transformers import Qwen2_5_VLImageProcessor as _ImageProcessorCls
        except ImportError:
            from transformers.models.qwen2_vl.image_processing_qwen2_vl import (  # type: ignore[no-redef]
                Qwen2VLImageProcessor as _ImageProcessorCls,
            )

        try:
            from transformers import Qwen2_5_VLVideoProcessor as _VideoProcessorCls
        except ImportError:
            from transformers.models.qwen2_vl.video_processing_qwen2_vl import (  # type: ignore[no-redef]
                Qwen2VLVideoProcessor as _VideoProcessorCls,
            )

        # Format the prompt with one vision placeholder per reference image
        text = _build_prompt(self.prompt, len(images))

        # Build the processor
        tokenizer_config = context.models.get_config(self.qwen_vl_encoder.tokenizer)
        model_root = context.models.get_absolute_path(tokenizer_config)
        tokenizer_dir = model_root / "tokenizer"

        tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_dir), local_files_only=True)

        # Search the likely locations for an image-processor config; fall back
        # to the class defaults if none is found on disk.
        image_processor = None
        for search_dir in [model_root / "processor", tokenizer_dir, model_root, model_root / "image_processor"]:
            if (search_dir / "preprocessor_config.json").exists():
                image_processor = _ImageProcessorCls.from_pretrained(str(search_dir), local_files_only=True)
                break
        if image_processor is None:
            image_processor = _ImageProcessorCls()

        processor = Qwen2_5_VLProcessor(
            tokenizer=tokenizer,
            image_processor=image_processor,
            video_processor=_VideoProcessorCls(),
        )

        context.util.signal_progress("Running Qwen2.5-VL text/vision encoder")

        # Either path returns (model, device, cleanup-callable-or-None).
        if self.quantization != "none":
            text_encoder, device, cleanup = self._load_quantized_encoder(context)
        else:
            text_encoder, device, cleanup = self._load_cached_encoder(context)

        try:
            model_inputs = processor(
                text=[text],
                images=images if images else None,
                padding=True,
                return_tensors="pt",
            ).to(device=device)

            outputs = text_encoder(
                input_ids=model_inputs.input_ids,
                attention_mask=model_inputs.attention_mask,
                pixel_values=getattr(model_inputs, "pixel_values", None),
                image_grid_thw=getattr(model_inputs, "image_grid_thw", None),
                output_hidden_states=True,
            )

            # Use last hidden state (matching diffusers pipeline)
            hidden_states = outputs.hidden_states[-1]

            # Extract valid (non-padding) tokens using the attention mask,
            # then drop the system prompt prefix tokens.
            # The drop index differs between edit mode (64) and generate mode (34).
            drop_idx = _EDIT_DROP_IDX if images else _GENERATE_DROP_IDX

            attn_mask = model_inputs.attention_mask
            bool_mask = attn_mask.bool()
            valid_lengths = bool_mask.sum(dim=1)
            # Flattened valid tokens, re-split back into per-batch-item runs.
            selected = hidden_states[bool_mask]
            split_hidden = torch.split(selected, valid_lengths.tolist(), dim=0)

            # Drop system prefix tokens and build padded output
            trimmed = [h[drop_idx:] for h in split_hidden]
            attn_mask_list = [torch.ones(h.size(0), dtype=torch.long, device=device) for h in trimmed]
            max_seq_len = max(h.size(0) for h in trimmed)

            # Right-pad each item with zeros up to the longest trimmed sequence.
            prompt_embeds = torch.stack(
                [torch.cat([h, h.new_zeros(max_seq_len - h.size(0), h.size(1))]) for h in trimmed]
            )
            encoder_attention_mask = torch.stack(
                [torch.cat([m, m.new_zeros(max_seq_len - m.size(0))]) for m in attn_mask_list]
            )

            prompt_embeds = prompt_embeds.to(dtype=torch.bfloat16)
        finally:
            # NOTE(review): for the quantized path, cleanup() only unbinds the
            # closure's reference — the local `text_encoder` here still holds
            # the model until _encode returns, so memory may not actually be
            # released at this point. Confirm and consider deleting the local
            # reference before calling cleanup().
            if cleanup is not None:
                cleanup()

        # If all tokens are valid (no padding), mask is not needed
        if encoder_attention_mask.all():
            encoder_attention_mask = None

        return prompt_embeds, encoder_attention_mask

    def _load_cached_encoder(self, context: InvocationContext):
        """Load the text encoder through the model cache (no quantization).

        Returns:
            (model, device, cleanup) — cleanup exits the model-on-device
            context entered here and must be called when encoding is done.
        """
        from transformers import Qwen2_5_VLForConditionalGeneration

        text_encoder_info = context.models.load(self.qwen_vl_encoder.text_encoder)
        # Enter the context manually; the returned cleanup lambda exits it.
        ctx = text_encoder_info.model_on_device()
        _, text_encoder = ctx.__enter__()
        device = get_effective_device(text_encoder)
        assert isinstance(text_encoder, Qwen2_5_VLForConditionalGeneration)
        return text_encoder, device, lambda: ctx.__exit__(None, None, None)

    def _load_quantized_encoder(self, context: InvocationContext):
        """Load the text encoder with BitsAndBytes quantization, bypassing the model cache.

        BnB-quantized models are pinned to GPU and can't be moved between devices,
        so they can't go through the standard model cache. The model is loaded fresh
        each time and freed after use via the cleanup callback.

        Returns:
            (model, device, cleanup) — cleanup drops the loader's reference and
            asks torch/gc to release the memory.
        """
        import gc
        import warnings

        from transformers import BitsAndBytesConfig, Qwen2_5_VLForConditionalGeneration

        encoder_config = context.models.get_config(self.qwen_vl_encoder.text_encoder)
        model_root = context.models.get_absolute_path(encoder_config)
        encoder_path = model_root / "text_encoder"

        if self.quantization == "nf4":
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_quant_type="nf4",
            )
        else:  # int8
            bnb_config = BitsAndBytesConfig(load_in_8bit=True)

        context.util.signal_progress("Loading Qwen2.5-VL encoder (quantized)")
        with warnings.catch_warnings():
            # BnB int8 internally casts bfloat16→float16; the warning is harmless
            warnings.filterwarnings("ignore", message="MatMul8bitLt.*cast.*float16")
            text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                str(encoder_path),
                quantization_config=bnb_config,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                local_files_only=True,
            )

        device = next(text_encoder.parameters()).device

        def cleanup() -> None:
            # NOTE(review): this unbinds only the closure cell's reference; any
            # caller-held reference keeps the model alive — verify callers drop
            # theirs before relying on this to free VRAM.
            nonlocal text_encoder
            del text_encoder
            gc.collect()
            torch.cuda.empty_cache()

        return text_encoder, device, cleanup
|
||||
@@ -34,7 +34,7 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
"sd3_denoise",
|
||||
title="Denoise - SD3",
|
||||
tags=["image", "sd3"],
|
||||
category="image",
|
||||
category="latents",
|
||||
version="1.1.1",
|
||||
)
|
||||
class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@@ -24,7 +24,7 @@ from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory
|
||||
"sd3_i2l",
|
||||
title="Image to Latents - SD3",
|
||||
tags=["image", "latents", "vae", "i2l", "sd3"],
|
||||
category="image",
|
||||
category="latents",
|
||||
version="1.0.1",
|
||||
)
|
||||
class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@@ -16,6 +16,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
|
||||
from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
||||
from invokeai.app.invocations.primitives import SD3ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
|
||||
from invokeai.backend.model_manager.taxonomy import ModelFormat
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
|
||||
@@ -30,7 +31,7 @@ SD3_T5_MAX_SEQ_LEN = 256
|
||||
"sd3_text_encoder",
|
||||
title="Prompt - SD3",
|
||||
tags=["prompt", "conditioning", "sd3"],
|
||||
category="conditioning",
|
||||
category="prompt",
|
||||
version="1.0.1",
|
||||
)
|
||||
class Sd3TextEncoderInvocation(BaseInvocation):
|
||||
@@ -103,6 +104,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
|
||||
context.util.signal_progress("Running T5 encoder")
|
||||
assert isinstance(t5_text_encoder, T5EncoderModel)
|
||||
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))
|
||||
t5_device = get_effective_device(t5_text_encoder)
|
||||
|
||||
text_inputs = t5_tokenizer(
|
||||
prompt,
|
||||
@@ -125,7 +127,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
|
||||
f" {max_seq_len} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = t5_text_encoder(text_input_ids.to(t5_text_encoder.device))[0]
|
||||
prompt_embeds = t5_text_encoder(text_input_ids.to(t5_device))[0]
|
||||
|
||||
assert isinstance(prompt_embeds, torch.Tensor)
|
||||
return prompt_embeds
|
||||
@@ -144,6 +146,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
|
||||
context.util.signal_progress("Running CLIP encoder")
|
||||
assert isinstance(clip_text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
|
||||
assert isinstance(clip_tokenizer, CLIPTokenizer)
|
||||
clip_device = get_effective_device(clip_text_encoder)
|
||||
|
||||
clip_text_encoder_config = clip_text_encoder_info.config
|
||||
assert clip_text_encoder_config is not None
|
||||
@@ -187,9 +190,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {tokenizer_max_length} tokens: {removed_text}"
|
||||
)
|
||||
prompt_embeds = clip_text_encoder(
|
||||
input_ids=text_input_ids.to(clip_text_encoder.device), output_hidden_states=True
|
||||
)
|
||||
prompt_embeds = clip_text_encoder(input_ids=text_input_ids.to(clip_device), output_hidden_states=True)
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ class StringPosNegOutput(BaseInvocationOutput):
|
||||
"string_split_neg",
|
||||
title="String Split Negative",
|
||||
tags=["string", "split", "negative"],
|
||||
category="string",
|
||||
category="strings",
|
||||
version="1.0.1",
|
||||
)
|
||||
class StringSplitNegInvocation(BaseInvocation):
|
||||
@@ -63,7 +63,7 @@ class String2Output(BaseInvocationOutput):
|
||||
string_2: str = OutputField(description="string 2")
|
||||
|
||||
|
||||
@invocation("string_split", title="String Split", tags=["string", "split"], category="string", version="1.0.1")
|
||||
@invocation("string_split", title="String Split", tags=["string", "split"], category="strings", version="1.0.1")
|
||||
class StringSplitInvocation(BaseInvocation):
|
||||
"""Splits string into two strings, based on the first occurance of the delimiter. The delimiter will be removed from the string"""
|
||||
|
||||
@@ -83,7 +83,7 @@ class StringSplitInvocation(BaseInvocation):
|
||||
return String2Output(string_1=part1, string_2=part2)
|
||||
|
||||
|
||||
@invocation("string_join", title="String Join", tags=["string", "join"], category="string", version="1.0.1")
|
||||
@invocation("string_join", title="String Join", tags=["string", "join"], category="strings", version="1.0.1")
|
||||
class StringJoinInvocation(BaseInvocation):
|
||||
"""Joins string left to string right"""
|
||||
|
||||
@@ -94,7 +94,9 @@ class StringJoinInvocation(BaseInvocation):
|
||||
return StringOutput(value=((self.string_left or "") + (self.string_right or "")))
|
||||
|
||||
|
||||
@invocation("string_join_three", title="String Join Three", tags=["string", "join"], category="string", version="1.0.1")
|
||||
@invocation(
|
||||
"string_join_three", title="String Join Three", tags=["string", "join"], category="strings", version="1.0.1"
|
||||
)
|
||||
class StringJoinThreeInvocation(BaseInvocation):
|
||||
"""Joins string left to string middle to string right"""
|
||||
|
||||
@@ -107,7 +109,7 @@ class StringJoinThreeInvocation(BaseInvocation):
|
||||
|
||||
|
||||
@invocation(
|
||||
"string_replace", title="String Replace", tags=["string", "replace", "regex"], category="string", version="1.0.1"
|
||||
"string_replace", title="String Replace", tags=["string", "replace", "regex"], category="strings", version="1.0.1"
|
||||
)
|
||||
class StringReplaceInvocation(BaseInvocation):
|
||||
"""Replaces the search string with the replace string"""
|
||||
|
||||
@@ -49,7 +49,7 @@ class T2IAdapterOutput(BaseInvocationOutput):
|
||||
"t2i_adapter",
|
||||
title="T2I-Adapter - SD1.5, SDXL",
|
||||
tags=["t2i_adapter", "control"],
|
||||
category="t2i_adapter",
|
||||
category="conditioning",
|
||||
version="1.0.4",
|
||||
)
|
||||
class T2IAdapterInvocation(BaseInvocation):
|
||||
|
||||
@@ -30,7 +30,7 @@ ESRGAN_MODEL_URLS: dict[str, str] = {
|
||||
}
|
||||
|
||||
|
||||
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.2")
|
||||
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="upscale", version="1.3.2")
|
||||
class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Upscales an image using RealESRGAN."""
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ class ZImageControlOutput(BaseInvocationOutput):
|
||||
"z_image_control",
|
||||
title="Z-Image ControlNet",
|
||||
tags=["image", "z-image", "control", "controlnet"],
|
||||
category="control",
|
||||
category="conditioning",
|
||||
version="1.1.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
|
||||
@@ -49,8 +49,8 @@ from invokeai.backend.z_image.z_image_transformer_patch import patch_transformer
|
||||
"z_image_denoise",
|
||||
title="Denoise - Z-Image",
|
||||
tags=["image", "z-image"],
|
||||
category="image",
|
||||
version="1.4.0",
|
||||
category="latents",
|
||||
version="1.5.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class ZImageDenoiseInvocation(BaseInvocation):
|
||||
@@ -104,6 +104,15 @@ class ZImageDenoiseInvocation(BaseInvocation):
|
||||
description=FieldDescriptions.vae + " Required for control conditioning.",
|
||||
input=Input.Connection,
|
||||
)
|
||||
# Shift override for the sigma schedule. If None, shift is auto-calculated from image dimensions.
|
||||
shift: Optional[float] = InputField(
|
||||
default=None,
|
||||
ge=0.0,
|
||||
description="Override the timestep shift (mu) for the sigma schedule. "
|
||||
"Leave blank to auto-calculate based on image dimensions (recommended). "
|
||||
"Lower values (~0.5) produce less noise shifting, higher values (~1.15) produce more.",
|
||||
title="Shift",
|
||||
)
|
||||
# Scheduler selection for the denoising process
|
||||
scheduler: ZIMAGE_SCHEDULER_NAME_VALUES = InputField(
|
||||
default="euler",
|
||||
@@ -225,34 +234,36 @@ class ZImageDenoiseInvocation(BaseInvocation):
|
||||
"""Calculate timestep shift based on image sequence length.
|
||||
|
||||
Based on diffusers ZImagePipeline.calculate_shift method.
|
||||
"""
|
||||
m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
|
||||
b = base_shift - m * base_image_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
def _get_sigmas(self, mu: float, num_steps: int) -> list[float]:
|
||||
"""Generate sigma schedule with time shift.
|
||||
|
||||
Based on FlowMatchEulerDiscreteScheduler with shift.
|
||||
Generates num_steps + 1 sigma values (including terminal 0.0).
|
||||
Returns a linear shift value (exp(mu) from the original formula).
|
||||
"""
|
||||
import math
|
||||
|
||||
def time_shift(mu: float, sigma: float, t: float) -> float:
|
||||
"""Apply time shift to a single timestep value."""
|
||||
m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
|
||||
b = base_shift - m * base_image_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
# Convert from exponential mu to linear shift value
|
||||
return math.exp(mu)
|
||||
|
||||
def _get_sigmas(self, shift: float, num_steps: int) -> list[float]:
|
||||
"""Generate sigma schedule with linear time shift.
|
||||
|
||||
Uses linear time shift: shift / (shift + (1/t - 1)).
|
||||
The shift value is used directly as a multiplier.
|
||||
Generates num_steps + 1 sigma values (including terminal 0.0).
|
||||
"""
|
||||
|
||||
def time_shift(shift: float, t: float) -> float:
|
||||
"""Apply linear time shift to a single timestep value."""
|
||||
if t <= 0:
|
||||
return 0.0
|
||||
if t >= 1:
|
||||
return 1.0
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
return shift / (shift + (1 / t - 1))
|
||||
|
||||
# Generate linearly spaced values from 1 to 0 (excluding endpoints for safety)
|
||||
# then apply time shift
|
||||
sigmas = []
|
||||
for i in range(num_steps + 1):
|
||||
t = 1.0 - i / num_steps # Goes from 1.0 to 0.0
|
||||
sigma = time_shift(mu, 1.0, t)
|
||||
sigma = time_shift(shift, t)
|
||||
sigmas.append(sigma)
|
||||
|
||||
return sigmas
|
||||
@@ -313,11 +324,14 @@ class ZImageDenoiseInvocation(BaseInvocation):
|
||||
# Concatenate all negative embeddings
|
||||
neg_prompt_embeds = torch.cat([tc.prompt_embeds for tc in neg_text_conditionings], dim=0)
|
||||
|
||||
# Calculate shift based on image sequence length
|
||||
mu = self._calculate_shift(img_seq_len)
|
||||
# Calculate shift based on image sequence length, or use override
|
||||
if self.shift is not None:
|
||||
shift = self.shift
|
||||
else:
|
||||
shift = self._calculate_shift(img_seq_len)
|
||||
|
||||
# Generate sigma schedule with time shift
|
||||
sigmas = self._get_sigmas(mu, self.steps)
|
||||
sigmas = self._get_sigmas(shift, self.steps)
|
||||
|
||||
# Apply denoising_start and denoising_end clipping
|
||||
if self.denoising_start > 0 or self.denoising_end < 1:
|
||||
|
||||
@@ -30,7 +30,7 @@ ZImageVAE = Union[AutoencoderKL, FluxAutoEncoder]
|
||||
"z_image_i2l",
|
||||
title="Image to Latents - Z-Image",
|
||||
tags=["image", "latents", "vae", "i2l", "z-image"],
|
||||
category="image",
|
||||
category="latents",
|
||||
version="1.1.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
|
||||
@@ -19,7 +19,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
"z_image_seed_variance_enhancer",
|
||||
title="Seed Variance Enhancer - Z-Image",
|
||||
tags=["conditioning", "z-image", "variance", "seed"],
|
||||
category="conditioning",
|
||||
category="prompt",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
|
||||
@@ -16,6 +16,7 @@ from invokeai.app.invocations.fields import (
|
||||
from invokeai.app.invocations.model import Qwen3EncoderField
|
||||
from invokeai.app.invocations.primitives import ZImageConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.lora_conversions.z_image_lora_constants import Z_IMAGE_LORA_QWEN3_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
@@ -33,7 +34,7 @@ Z_IMAGE_MAX_SEQ_LEN = 512
|
||||
"z_image_text_encoder",
|
||||
title="Prompt - Z-Image",
|
||||
tags=["prompt", "conditioning", "z-image"],
|
||||
category="conditioning",
|
||||
category="prompt",
|
||||
version="1.1.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
@@ -76,11 +77,17 @@ class ZImageTextEncoderInvocation(BaseInvocation):
|
||||
tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer)
|
||||
|
||||
with ExitStack() as exit_stack:
|
||||
(_, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())
|
||||
(cached_weights, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())
|
||||
(_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())
|
||||
|
||||
# Use the device that the text_encoder is actually on
|
||||
device = text_encoder.device
|
||||
# Use the device that the text encoder is effectively executing on, and repair any required tensors left on
|
||||
# the CPU by a previous interrupted run.
|
||||
repaired_tensors = text_encoder_info.repair_required_tensors_on_device()
|
||||
device = get_effective_device(text_encoder)
|
||||
if repaired_tensors > 0:
|
||||
context.logger.warning(
|
||||
f"Recovered {repaired_tensors} required Qwen3 tensor(s) onto {device} after a partial device mismatch."
|
||||
)
|
||||
|
||||
# Apply LoRA models to the text encoder
|
||||
lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
|
||||
@@ -90,6 +97,7 @@ class ZImageTextEncoderInvocation(BaseInvocation):
|
||||
patches=self._lora_iterator(context),
|
||||
prefix=Z_IMAGE_LORA_QWEN3_PREFIX,
|
||||
dtype=lora_dtype,
|
||||
cached_weights=cached_weights,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Password hashing and validation utilities."""
|
||||
|
||||
from typing import cast
|
||||
from typing import Literal, cast
|
||||
|
||||
from passlib.context import CryptContext
|
||||
|
||||
@@ -84,3 +84,30 @@ def validate_password_strength(password: str) -> tuple[bool, str]:
|
||||
return False, "Password must contain uppercase, lowercase, and numbers"
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def get_password_strength(password: str) -> Literal["weak", "moderate", "strong"]:
    """Classify a password as "weak", "moderate", or "strong".

    Classification rules:
    - "weak": fewer than 8 characters
    - "moderate": 8+ characters, but missing uppercase, lowercase, or a digit
    - "strong": 8+ characters containing uppercase, lowercase, and a digit

    Args:
        password: The password to evaluate

    Returns:
        One of "weak", "moderate", or "strong"
    """
    if len(password) < 8:
        return "weak"

    # Each character class must appear at least once for a "strong" rating.
    required_classes = (str.isupper, str.islower, str.isdigit)
    if all(any(is_class(ch) for ch in password) for is_class in required_classes):
        return "strong"

    return "moderate"
|
||||
|
||||
@@ -21,6 +21,7 @@ class TokenData(BaseModel):
|
||||
user_id: str
|
||||
email: str
|
||||
is_admin: bool
|
||||
remember_me: bool = False
|
||||
|
||||
|
||||
def set_jwt_secret(secret: str) -> None:
|
||||
|
||||
@@ -9,6 +9,17 @@ from invokeai.app.util.misc import get_iso_timestamp
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
|
||||
|
||||
class BoardVisibility(str, Enum, metaclass=MetaEnum):
    """Visibility levels controlling who can see or modify a board."""

    Private = "private"
    """Visible and modifiable only by the board's owner (and admins)."""
    Shared = "shared"
    """Viewable by every user; only the owner (and admins) can modify it."""
    Public = "public"
    """Viewable by every user; only the owner (and admins) can modify its structure."""
|
||||
|
||||
|
||||
class BoardRecord(BaseModelExcludeNull):
|
||||
"""Deserialized board record."""
|
||||
|
||||
@@ -28,6 +39,10 @@ class BoardRecord(BaseModelExcludeNull):
|
||||
"""The name of the cover image of the board."""
|
||||
archived: bool = Field(description="Whether or not the board is archived.")
|
||||
"""Whether or not the board is archived."""
|
||||
board_visibility: BoardVisibility = Field(
|
||||
default=BoardVisibility.Private, description="The visibility of the board."
|
||||
)
|
||||
"""The visibility of the board (private, shared, or public)."""
|
||||
|
||||
|
||||
def deserialize_board_record(board_dict: dict) -> BoardRecord:
|
||||
@@ -44,6 +59,11 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
|
||||
updated_at = board_dict.get("updated_at", get_iso_timestamp())
|
||||
deleted_at = board_dict.get("deleted_at", get_iso_timestamp())
|
||||
archived = board_dict.get("archived", False)
|
||||
board_visibility_raw = board_dict.get("board_visibility", BoardVisibility.Private.value)
|
||||
try:
|
||||
board_visibility = BoardVisibility(board_visibility_raw)
|
||||
except ValueError:
|
||||
board_visibility = BoardVisibility.Private
|
||||
|
||||
return BoardRecord(
|
||||
board_id=board_id,
|
||||
@@ -54,6 +74,7 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
|
||||
updated_at=updated_at,
|
||||
deleted_at=deleted_at,
|
||||
archived=archived,
|
||||
board_visibility=board_visibility,
|
||||
)
|
||||
|
||||
|
||||
@@ -61,6 +82,7 @@ class BoardChanges(BaseModel, extra="forbid"):
|
||||
board_name: Optional[str] = Field(default=None, description="The board's new name.", max_length=300)
|
||||
cover_image_name: Optional[str] = Field(default=None, description="The name of the board's new cover image.")
|
||||
archived: Optional[bool] = Field(default=None, description="Whether or not the board is archived")
|
||||
board_visibility: Optional[BoardVisibility] = Field(default=None, description="The visibility of the board.")
|
||||
|
||||
|
||||
class BoardRecordOrderBy(str, Enum, metaclass=MetaEnum):
|
||||
|
||||
@@ -116,6 +116,17 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
(changes.archived, board_id),
|
||||
)
|
||||
|
||||
# Change the visibility of a board
|
||||
if changes.board_visibility is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE boards
|
||||
SET board_visibility = ?
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(changes.board_visibility.value, board_id),
|
||||
)
|
||||
|
||||
except sqlite3.Error as e:
|
||||
raise BoardRecordSaveException from e
|
||||
return self.get(board_id)
|
||||
@@ -155,7 +166,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
SELECT DISTINCT boards.*
|
||||
FROM boards
|
||||
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
|
||||
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1)
|
||||
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public'))
|
||||
{archived_filter}
|
||||
ORDER BY {order_by} {direction}
|
||||
LIMIT ? OFFSET ?;
|
||||
@@ -194,14 +205,14 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
SELECT COUNT(DISTINCT boards.board_id)
|
||||
FROM boards
|
||||
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
|
||||
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1);
|
||||
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public'));
|
||||
"""
|
||||
else:
|
||||
count_query = """
|
||||
SELECT COUNT(DISTINCT boards.board_id)
|
||||
FROM boards
|
||||
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
|
||||
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1)
|
||||
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public'))
|
||||
AND boards.archived = 0;
|
||||
"""
|
||||
|
||||
@@ -251,7 +262,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
SELECT DISTINCT boards.*
|
||||
FROM boards
|
||||
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
|
||||
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1)
|
||||
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public'))
|
||||
{archived_filter}
|
||||
ORDER BY LOWER(boards.board_name) {direction}
|
||||
"""
|
||||
@@ -260,7 +271,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
SELECT DISTINCT boards.*
|
||||
FROM boards
|
||||
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
|
||||
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1)
|
||||
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public'))
|
||||
{archived_filter}
|
||||
ORDER BY {order_by} {direction}
|
||||
"""
|
||||
|
||||
@@ -7,7 +7,11 @@ class BulkDownloadBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def handler(
|
||||
self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str]
|
||||
self,
|
||||
image_names: Optional[list[str]],
|
||||
board_id: Optional[str],
|
||||
bulk_download_item_id: Optional[str],
|
||||
user_id: str = "system",
|
||||
) -> None:
|
||||
"""
|
||||
Create a zip file containing the images specified by the given image names or board id.
|
||||
@@ -15,6 +19,7 @@ class BulkDownloadBase(ABC):
|
||||
:param image_names: A list of image names to include in the zip file.
|
||||
:param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
|
||||
:param bulk_download_item_id: The bulk_download_item_id that will be used to retrieve the bulk download item when it is prepared, if none is provided a uuid will be generated.
|
||||
:param user_id: The ID of the user who initiated the download.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@@ -42,3 +47,12 @@ class BulkDownloadBase(ABC):
|
||||
|
||||
:param bulk_download_item_name: The name of the bulk download item.
|
||||
"""
|
||||
|
||||
@abstractmethod
def get_owner(self, bulk_download_item_name: str) -> Optional[str]:
    """
    Look up the user_id of the user who initiated a bulk download.

    :param bulk_download_item_name: The name of the bulk download item.
    :return: The user_id of the owner, or None if not tracked.
    """
|
||||
|
||||
@@ -25,15 +25,24 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
self._temp_directory = TemporaryDirectory()
|
||||
self._bulk_downloads_folder = Path(self._temp_directory.name) / "bulk_downloads"
|
||||
self._bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
|
||||
# Track which user owns each download so the fetch endpoint can enforce ownership
|
||||
self._download_owners: dict[str, str] = {}
|
||||
|
||||
def handler(
|
||||
self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str]
|
||||
self,
|
||||
image_names: Optional[list[str]],
|
||||
board_id: Optional[str],
|
||||
bulk_download_item_id: Optional[str],
|
||||
user_id: str = "system",
|
||||
) -> None:
|
||||
bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID
|
||||
bulk_download_item_id = bulk_download_item_id or uuid_string()
|
||||
bulk_download_item_name = bulk_download_item_id + ".zip"
|
||||
|
||||
self._signal_job_started(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
|
||||
# Record ownership so the fetch endpoint can verify the caller
|
||||
self._download_owners[bulk_download_item_name] = user_id
|
||||
|
||||
self._signal_job_started(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id)
|
||||
|
||||
try:
|
||||
image_dtos: list[ImageDTO] = []
|
||||
@@ -46,16 +55,16 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
raise BulkDownloadParametersException()
|
||||
|
||||
bulk_download_item_name: str = self._create_zip_file(image_dtos, bulk_download_item_id)
|
||||
self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
|
||||
self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id)
|
||||
except (
|
||||
ImageRecordNotFoundException,
|
||||
BoardRecordNotFoundException,
|
||||
BulkDownloadException,
|
||||
BulkDownloadParametersException,
|
||||
) as e:
|
||||
self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e)
|
||||
self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e, user_id)
|
||||
except Exception as e:
|
||||
self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e)
|
||||
self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e, user_id)
|
||||
self._invoker.services.logger.error("Problem bulk downloading images.")
|
||||
raise e
|
||||
|
||||
@@ -103,43 +112,60 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " " or c == "_" or c == "-"]).rstrip()
|
||||
|
||||
def _signal_job_started(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
self,
|
||||
bulk_download_id: str,
|
||||
bulk_download_item_id: str,
|
||||
bulk_download_item_name: str,
|
||||
user_id: str = "system",
|
||||
) -> None:
|
||||
"""Signal that a bulk download job has started."""
|
||||
if self._invoker:
|
||||
assert bulk_download_id is not None
|
||||
self._invoker.services.events.emit_bulk_download_started(
|
||||
bulk_download_id, bulk_download_item_id, bulk_download_item_name
|
||||
bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id=user_id
|
||||
)
|
||||
|
||||
def _signal_job_completed(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
self,
|
||||
bulk_download_id: str,
|
||||
bulk_download_item_id: str,
|
||||
bulk_download_item_name: str,
|
||||
user_id: str = "system",
|
||||
) -> None:
|
||||
"""Signal that a bulk download job has completed."""
|
||||
if self._invoker:
|
||||
assert bulk_download_id is not None
|
||||
assert bulk_download_item_name is not None
|
||||
self._invoker.services.events.emit_bulk_download_complete(
|
||||
bulk_download_id, bulk_download_item_id, bulk_download_item_name
|
||||
bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id=user_id
|
||||
)
|
||||
|
||||
def _signal_job_failed(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, exception: Exception
|
||||
self,
|
||||
bulk_download_id: str,
|
||||
bulk_download_item_id: str,
|
||||
bulk_download_item_name: str,
|
||||
exception: Exception,
|
||||
user_id: str = "system",
|
||||
) -> None:
|
||||
"""Signal that a bulk download job has failed."""
|
||||
if self._invoker:
|
||||
assert bulk_download_id is not None
|
||||
assert exception is not None
|
||||
self._invoker.services.events.emit_bulk_download_error(
|
||||
bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception)
|
||||
bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception), user_id=user_id
|
||||
)
|
||||
|
||||
def stop(self, *args, **kwargs):
    """Tear down the service by removing its temporary downloads directory."""
    self._temp_directory.cleanup()
|
||||
|
||||
def get_owner(self, bulk_download_item_name: str) -> Optional[str]:
    """Return the user_id recorded as owner of the named download, or None if untracked."""
    owner_map = self._download_owners
    return owner_map.get(bulk_download_item_name)
|
||||
|
||||
def delete(self, bulk_download_item_name: str) -> None:
    """Delete the zip file for the named download item and forget its owner."""
    zip_path = Path(self.get_path(bulk_download_item_name))
    zip_path.unlink()
    # Drop the ownership record; a missing entry is ignored.
    self._download_owners.pop(bulk_download_item_name, None)
|
||||
|
||||
def get_path(self, bulk_download_item_name: str) -> str:
|
||||
path = str(self._bulk_downloads_folder / bulk_download_item_name)
|
||||
|
||||
@@ -22,6 +22,7 @@ from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
|
||||
INIT_FILE = Path("invokeai.yaml")
|
||||
API_KEYS_FILE = Path("api_keys.yaml")
|
||||
DB_FILE = Path("invokeai.db")
|
||||
LEGACY_INIT_FILE = Path("invokeai.init")
|
||||
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
|
||||
@@ -30,6 +31,14 @@ ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8
|
||||
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
|
||||
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
|
||||
CONFIG_SCHEMA_VERSION = "4.0.2"
|
||||
EXTERNAL_PROVIDER_CONFIG_FIELDS = (
|
||||
"external_alibabacloud_api_key",
|
||||
"external_alibabacloud_base_url",
|
||||
"external_gemini_api_key",
|
||||
"external_gemini_base_url",
|
||||
"external_openai_api_key",
|
||||
"external_openai_base_url",
|
||||
)
|
||||
|
||||
|
||||
class URLRegexTokenPair(BaseModel):
|
||||
@@ -101,7 +110,8 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
|
||||
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
|
||||
max_queue_size: Maximum number of items in the session queue.
|
||||
clear_queue_on_startup: Empties session queue on startup.
|
||||
clear_queue_on_startup: Empties session queue on startup. If true, disables `max_queue_history`.
|
||||
max_queue_history: Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true.
|
||||
allow_nodes: List of nodes to allow. Omit to allow all.
|
||||
deny_nodes: List of nodes to deny. Omit to deny none.
|
||||
node_cache_size: How many cached nodes to keep in memory.
|
||||
@@ -111,6 +121,11 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
unsafe_disable_picklescan: UNSAFE. Disable the picklescan security check during model installation. Recommended only for development and testing purposes. This will allow arbitrary code execution during model installation, so should never be used in production.
|
||||
allow_unknown_models: Allow installation of models that we are unable to identify. If enabled, models will be marked as `unknown` in the database, and will not have any metadata associated with them. If disabled, unknown models will be rejected during installation.
|
||||
multiuser: Enable multiuser support. When disabled, the application runs in single-user mode using a default system account with administrator privileges. When enabled, requires user authentication and authorization.
|
||||
strict_password_checking: Enforce strict password requirements. When True, passwords must contain uppercase, lowercase, and numbers. When False (default), any password is accepted but its strength (weak/moderate/strong) is reported to the user.
|
||||
external_gemini_api_key: API key for Gemini image generation.
|
||||
external_openai_api_key: API key for OpenAI image generation.
|
||||
external_gemini_base_url: Base URL override for Gemini image generation.
|
||||
external_openai_base_url: Base URL override for OpenAI image generation.
|
||||
"""
|
||||
|
||||
_root: Optional[Path] = PrivateAttr(default=None)
|
||||
@@ -190,7 +205,8 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
|
||||
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
|
||||
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
|
||||
clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.")
|
||||
clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup. If true, disables `max_queue_history`.")
|
||||
max_queue_history: Optional[int] = Field(default=None, ge=0, description="Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true.")
|
||||
|
||||
# NODES
|
||||
allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.")
|
||||
@@ -206,6 +222,21 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
|
||||
# MULTIUSER
|
||||
multiuser: bool = Field(default=False, description="Enable multiuser support. When disabled, the application runs in single-user mode using a default system account with administrator privileges. When enabled, requires user authentication and authorization.")
|
||||
strict_password_checking: bool = Field(default=False, description="Enforce strict password requirements. When True, passwords must contain uppercase, lowercase, and numbers. When False (default), any password is accepted but its strength (weak/moderate/strong) is reported to the user.")
|
||||
|
||||
# EXTERNAL PROVIDERS
|
||||
external_alibabacloud_api_key: Optional[str] = Field(default=None, description="API key for Alibaba Cloud DashScope image generation.")
|
||||
external_alibabacloud_base_url: Optional[str] = Field(
|
||||
default=None, description="Base URL override for Alibaba Cloud DashScope image generation."
|
||||
)
|
||||
external_gemini_api_key: Optional[str] = Field(default=None, description="API key for Gemini image generation.")
|
||||
external_openai_api_key: Optional[str] = Field(default=None, description="API key for OpenAI image generation.")
|
||||
external_gemini_base_url: Optional[str] = Field(
|
||||
default=None, description="Base URL override for Gemini image generation."
|
||||
)
|
||||
external_openai_base_url: Optional[str] = Field(
|
||||
default=None, description="Base URL override for OpenAI image generation."
|
||||
)
|
||||
|
||||
# fmt: on
|
||||
|
||||
@@ -288,6 +319,13 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
assert resolved_path is not None
|
||||
return resolved_path
|
||||
|
||||
@property
def api_keys_file_path(self) -> Path:
    """Path to the api_keys.yaml file, resolved to an absolute path."""
    resolved = self._resolve(API_KEYS_FILE)
    assert resolved is not None
    return resolved
|
||||
|
||||
@property
|
||||
def outputs_path(self) -> Optional[Path]:
|
||||
"""Path to the outputs directory, resolved to an absolute path.."""
|
||||
@@ -500,6 +538,36 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
||||
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
|
||||
|
||||
|
||||
def load_external_api_keys(api_keys_file_path: Path) -> dict[str, str]:
    """Read external provider settings (API keys and base URLs) from a dedicated YAML file.

    Only the field names listed in EXTERNAL_PROVIDER_CONFIG_FIELDS are honored; each
    value must be a string, and blank/whitespace-only values are dropped.

    :param api_keys_file_path: Location of the api_keys YAML file.
    :return: Mapping of recognized field names to their stripped, non-empty values.
    :raises RuntimeError: If the file is not a YAML mapping, or a value is not a string.
    """
    if not api_keys_file_path.exists():
        return {}

    # NOTE(review): reads with the platform-preferred encoding rather than UTF-8 —
    # presumably to match how other config files are read; confirm before changing.
    with open(api_keys_file_path, "rt", encoding=locale.getpreferredencoding()) as file:
        loaded_api_keys: Any = yaml.safe_load(file)

    # An empty file parses to None; treat it the same as a missing file.
    if loaded_api_keys is None:
        return {}

    if not isinstance(loaded_api_keys, dict):
        raise RuntimeError(f"Failed to load api keys file {api_keys_file_path}: expected a mapping")

    collected: dict[str, str] = {}
    for field_name in EXTERNAL_PROVIDER_CONFIG_FIELDS:
        raw_value = loaded_api_keys.get(field_name)
        if raw_value is None:
            continue
        if not isinstance(raw_value, str):
            raise RuntimeError(
                f"Failed to load api keys file {api_keys_file_path}: value for '{field_name}' must be a string"
            )
        cleaned = raw_value.strip()
        if cleaned:
            collected[field_name] = cleaned

    return collected
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_config() -> InvokeAIAppConfig:
|
||||
"""Get the global singleton app config.
|
||||
@@ -516,6 +584,7 @@ def get_config() -> InvokeAIAppConfig:
|
||||
"""
|
||||
# This object includes environment variables, as parsed by pydantic-settings
|
||||
config = InvokeAIAppConfig()
|
||||
env_fields_set = set(config.model_fields_set)
|
||||
|
||||
args = InvokeAIArgs.args
|
||||
|
||||
@@ -577,4 +646,11 @@ def get_config() -> InvokeAIAppConfig:
|
||||
default_config = DefaultInvokeAIAppConfig()
|
||||
default_config.write_file(config.config_file_path, as_example=False)
|
||||
|
||||
api_keys_from_file = load_external_api_keys(config.api_keys_file_path)
|
||||
if api_keys_from_file:
|
||||
# API keys file should take precedence over invokeai.yaml, but not over environment variables.
|
||||
api_keys_to_apply = {key: value for key, value in api_keys_from_file.items() if key not in env_fields_set}
|
||||
if api_keys_to_apply:
|
||||
config.update_config(api_keys_to_apply, clobber=True)
|
||||
|
||||
return config
|
||||
|
||||
@@ -100,9 +100,9 @@ class EventServiceBase:
|
||||
"""Emitted when a queue item's status changes"""
|
||||
self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status))
|
||||
|
||||
def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None:
|
||||
def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult", user_id: str = "system") -> None:
|
||||
"""Emitted when a batch is enqueued"""
|
||||
self.dispatch(BatchEnqueuedEvent.build(enqueue_result))
|
||||
self.dispatch(BatchEnqueuedEvent.build(enqueue_result, user_id))
|
||||
|
||||
def emit_queue_items_retried(self, retry_result: "RetryItemsResult") -> None:
|
||||
"""Emitted when a list of queue items are retried"""
|
||||
@@ -112,9 +112,9 @@ class EventServiceBase:
|
||||
"""Emitted when a queue is cleared"""
|
||||
self.dispatch(QueueClearedEvent.build(queue_id))
|
||||
|
||||
def emit_recall_parameters_updated(self, queue_id: str, parameters: dict) -> None:
|
||||
def emit_recall_parameters_updated(self, queue_id: str, user_id: str, parameters: dict) -> None:
|
||||
"""Emitted when recall parameters are updated"""
|
||||
self.dispatch(RecallParametersUpdatedEvent.build(queue_id, parameters))
|
||||
self.dispatch(RecallParametersUpdatedEvent.build(queue_id, user_id, parameters))
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -194,23 +194,42 @@ class EventServiceBase:
|
||||
# region Bulk image download
|
||||
|
||||
def emit_bulk_download_started(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
self,
|
||||
bulk_download_id: str,
|
||||
bulk_download_item_id: str,
|
||||
bulk_download_item_name: str,
|
||||
user_id: str = "system",
|
||||
) -> None:
|
||||
"""Emitted when a bulk image download is started"""
|
||||
self.dispatch(BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
|
||||
self.dispatch(
|
||||
BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id)
|
||||
)
|
||||
|
||||
def emit_bulk_download_complete(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
self,
|
||||
bulk_download_id: str,
|
||||
bulk_download_item_id: str,
|
||||
bulk_download_item_name: str,
|
||||
user_id: str = "system",
|
||||
) -> None:
|
||||
"""Emitted when a bulk image download is complete"""
|
||||
self.dispatch(BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
|
||||
self.dispatch(
|
||||
BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id)
|
||||
)
|
||||
|
||||
def emit_bulk_download_error(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
|
||||
self,
|
||||
bulk_download_id: str,
|
||||
bulk_download_item_id: str,
|
||||
bulk_download_item_name: str,
|
||||
error: str,
|
||||
user_id: str = "system",
|
||||
) -> None:
|
||||
"""Emitted when a bulk image download has an error"""
|
||||
self.dispatch(
|
||||
BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error)
|
||||
BulkDownloadErrorEvent.build(
|
||||
bulk_download_id, bulk_download_item_id, bulk_download_item_name, error, user_id
|
||||
)
|
||||
)
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -281,9 +281,10 @@ class BatchEnqueuedEvent(QueueEventBase):
|
||||
)
|
||||
priority: int = Field(description="The priority of the batch")
|
||||
origin: str | None = Field(default=None, description="The origin of the batch")
|
||||
user_id: str = Field(default="system", description="The ID of the user who enqueued the batch")
|
||||
|
||||
@classmethod
|
||||
def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
|
||||
def build(cls, enqueue_result: EnqueueBatchResult, user_id: str = "system") -> "BatchEnqueuedEvent":
|
||||
return cls(
|
||||
queue_id=enqueue_result.queue_id,
|
||||
batch_id=enqueue_result.batch.batch_id,
|
||||
@@ -291,6 +292,7 @@ class BatchEnqueuedEvent(QueueEventBase):
|
||||
enqueued=enqueue_result.enqueued,
|
||||
requested=enqueue_result.requested,
|
||||
priority=enqueue_result.priority,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -609,6 +611,7 @@ class BulkDownloadEventBase(EventBase):
|
||||
bulk_download_id: str = Field(description="The ID of the bulk image download")
|
||||
bulk_download_item_id: str = Field(description="The ID of the bulk image download item")
|
||||
bulk_download_item_name: str = Field(description="The name of the bulk image download item")
|
||||
user_id: str = Field(default="system", description="The ID of the user who initiated the download")
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
@@ -619,12 +622,17 @@ class BulkDownloadStartedEvent(BulkDownloadEventBase):
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
cls,
|
||||
bulk_download_id: str,
|
||||
bulk_download_item_id: str,
|
||||
bulk_download_item_name: str,
|
||||
user_id: str = "system",
|
||||
) -> "BulkDownloadStartedEvent":
|
||||
return cls(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -636,12 +644,17 @@ class BulkDownloadCompleteEvent(BulkDownloadEventBase):
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
cls,
|
||||
bulk_download_id: str,
|
||||
bulk_download_item_id: str,
|
||||
bulk_download_item_name: str,
|
||||
user_id: str = "system",
|
||||
) -> "BulkDownloadCompleteEvent":
|
||||
return cls(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -655,13 +668,19 @@ class BulkDownloadErrorEvent(BulkDownloadEventBase):
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
|
||||
cls,
|
||||
bulk_download_id: str,
|
||||
bulk_download_item_id: str,
|
||||
bulk_download_item_name: str,
|
||||
error: str,
|
||||
user_id: str = "system",
|
||||
) -> "BulkDownloadErrorEvent":
|
||||
return cls(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
error=error,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -671,8 +690,9 @@ class RecallParametersUpdatedEvent(QueueEventBase):
|
||||
|
||||
__event_name__ = "recall_parameters_updated"
|
||||
|
||||
user_id: str = Field(description="The ID of the user whose recall parameters were updated")
|
||||
parameters: dict[str, Any] = Field(description="The recall parameters that were updated")
|
||||
|
||||
@classmethod
|
||||
def build(cls, queue_id: str, parameters: dict[str, Any]) -> "RecallParametersUpdatedEvent":
|
||||
return cls(queue_id=queue_id, parameters=parameters)
|
||||
def build(cls, queue_id: str, user_id: str, parameters: dict[str, Any]) -> "RecallParametersUpdatedEvent":
|
||||
return cls(queue_id=queue_id, user_id=user_id, parameters=parameters)
|
||||
|
||||
@@ -46,3 +46,9 @@ class FastAPIEventService(EventServiceBase):
|
||||
|
||||
except asyncio.CancelledError as e:
|
||||
raise e # Raise a proper error
|
||||
except Exception:
|
||||
import logging
|
||||
|
||||
logging.getLogger("InvokeAI").error(
|
||||
f"Error dispatching event {getattr(event, '__event_name__', event)}", exc_info=True
|
||||
)
|
||||
|
||||
23
invokeai/app/services/external_generation/__init__.py
Normal file
23
invokeai/app/services/external_generation/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from invokeai.app.services.external_generation.external_generation_base import (
|
||||
ExternalGenerationServiceBase,
|
||||
ExternalProvider,
|
||||
)
|
||||
from invokeai.app.services.external_generation.external_generation_common import (
|
||||
ExternalGeneratedImage,
|
||||
ExternalGenerationRequest,
|
||||
ExternalGenerationResult,
|
||||
ExternalProviderStatus,
|
||||
ExternalReferenceImage,
|
||||
)
|
||||
from invokeai.app.services.external_generation.external_generation_default import ExternalGenerationService
|
||||
|
||||
__all__ = [
|
||||
"ExternalGenerationRequest",
|
||||
"ExternalGenerationResult",
|
||||
"ExternalGeneratedImage",
|
||||
"ExternalGenerationService",
|
||||
"ExternalGenerationServiceBase",
|
||||
"ExternalProvider",
|
||||
"ExternalProviderStatus",
|
||||
"ExternalReferenceImage",
|
||||
]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user