diff --git a/docs/contributing/CANVAS_PROJECTS/CANVAS_PROJECTS.md b/docs/contributing/CANVAS_PROJECTS/CANVAS_PROJECTS.md new file mode 100644 index 0000000000..a05ef29492 --- /dev/null +++ b/docs/contributing/CANVAS_PROJECTS/CANVAS_PROJECTS.md @@ -0,0 +1,205 @@ +# Canvas Projects — Technical Documentation + +## Overview + +Canvas Projects provide a save/load mechanism for the entire canvas state. The feature serializes all canvas entities, generation parameters, reference images, and their associated image files into a ZIP-based `.invk` file. On load, it restores the full state, handling image deduplication and re-uploading as needed. + +## File Format + +The `.invk` file is a standard ZIP archive with the following structure: + +``` +project.invk +├── manifest.json +├── canvas_state.json +├── params.json +├── ref_images.json +├── loras.json +└── images/ + ├── {image_name_1}.png + ├── {image_name_2}.png + └── ... +``` + +### manifest.json + +Schema version and metadata. Validated on load with Zod. + +```json +{ + "version": 1, + "appVersion": "5.12.0", + "createdAt": "2026-02-26T12:00:00.000Z", + "name": "My Canvas Project" +} +``` + +| Field | Type | Description | +|---|---|---| +| `version` | `number` | Schema version, currently `1`. Used for migration logic on load. | +| `appVersion` | `string` | InvokeAI version that created the file. Informational only. | +| `createdAt` | `string` | ISO 8601 timestamp. | +| `name` | `string` | User-provided project name. Also used as the download filename. | + +### canvas_state.json + +The serialized canvas entity tree. Type: `CanvasProjectState`. + +```typescript +type CanvasProjectState = { + rasterLayers: CanvasRasterLayerState[]; + controlLayers: CanvasControlLayerState[]; + inpaintMasks: CanvasInpaintMaskState[]; + regionalGuidance: CanvasRegionalGuidanceState[]; + bbox: CanvasState['bbox']; + selectedEntityIdentifier: CanvasState['selectedEntityIdentifier']; + bookmarkedEntityIdentifier: CanvasState['bookmarkedEntityIdentifier']; +}; +``` + +Each entity contains its full state including all canvas objects (brush lines, eraser lines, rect shapes, images). Image objects reference files by `image_name` which correspond to files in the `images/` folder. + +### params.json + +The complete generation parameters state (`ParamsState`). Optional on load (older files may not have it). This includes all fields from the params Redux slice: + +- Prompts (positive, negative, prompt history) +- Core generation settings (seed, steps, CFG scale, guidance, scheduler, iterations) +- Model selections (main model, VAE, FLUX VAE, T5 encoder, CLIP embed models, refiner, Z-Image models, Klein models) +- Dimensions (width, height, aspect ratio) +- Img2img strength +- Infill settings (method, tile size, patchmatch downscale, color) +- Canvas coherence settings (mode, edge size, min denoise) +- Refiner parameters (steps, CFG scale, scheduler, aesthetic scores, start) +- FLUX-specific settings (scheduler, DyPE preset/scale/exponent) +- Z-Image-specific settings (scheduler, seed variance) +- Upscale settings (scheduler, CFG scale) +- Seamless tiling, mask blur, CLIP skip, VAE precision, CPU noise, color compensation + +### ref_images.json + +Global reference image entities (`RefImageState[]`). These are IP-Adapter / FLUX Redux configs with `CroppableImageWithDims` containing both original and cropped image references. Optional on load. + +### loras.json + +Array of LoRA configurations (`LoRA[]`). Each entry contains: + +```typescript +type LoRA = { + id: string; + isEnabled: boolean; + model: ModelIdentifierField; + weight: number; +}; +``` + +Optional on load. Like models, LoRA identifiers are stored as-is — if a LoRA is not installed when loading, the entry is restored but may not be usable. + +### images/ + +All image files referenced anywhere in the state. Keyed by their original `image_name`. On save, each image is fetched from the backend via `GET /api/v1/images/i/{name}/full` and stored as-is. + +## Key Source Files + +| File | Purpose | +|---|---| +| `features/controlLayers/util/canvasProjectFile.ts` | Types, constants, image name collection, remapping, existence checking | +| `features/controlLayers/hooks/useCanvasProjectSave.ts` | Save hook — collects Redux state, fetches images, builds ZIP | +| `features/controlLayers/hooks/useCanvasProjectLoad.ts` | Load hook — parses ZIP, deduplicates images, dispatches state | +| `features/controlLayers/components/SaveCanvasProjectDialog.tsx` | Save name dialog + `useSaveCanvasProjectWithDialog` hook | +| `features/controlLayers/components/LoadCanvasProjectConfirmationAlertDialog.tsx` | Load confirmation dialog + `useLoadCanvasProjectWithDialog` hook | +| `features/controlLayers/components/Toolbar/CanvasToolbarProjectMenuButton.tsx` | Toolbar dropdown UI | +| `features/controlLayers/store/canvasSlice.ts` | `canvasProjectRecalled` Redux action | + +## Save Flow + +1. User clicks "Save Canvas Project" → `SaveCanvasProjectDialog` opens asking for a project name +2. On confirm, `saveCanvasProject(name)` is called +3. Read Redux state via selectors: `selectCanvasSlice()`, `selectParamsSlice()`, `selectRefImagesSlice()`, `selectLoRAsSlice()` +4. Build `CanvasProjectState` from the canvas slice; use `paramsState` directly for params +5. Walk all entities to collect every `image_name` reference via `collectImageNames()`: + - `CanvasImageState.image.image_name` in layer/mask objects + - `CroppableImageWithDims.original.image.image_name` in global ref images + - `CroppableImageWithDims.crop.image.image_name` in cropped ref images + - `ImageWithDims.image_name` in regional guidance ref images +6. Fetch each image from the backend API +7. Build ZIP with JSZip: add `manifest.json` (including `name`), `canvas_state.json`, `params.json`, `ref_images.json`, and all images into `images/` +8. Sanitize the name for filesystem use and generate blob, trigger download as `{name}.invk` + +## Load Flow + +1. User selects `.invk` file → confirmation dialog opens +2. On confirm, parse ZIP with JSZip +3. Validate manifest version via Zod schema +4. Read `canvas_state.json`, `params.json` (optional), `ref_images.json` (optional) +5. Collect all `image_name` references from the loaded state +6. **Deduplicate images**: for each referenced image, check if it exists on the server via `getImageDTOSafe(image_name)` + - Already exists → skip (no upload) + - Missing → upload from ZIP via `uploadImage()`, record `oldName → newName` mapping +7. Remap all `image_name` values in the loaded state using the mapping (only for re-uploaded images whose names changed) +8. Dispatch Redux actions: + - `canvasProjectRecalled()` — restores all canvas entities, bbox, selected/bookmarked entity + - `refImagesRecalled()` — restores global reference images + - `paramsRecalled()` — replaces the entire params state in one action + - `loraAllDeleted()` + `loraRecalled()` — restores LoRAs +9. Show success/error toast + +## Image Name Collection & Remapping + +The `canvasProjectFile.ts` utility provides two parallel sets of functions: + +**Collection** (`collectImageNames`): Walks the entire state tree and returns a `Set` of all referenced `image_name` values. This is used by both save (to know which images to fetch) and load (to know which images to check/upload). + +**Remapping** (`remapCanvasState`, `remapRefImages`): Deep-clones state objects and replaces `image_name` values using a `Map` mapping. Only images that were re-uploaded with a different name are remapped. Images that already existed on the server are left unchanged. + +Both walk the same paths through the state tree: +- Layer/mask objects → `CanvasImageState.image.image_name` +- Regional guidance ref images → `ImageWithDims.image_name` +- Global ref images → `CroppableImageWithDims.original.image.image_name` and `.crop.image.image_name` + +## Extending the Format + +### Adding new optional data (non-breaking) + +Add a new JSON file to the ZIP. No version bump needed. + +1. **Save**: Add `zip.file('new_data.json', JSON.stringify(data))` in `useCanvasProjectSave.ts` +2. **Load**: Read with `zip.file('new_data.json')` in `useCanvasProjectLoad.ts` — check for `null` so older project files without it still load +3. **Dispatch**: Add the appropriate Redux action to restore the data + +### Adding new entity types with images + +1. Extend `CanvasProjectState` type in `canvasProjectFile.ts` +2. Add collection logic in `collectImageNames()` to walk the new entity's objects +3. Add remapping logic in `remapCanvasState()` to update image names +4. Include the new entity array in both save and load hooks +5. Handle it in the `canvasProjectRecalled` reducer in `canvasSlice.ts` + +### Breaking schema changes + +1. Bump `CANVAS_PROJECT_VERSION` in `canvasProjectFile.ts` +2. Update the Zod manifest schema: `version: z.union([z.literal(1), z.literal(2)])` +3. Add migration logic in the load hook: check version, transform v1 → v2 before dispatching + +## UI Architecture + +### Save dialog + +The save flow uses a **nanostore atom** (`$isOpen`) to control the `SaveCanvasProjectDialog`: + +1. `useSaveCanvasProjectWithDialog()` — returns a callback that sets `$isOpen` to `true` +2. `SaveCanvasProjectDialog` (singleton in `GlobalModalIsolator`) — renders an `AlertDialog` with a name input +3. On save → calls `saveCanvasProject(name)` and closes the dialog +4. On cancel → closes the dialog + +### Load dialog + +The load flow uses a **nanostore atom** (`$pendingFile`) to decouple the file dialog from the confirmation dialog: + +1. `useLoadCanvasProjectWithDialog()` — opens a programmatic file input (`document.createElement('input')`) +2. On file selection → sets `$pendingFile` atom +3. `LoadCanvasProjectConfirmationAlertDialog` (singleton in `GlobalModalIsolator`) — subscribes to `$pendingFile` via `useStore()` +4. On accept → calls `loadCanvasProject(file)` and clears the atom +5. On cancel → clears the atom + +The programmatic file input approach was chosen because the context menu component uses `isLazy: true`, which unmounts the DOM tree when the menu closes — a hidden `` element inside the menu would be destroyed before the file dialog returns. diff --git a/docs/features/Lasso_tool.md b/docs/features/Lasso_tool.md new file mode 100644 index 0000000000..8f7fc6d4ec --- /dev/null +++ b/docs/features/Lasso_tool.md @@ -0,0 +1,32 @@ +Lasso Tool +=========== + +- The Lasso tool creates selections and inpaint masks by drawing freehand or polygonal regions on the canvas. + +How to open the Lasso tool +-------------------------- +- Click the Lasso icon in the toolbar. +- Hotkey: press `L` (default). The hotkey is shown in the tool's tooltip and can be customized in Hotkeys settings. + +Modes +----- +- Freehand (default) + - Hold the pointer and drag to draw a continuous contour. + - Long segments are broken into intermediate points to keep the line continuous. + - Very long strokes may be simplified after drawing to reduce point count for performance. + +- Polygon + - Click to place points; click the first point (or a point near it) to close the polygon. + - The tool snaps the closing point to the start for precise closures. + +Basic interactions +------------------ +- Switch modes with the mode toggle in the toolbar. +- To close a polygon: click the starting point again or click near it — the tool aligns the final point to the start to complete the shape. +- The selection will be added to the current Inpaint Mask layer. If no Inpaint Mask layer exists, a new one will be created automatically. + +Tips & behavior +--------------- +- Hold `Space` to temporarily switch to the View tool for panning and zooming; release `Space` to return to the Lasso tool and continue drawing. +- When using the Polygon mode, you can hold `Shift` to snap points to horizontal, vertical, or 45-degree angles for more precise shapes. +- Hold `Ctrl` (Windows/Linux) or `Command` (macOS) while drawing to subtract from the current selection instead of adding to it. diff --git a/docs/features/canvas_projects.md b/docs/features/canvas_projects.md new file mode 100644 index 0000000000..8b161c6745 --- /dev/null +++ b/docs/features/canvas_projects.md @@ -0,0 +1,56 @@ +--- +title: Canvas Projects +--- + +# :material-folder-zip: Canvas Projects + +## Save and Restore Your Canvas Work + +Canvas Projects let you save your entire canvas setup to a file and load it back later. This is useful when you want to: + +- **Switch between tasks** without losing your current canvas arrangement +- **Back up complex setups** with multiple layers, masks, and reference images +- **Share canvas layouts** with others or transfer them between machines +- **Recover from deleted images** — all images are embedded in the project file + +## What Gets Saved + +A canvas project file (`.invk`) captures everything about your current canvas session: + +- **All layers** — raster layers, control layers, inpaint masks, regional guidance +- **All drawn content** — brush strokes, pasted images, eraser marks +- **Reference images** — global IP-Adapter / FLUX Redux images with crop settings +- **Regional guidance** — per-region prompts and reference images +- **Bounding box** — position, size, aspect ratio, and scale settings +- **All generation parameters** — prompts, seed, steps, CFG scale, guidance, scheduler, model, VAE, dimensions, img2img strength, infill settings, canvas coherence, refiner settings, FLUX/Z-Image specific parameters, and more +- **LoRAs** — all added LoRA models with their weights and enabled/disabled state + +## How to Save a Project + +You can save from two places: + +1. **Toolbar** — Click the **Archive icon** in the canvas toolbar, then select **Save Canvas Project** +2. **Context menu** — Right-click the canvas, open the **Project** submenu, then select **Save Canvas Project** + +A dialog will ask you to enter a **project name**. This name is used as the filename (e.g., entering "My Portrait" saves as `My Portrait.invk`) and is stored inside the project file. + +## How to Load a Project + +1. **Toolbar** — Click the **Archive icon**, then select **Load Canvas Project** +2. **Context menu** — Right-click the canvas, open the **Project** submenu, then select **Load Canvas Project** + +A file dialog will open. Select your `.invk` file. You will see a confirmation dialog warning that loading will replace your current canvas. Click **Load** to proceed. + +### What Happens on Load + +- Your current canvas is **completely replaced** — all existing layers, masks, reference images, and parameters are overwritten +- Images that are already present on your InvokeAI server are reused automatically (no duplicate uploads) +- Images that were deleted from the server are re-uploaded from the project file +- If the saved model is not installed on your system, the model identifier is still restored — you will need to select an available model manually + +## Good to Know + +- **No undo** — Loading a project replaces your canvas entirely. There is no way to undo this action, so save your current project first if you want to keep it. +- **Image deduplication** — When loading, images already on your server are not re-uploaded. Only missing images are uploaded from the project file. +- **File size** — The `.invk` file size depends on the number and resolution of images in your canvas. A project with many high-resolution layers can be large. +- **Model availability** — The project saves which model was selected, but does not include the model itself. If the model is not installed when you load the project, you will need to select a different one. diff --git a/invokeai/app/api/routers/app_info.py b/invokeai/app/api/routers/app_info.py index d8f3bb2f80..da777ebc73 100644 --- a/invokeai/app/api/routers/app_info.py +++ b/invokeai/app/api/routers/app_info.py @@ -6,8 +6,14 @@ from fastapi import Body from fastapi.routing import APIRouter from pydantic import BaseModel, Field +from invokeai.app.api.auth_dependencies import AdminUserOrDefault from invokeai.app.api.dependencies import ApiDependencies -from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config +from invokeai.app.services.config.config_default import ( + DefaultInvokeAIAppConfig, + InvokeAIAppConfig, + get_config, + load_and_migrate_config, +) from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch from invokeai.backend.util.logging import logging @@ -64,6 +70,16 @@ class InvokeAIAppConfigWithSetFields(BaseModel): config: InvokeAIAppConfig = Field(description="The InvokeAI App Config") +class UpdateAppGenerationSettingsRequest(BaseModel): + """Writable generation-related app settings.""" + + max_queue_history: int | None = Field( + default=None, + ge=0, + description="Keep the last N completed, failed, and canceled queue items on startup. Set to 0 to prune all terminal items.", + ) + + @app_router.get( "/runtime_config", operation_id="get_runtime_config", status_code=200, response_model=InvokeAIAppConfigWithSetFields ) @@ -72,6 +88,30 @@ async def get_runtime_config() -> InvokeAIAppConfigWithSetFields: return InvokeAIAppConfigWithSetFields(set_fields=config.model_fields_set, config=config) +@app_router.patch( + "/runtime_config", + operation_id="update_runtime_config", + status_code=200, + response_model=InvokeAIAppConfigWithSetFields, +) +async def update_runtime_config( + _: AdminUserOrDefault, + changes: UpdateAppGenerationSettingsRequest = Body(description="Writable runtime configuration changes"), +) -> InvokeAIAppConfigWithSetFields: + config = get_config() + update_dict = changes.model_dump(exclude_unset=True) + config.update_config(update_dict) + + if config.config_file_path.exists(): + persisted_config = load_and_migrate_config(config.config_file_path) + else: + persisted_config = DefaultInvokeAIAppConfig() + + persisted_config.update_config(update_dict) + persisted_config.write_file(config.config_file_path) + return InvokeAIAppConfigWithSetFields(set_fields=config.model_fields_set, config=config) + + @app_router.get( "/logging", operation_id="get_log_level", diff --git a/invokeai/app/invocations/batch.py b/invokeai/app/invocations/batch.py index 34ecd38f26..f79b8816ad 100644 --- a/invokeai/app/invocations/batch.py +++ b/invokeai/app/invocations/batch.py @@ -56,7 +56,7 @@ class BaseBatchInvocation(BaseInvocation): "image_batch", title="Image Batch", tags=["primitives", "image", "batch", "special"], - category="primitives", + category="batch", version="1.0.0", classification=Classification.Special, ) @@ -87,7 +87,7 @@ class ImageGeneratorField(BaseModel): "image_generator", title="Image Generator", tags=["primitives", "board", "image", "batch", "special"], - category="primitives", + category="batch", version="1.0.0", classification=Classification.Special, ) @@ -111,7 +111,7 @@ class ImageGenerator(BaseInvocation): "string_batch", title="String Batch", tags=["primitives", "string", "batch", "special"], - category="primitives", + category="batch", version="1.0.0", classification=Classification.Special, ) @@ -142,7 +142,7 @@ class StringGeneratorField(BaseModel): "string_generator", title="String Generator", tags=["primitives", "string", "number", "batch", "special"], - category="primitives", + category="batch", version="1.0.0", classification=Classification.Special, ) @@ -166,7 +166,7 @@ class StringGenerator(BaseInvocation): "integer_batch", title="Integer Batch", tags=["primitives", "integer", "number", "batch", "special"], - category="primitives", + category="batch", version="1.0.0", classification=Classification.Special, ) @@ -195,7 +195,7 @@ class IntegerGeneratorField(BaseModel): "integer_generator", title="Integer Generator", tags=["primitives", "int", "number", "batch", "special"], - category="primitives", + category="batch", version="1.0.0", classification=Classification.Special, ) @@ -219,7 +219,7 @@ class IntegerGenerator(BaseInvocation): "float_batch", title="Float Batch", tags=["primitives", "float", "number", "batch", "special"], - category="primitives", + category="batch", version="1.0.0", classification=Classification.Special, ) @@ -250,7 +250,7 @@ class FloatGeneratorField(BaseModel): "float_generator", title="Float Generator", tags=["primitives", "float", "number", "batch", "special"], - category="primitives", + category="batch", version="1.0.0", classification=Classification.Special, ) diff --git a/invokeai/app/invocations/canny.py b/invokeai/app/invocations/canny.py index 0cdc386e62..dbfde6d353 100644 --- a/invokeai/app/invocations/canny.py +++ b/invokeai/app/invocations/canny.py @@ -11,7 +11,7 @@ from invokeai.backend.image_util.util import cv2_to_pil, pil_to_cv2 "canny_edge_detection", title="Canny Edge Detection", tags=["controlnet", "canny"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class CannyEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/cogview4_denoise.py b/invokeai/app/invocations/cogview4_denoise.py index 070d8a3478..e8b910f731 100644 --- a/invokeai/app/invocations/cogview4_denoise.py +++ b/invokeai/app/invocations/cogview4_denoise.py @@ -33,7 +33,7 @@ from invokeai.backend.util.devices import TorchDevice "cogview4_denoise", title="Denoise - CogView4", tags=["image", "cogview4"], - category="image", + category="latents", version="1.0.0", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/cogview4_image_to_latents.py b/invokeai/app/invocations/cogview4_image_to_latents.py index 630b9ab1e3..facbc38dd4 100644 --- a/invokeai/app/invocations/cogview4_image_to_latents.py +++ b/invokeai/app/invocations/cogview4_image_to_latents.py @@ -27,7 +27,7 @@ from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory "cogview4_i2l", title="Image to Latents - CogView4", tags=["image", "latents", "vae", "i2l", "cogview4"], - category="image", + category="latents", version="1.0.0", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/cogview4_text_encoder.py b/invokeai/app/invocations/cogview4_text_encoder.py index 3b5b1dc73f..13234889fb 100644 --- a/invokeai/app/invocations/cogview4_text_encoder.py +++ b/invokeai/app/invocations/cogview4_text_encoder.py @@ -20,7 +20,7 @@ COGVIEW4_GLM_MAX_SEQ_LEN = 1024 "cogview4_text_encoder", title="Prompt - CogView4", tags=["prompt", "conditioning", "cogview4"], - category="conditioning", + category="prompt", version="1.0.0", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py index bd3dedb3f8..39e77f5b63 100644 --- a/invokeai/app/invocations/collections.py +++ b/invokeai/app/invocations/collections.py @@ -11,9 +11,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import SEED_MAX -@invocation( - "range", title="Integer Range", tags=["collection", "integer", "range"], category="collections", version="1.0.0" -) +@invocation("range", title="Integer Range", tags=["collection", "integer", "range"], category="batch", version="1.0.0") class RangeInvocation(BaseInvocation): """Creates a range of numbers from start to stop with step""" @@ -35,7 +33,7 @@ class RangeInvocation(BaseInvocation): "range_of_size", title="Integer Range of Size", tags=["collection", "integer", "size", "range"], - category="collections", + category="batch", version="1.0.0", ) class RangeOfSizeInvocation(BaseInvocation): @@ -55,7 +53,7 @@ class RangeOfSizeInvocation(BaseInvocation): "random_range", title="Random Range", tags=["range", "integer", "random", "collection"], - category="collections", + category="batch", version="1.0.1", use_cache=False, ) diff --git a/invokeai/app/invocations/color_map.py b/invokeai/app/invocations/color_map.py index e55584caf5..ec95acfffd 100644 --- a/invokeai/app/invocations/color_map.py +++ b/invokeai/app/invocations/color_map.py @@ -11,7 +11,7 @@ from invokeai.backend.image_util.util import np_to_pil, pil_to_np "color_map", title="Color Map", tags=["controlnet"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class ColorMapInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 0ff6be969f..99373531d8 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -43,7 +43,7 @@ from invokeai.backend.util.devices import TorchDevice "compel", title="Prompt - SD1.5", tags=["prompt", "compel"], - category="conditioning", + category="prompt", version="1.2.1", ) class CompelInvocation(BaseInvocation): @@ -248,7 +248,7 @@ class SDXLPromptInvocationBase: "sdxl_compel_prompt", title="Prompt - SDXL", tags=["sdxl", "compel", "prompt"], - category="conditioning", + category="prompt", version="1.2.1", ) class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): @@ -342,7 +342,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): "sdxl_refiner_compel_prompt", title="Prompt - SDXL Refiner", tags=["sdxl", "compel", "prompt"], - category="conditioning", + category="prompt", version="1.1.2", ) class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): @@ -391,7 +391,7 @@ class CLIPSkipInvocationOutput(BaseInvocationOutput): "clip_skip", title="Apply CLIP Skip - SD1.5, SDXL", tags=["clipskip", "clip", "skip"], - category="conditioning", + category="prompt", version="1.1.1", ) class CLIPSkipInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/content_shuffle.py b/invokeai/app/invocations/content_shuffle.py index e01096ecea..6fd35b53eb 100644 --- a/invokeai/app/invocations/content_shuffle.py +++ b/invokeai/app/invocations/content_shuffle.py @@ -9,7 +9,7 @@ from invokeai.backend.image_util.content_shuffle import content_shuffle "content_shuffle", title="Content Shuffle", tags=["controlnet", "normal"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class ContentShuffleInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/controlnet.py b/invokeai/app/invocations/controlnet.py index d1878d967e..9b0fc8219b 100644 --- a/invokeai/app/invocations/controlnet.py +++ b/invokeai/app/invocations/controlnet.py @@ -64,7 +64,7 @@ class ControlOutput(BaseInvocationOutput): @invocation( - "controlnet", title="ControlNet - SD1.5, SD2, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3" + "controlnet", title="ControlNet - SD1.5, SD2, SDXL", tags=["controlnet"], category="conditioning", version="1.1.3" ) class ControlNetInvocation(BaseInvocation): """Collects ControlNet info to pass to other nodes""" @@ -116,7 +116,7 @@ class ControlNetInvocation(BaseInvocation): "heuristic_resize", title="Heuristic Resize", tags=["image, controlnet"], - category="image", + category="controlnet_preprocessors", version="1.1.1", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/create_denoise_mask.py b/invokeai/app/invocations/create_denoise_mask.py index d013e8f4f6..419a516bcd 100644 --- a/invokeai/app/invocations/create_denoise_mask.py +++ b/invokeai/app/invocations/create_denoise_mask.py @@ -18,7 +18,7 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_t "create_denoise_mask", title="Create Denoise Mask", tags=["mask", "denoise"], - category="latents", + category="mask", version="1.0.2", ) class CreateDenoiseMaskInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/create_gradient_mask.py b/invokeai/app/invocations/create_gradient_mask.py index 8a7e7c5231..08826cc5ef 100644 --- a/invokeai/app/invocations/create_gradient_mask.py +++ b/invokeai/app/invocations/create_gradient_mask.py @@ -41,7 +41,7 @@ class GradientMaskOutput(BaseInvocationOutput): "create_gradient_mask", title="Create Gradient Mask", tags=["mask", "denoise"], - category="latents", + category="mask", version="1.3.0", ) class CreateGradientMaskInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/depth_anything.py b/invokeai/app/invocations/depth_anything.py index af79413ce0..1fd808efde 100644 --- a/invokeai/app/invocations/depth_anything.py +++ b/invokeai/app/invocations/depth_anything.py @@ -20,7 +20,7 @@ DEPTH_ANYTHING_MODELS = { "depth_anything_depth_estimation", title="Depth Anything Depth Estimation", tags=["controlnet", "depth", "depth anything"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class DepthAnythingDepthEstimationInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/dw_openpose.py b/invokeai/app/invocations/dw_openpose.py index 225c7e2283..918a4bc4d0 100644 --- a/invokeai/app/invocations/dw_openpose.py +++ b/invokeai/app/invocations/dw_openpose.py @@ -11,7 +11,7 @@ from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector "dw_openpose_detection", title="DW Openpose Detection", tags=["controlnet", "dwpose", "openpose"], - category="controlnet", + category="controlnet_preprocessors", version="1.1.1", ) class DWOpenposeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/facetools.py b/invokeai/app/invocations/facetools.py index 987f1b1e40..1092a67ce9 100644 --- a/invokeai/app/invocations/facetools.py +++ b/invokeai/app/invocations/facetools.py @@ -435,7 +435,9 @@ def get_faces_list( return all_faces -@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.2") +@invocation( + "face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="segmentation", version="1.2.2" +) class FaceOffInvocation(BaseInvocation, WithMetadata): """Bound, extract, and mask a face from an image using MediaPipe detection""" @@ -514,7 +516,9 @@ class FaceOffInvocation(BaseInvocation, WithMetadata): return output -@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.2") +@invocation( + "face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="segmentation", version="1.2.2" +) class FaceMaskInvocation(BaseInvocation, WithMetadata): """Face mask creation using mediapipe face detection""" @@ -617,7 +621,11 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata): @invocation( - "face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.2" + "face_identifier", + title="FaceIdentifier", + tags=["image", "face", "identifier"], + category="segmentation", + version="1.2.2", ) class FaceIdentifierInvocation(BaseInvocation, WithMetadata, WithBoard): """Outputs an image with detected face IDs printed on each face. For use with other FaceTools.""" diff --git a/invokeai/app/invocations/flux2_denoise.py b/invokeai/app/invocations/flux2_denoise.py index c387a72790..1b5ea372d6 100644 --- a/invokeai/app/invocations/flux2_denoise.py +++ b/invokeai/app/invocations/flux2_denoise.py @@ -53,7 +53,7 @@ from invokeai.backend.util.devices import TorchDevice "flux2_denoise", title="FLUX2 Denoise", tags=["image", "flux", "flux2", "klein", "denoise"], - category="image", + category="latents", version="1.4.0", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/flux2_klein_text_encoder.py b/invokeai/app/invocations/flux2_klein_text_encoder.py index b44e782c8a..b2728d1d7c 100644 --- a/invokeai/app/invocations/flux2_klein_text_encoder.py +++ b/invokeai/app/invocations/flux2_klein_text_encoder.py @@ -45,7 +45,7 @@ KLEIN_MAX_SEQ_LEN = 512 "flux2_klein_text_encoder", title="Prompt - Flux2 Klein", tags=["prompt", "conditioning", "flux", "klein", "qwen3"], - category="conditioning", + category="prompt", version="1.1.1", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/flux_controlnet.py b/invokeai/app/invocations/flux_controlnet.py index 8228484375..b11d497f31 100644 --- a/invokeai/app/invocations/flux_controlnet.py +++ b/invokeai/app/invocations/flux_controlnet.py @@ -50,7 +50,7 @@ class FluxControlNetOutput(BaseInvocationOutput): "flux_controlnet", title="FLUX ControlNet", tags=["controlnet", "flux"], - category="controlnet", + category="conditioning", version="1.0.0", ) class FluxControlNetInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index d6102b105b..84f0a030c5 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -70,7 +70,7 @@ from invokeai.backend.util.devices import TorchDevice "flux_denoise", title="FLUX Denoise", tags=["image", "flux"], - category="image", + category="latents", version="4.5.1", ) class FluxDenoiseInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/flux_fill.py b/invokeai/app/invocations/flux_fill.py index cff8f2b1e5..440f3e5c97 100644 --- a/invokeai/app/invocations/flux_fill.py +++ b/invokeai/app/invocations/flux_fill.py @@ -29,7 +29,7 @@ class FluxFillOutput(BaseInvocationOutput): "flux_fill", title="FLUX Fill Conditioning", tags=["inpaint"], - category="inpaint", + category="conditioning", version="1.0.0", classification=Classification.Beta, ) diff --git a/invokeai/app/invocations/flux_ip_adapter.py b/invokeai/app/invocations/flux_ip_adapter.py index 4a1997c512..c0d797d0bd 100644 --- a/invokeai/app/invocations/flux_ip_adapter.py +++ b/invokeai/app/invocations/flux_ip_adapter.py @@ -24,7 +24,7 @@ from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType "flux_ip_adapter", title="FLUX IP-Adapter", tags=["ip_adapter", "control"], - category="ip_adapter", + category="conditioning", version="1.0.0", ) class FluxIPAdapterInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/flux_redux.py b/invokeai/app/invocations/flux_redux.py index 403d78b078..b68e9911c5 100644 --- a/invokeai/app/invocations/flux_redux.py +++ b/invokeai/app/invocations/flux_redux.py @@ -47,7 +47,7 @@ DOWNSAMPLING_FUNCTIONS = Literal["nearest", "bilinear", "bicubic", "area", "near "flux_redux", title="FLUX Redux", tags=["ip_adapter", "control"], - category="ip_adapter", + category="conditioning", version="2.1.0", classification=Classification.Beta, ) diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py index 56ebbe7fd9..8b3b33fad1 100644 --- a/invokeai/app/invocations/flux_text_encoder.py +++ b/invokeai/app/invocations/flux_text_encoder.py @@ -28,7 +28,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit "flux_text_encoder", title="Prompt - FLUX", tags=["prompt", "conditioning", "flux"], - category="conditioning", + category="prompt", version="1.1.2", ) class FluxTextEncoderInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/grounding_dino.py b/invokeai/app/invocations/grounding_dino.py index 1e3d5cea0c..4d900c5034 100644 --- a/invokeai/app/invocations/grounding_dino.py +++ b/invokeai/app/invocations/grounding_dino.py @@ -24,7 +24,7 @@ GROUNDING_DINO_MODEL_IDS: dict[GroundingDinoModelKey, str] = { "grounding_dino", title="Grounding DINO (Text Prompt Object Detection)", tags=["prompt", "object detection"], - category="image", + category="segmentation", version="1.0.0", ) class GroundingDinoInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/hed.py b/invokeai/app/invocations/hed.py index 5ea6e8df1f..e2b68143e5 100644 --- a/invokeai/app/invocations/hed.py +++ b/invokeai/app/invocations/hed.py @@ -11,7 +11,7 @@ from invokeai.backend.image_util.hed import ControlNetHED_Apache2, HEDEdgeDetect "hed_edge_detection", title="HED Edge Detection", tags=["controlnet", "hed", "softedge"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class HEDEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/ideal_size.py b/invokeai/app/invocations/ideal_size.py index aae3a37c8e..5cfa9c04d0 100644 --- a/invokeai/app/invocations/ideal_size.py +++ b/invokeai/app/invocations/ideal_size.py @@ -21,6 +21,7 @@ class IdealSizeOutput(BaseInvocationOutput): "ideal_size", title="Ideal Size - SD1.5, SDXL", tags=["latents", "math", "ideal_size"], + category="latents", version="1.0.6", ) class IdealSizeInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index d4a1977319..17576a0296 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -197,7 +197,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard): "tomask", title="Mask from Alpha", tags=["image", "mask"], - category="image", + category="mask", version="1.2.2", ) class MaskFromAlphaInvocation(BaseInvocation, WithMetadata, WithBoard): @@ -604,7 +604,7 @@ class DecodeInvisibleWatermarkInvocation(BaseInvocation): "mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], - category="image", + category="mask", version="1.2.2", ) class MaskEdgeInvocation(BaseInvocation, WithMetadata, WithBoard): @@ -643,7 +643,7 @@ class MaskEdgeInvocation(BaseInvocation, WithMetadata, WithBoard): "mask_combine", title="Combine Masks", tags=["image", "mask", "multiply"], - category="image", + category="mask", version="1.2.2", ) class MaskCombineInvocation(BaseInvocation, WithMetadata, WithBoard): @@ -974,7 +974,7 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard): "save_image", title="Save Image", tags=["primitives", "image"], - category="primitives", + category="image", version="1.2.2", use_cache=False, ) @@ -995,7 +995,7 @@ class SaveImageInvocation(BaseInvocation, WithMetadata, WithBoard): "canvas_paste_back", title="Canvas Paste Back", tags=["image", "combine"], - category="image", + category="canvas", version="1.0.1", ) class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard): @@ -1032,7 +1032,7 @@ class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard): "mask_from_id", title="Mask from Segmented Image", tags=["image", "mask", "id"], - category="image", + category="mask", version="1.0.1", ) class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard): @@ -1069,7 +1069,7 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard): "canvas_v2_mask_and_crop", title="Canvas V2 Mask and Crop", tags=["image", "mask", "id"], - category="image", + category="canvas", version="1.0.0", classification=Classification.Deprecated, ) @@ -1110,7 +1110,7 @@ class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard): @invocation( - "expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="image", version="1.0.1" + "expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="mask", version="1.0.1" ) class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard): """Expands a mask with a fade effect. The mask uses black to indicate areas to keep from the generated image and white for areas to discard. @@ -1199,7 +1199,7 @@ class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard): "apply_mask_to_image", title="Apply Mask to Image", tags=["image", "mask", "blend"], - category="image", + category="mask", version="1.0.0", ) class ApplyMaskToImageInvocation(BaseInvocation, WithMetadata, WithBoard): @@ -1374,7 +1374,7 @@ class PasteImageIntoBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoar "flux_kontext_image_prep", title="FLUX Kontext Image Prep", tags=["image", "concatenate", "flux", "kontext"], - category="image", + category="conditioning", version="1.0.0", ) class FluxKontextConcatenateImagesInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/image_panels.py b/invokeai/app/invocations/image_panels.py index bb9aa4995a..71fefbd1c6 100644 --- a/invokeai/app/invocations/image_panels.py +++ b/invokeai/app/invocations/image_panels.py @@ -23,7 +23,7 @@ class ImagePanelCoordinateOutput(BaseInvocationOutput): "image_panel_layout", title="Image Panel Layout", tags=["image", "panel", "layout"], - category="image", + category="canvas", version="1.0.0", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index 2b2931e78f..711f910d58 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -73,7 +73,7 @@ CLIP_VISION_MODEL_MAP: dict[Literal["ViT-L", "ViT-H", "ViT-G"], StarterModel] = "ip_adapter", title="IP-Adapter - SD1.5, SDXL", tags=["ip_adapter", "control"], - category="ip_adapter", + category="conditioning", version="1.5.1", ) class IPAdapterInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/lineart.py b/invokeai/app/invocations/lineart.py index c486c329ec..3ffd51b5b6 100644 --- a/invokeai/app/invocations/lineart.py +++ b/invokeai/app/invocations/lineart.py @@ -11,7 +11,7 @@ from invokeai.backend.image_util.lineart import Generator, LineartEdgeDetector "lineart_edge_detection", title="Lineart Edge Detection", tags=["controlnet", "lineart"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class LineartEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/lineart_anime.py b/invokeai/app/invocations/lineart_anime.py index 848756b113..f07476491c 100644 --- a/invokeai/app/invocations/lineart_anime.py +++ b/invokeai/app/invocations/lineart_anime.py @@ -9,7 +9,7 @@ from invokeai.backend.image_util.lineart_anime import LineartAnimeEdgeDetector, "lineart_anime_edge_detection", title="Lineart Anime Edge Detection", tags=["controlnet", "lineart"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class LineartAnimeEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/llava_onevision_vllm.py b/invokeai/app/invocations/llava_onevision_vllm.py index fbd2420590..ff3b801d37 100644 --- a/invokeai/app/invocations/llava_onevision_vllm.py +++ b/invokeai/app/invocations/llava_onevision_vllm.py @@ -19,7 +19,7 @@ from invokeai.backend.util.devices import TorchDevice "llava_onevision_vllm", title="LLaVA OneVision VLLM", tags=["vllm"], - category="vllm", + category="multimodal", version="1.0.0", classification=Classification.Beta, ) diff --git a/invokeai/app/invocations/logic.py b/invokeai/app/invocations/logic.py index 3197427d4e..7cc98afbbc 100644 --- a/invokeai/app/invocations/logic.py +++ b/invokeai/app/invocations/logic.py @@ -12,7 +12,7 @@ class IfInvocationOutput(BaseInvocationOutput): ) -@invocation("if", title="If", tags=["logic", "conditional"], category="logic", version="1.0.0") +@invocation("if", title="If", tags=["logic", "conditional"], category="math", version="1.0.0") class IfInvocation(BaseInvocation): """Selects between two optional inputs based on a boolean condition.""" diff --git a/invokeai/app/invocations/mask.py b/invokeai/app/invocations/mask.py index 556ab8801d..49749f43b6 100644 --- a/invokeai/app/invocations/mask.py +++ b/invokeai/app/invocations/mask.py @@ -24,7 +24,7 @@ from invokeai.backend.image_util.util import pil_to_np "rectangle_mask", title="Create Rectangle Mask", tags=["conditioning"], - category="conditioning", + category="mask", version="1.0.1", ) class RectangleMaskInvocation(BaseInvocation, WithMetadata): @@ -55,7 +55,7 @@ class RectangleMaskInvocation(BaseInvocation, WithMetadata): "alpha_mask_to_tensor", title="Alpha Mask to Tensor", tags=["conditioning"], - category="conditioning", + category="mask", version="1.0.0", ) class AlphaMaskToTensorInvocation(BaseInvocation): @@ -83,7 +83,7 @@ class AlphaMaskToTensorInvocation(BaseInvocation): "invert_tensor_mask", title="Invert Tensor Mask", tags=["conditioning"], - category="conditioning", + category="mask", version="1.1.0", ) class InvertTensorMaskInvocation(BaseInvocation): @@ -115,7 +115,7 @@ class InvertTensorMaskInvocation(BaseInvocation): "image_mask_to_tensor", title="Image Mask to Tensor", tags=["conditioning"], - category="conditioning", + category="mask", version="1.0.0", ) class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata): diff --git a/invokeai/app/invocations/mediapipe_face.py b/invokeai/app/invocations/mediapipe_face.py index 89fccfc1ac..e81326463c 100644 --- a/invokeai/app/invocations/mediapipe_face.py +++ b/invokeai/app/invocations/mediapipe_face.py @@ -9,7 +9,7 @@ from invokeai.backend.image_util.mediapipe_face import detect_faces "mediapipe_face_detection", title="MediaPipe Face Detection", tags=["controlnet", "face"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class MediaPipeFaceDetectionInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/metadata_linked.py b/invokeai/app/invocations/metadata_linked.py index 6a9db3e589..53f2ea7471 100644 --- a/invokeai/app/invocations/metadata_linked.py +++ b/invokeai/app/invocations/metadata_linked.py @@ -621,7 +621,7 @@ class LatentsMetaOutput(LatentsOutput, MetadataOutput): "denoise_latents_meta", title=f"{DenoiseLatentsInvocation.UIConfig.title} + Metadata", tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"], - category="latents", + category="metadata", version="1.1.1", ) class DenoiseLatentsMetaInvocation(DenoiseLatentsInvocation, WithMetadata): @@ -686,7 +686,7 @@ class DenoiseLatentsMetaInvocation(DenoiseLatentsInvocation, WithMetadata): "flux_denoise_meta", title=f"{FluxDenoiseInvocation.UIConfig.title} + Metadata", tags=["flux", "latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"], - category="latents", + category="metadata", version="1.0.1", ) class FluxDenoiseLatentsMetaInvocation(FluxDenoiseInvocation, WithMetadata): @@ -734,7 +734,7 @@ class FluxDenoiseLatentsMetaInvocation(FluxDenoiseInvocation, WithMetadata): "z_image_denoise_meta", title=f"{ZImageDenoiseInvocation.UIConfig.title} + Metadata", tags=["z-image", "latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"], - category="latents", + category="metadata", version="1.0.0", ) class ZImageDenoiseMetaInvocation(ZImageDenoiseInvocation, WithMetadata): diff --git a/invokeai/app/invocations/mlsd.py b/invokeai/app/invocations/mlsd.py index 1526350db8..a2446876c8 100644 --- a/invokeai/app/invocations/mlsd.py +++ b/invokeai/app/invocations/mlsd.py @@ -10,7 +10,7 @@ from invokeai.backend.image_util.mlsd.models.mbv2_mlsd_large import MobileV2_MLS "mlsd_detection", title="MLSD Detection", tags=["controlnet", "mlsd", "edge"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class MLSDDetectionInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 6b5afb5529..0c96cdb1d9 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -584,7 +584,7 @@ class SeamlessModeInvocation(BaseInvocation): return SeamlessModeOutput(unet=unet, vae=vae) -@invocation("freeu", title="Apply FreeU - SD1.5, SDXL", tags=["freeu"], category="unet", version="1.0.2") +@invocation("freeu", title="Apply FreeU - SD1.5, SDXL", tags=["freeu"], category="model", version="1.0.2") class FreeUInvocation(BaseInvocation): """ Applies FreeU to the UNet. Suggested values (b1/b2/s1/s2): diff --git a/invokeai/app/invocations/normal_bae.py b/invokeai/app/invocations/normal_bae.py index ebbea869a1..1159927150 100644 --- a/invokeai/app/invocations/normal_bae.py +++ b/invokeai/app/invocations/normal_bae.py @@ -10,7 +10,7 @@ from invokeai.backend.image_util.normal_bae.nets.NNET import NNET "normal_map", title="Normal Map", tags=["controlnet", "normal"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class NormalMapInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/pbr_maps.py b/invokeai/app/invocations/pbr_maps.py index 5e519d38bc..945c3cad59 100644 --- a/invokeai/app/invocations/pbr_maps.py +++ b/invokeai/app/invocations/pbr_maps.py @@ -16,7 +16,9 @@ class PBRMapsOutput(BaseInvocationOutput): displacement_map: ImageField = OutputField(default=None, description="The generated displacement map") -@invocation("pbr_maps", title="PBR Maps", tags=["image", "material"], category="image", version="1.0.0") +@invocation( + "pbr_maps", title="PBR Maps", tags=["image", "material"], category="controlnet_preprocessors", version="1.0.0" +) class PBRMapsInvocation(BaseInvocation, WithMetadata, WithBoard): """Generate Normal, Displacement and Roughness Map from a given image""" diff --git a/invokeai/app/invocations/pidi.py b/invokeai/app/invocations/pidi.py index 47b241ee1f..5d8cab0458 100644 --- a/invokeai/app/invocations/pidi.py +++ b/invokeai/app/invocations/pidi.py @@ -10,7 +10,7 @@ from invokeai.backend.image_util.pidi.model import PiDiNet "pidi_edge_detection", title="PiDiNet Edge Detection", tags=["controlnet", "edge"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class PiDiNetEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/sd3_denoise.py b/invokeai/app/invocations/sd3_denoise.py index b9d69369b7..4b990ee42b 100644 --- a/invokeai/app/invocations/sd3_denoise.py +++ b/invokeai/app/invocations/sd3_denoise.py @@ -34,7 +34,7 @@ from invokeai.backend.util.devices import TorchDevice "sd3_denoise", title="Denoise - SD3", tags=["image", "sd3"], - category="image", + category="latents", version="1.1.1", ) class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/sd3_image_to_latents.py b/invokeai/app/invocations/sd3_image_to_latents.py index 71a48ee9ad..9af641d8bc 100644 --- a/invokeai/app/invocations/sd3_image_to_latents.py +++ b/invokeai/app/invocations/sd3_image_to_latents.py @@ -24,7 +24,7 @@ from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory "sd3_i2l", title="Image to Latents - SD3", tags=["image", "latents", "vae", "i2l", "sd3"], - category="image", + category="latents", version="1.0.1", ) class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/sd3_text_encoder.py b/invokeai/app/invocations/sd3_text_encoder.py index 58880f9a28..7af138fe45 100644 --- a/invokeai/app/invocations/sd3_text_encoder.py +++ b/invokeai/app/invocations/sd3_text_encoder.py @@ -31,7 +31,7 @@ SD3_T5_MAX_SEQ_LEN = 256 "sd3_text_encoder", title="Prompt - SD3", tags=["prompt", "conditioning", "sd3"], - category="conditioning", + category="prompt", version="1.0.1", ) class Sd3TextEncoderInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/strings.py b/invokeai/app/invocations/strings.py index 2b6bf300b9..6d64e8771a 100644 --- a/invokeai/app/invocations/strings.py +++ b/invokeai/app/invocations/strings.py @@ -20,7 +20,7 @@ class StringPosNegOutput(BaseInvocationOutput): "string_split_neg", title="String Split Negative", tags=["string", "split", "negative"], - category="string", + category="strings", version="1.0.1", ) class StringSplitNegInvocation(BaseInvocation): @@ -63,7 +63,7 @@ class String2Output(BaseInvocationOutput): string_2: str = OutputField(description="string 2") -@invocation("string_split", title="String Split", tags=["string", "split"], category="string", version="1.0.1") +@invocation("string_split", title="String Split", tags=["string", "split"], category="strings", version="1.0.1") class StringSplitInvocation(BaseInvocation): """Splits string into two strings, based on the first occurance of the delimiter. The delimiter will be removed from the string""" @@ -83,7 +83,7 @@ class StringSplitInvocation(BaseInvocation): return String2Output(string_1=part1, string_2=part2) -@invocation("string_join", title="String Join", tags=["string", "join"], category="string", version="1.0.1") +@invocation("string_join", title="String Join", tags=["string", "join"], category="strings", version="1.0.1") class StringJoinInvocation(BaseInvocation): """Joins string left to string right""" @@ -94,7 +94,9 @@ class StringJoinInvocation(BaseInvocation): return StringOutput(value=((self.string_left or "") + (self.string_right or ""))) -@invocation("string_join_three", title="String Join Three", tags=["string", "join"], category="string", version="1.0.1") +@invocation( + "string_join_three", title="String Join Three", tags=["string", "join"], category="strings", version="1.0.1" +) class StringJoinThreeInvocation(BaseInvocation): """Joins string left to string middle to string right""" @@ -107,7 +109,7 @@ class StringJoinThreeInvocation(BaseInvocation): @invocation( - "string_replace", title="String Replace", tags=["string", "replace", "regex"], category="string", version="1.0.1" + "string_replace", title="String Replace", tags=["string", "replace", "regex"], category="strings", version="1.0.1" ) class StringReplaceInvocation(BaseInvocation): """Replaces the search string with the replace string""" diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py index 15f1881eef..cf4b7cda47 100644 --- a/invokeai/app/invocations/t2i_adapter.py +++ b/invokeai/app/invocations/t2i_adapter.py @@ -49,7 +49,7 @@ class T2IAdapterOutput(BaseInvocationOutput): "t2i_adapter", title="T2I-Adapter - SD1.5, SDXL", tags=["t2i_adapter", "control"], - category="t2i_adapter", + category="conditioning", version="1.0.4", ) class T2IAdapterInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index e7b3968aec..64e372a0f6 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -30,7 +30,7 @@ ESRGAN_MODEL_URLS: dict[str, str] = { } -@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.2") +@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="upscale", version="1.3.2") class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard): """Upscales an image using RealESRGAN.""" diff --git a/invokeai/app/invocations/z_image_control.py b/invokeai/app/invocations/z_image_control.py index 3b01f12373..f51c2fcd16 100644 --- a/invokeai/app/invocations/z_image_control.py +++ b/invokeai/app/invocations/z_image_control.py @@ -57,7 +57,7 @@ class ZImageControlOutput(BaseInvocationOutput): "z_image_control", title="Z-Image ControlNet", tags=["image", "z-image", "control", "controlnet"], - category="control", + category="conditioning", version="1.1.0", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/z_image_denoise.py b/invokeai/app/invocations/z_image_denoise.py index 24b135e447..397e917112 100644 --- a/invokeai/app/invocations/z_image_denoise.py +++ b/invokeai/app/invocations/z_image_denoise.py @@ -49,7 +49,7 @@ from invokeai.backend.z_image.z_image_transformer_patch import patch_transformer "z_image_denoise", title="Denoise - Z-Image", tags=["image", "z-image"], - category="image", + category="latents", version="1.5.0", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/z_image_image_to_latents.py b/invokeai/app/invocations/z_image_image_to_latents.py index 5a70fdba13..263346e296 100644 --- a/invokeai/app/invocations/z_image_image_to_latents.py +++ b/invokeai/app/invocations/z_image_image_to_latents.py @@ -30,7 +30,7 @@ ZImageVAE = Union[AutoencoderKL, FluxAutoEncoder] "z_image_i2l", title="Image to Latents - Z-Image", tags=["image", "latents", "vae", "i2l", "z-image"], - category="image", + category="latents", version="1.1.0", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/z_image_seed_variance_enhancer.py b/invokeai/app/invocations/z_image_seed_variance_enhancer.py index b24002e971..72819a966a 100644 --- a/invokeai/app/invocations/z_image_seed_variance_enhancer.py +++ b/invokeai/app/invocations/z_image_seed_variance_enhancer.py @@ -19,7 +19,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( "z_image_seed_variance_enhancer", title="Seed Variance Enhancer - Z-Image", tags=["conditioning", "z-image", "variance", "seed"], - category="conditioning", + category="prompt", version="1.0.0", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/z_image_text_encoder.py b/invokeai/app/invocations/z_image_text_encoder.py index c3405d6dc8..71af6085d0 100644 --- a/invokeai/app/invocations/z_image_text_encoder.py +++ b/invokeai/app/invocations/z_image_text_encoder.py @@ -34,7 +34,7 @@ Z_IMAGE_MAX_SEQ_LEN = 512 "z_image_text_encoder", title="Prompt - Z-Image", tags=["prompt", "conditioning", "z-image"], - category="conditioning", + category="prompt", version="1.1.0", classification=Classification.Prototype, ) diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 5d1b1d0d8d..7e56db9f61 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -101,7 +101,8 @@ class InvokeAIAppConfig(BaseSettings): force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty). pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting. max_queue_size: Maximum number of items in the session queue. - clear_queue_on_startup: Empties session queue on startup. + clear_queue_on_startup: Empties session queue on startup. If true, disables `max_queue_history`. + max_queue_history: Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true. allow_nodes: List of nodes to allow. Omit to allow all. deny_nodes: List of nodes to deny. Omit to deny none. node_cache_size: How many cached nodes to keep in memory. @@ -191,7 +192,8 @@ class InvokeAIAppConfig(BaseSettings): force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).") pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.") max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.") - clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.") + clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup. If true, disables `max_queue_history`.") + max_queue_history: Optional[int] = Field(default=None, ge=0, description="Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true.") # NODES allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.") diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index 070a7cef29..172dc08d55 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -45,10 +45,19 @@ class SqliteSessionQueue(SessionQueueBase): def start(self, invoker: Invoker) -> None: self.__invoker = invoker self._set_in_progress_to_canceled() - if self.__invoker.services.configuration.clear_queue_on_startup: + config = self.__invoker.services.configuration + if config.clear_queue_on_startup: clear_result = self.clear(DEFAULT_QUEUE_ID) if clear_result.deleted > 0: self.__invoker.services.logger.info(f"Cleared all {clear_result.deleted} queue items") + return + + if config.max_queue_history is not None: + deleted = self._prune_terminal_to_limit(DEFAULT_QUEUE_ID, config.max_queue_history) + if deleted > 0: + self.__invoker.services.logger.info( + f"Pruned {deleted} completed/failed/canceled queue items (kept up to {config.max_queue_history})" + ) def __init__(self, db: SqliteDatabase) -> None: super().__init__() @@ -68,6 +77,51 @@ class SqliteSessionQueue(SessionQueueBase): """ ) + def _prune_terminal_to_limit(self, queue_id: str, keep: int) -> int: + """Prune terminal items (completed/failed/canceled) to keep at most N most-recent items.""" + with self._db.transaction() as cursor: + where = """--sql + WHERE + queue_id = ? + AND ( + status = 'completed' + OR status = 'failed' + OR status = 'canceled' + ) + """ + cursor.execute( + f"""--sql + SELECT COUNT(*) + FROM session_queue + {where} + AND item_id NOT IN ( + SELECT item_id + FROM session_queue + {where} + ORDER BY COALESCE(completed_at, updated_at, created_at) DESC, item_id DESC + LIMIT ? + ); + """, + (queue_id, queue_id, keep), + ) + count = cursor.fetchone()[0] + cursor.execute( + f"""--sql + DELETE + FROM session_queue + {where} + AND item_id NOT IN ( + SELECT item_id + FROM session_queue + {where} + ORDER BY COALESCE(completed_at, updated_at, created_at) DESC, item_id DESC + LIMIT ? + ); + """, + (queue_id, queue_id, keep), + ) + return count + def _get_current_queue_size(self, queue_id: str) -> int: """Gets the current number of pending queue items""" with self._db.transaction() as cursor: diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json index e9a896f1b4..e537362801 100644 --- a/invokeai/frontend/web/package.json +++ b/invokeai/frontend/web/package.json @@ -66,6 +66,7 @@ "i18next-http-backend": "^3.0.2", "idb-keyval": "6.2.1", "jsondiffpatch": "^0.7.3", + "jszip": "^3.10.1", "konva": "^9.3.22", "linkify-react": "^4.3.1", "linkifyjs": "^4.3.1", diff --git a/invokeai/frontend/web/pnpm-lock.yaml b/invokeai/frontend/web/pnpm-lock.yaml index 3f94ba7d69..6a2ed95ab0 100644 --- a/invokeai/frontend/web/pnpm-lock.yaml +++ b/invokeai/frontend/web/pnpm-lock.yaml @@ -86,6 +86,9 @@ importers: jsondiffpatch: specifier: ^0.7.3 version: 0.7.3 + jszip: + specifier: ^3.10.1 + version: 3.10.1 konva: specifier: ^9.3.22 version: 9.3.22 @@ -2003,6 +2006,9 @@ packages: copy-to-clipboard@3.3.3: resolution: {integrity: sha512-2KV8NhB5JqC3ky0r9PMCAZKbUHSwtEo4CwCs0KXgruG43gX5PMqDEBbVU4OUzw2MuAWUfsuFmWvEKG5QRfSnJA==} + core-util-is@1.0.3: + resolution: {integrity: sha512-ZQBvi1DcpJ4GDqanjucZ2Hj3wEO5pZDS89BWbkcrvdxksJorwUDDZamX9ldFkp9aw2lmBDLgkObEA4DWNJ9FYQ==} + cosmiconfig@7.1.0: resolution: {integrity: sha512-AdmX6xUzdNASswsFtmwSt7Vj8po9IuqXm0UXz7QKPuEUmPB4XyjGfaAr2PSuELMwkRMVH1EpIkX5bTZGRB3eCA==} engines: {node: '>=10'} @@ -2672,6 +2678,9 @@ packages: resolution: {integrity: sha512-Hs59xBNfUIunMFgWAbGX5cq6893IbWg4KnrjbYwX3tx0ztorVgTDA6B2sxf8ejHJ4wz8BqGUMYlnzNBer5NvGg==} engines: {node: '>= 4'} + immediate@3.0.6: + resolution: {integrity: sha512-XXOFtyqDjNDAQxVfYxuF7g9Il/IbWmmlQg2MYKOH8ExIT1qg6xc4zyS3HaEEATgs1btfzxq15ciUiY7gjSXRGQ==} + immer@10.1.1: resolution: {integrity: sha512-s2MPrmjovJcoMaHtx6K11Ra7oD05NT97w1IC5zpMkT6Atjr7H8LjaDd81iIxUYpMKSRRNMJE703M1Fhr/TctHw==} @@ -2825,6 +2834,9 @@ packages: resolution: {integrity: sha512-fKzAra0rGJUUBwGBgNkHZuToZcn+TtXHpeCgmkMJMMYx1sQDYaCSyjJBSCa2nH1DGm7s3n1oBnohoVTBaN7Lww==} engines: {node: '>=8'} + isarray@1.0.0: + resolution: {integrity: sha512-VLghIWNM6ELQzo7zwmcg0NmTVyWKYjvIeM83yjp0wRDTmUnrM678fQbcKBo6n2CJEF0szoG//ytg+TKla89ALQ==} + isarray@2.0.5: resolution: {integrity: sha512-xHjhDr3cNBK0BzdUJSPXZntQUx/mwMS5Rw4A7lPJ90XGAO6ISP/ePDNuo0vhqOZU+UD5JoodwCAAoZQd3FeAKw==} @@ -2916,6 +2928,9 @@ packages: resolution: {integrity: sha512-ZZow9HBI5O6EPgSJLUb8n2NKgmVWTwCvHGwFuJlMjvLFqlGG6pjirPhtdsseaLZjSibD8eegzmYpUZwoIlj2cQ==} engines: {node: '>=4.0'} + jszip@3.10.1: + resolution: {integrity: sha512-xXDvecyTpGLrqFrvkrUSoxxfJI5AH7U8zxxtVclpsUtMCq4JQ290LY8AW5c7Ggnr/Y/oK+bQMbqK2qmtk3pN4g==} + keyv@4.5.4: resolution: {integrity: sha512-oxVHkHR/EJf2CNXnWxRLW6mg7JyCCUcG0DtEGmL2ctUo1PNTin1PUil+r/+4r5MpVgC/fn1kjsx7mjSujKqIpw==} @@ -2934,6 +2949,9 @@ packages: resolution: {integrity: sha512-+bT2uH4E5LGE7h/n3evcS/sQlJXCpIp6ym8OWJ5eV6+67Dsql/LaaT7qJBAt2rzfoa/5QBGBhxDix1dMt2kQKQ==} engines: {node: '>= 0.8.0'} + lie@3.3.0: + resolution: {integrity: sha512-UaiMJzeWRlEujzAuw5LokY1L5ecNQYZKfmyZ9L7wDHb/p5etKaxXhohBcrw0EYby+G/NA52vRSN4N39dxHAIwQ==} + lines-and-columns@1.2.4: resolution: {integrity: sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg==} @@ -3210,6 +3228,9 @@ packages: package-json-from-dist@1.0.1: resolution: {integrity: sha512-UEZIS3/by4OC8vL3P2dTXRETpebLI2NiI5vIrjaD/5UtrkFX/tNbwjTSRAGC/+7CAo2pIcBaRgWmcBBHcsaCIw==} + pako@1.0.11: + resolution: {integrity: sha512-4hLB8Py4zZce5s4yd9XzopqwVv/yGNhV1Bl8NTmCq1763HeK2+EwVTv+leGeL13Dnh2wfbqowVPXCIO0z4taYw==} + pako@2.1.0: resolution: {integrity: sha512-w+eufiZ1WuJYgPXbV/PO3NCMEc3xqylkKHzp8bxp1uW4qaSNQUkwmLLEc3kKsfz8lpV1F8Ht3U1Cm+9Srog2ug==} @@ -3298,6 +3319,9 @@ packages: resolution: {integrity: sha512-Qb1gy5OrP5+zDf2Bvnzdl3jsTf1qXVMazbvCoKhtKqVs4/YK4ozX4gKQJJVyNe+cajNPn0KoC0MC3FUmaHWEmQ==} engines: {node: ^10.13.0 || ^12.13.0 || ^14.15.0 || >=15.0.0} + process-nextick-args@2.0.1: + resolution: {integrity: sha512-3ouUOpQhtgrbOa17J7+uxOTpITYWaGP7/AhoR3+A+/1e9skrzelGi/dXzEYyvbxubEF6Wn2ypscTKiKJFFn1ag==} + prop-types@15.8.1: resolution: {integrity: sha512-oj87CgZICdulUohogVAR7AjlC0327U4el4L6eAvOqCeudMDVU0NThNaV+b9Df4dXgSP1gXMTnPdhfe/2qDH5cg==} @@ -3539,6 +3563,9 @@ packages: resolution: {integrity: sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==} engines: {node: '>=0.10.0'} + readable-stream@2.3.8: + resolution: {integrity: sha512-8p0AUk4XODgIewSi0l8Epjs+EVnWiK7NoDIEGU0HhE7+ZyY8D1IMY7odu5lRrFXGg71L15KG8QrPmum45RTtdA==} + readable-stream@3.6.2: resolution: {integrity: sha512-9u/sniCrY3D5WdsERHzHE4G2YCXqoG5FTHUiCC4SIbr6XcLZBY05ya9EKjYek9O5xOAwjGq+1JdGBAS7Q9ScoA==} engines: {node: '>= 6'} @@ -3661,6 +3688,9 @@ packages: resolution: {integrity: sha512-AURm5f0jYEOydBj7VQlVvDrjeFgthDdEF5H1dP+6mNpoXOMo1quQqJ4wvJDyRZ9+pO3kGWoOdmV08cSv2aJV6Q==} engines: {node: '>=0.4'} + safe-buffer@5.1.2: + resolution: {integrity: sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g==} + safe-buffer@5.2.1: resolution: {integrity: sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==} @@ -3718,6 +3748,9 @@ packages: resolution: {integrity: sha512-RJRdvCo6IAnPdsvP/7m6bsQqNnn1FCBX5ZNtFL98MmFF/4xAIJTIg1YbHW5DC2W5SKZanrC6i4HsJqlajw/dZw==} engines: {node: '>= 0.4'} + setimmediate@1.0.5: + resolution: {integrity: sha512-MATJdZp8sLqDl/68LfQmbP8zKPLQNV6BIZoIgrscFDQ+RsvK/BxeDQOgyxKKoh0y/8h3BqVFnCqQ/gd+reiIXA==} + shebang-command@2.0.0: resolution: {integrity: sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==} engines: {node: '>=8'} @@ -3857,6 +3890,9 @@ packages: resolution: {integrity: sha512-UXSH262CSZY1tfu3G3Secr6uGLCFVPMhIqHjlgCUtCCcgihYc/xKs9djMTMUOb2j1mVSeU8EU6NWc/iQKU6Gfg==} engines: {node: '>= 0.4'} + string_decoder@1.1.1: + resolution: {integrity: sha512-n/ShnvDi6FHbbVfviro+WojiFzv+s8MPMHBczVePfUpDJLwoLT0ht1l4YwBCbi8pJAveEEdnkHyPyTP/mzRfwg==} + string_decoder@1.3.0: resolution: {integrity: sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA==} @@ -6153,6 +6189,8 @@ snapshots: dependencies: toggle-selection: 1.0.6 + core-util-is@1.0.3: {} + cosmiconfig@7.1.0: dependencies: '@types/parse-json': 4.0.2 @@ -6957,6 +6995,8 @@ snapshots: ignore@7.0.5: {} + immediate@3.0.6: {} + immer@10.1.1: {} import-fresh@3.3.1: @@ -7103,6 +7143,8 @@ snapshots: dependencies: is-docker: 2.2.1 + isarray@1.0.0: {} + isarray@2.0.5: {} isexe@2.0.0: {} @@ -7192,6 +7234,13 @@ snapshots: object.assign: 4.1.7 object.values: 1.2.1 + jszip@3.10.1: + dependencies: + lie: 3.3.0 + pako: 1.0.11 + readable-stream: 2.3.8 + setimmediate: 1.0.5 + keyv@4.5.4: dependencies: json-buffer: 3.0.1 @@ -7221,6 +7270,10 @@ snapshots: prelude-ls: 1.2.1 type-check: 0.4.0 + lie@3.3.0: + dependencies: + immediate: 3.0.6 + lines-and-columns@1.2.4: {} linkify-react@4.3.1(linkifyjs@4.3.1)(react@18.3.1): @@ -7510,6 +7563,8 @@ snapshots: package-json-from-dist@1.0.1: {} + pako@1.0.11: {} + pako@2.1.0: {} parent-module@1.0.1: @@ -7578,6 +7633,8 @@ snapshots: ansi-styles: 5.2.0 react-is: 17.0.2 + process-nextick-args@2.0.1: {} + prop-types@15.8.1: dependencies: loose-envify: 1.4.0 @@ -7843,6 +7900,16 @@ snapshots: dependencies: loose-envify: 1.4.0 + readable-stream@2.3.8: + dependencies: + core-util-is: 1.0.3 + inherits: 2.0.4 + isarray: 1.0.0 + process-nextick-args: 2.0.1 + safe-buffer: 5.1.2 + string_decoder: 1.1.1 + util-deprecate: 1.0.2 + readable-stream@3.6.2: dependencies: inherits: 2.0.4 @@ -7994,6 +8061,8 @@ snapshots: has-symbols: 1.1.0 isarray: 2.0.5 + safe-buffer@5.1.2: {} + safe-buffer@5.2.1: {} safe-push-apply@1.0.0: @@ -8051,6 +8120,8 @@ snapshots: es-errors: 1.3.0 es-object-atoms: 1.1.1 + setimmediate@1.0.5: {} + shebang-command@2.0.0: dependencies: shebang-regex: 3.0.0 @@ -8236,6 +8307,10 @@ snapshots: define-properties: 1.2.1 es-object-atoms: 1.1.1 + string_decoder@1.1.1: + dependencies: + safe-buffer: 5.1.2 + string_decoder@1.3.0: dependencies: safe-buffer: 5.2.1 diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 201ea8badb..fd6154760e 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -212,6 +212,7 @@ "copy": "Copy", "copyError": "$t(gallery.copy) Error", "clipboard": "Clipboard", + "collapseAll": "Collapse All", "crop": "Crop", "on": "On", "off": "Off", @@ -239,6 +240,7 @@ "error": "Error", "error_withCount_one": "{{count}} error", "error_withCount_other": "{{count}} errors", + "expandAll": "Expand All", "model_withCount_one": "{{count}} model", "model_withCount_other": "{{count}} models", "file": "File", @@ -715,6 +717,10 @@ "title": "Rect Tool", "desc": "Select the rect tool." }, + "selectLassoTool": { + "title": "Lasso Tool", + "desc": "Select the lasso tool." + }, "selectViewTool": { "title": "View Tool", "desc": "Select the view tool." @@ -1379,6 +1385,8 @@ "fullyContainNodesHelp": "Nodes must be fully inside the selection box to be selected", "showEdgeLabels": "Show Edge Labels", "showEdgeLabelsHelp": "Show labels on edges, indicating the connected nodes", + "groupNodesByCategory": "Group Nodes by Category", + "groupNodesByCategoryHelp": "Group nodes by category in the add node dialog", "hideLegendNodes": "Hide Field Type Legend", "hideMinimapnodes": "Hide MiniMap", "inputMayOnlyHaveOneConnection": "Input may only have one connection", @@ -1417,6 +1425,8 @@ "notes": "Notes", "description": "Description", "notesDescription": "Add notes about your workflow", + "addConnector": "Add Connector", + "deleteConnector": "Delete Connector", "problemSettingTitle": "Problem Setting Title", "resetToDefaultValue": "Reset to default value", "reloadNodeTemplates": "Reload Node Templates", @@ -1705,6 +1715,8 @@ "enableNSFWChecker": "Enable NSFW Checker", "general": "General", "generation": "Generation", + "maxQueueHistory": "Max Queue History", + "maxQueueHistorySaveFailed": "Failed to save Max Queue History", "models": "Models", "preferAttentionStyleNumeric": "Prefer Numeric Attention Style", "prompt": "Prompt", @@ -2758,10 +2770,16 @@ "radial": "Radial", "clip": "Clip Gradient" }, + "lasso": { + "freehand": "Freehand", + "polygon": "Polygon", + "polygonHint": "Click to add points, click the first point to close." + }, "tool": { "brush": "Brush", "eraser": "Eraser", "rectangle": "Rectangle", + "lasso": "Lasso", "gradient": "Gradient", "bbox": "Bbox", "move": "Move", @@ -3001,6 +3019,19 @@ "copyCanvasToClipboard": "Copy Canvas to Clipboard", "copyBboxToClipboard": "Copy Bbox to Clipboard" }, + "canvasProject": { + "project": "Project", + "saveProject": "Save Canvas Project", + "loadProject": "Load Canvas Project", + "saveSuccess": "Project Saved", + "saveSuccessDesc": "Saved project with {{count}} images", + "saveError": "Failed to Save Project", + "loadSuccess": "Project Loaded", + "loadSuccessDesc": "Canvas state restored from project file", + "loadError": "Failed to Load Project", + "loadWarning": "Loading a project will replace your current canvas, including all layers, masks, reference images, and generation parameters. This action cannot be undone.", + "projectName": "Project Name" + }, "stagingArea": { "accept": "Accept", "discardAll": "Discard All", diff --git a/invokeai/frontend/web/src/app/components/GlobalModalIsolator.tsx b/invokeai/frontend/web/src/app/components/GlobalModalIsolator.tsx index ef0747707f..e5ec5ccc56 100644 --- a/invokeai/frontend/web/src/app/components/GlobalModalIsolator.tsx +++ b/invokeai/frontend/web/src/app/components/GlobalModalIsolator.tsx @@ -2,6 +2,8 @@ import { GlobalImageHotkeys } from 'app/components/GlobalImageHotkeys'; import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal'; import { CanvasPasteModal } from 'features/controlLayers/components/CanvasPasteModal'; import { CanvasWorkflowIntegrationModal } from 'features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationModal'; +import { LoadCanvasProjectConfirmationAlertDialog } from 'features/controlLayers/components/LoadCanvasProjectConfirmationAlertDialog'; +import { SaveCanvasProjectDialog } from 'features/controlLayers/components/SaveCanvasProjectDialog'; import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate'; import { CropImageModal } from 'features/cropper/components/CropImageModal'; import { DeleteImageModal } from 'features/deleteImageModal/components/DeleteImageModal'; @@ -54,6 +56,8 @@ export const GlobalModalIsolator = memo(() => { + + diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasContextMenu/CanvasContextMenuGlobalMenuItems.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasContextMenu/CanvasContextMenuGlobalMenuItems.tsx index ca264fa389..064378b227 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/CanvasContextMenu/CanvasContextMenuGlobalMenuItems.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasContextMenu/CanvasContextMenuGlobalMenuItems.tsx @@ -2,6 +2,8 @@ import { Menu, MenuButton, MenuGroup, MenuItem, MenuList } from '@invoke-ai/ui-l import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu'; import { CanvasContextMenuItemsCropCanvasToBbox } from 'features/controlLayers/components/CanvasContextMenu/CanvasContextMenuItemsCropCanvasToBbox'; import { NewLayerIcon } from 'features/controlLayers/components/common/icons'; +import { useLoadCanvasProjectWithDialog } from 'features/controlLayers/components/LoadCanvasProjectConfirmationAlertDialog'; +import { useSaveCanvasProjectWithDialog } from 'features/controlLayers/components/SaveCanvasProjectDialog'; import { useCopyCanvasToClipboard } from 'features/controlLayers/hooks/copyHooks'; import { useNewControlLayerFromBbox, @@ -14,16 +16,19 @@ import { import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; -import { PiCopyBold, PiFloppyDiskBold } from 'react-icons/pi'; +import { PiArchiveBold, PiCopyBold, PiFileArrowDownBold, PiFileArrowUpBold, PiFloppyDiskBold } from 'react-icons/pi'; export const CanvasContextMenuGlobalMenuItems = memo(() => { const { t } = useTranslation(); const saveSubMenu = useSubMenu(); + const projectSubMenu = useSubMenu(); const newSubMenu = useSubMenu(); const copySubMenu = useSubMenu(); const isBusy = useCanvasIsBusy(); const saveCanvasToGallery = useSaveCanvasToGallery(); const saveBboxToGallery = useSaveBboxToGallery(); + const saveCanvasProject = useSaveCanvasProjectWithDialog(); + const loadCanvasProject = useLoadCanvasProjectWithDialog(); const newRegionalReferenceImageFromBbox = useNewRegionalReferenceImageFromBbox(); const newGlobalReferenceImageFromBbox = useNewGlobalReferenceImageFromBbox(); const newRasterLayerFromBbox = useNewRasterLayerFromBbox(); @@ -50,6 +55,21 @@ export const CanvasContextMenuGlobalMenuItems = memo(() => { + }> + + + + + + } isDisabled={isBusy} onClick={saveCanvasProject}> + {t('controlLayers.canvasProject.saveProject')} + + } isDisabled={isBusy} onClick={loadCanvasProject}> + {t('controlLayers.canvasProject.loadProject')} + + + + }> diff --git a/invokeai/frontend/web/src/features/controlLayers/components/LoadCanvasProjectConfirmationAlertDialog.tsx b/invokeai/frontend/web/src/features/controlLayers/components/LoadCanvasProjectConfirmationAlertDialog.tsx new file mode 100644 index 0000000000..149d5b4f17 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/LoadCanvasProjectConfirmationAlertDialog.tsx @@ -0,0 +1,69 @@ +import { ConfirmationAlertDialog, Flex, Text } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; +import { useAssertSingleton } from 'common/hooks/useAssertSingleton'; +import { useCanvasProjectLoad } from 'features/controlLayers/hooks/useCanvasProjectLoad'; +import { CANVAS_PROJECT_EXTENSION } from 'features/controlLayers/util/canvasProjectFile'; +import { atom } from 'nanostores'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; + +const $pendingFile = atom(null); + +const openFileDialog = (onFileSelected: (file: File) => void) => { + const input = document.createElement('input'); + input.type = 'file'; + input.accept = CANVAS_PROJECT_EXTENSION; + input.onchange = () => { + const file = input.files?.[0]; + if (file) { + onFileSelected(file); + } + }; + input.click(); +}; + +export const useLoadCanvasProjectWithDialog = () => { + const openDialog = useCallback(() => { + openFileDialog((file) => { + $pendingFile.set(file); + }); + }, []); + + return openDialog; +}; + +export const LoadCanvasProjectConfirmationAlertDialog = memo(() => { + useAssertSingleton('LoadCanvasProjectConfirmationAlertDialog'); + const { t } = useTranslation(); + const { loadCanvasProject } = useCanvasProjectLoad(); + const pendingFile = useStore($pendingFile); + + const onClose = useCallback(() => { + $pendingFile.set(null); + }, []); + + const onAccept = useCallback(() => { + const file = $pendingFile.get(); + if (file) { + void loadCanvasProject(file); + } + $pendingFile.set(null); + }, [loadCanvasProject]); + + return ( + + + {t('controlLayers.canvasProject.loadWarning')} + + + ); +}); + +LoadCanvasProjectConfirmationAlertDialog.displayName = 'LoadCanvasProjectConfirmationAlertDialog'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/SaveCanvasProjectDialog.tsx b/invokeai/frontend/web/src/features/controlLayers/components/SaveCanvasProjectDialog.tsx new file mode 100644 index 0000000000..bf947ba7c4 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/SaveCanvasProjectDialog.tsx @@ -0,0 +1,92 @@ +import { + AlertDialog, + AlertDialogBody, + AlertDialogContent, + AlertDialogFooter, + AlertDialogHeader, + Button, + Flex, + FormControl, + FormLabel, + Input, +} from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; +import { useAssertSingleton } from 'common/hooks/useAssertSingleton'; +import { useCanvasProjectSave } from 'features/controlLayers/hooks/useCanvasProjectSave'; +import { atom } from 'nanostores'; +import type { ChangeEvent, RefObject } from 'react'; +import { memo, useCallback, useRef, useState } from 'react'; +import { useTranslation } from 'react-i18next'; + +const $isOpen = atom(false); + +export const useSaveCanvasProjectWithDialog = () => { + return useCallback(() => { + $isOpen.set(true); + }, []); +}; + +export const SaveCanvasProjectDialog = memo(() => { + useAssertSingleton('SaveCanvasProjectDialog'); + const isOpen = useStore($isOpen); + const cancelRef = useRef(null); + + const onClose = useCallback(() => { + $isOpen.set(false); + }, []); + + return ( + + {isOpen && } + + ); +}); + +SaveCanvasProjectDialog.displayName = 'SaveCanvasProjectDialog'; + +const Content = memo(({ cancelRef }: { cancelRef: RefObject }) => { + const { t } = useTranslation(); + const { saveCanvasProject } = useCanvasProjectSave(); + const [name, setName] = useState('Canvas Project'); + + const onChange = useCallback((e: ChangeEvent) => { + setName(e.target.value); + }, []); + + const onClose = useCallback(() => { + $isOpen.set(false); + }, []); + + const onSave = useCallback(() => { + void saveCanvasProject(name); + $isOpen.set(false); + }, [name, saveCanvasProject]); + + return ( + + + {t('controlLayers.canvasProject.saveProject')} + + + + + {t('controlLayers.canvasProject.projectName')} + + + + + + + + + + + + ); +}); + +Content.displayName = 'SaveCanvasProjectDialogContent'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolChooser.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolChooser.tsx index 44efe12eb9..30d8272207 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolChooser.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolChooser.tsx @@ -3,6 +3,7 @@ import { ToolBboxButton } from 'features/controlLayers/components/Tool/ToolBboxB import { ToolBrushButton } from 'features/controlLayers/components/Tool/ToolBrushButton'; import { ToolColorPickerButton } from 'features/controlLayers/components/Tool/ToolColorPickerButton'; import { ToolGradientButton } from 'features/controlLayers/components/Tool/ToolGradientButton'; +import { ToolLassoButton } from 'features/controlLayers/components/Tool/ToolLassoButton'; import { ToolMoveButton } from 'features/controlLayers/components/Tool/ToolMoveButton'; import { ToolRectButton } from 'features/controlLayers/components/Tool/ToolRectButton'; import { ToolTextButton } from 'features/controlLayers/components/Tool/ToolTextButton'; @@ -20,6 +21,7 @@ export const ToolChooser: React.FC = () => { + diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolLassoButton.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolLassoButton.tsx new file mode 100644 index 0000000000..587e8a7223 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolLassoButton.tsx @@ -0,0 +1,34 @@ +import { IconButton, Tooltip } from '@invoke-ai/ui-library'; +import { useSelectTool, useToolIsSelected } from 'features/controlLayers/components/Tool/hooks'; +import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiLassoBold } from 'react-icons/pi'; + +export const ToolLassoButton = memo(() => { + const { t } = useTranslation(); + const isSelected = useToolIsSelected('lasso'); + const selectLasso = useSelectTool('lasso'); + + useRegisteredHotkeys({ + id: 'selectLassoTool', + category: 'canvas', + callback: selectLasso, + options: { enabled: !isSelected }, + dependencies: [isSelected, selectLasso], + }); + + return ( + + } + colorScheme={isSelected ? 'invokeBlue' : 'base'} + variant="solid" + onClick={selectLasso} + /> + + ); +}); + +ToolLassoButton.displayName = 'ToolLassoButton'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolLassoModeToggle.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolLassoModeToggle.tsx new file mode 100644 index 0000000000..63aa27c609 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolLassoModeToggle.tsx @@ -0,0 +1,47 @@ +import { ButtonGroup, IconButton, Tooltip } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { selectLassoMode, settingsLassoModeChanged } from 'features/controlLayers/store/canvasSettingsSlice'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiPolygonBold, PiScribbleLoopBold } from 'react-icons/pi'; + +export const ToolLassoModeToggle = memo(() => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const lassoMode = useAppSelector(selectLassoMode); + + const setFreehand = useCallback(() => { + dispatch(settingsLassoModeChanged('freehand')); + }, [dispatch]); + + const setPolygon = useCallback(() => { + dispatch(settingsLassoModeChanged('polygon')); + }, [dispatch]); + + return ( + + + } + colorScheme={lassoMode === 'freehand' ? 'invokeBlue' : 'base'} + variant="solid" + onClick={setFreehand} + /> + + + } + colorScheme={lassoMode === 'polygon' ? 'invokeBlue' : 'base'} + variant="solid" + onClick={setPolygon} + /> + + + ); +}); + +ToolLassoModeToggle.displayName = 'ToolLassoModeToggle'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx index bf186ed630..fc34f4331c 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx @@ -5,11 +5,13 @@ import { useToolIsSelected } from 'features/controlLayers/components/Tool/hooks' import { ToolFillColorPicker } from 'features/controlLayers/components/Tool/ToolFillColorPicker'; import { ToolGradientClipToggle } from 'features/controlLayers/components/Tool/ToolGradientClipToggle'; import { ToolGradientModeToggle } from 'features/controlLayers/components/Tool/ToolGradientModeToggle'; +import { ToolLassoModeToggle } from 'features/controlLayers/components/Tool/ToolLassoModeToggle'; import { ToolOptionsRowContainer } from 'features/controlLayers/components/Tool/ToolOptionsRowContainer'; import { ToolWidthPicker } from 'features/controlLayers/components/Tool/ToolWidthPicker'; import { CanvasToolbarFitBboxToLayersButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarFitBboxToLayersButton'; import { CanvasToolbarFitBboxToMasksButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarFitBboxToMasksButton'; import { CanvasToolbarNewSessionMenuButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarNewSessionMenuButton'; +import { CanvasToolbarProjectMenuButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarProjectMenuButton'; import { CanvasToolbarRedoButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarRedoButton'; import { CanvasToolbarResetViewButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarResetViewButton'; import { CanvasToolbarSaveToGalleryButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarSaveToGalleryButton'; @@ -31,6 +33,7 @@ export const CanvasToolbar = memo(() => { const isBrushSelected = useToolIsSelected('brush'); const isEraserSelected = useToolIsSelected('eraser'); const isTextSelected = useToolIsSelected('text'); + const isLassoSelected = useToolIsSelected('lasso'); const isGradientSelected = useToolIsSelected('gradient'); const showToolWithPicker = useMemo(() => { return !isTextSelected && (isBrushSelected || isEraserSelected); @@ -57,6 +60,11 @@ export const CanvasToolbar = memo(() => { )} + {isLassoSelected && ( + + + + )} {isTextSelected ? : showToolWithPicker && } @@ -67,6 +75,7 @@ export const CanvasToolbar = memo(() => { + diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbarProjectMenuButton.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbarProjectMenuButton.tsx new file mode 100644 index 0000000000..92cdc629ac --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbarProjectMenuButton.tsx @@ -0,0 +1,37 @@ +import { IconButton, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library'; +import { useLoadCanvasProjectWithDialog } from 'features/controlLayers/components/LoadCanvasProjectConfirmationAlertDialog'; +import { useSaveCanvasProjectWithDialog } from 'features/controlLayers/components/SaveCanvasProjectDialog'; +import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiArchiveBold, PiFileArrowDownBold, PiFileArrowUpBold } from 'react-icons/pi'; + +export const CanvasToolbarProjectMenuButton = memo(() => { + const { t } = useTranslation(); + const isBusy = useCanvasIsBusy(); + const saveCanvasProject = useSaveCanvasProjectWithDialog(); + const loadCanvasProject = useLoadCanvasProjectWithDialog(); + + return ( + + } + variant="link" + alignSelf="stretch" + /> + + } isDisabled={isBusy} onClick={saveCanvasProject}> + {t('controlLayers.canvasProject.saveProject')} + + } isDisabled={isBusy} onClick={loadCanvasProject}> + {t('controlLayers.canvasProject.loadProject')} + + + + ); +}); + +CanvasToolbarProjectMenuButton.displayName = 'CanvasToolbarProjectMenuButton'; diff --git a/invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasProjectLoad.ts b/invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasProjectLoad.ts new file mode 100644 index 0000000000..21de5d3b22 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasProjectLoad.ts @@ -0,0 +1,157 @@ +import { logger } from 'app/logging/logger'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { parseify } from 'common/util/serialize'; +import { canvasProjectRecalled } from 'features/controlLayers/store/canvasSlice'; +import { loraAllDeleted, loraRecalled } from 'features/controlLayers/store/lorasSlice'; +import { paramsRecalled } from 'features/controlLayers/store/paramsSlice'; +import { refImagesRecalled } from 'features/controlLayers/store/refImagesSlice'; +import type { LoRA, ParamsState, RefImageState } from 'features/controlLayers/store/types'; +import type { CanvasProjectState } from 'features/controlLayers/util/canvasProjectFile'; +import { + checkExistingImages, + collectImageNames, + parseManifest, + processWithConcurrencyLimit, + remapCanvasState, + remapRefImages, +} from 'features/controlLayers/util/canvasProjectFile'; +import { toast } from 'features/toast/toast'; +import JSZip from 'jszip'; +import { useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { uploadImage } from 'services/api/endpoints/images'; + +const log = logger('canvas'); + +export const useCanvasProjectLoad = () => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + + const loadCanvasProject = useCallback( + async (file: File) => { + try { + const zip = await JSZip.loadAsync(file); + + // Validate manifest + const manifestFile = zip.file('manifest.json'); + if (!manifestFile) { + throw new Error('Invalid project file: missing manifest.json'); + } + const manifestData = JSON.parse(await manifestFile.async('string')); + parseManifest(manifestData); + + // Read state files + const canvasStateFile = zip.file('canvas_state.json'); + if (!canvasStateFile) { + throw new Error('Invalid project file: missing canvas_state.json'); + } + const canvasState: CanvasProjectState = JSON.parse(await canvasStateFile.async('string')); + + const paramsFile = zip.file('params.json'); + let projectParams: ParamsState | null = null; + if (paramsFile) { + projectParams = JSON.parse(await paramsFile.async('string')); + } + + const refImagesFile = zip.file('ref_images.json'); + let refImages: RefImageState[] = []; + if (refImagesFile) { + refImages = JSON.parse(await refImagesFile.async('string')); + } + + const lorasFile = zip.file('loras.json'); + let loras: LoRA[] = []; + if (lorasFile) { + loras = JSON.parse(await lorasFile.async('string')); + } + + // Collect all image names referenced in the state + const imageNames = collectImageNames(canvasState, refImages); + + // Check which images already exist on the server + const { missing } = await checkExistingImages(imageNames); + + // Upload missing images from the ZIP + const imageNameMapping = new Map(); + const imagesFolder = zip.folder('images'); + + if (imagesFolder && missing.size > 0) { + await processWithConcurrencyLimit(Array.from(missing), async (imageName) => { + const imageFile = imagesFolder.file(imageName); + if (!imageFile) { + log.warn(`Image ${imageName} referenced but not found in ZIP`); + return; + } + + try { + const blob = await imageFile.async('blob'); + const uploadFile = new File([blob], imageName, { type: 'image/png' }); + const imageDTO = await uploadImage({ + file: uploadFile, + image_category: 'general', + is_intermediate: false, + silent: true, + }); + + // Map old name to new name (only if different) + if (imageDTO.image_name !== imageName) { + imageNameMapping.set(imageName, imageDTO.image_name); + } + } catch (error) { + log.warn({ error: parseify(error) }, `Failed to upload image ${imageName}`); + } + }); + } + + // Remap image names in state objects + const remappedCanvasState = remapCanvasState(canvasState, imageNameMapping); + const remappedRefImages = remapRefImages(refImages, imageNameMapping); + + // Dispatch state restoration + dispatch( + canvasProjectRecalled({ + rasterLayers: remappedCanvasState.rasterLayers, + controlLayers: remappedCanvasState.controlLayers, + inpaintMasks: remappedCanvasState.inpaintMasks, + regionalGuidance: remappedCanvasState.regionalGuidance, + bbox: remappedCanvasState.bbox, + selectedEntityIdentifier: remappedCanvasState.selectedEntityIdentifier, + bookmarkedEntityIdentifier: remappedCanvasState.bookmarkedEntityIdentifier, + }) + ); + + // Restore reference images + dispatch(refImagesRecalled({ entities: remappedRefImages, replace: true })); + + // Restore generation parameters + if (projectParams) { + dispatch(paramsRecalled(projectParams)); + } + + // Restore LoRAs (always clear, even if project has none) + dispatch(loraAllDeleted()); + for (const lora of loras) { + dispatch(loraRecalled({ lora })); + } + + toast({ + id: 'CANVAS_PROJECT_LOAD_SUCCESS', + title: t('controlLayers.canvasProject.loadSuccess'), + description: t('controlLayers.canvasProject.loadSuccessDesc'), + status: 'success', + }); + } catch (error) { + log.error({ error: parseify(error) }, 'Failed to load canvas project'); + toast({ + id: 'CANVAS_PROJECT_LOAD_ERROR', + title: t('controlLayers.canvasProject.loadError'), + description: String(error), + status: 'error', + }); + } + }, + [dispatch, t] + ); + + return { loadCanvasProject }; +}; diff --git a/invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasProjectSave.ts b/invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasProjectSave.ts new file mode 100644 index 0000000000..76a91a2efa --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasProjectSave.ts @@ -0,0 +1,116 @@ +import { logger } from 'app/logging/logger'; +import { useAppStore } from 'app/store/storeHooks'; +import { parseify } from 'common/util/serialize'; +import { downloadBlob } from 'features/controlLayers/konva/util'; +import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice'; +import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; +import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice'; +import { selectCanvasSlice } from 'features/controlLayers/store/selectors'; +import type { CanvasProjectManifest, CanvasProjectState } from 'features/controlLayers/util/canvasProjectFile'; +import { + CANVAS_PROJECT_EXTENSION, + CANVAS_PROJECT_VERSION, + collectImageNames, + processWithConcurrencyLimit, +} from 'features/controlLayers/util/canvasProjectFile'; +import { toast } from 'features/toast/toast'; +import JSZip from 'jszip'; +import { useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useGetAppVersionQuery } from 'services/api/endpoints/appInfo'; + +const log = logger('canvas'); + +const sanitizeFileName = (name: string): string => { + // Replace characters that are invalid in filenames + return name.replace(/[<>:"/\\|?*]/g, '_').trim() || 'canvas-project'; +}; + +export const useCanvasProjectSave = () => { + const { t } = useTranslation(); + const store = useAppStore(); + const { data: appVersion } = useGetAppVersionQuery(); + + const saveCanvasProject = useCallback( + async (name: string) => { + try { + const state = store.getState(); + const canvasState = selectCanvasSlice(state); + const paramsState = selectParamsSlice(state); + const refImagesState = selectRefImagesSlice(state); + const lorasState = selectLoRAsSlice(state); + + // Build the canvas project state + const projectState: CanvasProjectState = { + rasterLayers: canvasState.rasterLayers.entities, + controlLayers: canvasState.controlLayers.entities, + inpaintMasks: canvasState.inpaintMasks.entities, + regionalGuidance: canvasState.regionalGuidance.entities, + bbox: canvasState.bbox, + selectedEntityIdentifier: canvasState.selectedEntityIdentifier, + bookmarkedEntityIdentifier: canvasState.bookmarkedEntityIdentifier, + }; + + // Collect all image names referenced in the state + const imageNames = collectImageNames(projectState, refImagesState.entities); + + // Build ZIP + const zip = new JSZip(); + + // Add manifest + const manifest: CanvasProjectManifest = { + version: CANVAS_PROJECT_VERSION, + appVersion: appVersion?.version ?? 'unknown', + createdAt: new Date().toISOString(), + name, + }; + zip.file('manifest.json', JSON.stringify(manifest, null, 2)); + + // Add state files + zip.file('canvas_state.json', JSON.stringify(projectState, null, 2)); + zip.file('params.json', JSON.stringify(paramsState, null, 2)); + zip.file('ref_images.json', JSON.stringify(refImagesState.entities, null, 2)); + zip.file('loras.json', JSON.stringify(lorasState.loras, null, 2)); + + // Fetch and add images + const imagesFolder = zip.folder('images')!; + await processWithConcurrencyLimit(Array.from(imageNames), async (imageName) => { + try { + const response = await fetch(`/api/v1/images/i/${imageName}/full`); + if (!response.ok) { + log.warn(`Failed to fetch image ${imageName}: ${response.status}`); + return; + } + const blob = await response.blob(); + imagesFolder.file(imageName, blob); + } catch (error) { + log.warn({ error: parseify(error) }, `Failed to fetch image ${imageName}`); + } + }); + + // Generate ZIP blob and trigger download + const blob = await zip.generateAsync({ type: 'blob' }); + const fileName = `${sanitizeFileName(name)}${CANVAS_PROJECT_EXTENSION}`; + downloadBlob(blob, fileName); + + toast({ + id: 'CANVAS_PROJECT_SAVE_SUCCESS', + title: t('controlLayers.canvasProject.saveSuccess'), + description: t('controlLayers.canvasProject.saveSuccessDesc', { count: imageNames.size }), + status: 'success', + }); + } catch (error) { + log.error({ error: parseify(error) }, 'Failed to save canvas project'); + toast({ + id: 'CANVAS_PROJECT_SAVE_ERROR', + title: t('controlLayers.canvasProject.saveError'), + description: String(error), + status: 'error', + }); + } + }, + [appVersion?.version, store, t] + ); + + return { saveCanvasProject }; +}; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer.ts index ef5cee8d89..9941761a2e 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer.ts @@ -8,6 +8,7 @@ import { CanvasObjectEraserLine } from 'features/controlLayers/konva/CanvasObjec import { CanvasObjectEraserLineWithPressure } from 'features/controlLayers/konva/CanvasObject/CanvasObjectEraserLineWithPressure'; import { CanvasObjectGradient } from 'features/controlLayers/konva/CanvasObject/CanvasObjectGradient'; import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage'; +import { CanvasObjectLasso } from 'features/controlLayers/konva/CanvasObject/CanvasObjectLasso'; import { CanvasObjectRect } from 'features/controlLayers/konva/CanvasObject/CanvasObjectRect'; import type { AnyObjectRenderer, AnyObjectState } from 'features/controlLayers/konva/CanvasObject/types'; import { getPrefixedId } from 'features/controlLayers/konva/util'; @@ -152,6 +153,15 @@ export class CanvasEntityBufferObjectRenderer extends CanvasModuleBase { this.konva.group.add(this.renderer.konva.group); } + didRender = this.renderer.update(this.state, true); + } else if (this.state.type === 'lasso') { + assert(this.renderer instanceof CanvasObjectLasso || !this.renderer); + + if (!this.renderer) { + this.renderer = new CanvasObjectLasso(this.state, this); + this.konva.group.add(this.renderer.konva.group); + } + didRender = this.renderer.update(this.state, true); } else if (this.state.type === 'gradient') { assert(this.renderer instanceof CanvasObjectGradient || !this.renderer); @@ -247,6 +257,9 @@ export class CanvasEntityBufferObjectRenderer extends CanvasModuleBase { case 'rect': this.manager.stateApi.addRect({ entityIdentifier, rect: this.state }); break; + case 'lasso': + this.manager.stateApi.addLasso({ entityIdentifier, lasso: this.state }); + break; case 'gradient': this.manager.stateApi.addGradient({ entityIdentifier, gradient: this.state }); break; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts index 1498a6cbb5..903ccaa772 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts @@ -10,6 +10,7 @@ import { CanvasObjectEraserLine } from 'features/controlLayers/konva/CanvasObjec import { CanvasObjectEraserLineWithPressure } from 'features/controlLayers/konva/CanvasObject/CanvasObjectEraserLineWithPressure'; import { CanvasObjectGradient } from 'features/controlLayers/konva/CanvasObject/CanvasObjectGradient'; import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage'; +import { CanvasObjectLasso } from 'features/controlLayers/konva/CanvasObject/CanvasObjectLasso'; import { CanvasObjectRect } from 'features/controlLayers/konva/CanvasObject/CanvasObjectRect'; import type { AnyObjectRenderer, AnyObjectState } from 'features/controlLayers/konva/CanvasObject/types'; import { LightnessToAlphaFilter } from 'features/controlLayers/konva/filters'; @@ -397,6 +398,16 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase { this.konva.objectGroup.add(renderer.konva.group); } + didRender = renderer.update(objectState, force || isFirstRender); + } else if (objectState.type === 'lasso') { + assert(renderer instanceof CanvasObjectLasso || !renderer); + + if (!renderer) { + renderer = new CanvasObjectLasso(objectState, this); + this.renderers.set(renderer.id, renderer); + this.konva.objectGroup.add(renderer.konva.group); + } + didRender = renderer.update(objectState, force || isFirstRender); } else if (objectState.type === 'gradient') { assert(renderer instanceof CanvasObjectGradient || !renderer); @@ -433,17 +444,21 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase { * these visually transparent shapes in its calculation: * * - Eraser lines, which are normal lines with a globalCompositeOperation of 'destination-out'. + * - Subtracting lasso shapes, which use a globalCompositeOperation of 'destination-out'. * - Clipped portions of any shape. * - Images, which may have transparent areas. */ needsPixelBbox = (): boolean => { let needsPixelBbox = false; for (const renderer of this.renderers.values()) { - const isEraserLine = renderer instanceof CanvasObjectEraserLine; + const isEraserLine = + renderer instanceof CanvasObjectEraserLine || renderer instanceof CanvasObjectEraserLineWithPressure; + const isSubtractingLasso = + renderer instanceof CanvasObjectLasso && renderer.state.compositeOperation === 'destination-out'; const isImage = renderer instanceof CanvasObjectImage; const imageIgnoresTransparency = isImage && renderer.state.usePixelBbox === false; const hasClip = renderer instanceof CanvasObjectBrushLine && renderer.state.clip; - if (isEraserLine || hasClip || (isImage && !imageIgnoresTransparency)) { + if (isEraserLine || isSubtractingLasso || hasClip || (isImage && !imageIgnoresTransparency)) { needsPixelBbox = true; break; } diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectLasso.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectLasso.ts new file mode 100644 index 0000000000..ad433a9f67 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectLasso.ts @@ -0,0 +1,85 @@ +import { deepClone } from 'common/util/deepClone'; +import type { CanvasEntityBufferObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer'; +import type { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer'; +import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; +import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'; +import type { CanvasLassoState } from 'features/controlLayers/store/types'; +import Konva from 'konva'; +import type { Logger } from 'roarr'; + +export class CanvasObjectLasso extends CanvasModuleBase { + readonly type = 'object_lasso'; + readonly id: string; + readonly path: string[]; + readonly parent: CanvasEntityObjectRenderer | CanvasEntityBufferObjectRenderer; + readonly manager: CanvasManager; + readonly log: Logger; + + state: CanvasLassoState; + konva: { + group: Konva.Group; + line: Konva.Line; + }; + + constructor(state: CanvasLassoState, parent: CanvasEntityObjectRenderer | CanvasEntityBufferObjectRenderer) { + super(); + this.id = state.id; + this.parent = parent; + this.manager = parent.manager; + this.path = this.manager.buildPath(this); + this.log = this.manager.buildLogger(this); + + this.log.debug({ state }, 'Creating module'); + + this.konva = { + group: new Konva.Group({ + name: `${this.type}:group`, + listening: false, + }), + line: new Konva.Line({ + name: `${this.type}:line`, + listening: false, + closed: true, + fill: 'white', + strokeEnabled: false, + perfectDrawEnabled: false, + }), + }; + this.konva.group.add(this.konva.line); + this.state = state; + } + + update(state: CanvasLassoState, force = false): boolean { + if (force || this.state !== state) { + this.log.trace({ state }, 'Updating lasso'); + this.konva.line.setAttrs({ + points: state.points, + globalCompositeOperation: state.compositeOperation, + }); + this.state = state; + return true; + } + + return false; + } + + setVisibility(isVisible: boolean): void { + this.log.trace({ isVisible }, 'Setting lasso visibility'); + this.konva.group.visible(isVisible); + } + + destroy = () => { + this.log.debug('Destroying module'); + this.konva.group.destroy(); + }; + + repr = () => { + return { + id: this.id, + type: this.type, + path: this.path, + parent: this.parent.id, + state: deepClone(this.state), + }; + }; +} diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/types.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/types.ts index 31a2bfee07..f193c0b391 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/types.ts @@ -4,6 +4,7 @@ import type { CanvasObjectEraserLine } from 'features/controlLayers/konva/Canvas import type { CanvasObjectEraserLineWithPressure } from 'features/controlLayers/konva/CanvasObject/CanvasObjectEraserLineWithPressure'; import type { CanvasObjectGradient } from 'features/controlLayers/konva/CanvasObject/CanvasObjectGradient'; import type { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage'; +import type { CanvasObjectLasso } from 'features/controlLayers/konva/CanvasObject/CanvasObjectLasso'; import type { CanvasObjectRect } from 'features/controlLayers/konva/CanvasObject/CanvasObjectRect'; import type { CanvasBrushLineState, @@ -12,6 +13,7 @@ import type { CanvasEraserLineWithPressureState, CanvasGradientState, CanvasImageState, + CanvasLassoState, CanvasRectState, } from 'features/controlLayers/store/types'; @@ -25,6 +27,7 @@ export type AnyObjectRenderer = | CanvasObjectEraserLine | CanvasObjectEraserLineWithPressure | CanvasObjectRect + | CanvasObjectLasso | CanvasObjectImage | CanvasObjectGradient; /** @@ -37,4 +40,5 @@ export type AnyObjectState = | CanvasEraserLineWithPressureState | CanvasImageState | CanvasRectState + | CanvasLassoState | CanvasGradientState; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts index e00ade1f8b..7d4c76b0c0 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts @@ -21,6 +21,7 @@ import { entityBrushLineAdded, entityEraserLineAdded, entityGradientAdded, + entityLassoAdded, entityMovedBy, entityMovedTo, entityRasterized, @@ -43,6 +44,7 @@ import type { EntityEraserLineAddedPayload, EntityGradientAddedPayload, EntityIdentifierPayload, + EntityLassoAddedPayload, EntityMovedByPayload, EntityMovedToPayload, EntityRasterizedPayload, @@ -175,6 +177,13 @@ export class CanvasStateApiModule extends CanvasModuleBase { this.store.dispatch(entityRectAdded(arg)); }; + /** + * Adds a lasso object to an entity, pushing state to redux. + */ + addLasso = (arg: EntityLassoAddedPayload) => { + this.store.dispatch(entityLassoAdded(arg)); + }; + /** * Adds a gradient to an entity, pushing state to redux. */ diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasLassoToolModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasLassoToolModule.ts new file mode 100644 index 0000000000..12f2638abc --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasLassoToolModule.ts @@ -0,0 +1,566 @@ +import { rgbaColorToString } from 'common/util/colorCodeTransformers'; +import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; +import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'; +import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule'; +import { getPrefixedId, isDistanceMoreThanMin, offsetCoord } from 'features/controlLayers/konva/util'; +import type { Coordinate } from 'features/controlLayers/store/types'; +import { simplifyFlatNumbersArray } from 'features/controlLayers/util/simplify'; +import Konva from 'konva'; +import type { KonvaEventObject } from 'konva/lib/Node'; +import type { Logger } from 'roarr'; + +type CanvasLassoToolModuleConfig = { + PREVIEW_STROKE_COLOR: string; + PREVIEW_FILL_COLOR: string; + PREVIEW_STROKE_WIDTH_PX: number; + START_POINT_RADIUS_PX: number; + START_POINT_STROKE_WIDTH_PX: number; + START_POINT_HOVER_RADIUS_DELTA_PX: number; + POLYGON_CLOSE_RADIUS_PX: number; + MIN_FREEHAND_POINT_DISTANCE_PX: number; + MAX_FREEHAND_SEGMENT_LENGTH_PX: number; + FREEHAND_SIMPLIFY_MIN_POINTS: number; + FREEHAND_SIMPLIFY_TOLERANCE: number; +}; + +const DEFAULT_CONFIG: CanvasLassoToolModuleConfig = { + PREVIEW_STROKE_COLOR: rgbaColorToString({ r: 90, g: 175, b: 255, a: 1 }), + PREVIEW_FILL_COLOR: rgbaColorToString({ r: 90, g: 175, b: 255, a: 0.2 }), + PREVIEW_STROKE_WIDTH_PX: 1.5, + START_POINT_RADIUS_PX: 4, + START_POINT_STROKE_WIDTH_PX: 2, + START_POINT_HOVER_RADIUS_DELTA_PX: 2, + POLYGON_CLOSE_RADIUS_PX: 10, + MIN_FREEHAND_POINT_DISTANCE_PX: 1, + MAX_FREEHAND_SEGMENT_LENGTH_PX: 2, + FREEHAND_SIMPLIFY_MIN_POINTS: 200, + FREEHAND_SIMPLIFY_TOLERANCE: 0.6, +}; + +export class CanvasLassoToolModule extends CanvasModuleBase { + readonly type = 'lasso_tool'; + readonly id: string; + readonly path: string[]; + readonly parent: CanvasToolModule; + readonly manager: CanvasManager; + readonly log: Logger; + + config: CanvasLassoToolModuleConfig = DEFAULT_CONFIG; + + private freehandPoints: Coordinate[] = []; + private polygonPoints: Coordinate[] = []; + private polygonPointer: Coordinate | null = null; + private isDrawingFreehand = false; + + konva: { + group: Konva.Group; + fillShape: Konva.Line; + strokeShape: Konva.Line; + startPointIndicator: Konva.Circle; + }; + + constructor(parent: CanvasToolModule) { + super(); + this.id = getPrefixedId(this.type); + this.parent = parent; + this.manager = this.parent.manager; + this.path = this.manager.buildPath(this); + this.log = this.manager.buildLogger(this); + this.log.debug('Creating module'); + + this.konva = { + group: new Konva.Group({ name: `${this.type}:group`, listening: false }), + fillShape: new Konva.Line({ + name: `${this.type}:fill_shape`, + listening: false, + closed: true, + fill: this.config.PREVIEW_FILL_COLOR, + strokeEnabled: false, + visible: false, + perfectDrawEnabled: false, + }), + strokeShape: new Konva.Line({ + name: `${this.type}:stroke_shape`, + listening: false, + closed: false, + stroke: this.config.PREVIEW_STROKE_COLOR, + strokeWidth: this.config.PREVIEW_STROKE_WIDTH_PX, + lineCap: 'round', + lineJoin: 'round', + fillEnabled: false, + visible: false, + perfectDrawEnabled: false, + }), + startPointIndicator: new Konva.Circle({ + name: `${this.type}:start_point_indicator`, + listening: false, + fillEnabled: false, + stroke: this.config.PREVIEW_STROKE_COLOR, + visible: false, + perfectDrawEnabled: false, + }), + }; + + this.konva.group.add(this.konva.fillShape); + this.konva.group.add(this.konva.strokeShape); + this.konva.group.add(this.konva.startPointIndicator); + } + + syncCursorStyle = () => { + if (!this.parent.getCanDraw()) { + this.manager.stage.setCursor('not-allowed'); + return; + } + this.manager.stage.setCursor('crosshair'); + }; + + render = () => { + const tool = this.parent.$tool.get(); + const isTemporaryViewSwitch = tool === 'view' && this.parent.$toolBuffer.get() === 'lasso'; + if (tool !== 'lasso' && !isTemporaryViewSwitch) { + this.hidePreview(); + return; + } + + if (tool === 'lasso') { + this.syncCursorStyle(); + } + this.syncPreview(); + }; + + onToolChanged = () => { + const tool = this.parent.$tool.get(); + const isTemporaryViewSwitch = tool === 'view' && this.parent.$toolBuffer.get() === 'lasso'; + if (tool !== 'lasso' && !isTemporaryViewSwitch) { + this.reset(); + } + }; + + hasActiveSession = (): boolean => { + return this.isDrawingFreehand || this.freehandPoints.length > 0 || this.polygonPoints.length > 0; + }; + + onStagePointerDown = (e: KonvaEventObject) => { + const cursorPos = this.parent.$cursorPos.get(); + if (!cursorPos) { + return; + } + + const lassoMode = this.manager.stateApi.getSettings().lassoMode; + const point = cursorPos.relative; + + // Keep middle click for pan and right click for context menu. + if (e.evt.button !== 0) { + return; + } + + if (lassoMode === 'freehand') { + if (!this.parent.$isPrimaryPointerDown.get()) { + return; + } + + this.polygonPoints = []; + this.polygonPointer = null; + this.freehandPoints = [point]; + this.isDrawingFreehand = true; + this.syncPreview(); + return; + } + + this.freehandPoints = []; + this.isDrawingFreehand = false; + + if (this.polygonPoints.length === 0) { + this.polygonPoints = [point]; + this.polygonPointer = point; + this.syncPreview(); + return; + } + + const startPoint = this.polygonPoints[0]; + if (!startPoint) { + return; + } + + if ( + this.polygonPoints.length >= 3 && + Math.hypot(point.x - startPoint.x, point.y - startPoint.y) <= this.getPolygonCloseRadius() + ) { + this.commitContour(this.polygonPoints); + this.reset(); + return; + } + + const snappedPoint = this.getPolygonPoint(point, e.evt.shiftKey); + this.polygonPoints = [...this.polygonPoints, snappedPoint]; + this.polygonPointer = snappedPoint; + this.syncPreview(); + }; + + onStagePointerMove = (_e: KonvaEventObject) => { + this.handlePointerMove(_e.evt.shiftKey); + }; + + onWindowPointerMove = (e: PointerEvent) => { + this.handlePointerMove(e.shiftKey); + }; + + onStagePointerUp = (_e: KonvaEventObject) => { + const lassoMode = this.manager.stateApi.getSettings().lassoMode; + if (lassoMode !== 'freehand' || !this.isDrawingFreehand) { + return; + } + + this.commitContour(this.freehandPoints, true); + this.reset(); + }; + + onWindowPointerUp = () => { + const lassoMode = this.manager.stateApi.getSettings().lassoMode; + if (lassoMode !== 'freehand' || !this.isDrawingFreehand) { + return; + } + + this.commitContour(this.freehandPoints, true); + this.reset(); + }; + + reset = () => { + this.freehandPoints = []; + this.polygonPoints = []; + this.polygonPointer = null; + this.isDrawingFreehand = false; + this.hidePreview(); + }; + + private handlePointerMove = (shouldSnap: boolean) => { + const cursorPos = this.parent.$cursorPos.get(); + if (!cursorPos) { + return; + } + + const lassoMode = this.manager.stateApi.getSettings().lassoMode; + const point = cursorPos.relative; + + if (lassoMode === 'freehand') { + if (!this.isDrawingFreehand || !this.parent.$isPrimaryPointerDown.get()) { + return; + } + + const minDistance = this.manager.stage.unscale(this.config.MIN_FREEHAND_POINT_DISTANCE_PX); + const lastPoint = this.freehandPoints.at(-1) ?? null; + if (!isDistanceMoreThanMin(point, lastPoint, minDistance)) { + return; + } + this.appendFreehandPoint(point); + this.syncPreview(); + return; + } + + if (this.polygonPoints.length > 0) { + this.polygonPointer = this.getPolygonPoint(point, shouldSnap); + this.syncPreview(); + } + }; + + private appendFreehandPoint = (point: Coordinate) => { + const lastPoint = this.freehandPoints.at(-1) ?? null; + if (!lastPoint) { + this.freehandPoints.push(point); + return; + } + + const maxSegmentLength = this.manager.stage.unscale(this.config.MAX_FREEHAND_SEGMENT_LENGTH_PX); + const dx = point.x - lastPoint.x; + const dy = point.y - lastPoint.y; + const distance = Math.hypot(dx, dy); + + if (distance <= maxSegmentLength) { + this.freehandPoints.push(point); + return; + } + + const steps = Math.ceil(distance / maxSegmentLength); + for (let i = 1; i <= steps; i++) { + const t = i / steps; + this.freehandPoints.push({ + x: lastPoint.x + dx * t, + y: lastPoint.y + dy * t, + }); + } + }; + + private hidePreview = () => { + this.konva.strokeShape.visible(false); + this.konva.fillShape.visible(false); + this.konva.startPointIndicator.visible(false); + }; + + private syncPreview = () => { + const lassoMode = this.manager.stateApi.getSettings().lassoMode; + const stageScale = this.manager.stage.getScale(); + const strokeWidth = this.config.PREVIEW_STROKE_WIDTH_PX / stageScale; + + let points: Coordinate[] = []; + if (lassoMode === 'freehand') { + points = this.freehandPoints; + } else { + points = [...this.polygonPoints]; + if (this.polygonPointer) { + points.push(this.polygonPointer); + } + } + + if (points.length < 1) { + this.hidePreview(); + return; + } + + const flat = points.flatMap((point) => [point.x, point.y]); + this.konva.strokeShape.setAttrs({ + points: flat, + strokeWidth, + visible: true, + }); + + if (points.length >= 3) { + this.konva.fillShape.setAttrs({ + points: flat, + visible: true, + }); + } else { + this.konva.fillShape.visible(false); + } + + if (lassoMode === 'polygon' && this.polygonPoints.length > 0) { + const startPoint = this.polygonPoints[0]; + if (startPoint) { + const isHoveringStartPoint = this.getIsHoveringStartPoint(startPoint); + const baseRadius = this.manager.stage.unscale(this.config.START_POINT_RADIUS_PX); + this.konva.startPointIndicator.setAttrs({ + x: startPoint.x, + y: startPoint.y, + radius: + baseRadius + + (isHoveringStartPoint ? this.manager.stage.unscale(this.config.START_POINT_HOVER_RADIUS_DELTA_PX) : 0), + strokeWidth: this.manager.stage.unscale(this.config.START_POINT_STROKE_WIDTH_PX), + visible: true, + }); + } + } else { + this.konva.startPointIndicator.visible(false); + } + }; + + private getPolygonCloseRadius = (): number => { + return this.manager.stage.unscale(this.config.POLYGON_CLOSE_RADIUS_PX); + }; + + private getIsHoveringStartPoint = (startPoint: Coordinate): boolean => { + if (this.polygonPoints.length < 3) { + return false; + } + + const pointerPoint = this.parent.$cursorPos.get()?.relative; + if (!pointerPoint) { + return false; + } + + return Math.hypot(pointerPoint.x - startPoint.x, pointerPoint.y - startPoint.y) <= this.getPolygonCloseRadius(); + }; + + private getPolygonPoint = (point: Coordinate, shouldSnap: boolean): Coordinate => { + if (!shouldSnap) { + return point; + } + + const lastPoint = this.polygonPoints.at(-1); + if (!lastPoint) { + return point; + } + + const dx = point.x - lastPoint.x; + const dy = point.y - lastPoint.y; + const distance = Math.hypot(dx, dy); + + if (distance === 0) { + return point; + } + + const SNAP_ANGLE = Math.PI / 4; + const angle = Math.atan2(dy, dx); + const snappedAngle = Math.round(angle / SNAP_ANGLE) * SNAP_ANGLE; + + const snappedPoint = { + x: lastPoint.x + Math.cos(snappedAngle) * distance, + y: lastPoint.y + Math.sin(snappedAngle) * distance, + }; + + return this.alignPointToStart(snappedPoint); + }; + + private alignPointToStart = (point: Coordinate): Coordinate => { + if (this.polygonPoints.length < 2) { + return point; + } + + const startPoint = this.polygonPoints[0]; + if (!startPoint) { + return point; + } + + const alignThreshold = this.getPolygonCloseRadius(); + const deltaX = Math.abs(point.x - startPoint.x); + const deltaY = Math.abs(point.y - startPoint.y); + const canAlignX = deltaX <= alignThreshold; + const canAlignY = deltaY <= alignThreshold; + + if (!canAlignX && !canAlignY) { + return point; + } + + if (canAlignX && canAlignY) { + if (deltaX <= deltaY) { + return { x: startPoint.x, y: point.y }; + } + return { x: point.x, y: startPoint.y }; + } + + if (canAlignX) { + return { x: startPoint.x, y: point.y }; + } + + return { x: point.x, y: startPoint.y }; + }; + + private closeContour = (points: Coordinate[]): Coordinate[] => { + if (points.length === 0) { + return []; + } + + const start = points[0]; + const end = points.at(-1); + if (!start || !end) { + return points; + } + + if (start.x === end.x && start.y === end.y) { + return points; + } + + return [...points, start]; + }; + + private commitContour = (points: Coordinate[], simplifyFreehand: boolean = false) => { + const contourPoints = simplifyFreehand ? this.simplifyFreehandContour(points) : points; + if (contourPoints.length < 3) { + return; + } + + const closedPoints = this.closeContour(contourPoints); + if (closedPoints.length < 4) { + return; + } + + let targetMaskId = this.getActiveInpaintMaskId(); + if (!targetMaskId) { + this.manager.stateApi.addInpaintMask({ isSelected: true }); + targetMaskId = this.getActiveInpaintMaskId(); + } + + if (!targetMaskId) { + return; + } + + const targetMaskState = this.manager.stateApi + .getInpaintMasksState() + .entities.find((entity) => entity.id === targetMaskId); + if (!targetMaskState) { + return; + } + + const normalizedPoints = closedPoints.flatMap((point) => { + const normalizedPoint = offsetCoord(point, targetMaskState.position); + return [normalizedPoint.x, normalizedPoint.y]; + }); + + this.manager.stateApi.addLasso({ + entityIdentifier: { type: 'inpaint_mask', id: targetMaskId }, + lasso: { + id: getPrefixedId('lasso'), + type: 'lasso', + points: normalizedPoints, + compositeOperation: + this.manager.stateApi.$ctrlKey.get() || this.manager.stateApi.$metaKey.get() + ? 'destination-out' + : 'source-over', + }, + }); + }; + + private simplifyFreehandContour = (points: Coordinate[]): Coordinate[] => { + if (points.length < this.config.FREEHAND_SIMPLIFY_MIN_POINTS) { + return points; + } + + const flatPoints = points.flatMap((point) => [point.x, point.y]); + const simplifiedFlatPoints = simplifyFlatNumbersArray(flatPoints, { + tolerance: this.config.FREEHAND_SIMPLIFY_TOLERANCE, + highestQuality: true, + }); + if (simplifiedFlatPoints.length < 6) { + return points; + } + + const simplifiedPoints = this.flatNumbersToCoords(simplifiedFlatPoints); + if (simplifiedPoints.length < 3) { + return points; + } + + return simplifiedPoints; + }; + + private flatNumbersToCoords = (points: number[]): Coordinate[] => { + const coords: Coordinate[] = []; + for (let i = 0; i < points.length; i += 2) { + const x = points[i]; + const y = points[i + 1]; + if (x === undefined || y === undefined) { + continue; + } + coords.push({ x, y }); + } + return coords; + }; + + private getActiveInpaintMaskId = (): string | null => { + const canvasState = this.manager.stateApi.getCanvasState(); + const selectedEntityIdentifier = canvasState.selectedEntityIdentifier; + if (selectedEntityIdentifier?.type === 'inpaint_mask') { + const selectedMask = canvasState.inpaintMasks.entities.find( + (entity) => entity.id === selectedEntityIdentifier.id + ); + if (selectedMask?.isEnabled) { + return selectedMask.id; + } + // If the selected mask is disabled, commit to a new mask instead. + return null; + } + + const inpaintMasks = canvasState.inpaintMasks.entities; + const activeMask = [...inpaintMasks].reverse().find((entity) => entity.isEnabled); + return activeMask?.id ?? null; + }; + + repr = () => { + return { + id: this.id, + type: this.type, + path: this.path, + freehandPoints: this.freehandPoints, + polygonPoints: this.polygonPoints, + polygonPointer: this.polygonPointer, + isDrawingFreehand: this.isDrawingFreehand, + }; + }; +} diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasToolModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasToolModule.ts index c25a714bad..668ac7be3b 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasToolModule.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasToolModule.ts @@ -5,6 +5,7 @@ import { CanvasBrushToolModule } from 'features/controlLayers/konva/CanvasTool/C import { CanvasColorPickerToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasColorPickerToolModule'; import { CanvasEraserToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasEraserToolModule'; import { CanvasGradientToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasGradientToolModule'; +import { CanvasLassoToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasLassoToolModule'; import { CanvasMoveToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasMoveToolModule'; import { CanvasRectToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasRectToolModule'; import { CanvasTextToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasTextToolModule'; @@ -38,6 +39,7 @@ Konva.dragButtons = [0]; const KEY_ESCAPE = 'Escape'; const KEY_SPACE = ' '; const KEY_ALT = 'Alt'; +const CODE_SPACE = 'Space'; type CanvasToolModuleConfig = { BRUSH_SPACING_TARGET_SCALE: number; @@ -62,6 +64,7 @@ export class CanvasToolModule extends CanvasModuleBase { brush: CanvasBrushToolModule; eraser: CanvasEraserToolModule; rect: CanvasRectToolModule; + lasso: CanvasLassoToolModule; gradient: CanvasGradientToolModule; colorPicker: CanvasColorPickerToolModule; bbox: CanvasBboxToolModule; @@ -121,6 +124,7 @@ export class CanvasToolModule extends CanvasModuleBase { brush: new CanvasBrushToolModule(this), eraser: new CanvasEraserToolModule(this), rect: new CanvasRectToolModule(this), + lasso: new CanvasLassoToolModule(this), gradient: new CanvasGradientToolModule(this), colorPicker: new CanvasColorPickerToolModule(this), bbox: new CanvasBboxToolModule(this), @@ -139,15 +143,26 @@ export class CanvasToolModule extends CanvasModuleBase { this.konva.group.add(this.tools.colorPicker.konva.group); this.konva.group.add(this.tools.text.konva.group); this.konva.group.add(this.tools.bbox.konva.group); + this.konva.group.add(this.tools.lasso.konva.group); this.subscriptions.add(this.manager.stage.$stageAttrs.listen(this.render)); this.subscriptions.add(this.manager.$isBusy.listen(this.render)); this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectCanvasSettingsSlice, this.render)); this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectCanvasSlice, this.render)); this.subscriptions.add( - this.$tool.listen(() => { - // On tool switch, reset mouse state - this.manager.tool.$isPrimaryPointerDown.set(false); + this.$tool.listen((tool, previousTool) => { + // Preserve pointer state during temporary view switching so lasso sessions can freeze/resume on space. + const shouldPreservePointerState = + this.$toolBuffer.get() === 'lasso' && + this.tools.lasso.hasActiveSession() && + ((previousTool === 'lasso' && tool === 'view') || (previousTool === 'view' && tool === 'lasso')); + + if (!shouldPreservePointerState) { + // On tool switch, reset mouse state + this.manager.tool.$isPrimaryPointerDown.set(false); + } + + this.tools.lasso.onToolChanged(); void this.tools.text.onToolChanged(); this.render(); }) @@ -189,6 +204,8 @@ export class CanvasToolModule extends CanvasModuleBase { this.tools.colorPicker.syncCursorStyle(); } else if (tool === 'text') { this.tools.text.syncCursorStyle(); + } else if (tool === 'lasso') { + this.tools.lasso.syncCursorStyle(); } else if (selectedEntityAdapter) { if (selectedEntityAdapter.$isDisabled.get()) { stage.setCursor('not-allowed'); @@ -222,6 +239,7 @@ export class CanvasToolModule extends CanvasModuleBase { this.tools.colorPicker.render(); this.tools.text.render(); this.tools.bbox.render(); + this.tools.lasso.render(); }; syncCursorPositions = () => { @@ -235,6 +253,19 @@ export class CanvasToolModule extends CanvasModuleBase { this.$cursorPos.set({ relative, absolute }); }; + syncCursorPositionsFromWindowEvent = (e: PointerEvent): boolean => { + this.konva.stage.setPointersPositions(e); + const relative = this.konva.stage.getRelativePointerPosition(); + const absolute = this.konva.stage.getPointerPosition(); + + if (!relative || !absolute) { + return false; + } + + this.$cursorPos.set({ relative, absolute }); + return true; + }; + getClip = ( entity: CanvasRegionalGuidanceState | CanvasControlLayerState | CanvasRasterLayerState | CanvasInpaintMaskState ) => { @@ -274,6 +305,7 @@ export class CanvasToolModule extends CanvasModuleBase { window.addEventListener('keydown', this.onKeyDown); window.addEventListener('keyup', this.onKeyUp); + window.addEventListener('pointermove', this.onWindowPointerMove); window.addEventListener('pointerup', this.onWindowPointerUp); window.addEventListener('blur', this.onWindowBlur); @@ -289,6 +321,7 @@ export class CanvasToolModule extends CanvasModuleBase { window.removeEventListener('keydown', this.onKeyDown); window.removeEventListener('keyup', this.onKeyUp); + window.removeEventListener('pointermove', this.onWindowPointerMove); window.removeEventListener('pointerup', this.onWindowPointerUp); window.removeEventListener('blur', this.onWindowBlur); }; @@ -316,6 +349,18 @@ export class CanvasToolModule extends CanvasModuleBase { return true; } + if (tool === 'lasso') { + if (this.manager.$isBusy.get()) { + return false; + } + + if (this.manager.stage.getIsDragging()) { + return false; + } + + return true; + } + if (this.manager.stateApi.getRenderedEntityCount() === 0) { return false; } @@ -407,6 +452,8 @@ export class CanvasToolModule extends CanvasModuleBase { await this.tools.eraser.onStagePointerDown(e); } else if (tool === 'rect') { await this.tools.rect.onStagePointerDown(e); + } else if (tool === 'lasso') { + await this.tools.lasso.onStagePointerDown(e); } else if (tool === 'gradient') { await this.tools.gradient.onStagePointerDown(e); } else if (tool === 'text') { @@ -441,6 +488,8 @@ export class CanvasToolModule extends CanvasModuleBase { this.tools.eraser.onStagePointerUp(e); } else if (tool === 'rect') { this.tools.rect.onStagePointerUp(e); + } else if (tool === 'lasso') { + void this.tools.lasso.onStagePointerUp(e); } else if (tool === 'gradient') { this.tools.gradient.onStagePointerUp(e); } @@ -476,6 +525,8 @@ export class CanvasToolModule extends CanvasModuleBase { await this.tools.eraser.onStagePointerMove(e); } else if (tool === 'rect') { await this.tools.rect.onStagePointerMove(e); + } else if (tool === 'lasso') { + await this.tools.lasso.onStagePointerMove(e); } else if (tool === 'gradient') { await this.tools.gradient.onStagePointerMove(e); } else if (tool === 'text') { @@ -560,6 +611,7 @@ export class CanvasToolModule extends CanvasModuleBase { onWindowPointerUp = (_: PointerEvent) => { try { this.$isPrimaryPointerDown.set(false); + void this.tools.lasso.onWindowPointerUp(); const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter(); if (selectedEntity && selectedEntity.bufferRenderer.hasBuffer() && !this.manager.$isBusy.get()) { @@ -570,6 +622,41 @@ export class CanvasToolModule extends CanvasModuleBase { } }; + onWindowPointerMove = (e: PointerEvent) => { + const target = e.target; + if (target instanceof Node && this.manager.stage.container.contains(target)) { + return; + } + + if (this.$tool.get() !== 'lasso') { + return; + } + + if (!this.getCanDraw()) { + return; + } + + if (!this.$isPrimaryPointerDown.get()) { + return; + } + + if (!this.tools.lasso.hasActiveSession()) { + return; + } + + try { + this.$lastPointerType.set(e.pointerType); + + if (!this.syncCursorPositionsFromWindowEvent(e)) { + return; + } + + this.tools.lasso.onWindowPointerMove(e); + } finally { + this.render(); + } + }; + /** * We want to reset any "quick-switch" tool selection on window blur. Fixes an issue where you alt-tab out of the app * and the color picker tool is still active when you come back. @@ -579,6 +666,7 @@ export class CanvasToolModule extends CanvasModuleBase { }; onKeyDown = (e: KeyboardEvent) => { + const isSpaceKey = e.key === KEY_SPACE || e.code === CODE_SPACE; if (e.target instanceof HTMLInputElement || e.target instanceof HTMLTextAreaElement) { return; } @@ -600,6 +688,9 @@ export class CanvasToolModule extends CanvasModuleBase { if (e.key === KEY_ESCAPE) { // Cancel shape drawing on escape e.preventDefault(); + if (this.$tool.get() === 'lasso') { + this.tools.lasso.reset(); + } const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter(); if ( selectedEntity && @@ -612,19 +703,27 @@ export class CanvasToolModule extends CanvasModuleBase { return; } - if (e.key === KEY_SPACE) { + if (isSpaceKey) { // Select the view tool on space key down e.preventDefault(); - this.$toolBuffer.set(this.$tool.get()); - this.$tool.set('view'); + e.stopPropagation(); + const currentTool = this.$tool.get(); + this.$toolBuffer.set(currentTool); this.manager.stateApi.$spaceKey.set(true); - this.$cursorPos.set(null); + this.$tool.set('view'); + if (currentTool === 'lasso' && this.tools.lasso.hasActiveSession() && this.$isPrimaryPointerDown.get()) { + // Start panning immediately if user is already drawing with freehand lasso. + this.manager.stage.startDragging(); + } else { + this.$cursorPos.set(null); + } return; } if (e.key === KEY_ALT) { // Select the color picker on alt key down e.preventDefault(); + e.stopPropagation(); this.$toolBuffer.set(this.$tool.get()); this.$tool.set('colorPicker'); } @@ -644,9 +743,10 @@ export class CanvasToolModule extends CanvasModuleBase { return; } - if (e.key === KEY_SPACE) { + if (e.key === KEY_SPACE || e.code === CODE_SPACE) { // Revert the tool to the previous tool on space key up e.preventDefault(); + e.stopPropagation(); this.revertToolBuffer(); this.manager.stateApi.$spaceKey.set(false); return; @@ -655,6 +755,7 @@ export class CanvasToolModule extends CanvasModuleBase { if (e.key === KEY_ALT) { // Revert the tool to the previous tool on alt key up e.preventDefault(); + e.stopPropagation(); this.revertToolBuffer(); return; } @@ -684,6 +785,7 @@ export class CanvasToolModule extends CanvasModuleBase { eraser: this.tools.eraser.repr(), colorPicker: this.tools.colorPicker.repr(), rect: this.tools.rect.repr(), + lasso: this.tools.lasso.repr(), gradient: this.tools.gradient.repr(), bbox: this.tools.bbox.repr(), view: this.tools.view.repr(), diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSettingsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSettingsSlice.ts index 91428b4521..202b70e142 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSettingsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSettingsSlice.ts @@ -13,6 +13,7 @@ const zTransformSmoothingMode = z.enum(['bilinear', 'bicubic', 'hamming', 'lancz export type TransformSmoothingMode = z.infer; const zGradientType = z.enum(['linear', 'radial']); +const zLassoMode = z.enum(['freehand', 'polygon']); const zCanvasSettingsState = z.object({ /** @@ -118,6 +119,10 @@ const zCanvasSettingsState = z.object({ * Whether the gradient tool clips to the drag gesture. */ gradientClipEnabled: z.boolean().default(true), + /** + * The lasso tool mode. + */ + lassoMode: zLassoMode.default('freehand'), }); type CanvasSettingsState = z.infer; @@ -148,6 +153,7 @@ const getInitialState = (): CanvasSettingsState => ({ transformSmoothingMode: 'bicubic', gradientType: 'linear', gradientClipEnabled: true, + lassoMode: 'freehand', }); const slice = createSlice({ @@ -245,6 +251,9 @@ const slice = createSlice({ settingsGradientClipToggled: (state) => { state.gradientClipEnabled = !state.gradientClipEnabled; }, + settingsLassoModeChanged: (state, action: PayloadAction) => { + state.lassoMode = action.payload; + }, }, }); @@ -276,6 +285,7 @@ export const { settingsFillColorPickerPinnedSet, settingsGradientTypeChanged, settingsGradientClipToggled, + settingsLassoModeChanged, } = slice.actions; export const canvasSettingsSliceConfig: SliceConfig = { @@ -317,3 +327,4 @@ export const selectTransformSmoothingEnabled = createCanvasSettingsSelector( export const selectTransformSmoothingMode = createCanvasSettingsSelector((settings) => settings.transformSmoothingMode); export const selectGradientType = createCanvasSettingsSelector((settings) => settings.gradientType); export const selectGradientClipEnabled = createCanvasSettingsSelector((settings) => settings.gradientClipEnabled); +export const selectLassoMode = createCanvasSettingsSelector((settings) => settings.lassoMode); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts index 79d3963d12..fd170e19e8 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts @@ -66,6 +66,7 @@ import type { EntityEraserLineAddedPayload, EntityGradientAddedPayload, EntityIdentifierPayload, + EntityLassoAddedPayload, EntityMovedToPayload, EntityRasterizedPayload, EntityRectAddedPayload, @@ -99,6 +100,12 @@ import { makeDefaultRasterLayerAdjustments, } from './util'; +const resetInpaintMasksHiddenIfEmpty = (state: CanvasState) => { + if (state.inpaintMasks.entities.length === 0) { + state.inpaintMasks.isHidden = false; + } +}; + const slice = createSlice({ name: 'canvas', initialState: getInitialCanvasState(), @@ -1061,6 +1068,7 @@ const slice = createSlice({ (entity) => !mergedEntitiesToDelete.includes(entity.id) ); } + resetInpaintMasksHiddenIfEmpty(state); const entityIdentifier = getEntityIdentifier(entityState); if (isSelected || mergedEntitiesToDelete.length > 0) { @@ -1132,6 +1140,7 @@ const slice = createSlice({ if (replace) { // Remove the inpaint mask state.inpaintMasks.entities = state.inpaintMasks.entities.filter((layer) => layer.id !== entityIdentifier.id); + resetInpaintMasksHiddenIfEmpty(state); } // Add the new regional guidance @@ -1548,6 +1557,17 @@ const slice = createSlice({ // re-render it (reference equality check). I don't like this behaviour. entity.objects.push({ ...rect }); }, + entityLassoAdded: (state, action: PayloadAction) => { + const { entityIdentifier, lasso } = action.payload; + const entity = selectEntity(state, entityIdentifier); + if (!entity) { + return; + } + + // TODO(psyche): If we add the object without splatting, the renderer will see it as the same object and not + // re-render it (reference equality check). I don't like this behaviour. + entity.objects.push({ ...lasso }); + }, entityGradientAdded: (state, action: PayloadAction) => { const { entityIdentifier, gradient } = action.payload; const entity = selectEntity(state, entityIdentifier); @@ -1590,6 +1610,7 @@ const slice = createSlice({ break; } + resetInpaintMasksHiddenIfEmpty(state); state.selectedEntityIdentifier = selectedEntityIdentifier; }, entityArrangedForwardOne: (state, action: PayloadAction) => { @@ -1678,6 +1699,7 @@ const slice = createSlice({ break; case 'inpaint_mask': state.inpaintMasks.isHidden = !state.inpaintMasks.isHidden; + resetInpaintMasksHiddenIfEmpty(state); break; case 'regional_guidance': state.regionalGuidance.isHidden = !state.regionalGuidance.isHidden; @@ -1686,13 +1708,16 @@ const slice = createSlice({ }, allNonRasterLayersIsHiddenToggled: (state) => { const hasVisibleNonRasterLayers = - !state.controlLayers.isHidden || !state.inpaintMasks.isHidden || !state.regionalGuidance.isHidden; + (state.controlLayers.entities.length > 0 && !state.controlLayers.isHidden) || + (state.inpaintMasks.entities.length > 0 && !state.inpaintMasks.isHidden) || + (state.regionalGuidance.entities.length > 0 && !state.regionalGuidance.isHidden); const shouldHide = hasVisibleNonRasterLayers; state.controlLayers.isHidden = shouldHide; state.inpaintMasks.isHidden = shouldHide; state.regionalGuidance.isHidden = shouldHide; + resetInpaintMasksHiddenIfEmpty(state); }, allEntitiesDeleted: (state) => { // Deleting all entities is equivalent to resetting the state for each entity type @@ -1708,6 +1733,37 @@ const slice = createSlice({ state.inpaintMasks.entities = inpaintMasks; state.rasterLayers.entities = rasterLayers; state.regionalGuidance.entities = regionalGuidance; + resetInpaintMasksHiddenIfEmpty(state); + return state; + }, + canvasProjectRecalled: ( + state, + action: PayloadAction<{ + rasterLayers: CanvasRasterLayerState[]; + controlLayers: CanvasControlLayerState[]; + inpaintMasks: CanvasInpaintMaskState[]; + regionalGuidance: CanvasRegionalGuidanceState[]; + bbox: CanvasState['bbox']; + selectedEntityIdentifier: CanvasState['selectedEntityIdentifier']; + bookmarkedEntityIdentifier: CanvasState['bookmarkedEntityIdentifier']; + }> + ) => { + const { + rasterLayers, + controlLayers, + inpaintMasks, + regionalGuidance, + bbox, + selectedEntityIdentifier, + bookmarkedEntityIdentifier, + } = action.payload; + state.rasterLayers.entities = rasterLayers; + state.controlLayers.entities = controlLayers; + state.inpaintMasks.entities = inpaintMasks; + state.regionalGuidance.entities = regionalGuidance; + state.bbox = bbox; + state.selectedEntityIdentifier = selectedEntityIdentifier; + state.bookmarkedEntityIdentifier = bookmarkedEntityIdentifier; return state; }, canvasUndo: () => {}, @@ -1768,6 +1824,7 @@ const resetState = (state: CanvasState) => { export const { canvasMetadataRecalled, + canvasProjectRecalled, canvasUndo, canvasRedo, canvasClearHistory, @@ -1787,6 +1844,7 @@ export const { entityBrushLineAdded, entityEraserLineAdded, entityRectAdded, + entityLassoAdded, entityGradientAdded, // Raster layer adjustments rasterLayerAdjustmentsSet, @@ -1913,7 +1971,13 @@ export const canvasSliceConfig: SliceConfig = { }, }; -const doNotGroupMatcher = isAnyOf(entityBrushLineAdded, entityEraserLineAdded, entityRectAdded, entityGradientAdded); +const doNotGroupMatcher = isAnyOf( + entityBrushLineAdded, + entityEraserLineAdded, + entityRectAdded, + entityLassoAdded, + entityGradientAdded +); // Store rapid actions of the same type at most once every x time. // See: https://github.com/omnidan/redux-undo/blob/master/examples/throttled-drag/util/undoFilter.js diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts index 09a16f4bca..ba78c36c3f 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts @@ -471,6 +471,9 @@ const slice = createSlice({ } }, paramsReset: (state) => resetState(state), + paramsRecalled: (_state, action: PayloadAction) => { + return action.payload; + }, }, extraReducers(builder) { // Reset params state on logout to prevent user data leakage when switching users @@ -609,6 +612,7 @@ export const { syncedToOptimalDimension, paramsReset, + paramsRecalled, animaVaeModelSelected, animaQwen3EncoderModelSelected, animaT5EncoderModelSelected, diff --git a/invokeai/frontend/web/src/features/controlLayers/store/selectors.ts b/invokeai/frontend/web/src/features/controlLayers/store/selectors.ts index 5c0abfdb89..2e2ae09212 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/selectors.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/selectors.ts @@ -361,5 +361,9 @@ export const selectCanvasMetadata = createSelector( * This is used to determine the state of the toggle button that shows/hides all non-raster layers. */ export const selectNonRasterLayersIsHidden = createSelector(selectCanvasSlice, (canvas) => { - return canvas.controlLayers.isHidden && canvas.inpaintMasks.isHidden && canvas.regionalGuidance.isHidden; + const areControlLayersEffectivelyHidden = canvas.controlLayers.entities.length === 0 || canvas.controlLayers.isHidden; + const areInpaintMasksEffectivelyHidden = canvas.inpaintMasks.entities.length === 0 || canvas.inpaintMasks.isHidden; + const areRegionalGuidanceEffectivelyHidden = + canvas.regionalGuidance.entities.length === 0 || canvas.regionalGuidance.isHidden; + return areControlLayersEffectivelyHidden && areInpaintMasksEffectivelyHidden && areRegionalGuidanceEffectivelyHidden; }); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index 7247f4cf86..c16ffdbeab 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -105,7 +105,7 @@ const zIPMethodV2 = z.enum(['full', 'style', 'composition', 'style_strong', 'sty export type IPMethodV2 = z.infer; export const isIPMethodV2 = (v: unknown): v is IPMethodV2 => zIPMethodV2.safeParse(v).success; -const _zTool = z.enum(['brush', 'eraser', 'move', 'rect', 'gradient', 'view', 'bbox', 'colorPicker', 'text']); +const _zTool = z.enum(['brush', 'eraser', 'move', 'rect', 'lasso', 'gradient', 'view', 'bbox', 'colorPicker', 'text']); export type Tool = z.infer; const zPoints = z.array(z.number()).refine((points) => points.length % 2 === 0, { @@ -260,6 +260,20 @@ const zCanvasRectState = z.object({ }); export type CanvasRectState = z.infer; +const zCanvasLassoCompositeOperation = z.enum(['source-over', 'destination-out']); + +const zCanvasLassoState = z.object({ + id: zId, + type: z.literal('lasso'), + /** + * Points in the format [x1, y1, x2, y2, ...]. + * The lasso tool always commits a closed contour. + */ + points: zPoints, + compositeOperation: zCanvasLassoCompositeOperation.default('source-over'), +}); +export type CanvasLassoState = z.infer; + // Gradient state includes clip metadata so the tool can optionally clip to drag gesture. const zCanvasLinearGradientState = z.object({ id: zId, @@ -309,6 +323,7 @@ const zCanvasObjectState = z.union([ zCanvasBrushLineState, zCanvasEraserLineState, zCanvasRectState, + zCanvasLassoState, zCanvasBrushLineWithPressureState, zCanvasEraserLineWithPressureState, zCanvasGradientState, @@ -955,6 +970,7 @@ export type EntityEraserLineAddedPayload = EntityIdentifierPayload<{ eraserLine: CanvasEraserLineState | CanvasEraserLineWithPressureState; }>; export type EntityRectAddedPayload = EntityIdentifierPayload<{ rect: CanvasRectState }>; +export type EntityLassoAddedPayload = EntityIdentifierPayload<{ lasso: CanvasLassoState }>; export type EntityGradientAddedPayload = EntityIdentifierPayload<{ gradient: CanvasGradientState }>; export type EntityRasterizedPayload = EntityIdentifierPayload<{ imageObject: CanvasImageState; diff --git a/invokeai/frontend/web/src/features/controlLayers/util/canvasProjectFile.ts b/invokeai/frontend/web/src/features/controlLayers/util/canvasProjectFile.ts new file mode 100644 index 0000000000..97ea31e8bb --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/util/canvasProjectFile.ts @@ -0,0 +1,287 @@ +import { deepClone } from 'common/util/deepClone'; +import type { + CanvasControlLayerState, + CanvasInpaintMaskState, + CanvasObjectState, + CanvasRasterLayerState, + CanvasRegionalGuidanceState, + CanvasState, + CroppableImageWithDims, + ImageWithDims, + RefImageState, +} from 'features/controlLayers/store/types'; +import { getImageDTOSafe } from 'services/api/endpoints/images'; +import { z } from 'zod'; + +export const CANVAS_PROJECT_VERSION = 1; +export const CANVAS_PROJECT_EXTENSION = '.invk'; + +// #region Manifest + +const zCanvasProjectManifest = z.object({ + version: z.literal(CANVAS_PROJECT_VERSION), + appVersion: z.string(), + createdAt: z.string(), + name: z.string(), +}); +export type CanvasProjectManifest = z.infer; + +export const parseManifest = (data: unknown): CanvasProjectManifest => { + return zCanvasProjectManifest.parse(data); +}; + +// #endregion + +// #region Canvas Project State + +export type CanvasProjectState = { + rasterLayers: CanvasRasterLayerState[]; + controlLayers: CanvasControlLayerState[]; + inpaintMasks: CanvasInpaintMaskState[]; + regionalGuidance: CanvasRegionalGuidanceState[]; + bbox: CanvasState['bbox']; + selectedEntityIdentifier: CanvasState['selectedEntityIdentifier']; + bookmarkedEntityIdentifier: CanvasState['bookmarkedEntityIdentifier']; +}; + +// #endregion + +// #region Image Name Collection + +/** + * Collects image_name values from a CroppableImageWithDims (used by ref images). + */ +const collectFromCroppableImage = (image: CroppableImageWithDims | null, names: Set): void => { + if (!image) { + return; + } + names.add(image.original.image.image_name); + if (image.crop?.image) { + names.add(image.crop.image.image_name); + } +}; + +/** + * Collects image_name values from an ImageWithDims (used by regional guidance ref images). + */ +const collectFromImageWithDims = (image: ImageWithDims | null, names: Set): void => { + if (!image) { + return; + } + names.add(image.image_name); +}; + +/** + * Collects image_name values from canvas objects (brush lines, images, etc.). + */ +const collectFromObjects = (objects: CanvasObjectState[], names: Set): void => { + for (const obj of objects) { + if (obj.type === 'image' && 'image_name' in obj.image) { + names.add(obj.image.image_name); + } + } +}; + +/** + * Walks the entire canvas state + ref images and returns a deduplicated set of all image_name references. + */ +export const collectImageNames = (canvasState: CanvasProjectState, refImages: RefImageState[]): Set => { + const names = new Set(); + + // Raster layers + for (const layer of canvasState.rasterLayers) { + collectFromObjects(layer.objects, names); + } + + // Control layers + for (const layer of canvasState.controlLayers) { + collectFromObjects(layer.objects, names); + } + + // Inpaint masks + for (const mask of canvasState.inpaintMasks) { + collectFromObjects(mask.objects, names); + } + + // Regional guidance + for (const rg of canvasState.regionalGuidance) { + collectFromObjects(rg.objects, names); + for (const refImage of rg.referenceImages) { + if (refImage.config.type === 'ip_adapter' || refImage.config.type === 'flux_redux') { + collectFromImageWithDims(refImage.config.image, names); + } + } + } + + // Global reference images + for (const refImage of refImages) { + collectFromCroppableImage(refImage.config.image, names); + } + + return names; +}; + +// #endregion + +// #region Image Name Remapping + +/** + * Remaps image_name values in a CroppableImageWithDims. + */ +/** + * Remaps image_name values in a CroppableImageWithDims in-place. + * Caller is responsible for cloning beforehand. + */ +const remapCroppableImage = (image: CroppableImageWithDims | null, mapping: Map): void => { + if (!image) { + return; + } + + const newOriginalName = mapping.get(image.original.image.image_name); + if (newOriginalName) { + image.original.image.image_name = newOriginalName; + } + + if (image.crop?.image) { + const newCropName = mapping.get(image.crop.image.image_name); + if (newCropName) { + image.crop.image.image_name = newCropName; + } + } +}; + +/** + * Remaps image_name in an ImageWithDims. + */ +const remapImageWithDims = (image: ImageWithDims | null, mapping: Map): ImageWithDims | null => { + if (!image) { + return null; + } + + const result = deepClone(image); + const newName = mapping.get(result.image_name); + if (newName) { + result.image_name = newName; + } + return result; +}; + +/** + * Remaps image_name values in canvas objects. + */ +const remapObjects = (objects: CanvasObjectState[], mapping: Map): CanvasObjectState[] => { + return objects.map((obj) => { + if (obj.type === 'image' && 'image_name' in obj.image) { + const newName = mapping.get(obj.image.image_name); + if (newName) { + return { ...obj, image: { ...obj.image, image_name: newName } }; + } + } + return obj; + }); +}; + +/** + * Deep-clones canvas state and remaps all image_name values using the provided mapping. + * Only images present in the mapping are changed (images that already existed on the server are skipped). + */ +export const remapCanvasState = (canvasState: CanvasProjectState, mapping: Map): CanvasProjectState => { + if (mapping.size === 0) { + return canvasState; + } + + const result = deepClone(canvasState); + + for (const layer of result.rasterLayers) { + layer.objects = remapObjects(layer.objects, mapping); + } + + for (const layer of result.controlLayers) { + layer.objects = remapObjects(layer.objects, mapping); + } + + for (const mask of result.inpaintMasks) { + mask.objects = remapObjects(mask.objects, mapping); + } + + for (const rg of result.regionalGuidance) { + rg.objects = remapObjects(rg.objects, mapping); + for (const refImage of rg.referenceImages) { + if (refImage.config.type === 'ip_adapter' || refImage.config.type === 'flux_redux') { + refImage.config.image = remapImageWithDims(refImage.config.image, mapping); + } + } + } + + return result; +}; + +/** + * Deep-clones ref images and remaps all image_name values using the provided mapping. + */ +export const remapRefImages = (refImages: RefImageState[], mapping: Map): RefImageState[] => { + if (mapping.size === 0) { + return refImages; + } + + return refImages.map((refImage) => { + const result = deepClone(refImage); + remapCroppableImage(result.config.image, mapping); + return result; + }); +}; + +// #endregion + +// #region Concurrency + +const MAX_CONCURRENT_REQUESTS = 5; + +/** + * Processes an array of async tasks with a concurrency limit. + */ +export const processWithConcurrencyLimit = async ( + items: T[], + fn: (item: T) => Promise, + limit: number = MAX_CONCURRENT_REQUESTS +): Promise => { + let index = 0; + + const next = async (): Promise => { + while (index < items.length) { + const currentIndex = index++; + await fn(items[currentIndex]!); + } + }; + + const workers = Array.from({ length: Math.min(limit, items.length) }, () => next()); + await Promise.all(workers); +}; + +// #endregion + +// #region Image Existence Check + +/** + * Checks which images already exist on the backend server. + * Returns sets of existing and missing image names. + */ +export const checkExistingImages = async ( + imageNames: Set +): Promise<{ existing: Set; missing: Set }> => { + const existing = new Set(); + const missing = new Set(); + + await processWithConcurrencyLimit(Array.from(imageNames), async (imageName) => { + const dto = await getImageDTOSafe(imageName); + if (dto) { + existing.add(imageName); + } else { + missing.add(imageName); + } + }); + + return { existing, missing }; +}; + +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodeCmdk/AddNodeCmdk.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodeCmdk/AddNodeCmdk.tsx index 9feae5215a..4a72cad8bf 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodeCmdk/AddNodeCmdk.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodeCmdk/AddNodeCmdk.tsx @@ -1,6 +1,7 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library'; import { Box, + Button, Flex, Icon, Input, @@ -17,6 +18,7 @@ import { useAppSelector, useAppStore } from 'app/store/storeHooks'; import { CommandEmpty, CommandItem, CommandList, CommandRoot } from 'cmdk'; import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; +import { capitalize } from 'es-toolkit'; import { memoize } from 'es-toolkit/compat'; import { useBuildNode } from 'features/nodes/hooks/useBuildNode'; import { @@ -33,16 +35,24 @@ import { findUnoccupiedPosition } from 'features/nodes/store/util/findUnoccupied import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection'; import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil'; import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; +import { selectShouldGroupNodesByCategory } from 'features/nodes/store/workflowSettingsSlice'; import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData'; import { toast } from 'features/toast/toast'; import { selectActiveTab } from 'features/ui/store/uiSelectors'; import { computed } from 'nanostores'; -import type { ChangeEvent } from 'react'; +import type { ChangeEvent, Dispatch, SetStateAction } from 'react'; import { memo, useCallback, useMemo, useRef, useState } from 'react'; import { useTranslation } from 'react-i18next'; -import { PiCircuitryBold, PiFlaskBold, PiHammerBold, PiLightningFill } from 'react-icons/pi'; +import { + PiCaretDownBold, + PiCaretRightBold, + PiCircuitryBold, + PiFlaskBold, + PiHammerBold, + PiLightningFill, +} from 'react-icons/pi'; import type { S } from 'services/api/types'; import { objectEntries } from 'tsafe'; import { useDebounce } from 'use-debounce'; @@ -171,15 +181,36 @@ export const AddNodeCmdk = memo(() => { const onClose = useCallback(() => { close(); setSearchTerm(''); + setExpandedCategories(new Set()); $pendingConnection.set(null); }, [close]); + const [expandedCategories, setExpandedCategories] = useState>(new Set()); + + const toggleCategory = useCallback((category: string) => { + setExpandedCategories((prev) => { + const next = new Set(prev); + if (next.has(category)) { + next.delete(category); + } else { + next.add(category); + } + return next; + }); + }, []); + const onSelect = useCallback( (value: string) => { + // Category headers have a special prefix + if (value.startsWith('__category__:')) { + const category = value.slice('__category__:'.length); + toggleCategory(category); + return; + } addNode(value); onClose(); }, - [addNode, onClose] + [addNode, onClose, toggleCategory] ); return ( @@ -204,7 +235,12 @@ export const AddNodeCmdk = memo(() => { /> - + @@ -230,6 +266,7 @@ type NodeCommandItemData = { description: string; classification: S['Classification']; nodePack: string; + category: string; }; /** @@ -260,6 +297,7 @@ type FilterableItem = { tags: string[]; classification: S['Classification']; nodePack: string; + category: string; }; const filter = memoize( @@ -290,6 +328,10 @@ const filter = memoize( return true; } + if (item.category.includes(searchTerm) || regex.test(item.category)) { + return true; + } + for (const tag of item.tags) { if (tag.includes(searchTerm) || regex.test(tag)) { return true; @@ -301,112 +343,253 @@ const filter = memoize( (item: FilterableItem, searchTerm: string) => `${item.type}-${searchTerm}` ); -const NodeCommandList = memo(({ searchTerm, onSelect }: { searchTerm: string; onSelect: (value: string) => void }) => { - const { t } = useTranslation(); - const templatesArray = useStore($templatesArray); - const pendingConnection = useStore($pendingConnection); - const currentImageFilterItem = useMemo( - () => ({ - type: 'current_image', - title: t('nodes.currentImage'), - description: t('nodes.currentImageDescription'), - tags: ['progress', 'image', 'current'], - classification: 'stable', - nodePack: 'invokeai', - }), - [t] - ); - const notesFilterItem = useMemo( - () => ({ - type: 'notes', - title: t('nodes.notes'), - description: t('nodes.notesDescription'), - tags: ['notes'], - classification: 'stable', - nodePack: 'invokeai', - }), - [t] - ); +const categoryItemSx: SystemStyleObject = { + cursor: 'pointer', + userSelect: 'none', + '&[data-selected="true"]': { + bg: 'base.750', + }, +}; - const items = useMemo(() => { - // If we have a connection in progress, we need to filter the node choices - const _items: NodeCommandItemData[] = []; +const NodeCommandItem = memo( + ({ + item, + onSelect, + isGrouped, + }: { + item: NodeCommandItemData; + onSelect: (value: string) => void; + isGrouped?: boolean; + }) => ( + + + + {item.classification === 'beta' && } + {item.classification === 'prototype' && } + {item.classification === 'internal' && } + {item.classification === 'special' && } + {item.label} + + + {item.nodePack} + + + {item.description && {item.description}} + + + ) +); - if (!pendingConnection) { - for (const template of templatesArray) { - if (filter(template, searchTerm)) { - _items.push({ - label: template.title, - value: template.type, - description: template.description, - classification: template.classification, - nodePack: template.nodePack, - }); +NodeCommandItem.displayName = 'NodeCommandItem'; + +const NodeCommandList = memo( + ({ + searchTerm, + onSelect, + expandedCategories, + setExpandedCategories, + }: { + searchTerm: string; + onSelect: (value: string) => void; + expandedCategories: Set; + setExpandedCategories: Dispatch>>; + }) => { + const { t } = useTranslation(); + const templatesArray = useStore($templatesArray); + const pendingConnection = useStore($pendingConnection); + const shouldGroupNodesByCategory = useAppSelector(selectShouldGroupNodesByCategory); + const currentImageFilterItem = useMemo( + () => ({ + type: 'current_image', + title: t('nodes.currentImage'), + description: t('nodes.currentImageDescription'), + tags: ['progress', 'image', 'current'], + classification: 'stable', + nodePack: 'invokeai', + category: 'image', + }), + [t] + ); + const notesFilterItem = useMemo( + () => ({ + type: 'notes', + title: t('nodes.notes'), + description: t('nodes.notesDescription'), + tags: ['notes'], + classification: 'stable', + nodePack: 'invokeai', + category: 'other', + }), + [t] + ); + + const items = useMemo(() => { + // If we have a connection in progress, we need to filter the node choices + const _items: NodeCommandItemData[] = []; + + if (!pendingConnection) { + for (const template of templatesArray) { + if (filter(template, searchTerm)) { + _items.push({ + label: template.title, + value: template.type, + description: template.description, + classification: template.classification, + nodePack: template.nodePack, + category: template.category, + }); + } } - } - for (const item of [currentImageFilterItem, notesFilterItem]) { - if (filter(item, searchTerm)) { - _items.push({ - label: item.title, - value: item.type, - description: item.description, - classification: item.classification, - nodePack: item.nodePack, - }); + for (const item of [currentImageFilterItem, notesFilterItem]) { + if (filter(item, searchTerm)) { + _items.push({ + label: item.title, + value: item.type, + description: item.description, + classification: item.classification, + nodePack: item.nodePack, + category: item.category, + }); + } } - } - } else { - for (const template of templatesArray) { - if (filter(template, searchTerm)) { - const candidateFields = pendingConnection.handleType === 'source' ? template.inputs : template.outputs; + } else { + for (const template of templatesArray) { + if (filter(template, searchTerm)) { + const candidateFields = pendingConnection.handleType === 'source' ? template.inputs : template.outputs; - for (const [_fieldName, fieldTemplate] of objectEntries(candidateFields)) { - const sourceType = - pendingConnection.handleType === 'source' ? pendingConnection.fieldTemplate.type : fieldTemplate.type; - const targetType = - pendingConnection.handleType === 'target' ? pendingConnection.fieldTemplate.type : fieldTemplate.type; + for (const [_fieldName, fieldTemplate] of objectEntries(candidateFields)) { + const sourceType = + pendingConnection.handleType === 'source' ? pendingConnection.fieldTemplate.type : fieldTemplate.type; + const targetType = + pendingConnection.handleType === 'target' ? pendingConnection.fieldTemplate.type : fieldTemplate.type; - if (validateConnectionTypes(sourceType, targetType)) { - _items.push({ - label: template.title, - value: template.type, - description: template.description, - classification: template.classification, - nodePack: template.nodePack, - }); - break; + if (validateConnectionTypes(sourceType, targetType)) { + _items.push({ + label: template.title, + value: template.type, + description: template.description, + classification: template.classification, + nodePack: template.nodePack, + category: template.category, + }); + break; + } } } } } + + // Sort exact title matches to the top when searching + if (searchTerm) { + const lowerSearch = searchTerm.toLowerCase(); + _items.sort((a, b) => { + const aExact = a.label.toLowerCase() === lowerSearch; + const bExact = b.label.toLowerCase() === lowerSearch; + if (aExact && !bExact) { + return -1; + } + if (!aExact && bExact) { + return 1; + } + return 0; + }); + } + + return _items; + }, [pendingConnection, templatesArray, searchTerm, currentImageFilterItem, notesFilterItem]); + + const groupedItems = useMemo(() => { + const groups: Record = {}; + for (const item of items) { + const cat = item.category; + if (!groups[cat]) { + groups[cat] = []; + } + groups[cat].push(item); + } + // Sort categories alphabetically, but put "other" last. + // When searching, prioritize categories that contain an exact title match. + const lowerSearch = searchTerm.toLowerCase(); + return Object.entries(groups).sort(([a, aItems], [b, bItems]) => { + if (searchTerm) { + const aHasExact = aItems.some((item) => item.label.toLowerCase() === lowerSearch); + const bHasExact = bItems.some((item) => item.label.toLowerCase() === lowerSearch); + if (aHasExact && !bHasExact) { + return -1; + } + if (!aHasExact && bHasExact) { + return 1; + } + } + if (a === 'other') { + return 1; + } + if (b === 'other') { + return -1; + } + return a.localeCompare(b); + }); + }, [items, searchTerm]); + + // When searching, auto-expand all categories; when not searching, use manual state + const isSearching = searchTerm.length > 0; + + const expandAll = useCallback(() => { + setExpandedCategories(new Set(groupedItems.map(([cat]) => cat))); + }, [groupedItems, setExpandedCategories]); + + const collapseAll = useCallback(() => { + setExpandedCategories(new Set()); + }, [setExpandedCategories]); + + if (!shouldGroupNodesByCategory) { + return ( + <> + {items.map((item) => ( + + ))} + + ); } - return _items; - }, [pendingConnection, templatesArray, searchTerm, currentImageFilterItem, notesFilterItem]); - - return ( - <> - {items.map((item) => ( - - - - {item.classification === 'beta' && } - {item.classification === 'prototype' && } - {item.classification === 'internal' && } - {item.classification === 'special' && } - {item.label} - - - {item.nodePack} - - - {item.description && {item.description}} + return ( + <> + {!isSearching && ( + + + - - ))} - - ); -}); + )} + {groupedItems.map(([category, categoryItems]) => { + const isExpanded = isSearching || expandedCategories.has(category); + return ( + + + + + + {capitalize(category)} + + + ({categoryItems.length}) + + + + {isExpanded && + categoryItems.map((item) => ( + + ))} + + ); + })} + + ); + } +); NodeCommandList.displayName = 'CommandListItems'; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx index 0c48eddfc6..2fc5e12384 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx @@ -1,4 +1,4 @@ -import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library'; +import { Menu, MenuButton, MenuItem, MenuList, Portal, useGlobalMenuClose, useToken } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; import type { EdgeChange, @@ -32,7 +32,9 @@ import { $edgePendingUpdate, $lastEdgeUpdateMouseEvent, $pendingConnection, + $templates, $viewport, + connectorInserted, edgesChanged, nodesChanged, redo, @@ -46,18 +48,24 @@ import { selectNodes, selectNodesSlice, } from 'features/nodes/store/selectors'; +import { getConnectorDeletionSpliceConnections } from 'features/nodes/store/util/connectorTopology'; import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil'; +import { validateConnection } from 'features/nodes/store/util/validateConnection'; import { selectSelectionMode, selectShouldSnapToGrid } from 'features/nodes/store/workflowSettingsSlice'; import { NO_DRAG_CLASS, NO_PAN_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants'; import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation'; +import { buildConnectorNode } from 'features/nodes/util/node/buildConnectorNode'; import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData'; import type { CSSProperties, MouseEvent, RefObject } from 'react'; -import { memo, useCallback, useMemo, useRef } from 'react'; +import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; +import { useTranslation } from 'react-i18next'; +import { PiPlugsConnectedBold, PiTrashBold } from 'react-icons/pi'; import CustomConnectionLine from './connectionLines/CustomConnectionLine'; import InvocationCollapsedEdge from './edges/InvocationCollapsedEdge'; import InvocationDefaultEdge from './edges/InvocationDefaultEdge'; +import ConnectorNode from './nodes/Connector/ConnectorNode'; import CurrentImageNode from './nodes/CurrentImage/CurrentImageNode'; import InvocationNodeWrapper from './nodes/Invocation/InvocationNodeWrapper'; import NotesNode from './nodes/Notes/NotesNode'; @@ -70,6 +78,7 @@ const edgeTypes = { const nodeTypes = { invocation: InvocationNodeWrapper, + connector: ConnectorNode, current_image: CurrentImageNode, notes: NotesNode, } as const; @@ -81,17 +90,69 @@ const snapGrid: [number, number] = [25, 25]; const selectCancelConnection = (state: ReactFlowState) => state.cancelConnection; +type WorkflowContextMenuState = + | { + kind: 'pane'; + clientX: number; + clientY: number; + pageX: number; + pageY: number; + } + | { + kind: 'connector'; + connectorId: string; + pageX: number; + pageY: number; + } + | null; + +const getWorkflowContextMenuState = ( + event: globalThis.MouseEvent, + flowWrapper: HTMLDivElement | null +): WorkflowContextMenuState => { + if (event.shiftKey || !(event.target instanceof Element) || !flowWrapper?.contains(event.target)) { + return null; + } + + const connectorId = event.target.closest('[data-connector-node-id]')?.dataset.connectorNodeId; + if (connectorId) { + return { + kind: 'connector', + connectorId, + pageX: event.pageX, + pageY: event.pageY, + }; + } + + const paneTarget = event.target.closest('.react-flow__pane'); + if (paneTarget && flowWrapper.contains(paneTarget)) { + return { + kind: 'pane', + clientX: event.clientX, + clientY: event.clientY, + pageX: event.pageX, + pageY: event.pageY, + }; + } + + return null; +}; + export const Flow = memo(() => { + const { t } = useTranslation(); const dispatch = useAppDispatch(); const nodes = useAppSelector(selectNodes); const edges = useAppSelector(selectEdges); + const templates = useStore($templates); const viewport = useStore($viewport); const shouldSnapToGrid = useAppSelector(selectShouldSnapToGrid); const selectionMode = useAppSelector(selectSelectionMode); const { onConnectStart, onConnect, onConnectEnd } = useConnection(); const flowWrapper = useRef(null); + const pendingNodeInternalsUpdateRef = useRef(null); const isValidConnection = useIsValidConnection(); const updateNodeInternals = useUpdateNodeInternals(); + const [contextMenuState, setContextMenuState] = useState(null); useFocusRegion('workflows', flowWrapper); @@ -127,9 +188,16 @@ export const Flow = memo(() => { }, []); const { onCloseGlobal } = useGlobalMenuClose(); - const handlePaneClick = useCallback(() => { - onCloseGlobal(); - }, [onCloseGlobal]); + const handlePaneClick: NonNullable['onPaneClick']> = useCallback( + (event) => { + if ('button' in event && event.button !== 0) { + return; + } + onCloseGlobal(); + setContextMenuState(null); + }, + [onCloseGlobal] + ); const onInit: OnInit = useCallback((flow) => { $flow.set(flow); @@ -147,6 +215,149 @@ export const Flow = memo(() => { } }, []); + useEffect(() => { + const pendingNodeIds = pendingNodeInternalsUpdateRef.current; + if (!pendingNodeIds) { + return; + } + + pendingNodeInternalsUpdateRef.current = null; + + const frameId = requestAnimationFrame(() => { + updateNodeInternals([...new Set(pendingNodeIds)]); + }); + + return () => { + cancelAnimationFrame(frameId); + }; + }, [edges, nodes, updateNodeInternals]); + + const addConnectorAtPaneMenuPosition = useCallback(() => { + if (contextMenuState?.kind !== 'pane') { + return; + } + const flow = $flow.get(); + if (!flow) { + return; + } + const connector = buildConnectorNode( + flow.screenToFlowPosition({ + x: contextMenuState.clientX, + y: contextMenuState.clientY, + }) + ); + dispatch(nodesChanged([{ type: 'add', item: connector }])); + setContextMenuState(null); + }, [contextMenuState, dispatch]); + + const connectorSpliceConnections = useMemo( + () => + contextMenuState?.kind === 'connector' + ? getConnectorDeletionSpliceConnections( + contextMenuState.connectorId, + nodes, + edges, + templates, + validateConnection + ) + : null, + [contextMenuState, edges, nodes, templates] + ); + + const deleteConnectorFromContextMenu = useCallback(() => { + if (contextMenuState?.kind !== 'connector' || !connectorSpliceConnections) { + return; + } + const connectorEdgeRemovals: EdgeChange[] = edges + .filter((edge) => edge.source === contextMenuState.connectorId || edge.target === contextMenuState.connectorId) + .map((edge) => ({ type: 'remove', id: edge.id })); + const spliceEdgeAdditions: EdgeChange[] = connectorSpliceConnections.map((connection) => ({ + type: 'add', + item: connectionToEdge(connection), + })); + + pendingNodeInternalsUpdateRef.current = [ + contextMenuState.connectorId, + ...connectorSpliceConnections.flatMap((connection) => [connection.source, connection.target]), + ]; + dispatch(edgesChanged([...connectorEdgeRemovals, ...spliceEdgeAdditions])); + dispatch(nodesChanged([{ type: 'remove', id: contextMenuState.connectorId }])); + setContextMenuState(null); + }, [connectorSpliceConnections, contextMenuState, dispatch, edges]); + + useEffect(() => { + const onWindowContextMenu = (event: globalThis.MouseEvent) => { + const nextContextMenuState = getWorkflowContextMenuState(event, flowWrapper.current); + if (!nextContextMenuState) { + return; + } + + event.preventDefault(); + event.stopPropagation(); + setContextMenuState(nextContextMenuState); + }; + + window.addEventListener('contextmenu', onWindowContextMenu, { capture: true }); + + return () => { + window.removeEventListener('contextmenu', onWindowContextMenu, { capture: true }); + }; + }, []); + + const renderContextMenu = useCallback(() => { + if (contextMenuState?.kind === 'pane') { + return ( + + } onClick={addConnectorAtPaneMenuPosition}> + {t('nodes.addConnector')} + + + ); + } + + if (contextMenuState?.kind === 'connector') { + return ( + + } + onClick={deleteConnectorFromContextMenu} + isDisabled={!connectorSpliceConnections} + isDestructive + > + {t('nodes.deleteConnector')} + + + ); + } + + return ; + }, [addConnectorAtPaneMenuPosition, connectorSpliceConnections, contextMenuState, deleteConnectorFromContextMenu, t]); + + const closeContextMenu = useCallback(() => { + setContextMenuState(null); + }, []); + + const onEdgeDoubleClick = useCallback>( + (event, edge) => { + if (edge.type !== 'default' || edge.hidden) { + return; + } + const flow = $flow.get(); + if (!flow) { + return; + } + const connector = buildConnectorNode( + flow.screenToFlowPosition({ + x: event.clientX, + y: event.clientY, + }) + ); + dispatch(connectorInserted({ edgeId: edge.id, connector })); + updateNodeInternals([edge.source, edge.target, connector.id]); + }, + [dispatch, updateNodeInternals] + ); + // #region Updatable Edges /** @@ -209,16 +420,130 @@ export const Flow = memo(() => { // #endregion + const renderedNodes = useMemo(() => nodes, [nodes]); + + const renderedEdges = useMemo(() => edges, [edges]); + const contextMenuPosition = contextMenuState ? { x: contextMenuState.pageX, y: contextMenuState.pageY } : null; + const contextMenuKey = contextMenuPosition ? `${contextMenuPosition.x}-${contextMenuPosition.y}` : 'closed'; + return ( <> + + + + + {renderContextMenu()} + + + + + ); +}); + +Flow.displayName = 'Flow'; + +type FlowSurfaceProps = { + flowWrapper: { current: HTMLDivElement | null }; + viewport: ReactFlowProps['defaultViewport']; + renderedNodes: AnyNode[]; + renderedEdges: AnyEdge[]; + onInit: OnInit; + onMouseMove: (event: MouseEvent) => void; + onNodesChange: OnNodesChange; + onEdgesChange: OnEdgesChange; + onReconnect: OnReconnect; + onReconnectStart: NonNullable['onReconnectStart']>; + onReconnectEnd: NonNullable['onReconnectEnd']>; + onConnectStart: NonNullable['onConnectStart']>; + onConnect: NonNullable['onConnect']>; + onConnectEnd: NonNullable['onConnectEnd']>; + handleMoveEnd: OnMoveEnd; + onEdgeDoubleClick: NonNullable['onEdgeDoubleClick']>; + isValidConnection: NonNullable['isValidConnection']>; + shouldSnapToGrid: boolean; + flowStyles: CSSProperties; + handlePaneClick: NonNullable['onPaneClick']>; + selectionMode: ReturnType; +}; + +const FlowSurface = memo((props: FlowSurfaceProps) => { + const { + flowWrapper, + viewport, + renderedNodes, + renderedEdges, + onInit, + onMouseMove, + onNodesChange, + onEdgesChange, + onReconnect, + onReconnectStart, + onReconnectEnd, + onConnectStart, + onConnect, + onConnectEnd, + handleMoveEnd, + onEdgeDoubleClick, + isValidConnection, + shouldSnapToGrid, + flowStyles, + handlePaneClick, + selectionMode, + } = props; + + const setFlowWrapperElement = useCallback( + (el: HTMLDivElement | null) => { + flowWrapper.current = el; + }, + [flowWrapper] + ); + + return ( +
id="workflow-editor" - ref={flowWrapper} defaultViewport={viewport} nodeTypes={nodeTypes} edgeTypes={edgeTypes} - nodes={nodes} - edges={edges} + nodes={renderedNodes} + edges={renderedEdges} onInit={onInit} onMouseMove={onMouseMove} onNodesChange={onNodesChange} @@ -230,6 +555,7 @@ export const Flow = memo(() => { onConnect={onConnect} onConnectEnd={onConnectEnd} onMoveEnd={handleMoveEnd} + onEdgeDoubleClick={onEdgeDoubleClick} connectionLineComponent={CustomConnectionLine} isValidConnection={isValidConnection} minZoom={0.1} @@ -249,12 +575,11 @@ export const Flow = memo(() => { > - - +
); }); -Flow.displayName = 'Flow'; +FlowSurface.displayName = 'FlowSurface'; const HotkeyIsolator = memo(({ flowWrapper }: { flowWrapper: RefObject }) => { const mayUndo = useAppSelector(selectMayUndo); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/buildEdgeSelectors.ts b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/buildEdgeSelectors.ts index b64e5a6e6a..40a44a16c2 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/buildEdgeSelectors.ts +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/buildEdgeSelectors.ts @@ -1,9 +1,10 @@ import { createSelector } from '@reduxjs/toolkit'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; -import { selectNodes } from 'features/nodes/store/selectors'; +import { selectEdges, selectNodes } from 'features/nodes/store/selectors'; import type { Templates } from 'features/nodes/store/types'; +import { resolveConnectorSource, resolveConnectorSourceFieldType } from 'features/nodes/store/util/connectorTopology'; import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { isConnectorNode } from 'features/nodes/types/invocation'; import { getFieldColor } from './getEdgeColor'; @@ -22,23 +23,20 @@ export const buildSelectEdgeColor = ( target: string, targetHandleId: string | null | undefined ) => - createSelector(selectNodes, selectWorkflowSettingsSlice, (nodes, workflowSettings): string => { + createSelector(selectNodes, selectEdges, selectWorkflowSettingsSlice, (nodes, edges, workflowSettings): string => { const { shouldColorEdges } = workflowSettings; if (!shouldColorEdges) { return colorTokenToCssVar('base.500'); } const sourceNode = nodes.find((node) => node.id === source); - const targetNode = nodes.find((node) => node.id === target); - if (!sourceNode || !sourceHandleId || !targetNode || !targetHandleId) { + if (!sourceNode || !sourceHandleId || !targetHandleId) { return colorTokenToCssVar('base.500'); } - const sourceNodeTemplate = templates[sourceNode.data.type]; - - const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode); - const outputFieldTemplate = sourceNodeTemplate?.outputs[sourceHandleId]; - const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined; + const sourceType = isConnectorNode(sourceNode) + ? resolveConnectorSourceFieldType(sourceNode.id, nodes, edges, templates) + : templates[sourceNode.data.type]?.outputs[sourceHandleId]?.type; return sourceType ? getFieldColor(sourceType) : colorTokenToCssVar('base.500'); }); @@ -50,7 +48,7 @@ export const buildSelectEdgeLabel = ( target: string, targetHandleId: string | null | undefined ) => - createSelector(selectNodes, (nodes): string | null => { + createSelector(selectNodes, selectEdges, (nodes, edges): string | null => { const sourceNode = nodes.find((node) => node.id === source); const targetNode = nodes.find((node) => node.id === target); @@ -58,8 +56,12 @@ export const buildSelectEdgeLabel = ( return null; } - const sourceNodeTemplate = templates[sourceNode.data.type]; + const resolvedSource = isConnectorNode(sourceNode) ? resolveConnectorSource(sourceNode.id, nodes, edges) : null; + const sourceTemplate = + resolvedSource !== null + ? templates[nodes.find((node) => node.id === resolvedSource.nodeId)?.data.type ?? ''] + : templates[sourceNode.data.type]; const targetNodeTemplate = templates[targetNode.data.type]; - return `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`; + return `${sourceTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`; }); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Connector/ConnectorNode.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Connector/ConnectorNode.tsx new file mode 100644 index 0000000000..a971efb397 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Connector/ConnectorNode.tsx @@ -0,0 +1,97 @@ +import type { SystemStyleObject } from '@invoke-ai/ui-library'; +import { Box, Icon } from '@invoke-ai/ui-library'; +import type { Node, NodeProps } from '@xyflow/react'; +import { Handle, Position } from '@xyflow/react'; +import NonInvocationNodeWrapper from 'features/nodes/components/flow/nodes/common/NonInvocationNodeWrapper'; +import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from 'features/nodes/store/util/connectorTopology'; +import { NO_DRAG_CLASS } from 'features/nodes/types/constants'; +import type { ConnectorNodeData } from 'features/nodes/types/invocation'; +import type { CSSProperties } from 'react'; +import { memo } from 'react'; +import { PiDotOutlineFill } from 'react-icons/pi'; + +const handleVisualSx = { + w: 3, + h: 3, + borderRadius: 'full', + borderWidth: 2, + borderColor: 'base.900', + bg: 'base.100', + pointerEvents: 'none', +} satisfies SystemStyleObject; + +const handleStyles = { + position: 'absolute', + width: '1rem', + height: '1rem', + display: 'flex', + alignItems: 'center', + justifyContent: 'center', + top: '50%', + transform: 'translateY(-50%)', + zIndex: 1, + background: 'none', + border: 'none', +} satisfies CSSProperties; + +const inputHandleStyles = { + ...handleStyles, + insetInlineStart: 0, + justifyContent: 'flex-start', +} satisfies CSSProperties; + +const outputHandleStyles = { + ...handleStyles, + insetInlineEnd: 0, + justifyContent: 'flex-end', +} satisfies CSSProperties; + +const ConnectorNode = ({ id, selected }: NodeProps>) => { + return ( + + + + + + + + + + + + + + ); +}; + +export default memo(ConnectorNode); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NonInvocationNodeWrapper.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NonInvocationNodeWrapper.tsx index 7e2cde7093..84246ba43b 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NonInvocationNodeWrapper.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NonInvocationNodeWrapper.tsx @@ -16,10 +16,12 @@ type NonInvocationNodeWrapperProps = PropsWithChildren & { nodeId: string; selected: boolean; width?: ChakraProps['w']; + borderRadius?: ChakraProps['borderRadius']; + withChrome?: boolean; }; const NonInvocationNodeWrapper = (props: NonInvocationNodeWrapperProps) => { - const { nodeId, width, children, selected } = props; + const { nodeId, width, children, selected, borderRadius = 'base', withChrome = true } = props; const mouseOverNode = useMouseOverNode(nodeId); const zoomToNode = useZoomToNode(nodeId); @@ -62,14 +64,15 @@ const NonInvocationNodeWrapper = (props: NonInvocationNodeWrapperProps) => { onMouseOut={mouseOverNode.handleMouseOut} className={DRAG_HANDLE_CLASSNAME} sx={containerSx} + borderRadius={borderRadius} width={width || NODE_WIDTH} opacity={opacity} data-is-selected={selected} > - - + {withChrome && } + {withChrome && } {children} - + {withChrome && } ); }; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/shared.ts b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/shared.ts index 70e56cb4db..15624780c2 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/shared.ts +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/shared.ts @@ -6,7 +6,7 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library'; export const containerSx: SystemStyleObject = { h: 'full', position: 'relative', - borderRadius: 'base', + borderRadius: 'inherit', transitionProperty: 'none', cursor: 'grab', '--border-color': 'var(--invoke-colors-base-500)', @@ -30,7 +30,7 @@ export const containerSx: SystemStyleObject = { insetInlineEnd: 0, bottom: 0, insetInlineStart: 0, - borderRadius: 'base', + borderRadius: 'inherit', transitionProperty: 'none', pointerEvents: 'none', shadow: '0 0 0 1px var(--border-color)', @@ -64,7 +64,7 @@ export const shadowsSx: SystemStyleObject = { insetInlineEnd: 0, bottom: 0, insetInlineStart: 0, - borderRadius: 'base', + borderRadius: 'inherit', pointerEvents: 'none', zIndex: -1, shadow: 'var(--invoke-shadows-xl), var(--invoke-shadows-base), var(--invoke-shadows-base)', @@ -76,7 +76,7 @@ export const inProgressSx: SystemStyleObject = { insetInlineEnd: 0, bottom: 0, insetInlineStart: 0, - borderRadius: 'md', + borderRadius: 'inherit', pointerEvents: 'none', transitionProperty: 'none', opacity: 0.7, diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopRightPanel/WorkflowEditorSettings.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopRightPanel/WorkflowEditorSettings.tsx index 0d3ca06c8a..2009f92144 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopRightPanel/WorkflowEditorSettings.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopRightPanel/WorkflowEditorSettings.tsx @@ -24,11 +24,13 @@ import { selectSelectionMode, selectShouldAnimateEdges, selectShouldColorEdges, + selectShouldGroupNodesByCategory, selectShouldShouldValidateGraph, selectShouldShowEdgeLabels, selectShouldSnapToGrid, shouldAnimateEdgesChanged, shouldColorEdgesChanged, + shouldGroupNodesByCategoryChanged, shouldShowEdgeLabelsChanged, shouldSnapToGridChanged, shouldValidateGraphChanged, @@ -50,6 +52,7 @@ const WorkflowEditorSettings = () => { const shouldAnimateEdges = useAppSelector(selectShouldAnimateEdges); const shouldShowEdgeLabels = useAppSelector(selectShouldShowEdgeLabels); const shouldValidateGraph = useAppSelector(selectShouldShouldValidateGraph); + const shouldGroupNodesByCategory = useAppSelector(selectShouldGroupNodesByCategory); const handleChangeShouldValidate = useCallback( (e: ChangeEvent) => { @@ -93,6 +96,13 @@ const WorkflowEditorSettings = () => { [dispatch] ); + const handleChangeShouldGroupNodesByCategory = useCallback( + (e: ChangeEvent) => { + dispatch(shouldGroupNodesByCategoryChanged(e.target.checked)); + }, + [dispatch] + ); + const { t } = useTranslation(); return ( @@ -145,6 +155,14 @@ const WorkflowEditorSettings = () => { {t('nodes.showEdgeLabelsHelp')} + + + {t('nodes.groupNodesByCategory')} + + + {t('nodes.groupNodesByCategoryHelp')} + + {t('common.advanced')} diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts index 591aac6c18..428fcf1e82 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts @@ -3,6 +3,7 @@ import { $templates } from 'features/nodes/store/nodesSlice'; import { $flow } from 'features/nodes/store/reactFlowInstance'; import { NODE_WIDTH } from 'features/nodes/types/constants'; import type { AnyNode, InvocationTemplate } from 'features/nodes/types/invocation'; +import { buildConnectorNode } from 'features/nodes/util/node/buildConnectorNode'; import { buildCurrentImageNode } from 'features/nodes/util/node/buildCurrentImageNode'; import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode'; import { buildNotesNode } from 'features/nodes/util/node/buildNotesNode'; @@ -14,7 +15,7 @@ export const useBuildNode = () => { return useCallback( // string here is "any invocation type" - (type: string | 'current_image' | 'notes'): AnyNode => { + (type: string | 'connector' | 'current_image' | 'notes'): AnyNode => { const flow = $flow.get(); assert(flow !== null); @@ -42,6 +43,10 @@ export const useBuildNode = () => { return buildNotesNode(position); } + if (type === 'connector') { + return buildConnectorNode(position); + } + // TODO: Keep track of invocation types so we do not need to cast this // We know it is safe because the caller of this function gets the `type` arg from the list of invocation templates. const template = templates[type] as InvocationTemplate; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts index bfd8be95f5..d763254bc4 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts @@ -12,9 +12,15 @@ import { edgesChanged, } from 'features/nodes/store/nodesSlice'; import { selectNodes, selectNodesSlice } from 'features/nodes/store/selectors'; +import { + CONNECTOR_INPUT_HANDLE, + CONNECTOR_OUTPUT_HANDLE, + resolveConnectorSourceFieldType, +} from 'features/nodes/store/util/connectorTopology'; import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection'; import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil'; import type { AnyEdge } from 'features/nodes/types/invocation'; +import { isConnectorNode } from 'features/nodes/types/invocation'; import { useCallback, useMemo } from 'react'; import { assert } from 'tsafe'; @@ -33,6 +39,47 @@ export const useConnection = () => { return; } + if (isConnectorNode(node)) { + if (handleType === 'source' && handleId !== CONNECTOR_OUTPUT_HANDLE) { + return; + } + if (handleType === 'target' && handleId !== CONNECTOR_INPUT_HANDLE) { + return; + } + + const resolvedSourceType = + handleType === 'source' + ? resolveConnectorSourceFieldType(nodeId, nodes, selectNodesSlice(store.getState()).edges, templates) + : null; + $pendingConnection.set({ + nodeId, + handleId, + handleType, + fieldTemplate: + handleType === 'source' + ? { + name: CONNECTOR_OUTPUT_HANDLE, + title: 'Connector Output', + description: '', + fieldKind: 'output', + ui_hidden: false, + type: resolvedSourceType ?? { name: 'AnyField', cardinality: 'SINGLE', batch: false }, + } + : { + name: CONNECTOR_INPUT_HANDLE, + title: 'Connector Input', + description: '', + fieldKind: 'input', + input: 'connection', + required: false, + default: undefined, + ui_hidden: false, + type: { name: 'AnyField', cardinality: 'SINGLE', batch: false }, + }, + }); + return; + } + const template = templates[node.data.type]; if (!template) { return; diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.test.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.test.ts new file mode 100644 index 0000000000..5306347921 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.test.ts @@ -0,0 +1,160 @@ +import { deepClone } from 'common/util/deepClone'; +import { buildConnectorNode } from 'features/nodes/util/node/buildConnectorNode'; +import { describe, expect, it } from 'vitest'; + +import { connectorInserted, nodesChanged, nodesSliceConfig } from './nodesSlice'; +import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from './util/connectorTopology'; +import { add, buildEdge, buildNode, sub } from './util/testUtils'; + +const buildFixedConnectorNode = (id: string) => { + const connectorNode = buildConnectorNode({ x: 0, y: 0 }); + return { + ...connectorNode, + id, + data: { + ...connectorNode.data, + id, + }, + }; +}; + +describe('nodesSlice connector actions', () => { + it('splits a direct edge into source -> connector -> target edges when inserting a connector', () => { + const source = buildNode(add); + const target = buildNode(sub); + const connector = buildFixedConnectorNode('connector-1'); + const directEdge = buildEdge(source.id, 'value', target.id, 'a'); + + const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' })); + initialState.nodes = [source, target]; + initialState.edges = [directEdge]; + + const nextState = nodesSliceConfig.slice.reducer( + initialState, + connectorInserted({ + edgeId: directEdge.id, + connector, + }) + ); + + expect(nextState.nodes.map((node) => node.id)).toEqual([source.id, target.id, connector.id]); + expect(nextState.edges).toEqual([ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ]); + }); + + it('splices connector outputs back to the resolved upstream source when removed', () => { + const source = buildNode(add); + const target = buildNode(sub); + const connector = buildFixedConnectorNode('connector-1'); + + const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' })); + initialState.nodes = [source, connector, target]; + initialState.edges = [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ]; + + const nextState = nodesSliceConfig.slice.reducer( + initialState, + nodesChanged([{ type: 'remove', id: connector.id }]) + ); + + expect(nextState.nodes.map((node) => node.id)).toEqual([source.id, target.id]); + expect(nextState.edges).toEqual([buildEdge(source.id, 'value', target.id, 'a')]); + }); + + it('splices one connector source back to multiple downstream targets when removed', () => { + const source = buildNode(add); + const targetA = buildNode(sub); + const targetB = buildNode(sub); + const connector = buildFixedConnectorNode('connector-1'); + + const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' })); + initialState.nodes = [source, connector, targetA, targetB]; + initialState.edges = [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetA.id, 'a'), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetB.id, 'b'), + ]; + + const nextState = nodesSliceConfig.slice.reducer( + initialState, + nodesChanged([{ type: 'remove', id: connector.id }]) + ); + + expect(nextState.nodes.map((node) => node.id)).toEqual([source.id, targetA.id, targetB.id]); + expect(nextState.edges).toEqual([ + buildEdge(source.id, 'value', targetA.id, 'a'), + buildEdge(source.id, 'value', targetB.id, 'b'), + ]); + }); + + it('does not create any edges when removing a connector with no downstream targets', () => { + const source = buildNode(add); + const connector = buildFixedConnectorNode('connector-1'); + + const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' })); + initialState.nodes = [source, connector]; + initialState.edges = [buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)]; + + const nextState = nodesSliceConfig.slice.reducer( + initialState, + nodesChanged([{ type: 'remove', id: connector.id }]) + ); + + expect(nextState.nodes.map((node) => node.id)).toEqual([source.id]); + expect(nextState.edges).toEqual([]); + }); + + it('removes a connector while preserving downstream connector edges in a chained splice case', () => { + const source = buildNode(add); + const connectorA = buildFixedConnectorNode('connector-a'); + const connectorB = buildFixedConnectorNode('connector-b'); + const target = buildNode(sub); + + const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' })); + initialState.nodes = [source, connectorA, connectorB, target]; + initialState.edges = [ + buildEdge(source.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorB.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ]; + + const nextState = nodesSliceConfig.slice.reducer( + initialState, + nodesChanged([{ type: 'remove', id: connectorA.id }]) + ); + + expect(nextState.nodes.map((node) => node.id)).toEqual([source.id, connectorB.id, target.id]); + expect(nextState.edges).toHaveLength(2); + expect(nextState.edges).toEqual( + expect.arrayContaining([ + buildEdge(source.id, 'value', connectorB.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorB.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ]) + ); + }); + + it('splices connector edges when the connector is removed through generic node removal', () => { + const source = buildNode(add); + const target = buildNode(sub); + const connector = buildFixedConnectorNode('connector-1'); + + const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' })); + initialState.nodes = [source, connector, target]; + initialState.edges = [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ]; + + const nextState = nodesSliceConfig.slice.reducer( + initialState, + nodesChanged([{ type: 'remove', id: connector.id }]) + ); + + expect(nextState.nodes.map((node) => node.id)).toEqual([source.id, target.id]); + expect(nextState.edges).toEqual([buildEdge(source.id, 'value', target.id, 'a')]); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index bdab6c1ae3..6713ee8fb4 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -20,6 +20,13 @@ import { reparentElement, } from 'features/nodes/components/sidePanel/builder/form-manipulation'; import { type NodesState, zNodesState } from 'features/nodes/store/types'; +import { + CONNECTOR_INPUT_HANDLE, + CONNECTOR_OUTPUT_HANDLE, + getConnectorOutputEdges, + resolveConnectorSource, +} from 'features/nodes/store/util/connectorTopology'; +import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil'; import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants'; import type { BoardFieldValue, @@ -65,8 +72,8 @@ import { zStringGeneratorFieldValue, zStylePresetFieldValue, } from 'features/nodes/types/field'; -import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation'; -import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation'; +import type { AnyEdge, AnyNode, ConnectorNode } from 'features/nodes/types/invocation'; +import { isConnectorNode, isInvocationNode, isNotesNode } from 'features/nodes/types/invocation'; import type { BuilderForm, ContainerElement, @@ -103,7 +110,7 @@ export const getInitialWorkflow = (): Omit[]>) => { + const removedConnectorSpliceEdges: AnyEdge[] = action.payload.flatMap((change) => { + if (change.type !== 'remove') { + return []; + } + + const node = state.nodes.find((candidate) => candidate.id === change.id); + if (!isConnectorNode(node)) { + return []; + } + + const resolvedSource = resolveConnectorSource(node.id, state.nodes, state.edges); + if (!resolvedSource) { + return []; + } + + return getConnectorOutputEdges(node.id, state.edges) + .filter((edge): edge is AnyEdge & { type: 'default'; targetHandle: string } => edge.type === 'default') + .map((edge) => + connectionToEdge({ + source: resolvedSource.nodeId, + sourceHandle: resolvedSource.fieldName, + target: edge.target, + targetHandle: edge.targetHandle, + }) + ); + }); + // TODO(psyche): The below TS issue was recently fixed upstream. Need to upgrade @xyflow/react and then we // should be able to remove this cast. // @@ -206,6 +240,12 @@ const slice = createSlice({ if (edgeChanges.length > 0) { state.edges = applyEdgeChanges(edgeChanges, state.edges); } + if (removedConnectorSpliceEdges.length > 0) { + state.edges = applyEdgeChanges( + removedConnectorSpliceEdges.map((edge) => ({ type: 'add', item: edge })), + state.edges + ); + } } const wereNodesRemoved = action.payload.some((change) => change.type === 'remove' || change.type === 'replace'); @@ -396,11 +436,40 @@ const slice = createSlice({ } } }, + connectorInserted: ( + state, + action: PayloadAction<{ + edgeId: string; + connector: ConnectorNode; + }> + ) => { + const { edgeId, connector } = action.payload; + const edge = state.edges.find((candidate) => candidate.id === edgeId); + if (!edge || edge.type !== 'default') { + return; + } + state.nodes.push({ ...SHARED_NODE_PROPERTIES, ...connector } as (typeof state.nodes)[number]); + state.edges = state.edges.filter((candidate) => candidate.id !== edgeId); + state.edges.push( + connectionToEdge({ + source: edge.source, + sourceHandle: edge.sourceHandle ?? null, + target: connector.id, + targetHandle: CONNECTOR_INPUT_HANDLE, + }), + connectionToEdge({ + source: connector.id, + sourceHandle: CONNECTOR_OUTPUT_HANDLE, + target: edge.target, + targetHandle: edge.targetHandle ?? null, + }) + ); + }, nodeLabelChanged: (state, action: PayloadAction<{ nodeId: string; label: string }>) => { const { nodeId, label } = action.payload; const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId); const node = state.nodes?.[nodeIndex]; - if (isInvocationNode(node) || isNotesNode(node)) { + if (isInvocationNode(node) || isNotesNode(node) || isConnectorNode(node)) { node.data.label = label; } }, @@ -614,6 +683,7 @@ export const { nodeEditorReset, nodeIsIntermediateChanged, nodeIsOpenChanged, + connectorInserted, nodeLabelChanged, nodeNotesChanged, nodesChanged, diff --git a/invokeai/frontend/web/src/features/nodes/store/util/connectorTopology.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/connectorTopology.test.ts new file mode 100644 index 0000000000..e87ebcde79 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/connectorTopology.test.ts @@ -0,0 +1,126 @@ +import type { AnyNode, ConnectorNode } from 'features/nodes/types/invocation'; +import { describe, expect, it } from 'vitest'; + +import { + CONNECTOR_INPUT_HANDLE, + CONNECTOR_OUTPUT_HANDLE, + getConnectorDeletionSpliceConnections, + getConnectorInputEdge, + getConnectorOutputEdges, + resolveConnectorSource, + resolveConnectorSourceFieldType, +} from './connectorTopology'; +import { add, buildEdge, buildNode, img_resize, sub, templates } from './testUtils'; + +const buildConnectorNode = (id: string): ConnectorNode => ({ + id, + type: 'connector', + position: { x: 0, y: 0 }, + data: { + id, + type: 'connector', + label: 'Connector', + isOpen: true, + }, +}); + +describe('connectorTopology', () => { + it('resolves the effective upstream source through one connector', () => { + const source = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const target = buildNode(sub); + const nodes: AnyNode[] = [source, connector, target]; + const edges = [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ]; + + expect(resolveConnectorSource(connector.id, nodes, edges)).toEqual({ + nodeId: source.id, + fieldName: 'value', + }); + expect(resolveConnectorSourceFieldType(connector.id, nodes, edges, templates)).toEqual(add.outputs.value?.type); + }); + + it('resolves the effective upstream source through chained connectors', () => { + const source = buildNode(add); + const connectorA = buildConnectorNode('connector-a'); + const connectorB = buildConnectorNode('connector-b'); + const nodes: AnyNode[] = [source, connectorA, connectorB]; + const edges = [ + buildEdge(source.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE), + ]; + + expect(resolveConnectorSource(connectorB.id, nodes, edges)).toEqual({ + nodeId: source.id, + fieldName: 'value', + }); + }); + + it('returns no source or type for an unresolved connector chain', () => { + const connectorA = buildConnectorNode('connector-a'); + const connectorB = buildConnectorNode('connector-b'); + const nodes: AnyNode[] = [connectorA, connectorB]; + const edges = [buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE)]; + + expect(resolveConnectorSource(connectorB.id, nodes, edges)).toBe(null); + expect(resolveConnectorSourceFieldType(connectorB.id, nodes, edges, templates)).toBe(null); + }); + + it('enumerates multiple outgoing edges for a connector', () => { + const source = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const targetA = buildNode(sub); + const targetB = buildNode(img_resize); + const incoming = buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE); + const outgoingA = buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetA.id, 'a'); + const outgoingB = buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetB.id, 'width'); + const edges = [incoming, outgoingA, outgoingB]; + + expect(getConnectorInputEdge(connector.id, edges)).toEqual(incoming); + expect(getConnectorOutputEdges(connector.id, edges)).toEqual([outgoingA, outgoingB]); + }); + + it('rejects connector deletion splice-through when any downstream target would be invalid', () => { + const source = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const target = buildNode(img_resize); + const nodes: AnyNode[] = [source, connector, target]; + const edges = [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'image'), + ]; + + expect(getConnectorDeletionSpliceConnections(connector.id, nodes, edges, templates)).toBe(null); + }); + + it('builds connector deletion splice-through edges when every downstream target remains valid', () => { + const source = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const target = buildNode(sub); + const nodes: AnyNode[] = [source, connector, target]; + const edges = [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ]; + + expect(getConnectorDeletionSpliceConnections(connector.id, nodes, edges, templates)).toEqual([ + { + source: source.id, + sourceHandle: 'value', + target: target.id, + targetHandle: 'a', + }, + ]); + }); + + it('returns no splice-through edges when a connector has downstream targets but no upstream source', () => { + const connector = buildConnectorNode('connector-1'); + const target = buildNode(sub); + const nodes: AnyNode[] = [connector, target]; + const edges = [buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a')]; + + expect(getConnectorDeletionSpliceConnections(connector.id, nodes, edges, templates)).toBe(null); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/connectorTopology.ts b/invokeai/frontend/web/src/features/nodes/store/util/connectorTopology.ts new file mode 100644 index 0000000000..e1267763d7 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/connectorTopology.ts @@ -0,0 +1,228 @@ +import type { Templates } from 'features/nodes/store/types'; +import type { FieldType } from 'features/nodes/types/field'; +import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation'; +import { isConnectorNode, isInvocationNode } from 'features/nodes/types/invocation'; + +export const CONNECTOR_INPUT_HANDLE = 'in'; +export const CONNECTOR_OUTPUT_HANDLE = 'out'; + +type ResolvedConnectorSource = { + nodeId: string; + fieldName: string; +}; + +type SpliceConnection = { + source: string; + sourceHandle: string; + target: string; + targetHandle: string; +}; + +type SpliceConnectionValidator = ( + connection: SpliceConnection, + nodes: AnyNode[], + edges: AnyEdge[], + templates: Templates, + ignoreEdge: AnyEdge | null, + strict?: boolean +) => string | null; + +export const getConnectorInputEdge = (connectorId: string, edges: AnyEdge[]): AnyEdge | undefined => + edges.find( + (edge) => + edge.type === 'default' && + edge.target === connectorId && + edge.targetHandle === CONNECTOR_INPUT_HANDLE && + typeof edge.sourceHandle === 'string' + ); + +export const getConnectorOutputEdges = (connectorId: string, edges: AnyEdge[]): AnyEdge[] => + edges.filter( + (edge) => + edge.type === 'default' && + edge.source === connectorId && + edge.sourceHandle === CONNECTOR_OUTPUT_HANDLE && + typeof edge.targetHandle === 'string' + ); + +export const resolveConnectorSource = ( + connectorId: string, + nodes: AnyNode[], + edges: AnyEdge[] +): ResolvedConnectorSource | null => { + const visited = new Set(); + + const resolve = (nodeId: string): ResolvedConnectorSource | null => { + if (visited.has(nodeId)) { + return null; + } + visited.add(nodeId); + + const incomingEdge = getConnectorInputEdge(nodeId, edges); + if (!incomingEdge || incomingEdge.type !== 'default') { + return null; + } + if (typeof incomingEdge.sourceHandle !== 'string') { + return null; + } + + const sourceNode = nodes.find((node) => node.id === incomingEdge.source); + if (!sourceNode) { + return null; + } + + if (isInvocationNode(sourceNode)) { + return { nodeId: sourceNode.id, fieldName: incomingEdge.sourceHandle }; + } + + if (isConnectorNode(sourceNode)) { + return resolve(sourceNode.id); + } + + return null; + }; + + return resolve(connectorId); +}; + +export const resolveConnectorSourceFieldType = ( + connectorId: string, + nodes: AnyNode[], + edges: AnyEdge[], + templates: Templates +): FieldType | null => { + const resolvedSource = resolveConnectorSource(connectorId, nodes, edges); + if (!resolvedSource) { + return null; + } + + const sourceNode = nodes.find((node) => node.id === resolvedSource.nodeId); + if (!sourceNode || !isInvocationNode(sourceNode)) { + return null; + } + + const sourceTemplate = templates[sourceNode.data.type]; + return sourceTemplate?.outputs[resolvedSource.fieldName]?.type ?? null; +}; + +export const getConnectorDeletionSpliceConnections = ( + connectorId: string, + nodes: AnyNode[], + edges: AnyEdge[], + templates: Templates, + validateConnection?: SpliceConnectionValidator +): SpliceConnection[] | null => { + const resolvedSource = resolveConnectorSource(connectorId, nodes, edges); + if (!resolvedSource) { + return null; + } + + const outputEdges = getConnectorOutputEdges(connectorId, edges); + const spliceConnections = outputEdges + .filter((edge): edge is AnyEdge & { type: 'default'; targetHandle: string } => edge.type === 'default') + .map((edge) => ({ + source: resolvedSource.nodeId, + sourceHandle: resolvedSource.fieldName, + target: edge.target, + targetHandle: edge.targetHandle, + })); + + const deduped = new Set(); + for (const connection of spliceConnections) { + const key = `${connection.source}:${connection.sourceHandle}->${connection.target}:${connection.targetHandle}`; + if (deduped.has(key)) { + return null; + } + deduped.add(key); + } + + if (!validateConnection) { + const sourceType = resolveConnectorSourceFieldType(connectorId, nodes, edges, templates); + if (!sourceType) { + return null; + } + const inputEdgeId = getConnectorInputEdge(connectorId, edges)?.id; + const outputEdgeIds = new Set(outputEdges.map((edge) => edge.id)); + + for (const connection of spliceConnections) { + const targetNode = nodes.find((node) => node.id === connection.target); + if (!targetNode || !isInvocationNode(targetNode)) { + return null; + } + const targetTemplate = templates[targetNode.data.type]; + const targetFieldTemplate = targetTemplate?.inputs[connection.targetHandle]; + if (!targetFieldTemplate) { + return null; + } + + const matchesExistingDirectEdge = edges.some( + (edge) => + edge.type === 'default' && + edge.source === connection.source && + edge.sourceHandle === connection.sourceHandle && + edge.target === connection.target && + edge.targetHandle === connection.targetHandle + ); + if (matchesExistingDirectEdge) { + return null; + } + + const targetConflictCount = spliceConnections.filter( + (candidate) => candidate.target === connection.target && candidate.targetHandle === connection.targetHandle + ).length; + const existingTargetConflict = edges.some( + (edge) => + edge.type === 'default' && + edge.id !== inputEdgeId && + !outputEdgeIds.has(edge.id) && + edge.target === connection.target && + edge.targetHandle === connection.targetHandle + ); + if ( + targetFieldTemplate.type.name !== 'CollectionItemField' && + (targetConflictCount > 1 || existingTargetConflict) + ) { + return null; + } + + if ( + sourceType.name !== targetFieldTemplate.type.name && + targetFieldTemplate.type.name !== 'CollectionItemField' + ) { + return null; + } + } + + return spliceConnections; + } + + const ignoredEdgeIds = new Set([ + getConnectorInputEdge(connectorId, edges)?.id, + ...outputEdges.map((edge) => edge.id), + ]); + const existingEdges = edges.filter((edge) => !ignoredEdgeIds.has(edge.id)); + const stagedConnections: SpliceConnection[] = []; + + for (const connection of spliceConnections) { + const stagedEdges = [ + ...existingEdges, + ...stagedConnections.map( + ({ source, sourceHandle, target, targetHandle }) => + ({ + id: `splice-${source}-${sourceHandle}-${target}-${targetHandle}`, + type: 'default', + source, + sourceHandle, + target, + targetHandle, + }) satisfies AnyEdge + ), + ]; + if (validateConnection(connection, nodes, stagedEdges, templates, null, true) !== null) { + return null; + } + stagedConnections.push(connection); + } + + return spliceConnections; +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.test.ts index 288a4f9066..b4374a920c 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.test.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.test.ts @@ -1,13 +1,26 @@ import { deepClone } from 'common/util/deepClone'; import { unset } from 'es-toolkit/compat'; +import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from 'features/nodes/store/util/connectorTopology'; import { getFirstValidConnection, getSourceCandidateFields, getTargetCandidateFields, } from 'features/nodes/store/util/getFirstValidConnection'; -import { add, buildEdge, buildNode, img_resize, templates } from 'features/nodes/store/util/testUtils'; +import { add, buildEdge, buildNode, img_resize, sub, templates } from 'features/nodes/store/util/testUtils'; import { describe, expect, it } from 'vitest'; +const buildConnectorNode = (id: string) => ({ + id, + type: 'connector' as const, + position: { x: 0, y: 0 }, + data: { + id, + type: 'connector' as const, + label: 'Connector', + isOpen: true, + }, +}); + describe('getFirstValidConnection', () => { it('should return null if the pending and candidate nodes are the same node', () => { const n = buildNode(add); @@ -120,6 +133,33 @@ describe('getFirstValidConnection', () => { expect(r).toEqual(null); }); }); + + it('should resolve connector target candidates when connecting an invocation output to a connector', () => { + const n1 = buildNode(add); + const connector = buildConnectorNode('connector-1'); + expect(getFirstValidConnection(n1.id, 'value', connector.id, null, [n1, connector], [], templates, null)).toEqual({ + source: n1.id, + sourceHandle: 'value', + target: connector.id, + targetHandle: CONNECTOR_INPUT_HANDLE, + }); + }); + + it('should resolve connector source candidates when connecting a connector to a typed invocation input', () => { + const n1 = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const n2 = buildNode(img_resize); + const edges = [buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)]; + + expect( + getFirstValidConnection(connector.id, null, n2.id, 'width', [n1, connector, n2], edges, templates, null) + ).toEqual({ + source: connector.id, + sourceHandle: CONNECTOR_OUTPUT_HANDLE, + target: n2.id, + targetHandle: 'width', + }); + }); }); describe('getTargetCandidateFields', () => { @@ -160,6 +200,62 @@ describe('getTargetCandidateFields', () => { const r = getTargetCandidateFields(n1.id, 'width', n2.id, nodes, [], templates, edgePendingUpdate); expect(r).toEqual([img_resize.inputs['width'], img_resize.inputs['height']]); }); + it('should return the connector input handle when the target is a connector', () => { + const n1 = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const r = getTargetCandidateFields(n1.id, 'value', connector.id, [n1, connector], [], templates, null); + expect(r.map((field) => field.name)).toEqual([CONNECTOR_INPUT_HANDLE]); + }); + it('should advertise typed target candidates for an unresolved connector output when no downstream constraint exists', () => { + const connector = buildConnectorNode('connector-1'); + const n2 = buildNode(sub); + const r = getTargetCandidateFields( + connector.id, + CONNECTOR_OUTPUT_HANDLE, + n2.id, + [connector, n2], + [], + templates, + null + ); + expect(r.map((field) => field.name)).toEqual(['a', 'b']); + }); + it('should only advertise compatible typed target candidates for an unresolved connector output with downstream constraints', () => { + const connector = buildConnectorNode('connector-1'); + const n1 = buildNode(sub); + const n2 = buildNode(img_resize); + const edges = [buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, n1.id, 'a')]; + const r = getTargetCandidateFields( + connector.id, + CONNECTOR_OUTPUT_HANDLE, + n2.id, + [connector, n1, n2], + edges, + templates, + null + ); + expect(r.map((field) => field.name)).toEqual(['width', 'height']); + }); + it('should resolve chained connector sources like the direct upstream source', () => { + const n1 = buildNode(add); + const connectorA = buildConnectorNode('connector-a'); + const connectorB = buildConnectorNode('connector-b'); + const n2 = buildNode(img_resize); + const edges = [ + buildEdge(n1.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE), + ]; + const r = getTargetCandidateFields( + connectorB.id, + CONNECTOR_OUTPUT_HANDLE, + n2.id, + [n1, connectorA, connectorB, n2], + edges, + templates, + null + ); + expect(r).toEqual([img_resize.inputs['width'], img_resize.inputs['height']]); + }); }); describe('getSourceCandidateFields', () => { @@ -200,4 +296,18 @@ describe('getSourceCandidateFields', () => { const r = getSourceCandidateFields(n2.id, 'width', n1.id, nodes, [], templates, edgePendingUpdate); expect(r).toEqual([img_resize.outputs['width'], img_resize.outputs['height']]); }); + it('should return the connector output handle when the source is a connector with a typed upstream source', () => { + const n1 = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const n2 = buildNode(img_resize); + const edges = [buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)]; + const r = getSourceCandidateFields(n2.id, 'width', connector.id, [n1, connector, n2], edges, templates, null); + expect(r.map((field) => field.name)).toEqual([CONNECTOR_OUTPUT_HANDLE]); + }); + it('should return a target-constrained connector source candidate when the connector chain is unresolved', () => { + const connector = buildConnectorNode('connector-1'); + const n2 = buildNode(img_resize); + const r = getSourceCandidateFields(n2.id, 'width', connector.id, [connector, n2], [], templates, null); + expect(r.map((field) => field.name)).toEqual([CONNECTOR_OUTPUT_HANDLE]); + }); }); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts index 0b5aa17e17..2a466aae44 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts @@ -1,9 +1,15 @@ import type { Connection } from '@xyflow/react'; import { map } from 'es-toolkit/compat'; import type { Templates } from 'features/nodes/store/types'; +import { + CONNECTOR_INPUT_HANDLE, + CONNECTOR_OUTPUT_HANDLE, + resolveConnectorSourceFieldType, +} from 'features/nodes/store/util/connectorTopology'; import { validateConnection } from 'features/nodes/store/util/validateConnection'; import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field'; import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation'; +import { isConnectorNode } from 'features/nodes/types/invocation'; /** * @@ -91,16 +97,43 @@ export const getTargetCandidateFields = ( return []; } - const sourceTemplate = templates[sourceNode.data.type]; + if (isConnectorNode(targetNode)) { + const candidate = { + name: CONNECTOR_INPUT_HANDLE, + title: 'Connector Input', + description: '', + fieldKind: 'input', + input: 'connection', + required: false, + default: undefined, + ui_hidden: false, + type: { + name: 'AnyField', + cardinality: 'SINGLE', + batch: false, + }, + } satisfies FieldInputTemplate; + + const c = { source, sourceHandle, target, targetHandle: candidate.name }; + return validateConnection(c, nodes, edges, templates, edgePendingUpdate, true) === null ? [candidate] : []; + } + const targetTemplate = templates[targetNode.data.type]; - if (!sourceTemplate || !targetTemplate) { + if (!targetTemplate) { return []; } - const sourceField = sourceTemplate.outputs[sourceHandle]; + if (!isConnectorNode(sourceNode)) { + const sourceTemplate = templates[sourceNode.data.type]; + if (!sourceTemplate) { + return []; + } - if (!sourceField) { - return []; + const sourceField = sourceTemplate.outputs[sourceHandle]; + + if (!sourceField) { + return []; + } } const targetCandidateFields = map(targetTemplate.inputs).filter((field) => { @@ -127,15 +160,45 @@ export const getSourceCandidateFields = ( return []; } + if (isConnectorNode(sourceNode)) { + const sourceFieldType = resolveConnectorSourceFieldType(sourceNode.id, nodes, edges, templates); + const targetTemplate = !isConnectorNode(targetNode) ? templates[targetNode.data.type] : null; + const targetFieldType = targetTemplate?.inputs[targetHandle]?.type; + const candidateType = sourceFieldType ?? targetFieldType; + if (!candidateType) { + return []; + } + + const candidate = { + name: CONNECTOR_OUTPUT_HANDLE, + title: 'Connector Output', + description: '', + fieldKind: 'output', + ui_hidden: false, + type: candidateType, + } satisfies FieldOutputTemplate; + + const c = { source, sourceHandle: candidate.name, target, targetHandle }; + return validateConnection(c, nodes, edges, templates, edgePendingUpdate, true) === null ? [candidate] : []; + } + const sourceTemplate = templates[sourceNode.data.type]; - const targetTemplate = templates[targetNode.data.type]; - if (!sourceTemplate || !targetTemplate) { + if (!sourceTemplate) { return []; } - const targetField = targetTemplate.inputs[targetHandle]; + if (!isConnectorNode(targetNode)) { + const targetTemplate = templates[targetNode.data.type]; + if (!targetTemplate) { + return []; + } - if (!targetField) { + const targetField = targetTemplate.inputs[targetHandle]; + + if (!targetField) { + return []; + } + } else if (targetHandle !== CONNECTOR_INPUT_HANDLE) { return []; } diff --git a/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.test.ts new file mode 100644 index 0000000000..b70eda4bda --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.test.ts @@ -0,0 +1,23 @@ +import { describe, expect, it } from 'vitest'; + +import { connectionToEdge } from './reactFlowUtil'; + +describe('connectionToEdge', () => { + it('creates a default edge with the expected id and endpoints', () => { + expect( + connectionToEdge({ + source: 'source-node', + sourceHandle: 'value', + target: 'target-node', + targetHandle: 'a', + }) + ).toEqual({ + type: 'default', + source: 'source-node', + sourceHandle: 'value', + target: 'target-node', + targetHandle: 'a', + id: 'reactflow__edge-source-nodevalue-target-nodea', + }); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.ts b/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.ts index 67264bf41b..3eaece154f 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.ts @@ -24,6 +24,7 @@ export const connectionToEdge = (connection: Connection): AnyEdge => { const { source, sourceHandle, target, targetHandle } = connection; assert(source && sourceHandle && target && targetHandle, 'Invalid connection'); return { + type: 'default', source, sourceHandle, target, diff --git a/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts index 1eb445beaf..8706e199bb 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts @@ -70,6 +70,7 @@ export const add: InvocationTemplate = { useCache: true, nodePack: 'invokeai', classification: 'stable', + category: 'math', }; export const sub: InvocationTemplate = { @@ -128,6 +129,7 @@ export const sub: InvocationTemplate = { useCache: true, nodePack: 'invokeai', classification: 'stable', + category: 'math', }; export const collect: InvocationTemplate = { @@ -189,6 +191,7 @@ export const collect: InvocationTemplate = { useCache: true, nodePack: 'invokeai', classification: 'stable', + category: 'collections', }; const scheduler: InvocationTemplate = { @@ -245,6 +248,7 @@ const scheduler: InvocationTemplate = { useCache: true, nodePack: 'invokeai', classification: 'stable', + category: 'other', }; export const main_model_loader: InvocationTemplate = { @@ -313,6 +317,7 @@ export const main_model_loader: InvocationTemplate = { useCache: true, nodePack: 'invokeai', classification: 'stable', + category: 'model', }; export const img_resize: InvocationTemplate = { @@ -457,6 +462,7 @@ export const img_resize: InvocationTemplate = { useCache: true, nodePack: 'invokeai', classification: 'stable', + category: 'image', }; const iterate: InvocationTemplate = { @@ -526,6 +532,7 @@ const iterate: InvocationTemplate = { useCache: true, nodePack: 'invokeai', classification: 'stable', + category: 'collections', }; export const templates: Templates = { @@ -713,7 +720,6 @@ export const schema = { required: ['type', 'id'], title: 'Scheduler', description: 'Selects a scheduler.', - category: 'latents', classification: 'stable', node_pack: 'invokeai', tags: ['scheduler'], @@ -1199,6 +1205,7 @@ export const schema = { title: 'CollectInvocation', node_pack: 'invokeai', description: 'Collects values into a collection', + category: 'collections', classification: 'stable', version: '1.1.0', output: { @@ -1558,6 +1565,7 @@ export const schema = { required: ['type', 'id'], title: 'IterateInvocation', description: 'Iterates over a list of items', + category: 'collections', classification: 'stable', node_pack: 'invokeai', version: '1.1.0', diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts index 88eae8484f..730dced1d3 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts @@ -3,6 +3,11 @@ import { set } from 'es-toolkit/compat'; import type { InvocationTemplate } from 'features/nodes/types/invocation'; import { describe, expect, it } from 'vitest'; +import { + CONNECTOR_INPUT_HANDLE, + CONNECTOR_OUTPUT_HANDLE, + getConnectorDeletionSpliceConnections, +} from './connectorTopology'; import { add, buildEdge, buildNode, collect, img_resize, main_model_loader, sub, templates } from './testUtils'; import { validateConnection } from './validateConnection'; @@ -11,6 +16,7 @@ const ifTemplate: InvocationTemplate = { type: 'if', version: '1.0.0', tags: [], + category: 'math', description: 'Selects between two inputs based on a boolean condition', outputType: 'if_output', inputs: { @@ -88,6 +94,7 @@ const floatOutputTemplate: InvocationTemplate = { type: 'float_output', version: '1.0.0', tags: [], + category: 'primitives', description: 'Outputs a float', outputType: 'float_output', inputs: {}, @@ -116,6 +123,7 @@ const integerCollectionOutputTemplate: InvocationTemplate = { type: 'integer_collection_output', version: '1.0.0', tags: [], + category: 'primitives', description: 'Outputs an integer collection', outputType: 'integer_collection_output', inputs: {}, @@ -139,6 +147,18 @@ const integerCollectionOutputTemplate: InvocationTemplate = { classification: 'stable', }; +const buildConnectorNode = (id: string) => ({ + id, + type: 'connector' as const, + position: { x: 0, y: 0 }, + data: { + id, + type: 'connector' as const, + label: 'Connector', + isOpen: true, + }, +}); + describe(validateConnection.name, () => { it('should reject invalid connection to self', () => { const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' }; @@ -458,6 +478,214 @@ describe(validateConnection.name, () => { expect(r).toEqual('nodes.connectionWouldCreateCycle'); }); + describe('connectors', () => { + it('should accept invocation output to connector input', () => { + const n1 = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const r = validateConnection( + { source: n1.id, sourceHandle: 'value', target: connector.id, targetHandle: CONNECTOR_INPUT_HANDLE }, + [n1, connector], + [], + templates, + null + ); + expect(r).toEqual(null); + }); + + it('should reject a second input into a connector', () => { + const n1 = buildNode(add); + const n2 = buildNode(sub); + const connector = buildConnectorNode('connector-1'); + const edges = [buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)]; + const r = validateConnection( + { source: n2.id, sourceHandle: 'value', target: connector.id, targetHandle: CONNECTOR_INPUT_HANDLE }, + [n1, n2, connector], + edges, + templates, + null + ); + expect(r).toEqual('nodes.inputMayOnlyHaveOneConnection'); + }); + + it('should accept connector output to invocation input when the upstream type matches', () => { + const n1 = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const n2 = buildNode(sub); + const edges = [buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)]; + const r = validateConnection( + { source: connector.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'a' }, + [n1, connector, n2], + edges, + templates, + null + ); + expect(r).toEqual(null); + }); + + it('should reject connector output to invocation input when the upstream type mismatches', () => { + const n1 = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const n2 = buildNode(img_resize); + const edges = [buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)]; + const r = validateConnection( + { source: connector.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'image' }, + [n1, connector, n2], + edges, + templates, + null + ); + expect(r).toEqual('nodes.fieldTypesMustMatch'); + }); + + it('should accept unresolved connector output to a typed invocation input as the first downstream constraint', () => { + const connector = buildConnectorNode('connector-1'); + const n2 = buildNode(sub); + const r = validateConnection( + { source: connector.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'a' }, + [connector, n2], + [], + templates, + null + ); + expect(r).toEqual(null); + }); + + it('should reject unresolved connector output when it conflicts with an existing downstream typed constraint', () => { + const connector = buildConnectorNode('connector-1'); + const n1 = buildNode(sub); + const n2 = buildNode(img_resize); + const edges = [buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, n1.id, 'a')]; + const r = validateConnection( + { source: connector.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'image' }, + [connector, n1, n2], + edges, + templates, + null + ); + expect(r).toEqual('nodes.fieldTypesMustMatch'); + }); + + it('should reject connecting an incompatible upstream source into a connector with downstream typed constraints', () => { + const source = buildNode(main_model_loader); + const connector = buildConnectorNode('connector-1'); + const target = buildNode(sub); + const edges = [buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a')]; + const r = validateConnection( + { source: source.id, sourceHandle: 'vae', target: connector.id, targetHandle: CONNECTOR_INPUT_HANDLE }, + [source, connector, target], + edges, + templates, + null + ); + expect(r).toEqual('nodes.fieldTypesMustMatch'); + }); + + it('should preserve type information through chained connectors', () => { + const n1 = buildNode(add); + const connectorA = buildConnectorNode('connector-a'); + const connectorB = buildConnectorNode('connector-b'); + const n2 = buildNode(sub); + const edges = [ + buildEdge(n1.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE), + ]; + const r = validateConnection( + { source: connectorB.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'a' }, + [n1, connectorA, connectorB, n2], + edges, + templates, + null + ); + expect(r).toEqual(null); + }); + + it('should reject cycles routed through connectors', () => { + const n1 = buildNode(add); + const n2 = buildNode(sub); + const connector = buildConnectorNode('connector-1'); + const edges = [ + buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, n2.id, 'a'), + ]; + const r = validateConnection( + { source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' }, + [n1, n2, connector], + edges, + templates, + null + ); + expect(r).toEqual('nodes.connectionWouldCreateCycle'); + }); + + it('should preserve collect item validation through connectors', () => { + const n1 = buildNode(add); + const n2 = buildNode(collect); + const n3 = buildNode(main_model_loader); + const connector = buildConnectorNode('connector-1'); + const edges = [ + buildEdge(n1.id, 'value', n2.id, 'item'), + buildEdge(n3.id, 'vae', connector.id, CONNECTOR_INPUT_HANDLE), + ]; + const r = validateConnection( + { source: connector.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'item' }, + [n1, n2, n3, connector], + edges, + templates, + null + ); + expect(r).toEqual('nodes.cannotMixAndMatchCollectionItemTypes'); + }); + + it('should preserve if branch validation through connectors', () => { + const n1 = buildNode(add); + const n2 = buildNode(img_resize); + const n3 = buildNode(ifTemplate); + const connector = buildConnectorNode('connector-1'); + const edges = [ + buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, n3.id, 'true_input'), + ]; + const r = validateConnection( + { source: n2.id, sourceHandle: 'image', target: n3.id, targetHandle: 'false_input' }, + [n1, n2, n3, connector], + edges, + { ...templates, if: ifTemplate }, + null + ); + expect(r).toEqual('nodes.fieldTypesMustMatch'); + }); + + it('should reject connector deletion splice-through when it would duplicate an existing direct edge', () => { + const n1 = buildNode(add); + const n2 = buildNode(sub); + const connector = buildConnectorNode('connector-1'); + const edges = [ + buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, n2.id, 'a'), + buildEdge(n1.id, 'value', n2.id, 'a'), + ]; + + expect(getConnectorDeletionSpliceConnections(connector.id, [n1, n2, connector], edges, templates)).toBe(null); + }); + + it('should reject connector deletion splice-through when fan-out would violate a single-input target', () => { + const n1 = buildNode(add); + const connectorA = buildConnectorNode('connector-a'); + const connectorB = buildConnectorNode('connector-b'); + const n2 = buildNode(sub); + const edges = [ + buildEdge(n1.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, n2.id, 'a'), + buildEdge(connectorB.id, CONNECTOR_OUTPUT_HANDLE, n2.id, 'a'), + ]; + + expect( + getConnectorDeletionSpliceConnections(connectorA.id, [n1, connectorA, connectorB, n2], edges, templates) + ).toBe(null); + }); + }); + describe('non-strict mode', () => { it('should reject connections from self to self in non-strict mode', () => { const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' }; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts index b342df064b..bb98d472d3 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts @@ -1,10 +1,17 @@ import type { Connection as NullableConnection } from '@xyflow/react'; import type { Templates } from 'features/nodes/store/types'; import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual'; +import { + CONNECTOR_INPUT_HANDLE, + CONNECTOR_OUTPUT_HANDLE, + resolveConnectorSource, +} from 'features/nodes/store/util/connectorTopology'; import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType'; import { getHasCycles } from 'features/nodes/store/util/getHasCycles'; import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; -import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation'; +import type { FieldType } from 'features/nodes/types/field'; +import type { AnyEdge, AnyNode, InvocationNode } from 'features/nodes/types/invocation'; +import { isConnectorNode, isInvocationNode } from 'features/nodes/types/invocation'; import type { SetNonNullable } from 'type-fest'; type Connection = SetNonNullable; @@ -18,6 +25,12 @@ type ValidateConnectionFunc = ( strict?: boolean ) => string | null; +type EffectiveSource = { + node: InvocationNode; + handle: string; + fieldTemplate: NonNullable['outputs'][string]; +}; + const getEqualityPredicate = (c: Connection) => (e: AnyEdge): boolean => { @@ -41,10 +54,7 @@ const isIfInputHandle = (handle: string): handle is (typeof IF_INPUT_HANDLES)[nu return IF_INPUT_HANDLES.includes(handle as (typeof IF_INPUT_HANDLES)[number]); }; -const isSingleCollectionPairOfSameBaseType = ( - firstType: { name: string; cardinality: string; batch: boolean }, - secondType: { name: string; cardinality: string; batch: boolean } -) => { +const isSingleCollectionPairOfSameBaseType = (firstType: FieldType, secondType: FieldType) => { const isSingleToCollection = firstType.cardinality === 'SINGLE' && secondType.cardinality === 'COLLECTION' && firstType.name === secondType.name; const isCollectionToSingle = @@ -52,6 +62,139 @@ const isSingleCollectionPairOfSameBaseType = ( return firstType.batch === secondType.batch && (isSingleToCollection || isCollectionToSingle); }; +const areFieldTypesCompatible = (firstType: FieldType, secondType: FieldType) => + validateConnectionTypes(firstType, secondType) || + validateConnectionTypes(secondType, firstType) || + isSingleCollectionPairOfSameBaseType(firstType, secondType); + +type ConnectorTerminalTargetEdge = AnyEdge & { + type: 'default'; + sourceHandle: string; + targetHandle: string; +}; + +const getConnectorTerminalTargetEdges = (connectorId: string, nodes: AnyNode[], edges: AnyEdge[]) => { + const visited = new Set(); + const resolve = (currentConnectorId: string): ConnectorTerminalTargetEdge[] => { + if (visited.has(currentConnectorId)) { + return []; + } + visited.add(currentConnectorId); + + return edges.flatMap((edge) => { + if ( + edge.type !== 'default' || + edge.source !== currentConnectorId || + edge.sourceHandle !== CONNECTOR_OUTPUT_HANDLE || + typeof edge.targetHandle !== 'string' + ) { + return []; + } + + const targetNode = nodes.find((node) => node.id === edge.target); + if (targetNode && isConnectorNode(targetNode)) { + return resolve(targetNode.id); + } + + return [edge as ConnectorTerminalTargetEdge]; + }); + }; + + return resolve(connectorId); +}; + +const getConnectorSubgraphEdgeIds = (connectorId: string, nodes: AnyNode[], edges: AnyEdge[]) => { + const visited = new Set(); + const edgeIds = new Set(); + + const visit = (currentConnectorId: string) => { + if (visited.has(currentConnectorId)) { + return; + } + visited.add(currentConnectorId); + + edges.forEach((edge) => { + if ( + edge.type !== 'default' || + edge.source !== currentConnectorId || + edge.sourceHandle !== CONNECTOR_OUTPUT_HANDLE || + typeof edge.targetHandle !== 'string' + ) { + return; + } + + edgeIds.add(edge.id); + + const targetNode = nodes.find((node) => node.id === edge.target); + if (targetNode && isConnectorNode(targetNode)) { + visit(targetNode.id); + } + }); + }; + + visit(connectorId); + return edgeIds; +}; + +const getEffectiveSource = ( + sourceId: string, + sourceHandle: string, + nodes: AnyNode[], + edges: AnyEdge[], + templates: Templates +): EffectiveSource | 'nodes.missingNode' | 'nodes.missingInvocationTemplate' | 'nodes.missingFieldTemplate' | null => { + const sourceNode = nodes.find((n) => n.id === sourceId); + if (!sourceNode) { + return 'nodes.missingNode'; + } + + if (isConnectorNode(sourceNode)) { + if (sourceHandle !== CONNECTOR_OUTPUT_HANDLE) { + return 'nodes.missingFieldTemplate'; + } + + const resolvedSource = resolveConnectorSource(sourceNode.id, nodes, edges); + if (!resolvedSource) { + return null; + } + + return getEffectiveSource(resolvedSource.nodeId, resolvedSource.fieldName, nodes, edges, templates); + } + + if (!isInvocationNode(sourceNode)) { + return 'nodes.missingInvocationTemplate'; + } + + const sourceTemplate = templates[sourceNode.data.type]; + if (!sourceTemplate) { + return 'nodes.missingInvocationTemplate'; + } + + const sourceFieldTemplate = sourceTemplate.outputs[sourceHandle]; + if (!sourceFieldTemplate) { + return 'nodes.missingFieldTemplate'; + } + + return { + node: sourceNode, + handle: sourceHandle, + fieldTemplate: sourceFieldTemplate, + }; +}; + +const getEffectiveSourceForEdge = ( + edge: AnyEdge, + nodes: AnyNode[], + edges: AnyEdge[], + templates: Templates +): EffectiveSource | 'nodes.missingNode' | 'nodes.missingInvocationTemplate' | 'nodes.missingFieldTemplate' | null => { + if (edge.type !== 'default' || typeof edge.sourceHandle !== 'string') { + return null; + } + + return getEffectiveSource(edge.source, edge.sourceHandle, nodes, edges, templates); +}; + /** * Validates a connection between two fields * @returns A translation key for an error if the connection is invalid, otherwise null @@ -83,18 +226,63 @@ export const validateConnection: ValidateConnectionFunc = ( return 'nodes.cannotDuplicateConnection'; } - const sourceNode = nodes.find((n) => n.id === c.source); - if (!sourceNode) { - return 'nodes.missingNode'; - } - const targetNode = nodes.find((n) => n.id === c.target); + const sourceNode = nodes.find((n) => n.id === c.source); if (!targetNode) { return 'nodes.missingNode'; } - const sourceTemplate = templates[sourceNode.data.type]; - if (!sourceTemplate) { + const effectiveSource = getEffectiveSource(c.source, c.sourceHandle, nodes, filteredEdges, templates); + if (effectiveSource === 'nodes.missingNode') { + return 'nodes.missingNode'; + } + if (effectiveSource === 'nodes.missingInvocationTemplate') { + return 'nodes.missingInvocationTemplate'; + } + if (effectiveSource === 'nodes.missingFieldTemplate') { + return 'nodes.missingFieldTemplate'; + } + + if (isConnectorNode(targetNode)) { + if (c.targetHandle !== CONNECTOR_INPUT_HANDLE) { + return 'nodes.missingFieldTemplate'; + } + + if (filteredEdges.find(getTargetEqualityPredicate(c))) { + return 'nodes.inputMayOnlyHaveOneConnection'; + } + + if (effectiveSource) { + const connectorSubgraphEdgeIds = getConnectorSubgraphEdgeIds(targetNode.id, nodes, filteredEdges); + const stagedEdges = filteredEdges.filter((edge) => !connectorSubgraphEdgeIds.has(edge.id)); + const terminalTargetEdges = getConnectorTerminalTargetEdges(targetNode.id, nodes, filteredEdges); + + for (const terminalTargetEdge of terminalTargetEdges) { + const downstreamValidation = validateConnection( + { + source: c.source, + sourceHandle: c.sourceHandle, + target: terminalTargetEdge.target, + targetHandle: terminalTargetEdge.targetHandle, + }, + nodes, + stagedEdges, + templates, + null, + true + ); + + if (downstreamValidation !== null) { + return downstreamValidation; + } + } + } + + // Unresolved connector chains are allowed to terminate on another connector. + return null; + } + + if (!isInvocationNode(targetNode)) { return 'nodes.missingInvocationTemplate'; } @@ -103,11 +291,6 @@ export const validateConnection: ValidateConnectionFunc = ( return 'nodes.missingInvocationTemplate'; } - const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle]; - if (!sourceFieldTemplate) { - return 'nodes.missingFieldTemplate'; - } - const targetFieldTemplate = targetTemplate.inputs[c.targetHandle]; if (!targetFieldTemplate) { return 'nodes.missingFieldTemplate'; @@ -117,23 +300,58 @@ export const validateConnection: ValidateConnectionFunc = ( return 'nodes.cannotConnectToDirectInput'; } + if (!effectiveSource) { + if (sourceNode && isConnectorNode(sourceNode) && c.sourceHandle === CONNECTOR_OUTPUT_HANDLE) { + const existingTerminalTargetEdges = getConnectorTerminalTargetEdges(sourceNode.id, nodes, filteredEdges).filter( + (edge) => !(edge.target === c.target && edge.targetHandle === c.targetHandle) + ); + + for (const terminalTargetEdge of existingTerminalTargetEdges) { + const constrainedTargetNode = nodes.find((node) => node.id === terminalTargetEdge.target); + if (!constrainedTargetNode || !isInvocationNode(constrainedTargetNode)) { + return 'nodes.missingInvocationTemplate'; + } + + const constrainedTargetTemplate = templates[constrainedTargetNode.data.type]; + if (!constrainedTargetTemplate) { + return 'nodes.missingInvocationTemplate'; + } + + const constrainedTargetFieldTemplate = constrainedTargetTemplate.inputs[terminalTargetEdge.targetHandle]; + if (!constrainedTargetFieldTemplate) { + return 'nodes.missingFieldTemplate'; + } + + if (!areFieldTypesCompatible(constrainedTargetFieldTemplate.type, targetFieldTemplate.type)) { + return 'nodes.fieldTypesMustMatch'; + } + } + + return null; + } + + return 'nodes.fieldTypesMustMatch'; + } + + const { node: resolvedSourceNode, handle: sourceHandle, fieldTemplate: sourceFieldTemplate } = effectiveSource; + if (targetNode.data.type === 'collect' && c.targetHandle === 'item') { // Collect nodes shouldn't mix and match field types. - const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); + const collectItemType = getCollectItemType(templates, nodes, filteredEdges, targetNode.id); if (collectItemType && !areTypesEqual(sourceFieldTemplate.type, collectItemType)) { return 'nodes.cannotMixAndMatchCollectionItemTypes'; } } if ( - sourceNode.data.type === 'collect' && - c.sourceHandle === 'collection' && + resolvedSourceNode.data.type === 'collect' && + sourceHandle === 'collection' && targetNode.data.type === 'collect' && c.targetHandle === 'collection' ) { // Chained collect nodes should preserve a single item type when both ends are already typed. - const sourceCollectItemType = getCollectItemType(templates, nodes, edges, sourceNode.id); - const targetCollectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); + const sourceCollectItemType = getCollectItemType(templates, nodes, filteredEdges, resolvedSourceNode.id); + const targetCollectItemType = getCollectItemType(templates, nodes, filteredEdges, targetNode.id); if ( sourceCollectItemType && targetCollectItemType && @@ -148,33 +366,24 @@ export const validateConnection: ValidateConnectionFunc = ( const siblingInputEdge = filteredEdges.find((e) => e.target === c.target && e.targetHandle === siblingHandle); if (siblingInputEdge) { - if (siblingInputEdge.source === null || siblingInputEdge.source === undefined) { + const siblingEffectiveSource = getEffectiveSourceForEdge(siblingInputEdge, nodes, filteredEdges, templates); + if (siblingEffectiveSource === 'nodes.missingNode') { return 'nodes.missingNode'; } - - if (siblingInputEdge.sourceHandle === null || siblingInputEdge.sourceHandle === undefined) { - return 'nodes.missingFieldTemplate'; - } - - const siblingSourceNode = nodes.find((n) => n.id === siblingInputEdge.source); - if (!siblingSourceNode) { - return 'nodes.missingNode'; - } - - const siblingSourceTemplate = templates[siblingSourceNode.data.type]; - if (!siblingSourceTemplate) { + if (siblingEffectiveSource === 'nodes.missingInvocationTemplate') { return 'nodes.missingInvocationTemplate'; } - - const siblingSourceFieldTemplate = siblingSourceTemplate.outputs[siblingInputEdge.sourceHandle]; - if (!siblingSourceFieldTemplate) { + if (siblingEffectiveSource === 'nodes.missingFieldTemplate') { return 'nodes.missingFieldTemplate'; } + if (!siblingEffectiveSource) { + return 'nodes.fieldTypesMustMatch'; + } const areIfInputTypesCompatible = - validateConnectionTypes(sourceFieldTemplate.type, siblingSourceFieldTemplate.type) || - validateConnectionTypes(siblingSourceFieldTemplate.type, sourceFieldTemplate.type) || - isSingleCollectionPairOfSameBaseType(sourceFieldTemplate.type, siblingSourceFieldTemplate.type); + validateConnectionTypes(sourceFieldTemplate.type, siblingEffectiveSource.fieldTemplate.type) || + validateConnectionTypes(siblingEffectiveSource.fieldTemplate.type, sourceFieldTemplate.type) || + isSingleCollectionPairOfSameBaseType(sourceFieldTemplate.type, siblingEffectiveSource.fieldTemplate.type); if (!areIfInputTypesCompatible) { return 'nodes.fieldTypesMustMatch'; @@ -189,30 +398,17 @@ export const validateConnection: ValidateConnectionFunc = ( } } - if (sourceNode.data.type === 'if' && c.sourceHandle === 'value') { + if (resolvedSourceNode.data.type === 'if' && sourceHandle === 'value') { const ifInputEdges = filteredEdges.filter( - (e) => e.target === sourceNode.id && typeof e.targetHandle === 'string' && isIfInputHandle(e.targetHandle) + (e) => + e.target === resolvedSourceNode.id && typeof e.targetHandle === 'string' && isIfInputHandle(e.targetHandle) ); const ifInputTypes = ifInputEdges.flatMap((edge) => { - if (edge.source === null || edge.source === undefined) { + const ifInputSource = getEffectiveSourceForEdge(edge, nodes, filteredEdges, templates); + if (!ifInputSource || typeof ifInputSource === 'string') { return []; } - if (edge.sourceHandle === null || edge.sourceHandle === undefined) { - return []; - } - const ifInputSourceNode = nodes.find((n) => n.id === edge.source); - if (!ifInputSourceNode) { - return []; - } - const ifInputSourceTemplate = templates[ifInputSourceNode.data.type]; - if (!ifInputSourceTemplate) { - return []; - } - const ifInputSourceFieldTemplate = ifInputSourceTemplate.outputs[edge.sourceHandle]; - if (!ifInputSourceFieldTemplate) { - return []; - } - return [ifInputSourceFieldTemplate.type]; + return [ifInputSource.fieldTemplate.type]; }); if (ifInputTypes.length > 0) { diff --git a/invokeai/frontend/web/src/features/nodes/store/workflowSettingsSlice.ts b/invokeai/frontend/web/src/features/nodes/store/workflowSettingsSlice.ts index 85b803acd4..7c84bd6f1e 100644 --- a/invokeai/frontend/web/src/features/nodes/store/workflowSettingsSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/workflowSettingsSlice.ts @@ -30,6 +30,7 @@ const zWorkflowSettingsState = z.object({ shouldColorEdges: z.boolean(), shouldShowEdgeLabels: z.boolean(), selectionMode: zSelectionMode, + shouldGroupNodesByCategory: z.boolean(), }); export type WorkflowSettingsState = z.infer; @@ -49,6 +50,7 @@ const getInitialState = (): WorkflowSettingsState => ({ shouldShowEdgeLabels: false, nodeOpacity: 1, selectionMode: 'partial', + shouldGroupNodesByCategory: true, }); const slice = createSlice({ @@ -94,6 +96,9 @@ const slice = createSlice({ selectionModeChanged: (state, action: PayloadAction) => { state.selectionMode = action.payload ? 'full' : 'partial'; }, + shouldGroupNodesByCategoryChanged: (state, action: PayloadAction) => { + state.shouldGroupNodesByCategory = action.payload; + }, }, }); @@ -111,6 +116,7 @@ export const { shouldValidateGraphChanged, nodeOpacityChanged, selectionModeChanged, + shouldGroupNodesByCategoryChanged, } = slice.actions; export const workflowSettingsSliceConfig: SliceConfig = { @@ -123,6 +129,9 @@ export const workflowSettingsSliceConfig: SliceConfig = { if (!('_version' in state)) { state._version = 1; } + if (!('shouldGroupNodesByCategory' in state)) { + state.shouldGroupNodesByCategory = false; + } return zWorkflowSettingsState.parse(state); }, }, @@ -145,3 +154,4 @@ export const selectNodeSpacing = createWorkflowSettingsSelector((s) => s.nodeSpa export const selectLayerSpacing = createWorkflowSettingsSelector((s) => s.layerSpacing); export const selectLayoutDirection = createWorkflowSettingsSelector((s) => s.layoutDirection); export const selectNodeAlignment = createWorkflowSettingsSelector((s) => s.nodeAlignment); +export const selectShouldGroupNodesByCategory = createWorkflowSettingsSelector((s) => s.shouldGroupNodesByCategory); diff --git a/invokeai/frontend/web/src/features/nodes/types/invocation.ts b/invokeai/frontend/web/src/features/nodes/types/invocation.ts index 8cd529deb7..5d8d85dd87 100644 --- a/invokeai/frontend/web/src/features/nodes/types/invocation.ts +++ b/invokeai/frontend/web/src/features/nodes/types/invocation.ts @@ -18,6 +18,7 @@ const _zInvocationTemplate = z.object({ useCache: z.boolean(), nodePack: z.string().min(1).default('invokeai'), classification: zClassification, + category: z.string().default('other'), }); export type InvocationTemplate = z.infer; // #endregion @@ -43,6 +44,12 @@ export const zNotesNodeData = z.object({ isOpen: z.boolean(), notes: z.string(), }); +export const zConnectorNodeData = z.object({ + id: z.string().trim().min(1), + type: z.literal('connector'), + label: z.string(), + isOpen: z.boolean(), +}); const zCurrentImageNodeData = z.object({ id: z.string().trim().min(1), type: z.literal('current_image'), @@ -52,6 +59,7 @@ const zCurrentImageNodeData = z.object({ export type NotesNodeData = z.infer; export type InvocationNodeData = z.infer; +export type ConnectorNodeData = z.infer; type CurrentImageNodeData = z.infer; const zInvocationNodeValidationSchema = z.looseObject({ @@ -70,6 +78,15 @@ const zNotesNodeValidationSchema = z.looseObject({ const zNotesNode = z.custom>((val) => zNotesNodeValidationSchema.safeParse(val).success); export type NotesNode = z.infer; +const zConnectorNodeValidationSchema = z.looseObject({ + type: z.literal('connector'), + data: zConnectorNodeData, +}); +const zConnectorNode = z.custom>( + (val) => zConnectorNodeValidationSchema.safeParse(val).success +); +export type ConnectorNode = z.infer; + const zCurrentImageNodeValidationSchema = z.looseObject({ type: z.literal('current_image'), data: zCurrentImageNodeData, @@ -79,12 +96,14 @@ const zCurrentImageNode = z.custom>( ); export type CurrentImageNode = z.infer; -export const zAnyNode = z.union([zInvocationNode, zNotesNode, zCurrentImageNode]); +export const zAnyNode = z.union([zInvocationNode, zNotesNode, zConnectorNode, zCurrentImageNode]); export type AnyNode = z.infer; export const isInvocationNode = (node?: AnyNode | null): node is InvocationNode => Boolean(node && node.type === 'invocation'); export const isNotesNode = (node?: AnyNode | null): node is NotesNode => Boolean(node && node.type === 'notes'); +export const isConnectorNode = (node?: AnyNode | null): node is ConnectorNode => + Boolean(node && node.type === 'connector'); // #endregion // #region NodeExecutionState diff --git a/invokeai/frontend/web/src/features/nodes/types/workflow.ts b/invokeai/frontend/web/src/features/nodes/types/workflow.ts index 66e69ec585..34f98eb289 100644 --- a/invokeai/frontend/web/src/features/nodes/types/workflow.ts +++ b/invokeai/frontend/web/src/features/nodes/types/workflow.ts @@ -3,7 +3,7 @@ import { z } from 'zod'; import type { FieldType } from './field'; import { zFieldIdentifier } from './field'; -import { zInvocationNodeData, zNotesNodeData } from './invocation'; +import { zConnectorNodeData, zInvocationNodeData, zNotesNodeData } from './invocation'; // #region Workflow misc const zXYPosition = z @@ -31,7 +31,13 @@ const zWorkflowNotesNode = z.object({ data: zNotesNodeData, position: zXYPosition, }); -const zWorkflowNode = z.union([zWorkflowInvocationNode, zWorkflowNotesNode]); +const zWorkflowConnectorNode = z.object({ + id: z.string().trim().min(1), + type: z.literal('connector'), + data: zConnectorNodeData, + position: zXYPosition, +}); +const zWorkflowNode = z.union([zWorkflowInvocationNode, zWorkflowNotesNode, zWorkflowConnectorNode]); type WorkflowInvocationNode = z.infer; @@ -377,7 +383,7 @@ export const zWorkflowV3 = z.object({ exposedFields: z.array(zFieldIdentifier), meta: z.object({ category: zWorkflowCategory.default('user'), - version: z.literal('3.0.0'), + version: z.literal('4.0.0'), }), // Use the validated form schema! form: zValidatedBuilderForm, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.test.ts new file mode 100644 index 0000000000..b44c6b38cd --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.test.ts @@ -0,0 +1,194 @@ +import { deepClone } from 'common/util/deepClone'; +import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from 'features/nodes/store/util/connectorTopology'; +import { add, buildEdge, buildNode, img_resize, sub, templates } from 'features/nodes/store/util/testUtils'; +import { describe, expect, it } from 'vitest'; + +import { buildNodesGraph } from './buildNodesGraph'; + +const buildConnectorNode = (id: string) => ({ + id, + type: 'connector' as const, + position: { x: 0, y: 0 }, + data: { + id, + type: 'connector' as const, + label: 'Connector', + isOpen: true, + }, +}); + +const buildState = (nodes: unknown[], edges: unknown[]) => + ({ + nodes: { + present: { + _version: 1, + nodes, + edges, + formFieldInitialValues: {}, + id: undefined, + name: '', + author: '', + description: '', + version: '', + contact: '', + tags: '', + notes: '', + exposedFields: [], + meta: { version: '4.0.0', category: 'user' }, + form: { + rootElementId: 'root', + elements: { + root: { + id: 'root', + type: 'container', + data: { layout: 'column', children: [] }, + }, + }, + }, + }, + }, + gallery: { + autoAddBoardId: 'none', + selection: [], + }, + }) as unknown as Parameters[0]; + +describe('buildNodesGraph', () => { + it('flattens a single connector to one direct execution edge', () => { + const source = buildNode(add); + const target = buildNode(sub); + const connector = buildConnectorNode('connector-1'); + const state = buildState( + [source, target, connector], + [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ] + ); + + const graph = buildNodesGraph(state, templates); + + expect(graph.nodes).not.toHaveProperty(connector.id); + expect(graph.edges).toEqual([ + { + source: { node_id: source.id, field: 'value' }, + destination: { node_id: target.id, field: 'a' }, + }, + ]); + }); + + it('flattens chained connectors transitively', () => { + const source = buildNode(add); + const target = buildNode(sub); + const connectorA = buildConnectorNode('connector-a'); + const connectorB = buildConnectorNode('connector-b'); + const state = buildState( + [source, target, connectorA, connectorB], + [ + buildEdge(source.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorB.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ] + ); + + const graph = buildNodesGraph(state, templates); + + expect(graph.edges).toEqual([ + { + source: { node_id: source.id, field: 'value' }, + destination: { node_id: target.id, field: 'a' }, + }, + ]); + }); + + it('fans out through a connector into multiple execution edges', () => { + const source = buildNode(add); + const targetA = buildNode(sub); + const targetB = buildNode(img_resize); + const connector = buildConnectorNode('connector-1'); + const state = buildState( + [source, targetA, targetB, connector], + [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetA.id, 'a'), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetB.id, 'width'), + ] + ); + + const graph = buildNodesGraph(state, templates); + + expect(graph.edges).toEqual([ + { + source: { node_id: source.id, field: 'value' }, + destination: { node_id: targetA.id, field: 'a' }, + }, + { + source: { node_id: source.id, field: 'value' }, + destination: { node_id: targetB.id, field: 'width' }, + }, + ]); + }); + + it('drops unresolved connector paths from the execution graph', () => { + const target = buildNode(sub); + const connector = buildConnectorNode('connector-1'); + const state = buildState([target, connector], [buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a')]); + + const graph = buildNodesGraph(state, templates); + + expect(graph.nodes).not.toHaveProperty(connector.id); + expect(graph.edges).toEqual([]); + }); + + it('deduplicates effective execution edges created by flattening', () => { + const source = buildNode(add); + const target = buildNode(sub); + const connector = buildConnectorNode('connector-1'); + const state = buildState( + [source, target, connector], + [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + buildEdge(source.id, 'value', target.id, 'a'), + ] + ); + + const graph = buildNodesGraph(state, templates); + + expect(graph.edges).toEqual([ + { + source: { node_id: source.id, field: 'value' }, + destination: { node_id: target.id, field: 'a' }, + }, + ]); + }); + + it('still omits explicit destination input values when the flattened edge exists', () => { + const source = buildNode(add); + const target = deepClone(buildNode(sub)); + const connector = buildConnectorNode('connector-1'); + const inputA = target.data.inputs.a; + expect(inputA).toBeDefined(); + if (!inputA) { + throw new Error('Missing input a'); + } + inputA.value = 'not-an-integer' as never; + const state = buildState( + [source, target, connector], + [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ] + ); + + const graph = buildNodesGraph(state, templates); + + expect(graph.edges).toEqual([ + { + source: { node_id: source.id, field: 'value' }, + destination: { node_id: target.id, field: 'a' }, + }, + ]); + expect(graph.nodes[target.id]).not.toHaveProperty('a'); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts index d83555e558..50052c806c 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts @@ -4,10 +4,11 @@ import { omit, reduce } from 'es-toolkit/compat'; import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors'; import { selectNodesSlice } from 'features/nodes/store/selectors'; import type { Templates } from 'features/nodes/store/types'; +import { resolveConnectorSource } from 'features/nodes/store/util/connectorTopology'; import type { BoardField } from 'features/nodes/types/common'; import type { BoardFieldInputInstance } from 'features/nodes/types/field'; import { isBoardFieldInputInstance, isBoardFieldInputTemplate } from 'features/nodes/types/field'; -import { isExecutableNode, isInvocationNode } from 'features/nodes/types/invocation'; +import { isConnectorNode, isExecutableNode, isInvocationNode } from 'features/nodes/types/invocation'; import type { AnyInvocation, Graph } from 'services/api/types'; import { v4 as uuidv4 } from 'uuid'; @@ -96,12 +97,57 @@ export const buildNodesGraph = (state: RootState, templates: Templates): Require const filteredNodeIds = filteredNodes.map(({ id }) => id); // skip out the "dummy" edges between collapsed nodes - const filteredEdges = edges - .filter((edge) => edge.type !== 'collapsed') - .filter((edge) => filteredNodeIds.includes(edge.source) && filteredNodeIds.includes(edge.target)); + const flattenedEdges = edges + .filter((edge) => edge.type === 'default') + .flatMap((edge) => { + const targetNode = nodes.find((node) => node.id === edge.target); + if (!targetNode || !isInvocationNode(targetNode) || !isExecutableNode(targetNode)) { + return []; + } + + const sourceNode = nodes.find((node) => node.id === edge.source); + if (!sourceNode) { + return []; + } + + if (isInvocationNode(sourceNode)) { + if (!isExecutableNode(sourceNode) || !filteredNodeIds.includes(sourceNode.id)) { + return []; + } + return [edge]; + } + + if (isConnectorNode(sourceNode)) { + const resolvedSource = resolveConnectorSource(sourceNode.id, nodes, edges); + if (!resolvedSource || !filteredNodeIds.includes(resolvedSource.nodeId)) { + return []; + } + return [ + { + ...edge, + id: `flattened-${resolvedSource.nodeId}-${resolvedSource.fieldName}-${edge.target}-${edge.targetHandle}`, + source: resolvedSource.nodeId, + sourceHandle: resolvedSource.fieldName, + }, + ]; + } + + return []; + }) + .filter((edge, index, allEdges) => { + return ( + allEdges.findIndex( + (candidate) => + candidate.source === edge.source && + candidate.sourceHandle === edge.sourceHandle && + candidate.target === edge.target && + candidate.targetHandle === edge.targetHandle + ) === index + ); + }); // Reduce the node editor edges into invocation graph edges - const parsedEdges = filteredEdges.reduce>((edgesAccumulator, edge) => { + const parsedEdges = flattenedEdges.reduce>((edgesAccumulator, edge) => { const { source, target, sourceHandle, targetHandle } = edge; if (!sourceHandle || !targetHandle) { diff --git a/invokeai/frontend/web/src/features/nodes/util/node/buildConnectorNode.ts b/invokeai/frontend/web/src/features/nodes/util/node/buildConnectorNode.ts new file mode 100644 index 0000000000..18cb45fb1c --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/node/buildConnectorNode.ts @@ -0,0 +1,21 @@ +import type { XYPosition } from '@xyflow/react'; +import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants'; +import type { ConnectorNode } from 'features/nodes/types/invocation'; +import { v4 as uuidv4 } from 'uuid'; + +export const buildConnectorNode = (position: XYPosition): ConnectorNode => { + const nodeId = uuidv4(); + const node: ConnectorNode = { + ...SHARED_NODE_PROPERTIES, + id: nodeId, + type: 'connector', + position, + data: { + id: nodeId, + type: 'connector', + isOpen: true, + label: 'Connector', + }, + }; + return node; +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts index 57cd9943c5..47be2c62ec 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts @@ -113,6 +113,7 @@ export const parseSchema = ( const version = schema.version; const nodePack = schema.node_pack; const classification = schema.classification; + const category = schema.category ?? 'other'; const inputs = reduce( schema.properties, @@ -260,6 +261,7 @@ export const parseSchema = ( useCache, nodePack, classification, + category, }; Object.assign(invocationsAccumulator, { [type]: invocation }); diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts index 17cd0a33a7..21f0c38009 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts @@ -5,7 +5,7 @@ import { parseify } from 'common/util/serialize'; import { pick } from 'es-toolkit/compat'; import { selectNodesSlice } from 'features/nodes/store/selectors'; import type { NodesState } from 'features/nodes/store/types'; -import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation'; +import { isConnectorNode, isInvocationNode, isNotesNode } from 'features/nodes/types/invocation'; import type { WorkflowV3 } from 'features/nodes/types/workflow'; import { zWorkflowV3 } from 'features/nodes/types/workflow'; import i18n from 'i18n'; @@ -42,6 +42,9 @@ export const buildWorkflowFast = (nodesState: NodesState): WorkflowV3 => { if (isInvocationNode(node) && node.type) { const { id, type, data, position } = node; newWorkflow.nodes.push({ id, type, data, position }); + } else if (isConnectorNode(node) && node.type) { + const { id, type, data, position } = node; + newWorkflow.nodes.push({ id, type, data, position }); } else if (isNotesNode(node) && node.type) { const { id, type, data, position } = node; newWorkflow.nodes.push({ id, type, data, position }); diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/graphToWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/graphToWorkflow.ts index 05efc26fea..c09f4e1729 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/graphToWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/graphToWorkflow.ts @@ -30,7 +30,7 @@ export const graphToWorkflow = (graph: NonNullableGraph, autoLayout = true): Wor description: '', meta: { category: 'user', - version: '3.0.0', + version: '4.0.0', }, notes: '', tags: '', diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts index 32971a02d0..638e66806d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts @@ -70,9 +70,11 @@ const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => { }; const migrateV2toV3 = (workflowToMigrate: WorkflowV2): WorkflowV3 => { - // Bump version - (workflowToMigrate as unknown as WorkflowV3).meta.version = '3.0.0'; - // Parsing strips out any extra properties not in the latest version + return migrateV3toV4(workflowToMigrate as unknown as WorkflowV3); +}; + +const migrateV3toV4 = (workflowToMigrate: WorkflowV3): WorkflowV3 => { + workflowToMigrate.meta.version = '4.0.0'; return zWorkflowV3.parse(workflowToMigrate); }; @@ -100,6 +102,10 @@ export const parseAndMigrateWorkflow = (data: unknown): WorkflowV3 => { workflow = migrateV2toV3(v2); } + if (get(workflow, 'meta.version') === '3.0.0') { + workflow = migrateV3toV4(workflow as WorkflowV3); + } + // We should now have a V3 workflow const migratedWorkflow = zWorkflowV3.parse(workflow); diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.test.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.test.ts index 1442a3475e..c1d0858831 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.test.ts @@ -1,4 +1,5 @@ import { get } from 'es-toolkit/compat'; +import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from 'features/nodes/store/util/connectorTopology'; import { img_resize, main_model_loader } from 'features/nodes/store/util/testUtils'; import type { WorkflowV3 } from 'features/nodes/types/workflow'; import { getDefaultForm } from 'features/nodes/types/workflow'; @@ -7,6 +8,17 @@ import { describe, expect, it } from 'vitest'; //TODO(psyche): Test workflow validation for form builder fields describe('validateWorkflow', () => { + const buildConnectorNode = (id: string) => ({ + id, + type: 'connector' as const, + position: { x: 0, y: 0 }, + data: { + id, + type: 'connector' as const, + label: 'Connector', + isOpen: true, + }, + }); const getWorkflow = (): WorkflowV3 => ({ name: '', author: '', @@ -17,7 +29,7 @@ describe('validateWorkflow', () => { notes: '', exposedFields: [], form: getDefaultForm(), - meta: { version: '3.0.0', category: 'user' }, + meta: { version: '4.0.0', category: 'user' }, nodes: [ { id: '94b1d596-f2f2-4c1c-bd5b-a79c62d947ad', @@ -104,6 +116,7 @@ describe('validateWorkflow', () => { }); expect(validationResult.warnings.length).toBe(1); expect(get(validationResult, 'workflow.nodes[1].data.inputs.image.value')).toBeUndefined(); + expect(validationResult.workflow.meta.version).toBe('4.0.0'); }); it('should reset boards that are inaccessible', async () => { const validationResult = await validateWorkflow({ @@ -127,4 +140,138 @@ describe('validateWorkflow', () => { expect(validationResult.warnings.length).toBe(1); expect(get(validationResult, 'workflow.nodes[0].data.inputs.model.value')).toBeUndefined(); }); + + it('should delete malformed connector edges with invalid handles', async () => { + const workflow = getWorkflow(); + workflow.nodes.push(buildConnectorNode('connector-1')); + workflow.edges.push({ + id: 'e1', + type: 'default', + source: workflow.nodes[0]!.id, + sourceHandle: 'vae', + target: 'connector-1', + targetHandle: 'wrong', + }); + + const validationResult = await validateWorkflow({ + workflow, + templates: { img_resize, main_model_loader }, + checkImageAccess: resolveTrue, + checkBoardAccess: resolveTrue, + checkModelAccess: resolveTrue, + }); + + expect(validationResult.workflow.edges).toEqual([]); + expect(validationResult.warnings.length).toBe(1); + }); + + it('should delete connector edges with missing endpoints', async () => { + const workflow = getWorkflow(); + workflow.nodes.push(buildConnectorNode('connector-1')); + workflow.edges.push({ + id: 'e1', + type: 'default', + source: 'missing-node', + sourceHandle: 'value', + target: 'connector-1', + targetHandle: CONNECTOR_INPUT_HANDLE, + }); + + const validationResult = await validateWorkflow({ + workflow, + templates: { img_resize, main_model_loader }, + checkImageAccess: resolveTrue, + checkBoardAccess: resolveTrue, + checkModelAccess: resolveTrue, + }); + + expect(validationResult.workflow.edges).toEqual([]); + expect(validationResult.warnings.length).toBe(1); + }); + + it('should repair invalid multi-input connector state predictably by keeping the first valid input edge', async () => { + const workflow = getWorkflow(); + const loader2 = structuredClone(workflow.nodes[0]!); + loader2.id = 'second-loader'; + loader2.data.id = 'second-loader'; + const connector = buildConnectorNode('connector-1'); + workflow.nodes.push(loader2, connector); + workflow.edges.push({ + id: 'e1', + type: 'default', + source: workflow.nodes[0]!.id, + sourceHandle: 'vae', + target: connector.id, + targetHandle: CONNECTOR_INPUT_HANDLE, + }); + workflow.edges.push({ + id: 'e2', + type: 'default', + source: loader2.id, + sourceHandle: 'vae', + target: connector.id, + targetHandle: CONNECTOR_INPUT_HANDLE, + }); + + const validationResult = await validateWorkflow({ + workflow, + templates: { img_resize, main_model_loader }, + checkImageAccess: resolveTrue, + checkBoardAccess: resolveTrue, + checkModelAccess: resolveTrue, + }); + + expect(validationResult.workflow.edges).toEqual([ + { + id: 'e1', + type: 'default', + source: workflow.nodes[0]!.id, + sourceHandle: 'vae', + target: connector.id, + targetHandle: CONNECTOR_INPUT_HANDLE, + }, + ]); + expect(validationResult.warnings.length).toBe(1); + }); + + it('should retain isolated connectors during workflow validation', async () => { + const workflow = getWorkflow(); + workflow.nodes.push(buildConnectorNode('connector-1')); + + const validationResult = await validateWorkflow({ + workflow, + templates: { img_resize, main_model_loader }, + checkImageAccess: resolveTrue, + checkBoardAccess: resolveTrue, + checkModelAccess: resolveTrue, + }); + + expect(validationResult.workflow.nodes.find((node) => node.id === 'connector-1')).toBeDefined(); + expect(validationResult.warnings).toEqual([]); + }); + + it('should retain unresolved connector output edges that establish downstream constraints in the editor', async () => { + const workflow = getWorkflow(); + workflow.nodes.push(buildConnectorNode('connector-1')); + const unresolvedEdge = { + id: 'e1', + type: 'default' as const, + source: 'connector-1', + sourceHandle: CONNECTOR_OUTPUT_HANDLE, + target: workflow.nodes[1]!.id, + targetHandle: 'image', + }; + workflow.edges.push(unresolvedEdge); + + const validationResult = await validateWorkflow({ + workflow, + templates: { img_resize, main_model_loader }, + checkImageAccess: resolveTrue, + checkBoardAccess: resolveTrue, + checkModelAccess: resolveTrue, + }); + + expect(validationResult.workflow.edges).toEqual([unresolvedEdge]); + expect(validationResult.warnings).toEqual([]); + }); }); diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts index b86870d450..448214defe 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts @@ -1,6 +1,7 @@ import { parseify } from 'common/util/serialize'; import { addElement, getIsFormEmpty } from 'features/nodes/components/sidePanel/builder/form-manipulation'; import type { Templates } from 'features/nodes/store/types'; +import { validateConnection } from 'features/nodes/store/util/validateConnection'; import { isBoardFieldInputInstance, isImageFieldCollectionInputInstance, @@ -149,8 +150,7 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise(); + const validEdges = []; for (const edge of edges) { // Validate each edge. If the edge is invalid, we must remove it to prevent runtime errors with reactflow. @@ -169,7 +169,7 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise !edgesToDelete.has(id)); + _workflow.edges = validEdges; // Migrated exposed fields to form elements if they exist and the form does not // Note: If the form is invalid per its zod schema, it will be reset to a default, empty form! diff --git a/invokeai/frontend/web/src/features/parameters/components/Core/ParamNegativePrompt.tsx b/invokeai/frontend/web/src/features/parameters/components/Core/ParamNegativePrompt.tsx index 685a0b7a2a..b0405c0ff3 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Core/ParamNegativePrompt.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Core/ParamNegativePrompt.tsx @@ -4,6 +4,7 @@ import { usePersistedTextAreaSize } from 'common/hooks/usePersistedTextareaSize' import { negativePromptChanged, selectNegativePromptWithFallback } from 'features/controlLayers/store/paramsSlice'; import { PromptLabel } from 'features/parameters/components/Prompts/PromptLabel'; import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper'; +import { PromptResizeHandle } from 'features/parameters/components/Prompts/PromptResizeHandle'; import { ViewModePrompt } from 'features/parameters/components/Prompts/ViewModePrompt'; import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton'; import { PromptPopover } from 'features/prompt/PromptPopover'; @@ -22,6 +23,8 @@ const persistOptions: Parameters[2] = { trackHeight: true, }; +const NEGATIVE_PROMPT_MIN_HEIGHT = 28; + export const ParamNegativePrompt = memo(() => { const dispatch = useAppDispatch(); const prompt = useAppSelector(selectNegativePromptWithFallback); @@ -70,14 +73,16 @@ export const ParamNegativePrompt = memo(() => { onChange={onChange} onKeyDown={onKeyDown} variant="darkFilled" - minH={28} borderTopWidth={24} // This prevents the prompt from being hidden behind the header paddingInlineEnd={10} paddingInlineStart={3} paddingTop={0} paddingBottom={3} + resize="none" + minH={NEGATIVE_PROMPT_MIN_HEIGHT} fontFamily="mono" fontSize="0.82rem" + sx={{ '&::-webkit-resizer': { display: 'none' } }} /> @@ -90,6 +95,7 @@ export const ParamNegativePrompt = memo(() => { label={`${t('parameters.negativePromptPlaceholder')} (${t('stylePresets.preview')})`} /> )} + ); diff --git a/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx b/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx index f95e950c25..89169b5ea5 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx @@ -11,6 +11,7 @@ import { ShowDynamicPromptsPreviewButton } from 'features/dynamicPrompts/compone import { NegativePromptToggleButton } from 'features/parameters/components/Core/NegativePromptToggleButton'; import { PromptLabel } from 'features/parameters/components/Prompts/PromptLabel'; import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper'; +import { PromptResizeHandle } from 'features/parameters/components/Prompts/PromptResizeHandle'; import { ViewModePrompt } from 'features/parameters/components/Prompts/ViewModePrompt'; import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton'; import { PromptPopover } from 'features/prompt/PromptPopover'; @@ -35,6 +36,8 @@ const persistOptions: Parameters[2] = { initialHeight: 120, }; +const POSITIVE_PROMPT_MIN_HEIGHT = 32; + const usePromptHistory = () => { const store = useAppStore(); const history = useAppSelector(selectPositivePromptHistory); @@ -215,10 +218,11 @@ export const ParamPositivePrompt = memo(() => { paddingInlineStart={3} paddingTop={0} paddingBottom={3} - resize="vertical" - minH={32} + resize="none" + minH={POSITIVE_PROMPT_MIN_HEIGHT} fontFamily="mono" fontSize="0.82rem" + sx={{ '&::-webkit-resizer': { display: 'none' } }} /> @@ -236,6 +240,7 @@ export const ParamPositivePrompt = memo(() => { label={`${t('parameters.positivePromptPlaceholder')} (${t('stylePresets.preview')})`} /> )} + diff --git a/invokeai/frontend/web/src/features/parameters/components/Prompts/PromptResizeHandle.tsx b/invokeai/frontend/web/src/features/parameters/components/Prompts/PromptResizeHandle.tsx new file mode 100644 index 0000000000..0a5f211924 --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/components/Prompts/PromptResizeHandle.tsx @@ -0,0 +1,135 @@ +import { Box } from '@invoke-ai/ui-library'; +import { + memo, + type PointerEvent as ReactPointerEvent, + type RefObject, + useCallback, + useEffect, + useRef, + useState, +} from 'react'; + +type PromptResizeHandleProps = { + textareaRef: RefObject; + minHeight: number; +}; + +const PROMPT_RESIZE_HANDLE_HEIGHT_PX = 8; + +export const PromptResizeHandle = memo(({ textareaRef, minHeight }: PromptResizeHandleProps) => { + const activePointerIdRef = useRef(null); + const startHeightRef = useRef(0); + const startYRef = useRef(0); + const previousCursorRef = useRef(''); + const previousUserSelectRef = useRef(''); + const [isResizing, setIsResizing] = useState(false); + + const stopResize = useCallback(() => { + if (activePointerIdRef.current === null) { + return; + } + + activePointerIdRef.current = null; + setIsResizing(false); + document.body.style.cursor = previousCursorRef.current; + document.body.style.userSelect = previousUserSelectRef.current; + }, []); + + useEffect(() => stopResize, [stopResize]); + + const onPointerDown = useCallback( + (e: ReactPointerEvent) => { + if (e.button !== 0) { + return; + } + + const textarea = textareaRef.current; + if (!textarea) { + return; + } + + activePointerIdRef.current = e.pointerId; + startYRef.current = e.clientY; + startHeightRef.current = textarea.offsetHeight; + previousCursorRef.current = document.body.style.cursor; + previousUserSelectRef.current = document.body.style.userSelect; + + document.body.style.cursor = 'ns-resize'; + document.body.style.userSelect = 'none'; + e.currentTarget.setPointerCapture(e.pointerId); + setIsResizing(true); + e.preventDefault(); + }, + [textareaRef] + ); + + const onPointerMove = useCallback( + (e: ReactPointerEvent) => { + if (activePointerIdRef.current !== e.pointerId) { + return; + } + + const textarea = textareaRef.current; + if (!textarea) { + return; + } + + const nextHeight = Math.max(minHeight, startHeightRef.current + e.clientY - startYRef.current); + textarea.style.height = `${nextHeight}px`; + e.preventDefault(); + }, + [minHeight, textareaRef] + ); + + const onPointerUp = useCallback( + (e: ReactPointerEvent) => { + if (activePointerIdRef.current !== e.pointerId) { + return; + } + + if (e.currentTarget.hasPointerCapture(e.pointerId)) { + e.currentTarget.releasePointerCapture(e.pointerId); + } + + stopResize(); + }, + [stopResize] + ); + + const onPointerCancel = useCallback( + (e: ReactPointerEvent) => { + if (activePointerIdRef.current !== e.pointerId) { + return; + } + + stopResize(); + }, + [stopResize] + ); + + return ( + + ); +}); + +PromptResizeHandle.displayName = 'PromptResizeHandle'; diff --git a/invokeai/frontend/web/src/features/system/components/HotkeysModal/useHotkeyData.ts b/invokeai/frontend/web/src/features/system/components/HotkeysModal/useHotkeyData.ts index c6a0bd6705..dfc9b5d280 100644 --- a/invokeai/frontend/web/src/features/system/components/HotkeysModal/useHotkeyData.ts +++ b/invokeai/frontend/web/src/features/system/components/HotkeysModal/useHotkeyData.ts @@ -118,6 +118,7 @@ export const useHotkeyData = (): HotkeysData => { addHotkey('canvas', 'selectEraserTool', ['e']); addHotkey('canvas', 'selectMoveTool', ['v']); addHotkey('canvas', 'selectRectTool', ['u']); + addHotkey('canvas', 'selectLassoTool', ['l']); addHotkey('canvas', 'selectViewTool', ['h']); addHotkey('canvas', 'selectColorPickerTool', ['i']); addHotkey('canvas', 'setFillColorsToDefault', ['d']); diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx index b95d2adb47..7bcc103402 100644 --- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx +++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx @@ -11,6 +11,8 @@ import { ModalFooter, ModalHeader, ModalOverlay, + NumberInput, + NumberInputField, Switch, Text, } from '@invoke-ai/ui-library'; @@ -19,6 +21,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import { buildUseBoolean } from 'common/hooks/useBoolean'; +import { selectCurrentUser } from 'features/auth/store/authSlice'; import { selectShouldUseCPUNoise, shouldUseCpuNoiseChanged } from 'features/controlLayers/store/paramsSlice'; import { useRefreshAfterResetModal } from 'features/system/components/SettingsModal/RefreshAfterResetModal'; import { SettingsDeveloperLogIsEnabled } from 'features/system/components/SettingsModal/SettingsDeveloperLogIsEnabled'; @@ -48,16 +51,25 @@ import { shouldUseNSFWCheckerChanged, shouldUseWatermarkerChanged, } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { selectShouldShowProgressInViewer } from 'features/ui/store/uiSelectors'; import { setShouldShowProgressInViewer } from 'features/ui/store/uiSlice'; -import type { ChangeEvent, ReactElement } from 'react'; -import { cloneElement, memo, useCallback, useEffect } from 'react'; +import type { ChangeEvent, KeyboardEvent, ReactElement } from 'react'; +import { cloneElement, memo, useCallback, useEffect, useRef, useState } from 'react'; import { useTranslation } from 'react-i18next'; +import { useGetRuntimeConfigQuery, useUpdateRuntimeConfigMutation } from 'services/api/endpoints/appInfo'; import { SettingsLanguageSelect } from './SettingsLanguageSelect'; const [useSettingsModal] = buildUseBoolean(false); +const formatOptionalInteger = (value: number | null | undefined) => { + if (value === null || value === undefined) { + return ''; + } + return String(value); +}; + const SettingsModal = (props: { children: ReactElement }) => { const dispatch = useAppDispatch(); const { t } = useTranslation(); @@ -72,6 +84,10 @@ const SettingsModal = (props: { children: ReactElement }) => { const settingsModal = useSettingsModal(); const refreshModal = useRefreshAfterResetModal(); + const currentUser = useAppSelector(selectCurrentUser); + const { data: runtimeConfig } = useGetRuntimeConfigQuery(); + const [updateRuntimeConfig, { isLoading: isUpdatingRuntimeConfig }] = useUpdateRuntimeConfigMutation(); + const pendingMaxQueueHistoryRef = useRef(undefined); const prefersNumericAttentionWeights = useAppSelector(selectSystemPrefersNumericAttentionWeights); const shouldUseCpuNoise = useAppSelector(selectShouldUseCPUNoise); @@ -85,6 +101,10 @@ const SettingsModal = (props: { children: ReactElement }) => { const shouldHighlightFocusedRegions = useAppSelector(selectSystemShouldEnableHighlightFocusedRegions); const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession); const shouldShowInvocationProgressDetail = useAppSelector(selectSystemShouldShowInvocationProgressDetail); + const maxQueueHistory = runtimeConfig?.config.max_queue_history ?? null; + const canEditRuntimeConfig = runtimeConfig ? !runtimeConfig.config.multiuser || currentUser?.is_admin : false; + const [maxQueueHistoryInput, setMaxQueueHistoryInput] = useState(formatOptionalInteger(maxQueueHistory)); + const onToggleConfirmOnNewSession = useCallback(() => { dispatch(shouldConfirmOnNewSessionToggled()); }, [dispatch]); @@ -96,11 +116,60 @@ const SettingsModal = (props: { children: ReactElement }) => { } }, [refetchIntermediatesCount, settingsModal.isTrue]); + useEffect(() => { + setMaxQueueHistoryInput(formatOptionalInteger(maxQueueHistory)); + }, [maxQueueHistory]); + + const commitMaxQueueHistory = useCallback(async () => { + if (!runtimeConfig || !canEditRuntimeConfig) { + return; + } + + const trimmedValue = maxQueueHistoryInput.trim(); + const parsedValue = trimmedValue === '' ? null : Number.parseInt(trimmedValue, 10); + + if (parsedValue !== null && Number.isNaN(parsedValue)) { + setMaxQueueHistoryInput(formatOptionalInteger(maxQueueHistory)); + return; + } + + const normalizedValue = parsedValue === null ? null : Math.max(0, parsedValue); + const currentValue = + pendingMaxQueueHistoryRef.current === undefined ? maxQueueHistory : pendingMaxQueueHistoryRef.current; + + if (normalizedValue === currentValue) { + setMaxQueueHistoryInput(formatOptionalInteger(currentValue)); + return; + } + + pendingMaxQueueHistoryRef.current = normalizedValue; + setMaxQueueHistoryInput(formatOptionalInteger(normalizedValue)); + + try { + await updateRuntimeConfig({ max_queue_history: normalizedValue }).unwrap(); + } catch { + setMaxQueueHistoryInput(formatOptionalInteger(maxQueueHistory)); + toast({ + id: 'SETTINGS_MAX_QUEUE_HISTORY_SAVE_FAILED', + title: t('settings.maxQueueHistorySaveFailed'), + status: 'error', + }); + } finally { + pendingMaxQueueHistoryRef.current = undefined; + } + }, [canEditRuntimeConfig, maxQueueHistory, maxQueueHistoryInput, runtimeConfig, t, updateRuntimeConfig]); + + const handleCloseSettingsModal = useCallback(() => { + void commitMaxQueueHistory(); + settingsModal.setFalse(); + }, [commitMaxQueueHistory, settingsModal]); + const handleClickResetWebUI = useCallback(() => { + void commitMaxQueueHistory(); clearStorage(); settingsModal.setFalse(); refreshModal.setTrue(); - }, [settingsModal, refreshModal]); + }, [commitMaxQueueHistory, refreshModal, settingsModal]); const handleChangeShouldConfirmOnDelete = useCallback( (e: ChangeEvent) => { @@ -172,12 +241,30 @@ const SettingsModal = (props: { children: ReactElement }) => { [dispatch] ); + const handleChangeMaxQueueHistory = useCallback((valueAsString: string) => { + setMaxQueueHistoryInput(valueAsString); + }, []); + + const handleBlurMaxQueueHistory = useCallback(() => { + void commitMaxQueueHistory(); + }, [commitMaxQueueHistory]); + + const handleKeyDownMaxQueueHistory = useCallback( + (e: KeyboardEvent) => { + if (e.key === 'Enter') { + void commitMaxQueueHistory(); + e.currentTarget.blur(); + } + }, + [commitMaxQueueHistory] + ); + return ( <> {cloneElement(props.children, { onClick: settingsModal.setTrue, })} - + {t('common.settingsLabel')} @@ -206,6 +293,21 @@ const SettingsModal = (props: { children: ReactElement }) => { {t('settings.enableInvisibleWatermark')} + + {t('settings.maxQueueHistory')} + + + + diff --git a/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts b/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts index 8fe85125e6..1c656d289b 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts @@ -51,6 +51,26 @@ export const appInfoApi = api.injectEndpoints({ url: buildAppInfoUrl('runtime_config'), method: 'GET', }), + providesTags: ['AppConfig'], + }), + updateRuntimeConfig: build.mutation< + paths['/api/v1/app/runtime_config']['patch']['responses']['200']['content']['application/json'], + paths['/api/v1/app/runtime_config']['patch']['requestBody']['content']['application/json'] + >({ + query: (body) => ({ + url: buildAppInfoUrl('runtime_config'), + method: 'PATCH', + body, + }), + async onQueryStarted(_, { dispatch, queryFulfilled }) { + try { + const { data } = await queryFulfilled; + dispatch(appInfoApi.util.upsertQueryData('getRuntimeConfig', undefined, data)); + } catch { + // no-op + } + }, + invalidatesTags: ['AppConfig'], }), getInvocationCacheStatus: build.query< paths['/api/v1/app/invocation_cache/status']['get']['responses']['200']['content']['application/json'], @@ -95,6 +115,7 @@ export const { useGetAppDepsQuery, useGetPatchmatchStatusQuery, useGetRuntimeConfigQuery, + useUpdateRuntimeConfigMutation, useClearInvocationCacheMutation, useDisableInvocationCacheMutation, useEnableInvocationCacheMutation, diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 14f8e407a3..10afb4dfce 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -1586,7 +1586,8 @@ export type paths = { delete?: never; options?: never; head?: never; - patch?: never; + /** Update Runtime Config */ + patch: operations["update_runtime_config"]; trace?: never; }; "/api/v1/app/logging": { @@ -15201,7 +15202,8 @@ export type components = { * force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty). * pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting. * max_queue_size: Maximum number of items in the session queue. - * clear_queue_on_startup: Empties session queue on startup. + * clear_queue_on_startup: Empties session queue on startup. If true, disables `max_queue_history`. + * max_queue_history: Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true. * allow_nodes: List of nodes to allow. Omit to allow all. * deny_nodes: List of nodes to deny. Omit to deny none. * node_cache_size: How many cached nodes to keep in memory. @@ -15529,10 +15531,15 @@ export type components = { max_queue_size?: number; /** * Clear Queue On Startup - * @description Empties session queue on startup. + * @description Empties session queue on startup. If true, disables `max_queue_history`. * @default false */ clear_queue_on_startup?: boolean; + /** + * Max Queue History + * @description Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true. + */ + max_queue_history?: number | null; /** * Allow Nodes * @description List of nodes to allow. Omit to allow all. @@ -28610,6 +28617,17 @@ export type components = { */ unstarred_images: string[]; }; + /** + * UpdateAppGenerationSettingsRequest + * @description Writable generation-related app settings. + */ + UpdateAppGenerationSettingsRequest: { + /** + * Max Queue History + * @description Keep the last N completed, failed, and canceled queue items on startup. Set to 0 to prune all terminal items. + */ + max_queue_history?: number | null; + }; /** * UserDTO * @description User data transfer object. @@ -33776,6 +33794,39 @@ export interface operations { }; }; }; + update_runtime_config: { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + requestBody: { + content: { + "application/json": components["schemas"]["UpdateAppGenerationSettingsRequest"]; + }; + }; + responses: { + /** @description Successful Response */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["InvokeAIAppConfigWithSetFields"]; + }; + }; + /** @description Validation Error */ + 422: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; get_log_level: { parameters: { query?: never;