Merge remote-tracking branch 'upstream/main' into external-models

This commit is contained in:
Alexander Eichhorn
2026-04-14 03:43:39 +02:00
99 changed files with 3025 additions and 211 deletions

View File

@@ -0,0 +1,205 @@
# Canvas Projects — Technical Documentation
## Overview
Canvas Projects provide a save/load mechanism for the entire canvas state. The feature serializes all canvas entities, generation parameters, reference images, and their associated image files into a ZIP-based `.invk` file. On load, it restores the full state, handling image deduplication and re-uploading as needed.
## File Format
The `.invk` file is a standard ZIP archive with the following structure:
```
project.invk
├── manifest.json
├── canvas_state.json
├── params.json
├── ref_images.json
├── loras.json
└── images/
├── {image_name_1}.png
├── {image_name_2}.png
└── ...
```
### manifest.json
Schema version and metadata. Validated on load with Zod.
```json
{
"version": 1,
"appVersion": "5.12.0",
"createdAt": "2026-02-26T12:00:00.000Z",
"name": "My Canvas Project"
}
```
| Field | Type | Description |
|---|---|---|
| `version` | `number` | Schema version, currently `1`. Used for migration logic on load. |
| `appVersion` | `string` | InvokeAI version that created the file. Informational only. |
| `createdAt` | `string` | ISO 8601 timestamp. |
| `name` | `string` | User-provided project name. Also used as the download filename. |
### canvas_state.json
The serialized canvas entity tree. Type: `CanvasProjectState`.
```typescript
type CanvasProjectState = {
rasterLayers: CanvasRasterLayerState[];
controlLayers: CanvasControlLayerState[];
inpaintMasks: CanvasInpaintMaskState[];
regionalGuidance: CanvasRegionalGuidanceState[];
bbox: CanvasState['bbox'];
selectedEntityIdentifier: CanvasState['selectedEntityIdentifier'];
bookmarkedEntityIdentifier: CanvasState['bookmarkedEntityIdentifier'];
};
```
Each entity contains its full state including all canvas objects (brush lines, eraser lines, rect shapes, images). Image objects reference files by `image_name` which correspond to files in the `images/` folder.
### params.json
The complete generation parameters state (`ParamsState`). Optional on load (older files may not have it). This includes all fields from the params Redux slice:
- Prompts (positive, negative, prompt history)
- Core generation settings (seed, steps, CFG scale, guidance, scheduler, iterations)
- Model selections (main model, VAE, FLUX VAE, T5 encoder, CLIP embed models, refiner, Z-Image models, Klein models)
- Dimensions (width, height, aspect ratio)
- Img2img strength
- Infill settings (method, tile size, patchmatch downscale, color)
- Canvas coherence settings (mode, edge size, min denoise)
- Refiner parameters (steps, CFG scale, scheduler, aesthetic scores, start)
- FLUX-specific settings (scheduler, DyPE preset/scale/exponent)
- Z-Image-specific settings (scheduler, seed variance)
- Upscale settings (scheduler, CFG scale)
- Seamless tiling, mask blur, CLIP skip, VAE precision, CPU noise, color compensation
### ref_images.json
Global reference image entities (`RefImageState[]`). These are IP-Adapter / FLUX Redux configs with `CroppableImageWithDims` containing both original and cropped image references. Optional on load.
### loras.json
Array of LoRA configurations (`LoRA[]`). Each entry contains:
```typescript
type LoRA = {
id: string;
isEnabled: boolean;
model: ModelIdentifierField;
weight: number;
};
```
Optional on load. Like models, LoRA identifiers are stored as-is — if a LoRA is not installed when loading, the entry is restored but may not be usable.
### images/
All image files referenced anywhere in the state. Keyed by their original `image_name`. On save, each image is fetched from the backend via `GET /api/v1/images/i/{name}/full` and stored as-is.
## Key Source Files
| File | Purpose |
|---|---|
| `features/controlLayers/util/canvasProjectFile.ts` | Types, constants, image name collection, remapping, existence checking |
| `features/controlLayers/hooks/useCanvasProjectSave.ts` | Save hook — collects Redux state, fetches images, builds ZIP |
| `features/controlLayers/hooks/useCanvasProjectLoad.ts` | Load hook — parses ZIP, deduplicates images, dispatches state |
| `features/controlLayers/components/SaveCanvasProjectDialog.tsx` | Save name dialog + `useSaveCanvasProjectWithDialog` hook |
| `features/controlLayers/components/LoadCanvasProjectConfirmationAlertDialog.tsx` | Load confirmation dialog + `useLoadCanvasProjectWithDialog` hook |
| `features/controlLayers/components/Toolbar/CanvasToolbarProjectMenuButton.tsx` | Toolbar dropdown UI |
| `features/controlLayers/store/canvasSlice.ts` | `canvasProjectRecalled` Redux action |
## Save Flow
1. User clicks "Save Canvas Project" → `SaveCanvasProjectDialog` opens asking for a project name
2. On confirm, `saveCanvasProject(name)` is called
3. Read Redux state via selectors: `selectCanvasSlice()`, `selectParamsSlice()`, `selectRefImagesSlice()`, `selectLoRAsSlice()`
4. Build `CanvasProjectState` from the canvas slice; use `paramsState` directly for params
5. Walk all entities to collect every `image_name` reference via `collectImageNames()`:
- `CanvasImageState.image.image_name` in layer/mask objects
- `CroppableImageWithDims.original.image.image_name` in global ref images
- `CroppableImageWithDims.crop.image.image_name` in cropped ref images
- `ImageWithDims.image_name` in regional guidance ref images
6. Fetch each image from the backend API
7. Build ZIP with JSZip: add `manifest.json` (including `name`), `canvas_state.json`, `params.json`, `ref_images.json`, and all images into `images/`
8. Sanitize the name for filesystem use and generate blob, trigger download as `{name}.invk`
## Load Flow
1. User selects `.invk` file → confirmation dialog opens
2. On confirm, parse ZIP with JSZip
3. Validate manifest version via Zod schema
4. Read `canvas_state.json`, `params.json` (optional), `ref_images.json` (optional)
5. Collect all `image_name` references from the loaded state
6. **Deduplicate images**: for each referenced image, check if it exists on the server via `getImageDTOSafe(image_name)`
- Already exists → skip (no upload)
- Missing → upload from ZIP via `uploadImage()`, record `oldName → newName` mapping
7. Remap all `image_name` values in the loaded state using the mapping (only for re-uploaded images whose names changed)
8. Dispatch Redux actions:
- `canvasProjectRecalled()` — restores all canvas entities, bbox, selected/bookmarked entity
- `refImagesRecalled()` — restores global reference images
- `paramsRecalled()` — replaces the entire params state in one action
- `loraAllDeleted()` + `loraRecalled()` — restores LoRAs
9. Show success/error toast
## Image Name Collection & Remapping
The `canvasProjectFile.ts` utility provides two parallel sets of functions:
**Collection** (`collectImageNames`): Walks the entire state tree and returns a `Set<string>` of all referenced `image_name` values. This is used by both save (to know which images to fetch) and load (to know which images to check/upload).
**Remapping** (`remapCanvasState`, `remapRefImages`): Deep-clones state objects and replaces `image_name` values using a `Map<string, string>` mapping. Only images that were re-uploaded with a different name are remapped. Images that already existed on the server are left unchanged.
Both walk the same paths through the state tree:
- Layer/mask objects → `CanvasImageState.image.image_name`
- Regional guidance ref images → `ImageWithDims.image_name`
- Global ref images → `CroppableImageWithDims.original.image.image_name` and `.crop.image.image_name`
## Extending the Format
### Adding new optional data (non-breaking)
Add a new JSON file to the ZIP. No version bump needed.
1. **Save**: Add `zip.file('new_data.json', JSON.stringify(data))` in `useCanvasProjectSave.ts`
2. **Load**: Read with `zip.file('new_data.json')` in `useCanvasProjectLoad.ts` — check for `null` so older project files without it still load
3. **Dispatch**: Add the appropriate Redux action to restore the data
### Adding new entity types with images
1. Extend `CanvasProjectState` type in `canvasProjectFile.ts`
2. Add collection logic in `collectImageNames()` to walk the new entity's objects
3. Add remapping logic in `remapCanvasState()` to update image names
4. Include the new entity array in both save and load hooks
5. Handle it in the `canvasProjectRecalled` reducer in `canvasSlice.ts`
### Breaking schema changes
1. Bump `CANVAS_PROJECT_VERSION` in `canvasProjectFile.ts`
2. Update the Zod manifest schema: `version: z.union([z.literal(1), z.literal(2)])`
3. Add migration logic in the load hook: check version, transform v1 → v2 before dispatching
## UI Architecture
### Save dialog
The save flow uses a **nanostore atom** (`$isOpen`) to control the `SaveCanvasProjectDialog`:
1. `useSaveCanvasProjectWithDialog()` — returns a callback that sets `$isOpen` to `true`
2. `SaveCanvasProjectDialog` (singleton in `GlobalModalIsolator`) — renders an `AlertDialog` with a name input
3. On save → calls `saveCanvasProject(name)` and closes the dialog
4. On cancel → closes the dialog
### Load dialog
The load flow uses a **nanostore atom** (`$pendingFile`) to decouple the file dialog from the confirmation dialog:
1. `useLoadCanvasProjectWithDialog()` — opens a programmatic file input (`document.createElement('input')`)
2. On file selection → sets `$pendingFile` atom
3. `LoadCanvasProjectConfirmationAlertDialog` (singleton in `GlobalModalIsolator`) — subscribes to `$pendingFile` via `useStore()`
4. On accept → calls `loadCanvasProject(file)` and clears the atom
5. On cancel → clears the atom
The programmatic file input approach was chosen because the context menu component uses `isLazy: true`, which unmounts the DOM tree when the menu closes — a hidden `<input>` element inside the menu would be destroyed before the file dialog returns.

View File

@@ -0,0 +1,32 @@
Lasso Tool
===========
- The Lasso tool creates selections and inpaint masks by drawing freehand or polygonal regions on the canvas.
How to open the Lasso tool
--------------------------
- Click the Lasso icon in the toolbar.
- Hotkey: press `L` (default). The hotkey is shown in the tool's tooltip and can be customized in Hotkeys settings.
Modes
-----
- Freehand (default)
- Hold the pointer and drag to draw a continuous contour.
- Long segments are broken into intermediate points to keep the line continuous.
- Very long strokes may be simplified after drawing to reduce point count for performance.
- Polygon
- Click to place points; click the first point (or a point near it) to close the polygon.
- The tool snaps the closing point to the start for precise closures.
Basic interactions
------------------
- Switch modes with the mode toggle in the toolbar.
- To close a polygon: click the starting point again or click near it — the tool aligns the final point to the start to complete the shape.
- The selection will be added to the current Inpaint Mask layer. If no Inpaint Mask layer exists, a new one will be created automatically.
Tips & behavior
---------------
- Hold `Space` to temporarily switch to the View tool for panning and zooming; release `Space` to return to the Lasso tool and continue drawing.
- When using the Polygon mode, you can hold `Shift` to snap points to horizontal, vertical, or 45-degree angles for more precise shapes.
- Hold `Ctrl` (Windows/Linux) or `Command` (macOS) while drawing to subtract from the current selection instead of adding to it.

View File

@@ -0,0 +1,56 @@
---
title: Canvas Projects
---
# :material-folder-zip: Canvas Projects
## Save and Restore Your Canvas Work
Canvas Projects let you save your entire canvas setup to a file and load it back later. This is useful when you want to:
- **Switch between tasks** without losing your current canvas arrangement
- **Back up complex setups** with multiple layers, masks, and reference images
- **Share canvas layouts** with others or transfer them between machines
- **Recover from deleted images** — all images are embedded in the project file
## What Gets Saved
A canvas project file (`.invk`) captures everything about your current canvas session:
- **All layers** — raster layers, control layers, inpaint masks, regional guidance
- **All drawn content** — brush strokes, pasted images, eraser marks
- **Reference images** — global IP-Adapter / FLUX Redux images with crop settings
- **Regional guidance** — per-region prompts and reference images
- **Bounding box** — position, size, aspect ratio, and scale settings
- **All generation parameters** — prompts, seed, steps, CFG scale, guidance, scheduler, model, VAE, dimensions, img2img strength, infill settings, canvas coherence, refiner settings, FLUX/Z-Image specific parameters, and more
- **LoRAs** — all added LoRA models with their weights and enabled/disabled state
## How to Save a Project
You can save from two places:
1. **Toolbar** — Click the **Archive icon** in the canvas toolbar, then select **Save Canvas Project**
2. **Context menu** — Right-click the canvas, open the **Project** submenu, then select **Save Canvas Project**
A dialog will ask you to enter a **project name**. This name is used as the filename (e.g., entering "My Portrait" saves as `My Portrait.invk`) and is stored inside the project file.
## How to Load a Project
1. **Toolbar** — Click the **Archive icon**, then select **Load Canvas Project**
2. **Context menu** — Right-click the canvas, open the **Project** submenu, then select **Load Canvas Project**
A file dialog will open. Select your `.invk` file. You will see a confirmation dialog warning that loading will replace your current canvas. Click **Load** to proceed.
### What Happens on Load
- Your current canvas is **completely replaced** — all existing layers, masks, reference images, and parameters are overwritten
- Images that are already present on your InvokeAI server are reused automatically (no duplicate uploads)
- Images that were deleted from the server are re-uploaded from the project file
- If the saved model is not installed on your system, the model identifier is still restored — you will need to select an available model manually
## Good to Know
- **No undo** — Loading a project replaces your canvas entirely. There is no way to undo this action, so save your current project first if you want to keep it.
- **Image deduplication** — When loading, images already on your server are not re-uploaded. Only missing images are uploaded from the project file.
- **File size** — The `.invk` file size depends on the number and resolution of images in your canvas. A project with many high-resolution layers can be large.
- **Model availability** — The project saves which model was selected, but does not include the model itself. If the model is not installed when you load the project, you will need to select a different one.

View File

@@ -10,6 +10,7 @@ from fastapi import Body, HTTPException, Path
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
from invokeai.app.api.auth_dependencies import AdminUserOrDefault
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.config.config_default import (
EXTERNAL_PROVIDER_CONFIG_FIELDS,
@@ -102,6 +103,16 @@ EXTERNAL_PROVIDER_FIELDS: dict[str, tuple[str, str]] = {
_EXTERNAL_PROVIDER_CONFIG_LOCK = Lock()
class UpdateAppGenerationSettingsRequest(BaseModel):
"""Writable generation-related app settings."""
max_queue_history: int | None = Field(
default=None,
ge=0,
description="Keep the last N completed, failed, and canceled queue items on startup. Set to 0 to prune all terminal items.",
)
@app_router.get(
"/runtime_config", operation_id="get_runtime_config", status_code=200, response_model=InvokeAIAppConfigWithSetFields
)
@@ -110,6 +121,30 @@ async def get_runtime_config() -> InvokeAIAppConfigWithSetFields:
return InvokeAIAppConfigWithSetFields(set_fields=config.model_fields_set, config=config)
@app_router.patch(
"/runtime_config",
operation_id="update_runtime_config",
status_code=200,
response_model=InvokeAIAppConfigWithSetFields,
)
async def update_runtime_config(
_: AdminUserOrDefault,
changes: UpdateAppGenerationSettingsRequest = Body(description="Writable runtime configuration changes"),
) -> InvokeAIAppConfigWithSetFields:
config = get_config()
update_dict = changes.model_dump(exclude_unset=True)
config.update_config(update_dict)
if config.config_file_path.exists():
persisted_config = load_and_migrate_config(config.config_file_path)
else:
persisted_config = DefaultInvokeAIAppConfig()
persisted_config.update_config(update_dict)
persisted_config.write_file(config.config_file_path)
return InvokeAIAppConfigWithSetFields(set_fields=config.model_fields_set, config=config)
@app_router.get(
"/external_providers/status",
operation_id="get_external_provider_statuses",

View File

@@ -56,7 +56,7 @@ class BaseBatchInvocation(BaseInvocation):
"image_batch",
title="Image Batch",
tags=["primitives", "image", "batch", "special"],
category="primitives",
category="batch",
version="1.0.0",
classification=Classification.Special,
)
@@ -87,7 +87,7 @@ class ImageGeneratorField(BaseModel):
"image_generator",
title="Image Generator",
tags=["primitives", "board", "image", "batch", "special"],
category="primitives",
category="batch",
version="1.0.0",
classification=Classification.Special,
)
@@ -111,7 +111,7 @@ class ImageGenerator(BaseInvocation):
"string_batch",
title="String Batch",
tags=["primitives", "string", "batch", "special"],
category="primitives",
category="batch",
version="1.0.0",
classification=Classification.Special,
)
@@ -142,7 +142,7 @@ class StringGeneratorField(BaseModel):
"string_generator",
title="String Generator",
tags=["primitives", "string", "number", "batch", "special"],
category="primitives",
category="batch",
version="1.0.0",
classification=Classification.Special,
)
@@ -166,7 +166,7 @@ class StringGenerator(BaseInvocation):
"integer_batch",
title="Integer Batch",
tags=["primitives", "integer", "number", "batch", "special"],
category="primitives",
category="batch",
version="1.0.0",
classification=Classification.Special,
)
@@ -195,7 +195,7 @@ class IntegerGeneratorField(BaseModel):
"integer_generator",
title="Integer Generator",
tags=["primitives", "int", "number", "batch", "special"],
category="primitives",
category="batch",
version="1.0.0",
classification=Classification.Special,
)
@@ -219,7 +219,7 @@ class IntegerGenerator(BaseInvocation):
"float_batch",
title="Float Batch",
tags=["primitives", "float", "number", "batch", "special"],
category="primitives",
category="batch",
version="1.0.0",
classification=Classification.Special,
)
@@ -250,7 +250,7 @@ class FloatGeneratorField(BaseModel):
"float_generator",
title="Float Generator",
tags=["primitives", "float", "number", "batch", "special"],
category="primitives",
category="batch",
version="1.0.0",
classification=Classification.Special,
)

View File

@@ -11,7 +11,7 @@ from invokeai.backend.image_util.util import cv2_to_pil, pil_to_cv2
"canny_edge_detection",
title="Canny Edge Detection",
tags=["controlnet", "canny"],
category="controlnet",
category="controlnet_preprocessors",
version="1.0.0",
)
class CannyEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -33,7 +33,7 @@ from invokeai.backend.util.devices import TorchDevice
"cogview4_denoise",
title="Denoise - CogView4",
tags=["image", "cogview4"],
category="image",
category="latents",
version="1.0.0",
classification=Classification.Prototype,
)

View File

@@ -27,7 +27,7 @@ from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory
"cogview4_i2l",
title="Image to Latents - CogView4",
tags=["image", "latents", "vae", "i2l", "cogview4"],
category="image",
category="latents",
version="1.0.0",
classification=Classification.Prototype,
)

View File

@@ -20,7 +20,7 @@ COGVIEW4_GLM_MAX_SEQ_LEN = 1024
"cogview4_text_encoder",
title="Prompt - CogView4",
tags=["prompt", "conditioning", "cogview4"],
category="conditioning",
category="prompt",
version="1.0.0",
classification=Classification.Prototype,
)

View File

@@ -11,9 +11,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import SEED_MAX
@invocation(
"range", title="Integer Range", tags=["collection", "integer", "range"], category="collections", version="1.0.0"
)
@invocation("range", title="Integer Range", tags=["collection", "integer", "range"], category="batch", version="1.0.0")
class RangeInvocation(BaseInvocation):
"""Creates a range of numbers from start to stop with step"""
@@ -35,7 +33,7 @@ class RangeInvocation(BaseInvocation):
"range_of_size",
title="Integer Range of Size",
tags=["collection", "integer", "size", "range"],
category="collections",
category="batch",
version="1.0.0",
)
class RangeOfSizeInvocation(BaseInvocation):
@@ -55,7 +53,7 @@ class RangeOfSizeInvocation(BaseInvocation):
"random_range",
title="Random Range",
tags=["range", "integer", "random", "collection"],
category="collections",
category="batch",
version="1.0.1",
use_cache=False,
)

View File

@@ -11,7 +11,7 @@ from invokeai.backend.image_util.util import np_to_pil, pil_to_np
"color_map",
title="Color Map",
tags=["controlnet"],
category="controlnet",
category="controlnet_preprocessors",
version="1.0.0",
)
class ColorMapInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -43,7 +43,7 @@ from invokeai.backend.util.devices import TorchDevice
"compel",
title="Prompt - SD1.5",
tags=["prompt", "compel"],
category="conditioning",
category="prompt",
version="1.2.1",
)
class CompelInvocation(BaseInvocation):
@@ -248,7 +248,7 @@ class SDXLPromptInvocationBase:
"sdxl_compel_prompt",
title="Prompt - SDXL",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
category="prompt",
version="1.2.1",
)
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
@@ -342,7 +342,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"sdxl_refiner_compel_prompt",
title="Prompt - SDXL Refiner",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
category="prompt",
version="1.1.2",
)
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
@@ -391,7 +391,7 @@ class CLIPSkipInvocationOutput(BaseInvocationOutput):
"clip_skip",
title="Apply CLIP Skip - SD1.5, SDXL",
tags=["clipskip", "clip", "skip"],
category="conditioning",
category="prompt",
version="1.1.1",
)
class CLIPSkipInvocation(BaseInvocation):

View File

@@ -9,7 +9,7 @@ from invokeai.backend.image_util.content_shuffle import content_shuffle
"content_shuffle",
title="Content Shuffle",
tags=["controlnet", "normal"],
category="controlnet",
category="controlnet_preprocessors",
version="1.0.0",
)
class ContentShuffleInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -64,7 +64,7 @@ class ControlOutput(BaseInvocationOutput):
@invocation(
"controlnet", title="ControlNet - SD1.5, SD2, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3"
"controlnet", title="ControlNet - SD1.5, SD2, SDXL", tags=["controlnet"], category="conditioning", version="1.1.3"
)
class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
@@ -116,7 +116,7 @@ class ControlNetInvocation(BaseInvocation):
"heuristic_resize",
title="Heuristic Resize",
tags=["image, controlnet"],
category="image",
category="controlnet_preprocessors",
version="1.1.1",
classification=Classification.Prototype,
)

View File

@@ -18,7 +18,7 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_t
"create_denoise_mask",
title="Create Denoise Mask",
tags=["mask", "denoise"],
category="latents",
category="mask",
version="1.0.2",
)
class CreateDenoiseMaskInvocation(BaseInvocation):

View File

@@ -41,7 +41,7 @@ class GradientMaskOutput(BaseInvocationOutput):
"create_gradient_mask",
title="Create Gradient Mask",
tags=["mask", "denoise"],
category="latents",
category="mask",
version="1.3.0",
)
class CreateGradientMaskInvocation(BaseInvocation):

View File

@@ -20,7 +20,7 @@ DEPTH_ANYTHING_MODELS = {
"depth_anything_depth_estimation",
title="Depth Anything Depth Estimation",
tags=["controlnet", "depth", "depth anything"],
category="controlnet",
category="controlnet_preprocessors",
version="1.0.0",
)
class DepthAnythingDepthEstimationInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -11,7 +11,7 @@ from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
"dw_openpose_detection",
title="DW Openpose Detection",
tags=["controlnet", "dwpose", "openpose"],
category="controlnet",
category="controlnet_preprocessors",
version="1.1.1",
)
class DWOpenposeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -435,7 +435,9 @@ def get_faces_list(
return all_faces
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.2")
@invocation(
"face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="segmentation", version="1.2.2"
)
class FaceOffInvocation(BaseInvocation, WithMetadata):
"""Bound, extract, and mask a face from an image using MediaPipe detection"""
@@ -514,7 +516,9 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
return output
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.2")
@invocation(
"face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="segmentation", version="1.2.2"
)
class FaceMaskInvocation(BaseInvocation, WithMetadata):
"""Face mask creation using mediapipe face detection"""
@@ -617,7 +621,11 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata):
@invocation(
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.2"
"face_identifier",
title="FaceIdentifier",
tags=["image", "face", "identifier"],
category="segmentation",
version="1.2.2",
)
class FaceIdentifierInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Outputs an image with detected face IDs printed on each face. For use with other FaceTools."""

View File

@@ -53,7 +53,7 @@ from invokeai.backend.util.devices import TorchDevice
"flux2_denoise",
title="FLUX2 Denoise",
tags=["image", "flux", "flux2", "klein", "denoise"],
category="image",
category="latents",
version="1.4.0",
classification=Classification.Prototype,
)

View File

@@ -45,7 +45,7 @@ KLEIN_MAX_SEQ_LEN = 512
"flux2_klein_text_encoder",
title="Prompt - Flux2 Klein",
tags=["prompt", "conditioning", "flux", "klein", "qwen3"],
category="conditioning",
category="prompt",
version="1.1.1",
classification=Classification.Prototype,
)

View File

@@ -50,7 +50,7 @@ class FluxControlNetOutput(BaseInvocationOutput):
"flux_controlnet",
title="FLUX ControlNet",
tags=["controlnet", "flux"],
category="controlnet",
category="conditioning",
version="1.0.0",
)
class FluxControlNetInvocation(BaseInvocation):

View File

@@ -70,7 +70,7 @@ from invokeai.backend.util.devices import TorchDevice
"flux_denoise",
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
category="latents",
version="4.5.1",
)
class FluxDenoiseInvocation(BaseInvocation):

View File

@@ -29,7 +29,7 @@ class FluxFillOutput(BaseInvocationOutput):
"flux_fill",
title="FLUX Fill Conditioning",
tags=["inpaint"],
category="inpaint",
category="conditioning",
version="1.0.0",
classification=Classification.Beta,
)

View File

@@ -24,7 +24,7 @@ from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
"flux_ip_adapter",
title="FLUX IP-Adapter",
tags=["ip_adapter", "control"],
category="ip_adapter",
category="conditioning",
version="1.0.0",
)
class FluxIPAdapterInvocation(BaseInvocation):

View File

@@ -47,7 +47,7 @@ DOWNSAMPLING_FUNCTIONS = Literal["nearest", "bilinear", "bicubic", "area", "near
"flux_redux",
title="FLUX Redux",
tags=["ip_adapter", "control"],
category="ip_adapter",
category="conditioning",
version="2.1.0",
classification=Classification.Beta,
)

View File

@@ -28,7 +28,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
"flux_text_encoder",
title="Prompt - FLUX",
tags=["prompt", "conditioning", "flux"],
category="conditioning",
category="prompt",
version="1.1.2",
)
class FluxTextEncoderInvocation(BaseInvocation):

View File

@@ -24,7 +24,7 @@ GROUNDING_DINO_MODEL_IDS: dict[GroundingDinoModelKey, str] = {
"grounding_dino",
title="Grounding DINO (Text Prompt Object Detection)",
tags=["prompt", "object detection"],
category="image",
category="segmentation",
version="1.0.0",
)
class GroundingDinoInvocation(BaseInvocation):

View File

@@ -11,7 +11,7 @@ from invokeai.backend.image_util.hed import ControlNetHED_Apache2, HEDEdgeDetect
"hed_edge_detection",
title="HED Edge Detection",
tags=["controlnet", "hed", "softedge"],
category="controlnet",
category="controlnet_preprocessors",
version="1.0.0",
)
class HEDEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -21,6 +21,7 @@ class IdealSizeOutput(BaseInvocationOutput):
"ideal_size",
title="Ideal Size - SD1.5, SDXL",
tags=["latents", "math", "ideal_size"],
category="latents",
version="1.0.6",
)
class IdealSizeInvocation(BaseInvocation):

View File

@@ -197,7 +197,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard):
"tomask",
title="Mask from Alpha",
tags=["image", "mask"],
category="image",
category="mask",
version="1.2.2",
)
class MaskFromAlphaInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -604,7 +604,7 @@ class DecodeInvisibleWatermarkInvocation(BaseInvocation):
"mask_edge",
title="Mask Edge",
tags=["image", "mask", "inpaint"],
category="image",
category="mask",
version="1.2.2",
)
class MaskEdgeInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -643,7 +643,7 @@ class MaskEdgeInvocation(BaseInvocation, WithMetadata, WithBoard):
"mask_combine",
title="Combine Masks",
tags=["image", "mask", "multiply"],
category="image",
category="mask",
version="1.2.2",
)
class MaskCombineInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -974,7 +974,7 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard):
"save_image",
title="Save Image",
tags=["primitives", "image"],
category="primitives",
category="image",
version="1.2.2",
use_cache=False,
)
@@ -995,7 +995,7 @@ class SaveImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"canvas_paste_back",
title="Canvas Paste Back",
tags=["image", "combine"],
category="image",
category="canvas",
version="1.0.1",
)
class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -1032,7 +1032,7 @@ class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard):
"mask_from_id",
title="Mask from Segmented Image",
tags=["image", "mask", "id"],
category="image",
category="mask",
version="1.0.1",
)
class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -1069,7 +1069,7 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
"canvas_v2_mask_and_crop",
title="Canvas V2 Mask and Crop",
tags=["image", "mask", "id"],
category="image",
category="canvas",
version="1.0.0",
classification=Classification.Deprecated,
)
@@ -1110,7 +1110,7 @@ class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
@invocation(
"expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="image", version="1.0.1"
"expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="mask", version="1.0.1"
)
class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Expands a mask with a fade effect. The mask uses black to indicate areas to keep from the generated image and white for areas to discard.
@@ -1199,7 +1199,7 @@ class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
"apply_mask_to_image",
title="Apply Mask to Image",
tags=["image", "mask", "blend"],
category="image",
category="mask",
version="1.0.0",
)
class ApplyMaskToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -1374,7 +1374,7 @@ class PasteImageIntoBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoar
"flux_kontext_image_prep",
title="FLUX Kontext Image Prep",
tags=["image", "concatenate", "flux", "kontext"],
category="image",
category="conditioning",
version="1.0.0",
)
class FluxKontextConcatenateImagesInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -23,7 +23,7 @@ class ImagePanelCoordinateOutput(BaseInvocationOutput):
"image_panel_layout",
title="Image Panel Layout",
tags=["image", "panel", "layout"],
category="image",
category="canvas",
version="1.0.0",
classification=Classification.Prototype,
)

View File

@@ -73,7 +73,7 @@ CLIP_VISION_MODEL_MAP: dict[Literal["ViT-L", "ViT-H", "ViT-G"], StarterModel] =
"ip_adapter",
title="IP-Adapter - SD1.5, SDXL",
tags=["ip_adapter", "control"],
category="ip_adapter",
category="conditioning",
version="1.5.1",
)
class IPAdapterInvocation(BaseInvocation):

View File

@@ -11,7 +11,7 @@ from invokeai.backend.image_util.lineart import Generator, LineartEdgeDetector
"lineart_edge_detection",
title="Lineart Edge Detection",
tags=["controlnet", "lineart"],
category="controlnet",
category="controlnet_preprocessors",
version="1.0.0",
)
class LineartEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -9,7 +9,7 @@ from invokeai.backend.image_util.lineart_anime import LineartAnimeEdgeDetector,
"lineart_anime_edge_detection",
title="Lineart Anime Edge Detection",
tags=["controlnet", "lineart"],
category="controlnet",
category="controlnet_preprocessors",
version="1.0.0",
)
class LineartAnimeEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -19,7 +19,7 @@ from invokeai.backend.util.devices import TorchDevice
"llava_onevision_vllm",
title="LLaVA OneVision VLLM",
tags=["vllm"],
category="vllm",
category="multimodal",
version="1.0.0",
classification=Classification.Beta,
)

View File

@@ -12,7 +12,7 @@ class IfInvocationOutput(BaseInvocationOutput):
)
@invocation("if", title="If", tags=["logic", "conditional"], category="logic", version="1.0.0")
@invocation("if", title="If", tags=["logic", "conditional"], category="math", version="1.0.0")
class IfInvocation(BaseInvocation):
"""Selects between two optional inputs based on a boolean condition."""

View File

@@ -24,7 +24,7 @@ from invokeai.backend.image_util.util import pil_to_np
"rectangle_mask",
title="Create Rectangle Mask",
tags=["conditioning"],
category="conditioning",
category="mask",
version="1.0.1",
)
class RectangleMaskInvocation(BaseInvocation, WithMetadata):
@@ -55,7 +55,7 @@ class RectangleMaskInvocation(BaseInvocation, WithMetadata):
"alpha_mask_to_tensor",
title="Alpha Mask to Tensor",
tags=["conditioning"],
category="conditioning",
category="mask",
version="1.0.0",
)
class AlphaMaskToTensorInvocation(BaseInvocation):
@@ -83,7 +83,7 @@ class AlphaMaskToTensorInvocation(BaseInvocation):
"invert_tensor_mask",
title="Invert Tensor Mask",
tags=["conditioning"],
category="conditioning",
category="mask",
version="1.1.0",
)
class InvertTensorMaskInvocation(BaseInvocation):
@@ -115,7 +115,7 @@ class InvertTensorMaskInvocation(BaseInvocation):
"image_mask_to_tensor",
title="Image Mask to Tensor",
tags=["conditioning"],
category="conditioning",
category="mask",
version="1.0.0",
)
class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):

View File

@@ -9,7 +9,7 @@ from invokeai.backend.image_util.mediapipe_face import detect_faces
"mediapipe_face_detection",
title="MediaPipe Face Detection",
tags=["controlnet", "face"],
category="controlnet",
category="controlnet_preprocessors",
version="1.0.0",
)
class MediaPipeFaceDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -621,7 +621,7 @@ class LatentsMetaOutput(LatentsOutput, MetadataOutput):
"denoise_latents_meta",
title=f"{DenoiseLatentsInvocation.UIConfig.title} + Metadata",
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
category="latents",
category="metadata",
version="1.1.1",
)
class DenoiseLatentsMetaInvocation(DenoiseLatentsInvocation, WithMetadata):
@@ -686,7 +686,7 @@ class DenoiseLatentsMetaInvocation(DenoiseLatentsInvocation, WithMetadata):
"flux_denoise_meta",
title=f"{FluxDenoiseInvocation.UIConfig.title} + Metadata",
tags=["flux", "latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
category="latents",
category="metadata",
version="1.0.1",
)
class FluxDenoiseLatentsMetaInvocation(FluxDenoiseInvocation, WithMetadata):
@@ -734,7 +734,7 @@ class FluxDenoiseLatentsMetaInvocation(FluxDenoiseInvocation, WithMetadata):
"z_image_denoise_meta",
title=f"{ZImageDenoiseInvocation.UIConfig.title} + Metadata",
tags=["z-image", "latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
category="latents",
category="metadata",
version="1.0.0",
)
class ZImageDenoiseMetaInvocation(ZImageDenoiseInvocation, WithMetadata):

View File

@@ -10,7 +10,7 @@ from invokeai.backend.image_util.mlsd.models.mbv2_mlsd_large import MobileV2_MLS
"mlsd_detection",
title="MLSD Detection",
tags=["controlnet", "mlsd", "edge"],
category="controlnet",
category="controlnet_preprocessors",
version="1.0.0",
)
class MLSDDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -584,7 +584,7 @@ class SeamlessModeInvocation(BaseInvocation):
return SeamlessModeOutput(unet=unet, vae=vae)
@invocation("freeu", title="Apply FreeU - SD1.5, SDXL", tags=["freeu"], category="unet", version="1.0.2")
@invocation("freeu", title="Apply FreeU - SD1.5, SDXL", tags=["freeu"], category="model", version="1.0.2")
class FreeUInvocation(BaseInvocation):
"""
Applies FreeU to the UNet. Suggested values (b1/b2/s1/s2):

View File

@@ -10,7 +10,7 @@ from invokeai.backend.image_util.normal_bae.nets.NNET import NNET
"normal_map",
title="Normal Map",
tags=["controlnet", "normal"],
category="controlnet",
category="controlnet_preprocessors",
version="1.0.0",
)
class NormalMapInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -16,7 +16,9 @@ class PBRMapsOutput(BaseInvocationOutput):
displacement_map: ImageField = OutputField(default=None, description="The generated displacement map")
@invocation("pbr_maps", title="PBR Maps", tags=["image", "material"], category="image", version="1.0.0")
@invocation(
"pbr_maps", title="PBR Maps", tags=["image", "material"], category="controlnet_preprocessors", version="1.0.0"
)
class PBRMapsInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generate Normal, Displacement and Roughness Map from a given image"""

View File

@@ -10,7 +10,7 @@ from invokeai.backend.image_util.pidi.model import PiDiNet
"pidi_edge_detection",
title="PiDiNet Edge Detection",
tags=["controlnet", "edge"],
category="controlnet",
category="controlnet_preprocessors",
version="1.0.0",
)
class PiDiNetEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -34,7 +34,7 @@ from invokeai.backend.util.devices import TorchDevice
"sd3_denoise",
title="Denoise - SD3",
tags=["image", "sd3"],
category="image",
category="latents",
version="1.1.1",
)
class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -24,7 +24,7 @@ from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory
"sd3_i2l",
title="Image to Latents - SD3",
tags=["image", "latents", "vae", "i2l", "sd3"],
category="image",
category="latents",
version="1.0.1",
)
class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):

View File

@@ -31,7 +31,7 @@ SD3_T5_MAX_SEQ_LEN = 256
"sd3_text_encoder",
title="Prompt - SD3",
tags=["prompt", "conditioning", "sd3"],
category="conditioning",
category="prompt",
version="1.0.1",
)
class Sd3TextEncoderInvocation(BaseInvocation):

View File

@@ -20,7 +20,7 @@ class StringPosNegOutput(BaseInvocationOutput):
"string_split_neg",
title="String Split Negative",
tags=["string", "split", "negative"],
category="string",
category="strings",
version="1.0.1",
)
class StringSplitNegInvocation(BaseInvocation):
@@ -63,7 +63,7 @@ class String2Output(BaseInvocationOutput):
string_2: str = OutputField(description="string 2")
@invocation("string_split", title="String Split", tags=["string", "split"], category="string", version="1.0.1")
@invocation("string_split", title="String Split", tags=["string", "split"], category="strings", version="1.0.1")
class StringSplitInvocation(BaseInvocation):
"""Splits string into two strings, based on the first occurance of the delimiter. The delimiter will be removed from the string"""
@@ -83,7 +83,7 @@ class StringSplitInvocation(BaseInvocation):
return String2Output(string_1=part1, string_2=part2)
@invocation("string_join", title="String Join", tags=["string", "join"], category="string", version="1.0.1")
@invocation("string_join", title="String Join", tags=["string", "join"], category="strings", version="1.0.1")
class StringJoinInvocation(BaseInvocation):
"""Joins string left to string right"""
@@ -94,7 +94,9 @@ class StringJoinInvocation(BaseInvocation):
return StringOutput(value=((self.string_left or "") + (self.string_right or "")))
@invocation("string_join_three", title="String Join Three", tags=["string", "join"], category="string", version="1.0.1")
@invocation(
"string_join_three", title="String Join Three", tags=["string", "join"], category="strings", version="1.0.1"
)
class StringJoinThreeInvocation(BaseInvocation):
"""Joins string left to string middle to string right"""
@@ -107,7 +109,7 @@ class StringJoinThreeInvocation(BaseInvocation):
@invocation(
"string_replace", title="String Replace", tags=["string", "replace", "regex"], category="string", version="1.0.1"
"string_replace", title="String Replace", tags=["string", "replace", "regex"], category="strings", version="1.0.1"
)
class StringReplaceInvocation(BaseInvocation):
"""Replaces the search string with the replace string"""

View File

@@ -49,7 +49,7 @@ class T2IAdapterOutput(BaseInvocationOutput):
"t2i_adapter",
title="T2I-Adapter - SD1.5, SDXL",
tags=["t2i_adapter", "control"],
category="t2i_adapter",
category="conditioning",
version="1.0.4",
)
class T2IAdapterInvocation(BaseInvocation):

View File

@@ -30,7 +30,7 @@ ESRGAN_MODEL_URLS: dict[str, str] = {
}
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.2")
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="upscale", version="1.3.2")
class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Upscales an image using RealESRGAN."""

View File

@@ -57,7 +57,7 @@ class ZImageControlOutput(BaseInvocationOutput):
"z_image_control",
title="Z-Image ControlNet",
tags=["image", "z-image", "control", "controlnet"],
category="control",
category="conditioning",
version="1.1.0",
classification=Classification.Prototype,
)

View File

@@ -49,7 +49,7 @@ from invokeai.backend.z_image.z_image_transformer_patch import patch_transformer
"z_image_denoise",
title="Denoise - Z-Image",
tags=["image", "z-image"],
category="image",
category="latents",
version="1.5.0",
classification=Classification.Prototype,
)

View File

@@ -30,7 +30,7 @@ ZImageVAE = Union[AutoencoderKL, FluxAutoEncoder]
"z_image_i2l",
title="Image to Latents - Z-Image",
tags=["image", "latents", "vae", "i2l", "z-image"],
category="image",
category="latents",
version="1.1.0",
classification=Classification.Prototype,
)

View File

@@ -19,7 +19,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
"z_image_seed_variance_enhancer",
title="Seed Variance Enhancer - Z-Image",
tags=["conditioning", "z-image", "variance", "seed"],
category="conditioning",
category="prompt",
version="1.0.0",
classification=Classification.Prototype,
)

View File

@@ -34,7 +34,7 @@ Z_IMAGE_MAX_SEQ_LEN = 512
"z_image_text_encoder",
title="Prompt - Z-Image",
tags=["prompt", "conditioning", "z-image"],
category="conditioning",
category="prompt",
version="1.1.0",
classification=Classification.Prototype,
)

View File

@@ -108,7 +108,8 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
max_queue_size: Maximum number of items in the session queue.
clear_queue_on_startup: Empties session queue on startup.
clear_queue_on_startup: Empties session queue on startup. If true, disables `max_queue_history`.
max_queue_history: Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true.
allow_nodes: List of nodes to allow. Omit to allow all.
deny_nodes: List of nodes to deny. Omit to deny none.
node_cache_size: How many cached nodes to keep in memory.
@@ -202,7 +203,8 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.")
clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup. If true, disables `max_queue_history`.")
max_queue_history: Optional[int] = Field(default=None, ge=0, description="Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true.")
# NODES
allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.")

View File

@@ -45,10 +45,19 @@ class SqliteSessionQueue(SessionQueueBase):
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
self._set_in_progress_to_canceled()
if self.__invoker.services.configuration.clear_queue_on_startup:
config = self.__invoker.services.configuration
if config.clear_queue_on_startup:
clear_result = self.clear(DEFAULT_QUEUE_ID)
if clear_result.deleted > 0:
self.__invoker.services.logger.info(f"Cleared all {clear_result.deleted} queue items")
return
if config.max_queue_history is not None:
deleted = self._prune_terminal_to_limit(DEFAULT_QUEUE_ID, config.max_queue_history)
if deleted > 0:
self.__invoker.services.logger.info(
f"Pruned {deleted} completed/failed/canceled queue items (kept up to {config.max_queue_history})"
)
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
@@ -68,6 +77,51 @@ class SqliteSessionQueue(SessionQueueBase):
"""
)
def _prune_terminal_to_limit(self, queue_id: str, keep: int) -> int:
"""Prune terminal items (completed/failed/canceled) to keep at most N most-recent items."""
with self._db.transaction() as cursor:
where = """--sql
WHERE
queue_id = ?
AND (
status = 'completed'
OR status = 'failed'
OR status = 'canceled'
)
"""
cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
{where}
AND item_id NOT IN (
SELECT item_id
FROM session_queue
{where}
ORDER BY COALESCE(completed_at, updated_at, created_at) DESC, item_id DESC
LIMIT ?
);
""",
(queue_id, queue_id, keep),
)
count = cursor.fetchone()[0]
cursor.execute(
f"""--sql
DELETE
FROM session_queue
{where}
AND item_id NOT IN (
SELECT item_id
FROM session_queue
{where}
ORDER BY COALESCE(completed_at, updated_at, created_at) DESC, item_id DESC
LIMIT ?
);
""",
(queue_id, queue_id, keep),
)
return count
def _get_current_queue_size(self, queue_id: str) -> int:
"""Gets the current number of pending queue items"""
with self._db.transaction() as cursor:

View File

@@ -66,6 +66,7 @@
"i18next-http-backend": "^3.0.2",
"idb-keyval": "6.2.1",
"jsondiffpatch": "^0.7.3",
"jszip": "^3.10.1",
"konva": "^9.3.22",
"linkify-react": "^4.3.1",
"linkifyjs": "^4.3.1",

View File

@@ -86,6 +86,9 @@ importers:
jsondiffpatch:
specifier: ^0.7.3
version: 0.7.3
jszip:
specifier: ^3.10.1
version: 3.10.1
konva:
specifier: ^9.3.22
version: 9.3.22
@@ -2003,6 +2006,9 @@ packages:
copy-to-clipboard@3.3.3:
resolution: {integrity: sha512-2KV8NhB5JqC3ky0r9PMCAZKbUHSwtEo4CwCs0KXgruG43gX5PMqDEBbVU4OUzw2MuAWUfsuFmWvEKG5QRfSnJA==}
core-util-is@1.0.3:
resolution: {integrity: sha512-ZQBvi1DcpJ4GDqanjucZ2Hj3wEO5pZDS89BWbkcrvdxksJorwUDDZamX9ldFkp9aw2lmBDLgkObEA4DWNJ9FYQ==}
cosmiconfig@7.1.0:
resolution: {integrity: sha512-AdmX6xUzdNASswsFtmwSt7Vj8po9IuqXm0UXz7QKPuEUmPB4XyjGfaAr2PSuELMwkRMVH1EpIkX5bTZGRB3eCA==}
engines: {node: '>=10'}
@@ -2672,6 +2678,9 @@ packages:
resolution: {integrity: sha512-Hs59xBNfUIunMFgWAbGX5cq6893IbWg4KnrjbYwX3tx0ztorVgTDA6B2sxf8ejHJ4wz8BqGUMYlnzNBer5NvGg==}
engines: {node: '>= 4'}
immediate@3.0.6:
resolution: {integrity: sha512-XXOFtyqDjNDAQxVfYxuF7g9Il/IbWmmlQg2MYKOH8ExIT1qg6xc4zyS3HaEEATgs1btfzxq15ciUiY7gjSXRGQ==}
immer@10.1.1:
resolution: {integrity: sha512-s2MPrmjovJcoMaHtx6K11Ra7oD05NT97w1IC5zpMkT6Atjr7H8LjaDd81iIxUYpMKSRRNMJE703M1Fhr/TctHw==}
@@ -2825,6 +2834,9 @@ packages:
resolution: {integrity: sha512-fKzAra0rGJUUBwGBgNkHZuToZcn+TtXHpeCgmkMJMMYx1sQDYaCSyjJBSCa2nH1DGm7s3n1oBnohoVTBaN7Lww==}
engines: {node: '>=8'}
isarray@1.0.0:
resolution: {integrity: sha512-VLghIWNM6ELQzo7zwmcg0NmTVyWKYjvIeM83yjp0wRDTmUnrM678fQbcKBo6n2CJEF0szoG//ytg+TKla89ALQ==}
isarray@2.0.5:
resolution: {integrity: sha512-xHjhDr3cNBK0BzdUJSPXZntQUx/mwMS5Rw4A7lPJ90XGAO6ISP/ePDNuo0vhqOZU+UD5JoodwCAAoZQd3FeAKw==}
@@ -2916,6 +2928,9 @@ packages:
resolution: {integrity: sha512-ZZow9HBI5O6EPgSJLUb8n2NKgmVWTwCvHGwFuJlMjvLFqlGG6pjirPhtdsseaLZjSibD8eegzmYpUZwoIlj2cQ==}
engines: {node: '>=4.0'}
jszip@3.10.1:
resolution: {integrity: sha512-xXDvecyTpGLrqFrvkrUSoxxfJI5AH7U8zxxtVclpsUtMCq4JQ290LY8AW5c7Ggnr/Y/oK+bQMbqK2qmtk3pN4g==}
keyv@4.5.4:
resolution: {integrity: sha512-oxVHkHR/EJf2CNXnWxRLW6mg7JyCCUcG0DtEGmL2ctUo1PNTin1PUil+r/+4r5MpVgC/fn1kjsx7mjSujKqIpw==}
@@ -2934,6 +2949,9 @@ packages:
resolution: {integrity: sha512-+bT2uH4E5LGE7h/n3evcS/sQlJXCpIp6ym8OWJ5eV6+67Dsql/LaaT7qJBAt2rzfoa/5QBGBhxDix1dMt2kQKQ==}
engines: {node: '>= 0.8.0'}
lie@3.3.0:
resolution: {integrity: sha512-UaiMJzeWRlEujzAuw5LokY1L5ecNQYZKfmyZ9L7wDHb/p5etKaxXhohBcrw0EYby+G/NA52vRSN4N39dxHAIwQ==}
lines-and-columns@1.2.4:
resolution: {integrity: sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg==}
@@ -3210,6 +3228,9 @@ packages:
package-json-from-dist@1.0.1:
resolution: {integrity: sha512-UEZIS3/by4OC8vL3P2dTXRETpebLI2NiI5vIrjaD/5UtrkFX/tNbwjTSRAGC/+7CAo2pIcBaRgWmcBBHcsaCIw==}
pako@1.0.11:
resolution: {integrity: sha512-4hLB8Py4zZce5s4yd9XzopqwVv/yGNhV1Bl8NTmCq1763HeK2+EwVTv+leGeL13Dnh2wfbqowVPXCIO0z4taYw==}
pako@2.1.0:
resolution: {integrity: sha512-w+eufiZ1WuJYgPXbV/PO3NCMEc3xqylkKHzp8bxp1uW4qaSNQUkwmLLEc3kKsfz8lpV1F8Ht3U1Cm+9Srog2ug==}
@@ -3298,6 +3319,9 @@ packages:
resolution: {integrity: sha512-Qb1gy5OrP5+zDf2Bvnzdl3jsTf1qXVMazbvCoKhtKqVs4/YK4ozX4gKQJJVyNe+cajNPn0KoC0MC3FUmaHWEmQ==}
engines: {node: ^10.13.0 || ^12.13.0 || ^14.15.0 || >=15.0.0}
process-nextick-args@2.0.1:
resolution: {integrity: sha512-3ouUOpQhtgrbOa17J7+uxOTpITYWaGP7/AhoR3+A+/1e9skrzelGi/dXzEYyvbxubEF6Wn2ypscTKiKJFFn1ag==}
prop-types@15.8.1:
resolution: {integrity: sha512-oj87CgZICdulUohogVAR7AjlC0327U4el4L6eAvOqCeudMDVU0NThNaV+b9Df4dXgSP1gXMTnPdhfe/2qDH5cg==}
@@ -3539,6 +3563,9 @@ packages:
resolution: {integrity: sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==}
engines: {node: '>=0.10.0'}
readable-stream@2.3.8:
resolution: {integrity: sha512-8p0AUk4XODgIewSi0l8Epjs+EVnWiK7NoDIEGU0HhE7+ZyY8D1IMY7odu5lRrFXGg71L15KG8QrPmum45RTtdA==}
readable-stream@3.6.2:
resolution: {integrity: sha512-9u/sniCrY3D5WdsERHzHE4G2YCXqoG5FTHUiCC4SIbr6XcLZBY05ya9EKjYek9O5xOAwjGq+1JdGBAS7Q9ScoA==}
engines: {node: '>= 6'}
@@ -3661,6 +3688,9 @@ packages:
resolution: {integrity: sha512-AURm5f0jYEOydBj7VQlVvDrjeFgthDdEF5H1dP+6mNpoXOMo1quQqJ4wvJDyRZ9+pO3kGWoOdmV08cSv2aJV6Q==}
engines: {node: '>=0.4'}
safe-buffer@5.1.2:
resolution: {integrity: sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g==}
safe-buffer@5.2.1:
resolution: {integrity: sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==}
@@ -3718,6 +3748,9 @@ packages:
resolution: {integrity: sha512-RJRdvCo6IAnPdsvP/7m6bsQqNnn1FCBX5ZNtFL98MmFF/4xAIJTIg1YbHW5DC2W5SKZanrC6i4HsJqlajw/dZw==}
engines: {node: '>= 0.4'}
setimmediate@1.0.5:
resolution: {integrity: sha512-MATJdZp8sLqDl/68LfQmbP8zKPLQNV6BIZoIgrscFDQ+RsvK/BxeDQOgyxKKoh0y/8h3BqVFnCqQ/gd+reiIXA==}
shebang-command@2.0.0:
resolution: {integrity: sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==}
engines: {node: '>=8'}
@@ -3857,6 +3890,9 @@ packages:
resolution: {integrity: sha512-UXSH262CSZY1tfu3G3Secr6uGLCFVPMhIqHjlgCUtCCcgihYc/xKs9djMTMUOb2j1mVSeU8EU6NWc/iQKU6Gfg==}
engines: {node: '>= 0.4'}
string_decoder@1.1.1:
resolution: {integrity: sha512-n/ShnvDi6FHbbVfviro+WojiFzv+s8MPMHBczVePfUpDJLwoLT0ht1l4YwBCbi8pJAveEEdnkHyPyTP/mzRfwg==}
string_decoder@1.3.0:
resolution: {integrity: sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA==}
@@ -6153,6 +6189,8 @@ snapshots:
dependencies:
toggle-selection: 1.0.6
core-util-is@1.0.3: {}
cosmiconfig@7.1.0:
dependencies:
'@types/parse-json': 4.0.2
@@ -6957,6 +6995,8 @@ snapshots:
ignore@7.0.5: {}
immediate@3.0.6: {}
immer@10.1.1: {}
import-fresh@3.3.1:
@@ -7103,6 +7143,8 @@ snapshots:
dependencies:
is-docker: 2.2.1
isarray@1.0.0: {}
isarray@2.0.5: {}
isexe@2.0.0: {}
@@ -7192,6 +7234,13 @@ snapshots:
object.assign: 4.1.7
object.values: 1.2.1
jszip@3.10.1:
dependencies:
lie: 3.3.0
pako: 1.0.11
readable-stream: 2.3.8
setimmediate: 1.0.5
keyv@4.5.4:
dependencies:
json-buffer: 3.0.1
@@ -7221,6 +7270,10 @@ snapshots:
prelude-ls: 1.2.1
type-check: 0.4.0
lie@3.3.0:
dependencies:
immediate: 3.0.6
lines-and-columns@1.2.4: {}
linkify-react@4.3.1(linkifyjs@4.3.1)(react@18.3.1):
@@ -7510,6 +7563,8 @@ snapshots:
package-json-from-dist@1.0.1: {}
pako@1.0.11: {}
pako@2.1.0: {}
parent-module@1.0.1:
@@ -7578,6 +7633,8 @@ snapshots:
ansi-styles: 5.2.0
react-is: 17.0.2
process-nextick-args@2.0.1: {}
prop-types@15.8.1:
dependencies:
loose-envify: 1.4.0
@@ -7843,6 +7900,16 @@ snapshots:
dependencies:
loose-envify: 1.4.0
readable-stream@2.3.8:
dependencies:
core-util-is: 1.0.3
inherits: 2.0.4
isarray: 1.0.0
process-nextick-args: 2.0.1
safe-buffer: 5.1.2
string_decoder: 1.1.1
util-deprecate: 1.0.2
readable-stream@3.6.2:
dependencies:
inherits: 2.0.4
@@ -7994,6 +8061,8 @@ snapshots:
has-symbols: 1.1.0
isarray: 2.0.5
safe-buffer@5.1.2: {}
safe-buffer@5.2.1: {}
safe-push-apply@1.0.0:
@@ -8051,6 +8120,8 @@ snapshots:
es-errors: 1.3.0
es-object-atoms: 1.1.1
setimmediate@1.0.5: {}
shebang-command@2.0.0:
dependencies:
shebang-regex: 3.0.0
@@ -8236,6 +8307,10 @@ snapshots:
define-properties: 1.2.1
es-object-atoms: 1.1.1
string_decoder@1.1.1:
dependencies:
safe-buffer: 5.1.2
string_decoder@1.3.0:
dependencies:
safe-buffer: 5.2.1

View File

@@ -212,6 +212,7 @@
"copy": "Copy",
"copyError": "$t(gallery.copy) Error",
"clipboard": "Clipboard",
"collapseAll": "Collapse All",
"crop": "Crop",
"on": "On",
"off": "Off",
@@ -239,6 +240,7 @@
"error": "Error",
"error_withCount_one": "{{count}} error",
"error_withCount_other": "{{count}} errors",
"expandAll": "Expand All",
"model_withCount_one": "{{count}} model",
"model_withCount_other": "{{count}} models",
"file": "File",
@@ -715,6 +717,10 @@
"title": "Rect Tool",
"desc": "Select the rect tool."
},
"selectLassoTool": {
"title": "Lasso Tool",
"desc": "Select the lasso tool."
},
"selectViewTool": {
"title": "View Tool",
"desc": "Select the view tool."
@@ -1411,6 +1417,8 @@
"fullyContainNodesHelp": "Nodes must be fully inside the selection box to be selected",
"showEdgeLabels": "Show Edge Labels",
"showEdgeLabelsHelp": "Show labels on edges, indicating the connected nodes",
"groupNodesByCategory": "Group Nodes by Category",
"groupNodesByCategoryHelp": "Group nodes by category in the add node dialog",
"hideLegendNodes": "Hide Field Type Legend",
"hideMinimapnodes": "Hide MiniMap",
"inputMayOnlyHaveOneConnection": "Input may only have one connection",
@@ -1741,6 +1749,8 @@
"enableNSFWChecker": "Enable NSFW Checker",
"general": "General",
"generation": "Generation",
"maxQueueHistory": "Max Queue History",
"maxQueueHistorySaveFailed": "Failed to save Max Queue History",
"models": "Models",
"preferAttentionStyleNumeric": "Prefer Numeric Attention Style",
"prompt": "Prompt",
@@ -2798,10 +2808,16 @@
"radial": "Radial",
"clip": "Clip Gradient"
},
"lasso": {
"freehand": "Freehand",
"polygon": "Polygon",
"polygonHint": "Click to add points, click the first point to close."
},
"tool": {
"brush": "Brush",
"eraser": "Eraser",
"rectangle": "Rectangle",
"lasso": "Lasso",
"gradient": "Gradient",
"bbox": "Bbox",
"move": "Move",
@@ -3041,6 +3057,19 @@
"copyCanvasToClipboard": "Copy Canvas to Clipboard",
"copyBboxToClipboard": "Copy Bbox to Clipboard"
},
"canvasProject": {
"project": "Project",
"saveProject": "Save Canvas Project",
"loadProject": "Load Canvas Project",
"saveSuccess": "Project Saved",
"saveSuccessDesc": "Saved project with {{count}} images",
"saveError": "Failed to Save Project",
"loadSuccess": "Project Loaded",
"loadSuccessDesc": "Canvas state restored from project file",
"loadError": "Failed to Load Project",
"loadWarning": "Loading a project will replace your current canvas, including all layers, masks, reference images, and generation parameters. This action cannot be undone.",
"projectName": "Project Name"
},
"stagingArea": {
"accept": "Accept",
"discardAll": "Discard All",

View File

@@ -2,6 +2,8 @@ import { GlobalImageHotkeys } from 'app/components/GlobalImageHotkeys';
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
import { CanvasPasteModal } from 'features/controlLayers/components/CanvasPasteModal';
import { CanvasWorkflowIntegrationModal } from 'features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationModal';
import { LoadCanvasProjectConfirmationAlertDialog } from 'features/controlLayers/components/LoadCanvasProjectConfirmationAlertDialog';
import { SaveCanvasProjectDialog } from 'features/controlLayers/components/SaveCanvasProjectDialog';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { CropImageModal } from 'features/cropper/components/CropImageModal';
import { DeleteImageModal } from 'features/deleteImageModal/components/DeleteImageModal';
@@ -54,6 +56,8 @@ export const GlobalModalIsolator = memo(() => {
<CanvasPasteModal />
<CanvasWorkflowIntegrationModal />
</CanvasManagerProviderGate>
<SaveCanvasProjectDialog />
<LoadCanvasProjectConfirmationAlertDialog />
<LoadWorkflowFromGraphModal />
<CropImageModal />
</>

View File

@@ -2,6 +2,8 @@ import { Menu, MenuButton, MenuGroup, MenuItem, MenuList } from '@invoke-ai/ui-l
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { CanvasContextMenuItemsCropCanvasToBbox } from 'features/controlLayers/components/CanvasContextMenu/CanvasContextMenuItemsCropCanvasToBbox';
import { NewLayerIcon } from 'features/controlLayers/components/common/icons';
import { useLoadCanvasProjectWithDialog } from 'features/controlLayers/components/LoadCanvasProjectConfirmationAlertDialog';
import { useSaveCanvasProjectWithDialog } from 'features/controlLayers/components/SaveCanvasProjectDialog';
import { useCopyCanvasToClipboard } from 'features/controlLayers/hooks/copyHooks';
import {
useNewControlLayerFromBbox,
@@ -14,16 +16,19 @@ import {
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCopyBold, PiFloppyDiskBold } from 'react-icons/pi';
import { PiArchiveBold, PiCopyBold, PiFileArrowDownBold, PiFileArrowUpBold, PiFloppyDiskBold } from 'react-icons/pi';
export const CanvasContextMenuGlobalMenuItems = memo(() => {
const { t } = useTranslation();
const saveSubMenu = useSubMenu();
const projectSubMenu = useSubMenu();
const newSubMenu = useSubMenu();
const copySubMenu = useSubMenu();
const isBusy = useCanvasIsBusy();
const saveCanvasToGallery = useSaveCanvasToGallery();
const saveBboxToGallery = useSaveBboxToGallery();
const saveCanvasProject = useSaveCanvasProjectWithDialog();
const loadCanvasProject = useLoadCanvasProjectWithDialog();
const newRegionalReferenceImageFromBbox = useNewRegionalReferenceImageFromBbox();
const newGlobalReferenceImageFromBbox = useNewGlobalReferenceImageFromBbox();
const newRasterLayerFromBbox = useNewRasterLayerFromBbox();
@@ -50,6 +55,21 @@ export const CanvasContextMenuGlobalMenuItems = memo(() => {
</MenuList>
</Menu>
</MenuItem>
<MenuItem {...projectSubMenu.parentMenuItemProps} icon={<PiArchiveBold />}>
<Menu {...projectSubMenu.menuProps}>
<MenuButton {...projectSubMenu.menuButtonProps}>
<SubMenuButtonContent label={t('controlLayers.canvasProject.project')} />
</MenuButton>
<MenuList {...projectSubMenu.menuListProps}>
<MenuItem icon={<PiFileArrowDownBold />} isDisabled={isBusy} onClick={saveCanvasProject}>
{t('controlLayers.canvasProject.saveProject')}
</MenuItem>
<MenuItem icon={<PiFileArrowUpBold />} isDisabled={isBusy} onClick={loadCanvasProject}>
{t('controlLayers.canvasProject.loadProject')}
</MenuItem>
</MenuList>
</Menu>
</MenuItem>
<MenuItem {...newSubMenu.parentMenuItemProps} icon={<NewLayerIcon />}>
<Menu {...newSubMenu.menuProps}>
<MenuButton {...newSubMenu.menuButtonProps}>

View File

@@ -0,0 +1,69 @@
import { ConfirmationAlertDialog, Flex, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { useCanvasProjectLoad } from 'features/controlLayers/hooks/useCanvasProjectLoad';
import { CANVAS_PROJECT_EXTENSION } from 'features/controlLayers/util/canvasProjectFile';
import { atom } from 'nanostores';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const $pendingFile = atom<File | null>(null);
const openFileDialog = (onFileSelected: (file: File) => void) => {
const input = document.createElement('input');
input.type = 'file';
input.accept = CANVAS_PROJECT_EXTENSION;
input.onchange = () => {
const file = input.files?.[0];
if (file) {
onFileSelected(file);
}
};
input.click();
};
export const useLoadCanvasProjectWithDialog = () => {
const openDialog = useCallback(() => {
openFileDialog((file) => {
$pendingFile.set(file);
});
}, []);
return openDialog;
};
export const LoadCanvasProjectConfirmationAlertDialog = memo(() => {
useAssertSingleton('LoadCanvasProjectConfirmationAlertDialog');
const { t } = useTranslation();
const { loadCanvasProject } = useCanvasProjectLoad();
const pendingFile = useStore($pendingFile);
const onClose = useCallback(() => {
$pendingFile.set(null);
}, []);
const onAccept = useCallback(() => {
const file = $pendingFile.get();
if (file) {
void loadCanvasProject(file);
}
$pendingFile.set(null);
}, [loadCanvasProject]);
return (
<ConfirmationAlertDialog
isOpen={pendingFile !== null}
onClose={onClose}
title={t('controlLayers.canvasProject.loadProject')}
acceptCallback={onAccept}
acceptButtonText={t('common.load')}
useInert={false}
>
<Flex flexDir="column" gap={2}>
<Text>{t('controlLayers.canvasProject.loadWarning')}</Text>
</Flex>
</ConfirmationAlertDialog>
);
});
LoadCanvasProjectConfirmationAlertDialog.displayName = 'LoadCanvasProjectConfirmationAlertDialog';

View File

@@ -0,0 +1,92 @@
import {
AlertDialog,
AlertDialogBody,
AlertDialogContent,
AlertDialogFooter,
AlertDialogHeader,
Button,
Flex,
FormControl,
FormLabel,
Input,
} from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { useCanvasProjectSave } from 'features/controlLayers/hooks/useCanvasProjectSave';
import { atom } from 'nanostores';
import type { ChangeEvent, RefObject } from 'react';
import { memo, useCallback, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
const $isOpen = atom(false);
export const useSaveCanvasProjectWithDialog = () => {
return useCallback(() => {
$isOpen.set(true);
}, []);
};
export const SaveCanvasProjectDialog = memo(() => {
useAssertSingleton('SaveCanvasProjectDialog');
const isOpen = useStore($isOpen);
const cancelRef = useRef<HTMLButtonElement>(null);
const onClose = useCallback(() => {
$isOpen.set(false);
}, []);
return (
<AlertDialog isOpen={isOpen} onClose={onClose} leastDestructiveRef={cancelRef} isCentered>
{isOpen && <Content cancelRef={cancelRef} />}
</AlertDialog>
);
});
SaveCanvasProjectDialog.displayName = 'SaveCanvasProjectDialog';
const Content = memo(({ cancelRef }: { cancelRef: RefObject<HTMLButtonElement> }) => {
const { t } = useTranslation();
const { saveCanvasProject } = useCanvasProjectSave();
const [name, setName] = useState('Canvas Project');
const onChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setName(e.target.value);
}, []);
const onClose = useCallback(() => {
$isOpen.set(false);
}, []);
const onSave = useCallback(() => {
void saveCanvasProject(name);
$isOpen.set(false);
}, [name, saveCanvasProject]);
return (
<AlertDialogContent>
<AlertDialogHeader fontSize="lg" fontWeight="bold">
{t('controlLayers.canvasProject.saveProject')}
</AlertDialogHeader>
<AlertDialogBody>
<FormControl alignItems="flex-start">
<FormLabel mt="2">{t('controlLayers.canvasProject.projectName')}</FormLabel>
<Flex flexDir="column" width="full" gap="2">
<Input value={name} onChange={onChange} placeholder={t('controlLayers.canvasProject.projectName')} />
</Flex>
</FormControl>
</AlertDialogBody>
<AlertDialogFooter>
<Button ref={cancelRef} onClick={onClose}>
{t('common.cancel')}
</Button>
<Button colorScheme="invokeBlue" onClick={onSave} ml={3} isDisabled={!name.trim()}>
{t('common.save')}
</Button>
</AlertDialogFooter>
</AlertDialogContent>
);
});
Content.displayName = 'SaveCanvasProjectDialogContent';

View File

@@ -3,6 +3,7 @@ import { ToolBboxButton } from 'features/controlLayers/components/Tool/ToolBboxB
import { ToolBrushButton } from 'features/controlLayers/components/Tool/ToolBrushButton';
import { ToolColorPickerButton } from 'features/controlLayers/components/Tool/ToolColorPickerButton';
import { ToolGradientButton } from 'features/controlLayers/components/Tool/ToolGradientButton';
import { ToolLassoButton } from 'features/controlLayers/components/Tool/ToolLassoButton';
import { ToolMoveButton } from 'features/controlLayers/components/Tool/ToolMoveButton';
import { ToolRectButton } from 'features/controlLayers/components/Tool/ToolRectButton';
import { ToolTextButton } from 'features/controlLayers/components/Tool/ToolTextButton';
@@ -20,6 +21,7 @@ export const ToolChooser: React.FC = () => {
<ToolRectButton />
<ToolGradientButton />
<ToolTextButton />
<ToolLassoButton />
<ToolMoveButton />
<ToolViewButton />
<ToolBboxButton />

View File

@@ -0,0 +1,34 @@
import { IconButton, Tooltip } from '@invoke-ai/ui-library';
import { useSelectTool, useToolIsSelected } from 'features/controlLayers/components/Tool/hooks';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiLassoBold } from 'react-icons/pi';
export const ToolLassoButton = memo(() => {
const { t } = useTranslation();
const isSelected = useToolIsSelected('lasso');
const selectLasso = useSelectTool('lasso');
useRegisteredHotkeys({
id: 'selectLassoTool',
category: 'canvas',
callback: selectLasso,
options: { enabled: !isSelected },
dependencies: [isSelected, selectLasso],
});
return (
<Tooltip label={`${t('controlLayers.tool.lasso', { defaultValue: 'Lasso' })} (L)`} placement="end">
<IconButton
aria-label={`${t('controlLayers.tool.lasso', { defaultValue: 'Lasso' })} (L)`}
icon={<PiLassoBold />}
colorScheme={isSelected ? 'invokeBlue' : 'base'}
variant="solid"
onClick={selectLasso}
/>
</Tooltip>
);
});
ToolLassoButton.displayName = 'ToolLassoButton';

View File

@@ -0,0 +1,47 @@
import { ButtonGroup, IconButton, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectLassoMode, settingsLassoModeChanged } from 'features/controlLayers/store/canvasSettingsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPolygonBold, PiScribbleLoopBold } from 'react-icons/pi';
export const ToolLassoModeToggle = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const lassoMode = useAppSelector(selectLassoMode);
const setFreehand = useCallback(() => {
dispatch(settingsLassoModeChanged('freehand'));
}, [dispatch]);
const setPolygon = useCallback(() => {
dispatch(settingsLassoModeChanged('polygon'));
}, [dispatch]);
return (
<ButtonGroup isAttached size="sm">
<Tooltip label={t('controlLayers.lasso.freehand', { defaultValue: 'Freehand' })}>
<IconButton
aria-label={t('controlLayers.lasso.freehand', { defaultValue: 'Freehand' })}
icon={<PiScribbleLoopBold size={16} />}
colorScheme={lassoMode === 'freehand' ? 'invokeBlue' : 'base'}
variant="solid"
onClick={setFreehand}
/>
</Tooltip>
<Tooltip label={t('controlLayers.lasso.polygon', { defaultValue: 'Polygon' })}>
<IconButton
aria-label={t('controlLayers.lasso.polygonHint', {
defaultValue: 'Click to add points, click the first point to close.',
})}
icon={<PiPolygonBold size={16} />}
colorScheme={lassoMode === 'polygon' ? 'invokeBlue' : 'base'}
variant="solid"
onClick={setPolygon}
/>
</Tooltip>
</ButtonGroup>
);
});
ToolLassoModeToggle.displayName = 'ToolLassoModeToggle';

View File

@@ -5,11 +5,13 @@ import { useToolIsSelected } from 'features/controlLayers/components/Tool/hooks'
import { ToolFillColorPicker } from 'features/controlLayers/components/Tool/ToolFillColorPicker';
import { ToolGradientClipToggle } from 'features/controlLayers/components/Tool/ToolGradientClipToggle';
import { ToolGradientModeToggle } from 'features/controlLayers/components/Tool/ToolGradientModeToggle';
import { ToolLassoModeToggle } from 'features/controlLayers/components/Tool/ToolLassoModeToggle';
import { ToolOptionsRowContainer } from 'features/controlLayers/components/Tool/ToolOptionsRowContainer';
import { ToolWidthPicker } from 'features/controlLayers/components/Tool/ToolWidthPicker';
import { CanvasToolbarFitBboxToLayersButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarFitBboxToLayersButton';
import { CanvasToolbarFitBboxToMasksButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarFitBboxToMasksButton';
import { CanvasToolbarNewSessionMenuButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarNewSessionMenuButton';
import { CanvasToolbarProjectMenuButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarProjectMenuButton';
import { CanvasToolbarRedoButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarRedoButton';
import { CanvasToolbarResetViewButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarResetViewButton';
import { CanvasToolbarSaveToGalleryButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarSaveToGalleryButton';
@@ -31,6 +33,7 @@ export const CanvasToolbar = memo(() => {
const isBrushSelected = useToolIsSelected('brush');
const isEraserSelected = useToolIsSelected('eraser');
const isTextSelected = useToolIsSelected('text');
const isLassoSelected = useToolIsSelected('lasso');
const isGradientSelected = useToolIsSelected('gradient');
const showToolWithPicker = useMemo(() => {
return !isTextSelected && (isBrushSelected || isEraserSelected);
@@ -57,6 +60,11 @@ export const CanvasToolbar = memo(() => {
<ToolGradientModeToggle />
</Box>
)}
{isLassoSelected && (
<Box ms={2} mt="-2px" display="flex" alignItems="center" gap={2}>
<ToolLassoModeToggle />
</Box>
)}
{isTextSelected ? <TextToolOptions /> : showToolWithPicker && <ToolWidthPicker />}
</ToolOptionsRowContainer>
<Flex alignItems="center" h="full">
@@ -67,6 +75,7 @@ export const CanvasToolbar = memo(() => {
</Flex>
<Divider orientation="vertical" />
<Flex alignItems="center" h="full">
<CanvasToolbarProjectMenuButton />
<CanvasToolbarSaveToGalleryButton />
<CanvasToolbarUndoButton />
<CanvasToolbarRedoButton />

View File

@@ -0,0 +1,37 @@
import { IconButton, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useLoadCanvasProjectWithDialog } from 'features/controlLayers/components/LoadCanvasProjectConfirmationAlertDialog';
import { useSaveCanvasProjectWithDialog } from 'features/controlLayers/components/SaveCanvasProjectDialog';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArchiveBold, PiFileArrowDownBold, PiFileArrowUpBold } from 'react-icons/pi';
export const CanvasToolbarProjectMenuButton = memo(() => {
const { t } = useTranslation();
const isBusy = useCanvasIsBusy();
const saveCanvasProject = useSaveCanvasProjectWithDialog();
const loadCanvasProject = useLoadCanvasProjectWithDialog();
return (
<Menu placement="bottom-end">
<MenuButton
as={IconButton}
aria-label={t('controlLayers.canvasProject.project')}
tooltip={t('controlLayers.canvasProject.project')}
icon={<PiArchiveBold />}
variant="link"
alignSelf="stretch"
/>
<MenuList>
<MenuItem icon={<PiFileArrowDownBold />} isDisabled={isBusy} onClick={saveCanvasProject}>
{t('controlLayers.canvasProject.saveProject')}
</MenuItem>
<MenuItem icon={<PiFileArrowUpBold />} isDisabled={isBusy} onClick={loadCanvasProject}>
{t('controlLayers.canvasProject.loadProject')}
</MenuItem>
</MenuList>
</Menu>
);
});
CanvasToolbarProjectMenuButton.displayName = 'CanvasToolbarProjectMenuButton';

View File

@@ -0,0 +1,157 @@
import { logger } from 'app/logging/logger';
import { useAppDispatch } from 'app/store/storeHooks';
import { parseify } from 'common/util/serialize';
import { canvasProjectRecalled } from 'features/controlLayers/store/canvasSlice';
import { loraAllDeleted, loraRecalled } from 'features/controlLayers/store/lorasSlice';
import { paramsRecalled } from 'features/controlLayers/store/paramsSlice';
import { refImagesRecalled } from 'features/controlLayers/store/refImagesSlice';
import type { LoRA, ParamsState, RefImageState } from 'features/controlLayers/store/types';
import type { CanvasProjectState } from 'features/controlLayers/util/canvasProjectFile';
import {
checkExistingImages,
collectImageNames,
parseManifest,
processWithConcurrencyLimit,
remapCanvasState,
remapRefImages,
} from 'features/controlLayers/util/canvasProjectFile';
import { toast } from 'features/toast/toast';
import JSZip from 'jszip';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { uploadImage } from 'services/api/endpoints/images';
const log = logger('canvas');
export const useCanvasProjectLoad = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const loadCanvasProject = useCallback(
async (file: File) => {
try {
const zip = await JSZip.loadAsync(file);
// Validate manifest
const manifestFile = zip.file('manifest.json');
if (!manifestFile) {
throw new Error('Invalid project file: missing manifest.json');
}
const manifestData = JSON.parse(await manifestFile.async('string'));
parseManifest(manifestData);
// Read state files
const canvasStateFile = zip.file('canvas_state.json');
if (!canvasStateFile) {
throw new Error('Invalid project file: missing canvas_state.json');
}
const canvasState: CanvasProjectState = JSON.parse(await canvasStateFile.async('string'));
const paramsFile = zip.file('params.json');
let projectParams: ParamsState | null = null;
if (paramsFile) {
projectParams = JSON.parse(await paramsFile.async('string'));
}
const refImagesFile = zip.file('ref_images.json');
let refImages: RefImageState[] = [];
if (refImagesFile) {
refImages = JSON.parse(await refImagesFile.async('string'));
}
const lorasFile = zip.file('loras.json');
let loras: LoRA[] = [];
if (lorasFile) {
loras = JSON.parse(await lorasFile.async('string'));
}
// Collect all image names referenced in the state
const imageNames = collectImageNames(canvasState, refImages);
// Check which images already exist on the server
const { missing } = await checkExistingImages(imageNames);
// Upload missing images from the ZIP
const imageNameMapping = new Map<string, string>();
const imagesFolder = zip.folder('images');
if (imagesFolder && missing.size > 0) {
await processWithConcurrencyLimit(Array.from(missing), async (imageName) => {
const imageFile = imagesFolder.file(imageName);
if (!imageFile) {
log.warn(`Image ${imageName} referenced but not found in ZIP`);
return;
}
try {
const blob = await imageFile.async('blob');
const uploadFile = new File([blob], imageName, { type: 'image/png' });
const imageDTO = await uploadImage({
file: uploadFile,
image_category: 'general',
is_intermediate: false,
silent: true,
});
// Map old name to new name (only if different)
if (imageDTO.image_name !== imageName) {
imageNameMapping.set(imageName, imageDTO.image_name);
}
} catch (error) {
log.warn({ error: parseify(error) }, `Failed to upload image ${imageName}`);
}
});
}
// Remap image names in state objects
const remappedCanvasState = remapCanvasState(canvasState, imageNameMapping);
const remappedRefImages = remapRefImages(refImages, imageNameMapping);
// Dispatch state restoration
dispatch(
canvasProjectRecalled({
rasterLayers: remappedCanvasState.rasterLayers,
controlLayers: remappedCanvasState.controlLayers,
inpaintMasks: remappedCanvasState.inpaintMasks,
regionalGuidance: remappedCanvasState.regionalGuidance,
bbox: remappedCanvasState.bbox,
selectedEntityIdentifier: remappedCanvasState.selectedEntityIdentifier,
bookmarkedEntityIdentifier: remappedCanvasState.bookmarkedEntityIdentifier,
})
);
// Restore reference images
dispatch(refImagesRecalled({ entities: remappedRefImages, replace: true }));
// Restore generation parameters
if (projectParams) {
dispatch(paramsRecalled(projectParams));
}
// Restore LoRAs (always clear, even if project has none)
dispatch(loraAllDeleted());
for (const lora of loras) {
dispatch(loraRecalled({ lora }));
}
toast({
id: 'CANVAS_PROJECT_LOAD_SUCCESS',
title: t('controlLayers.canvasProject.loadSuccess'),
description: t('controlLayers.canvasProject.loadSuccessDesc'),
status: 'success',
});
} catch (error) {
log.error({ error: parseify(error) }, 'Failed to load canvas project');
toast({
id: 'CANVAS_PROJECT_LOAD_ERROR',
title: t('controlLayers.canvasProject.loadError'),
description: String(error),
status: 'error',
});
}
},
[dispatch, t]
);
return { loadCanvasProject };
};

View File

@@ -0,0 +1,116 @@
import { logger } from 'app/logging/logger';
import { useAppStore } from 'app/store/storeHooks';
import { parseify } from 'common/util/serialize';
import { downloadBlob } from 'features/controlLayers/konva/util';
import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { CanvasProjectManifest, CanvasProjectState } from 'features/controlLayers/util/canvasProjectFile';
import {
CANVAS_PROJECT_EXTENSION,
CANVAS_PROJECT_VERSION,
collectImageNames,
processWithConcurrencyLimit,
} from 'features/controlLayers/util/canvasProjectFile';
import { toast } from 'features/toast/toast';
import JSZip from 'jszip';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetAppVersionQuery } from 'services/api/endpoints/appInfo';
const log = logger('canvas');
const sanitizeFileName = (name: string): string => {
// Replace characters that are invalid in filenames
return name.replace(/[<>:"/\\|?*]/g, '_').trim() || 'canvas-project';
};
export const useCanvasProjectSave = () => {
const { t } = useTranslation();
const store = useAppStore();
const { data: appVersion } = useGetAppVersionQuery();
const saveCanvasProject = useCallback(
async (name: string) => {
try {
const state = store.getState();
const canvasState = selectCanvasSlice(state);
const paramsState = selectParamsSlice(state);
const refImagesState = selectRefImagesSlice(state);
const lorasState = selectLoRAsSlice(state);
// Build the canvas project state
const projectState: CanvasProjectState = {
rasterLayers: canvasState.rasterLayers.entities,
controlLayers: canvasState.controlLayers.entities,
inpaintMasks: canvasState.inpaintMasks.entities,
regionalGuidance: canvasState.regionalGuidance.entities,
bbox: canvasState.bbox,
selectedEntityIdentifier: canvasState.selectedEntityIdentifier,
bookmarkedEntityIdentifier: canvasState.bookmarkedEntityIdentifier,
};
// Collect all image names referenced in the state
const imageNames = collectImageNames(projectState, refImagesState.entities);
// Build ZIP
const zip = new JSZip();
// Add manifest
const manifest: CanvasProjectManifest = {
version: CANVAS_PROJECT_VERSION,
appVersion: appVersion?.version ?? 'unknown',
createdAt: new Date().toISOString(),
name,
};
zip.file('manifest.json', JSON.stringify(manifest, null, 2));
// Add state files
zip.file('canvas_state.json', JSON.stringify(projectState, null, 2));
zip.file('params.json', JSON.stringify(paramsState, null, 2));
zip.file('ref_images.json', JSON.stringify(refImagesState.entities, null, 2));
zip.file('loras.json', JSON.stringify(lorasState.loras, null, 2));
// Fetch and add images
const imagesFolder = zip.folder('images')!;
await processWithConcurrencyLimit(Array.from(imageNames), async (imageName) => {
try {
const response = await fetch(`/api/v1/images/i/${imageName}/full`);
if (!response.ok) {
log.warn(`Failed to fetch image ${imageName}: ${response.status}`);
return;
}
const blob = await response.blob();
imagesFolder.file(imageName, blob);
} catch (error) {
log.warn({ error: parseify(error) }, `Failed to fetch image ${imageName}`);
}
});
// Generate ZIP blob and trigger download
const blob = await zip.generateAsync({ type: 'blob' });
const fileName = `${sanitizeFileName(name)}${CANVAS_PROJECT_EXTENSION}`;
downloadBlob(blob, fileName);
toast({
id: 'CANVAS_PROJECT_SAVE_SUCCESS',
title: t('controlLayers.canvasProject.saveSuccess'),
description: t('controlLayers.canvasProject.saveSuccessDesc', { count: imageNames.size }),
status: 'success',
});
} catch (error) {
log.error({ error: parseify(error) }, 'Failed to save canvas project');
toast({
id: 'CANVAS_PROJECT_SAVE_ERROR',
title: t('controlLayers.canvasProject.saveError'),
description: String(error),
status: 'error',
});
}
},
[appVersion?.version, store, t]
);
return { saveCanvasProject };
};

View File

@@ -8,6 +8,7 @@ import { CanvasObjectEraserLine } from 'features/controlLayers/konva/CanvasObjec
import { CanvasObjectEraserLineWithPressure } from 'features/controlLayers/konva/CanvasObject/CanvasObjectEraserLineWithPressure';
import { CanvasObjectGradient } from 'features/controlLayers/konva/CanvasObject/CanvasObjectGradient';
import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage';
import { CanvasObjectLasso } from 'features/controlLayers/konva/CanvasObject/CanvasObjectLasso';
import { CanvasObjectRect } from 'features/controlLayers/konva/CanvasObject/CanvasObjectRect';
import type { AnyObjectRenderer, AnyObjectState } from 'features/controlLayers/konva/CanvasObject/types';
import { getPrefixedId } from 'features/controlLayers/konva/util';
@@ -152,6 +153,15 @@ export class CanvasEntityBufferObjectRenderer extends CanvasModuleBase {
this.konva.group.add(this.renderer.konva.group);
}
didRender = this.renderer.update(this.state, true);
} else if (this.state.type === 'lasso') {
assert(this.renderer instanceof CanvasObjectLasso || !this.renderer);
if (!this.renderer) {
this.renderer = new CanvasObjectLasso(this.state, this);
this.konva.group.add(this.renderer.konva.group);
}
didRender = this.renderer.update(this.state, true);
} else if (this.state.type === 'gradient') {
assert(this.renderer instanceof CanvasObjectGradient || !this.renderer);
@@ -247,6 +257,9 @@ export class CanvasEntityBufferObjectRenderer extends CanvasModuleBase {
case 'rect':
this.manager.stateApi.addRect({ entityIdentifier, rect: this.state });
break;
case 'lasso':
this.manager.stateApi.addLasso({ entityIdentifier, lasso: this.state });
break;
case 'gradient':
this.manager.stateApi.addGradient({ entityIdentifier, gradient: this.state });
break;

View File

@@ -10,6 +10,7 @@ import { CanvasObjectEraserLine } from 'features/controlLayers/konva/CanvasObjec
import { CanvasObjectEraserLineWithPressure } from 'features/controlLayers/konva/CanvasObject/CanvasObjectEraserLineWithPressure';
import { CanvasObjectGradient } from 'features/controlLayers/konva/CanvasObject/CanvasObjectGradient';
import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage';
import { CanvasObjectLasso } from 'features/controlLayers/konva/CanvasObject/CanvasObjectLasso';
import { CanvasObjectRect } from 'features/controlLayers/konva/CanvasObject/CanvasObjectRect';
import type { AnyObjectRenderer, AnyObjectState } from 'features/controlLayers/konva/CanvasObject/types';
import { LightnessToAlphaFilter } from 'features/controlLayers/konva/filters';
@@ -397,6 +398,16 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase {
this.konva.objectGroup.add(renderer.konva.group);
}
didRender = renderer.update(objectState, force || isFirstRender);
} else if (objectState.type === 'lasso') {
assert(renderer instanceof CanvasObjectLasso || !renderer);
if (!renderer) {
renderer = new CanvasObjectLasso(objectState, this);
this.renderers.set(renderer.id, renderer);
this.konva.objectGroup.add(renderer.konva.group);
}
didRender = renderer.update(objectState, force || isFirstRender);
} else if (objectState.type === 'gradient') {
assert(renderer instanceof CanvasObjectGradient || !renderer);
@@ -433,17 +444,21 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase {
* these visually transparent shapes in its calculation:
*
* - Eraser lines, which are normal lines with a globalCompositeOperation of 'destination-out'.
* - Subtracting lasso shapes, which use a globalCompositeOperation of 'destination-out'.
* - Clipped portions of any shape.
* - Images, which may have transparent areas.
*/
needsPixelBbox = (): boolean => {
let needsPixelBbox = false;
for (const renderer of this.renderers.values()) {
const isEraserLine = renderer instanceof CanvasObjectEraserLine;
const isEraserLine =
renderer instanceof CanvasObjectEraserLine || renderer instanceof CanvasObjectEraserLineWithPressure;
const isSubtractingLasso =
renderer instanceof CanvasObjectLasso && renderer.state.compositeOperation === 'destination-out';
const isImage = renderer instanceof CanvasObjectImage;
const imageIgnoresTransparency = isImage && renderer.state.usePixelBbox === false;
const hasClip = renderer instanceof CanvasObjectBrushLine && renderer.state.clip;
if (isEraserLine || hasClip || (isImage && !imageIgnoresTransparency)) {
if (isEraserLine || isSubtractingLasso || hasClip || (isImage && !imageIgnoresTransparency)) {
needsPixelBbox = true;
break;
}

View File

@@ -0,0 +1,85 @@
import { deepClone } from 'common/util/deepClone';
import type { CanvasEntityBufferObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer';
import type { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasLassoState } from 'features/controlLayers/store/types';
import Konva from 'konva';
import type { Logger } from 'roarr';
export class CanvasObjectLasso extends CanvasModuleBase {
readonly type = 'object_lasso';
readonly id: string;
readonly path: string[];
readonly parent: CanvasEntityObjectRenderer | CanvasEntityBufferObjectRenderer;
readonly manager: CanvasManager;
readonly log: Logger;
state: CanvasLassoState;
konva: {
group: Konva.Group;
line: Konva.Line;
};
constructor(state: CanvasLassoState, parent: CanvasEntityObjectRenderer | CanvasEntityBufferObjectRenderer) {
super();
this.id = state.id;
this.parent = parent;
this.manager = parent.manager;
this.path = this.manager.buildPath(this);
this.log = this.manager.buildLogger(this);
this.log.debug({ state }, 'Creating module');
this.konva = {
group: new Konva.Group({
name: `${this.type}:group`,
listening: false,
}),
line: new Konva.Line({
name: `${this.type}:line`,
listening: false,
closed: true,
fill: 'white',
strokeEnabled: false,
perfectDrawEnabled: false,
}),
};
this.konva.group.add(this.konva.line);
this.state = state;
}
update(state: CanvasLassoState, force = false): boolean {
if (force || this.state !== state) {
this.log.trace({ state }, 'Updating lasso');
this.konva.line.setAttrs({
points: state.points,
globalCompositeOperation: state.compositeOperation,
});
this.state = state;
return true;
}
return false;
}
setVisibility(isVisible: boolean): void {
this.log.trace({ isVisible }, 'Setting lasso visibility');
this.konva.group.visible(isVisible);
}
destroy = () => {
this.log.debug('Destroying module');
this.konva.group.destroy();
};
repr = () => {
return {
id: this.id,
type: this.type,
path: this.path,
parent: this.parent.id,
state: deepClone(this.state),
};
};
}

View File

@@ -4,6 +4,7 @@ import type { CanvasObjectEraserLine } from 'features/controlLayers/konva/Canvas
import type { CanvasObjectEraserLineWithPressure } from 'features/controlLayers/konva/CanvasObject/CanvasObjectEraserLineWithPressure';
import type { CanvasObjectGradient } from 'features/controlLayers/konva/CanvasObject/CanvasObjectGradient';
import type { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage';
import type { CanvasObjectLasso } from 'features/controlLayers/konva/CanvasObject/CanvasObjectLasso';
import type { CanvasObjectRect } from 'features/controlLayers/konva/CanvasObject/CanvasObjectRect';
import type {
CanvasBrushLineState,
@@ -12,6 +13,7 @@ import type {
CanvasEraserLineWithPressureState,
CanvasGradientState,
CanvasImageState,
CanvasLassoState,
CanvasRectState,
} from 'features/controlLayers/store/types';
@@ -25,6 +27,7 @@ export type AnyObjectRenderer =
| CanvasObjectEraserLine
| CanvasObjectEraserLineWithPressure
| CanvasObjectRect
| CanvasObjectLasso
| CanvasObjectImage
| CanvasObjectGradient;
/**
@@ -37,4 +40,5 @@ export type AnyObjectState =
| CanvasEraserLineWithPressureState
| CanvasImageState
| CanvasRectState
| CanvasLassoState
| CanvasGradientState;

View File

@@ -21,6 +21,7 @@ import {
entityBrushLineAdded,
entityEraserLineAdded,
entityGradientAdded,
entityLassoAdded,
entityMovedBy,
entityMovedTo,
entityRasterized,
@@ -43,6 +44,7 @@ import type {
EntityEraserLineAddedPayload,
EntityGradientAddedPayload,
EntityIdentifierPayload,
EntityLassoAddedPayload,
EntityMovedByPayload,
EntityMovedToPayload,
EntityRasterizedPayload,
@@ -175,6 +177,13 @@ export class CanvasStateApiModule extends CanvasModuleBase {
this.store.dispatch(entityRectAdded(arg));
};
/**
* Adds a lasso object to an entity, pushing state to redux.
*/
addLasso = (arg: EntityLassoAddedPayload) => {
this.store.dispatch(entityLassoAdded(arg));
};
/**
* Adds a gradient to an entity, pushing state to redux.
*/

View File

@@ -0,0 +1,566 @@
import { rgbaColorToString } from 'common/util/colorCodeTransformers';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
import { getPrefixedId, isDistanceMoreThanMin, offsetCoord } from 'features/controlLayers/konva/util';
import type { Coordinate } from 'features/controlLayers/store/types';
import { simplifyFlatNumbersArray } from 'features/controlLayers/util/simplify';
import Konva from 'konva';
import type { KonvaEventObject } from 'konva/lib/Node';
import type { Logger } from 'roarr';
type CanvasLassoToolModuleConfig = {
PREVIEW_STROKE_COLOR: string;
PREVIEW_FILL_COLOR: string;
PREVIEW_STROKE_WIDTH_PX: number;
START_POINT_RADIUS_PX: number;
START_POINT_STROKE_WIDTH_PX: number;
START_POINT_HOVER_RADIUS_DELTA_PX: number;
POLYGON_CLOSE_RADIUS_PX: number;
MIN_FREEHAND_POINT_DISTANCE_PX: number;
MAX_FREEHAND_SEGMENT_LENGTH_PX: number;
FREEHAND_SIMPLIFY_MIN_POINTS: number;
FREEHAND_SIMPLIFY_TOLERANCE: number;
};
const DEFAULT_CONFIG: CanvasLassoToolModuleConfig = {
PREVIEW_STROKE_COLOR: rgbaColorToString({ r: 90, g: 175, b: 255, a: 1 }),
PREVIEW_FILL_COLOR: rgbaColorToString({ r: 90, g: 175, b: 255, a: 0.2 }),
PREVIEW_STROKE_WIDTH_PX: 1.5,
START_POINT_RADIUS_PX: 4,
START_POINT_STROKE_WIDTH_PX: 2,
START_POINT_HOVER_RADIUS_DELTA_PX: 2,
POLYGON_CLOSE_RADIUS_PX: 10,
MIN_FREEHAND_POINT_DISTANCE_PX: 1,
MAX_FREEHAND_SEGMENT_LENGTH_PX: 2,
FREEHAND_SIMPLIFY_MIN_POINTS: 200,
FREEHAND_SIMPLIFY_TOLERANCE: 0.6,
};
export class CanvasLassoToolModule extends CanvasModuleBase {
readonly type = 'lasso_tool';
readonly id: string;
readonly path: string[];
readonly parent: CanvasToolModule;
readonly manager: CanvasManager;
readonly log: Logger;
config: CanvasLassoToolModuleConfig = DEFAULT_CONFIG;
private freehandPoints: Coordinate[] = [];
private polygonPoints: Coordinate[] = [];
private polygonPointer: Coordinate | null = null;
private isDrawingFreehand = false;
konva: {
group: Konva.Group;
fillShape: Konva.Line;
strokeShape: Konva.Line;
startPointIndicator: Konva.Circle;
};
constructor(parent: CanvasToolModule) {
super();
this.id = getPrefixedId(this.type);
this.parent = parent;
this.manager = this.parent.manager;
this.path = this.manager.buildPath(this);
this.log = this.manager.buildLogger(this);
this.log.debug('Creating module');
this.konva = {
group: new Konva.Group({ name: `${this.type}:group`, listening: false }),
fillShape: new Konva.Line({
name: `${this.type}:fill_shape`,
listening: false,
closed: true,
fill: this.config.PREVIEW_FILL_COLOR,
strokeEnabled: false,
visible: false,
perfectDrawEnabled: false,
}),
strokeShape: new Konva.Line({
name: `${this.type}:stroke_shape`,
listening: false,
closed: false,
stroke: this.config.PREVIEW_STROKE_COLOR,
strokeWidth: this.config.PREVIEW_STROKE_WIDTH_PX,
lineCap: 'round',
lineJoin: 'round',
fillEnabled: false,
visible: false,
perfectDrawEnabled: false,
}),
startPointIndicator: new Konva.Circle({
name: `${this.type}:start_point_indicator`,
listening: false,
fillEnabled: false,
stroke: this.config.PREVIEW_STROKE_COLOR,
visible: false,
perfectDrawEnabled: false,
}),
};
this.konva.group.add(this.konva.fillShape);
this.konva.group.add(this.konva.strokeShape);
this.konva.group.add(this.konva.startPointIndicator);
}
syncCursorStyle = () => {
if (!this.parent.getCanDraw()) {
this.manager.stage.setCursor('not-allowed');
return;
}
this.manager.stage.setCursor('crosshair');
};
render = () => {
const tool = this.parent.$tool.get();
const isTemporaryViewSwitch = tool === 'view' && this.parent.$toolBuffer.get() === 'lasso';
if (tool !== 'lasso' && !isTemporaryViewSwitch) {
this.hidePreview();
return;
}
if (tool === 'lasso') {
this.syncCursorStyle();
}
this.syncPreview();
};
onToolChanged = () => {
const tool = this.parent.$tool.get();
const isTemporaryViewSwitch = tool === 'view' && this.parent.$toolBuffer.get() === 'lasso';
if (tool !== 'lasso' && !isTemporaryViewSwitch) {
this.reset();
}
};
hasActiveSession = (): boolean => {
return this.isDrawingFreehand || this.freehandPoints.length > 0 || this.polygonPoints.length > 0;
};
onStagePointerDown = (e: KonvaEventObject<PointerEvent>) => {
const cursorPos = this.parent.$cursorPos.get();
if (!cursorPos) {
return;
}
const lassoMode = this.manager.stateApi.getSettings().lassoMode;
const point = cursorPos.relative;
// Keep middle click for pan and right click for context menu.
if (e.evt.button !== 0) {
return;
}
if (lassoMode === 'freehand') {
if (!this.parent.$isPrimaryPointerDown.get()) {
return;
}
this.polygonPoints = [];
this.polygonPointer = null;
this.freehandPoints = [point];
this.isDrawingFreehand = true;
this.syncPreview();
return;
}
this.freehandPoints = [];
this.isDrawingFreehand = false;
if (this.polygonPoints.length === 0) {
this.polygonPoints = [point];
this.polygonPointer = point;
this.syncPreview();
return;
}
const startPoint = this.polygonPoints[0];
if (!startPoint) {
return;
}
if (
this.polygonPoints.length >= 3 &&
Math.hypot(point.x - startPoint.x, point.y - startPoint.y) <= this.getPolygonCloseRadius()
) {
this.commitContour(this.polygonPoints);
this.reset();
return;
}
const snappedPoint = this.getPolygonPoint(point, e.evt.shiftKey);
this.polygonPoints = [...this.polygonPoints, snappedPoint];
this.polygonPointer = snappedPoint;
this.syncPreview();
};
onStagePointerMove = (_e: KonvaEventObject<PointerEvent>) => {
this.handlePointerMove(_e.evt.shiftKey);
};
onWindowPointerMove = (e: PointerEvent) => {
this.handlePointerMove(e.shiftKey);
};
onStagePointerUp = (_e: KonvaEventObject<PointerEvent>) => {
const lassoMode = this.manager.stateApi.getSettings().lassoMode;
if (lassoMode !== 'freehand' || !this.isDrawingFreehand) {
return;
}
this.commitContour(this.freehandPoints, true);
this.reset();
};
onWindowPointerUp = () => {
const lassoMode = this.manager.stateApi.getSettings().lassoMode;
if (lassoMode !== 'freehand' || !this.isDrawingFreehand) {
return;
}
this.commitContour(this.freehandPoints, true);
this.reset();
};
reset = () => {
this.freehandPoints = [];
this.polygonPoints = [];
this.polygonPointer = null;
this.isDrawingFreehand = false;
this.hidePreview();
};
private handlePointerMove = (shouldSnap: boolean) => {
const cursorPos = this.parent.$cursorPos.get();
if (!cursorPos) {
return;
}
const lassoMode = this.manager.stateApi.getSettings().lassoMode;
const point = cursorPos.relative;
if (lassoMode === 'freehand') {
if (!this.isDrawingFreehand || !this.parent.$isPrimaryPointerDown.get()) {
return;
}
const minDistance = this.manager.stage.unscale(this.config.MIN_FREEHAND_POINT_DISTANCE_PX);
const lastPoint = this.freehandPoints.at(-1) ?? null;
if (!isDistanceMoreThanMin(point, lastPoint, minDistance)) {
return;
}
this.appendFreehandPoint(point);
this.syncPreview();
return;
}
if (this.polygonPoints.length > 0) {
this.polygonPointer = this.getPolygonPoint(point, shouldSnap);
this.syncPreview();
}
};
private appendFreehandPoint = (point: Coordinate) => {
const lastPoint = this.freehandPoints.at(-1) ?? null;
if (!lastPoint) {
this.freehandPoints.push(point);
return;
}
const maxSegmentLength = this.manager.stage.unscale(this.config.MAX_FREEHAND_SEGMENT_LENGTH_PX);
const dx = point.x - lastPoint.x;
const dy = point.y - lastPoint.y;
const distance = Math.hypot(dx, dy);
if (distance <= maxSegmentLength) {
this.freehandPoints.push(point);
return;
}
const steps = Math.ceil(distance / maxSegmentLength);
for (let i = 1; i <= steps; i++) {
const t = i / steps;
this.freehandPoints.push({
x: lastPoint.x + dx * t,
y: lastPoint.y + dy * t,
});
}
};
private hidePreview = () => {
this.konva.strokeShape.visible(false);
this.konva.fillShape.visible(false);
this.konva.startPointIndicator.visible(false);
};
private syncPreview = () => {
const lassoMode = this.manager.stateApi.getSettings().lassoMode;
const stageScale = this.manager.stage.getScale();
const strokeWidth = this.config.PREVIEW_STROKE_WIDTH_PX / stageScale;
let points: Coordinate[] = [];
if (lassoMode === 'freehand') {
points = this.freehandPoints;
} else {
points = [...this.polygonPoints];
if (this.polygonPointer) {
points.push(this.polygonPointer);
}
}
if (points.length < 1) {
this.hidePreview();
return;
}
const flat = points.flatMap((point) => [point.x, point.y]);
this.konva.strokeShape.setAttrs({
points: flat,
strokeWidth,
visible: true,
});
if (points.length >= 3) {
this.konva.fillShape.setAttrs({
points: flat,
visible: true,
});
} else {
this.konva.fillShape.visible(false);
}
if (lassoMode === 'polygon' && this.polygonPoints.length > 0) {
const startPoint = this.polygonPoints[0];
if (startPoint) {
const isHoveringStartPoint = this.getIsHoveringStartPoint(startPoint);
const baseRadius = this.manager.stage.unscale(this.config.START_POINT_RADIUS_PX);
this.konva.startPointIndicator.setAttrs({
x: startPoint.x,
y: startPoint.y,
radius:
baseRadius +
(isHoveringStartPoint ? this.manager.stage.unscale(this.config.START_POINT_HOVER_RADIUS_DELTA_PX) : 0),
strokeWidth: this.manager.stage.unscale(this.config.START_POINT_STROKE_WIDTH_PX),
visible: true,
});
}
} else {
this.konva.startPointIndicator.visible(false);
}
};
private getPolygonCloseRadius = (): number => {
return this.manager.stage.unscale(this.config.POLYGON_CLOSE_RADIUS_PX);
};
private getIsHoveringStartPoint = (startPoint: Coordinate): boolean => {
if (this.polygonPoints.length < 3) {
return false;
}
const pointerPoint = this.parent.$cursorPos.get()?.relative;
if (!pointerPoint) {
return false;
}
return Math.hypot(pointerPoint.x - startPoint.x, pointerPoint.y - startPoint.y) <= this.getPolygonCloseRadius();
};
private getPolygonPoint = (point: Coordinate, shouldSnap: boolean): Coordinate => {
if (!shouldSnap) {
return point;
}
const lastPoint = this.polygonPoints.at(-1);
if (!lastPoint) {
return point;
}
const dx = point.x - lastPoint.x;
const dy = point.y - lastPoint.y;
const distance = Math.hypot(dx, dy);
if (distance === 0) {
return point;
}
const SNAP_ANGLE = Math.PI / 4;
const angle = Math.atan2(dy, dx);
const snappedAngle = Math.round(angle / SNAP_ANGLE) * SNAP_ANGLE;
const snappedPoint = {
x: lastPoint.x + Math.cos(snappedAngle) * distance,
y: lastPoint.y + Math.sin(snappedAngle) * distance,
};
return this.alignPointToStart(snappedPoint);
};
private alignPointToStart = (point: Coordinate): Coordinate => {
if (this.polygonPoints.length < 2) {
return point;
}
const startPoint = this.polygonPoints[0];
if (!startPoint) {
return point;
}
const alignThreshold = this.getPolygonCloseRadius();
const deltaX = Math.abs(point.x - startPoint.x);
const deltaY = Math.abs(point.y - startPoint.y);
const canAlignX = deltaX <= alignThreshold;
const canAlignY = deltaY <= alignThreshold;
if (!canAlignX && !canAlignY) {
return point;
}
if (canAlignX && canAlignY) {
if (deltaX <= deltaY) {
return { x: startPoint.x, y: point.y };
}
return { x: point.x, y: startPoint.y };
}
if (canAlignX) {
return { x: startPoint.x, y: point.y };
}
return { x: point.x, y: startPoint.y };
};
private closeContour = (points: Coordinate[]): Coordinate[] => {
if (points.length === 0) {
return [];
}
const start = points[0];
const end = points.at(-1);
if (!start || !end) {
return points;
}
if (start.x === end.x && start.y === end.y) {
return points;
}
return [...points, start];
};
private commitContour = (points: Coordinate[], simplifyFreehand: boolean = false) => {
const contourPoints = simplifyFreehand ? this.simplifyFreehandContour(points) : points;
if (contourPoints.length < 3) {
return;
}
const closedPoints = this.closeContour(contourPoints);
if (closedPoints.length < 4) {
return;
}
let targetMaskId = this.getActiveInpaintMaskId();
if (!targetMaskId) {
this.manager.stateApi.addInpaintMask({ isSelected: true });
targetMaskId = this.getActiveInpaintMaskId();
}
if (!targetMaskId) {
return;
}
const targetMaskState = this.manager.stateApi
.getInpaintMasksState()
.entities.find((entity) => entity.id === targetMaskId);
if (!targetMaskState) {
return;
}
const normalizedPoints = closedPoints.flatMap((point) => {
const normalizedPoint = offsetCoord(point, targetMaskState.position);
return [normalizedPoint.x, normalizedPoint.y];
});
this.manager.stateApi.addLasso({
entityIdentifier: { type: 'inpaint_mask', id: targetMaskId },
lasso: {
id: getPrefixedId('lasso'),
type: 'lasso',
points: normalizedPoints,
compositeOperation:
this.manager.stateApi.$ctrlKey.get() || this.manager.stateApi.$metaKey.get()
? 'destination-out'
: 'source-over',
},
});
};
private simplifyFreehandContour = (points: Coordinate[]): Coordinate[] => {
if (points.length < this.config.FREEHAND_SIMPLIFY_MIN_POINTS) {
return points;
}
const flatPoints = points.flatMap((point) => [point.x, point.y]);
const simplifiedFlatPoints = simplifyFlatNumbersArray(flatPoints, {
tolerance: this.config.FREEHAND_SIMPLIFY_TOLERANCE,
highestQuality: true,
});
if (simplifiedFlatPoints.length < 6) {
return points;
}
const simplifiedPoints = this.flatNumbersToCoords(simplifiedFlatPoints);
if (simplifiedPoints.length < 3) {
return points;
}
return simplifiedPoints;
};
private flatNumbersToCoords = (points: number[]): Coordinate[] => {
const coords: Coordinate[] = [];
for (let i = 0; i < points.length; i += 2) {
const x = points[i];
const y = points[i + 1];
if (x === undefined || y === undefined) {
continue;
}
coords.push({ x, y });
}
return coords;
};
private getActiveInpaintMaskId = (): string | null => {
const canvasState = this.manager.stateApi.getCanvasState();
const selectedEntityIdentifier = canvasState.selectedEntityIdentifier;
if (selectedEntityIdentifier?.type === 'inpaint_mask') {
const selectedMask = canvasState.inpaintMasks.entities.find(
(entity) => entity.id === selectedEntityIdentifier.id
);
if (selectedMask?.isEnabled) {
return selectedMask.id;
}
// If the selected mask is disabled, commit to a new mask instead.
return null;
}
const inpaintMasks = canvasState.inpaintMasks.entities;
const activeMask = [...inpaintMasks].reverse().find((entity) => entity.isEnabled);
return activeMask?.id ?? null;
};
repr = () => {
return {
id: this.id,
type: this.type,
path: this.path,
freehandPoints: this.freehandPoints,
polygonPoints: this.polygonPoints,
polygonPointer: this.polygonPointer,
isDrawingFreehand: this.isDrawingFreehand,
};
};
}

View File

@@ -5,6 +5,7 @@ import { CanvasBrushToolModule } from 'features/controlLayers/konva/CanvasTool/C
import { CanvasColorPickerToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasColorPickerToolModule';
import { CanvasEraserToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasEraserToolModule';
import { CanvasGradientToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasGradientToolModule';
import { CanvasLassoToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasLassoToolModule';
import { CanvasMoveToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasMoveToolModule';
import { CanvasRectToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasRectToolModule';
import { CanvasTextToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasTextToolModule';
@@ -38,6 +39,7 @@ Konva.dragButtons = [0];
const KEY_ESCAPE = 'Escape';
const KEY_SPACE = ' ';
const KEY_ALT = 'Alt';
const CODE_SPACE = 'Space';
type CanvasToolModuleConfig = {
BRUSH_SPACING_TARGET_SCALE: number;
@@ -62,6 +64,7 @@ export class CanvasToolModule extends CanvasModuleBase {
brush: CanvasBrushToolModule;
eraser: CanvasEraserToolModule;
rect: CanvasRectToolModule;
lasso: CanvasLassoToolModule;
gradient: CanvasGradientToolModule;
colorPicker: CanvasColorPickerToolModule;
bbox: CanvasBboxToolModule;
@@ -121,6 +124,7 @@ export class CanvasToolModule extends CanvasModuleBase {
brush: new CanvasBrushToolModule(this),
eraser: new CanvasEraserToolModule(this),
rect: new CanvasRectToolModule(this),
lasso: new CanvasLassoToolModule(this),
gradient: new CanvasGradientToolModule(this),
colorPicker: new CanvasColorPickerToolModule(this),
bbox: new CanvasBboxToolModule(this),
@@ -139,15 +143,26 @@ export class CanvasToolModule extends CanvasModuleBase {
this.konva.group.add(this.tools.colorPicker.konva.group);
this.konva.group.add(this.tools.text.konva.group);
this.konva.group.add(this.tools.bbox.konva.group);
this.konva.group.add(this.tools.lasso.konva.group);
this.subscriptions.add(this.manager.stage.$stageAttrs.listen(this.render));
this.subscriptions.add(this.manager.$isBusy.listen(this.render));
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectCanvasSettingsSlice, this.render));
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectCanvasSlice, this.render));
this.subscriptions.add(
this.$tool.listen(() => {
// On tool switch, reset mouse state
this.manager.tool.$isPrimaryPointerDown.set(false);
this.$tool.listen((tool, previousTool) => {
// Preserve pointer state during temporary view switching so lasso sessions can freeze/resume on space.
const shouldPreservePointerState =
this.$toolBuffer.get() === 'lasso' &&
this.tools.lasso.hasActiveSession() &&
((previousTool === 'lasso' && tool === 'view') || (previousTool === 'view' && tool === 'lasso'));
if (!shouldPreservePointerState) {
// On tool switch, reset mouse state
this.manager.tool.$isPrimaryPointerDown.set(false);
}
this.tools.lasso.onToolChanged();
void this.tools.text.onToolChanged();
this.render();
})
@@ -189,6 +204,8 @@ export class CanvasToolModule extends CanvasModuleBase {
this.tools.colorPicker.syncCursorStyle();
} else if (tool === 'text') {
this.tools.text.syncCursorStyle();
} else if (tool === 'lasso') {
this.tools.lasso.syncCursorStyle();
} else if (selectedEntityAdapter) {
if (selectedEntityAdapter.$isDisabled.get()) {
stage.setCursor('not-allowed');
@@ -222,6 +239,7 @@ export class CanvasToolModule extends CanvasModuleBase {
this.tools.colorPicker.render();
this.tools.text.render();
this.tools.bbox.render();
this.tools.lasso.render();
};
syncCursorPositions = () => {
@@ -235,6 +253,19 @@ export class CanvasToolModule extends CanvasModuleBase {
this.$cursorPos.set({ relative, absolute });
};
syncCursorPositionsFromWindowEvent = (e: PointerEvent): boolean => {
this.konva.stage.setPointersPositions(e);
const relative = this.konva.stage.getRelativePointerPosition();
const absolute = this.konva.stage.getPointerPosition();
if (!relative || !absolute) {
return false;
}
this.$cursorPos.set({ relative, absolute });
return true;
};
getClip = (
entity: CanvasRegionalGuidanceState | CanvasControlLayerState | CanvasRasterLayerState | CanvasInpaintMaskState
) => {
@@ -274,6 +305,7 @@ export class CanvasToolModule extends CanvasModuleBase {
window.addEventListener('keydown', this.onKeyDown);
window.addEventListener('keyup', this.onKeyUp);
window.addEventListener('pointermove', this.onWindowPointerMove);
window.addEventListener('pointerup', this.onWindowPointerUp);
window.addEventListener('blur', this.onWindowBlur);
@@ -289,6 +321,7 @@ export class CanvasToolModule extends CanvasModuleBase {
window.removeEventListener('keydown', this.onKeyDown);
window.removeEventListener('keyup', this.onKeyUp);
window.removeEventListener('pointermove', this.onWindowPointerMove);
window.removeEventListener('pointerup', this.onWindowPointerUp);
window.removeEventListener('blur', this.onWindowBlur);
};
@@ -316,6 +349,18 @@ export class CanvasToolModule extends CanvasModuleBase {
return true;
}
if (tool === 'lasso') {
if (this.manager.$isBusy.get()) {
return false;
}
if (this.manager.stage.getIsDragging()) {
return false;
}
return true;
}
if (this.manager.stateApi.getRenderedEntityCount() === 0) {
return false;
}
@@ -407,6 +452,8 @@ export class CanvasToolModule extends CanvasModuleBase {
await this.tools.eraser.onStagePointerDown(e);
} else if (tool === 'rect') {
await this.tools.rect.onStagePointerDown(e);
} else if (tool === 'lasso') {
await this.tools.lasso.onStagePointerDown(e);
} else if (tool === 'gradient') {
await this.tools.gradient.onStagePointerDown(e);
} else if (tool === 'text') {
@@ -441,6 +488,8 @@ export class CanvasToolModule extends CanvasModuleBase {
this.tools.eraser.onStagePointerUp(e);
} else if (tool === 'rect') {
this.tools.rect.onStagePointerUp(e);
} else if (tool === 'lasso') {
void this.tools.lasso.onStagePointerUp(e);
} else if (tool === 'gradient') {
this.tools.gradient.onStagePointerUp(e);
}
@@ -476,6 +525,8 @@ export class CanvasToolModule extends CanvasModuleBase {
await this.tools.eraser.onStagePointerMove(e);
} else if (tool === 'rect') {
await this.tools.rect.onStagePointerMove(e);
} else if (tool === 'lasso') {
await this.tools.lasso.onStagePointerMove(e);
} else if (tool === 'gradient') {
await this.tools.gradient.onStagePointerMove(e);
} else if (tool === 'text') {
@@ -560,6 +611,7 @@ export class CanvasToolModule extends CanvasModuleBase {
onWindowPointerUp = (_: PointerEvent) => {
try {
this.$isPrimaryPointerDown.set(false);
void this.tools.lasso.onWindowPointerUp();
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (selectedEntity && selectedEntity.bufferRenderer.hasBuffer() && !this.manager.$isBusy.get()) {
@@ -570,6 +622,41 @@ export class CanvasToolModule extends CanvasModuleBase {
}
};
onWindowPointerMove = (e: PointerEvent) => {
const target = e.target;
if (target instanceof Node && this.manager.stage.container.contains(target)) {
return;
}
if (this.$tool.get() !== 'lasso') {
return;
}
if (!this.getCanDraw()) {
return;
}
if (!this.$isPrimaryPointerDown.get()) {
return;
}
if (!this.tools.lasso.hasActiveSession()) {
return;
}
try {
this.$lastPointerType.set(e.pointerType);
if (!this.syncCursorPositionsFromWindowEvent(e)) {
return;
}
this.tools.lasso.onWindowPointerMove(e);
} finally {
this.render();
}
};
/**
* We want to reset any "quick-switch" tool selection on window blur. Fixes an issue where you alt-tab out of the app
* and the color picker tool is still active when you come back.
@@ -579,6 +666,7 @@ export class CanvasToolModule extends CanvasModuleBase {
};
onKeyDown = (e: KeyboardEvent) => {
const isSpaceKey = e.key === KEY_SPACE || e.code === CODE_SPACE;
if (e.target instanceof HTMLInputElement || e.target instanceof HTMLTextAreaElement) {
return;
}
@@ -600,6 +688,9 @@ export class CanvasToolModule extends CanvasModuleBase {
if (e.key === KEY_ESCAPE) {
// Cancel shape drawing on escape
e.preventDefault();
if (this.$tool.get() === 'lasso') {
this.tools.lasso.reset();
}
const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter();
if (
selectedEntity &&
@@ -612,19 +703,27 @@ export class CanvasToolModule extends CanvasModuleBase {
return;
}
if (e.key === KEY_SPACE) {
if (isSpaceKey) {
// Select the view tool on space key down
e.preventDefault();
this.$toolBuffer.set(this.$tool.get());
this.$tool.set('view');
e.stopPropagation();
const currentTool = this.$tool.get();
this.$toolBuffer.set(currentTool);
this.manager.stateApi.$spaceKey.set(true);
this.$cursorPos.set(null);
this.$tool.set('view');
if (currentTool === 'lasso' && this.tools.lasso.hasActiveSession() && this.$isPrimaryPointerDown.get()) {
// Start panning immediately if user is already drawing with freehand lasso.
this.manager.stage.startDragging();
} else {
this.$cursorPos.set(null);
}
return;
}
if (e.key === KEY_ALT) {
// Select the color picker on alt key down
e.preventDefault();
e.stopPropagation();
this.$toolBuffer.set(this.$tool.get());
this.$tool.set('colorPicker');
}
@@ -644,9 +743,10 @@ export class CanvasToolModule extends CanvasModuleBase {
return;
}
if (e.key === KEY_SPACE) {
if (e.key === KEY_SPACE || e.code === CODE_SPACE) {
// Revert the tool to the previous tool on space key up
e.preventDefault();
e.stopPropagation();
this.revertToolBuffer();
this.manager.stateApi.$spaceKey.set(false);
return;
@@ -655,6 +755,7 @@ export class CanvasToolModule extends CanvasModuleBase {
if (e.key === KEY_ALT) {
// Revert the tool to the previous tool on alt key up
e.preventDefault();
e.stopPropagation();
this.revertToolBuffer();
return;
}
@@ -684,6 +785,7 @@ export class CanvasToolModule extends CanvasModuleBase {
eraser: this.tools.eraser.repr(),
colorPicker: this.tools.colorPicker.repr(),
rect: this.tools.rect.repr(),
lasso: this.tools.lasso.repr(),
gradient: this.tools.gradient.repr(),
bbox: this.tools.bbox.repr(),
view: this.tools.view.repr(),

View File

@@ -13,6 +13,7 @@ const zTransformSmoothingMode = z.enum(['bilinear', 'bicubic', 'hamming', 'lancz
export type TransformSmoothingMode = z.infer<typeof zTransformSmoothingMode>;
const zGradientType = z.enum(['linear', 'radial']);
const zLassoMode = z.enum(['freehand', 'polygon']);
const zCanvasSettingsState = z.object({
/**
@@ -118,6 +119,10 @@ const zCanvasSettingsState = z.object({
* Whether the gradient tool clips to the drag gesture.
*/
gradientClipEnabled: z.boolean().default(true),
/**
* The lasso tool mode.
*/
lassoMode: zLassoMode.default('freehand'),
});
type CanvasSettingsState = z.infer<typeof zCanvasSettingsState>;
@@ -148,6 +153,7 @@ const getInitialState = (): CanvasSettingsState => ({
transformSmoothingMode: 'bicubic',
gradientType: 'linear',
gradientClipEnabled: true,
lassoMode: 'freehand',
});
const slice = createSlice({
@@ -245,6 +251,9 @@ const slice = createSlice({
settingsGradientClipToggled: (state) => {
state.gradientClipEnabled = !state.gradientClipEnabled;
},
settingsLassoModeChanged: (state, action: PayloadAction<CanvasSettingsState['lassoMode']>) => {
state.lassoMode = action.payload;
},
},
});
@@ -276,6 +285,7 @@ export const {
settingsFillColorPickerPinnedSet,
settingsGradientTypeChanged,
settingsGradientClipToggled,
settingsLassoModeChanged,
} = slice.actions;
export const canvasSettingsSliceConfig: SliceConfig<typeof slice> = {
@@ -317,3 +327,4 @@ export const selectTransformSmoothingEnabled = createCanvasSettingsSelector(
export const selectTransformSmoothingMode = createCanvasSettingsSelector((settings) => settings.transformSmoothingMode);
export const selectGradientType = createCanvasSettingsSelector((settings) => settings.gradientType);
export const selectGradientClipEnabled = createCanvasSettingsSelector((settings) => settings.gradientClipEnabled);
export const selectLassoMode = createCanvasSettingsSelector((settings) => settings.lassoMode);

View File

@@ -67,6 +67,7 @@ import type {
EntityEraserLineAddedPayload,
EntityGradientAddedPayload,
EntityIdentifierPayload,
EntityLassoAddedPayload,
EntityMovedToPayload,
EntityRasterizedPayload,
EntityRectAddedPayload,
@@ -100,6 +101,12 @@ import {
makeDefaultRasterLayerAdjustments,
} from './util';
const resetInpaintMasksHiddenIfEmpty = (state: CanvasState) => {
if (state.inpaintMasks.entities.length === 0) {
state.inpaintMasks.isHidden = false;
}
};
const slice = createSlice({
name: 'canvas',
initialState: getInitialCanvasState(),
@@ -1062,6 +1069,7 @@ const slice = createSlice({
(entity) => !mergedEntitiesToDelete.includes(entity.id)
);
}
resetInpaintMasksHiddenIfEmpty(state);
const entityIdentifier = getEntityIdentifier(entityState);
if (isSelected || mergedEntitiesToDelete.length > 0) {
@@ -1133,6 +1141,7 @@ const slice = createSlice({
if (replace) {
// Remove the inpaint mask
state.inpaintMasks.entities = state.inpaintMasks.entities.filter((layer) => layer.id !== entityIdentifier.id);
resetInpaintMasksHiddenIfEmpty(state);
}
// Add the new regional guidance
@@ -1559,6 +1568,17 @@ const slice = createSlice({
// re-render it (reference equality check). I don't like this behaviour.
entity.objects.push({ ...rect });
},
entityLassoAdded: (state, action: PayloadAction<EntityLassoAddedPayload>) => {
const { entityIdentifier, lasso } = action.payload;
const entity = selectEntity(state, entityIdentifier);
if (!entity) {
return;
}
// TODO(psyche): If we add the object without splatting, the renderer will see it as the same object and not
// re-render it (reference equality check). I don't like this behaviour.
entity.objects.push({ ...lasso });
},
entityGradientAdded: (state, action: PayloadAction<EntityGradientAddedPayload>) => {
const { entityIdentifier, gradient } = action.payload;
const entity = selectEntity(state, entityIdentifier);
@@ -1601,6 +1621,7 @@ const slice = createSlice({
break;
}
resetInpaintMasksHiddenIfEmpty(state);
state.selectedEntityIdentifier = selectedEntityIdentifier;
},
entityArrangedForwardOne: (state, action: PayloadAction<EntityIdentifierPayload>) => {
@@ -1689,6 +1710,7 @@ const slice = createSlice({
break;
case 'inpaint_mask':
state.inpaintMasks.isHidden = !state.inpaintMasks.isHidden;
resetInpaintMasksHiddenIfEmpty(state);
break;
case 'regional_guidance':
state.regionalGuidance.isHidden = !state.regionalGuidance.isHidden;
@@ -1697,13 +1719,16 @@ const slice = createSlice({
},
allNonRasterLayersIsHiddenToggled: (state) => {
const hasVisibleNonRasterLayers =
!state.controlLayers.isHidden || !state.inpaintMasks.isHidden || !state.regionalGuidance.isHidden;
(state.controlLayers.entities.length > 0 && !state.controlLayers.isHidden) ||
(state.inpaintMasks.entities.length > 0 && !state.inpaintMasks.isHidden) ||
(state.regionalGuidance.entities.length > 0 && !state.regionalGuidance.isHidden);
const shouldHide = hasVisibleNonRasterLayers;
state.controlLayers.isHidden = shouldHide;
state.inpaintMasks.isHidden = shouldHide;
state.regionalGuidance.isHidden = shouldHide;
resetInpaintMasksHiddenIfEmpty(state);
},
allEntitiesDeleted: (state) => {
// Deleting all entities is equivalent to resetting the state for each entity type
@@ -1719,6 +1744,37 @@ const slice = createSlice({
state.inpaintMasks.entities = inpaintMasks;
state.rasterLayers.entities = rasterLayers;
state.regionalGuidance.entities = regionalGuidance;
resetInpaintMasksHiddenIfEmpty(state);
return state;
},
canvasProjectRecalled: (
state,
action: PayloadAction<{
rasterLayers: CanvasRasterLayerState[];
controlLayers: CanvasControlLayerState[];
inpaintMasks: CanvasInpaintMaskState[];
regionalGuidance: CanvasRegionalGuidanceState[];
bbox: CanvasState['bbox'];
selectedEntityIdentifier: CanvasState['selectedEntityIdentifier'];
bookmarkedEntityIdentifier: CanvasState['bookmarkedEntityIdentifier'];
}>
) => {
const {
rasterLayers,
controlLayers,
inpaintMasks,
regionalGuidance,
bbox,
selectedEntityIdentifier,
bookmarkedEntityIdentifier,
} = action.payload;
state.rasterLayers.entities = rasterLayers;
state.controlLayers.entities = controlLayers;
state.inpaintMasks.entities = inpaintMasks;
state.regionalGuidance.entities = regionalGuidance;
state.bbox = bbox;
state.selectedEntityIdentifier = selectedEntityIdentifier;
state.bookmarkedEntityIdentifier = bookmarkedEntityIdentifier;
return state;
},
canvasUndo: () => {},
@@ -1802,6 +1858,7 @@ const resetState = (state: CanvasState) => {
export const {
canvasMetadataRecalled,
canvasProjectRecalled,
canvasUndo,
canvasRedo,
canvasClearHistory,
@@ -1821,6 +1878,7 @@ export const {
entityBrushLineAdded,
entityEraserLineAdded,
entityRectAdded,
entityLassoAdded,
entityGradientAdded,
// Raster layer adjustments
rasterLayerAdjustmentsSet,
@@ -1947,7 +2005,13 @@ export const canvasSliceConfig: SliceConfig<typeof slice> = {
},
};
const doNotGroupMatcher = isAnyOf(entityBrushLineAdded, entityEraserLineAdded, entityRectAdded, entityGradientAdded);
const doNotGroupMatcher = isAnyOf(
entityBrushLineAdded,
entityEraserLineAdded,
entityRectAdded,
entityLassoAdded,
entityGradientAdded
);
// Store rapid actions of the same type at most once every x time.
// See: https://github.com/omnidan/redux-undo/blob/master/examples/throttled-drag/util/undoFilter.js

View File

@@ -519,6 +519,9 @@ const slice = createSlice({
state.dimensions.aspectRatio.isLocked = true;
},
paramsReset: (state) => resetState(state),
paramsRecalled: (_state, action: PayloadAction<ParamsState>) => {
return action.payload;
},
},
extraReducers(builder) {
// Reset params state on logout to prevent user data leakage when switching users
@@ -666,6 +669,7 @@ export const {
openaiInputFidelityChanged,
geminiTemperatureChanged,
geminiThinkingLevelChanged,
paramsRecalled,
animaVaeModelSelected,
animaQwen3EncoderModelSelected,
animaT5EncoderModelSelected,

View File

@@ -362,5 +362,9 @@ export const selectCanvasMetadata = createSelector(
* This is used to determine the state of the toggle button that shows/hides all non-raster layers.
*/
export const selectNonRasterLayersIsHidden = createSelector(selectCanvasSlice, (canvas) => {
return canvas.controlLayers.isHidden && canvas.inpaintMasks.isHidden && canvas.regionalGuidance.isHidden;
const areControlLayersEffectivelyHidden = canvas.controlLayers.entities.length === 0 || canvas.controlLayers.isHidden;
const areInpaintMasksEffectivelyHidden = canvas.inpaintMasks.entities.length === 0 || canvas.inpaintMasks.isHidden;
const areRegionalGuidanceEffectivelyHidden =
canvas.regionalGuidance.entities.length === 0 || canvas.regionalGuidance.isHidden;
return areControlLayersEffectivelyHidden && areInpaintMasksEffectivelyHidden && areRegionalGuidanceEffectivelyHidden;
});

View File

@@ -105,7 +105,7 @@ const zIPMethodV2 = z.enum(['full', 'style', 'composition', 'style_strong', 'sty
export type IPMethodV2 = z.infer<typeof zIPMethodV2>;
export const isIPMethodV2 = (v: unknown): v is IPMethodV2 => zIPMethodV2.safeParse(v).success;
const _zTool = z.enum(['brush', 'eraser', 'move', 'rect', 'gradient', 'view', 'bbox', 'colorPicker', 'text']);
const _zTool = z.enum(['brush', 'eraser', 'move', 'rect', 'lasso', 'gradient', 'view', 'bbox', 'colorPicker', 'text']);
export type Tool = z.infer<typeof _zTool>;
const zPoints = z.array(z.number()).refine((points) => points.length % 2 === 0, {
@@ -260,6 +260,20 @@ const zCanvasRectState = z.object({
});
export type CanvasRectState = z.infer<typeof zCanvasRectState>;
const zCanvasLassoCompositeOperation = z.enum(['source-over', 'destination-out']);
const zCanvasLassoState = z.object({
id: zId,
type: z.literal('lasso'),
/**
* Points in the format [x1, y1, x2, y2, ...].
* The lasso tool always commits a closed contour.
*/
points: zPoints,
compositeOperation: zCanvasLassoCompositeOperation.default('source-over'),
});
export type CanvasLassoState = z.infer<typeof zCanvasLassoState>;
// Gradient state includes clip metadata so the tool can optionally clip to drag gesture.
const zCanvasLinearGradientState = z.object({
id: zId,
@@ -309,6 +323,7 @@ const zCanvasObjectState = z.union([
zCanvasBrushLineState,
zCanvasEraserLineState,
zCanvasRectState,
zCanvasLassoState,
zCanvasBrushLineWithPressureState,
zCanvasEraserLineWithPressureState,
zCanvasGradientState,
@@ -992,6 +1007,7 @@ export type EntityEraserLineAddedPayload = EntityIdentifierPayload<{
eraserLine: CanvasEraserLineState | CanvasEraserLineWithPressureState;
}>;
export type EntityRectAddedPayload = EntityIdentifierPayload<{ rect: CanvasRectState }>;
export type EntityLassoAddedPayload = EntityIdentifierPayload<{ lasso: CanvasLassoState }>;
export type EntityGradientAddedPayload = EntityIdentifierPayload<{ gradient: CanvasGradientState }>;
export type EntityRasterizedPayload = EntityIdentifierPayload<{
imageObject: CanvasImageState;

View File

@@ -0,0 +1,287 @@
import { deepClone } from 'common/util/deepClone';
import type {
CanvasControlLayerState,
CanvasInpaintMaskState,
CanvasObjectState,
CanvasRasterLayerState,
CanvasRegionalGuidanceState,
CanvasState,
CroppableImageWithDims,
ImageWithDims,
RefImageState,
} from 'features/controlLayers/store/types';
import { getImageDTOSafe } from 'services/api/endpoints/images';
import { z } from 'zod';
export const CANVAS_PROJECT_VERSION = 1;
export const CANVAS_PROJECT_EXTENSION = '.invk';
// #region Manifest
const zCanvasProjectManifest = z.object({
version: z.literal(CANVAS_PROJECT_VERSION),
appVersion: z.string(),
createdAt: z.string(),
name: z.string(),
});
export type CanvasProjectManifest = z.infer<typeof zCanvasProjectManifest>;
export const parseManifest = (data: unknown): CanvasProjectManifest => {
return zCanvasProjectManifest.parse(data);
};
// #endregion
// #region Canvas Project State
export type CanvasProjectState = {
rasterLayers: CanvasRasterLayerState[];
controlLayers: CanvasControlLayerState[];
inpaintMasks: CanvasInpaintMaskState[];
regionalGuidance: CanvasRegionalGuidanceState[];
bbox: CanvasState['bbox'];
selectedEntityIdentifier: CanvasState['selectedEntityIdentifier'];
bookmarkedEntityIdentifier: CanvasState['bookmarkedEntityIdentifier'];
};
// #endregion
// #region Image Name Collection
/**
* Collects image_name values from a CroppableImageWithDims (used by ref images).
*/
const collectFromCroppableImage = (image: CroppableImageWithDims | null, names: Set<string>): void => {
if (!image) {
return;
}
names.add(image.original.image.image_name);
if (image.crop?.image) {
names.add(image.crop.image.image_name);
}
};
/**
* Collects image_name values from an ImageWithDims (used by regional guidance ref images).
*/
const collectFromImageWithDims = (image: ImageWithDims | null, names: Set<string>): void => {
if (!image) {
return;
}
names.add(image.image_name);
};
/**
* Collects image_name values from canvas objects (brush lines, images, etc.).
*/
const collectFromObjects = (objects: CanvasObjectState[], names: Set<string>): void => {
for (const obj of objects) {
if (obj.type === 'image' && 'image_name' in obj.image) {
names.add(obj.image.image_name);
}
}
};
/**
* Walks the entire canvas state + ref images and returns a deduplicated set of all image_name references.
*/
export const collectImageNames = (canvasState: CanvasProjectState, refImages: RefImageState[]): Set<string> => {
const names = new Set<string>();
// Raster layers
for (const layer of canvasState.rasterLayers) {
collectFromObjects(layer.objects, names);
}
// Control layers
for (const layer of canvasState.controlLayers) {
collectFromObjects(layer.objects, names);
}
// Inpaint masks
for (const mask of canvasState.inpaintMasks) {
collectFromObjects(mask.objects, names);
}
// Regional guidance
for (const rg of canvasState.regionalGuidance) {
collectFromObjects(rg.objects, names);
for (const refImage of rg.referenceImages) {
if (refImage.config.type === 'ip_adapter' || refImage.config.type === 'flux_redux') {
collectFromImageWithDims(refImage.config.image, names);
}
}
}
// Global reference images
for (const refImage of refImages) {
collectFromCroppableImage(refImage.config.image, names);
}
return names;
};
// #endregion
// #region Image Name Remapping
/**
* Remaps image_name values in a CroppableImageWithDims.
*/
/**
* Remaps image_name values in a CroppableImageWithDims in-place.
* Caller is responsible for cloning beforehand.
*/
const remapCroppableImage = (image: CroppableImageWithDims | null, mapping: Map<string, string>): void => {
if (!image) {
return;
}
const newOriginalName = mapping.get(image.original.image.image_name);
if (newOriginalName) {
image.original.image.image_name = newOriginalName;
}
if (image.crop?.image) {
const newCropName = mapping.get(image.crop.image.image_name);
if (newCropName) {
image.crop.image.image_name = newCropName;
}
}
};
/**
* Remaps image_name in an ImageWithDims.
*/
const remapImageWithDims = (image: ImageWithDims | null, mapping: Map<string, string>): ImageWithDims | null => {
if (!image) {
return null;
}
const result = deepClone(image);
const newName = mapping.get(result.image_name);
if (newName) {
result.image_name = newName;
}
return result;
};
/**
* Remaps image_name values in canvas objects.
*/
const remapObjects = (objects: CanvasObjectState[], mapping: Map<string, string>): CanvasObjectState[] => {
return objects.map((obj) => {
if (obj.type === 'image' && 'image_name' in obj.image) {
const newName = mapping.get(obj.image.image_name);
if (newName) {
return { ...obj, image: { ...obj.image, image_name: newName } };
}
}
return obj;
});
};
/**
* Deep-clones canvas state and remaps all image_name values using the provided mapping.
* Only images present in the mapping are changed (images that already existed on the server are skipped).
*/
export const remapCanvasState = (canvasState: CanvasProjectState, mapping: Map<string, string>): CanvasProjectState => {
if (mapping.size === 0) {
return canvasState;
}
const result = deepClone(canvasState);
for (const layer of result.rasterLayers) {
layer.objects = remapObjects(layer.objects, mapping);
}
for (const layer of result.controlLayers) {
layer.objects = remapObjects(layer.objects, mapping);
}
for (const mask of result.inpaintMasks) {
mask.objects = remapObjects(mask.objects, mapping);
}
for (const rg of result.regionalGuidance) {
rg.objects = remapObjects(rg.objects, mapping);
for (const refImage of rg.referenceImages) {
if (refImage.config.type === 'ip_adapter' || refImage.config.type === 'flux_redux') {
refImage.config.image = remapImageWithDims(refImage.config.image, mapping);
}
}
}
return result;
};
/**
* Deep-clones ref images and remaps all image_name values using the provided mapping.
*/
export const remapRefImages = (refImages: RefImageState[], mapping: Map<string, string>): RefImageState[] => {
if (mapping.size === 0) {
return refImages;
}
return refImages.map((refImage) => {
const result = deepClone(refImage);
remapCroppableImage(result.config.image, mapping);
return result;
});
};
// #endregion
// #region Concurrency
const MAX_CONCURRENT_REQUESTS = 5;
/**
* Processes an array of async tasks with a concurrency limit.
*/
export const processWithConcurrencyLimit = async <T>(
items: T[],
fn: (item: T) => Promise<void>,
limit: number = MAX_CONCURRENT_REQUESTS
): Promise<void> => {
let index = 0;
const next = async (): Promise<void> => {
while (index < items.length) {
const currentIndex = index++;
await fn(items[currentIndex]!);
}
};
const workers = Array.from({ length: Math.min(limit, items.length) }, () => next());
await Promise.all(workers);
};
// #endregion
// #region Image Existence Check
/**
* Checks which images already exist on the backend server.
* Returns sets of existing and missing image names.
*/
export const checkExistingImages = async (
imageNames: Set<string>
): Promise<{ existing: Set<string>; missing: Set<string> }> => {
const existing = new Set<string>();
const missing = new Set<string>();
await processWithConcurrencyLimit(Array.from(imageNames), async (imageName) => {
const dto = await getImageDTOSafe(imageName);
if (dto) {
existing.add(imageName);
} else {
missing.add(imageName);
}
});
return { existing, missing };
};
// #endregion

View File

@@ -1,6 +1,7 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import {
Box,
Button,
Flex,
Icon,
Input,
@@ -17,6 +18,7 @@ import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { CommandEmpty, CommandItem, CommandList, CommandRoot } from 'cmdk';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { capitalize } from 'es-toolkit';
import { memoize } from 'es-toolkit/compat';
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
import {
@@ -33,16 +35,24 @@ import { findUnoccupiedPosition } from 'features/nodes/store/util/findUnoccupied
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
import { selectShouldGroupNodesByCategory } from 'features/nodes/store/workflowSettingsSlice';
import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { toast } from 'features/toast/toast';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { computed } from 'nanostores';
import type { ChangeEvent } from 'react';
import type { ChangeEvent, Dispatch, SetStateAction } from 'react';
import { memo, useCallback, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCircuitryBold, PiFlaskBold, PiHammerBold, PiLightningFill } from 'react-icons/pi';
import {
PiCaretDownBold,
PiCaretRightBold,
PiCircuitryBold,
PiFlaskBold,
PiHammerBold,
PiLightningFill,
} from 'react-icons/pi';
import type { S } from 'services/api/types';
import { objectEntries } from 'tsafe';
import { useDebounce } from 'use-debounce';
@@ -171,15 +181,36 @@ export const AddNodeCmdk = memo(() => {
const onClose = useCallback(() => {
close();
setSearchTerm('');
setExpandedCategories(new Set());
$pendingConnection.set(null);
}, [close]);
const [expandedCategories, setExpandedCategories] = useState<Set<string>>(new Set());
const toggleCategory = useCallback((category: string) => {
setExpandedCategories((prev) => {
const next = new Set(prev);
if (next.has(category)) {
next.delete(category);
} else {
next.add(category);
}
return next;
});
}, []);
const onSelect = useCallback(
(value: string) => {
// Category headers have a special prefix
if (value.startsWith('__category__:')) {
const category = value.slice('__category__:'.length);
toggleCategory(category);
return;
}
addNode(value);
onClose();
},
[addNode, onClose]
[addNode, onClose, toggleCategory]
);
return (
@@ -204,7 +235,12 @@ export const AddNodeCmdk = memo(() => {
/>
</CommandEmpty>
<CommandList>
<NodeCommandList searchTerm={debouncedSearchTerm} onSelect={onSelect} />
<NodeCommandList
searchTerm={debouncedSearchTerm}
onSelect={onSelect}
expandedCategories={expandedCategories}
setExpandedCategories={setExpandedCategories}
/>
</CommandList>
</ScrollableContent>
</Box>
@@ -230,6 +266,7 @@ type NodeCommandItemData = {
description: string;
classification: S['Classification'];
nodePack: string;
category: string;
};
/**
@@ -260,6 +297,7 @@ type FilterableItem = {
tags: string[];
classification: S['Classification'];
nodePack: string;
category: string;
};
const filter = memoize(
@@ -290,6 +328,10 @@ const filter = memoize(
return true;
}
if (item.category.includes(searchTerm) || regex.test(item.category)) {
return true;
}
for (const tag of item.tags) {
if (tag.includes(searchTerm) || regex.test(tag)) {
return true;
@@ -301,112 +343,253 @@ const filter = memoize(
(item: FilterableItem, searchTerm: string) => `${item.type}-${searchTerm}`
);
const NodeCommandList = memo(({ searchTerm, onSelect }: { searchTerm: string; onSelect: (value: string) => void }) => {
const { t } = useTranslation();
const templatesArray = useStore($templatesArray);
const pendingConnection = useStore($pendingConnection);
const currentImageFilterItem = useMemo<FilterableItem>(
() => ({
type: 'current_image',
title: t('nodes.currentImage'),
description: t('nodes.currentImageDescription'),
tags: ['progress', 'image', 'current'],
classification: 'stable',
nodePack: 'invokeai',
}),
[t]
);
const notesFilterItem = useMemo<FilterableItem>(
() => ({
type: 'notes',
title: t('nodes.notes'),
description: t('nodes.notesDescription'),
tags: ['notes'],
classification: 'stable',
nodePack: 'invokeai',
}),
[t]
);
const categoryItemSx: SystemStyleObject = {
cursor: 'pointer',
userSelect: 'none',
'&[data-selected="true"]': {
bg: 'base.750',
},
};
const items = useMemo<NodeCommandItemData[]>(() => {
// If we have a connection in progress, we need to filter the node choices
const _items: NodeCommandItemData[] = [];
const NodeCommandItem = memo(
({
item,
onSelect,
isGrouped,
}: {
item: NodeCommandItemData;
onSelect: (value: string) => void;
isGrouped?: boolean;
}) => (
<CommandItem value={item.value} onSelect={onSelect} asChild>
<Flex role="button" flexDir="column" sx={cmdkItemSx} py={1} px={2} ps={isGrouped ? 6 : 2} borderRadius="base">
<Flex alignItems="center" gap={2}>
{item.classification === 'beta' && <Icon boxSize={4} color="invokeYellow.300" as={PiHammerBold} />}
{item.classification === 'prototype' && <Icon boxSize={4} color="invokeRed.300" as={PiFlaskBold} />}
{item.classification === 'internal' && <Icon boxSize={4} color="invokePurple.300" as={PiCircuitryBold} />}
{item.classification === 'special' && <Icon boxSize={4} color="invokeGreen.300" as={PiLightningFill} />}
<Text fontWeight="semibold">{item.label}</Text>
<Spacer />
<Text variant="subtext" fontWeight="semibold">
{item.nodePack}
</Text>
</Flex>
{item.description && <Text color="base.200">{item.description}</Text>}
</Flex>
</CommandItem>
)
);
if (!pendingConnection) {
for (const template of templatesArray) {
if (filter(template, searchTerm)) {
_items.push({
label: template.title,
value: template.type,
description: template.description,
classification: template.classification,
nodePack: template.nodePack,
});
NodeCommandItem.displayName = 'NodeCommandItem';
const NodeCommandList = memo(
({
searchTerm,
onSelect,
expandedCategories,
setExpandedCategories,
}: {
searchTerm: string;
onSelect: (value: string) => void;
expandedCategories: Set<string>;
setExpandedCategories: Dispatch<SetStateAction<Set<string>>>;
}) => {
const { t } = useTranslation();
const templatesArray = useStore($templatesArray);
const pendingConnection = useStore($pendingConnection);
const shouldGroupNodesByCategory = useAppSelector(selectShouldGroupNodesByCategory);
const currentImageFilterItem = useMemo<FilterableItem>(
() => ({
type: 'current_image',
title: t('nodes.currentImage'),
description: t('nodes.currentImageDescription'),
tags: ['progress', 'image', 'current'],
classification: 'stable',
nodePack: 'invokeai',
category: 'image',
}),
[t]
);
const notesFilterItem = useMemo<FilterableItem>(
() => ({
type: 'notes',
title: t('nodes.notes'),
description: t('nodes.notesDescription'),
tags: ['notes'],
classification: 'stable',
nodePack: 'invokeai',
category: 'other',
}),
[t]
);
const items = useMemo<NodeCommandItemData[]>(() => {
// If we have a connection in progress, we need to filter the node choices
const _items: NodeCommandItemData[] = [];
if (!pendingConnection) {
for (const template of templatesArray) {
if (filter(template, searchTerm)) {
_items.push({
label: template.title,
value: template.type,
description: template.description,
classification: template.classification,
nodePack: template.nodePack,
category: template.category,
});
}
}
}
for (const item of [currentImageFilterItem, notesFilterItem]) {
if (filter(item, searchTerm)) {
_items.push({
label: item.title,
value: item.type,
description: item.description,
classification: item.classification,
nodePack: item.nodePack,
});
for (const item of [currentImageFilterItem, notesFilterItem]) {
if (filter(item, searchTerm)) {
_items.push({
label: item.title,
value: item.type,
description: item.description,
classification: item.classification,
nodePack: item.nodePack,
category: item.category,
});
}
}
}
} else {
for (const template of templatesArray) {
if (filter(template, searchTerm)) {
const candidateFields = pendingConnection.handleType === 'source' ? template.inputs : template.outputs;
} else {
for (const template of templatesArray) {
if (filter(template, searchTerm)) {
const candidateFields = pendingConnection.handleType === 'source' ? template.inputs : template.outputs;
for (const [_fieldName, fieldTemplate] of objectEntries(candidateFields)) {
const sourceType =
pendingConnection.handleType === 'source' ? pendingConnection.fieldTemplate.type : fieldTemplate.type;
const targetType =
pendingConnection.handleType === 'target' ? pendingConnection.fieldTemplate.type : fieldTemplate.type;
for (const [_fieldName, fieldTemplate] of objectEntries(candidateFields)) {
const sourceType =
pendingConnection.handleType === 'source' ? pendingConnection.fieldTemplate.type : fieldTemplate.type;
const targetType =
pendingConnection.handleType === 'target' ? pendingConnection.fieldTemplate.type : fieldTemplate.type;
if (validateConnectionTypes(sourceType, targetType)) {
_items.push({
label: template.title,
value: template.type,
description: template.description,
classification: template.classification,
nodePack: template.nodePack,
});
break;
if (validateConnectionTypes(sourceType, targetType)) {
_items.push({
label: template.title,
value: template.type,
description: template.description,
classification: template.classification,
nodePack: template.nodePack,
category: template.category,
});
break;
}
}
}
}
}
// Sort exact title matches to the top when searching
if (searchTerm) {
const lowerSearch = searchTerm.toLowerCase();
_items.sort((a, b) => {
const aExact = a.label.toLowerCase() === lowerSearch;
const bExact = b.label.toLowerCase() === lowerSearch;
if (aExact && !bExact) {
return -1;
}
if (!aExact && bExact) {
return 1;
}
return 0;
});
}
return _items;
}, [pendingConnection, templatesArray, searchTerm, currentImageFilterItem, notesFilterItem]);
const groupedItems = useMemo(() => {
const groups: Record<string, NodeCommandItemData[]> = {};
for (const item of items) {
const cat = item.category;
if (!groups[cat]) {
groups[cat] = [];
}
groups[cat].push(item);
}
// Sort categories alphabetically, but put "other" last.
// When searching, prioritize categories that contain an exact title match.
const lowerSearch = searchTerm.toLowerCase();
return Object.entries(groups).sort(([a, aItems], [b, bItems]) => {
if (searchTerm) {
const aHasExact = aItems.some((item) => item.label.toLowerCase() === lowerSearch);
const bHasExact = bItems.some((item) => item.label.toLowerCase() === lowerSearch);
if (aHasExact && !bHasExact) {
return -1;
}
if (!aHasExact && bHasExact) {
return 1;
}
}
if (a === 'other') {
return 1;
}
if (b === 'other') {
return -1;
}
return a.localeCompare(b);
});
}, [items, searchTerm]);
// When searching, auto-expand all categories; when not searching, use manual state
const isSearching = searchTerm.length > 0;
const expandAll = useCallback(() => {
setExpandedCategories(new Set(groupedItems.map(([cat]) => cat)));
}, [groupedItems, setExpandedCategories]);
const collapseAll = useCallback(() => {
setExpandedCategories(new Set());
}, [setExpandedCategories]);
if (!shouldGroupNodesByCategory) {
return (
<>
{items.map((item) => (
<NodeCommandItem key={item.value} item={item} onSelect={onSelect} />
))}
</>
);
}
return _items;
}, [pendingConnection, templatesArray, searchTerm, currentImageFilterItem, notesFilterItem]);
return (
<>
{items.map((item) => (
<CommandItem key={item.value} value={item.value} onSelect={onSelect} asChild>
<Flex role="button" flexDir="column" sx={cmdkItemSx} py={1} px={2} borderRadius="base">
<Flex alignItems="center" gap={2}>
{item.classification === 'beta' && <Icon boxSize={4} color="invokeYellow.300" as={PiHammerBold} />}
{item.classification === 'prototype' && <Icon boxSize={4} color="invokeRed.300" as={PiFlaskBold} />}
{item.classification === 'internal' && <Icon boxSize={4} color="invokePurple.300" as={PiCircuitryBold} />}
{item.classification === 'special' && <Icon boxSize={4} color="invokeGreen.300" as={PiLightningFill} />}
<Text fontWeight="semibold">{item.label}</Text>
<Spacer />
<Text variant="subtext" fontWeight="semibold">
{item.nodePack}
</Text>
</Flex>
{item.description && <Text color="base.200">{item.description}</Text>}
return (
<>
{!isSearching && (
<Flex gap={1} px={2} pb={1}>
<Button size="sm" variant="ghost" onClick={expandAll}>
{t('common.expandAll')}
</Button>
<Button size="sm" variant="ghost" onClick={collapseAll}>
{t('common.collapseAll')}
</Button>
</Flex>
</CommandItem>
))}
</>
);
});
)}
{groupedItems.map(([category, categoryItems]) => {
const isExpanded = isSearching || expandedCategories.has(category);
return (
<Box key={category}>
<CommandItem value={`__category__:${category}`} onSelect={onSelect} asChild>
<Flex role="button" alignItems="center" gap={2} px={2} py={1.5} borderRadius="base" sx={categoryItemSx}>
<Icon boxSize={3} as={isExpanded ? PiCaretDownBold : PiCaretRightBold} color="base.400" />
<Text fontSize="sm" fontWeight="bold" color="base.400">
{capitalize(category)}
</Text>
<Text fontSize="xs" color="base.500">
({categoryItems.length})
</Text>
</Flex>
</CommandItem>
{isExpanded &&
categoryItems.map((item) => (
<NodeCommandItem key={item.value} item={item} onSelect={onSelect} isGrouped />
))}
</Box>
);
})}
</>
);
}
);
NodeCommandList.displayName = 'CommandListItems';

View File

@@ -24,11 +24,13 @@ import {
selectSelectionMode,
selectShouldAnimateEdges,
selectShouldColorEdges,
selectShouldGroupNodesByCategory,
selectShouldShouldValidateGraph,
selectShouldShowEdgeLabels,
selectShouldSnapToGrid,
shouldAnimateEdgesChanged,
shouldColorEdgesChanged,
shouldGroupNodesByCategoryChanged,
shouldShowEdgeLabelsChanged,
shouldSnapToGridChanged,
shouldValidateGraphChanged,
@@ -50,6 +52,7 @@ const WorkflowEditorSettings = () => {
const shouldAnimateEdges = useAppSelector(selectShouldAnimateEdges);
const shouldShowEdgeLabels = useAppSelector(selectShouldShowEdgeLabels);
const shouldValidateGraph = useAppSelector(selectShouldShouldValidateGraph);
const shouldGroupNodesByCategory = useAppSelector(selectShouldGroupNodesByCategory);
const handleChangeShouldValidate = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
@@ -93,6 +96,13 @@ const WorkflowEditorSettings = () => {
[dispatch]
);
const handleChangeShouldGroupNodesByCategory = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(shouldGroupNodesByCategoryChanged(e.target.checked));
},
[dispatch]
);
const { t } = useTranslation();
return (
@@ -145,6 +155,14 @@ const WorkflowEditorSettings = () => {
<FormHelperText>{t('nodes.showEdgeLabelsHelp')}</FormHelperText>
</FormControl>
<Divider />
<FormControl>
<Flex w="full">
<FormLabel>{t('nodes.groupNodesByCategory')}</FormLabel>
<Switch isChecked={shouldGroupNodesByCategory} onChange={handleChangeShouldGroupNodesByCategory} />
</Flex>
<FormHelperText>{t('nodes.groupNodesByCategoryHelp')}</FormHelperText>
</FormControl>
<Divider />
<Heading size="sm" pt={4}>
{t('common.advanced')}
</Heading>

View File

@@ -70,6 +70,7 @@ export const add: InvocationTemplate = {
useCache: true,
nodePack: 'invokeai',
classification: 'stable',
category: 'math',
};
export const sub: InvocationTemplate = {
@@ -128,6 +129,7 @@ export const sub: InvocationTemplate = {
useCache: true,
nodePack: 'invokeai',
classification: 'stable',
category: 'math',
};
export const collect: InvocationTemplate = {
@@ -189,6 +191,7 @@ export const collect: InvocationTemplate = {
useCache: true,
nodePack: 'invokeai',
classification: 'stable',
category: 'collections',
};
const scheduler: InvocationTemplate = {
@@ -245,6 +248,7 @@ const scheduler: InvocationTemplate = {
useCache: true,
nodePack: 'invokeai',
classification: 'stable',
category: 'other',
};
export const main_model_loader: InvocationTemplate = {
@@ -313,6 +317,7 @@ export const main_model_loader: InvocationTemplate = {
useCache: true,
nodePack: 'invokeai',
classification: 'stable',
category: 'model',
};
export const img_resize: InvocationTemplate = {
@@ -457,6 +462,7 @@ export const img_resize: InvocationTemplate = {
useCache: true,
nodePack: 'invokeai',
classification: 'stable',
category: 'image',
};
const iterate: InvocationTemplate = {
@@ -526,6 +532,7 @@ const iterate: InvocationTemplate = {
useCache: true,
nodePack: 'invokeai',
classification: 'stable',
category: 'collections',
};
export const templates: Templates = {
@@ -713,7 +720,6 @@ export const schema = {
required: ['type', 'id'],
title: 'Scheduler',
description: 'Selects a scheduler.',
category: 'latents',
classification: 'stable',
node_pack: 'invokeai',
tags: ['scheduler'],
@@ -1199,6 +1205,7 @@ export const schema = {
title: 'CollectInvocation',
node_pack: 'invokeai',
description: 'Collects values into a collection',
category: 'collections',
classification: 'stable',
version: '1.1.0',
output: {
@@ -1558,6 +1565,7 @@ export const schema = {
required: ['type', 'id'],
title: 'IterateInvocation',
description: 'Iterates over a list of items',
category: 'collections',
classification: 'stable',
node_pack: 'invokeai',
version: '1.1.0',

View File

@@ -16,6 +16,7 @@ const ifTemplate: InvocationTemplate = {
type: 'if',
version: '1.0.0',
tags: [],
category: 'math',
description: 'Selects between two inputs based on a boolean condition',
outputType: 'if_output',
inputs: {
@@ -93,6 +94,7 @@ const floatOutputTemplate: InvocationTemplate = {
type: 'float_output',
version: '1.0.0',
tags: [],
category: 'primitives',
description: 'Outputs a float',
outputType: 'float_output',
inputs: {},
@@ -121,6 +123,7 @@ const integerCollectionOutputTemplate: InvocationTemplate = {
type: 'integer_collection_output',
version: '1.0.0',
tags: [],
category: 'primitives',
description: 'Outputs an integer collection',
outputType: 'integer_collection_output',
inputs: {},

View File

@@ -30,6 +30,7 @@ const zWorkflowSettingsState = z.object({
shouldColorEdges: z.boolean(),
shouldShowEdgeLabels: z.boolean(),
selectionMode: zSelectionMode,
shouldGroupNodesByCategory: z.boolean(),
});
export type WorkflowSettingsState = z.infer<typeof zWorkflowSettingsState>;
@@ -49,6 +50,7 @@ const getInitialState = (): WorkflowSettingsState => ({
shouldShowEdgeLabels: false,
nodeOpacity: 1,
selectionMode: 'partial',
shouldGroupNodesByCategory: true,
});
const slice = createSlice({
@@ -94,6 +96,9 @@ const slice = createSlice({
selectionModeChanged: (state, action: PayloadAction<boolean>) => {
state.selectionMode = action.payload ? 'full' : 'partial';
},
shouldGroupNodesByCategoryChanged: (state, action: PayloadAction<boolean>) => {
state.shouldGroupNodesByCategory = action.payload;
},
},
});
@@ -111,6 +116,7 @@ export const {
shouldValidateGraphChanged,
nodeOpacityChanged,
selectionModeChanged,
shouldGroupNodesByCategoryChanged,
} = slice.actions;
export const workflowSettingsSliceConfig: SliceConfig<typeof slice> = {
@@ -123,6 +129,9 @@ export const workflowSettingsSliceConfig: SliceConfig<typeof slice> = {
if (!('_version' in state)) {
state._version = 1;
}
if (!('shouldGroupNodesByCategory' in state)) {
state.shouldGroupNodesByCategory = false;
}
return zWorkflowSettingsState.parse(state);
},
},
@@ -145,3 +154,4 @@ export const selectNodeSpacing = createWorkflowSettingsSelector((s) => s.nodeSpa
export const selectLayerSpacing = createWorkflowSettingsSelector((s) => s.layerSpacing);
export const selectLayoutDirection = createWorkflowSettingsSelector((s) => s.layoutDirection);
export const selectNodeAlignment = createWorkflowSettingsSelector((s) => s.nodeAlignment);
export const selectShouldGroupNodesByCategory = createWorkflowSettingsSelector((s) => s.shouldGroupNodesByCategory);

View File

@@ -18,6 +18,7 @@ const _zInvocationTemplate = z.object({
useCache: z.boolean(),
nodePack: z.string().min(1).default('invokeai'),
classification: zClassification,
category: z.string().default('other'),
});
export type InvocationTemplate = z.infer<typeof _zInvocationTemplate>;
// #endregion

View File

@@ -113,6 +113,7 @@ export const parseSchema = (
const version = schema.version;
const nodePack = schema.node_pack;
const classification = schema.classification;
const category = schema.category ?? 'other';
const inputs = reduce(
schema.properties,
@@ -260,6 +261,7 @@ export const parseSchema = (
useCache,
nodePack,
classification,
category,
};
Object.assign(invocationsAccumulator, { [type]: invocation });

View File

@@ -4,6 +4,7 @@ import { usePersistedTextAreaSize } from 'common/hooks/usePersistedTextareaSize'
import { negativePromptChanged, selectNegativePromptWithFallback } from 'features/controlLayers/store/paramsSlice';
import { PromptLabel } from 'features/parameters/components/Prompts/PromptLabel';
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
import { PromptResizeHandle } from 'features/parameters/components/Prompts/PromptResizeHandle';
import { ViewModePrompt } from 'features/parameters/components/Prompts/ViewModePrompt';
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
import { PromptPopover } from 'features/prompt/PromptPopover';
@@ -22,6 +23,8 @@ const persistOptions: Parameters<typeof usePersistedTextAreaSize>[2] = {
trackHeight: true,
};
const NEGATIVE_PROMPT_MIN_HEIGHT = 28;
export const ParamNegativePrompt = memo(() => {
const dispatch = useAppDispatch();
const prompt = useAppSelector(selectNegativePromptWithFallback);
@@ -70,14 +73,16 @@ export const ParamNegativePrompt = memo(() => {
onChange={onChange}
onKeyDown={onKeyDown}
variant="darkFilled"
minH={28}
borderTopWidth={24} // This prevents the prompt from being hidden behind the header
paddingInlineEnd={10}
paddingInlineStart={3}
paddingTop={0}
paddingBottom={3}
resize="none"
minH={NEGATIVE_PROMPT_MIN_HEIGHT}
fontFamily="mono"
fontSize="0.82rem"
sx={{ '&::-webkit-resizer': { display: 'none' } }}
/>
<PromptOverlayButtonWrapper>
<AddPromptTriggerButton isOpen={isOpen} onOpen={onOpen} />
@@ -90,6 +95,7 @@ export const ParamNegativePrompt = memo(() => {
label={`${t('parameters.negativePromptPlaceholder')} (${t('stylePresets.preview')})`}
/>
)}
<PromptResizeHandle textareaRef={textareaRef} minHeight={NEGATIVE_PROMPT_MIN_HEIGHT} />
</Box>
</PromptPopover>
);

View File

@@ -11,6 +11,7 @@ import { ShowDynamicPromptsPreviewButton } from 'features/dynamicPrompts/compone
import { NegativePromptToggleButton } from 'features/parameters/components/Core/NegativePromptToggleButton';
import { PromptLabel } from 'features/parameters/components/Prompts/PromptLabel';
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
import { PromptResizeHandle } from 'features/parameters/components/Prompts/PromptResizeHandle';
import { ViewModePrompt } from 'features/parameters/components/Prompts/ViewModePrompt';
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
import { PromptPopover } from 'features/prompt/PromptPopover';
@@ -35,6 +36,8 @@ const persistOptions: Parameters<typeof usePersistedTextAreaSize>[2] = {
initialHeight: 120,
};
const POSITIVE_PROMPT_MIN_HEIGHT = 32;
const usePromptHistory = () => {
const store = useAppStore();
const history = useAppSelector(selectPositivePromptHistory);
@@ -215,10 +218,11 @@ export const ParamPositivePrompt = memo(() => {
paddingInlineStart={3}
paddingTop={0}
paddingBottom={3}
resize="vertical"
minH={32}
resize="none"
minH={POSITIVE_PROMPT_MIN_HEIGHT}
fontFamily="mono"
fontSize="0.82rem"
sx={{ '&::-webkit-resizer': { display: 'none' } }}
/>
<PromptOverlayButtonWrapper>
<Flex flexDir="column" gap={2} justifyContent="flex-start" alignItems="center">
@@ -236,6 +240,7 @@ export const ParamPositivePrompt = memo(() => {
label={`${t('parameters.positivePromptPlaceholder')} (${t('stylePresets.preview')})`}
/>
)}
<PromptResizeHandle textareaRef={textareaRef} minHeight={POSITIVE_PROMPT_MIN_HEIGHT} />
</Box>
</PromptPopover>
</Box>

View File

@@ -0,0 +1,135 @@
import { Box } from '@invoke-ai/ui-library';
import {
memo,
type PointerEvent as ReactPointerEvent,
type RefObject,
useCallback,
useEffect,
useRef,
useState,
} from 'react';
type PromptResizeHandleProps = {
textareaRef: RefObject<HTMLTextAreaElement>;
minHeight: number;
};
const PROMPT_RESIZE_HANDLE_HEIGHT_PX = 8;
export const PromptResizeHandle = memo(({ textareaRef, minHeight }: PromptResizeHandleProps) => {
const activePointerIdRef = useRef<number | null>(null);
const startHeightRef = useRef(0);
const startYRef = useRef(0);
const previousCursorRef = useRef('');
const previousUserSelectRef = useRef('');
const [isResizing, setIsResizing] = useState(false);
const stopResize = useCallback(() => {
if (activePointerIdRef.current === null) {
return;
}
activePointerIdRef.current = null;
setIsResizing(false);
document.body.style.cursor = previousCursorRef.current;
document.body.style.userSelect = previousUserSelectRef.current;
}, []);
useEffect(() => stopResize, [stopResize]);
const onPointerDown = useCallback(
(e: ReactPointerEvent<HTMLDivElement>) => {
if (e.button !== 0) {
return;
}
const textarea = textareaRef.current;
if (!textarea) {
return;
}
activePointerIdRef.current = e.pointerId;
startYRef.current = e.clientY;
startHeightRef.current = textarea.offsetHeight;
previousCursorRef.current = document.body.style.cursor;
previousUserSelectRef.current = document.body.style.userSelect;
document.body.style.cursor = 'ns-resize';
document.body.style.userSelect = 'none';
e.currentTarget.setPointerCapture(e.pointerId);
setIsResizing(true);
e.preventDefault();
},
[textareaRef]
);
const onPointerMove = useCallback(
(e: ReactPointerEvent<HTMLDivElement>) => {
if (activePointerIdRef.current !== e.pointerId) {
return;
}
const textarea = textareaRef.current;
if (!textarea) {
return;
}
const nextHeight = Math.max(minHeight, startHeightRef.current + e.clientY - startYRef.current);
textarea.style.height = `${nextHeight}px`;
e.preventDefault();
},
[minHeight, textareaRef]
);
const onPointerUp = useCallback(
(e: ReactPointerEvent<HTMLDivElement>) => {
if (activePointerIdRef.current !== e.pointerId) {
return;
}
if (e.currentTarget.hasPointerCapture(e.pointerId)) {
e.currentTarget.releasePointerCapture(e.pointerId);
}
stopResize();
},
[stopResize]
);
const onPointerCancel = useCallback(
(e: ReactPointerEvent<HTMLDivElement>) => {
if (activePointerIdRef.current !== e.pointerId) {
return;
}
stopResize();
},
[stopResize]
);
return (
<Box
aria-hidden
pos="absolute"
insetInlineStart={0}
insetInlineEnd={0}
insetBlockEnd={0}
h={`${PROMPT_RESIZE_HANDLE_HEIGHT_PX}px`}
borderBottomRadius="base"
bg={isResizing ? 'base.500' : 'base.700'}
cursor="ns-resize"
zIndex={1}
style={{ touchAction: 'none' }}
transitionProperty="background-color"
transitionDuration="normal"
_hover={{ bg: 'base.600' }}
onPointerDown={onPointerDown}
onPointerMove={onPointerMove}
onPointerUp={onPointerUp}
onPointerCancel={onPointerCancel}
onLostPointerCapture={stopResize}
/>
);
});
PromptResizeHandle.displayName = 'PromptResizeHandle';

View File

@@ -118,6 +118,7 @@ export const useHotkeyData = (): HotkeysData => {
addHotkey('canvas', 'selectEraserTool', ['e']);
addHotkey('canvas', 'selectMoveTool', ['v']);
addHotkey('canvas', 'selectRectTool', ['u']);
addHotkey('canvas', 'selectLassoTool', ['l']);
addHotkey('canvas', 'selectViewTool', ['h']);
addHotkey('canvas', 'selectColorPickerTool', ['i']);
addHotkey('canvas', 'setFillColorsToDefault', ['d']);

View File

@@ -11,6 +11,8 @@ import {
ModalFooter,
ModalHeader,
ModalOverlay,
NumberInput,
NumberInputField,
Switch,
Text,
} from '@invoke-ai/ui-library';
@@ -19,6 +21,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { buildUseBoolean } from 'common/hooks/useBoolean';
import { selectCurrentUser } from 'features/auth/store/authSlice';
import { selectShouldUseCPUNoise, shouldUseCpuNoiseChanged } from 'features/controlLayers/store/paramsSlice';
import { ExternalProviderStatusList } from 'features/system/components/SettingsModal/ExternalProviderStatusList';
import { useRefreshAfterResetModal } from 'features/system/components/SettingsModal/RefreshAfterResetModal';
@@ -49,15 +52,25 @@ import {
shouldUseNSFWCheckerChanged,
shouldUseWatermarkerChanged,
} from 'features/system/store/systemSlice';
import { toast } from 'features/toast/toast';
import { selectShouldShowProgressInViewer } from 'features/ui/store/uiSelectors';
import { setShouldShowProgressInViewer } from 'features/ui/store/uiSlice';
import { type ChangeEvent, cloneElement, memo, type ReactElement, useCallback, useEffect } from 'react';
import type { ChangeEvent, KeyboardEvent, ReactElement } from 'react';
import { cloneElement, memo, useCallback, useEffect, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetRuntimeConfigQuery, useUpdateRuntimeConfigMutation } from 'services/api/endpoints/appInfo';
import { SettingsLanguageSelect } from './SettingsLanguageSelect';
const [useSettingsModal] = buildUseBoolean(false);
const formatOptionalInteger = (value: number | null | undefined) => {
if (value === null || value === undefined) {
return '';
}
return String(value);
};
const SettingsModal = (props: { children: ReactElement }) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
@@ -72,6 +85,10 @@ const SettingsModal = (props: { children: ReactElement }) => {
const settingsModal = useSettingsModal();
const refreshModal = useRefreshAfterResetModal();
const currentUser = useAppSelector(selectCurrentUser);
const { data: runtimeConfig } = useGetRuntimeConfigQuery();
const [updateRuntimeConfig, { isLoading: isUpdatingRuntimeConfig }] = useUpdateRuntimeConfigMutation();
const pendingMaxQueueHistoryRef = useRef<number | null | undefined>(undefined);
const prefersNumericAttentionWeights = useAppSelector(selectSystemPrefersNumericAttentionWeights);
const shouldUseCpuNoise = useAppSelector(selectShouldUseCPUNoise);
@@ -85,6 +102,10 @@ const SettingsModal = (props: { children: ReactElement }) => {
const shouldHighlightFocusedRegions = useAppSelector(selectSystemShouldEnableHighlightFocusedRegions);
const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession);
const shouldShowInvocationProgressDetail = useAppSelector(selectSystemShouldShowInvocationProgressDetail);
const maxQueueHistory = runtimeConfig?.config.max_queue_history ?? null;
const canEditRuntimeConfig = runtimeConfig ? !runtimeConfig.config.multiuser || currentUser?.is_admin : false;
const [maxQueueHistoryInput, setMaxQueueHistoryInput] = useState(formatOptionalInteger(maxQueueHistory));
const onToggleConfirmOnNewSession = useCallback(() => {
dispatch(shouldConfirmOnNewSessionToggled());
}, [dispatch]);
@@ -96,11 +117,60 @@ const SettingsModal = (props: { children: ReactElement }) => {
}
}, [refetchIntermediatesCount, settingsModal.isTrue]);
useEffect(() => {
setMaxQueueHistoryInput(formatOptionalInteger(maxQueueHistory));
}, [maxQueueHistory]);
const commitMaxQueueHistory = useCallback(async () => {
if (!runtimeConfig || !canEditRuntimeConfig) {
return;
}
const trimmedValue = maxQueueHistoryInput.trim();
const parsedValue = trimmedValue === '' ? null : Number.parseInt(trimmedValue, 10);
if (parsedValue !== null && Number.isNaN(parsedValue)) {
setMaxQueueHistoryInput(formatOptionalInteger(maxQueueHistory));
return;
}
const normalizedValue = parsedValue === null ? null : Math.max(0, parsedValue);
const currentValue =
pendingMaxQueueHistoryRef.current === undefined ? maxQueueHistory : pendingMaxQueueHistoryRef.current;
if (normalizedValue === currentValue) {
setMaxQueueHistoryInput(formatOptionalInteger(currentValue));
return;
}
pendingMaxQueueHistoryRef.current = normalizedValue;
setMaxQueueHistoryInput(formatOptionalInteger(normalizedValue));
try {
await updateRuntimeConfig({ max_queue_history: normalizedValue }).unwrap();
} catch {
setMaxQueueHistoryInput(formatOptionalInteger(maxQueueHistory));
toast({
id: 'SETTINGS_MAX_QUEUE_HISTORY_SAVE_FAILED',
title: t('settings.maxQueueHistorySaveFailed'),
status: 'error',
});
} finally {
pendingMaxQueueHistoryRef.current = undefined;
}
}, [canEditRuntimeConfig, maxQueueHistory, maxQueueHistoryInput, runtimeConfig, t, updateRuntimeConfig]);
const handleCloseSettingsModal = useCallback(() => {
void commitMaxQueueHistory();
settingsModal.setFalse();
}, [commitMaxQueueHistory, settingsModal]);
const handleClickResetWebUI = useCallback(() => {
void commitMaxQueueHistory();
clearStorage();
settingsModal.setFalse();
refreshModal.setTrue();
}, [settingsModal, refreshModal]);
}, [commitMaxQueueHistory, refreshModal, settingsModal]);
const handleChangeShouldConfirmOnDelete = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
@@ -172,12 +242,30 @@ const SettingsModal = (props: { children: ReactElement }) => {
[dispatch]
);
const handleChangeMaxQueueHistory = useCallback((valueAsString: string) => {
setMaxQueueHistoryInput(valueAsString);
}, []);
const handleBlurMaxQueueHistory = useCallback(() => {
void commitMaxQueueHistory();
}, [commitMaxQueueHistory]);
const handleKeyDownMaxQueueHistory = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (e.key === 'Enter') {
void commitMaxQueueHistory();
e.currentTarget.blur();
}
},
[commitMaxQueueHistory]
);
return (
<>
{cloneElement(props.children, {
onClick: settingsModal.setTrue,
})}
<Modal isOpen={settingsModal.isTrue} onClose={settingsModal.setFalse} size="2xl" isCentered useInert={false}>
<Modal isOpen={settingsModal.isTrue} onClose={handleCloseSettingsModal} size="2xl" isCentered useInert={false}>
<ModalOverlay />
<ModalContent maxH="80vh" h="68rem">
<ModalHeader bg="none">{t('common.settingsLabel')}</ModalHeader>
@@ -206,6 +294,21 @@ const SettingsModal = (props: { children: ReactElement }) => {
<FormLabel>{t('settings.enableInvisibleWatermark')}</FormLabel>
<Switch isChecked={shouldUseWatermarker} onChange={handleChangeShouldUseWatermarker} />
</FormControl>
<FormControl>
<FormLabel>{t('settings.maxQueueHistory')}</FormLabel>
<NumberInput
min={0}
step={1}
value={maxQueueHistoryInput}
onChange={handleChangeMaxQueueHistory}
onBlur={handleBlurMaxQueueHistory}
clampValueOnBlur={false}
isDisabled={!runtimeConfig || !canEditRuntimeConfig || isUpdatingRuntimeConfig}
w="8rem"
>
<NumberInputField onKeyDown={handleKeyDownMaxQueueHistory} />
</NumberInput>
</FormControl>
</StickyScrollable>
<StickyScrollable title={t('settings.models')}>

View File

@@ -56,6 +56,26 @@ export const appInfoApi = api.injectEndpoints({
url: buildAppInfoUrl('runtime_config'),
method: 'GET',
}),
providesTags: ['AppConfig'],
}),
updateRuntimeConfig: build.mutation<
paths['/api/v1/app/runtime_config']['patch']['responses']['200']['content']['application/json'],
paths['/api/v1/app/runtime_config']['patch']['requestBody']['content']['application/json']
>({
query: (body) => ({
url: buildAppInfoUrl('runtime_config'),
method: 'PATCH',
body,
}),
async onQueryStarted(_, { dispatch, queryFulfilled }) {
try {
const { data } = await queryFulfilled;
dispatch(appInfoApi.util.upsertQueryData('getRuntimeConfig', undefined, data));
} catch {
// no-op
}
},
invalidatesTags: ['AppConfig'],
}),
getExternalProviderStatuses: build.query<ExternalProviderStatus[], void>({
query: () => ({
@@ -133,6 +153,7 @@ export const {
useGetExternalProviderConfigsQuery,
useSetExternalProviderConfigMutation,
useResetExternalProviderConfigMutation,
useUpdateRuntimeConfigMutation,
useClearInvocationCacheMutation,
useDisableInvocationCacheMutation,
useEnableInvocationCacheMutation,

View File

@@ -1586,7 +1586,8 @@ export type paths = {
delete?: never;
options?: never;
head?: never;
patch?: never;
/** Update Runtime Config */
patch: operations["update_runtime_config"];
trace?: never;
};
"/api/v1/app/external_providers/status": {
@@ -15745,7 +15746,8 @@ export type components = {
* force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
* pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
* max_queue_size: Maximum number of items in the session queue.
* clear_queue_on_startup: Empties session queue on startup.
* clear_queue_on_startup: Empties session queue on startup. If true, disables `max_queue_history`.
* max_queue_history: Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true.
* allow_nodes: List of nodes to allow. Omit to allow all.
* deny_nodes: List of nodes to deny. Omit to deny none.
* node_cache_size: How many cached nodes to keep in memory.
@@ -16077,10 +16079,15 @@ export type components = {
max_queue_size?: number;
/**
* Clear Queue On Startup
* @description Empties session queue on startup.
* @description Empties session queue on startup. If true, disables `max_queue_history`.
* @default false
*/
clear_queue_on_startup?: boolean;
/**
* Max Queue History
* @description Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true.
*/
max_queue_history?: number | null;
/**
* Allow Nodes
* @description List of nodes to allow. Omit to allow all.
@@ -29296,6 +29303,17 @@ export type components = {
*/
unstarred_images: string[];
};
/**
* UpdateAppGenerationSettingsRequest
* @description Writable generation-related app settings.
*/
UpdateAppGenerationSettingsRequest: {
/**
* Max Queue History
* @description Keep the last N completed, failed, and canceled queue items on startup. Set to 0 to prune all terminal items.
*/
max_queue_history?: number | null;
};
/**
* UserDTO
* @description User data transfer object.
@@ -34462,6 +34480,39 @@ export interface operations {
};
};
};
update_runtime_config: {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
requestBody: {
content: {
"application/json": components["schemas"]["UpdateAppGenerationSettingsRequest"];
};
};
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["InvokeAIAppConfigWithSetFields"];
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
get_external_provider_statuses: {
parameters: {
query?: never;