diff --git a/Makefile b/Makefile
index f1e81429e7..ecf101f1d5 100644
--- a/Makefile
+++ b/Makefile
@@ -12,12 +12,13 @@ help:
	@echo "mypy-all Run mypy ignoring the config in pyproject.toml but still ignoring missing imports"
	@echo "test Run the unit tests."
	@echo "update-config-docstring Update the app's config docstring so mkdocs can autogenerate it correctly."
-	@echo "frontend-install Install the pnpm modules needed for the front end"
-	@echo "frontend-build Build the frontend in order to run on localhost:9090"
+	@echo "frontend-install Install the pnpm modules needed for the frontend"
+	@echo "frontend-build Build the frontend for localhost:9090"
+	@echo "frontend-test Run the frontend test suite once"
	@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
	@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
-	@echo "frontend-prettier Format the frontend using lint:prettier"
-	@echo "wheel Build the wheel for the current version"
+	@echo "frontend-lint Run frontend checks and fixable lint/format steps"
+	@echo "wheel Build the wheel for the current version"
	@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
	@echo "openapi Generate the OpenAPI schema for the app, outputting to stdout"
	@echo "docs Serve the mkdocs site with live reload"
@@ -57,6 +58,10 @@ frontend-install:
 frontend-build:
	cd invokeai/frontend/web && pnpm build
 
+# Run the frontend test suite once
+frontend-test:
+	cd invokeai/frontend/web && pnpm run test:run
+
 # Run the frontend in dev mode
 frontend-dev:
	cd invokeai/frontend/web && pnpm dev
diff --git a/README.md b/README.md
index 06fc98e46b..4c6cc52493 100644
--- a/README.md
+++ b/README.md
@@ -58,15 +58,39 @@ Invoke offers a fully featured workflow management solution, enabling users to c
 Invoke features an organized gallery system for easily storing, accessing, and remixing your content in the Invoke workspace. Images can be dragged/dropped onto any Image-based UI element in the application, and rich metadata within the Image allows for easy recall of key prompts or settings used in your workflow.
 
+### Model Support
+- SD 1.5
+- SD 2.0
+- SDXL
+- SD 3.5 Medium
+- SD 3.5 Large
+- CogView 4
+- Flux.1 Dev
+- Flux.1 Schnell
+- Flux.1 Kontext
+- Flux.1 Krea
+- Flux Redux
+- Flux Fill
+- Flux.2 Klein 4B
+- Flux.2 Klein 9B
+- Z-Image Turbo
+- Z-Image Base
+- Anima
+- Qwen Image
+- Qwen Image Edit
+- Nano Banana (API Only)
+- GPT Image (API Only)
+- Wan (API Only)
+
 ### Other features
 
-- Support for both ckpt and diffusers models
-- SD1.5, SD2.0, SDXL, and FLUX support
+- Support for ckpt, diffusers, and some gguf models
 - Upscaling Tools
 - Embedding Manager & Support
 - Model Manager & Support
 - Workflow creation & management
 - Node-Based Architecture
+- Object Segmentation & Selection Models (SAM / SAM2)
 
 ## Contributing
diff --git a/docs/contributing/CANVAS_PROJECTS/CANVAS_PROJECTS.md b/docs/contributing/CANVAS_PROJECTS/CANVAS_PROJECTS.md
new file mode 100644
index 0000000000..a05ef29492
--- /dev/null
+++ b/docs/contributing/CANVAS_PROJECTS/CANVAS_PROJECTS.md
@@ -0,0 +1,205 @@
# Canvas Projects — Technical Documentation

## Overview

Canvas Projects provide a save/load mechanism for the entire canvas state. The feature serializes all canvas entities, generation parameters, reference images, and their associated image files into a ZIP-based `.invk` file. On load, it restores the full state, handling image deduplication and re-uploading as needed.
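As a rough sketch of the save half of this mechanism (the `ProjectPayload` shape and `buildProjectZip` helper are illustrative names, not the actual hook internals; `ref_images.json` and `loras.json` are omitted for brevity):

```typescript
import JSZip from 'jszip';

// Hypothetical payload shape for illustration: the real save hook assembles
// this data from Redux selectors and the images API.
type ProjectPayload = {
  name: string;
  appVersion: string;
  canvasState: unknown;
  params: unknown;
  images: Map<string, Blob>; // image_name -> image file contents
};

const buildProjectZip = async (payload: ProjectPayload): Promise<Blob> => {
  const zip = new JSZip();
  zip.file(
    'manifest.json',
    JSON.stringify({
      version: 1,
      appVersion: payload.appVersion,
      createdAt: new Date().toISOString(),
      name: payload.name,
    })
  );
  zip.file('canvas_state.json', JSON.stringify(payload.canvasState));
  zip.file('params.json', JSON.stringify(payload.params));
  const images = zip.folder('images');
  for (const [imageName, blob] of payload.images) {
    images?.file(imageName, blob);
  }
  // The caller turns this Blob into a download named `${payload.name}.invk`.
  return zip.generateAsync({ type: 'blob' });
};
```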
+ +## File Format + +The `.invk` file is a standard ZIP archive with the following structure: + +``` +project.invk +├── manifest.json +├── canvas_state.json +├── params.json +├── ref_images.json +├── loras.json +└── images/ + ├── {image_name_1}.png + ├── {image_name_2}.png + └── ... +``` + +### manifest.json + +Schema version and metadata. Validated on load with Zod. + +```json +{ + "version": 1, + "appVersion": "5.12.0", + "createdAt": "2026-02-26T12:00:00.000Z", + "name": "My Canvas Project" +} +``` + +| Field | Type | Description | +|---|---|---| +| `version` | `number` | Schema version, currently `1`. Used for migration logic on load. | +| `appVersion` | `string` | InvokeAI version that created the file. Informational only. | +| `createdAt` | `string` | ISO 8601 timestamp. | +| `name` | `string` | User-provided project name. Also used as the download filename. | + +### canvas_state.json + +The serialized canvas entity tree. Type: `CanvasProjectState`. + +```typescript +type CanvasProjectState = { + rasterLayers: CanvasRasterLayerState[]; + controlLayers: CanvasControlLayerState[]; + inpaintMasks: CanvasInpaintMaskState[]; + regionalGuidance: CanvasRegionalGuidanceState[]; + bbox: CanvasState['bbox']; + selectedEntityIdentifier: CanvasState['selectedEntityIdentifier']; + bookmarkedEntityIdentifier: CanvasState['bookmarkedEntityIdentifier']; +}; +``` + +Each entity contains its full state including all canvas objects (brush lines, eraser lines, rect shapes, images). Image objects reference files by `image_name` which correspond to files in the `images/` folder. + +### params.json + +The complete generation parameters state (`ParamsState`). Optional on load (older files may not have it). This includes all fields from the params Redux slice: + +- Prompts (positive, negative, prompt history) +- Core generation settings (seed, steps, CFG scale, guidance, scheduler, iterations) +- Model selections (main model, VAE, FLUX VAE, T5 encoder, CLIP embed models, refiner, Z-Image models, Klein models) +- Dimensions (width, height, aspect ratio) +- Img2img strength +- Infill settings (method, tile size, patchmatch downscale, color) +- Canvas coherence settings (mode, edge size, min denoise) +- Refiner parameters (steps, CFG scale, scheduler, aesthetic scores, start) +- FLUX-specific settings (scheduler, DyPE preset/scale/exponent) +- Z-Image-specific settings (scheduler, seed variance) +- Upscale settings (scheduler, CFG scale) +- Seamless tiling, mask blur, CLIP skip, VAE precision, CPU noise, color compensation + +### ref_images.json + +Global reference image entities (`RefImageState[]`). These are IP-Adapter / FLUX Redux configs with `CroppableImageWithDims` containing both original and cropped image references. Optional on load. + +### loras.json + +Array of LoRA configurations (`LoRA[]`). Each entry contains: + +```typescript +type LoRA = { + id: string; + isEnabled: boolean; + model: ModelIdentifierField; + weight: number; +}; +``` + +Optional on load. Like models, LoRA identifiers are stored as-is — if a LoRA is not installed when loading, the entry is restored but may not be usable. + +### images/ + +All image files referenced anywhere in the state. Keyed by their original `image_name`. On save, each image is fetched from the backend via `GET /api/v1/images/i/{name}/full` and stored as-is. 
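For illustration, the per-image fetch on save can be sketched as below; the endpoint path is the one documented above, while the error handling and any auth plumbing are assumptions:

```typescript
// Fetch one image's file contents so it can be added to the `images/` folder.
const fetchImageBlob = async (imageName: string): Promise<Blob> => {
  const res = await fetch(`/api/v1/images/i/${imageName}/full`);
  if (!res.ok) {
    throw new Error(`Failed to fetch image ${imageName}: ${res.status}`);
  }
  return res.blob();
};
```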
## Key Source Files

| File | Purpose |
|---|---|
| `features/controlLayers/util/canvasProjectFile.ts` | Types, constants, image name collection, remapping, existence checking |
| `features/controlLayers/hooks/useCanvasProjectSave.ts` | Save hook — collects Redux state, fetches images, builds ZIP |
| `features/controlLayers/hooks/useCanvasProjectLoad.ts` | Load hook — parses ZIP, deduplicates images, dispatches state |
| `features/controlLayers/components/SaveCanvasProjectDialog.tsx` | Save name dialog + `useSaveCanvasProjectWithDialog` hook |
| `features/controlLayers/components/LoadCanvasProjectConfirmationAlertDialog.tsx` | Load confirmation dialog + `useLoadCanvasProjectWithDialog` hook |
| `features/controlLayers/components/Toolbar/CanvasToolbarProjectMenuButton.tsx` | Toolbar dropdown UI |
| `features/controlLayers/store/canvasSlice.ts` | `canvasProjectRecalled` Redux action |

## Save Flow

1. User clicks "Save Canvas Project" → `SaveCanvasProjectDialog` opens asking for a project name
2. On confirm, `saveCanvasProject(name)` is called
3. Read Redux state via selectors: `selectCanvasSlice()`, `selectParamsSlice()`, `selectRefImagesSlice()`, `selectLoRAsSlice()`
4. Build `CanvasProjectState` from the canvas slice; use `paramsState` directly for params
5. Walk all entities to collect every `image_name` reference via `collectImageNames()`:
   - `CanvasImageState.image.image_name` in layer/mask objects
   - `CroppableImageWithDims.original.image.image_name` in global ref images
   - `CroppableImageWithDims.crop.image.image_name` in cropped ref images
   - `ImageWithDims.image_name` in regional guidance ref images
6. Fetch each image from the backend API
7. Build ZIP with JSZip: add `manifest.json` (including `name`), `canvas_state.json`, `params.json`, `ref_images.json`, `loras.json`, and all images into `images/`
8. Sanitize the name for filesystem use, generate the ZIP blob, and trigger a download as `{name}.invk`

## Load Flow

1. User selects `.invk` file → confirmation dialog opens
2. On confirm, parse ZIP with JSZip
3. Validate manifest version via Zod schema
4. Read `canvas_state.json`, plus `params.json`, `ref_images.json`, and `loras.json` (each optional)
5. Collect all `image_name` references from the loaded state
6. **Deduplicate images**: for each referenced image, check if it exists on the server via `getImageDTOSafe(image_name)`
   - Already exists → skip (no upload)
   - Missing → upload from ZIP via `uploadImage()`, record `oldName → newName` mapping
7. Remap all `image_name` values in the loaded state using the mapping (only for re-uploaded images whose names changed)
8. Dispatch Redux actions:
   - `canvasProjectRecalled()` — restores all canvas entities, bbox, selected/bookmarked entity
   - `refImagesRecalled()` — restores global reference images
   - `paramsRecalled()` — replaces the entire params state in one action
   - `loraAllDeleted()` + `loraRecalled()` — restores LoRAs
9. Show success/error toast

## Image Name Collection & Remapping

The `canvasProjectFile.ts` utility provides two parallel sets of functions:

**Collection** (`collectImageNames`): Walks the entire state tree and returns a `Set` of all referenced `image_name` values. This is used by both save (to know which images to fetch) and load (to know which images to check/upload).

**Remapping** (`remapCanvasState`, `remapRefImages`): Deep-clones state objects and replaces `image_name` values using an old-name → new-name `Map`. Only images that were re-uploaded with a different name are remapped. Images that already existed on the server are left unchanged.
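A simplified sketch of the collect/remap pair, covering only the raster-layer path (the real functions walk every path listed below; the local types here are pared down for illustration):

```typescript
type ImageObject = { type: 'image'; image: { image_name: string } };
type Layer = { objects: ImageObject[] };

// Gather every image_name referenced by the given layers.
const collectImageNames = (layers: Layer[]): Set<string> => {
  const names = new Set<string>();
  for (const layer of layers) {
    for (const obj of layer.objects) {
      names.add(obj.image.image_name);
    }
  }
  return names;
};

// Deep-clone, then rewrite only the names present in the mapping, i.e. the
// images that were re-uploaded under a different name.
const remapImageNames = (layers: Layer[], mapping: Map<string, string>): Layer[] => {
  const cloned: Layer[] = structuredClone(layers);
  for (const layer of cloned) {
    for (const obj of layer.objects) {
      const newName = mapping.get(obj.image.image_name);
      if (newName) {
        obj.image.image_name = newName;
      }
    }
  }
  return cloned;
};
```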
Both walk the same paths through the state tree:

- Layer/mask objects → `CanvasImageState.image.image_name`
- Regional guidance ref images → `ImageWithDims.image_name`
- Global ref images → `CroppableImageWithDims.original.image.image_name` and `.crop.image.image_name`

## Extending the Format

### Adding new optional data (non-breaking)

Add a new JSON file to the ZIP. No version bump needed.

1. **Save**: Add `zip.file('new_data.json', JSON.stringify(data))` in `useCanvasProjectSave.ts`
2. **Load**: Read with `zip.file('new_data.json')` in `useCanvasProjectLoad.ts` — check for `null` so older project files without it still load
3. **Dispatch**: Add the appropriate Redux action to restore the data

### Adding new entity types with images

1. Extend the `CanvasProjectState` type in `canvasProjectFile.ts`
2. Add collection logic in `collectImageNames()` to walk the new entity's objects
3. Add remapping logic in `remapCanvasState()` to update image names
4. Include the new entity array in both the save and load hooks
5. Handle it in the `canvasProjectRecalled` reducer in `canvasSlice.ts`

### Breaking schema changes

1. Bump `CANVAS_PROJECT_VERSION` in `canvasProjectFile.ts`
2. Update the Zod manifest schema: `version: z.union([z.literal(1), z.literal(2)])`
3. Add migration logic in the load hook: check the version and transform v1 → v2 before dispatching

## UI Architecture

### Save dialog

The save flow uses a **nanostore atom** (`$isOpen`) to control the `SaveCanvasProjectDialog`:

1. `useSaveCanvasProjectWithDialog()` — returns a callback that sets `$isOpen` to `true`
2. `SaveCanvasProjectDialog` (singleton in `GlobalModalIsolator`) — renders an `AlertDialog` with a name input
3. On save → calls `saveCanvasProject(name)` and closes the dialog
4. On cancel → closes the dialog

### Load dialog

The load flow uses a **nanostore atom** (`$pendingFile`) to decouple the file dialog from the confirmation dialog:

1. `useLoadCanvasProjectWithDialog()` — opens a programmatic file input (`document.createElement('input')`)
2. On file selection → sets the `$pendingFile` atom
3. `LoadCanvasProjectConfirmationAlertDialog` (singleton in `GlobalModalIsolator`) — subscribes to `$pendingFile` via `useStore()`
4. On accept → calls `loadCanvasProject(file)` and clears the atom
5. On cancel → clears the atom

The programmatic file input approach was chosen because the context menu component uses `isLazy: true`, which unmounts the DOM tree when the menu closes — a hidden `<input>` element inside the menu would be destroyed before the file dialog returns.
diff --git a/docs/features/Lasso_tool.md b/docs/features/Lasso_tool.md
new file mode 100644
index 0000000000..8f7fc6d4ec
--- /dev/null
+++ b/docs/features/Lasso_tool.md
@@ -0,0 +1,32 @@
Lasso Tool
===========

- The Lasso tool creates selections and inpaint masks by drawing freehand or polygonal regions on the canvas.

How to open the Lasso tool
--------------------------
- Click the Lasso icon in the toolbar.
- Hotkey: press `L` (default). The hotkey is shown in the tool's tooltip and can be customized in Hotkeys settings.

Modes
-----
- Freehand (default)
  - Hold the pointer and drag to draw a continuous contour.
  - Long segments are broken into intermediate points to keep the line continuous.
  - Very long strokes may be simplified after drawing to reduce the point count for performance.

- Polygon
  - Click to place points; click the first point (or a point near it) to close the polygon.
  - The tool snaps the closing point to the start for precise closures (see the sketch below).
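For the curious, a minimal sketch of the close-and-snap behavior (the 10-pixel threshold and function shape are assumptions for illustration, not the tool's actual tolerance or code):

```typescript
type Point = { x: number; y: number };

// Close the polygon when a click lands within `threshold` pixels of the first
// point; the final point is snapped exactly onto the start.
const tryClosePolygon = (points: Point[], click: Point, threshold = 10): Point[] | null => {
  if (points.length < 3) {
    return null; // need at least a triangle before closing
  }
  const start = points[0];
  if (Math.hypot(click.x - start.x, click.y - start.y) > threshold) {
    return null; // not near the start: keep placing points
  }
  return [...points, { ...start }]; // snap the closing point to the start
};
```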
Basic interactions
------------------
- Switch modes with the mode toggle in the toolbar.
- To close a polygon: click the starting point again or click near it — the tool aligns the final point to the start to complete the shape.
- The selection will be added to the current Inpaint Mask layer. If no Inpaint Mask layer exists, a new one will be created automatically.

Tips & behavior
---------------
- Hold `Space` to temporarily switch to the View tool for panning and zooming; release `Space` to return to the Lasso tool and continue drawing.
- When using Polygon mode, you can hold `Shift` to snap points to horizontal, vertical, or 45-degree angles for more precise shapes.
- Hold `Ctrl` (Windows/Linux) or `Command` (macOS) while drawing to subtract from the current selection instead of adding to it.
diff --git a/docs/features/canvas_projects.md b/docs/features/canvas_projects.md
new file mode 100644
index 0000000000..8b161c6745
--- /dev/null
+++ b/docs/features/canvas_projects.md
@@ -0,0 +1,56 @@
---
title: Canvas Projects
---

# :material-folder-zip: Canvas Projects

## Save and Restore Your Canvas Work

Canvas Projects let you save your entire canvas setup to a file and load it back later. This is useful when you want to:

- **Switch between tasks** without losing your current canvas arrangement
- **Back up complex setups** with multiple layers, masks, and reference images
- **Share canvas layouts** with others or transfer them between machines
- **Recover from deleted images** — all images are embedded in the project file

## What Gets Saved

A canvas project file (`.invk`) captures everything about your current canvas session:

- **All layers** — raster layers, control layers, inpaint masks, regional guidance
- **All drawn content** — brush strokes, pasted images, eraser marks
- **Reference images** — global IP-Adapter / FLUX Redux images with crop settings
- **Regional guidance** — per-region prompts and reference images
- **Bounding box** — position, size, aspect ratio, and scale settings
- **All generation parameters** — prompts, seed, steps, CFG scale, guidance, scheduler, model, VAE, dimensions, img2img strength, infill settings, canvas coherence, refiner settings, FLUX/Z-Image specific parameters, and more
- **LoRAs** — all added LoRA models with their weights and enabled/disabled state

## How to Save a Project

You can save from two places:

1. **Toolbar** — Click the **Archive icon** in the canvas toolbar, then select **Save Canvas Project**
2. **Context menu** — Right-click the canvas, open the **Project** submenu, then select **Save Canvas Project**

A dialog will ask you to enter a **project name**. This name is used as the filename (e.g., entering "My Portrait" saves as `My Portrait.invk`) and is stored inside the project file.

## How to Load a Project

You can load from the same two places:

1. **Toolbar** — Click the **Archive icon**, then select **Load Canvas Project**
2. **Context menu** — Right-click the canvas, open the **Project** submenu, then select **Load Canvas Project**

A file dialog will open. Select your `.invk` file. You will see a confirmation dialog warning that loading will replace your current canvas. Click **Load** to proceed.
### What Happens on Load

- Your current canvas is **completely replaced** — all existing layers, masks, reference images, and parameters are overwritten
- Images that are already present on your InvokeAI server are reused automatically (no duplicate uploads)
- Images that were deleted from the server are re-uploaded from the project file
- If the saved model is not installed on your system, the model identifier is still restored — you will need to select an available model manually

## Good to Know

- **No undo** — Loading a project replaces your canvas entirely. There is no way to undo this action, so save your current project first if you want to keep it.
- **Image deduplication** — When loading, images already on your server are not re-uploaded. Only missing images are uploaded from the project file.
- **File size** — The `.invk` file size depends on the number and resolution of images in your canvas. A project with many high-resolution layers can be large.
- **Model availability** — The project saves which model was selected, but does not include the model itself. If the model is not installed when you load the project, you will need to select a different one.
diff --git a/docs/multiuser/user_guide.md b/docs/multiuser/user_guide.md
index 9c950913de..87587c599f 100644
--- a/docs/multiuser/user_guide.md
+++ b/docs/multiuser/user_guide.md
@@ -140,13 +140,13 @@ As a regular user, you can:
 - ✅ View your own generation queue
 - ✅ Customize your UI preferences (theme, hotkeys, etc.)
 - ✅ View available models (read-only access to Model Manager)
-- ✅ Access shared boards (based on permissions granted to you) (FUTURE FEATURE)
-- ✅ Access workflows marked as public (FUTURE FEATURE)
+- ✅ View shared and public boards created by other users
+- ✅ View and use workflows marked as shared by other users
 
 You cannot:
 
 - ❌ Add, delete, or modify models
-- ❌ View or modify other users' boards, images, or workflows
+- ❌ View or modify other users' private boards, images, or workflows
 - ❌ Manage user accounts
 - ❌ Access system configuration
 - ❌ View or cancel other users' generation tasks
@@ -173,7 +173,7 @@ Administrators have all regular user capabilities, plus:
 - ✅ Full model management (add, delete, configure models)
 - ✅ Create and manage user accounts
 - ✅ View and manage all users' generation queues
-- ✅ Create and manage shared boards (FUTURE FEATURE)
+- ✅ View and manage all users' boards, images, and workflows (including system-owned legacy content)
 - ✅ Access system configuration
 - ✅ Grant or revoke admin privileges
@@ -183,23 +183,30 @@
 ### Image Boards
 
-In multi-user model, Image Boards work as before. Each user can create an unlimited number of boards and organize their images and assets as they see fit. Boards are private: you cannot see a board owned by a different user.
+In multi-user mode, each user can create an unlimited number of boards and organize their images and assets as they see fit. Boards have three visibility levels:
 
-!!! tip "Shared Boards"
-    InvokeAI 6.13 will add support for creating public boards that are accessible to all users.
+- **Private** (default): Only you (and administrators) can see and modify the board.
+- **Shared**: All users can view the board and its contents, but only you (and administrators) can modify it (rename, archive, delete, or add/remove images).
+- **Public**: All users can view the board and contribute images to it. Only you (and administrators) can modify the board's structure (rename, archive, delete).
-The Administrator can see all users Image Boards and their contents. +To change a board's visibility, right-click on the board and select the desired visibility option. -### Going From Multi-User to Single-User mode +Administrators can see and manage all users' image boards and their contents regardless of visibility settings. + +### Going From Multi-User to Single-User Mode If an InvokeAI instance was in multiuser mode and then restarted in single user mode (by setting `multiuser: false` in the configuration file), all users' boards will be consolidated in one place. Any images that were in "Uncategorized" will be merged together into a single Uncategorized board. If, at a later date, the server is restarted in multi-user mode, the boards and images will be separated and restored to their owners. ### Workflows -In the current released version (6.12) workflows are always shared among users. Any workflow that you create will be visible to other users and vice-versa, and there is no protection against one user modifying another user's workflow. +Each user has their own private workflow library. Workflows you create are visible only to you by default. -!!! tip "Private and Shared Workflows" - InvokeAI 6.13 will provide the ability to create private and shared workflows. A private workflow can only be viewed by the user who created it. At any time, however, the user can designate the workflow *shared*, in which case it can be opened on a read-only basis by all logged-in users. +You can share a workflow with other users by marking it as **shared** (public). Shared workflows appear in all users' workflow libraries and can be opened by anyone, but only the owner (or an administrator) can modify or delete them. + +To share a workflow, open it and use the sharing controls to toggle its public/shared status. + +!!! warning "Preexisting workflows after enabling multi-user mode" + When you enable multi-user mode for the first time on an existing InvokeAI installation, all workflows that were created before multi-user mode was activated will appear in the **shared workflows** section. These preexisting workflows are owned by the internal "system" account and are visible to all users. Administrators can edit or delete these shared legacy workflows. Regular users can view and use them but cannot modify them. ### The Generation Queue @@ -330,11 +337,11 @@ These settings are stored per-user and won't affect other users. ### Can other users see my images? -No, unless you add them to a shared board (FUTURE FEATURE). All your personal boards and images are private. +Not unless you change your board's visibility to "shared" or "public". All personal boards and images are private by default. ### Can I share my workflows with others? -Not directly. Ask your administrator to mark workflows as public if you want to share them. +Yes. You can mark any workflow as shared (public), which makes it visible to all users. Other users can view and use shared workflows, but only you or an administrator can modify or delete them. ### How long do sessions last? 
diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 856dcb7d1a..ed7910cee7 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -49,10 +49,12 @@ from invokeai.app.services.users.users_default import UserService from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_disk import WorkflowThumbnailFileStorageDisk from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( + AnimaConditioningInfo, BasicConditioningInfo, CogView4ConditioningInfo, ConditioningFieldData, FLUXConditioningInfo, + QwenImageConditioningInfo, SD3ConditioningInfo, SDXLConditioningInfo, ZImageConditioningInfo, @@ -143,6 +145,8 @@ class ApiDependencies: SD3ConditioningInfo, CogView4ConditioningInfo, ZImageConditioningInfo, + QwenImageConditioningInfo, + AnimaConditioningInfo, ], ephemeral=True, ), diff --git a/invokeai/app/api/routers/app_info.py b/invokeai/app/api/routers/app_info.py index ea3ca480b9..d2b3c9ef02 100644 --- a/invokeai/app/api/routers/app_info.py +++ b/invokeai/app/api/routers/app_info.py @@ -10,14 +10,15 @@ from fastapi import Body, HTTPException, Path from fastapi.routing import APIRouter from pydantic import BaseModel, Field +from invokeai.app.api.auth_dependencies import AdminUserOrDefault from invokeai.app.api.dependencies import ApiDependencies from invokeai.app.services.config.config_default import ( - DefaultInvokeAIAppConfig, EXTERNAL_PROVIDER_CONFIG_FIELDS, + DefaultInvokeAIAppConfig, InvokeAIAppConfig, get_config, - load_external_api_keys, load_and_migrate_config, + load_external_api_keys, ) from invokeai.app.services.external_generation.external_generation_common import ExternalProviderStatus from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus @@ -103,6 +104,16 @@ EXTERNAL_PROVIDER_FIELDS: dict[str, tuple[str, str]] = { _EXTERNAL_PROVIDER_CONFIG_LOCK = Lock() +class UpdateAppGenerationSettingsRequest(BaseModel): + """Writable generation-related app settings.""" + + max_queue_history: int | None = Field( + default=None, + ge=0, + description="Keep the last N completed, failed, and canceled queue items on startup. 
Set to 0 to prune all terminal items.", + ) + + @app_router.get( "/runtime_config", operation_id="get_runtime_config", status_code=200, response_model=InvokeAIAppConfigWithSetFields ) @@ -111,6 +122,30 @@ async def get_runtime_config() -> InvokeAIAppConfigWithSetFields: return InvokeAIAppConfigWithSetFields(set_fields=config.model_fields_set, config=config) +@app_router.patch( + "/runtime_config", + operation_id="update_runtime_config", + status_code=200, + response_model=InvokeAIAppConfigWithSetFields, +) +async def update_runtime_config( + _: AdminUserOrDefault, + changes: UpdateAppGenerationSettingsRequest = Body(description="Writable runtime configuration changes"), +) -> InvokeAIAppConfigWithSetFields: + config = get_config() + update_dict = changes.model_dump(exclude_unset=True) + config.update_config(update_dict) + + if config.config_file_path.exists(): + persisted_config = load_and_migrate_config(config.config_file_path) + else: + persisted_config = DefaultInvokeAIAppConfig() + + persisted_config.update_config(update_dict) + persisted_config.write_file(config.config_file_path) + return InvokeAIAppConfigWithSetFields(set_fields=config.model_fields_set, config=config) + + @app_router.get( "/external_providers/status", operation_id="get_external_provider_statuses", diff --git a/invokeai/app/api/routers/auth.py b/invokeai/app/api/routers/auth.py index b4c1e86cf3..e0b0c885cd 100644 --- a/invokeai/app/api/routers/auth.py +++ b/invokeai/app/api/routers/auth.py @@ -80,6 +80,7 @@ class SetupStatusResponse(BaseModel): setup_required: bool = Field(description="Whether initial setup is required") multiuser_enabled: bool = Field(description="Whether multiuser mode is enabled") strict_password_checking: bool = Field(description="Whether strict password requirements are enforced") + admin_email: str | None = Field(default=None, description="Email of the first active admin user, if any") @auth_router.get("/status", response_model=SetupStatusResponse) @@ -94,15 +95,25 @@ async def get_setup_status() -> SetupStatusResponse: # If multiuser is disabled, setup is never required if not config.multiuser: return SetupStatusResponse( - setup_required=False, multiuser_enabled=False, strict_password_checking=config.strict_password_checking + setup_required=False, + multiuser_enabled=False, + strict_password_checking=config.strict_password_checking, + admin_email=None, ) # In multiuser mode, check if an admin exists user_service = ApiDependencies.invoker.services.users setup_required = not user_service.has_admin() + # Only expose admin_email during initial setup to avoid leaking + # administrator identity on public deployments. 
+ admin_email = user_service.get_admin_email() if setup_required else None + return SetupStatusResponse( - setup_required=setup_required, multiuser_enabled=True, strict_password_checking=config.strict_password_checking + setup_required=setup_required, + multiuser_enabled=True, + strict_password_checking=config.strict_password_checking, + admin_email=admin_email, ) @@ -150,6 +161,7 @@ async def login( user_id=user.user_id, email=user.email, is_admin=user.is_admin, + remember_me=request.remember_me, ) token = create_access_token(token_data, expires_delta) diff --git a/invokeai/app/api/routers/board_images.py b/invokeai/app/api/routers/board_images.py index cb5e0ab51a..f94e4f2437 100644 --- a/invokeai/app/api/routers/board_images.py +++ b/invokeai/app/api/routers/board_images.py @@ -1,12 +1,53 @@ from fastapi import Body, HTTPException from fastapi.routing import APIRouter +from invokeai.app.api.auth_dependencies import CurrentUserOrDefault from invokeai.app.api.dependencies import ApiDependencies from invokeai.app.services.images.images_common import AddImagesToBoardResult, RemoveImagesFromBoardResult board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"]) +def _assert_board_write_access(board_id: str, current_user: CurrentUserOrDefault) -> None: + """Raise 403 if the current user may not mutate the given board. + + Write access is granted when ANY of these hold: + - The user is an admin. + - The user owns the board. + - The board visibility is Public (public boards accept contributions from any user). + """ + from invokeai.app.services.board_records.board_records_common import BoardVisibility + + try: + board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id) + except Exception: + raise HTTPException(status_code=404, detail="Board not found") + if current_user.is_admin: + return + if board.user_id == current_user.user_id: + return + if board.board_visibility == BoardVisibility.Public: + return + raise HTTPException(status_code=403, detail="Not authorized to modify this board") + + +def _assert_image_direct_owner(image_name: str, current_user: CurrentUserOrDefault) -> None: + """Raise 403 if the current user is not the direct owner of the image. + + This is intentionally stricter than _assert_image_owner in images.py: + board ownership is NOT sufficient here. Allowing a user to add someone + else's image to their own board would grant them mutation rights via the + board-ownership fallback in _assert_image_owner, escalating read access + into write access. 
+ """ + if current_user.is_admin: + return + owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name) + if owner is not None and owner == current_user.user_id: + return + raise HTTPException(status_code=403, detail="Not authorized to move this image") + + @board_images_router.post( "/", operation_id="add_image_to_board", @@ -17,14 +58,17 @@ board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"]) response_model=AddImagesToBoardResult, ) async def add_image_to_board( + current_user: CurrentUserOrDefault, board_id: str = Body(description="The id of the board to add to"), image_name: str = Body(description="The name of the image to add"), ) -> AddImagesToBoardResult: """Creates a board_image""" + _assert_board_write_access(board_id, current_user) + _assert_image_direct_owner(image_name, current_user) try: added_images: set[str] = set() affected_boards: set[str] = set() - old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none" + old_board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name) or "none" ApiDependencies.invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name) added_images.add(image_name) affected_boards.add(board_id) @@ -48,13 +92,16 @@ async def add_image_to_board( response_model=RemoveImagesFromBoardResult, ) async def remove_image_from_board( + current_user: CurrentUserOrDefault, image_name: str = Body(description="The name of the image to remove", embed=True), ) -> RemoveImagesFromBoardResult: """Removes an image from its board, if it had one""" try: + old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none" + if old_board_id != "none": + _assert_board_write_access(old_board_id, current_user) removed_images: set[str] = set() affected_boards: set[str] = set() - old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none" ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name) removed_images.add(image_name) affected_boards.add("none") @@ -64,6 +111,8 @@ async def remove_image_from_board( affected_boards=list(affected_boards), ) + except HTTPException: + raise except Exception: raise HTTPException(status_code=500, detail="Failed to remove image from board") @@ -78,16 +127,21 @@ async def remove_image_from_board( response_model=AddImagesToBoardResult, ) async def add_images_to_board( + current_user: CurrentUserOrDefault, board_id: str = Body(description="The id of the board to add to"), image_names: list[str] = Body(description="The names of the images to add", embed=True), ) -> AddImagesToBoardResult: """Adds a list of images to a board""" + _assert_board_write_access(board_id, current_user) try: added_images: set[str] = set() affected_boards: set[str] = set() for image_name in image_names: try: - old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none" + _assert_image_direct_owner(image_name, current_user) + old_board_id = ( + ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name) or "none" + ) ApiDependencies.invoker.services.board_images.add_image_to_board( board_id=board_id, image_name=image_name, @@ -96,12 +150,16 @@ async def add_images_to_board( affected_boards.add(board_id) affected_boards.add(old_board_id) + except HTTPException: + raise except Exception: pass return AddImagesToBoardResult( added_images=list(added_images), 
affected_boards=list(affected_boards), ) + except HTTPException: + raise except Exception: raise HTTPException(status_code=500, detail="Failed to add images to board") @@ -116,6 +174,7 @@ async def add_images_to_board( response_model=RemoveImagesFromBoardResult, ) async def remove_images_from_board( + current_user: CurrentUserOrDefault, image_names: list[str] = Body(description="The names of the images to remove", embed=True), ) -> RemoveImagesFromBoardResult: """Removes a list of images from their board, if they had one""" @@ -125,15 +184,21 @@ async def remove_images_from_board( for image_name in image_names: try: old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none" + if old_board_id != "none": + _assert_board_write_access(old_board_id, current_user) ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name) removed_images.add(image_name) affected_boards.add("none") affected_boards.add(old_board_id) + except HTTPException: + raise except Exception: pass return RemoveImagesFromBoardResult( removed_images=list(removed_images), affected_boards=list(affected_boards), ) + except HTTPException: + raise except Exception: raise HTTPException(status_code=500, detail="Failed to remove images from board") diff --git a/invokeai/app/api/routers/boards.py b/invokeai/app/api/routers/boards.py index e93bb8b2a9..6897e90aff 100644 --- a/invokeai/app/api/routers/boards.py +++ b/invokeai/app/api/routers/boards.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, Field from invokeai.app.api.auth_dependencies import CurrentUserOrDefault from invokeai.app.api.dependencies import ApiDependencies -from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy +from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy, BoardVisibility from invokeai.app.services.boards.boards_common import BoardDTO from invokeai.app.services.image_records.image_records_common import ImageCategory from invokeai.app.services.shared.pagination import OffsetPaginatedResults @@ -56,7 +56,14 @@ async def get_board( except Exception: raise HTTPException(status_code=404, detail="Board not found") - if not current_user.is_admin and result.user_id != current_user.user_id: + # Admins can access any board. + # Owners can access their own boards. + # Shared and public boards are visible to all authenticated users. + if ( + not current_user.is_admin + and result.user_id != current_user.user_id + and result.board_visibility == BoardVisibility.Private + ): raise HTTPException(status_code=403, detail="Not authorized to access this board") return result @@ -188,7 +195,11 @@ async def list_all_board_image_names( except Exception: raise HTTPException(status_code=404, detail="Board not found") - if not current_user.is_admin and board.user_id != current_user.user_id: + if ( + not current_user.is_admin + and board.user_id != current_user.user_id + and board.board_visibility == BoardVisibility.Private + ): raise HTTPException(status_code=403, detail="Not authorized to access this board") image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board( @@ -196,4 +207,15 @@ async def list_all_board_image_names( categories, is_intermediate, ) + + # For uncategorized images (board_id="none"), filter to only the caller's + # images so that one user cannot enumerate another's uncategorized images. + # Admin users can see all uncategorized images. 
+ if board_id == "none" and not current_user.is_admin: + image_names = [ + name + for name in image_names + if ApiDependencies.invoker.services.image_records.get_user_id(name) == current_user.user_id + ] + return image_names diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 6b11762c9e..a3ae6fce82 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -38,6 +38,96 @@ images_router = APIRouter(prefix="/v1/images", tags=["images"]) IMAGE_MAX_AGE = 31536000 +def _assert_image_owner(image_name: str, current_user: CurrentUserOrDefault) -> None: + """Raise 403 if the current user does not own the image and is not an admin. + + Ownership is satisfied when ANY of these hold: + - The user is an admin. + - The user is the image's direct owner (image_records.user_id). + - The user owns the board the image sits on. + - The image sits on a Public board (public boards grant mutation rights). + """ + from invokeai.app.services.board_records.board_records_common import BoardVisibility + + if current_user.is_admin: + return + owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name) + if owner is not None and owner == current_user.user_id: + return + + # Check whether the user owns the board the image belongs to, + # or the board is Public (public boards grant mutation rights). + board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name) + if board_id is not None: + try: + board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id) + if board.user_id == current_user.user_id: + return + if board.board_visibility == BoardVisibility.Public: + return + except Exception: + pass + + raise HTTPException(status_code=403, detail="Not authorized to modify this image") + + +def _assert_image_read_access(image_name: str, current_user: CurrentUserOrDefault) -> None: + """Raise 403 if the current user may not view the image. + + Access is granted when ANY of these hold: + - The user is an admin. + - The user owns the image. + - The image sits on a shared or public board. + """ + from invokeai.app.services.board_records.board_records_common import BoardVisibility + + if current_user.is_admin: + return + + owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name) + if owner is not None and owner == current_user.user_id: + return + + # Check whether the image's board makes it visible to other users. + board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name) + if board_id is not None: + try: + board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id) + if board.board_visibility in (BoardVisibility.Shared, BoardVisibility.Public): + return + except Exception: + pass + + raise HTTPException(status_code=403, detail="Not authorized to access this image") + + +def _assert_board_read_access(board_id: str, current_user: CurrentUserOrDefault) -> None: + """Raise 403 if the current user may not read images from this board. + + Access is granted when ANY of these hold: + - The user is an admin. + - The user owns the board. + - The board visibility is Shared or Public. 
+ """ + from invokeai.app.services.board_records.board_records_common import BoardVisibility + + if current_user.is_admin: + return + + try: + board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id) + except Exception: + raise HTTPException(status_code=404, detail="Board not found") + + if board.user_id == current_user.user_id: + return + + if board.board_visibility in (BoardVisibility.Shared, BoardVisibility.Public): + return + + raise HTTPException(status_code=403, detail="Not authorized to access this board") + + class ResizeToDimensions(BaseModel): width: int = Field(..., gt=0) height: int = Field(..., gt=0) @@ -83,6 +173,22 @@ async def upload_image( ), ) -> ImageDTO: """Uploads an image for the current user""" + # If uploading into a board, verify the user has write access. + # Public boards allow uploads from any authenticated user. + if board_id is not None: + from invokeai.app.services.board_records.board_records_common import BoardVisibility + + try: + board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id) + except Exception: + raise HTTPException(status_code=404, detail="Board not found") + if ( + not current_user.is_admin + and board.user_id != current_user.user_id + and board.board_visibility != BoardVisibility.Public + ): + raise HTTPException(status_code=403, detail="Not authorized to upload to this board") + if not file.content_type or not file.content_type.startswith("image"): raise HTTPException(status_code=415, detail="Not an image") @@ -165,9 +271,11 @@ async def create_image_upload_entry( @images_router.delete("/i/{image_name}", operation_id="delete_image", response_model=DeleteImagesResult) async def delete_image( + current_user: CurrentUserOrDefault, image_name: str = Path(description="The name of the image to delete"), ) -> DeleteImagesResult: """Deletes an image""" + _assert_image_owner(image_name, current_user) deleted_images: set[str] = set() affected_boards: set[str] = set() @@ -189,26 +297,31 @@ async def delete_image( @images_router.delete("/intermediates", operation_id="clear_intermediates") -async def clear_intermediates() -> int: - """Clears all intermediates""" +async def clear_intermediates( + current_user: CurrentUserOrDefault, +) -> int: + """Clears all intermediates. Requires admin.""" + if not current_user.is_admin: + raise HTTPException(status_code=403, detail="Only admins can clear all intermediates") try: count_deleted = ApiDependencies.invoker.services.images.delete_intermediates() return count_deleted except Exception: raise HTTPException(status_code=500, detail="Failed to clear intermediates") - pass @images_router.get("/intermediates", operation_id="get_intermediates_count") -async def get_intermediates_count() -> int: - """Gets the count of intermediate images""" +async def get_intermediates_count( + current_user: CurrentUserOrDefault, +) -> int: + """Gets the count of intermediate images. 
Non-admin users only see their own intermediates."""
     try:
-        return ApiDependencies.invoker.services.images.get_intermediates_count()
+        user_id = None if current_user.is_admin else current_user.user_id
+        return ApiDependencies.invoker.services.images.get_intermediates_count(user_id=user_id)
     except Exception:
         raise HTTPException(status_code=500, detail="Failed to get intermediates")
-    pass
 
 
 @images_router.patch(
     "/i/{image_name}",
     operation_id="update_image",
     response_model=ImageDTO,
 )
 async def update_image(
+    current_user: CurrentUserOrDefault,
     image_name: str = Path(description="The name of the image to update"),
     image_changes: ImageRecordChanges = Body(description="The changes to apply to the image"),
 ) -> ImageDTO:
     """Updates an image"""
+    _assert_image_owner(image_name, current_user)
 
     try:
         return ApiDependencies.invoker.services.images.update(image_name, image_changes)
@@ -234,9 +349,11 @@
     response_model=ImageDTO,
 )
 async def get_image_dto(
+    current_user: CurrentUserOrDefault,
     image_name: str = Path(description="The name of image to get"),
 ) -> ImageDTO:
     """Gets an image's DTO"""
+    _assert_image_read_access(image_name, current_user)
 
     try:
         return ApiDependencies.invoker.services.images.get_dto(image_name)
@@ -250,9 +367,11 @@
     response_model=Optional[MetadataField],
 )
 async def get_image_metadata(
+    current_user: CurrentUserOrDefault,
     image_name: str = Path(description="The name of image to get"),
 ) -> Optional[MetadataField]:
     """Gets an image's metadata"""
+    _assert_image_read_access(image_name, current_user)
 
     try:
         return ApiDependencies.invoker.services.images.get_metadata(image_name)
@@ -269,8 +388,11 @@ class WorkflowAndGraphResponse(BaseModel):
     "/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=WorkflowAndGraphResponse
 )
 async def get_image_workflow(
+    current_user: CurrentUserOrDefault,
     image_name: str = Path(description="The name of image whose workflow to get"),
 ) -> WorkflowAndGraphResponse:
+    _assert_image_read_access(image_name, current_user)
+
     try:
         workflow = ApiDependencies.invoker.services.images.get_workflow(image_name)
         graph = ApiDependencies.invoker.services.images.get_graph(image_name)
@@ -306,8 +428,12 @@
 async def get_image_full(
     image_name: str = Path(description="The name of full-resolution image file to get"),
 ) -> Response:
-    """Gets a full-resolution image file"""
+    """Gets a full-resolution image file.
+
+    This endpoint is intentionally unauthenticated because browsers load images
+    via <img> tags which cannot send Bearer tokens. Image names are UUIDs,
+    providing security through unguessability.
+    """
     try:
         path = ApiDependencies.invoker.services.images.get_path(image_name)
         with open(path, "rb") as f:
@@ -335,8 +461,12 @@
 async def get_image_thumbnail(
     image_name: str = Path(description="The name of thumbnail image file to get"),
 ) -> Response:
-    """Gets a thumbnail image file"""
+    """Gets a thumbnail image file.
+
+    This endpoint is intentionally unauthenticated because browsers load images
+    via <img> tags which cannot send Bearer tokens. Image names are UUIDs,
+    providing security through unguessability.
+ """ try: path = ApiDependencies.invoker.services.images.get_path(image_name, thumbnail=True) with open(path, "rb") as f: @@ -354,9 +484,11 @@ async def get_image_thumbnail( response_model=ImageUrlsDTO, ) async def get_image_urls( + current_user: CurrentUserOrDefault, image_name: str = Path(description="The name of the image whose URL to get"), ) -> ImageUrlsDTO: """Gets an image and thumbnail URL""" + _assert_image_read_access(image_name, current_user) try: image_url = ApiDependencies.invoker.services.images.get_url(image_name) @@ -392,6 +524,11 @@ async def list_image_dtos( ) -> OffsetPaginatedResults[ImageDTO]: """Gets a list of image DTOs for the current user""" + # Validate that the caller can read from this board before listing its images. + # "none" is a sentinel for uncategorized images and is handled by the SQL layer. + if board_id is not None and board_id != "none": + _assert_board_read_access(board_id, current_user) + image_dtos = ApiDependencies.invoker.services.images.get_many( offset, limit, @@ -410,6 +547,7 @@ async def list_image_dtos( @images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesResult) async def delete_images_from_list( + current_user: CurrentUserOrDefault, image_names: list[str] = Body(description="The list of names of images to delete", embed=True), ) -> DeleteImagesResult: try: @@ -417,24 +555,31 @@ async def delete_images_from_list( affected_boards: set[str] = set() for image_name in image_names: try: + _assert_image_owner(image_name, current_user) image_dto = ApiDependencies.invoker.services.images.get_dto(image_name) board_id = image_dto.board_id or "none" ApiDependencies.invoker.services.images.delete(image_name) deleted_images.add(image_name) affected_boards.add(board_id) + except HTTPException: + raise except Exception: pass return DeleteImagesResult( deleted_images=list(deleted_images), affected_boards=list(affected_boards), ) + except HTTPException: + raise except Exception: raise HTTPException(status_code=500, detail="Failed to delete images") @images_router.delete("/uncategorized", operation_id="delete_uncategorized_images", response_model=DeleteImagesResult) -async def delete_uncategorized_images() -> DeleteImagesResult: - """Deletes all images that are uncategorized""" +async def delete_uncategorized_images( + current_user: CurrentUserOrDefault, +) -> DeleteImagesResult: + """Deletes all uncategorized images owned by the current user (or all if admin)""" image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board( board_id="none", categories=None, is_intermediate=None @@ -445,9 +590,13 @@ async def delete_uncategorized_images() -> DeleteImagesResult: affected_boards: set[str] = set() for image_name in image_names: try: + _assert_image_owner(image_name, current_user) ApiDependencies.invoker.services.images.delete(image_name) deleted_images.add(image_name) affected_boards.add("none") + except HTTPException: + # Skip images not owned by the current user + pass except Exception: pass return DeleteImagesResult( @@ -464,6 +613,7 @@ class ImagesUpdatedFromListResult(BaseModel): @images_router.post("/star", operation_id="star_images_in_list", response_model=StarredImagesResult) async def star_images_in_list( + current_user: CurrentUserOrDefault, image_names: list[str] = Body(description="The list of names of images to star", embed=True), ) -> StarredImagesResult: try: @@ -471,23 +621,29 @@ async def star_images_in_list( affected_boards: set[str] = set() for image_name in 
image_names: try: + _assert_image_owner(image_name, current_user) updated_image_dto = ApiDependencies.invoker.services.images.update( image_name, changes=ImageRecordChanges(starred=True) ) starred_images.add(image_name) affected_boards.add(updated_image_dto.board_id or "none") + except HTTPException: + raise except Exception: pass return StarredImagesResult( starred_images=list(starred_images), affected_boards=list(affected_boards), ) + except HTTPException: + raise except Exception: raise HTTPException(status_code=500, detail="Failed to star images") @images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=UnstarredImagesResult) async def unstar_images_in_list( + current_user: CurrentUserOrDefault, image_names: list[str] = Body(description="The list of names of images to unstar", embed=True), ) -> UnstarredImagesResult: try: @@ -495,17 +651,22 @@ async def unstar_images_in_list( affected_boards: set[str] = set() for image_name in image_names: try: + _assert_image_owner(image_name, current_user) updated_image_dto = ApiDependencies.invoker.services.images.update( image_name, changes=ImageRecordChanges(starred=False) ) unstarred_images.add(image_name) affected_boards.add(updated_image_dto.board_id or "none") + except HTTPException: + raise except Exception: pass return UnstarredImagesResult( unstarred_images=list(unstarred_images), affected_boards=list(affected_boards), ) + except HTTPException: + raise except Exception: raise HTTPException(status_code=500, detail="Failed to unstar images") @@ -523,6 +684,7 @@ class ImagesDownloaded(BaseModel): "/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202 ) async def download_images_from_list( + current_user: CurrentUserOrDefault, background_tasks: BackgroundTasks, image_names: Optional[list[str]] = Body( default=None, description="The list of names of images to download", embed=True @@ -533,6 +695,16 @@ async def download_images_from_list( ) -> ImagesDownloaded: if (image_names is None or len(image_names) == 0) and board_id is None: raise HTTPException(status_code=400, detail="No images or board id specified.") + + # Validate that the caller can read every image they are requesting. + # For a board_id request, check board visibility; for explicit image names, + # check each image individually. + if board_id: + _assert_board_read_access(board_id, current_user) + if image_names: + for name in image_names: + _assert_image_read_access(name, current_user) + bulk_download_item_id: str = ApiDependencies.invoker.services.bulk_download.generate_item_id(board_id) background_tasks.add_task( @@ -540,6 +712,7 @@ async def download_images_from_list( image_names, board_id, bulk_download_item_id, + current_user.user_id, ) return ImagesDownloaded(bulk_download_item_name=bulk_download_item_id + ".zip") @@ -558,11 +731,21 @@ async def download_images_from_list( }, ) async def get_bulk_download_item( + current_user: CurrentUserOrDefault, background_tasks: BackgroundTasks, bulk_download_item_name: str = Path(description="The bulk_download_item_name of the bulk download item to get"), ) -> FileResponse: - """Gets a bulk download zip file""" + """Gets a bulk download zip file. + + Requires authentication. The caller must be the user who initiated the + download (tracked by the bulk download service) or an admin. 
+ """ try: + # Verify the caller owns this download (or is an admin) + owner = ApiDependencies.invoker.services.bulk_download.get_owner(bulk_download_item_name) + if owner is not None and owner != current_user.user_id and not current_user.is_admin: + raise HTTPException(status_code=403, detail="Not authorized to access this download") + path = ApiDependencies.invoker.services.bulk_download.get_path(bulk_download_item_name) response = FileResponse( @@ -574,6 +757,8 @@ async def get_bulk_download_item( response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}" background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.delete, bulk_download_item_name) return response + except HTTPException: + raise except Exception: raise HTTPException(status_code=404) @@ -594,6 +779,10 @@ async def get_image_names( ) -> ImageNamesResult: """Gets ordered list of image names with metadata for optimistic updates""" + # Validate that the caller can read from this board before listing its images. + if board_id is not None and board_id != "none": + _assert_board_read_access(board_id, current_user) + try: result = ApiDependencies.invoker.services.images.get_image_names( starred_first=starred_first, @@ -617,6 +806,7 @@ async def get_image_names( responses={200: {"model": list[ImageDTO]}}, ) async def get_images_by_names( + current_user: CurrentUserOrDefault, image_names: list[str] = Body(embed=True, description="Object containing list of image names to fetch DTOs for"), ) -> list[ImageDTO]: """Gets image DTOs for the specified image names. Maintains order of input names.""" @@ -628,8 +818,12 @@ async def get_images_by_names( image_dtos: list[ImageDTO] = [] for name in image_names: try: + _assert_image_read_access(name, current_user) dto = image_service.get_dto(name) image_dtos.append(dto) + except HTTPException: + # Skip images the user is not authorized to view + continue except Exception: # Skip missing images - they may have been deleted between name fetch and DTO fetch continue diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index 127a2c1f50..f351be11ad 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -547,6 +547,19 @@ class BulkDeleteModelsResponse(BaseModel): failed: List[dict] = Field(description="List of failed deletions with error messages") +class BulkReidentifyModelsRequest(BaseModel): + """Request body for bulk model reidentification.""" + + keys: List[str] = Field(description="List of model keys to reidentify") + + +class BulkReidentifyModelsResponse(BaseModel): + """Response body for bulk model reidentification.""" + + succeeded: List[str] = Field(description="List of successfully reidentified model keys") + failed: List[dict] = Field(description="List of failed reidentifications with error messages") + + @model_manager_router.post( "/i/bulk_delete", operation_id="bulk_delete_models", @@ -588,6 +601,67 @@ async def bulk_delete_models( return BulkDeleteModelsResponse(deleted=deleted, failed=failed) +@model_manager_router.post( + "/i/bulk_reidentify", + operation_id="bulk_reidentify_models", + responses={ + 200: {"description": "Models reidentified (possibly with some failures)"}, + }, + status_code=200, +) +async def bulk_reidentify_models( + current_admin: AdminUserOrDefault, + request: BulkReidentifyModelsRequest = Body(description="List of model keys to reidentify"), +) -> BulkReidentifyModelsResponse: + """ + Reidentify multiple models by re-probing their weights files. 
+ + Returns a list of successfully reidentified keys and failed reidentifications with error messages. + """ + logger = ApiDependencies.invoker.services.logger + store = ApiDependencies.invoker.services.model_manager.store + models_path = ApiDependencies.invoker.services.configuration.models_path + + succeeded = [] + failed = [] + + for key in request.keys: + try: + config = store.get_model(key) + if pathlib.Path(config.path).is_relative_to(models_path): + model_path = pathlib.Path(config.path) + else: + model_path = models_path / config.path + mod = ModelOnDisk(model_path) + result = ModelConfigFactory.from_model_on_disk(mod) + if result.config is None: + raise InvalidModelException("Unable to identify model format") + + # Retain user-editable fields from the original config + result.config.path = config.path + result.config.key = config.key + result.config.name = config.name + result.config.description = config.description + result.config.cover_image = config.cover_image + if hasattr(config, "trigger_phrases") and hasattr(result.config, "trigger_phrases"): + result.config.trigger_phrases = config.trigger_phrases + result.config.source = config.source + result.config.source_type = config.source_type + + store.replace_model(config.key, result.config) + succeeded.append(key) + logger.info(f"Reidentified model: {key}") + except UnknownModelException as e: + logger.error(f"Failed to reidentify model {key}: {str(e)}") + failed.append({"key": key, "error": str(e)}) + except Exception as e: + logger.error(f"Failed to reidentify model {key}: {str(e)}") + failed.append({"key": key, "error": str(e)}) + + logger.info(f"Bulk reidentify completed: {len(succeeded)} succeeded, {len(failed)} failed") + return BulkReidentifyModelsResponse(succeeded=succeeded, failed=failed) + + @model_manager_router.delete( "/i/{key}/image", operation_id="delete_model_image", @@ -815,7 +889,7 @@ async def install_hugging_face_model( "/install", operation_id="list_model_installs", ) -async def list_model_installs() -> List[ModelInstallJob]: +async def list_model_installs(current_admin: AdminUserOrDefault) -> List[ModelInstallJob]: """Return the list of model install jobs. Install jobs have a numeric `id`, a `status`, and other fields that provide information on @@ -847,7 +921,9 @@ async def list_model_installs() -> List[ModelInstallJob]: 404: {"description": "No such job"}, }, ) -async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob: +async def get_model_install_job( + current_admin: AdminUserOrDefault, id: int = Path(description="Model install id") +) -> ModelInstallJob: """ Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs' for information on the format of the return value. 
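For reference, a call against the new bulk-reidentify endpoint might look like the following sketch. The `/api/v2/models` prefix is an assumption about where the model manager router is mounted, and the keys are placeholders:

```python
import requests

admin_token = "..."  # obtained from the login endpoint; the route requires an admin

resp = requests.post(
    "http://localhost:9090/api/v2/models/i/bulk_reidentify",  # assumed mount prefix
    json={"keys": ["4b1a0099", "9c2f77d1"]},  # placeholder model keys
    headers={"Authorization": f"Bearer {admin_token}"},
)
result = resp.json()
print("succeeded:", result["succeeded"])
print("failed:", result["failed"])  # entries look like {"key": ..., "error": ...}
```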
@@ -890,7 +966,9 @@ async def cancel_model_install_job( }, status_code=201, ) -async def pause_model_install_job(id: int = Path(description="Model install job ID")) -> ModelInstallJob: +async def pause_model_install_job( + current_admin: AdminUserOrDefault, id: int = Path(description="Model install job ID") +) -> ModelInstallJob: """Pause the model install job corresponding to the given job ID.""" installer = ApiDependencies.invoker.services.model_manager.install try: @@ -910,7 +988,9 @@ async def pause_model_install_job(id: int = Path(description="Model install job }, status_code=201, ) -async def resume_model_install_job(id: int = Path(description="Model install job ID")) -> ModelInstallJob: +async def resume_model_install_job( + current_admin: AdminUserOrDefault, id: int = Path(description="Model install job ID") +) -> ModelInstallJob: """Resume a paused model install job corresponding to the given job ID.""" installer = ApiDependencies.invoker.services.model_manager.install try: @@ -930,7 +1010,9 @@ async def resume_model_install_job(id: int = Path(description="Model install job }, status_code=201, ) -async def restart_failed_model_install_job(id: int = Path(description="Model install job ID")) -> ModelInstallJob: +async def restart_failed_model_install_job( + current_admin: AdminUserOrDefault, id: int = Path(description="Model install job ID") +) -> ModelInstallJob: """Restart failed or non-resumable file downloads for the given job.""" installer = ApiDependencies.invoker.services.model_manager.install try: @@ -951,6 +1033,7 @@ async def restart_failed_model_install_job(id: int = Path(description="Model ins status_code=201, ) async def restart_model_install_file( + current_admin: AdminUserOrDefault, id: int = Path(description="Model install job ID"), file_source: AnyHttpUrl = Body(description="File download URL to restart"), ) -> ModelInstallJob: @@ -1262,7 +1345,7 @@ class DeleteOrphanedModelsResponse(BaseModel): operation_id="get_orphaned_models", response_model=list[OrphanedModelInfo], ) -async def get_orphaned_models() -> list[OrphanedModelInfo]: +async def get_orphaned_models(_: AdminUserOrDefault) -> list[OrphanedModelInfo]: """Find orphaned model directories. Orphaned models are directories in the models folder that contain model files @@ -1289,7 +1372,9 @@ async def get_orphaned_models() -> list[OrphanedModelInfo]: operation_id="delete_orphaned_models", response_model=DeleteOrphanedModelsResponse, ) -async def delete_orphaned_models(request: DeleteOrphanedModelsRequest) -> DeleteOrphanedModelsResponse: +async def delete_orphaned_models( + request: DeleteOrphanedModelsRequest, _: AdminUserOrDefault +) -> DeleteOrphanedModelsResponse: """Delete specified orphaned model directories. 
Args: diff --git a/invokeai/app/api/routers/recall_parameters.py b/invokeai/app/api/routers/recall_parameters.py index 0af3fd29b0..ec08adba2e 100644 --- a/invokeai/app/api/routers/recall_parameters.py +++ b/invokeai/app/api/routers/recall_parameters.py @@ -7,6 +7,7 @@ from fastapi import Body, HTTPException, Path from fastapi.routing import APIRouter from pydantic import BaseModel, ConfigDict, Field +from invokeai.app.api.auth_dependencies import CurrentUserOrDefault from invokeai.app.api.dependencies import ApiDependencies from invokeai.backend.image_util.controlnet_processor import process_controlnet_image from invokeai.backend.model_manager.taxonomy import ModelType @@ -291,12 +292,58 @@ def resolve_ip_adapter_models(ip_adapters: list[IPAdapterRecallParameter]) -> li return resolved_adapters +def _assert_recall_image_access(parameters: "RecallParameter", current_user: CurrentUserOrDefault) -> None: + """Validate that the caller can read every image referenced in the recall parameters. + + Control layers and IP adapters may reference image_name fields. Without this + check an attacker who knows another user's image UUID could use the recall + endpoint to extract image dimensions and — for ControlNet preprocessors — mint + a derived processed image they can then fetch. + """ + from invokeai.app.services.board_records.board_records_common import BoardVisibility + + image_names: list[str] = [] + if parameters.control_layers: + for layer in parameters.control_layers: + if layer.image_name is not None: + image_names.append(layer.image_name) + if parameters.ip_adapters: + for adapter in parameters.ip_adapters: + if adapter.image_name is not None: + image_names.append(adapter.image_name) + + if not image_names: + return + + # Admin can access all images + if current_user.is_admin: + return + + for image_name in image_names: + owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name) + if owner is not None and owner == current_user.user_id: + continue + + # Check board visibility + board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name) + if board_id is not None: + try: + board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id) + if board.board_visibility in (BoardVisibility.Shared, BoardVisibility.Public): + continue + except Exception: + pass + + raise HTTPException(status_code=403, detail=f"Not authorized to access image {image_name}") + + @recall_parameters_router.post( "/{queue_id}", operation_id="update_recall_parameters", response_model=dict[str, Any], ) async def update_recall_parameters( + current_user: CurrentUserOrDefault, queue_id: str = Path(..., description="The queue id to perform this operation on"), parameters: RecallParameter = Body(..., description="Recall parameters to update"), ) -> dict[str, Any]: @@ -328,6 +375,10 @@ async def update_recall_parameters( """ logger = ApiDependencies.invoker.services.logger + # Validate image access before processing — prevents information leakage + # (dimensions) and derived-image minting via ControlNet preprocessors. 
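The rule encoded in `_assert_recall_image_access` (admin, or owner, or image on a shared/public board) is the same read-access rule applied across the images router. Restated as a standalone predicate, it reads roughly like this sketch, using the service names from the diff above:

```python
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.board_records.board_records_common import BoardVisibility


def can_read_image(image_name: str, user) -> bool:
    """Sketch of the read-access rule: admin, owner, or shared/public board."""
    if user.is_admin:
        return True
    services = ApiDependencies.invoker.services
    if services.image_records.get_user_id(image_name) == user.user_id:
        return True
    board_id = services.board_image_records.get_board_for_image(image_name)
    if board_id is None:
        return False
    try:
        board = services.boards.get_dto(board_id=board_id)
    except Exception:
        return False
    return board.board_visibility in (BoardVisibility.Shared, BoardVisibility.Public)
```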
+ _assert_recall_image_access(parameters, current_user) + try: # Get only the parameters that were actually provided (non-None values) provided_params = {k: v for k, v in parameters.model_dump().items() if v is not None} @@ -335,14 +386,14 @@ async def update_recall_parameters( if not provided_params: return {"status": "no_parameters_provided", "updated_count": 0} - # Store each parameter in client state using a consistent key format + # Store each parameter in client state scoped to the current user updated_count = 0 for param_key, param_value in provided_params.items(): # Convert parameter values to JSON strings for storage value_str = json.dumps(param_value) try: ApiDependencies.invoker.services.client_state_persistence.set_by_key( - queue_id, f"recall_{param_key}", value_str + current_user.user_id, f"recall_{param_key}", value_str ) updated_count += 1 except Exception as e: @@ -396,7 +447,9 @@ async def update_recall_parameters( logger.info( f"Emitting recall_parameters_updated event for queue {queue_id} with {len(provided_params)} parameters" ) - ApiDependencies.invoker.services.events.emit_recall_parameters_updated(queue_id, provided_params) + ApiDependencies.invoker.services.events.emit_recall_parameters_updated( + queue_id, current_user.user_id, provided_params + ) logger.info("Successfully emitted recall_parameters_updated event") except Exception as e: logger.error(f"Error emitting recall parameters event: {e}", exc_info=True) @@ -425,6 +478,7 @@ async def update_recall_parameters( response_model=dict[str, Any], ) async def get_recall_parameters( + current_user: CurrentUserOrDefault, queue_id: str = Path(..., description="The queue id to retrieve parameters for"), ) -> dict[str, Any]: """ diff --git a/invokeai/app/api/routers/session_queue.py b/invokeai/app/api/routers/session_queue.py index 403e7727cb..41a5a411c7 100644 --- a/invokeai/app/api/routers/session_queue.py +++ b/invokeai/app/api/routers/session_queue.py @@ -44,7 +44,8 @@ def sanitize_queue_item_for_user( """Sanitize queue item for non-admin users viewing other users' items. For non-admin users viewing queue items belonging to other users, - the field_values, session graph, and workflow should be hidden/cleared to protect privacy. + only timestamps, status, and error information are exposed. All other + fields (user identity, generation parameters, graphs, workflows) are stripped. 
     Args:
         queue_item: The queue item to sanitize
@@ -58,15 +59,25 @@
     if is_admin or queue_item.user_id == current_user_id:
         return queue_item
 
-    # For non-admins viewing other users' items, clear sensitive fields
-    # Create a shallow copy to avoid mutating the original
+    # For non-admins viewing other users' items, strip everything except
+    # item_id, queue_id, status, and timestamps
     sanitized_item = queue_item.model_copy(deep=False)
+    sanitized_item.user_id = "redacted"
+    sanitized_item.user_display_name = None
+    sanitized_item.user_email = None
+    sanitized_item.batch_id = "redacted"
+    sanitized_item.session_id = "redacted"
+    sanitized_item.origin = None
+    sanitized_item.destination = None
+    sanitized_item.priority = 0
     sanitized_item.field_values = None
+    sanitized_item.retried_from_item_id = None
     sanitized_item.workflow = None
-    # Clear the session graph by replacing it with an empty graph execution state
-    # This prevents information leakage through the generation graph
+    sanitized_item.error_type = None
+    sanitized_item.error_message = None
+    sanitized_item.error_traceback = None
     sanitized_item.session = GraphExecutionState(
-        id=queue_item.session.id,
+        id="redacted",
         graph=Graph(),
     )
     return sanitized_item
@@ -126,12 +137,16 @@ async def list_all_queue_items(
     },
 )
 async def get_queue_item_ids(
+    current_user: CurrentUserOrDefault,
     queue_id: str = Path(description="The queue id to perform this operation on"),
     order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
 ) -> ItemIdsResult:
-    """Gets all queue item ids that match the given parameters"""
+    """Gets all queue item ids that match the given parameters. Non-admin users only see their own items."""
     try:
-        return ApiDependencies.invoker.services.session_queue.get_queue_item_ids(queue_id=queue_id, order_dir=order_dir)
+        user_id = None if current_user.is_admin else current_user.user_id
+        return ApiDependencies.invoker.services.session_queue.get_queue_item_ids(
+            queue_id=queue_id, order_dir=order_dir, user_id=user_id
+        )
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue item ids: {e}")
 
@@ -376,11 +391,15 @@ async def prune(
     },
 )
 async def get_current_queue_item(
+    current_user: CurrentUserOrDefault,
     queue_id: str = Path(description="The queue id to perform this operation on"),
 ) -> Optional[SessionQueueItem]:
     """Gets the currently executing queue item"""
     try:
-        return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
+        item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
+        if item is not None:
+            item = sanitize_queue_item_for_user(item, current_user.user_id, current_user.is_admin)
+        return item
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Unexpected error while getting current queue item: {e}")
 
@@ -393,11 +412,15 @@ async def get_current_queue_item(
     },
 )
 async def get_next_queue_item(
+    current_user: CurrentUserOrDefault,
    queue_id: str = Path(description="The queue id to perform this operation on"),
 ) -> Optional[SessionQueueItem]:
     """Gets the next queue item, without executing it"""
     try:
-        return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
+        item = ApiDependencies.invoker.services.session_queue.get_next(queue_id)
+        if item is not None:
+            item = sanitize_queue_item_for_user(item, current_user.user_id, current_user.is_admin)
+        return item
     except Exception as e:
         raise HTTPException(status_code=500,
detail=f"Unexpected error while getting next queue item: {e}") @@ -413,9 +436,10 @@ async def get_queue_status( current_user: CurrentUserOrDefault, queue_id: str = Path(description="The queue id to perform this operation on"), ) -> SessionQueueAndProcessorStatus: - """Gets the status of the session queue""" + """Gets the status of the session queue. Non-admin users see only their own counts and cannot see current item details unless they own it.""" try: - queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id, user_id=current_user.user_id) + user_id = None if current_user.is_admin else current_user.user_id + queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id, user_id=user_id) processor = ApiDependencies.invoker.services.session_processor.get_status() return SessionQueueAndProcessorStatus(queue=queue, processor=processor) except Exception as e: @@ -430,12 +454,16 @@ async def get_queue_status( }, ) async def get_batch_status( + current_user: CurrentUserOrDefault, queue_id: str = Path(description="The queue id to perform this operation on"), batch_id: str = Path(description="The batch to get the status of"), ) -> BatchStatus: - """Gets the status of the session queue""" + """Gets the status of a batch. Non-admin users only see their own batches.""" try: - return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id) + user_id = None if current_user.is_admin else current_user.user_id + return ApiDependencies.invoker.services.session_queue.get_batch_status( + queue_id=queue_id, batch_id=batch_id, user_id=user_id + ) except Exception as e: raise HTTPException(status_code=500, detail=f"Unexpected error while getting batch status: {e}") @@ -529,13 +557,15 @@ async def cancel_queue_item( responses={200: {"model": SessionQueueCountsByDestination}}, ) async def counts_by_destination( + current_user: CurrentUserOrDefault, queue_id: str = Path(description="The queue id to query"), destination: str = Query(description="The destination to query"), ) -> SessionQueueCountsByDestination: - """Gets the counts of queue items by destination""" + """Gets the counts of queue items by destination. 
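The sanitizer's contract is easy to pin down in a test. A pytest-style sketch, with field names taken from `sanitize_queue_item_for_user` above and `queue_item` standing in for a fixture that provides a populated `SessionQueueItem`:

```python
from invokeai.app.api.routers.session_queue import sanitize_queue_item_for_user


def test_sanitize_strips_other_users_items(queue_item) -> None:
    # Non-owner, non-admin: everything identifying is redacted or cleared.
    sanitized = sanitize_queue_item_for_user(queue_item, "someone-else", False)
    assert sanitized.user_id == "redacted"
    assert sanitized.batch_id == "redacted" and sanitized.session_id == "redacted"
    assert sanitized.field_values is None and sanitized.workflow is None
    assert sanitized.error_message is None and sanitized.error_traceback is None
    assert sanitized.session.id == "redacted"  # backed by an empty Graph()

    # Owner and admin both get the item back untouched.
    assert sanitize_queue_item_for_user(queue_item, queue_item.user_id, False) is queue_item
    assert sanitize_queue_item_for_user(queue_item, "someone-else", True) is queue_item
```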
Non-admin users only see their own items.""" try: + user_id = None if current_user.is_admin else current_user.user_id return ApiDependencies.invoker.services.session_queue.get_counts_by_destination( - queue_id=queue_id, destination=destination + queue_id=queue_id, destination=destination, user_id=user_id ) except Exception as e: raise HTTPException(status_code=500, detail=f"Unexpected error while fetching counts by destination: {e}") diff --git a/invokeai/app/api/routers/workflows.py b/invokeai/app/api/routers/workflows.py index 72d50a416b..eb89325195 100644 --- a/invokeai/app/api/routers/workflows.py +++ b/invokeai/app/api/routers/workflows.py @@ -6,6 +6,7 @@ from fastapi import APIRouter, Body, File, HTTPException, Path, Query, UploadFil from fastapi.responses import FileResponse from PIL import Image +from invokeai.app.api.auth_dependencies import CurrentUserOrDefault from invokeai.app.api.dependencies import ApiDependencies from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection @@ -33,16 +34,25 @@ workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"]) }, ) async def get_workflow( + current_user: CurrentUserOrDefault, workflow_id: str = Path(description="The workflow to get"), ) -> WorkflowRecordWithThumbnailDTO: """Gets a workflow""" try: - thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id) workflow = ApiDependencies.invoker.services.workflow_records.get(workflow_id) - return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump()) except WorkflowNotFoundError: raise HTTPException(status_code=404, detail="Workflow not found") + config = ApiDependencies.invoker.services.configuration + if config.multiuser: + is_default = workflow.workflow.meta.category is WorkflowCategory.Default + is_owner = workflow.user_id == current_user.user_id + if not (is_default or is_owner or workflow.is_public or current_user.is_admin): + raise HTTPException(status_code=403, detail="Not authorized to access this workflow") + + thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id) + return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump()) + @workflows_router.patch( "/i/{workflow_id}", @@ -52,10 +62,21 @@ async def get_workflow( }, ) async def update_workflow( + current_user: CurrentUserOrDefault, workflow: Workflow = Body(description="The updated workflow", embed=True), ) -> WorkflowRecordDTO: """Updates a workflow""" - return ApiDependencies.invoker.services.workflow_records.update(workflow=workflow) + config = ApiDependencies.invoker.services.configuration + if config.multiuser: + try: + existing = ApiDependencies.invoker.services.workflow_records.get(workflow.id) + except WorkflowNotFoundError: + raise HTTPException(status_code=404, detail="Workflow not found") + if not current_user.is_admin and existing.user_id != current_user.user_id: + raise HTTPException(status_code=403, detail="Not authorized to update this workflow") + # Pass user_id for defense-in-depth SQL scoping; admins pass None to allow any. 
+ user_id = None if current_user.is_admin else current_user.user_id + return ApiDependencies.invoker.services.workflow_records.update(workflow=workflow, user_id=user_id) @workflows_router.delete( @@ -63,15 +84,25 @@ async def update_workflow( operation_id="delete_workflow", ) async def delete_workflow( + current_user: CurrentUserOrDefault, workflow_id: str = Path(description="The workflow to delete"), ) -> None: """Deletes a workflow""" + config = ApiDependencies.invoker.services.configuration + if config.multiuser: + try: + existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id) + except WorkflowNotFoundError: + raise HTTPException(status_code=404, detail="Workflow not found") + if not current_user.is_admin and existing.user_id != current_user.user_id: + raise HTTPException(status_code=403, detail="Not authorized to delete this workflow") try: ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id) except WorkflowThumbnailFileNotFoundException: # It's OK if the workflow has no thumbnail file. We can still delete the workflow. pass - ApiDependencies.invoker.services.workflow_records.delete(workflow_id) + user_id = None if current_user.is_admin else current_user.user_id + ApiDependencies.invoker.services.workflow_records.delete(workflow_id, user_id=user_id) @workflows_router.post( @@ -82,10 +113,17 @@ async def delete_workflow( }, ) async def create_workflow( + current_user: CurrentUserOrDefault, workflow: WorkflowWithoutID = Body(description="The workflow to create", embed=True), ) -> WorkflowRecordDTO: """Creates a workflow""" - return ApiDependencies.invoker.services.workflow_records.create(workflow=workflow) + # In single-user mode, workflows are owned by 'system' and shared by default so all legacy/single-user + # workflows remain visible. In multiuser mode, workflows are private to the creator by default. + config = ApiDependencies.invoker.services.configuration + is_public = not config.multiuser + return ApiDependencies.invoker.services.workflow_records.create( + workflow=workflow, user_id=current_user.user_id, is_public=is_public + ) @workflows_router.get( @@ -96,6 +134,7 @@ async def create_workflow( }, ) async def list_workflows( + current_user: CurrentUserOrDefault, page: int = Query(default=0, description="The page to get"), per_page: Optional[int] = Query(default=None, description="The number of workflows per page"), order_by: WorkflowRecordOrderBy = Query( @@ -106,8 +145,19 @@ async def list_workflows( tags: Optional[list[str]] = Query(default=None, description="The tags of workflow to get"), query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"), has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"), + is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"), ) -> PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]: """Gets a page of workflows""" + config = ApiDependencies.invoker.services.configuration + + # In multiuser mode, scope user-category workflows to the current user unless fetching shared workflows. + # Admins skip the user_id filter so they can see and manage all workflows including system-owned ones. 
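This scoping decision recurs with small variations in `list_workflows`, `get_all_tags`, `get_counts_by_tag`, and `counts_by_category`; a shared helper would keep the four copies from drifting. A sketch using the names from this diff, with a hypothetical helper name:

```python
from typing import Optional

from invokeai.app.api.dependencies import ApiDependencies


def _workflow_user_scope(current_user, categories, is_public: Optional[bool]) -> Optional[str]:
    """Return the user_id to filter by, or None to disable scoping (admin or single-user)."""
    config = ApiDependencies.invoker.services.configuration
    if not config.multiuser or current_user.is_admin:
        return None
    # counts_by_category receives a required categories list, so the
    # `not categories` fallback only matters for the optional-query endpoints.
    has_user_category = not categories or WorkflowCategory.User in categories
    if has_user_category and is_public is not True:
        return current_user.user_id
    return None
```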
+    user_id_filter: Optional[str] = None
+    if config.multiuser and not current_user.is_admin:
+        has_user_category = not categories or WorkflowCategory.User in categories
+        if has_user_category and is_public is not True:
+            user_id_filter = current_user.user_id
+
     workflows_with_thumbnails: list[WorkflowRecordListItemWithThumbnailDTO] = []
     workflows = ApiDependencies.invoker.services.workflow_records.get_many(
         order_by=order_by,
@@ -118,6 +168,8 @@ async def list_workflows(
         categories=categories,
         tags=tags,
         has_been_opened=has_been_opened,
+        user_id=user_id_filter,
+        is_public=is_public,
     )
     for workflow in workflows.items:
         workflows_with_thumbnails.append(
@@ -143,15 +195,20 @@
     },
 )
 async def set_workflow_thumbnail(
+    current_user: CurrentUserOrDefault,
     workflow_id: str = Path(description="The workflow to update"),
     image: UploadFile = File(description="The image file to upload"),
 ):
     """Sets a workflow's thumbnail image"""
     try:
-        ApiDependencies.invoker.services.workflow_records.get(workflow_id)
+        existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
     except WorkflowNotFoundError:
         raise HTTPException(status_code=404, detail="Workflow not found")
 
+    config = ApiDependencies.invoker.services.configuration
+    if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id:
+        raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
+
     if not image.content_type or not image.content_type.startswith("image"):
         raise HTTPException(status_code=415, detail="Not an image")
 
@@ -177,14 +234,19 @@
     },
 )
 async def delete_workflow_thumbnail(
+    current_user: CurrentUserOrDefault,
     workflow_id: str = Path(description="The workflow to update"),
 ):
     """Removes a workflow's thumbnail image"""
     try:
-        ApiDependencies.invoker.services.workflow_records.get(workflow_id)
+        existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
     except WorkflowNotFoundError:
         raise HTTPException(status_code=404, detail="Workflow not found")
 
+    config = ApiDependencies.invoker.services.configuration
+    if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id:
+        raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
+
     try:
         ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id)
     except ValueError as e:
@@ -206,8 +268,12 @@ async def delete_workflow_thumbnail(
 async def get_workflow_thumbnail(
     workflow_id: str = Path(description="The id of the workflow thumbnail to get"),
 ) -> FileResponse:
-    """Gets a workflow's thumbnail image"""
+    """Gets a workflow's thumbnail image.
+    This endpoint is intentionally unauthenticated because browsers load images
+    via <img> tags, which cannot send Bearer tokens. Workflow IDs are UUIDs,
+    providing security through unguessability.
+ """ try: path = ApiDependencies.invoker.services.workflow_thumbnails.get_path(workflow_id) @@ -223,37 +289,91 @@ async def get_workflow_thumbnail( raise HTTPException(status_code=404) +@workflows_router.patch( + "/i/{workflow_id}/is_public", + operation_id="update_workflow_is_public", + responses={ + 200: {"model": WorkflowRecordDTO}, + }, +) +async def update_workflow_is_public( + current_user: CurrentUserOrDefault, + workflow_id: str = Path(description="The workflow to update"), + is_public: bool = Body(description="Whether the workflow should be shared publicly", embed=True), +) -> WorkflowRecordDTO: + """Updates whether a workflow is shared publicly""" + try: + existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id) + except WorkflowNotFoundError: + raise HTTPException(status_code=404, detail="Workflow not found") + + config = ApiDependencies.invoker.services.configuration + if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id: + raise HTTPException(status_code=403, detail="Not authorized to update this workflow") + + user_id = None if current_user.is_admin else current_user.user_id + return ApiDependencies.invoker.services.workflow_records.update_is_public( + workflow_id=workflow_id, is_public=is_public, user_id=user_id + ) + + @workflows_router.get("/tags", operation_id="get_all_tags") async def get_all_tags( + current_user: CurrentUserOrDefault, categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"), + is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"), ) -> list[str]: """Gets all unique tags from workflows""" + config = ApiDependencies.invoker.services.configuration + user_id_filter: Optional[str] = None + if config.multiuser and not current_user.is_admin: + has_user_category = not categories or WorkflowCategory.User in categories + if has_user_category and is_public is not True: + user_id_filter = current_user.user_id - return ApiDependencies.invoker.services.workflow_records.get_all_tags(categories=categories) + return ApiDependencies.invoker.services.workflow_records.get_all_tags( + categories=categories, user_id=user_id_filter, is_public=is_public + ) @workflows_router.get("/counts_by_tag", operation_id="get_counts_by_tag") async def get_counts_by_tag( + current_user: CurrentUserOrDefault, tags: list[str] = Query(description="The tags to get counts for"), categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"), has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"), + is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"), ) -> dict[str, int]: """Counts workflows by tag""" + config = ApiDependencies.invoker.services.configuration + user_id_filter: Optional[str] = None + if config.multiuser and not current_user.is_admin: + has_user_category = not categories or WorkflowCategory.User in categories + if has_user_category and is_public is not True: + user_id_filter = current_user.user_id return ApiDependencies.invoker.services.workflow_records.counts_by_tag( - tags=tags, categories=categories, has_been_opened=has_been_opened + tags=tags, categories=categories, has_been_opened=has_been_opened, user_id=user_id_filter, is_public=is_public ) @workflows_router.get("/counts_by_category", operation_id="counts_by_category") async def counts_by_category( + current_user: 
CurrentUserOrDefault, categories: list[WorkflowCategory] = Query(description="The categories to include"), has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"), + is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"), ) -> dict[str, int]: """Counts workflows by category""" + config = ApiDependencies.invoker.services.configuration + user_id_filter: Optional[str] = None + if config.multiuser and not current_user.is_admin: + has_user_category = WorkflowCategory.User in categories + if has_user_category and is_public is not True: + user_id_filter = current_user.user_id return ApiDependencies.invoker.services.workflow_records.counts_by_category( - categories=categories, has_been_opened=has_been_opened + categories=categories, has_been_opened=has_been_opened, user_id=user_id_filter, is_public=is_public ) @@ -262,7 +382,18 @@ async def counts_by_category( operation_id="update_opened_at", ) async def update_opened_at( + current_user: CurrentUserOrDefault, workflow_id: str = Path(description="The workflow to update"), ) -> None: """Updates the opened_at field of a workflow""" - ApiDependencies.invoker.services.workflow_records.update_opened_at(workflow_id) + try: + existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id) + except WorkflowNotFoundError: + raise HTTPException(status_code=404, detail="Workflow not found") + + config = ApiDependencies.invoker.services.configuration + if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id: + raise HTTPException(status_code=403, detail="Not authorized to update this workflow") + + user_id = None if current_user.is_admin else current_user.user_id + ApiDependencies.invoker.services.workflow_records.update_opened_at(workflow_id, user_id=user_id) diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py index fcead54eb1..5783b804c0 100644 --- a/invokeai/app/api/sockets.py +++ b/invokeai/app/api/sockets.py @@ -121,6 +121,11 @@ class SocketIO: Returns True to accept the connection, False to reject it. Stores user_id in the internal socket users dict for later use. + + In multiuser mode, connections without a valid token are rejected outright + so that anonymous clients cannot subscribe to queue rooms and observe + queue activity belonging to other users. In single-user mode, unauthenticated + connections are accepted as the system admin user. """ # Extract token from auth data or headers token = None @@ -137,6 +142,23 @@ class SocketIO: if token: token_data = verify_token(token) if token_data: + # In multiuser mode, also verify the backing user record still + # exists and is active — mirrors the REST auth check in + # auth_dependencies.py. A deleted or deactivated user whose + # JWT has not yet expired must not be allowed to open a socket. 
+ if self._is_multiuser_enabled(): + try: + from invokeai.app.api.dependencies import ApiDependencies + + user = ApiDependencies.invoker.services.users.get(token_data.user_id) + if user is None or not user.is_active: + logger.warning(f"Rejecting socket {sid}: user {token_data.user_id} not found or inactive") + return False + except Exception: + # If user service is unavailable, fail closed + logger.warning(f"Rejecting socket {sid}: unable to verify user record") + return False + # Store user_id and is_admin in socket users dict self._socket_users[sid] = { "user_id": token_data.user_id, @@ -147,14 +169,37 @@ class SocketIO: ) return True - # If no valid token, store system user for backward compatibility + # No valid token provided. In multiuser mode this is not allowed — reject + # the connection so anonymous clients cannot subscribe to queue rooms. + # In single-user mode, fall through and accept the socket as system admin. + if self._is_multiuser_enabled(): + logger.warning( + f"Rejecting socket {sid} connection: multiuser mode is enabled and no valid auth token was provided" + ) + return False + self._socket_users[sid] = { "user_id": "system", - "is_admin": False, + "is_admin": True, } - logger.debug(f"Socket {sid} connected as system user (no valid token)") + logger.debug(f"Socket {sid} connected as system admin (single-user mode)") return True + @staticmethod + def _is_multiuser_enabled() -> bool: + """Check whether multiuser mode is enabled. Fails closed if configuration + is not yet initialized, which should not happen in practice but prevents + accidentally opening the socket during startup races.""" + try: + # Imported here to avoid a circular import at module load time. + from invokeai.app.api.dependencies import ApiDependencies + + return bool(ApiDependencies.invoker.services.configuration.multiuser) + except Exception: + # If dependencies are not initialized, fail closed (treat as multiuser) + # so we never accidentally admit an anonymous socket. + return True + async def _handle_disconnect(self, sid: str) -> None: """Handle socket disconnection and cleanup user info.""" if sid in self._socket_users: @@ -165,15 +210,20 @@ class SocketIO: """Handle queue subscription and add socket to both queue and user-specific rooms.""" queue_id = QueueSubscriptionEvent(**data).queue_id - # Check if we have user info for this socket + # Check if we have user info for this socket. In multiuser mode _handle_connect + # will have already rejected any socket without a valid token, so missing user + # info here is a bug — refuse the subscription rather than silently falling back + # to an anonymous system user who could then receive queue item events. if sid not in self._socket_users: - logger.warning( - f"Socket {sid} subscribing to queue {queue_id} but has no user info - need to authenticate via connect event" - ) - # Store as system user temporarily - real auth should happen in connect + if self._is_multiuser_enabled(): + logger.warning( + f"Refusing queue subscription for socket {sid}: no user info (socket not authenticated via connect event)" + ) + return + # Single-user mode: safe to fall back to the system admin user. 
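From the client side, the connect-time auth that `_handle_connect` verifies is supplied via Socket.IO's `auth` payload. A sketch with the python-socketio client; the `socketio_path` mount and the `subscribe_queue` event name are assumptions inferred from the subscription handler above:

```python
import socketio

jwt = "..."  # token from the login endpoint
sio = socketio.Client()
# In multiuser mode the server rejects the connection unless a valid token is
# presented here; in single-user mode the socket is admitted as system admin.
sio.connect(
    "http://localhost:9090",
    auth={"token": jwt},
    socketio_path="/ws/socket.io",  # assumed mount path
)
sio.emit("subscribe_queue", {"queue_id": "default"})  # handled by _handle_sub_queue
```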
self._socket_users[sid] = { "user_id": "system", - "is_admin": False, + "is_admin": True, } user_id = self._socket_users[sid]["user_id"] @@ -198,6 +248,13 @@ class SocketIO: await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id) async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None: + # In multiuser mode, only allow authenticated sockets to subscribe. + # Bulk download events are routed to user-specific rooms, so the + # bulk_download_id room subscription is only kept for single-user + # backward compatibility. + if self._is_multiuser_enabled() and sid not in self._socket_users: + logger.warning(f"Refusing bulk download subscription for unknown socket {sid} in multiuser mode") + return await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id) async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None: @@ -206,9 +263,17 @@ class SocketIO: async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]): """Handle queue events with user isolation. - Invocation events (progress, started, complete) are private - only emit to owner and admins. - Queue item status events are public - emit to all users (field values hidden via API). - Other queue events emit to all subscribers. + All queue item events (invocation events AND QueueItemStatusChangedEvent) are + private to the owning user and admins. They carry unsanitized user_id, batch_id, + session_id, origin, destination and error metadata, and must never be broadcast + to the whole queue room — otherwise any other authenticated subscriber could + observe cross-user queue activity. + + RecallParametersUpdatedEvent is also private to the owner + admins. + + BatchEnqueuedEvent carries the enqueuing user's batch_id/origin/counts and + is also routed privately. QueueClearedEvent is the only queue event that + is still broadcast to the whole queue room. IMPORTANT: Check InvocationEventBase BEFORE QueueItemEventBase since InvocationEventBase inherits from QueueItemEventBase. The order of isinstance checks matters! @@ -237,24 +302,40 @@ class SocketIO: logger.debug(f"Emitted private invocation event {event_name} to user room {user_room} and admin room") - # Queue item status events are visible to all users (field values masked via API) - # This catches QueueItemStatusChangedEvent but NOT InvocationEvents (already handled above) + # Other queue item events (QueueItemStatusChangedEvent) carry unsanitized + # user_id, batch_id, session_id, origin, destination and error metadata. + # They are private to the owning user + admins — never broadcast to the + # full queue room. 
elif isinstance(event_data, QueueItemEventBase) and hasattr(event_data, "user_id"): - # Emit to all subscribers in the queue - await self._sio.emit( - event=event_name, data=event_data.model_dump(mode="json"), room=event_data.queue_id - ) + user_room = f"user:{event_data.user_id}" + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room) + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin") - logger.info( - f"Emitted public queue item event {event_name} to all subscribers in queue {event_data.queue_id}" - ) + logger.debug(f"Emitted private queue item event {event_name} to user room {user_room} and admin room") + + # RecallParametersUpdatedEvent is private - only emit to owner + admins + elif isinstance(event_data, RecallParametersUpdatedEvent): + user_room = f"user:{event_data.user_id}" + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room) + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin") + logger.debug(f"Emitted private recall_parameters_updated event to user room {user_room} and admin room") + + # BatchEnqueuedEvent carries the enqueuing user's batch_id, origin, and + # enqueued counts. Route it privately to the owner + admins so other + # users do not observe cross-user batch activity. + elif isinstance(event_data, BatchEnqueuedEvent): + user_room = f"user:{event_data.user_id}" + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room) + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin") + logger.debug(f"Emitted private batch_enqueued event to user room {user_room} and admin room") else: - # For other queue events (like QueueClearedEvent, BatchEnqueuedEvent), emit to all subscribers + # For remaining queue events (e.g. QueueClearedEvent) that do not + # carry user identity, emit to all subscribers in the queue room. await self._sio.emit( event=event_name, data=event_data.model_dump(mode="json"), room=event_data.queue_id ) - logger.info( + logger.debug( f"Emitted general queue event {event_name} to all subscribers in queue {event_data.queue_id}" ) except Exception as e: @@ -265,4 +346,17 @@ class SocketIO: await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json")) async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None: - await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].bulk_download_id) + event_name, event_data = event + # Route to user-specific + admin rooms so that other authenticated + # users cannot learn the bulk_download_item_name (the capability token + # needed to fetch the zip from the unauthenticated GET endpoint). + # In single-user mode (user_id="system"), fall back to the shared + # bulk_download_id room for backward compatibility. 
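The owner-room/admin-room emit pair now appears once per private event type in `_handle_queue_event`; it could be collapsed into a small method. A sketch with a hypothetical helper name, assuming `event_data` always carries `user_id` at these call sites:

```python
async def _emit_private(self, event_name: str, event_data) -> None:
    """Emit an event only to the owning user's room and the admin room."""
    payload = event_data.model_dump(mode="json")
    await self._sio.emit(event=event_name, data=payload, room=f"user:{event_data.user_id}")
    await self._sio.emit(event=event_name, data=payload, room="admin")
```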
+        if hasattr(event_data, "user_id") and event_data.user_id != "system":
+            user_room = f"user:{event_data.user_id}"
+            await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room)
+            await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin")
+        else:
+            await self._sio.emit(
+                event=event_name, data=event_data.model_dump(mode="json"), room=event_data.bulk_download_id
+            )
diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py
index 49894dba3c..2ca6746b49 100644
--- a/invokeai/app/api_app.py
+++ b/invokeai/app/api_app.py
@@ -79,6 +79,50 @@ app = FastAPI(
 )
 
 
+class SlidingWindowTokenMiddleware(BaseHTTPMiddleware):
+    """Refresh the JWT token on each authenticated response.
+
+    When a request includes a valid Bearer token, the response includes an
+    X-Refreshed-Token header with a new token that has a fresh expiry.
+    This implements sliding-window session expiry: the session only expires
+    after a period of *inactivity*, not a fixed time after login.
+    """
+
+    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
+        response = await call_next(request)
+
+        # Only refresh on mutating requests (POST/PUT/PATCH/DELETE) — these indicate
+        # genuine user activity. GET requests are often background fetches (RTK Query
+        # cache revalidation, refetch-on-focus, etc.) and should not reset the
+        # inactivity timer.
+        if response.status_code < 400 and request.method in ("POST", "PUT", "PATCH", "DELETE"):
+            auth_header = request.headers.get("authorization", "")
+            if auth_header.startswith("Bearer "):
+                token = auth_header[7:]
+                try:
+                    from datetime import timedelta
+
+                    from invokeai.app.api.routers.auth import TOKEN_EXPIRATION_NORMAL, TOKEN_EXPIRATION_REMEMBER_ME
+                    from invokeai.app.services.auth.token_service import create_access_token, verify_token
+
+                    token_data = verify_token(token)
+                    if token_data is not None:
+                        # Use the remember_me claim from the token to determine the
+                        # correct refresh duration. This avoids the bug where a 7-day
+                        # token with <24h remaining would be silently downgraded to 1 day.
+                        if token_data.remember_me:
+                            expires_delta = timedelta(days=TOKEN_EXPIRATION_REMEMBER_ME)
+                        else:
+                            expires_delta = timedelta(days=TOKEN_EXPIRATION_NORMAL)
+
+                        new_token = create_access_token(token_data, expires_delta)
+                        response.headers["X-Refreshed-Token"] = new_token
+                except Exception:
+                    pass  # Don't fail the request if token refresh fails
+
+        return response
+
+
 class RedirectRootWithQueryStringMiddleware(BaseHTTPMiddleware):
     """When a request is made to the root path with a query string, redirect to the root path without the query string.
@@ -99,6 +143,7 @@ class RedirectRootWithQueryStringMiddleware(BaseHTTPMiddleware):
 
 # Add the middleware
 app.add_middleware(RedirectRootWithQueryStringMiddleware)
+app.add_middleware(SlidingWindowTokenMiddleware)
 
 # Add event handler
@@ -117,6 +162,7 @@ app.add_middleware(
     allow_credentials=app_config.allow_credentials,
     allow_methods=app_config.allow_methods,
     allow_headers=app_config.allow_headers,
+    expose_headers=["X-Refreshed-Token"],
 )
 
 app.add_middleware(GZipMiddleware, minimum_size=1000)
diff --git a/invokeai/app/invocations/anima_denoise.py b/invokeai/app/invocations/anima_denoise.py
new file mode 100644
index 0000000000..530bf2918f
--- /dev/null
+++ b/invokeai/app/invocations/anima_denoise.py
@@ -0,0 +1,715 @@
+"""Anima denoising invocation.
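A client consuming the sliding-window refresh only has to read the header off each mutating response and swap its stored token. A minimal sketch with `requests`, using the header name defined by the middleware above:

```python
import requests


class TokenSession:
    """Keeps the bearer token fresh from X-Refreshed-Token response headers."""

    def __init__(self, base_url: str, token: str) -> None:
        self.base_url = base_url
        self.token = token
        self._session = requests.Session()

    def post(self, path: str, **kwargs) -> requests.Response:
        resp = self._session.post(
            self.base_url + path,
            headers={"Authorization": f"Bearer {self.token}"},
            **kwargs,
        )
        # The middleware only refreshes successful mutating requests.
        self.token = resp.headers.get("X-Refreshed-Token", self.token)
        return resp
```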
+ +Implements the rectified flow denoising loop for Anima models: +- Direct prediction: denoised = input - output * sigma +- Fixed shift=3.0 via loglinear_timestep_shift (Flux paper by Black Forest Labs) +- Timestep convention: timestep = sigma * 1.0 (raw sigma, NOT 1-sigma like Z-Image) +- NO v-prediction negation (unlike Z-Image) +- 3D latent space: [B, C, T, H, W] with T=1 for images +- 16 latent channels, 8x spatial compression + +Key differences from Z-Image denoise: +- Anima uses fixed shift=3.0, Z-Image uses dynamic shift based on resolution +- Anima: timestep = sigma (raw), Z-Image: model_t = 1.0 - sigma +- Anima: noise_pred = model_output (direct), Z-Image: noise_pred = -model_output (v-pred) +- Anima transformer takes (x, timesteps, context, t5xxl_ids, t5xxl_weights) +- Anima uses 3D latents directly, Z-Image converts 4D -> list of 5D +""" + +import inspect +import math +from contextlib import ExitStack +from typing import Callable, Iterator, Optional, Tuple + +import torch +import torchvision.transforms as tv_transforms +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from torchvision.transforms.functional import resize as tv_resize +from tqdm import tqdm + +from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation +from invokeai.app.invocations.fields import ( + AnimaConditioningField, + DenoiseMaskField, + FieldDescriptions, + Input, + InputField, + LatentsField, +) +from invokeai.app.invocations.model import TransformerField +from invokeai.app.invocations.primitives import LatentsOutput +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.anima.anima_transformer_patch import patch_anima_for_regional_prompting +from invokeai.backend.anima.conditioning_data import AnimaRegionalTextConditioning, AnimaTextConditioning +from invokeai.backend.anima.regional_prompting import AnimaRegionalPromptingExtension +from invokeai.backend.flux.schedulers import ANIMA_SCHEDULER_LABELS, ANIMA_SCHEDULER_MAP, ANIMA_SCHEDULER_NAME_VALUES +from invokeai.backend.model_manager.taxonomy import BaseModelType +from invokeai.backend.patches.layer_patcher import LayerPatcher +from invokeai.backend.patches.lora_conversions.anima_lora_constants import ANIMA_LORA_TRANSFORMER_PREFIX +from invokeai.backend.patches.model_patch_raw import ModelPatchRaw +from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import ( + RectifiedFlowInpaintExtension, + assert_broadcastable, +) +from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import AnimaConditioningInfo, Range +from invokeai.backend.util.devices import TorchDevice + +# Anima uses 8x spatial compression (VAE downsamples by 2^3) +ANIMA_LATENT_SCALE_FACTOR = 8 +# Anima uses 16 latent channels +ANIMA_LATENT_CHANNELS = 16 +# Anima uses fixed shift=3.0 for the rectified flow schedule +ANIMA_SHIFT = 3.0 +# Anima uses raw sigma values as timesteps (no rescaling) +ANIMA_MULTIPLIER = 1.0 + + +def loglinear_timestep_shift(alpha: float, t: float) -> float: + """Apply log-linear timestep shift to a noise schedule value. + + This shift biases the noise schedule toward higher noise levels, as described + in the Flux model (Black Forest Labs, 2024). With alpha > 1, the model spends + proportionally more denoising steps at higher noise levels. 
+ + Formula: sigma = alpha * t / (1 + (alpha - 1) * t) + + Args: + alpha: Shift factor (3.0 for Anima, resolution-dependent for Flux). + t: Timestep value in [0, 1]. + + Returns: + Shifted timestep value. + """ + if alpha == 1.0: + return t + return alpha * t / (1 + (alpha - 1) * t) + + +def inverse_loglinear_timestep_shift(alpha: float, sigma: float) -> float: + """Recover linear t from a shifted sigma value. + + Inverse of loglinear_timestep_shift: given sigma = alpha * t / (1 + (alpha-1) * t), + solve for t = sigma / (alpha - (alpha-1) * sigma). + + This is needed for the inpainting extension, which expects linear t values + for gradient mask thresholding. With Anima's shift=3.0, the difference + between shifted sigma and linear t is large (e.g. at t=0.5, sigma=0.75), + causing overly aggressive mask thresholding if sigma is used directly. + + Args: + alpha: Shift factor (3.0 for Anima). + sigma: Shifted sigma value in [0, 1]. + + Returns: + Linear t value in [0, 1]. + """ + if alpha == 1.0: + return sigma + denominator = alpha - (alpha - 1) * sigma + if abs(denominator) < 1e-8: + return 1.0 + return sigma / denominator + + +class AnimaInpaintExtension(RectifiedFlowInpaintExtension): + """Inpaint extension for Anima that accounts for the time-SNR shift. + + Anima uses a fixed shift=3.0 which makes sigma values significantly larger + than the corresponding linear t values. The base RectifiedFlowInpaintExtension + uses t_prev for both gradient mask thresholding and noise mixing, which assumes + linear t values. + + This subclass: + - Uses the LINEAR t for gradient mask thresholding (correct progressive reveal) + - Uses the SHIFTED sigma for noise mixing (matches the denoiser's noise level) + """ + + def __init__( + self, + init_latents: torch.Tensor, + inpaint_mask: torch.Tensor, + noise: torch.Tensor, + shift: float = ANIMA_SHIFT, + ): + assert_broadcastable(init_latents.shape, inpaint_mask.shape, noise.shape) + self._init_latents = init_latents + self._inpaint_mask = inpaint_mask + self._noise = noise + self._shift = shift + + def merge_intermediate_latents_with_init_latents( + self, intermediate_latents: torch.Tensor, sigma_prev: float + ) -> torch.Tensor: + """Merge intermediate latents with init latents, correcting for Anima's shift. + + Args: + intermediate_latents: The denoised latents at the current step. + sigma_prev: The SHIFTED sigma value for the next step. + """ + # Recover linear t from shifted sigma for gradient mask thresholding. + # This ensures the gradient mask is revealed at the correct pace. + t_prev = inverse_loglinear_timestep_shift(self._shift, sigma_prev) + mask = self._apply_mask_gradient_adjustment(t_prev) + + # Use shifted sigma for noise mixing to match the denoiser's noise level. + # The Euler step produces latents at noise level sigma_prev, so the + # preserved regions must also be at sigma_prev noise level. + noised_init_latents = self._noise * sigma_prev + (1.0 - sigma_prev) * self._init_latents + + return intermediate_latents * mask + noised_init_latents * (1.0 - mask) + + +@invocation( + "anima_denoise", + title="Denoise - Anima", + tags=["image", "anima"], + category="image", + version="1.2.0", + classification=Classification.Prototype, +) +class AnimaDenoiseInvocation(BaseInvocation): + """Run the denoising process with an Anima model. + + Uses rectified flow sampling with shift=3.0 and the Cosmos Predict2 DiT + backbone with integrated LLM Adapter for text conditioning. 
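The worked number in the inverse-shift docstring doubles as a sanity check: at t=0.5 with alpha=3.0 the shift gives sigma = 1.5/2 = 0.75, and the inverse recovers 0.5. A quick check sketch against the two functions defined above:

```python
from invokeai.app.invocations.anima_denoise import (
    inverse_loglinear_timestep_shift,
    loglinear_timestep_shift,
)

assert loglinear_timestep_shift(3.0, 0.5) == 0.75
assert abs(inverse_loglinear_timestep_shift(3.0, 0.75) - 0.5) < 1e-12
# The endpoints survive the shift unchanged.
assert loglinear_timestep_shift(3.0, 1.0) == 1.0
assert loglinear_timestep_shift(3.0, 0.0) == 0.0
```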
+ + Supports txt2img, img2img (via latents input), and inpainting (via denoise_mask). + """ + + # If latents is provided, this means we are doing image-to-image. + latents: Optional[LatentsField] = InputField( + default=None, description=FieldDescriptions.latents, input=Input.Connection + ) + # denoise_mask is used for inpainting. Only the masked region is modified. + denoise_mask: Optional[DenoiseMaskField] = InputField( + default=None, description=FieldDescriptions.denoise_mask, input=Input.Connection + ) + denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start) + denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end) + add_noise: bool = InputField(default=True, description="Add noise based on denoising start.") + transformer: TransformerField = InputField( + description="Anima transformer model.", input=Input.Connection, title="Transformer" + ) + positive_conditioning: AnimaConditioningField | list[AnimaConditioningField] = InputField( + description=FieldDescriptions.positive_cond, input=Input.Connection + ) + negative_conditioning: AnimaConditioningField | list[AnimaConditioningField] | None = InputField( + default=None, description=FieldDescriptions.negative_cond, input=Input.Connection + ) + guidance_scale: float = InputField( + default=4.5, + ge=1.0, + description="Guidance scale for classifier-free guidance. Recommended: 4.0-5.0 for Anima.", + title="Guidance Scale", + ) + width: int = InputField(default=1024, multiple_of=8, description="Width of the generated image.") + height: int = InputField(default=1024, multiple_of=8, description="Height of the generated image.") + steps: int = InputField(default=30, gt=0, description="Number of denoising steps. 30 recommended for Anima.") + seed: int = InputField(default=0, description="Randomness seed for reproducibility.") + scheduler: ANIMA_SCHEDULER_NAME_VALUES = InputField( + default="euler", + description="Scheduler (sampler) for the denoising process.", + ui_choice_labels=ANIMA_SCHEDULER_LABELS, + ) + + @torch.no_grad() + def invoke(self, context: InvocationContext) -> LatentsOutput: + latents = self._run_diffusion(context) + latents = latents.detach().to("cpu") + name = context.tensors.save(tensor=latents) + return LatentsOutput.build(latents_name=name, latents=latents, seed=None) + + def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None: + """Prepare the inpaint mask for Anima. + + Anima uses 3D latents [B, C, T, H, W] internally but the mask operates + on the spatial dimensions [B, C, H, W] which match the squeezed output. 
+ """ + if self.denoise_mask is None: + return None + mask = context.tensors.load(self.denoise_mask.mask_name) + + # Invert mask: 0.0 = regions to denoise, 1.0 = regions to preserve + mask = 1.0 - mask + + _, _, latent_height, latent_width = latents.shape + mask = tv_resize( + img=mask, + size=[latent_height, latent_width], + interpolation=tv_transforms.InterpolationMode.BILINEAR, + antialias=False, + ) + + mask = mask.to(device=latents.device, dtype=latents.dtype) + return mask + + def _get_noise( + self, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + seed: int, + ) -> torch.Tensor: + """Generate initial noise tensor in 3D latent space [B, C, T, H, W].""" + rand_device = "cpu" + return torch.randn( + 1, + ANIMA_LATENT_CHANNELS, + 1, # T=1 for single image + height // ANIMA_LATENT_SCALE_FACTOR, + width // ANIMA_LATENT_SCALE_FACTOR, + device=rand_device, + dtype=torch.float32, + generator=torch.Generator(device=rand_device).manual_seed(seed), + ).to(device=device, dtype=dtype) + + def _get_sigmas(self, num_steps: int) -> list[float]: + """Generate sigma schedule with fixed shift=3.0. + + Uses the log-linear timestep shift from the Flux model (Black Forest Labs) + with a fixed shift factor of 3.0 (no dynamic resolution-based shift). + + Returns: + List of num_steps + 1 sigma values from ~1.0 (noise) to 0.0 (clean). + """ + sigmas = [] + for i in range(num_steps + 1): + t = 1.0 - i / num_steps + sigma = loglinear_timestep_shift(ANIMA_SHIFT, t) + sigmas.append(sigma) + return sigmas + + def _load_conditioning( + self, + context: InvocationContext, + cond_field: AnimaConditioningField, + dtype: torch.dtype, + device: torch.device, + ) -> AnimaConditioningInfo: + """Load Anima conditioning data from storage.""" + cond_data = context.conditioning.load(cond_field.conditioning_name) + assert len(cond_data.conditionings) == 1 + cond_info = cond_data.conditionings[0] + assert isinstance(cond_info, AnimaConditioningInfo) + return cond_info.to(dtype=dtype, device=device) + + def _load_text_conditionings( + self, + context: InvocationContext, + cond_field: AnimaConditioningField | list[AnimaConditioningField], + img_token_height: int, + img_token_width: int, + dtype: torch.dtype, + device: torch.device, + ) -> list[AnimaTextConditioning]: + """Load Anima text conditioning with optional regional masks. + + Args: + context: The invocation context. + cond_field: Single conditioning field or list of fields. + img_token_height: Height of the image token grid (H // patch_size). + img_token_width: Width of the image token grid (W // patch_size). + dtype: Target dtype. + device: Target device. + + Returns: + List of AnimaTextConditioning objects with optional masks. 
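To make the `_get_sigmas` schedule concrete: with `steps=4`, the linear grid t = 1.0, 0.75, 0.5, 0.25, 0.0 maps under shift=3.0 to sigma = 1.0, 0.9, 0.75, 0.5, 0.0, so step sizes are front-loaded toward high noise. A quick check sketch:

```python
from invokeai.app.invocations.anima_denoise import loglinear_timestep_shift

sigmas = [loglinear_timestep_shift(3.0, 1.0 - i / 4) for i in range(5)]
assert sigmas == [1.0, 0.9, 0.75, 0.5, 0.0]  # exact in IEEE double precision here
```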
+ """ + cond_list = cond_field if isinstance(cond_field, list) else [cond_field] + + text_conditionings: list[AnimaTextConditioning] = [] + for cond in cond_list: + cond_info = self._load_conditioning(context, cond, dtype, device) + + # Load the mask, if provided + mask: torch.Tensor | None = None + if cond.mask is not None: + mask = context.tensors.load(cond.mask.tensor_name) + mask = mask.to(device=device) + mask = AnimaRegionalPromptingExtension.preprocess_regional_prompt_mask( + mask, img_token_height, img_token_width, dtype, device + ) + + text_conditionings.append( + AnimaTextConditioning( + qwen3_embeds=cond_info.qwen3_embeds, + t5xxl_ids=cond_info.t5xxl_ids, + t5xxl_weights=cond_info.t5xxl_weights, + mask=mask, + ) + ) + + return text_conditionings + + def _run_llm_adapter_for_regions( + self, + transformer, + text_conditionings: list[AnimaTextConditioning], + dtype: torch.dtype, + ) -> AnimaRegionalTextConditioning: + """Run the LLM Adapter separately for each regional conditioning and concatenate. + + Args: + transformer: The AnimaTransformer instance (must be on device). + text_conditionings: List of per-region conditioning data. + dtype: Inference dtype. + + Returns: + AnimaRegionalTextConditioning with concatenated context and masks. + """ + context_embeds_list: list[torch.Tensor] = [] + context_ranges: list[Range] = [] + image_masks: list[torch.Tensor | None] = [] + cur_len = 0 + + for tc in text_conditionings: + qwen3_embeds = tc.qwen3_embeds.unsqueeze(0) # (1, seq_len, 1024) + t5xxl_ids = tc.t5xxl_ids.unsqueeze(0) # (1, seq_len) + t5xxl_weights = None + if tc.t5xxl_weights is not None: + t5xxl_weights = tc.t5xxl_weights.unsqueeze(0).unsqueeze(-1) # (1, seq_len, 1) + + # Run the LLM Adapter to produce context for this region + context = transformer.preprocess_text_embeds( + qwen3_embeds.to(dtype=dtype), + t5xxl_ids, + t5xxl_weights=t5xxl_weights.to(dtype=dtype) if t5xxl_weights is not None else None, + ) + # context shape: (1, 512, 1024) — squeeze batch dim + context_2d = context.squeeze(0) # (512, 1024) + + context_embeds_list.append(context_2d) + context_ranges.append(Range(start=cur_len, end=cur_len + context_2d.shape[0])) + image_masks.append(tc.mask) + cur_len += context_2d.shape[0] + + concatenated_context = torch.cat(context_embeds_list, dim=0) + + return AnimaRegionalTextConditioning( + context_embeds=concatenated_context, + image_masks=image_masks, + context_ranges=context_ranges, + ) + + def _run_diffusion(self, context: InvocationContext) -> torch.Tensor: + device = TorchDevice.choose_torch_device() + inference_dtype = TorchDevice.choose_bfloat16_safe_dtype(device) + + if self.denoising_start >= self.denoising_end: + raise ValueError( + f"denoising_start ({self.denoising_start}) must be less than denoising_end ({self.denoising_end})." 
+ ) + + transformer_info = context.models.load(self.transformer.transformer) + + # Compute image token grid dimensions for regional prompting + # Anima: 8x VAE compression, 2x patch size → 16x total + patch_size = 2 + latent_height = self.height // ANIMA_LATENT_SCALE_FACTOR + latent_width = self.width // ANIMA_LATENT_SCALE_FACTOR + img_token_height = latent_height // patch_size + img_token_width = latent_width // patch_size + img_seq_len = img_token_height * img_token_width + + # Load positive conditioning with optional regional masks + pos_text_conditionings = self._load_text_conditionings( + context=context, + cond_field=self.positive_conditioning, + img_token_height=img_token_height, + img_token_width=img_token_width, + dtype=inference_dtype, + device=device, + ) + has_regional = len(pos_text_conditionings) > 1 or any(tc.mask is not None for tc in pos_text_conditionings) + + # Load negative conditioning if CFG is enabled + do_cfg = not math.isclose(self.guidance_scale, 1.0) and self.negative_conditioning is not None + neg_text_conditionings: list[AnimaTextConditioning] | None = None + if do_cfg: + assert self.negative_conditioning is not None + neg_text_conditionings = self._load_text_conditionings( + context=context, + cond_field=self.negative_conditioning, + img_token_height=img_token_height, + img_token_width=img_token_width, + dtype=inference_dtype, + device=device, + ) + + # Generate sigma schedule + sigmas = self._get_sigmas(self.steps) + + # Apply denoising_start and denoising_end clipping (for img2img/inpaint) + if self.denoising_start > 0 or self.denoising_end < 1: + total_sigmas = len(sigmas) + start_idx = int(self.denoising_start * (total_sigmas - 1)) + end_idx = int(self.denoising_end * (total_sigmas - 1)) + 1 + sigmas = sigmas[start_idx:end_idx] + + total_steps = len(sigmas) - 1 + + # Load input latents if provided (image-to-image) + init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None + if init_latents is not None: + init_latents = init_latents.to(device=device, dtype=inference_dtype) + # Anima denoiser works in 3D: add temporal dim if needed + if init_latents.ndim == 4: + init_latents = init_latents.unsqueeze(2) # [B, C, H, W] -> [B, C, 1, H, W] + + # Generate initial noise (3D latent: [B, C, T, H, W]) + noise = self._get_noise(self.height, self.width, inference_dtype, device, self.seed) + + # Prepare input latents + if init_latents is not None: + if self.add_noise: + s_0 = sigmas[0] + latents = s_0 * noise + (1.0 - s_0) * init_latents + else: + latents = init_latents + else: + if self.denoising_start > 1e-5: + raise ValueError("denoising_start should be 0 when initial latents are not provided.") + latents = noise + + if total_steps <= 0: + return latents.squeeze(2) + + # Prepare inpaint extension + inpaint_mask = self._prep_inpaint_mask(context, latents.squeeze(2)) + inpaint_extension: AnimaInpaintExtension | None = None + if inpaint_mask is not None: + if init_latents is None: + raise ValueError("Initial latents are required when using an inpaint mask (image-to-image inpainting)") + inpaint_extension = AnimaInpaintExtension( + init_latents=init_latents.squeeze(2), + inpaint_mask=inpaint_mask, + noise=noise.squeeze(2), + shift=ANIMA_SHIFT, + ) + + step_callback = self._build_step_callback(context) + + # Initialize diffusers scheduler if not using built-in Euler + scheduler: SchedulerMixin | None = None + use_scheduler = self.scheduler != "euler" + + if use_scheduler: + scheduler_class = ANIMA_SCHEDULER_MAP[self.scheduler] + scheduler = 
scheduler_class(num_train_timesteps=1000, shift=1.0) + is_lcm = self.scheduler == "lcm" + set_timesteps_sig = inspect.signature(scheduler.set_timesteps) + if not is_lcm and "sigmas" in set_timesteps_sig.parameters: + scheduler.set_timesteps(sigmas=sigmas, device=device) + else: + scheduler.set_timesteps(num_inference_steps=total_steps, device=device) + num_scheduler_steps = len(scheduler.timesteps) + else: + num_scheduler_steps = total_steps + + with ExitStack() as exit_stack: + (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device()) + + # Apply LoRA models to the transformer. + # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching. + exit_stack.enter_context( + LayerPatcher.apply_smart_model_patches( + model=transformer, + patches=self._lora_iterator(context), + prefix=ANIMA_LORA_TRANSFORMER_PREFIX, + dtype=inference_dtype, + cached_weights=cached_weights, + ) + ) + + # Run LLM Adapter for each regional conditioning to produce context vectors. + # This must happen with the transformer on device since it uses the adapter weights. + if has_regional: + pos_regional = self._run_llm_adapter_for_regions(transformer, pos_text_conditionings, inference_dtype) + pos_context = pos_regional.context_embeds.unsqueeze(0) # (1, total_ctx_len, 1024) + + # Build regional prompting extension with cross-attention mask + regional_extension = AnimaRegionalPromptingExtension.from_regional_conditioning( + pos_regional, img_seq_len + ) + + # For negative, concatenate all regions without masking (matches Z-Image behavior) + neg_context = None + if do_cfg and neg_text_conditionings is not None: + neg_regional = self._run_llm_adapter_for_regions( + transformer, neg_text_conditionings, inference_dtype + ) + neg_context = neg_regional.context_embeds.unsqueeze(0) + else: + # Single conditioning — run LLM Adapter via normal forward path + tc = pos_text_conditionings[0] + pos_qwen3_embeds = tc.qwen3_embeds.unsqueeze(0) + pos_t5xxl_ids = tc.t5xxl_ids.unsqueeze(0) + pos_t5xxl_weights = None + if tc.t5xxl_weights is not None: + pos_t5xxl_weights = tc.t5xxl_weights.unsqueeze(0).unsqueeze(-1) + + # Pre-compute context via LLM Adapter + pos_context = transformer.preprocess_text_embeds( + pos_qwen3_embeds.to(dtype=inference_dtype), + pos_t5xxl_ids, + t5xxl_weights=pos_t5xxl_weights.to(dtype=inference_dtype) + if pos_t5xxl_weights is not None + else None, + ) + + neg_context = None + if do_cfg and neg_text_conditionings is not None: + ntc = neg_text_conditionings[0] + neg_qwen3 = ntc.qwen3_embeds.unsqueeze(0) + neg_ids = ntc.t5xxl_ids.unsqueeze(0) + neg_weights = None + if ntc.t5xxl_weights is not None: + neg_weights = ntc.t5xxl_weights.unsqueeze(0).unsqueeze(-1) + neg_context = transformer.preprocess_text_embeds( + neg_qwen3.to(dtype=inference_dtype), + neg_ids, + t5xxl_weights=neg_weights.to(dtype=inference_dtype) if neg_weights is not None else None, + ) + + regional_extension = None + + # Apply regional prompting patch if we have regional masks + exit_stack.enter_context(patch_anima_for_regional_prompting(transformer, regional_extension)) + + # Helper to run transformer with pre-computed context (bypasses LLM Adapter) + def _run_transformer(ctx: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + return transformer( + x=x.to(transformer.dtype if hasattr(transformer, "dtype") else inference_dtype), + timesteps=t, + context=ctx, + # t5xxl_ids=None skips the LLM Adapter — context is already pre-computed + ) + + if use_scheduler 
and scheduler is not None: + # Scheduler-based denoising + user_step = 0 + pbar = tqdm(total=total_steps, desc="Denoising (Anima)") + for step_index in range(num_scheduler_steps): + sched_timestep = scheduler.timesteps[step_index] + sigma_curr = sched_timestep.item() / scheduler.config.num_train_timesteps + + is_heun = hasattr(scheduler, "state_in_first_order") + in_first_order = scheduler.state_in_first_order if is_heun else True + + timestep = torch.tensor( + [sigma_curr * ANIMA_MULTIPLIER], device=device, dtype=inference_dtype + ).expand(latents.shape[0]) + + noise_pred_cond = _run_transformer(pos_context, latents, timestep).float() + + if do_cfg and neg_context is not None: + noise_pred_uncond = _run_transformer(neg_context, latents, timestep).float() + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: + noise_pred = noise_pred_cond + + step_output = scheduler.step(model_output=noise_pred, timestep=sched_timestep, sample=latents) + latents = step_output.prev_sample + + if step_index + 1 < len(scheduler.sigmas): + sigma_prev = scheduler.sigmas[step_index + 1].item() + else: + sigma_prev = 0.0 + + if inpaint_extension is not None: + latents_4d = latents.squeeze(2) + latents_4d = inpaint_extension.merge_intermediate_latents_with_init_latents( + latents_4d, sigma_prev + ) + latents = latents_4d.unsqueeze(2) + + if is_heun: + if not in_first_order: + user_step += 1 + if user_step <= total_steps: + pbar.update(1) + step_callback( + PipelineIntermediateState( + step=user_step, + order=2, + total_steps=total_steps, + timestep=int(sigma_curr * 1000), + latents=latents.squeeze(2), + ) + ) + else: + user_step += 1 + if user_step <= total_steps: + pbar.update(1) + step_callback( + PipelineIntermediateState( + step=user_step, + order=1, + total_steps=total_steps, + timestep=int(sigma_curr * 1000), + latents=latents.squeeze(2), + ) + ) + pbar.close() + else: + # Built-in Euler implementation (default for Anima) + for step_idx in tqdm(range(total_steps), desc="Denoising (Anima)"): + sigma_curr = sigmas[step_idx] + sigma_prev = sigmas[step_idx + 1] + + timestep = torch.tensor( + [sigma_curr * ANIMA_MULTIPLIER], device=device, dtype=inference_dtype + ).expand(latents.shape[0]) + + noise_pred_cond = _run_transformer(pos_context, latents, timestep).float() + + if do_cfg and neg_context is not None: + noise_pred_uncond = _run_transformer(neg_context, latents, timestep).float() + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: + noise_pred = noise_pred_cond + + latents_dtype = latents.dtype + latents = latents.to(dtype=torch.float32) + latents = latents + (sigma_prev - sigma_curr) * noise_pred + latents = latents.to(dtype=latents_dtype) + + if inpaint_extension is not None: + latents_4d = latents.squeeze(2) + latents_4d = inpaint_extension.merge_intermediate_latents_with_init_latents( + latents_4d, sigma_prev + ) + latents = latents_4d.unsqueeze(2) + + step_callback( + PipelineIntermediateState( + step=step_idx + 1, + order=1, + total_steps=total_steps, + timestep=int(sigma_curr * 1000), + latents=latents.squeeze(2), + ), + ) + + # Remove temporal dimension for output: [B, C, 1, H, W] -> [B, C, H, W] + return latents.squeeze(2) + + def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]: + def step_callback(state: PipelineIntermediateState) -> None: + context.util.sd_step_callback(state, BaseModelType.Anima) + + return step_callback + + def 
_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]: + """Iterate over LoRA models to apply to the transformer.""" + for lora in self.transformer.loras: + lora_info = context.models.load(lora.lora) + if not isinstance(lora_info.model, ModelPatchRaw): + raise TypeError( + f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}. " + "The LoRA model may be corrupted or incompatible." + ) + yield (lora_info.model, lora.weight) + del lora_info diff --git a/invokeai/app/invocations/anima_image_to_latents.py b/invokeai/app/invocations/anima_image_to_latents.py new file mode 100644 index 0000000000..83073ab4a8 --- /dev/null +++ b/invokeai/app/invocations/anima_image_to_latents.py @@ -0,0 +1,119 @@ +"""Anima image-to-latents invocation. + +Encodes an image to latent space using the Anima VAE (AutoencoderKLWan or FLUX VAE). + +For Wan VAE (AutoencoderKLWan): +- Input image is converted to 5D tensor [B, C, T, H, W] with T=1 +- After encoding, latents are normalized: (latents - mean) / std + (inverse of the denormalization in anima_latents_to_image.py) + +For FLUX VAE (AutoEncoder): +- Encoding is handled internally by the FLUX VAE +""" + +from typing import Union + +import einops +import torch +from diffusers.models.autoencoders import AutoencoderKLWan + +from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation +from invokeai.app.invocations.fields import ( + FieldDescriptions, + ImageField, + Input, + InputField, + WithBoard, + WithMetadata, +) +from invokeai.app.invocations.model import VAEField +from invokeai.app.invocations.primitives import LatentsOutput +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.flux.modules.autoencoder import AutoEncoder as FluxAutoEncoder +from invokeai.backend.model_manager.load.load_base import LoadedModel +from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor +from invokeai.backend.util.devices import TorchDevice +from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux + +AnimaVAE = Union[AutoencoderKLWan, FluxAutoEncoder] + + +@invocation( + "anima_i2l", + title="Image to Latents - Anima", + tags=["image", "latents", "vae", "i2l", "anima"], + category="image", + version="1.0.0", + classification=Classification.Prototype, +) +class AnimaImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard): + """Generates latents from an image using the Anima VAE (supports Wan 2.1 and FLUX VAE).""" + + image: ImageField = InputField(description="The image to encode.") + vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection) + + @staticmethod + def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor: + if not isinstance(vae_info.model, (AutoencoderKLWan, FluxAutoEncoder)): + raise TypeError( + f"Expected AutoencoderKLWan or FluxAutoEncoder for Anima VAE, got {type(vae_info.model).__name__}." 
+ ) + + estimated_working_memory = estimate_vae_working_memory_flux( + operation="encode", + image_tensor=image_tensor, + vae=vae_info.model, + ) + + with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae): + if not isinstance(vae, (AutoencoderKLWan, FluxAutoEncoder)): + raise TypeError(f"Expected AutoencoderKLWan or FluxAutoEncoder, got {type(vae).__name__}.") + + vae_dtype = next(iter(vae.parameters())).dtype + image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype) + + with torch.inference_mode(): + if isinstance(vae, FluxAutoEncoder): + # FLUX VAE handles scaling internally + generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0) + latents = vae.encode(image_tensor, sample=True, generator=generator) + else: + # AutoencoderKLWan expects 5D input [B, C, T, H, W] + if image_tensor.ndim == 4: + image_tensor = image_tensor.unsqueeze(2) # [B, C, H, W] -> [B, C, 1, H, W] + + encoded = vae.encode(image_tensor, return_dict=False)[0] + latents = encoded.sample().to(dtype=vae_dtype) + + # Normalize to denoiser space: (latents - mean) / std + # This is the inverse of the denormalization in anima_latents_to_image.py + latents_mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(latents) + latents_std = torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(latents) + latents = (latents - latents_mean) / latents_std + + # Remove temporal dimension: [B, C, 1, H, W] -> [B, C, H, W] + if latents.ndim == 5: + latents = latents.squeeze(2) + + return latents + + @torch.no_grad() + def invoke(self, context: InvocationContext) -> LatentsOutput: + image = context.images.get_pil(self.image.image_name) + + image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) + if image_tensor.dim() == 3: + image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w") + + vae_info = context.models.load(self.vae.vae) + if not isinstance(vae_info.model, (AutoencoderKLWan, FluxAutoEncoder)): + raise TypeError( + f"Expected AutoencoderKLWan or FluxAutoEncoder for Anima VAE, got {type(vae_info.model).__name__}." + ) + + context.util.signal_progress("Running Anima VAE encode") + latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor) + + latents = latents.to("cpu") + name = context.tensors.save(tensor=latents) + return LatentsOutput.build(latents_name=name, latents=latents, seed=None) diff --git a/invokeai/app/invocations/anima_latents_to_image.py b/invokeai/app/invocations/anima_latents_to_image.py new file mode 100644 index 0000000000..080c101fa4 --- /dev/null +++ b/invokeai/app/invocations/anima_latents_to_image.py @@ -0,0 +1,108 @@ +"""Anima latents-to-image invocation. + +Decodes Anima latents using the QwenImage VAE (AutoencoderKLWan) or +compatible FLUX VAE as fallback. + +Latents from the denoiser are in normalized space (zero-centered). Before +VAE decode, they must be denormalized using the Wan 2.1 per-channel +mean/std: latents = latents * std + mean (matching diffusers WanPipeline). + +The VAE expects 5D latents [B, C, T, H, W] — for single images, T=1. 
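+
+Denormalization sketch (assumes vae.config.latents_mean and
+vae.config.latents_std are per-channel lists, as in the Wan 2.1 config):
+
+ mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1)
+ std = torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1)
+ latents = latents * std + mean # denoiser space -> raw VAE space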
+""" + +import torch +from diffusers.models.autoencoders import AutoencoderKLWan +from einops import rearrange +from PIL import Image + +from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation +from invokeai.app.invocations.fields import ( + FieldDescriptions, + Input, + InputField, + LatentsField, + WithBoard, + WithMetadata, +) +from invokeai.app.invocations.model import VAEField +from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.flux.modules.autoencoder import AutoEncoder as FluxAutoEncoder +from invokeai.backend.util.devices import TorchDevice +from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux + + +@invocation( + "anima_l2i", + title="Latents to Image - Anima", + tags=["latents", "image", "vae", "l2i", "anima"], + category="latents", + version="1.0.2", + classification=Classification.Prototype, +) +class AnimaLatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): + """Generates an image from latents using the Anima VAE. + + Supports the Wan 2.1 QwenImage VAE (AutoencoderKLWan) with explicit + latent denormalization, and FLUX VAE as fallback. + """ + + latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection) + vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection) + + @torch.no_grad() + def invoke(self, context: InvocationContext) -> ImageOutput: + latents = context.tensors.load(self.latents.latents_name) + + vae_info = context.models.load(self.vae.vae) + if not isinstance(vae_info.model, (AutoencoderKLWan, FluxAutoEncoder)): + raise TypeError( + f"Expected AutoencoderKLWan or FluxAutoEncoder for Anima VAE, got {type(vae_info.model).__name__}." 
+ ) + + estimated_working_memory = estimate_vae_working_memory_flux( + operation="decode", + image_tensor=latents, + vae=vae_info.model, + ) + + with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae): + context.util.signal_progress("Running Anima VAE decode") + if not isinstance(vae, (AutoencoderKLWan, FluxAutoEncoder)): + raise TypeError(f"Expected AutoencoderKLWan or FluxAutoEncoder, got {type(vae).__name__}.") + + vae_dtype = next(iter(vae.parameters())).dtype + latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype) + + TorchDevice.empty_cache() + + with torch.inference_mode(): + if isinstance(vae, FluxAutoEncoder): + # FLUX VAE handles scaling internally, expects 4D [B, C, H, W] + img = vae.decode(latents) + else: + # Expects 5D latents [B, C, T, H, W] + if latents.ndim == 4: + latents = latents.unsqueeze(2) # [B, C, H, W] -> [B, C, 1, H, W] + + # Denormalize from denoiser space to raw VAE space + # (same as diffusers WanPipeline and ComfyUI Wan21.process_out) + latents_mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(latents) + latents_std = torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(latents) + latents = latents * latents_std + latents_mean + + decoded = vae.decode(latents, return_dict=False)[0] + + # Output is 5D [B, C, T, H, W] — squeeze temporal dim + if decoded.ndim == 5: + decoded = decoded.squeeze(2) + img = decoded + + img = img.clamp(-1, 1) + img = rearrange(img[0], "c h w -> h w c") + img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy()) + + TorchDevice.empty_cache() + + image_dto = context.images.save(image=img_pil) + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/anima_lora_loader.py b/invokeai/app/invocations/anima_lora_loader.py new file mode 100644 index 0000000000..6a035b55aa --- /dev/null +++ b/invokeai/app/invocations/anima_lora_loader.py @@ -0,0 +1,162 @@ +from typing import Optional + +from invokeai.app.invocations.baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + Classification, + invocation, + invocation_output, +) +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField +from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, Qwen3EncoderField, TransformerField +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType + + +@invocation_output("anima_lora_loader_output") +class AnimaLoRALoaderOutput(BaseInvocationOutput): + """Anima LoRA Loader Output""" + + transformer: Optional[TransformerField] = OutputField( + default=None, description=FieldDescriptions.transformer, title="Anima Transformer" + ) + qwen3_encoder: Optional[Qwen3EncoderField] = OutputField( + default=None, description=FieldDescriptions.qwen3_encoder, title="Qwen3 Encoder" + ) + + +@invocation( + "anima_lora_loader", + title="Apply LoRA - Anima", + tags=["lora", "model", "anima"], + category="model", + version="1.0.0", + classification=Classification.Prototype, +) +class AnimaLoRALoaderInvocation(BaseInvocation): + """Apply a LoRA model to an Anima transformer and/or Qwen3 text encoder.""" + + lora: ModelIdentifierField = InputField( + description=FieldDescriptions.lora_model, + title="LoRA", + ui_model_base=BaseModelType.Anima, + ui_model_type=ModelType.LoRA, + ) + weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) + transformer: TransformerField | None = 
InputField( + default=None, + description=FieldDescriptions.transformer, + input=Input.Connection, + title="Anima Transformer", + ) + qwen3_encoder: Qwen3EncoderField | None = InputField( + default=None, + title="Qwen3 Encoder", + description=FieldDescriptions.qwen3_encoder, + input=Input.Connection, + ) + + def invoke(self, context: InvocationContext) -> AnimaLoRALoaderOutput: + lora_key = self.lora.key + + if not context.models.exists(lora_key): + raise ValueError(f"Unknown lora: {lora_key}!") + + if self.transformer and any(lora.lora.key == lora_key for lora in self.transformer.loras): + raise ValueError(f'LoRA "{lora_key}" already applied to transformer.') + if self.qwen3_encoder and any(lora.lora.key == lora_key for lora in self.qwen3_encoder.loras): + raise ValueError(f'LoRA "{lora_key}" already applied to Qwen3 encoder.') + + output = AnimaLoRALoaderOutput() + + if self.transformer is not None: + output.transformer = self.transformer.model_copy(deep=True) + output.transformer.loras.append( + LoRAField( + lora=self.lora, + weight=self.weight, + ) + ) + if self.qwen3_encoder is not None: + output.qwen3_encoder = self.qwen3_encoder.model_copy(deep=True) + output.qwen3_encoder.loras.append( + LoRAField( + lora=self.lora, + weight=self.weight, + ) + ) + + return output + + +@invocation( + "anima_lora_collection_loader", + title="Apply LoRA Collection - Anima", + tags=["lora", "model", "anima"], + category="model", + version="1.0.0", + classification=Classification.Prototype, +) +class AnimaLoRACollectionLoader(BaseInvocation): + """Applies a collection of LoRAs to an Anima transformer.""" + + loras: Optional[LoRAField | list[LoRAField]] = InputField( + default=None, description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs" + ) + + transformer: Optional[TransformerField] = InputField( + default=None, + description=FieldDescriptions.transformer, + input=Input.Connection, + title="Transformer", + ) + qwen3_encoder: Qwen3EncoderField | None = InputField( + default=None, + title="Qwen3 Encoder", + description=FieldDescriptions.qwen3_encoder, + input=Input.Connection, + ) + + def invoke(self, context: InvocationContext) -> AnimaLoRALoaderOutput: + output = AnimaLoRALoaderOutput() + + if self.loras is None: + if self.transformer is not None: + output.transformer = self.transformer.model_copy(deep=True) + if self.qwen3_encoder is not None: + output.qwen3_encoder = self.qwen3_encoder.model_copy(deep=True) + return output + + loras = self.loras if isinstance(self.loras, list) else [self.loras] + added_loras: list[str] = [] + + if self.transformer is not None: + output.transformer = self.transformer.model_copy(deep=True) + + if self.qwen3_encoder is not None: + output.qwen3_encoder = self.qwen3_encoder.model_copy(deep=True) + + for lora in loras: + if lora is None: + continue + if lora.lora.key in added_loras: + continue + + if not context.models.exists(lora.lora.key): + raise ValueError(f"Unknown lora: {lora.lora.key}!") + + if lora.lora.base is not BaseModelType.Anima: + raise ValueError( + f"LoRA '{lora.lora.key}' is for {lora.lora.base.value if lora.lora.base else 'unknown'} models, " + "not Anima models. Ensure you are using an Anima compatible LoRA." 
+ ) + + added_loras.append(lora.lora.key) + + if self.transformer is not None and output.transformer is not None: + output.transformer.loras.append(lora) + + if self.qwen3_encoder is not None and output.qwen3_encoder is not None: + output.qwen3_encoder.loras.append(lora) + + return output diff --git a/invokeai/app/invocations/anima_model_loader.py b/invokeai/app/invocations/anima_model_loader.py new file mode 100644 index 0000000000..7051148cb1 --- /dev/null +++ b/invokeai/app/invocations/anima_model_loader.py @@ -0,0 +1,102 @@ +from invokeai.app.invocations.baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + Classification, + invocation, + invocation_output, +) +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField +from invokeai.app.invocations.model import ( + ModelIdentifierField, + Qwen3EncoderField, + T5EncoderField, + TransformerField, + VAEField, +) +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.app.util.t5_model_identifier import ( + preprocess_t5_encoder_model_identifier, + preprocess_t5_tokenizer_model_identifier, +) +from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType + + +@invocation_output("anima_model_loader_output") +class AnimaModelLoaderOutput(BaseInvocationOutput): + """Anima model loader output.""" + + transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer") + qwen3_encoder: Qwen3EncoderField = OutputField(description=FieldDescriptions.qwen3_encoder, title="Qwen3 Encoder") + vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE") + t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder") + + +@invocation( + "anima_model_loader", + title="Main Model - Anima", + tags=["model", "anima"], + category="model", + version="1.3.0", + classification=Classification.Prototype, +) +class AnimaModelLoaderInvocation(BaseInvocation): + """Loads an Anima model, outputting its submodels. + + Anima uses: + - Transformer: Cosmos Predict2 DiT + LLM Adapter (from single-file checkpoint) + - Qwen3 Encoder: Qwen3 0.6B (standalone single-file) + - VAE: AutoencoderKLQwenImage / Wan 2.1 VAE (standalone single-file or FLUX VAE) + - T5 Encoder: T5-XXL model (only the tokenizer submodel is used, for LLM Adapter token IDs) + """ + + model: ModelIdentifierField = InputField( + description="Anima main model (transformer + LLM adapter).", + input=Input.Direct, + ui_model_base=BaseModelType.Anima, + ui_model_type=ModelType.Main, + title="Transformer", + ) + + vae_model: ModelIdentifierField = InputField( + description="Standalone VAE model. Anima uses a Wan 2.1 / QwenImage VAE (16-channel). " + "A FLUX VAE can also be used as a compatible fallback.", + input=Input.Direct, + ui_model_type=ModelType.VAE, + title="VAE", + ) + + qwen3_encoder_model: ModelIdentifierField = InputField( + description="Standalone Qwen3 0.6B Encoder model.", + input=Input.Direct, + ui_model_type=ModelType.Qwen3Encoder, + title="Qwen3 Encoder", + ) + + t5_encoder_model: ModelIdentifierField = InputField( + description="T5-XXL encoder model. 
The tokenizer submodel is used for Anima text encoding.",
+ input=Input.Direct,
+ ui_model_type=ModelType.T5Encoder,
+ title="T5 Encoder",
+ )
+
+ def invoke(self, context: InvocationContext) -> AnimaModelLoaderOutput:
+ # Transformer always comes from the main model
+ transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
+
+ # VAE
+ vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
+
+ # Qwen3 Encoder
+ qwen3_tokenizer = self.qwen3_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
+ qwen3_encoder = self.qwen3_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
+
+ # T5 Encoder (only the tokenizer submodel is used by Anima)
+ t5_tokenizer = preprocess_t5_tokenizer_model_identifier(self.t5_encoder_model)
+ t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model)
+
+ return AnimaModelLoaderOutput(
+ transformer=TransformerField(transformer=transformer, loras=[]),
+ qwen3_encoder=Qwen3EncoderField(tokenizer=qwen3_tokenizer, text_encoder=qwen3_encoder),
+ vae=VAEField(vae=vae),
+ t5_encoder=T5EncoderField(tokenizer=t5_tokenizer, text_encoder=t5_encoder, loras=[]),
+ )
diff --git a/invokeai/app/invocations/anima_text_encoder.py b/invokeai/app/invocations/anima_text_encoder.py
new file mode 100644
index 0000000000..1856a69ae7
--- /dev/null
+++ b/invokeai/app/invocations/anima_text_encoder.py
@@ -0,0 +1,221 @@
+"""Anima text encoder invocation.
+
+Encodes text using the dual-conditioning pipeline:
+1. Qwen3 0.6B: Produces hidden states (last layer)
+2. T5-XXL Tokenizer: Produces token IDs only (no T5 model needed)
+
+Both outputs are stored together in AnimaConditioningInfo and used by
+the LLM Adapter inside the transformer during denoising.
+
+Key differences from the Z-Image text encoder:
+- Anima uses Qwen3 0.6B (base model, NOT instruct) — no chat template
+- Anima additionally tokenizes with the T5-XXL tokenizer to get token IDs
+- Qwen3 output keeps all token positions (the prompt is tokenized without padding) for full context
+"""
+
+from contextlib import ExitStack
+from typing import Iterator, Tuple
+
+import torch
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
+from invokeai.app.invocations.fields import (
+ AnimaConditioningField,
+ FieldDescriptions,
+ Input,
+ InputField,
+ TensorField,
+ UIComponent,
+)
+from invokeai.app.invocations.model import Qwen3EncoderField, T5EncoderField
+from invokeai.app.invocations.primitives import AnimaConditioningOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.patches.layer_patcher import LayerPatcher
+from invokeai.backend.patches.lora_conversions.anima_lora_constants import ANIMA_LORA_QWEN3_PREFIX
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+ AnimaConditioningInfo,
+ ConditioningFieldData,
+)
+from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.logging import InvokeAILogger
+
+logger = InvokeAILogger.get_logger(__name__)
+
+# T5-XXL max sequence length for token IDs
+T5_MAX_SEQ_LEN = 512
+
+# Safety cap for Qwen3 sequence length to prevent GPU OOM on extremely long prompts.
+# Qwen3 0.6B supports 32K context but the LLM Adapter doesn't need that much.
+QWEN3_MAX_SEQ_LEN = 8192
+
+
+@invocation(
+ "anima_text_encoder",
+ title="Prompt - Anima",
+ tags=["prompt", "conditioning", "anima"],
+ category="conditioning",
+ version="1.3.0",
+ classification=Classification.Prototype,
+)
+class AnimaTextEncoderInvocation(BaseInvocation):
+ """Encodes and preps a prompt for an Anima image.
+
+ Uses Qwen3 0.6B for hidden state extraction and the T5-XXL tokenizer for
+ token IDs (no T5 model weights needed). Both are combined by the
+ LLM Adapter inside the Anima transformer during denoising.
+ """
+
+ prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
+ qwen3_encoder: Qwen3EncoderField = InputField(
+ title="Qwen3 Encoder",
+ description=FieldDescriptions.qwen3_encoder,
+ input=Input.Connection,
+ )
+ t5_encoder: T5EncoderField = InputField(
+ title="T5 Encoder",
+ description=FieldDescriptions.t5_encoder,
+ input=Input.Connection,
+ )
+ mask: TensorField | None = InputField(
+ default=None,
+ description="A mask defining the region that this conditioning prompt applies to.",
+ )
+
+ @torch.no_grad()
+ def invoke(self, context: InvocationContext) -> AnimaConditioningOutput:
+ qwen3_embeds, t5xxl_ids, t5xxl_weights = self._encode_prompt(context)
+
+ # Move to CPU for storage
+ qwen3_embeds = qwen3_embeds.detach().to("cpu")
+ t5xxl_ids = t5xxl_ids.detach().to("cpu")
+ t5xxl_weights = t5xxl_weights.detach().to("cpu") if t5xxl_weights is not None else None
+
+ conditioning_data = ConditioningFieldData(
+ conditionings=[
+ AnimaConditioningInfo(
+ qwen3_embeds=qwen3_embeds,
+ t5xxl_ids=t5xxl_ids,
+ t5xxl_weights=t5xxl_weights,
+ )
+ ]
+ )
+ conditioning_name = context.conditioning.save(conditioning_data)
+ return AnimaConditioningOutput(
+ conditioning=AnimaConditioningField(conditioning_name=conditioning_name, mask=self.mask)
+ )
+
+ def _encode_prompt(
+ self,
+ context: InvocationContext,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
+ """Encode prompt using Qwen3 0.6B and the T5-XXL tokenizer.
+
+ Returns:
+ Tuple of (qwen3_embeds, t5xxl_ids, t5xxl_weights).
+ - qwen3_embeds: Shape (seq_len, 1024) — all tokenized positions are kept
+ (the prompt is tokenized without padding) to preserve full sequence
+ context for the LLM Adapter.
+ - t5xxl_ids: Shape (seq_len,) — T5-XXL token IDs (unpadded).
+ - t5xxl_weights: None (uniform weights for now).
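+
+ Example (sketch): the two tokenizers run independently, so their
+ sequence lengths generally differ:
+
+ qwen3_embeds, t5xxl_ids, _ = self._encode_prompt(context)
+ assert qwen3_embeds.shape[-1] == 1024 # Qwen3 0.6B hidden size
+ assert t5xxl_ids.ndim == 1 # unpadded T5-XXL token IDs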
+ """ + prompt = self.prompt + + # --- Step 1: Encode with Qwen3 0.6B --- + text_encoder_info = context.models.load(self.qwen3_encoder.text_encoder) + tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer) + + with ExitStack() as exit_stack: + (_, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device()) + (_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device()) + + device = text_encoder.device + + # Apply LoRA models to the text encoder + lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device) + exit_stack.enter_context( + LayerPatcher.apply_smart_model_patches( + model=text_encoder, + patches=self._lora_iterator(context), + prefix=ANIMA_LORA_QWEN3_PREFIX, + dtype=lora_dtype, + ) + ) + + if not isinstance(text_encoder, PreTrainedModel): + raise TypeError(f"Expected PreTrainedModel for text encoder, got {type(text_encoder).__name__}.") + if not isinstance(tokenizer, PreTrainedTokenizerBase): + raise TypeError(f"Expected PreTrainedTokenizerBase for tokenizer, got {type(tokenizer).__name__}.") + + context.util.signal_progress("Running Qwen3 0.6B text encoder") + + # Anima uses base Qwen3 (not instruct) — tokenize directly, no chat template. + # A safety cap is applied to prevent GPU OOM on extremely long prompts. + text_inputs = tokenizer( + prompt, + padding=False, + truncation=True, + max_length=QWEN3_MAX_SEQ_LEN, + return_attention_mask=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + if not isinstance(text_input_ids, torch.Tensor) or not isinstance(attention_mask, torch.Tensor): + raise TypeError("Tokenizer returned unexpected types.") + + if text_input_ids.shape[-1] == QWEN3_MAX_SEQ_LEN: + logger.warning( + f"Prompt was truncated to {QWEN3_MAX_SEQ_LEN} tokens. " + "Consider shortening the prompt for best results." 
+ ) + + # Ensure at least 1 token (empty prompts produce 0 tokens with padding=False) + if text_input_ids.shape[-1] == 0: + pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id + text_input_ids = torch.tensor([[pad_id]]) + attention_mask = torch.tensor([[1]]) + + # Get last hidden state from Qwen3 (final layer output) + prompt_mask = attention_mask.to(device).bool() + outputs = text_encoder( + text_input_ids.to(device), + attention_mask=prompt_mask, + output_hidden_states=True, + ) + + if not hasattr(outputs, "hidden_states") or outputs.hidden_states is None: + raise RuntimeError("Text encoder did not return hidden_states.") + if len(outputs.hidden_states) < 1: + raise RuntimeError(f"Expected at least 1 hidden state, got {len(outputs.hidden_states)}.") + + # Use last hidden state — only real tokens, no padding + qwen3_embeds = outputs.hidden_states[-1][0] # Shape: (seq_len, 1024) + + # --- Step 2: Tokenize with T5-XXL tokenizer (IDs only, no model) --- + context.util.signal_progress("Tokenizing with T5-XXL") + t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer) + with t5_tokenizer_info.model_on_device() as (_, t5_tokenizer): + t5_tokens = t5_tokenizer( + prompt, + padding=False, + truncation=True, + max_length=T5_MAX_SEQ_LEN, + return_tensors="pt", + ) + t5xxl_ids = t5_tokens.input_ids[0] # Shape: (seq_len,) + + return qwen3_embeds, t5xxl_ids, None + + def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]: + """Iterate over LoRA models to apply to the Qwen3 text encoder.""" + for lora in self.qwen3_encoder.loras: + lora_info = context.models.load(lora.lora) + if not isinstance(lora_info.model, ModelPatchRaw): + raise TypeError( + f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}. " + "The LoRA model may be corrupted or incompatible." 
+ ) + yield (lora_info.model, lora.weight) + del lora_info diff --git a/invokeai/app/invocations/batch.py b/invokeai/app/invocations/batch.py index 34ecd38f26..f79b8816ad 100644 --- a/invokeai/app/invocations/batch.py +++ b/invokeai/app/invocations/batch.py @@ -56,7 +56,7 @@ class BaseBatchInvocation(BaseInvocation): "image_batch", title="Image Batch", tags=["primitives", "image", "batch", "special"], - category="primitives", + category="batch", version="1.0.0", classification=Classification.Special, ) @@ -87,7 +87,7 @@ class ImageGeneratorField(BaseModel): "image_generator", title="Image Generator", tags=["primitives", "board", "image", "batch", "special"], - category="primitives", + category="batch", version="1.0.0", classification=Classification.Special, ) @@ -111,7 +111,7 @@ class ImageGenerator(BaseInvocation): "string_batch", title="String Batch", tags=["primitives", "string", "batch", "special"], - category="primitives", + category="batch", version="1.0.0", classification=Classification.Special, ) @@ -142,7 +142,7 @@ class StringGeneratorField(BaseModel): "string_generator", title="String Generator", tags=["primitives", "string", "number", "batch", "special"], - category="primitives", + category="batch", version="1.0.0", classification=Classification.Special, ) @@ -166,7 +166,7 @@ class StringGenerator(BaseInvocation): "integer_batch", title="Integer Batch", tags=["primitives", "integer", "number", "batch", "special"], - category="primitives", + category="batch", version="1.0.0", classification=Classification.Special, ) @@ -195,7 +195,7 @@ class IntegerGeneratorField(BaseModel): "integer_generator", title="Integer Generator", tags=["primitives", "int", "number", "batch", "special"], - category="primitives", + category="batch", version="1.0.0", classification=Classification.Special, ) @@ -219,7 +219,7 @@ class IntegerGenerator(BaseInvocation): "float_batch", title="Float Batch", tags=["primitives", "float", "number", "batch", "special"], - category="primitives", + category="batch", version="1.0.0", classification=Classification.Special, ) @@ -250,7 +250,7 @@ class FloatGeneratorField(BaseModel): "float_generator", title="Float Generator", tags=["primitives", "float", "number", "batch", "special"], - category="primitives", + category="batch", version="1.0.0", classification=Classification.Special, ) diff --git a/invokeai/app/invocations/canny.py b/invokeai/app/invocations/canny.py index 0cdc386e62..dbfde6d353 100644 --- a/invokeai/app/invocations/canny.py +++ b/invokeai/app/invocations/canny.py @@ -11,7 +11,7 @@ from invokeai.backend.image_util.util import cv2_to_pil, pil_to_cv2 "canny_edge_detection", title="Canny Edge Detection", tags=["controlnet", "canny"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class CannyEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/canvas.py b/invokeai/app/invocations/canvas.py new file mode 100644 index 0000000000..cf13c3334f --- /dev/null +++ b/invokeai/app/invocations/canvas.py @@ -0,0 +1,27 @@ +from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation +from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField +from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.shared.invocation_context import InvocationContext + + +@invocation( + "canvas_output", + title="Canvas Output", + tags=["canvas", "output", "image"], + category="canvas", + version="1.0.0", + use_cache=False, +) 
+class CanvasOutputInvocation(BaseInvocation): + """Outputs an image to the canvas staging area. + + Use this node in workflows intended for canvas workflow integration. + Connect the final image of your workflow to this node to send it + to the canvas staging area when run via 'Run Workflow on Canvas'.""" + + image: ImageField = InputField(description=FieldDescriptions.image) + + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) + image_dto = context.images.save(image=image) + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/cogview4_denoise.py b/invokeai/app/invocations/cogview4_denoise.py index 070d8a3478..e8b910f731 100644 --- a/invokeai/app/invocations/cogview4_denoise.py +++ b/invokeai/app/invocations/cogview4_denoise.py @@ -33,7 +33,7 @@ from invokeai.backend.util.devices import TorchDevice "cogview4_denoise", title="Denoise - CogView4", tags=["image", "cogview4"], - category="image", + category="latents", version="1.0.0", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/cogview4_image_to_latents.py b/invokeai/app/invocations/cogview4_image_to_latents.py index 630b9ab1e3..facbc38dd4 100644 --- a/invokeai/app/invocations/cogview4_image_to_latents.py +++ b/invokeai/app/invocations/cogview4_image_to_latents.py @@ -27,7 +27,7 @@ from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory "cogview4_i2l", title="Image to Latents - CogView4", tags=["image", "latents", "vae", "i2l", "cogview4"], - category="image", + category="latents", version="1.0.0", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/cogview4_text_encoder.py b/invokeai/app/invocations/cogview4_text_encoder.py index c6ef1663cf..13234889fb 100644 --- a/invokeai/app/invocations/cogview4_text_encoder.py +++ b/invokeai/app/invocations/cogview4_text_encoder.py @@ -6,6 +6,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField from invokeai.app.invocations.model import GlmEncoderField from invokeai.app.invocations.primitives import CogView4ConditioningOutput from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( CogView4ConditioningInfo, ConditioningFieldData, @@ -19,7 +20,7 @@ COGVIEW4_GLM_MAX_SEQ_LEN = 1024 "cogview4_text_encoder", title="Prompt - CogView4", tags=["prompt", "conditioning", "cogview4"], - category="conditioning", + category="prompt", version="1.0.0", classification=Classification.Prototype, ) @@ -46,10 +47,18 @@ class CogView4TextEncoderInvocation(BaseInvocation): prompt = [self.prompt] # TODO(ryand): Add model inputs to the invocation rather than hard-coding. + glm_text_encoder_info = context.models.load(self.glm_encoder.text_encoder) with ( - context.models.load(self.glm_encoder.text_encoder).model_on_device() as (_, glm_text_encoder), + glm_text_encoder_info.model_on_device() as (_, glm_text_encoder), context.models.load(self.glm_encoder.tokenizer).model_on_device() as (_, glm_tokenizer), ): + repaired_tensors = glm_text_encoder_info.repair_required_tensors_on_device() + device = get_effective_device(glm_text_encoder) + if repaired_tensors > 0: + context.logger.warning( + f"Recovered {repaired_tensors} required GLM tensor(s) onto {device} after a partial device mismatch." 
+ ) + context.util.signal_progress("Running GLM text encoder") assert isinstance(glm_text_encoder, GlmModel) assert isinstance(glm_tokenizer, PreTrainedTokenizerFast) @@ -85,9 +94,7 @@ class CogView4TextEncoderInvocation(BaseInvocation): device=text_input_ids.device, ) text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1) - prompt_embeds = glm_text_encoder( - text_input_ids.to(glm_text_encoder.device), output_hidden_states=True - ).hidden_states[-2] + prompt_embeds = glm_text_encoder(text_input_ids.to(device), output_hidden_states=True).hidden_states[-2] assert isinstance(prompt_embeds, torch.Tensor) return prompt_embeds diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py index bd3dedb3f8..39e77f5b63 100644 --- a/invokeai/app/invocations/collections.py +++ b/invokeai/app/invocations/collections.py @@ -11,9 +11,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import SEED_MAX -@invocation( - "range", title="Integer Range", tags=["collection", "integer", "range"], category="collections", version="1.0.0" -) +@invocation("range", title="Integer Range", tags=["collection", "integer", "range"], category="batch", version="1.0.0") class RangeInvocation(BaseInvocation): """Creates a range of numbers from start to stop with step""" @@ -35,7 +33,7 @@ class RangeInvocation(BaseInvocation): "range_of_size", title="Integer Range of Size", tags=["collection", "integer", "size", "range"], - category="collections", + category="batch", version="1.0.0", ) class RangeOfSizeInvocation(BaseInvocation): @@ -55,7 +53,7 @@ class RangeOfSizeInvocation(BaseInvocation): "random_range", title="Random Range", tags=["range", "integer", "random", "collection"], - category="collections", + category="batch", version="1.0.1", use_cache=False, ) diff --git a/invokeai/app/invocations/color_map.py b/invokeai/app/invocations/color_map.py index e55584caf5..ec95acfffd 100644 --- a/invokeai/app/invocations/color_map.py +++ b/invokeai/app/invocations/color_map.py @@ -11,7 +11,7 @@ from invokeai.backend.image_util.util import np_to_pil, pil_to_np "color_map", title="Color Map", tags=["controlnet"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class ColorMapInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 5ce88145ff..99373531d8 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -19,6 +19,7 @@ from invokeai.app.invocations.model import CLIPField from invokeai.app.invocations.primitives import ConditioningOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.ti_utils import generate_ti_list +from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.patches.layer_patcher import LayerPatcher from invokeai.backend.patches.model_patch_raw import ModelPatchRaw @@ -42,7 +43,7 @@ from invokeai.backend.util.devices import TorchDevice "compel", title="Prompt - SD1.5", tags=["prompt", "compel"], - category="conditioning", + category="prompt", version="1.2.1", ) class CompelInvocation(BaseInvocation): @@ -103,7 +104,7 @@ class CompelInvocation(BaseInvocation): textual_inversion_manager=ti_manager, dtype_for_device_getter=TorchDevice.choose_torch_dtype, truncate_long_prompts=False, - 
device=text_encoder.device, # Use the device the model is actually on + device=get_effective_device(text_encoder), split_long_text_mode=SplitLongTextMode.SENTENCES, ) @@ -212,7 +213,7 @@ class SDXLPromptInvocationBase: truncate_long_prompts=False, # TODO: returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip requires_pooled=get_pooled, - device=text_encoder.device, # Use the device the model is actually on + device=get_effective_device(text_encoder), split_long_text_mode=SplitLongTextMode.SENTENCES, ) @@ -247,7 +248,7 @@ class SDXLPromptInvocationBase: "sdxl_compel_prompt", title="Prompt - SDXL", tags=["sdxl", "compel", "prompt"], - category="conditioning", + category="prompt", version="1.2.1", ) class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): @@ -341,7 +342,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): "sdxl_refiner_compel_prompt", title="Prompt - SDXL Refiner", tags=["sdxl", "compel", "prompt"], - category="conditioning", + category="prompt", version="1.1.2", ) class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): @@ -390,7 +391,7 @@ class CLIPSkipInvocationOutput(BaseInvocationOutput): "clip_skip", title="Apply CLIP Skip - SD1.5, SDXL", tags=["clipskip", "clip", "skip"], - category="conditioning", + category="prompt", version="1.1.1", ) class CLIPSkipInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/content_shuffle.py b/invokeai/app/invocations/content_shuffle.py index e01096ecea..6fd35b53eb 100644 --- a/invokeai/app/invocations/content_shuffle.py +++ b/invokeai/app/invocations/content_shuffle.py @@ -9,7 +9,7 @@ from invokeai.backend.image_util.content_shuffle import content_shuffle "content_shuffle", title="Content Shuffle", tags=["controlnet", "normal"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class ContentShuffleInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/controlnet.py b/invokeai/app/invocations/controlnet.py index d1878d967e..9b0fc8219b 100644 --- a/invokeai/app/invocations/controlnet.py +++ b/invokeai/app/invocations/controlnet.py @@ -64,7 +64,7 @@ class ControlOutput(BaseInvocationOutput): @invocation( - "controlnet", title="ControlNet - SD1.5, SD2, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3" + "controlnet", title="ControlNet - SD1.5, SD2, SDXL", tags=["controlnet"], category="conditioning", version="1.1.3" ) class ControlNetInvocation(BaseInvocation): """Collects ControlNet info to pass to other nodes""" @@ -116,7 +116,7 @@ class ControlNetInvocation(BaseInvocation): "heuristic_resize", title="Heuristic Resize", tags=["image, controlnet"], - category="image", + category="controlnet_preprocessors", version="1.1.1", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/create_denoise_mask.py b/invokeai/app/invocations/create_denoise_mask.py index d013e8f4f6..419a516bcd 100644 --- a/invokeai/app/invocations/create_denoise_mask.py +++ b/invokeai/app/invocations/create_denoise_mask.py @@ -18,7 +18,7 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_t "create_denoise_mask", title="Create Denoise Mask", tags=["mask", "denoise"], - category="latents", + category="mask", version="1.0.2", ) class CreateDenoiseMaskInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/create_gradient_mask.py b/invokeai/app/invocations/create_gradient_mask.py index 
8a7e7c5231..08826cc5ef 100644 --- a/invokeai/app/invocations/create_gradient_mask.py +++ b/invokeai/app/invocations/create_gradient_mask.py @@ -41,7 +41,7 @@ class GradientMaskOutput(BaseInvocationOutput): "create_gradient_mask", title="Create Gradient Mask", tags=["mask", "denoise"], - category="latents", + category="mask", version="1.3.0", ) class CreateGradientMaskInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/depth_anything.py b/invokeai/app/invocations/depth_anything.py index af79413ce0..1fd808efde 100644 --- a/invokeai/app/invocations/depth_anything.py +++ b/invokeai/app/invocations/depth_anything.py @@ -20,7 +20,7 @@ DEPTH_ANYTHING_MODELS = { "depth_anything_depth_estimation", title="Depth Anything Depth Estimation", tags=["controlnet", "depth", "depth anything"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class DepthAnythingDepthEstimationInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/dw_openpose.py b/invokeai/app/invocations/dw_openpose.py index 225c7e2283..918a4bc4d0 100644 --- a/invokeai/app/invocations/dw_openpose.py +++ b/invokeai/app/invocations/dw_openpose.py @@ -11,7 +11,7 @@ from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector "dw_openpose_detection", title="DW Openpose Detection", tags=["controlnet", "dwpose", "openpose"], - category="controlnet", + category="controlnet_preprocessors", version="1.1.1", ) class DWOpenposeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/external_image_generation.py b/invokeai/app/invocations/external_image_generation.py index 983dc5caf5..b66affe9b0 100644 --- a/invokeai/app/invocations/external_image_generation.py +++ b/invokeai/app/invocations/external_image_generation.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar +from typing import Any, ClassVar, Literal from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation from invokeai.app.invocations.fields import ( @@ -34,18 +34,18 @@ class BaseExternalImageGenerationInvocation(BaseInvocation, WithMetadata, WithBo ) mode: ExternalGenerationMode = InputField(default="txt2img", description="Generation mode") prompt: str = InputField(description="Prompt") - negative_prompt: str | None = InputField(default=None, description="Negative prompt") seed: int | None = InputField(default=None, description=FieldDescriptions.seed) num_images: int = InputField(default=1, gt=0, description="Number of images to generate") width: int = InputField(default=1024, gt=0, description=FieldDescriptions.width) height: int = InputField(default=1024, gt=0, description=FieldDescriptions.height) - steps: int | None = InputField(default=None, gt=0, description=FieldDescriptions.steps) - guidance: float | None = InputField(default=None, ge=0, description="Guidance strength") + image_size: str | None = InputField(default=None, description="Image size preset (e.g. 
1K, 2K, 4K)") init_image: ImageField | None = InputField(default=None, description="Init image for img2img/inpaint") mask_image: ImageField | None = InputField(default=None, description="Mask image for inpaint") reference_images: list[ImageField] = InputField(default=[], description="Reference images") - reference_image_weights: list[float] | None = InputField(default=None, description="Reference image weights") - reference_image_modes: list[str] | None = InputField(default=None, description="Reference image modes") + + def _build_provider_options(self) -> dict[str, Any] | None: + """Override in provider-specific subclasses to pass extra options.""" + return None def invoke(self, context: InvocationContext) -> ImageCollectionOutput: model_config = context.models.get_config(self.model) @@ -65,38 +65,25 @@ class BaseExternalImageGenerationInvocation(BaseInvocation, WithMetadata, WithBo if self.mask_image is not None: mask_image = context.images.get_pil(self.mask_image.image_name, mode="L") - if self.reference_image_weights is not None and len(self.reference_image_weights) != len(self.reference_images): - raise ValueError("reference_image_weights must match reference_images length") - - if self.reference_image_modes is not None and len(self.reference_image_modes) != len(self.reference_images): - raise ValueError("reference_image_modes must match reference_images length") - reference_images: list[ExternalReferenceImage] = [] - for index, image_field in enumerate(self.reference_images): + for image_field in self.reference_images: reference_image = context.images.get_pil(image_field.image_name, mode="RGB") - weight = None - mode = None - if self.reference_image_weights is not None: - weight = self.reference_image_weights[index] - if self.reference_image_modes is not None: - mode = self.reference_image_modes[index] - reference_images.append(ExternalReferenceImage(image=reference_image, weight=weight, mode=mode)) + reference_images.append(ExternalReferenceImage(image=reference_image)) request = ExternalGenerationRequest( model=model_config, mode=self.mode, prompt=self.prompt, - negative_prompt=self.negative_prompt, seed=self.seed, num_images=self.num_images, width=self.width, height=self.height, - steps=self.steps, - guidance=self.guidance, + image_size=self.image_size, init_image=init_image, mask_image=mask_image, reference_images=reference_images, metadata=self._build_request_metadata(), + provider_options=self._build_provider_options(), ) result = context._services.external_generation.generate(request) @@ -172,6 +159,23 @@ class OpenAIImageGenerationInvocation(BaseExternalImageGenerationInvocation): provider_id = "openai" + quality: Literal["auto", "high", "medium", "low"] = InputField(default="auto", description="Output image quality") + background: Literal["auto", "transparent", "opaque"] = InputField( + default="auto", description="Background transparency handling" + ) + input_fidelity: Literal["low", "high"] | None = InputField( + default=None, description="Fidelity to source images (edits only)" + ) + + def _build_provider_options(self) -> dict[str, Any]: + options: dict[str, Any] = { + "quality": self.quality, + "background": self.background, + } + if self.input_fidelity is not None: + options["input_fidelity"] = self.input_fidelity + return options + @invocation( "gemini_image_generation", @@ -184,3 +188,16 @@ class GeminiImageGenerationInvocation(BaseExternalImageGenerationInvocation): """Generate images using a Gemini-hosted external model.""" provider_id = "gemini" + + temperature: 
float | None = InputField(default=None, ge=0.0, le=2.0, description="Sampling temperature") + thinking_level: Literal["minimal", "high"] | None = InputField( + default=None, description="Thinking level for image generation" + ) + + def _build_provider_options(self) -> dict[str, Any] | None: + options: dict[str, Any] = {} + if self.temperature is not None: + options["temperature"] = self.temperature + if self.thinking_level is not None: + options["thinking_level"] = self.thinking_level + return options or None diff --git a/invokeai/app/invocations/facetools.py b/invokeai/app/invocations/facetools.py index 987f1b1e40..1092a67ce9 100644 --- a/invokeai/app/invocations/facetools.py +++ b/invokeai/app/invocations/facetools.py @@ -435,7 +435,9 @@ def get_faces_list( return all_faces -@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.2") +@invocation( + "face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="segmentation", version="1.2.2" +) class FaceOffInvocation(BaseInvocation, WithMetadata): """Bound, extract, and mask a face from an image using MediaPipe detection""" @@ -514,7 +516,9 @@ class FaceOffInvocation(BaseInvocation, WithMetadata): return output -@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.2") +@invocation( + "face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="segmentation", version="1.2.2" +) class FaceMaskInvocation(BaseInvocation, WithMetadata): """Face mask creation using mediapipe face detection""" @@ -617,7 +621,11 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata): @invocation( - "face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.2" + "face_identifier", + title="FaceIdentifier", + tags=["image", "face", "identifier"], + category="segmentation", + version="1.2.2", ) class FaceIdentifierInvocation(BaseInvocation, WithMetadata, WithBoard): """Outputs an image with detected face IDs printed on each face. For use with other FaceTools.""" diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index cca09a059d..fbe0e9a615 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -171,6 +171,8 @@ class FieldDescriptions: sd3_model = "SD3 model (MMDiTX) to load" cogview4_model = "CogView4 model (Transformer) to load" z_image_model = "Z-Image model (Transformer) to load" + qwen_image_model = "Qwen Image Edit model (Transformer) to load" + qwen_vl_encoder = "Qwen2.5-VL tokenizer, processor and text/vision encoder" sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load" sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load" onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load" @@ -340,6 +342,27 @@ class ZImageConditioningField(BaseModel): ) +class QwenImageConditioningField(BaseModel): + """A Qwen Image Edit conditioning tensor primitive value""" + + conditioning_name: str = Field(description="The name of conditioning tensor") + + +class AnimaConditioningField(BaseModel): + """An Anima conditioning tensor primitive value. + + Anima conditioning contains Qwen3 0.6B hidden states and T5-XXL token IDs, + which are combined by the LLM Adapter inside the transformer. 
+ """ + + conditioning_name: str = Field(description="The name of conditioning tensor") + mask: Optional[TensorField] = Field( + default=None, + description="The mask associated with this conditioning tensor for regional prompting. " + "Excluded regions should be set to False, included regions should be set to True.", + ) + + class ConditioningField(BaseModel): """A conditioning tensor primitive value""" diff --git a/invokeai/app/invocations/flux2_denoise.py b/invokeai/app/invocations/flux2_denoise.py index c387a72790..1b5ea372d6 100644 --- a/invokeai/app/invocations/flux2_denoise.py +++ b/invokeai/app/invocations/flux2_denoise.py @@ -53,7 +53,7 @@ from invokeai.backend.util.devices import TorchDevice "flux2_denoise", title="FLUX2 Denoise", tags=["image", "flux", "flux2", "klein", "denoise"], - category="image", + category="latents", version="1.4.0", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/flux2_klein_text_encoder.py b/invokeai/app/invocations/flux2_klein_text_encoder.py index 6ca307ebf0..b2728d1d7c 100644 --- a/invokeai/app/invocations/flux2_klein_text_encoder.py +++ b/invokeai/app/invocations/flux2_klein_text_encoder.py @@ -25,6 +25,7 @@ from invokeai.app.invocations.fields import ( from invokeai.app.invocations.model import Qwen3EncoderField from invokeai.app.invocations.primitives import FluxConditioningOutput from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device from invokeai.backend.patches.layer_patcher import LayerPatcher from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_T5_PREFIX from invokeai.backend.patches.model_patch_raw import ModelPatchRaw @@ -44,7 +45,7 @@ KLEIN_MAX_SEQ_LEN = 512 "flux2_klein_text_encoder", title="Prompt - Flux2 Klein", tags=["prompt", "conditioning", "flux", "klein", "qwen3"], - category="conditioning", + category="prompt", version="1.1.1", classification=Classification.Prototype, ) @@ -100,7 +101,12 @@ class Flux2KleinTextEncoderInvocation(BaseInvocation): tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer) (_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device()) - device = text_encoder.device + repaired_tensors = text_encoder_info.repair_required_tensors_on_device() + device = get_effective_device(text_encoder) + if repaired_tensors > 0: + context.logger.warning( + f"Recovered {repaired_tensors} required Qwen3 tensor(s) onto {device} after a partial device mismatch." 
+ ) # Apply LoRA models lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device) diff --git a/invokeai/app/invocations/flux_controlnet.py b/invokeai/app/invocations/flux_controlnet.py index 8228484375..b11d497f31 100644 --- a/invokeai/app/invocations/flux_controlnet.py +++ b/invokeai/app/invocations/flux_controlnet.py @@ -50,7 +50,7 @@ class FluxControlNetOutput(BaseInvocationOutput): "flux_controlnet", title="FLUX ControlNet", tags=["controlnet", "flux"], - category="controlnet", + category="conditioning", version="1.0.0", ) class FluxControlNetInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index d6102b105b..84f0a030c5 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -70,7 +70,7 @@ from invokeai.backend.util.devices import TorchDevice "flux_denoise", title="FLUX Denoise", tags=["image", "flux"], - category="image", + category="latents", version="4.5.1", ) class FluxDenoiseInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/flux_fill.py b/invokeai/app/invocations/flux_fill.py index cff8f2b1e5..440f3e5c97 100644 --- a/invokeai/app/invocations/flux_fill.py +++ b/invokeai/app/invocations/flux_fill.py @@ -29,7 +29,7 @@ class FluxFillOutput(BaseInvocationOutput): "flux_fill", title="FLUX Fill Conditioning", tags=["inpaint"], - category="inpaint", + category="conditioning", version="1.0.0", classification=Classification.Beta, ) diff --git a/invokeai/app/invocations/flux_ip_adapter.py b/invokeai/app/invocations/flux_ip_adapter.py index 4a1997c512..c0d797d0bd 100644 --- a/invokeai/app/invocations/flux_ip_adapter.py +++ b/invokeai/app/invocations/flux_ip_adapter.py @@ -24,7 +24,7 @@ from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType "flux_ip_adapter", title="FLUX IP-Adapter", tags=["ip_adapter", "control"], - category="ip_adapter", + category="conditioning", version="1.0.0", ) class FluxIPAdapterInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/flux_redux.py b/invokeai/app/invocations/flux_redux.py index 403d78b078..b68e9911c5 100644 --- a/invokeai/app/invocations/flux_redux.py +++ b/invokeai/app/invocations/flux_redux.py @@ -47,7 +47,7 @@ DOWNSAMPLING_FUNCTIONS = Literal["nearest", "bilinear", "bicubic", "area", "near "flux_redux", title="FLUX Redux", tags=["ip_adapter", "control"], - category="ip_adapter", + category="conditioning", version="2.1.0", classification=Classification.Beta, ) diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py index 56ebbe7fd9..8b3b33fad1 100644 --- a/invokeai/app/invocations/flux_text_encoder.py +++ b/invokeai/app/invocations/flux_text_encoder.py @@ -28,7 +28,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit "flux_text_encoder", title="Prompt - FLUX", tags=["prompt", "conditioning", "flux"], - category="conditioning", + category="prompt", version="1.1.2", ) class FluxTextEncoderInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/grounding_dino.py b/invokeai/app/invocations/grounding_dino.py index 1e3d5cea0c..4d900c5034 100644 --- a/invokeai/app/invocations/grounding_dino.py +++ b/invokeai/app/invocations/grounding_dino.py @@ -24,7 +24,7 @@ GROUNDING_DINO_MODEL_IDS: dict[GroundingDinoModelKey, str] = { "grounding_dino", title="Grounding DINO (Text Prompt Object Detection)", tags=["prompt", "object detection"], - category="image", + category="segmentation", version="1.0.0", ) 
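# (Illustrative note: Grounding DINO performs text-prompted object detection,
# returning bounding boxes for whatever the prompt names; grouped under
# "segmentation" it typically pairs with a Segment Anything node that turns
# those boxes into pixel masks.)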
class GroundingDinoInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/hed.py b/invokeai/app/invocations/hed.py index 5ea6e8df1f..e2b68143e5 100644 --- a/invokeai/app/invocations/hed.py +++ b/invokeai/app/invocations/hed.py @@ -11,7 +11,7 @@ from invokeai.backend.image_util.hed import ControlNetHED_Apache2, HEDEdgeDetect "hed_edge_detection", title="HED Edge Detection", tags=["controlnet", "hed", "softedge"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class HEDEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/ideal_size.py b/invokeai/app/invocations/ideal_size.py index aae3a37c8e..5cfa9c04d0 100644 --- a/invokeai/app/invocations/ideal_size.py +++ b/invokeai/app/invocations/ideal_size.py @@ -21,6 +21,7 @@ class IdealSizeOutput(BaseInvocationOutput): "ideal_size", title="Ideal Size - SD1.5, SDXL", tags=["latents", "math", "ideal_size"], + category="latents", version="1.0.6", ) class IdealSizeInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 1d5ba44b24..17576a0296 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -21,7 +21,7 @@ from invokeai.app.invocations.fields import ( WithBoard, WithMetadata, ) -from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.invocations.primitives import ImageOutput, StringOutput from invokeai.app.services.image_records.image_records_common import ImageCategory from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import SEED_MAX @@ -197,7 +197,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard): "tomask", title="Mask from Alpha", tags=["image", "mask"], - category="image", + category="mask", version="1.2.2", ) class MaskFromAlphaInvocation(BaseInvocation, WithMetadata, WithBoard): @@ -581,11 +581,30 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithBoard): return ImageOutput.build(image_dto) +@invocation( + "decode_watermark", + title="Decode Invisible Watermark", + tags=["image", "watermark"], + category="image", + version="1.0.0", +) +class DecodeInvisibleWatermarkInvocation(BaseInvocation): + """Decode an invisible watermark from an image.""" + + image: ImageField = InputField(description="The image to decode the watermark from") + length: int = InputField(default=8, description="The expected watermark length in bytes") + + def invoke(self, context: InvocationContext) -> StringOutput: + image = context.images.get_pil(self.image.image_name) + watermark = InvisibleWatermark.decode_watermark(image, self.length) + return StringOutput(value=watermark) + + @invocation( "mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], - category="image", + category="mask", version="1.2.2", ) class MaskEdgeInvocation(BaseInvocation, WithMetadata, WithBoard): @@ -624,7 +643,7 @@ class MaskEdgeInvocation(BaseInvocation, WithMetadata, WithBoard): "mask_combine", title="Combine Masks", tags=["image", "mask", "multiply"], - category="image", + category="mask", version="1.2.2", ) class MaskCombineInvocation(BaseInvocation, WithMetadata, WithBoard): @@ -955,7 +974,7 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard): "save_image", title="Save Image", tags=["primitives", "image"], - category="primitives", + category="image", version="1.2.2", use_cache=False, ) @@ -976,7 +995,7 @@ class SaveImageInvocation(BaseInvocation, 
WithMetadata, WithBoard): "canvas_paste_back", title="Canvas Paste Back", tags=["image", "combine"], - category="image", + category="canvas", version="1.0.1", ) class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard): @@ -1013,7 +1032,7 @@ class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard): "mask_from_id", title="Mask from Segmented Image", tags=["image", "mask", "id"], - category="image", + category="mask", version="1.0.1", ) class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard): @@ -1050,7 +1069,7 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard): "canvas_v2_mask_and_crop", title="Canvas V2 Mask and Crop", tags=["image", "mask", "id"], - category="image", + category="canvas", version="1.0.0", classification=Classification.Deprecated, ) @@ -1091,7 +1110,7 @@ class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard): @invocation( - "expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="image", version="1.0.1" + "expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="mask", version="1.0.1" ) class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard): """Expands a mask with a fade effect. The mask uses black to indicate areas to keep from the generated image and white for areas to discard. @@ -1180,7 +1199,7 @@ class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard): "apply_mask_to_image", title="Apply Mask to Image", tags=["image", "mask", "blend"], - category="image", + category="mask", version="1.0.0", ) class ApplyMaskToImageInvocation(BaseInvocation, WithMetadata, WithBoard): @@ -1355,7 +1374,7 @@ class PasteImageIntoBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoar "flux_kontext_image_prep", title="FLUX Kontext Image Prep", tags=["image", "concatenate", "flux", "kontext"], - category="image", + category="conditioning", version="1.0.0", ) class FluxKontextConcatenateImagesInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/image_panels.py b/invokeai/app/invocations/image_panels.py index bb9aa4995a..71fefbd1c6 100644 --- a/invokeai/app/invocations/image_panels.py +++ b/invokeai/app/invocations/image_panels.py @@ -23,7 +23,7 @@ class ImagePanelCoordinateOutput(BaseInvocationOutput): "image_panel_layout", title="Image Panel Layout", tags=["image", "panel", "layout"], - category="image", + category="canvas", version="1.0.0", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index 2b2931e78f..711f910d58 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -73,7 +73,7 @@ CLIP_VISION_MODEL_MAP: dict[Literal["ViT-L", "ViT-H", "ViT-G"], StarterModel] = "ip_adapter", title="IP-Adapter - SD1.5, SDXL", tags=["ip_adapter", "control"], - category="ip_adapter", + category="conditioning", version="1.5.1", ) class IPAdapterInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/lineart.py b/invokeai/app/invocations/lineart.py index c486c329ec..3ffd51b5b6 100644 --- a/invokeai/app/invocations/lineart.py +++ b/invokeai/app/invocations/lineart.py @@ -11,7 +11,7 @@ from invokeai.backend.image_util.lineart import Generator, LineartEdgeDetector "lineart_edge_detection", title="Lineart Edge Detection", tags=["controlnet", "lineart"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class 
LineartEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/lineart_anime.py b/invokeai/app/invocations/lineart_anime.py index 848756b113..f07476491c 100644 --- a/invokeai/app/invocations/lineart_anime.py +++ b/invokeai/app/invocations/lineart_anime.py @@ -9,7 +9,7 @@ from invokeai.backend.image_util.lineart_anime import LineartAnimeEdgeDetector, "lineart_anime_edge_detection", title="Lineart Anime Edge Detection", tags=["controlnet", "lineart"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class LineartAnimeEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/llava_onevision_vllm.py b/invokeai/app/invocations/llava_onevision_vllm.py index fbd2420590..ff3b801d37 100644 --- a/invokeai/app/invocations/llava_onevision_vllm.py +++ b/invokeai/app/invocations/llava_onevision_vllm.py @@ -19,7 +19,7 @@ from invokeai.backend.util.devices import TorchDevice "llava_onevision_vllm", title="LLaVA OneVision VLLM", tags=["vllm"], - category="vllm", + category="multimodal", version="1.0.0", classification=Classification.Beta, ) diff --git a/invokeai/app/invocations/logic.py b/invokeai/app/invocations/logic.py new file mode 100644 index 0000000000..7cc98afbbc --- /dev/null +++ b/invokeai/app/invocations/logic.py @@ -0,0 +1,34 @@ +from typing import Any, Optional + +from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output +from invokeai.app.invocations.fields import InputField, OutputField, UIType +from invokeai.app.services.shared.invocation_context import InvocationContext + + +@invocation_output("if_output") +class IfInvocationOutput(BaseInvocationOutput): + value: Optional[Any] = OutputField( + default=None, description="The selected value", title="Output", ui_type=UIType.Any + ) + + +@invocation("if", title="If", tags=["logic", "conditional"], category="math", version="1.0.0") +class IfInvocation(BaseInvocation): + """Selects between two optional inputs based on a boolean condition.""" + + condition: bool = InputField(default=False, description="The condition used to select an input", title="Condition") + true_input: Optional[Any] = InputField( + default=None, + description="Selected when the condition is true", + title="True Input", + ui_type=UIType.Any, + ) + false_input: Optional[Any] = InputField( + default=None, + description="Selected when the condition is false", + title="False Input", + ui_type=UIType.Any, + ) + + def invoke(self, context: InvocationContext) -> IfInvocationOutput: + return IfInvocationOutput(value=self.true_input if self.condition else self.false_input) diff --git a/invokeai/app/invocations/mask.py b/invokeai/app/invocations/mask.py index 556ab8801d..49749f43b6 100644 --- a/invokeai/app/invocations/mask.py +++ b/invokeai/app/invocations/mask.py @@ -24,7 +24,7 @@ from invokeai.backend.image_util.util import pil_to_np "rectangle_mask", title="Create Rectangle Mask", tags=["conditioning"], - category="conditioning", + category="mask", version="1.0.1", ) class RectangleMaskInvocation(BaseInvocation, WithMetadata): @@ -55,7 +55,7 @@ class RectangleMaskInvocation(BaseInvocation, WithMetadata): "alpha_mask_to_tensor", title="Alpha Mask to Tensor", tags=["conditioning"], - category="conditioning", + category="mask", version="1.0.0", ) class AlphaMaskToTensorInvocation(BaseInvocation): @@ -83,7 +83,7 @@ class AlphaMaskToTensorInvocation(BaseInvocation): "invert_tensor_mask", 
title="Invert Tensor Mask", tags=["conditioning"], - category="conditioning", + category="mask", version="1.1.0", ) class InvertTensorMaskInvocation(BaseInvocation): @@ -115,7 +115,7 @@ class InvertTensorMaskInvocation(BaseInvocation): "image_mask_to_tensor", title="Image Mask to Tensor", tags=["conditioning"], - category="conditioning", + category="mask", version="1.0.0", ) class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata): diff --git a/invokeai/app/invocations/mediapipe_face.py b/invokeai/app/invocations/mediapipe_face.py index 89fccfc1ac..e81326463c 100644 --- a/invokeai/app/invocations/mediapipe_face.py +++ b/invokeai/app/invocations/mediapipe_face.py @@ -9,7 +9,7 @@ from invokeai.backend.image_util.mediapipe_face import detect_faces "mediapipe_face_detection", title="MediaPipe Face Detection", tags=["controlnet", "face"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class MediaPipeFaceDetectionInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py index bc13b72c7b..da24d8802b 100644 --- a/invokeai/app/invocations/metadata.py +++ b/invokeai/app/invocations/metadata.py @@ -166,6 +166,14 @@ GENERATION_MODES = Literal[ "z_image_img2img", "z_image_inpaint", "z_image_outpaint", + "qwen_image_txt2img", + "qwen_image_img2img", + "qwen_image_inpaint", + "qwen_image_outpaint", + "anima_txt2img", + "anima_img2img", + "anima_inpaint", + "anima_outpaint", ] diff --git a/invokeai/app/invocations/metadata_linked.py b/invokeai/app/invocations/metadata_linked.py index 6a9db3e589..53f2ea7471 100644 --- a/invokeai/app/invocations/metadata_linked.py +++ b/invokeai/app/invocations/metadata_linked.py @@ -621,7 +621,7 @@ class LatentsMetaOutput(LatentsOutput, MetadataOutput): "denoise_latents_meta", title=f"{DenoiseLatentsInvocation.UIConfig.title} + Metadata", tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"], - category="latents", + category="metadata", version="1.1.1", ) class DenoiseLatentsMetaInvocation(DenoiseLatentsInvocation, WithMetadata): @@ -686,7 +686,7 @@ class DenoiseLatentsMetaInvocation(DenoiseLatentsInvocation, WithMetadata): "flux_denoise_meta", title=f"{FluxDenoiseInvocation.UIConfig.title} + Metadata", tags=["flux", "latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"], - category="latents", + category="metadata", version="1.0.1", ) class FluxDenoiseLatentsMetaInvocation(FluxDenoiseInvocation, WithMetadata): @@ -734,7 +734,7 @@ class FluxDenoiseLatentsMetaInvocation(FluxDenoiseInvocation, WithMetadata): "z_image_denoise_meta", title=f"{ZImageDenoiseInvocation.UIConfig.title} + Metadata", tags=["z-image", "latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"], - category="latents", + category="metadata", version="1.0.0", ) class ZImageDenoiseMetaInvocation(ZImageDenoiseInvocation, WithMetadata): diff --git a/invokeai/app/invocations/mlsd.py b/invokeai/app/invocations/mlsd.py index 1526350db8..a2446876c8 100644 --- a/invokeai/app/invocations/mlsd.py +++ b/invokeai/app/invocations/mlsd.py @@ -10,7 +10,7 @@ from invokeai.backend.image_util.mlsd.models.mbv2_mlsd_large import MobileV2_MLS "mlsd_detection", title="MLSD Detection", tags=["controlnet", "mlsd", "edge"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class MLSDDetectionInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/model.py 
b/invokeai/app/invocations/model.py index 29fbe5100c..0c96cdb1d9 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -72,6 +72,13 @@ class GlmEncoderField(BaseModel): text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel") +class QwenVLEncoderField(BaseModel): + """Field for Qwen2.5-VL encoder used by Qwen Image Edit models.""" + + tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel") + text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel") + + class Qwen3EncoderField(BaseModel): """Field for Qwen3 text encoder used by Z-Image models.""" @@ -577,7 +584,7 @@ class SeamlessModeInvocation(BaseInvocation): return SeamlessModeOutput(unet=unet, vae=vae) -@invocation("freeu", title="Apply FreeU - SD1.5, SDXL", tags=["freeu"], category="unet", version="1.0.2") +@invocation("freeu", title="Apply FreeU - SD1.5, SDXL", tags=["freeu"], category="model", version="1.0.2") class FreeUInvocation(BaseInvocation): """ Applies FreeU to the UNet. Suggested values (b1/b2/s1/s2): diff --git a/invokeai/app/invocations/normal_bae.py b/invokeai/app/invocations/normal_bae.py index ebbea869a1..1159927150 100644 --- a/invokeai/app/invocations/normal_bae.py +++ b/invokeai/app/invocations/normal_bae.py @@ -10,7 +10,7 @@ from invokeai.backend.image_util.normal_bae.nets.NNET import NNET "normal_map", title="Normal Map", tags=["controlnet", "normal"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class NormalMapInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/pbr_maps.py b/invokeai/app/invocations/pbr_maps.py index 5e519d38bc..945c3cad59 100644 --- a/invokeai/app/invocations/pbr_maps.py +++ b/invokeai/app/invocations/pbr_maps.py @@ -16,7 +16,9 @@ class PBRMapsOutput(BaseInvocationOutput): displacement_map: ImageField = OutputField(default=None, description="The generated displacement map") -@invocation("pbr_maps", title="PBR Maps", tags=["image", "material"], category="image", version="1.0.0") +@invocation( + "pbr_maps", title="PBR Maps", tags=["image", "material"], category="controlnet_preprocessors", version="1.0.0" +) class PBRMapsInvocation(BaseInvocation, WithMetadata, WithBoard): """Generate Normal, Displacement and Roughness Map from a given image""" diff --git a/invokeai/app/invocations/pidi.py b/invokeai/app/invocations/pidi.py index 47b241ee1f..5d8cab0458 100644 --- a/invokeai/app/invocations/pidi.py +++ b/invokeai/app/invocations/pidi.py @@ -10,7 +10,7 @@ from invokeai.backend.image_util.pidi.model import PiDiNet "pidi_edge_detection", title="PiDiNet Edge Detection", tags=["controlnet", "edge"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class PiDiNetEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index dcb1fc6a45..7ec6c3dc14 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -12,6 +12,7 @@ from invokeai.app.invocations.baseinvocation import ( ) from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR from invokeai.app.invocations.fields import ( + AnimaConditioningField, BoundingBoxField, CogView4ConditioningField, ColorField, @@ -24,6 +25,7 @@ from invokeai.app.invocations.fields import ( InputField, LatentsField, OutputField, + QwenImageConditioningField, SD3ConditioningField, 
TensorField, UIComponent, @@ -473,6 +475,28 @@ class ZImageConditioningOutput(BaseInvocationOutput): return cls(conditioning=ZImageConditioningField(conditioning_name=conditioning_name)) +@invocation_output("qwen_image_conditioning_output") +class QwenImageConditioningOutput(BaseInvocationOutput): + """Base class for nodes that output a Qwen Image Edit conditioning tensor.""" + + conditioning: QwenImageConditioningField = OutputField(description=FieldDescriptions.cond) + + @classmethod + def build(cls, conditioning_name: str) -> "QwenImageConditioningOutput": + return cls(conditioning=QwenImageConditioningField(conditioning_name=conditioning_name)) + + +@invocation_output("anima_conditioning_output") +class AnimaConditioningOutput(BaseInvocationOutput): + """Base class for nodes that output an Anima text conditioning tensor.""" + + conditioning: AnimaConditioningField = OutputField(description=FieldDescriptions.cond) + + @classmethod + def build(cls, conditioning_name: str) -> "AnimaConditioningOutput": + return cls(conditioning=AnimaConditioningField(conditioning_name=conditioning_name)) + + @invocation_output("conditioning_output") class ConditioningOutput(BaseInvocationOutput): """Base class for nodes that output a single conditioning tensor""" diff --git a/invokeai/app/invocations/qwen_image_denoise.py b/invokeai/app/invocations/qwen_image_denoise.py new file mode 100644 index 0000000000..04e21a26c3 --- /dev/null +++ b/invokeai/app/invocations/qwen_image_denoise.py @@ -0,0 +1,490 @@ +from contextlib import ExitStack +from typing import Callable, Iterator, Optional, Tuple + +import torch +import torchvision.transforms as tv_transforms +from diffusers.models.transformers.transformer_qwenimage import QwenImageTransformer2DModel +from torchvision.transforms.functional import resize as tv_resize +from tqdm import tqdm + +from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR +from invokeai.app.invocations.fields import ( + DenoiseMaskField, + FieldDescriptions, + Input, + InputField, + LatentsField, + QwenImageConditioningField, + WithBoard, + WithMetadata, +) +from invokeai.app.invocations.model import TransformerField +from invokeai.app.invocations.primitives import LatentsOutput +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat +from invokeai.backend.patches.layer_patcher import LayerPatcher +from invokeai.backend.patches.lora_conversions.qwen_image_lora_constants import ( + QWEN_IMAGE_EDIT_LORA_TRANSFORMER_PREFIX, +) +from invokeai.backend.patches.model_patch_raw import ModelPatchRaw +from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension +from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import QwenImageConditioningInfo +from invokeai.backend.util.devices import TorchDevice + + +@invocation( + "qwen_image_denoise", + title="Denoise - Qwen Image", + tags=["image", "qwen_image"], + category="image", + version="1.0.0", + classification=Classification.Prototype, +) +class QwenImageDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): + """Run the denoising process with a Qwen Image model.""" + + # If latents is provided, this means we are doing image-to-image. 
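    # (Illustrative note: with a rectified-flow model, image-to-image blends
    # the VAE-encoded init image with noise at the starting sigma, roughly
    #     x_start = sigma_0 * noise + (1 - sigma_0) * init_latents
    # which is what _run_diffusion does below when `latents` is connected;
    # denoising_start controls how much of the schedule is skipped.)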
+ latents: Optional[LatentsField] = InputField( + default=None, description=FieldDescriptions.latents, input=Input.Connection + ) + # Reference image latents (encoded through VAE) to concatenate with noisy latents. + reference_latents: Optional[LatentsField] = InputField( + default=None, + description="Reference image latents to guide generation. Encoded through the VAE.", + input=Input.Connection, + ) + # denoise_mask is used for image-to-image inpainting. Only the masked region is modified. + denoise_mask: Optional[DenoiseMaskField] = InputField( + default=None, description=FieldDescriptions.denoise_mask, input=Input.Connection + ) + denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start) + denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end) + transformer: TransformerField = InputField( + description=FieldDescriptions.qwen_image_model, input=Input.Connection, title="Transformer" + ) + positive_conditioning: QwenImageConditioningField = InputField( + description=FieldDescriptions.positive_cond, input=Input.Connection + ) + negative_conditioning: Optional[QwenImageConditioningField] = InputField( + default=None, description=FieldDescriptions.negative_cond, input=Input.Connection + ) + cfg_scale: float | list[float] = InputField(default=4.0, description=FieldDescriptions.cfg_scale, title="CFG Scale") + width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.") + height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.") + steps: int = InputField(default=40, gt=0, description=FieldDescriptions.steps) + seed: int = InputField(default=0, description="Randomness seed for reproducibility.") + shift: Optional[float] = InputField( + default=None, + description="Override the sigma schedule shift. " + "When set, uses a fixed shift (e.g. 3.0 for Lightning LoRAs) instead of the default dynamic shifting. 
" + "Leave unset for the base model's default schedule.", + ) + + @torch.no_grad() + def invoke(self, context: InvocationContext) -> LatentsOutput: + latents = self._run_diffusion(context) + latents = latents.detach().to("cpu") + + name = context.tensors.save(tensor=latents) + return LatentsOutput.build(latents_name=name, latents=latents, seed=None) + + def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None: + if self.denoise_mask is None: + return None + mask = context.tensors.load(self.denoise_mask.mask_name) + mask = 1.0 - mask + + _, _, latent_height, latent_width = latents.shape + mask = tv_resize( + img=mask, + size=[latent_height, latent_width], + interpolation=tv_transforms.InterpolationMode.BILINEAR, + antialias=False, + ) + + mask = mask.to(device=latents.device, dtype=latents.dtype) + return mask + + def _load_text_conditioning( + self, + context: InvocationContext, + conditioning_name: str, + dtype: torch.dtype, + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + cond_data = context.conditioning.load(conditioning_name) + assert len(cond_data.conditionings) == 1 + conditioning = cond_data.conditionings[0] + assert isinstance(conditioning, QwenImageConditioningInfo) + conditioning = conditioning.to(dtype=dtype, device=device) + return conditioning.prompt_embeds, conditioning.prompt_embeds_mask + + def _get_noise( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + seed: int, + ) -> torch.Tensor: + rand_device = "cpu" + rand_dtype = torch.float32 + + return torch.randn( + batch_size, + num_channels_latents, + int(height) // LATENT_SCALE_FACTOR, + int(width) // LATENT_SCALE_FACTOR, + device=rand_device, + dtype=rand_dtype, + generator=torch.Generator(device=rand_device).manual_seed(seed), + ).to(device=device, dtype=dtype) + + def _prepare_cfg_scale(self, num_timesteps: int) -> list[float]: + if isinstance(self.cfg_scale, float): + cfg_scale = [self.cfg_scale] * num_timesteps + elif isinstance(self.cfg_scale, list): + assert len(self.cfg_scale) == num_timesteps + cfg_scale = self.cfg_scale + else: + raise ValueError(f"Invalid CFG scale type: {type(self.cfg_scale)}") + return cfg_scale + + @staticmethod + def _pack_latents( + latents: torch.Tensor, batch_size: int, num_channels: int, height: int, width: int + ) -> torch.Tensor: + """Pack 4D latents (B, C, H, W) into 2x2-patched 3D (B, H/2*W/2, C*4).""" + latents = latents.view(batch_size, num_channels, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels * 4) + return latents + + @staticmethod + def _unpack_latents(latents: torch.Tensor, height: int, width: int) -> torch.Tensor: + """Unpack 3D patched latents (B, seq, C*4) back to 4D (B, C, H, W).""" + batch_size, _num_patches, channels = latents.shape + # height/width are in latent space; they must be divisible by 2 for packing + h = 2 * (height // 2) + w = 2 * (width // 2) + latents = latents.view(batch_size, h // 2, w // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + latents = latents.reshape(batch_size, channels // 4, h, w) + return latents + + def _run_diffusion(self, context: InvocationContext): + inference_dtype = torch.bfloat16 + device = TorchDevice.choose_torch_device() + + transformer_info = context.models.load(self.transformer.transformer) + assert isinstance(transformer_info.model, 
QwenImageTransformer2DModel) + + # Load conditioning + pos_prompt_embeds, pos_prompt_mask = self._load_text_conditioning( + context=context, + conditioning_name=self.positive_conditioning.conditioning_name, + dtype=inference_dtype, + device=device, + ) + + neg_prompt_embeds = None + neg_prompt_mask = None + # Match the diffusers pipeline: only enable CFG when cfg_scale > 1 AND negative conditioning is provided. + # With cfg_scale <= 1, the negative prediction is unused, so skip it entirely. + # For per-step arrays, enable CFG if any step has scale > 1. + if isinstance(self.cfg_scale, list): + any_cfg_above_one = any(v > 1.0 for v in self.cfg_scale) + else: + any_cfg_above_one = self.cfg_scale > 1.0 + do_classifier_free_guidance = self.negative_conditioning is not None and any_cfg_above_one + if do_classifier_free_guidance: + neg_prompt_embeds, neg_prompt_mask = self._load_text_conditioning( + context=context, + conditioning_name=self.negative_conditioning.conditioning_name, + dtype=inference_dtype, + device=device, + ) + + # Prepare the timestep / sigma schedule + patch_size = transformer_info.model.config.patch_size + assert isinstance(patch_size, int) + # Output channels is 16 (the actual latent channels) + out_channels = transformer_info.model.config.out_channels + assert isinstance(out_channels, int) + + latent_height = self.height // LATENT_SCALE_FACTOR + latent_width = self.width // LATENT_SCALE_FACTOR + image_seq_len = (latent_height * latent_width) // (patch_size**2) + + # Use the actual FlowMatchEulerDiscreteScheduler to compute sigmas/timesteps, + # exactly matching the diffusers pipeline. + import math + + import numpy as np + from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler + + # Try to load the scheduler config from the model's directory (Diffusers models + # have a scheduler/ subdir). For GGUF models this path doesn't exist, so fall + # back to instantiating the scheduler with the known Qwen Image defaults. + model_path = context.models.get_absolute_path(context.models.get_config(self.transformer.transformer)) + scheduler_path = model_path / "scheduler" + if scheduler_path.is_dir() and (scheduler_path / "scheduler_config.json").exists(): + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(str(scheduler_path), local_files_only=True) + else: + scheduler = FlowMatchEulerDiscreteScheduler( + use_dynamic_shifting=True, + base_shift=0.5, + max_shift=0.9, + base_image_seq_len=256, + max_image_seq_len=8192, + shift_terminal=0.02, + num_train_timesteps=1000, + time_shift_type="exponential", + ) + + if self.shift is not None: + # Lightning LoRA: fixed shift + mu = math.log(self.shift) + else: + # Default dynamic shifting + # Linear interpolation matching diffusers' calculate_shift + base_shift = scheduler.config.get("base_shift", 0.5) + max_shift = scheduler.config.get("max_shift", 0.9) + base_seq = scheduler.config.get("base_image_seq_len", 256) + max_seq = scheduler.config.get("max_image_seq_len", 4096) + m = (max_shift - base_shift) / (max_seq - base_seq) + b = base_shift - m * base_seq + mu = image_seq_len * m + b + + init_sigmas = np.linspace(1.0, 1.0 / self.steps, self.steps).tolist() + scheduler.set_timesteps(sigmas=init_sigmas, mu=mu, device=device) + + # Clip the schedule based on denoising_start/denoising_end to support img2img strength. + # The scheduler's sigmas go from high (noisy) to 0 (clean). We clip to the fractional range. 
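        # Worked example (illustrative): with steps=30, denoising_start=0.4,
        # denoising_end=1.0 -> total_sigmas=30, start_idx=12, end_idx=30, so
        # only the last 18 steps run, i.e. an effective img2img strength of 0.6.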
+ sigmas_sched = scheduler.sigmas # (N+1,) including terminal 0 + if self.denoising_start > 0 or self.denoising_end < 1: + total_sigmas = len(sigmas_sched) - 1 # exclude terminal + start_idx = int(round(self.denoising_start * total_sigmas)) + end_idx = int(round(self.denoising_end * total_sigmas)) + sigmas_sched = sigmas_sched[start_idx : end_idx + 1] # +1 to include the next sigma for dt + # Rebuild timesteps from clipped sigmas (exclude terminal 0) + timesteps_sched = sigmas_sched[:-1] * scheduler.config.num_train_timesteps + else: + timesteps_sched = scheduler.timesteps + + total_steps = len(timesteps_sched) + + cfg_scale = self._prepare_cfg_scale(total_steps) + + # Load initial latents if provided (for img2img) + init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None + if init_latents is not None: + init_latents = init_latents.to(device=device, dtype=inference_dtype) + if init_latents.dim() == 5: + init_latents = init_latents.squeeze(2) + + # Load reference image latents if provided + ref_latents = None + if self.reference_latents is not None: + ref_latents = context.tensors.load(self.reference_latents.latents_name) + ref_latents = ref_latents.to(device=device, dtype=inference_dtype) + # The VAE encoder produces 5D latents (B, C, 1, H, W); squeeze the frame dim + # so we have 4D (B, C, H, W) for packing. + if ref_latents.dim() == 5: + ref_latents = ref_latents.squeeze(2) + + # Generate noise (16 channels - the output latent channels) + noise = self._get_noise( + batch_size=1, + num_channels_latents=out_channels, + height=self.height, + width=self.width, + dtype=inference_dtype, + device=device, + seed=self.seed, + ) + + # Prepare input latent image + if init_latents is not None: + s_0 = sigmas_sched[0].item() + latents = s_0 * noise + (1.0 - s_0) * init_latents + else: + if self.denoising_start > 1e-5: + raise ValueError("denoising_start should be 0 when initial latents are not provided.") + latents = noise + + if total_steps <= 0: + return latents + + # Pack latents into 2x2 patches: (B, C, H, W) -> (B, H/2*W/2, C*4) + latents = self._pack_latents(latents, 1, out_channels, latent_height, latent_width) + + # Determine whether the model uses reference latent conditioning (zero_cond_t). + # Edit models (zero_cond_t=True) expect [noisy_patches ; ref_patches] in the sequence. + # Txt2img models (zero_cond_t=False) only take noisy patches. + has_zero_cond_t = getattr(transformer_info.model, "zero_cond_t", False) or getattr( + transformer_info.model.config, "zero_cond_t", False + ) + use_ref_latents = has_zero_cond_t + + ref_latents_packed = None + if use_ref_latents: + if ref_latents is not None: + _, ref_ch, rh, rw = ref_latents.shape + if rh != latent_height or rw != latent_width: + ref_latents = torch.nn.functional.interpolate( + ref_latents, size=(latent_height, latent_width), mode="bilinear" + ) + else: + # No reference image provided — use zeros so the model still gets the + # expected sequence layout. + ref_latents = torch.zeros( + 1, out_channels, latent_height, latent_width, device=device, dtype=inference_dtype + ) + ref_latents_packed = self._pack_latents(ref_latents, 1, out_channels, latent_height, latent_width) + + # img_shapes tells the transformer the spatial layout of patches. 
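        # Shape walk-through (illustrative, assuming the usual 8x VAE downscale
        # and 2x2 patching): at 1024x1024, latent_height and latent_width are
        # 1024 // 8 = 128, each img_shapes entry is (1, 64, 64), and the packed
        # sequence is 64 * 64 = 4096 tokens; edit models see twice that once
        # the reference patches are concatenated.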
+ if use_ref_latents: + img_shapes = [ + [ + (1, latent_height // 2, latent_width // 2), + (1, latent_height // 2, latent_width // 2), + ] + ] + else: + img_shapes = [ + [ + (1, latent_height // 2, latent_width // 2), + ] + ] + + # Prepare inpaint extension (operates in 4D space, so unpack/repack around it) + inpaint_mask = self._prep_inpaint_mask(context, noise) # noise has the right 4D shape + inpaint_extension: RectifiedFlowInpaintExtension | None = None + if inpaint_mask is not None: + assert init_latents is not None + inpaint_extension = RectifiedFlowInpaintExtension( + init_latents=init_latents, + inpaint_mask=inpaint_mask, + noise=noise, + ) + + step_callback = self._build_step_callback(context) + + step_callback( + PipelineIntermediateState( + step=0, + order=1, + total_steps=total_steps, + timestep=int(timesteps_sched[0].item()) if len(timesteps_sched) > 0 else 0, + latents=self._unpack_latents(latents, latent_height, latent_width), + ), + ) + + noisy_seq_len = latents.shape[1] + + # Determine if the model is quantized — GGUF models need sidecar patching for LoRAs + transformer_config = context.models.get_config(self.transformer.transformer) + model_is_quantized = transformer_config.format in (ModelFormat.GGUFQuantized,) + + with ExitStack() as exit_stack: + (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device()) + assert isinstance(transformer, QwenImageTransformer2DModel) + + # Apply LoRA patches to the transformer + exit_stack.enter_context( + LayerPatcher.apply_smart_model_patches( + model=transformer, + patches=self._lora_iterator(context), + prefix=QWEN_IMAGE_EDIT_LORA_TRANSFORMER_PREFIX, + dtype=inference_dtype, + cached_weights=cached_weights, + force_sidecar_patching=model_is_quantized, + ) + ) + + for step_idx, t in enumerate(tqdm(timesteps_sched)): + # The pipeline passes timestep / 1000 to the transformer + timestep = t.expand(latents.shape[0]).to(inference_dtype) + + # For edit models: concatenate noisy and reference patches along the sequence dim + # For txt2img models: just use noisy patches + if ref_latents_packed is not None: + model_input = torch.cat([latents, ref_latents_packed], dim=1) + else: + model_input = latents + + noise_pred_cond = transformer( + hidden_states=model_input, + encoder_hidden_states=pos_prompt_embeds, + encoder_hidden_states_mask=pos_prompt_mask, + timestep=timestep / 1000, + img_shapes=img_shapes, + return_dict=False, + )[0] + # Only keep the noisy-latent portion of the output + noise_pred_cond = noise_pred_cond[:, :noisy_seq_len] + + if do_classifier_free_guidance and neg_prompt_embeds is not None: + noise_pred_uncond = transformer( + hidden_states=model_input, + encoder_hidden_states=neg_prompt_embeds, + encoder_hidden_states_mask=neg_prompt_mask, + timestep=timestep / 1000, + img_shapes=img_shapes, + return_dict=False, + )[0] + noise_pred_uncond = noise_pred_uncond[:, :noisy_seq_len] + + noise_pred = noise_pred_uncond + cfg_scale[step_idx] * (noise_pred_cond - noise_pred_uncond) + else: + noise_pred = noise_pred_cond + + # Euler step using the (possibly clipped) sigma schedule + sigma_curr = sigmas_sched[step_idx] + sigma_next = sigmas_sched[step_idx + 1] + dt = sigma_next - sigma_curr + latents = latents.to(torch.float32) + dt * noise_pred.to(torch.float32) + latents = latents.to(inference_dtype) + + if inpaint_extension is not None: + sigma_next = sigmas_sched[step_idx + 1].item() + latents_4d = self._unpack_latents(latents, latent_height, latent_width) + latents_4d = 
inpaint_extension.merge_intermediate_latents_with_init_latents(latents_4d, sigma_next) + latents = self._pack_latents(latents_4d, 1, out_channels, latent_height, latent_width) + + step_callback( + PipelineIntermediateState( + step=step_idx + 1, + order=1, + total_steps=total_steps, + timestep=int(t.item()), + latents=self._unpack_latents(latents, latent_height, latent_width), + ), + ) + + # Unpack back to 4D then add frame dim for the video-style VAE: (B, C, 1, H, W) + latents = self._unpack_latents(latents, latent_height, latent_width) + latents = latents.unsqueeze(2) + return latents + + def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]: + def step_callback(state: PipelineIntermediateState) -> None: + context.util.sd_step_callback(state, BaseModelType.QwenImage) + + return step_callback + + def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]: + """Iterate over LoRA models to apply to the transformer.""" + for lora in self.transformer.loras: + lora_info = context.models.load(lora.lora) + if not isinstance(lora_info.model, ModelPatchRaw): + raise TypeError( + f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}." + ) + yield (lora_info.model, lora.weight) + del lora_info diff --git a/invokeai/app/invocations/qwen_image_image_to_latents.py b/invokeai/app/invocations/qwen_image_image_to_latents.py new file mode 100644 index 0000000000..c5fe1b5d5c --- /dev/null +++ b/invokeai/app/invocations/qwen_image_image_to_latents.py @@ -0,0 +1,96 @@ +import einops +import torch +from diffusers.models.autoencoders.autoencoder_kl_qwenimage import AutoencoderKLQwenImage +from PIL import Image as PILImage + +from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation +from invokeai.app.invocations.fields import ( + FieldDescriptions, + ImageField, + Input, + InputField, + WithBoard, + WithMetadata, +) +from invokeai.app.invocations.model import VAEField +from invokeai.app.invocations.primitives import LatentsOutput +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.model_manager.load.load_base import LoadedModel +from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor +from invokeai.backend.util.devices import TorchDevice + + +@invocation( + "qwen_image_i2l", + title="Image to Latents - Qwen Image", + tags=["image", "latents", "vae", "i2l", "qwen_image"], + category="image", + version="1.0.0", + classification=Classification.Prototype, +) +class QwenImageImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard): + """Generates latents from an image using the Qwen Image VAE.""" + + image: ImageField = InputField(description="The image to encode.") + vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection) + width: int | None = InputField( + default=None, + description="Resize the image to this width before encoding. If not set, encodes at the image's original size.", + ) + height: int | None = InputField( + default=None, + description="Resize the image to this height before encoding. 
If not set, encodes at the image's original size.", + ) + + @staticmethod + def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor: + with vae_info.model_on_device() as (_, vae): + assert isinstance(vae, AutoencoderKLQwenImage) + + vae.disable_tiling() + + image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype) + with torch.inference_mode(): + # The Qwen Image VAE expects 5D input: (B, C, num_frames, H, W) + if image_tensor.dim() == 4: + image_tensor = image_tensor.unsqueeze(2) + + posterior = vae.encode(image_tensor).latent_dist + # Use mode (argmax) for deterministic encoding, matching diffusers + latents: torch.Tensor = posterior.mode().to(dtype=vae.dtype) + + # Normalize with per-channel latents_mean / latents_std + latents_mean = ( + torch.tensor(vae.config.latents_mean) + .view(1, vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(vae.config.latents_std) + .view(1, vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents = (latents - latents_mean) / latents_std + + return latents + + @torch.no_grad() + def invoke(self, context: InvocationContext) -> LatentsOutput: + image = context.images.get_pil(self.image.image_name) + + # If target dimensions are specified, resize the image BEFORE encoding + # (matching the diffusers pipeline which resizes in pixel space, not latent space). + if self.width is not None and self.height is not None: + image = image.convert("RGB").resize((self.width, self.height), resample=PILImage.LANCZOS) + + image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) + if image_tensor.dim() == 3: + image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w") + + vae_info = context.models.load(self.vae.vae) + + latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor) + + latents = latents.to("cpu") + name = context.tensors.save(tensor=latents) + return LatentsOutput.build(latents_name=name, latents=latents, seed=None) diff --git a/invokeai/app/invocations/qwen_image_latents_to_image.py b/invokeai/app/invocations/qwen_image_latents_to_image.py new file mode 100644 index 0000000000..b3ea39c4bb --- /dev/null +++ b/invokeai/app/invocations/qwen_image_latents_to_image.py @@ -0,0 +1,85 @@ +from contextlib import nullcontext + +import torch +from diffusers.models.autoencoders.autoencoder_kl_qwenimage import AutoencoderKLQwenImage +from einops import rearrange +from PIL import Image + +from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation +from invokeai.app.invocations.fields import ( + FieldDescriptions, + Input, + InputField, + LatentsField, + WithBoard, + WithMetadata, +) +from invokeai.app.invocations.model import VAEField +from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt +from invokeai.backend.util.devices import TorchDevice + + +@invocation( + "qwen_image_l2i", + title="Latents to Image - Qwen Image", + tags=["latents", "image", "vae", "l2i", "qwen_image"], + category="latents", + version="1.0.0", + classification=Classification.Prototype, +) +class QwenImageLatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): + """Generates an image from latents using the Qwen Image VAE.""" + + latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection) + vae: VAEField 
= InputField(description=FieldDescriptions.vae, input=Input.Connection)
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        latents = context.tensors.load(self.latents.latents_name)
+
+        vae_info = context.models.load(self.vae.vae)
+        assert isinstance(vae_info.model, AutoencoderKLQwenImage)
+        with (
+            SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
+            vae_info.model_on_device() as (_, vae),
+        ):
+            context.util.signal_progress("Running VAE")
+            assert isinstance(vae, AutoencoderKLQwenImage)
+            latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype)
+
+            vae.disable_tiling()
+
+            # Tiling is disabled above; nullcontext stands in where a tiled-decode
+            # context would otherwise be entered.
+            tiling_context = nullcontext()
+
+            TorchDevice.empty_cache()
+
+            with torch.inference_mode(), tiling_context:
+                # The Qwen Image VAE uses per-channel latents_mean / latents_std
+                # instead of a single scaling_factor. Note that latents_std below
+                # holds the reciprocal (1.0 / std), so the division that follows
+                # multiplies by the true std, inverting the encoder-side
+                # (x - mean) / std normalization.
+                # Latents are 5D: (B, C, num_frames, H, W) — the unpack from the
+                # denoise step already produces this shape.
+                latents_mean = (
+                    torch.tensor(vae.config.latents_mean)
+                    .view(1, vae.config.z_dim, 1, 1, 1)
+                    .to(latents.device, latents.dtype)
+                )
+                latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to(
+                    latents.device, latents.dtype
+                )
+                latents = latents / latents_std + latents_mean
+
+                img = vae.decode(latents, return_dict=False)[0]
+                # Drop the temporal frame dimension: (B, C, 1, H, W) -> (B, C, H, W)
+                img = img[:, :, 0]
+
+                img = img.clamp(-1, 1)
+                img = rearrange(img[0], "c h w -> h w c")
+                img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
+
+        TorchDevice.empty_cache()
+
+        image_dto = context.images.save(image=img_pil)
+
+        return ImageOutput.build(image_dto)
diff --git a/invokeai/app/invocations/qwen_image_lora_loader.py b/invokeai/app/invocations/qwen_image_lora_loader.py
new file mode 100644
index 0000000000..f670b2d895
--- /dev/null
+++ b/invokeai/app/invocations/qwen_image_lora_loader.py
@@ -0,0 +1,115 @@
+from typing import Optional
+
+from invokeai.app.invocations.baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    Classification,
+    invocation,
+    invocation_output,
+)
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
+from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
+
+
+@invocation_output("qwen_image_lora_loader_output")
+class QwenImageLoRALoaderOutput(BaseInvocationOutput):
+    """Qwen Image LoRA Loader Output"""
+
+    transformer: Optional[TransformerField] = OutputField(
+        default=None, description=FieldDescriptions.transformer, title="Transformer"
+    )
+
+
+@invocation(
+    "qwen_image_lora_loader",
+    title="Apply LoRA - Qwen Image",
+    tags=["lora", "model", "qwen_image"],
+    category="model",
+    version="1.0.0",
+    classification=Classification.Prototype,
+)
+class QwenImageLoRALoaderInvocation(BaseInvocation):
+    """Apply a LoRA model to a Qwen Image transformer."""
+
+    lora: ModelIdentifierField = InputField(
+        description=FieldDescriptions.lora_model,
+        title="LoRA",
+        ui_model_base=BaseModelType.QwenImage,
+        ui_model_type=ModelType.LoRA,
+    )
+    weight: float = InputField(default=1.0, description=FieldDescriptions.lora_weight)
+    transformer: TransformerField | None = InputField(
+        default=None,
+        description=FieldDescriptions.transformer,
+        input=Input.Connection,
+        title="Transformer",
+    )
+
+
+
def invoke(self, context: InvocationContext) -> QwenImageLoRALoaderOutput: + lora_key = self.lora.key + + if not context.models.exists(lora_key): + raise ValueError(f"Unknown lora: {lora_key}!") + + if self.transformer and any(lora.lora.key == lora_key for lora in self.transformer.loras): + raise ValueError(f'LoRA "{lora_key}" already applied to transformer.') + + output = QwenImageLoRALoaderOutput() + + if self.transformer is not None: + output.transformer = self.transformer.model_copy(deep=True) + output.transformer.loras.append( + LoRAField( + lora=self.lora, + weight=self.weight, + ) + ) + + return output + + +@invocation( + "qwen_image_lora_collection_loader", + title="Apply LoRA Collection - Qwen Image", + tags=["lora", "model", "qwen_image"], + category="model", + version="1.0.0", + classification=Classification.Prototype, +) +class QwenImageLoRACollectionLoader(BaseInvocation): + """Applies a collection of LoRAs to a Qwen Image transformer.""" + + loras: Optional[LoRAField | list[LoRAField]] = InputField( + default=None, description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs" + ) + transformer: Optional[TransformerField] = InputField( + default=None, + description=FieldDescriptions.transformer, + input=Input.Connection, + title="Transformer", + ) + + def invoke(self, context: InvocationContext) -> QwenImageLoRALoaderOutput: + output = QwenImageLoRALoaderOutput() + loras = self.loras if isinstance(self.loras, list) else [self.loras] + added_loras: list[str] = [] + + if self.transformer is not None: + output.transformer = self.transformer.model_copy(deep=True) + + for lora in loras: + if lora is None: + continue + if lora.lora.key in added_loras: + continue + if not context.models.exists(lora.lora.key): + raise Exception(f"Unknown lora: {lora.lora.key}!") + + added_loras.append(lora.lora.key) + + if self.transformer is not None and output.transformer is not None: + output.transformer.loras.append(lora) + + return output diff --git a/invokeai/app/invocations/qwen_image_model_loader.py b/invokeai/app/invocations/qwen_image_model_loader.py new file mode 100644 index 0000000000..fd96067f56 --- /dev/null +++ b/invokeai/app/invocations/qwen_image_model_loader.py @@ -0,0 +1,107 @@ +from typing import Optional + +from invokeai.app.invocations.baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + Classification, + invocation, + invocation_output, +) +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField +from invokeai.app.invocations.model import ( + ModelIdentifierField, + QwenVLEncoderField, + TransformerField, + VAEField, +) +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType, SubModelType + + +@invocation_output("qwen_image_model_loader_output") +class QwenImageModelLoaderOutput(BaseInvocationOutput): + """Qwen Image model loader output.""" + + transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer") + qwen_vl_encoder: QwenVLEncoderField = OutputField( + description=FieldDescriptions.qwen_vl_encoder, title="Qwen VL Encoder" + ) + vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE") + + +@invocation( + "qwen_image_model_loader", + title="Main Model - Qwen Image", + tags=["model", "qwen_image"], + category="model", + version="1.1.0", + classification=Classification.Prototype, +) +class 
QwenImageModelLoaderInvocation(BaseInvocation): + """Loads a Qwen Image model, outputting its submodels. + + The transformer is always loaded from the main model (Diffusers or GGUF). + + For GGUF quantized models, the VAE and Qwen VL encoder must come from a + separate Diffusers model specified in the "Component Source" field. + + For Diffusers models, all components are extracted from the main model + automatically. The "Component Source" field is ignored. + """ + + model: ModelIdentifierField = InputField( + description=FieldDescriptions.qwen_image_model, + input=Input.Direct, + ui_model_base=BaseModelType.QwenImage, + ui_model_type=ModelType.Main, + title="Transformer", + ) + + component_source: Optional[ModelIdentifierField] = InputField( + default=None, + description="Diffusers Qwen Image model to extract the VAE and Qwen VL encoder from. " + "Required when using a GGUF quantized transformer. " + "Ignored when the main model is already in Diffusers format.", + input=Input.Direct, + ui_model_base=BaseModelType.QwenImage, + ui_model_type=ModelType.Main, + ui_model_format=ModelFormat.Diffusers, + title="Component Source (Diffusers)", + ) + + def invoke(self, context: InvocationContext) -> QwenImageModelLoaderOutput: + main_config = context.models.get_config(self.model) + main_is_diffusers = main_config.format == ModelFormat.Diffusers + + # Transformer always comes from the main model + transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer}) + + if main_is_diffusers: + # Diffusers model: extract all components directly + vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE}) + tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer}) + text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder}) + elif self.component_source is not None: + # GGUF/checkpoint transformer: get VAE + encoder from the component source + source_config = context.models.get_config(self.component_source) + if source_config.format != ModelFormat.Diffusers: + raise ValueError( + f"The Component Source model must be in Diffusers format. " + f"The selected model '{source_config.name}' is in {source_config.format.value} format." + ) + vae = self.component_source.model_copy(update={"submodel_type": SubModelType.VAE}) + tokenizer = self.component_source.model_copy(update={"submodel_type": SubModelType.Tokenizer}) + text_encoder = self.component_source.model_copy(update={"submodel_type": SubModelType.TextEncoder}) + else: + raise ValueError( + "No source for VAE and Qwen VL encoder. " + "GGUF quantized models only contain the transformer — " + "please set 'Component Source' to a Diffusers Qwen Image model " + "to provide the VAE and text encoder." 
+            )
+
+        return QwenImageModelLoaderOutput(
+            transformer=TransformerField(transformer=transformer, loras=[]),
+            qwen_vl_encoder=QwenVLEncoderField(tokenizer=tokenizer, text_encoder=text_encoder),
+            vae=VAEField(vae=vae),
+        )
diff --git a/invokeai/app/invocations/qwen_image_text_encoder.py b/invokeai/app/invocations/qwen_image_text_encoder.py
new file mode 100644
index 0000000000..a067421452
--- /dev/null
+++ b/invokeai/app/invocations/qwen_image_text_encoder.py
@@ -0,0 +1,298 @@
+from typing import Literal
+
+import torch
+from PIL import Image as PILImage
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    ImageField,
+    Input,
+    InputField,
+    UIComponent,
+)
+from invokeai.app.invocations.model import QwenVLEncoderField
+from invokeai.app.invocations.primitives import QwenImageConditioningOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+    ConditioningFieldData,
+    QwenImageConditioningInfo,
+)
+
+# Prompt templates and drop indices for the two Qwen Image model modes.
+# These are taken directly from the diffusers pipelines.
+
+# Image editing mode (QwenImageEditPipeline)
+_EDIT_SYSTEM_PROMPT = (
+    "Describe the key features of the input image (color, shape, size, texture, objects, background), "
+    "then explain how the user's text instruction should alter or modify the image. "
+    "Generate a new image that meets the user's requirements while maintaining consistency "
+    "with the original input where appropriate."
+)
+_EDIT_DROP_IDX = 64
+
+# Text-to-image mode (QwenImagePipeline)
+_GENERATE_SYSTEM_PROMPT = (
+    "Describe the image by detailing the color, shape, size, texture, quantity, "
+    "text, spatial relationships of the objects and background:"
+)
+_GENERATE_DROP_IDX = 34
+
+_IMAGE_PLACEHOLDER = "<|vision_start|><|image_pad|><|vision_end|>"
+
+
+def _build_prompt(user_prompt: str, num_images: int) -> str:
+    """Build the full prompt with the appropriate template based on whether reference images are provided."""
+    if num_images > 0:
+        # Edit mode: include vision placeholders for reference images
+        image_tokens = _IMAGE_PLACEHOLDER * num_images
+        return (
+            f"<|im_start|>system\n{_EDIT_SYSTEM_PROMPT}<|im_end|>\n"
+            f"<|im_start|>user\n{image_tokens}{user_prompt}<|im_end|>\n"
+            "<|im_start|>assistant\n"
+        )
+    else:
+        # Generate mode: text-only prompt
+        return (
+            f"<|im_start|>system\n{_GENERATE_SYSTEM_PROMPT}<|im_end|>\n"
+            f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
+            "<|im_start|>assistant\n"
+        )
+
+
+@invocation(
+    "qwen_image_text_encoder",
+    title="Prompt - Qwen Image",
+    tags=["prompt", "conditioning", "qwen_image"],
+    category="conditioning",
+    version="1.2.0",
+    classification=Classification.Prototype,
+)
+class QwenImageTextEncoderInvocation(BaseInvocation):
+    """Encodes text and reference images for Qwen Image using Qwen2.5-VL."""
+
+    prompt: str = InputField(description="Text prompt describing the desired edit.", ui_component=UIComponent.Textarea)
+    reference_images: list[ImageField] = InputField(
+        default=[],
+        description="Reference images to guide the edit. The model can use multiple reference images.",
+    )
+    qwen_vl_encoder: QwenVLEncoderField = InputField(
+        title="Qwen VL Encoder",
+        description=FieldDescriptions.qwen_vl_encoder,
+        input=Input.Connection,
+    )
+    quantization: Literal["none", "int8", "nf4"] = InputField(
+        default="none",
+        description="Quantize the Qwen VL encoder to reduce VRAM usage. "
+        "'nf4' (4-bit) saves the most memory, 'int8' (8-bit) is a middle ground.",
+    )
+
+    @staticmethod
+    def _resize_for_vl_encoder(image: PILImage.Image, target_pixels: int = 512 * 512) -> PILImage.Image:
+        """Resize image to fit within target_pixels while preserving aspect ratio.
+
+        Matches the diffusers pipeline's calculate_dimensions logic: the image is resized
+        so its total pixel count is approximately target_pixels, with dimensions rounded to
+        multiples of 32. This prevents large images from producing too many vision tokens
+        which can overwhelm the text prompt.
+        """
+        w, h = image.size
+        aspect = w / h
+        # Compute dimensions that preserve aspect ratio at ~target_pixels total
+        new_w = int((target_pixels * aspect) ** 0.5)
+        new_h = int(target_pixels / new_w)
+        # Round to multiples of 32
+        new_w = max(32, (new_w // 32) * 32)
+        new_h = max(32, (new_h // 32) * 32)
+        if new_w != w or new_h != h:
+            image = image.resize((new_w, new_h), resample=PILImage.LANCZOS)
+        return image
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> QwenImageConditioningOutput:
+        # Load and resize reference images to a ~512*512 pixel budget (matching the diffusers pipeline)
+        pil_images: list[PILImage.Image] = []
+        for img_field in self.reference_images:
+            pil_img = context.images.get_pil(img_field.image_name)
+            pil_img = self._resize_for_vl_encoder(pil_img.convert("RGB"))
+            pil_images.append(pil_img)
+
+        prompt_embeds, prompt_mask = self._encode(context, pil_images)
+        prompt_embeds = prompt_embeds.detach().to("cpu")
+        prompt_mask = prompt_mask.detach().to("cpu") if prompt_mask is not None else None
+
+        conditioning_data = ConditioningFieldData(
+            conditionings=[QwenImageConditioningInfo(prompt_embeds=prompt_embeds, prompt_embeds_mask=prompt_mask)]
+        )
+        conditioning_name = context.conditioning.save(conditioning_data)
+        return QwenImageConditioningOutput.build(conditioning_name)
+
+    def _encode(
+        self, context: InvocationContext, images: list[PILImage.Image]
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        """Encode text prompt and reference images using Qwen2.5-VL.
+
+        Matches the diffusers QwenImagePipeline._get_qwen_prompt_embeds logic:
+        1. Format prompt with the mode-specific system template
+        2. Run through Qwen2.5-VL to get hidden states
+        3. Extract valid (non-padding) tokens and drop the system prefix
+        4.
Return padded embeddings + attention mask + """ + from transformers import AutoTokenizer, Qwen2_5_VLProcessor + + try: + from transformers import Qwen2_5_VLImageProcessor as _ImageProcessorCls + except ImportError: + from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( # type: ignore[no-redef] + Qwen2VLImageProcessor as _ImageProcessorCls, + ) + + try: + from transformers import Qwen2_5_VLVideoProcessor as _VideoProcessorCls + except ImportError: + from transformers.models.qwen2_vl.video_processing_qwen2_vl import ( # type: ignore[no-redef] + Qwen2VLVideoProcessor as _VideoProcessorCls, + ) + + # Format the prompt with one vision placeholder per reference image + text = _build_prompt(self.prompt, len(images)) + + # Build the processor + tokenizer_config = context.models.get_config(self.qwen_vl_encoder.tokenizer) + model_root = context.models.get_absolute_path(tokenizer_config) + tokenizer_dir = model_root / "tokenizer" + + tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_dir), local_files_only=True) + + image_processor = None + for search_dir in [model_root / "processor", tokenizer_dir, model_root, model_root / "image_processor"]: + if (search_dir / "preprocessor_config.json").exists(): + image_processor = _ImageProcessorCls.from_pretrained(str(search_dir), local_files_only=True) + break + if image_processor is None: + image_processor = _ImageProcessorCls() + + processor = Qwen2_5_VLProcessor( + tokenizer=tokenizer, + image_processor=image_processor, + video_processor=_VideoProcessorCls(), + ) + + context.util.signal_progress("Running Qwen2.5-VL text/vision encoder") + + if self.quantization != "none": + text_encoder, device, cleanup = self._load_quantized_encoder(context) + else: + text_encoder, device, cleanup = self._load_cached_encoder(context) + + try: + model_inputs = processor( + text=[text], + images=images if images else None, + padding=True, + return_tensors="pt", + ).to(device=device) + + outputs = text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=getattr(model_inputs, "pixel_values", None), + image_grid_thw=getattr(model_inputs, "image_grid_thw", None), + output_hidden_states=True, + ) + + # Use last hidden state (matching diffusers pipeline) + hidden_states = outputs.hidden_states[-1] + + # Extract valid (non-padding) tokens using the attention mask, + # then drop the system prompt prefix tokens. + # The drop index differs between edit mode (64) and generate mode (34). 
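+            # As an illustration of the trimming below: for a single prompt whose
+            # attention mask is [1, 1, 1, 1, 0, 0] and whose drop index is 2, the
+            # four valid token embeddings are selected, the first two (system
+            # prefix) tokens are dropped, and the remainder is re-padded to the
+            # longest trimmed sequence in the batch.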
+ drop_idx = _EDIT_DROP_IDX if images else _GENERATE_DROP_IDX + + attn_mask = model_inputs.attention_mask + bool_mask = attn_mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_hidden = torch.split(selected, valid_lengths.tolist(), dim=0) + + # Drop system prefix tokens and build padded output + trimmed = [h[drop_idx:] for h in split_hidden] + attn_mask_list = [torch.ones(h.size(0), dtype=torch.long, device=device) for h in trimmed] + max_seq_len = max(h.size(0) for h in trimmed) + + prompt_embeds = torch.stack( + [torch.cat([h, h.new_zeros(max_seq_len - h.size(0), h.size(1))]) for h in trimmed] + ) + encoder_attention_mask = torch.stack( + [torch.cat([m, m.new_zeros(max_seq_len - m.size(0))]) for m in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=torch.bfloat16) + finally: + if cleanup is not None: + cleanup() + + # If all tokens are valid (no padding), mask is not needed + if encoder_attention_mask.all(): + encoder_attention_mask = None + + return prompt_embeds, encoder_attention_mask + + def _load_cached_encoder(self, context: InvocationContext): + """Load the text encoder through the model cache (no quantization).""" + from transformers import Qwen2_5_VLForConditionalGeneration + + text_encoder_info = context.models.load(self.qwen_vl_encoder.text_encoder) + ctx = text_encoder_info.model_on_device() + _, text_encoder = ctx.__enter__() + device = get_effective_device(text_encoder) + assert isinstance(text_encoder, Qwen2_5_VLForConditionalGeneration) + return text_encoder, device, lambda: ctx.__exit__(None, None, None) + + def _load_quantized_encoder(self, context: InvocationContext): + """Load the text encoder with BitsAndBytes quantization, bypassing the model cache. + + BnB-quantized models are pinned to GPU and can't be moved between devices, + so they can't go through the standard model cache. The model is loaded fresh + each time and freed after use via the cleanup callback. 
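+
+        Both this helper and _load_cached_encoder return a (text_encoder, device,
+        cleanup) tuple, so the caller in _encode can treat the cached and
+        quantized code paths uniformly.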
+ """ + import gc + import warnings + + from transformers import BitsAndBytesConfig, Qwen2_5_VLForConditionalGeneration + + encoder_config = context.models.get_config(self.qwen_vl_encoder.text_encoder) + model_root = context.models.get_absolute_path(encoder_config) + encoder_path = model_root / "text_encoder" + + if self.quantization == "nf4": + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_quant_type="nf4", + ) + else: # int8 + bnb_config = BitsAndBytesConfig(load_in_8bit=True) + + context.util.signal_progress("Loading Qwen2.5-VL encoder (quantized)") + with warnings.catch_warnings(): + # BnB int8 internally casts bfloat16→float16; the warning is harmless + warnings.filterwarnings("ignore", message="MatMul8bitLt.*cast.*float16") + text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained( + str(encoder_path), + quantization_config=bnb_config, + device_map="auto", + torch_dtype=torch.bfloat16, + local_files_only=True, + ) + + device = next(text_encoder.parameters()).device + + def cleanup(): + nonlocal text_encoder + del text_encoder + gc.collect() + torch.cuda.empty_cache() + + return text_encoder, device, cleanup diff --git a/invokeai/app/invocations/sd3_denoise.py b/invokeai/app/invocations/sd3_denoise.py index b9d69369b7..4b990ee42b 100644 --- a/invokeai/app/invocations/sd3_denoise.py +++ b/invokeai/app/invocations/sd3_denoise.py @@ -34,7 +34,7 @@ from invokeai.backend.util.devices import TorchDevice "sd3_denoise", title="Denoise - SD3", tags=["image", "sd3"], - category="image", + category="latents", version="1.1.1", ) class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/sd3_image_to_latents.py b/invokeai/app/invocations/sd3_image_to_latents.py index 71a48ee9ad..9af641d8bc 100644 --- a/invokeai/app/invocations/sd3_image_to_latents.py +++ b/invokeai/app/invocations/sd3_image_to_latents.py @@ -24,7 +24,7 @@ from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory "sd3_i2l", title="Image to Latents - SD3", tags=["image", "latents", "vae", "i2l", "sd3"], - category="image", + category="latents", version="1.0.1", ) class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/sd3_text_encoder.py b/invokeai/app/invocations/sd3_text_encoder.py index 24647c9cfc..7af138fe45 100644 --- a/invokeai/app/invocations/sd3_text_encoder.py +++ b/invokeai/app/invocations/sd3_text_encoder.py @@ -16,6 +16,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField from invokeai.app.invocations.model import CLIPField, T5EncoderField from invokeai.app.invocations.primitives import SD3ConditioningOutput from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device from invokeai.backend.model_manager.taxonomy import ModelFormat from invokeai.backend.patches.layer_patcher import LayerPatcher from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX @@ -30,7 +31,7 @@ SD3_T5_MAX_SEQ_LEN = 256 "sd3_text_encoder", title="Prompt - SD3", tags=["prompt", "conditioning", "sd3"], - category="conditioning", + category="prompt", version="1.0.1", ) class Sd3TextEncoderInvocation(BaseInvocation): @@ -103,6 +104,7 @@ class Sd3TextEncoderInvocation(BaseInvocation): context.util.signal_progress("Running T5 encoder") assert isinstance(t5_text_encoder, T5EncoderModel) 
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))
+        t5_device = get_effective_device(t5_text_encoder)
 
         text_inputs = t5_tokenizer(
             prompt,
@@ -125,7 +127,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
                 f" {max_seq_len} tokens: {removed_text}"
             )
 
-        prompt_embeds = t5_text_encoder(text_input_ids.to(t5_text_encoder.device))[0]
+        prompt_embeds = t5_text_encoder(text_input_ids.to(t5_device))[0]
 
         assert isinstance(prompt_embeds, torch.Tensor)
         return prompt_embeds
@@ -144,6 +146,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
         context.util.signal_progress("Running CLIP encoder")
         assert isinstance(clip_text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
         assert isinstance(clip_tokenizer, CLIPTokenizer)
+        clip_device = get_effective_device(clip_text_encoder)
 
         clip_text_encoder_config = clip_text_encoder_info.config
         assert clip_text_encoder_config is not None
@@ -187,9 +190,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {tokenizer_max_length} tokens: {removed_text}"
             )
-        prompt_embeds = clip_text_encoder(
-            input_ids=text_input_ids.to(clip_text_encoder.device), output_hidden_states=True
-        )
+        prompt_embeds = clip_text_encoder(input_ids=text_input_ids.to(clip_device), output_hidden_states=True)
 
         pooled_prompt_embeds = prompt_embeds[0]
         prompt_embeds = prompt_embeds.hidden_states[-2]
diff --git a/invokeai/app/invocations/strings.py b/invokeai/app/invocations/strings.py
index 2b6bf300b9..6d64e8771a 100644
--- a/invokeai/app/invocations/strings.py
+++ b/invokeai/app/invocations/strings.py
@@ -20,7 +20,7 @@ class StringPosNegOutput(BaseInvocationOutput):
     "string_split_neg",
     title="String Split Negative",
     tags=["string", "split", "negative"],
-    category="string",
+    category="strings",
     version="1.0.1",
 )
 class StringSplitNegInvocation(BaseInvocation):
@@ -63,7 +63,7 @@ class String2Output(BaseInvocationOutput):
     string_2: str = OutputField(description="string 2")
 
 
-@invocation("string_split", title="String Split", tags=["string", "split"], category="string", version="1.0.1")
+@invocation("string_split", title="String Split", tags=["string", "split"], category="strings", version="1.0.1")
 class StringSplitInvocation(BaseInvocation):
     """Splits string into two strings, based on the first occurrence of the delimiter.
The delimiter will be removed from the string""" @@ -83,7 +83,7 @@ class StringSplitInvocation(BaseInvocation): return String2Output(string_1=part1, string_2=part2) -@invocation("string_join", title="String Join", tags=["string", "join"], category="string", version="1.0.1") +@invocation("string_join", title="String Join", tags=["string", "join"], category="strings", version="1.0.1") class StringJoinInvocation(BaseInvocation): """Joins string left to string right""" @@ -94,7 +94,9 @@ class StringJoinInvocation(BaseInvocation): return StringOutput(value=((self.string_left or "") + (self.string_right or ""))) -@invocation("string_join_three", title="String Join Three", tags=["string", "join"], category="string", version="1.0.1") +@invocation( + "string_join_three", title="String Join Three", tags=["string", "join"], category="strings", version="1.0.1" +) class StringJoinThreeInvocation(BaseInvocation): """Joins string left to string middle to string right""" @@ -107,7 +109,7 @@ class StringJoinThreeInvocation(BaseInvocation): @invocation( - "string_replace", title="String Replace", tags=["string", "replace", "regex"], category="string", version="1.0.1" + "string_replace", title="String Replace", tags=["string", "replace", "regex"], category="strings", version="1.0.1" ) class StringReplaceInvocation(BaseInvocation): """Replaces the search string with the replace string""" diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py index 15f1881eef..cf4b7cda47 100644 --- a/invokeai/app/invocations/t2i_adapter.py +++ b/invokeai/app/invocations/t2i_adapter.py @@ -49,7 +49,7 @@ class T2IAdapterOutput(BaseInvocationOutput): "t2i_adapter", title="T2I-Adapter - SD1.5, SDXL", tags=["t2i_adapter", "control"], - category="t2i_adapter", + category="conditioning", version="1.0.4", ) class T2IAdapterInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index e7b3968aec..64e372a0f6 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -30,7 +30,7 @@ ESRGAN_MODEL_URLS: dict[str, str] = { } -@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.2") +@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="upscale", version="1.3.2") class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard): """Upscales an image using RealESRGAN.""" diff --git a/invokeai/app/invocations/z_image_control.py b/invokeai/app/invocations/z_image_control.py index 3b01f12373..f51c2fcd16 100644 --- a/invokeai/app/invocations/z_image_control.py +++ b/invokeai/app/invocations/z_image_control.py @@ -57,7 +57,7 @@ class ZImageControlOutput(BaseInvocationOutput): "z_image_control", title="Z-Image ControlNet", tags=["image", "z-image", "control", "controlnet"], - category="control", + category="conditioning", version="1.1.0", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/z_image_denoise.py b/invokeai/app/invocations/z_image_denoise.py index 6ab43d6657..397e917112 100644 --- a/invokeai/app/invocations/z_image_denoise.py +++ b/invokeai/app/invocations/z_image_denoise.py @@ -49,8 +49,8 @@ from invokeai.backend.z_image.z_image_transformer_patch import patch_transformer "z_image_denoise", title="Denoise - Z-Image", tags=["image", "z-image"], - category="image", - version="1.4.0", + category="latents", + version="1.5.0", classification=Classification.Prototype, ) class 
ZImageDenoiseInvocation(BaseInvocation): @@ -104,6 +104,15 @@ class ZImageDenoiseInvocation(BaseInvocation): description=FieldDescriptions.vae + " Required for control conditioning.", input=Input.Connection, ) + # Shift override for the sigma schedule. If None, shift is auto-calculated from image dimensions. + shift: Optional[float] = InputField( + default=None, + ge=0.0, + description="Override the timestep shift (mu) for the sigma schedule. " + "Leave blank to auto-calculate based on image dimensions (recommended). " + "Lower values (~0.5) produce less noise shifting, higher values (~1.15) produce more.", + title="Shift", + ) # Scheduler selection for the denoising process scheduler: ZIMAGE_SCHEDULER_NAME_VALUES = InputField( default="euler", @@ -225,34 +234,36 @@ class ZImageDenoiseInvocation(BaseInvocation): """Calculate timestep shift based on image sequence length. Based on diffusers ZImagePipeline.calculate_shift method. - """ - m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len) - b = base_shift - m * base_image_seq_len - mu = image_seq_len * m + b - return mu - - def _get_sigmas(self, mu: float, num_steps: int) -> list[float]: - """Generate sigma schedule with time shift. - - Based on FlowMatchEulerDiscreteScheduler with shift. - Generates num_steps + 1 sigma values (including terminal 0.0). + Returns a linear shift value (exp(mu) from the original formula). """ import math - def time_shift(mu: float, sigma: float, t: float) -> float: - """Apply time shift to a single timestep value.""" + m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len) + b = base_shift - m * base_image_seq_len + mu = image_seq_len * m + b + # Convert from exponential mu to linear shift value + return math.exp(mu) + + def _get_sigmas(self, shift: float, num_steps: int) -> list[float]: + """Generate sigma schedule with linear time shift. + + Uses linear time shift: shift / (shift + (1/t - 1)). + The shift value is used directly as a multiplier. + Generates num_steps + 1 sigma values (including terminal 0.0). 
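+
+        For example, shift == 1.0 leaves the schedule unshifted (each sigma equals
+        its linearly spaced t), while larger shift values keep sigmas closer to
+        1.0 for more of the schedule, i.e. more steps are spent at high noise.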
+ """ + + def time_shift(shift: float, t: float) -> float: + """Apply linear time shift to a single timestep value.""" if t <= 0: return 0.0 if t >= 1: return 1.0 - return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + return shift / (shift + (1 / t - 1)) - # Generate linearly spaced values from 1 to 0 (excluding endpoints for safety) - # then apply time shift sigmas = [] for i in range(num_steps + 1): t = 1.0 - i / num_steps # Goes from 1.0 to 0.0 - sigma = time_shift(mu, 1.0, t) + sigma = time_shift(shift, t) sigmas.append(sigma) return sigmas @@ -313,11 +324,14 @@ class ZImageDenoiseInvocation(BaseInvocation): # Concatenate all negative embeddings neg_prompt_embeds = torch.cat([tc.prompt_embeds for tc in neg_text_conditionings], dim=0) - # Calculate shift based on image sequence length - mu = self._calculate_shift(img_seq_len) + # Calculate shift based on image sequence length, or use override + if self.shift is not None: + shift = self.shift + else: + shift = self._calculate_shift(img_seq_len) # Generate sigma schedule with time shift - sigmas = self._get_sigmas(mu, self.steps) + sigmas = self._get_sigmas(shift, self.steps) # Apply denoising_start and denoising_end clipping if self.denoising_start > 0 or self.denoising_end < 1: diff --git a/invokeai/app/invocations/z_image_image_to_latents.py b/invokeai/app/invocations/z_image_image_to_latents.py index 5a70fdba13..263346e296 100644 --- a/invokeai/app/invocations/z_image_image_to_latents.py +++ b/invokeai/app/invocations/z_image_image_to_latents.py @@ -30,7 +30,7 @@ ZImageVAE = Union[AutoencoderKL, FluxAutoEncoder] "z_image_i2l", title="Image to Latents - Z-Image", tags=["image", "latents", "vae", "i2l", "z-image"], - category="image", + category="latents", version="1.1.0", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/z_image_seed_variance_enhancer.py b/invokeai/app/invocations/z_image_seed_variance_enhancer.py index b24002e971..72819a966a 100644 --- a/invokeai/app/invocations/z_image_seed_variance_enhancer.py +++ b/invokeai/app/invocations/z_image_seed_variance_enhancer.py @@ -19,7 +19,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( "z_image_seed_variance_enhancer", title="Seed Variance Enhancer - Z-Image", tags=["conditioning", "z-image", "variance", "seed"], - category="conditioning", + category="prompt", version="1.0.0", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/z_image_text_encoder.py b/invokeai/app/invocations/z_image_text_encoder.py index 06718c4897..71af6085d0 100644 --- a/invokeai/app/invocations/z_image_text_encoder.py +++ b/invokeai/app/invocations/z_image_text_encoder.py @@ -16,6 +16,7 @@ from invokeai.app.invocations.fields import ( from invokeai.app.invocations.model import Qwen3EncoderField from invokeai.app.invocations.primitives import ZImageConditioningOutput from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device from invokeai.backend.patches.layer_patcher import LayerPatcher from invokeai.backend.patches.lora_conversions.z_image_lora_constants import Z_IMAGE_LORA_QWEN3_PREFIX from invokeai.backend.patches.model_patch_raw import ModelPatchRaw @@ -33,7 +34,7 @@ Z_IMAGE_MAX_SEQ_LEN = 512 "z_image_text_encoder", title="Prompt - Z-Image", tags=["prompt", "conditioning", "z-image"], - category="conditioning", + category="prompt", version="1.1.0", classification=Classification.Prototype, ) @@ -76,11 
+77,17 @@ class ZImageTextEncoderInvocation(BaseInvocation): tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer) with ExitStack() as exit_stack: - (_, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device()) + (cached_weights, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device()) (_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device()) - # Use the device that the text_encoder is actually on - device = text_encoder.device + # Use the device that the text encoder is effectively executing on, and repair any required tensors left on + # the CPU by a previous interrupted run. + repaired_tensors = text_encoder_info.repair_required_tensors_on_device() + device = get_effective_device(text_encoder) + if repaired_tensors > 0: + context.logger.warning( + f"Recovered {repaired_tensors} required Qwen3 tensor(s) onto {device} after a partial device mismatch." + ) # Apply LoRA models to the text encoder lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device) @@ -90,6 +97,7 @@ class ZImageTextEncoderInvocation(BaseInvocation): patches=self._lora_iterator(context), prefix=Z_IMAGE_LORA_QWEN3_PREFIX, dtype=lora_dtype, + cached_weights=cached_weights, ) ) diff --git a/invokeai/app/services/auth/token_service.py b/invokeai/app/services/auth/token_service.py index 9c35261c38..2d766bb90a 100644 --- a/invokeai/app/services/auth/token_service.py +++ b/invokeai/app/services/auth/token_service.py @@ -21,6 +21,7 @@ class TokenData(BaseModel): user_id: str email: str is_admin: bool + remember_me: bool = False def set_jwt_secret(secret: str) -> None: diff --git a/invokeai/app/services/board_records/board_records_common.py b/invokeai/app/services/board_records/board_records_common.py index ab6355a393..b263f264cb 100644 --- a/invokeai/app/services/board_records/board_records_common.py +++ b/invokeai/app/services/board_records/board_records_common.py @@ -9,6 +9,17 @@ from invokeai.app.util.misc import get_iso_timestamp from invokeai.app.util.model_exclude_null import BaseModelExcludeNull +class BoardVisibility(str, Enum, metaclass=MetaEnum): + """The visibility options for a board.""" + + Private = "private" + """Only the board owner (and admins) can see and modify this board.""" + Shared = "shared" + """All users can view this board, but only the owner (and admins) can modify it.""" + Public = "public" + """All users can view this board; only the owner (and admins) can modify its structure.""" + + class BoardRecord(BaseModelExcludeNull): """Deserialized board record.""" @@ -28,6 +39,10 @@ class BoardRecord(BaseModelExcludeNull): """The name of the cover image of the board.""" archived: bool = Field(description="Whether or not the board is archived.") """Whether or not the board is archived.""" + board_visibility: BoardVisibility = Field( + default=BoardVisibility.Private, description="The visibility of the board." 
+ ) + """The visibility of the board (private, shared, or public).""" def deserialize_board_record(board_dict: dict) -> BoardRecord: @@ -44,6 +59,11 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord: updated_at = board_dict.get("updated_at", get_iso_timestamp()) deleted_at = board_dict.get("deleted_at", get_iso_timestamp()) archived = board_dict.get("archived", False) + board_visibility_raw = board_dict.get("board_visibility", BoardVisibility.Private.value) + try: + board_visibility = BoardVisibility(board_visibility_raw) + except ValueError: + board_visibility = BoardVisibility.Private return BoardRecord( board_id=board_id, @@ -54,6 +74,7 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord: updated_at=updated_at, deleted_at=deleted_at, archived=archived, + board_visibility=board_visibility, ) @@ -61,6 +82,7 @@ class BoardChanges(BaseModel, extra="forbid"): board_name: Optional[str] = Field(default=None, description="The board's new name.", max_length=300) cover_image_name: Optional[str] = Field(default=None, description="The name of the board's new cover image.") archived: Optional[bool] = Field(default=None, description="Whether or not the board is archived") + board_visibility: Optional[BoardVisibility] = Field(default=None, description="The visibility of the board.") class BoardRecordOrderBy(str, Enum, metaclass=MetaEnum): diff --git a/invokeai/app/services/board_records/board_records_sqlite.py b/invokeai/app/services/board_records/board_records_sqlite.py index a54f65686f..1e3e11c8a3 100644 --- a/invokeai/app/services/board_records/board_records_sqlite.py +++ b/invokeai/app/services/board_records/board_records_sqlite.py @@ -116,6 +116,17 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase): (changes.archived, board_id), ) + # Change the visibility of a board + if changes.board_visibility is not None: + cursor.execute( + """--sql + UPDATE boards + SET board_visibility = ? + WHERE board_id = ?; + """, + (changes.board_visibility.value, board_id), + ) + except sqlite3.Error as e: raise BoardRecordSaveException from e return self.get(board_id) @@ -155,7 +166,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase): SELECT DISTINCT boards.* FROM boards LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id - WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1) + WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public')) {archived_filter} ORDER BY {order_by} {direction} LIMIT ? OFFSET ?; @@ -194,14 +205,14 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase): SELECT COUNT(DISTINCT boards.board_id) FROM boards LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id - WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1); + WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public')); """ else: count_query = """ SELECT COUNT(DISTINCT boards.board_id) FROM boards LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id - WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1) + WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public')) AND boards.archived = 0; """ @@ -251,7 +262,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase): SELECT DISTINCT boards.* FROM boards LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id - WHERE (boards.user_id = ? OR shared_boards.user_id = ? 
OR boards.is_public = 1) + WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public')) {archived_filter} ORDER BY LOWER(boards.board_name) {direction} """ @@ -260,7 +271,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase): SELECT DISTINCT boards.* FROM boards LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id - WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1) + WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public')) {archived_filter} ORDER BY {order_by} {direction} """ diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index 617b611f56..6cd4ed0cba 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -7,7 +7,11 @@ class BulkDownloadBase(ABC): @abstractmethod def handler( - self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str] + self, + image_names: Optional[list[str]], + board_id: Optional[str], + bulk_download_item_id: Optional[str], + user_id: str = "system", ) -> None: """ Create a zip file containing the images specified by the given image names or board id. @@ -15,6 +19,7 @@ class BulkDownloadBase(ABC): :param image_names: A list of image names to include in the zip file. :param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file. :param bulk_download_item_id: The bulk_download_item_id that will be used to retrieve the bulk download item when it is prepared, if none is provided a uuid will be generated. + :param user_id: The ID of the user who initiated the download. """ @abstractmethod @@ -42,3 +47,12 @@ class BulkDownloadBase(ABC): :param bulk_download_item_name: The name of the bulk download item. """ + + @abstractmethod + def get_owner(self, bulk_download_item_name: str) -> Optional[str]: + """ + Get the user_id of the user who initiated the download. + + :param bulk_download_item_name: The name of the bulk download item. + :return: The user_id of the owner, or None if not tracked. 
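+
+        Note: the default service tracks ownership in an in-memory map, so
+        ownership information does not survive a restart.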
+ """ diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index dc4f8b1d81..c037e9c5c1 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -25,15 +25,24 @@ class BulkDownloadService(BulkDownloadBase): self._temp_directory = TemporaryDirectory() self._bulk_downloads_folder = Path(self._temp_directory.name) / "bulk_downloads" self._bulk_downloads_folder.mkdir(parents=True, exist_ok=True) + # Track which user owns each download so the fetch endpoint can enforce ownership + self._download_owners: dict[str, str] = {} def handler( - self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str] + self, + image_names: Optional[list[str]], + board_id: Optional[str], + bulk_download_item_id: Optional[str], + user_id: str = "system", ) -> None: bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID bulk_download_item_id = bulk_download_item_id or uuid_string() bulk_download_item_name = bulk_download_item_id + ".zip" - self._signal_job_started(bulk_download_id, bulk_download_item_id, bulk_download_item_name) + # Record ownership so the fetch endpoint can verify the caller + self._download_owners[bulk_download_item_name] = user_id + + self._signal_job_started(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id) try: image_dtos: list[ImageDTO] = [] @@ -46,16 +55,16 @@ class BulkDownloadService(BulkDownloadBase): raise BulkDownloadParametersException() bulk_download_item_name: str = self._create_zip_file(image_dtos, bulk_download_item_id) - self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name) + self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id) except ( ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException, BulkDownloadParametersException, ) as e: - self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e) + self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e, user_id) except Exception as e: - self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e) + self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e, user_id) self._invoker.services.logger.error("Problem bulk downloading images.") raise e @@ -103,43 +112,60 @@ class BulkDownloadService(BulkDownloadBase): return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " " or c == "_" or c == "-"]).rstrip() def _signal_job_started( - self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + self, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + user_id: str = "system", ) -> None: """Signal that a bulk download job has started.""" if self._invoker: assert bulk_download_id is not None self._invoker.services.events.emit_bulk_download_started( - bulk_download_id, bulk_download_item_id, bulk_download_item_name + bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id=user_id ) def _signal_job_completed( - self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + self, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + user_id: str = "system", ) -> None: """Signal that a bulk download job has completed.""" if 
self._invoker: assert bulk_download_id is not None assert bulk_download_item_name is not None self._invoker.services.events.emit_bulk_download_complete( - bulk_download_id, bulk_download_item_id, bulk_download_item_name + bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id=user_id ) def _signal_job_failed( - self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, exception: Exception + self, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + exception: Exception, + user_id: str = "system", ) -> None: """Signal that a bulk download job has failed.""" if self._invoker: assert bulk_download_id is not None assert exception is not None self._invoker.services.events.emit_bulk_download_error( - bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception) + bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception), user_id=user_id ) def stop(self, *args, **kwargs): self._temp_directory.cleanup() + def get_owner(self, bulk_download_item_name: str) -> Optional[str]: + return self._download_owners.get(bulk_download_item_name) + def delete(self, bulk_download_item_name: str) -> None: path = self.get_path(bulk_download_item_name) Path(path).unlink() + self._download_owners.pop(bulk_download_item_name, None) def get_path(self, bulk_download_item_name: str) -> str: path = str(self._bulk_downloads_folder / bulk_download_item_name) diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 1351a94a7b..b391f92020 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -110,7 +110,8 @@ class InvokeAIAppConfig(BaseSettings): force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty). pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting. max_queue_size: Maximum number of items in the session queue. - clear_queue_on_startup: Empties session queue on startup. + clear_queue_on_startup: Empties session queue on startup. If true, disables `max_queue_history`. + max_queue_history: Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true. allow_nodes: List of nodes to allow. Omit to allow all. deny_nodes: List of nodes to deny. Omit to deny none. node_cache_size: How many cached nodes to keep in memory. @@ -204,7 +205,8 @@ class InvokeAIAppConfig(BaseSettings): force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).") pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.") max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.") - clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.") + clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup. 
If true, disables `max_queue_history`.") + max_queue_history: Optional[int] = Field(default=None, ge=0, description="Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true.") # NODES allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.") diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index aa1cbb5e0e..935b422a73 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -100,9 +100,9 @@ class EventServiceBase: """Emitted when a queue item's status changes""" self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status)) - def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None: + def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult", user_id: str = "system") -> None: """Emitted when a batch is enqueued""" - self.dispatch(BatchEnqueuedEvent.build(enqueue_result)) + self.dispatch(BatchEnqueuedEvent.build(enqueue_result, user_id)) def emit_queue_items_retried(self, retry_result: "RetryItemsResult") -> None: """Emitted when a list of queue items are retried""" @@ -112,9 +112,9 @@ class EventServiceBase: """Emitted when a queue is cleared""" self.dispatch(QueueClearedEvent.build(queue_id)) - def emit_recall_parameters_updated(self, queue_id: str, parameters: dict) -> None: + def emit_recall_parameters_updated(self, queue_id: str, user_id: str, parameters: dict) -> None: """Emitted when recall parameters are updated""" - self.dispatch(RecallParametersUpdatedEvent.build(queue_id, parameters)) + self.dispatch(RecallParametersUpdatedEvent.build(queue_id, user_id, parameters)) # endregion @@ -194,23 +194,42 @@ class EventServiceBase: # region Bulk image download def emit_bulk_download_started( - self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + self, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + user_id: str = "system", ) -> None: """Emitted when a bulk image download is started""" - self.dispatch(BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name)) + self.dispatch( + BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id) + ) def emit_bulk_download_complete( - self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + self, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + user_id: str = "system", ) -> None: """Emitted when a bulk image download is complete""" - self.dispatch(BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name)) + self.dispatch( + BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id) + ) def emit_bulk_download_error( - self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str + self, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + error: str, + user_id: str = "system", ) -> None: """Emitted when a bulk image download has an error""" self.dispatch( - BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error) + BulkDownloadErrorEvent.build( + bulk_download_id, bulk_download_item_id, 
bulk_download_item_name, error, user_id + ) ) # endregion diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index bfb44eb48e..998fe4f530 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -281,9 +281,10 @@ class BatchEnqueuedEvent(QueueEventBase): ) priority: int = Field(description="The priority of the batch") origin: str | None = Field(default=None, description="The origin of the batch") + user_id: str = Field(default="system", description="The ID of the user who enqueued the batch") @classmethod - def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent": + def build(cls, enqueue_result: EnqueueBatchResult, user_id: str = "system") -> "BatchEnqueuedEvent": return cls( queue_id=enqueue_result.queue_id, batch_id=enqueue_result.batch.batch_id, @@ -291,6 +292,7 @@ class BatchEnqueuedEvent(QueueEventBase): enqueued=enqueue_result.enqueued, requested=enqueue_result.requested, priority=enqueue_result.priority, + user_id=user_id, ) @@ -609,6 +611,7 @@ class BulkDownloadEventBase(EventBase): bulk_download_id: str = Field(description="The ID of the bulk image download") bulk_download_item_id: str = Field(description="The ID of the bulk image download item") bulk_download_item_name: str = Field(description="The name of the bulk image download item") + user_id: str = Field(default="system", description="The ID of the user who initiated the download") @payload_schema.register @@ -619,12 +622,17 @@ class BulkDownloadStartedEvent(BulkDownloadEventBase): @classmethod def build( - cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + cls, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + user_id: str = "system", ) -> "BulkDownloadStartedEvent": return cls( bulk_download_id=bulk_download_id, bulk_download_item_id=bulk_download_item_id, bulk_download_item_name=bulk_download_item_name, + user_id=user_id, ) @@ -636,12 +644,17 @@ class BulkDownloadCompleteEvent(BulkDownloadEventBase): @classmethod def build( - cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + cls, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + user_id: str = "system", ) -> "BulkDownloadCompleteEvent": return cls( bulk_download_id=bulk_download_id, bulk_download_item_id=bulk_download_item_id, bulk_download_item_name=bulk_download_item_name, + user_id=user_id, ) @@ -655,13 +668,19 @@ class BulkDownloadErrorEvent(BulkDownloadEventBase): @classmethod def build( - cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str + cls, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + error: str, + user_id: str = "system", ) -> "BulkDownloadErrorEvent": return cls( bulk_download_id=bulk_download_id, bulk_download_item_id=bulk_download_item_id, bulk_download_item_name=bulk_download_item_name, error=error, + user_id=user_id, ) @@ -671,8 +690,9 @@ class RecallParametersUpdatedEvent(QueueEventBase): __event_name__ = "recall_parameters_updated" + user_id: str = Field(description="The ID of the user whose recall parameters were updated") parameters: dict[str, Any] = Field(description="The recall parameters that were updated") @classmethod - def build(cls, queue_id: str, parameters: dict[str, Any]) -> "RecallParametersUpdatedEvent": - return cls(queue_id=queue_id, parameters=parameters) + def 
build(cls, queue_id: str, user_id: str, parameters: dict[str, Any]) -> "RecallParametersUpdatedEvent": + return cls(queue_id=queue_id, user_id=user_id, parameters=parameters) diff --git a/invokeai/app/services/events/events_fastapievents.py b/invokeai/app/services/events/events_fastapievents.py index f44eecc555..90e1402773 100644 --- a/invokeai/app/services/events/events_fastapievents.py +++ b/invokeai/app/services/events/events_fastapievents.py @@ -46,3 +46,9 @@ class FastAPIEventService(EventServiceBase): except asyncio.CancelledError as e: raise e # Raise a proper error + except Exception: + import logging + + logging.getLogger("InvokeAI").error( + f"Error dispatching event {getattr(event, '__event_name__', event)}", exc_info=True + ) diff --git a/invokeai/app/services/external_generation/__init__.py b/invokeai/app/services/external_generation/__init__.py index 692da64643..b933811d29 100644 --- a/invokeai/app/services/external_generation/__init__.py +++ b/invokeai/app/services/external_generation/__init__.py @@ -3,9 +3,9 @@ from invokeai.app.services.external_generation.external_generation_base import ( ExternalProvider, ) from invokeai.app.services.external_generation.external_generation_common import ( + ExternalGeneratedImage, ExternalGenerationRequest, ExternalGenerationResult, - ExternalGeneratedImage, ExternalProviderStatus, ExternalReferenceImage, ) diff --git a/invokeai/app/services/external_generation/errors.py b/invokeai/app/services/external_generation/errors.py index 9980b39bc4..f61a6a8c73 100644 --- a/invokeai/app/services/external_generation/errors.py +++ b/invokeai/app/services/external_generation/errors.py @@ -16,3 +16,13 @@ class ExternalProviderCapabilityError(ExternalGenerationError): class ExternalProviderRequestError(ExternalGenerationError): """Raised when a provider rejects the request or returns an error.""" + + +class ExternalProviderRateLimitError(ExternalProviderRequestError): + """Raised when a provider returns HTTP 429 (rate limit exceeded).""" + + retry_after: float | None + + def __init__(self, message: str, retry_after: float | None = None) -> None: + super().__init__(message) + self.retry_after = retry_after diff --git a/invokeai/app/services/external_generation/external_generation_common.py b/invokeai/app/services/external_generation/external_generation_common.py index c1e2f4706f..f14bff52dd 100644 --- a/invokeai/app/services/external_generation/external_generation_common.py +++ b/invokeai/app/services/external_generation/external_generation_common.py @@ -11,8 +11,6 @@ from invokeai.backend.model_manager.configs.external_api import ExternalApiModel @dataclass(frozen=True) class ExternalReferenceImage: image: PILImageType - weight: float | None = None - mode: str | None = None @dataclass(frozen=True) @@ -20,17 +18,16 @@ class ExternalGenerationRequest: model: ExternalApiModelConfig mode: ExternalGenerationMode prompt: str - negative_prompt: str | None seed: int | None num_images: int width: int height: int - steps: int | None - guidance: float | None + image_size: str | None init_image: PILImageType | None mask_image: PILImageType | None reference_images: list[ExternalReferenceImage] metadata: dict[str, Any] | None + provider_options: dict[str, Any] | None = None @dataclass(frozen=True) diff --git a/invokeai/app/services/external_generation/external_generation_default.py b/invokeai/app/services/external_generation/external_generation_default.py index ff54d71476..2622aa9b1c 100644 --- 
a/invokeai/app/services/external_generation/external_generation_default.py +++ b/invokeai/app/services/external_generation/external_generation_default.py @@ -1,5 +1,6 @@ from __future__ import annotations +import time from logging import Logger from typing import TYPE_CHECKING @@ -10,6 +11,7 @@ from invokeai.app.services.external_generation.errors import ( ExternalProviderCapabilityError, ExternalProviderNotConfiguredError, ExternalProviderNotFoundError, + ExternalProviderRateLimitError, ) from invokeai.app.services.external_generation.external_generation_base import ( ExternalGenerationServiceBase, @@ -52,7 +54,7 @@ class ExternalGenerationService(ExternalGenerationServiceBase): request = self._bucket_request(request) self._validate_request(request) - result = provider.generate(request) + result = self._generate_with_retry(provider, request) if resize_to_original_inpaint_size is None: return result @@ -60,6 +62,30 @@ class ExternalGenerationService(ExternalGenerationServiceBase): width, height = resize_to_original_inpaint_size return _resize_result_images(result, width, height) + _MAX_RETRIES = 3 + _DEFAULT_RETRY_DELAY = 10.0 + _MAX_RETRY_DELAY = 60.0 + + def _generate_with_retry( + self, provider: ExternalProvider, request: ExternalGenerationRequest + ) -> ExternalGenerationResult: + for attempt in range(self._MAX_RETRIES): + try: + return provider.generate(request) + except ExternalProviderRateLimitError as exc: + if attempt == self._MAX_RETRIES - 1: + raise + delay = min(exc.retry_after or self._DEFAULT_RETRY_DELAY, self._MAX_RETRY_DELAY) + self._logger.warning( + "Rate limited by %s (attempt %d/%d), retrying in %.0fs", + request.model.provider_id, + attempt + 1, + self._MAX_RETRIES, + delay, + ) + time.sleep(delay) + raise ExternalProviderRateLimitError("Rate limit exceeded after all retries") + def get_provider_statuses(self) -> dict[str, ExternalProviderStatus]: return {provider_id: provider.get_status() for provider_id, provider in self._providers.items()} @@ -77,15 +103,9 @@ class ExternalGenerationService(ExternalGenerationServiceBase): if request.mode not in capabilities.modes: raise ExternalProviderCapabilityError(f"Mode '{request.mode}' is not supported by {request.model.name}") - if request.negative_prompt and not capabilities.supports_negative_prompt: - raise ExternalProviderCapabilityError(f"Negative prompts are not supported by {request.model.name}") - if request.seed is not None and not capabilities.supports_seed: raise ExternalProviderCapabilityError(f"Seed control is not supported by {request.model.name}") - if request.guidance is not None and not capabilities.supports_guidance: - raise ExternalProviderCapabilityError(f"Guidance is not supported by {request.model.name}") - if request.reference_images and not capabilities.supports_reference_images: raise ExternalProviderCapabilityError(f"Reference images are not supported by {request.model.name}") @@ -159,17 +179,16 @@ class ExternalGenerationService(ExternalGenerationServiceBase): model=record, mode=request.mode, prompt=request.prompt, - negative_prompt=request.negative_prompt, seed=request.seed, num_images=request.num_images, width=request.width, height=request.height, - steps=request.steps, - guidance=request.guidance, + image_size=request.image_size, init_image=request.init_image, mask_image=request.mask_image, reference_images=request.reference_images, metadata=request.metadata, + provider_options=request.provider_options, ) def _bucket_request(self, request: ExternalGenerationRequest) -> ExternalGenerationRequest: 
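Aside on the `_generate_with_retry` helper introduced above: it retries a rate-limited request up to `_MAX_RETRIES` times, honors the provider's `Retry-After` hint, caps each wait at `_MAX_RETRY_DELAY`, and re-raises on the final attempt. The sketch below mirrors those semantics in isolation; `RateLimitError` and `FlakyProvider` are illustrative stand-ins, not InvokeAI classes.

```python
# Self-contained sketch of the service-level retry pattern (illustrative names).
import time


class RateLimitError(Exception):
    """Stand-in for ExternalProviderRateLimitError: carries a Retry-After hint."""

    def __init__(self, message: str, retry_after: float | None = None) -> None:
        super().__init__(message)
        self.retry_after = retry_after


class FlakyProvider:
    """Hypothetical provider that is rate-limited twice, then succeeds."""

    def __init__(self) -> None:
        self.calls = 0

    def generate(self, prompt: str) -> str:
        self.calls += 1
        if self.calls < 3:
            raise RateLimitError("HTTP 429", retry_after=0.01)
        return f"image for {prompt!r}"


MAX_RETRIES = 3
DEFAULT_RETRY_DELAY = 10.0
MAX_RETRY_DELAY = 60.0


def generate_with_retry(provider: FlakyProvider, prompt: str) -> str:
    for attempt in range(MAX_RETRIES):
        try:
            return provider.generate(prompt)
        except RateLimitError as exc:
            if attempt == MAX_RETRIES - 1:
                raise  # out of attempts: surface the original rate-limit error
            # Honor the Retry-After hint when present, but never sleep longer
            # than MAX_RETRY_DELAY; fall back to a fixed default otherwise.
            time.sleep(min(exc.retry_after or DEFAULT_RETRY_DELAY, MAX_RETRY_DELAY))
    raise RateLimitError("rate limit exceeded after all retries")


print(generate_with_retry(FlakyProvider(), "a cat"))  # succeeds on the third attempt
```

One behavioral note carried over from the real code: a `retry_after` of `0` is falsy, so `exc.retry_after or DEFAULT_RETRY_DELAY` falls back to the default delay rather than retrying immediately.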
@@ -229,17 +248,16 @@ class ExternalGenerationService(ExternalGenerationServiceBase): model=request.model, mode=request.mode, prompt=request.prompt, - negative_prompt=request.negative_prompt, seed=request.seed, num_images=request.num_images, width=width, height=height, - steps=request.steps, - guidance=request.guidance, + image_size=request.image_size, init_image=_resize_image(request.init_image, width, height, "RGB"), mask_image=_resize_image(request.mask_image, width, height, "L"), reference_images=request.reference_images, metadata=request.metadata, + provider_options=request.provider_options, ) diff --git a/invokeai/app/services/external_generation/providers/gemini.py b/invokeai/app/services/external_generation/providers/gemini.py index 4d43431a14..70cf6a0965 100644 --- a/invokeai/app/services/external_generation/providers/gemini.py +++ b/invokeai/app/services/external_generation/providers/gemini.py @@ -6,7 +6,10 @@ import uuid import requests from PIL.Image import Image as PILImageType -from invokeai.app.services.external_generation.errors import ExternalProviderRequestError +from invokeai.app.services.external_generation.errors import ( + ExternalProviderRateLimitError, + ExternalProviderRequestError, +) from invokeai.app.services.external_generation.external_generation_base import ExternalProvider from invokeai.app.services.external_generation.external_generation_common import ( ExternalGeneratedImage, @@ -64,15 +67,28 @@ class GeminiProvider(ExternalProvider): } ) + opts = request.provider_options or {} + generation_config: dict[str, object] = { "candidateCount": request.num_images, "responseModalities": ["IMAGE"], } + if "temperature" in opts: + generation_config["temperature"] = opts["temperature"] aspect_ratio = _select_aspect_ratio( request.width, request.height, request.model.capabilities.allowed_aspect_ratios, ) + uses_image_config = request.model.capabilities.resolution_presets is not None + if uses_image_config: + image_config: dict[str, str] = {} + if aspect_ratio is not None: + image_config["aspectRatio"] = aspect_ratio + if request.image_size is not None: + image_config["imageSize"] = request.image_size + if image_config: + generation_config["imageConfig"] = image_config system_instruction = self._SYSTEM_INSTRUCTION if request.init_image is not None: system_instruction = ( @@ -80,7 +96,7 @@ class GeminiProvider(ExternalProvider): "Treat the prompt as an edit instruction and modify the image accordingly. " "Do not return the original image unchanged." ) - if aspect_ratio is not None: + if not uses_image_config and aspect_ratio is not None: system_instruction = f"{system_instruction} Use an aspect ratio of {aspect_ratio}." payload: dict[str, object] = { @@ -88,6 +104,8 @@ class GeminiProvider(ExternalProvider): "contents": [{"role": "user", "parts": request_parts}], "generationConfig": generation_config, } + if "thinking_level" in opts: + payload["thinkingConfig"] = {"thinkingLevel": opts["thinking_level"].upper()} self._dump_debug_payload("request", payload) @@ -99,6 +117,12 @@ class GeminiProvider(ExternalProvider): ) if not response.ok: + if response.status_code == 429: + retry_after = _parse_retry_after(response.headers.get("retry-after")) + raise ExternalProviderRateLimitError( + f"Gemini rate limit exceeded. {f'Retry after {retry_after:.0f}s.' 
if retry_after else 'Please try again later.'}", + retry_after=retry_after, + ) raise ExternalProviderRequestError( f"Gemini request failed with status {response.status_code} for model '{model_id}': {response.text}" ) @@ -243,6 +267,15 @@ def _parse_ratio(value: str) -> float | None: return numerator / denominator +def _parse_retry_after(value: str | None) -> float | None: + if not value: + return None + try: + return float(value) + except ValueError: + return None + + def _gcd(a: int, b: int) -> int: while b: a, b = b, a % b diff --git a/invokeai/app/services/external_generation/providers/openai.py b/invokeai/app/services/external_generation/providers/openai.py index f06491a225..033f6cd4a6 100644 --- a/invokeai/app/services/external_generation/providers/openai.py +++ b/invokeai/app/services/external_generation/providers/openai.py @@ -5,7 +5,10 @@ import io import requests from PIL.Image import Image as PILImageType -from invokeai.app.services.external_generation.errors import ExternalProviderRequestError +from invokeai.app.services.external_generation.errors import ( + ExternalProviderRateLimitError, + ExternalProviderRequestError, +) from invokeai.app.services.external_generation.external_generation_base import ExternalProvider from invokeai.app.services.external_generation.external_generation_common import ( ExternalGeneratedImage, @@ -18,6 +21,8 @@ from invokeai.app.services.external_generation.image_utils import decode_image_b class OpenAIProvider(ExternalProvider): provider_id = "openai" + _GPT_IMAGE_MODELS = {"gpt-image-1", "gpt-image-1.5", "gpt-image-1-mini"} + def is_configured(self) -> bool: return bool(self._app_config.external_openai_api_key) @@ -26,21 +31,33 @@ class OpenAIProvider(ExternalProvider): if not api_key: raise ExternalProviderRequestError("OpenAI API key is not configured") + model_id = request.model.provider_model_id + is_gpt_image = model_id in self._GPT_IMAGE_MODELS size = f"{request.width}x{request.height}" base_url = (self._app_config.external_openai_base_url or "https://api.openai.com").rstrip("/") headers = {"Authorization": f"Bearer {api_key}"} use_edits_endpoint = request.mode != "txt2img" or bool(request.reference_images) + opts = request.provider_options or {} + if not use_edits_endpoint: payload: dict[str, object] = { + "model": model_id, "prompt": request.prompt, "n": request.num_images, "size": size, - "response_format": "b64_json", } - if request.seed is not None: - payload["seed"] = request.seed + # GPT Image models use output_format; DALL-E uses response_format + if is_gpt_image: + payload["output_format"] = "png" + else: + payload["response_format"] = "b64_json" + if is_gpt_image: + if opts.get("quality") and opts["quality"] != "auto": + payload["quality"] = opts["quality"] + if opts.get("background") and opts["background"] != "auto": + payload["background"] = opts["background"] response = requests.post( f"{base_url}/v1/images/generations", headers=headers, @@ -72,11 +89,22 @@ class OpenAIProvider(ExternalProvider): files.append(("mask", ("mask.png", mask_buffer, "image/png"))) data: dict[str, object] = { + "model": model_id, "prompt": request.prompt, "n": request.num_images, "size": size, - "response_format": "b64_json", } + if is_gpt_image: + data["output_format"] = "png" + else: + data["response_format"] = "b64_json" + if is_gpt_image: + if opts.get("quality") and opts["quality"] != "auto": + data["quality"] = opts["quality"] + if opts.get("background") and opts["background"] != "auto": + data["background"] = opts["background"] + if 
opts.get("input_fidelity"): + data["input_fidelity"] = opts["input_fidelity"] response = requests.post( f"{base_url}/v1/images/edits", headers=headers, @@ -86,15 +114,21 @@ class OpenAIProvider(ExternalProvider): ) if not response.ok: + if response.status_code == 429: + retry_after = _parse_retry_after(response.headers.get("retry-after")) + raise ExternalProviderRateLimitError( + f"OpenAI rate limit exceeded. {f'Retry after {retry_after:.0f}s.' if retry_after else 'Please try again later.'}", + retry_after=retry_after, + ) raise ExternalProviderRequestError( f"OpenAI request failed with status {response.status_code}: {response.text}" ) - payload = response.json() - if not isinstance(payload, dict): + response_payload = response.json() + if not isinstance(response_payload, dict): raise ExternalProviderRequestError("OpenAI response payload was not a JSON object") images: list[ExternalGeneratedImage] = [] - data_items = payload.get("data") + data_items = response_payload.get("data") if not isinstance(data_items, list): raise ExternalProviderRequestError("OpenAI response payload missing image data") for item in data_items: @@ -112,5 +146,14 @@ class OpenAIProvider(ExternalProvider): images=images, seed_used=request.seed, provider_request_id=response.headers.get("x-request-id"), - provider_metadata={"model": request.model.provider_model_id}, + provider_metadata={"model": model_id}, ) + + +def _parse_retry_after(value: str | None) -> float | None: + if not value: + return None + try: + return float(value) + except ValueError: + return None diff --git a/invokeai/app/services/image_records/image_records_base.py b/invokeai/app/services/image_records/image_records_base.py index 16405c5270..457cf2f468 100644 --- a/invokeai/app/services/image_records/image_records_base.py +++ b/invokeai/app/services/image_records/image_records_base.py @@ -74,8 +74,8 @@ class ImageRecordStorageBase(ABC): pass @abstractmethod - def get_intermediates_count(self) -> int: - """Gets a count of all intermediate images.""" + def get_intermediates_count(self, user_id: Optional[str] = None) -> int: + """Gets a count of intermediate images. If user_id is provided, only counts that user's intermediates.""" pass @abstractmethod @@ -97,6 +97,11 @@ class ImageRecordStorageBase(ABC): """Saves an image record.""" pass + @abstractmethod + def get_user_id(self, image_name: str) -> Optional[str]: + """Gets the user_id of the image owner. 
Returns None if image not found.""" + pass + @abstractmethod def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]: """Gets the most recent image for a board.""" diff --git a/invokeai/app/services/image_records/image_records_sqlite.py b/invokeai/app/services/image_records/image_records_sqlite.py index c6c237fc1e..07126d53a9 100644 --- a/invokeai/app/services/image_records/image_records_sqlite.py +++ b/invokeai/app/services/image_records/image_records_sqlite.py @@ -46,6 +46,20 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): return deserialize_image_record(dict(result)) + def get_user_id(self, image_name: str) -> Optional[str]: + with self._db.transaction() as cursor: + cursor.execute( + """--sql + SELECT user_id FROM images + WHERE image_name = ?; + """, + (image_name,), + ) + result = cast(Optional[sqlite3.Row], cursor.fetchone()) + if not result: + return None + return cast(Optional[str], dict(result).get("user_id")) + def get_metadata(self, image_name: str) -> Optional[MetadataField]: with self._db.transaction() as cursor: try: @@ -269,14 +283,14 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): except sqlite3.Error as e: raise ImageRecordDeleteException from e - def get_intermediates_count(self) -> int: + def get_intermediates_count(self, user_id: Optional[str] = None) -> int: with self._db.transaction() as cursor: - cursor.execute( - """--sql - SELECT COUNT(*) FROM images - WHERE is_intermediate = TRUE; - """ - ) + query = "SELECT COUNT(*) FROM images WHERE is_intermediate = TRUE" + params: list[str] = [] + if user_id is not None: + query += " AND user_id = ?" + params.append(user_id) + cursor.execute(query, params) count = cast(int, cursor.fetchone()[0]) return count diff --git a/invokeai/app/services/images/images_base.py b/invokeai/app/services/images/images_base.py index d11d75b3c1..aebbead2f3 100644 --- a/invokeai/app/services/images/images_base.py +++ b/invokeai/app/services/images/images_base.py @@ -143,8 +143,8 @@ class ImageServiceABC(ABC): pass @abstractmethod - def get_intermediates_count(self) -> int: - """Gets the number of intermediate images.""" + def get_intermediates_count(self, user_id: Optional[str] = None) -> int: + """Gets the number of intermediate images. 
If user_id is provided, only counts that user's intermediates.""" pass @abstractmethod diff --git a/invokeai/app/services/images/images_default.py b/invokeai/app/services/images/images_default.py index e82bd7f4de..0f03f7c400 100644 --- a/invokeai/app/services/images/images_default.py +++ b/invokeai/app/services/images/images_default.py @@ -310,9 +310,9 @@ class ImageService(ImageServiceABC): self.__invoker.services.logger.error("Problem deleting image records and files") raise e - def get_intermediates_count(self) -> int: + def get_intermediates_count(self, user_id: Optional[str] = None) -> int: try: - return self.__invoker.services.image_records.get_intermediates_count() + return self.__invoker.services.image_records.get_intermediates_count(user_id=user_id) except Exception as e: self.__invoker.services.logger.error("Problem getting intermediates count") raise e diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 361c2e4811..49d3cfdf7f 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union import torch import yaml -from huggingface_hub import HfFolder +from huggingface_hub import get_token as hf_get_token from pydantic.networks import AnyHttpUrl from pydantic_core import Url from requests import Session @@ -1115,7 +1115,7 @@ class ModelInstallService(ModelInstallServiceBase): ) -> ModelInstallJob: # Add user's cached access token to HuggingFace requests if source.access_token is None: - source.access_token = HfFolder.get_token() + source.access_token = hf_get_token() remote_files, metadata = self._remote_files_from_source(source) return self._import_remote_model( source=source, diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 318ebb000e..6420949c29 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -30,6 +30,7 @@ from invokeai.backend.model_manager.taxonomy import ( ModelType, ModelVariantType, Qwen3VariantType, + QwenImageVariantType, SchedulerPredictionType, ZImageVariantType, ) @@ -109,7 +110,13 @@ class ModelRecordChanges(BaseModelExcludeNull): # Checkpoint-specific changes # TODO(MM2): Should we expose these? Feels footgun-y... variant: Optional[ - ModelVariantType | ClipVariantType | FluxVariantType | Flux2VariantType | ZImageVariantType | Qwen3VariantType + ModelVariantType + | ClipVariantType + | FluxVariantType + | Flux2VariantType + | ZImageVariantType + | QwenImageVariantType + | Qwen3VariantType ] = Field(description="The variant of the model.", default=None) prediction_type: Optional[SchedulerPredictionType] = Field( description="The prediction type of the model.", default=None diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index bda6ac98e3..7159c19e74 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -133,9 +133,6 @@ class DefaultSessionRunner(SessionRunnerBase): self._on_after_run_node(invocation, queue_item, output) - except KeyboardInterrupt: - # TODO(psyche): This is expected to be caught in the main thread. 
Do we need to catch this here? - pass except CanceledException: # A CanceledException is raised during the denoising step callback if the cancel event is set. We don't need # to do any handling here, and no error should be set - just pass and the cancellation will be handled diff --git a/invokeai/app/services/session_queue/session_queue_base.py b/invokeai/app/services/session_queue/session_queue_base.py index 3c037dc77a..14b93d97fc 100644 --- a/invokeai/app/services/session_queue/session_queue_base.py +++ b/invokeai/app/services/session_queue/session_queue_base.py @@ -78,13 +78,15 @@ class SessionQueueBase(ABC): pass @abstractmethod - def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination: - """Gets the counts of queue items by destination""" + def get_counts_by_destination( + self, queue_id: str, destination: str, user_id: Optional[str] = None + ) -> SessionQueueCountsByDestination: + """Gets the counts of queue items by destination. If user_id is provided, only counts that user's items.""" pass @abstractmethod - def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus: - """Gets the status of a batch""" + def get_batch_status(self, queue_id: str, batch_id: str, user_id: Optional[str] = None) -> BatchStatus: + """Gets the status of a batch. If user_id is provided, only counts that user's items.""" pass @abstractmethod @@ -172,8 +174,9 @@ class SessionQueueBase(ABC): self, queue_id: str, order_dir: SQLiteDirection = SQLiteDirection.Descending, + user_id: Optional[str] = None, ) -> ItemIdsResult: - """Gets all queue item ids that match the given parameters""" + """Gets all queue item ids that match the given parameters. If user_id is provided, only returns items for that user.""" pass @abstractmethod diff --git a/invokeai/app/services/session_queue/session_queue_common.py b/invokeai/app/services/session_queue/session_queue_common.py index 5854442211..09820fe621 100644 --- a/invokeai/app/services/session_queue/session_queue_common.py +++ b/invokeai/app/services/session_queue/session_queue_common.py @@ -304,12 +304,6 @@ class SessionQueueStatus(BaseModel): failed: int = Field(..., description="Number of queue items with status 'error'") canceled: int = Field(..., description="Number of queue items with status 'canceled'") total: int = Field(..., description="Total number of queue items") - user_pending: Optional[int] = Field( - default=None, description="Number of queue items with status 'pending' for the current user" - ) - user_in_progress: Optional[int] = Field( - default=None, description="Number of queue items with status 'in_progress' for the current user" - ) class SessionQueueCountsByDestination(BaseModel): diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index 4f46136fd7..172dc08d55 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -45,10 +45,19 @@ class SqliteSessionQueue(SessionQueueBase): def start(self, invoker: Invoker) -> None: self.__invoker = invoker self._set_in_progress_to_canceled() - if self.__invoker.services.configuration.clear_queue_on_startup: + config = self.__invoker.services.configuration + if config.clear_queue_on_startup: clear_result = self.clear(DEFAULT_QUEUE_ID) if clear_result.deleted > 0: self.__invoker.services.logger.info(f"Cleared all {clear_result.deleted} queue items") + return + + if config.max_queue_history is not 
None: + deleted = self._prune_terminal_to_limit(DEFAULT_QUEUE_ID, config.max_queue_history) + if deleted > 0: + self.__invoker.services.logger.info( + f"Pruned {deleted} completed/failed/canceled queue items (kept up to {config.max_queue_history})" + ) def __init__(self, db: SqliteDatabase) -> None: super().__init__() @@ -68,6 +77,51 @@ class SqliteSessionQueue(SessionQueueBase): """ ) + def _prune_terminal_to_limit(self, queue_id: str, keep: int) -> int: + """Prune terminal items (completed/failed/canceled) to keep at most N most-recent items.""" + with self._db.transaction() as cursor: + where = """--sql + WHERE + queue_id = ? + AND ( + status = 'completed' + OR status = 'failed' + OR status = 'canceled' + ) + """ + cursor.execute( + f"""--sql + SELECT COUNT(*) + FROM session_queue + {where} + AND item_id NOT IN ( + SELECT item_id + FROM session_queue + {where} + ORDER BY COALESCE(completed_at, updated_at, created_at) DESC, item_id DESC + LIMIT ? + ); + """, + (queue_id, queue_id, keep), + ) + count = cursor.fetchone()[0] + cursor.execute( + f"""--sql + DELETE + FROM session_queue + {where} + AND item_id NOT IN ( + SELECT item_id + FROM session_queue + {where} + ORDER BY COALESCE(completed_at, updated_at, created_at) DESC, item_id DESC + LIMIT ? + ); + """, + (queue_id, queue_id, keep), + ) + return count + def _get_current_queue_size(self, queue_id: str) -> int: """Gets the current number of pending queue items""" with self._db.transaction() as cursor: @@ -151,7 +205,7 @@ class SqliteSessionQueue(SessionQueueBase): priority=priority, item_ids=item_ids, ) - self.__invoker.services.events.emit_batch_enqueued(enqueue_result) + self.__invoker.services.events.emit_batch_enqueued(enqueue_result, user_id=user_id) return enqueue_result def dequeue(self) -> Optional[SessionQueueItem]: @@ -765,15 +819,21 @@ class SqliteSessionQueue(SessionQueueBase): self, queue_id: str, order_dir: SQLiteDirection = SQLiteDirection.Descending, + user_id: Optional[str] = None, ) -> ItemIdsResult: with self._db.transaction() as cursor_: - query = f"""--sql + query = """--sql SELECT item_id FROM session_queue WHERE queue_id = ? - ORDER BY created_at {order_dir.value} """ - query_params = [queue_id] + query_params: list[str] = [queue_id] + + if user_id is not None: + query += " AND user_id = ?" + query_params.append(user_id) + + query += f" ORDER BY created_at {order_dir.value}" cursor_.execute(query, query_params) result = cast(list[sqlite3.Row], cursor_.fetchall()) @@ -783,20 +843,7 @@ class SqliteSessionQueue(SessionQueueBase): def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> SessionQueueStatus: with self._db.transaction() as cursor: - # Get total counts - cursor.execute( - """--sql - SELECT status, count(*) - FROM session_queue - WHERE queue_id = ? - GROUP BY status - """, - (queue_id,), - ) - counts_result = cast(list[sqlite3.Row], cursor.fetchall()) - - # Get user-specific counts if user_id is provided (using a single query with CASE) - user_counts_result = [] + # When user_id is provided (non-admin), only count that user's items if user_id is not None: cursor.execute( """--sql @@ -807,48 +854,51 @@ class SqliteSessionQueue(SessionQueueBase): """, (queue_id, user_id), ) - user_counts_result = cast(list[sqlite3.Row], cursor.fetchall()) + else: + cursor.execute( + """--sql + SELECT status, count(*) + FROM session_queue + WHERE queue_id = ? 
+ GROUP BY status + """, + (queue_id,), + ) + counts_result = cast(list[sqlite3.Row], cursor.fetchall()) current_item = self.get_current(queue_id=queue_id) total = sum(row[1] or 0 for row in counts_result) counts: dict[str, int] = {row[0]: row[1] for row in counts_result} - # Process user-specific counts if available - user_pending = None - user_in_progress = None - if user_id is not None: - user_counts: dict[str, int] = {row[0]: row[1] for row in user_counts_result} - user_pending = user_counts.get("pending", 0) - user_in_progress = user_counts.get("in_progress", 0) + # For non-admin users, hide current item details if they don't own it + show_current_item = current_item is not None and (user_id is None or current_item.user_id == user_id) return SessionQueueStatus( queue_id=queue_id, - item_id=current_item.item_id if current_item else None, - session_id=current_item.session_id if current_item else None, - batch_id=current_item.batch_id if current_item else None, + item_id=current_item.item_id if show_current_item else None, + session_id=current_item.session_id if show_current_item else None, + batch_id=current_item.batch_id if show_current_item else None, pending=counts.get("pending", 0), in_progress=counts.get("in_progress", 0), completed=counts.get("completed", 0), failed=counts.get("failed", 0), canceled=counts.get("canceled", 0), total=total, - user_pending=user_pending, - user_in_progress=user_in_progress, ) - def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus: + def get_batch_status(self, queue_id: str, batch_id: str, user_id: Optional[str] = None) -> BatchStatus: with self._db.transaction() as cursor: - cursor.execute( - """--sql + query = """--sql SELECT status, count(*), origin, destination FROM session_queue - WHERE - queue_id = ? - AND batch_id = ? - GROUP BY status - """, - (queue_id, batch_id), - ) + WHERE queue_id = ? AND batch_id = ? + """ + params: list[str] = [queue_id, batch_id] + if user_id is not None: + query += " AND user_id = ?" + params.append(user_id) + query += " GROUP BY status" + cursor.execute(query, params) result = cast(list[sqlite3.Row], cursor.fetchall()) total = sum(row[1] or 0 for row in result) counts: dict[str, int] = {row[0]: row[1] for row in result} @@ -868,18 +918,21 @@ class SqliteSessionQueue(SessionQueueBase): total=total, ) - def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination: + def get_counts_by_destination( + self, queue_id: str, destination: str, user_id: Optional[str] = None + ) -> SessionQueueCountsByDestination: with self._db.transaction() as cursor: - cursor.execute( - """--sql + query = """--sql SELECT status, count(*) FROM session_queue - WHERE queue_id = ? - AND destination = ? - GROUP BY status - """, - (queue_id, destination), - ) + WHERE queue_id = ? AND destination = ? + """ + params: list[str] = [queue_id, destination] + if user_id is not None: + query += " AND user_id = ?" + params.append(user_id) + query += " GROUP BY status" + cursor.execute(query, params) counts_result = cast(list[sqlite3.Row], cursor.fetchall()) total = sum(row[1] or 0 for row in counts_result) diff --git a/invokeai/app/services/shared/README.md b/invokeai/app/services/shared/README.md index a65a8ebe49..113b7a41e5 100644 --- a/invokeai/app/services/shared/README.md +++ b/invokeai/app/services/shared/README.md @@ -6,32 +6,32 @@ High-level design for the graph module. 
Focuses on responsibilities, data flow, Provide a typed, acyclic workflow model (**Graph**) plus a runtime scheduler (**GraphExecutionState**) that expands iterator patterns, tracks readiness via indegree (the number of incoming edges to a node in the directed graph), and -executes nodes in class-grouped batches. Source graphs remain immutable during a run; runtime expansion happens in a -separate execution graph. +executes nodes in class-grouped batches. In normal execution, runtime expansion happens in a separate execution graph +instead of mutating the source graph. ## 2) Major Data Types ### EdgeConnection -* Fields: `node_id: str`, `field: str`. -* Hashable; printed as `node.field` for readable diagnostics. +- Fields: `node_id: str`, `field: str`. +- Hashable; printed as `node.field` for readable diagnostics. ### Edge -* Fields: `source: EdgeConnection`, `destination: EdgeConnection`. -* One directed connection from a specific output port to a specific input port. +- Fields: `source: EdgeConnection`, `destination: EdgeConnection`. +- One directed connection from a specific output port to a specific input port. ### AnyInvocation / AnyInvocationOutput -* Pydantic wrappers that carry concrete invocation models and outputs. -* No registry logic in this file; they are permissive containers for heterogeneous nodes. +- Pydantic wrappers that carry concrete invocation models and outputs. +- No registry logic in this file; they are permissive containers for heterogeneous nodes. ### IterateInvocation / CollectInvocation -* Control nodes used by validation and execution: +- Control nodes used by validation and execution: - * **IterateInvocation**: input `collection`, outputs include `item` (and index/total). - * **CollectInvocation**: many `item` inputs aggregated to one `collection` output. + - **IterateInvocation**: input `collection`, outputs include `item` (and index/total). + - **CollectInvocation**: many `item` inputs aggregated to one `collection` output. ## 3) Graph (author-time model) @@ -39,156 +39,209 @@ A container for declared nodes and edges. Does **not** perform iteration expansi ### 3.1 Data -* `nodes: dict[str, AnyInvocation]` - key must equal `node.id`. -* `edges: list[Edge]` - zero or more. -* Utility: `_get_input_edges(node_id, field?)`, `_get_output_edges(node_id, field?)` - These scan `self.edges` (no adjacency indices in the current code). +- `nodes: dict[str, AnyInvocation]` - key must equal `node.id`. +- `edges: list[Edge]` - zero or more. +- Utility: `_get_input_edges(node_id, field?)`, `_get_output_edges(node_id, field?)` These scan `self.edges` (no + adjacency indices in the current code). ### 3.2 Validation (`validate_self`) Runs a sequence of checks: -1. **Node ID uniqueness** - No duplicate IDs; map key equals `node.id`. -2. **Endpoint existence** - Source and destination node IDs must exist. -3. **Port existence** - Input ports must exist on the node class; output ports on the node's output model. -4. **Type compatibility** - `get_output_field_type` vs `get_input_field_type` and `are_connection_types_compatible`. -5. **DAG constraint** - Build a *flat* `DiGraph` (no runtime expansion) and assert acyclicity. -6. **Iterator / collector structure** - Enforce special rules: +1. **Node ID uniqueness** No duplicate IDs; map key equals `node.id`. - * Iterator's input must be `collection`; its outgoing edges use `item`. - * Collector accepts many `item` inputs; outputs a single `collection`. - * Edge fan-in to a non-collector input is rejected. +1. 
**Endpoint existence** Source and destination node IDs must exist. + +1. **Port existence** Input ports must exist on the node class; output ports on the node's output model. + +1. **DAG constraint** Build a *flat* `DiGraph` (no runtime expansion) and assert acyclicity. + +1. **Type compatibility** `get_output_field_type` vs `get_input_field_type` and `are_connection_types_compatible`. + +1. **Iterator / collector structure** Enforce special rules: + + - Iterator's input must be `collection`; its outgoing edges use `item`. + - Collector accepts many `item` inputs; outputs a single `collection`. + - Edge fan-in to a non-collector input is rejected. ### 3.3 Edge admission (`_validate_edge`) Checks a single prospective edge before insertion: -* Endpoints/ports exist. -* Destination port is not already occupied unless it's a collector `item`. -* Adding the edge to the flat DAG must keep it acyclic. -* Iterator/collector constraints re-checked when the edge creates relevant patterns. +- Endpoints/ports exist. +- Destination port is not already occupied unless it's a collector `item`. +- Adding the edge to the flat DAG must keep it acyclic. +- Iterator/collector constraints re-checked when the edge creates relevant patterns. ### 3.4 Topology utilities -* `nx_graph()` - DiGraph of declared nodes and edges. -* `nx_graph_with_data()` - includes node/edge attributes. -* `nx_graph_flat()` - "flattened" DAG (still author-time; no runtime copies). - Used in validation and in `_prepare()` during execution planning. +- `nx_graph()` - DiGraph of declared nodes and edges. +- `nx_graph_flat()` - "flattened" DAG (still author-time; no runtime copies). Used in validation and in `_prepare()` + during execution planning. ### 3.5 Mutation helpers -* `add_node`, `update_node` (preserve edges, rewrite endpoints if id changes), `delete_node`. -* `add_edge`, `delete_edge` (with validation). +- `add_node`, `update_node` (preserve edges, rewrite endpoints if id changes), `delete_node`. +- `add_edge`, `delete_edge` (with validation). ## 4) GraphExecutionState (runtime) -Holds the state for a single run. Keeps the source graph intact; materializes a separate execution graph. +Holds the state for a single run. Keeps the source graph intact and materializes a separate execution graph. +`GraphExecutionState` is still the public runtime entry point, but most execution behavior is now delegated to a small +set of internal helper classes. + +The source graph is treated as stable during normal execution, but the runtime object still exposes guarded graph +mutation helpers. Those helpers reject changes once the affected nodes have already been prepared or executed. ### 4.1 Data -* `graph: Graph` - immutable source during a run. -* `execution_graph: Graph` - materialized runtime nodes/edges. -* `executed: set[str]`, `executed_history: list[str]`. -* `results: dict[str, AnyInvocationOutput]`, `errors: dict[str, str]`. -* `prepared_source_mapping: dict[str, str]` - exec id → source id. -* `source_prepared_mapping: dict[str, set[str]]` - source id → exec ids. -* `indegree: dict[str, int]` - unmet inputs per exec node. -* **Ready queues grouped by class** (private attrs): - `_ready_queues: dict[class_name, deque[str]]`, `_active_class: Optional[str]`. Optional `ready_order: list[str]` to - prioritize classes. +- `graph: Graph` - source graph for the run; treated as stable during normal execution. +- `execution_graph: Graph` - materialized runtime nodes/edges. +- `executed: set[str]`, `executed_history: list[str]`. 
+- `results: dict[str, AnyInvocationOutput]`, `errors: dict[str, str]`. +- `prepared_source_mapping: dict[str, str]` - exec id -> source id. +- `source_prepared_mapping: dict[str, set[str]]` - source id -> exec ids. +- `indegree: dict[str, int]` - unmet inputs per exec node. +- Prepared exec metadata caches: + - source node id + - iteration path + - runtime state such as pending, ready, executed, or skipped +- **Ready queues grouped by class** (private attrs): `_ready_queues: dict[class_name, deque[str]]`, + `_active_class: Optional[str]`. Optional `ready_order: list[str]` to prioritize classes. ### 4.2 Core methods -* `next()` - Returns the next ready exec node. If none, calls `_prepare()` to materialize more, then retries. Before returning a - node, `_prepare_inputs()` deep-copies inbound values into the node fields. -* `complete(node_id, output)` - Record result; mark exec node executed; if all exec copies of the same **source** are done, mark the source executed. - For each outgoing exec edge, decrement child indegree and enqueue when it reaches zero. +- `next()` Returns the next ready exec node. If none are ready, it asks the materializer to expand more source nodes and + then retries. Before returning a node, the runtime helper deep-copies inbound values into the node fields. +- `complete(node_id, output)` Records the result, marks the exec node executed, marks the source node executed once all + of its prepared exec copies are done, then decrements downstream indegrees and enqueues newly ready nodes. -### 4.3 Preparation (`_prepare()`) +### 4.3 Runtime helper classes -* Build a flat DAG from the **source** graph. -* Choose the **next source node** in topological order that: +`GraphExecutionState` now delegates most runtime behavior to internal helpers: + +- `_PreparedExecRegistry` Owns the relationship between source graph nodes and prepared execution graph nodes, plus + cached metadata such as iteration path and runtime state. +- `_ExecutionMaterializer` Expands source graph nodes into concrete execution graph nodes when the scheduler runs out of + ready work. +- `_ExecutionScheduler` Owns indegree transitions, ready queues, class batching, and downstream release on completion. +- `_ExecutionRuntime` Owns iteration-path lookup and input hydration for prepared exec nodes. +- `_IfBranchScheduler` Applies lazy `If` semantics by deferring branch-local work until the condition is known, then + releasing the selected branch and skipping the unselected branch. + +### 4.4 Preparation (`_prepare()`) + +- Build a flat DAG from the **source** graph. + +- Choose the **next source node** in topological order that: 1. has not been prepared, - 2. if it is an iterator, *its inputs are already executed*, - 3. it has *no unexecuted iterator ancestors*. -* If the node is a **CollectInvocation**: collapse all prepared parents into one mapping and create **one** exec node. -* Otherwise: compute all combinations of prepared iterator ancestors. For each combination, pick the matching prepared parent per upstream and create **one** exec node. -* For each new exec node: + 1. if it is an iterator, *its inputs are already executed*, + 1. it has *no unexecuted iterator ancestors*. - * Deep-copy the source node; assign a fresh ID (and `index` for iterators). - * Wire edges from chosen prepared parents. - * Set `indegree = number of unmet inputs` (i.e., parents not yet executed). - * If `indegree == 0`, enqueue into its class queue. 
+- If the node is a **CollectInvocation**: collapse all prepared parents into one mapping and create **one** exec node. -### 4.4 Readiness and batching +- Otherwise: compute all combinations of prepared iterator ancestors. For each combination, choose the prepared parent + for each upstream by matching iterator ancestry, then create **one** exec node. -* `_enqueue_if_ready(nid)` enqueues by class name only when `indegree == 0` and not executed. -* `_get_next_node()` drains the `_active_class` queue FIFO; when empty, selects the next nonempty class queue (by `ready_order` if set, else alphabetical), and continues. Optional fairness knobs can limit batch size per class; default is drain fully. +- For each new exec node: -#### 4.4.1 Indegree (what it is and how it's used) + - Deep-copy the source node; assign a fresh ID (and `index` for iterators). + - Wire edges from chosen prepared parents. + - Set `indegree = number of unmet inputs` (i.e., parents not yet executed). + - Try to resolve any `If`-specific scheduling state. + - If the node is ready and not deferred by an unresolved `If`, enqueue it into its class queue. + +### 4.5 Readiness and batching + +- `_enqueue_if_ready(nid)` enqueues by class name only when `indegree == 0`, the node has not already executed, and the + node is not deferred by an unresolved `If`. +- `_get_next_node()` drains the `_active_class` queue FIFO; when empty, selects the next nonempty class queue (by + `ready_order` if set, else alphabetical), and continues. Optional fairness knobs can limit batch size per class; + default is drain fully. + +#### 4.5.1 Indegree (what it is and how it's used) **Indegree** is the number of incoming edges to a node in the execution graph that are still unmet. In this engine: -* For every materialized exec node, `indegree[node]` equals the count of its prerequisite parents that have **not** finished yet. -* A node is "ready" exactly when `indegree[node] == 0`; only then is it enqueued. -* When a node completes, the scheduler decrements `indegree[child]` for each outgoing edge. Any child that reaches 0 is enqueued. -Example: edges `A→C`, `B→C`, `C→D`. Start: `A:0, B:0, C:2, D:1`. Run `A` → `C:1`. Run `B` → `C:0` → enqueue `C`. Run `C` -→ `D:0` → enqueue `D`. Run `D` → done. +- For every materialized exec node, `indegree[node]` equals the count of its prerequisite parents that have **not** + finished yet. +- A node is "ready" exactly when `indegree[node] == 0`; only then is it enqueued. +- When a node completes, the scheduler decrements `indegree[child]` for each outgoing edge. Any child that reaches 0 is + enqueued. -### 4.5 Input hydration (`_prepare_inputs()`) +Example: edges `A->C`, `B->C`, `C->D`. Start: `A:0, B:0, C:2, D:1`. Run `A` -> `C:1`. Run `B` -> `C:0` -> enqueue `C`. +Run `C` -> `D:0` -> enqueue `D`. Run `D` -> done. -* For **CollectInvocation**: gather all incoming `item` values into `collection`. -* For all others: deep-copy each incoming edge's value into the destination field. - This prevents cross-node mutation through shared references. +### 4.6 Input hydration (`_prepare_inputs()`) + +- For **CollectInvocation**: gather all incoming `item` values into `collection`, sorting inputs by iteration path so + collected results are stable across expanded iterations. Incoming `collection` values are merged first, then incoming + `item` values are appended. +- For **IfInvocation**: hydrate only `condition` and the selected branch input. +- For all others: deep-copy each incoming edge's value into the destination field. 
This prevents cross-node mutation + through shared references. + +### 4.7 Lazy `If` semantics + +`IfInvocation` now acts as a lazy branch boundary rather than a simple value multiplexer. + +- The `condition` input must resolve first. +- Nodes that are exclusive to the true or false branch can remain deferred even when their indegree is zero. +- Once the prepared `If` node resolves its condition: + - the selected branch is released + - the unselected branch is marked skipped + - branch-exclusive ancestors of the unselected branch are never executed +- Shared ancestors still execute if they are required by the selected branch or by any other live path in the graph. + +This behavior is implemented in the runtime scheduler, not in the invocation body itself. ## 5) Traversal Summary 1. Author builds a valid **Graph**. -2. Create **GraphExecutionState** with that graph. -3. Loop: - * `node = state.next()` → may trigger `_prepare()` expansion. - * Execute node externally → `output`. - * `state.complete(node.id, output)` → updates indegrees and queues. -4. Finish when `next()` returns `None`. +1. Create **GraphExecutionState** with that graph. -The source graph is never mutated; all expansion occurs in `execution_graph` with traceability back to source nodes. +1. Loop: + + - `node = state.next()` -> may trigger `_prepare()` expansion. + - Execute node externally -> `output`. + - `state.complete(node.id, output)` -> updates indegrees, `If` state, and ready queues. + +1. Finish when `next()` returns `None`. + +In normal execution, all runtime expansion occurs in `execution_graph` with traceability back to source nodes. ## 6) Invariants -* Source **Graph** remains a DAG and type-consistent. -* `execution_graph` remains a DAG. -* Nodes are enqueued only when `indegree == 0`. -* `results` and `errors` are keyed by **exec node id**. -* Collectors only aggregate `item` inputs; other inputs behave one-to-one. +- Source **Graph** remains a DAG and type-consistent. +- `execution_graph` remains a DAG. +- Nodes are enqueued only when `indegree == 0` and they are not deferred by an unresolved `If`. +- `results` and `errors` are keyed by **exec node id**. +- Collectors aggregate `item` inputs and may also merge incoming `collection` inputs during runtime hydration. +- Branch-exclusive nodes behind an unselected `If` branch are skipped, not failed. ## 7) Extensibility -* **New node types**: implement as Pydantic models with typed fields and outputs. Register per your invocation system; this file accepts them as `AnyInvocation`. -* **Scheduling policy**: adjust `ready_order` to batch by class; add a batch cap for fairness without changing complexity. -* **Dynamic behaviors** (future): can be added in `GraphExecutionState` by creating exec nodes and edges at `complete()` time, as long as the DAG invariant holds. +- **New node types**: implement as Pydantic models with typed fields and outputs. Register per your invocation system; + this file accepts them as `AnyInvocation`. +- **Scheduling policy**: adjust `ready_order` to batch by class; add a batch cap for fairness without changing + complexity. +- **Dynamic behaviors** (future): can be added in `GraphExecutionState` by creating exec nodes and edges at `complete()` + time, as long as the DAG invariant holds. 
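To make the indegree discipline in section 4.5.1 and the traversal loop in section 5 concrete, here is a toy, self-contained sketch (an illustration of the invariant, not InvokeAI code) that replays the worked example `A->C`, `B->C`, `C->D`:

```python
# Toy model of the indegree + ready-queue scheduling invariant.
from collections import deque

nodes = {"A", "B", "C", "D"}
edges = [("A", "C"), ("B", "C"), ("C", "D")]

indegree = {n: 0 for n in nodes}
children: dict[str, list[str]] = {n: [] for n in nodes}
for src, dst in edges:
    indegree[dst] += 1
    children[src].append(dst)

# A node is enqueued exactly when its indegree reaches zero.
ready = deque(sorted(n for n in nodes if indegree[n] == 0))
order: list[str] = []
while ready:
    node = ready.popleft()
    order.append(node)  # "execute" the node
    for child in children[node]:
        indegree[child] -= 1  # one prerequisite satisfied
        if indegree[child] == 0:
            ready.append(child)  # child becomes ready

print(order)  # ['A', 'B', 'C', 'D']
```

The real scheduler layers class-grouped ready queues, iterator expansion, and `If` deferral on top of this same zero-indegree invariant.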
## 8) Error Model (selected) -* `DuplicateNodeIdError`, `NodeAlreadyInGraphError` -* `NodeNotFoundError`, `NodeFieldNotFoundError` -* `InvalidEdgeError`, `CyclicalGraphError` -* `NodeInputError` (raised when preparing inputs for execution) +- `DuplicateNodeIdError`, `NodeAlreadyInGraphError` +- `NodeNotFoundError`, `NodeFieldNotFoundError` +- `InvalidEdgeError`, `CyclicalGraphError` +- `NodeInputError` (raised when preparing inputs for execution) Messages favor short, precise diagnostics (node id, field, and failing condition). ## 9) Rationale -* **Two-graph approach** isolates authoring from execution expansion and keeps validation simple. -* **Indegree + queues** gives O(1) scheduling decisions with clear batching semantics. -* **Iterator/collector separation** keeps fan-out/fan-in explicit and testable. -* **Deep-copy hydration** avoids incidental aliasing bugs between nodes. +- **Two-graph approach** isolates authoring from execution expansion and keeps validation simple. +- **Indegree + queues** gives O(1) scheduling decisions with clear batching semantics. +- **Iterator/collector separation** keeps fan-out/fan-in explicit and testable. +- **Deep-copy hydration** avoids incidental aliasing bugs between nodes. diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index fd31448ea4..24c1dd1fe4 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -3,7 +3,8 @@ import copy import itertools from collections import deque -from typing import Any, Deque, Iterable, Optional, Type, TypeVar, Union, get_args, get_origin +from dataclasses import dataclass +from typing import Any, Deque, Iterable, Literal, Optional, Type, TypeVar, Union, get_args, get_origin import networkx as nx from pydantic import ( @@ -29,6 +30,7 @@ from invokeai.app.invocations.baseinvocation import ( invocation_output, ) from invokeai.app.invocations.fields import Input, InputField, OutputField, UIType +from invokeai.app.invocations.logic import IfInvocation from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import uuid_string @@ -63,6 +65,731 @@ class Edge(BaseModel): return f"{self.source.node_id}.{self.source.field} -> {self.destination.node_id}.{self.destination.field}" +PreparedExecState = Literal["pending", "ready", "executed", "skipped"] + + +@dataclass +class _PreparedExecNodeMetadata: + """Cached metadata for a materialized execution node.""" + + source_node_id: str + iteration_path: Optional[tuple[int, ...]] = None + state: PreparedExecState = "pending" + + +class _PreparedExecRegistry: + """Tracks prepared execution nodes and their relationship to source graph nodes.""" + + def __init__( + self, + prepared_source_mapping: dict[str, str], + source_prepared_mapping: dict[str, set[str]], + metadata: dict[str, _PreparedExecNodeMetadata], + ) -> None: + self._prepared_source_mapping = prepared_source_mapping + self._source_prepared_mapping = source_prepared_mapping + self._metadata = metadata + + def register(self, exec_node_id: str, source_node_id: str) -> None: + self._prepared_source_mapping[exec_node_id] = source_node_id + self._metadata[exec_node_id] = _PreparedExecNodeMetadata(source_node_id=source_node_id) + if source_node_id not in self._source_prepared_mapping: + self._source_prepared_mapping[source_node_id] = set() + self._source_prepared_mapping[source_node_id].add(exec_node_id) + + def get_metadata(self, exec_node_id: str) -> _PreparedExecNodeMetadata: + metadata = 
self._metadata.get(exec_node_id) + if metadata is None: + metadata = _PreparedExecNodeMetadata(source_node_id=self._prepared_source_mapping[exec_node_id]) + self._metadata[exec_node_id] = metadata + return metadata + + def get_source_node_id(self, exec_node_id: str) -> str: + metadata = self._metadata.get(exec_node_id) + if metadata is not None: + return metadata.source_node_id + return self._prepared_source_mapping[exec_node_id] + + def get_prepared_ids(self, source_node_id: str) -> set[str]: + return self._source_prepared_mapping.get(source_node_id, set()) + + def set_state(self, exec_node_id: str, state: PreparedExecState) -> None: + self.get_metadata(exec_node_id).state = state + + def get_iteration_path(self, exec_node_id: str) -> Optional[tuple[int, ...]]: + metadata = self._metadata.get(exec_node_id) + return metadata.iteration_path if metadata is not None else None + + def set_iteration_path(self, exec_node_id: str, iteration_path: tuple[int, ...]) -> None: + self.get_metadata(exec_node_id).iteration_path = iteration_path + + +class _IfBranchScheduler: + """Applies lazy `If` semantics by deferring, releasing, and skipping branch-local exec nodes.""" + + def __init__(self, state: "GraphExecutionState") -> None: + self._state = state + + def _get_branch_input_sources(self, if_node_id: str, branch_field: str) -> set[str]: + return {e.source.node_id for e in self._state.graph._get_input_edges(if_node_id, branch_field)} + + def _expand_with_ancestors(self, node_ids: set[str]) -> set[str]: + expanded = set(node_ids) + source_graph = self._state.graph.nx_graph_flat() + for node_id in list(expanded): + expanded.update(nx.ancestors(source_graph, node_id)) + return expanded + + def _node_outputs_stay_in_branch( + self, node_id: str, if_node_id: str, branch_field: str, branch_nodes: set[str] + ) -> bool: + output_edges = self._state.graph._get_output_edges(node_id) + return all( + edge.destination.node_id in branch_nodes + or (edge.destination.node_id == if_node_id and edge.destination.field == branch_field) + for edge in output_edges + ) + + def _prune_nonexclusive_branch_nodes( + self, if_node_id: str, branch_field: str, candidate_nodes: set[str] + ) -> set[str]: + exclusive_nodes = set(candidate_nodes) + changed = True + while changed: + changed = False + for node_id in list(exclusive_nodes): + if self._node_outputs_stay_in_branch(node_id, if_node_id, branch_field, exclusive_nodes): + continue + exclusive_nodes.remove(node_id) + changed = True + return exclusive_nodes + + def _get_matching_prepared_if_ids(self, if_node_id: str, iteration_path: tuple[int, ...]) -> list[str]: + prepared_if_ids = self._state._prepared_registry().get_prepared_ids(if_node_id) + return [pid for pid in prepared_if_ids if self._state._get_iteration_path(pid) == iteration_path] + + def _has_unresolved_matching_if(self, if_node_id: str, iteration_path: tuple[int, ...]) -> bool: + matching_prepared_if_ids = self._get_matching_prepared_if_ids(if_node_id, iteration_path) + if not matching_prepared_if_ids: + return True + return not all(pid in self._state._resolved_if_exec_branches for pid in matching_prepared_if_ids) + + def _apply_condition_inputs(self, exec_node_id: str, node: IfInvocation) -> bool: + condition_edges = self._state.execution_graph._get_input_edges(exec_node_id, "condition") + if any(edge.source.node_id not in self._state.executed for edge in condition_edges): + return False + + for edge in condition_edges: + setattr( + node, + edge.destination.field, + 
copy.deepcopy(getattr(self._state.results[edge.source.node_id], edge.source.field)),
+            )
+        return True
+
+    def _get_selected_branch_fields(self, node: IfInvocation) -> tuple[str, str]:
+        selected_field = "true_input" if node.condition else "false_input"
+        unselected_field = "false_input" if node.condition else "true_input"
+        return selected_field, unselected_field
+
+    def _prune_unselected_if_inputs(self, exec_node_id: str, unselected_field: str) -> None:
+        for edge in self._state.execution_graph._get_input_edges(exec_node_id, unselected_field):
+            if edge.source.node_id in self._state.executed:
+                continue
+            if self._state.indegree[exec_node_id] == 0:
+                raise RuntimeError(f"indegree underflow for {exec_node_id} when pruning {unselected_field}")
+            self._state.indegree[exec_node_id] -= 1
+
+    def _apply_branch_resolution(
+        self,
+        exec_node_id: str,
+        iteration_path: tuple[int, ...],
+        exclusive_sources: dict[str, set[str]],
+        selected_field: str,
+        unselected_field: str,
+    ) -> None:
+        # This iterates over the stable prepared-source mapping while mutating per-exec runtime state such as ready
+        # queues, execution state, and prepared metadata. Branch resolution never adds or removes prepared exec nodes.
+        for prepared_id, prepared_source in self._state.prepared_source_mapping.items():
+            if prepared_id in self._state.executed:
+                continue
+            if self._state._get_iteration_path(prepared_id) != iteration_path:
+                continue
+            if prepared_source in exclusive_sources[selected_field]:
+                self._state._enqueue_if_ready(prepared_id)
+            elif prepared_source in exclusive_sources[unselected_field]:
+                self.mark_exec_node_skipped(prepared_id)
+
+    def get_branch_exclusive_sources(self, if_node_id: str) -> dict[str, set[str]]:
+        cached = self._state._if_branch_exclusive_sources.get(if_node_id)
+        if cached is not None:
+            return cached
+
+        branch_sources: dict[str, set[str]] = {}
+        for branch_field in ("true_input", "false_input"):
+            direct_inputs = self._get_branch_input_sources(if_node_id, branch_field)
+            candidate_nodes = self._expand_with_ancestors(direct_inputs)
+            branch_sources[branch_field] = self._prune_nonexclusive_branch_nodes(
+                if_node_id, branch_field, candidate_nodes
+            )
+
+        self._state._if_branch_exclusive_sources[if_node_id] = branch_sources
+        return branch_sources
+
+    def is_deferred_by_unresolved_if(self, exec_node_id: str) -> bool:
+        source_node_id = self._state._prepared_registry().get_source_node_id(exec_node_id)
+        iteration_path = self._state._get_iteration_path(exec_node_id)
+
+        for source_if_id, source_if_node in self._state.graph.nodes.items():
+            if not isinstance(source_if_node, IfInvocation):
+                continue
+
+            branches = self.get_branch_exclusive_sources(source_if_id)
+            if source_node_id not in branches["true_input"] and source_node_id not in branches["false_input"]:
+                continue
+
+            if self._has_unresolved_matching_if(source_if_id, iteration_path):
+                return True
+        return False
+
+    def mark_exec_node_skipped(self, exec_node_id: str) -> None:
+        self._state._remove_from_ready_queues(exec_node_id)
+        self._state._set_prepared_exec_state(exec_node_id, "skipped")
+        self._state.executed.add(exec_node_id)
+
+        registry = self._state._prepared_registry()
+        source_node_id = registry.get_source_node_id(exec_node_id)
+        prepared_nodes = registry.get_prepared_ids(source_node_id)
+        if all(n in self._state.executed for n in prepared_nodes):
+            if source_node_id not in self._state.executed:
+                self._state.executed.add(source_node_id)
+                self._state.executed_history.append(source_node_id)
+
+    def try_resolve_if_node(self, 
exec_node_id: str) -> None: + if exec_node_id in self._state._resolved_if_exec_branches: + return + node = self._state.execution_graph.get_node(exec_node_id) + if not isinstance(node, IfInvocation): + return + + if not self._apply_condition_inputs(exec_node_id, node): + return + + selected_field, unselected_field = self._get_selected_branch_fields(node) + self._state._resolved_if_exec_branches[exec_node_id] = selected_field + + source_if_node_id = self._state._prepared_registry().get_source_node_id(exec_node_id) + exclusive_sources = self.get_branch_exclusive_sources(source_if_node_id) + + iteration_path = self._state._get_iteration_path(exec_node_id) + self._prune_unselected_if_inputs(exec_node_id, unselected_field) + self._apply_branch_resolution(exec_node_id, iteration_path, exclusive_sources, selected_field, unselected_field) + self._state._enqueue_if_ready(exec_node_id) + + +class _ExecutionMaterializer: + """Expands source-graph nodes into concrete execution-graph nodes for the current runtime state. + + `GraphExecutionState.next()` calls into this helper when no prepared exec node is ready. The materializer chooses + the next source node that can be expanded, creates the corresponding exec nodes in the execution graph, wires their + inputs, and initializes their scheduler state. + """ + + def __init__(self, state: "GraphExecutionState") -> None: + self._state = state + + def _get_iterator_iteration_count(self, node_id: str, iteration_node_map: list[tuple[str, str]]) -> int: + input_collection_edge = next(iter(self._state.graph._get_input_edges(node_id, COLLECTION_FIELD))) + input_collection_prepared_node_id = next( + prepared_id + for source_id, prepared_id in iteration_node_map + if source_id == input_collection_edge.source.node_id + ) + input_collection_output = self._state.results[input_collection_prepared_node_id] + input_collection = getattr(input_collection_output, input_collection_edge.source.field) + return len(input_collection) + + def _get_new_node_iterations( + self, node: BaseInvocation, node_id: str, iteration_node_map: list[tuple[str, str]] + ) -> list[int]: + if not isinstance(node, IterateInvocation): + return [-1] + + iteration_count = self._get_iterator_iteration_count(node_id, iteration_node_map) + if iteration_count == 0: + return [] + return list(range(iteration_count)) + + def _build_execution_edges(self, node_id: str, iteration_node_map: list[tuple[str, str]]) -> list[Edge]: + input_edges = self._state.graph._get_input_edges(node_id) + new_edges: list[Edge] = [] + for edge in input_edges: + matching_inputs = [ + prepared_id for source_id, prepared_id in iteration_node_map if source_id == edge.source.node_id + ] + for input_node_id in matching_inputs: + new_edges.append( + Edge( + source=EdgeConnection(node_id=input_node_id, field=edge.source.field), + destination=EdgeConnection(node_id="", field=edge.destination.field), + ) + ) + return new_edges + + def _create_execution_node_copy(self, node: BaseInvocation, node_id: str, iteration_index: int) -> BaseInvocation: + new_node = node.model_copy(deep=True) + new_node.id = uuid_string() + + if isinstance(new_node, IterateInvocation): + new_node.index = iteration_index + + self._state.execution_graph.add_node(new_node) + self._state._register_prepared_exec_node(new_node.id, node_id) + return new_node + + def _attach_execution_edges(self, exec_node_id: str, new_edges: list[Edge]) -> None: + for edge in new_edges: + self._state.execution_graph.add_edge( + Edge( + source=edge.source, + 
destination=EdgeConnection(node_id=exec_node_id, field=edge.destination.field), + ) + ) + + def _initialize_execution_node(self, exec_node_id: str) -> None: + inputs = self._state.execution_graph._get_input_edges(exec_node_id) + unmet = sum(1 for edge in inputs if edge.source.node_id not in self._state.executed) + self._state.indegree[exec_node_id] = unmet + self._state._try_resolve_if_node(exec_node_id) + self._state._enqueue_if_ready(exec_node_id) + + def _get_collect_iteration_mappings(self, parent_node_ids: list[str]) -> list[tuple[str, str]]: + all_iteration_mappings: list[tuple[str, str]] = [] + for source_node_id in parent_node_ids: + prepared_nodes = self._state.source_prepared_mapping[source_node_id] + all_iteration_mappings.extend((source_node_id, prepared_id) for prepared_id in prepared_nodes) + return all_iteration_mappings + + def _get_parent_iteration_mappings(self, next_node_id: str, graph: nx.DiGraph) -> list[list[tuple[str, str]]]: + parent_node_ids = [source_id for source_id, _ in graph.in_edges(next_node_id)] + iterator_graph = self.iterator_graph(graph) + iterator_nodes = self.get_node_iterators(next_node_id, iterator_graph) + iterator_nodes_prepared = [list(self._state.source_prepared_mapping[node_id]) for node_id in iterator_nodes] + iterator_node_prepared_combinations = list(itertools.product(*iterator_nodes_prepared)) + + execution_graph = self._state.execution_graph.nx_graph_flat() + prepared_parent_mappings = [ + [ + (node_id, self.get_iteration_node(node_id, graph, execution_graph, prepared_iterators)) + for node_id in parent_node_ids + ] + for prepared_iterators in iterator_node_prepared_combinations + ] + return [ + mapping + for mapping in prepared_parent_mappings + if all(prepared_id is not None for _, prepared_id in mapping) + ] + + def create_execution_node(self, node_id: str, iteration_node_map: list[tuple[str, str]]) -> list[str]: + """Prepares an iteration node and connects all edges, returning the new node id""" + + node = self._state.graph.get_node(node_id) + iteration_indexes = self._get_new_node_iterations(node, node_id, iteration_node_map) + if not iteration_indexes: + return [] + + new_edges = self._build_execution_edges(node_id, iteration_node_map) + new_nodes: list[str] = [] + for iteration_index in iteration_indexes: + new_node = self._create_execution_node_copy(node, node_id, iteration_index) + self._attach_execution_edges(new_node.id, new_edges) + self._initialize_execution_node(new_node.id) + new_nodes.append(new_node.id) + + return new_nodes + + def iterator_graph(self, base: Optional[nx.DiGraph] = None) -> nx.DiGraph: + """Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node""" + g = base.copy() if base is not None else self._state.graph.nx_graph_flat() + collectors = ( + n for n in self._state.graph.nodes if isinstance(self._state.graph.get_node(n), CollectInvocation) + ) + for c in collectors: + g.remove_edges_from(list(g.in_edges(c))) + return g + + def get_node_iterators(self, node_id: str, it_graph: Optional[nx.DiGraph] = None) -> list[str]: + g = it_graph or self.iterator_graph() + return [n for n in nx.ancestors(g, node_id) if isinstance(self._state.graph.get_node(n), IterateInvocation)] + + def _get_prepared_nodes_for_source(self, source_node_id: str) -> set[str]: + return self._state.source_prepared_mapping[source_node_id] + + def _get_parent_iterator_exec_nodes( + self, source_node_id: str, graph: nx.DiGraph, prepared_iterator_nodes: list[str] + ) -> list[tuple[str, 
str]]: + iterator_source_node_mapping = [ + (prepared_exec_node_id, self._state.prepared_source_mapping[prepared_exec_node_id]) + for prepared_exec_node_id in prepared_iterator_nodes + ] + return [ + iterator_mapping + for iterator_mapping in iterator_source_node_mapping + if nx.has_path(graph, iterator_mapping[1], source_node_id) + ] + + def _matches_parent_iterators( + self, candidate_exec_node_id: str, parent_iterators: list[tuple[str, str]], execution_graph: nx.DiGraph + ) -> bool: + return all( + nx.has_path(execution_graph, parent_iterator_exec_id, candidate_exec_node_id) + for parent_iterator_exec_id, _ in parent_iterators + ) + + def _get_direct_prepared_iterator_match( + self, + prepared_nodes: set[str], + prepared_iterator_nodes: list[str], + parent_iterators: list[tuple[str, str]], + execution_graph: nx.DiGraph, + ) -> Optional[str]: + prepared_iterator = next((node_id for node_id in prepared_nodes if node_id in prepared_iterator_nodes), None) + if prepared_iterator is None: + return None + if self._matches_parent_iterators(prepared_iterator, parent_iterators, execution_graph): + return prepared_iterator + return None + + def _find_prepared_node_matching_iterators( + self, prepared_nodes: set[str], parent_iterators: list[tuple[str, str]], execution_graph: nx.DiGraph + ) -> Optional[str]: + return next( + ( + node_id + for node_id in prepared_nodes + if self._matches_parent_iterators(node_id, parent_iterators, execution_graph) + ), + None, + ) + + def get_iteration_node( + self, + source_node_id: str, + graph: nx.DiGraph, + execution_graph: nx.DiGraph, + prepared_iterator_nodes: list[str], + ) -> Optional[str]: + prepared_nodes = self._get_prepared_nodes_for_source(source_node_id) + if len(prepared_nodes) == 1: + return next(iter(prepared_nodes)) + + parent_iterators = self._get_parent_iterator_exec_nodes(source_node_id, graph, prepared_iterator_nodes) + + direct_iterator_match = self._get_direct_prepared_iterator_match( + prepared_nodes, prepared_iterator_nodes, parent_iterators, execution_graph + ) + if direct_iterator_match is not None: + return direct_iterator_match + + return self._find_prepared_node_matching_iterators(prepared_nodes, parent_iterators, execution_graph) + + def prepare(self, base_g: Optional[nx.DiGraph] = None) -> Optional[str]: + g = base_g or self._state.graph.nx_graph_flat() + next_node_id = next( + ( + node_id + for node_id in nx.topological_sort(g) + if node_id not in self._state.source_prepared_mapping + and ( + not isinstance(self._state.graph.get_node(node_id), IterateInvocation) + or all(source_id in self._state.executed for source_id, _ in g.in_edges(node_id)) + ) + and not any( + isinstance(self._state.graph.get_node(ancestor_id), IterateInvocation) + and ancestor_id not in self._state.executed + for ancestor_id in nx.ancestors(g, node_id) + ) + ), + None, + ) + + if next_node_id is None: + return None + + next_node = self._state.graph.get_node(next_node_id) + new_node_ids: list[str] = [] + + if isinstance(next_node, CollectInvocation): + next_node_parents = [source_id for source_id, _ in g.in_edges(next_node_id)] + create_results = self.create_execution_node( + next_node_id, self._get_collect_iteration_mappings(next_node_parents) + ) + if create_results is not None: + new_node_ids.extend(create_results) + else: + for iteration_mappings in self._get_parent_iteration_mappings(next_node_id, g): + create_results = self.create_execution_node(next_node_id, iteration_mappings) + if create_results is not None: + new_node_ids.extend(create_results) + + 
return next(iter(new_node_ids), None) + + +class _ExecutionScheduler: + """Owns ready-queue ordering and indegree-driven execution transitions.""" + + def __init__(self, state: "GraphExecutionState") -> None: + self._state = state + + def _validate_exec_node_ready_state(self, exec_node_id: str) -> None: + if exec_node_id not in self._state.execution_graph.nodes: + raise KeyError(f"exec node {exec_node_id} missing from execution_graph") + if exec_node_id not in self._state.indegree: + raise KeyError(f"indegree missing for exec node {exec_node_id}") + + def _should_skip_ready_enqueue(self, exec_node_id: str) -> bool: + return ( + self._state.indegree[exec_node_id] != 0 + or exec_node_id in self._state.executed + or self._state._is_deferred_by_unresolved_if(exec_node_id) + ) + + def _get_ready_queue(self, exec_node_id: str) -> Deque[str]: + node_obj = self._state.execution_graph.nodes[exec_node_id] + return self.queue_for(self._state._type_key(node_obj)) + + def _insert_ready_node(self, queue: Deque[str], exec_node_id: str) -> None: + exec_node_path = self._state._get_iteration_path(exec_node_id) + for i, existing in enumerate(queue): + if self._state._get_iteration_path(existing) > exec_node_path: + queue.insert(i, exec_node_id) + return + queue.append(exec_node_id) + + def _record_completed_node(self, exec_node_id: str, output: BaseInvocationOutput) -> None: + self._state._set_prepared_exec_state(exec_node_id, "executed") + self._state.executed.add(exec_node_id) + self._state.results[exec_node_id] = output + + def _mark_source_node_complete(self, exec_node_id: str) -> None: + registry = self._state._prepared_registry() + source_node_id = registry.get_source_node_id(exec_node_id) + prepared_nodes = registry.get_prepared_ids(source_node_id) + if all(node_id in self._state.executed for node_id in prepared_nodes): + self._state.executed.add(source_node_id) + self._state.executed_history.append(source_node_id) + + def _decrement_child_indegree(self, child_exec_node_id: str, parent_exec_node_id: str) -> None: + if child_exec_node_id not in self._state.indegree: + raise KeyError(f"indegree missing for exec node {child_exec_node_id}") + if self._state.indegree[child_exec_node_id] == 0: + raise RuntimeError(f"indegree underflow for {child_exec_node_id} from parent {parent_exec_node_id}") + self._state.indegree[child_exec_node_id] -= 1 + + def _release_downstream_nodes(self, exec_node_id: str) -> None: + for edge in self._state.execution_graph._get_output_edges(exec_node_id): + child = edge.destination.node_id + self._decrement_child_indegree(child, exec_node_id) + self._state._try_resolve_if_node(child) + if self._state.indegree[child] == 0: + self.enqueue_if_ready(child) + + def queue_for(self, cls_name: str) -> Deque[str]: + q = self._state._ready_queues.get(cls_name) + if q is None: + q = deque() + self._state._ready_queues[cls_name] = q + return q + + def remove_from_ready_queues(self, exec_node_id: str) -> None: + for q in self._state._ready_queues.values(): + try: + q.remove(exec_node_id) + except ValueError: + continue + + def enqueue_if_ready(self, exec_node_id: str) -> None: + """Push exec_node_id to its class queue if unmet inputs == 0.""" + self._validate_exec_node_ready_state(exec_node_id) + if self._should_skip_ready_enqueue(exec_node_id): + return + queue = self._get_ready_queue(exec_node_id) + if exec_node_id in queue: + return + self._state._set_prepared_exec_state(exec_node_id, "ready") + self._insert_ready_node(queue, exec_node_id) + + def get_next_node(self) -> 
Optional[BaseInvocation]: + """Gets the next ready node: FIFO within class, drain class before switching.""" + while True: + if self._state._active_class: + q = self._state._ready_queues.get(self._state._active_class) + while q: + exec_node_id = q.popleft() + if exec_node_id not in self._state.executed: + return self._state.execution_graph.nodes[exec_node_id] + self._state._active_class = None + continue + + seen = set(self._state.ready_order) + next_class = next( + (cls_name for cls_name in self._state.ready_order if self._state._ready_queues.get(cls_name)), + None, + ) + if next_class is None: + next_class = next( + ( + cls_name + for cls_name in sorted(k for k in self._state._ready_queues.keys() if k not in seen) + if self._state._ready_queues[cls_name] + ), + None, + ) + if next_class is None: + return None + + self._state._active_class = next_class + + def complete(self, exec_node_id: str, output: BaseInvocationOutput) -> None: + if exec_node_id not in self._state.execution_graph.nodes: + return + + self._record_completed_node(exec_node_id, output) + self._mark_source_node_complete(exec_node_id) + self._release_downstream_nodes(exec_node_id) + + +class _ExecutionRuntime: + """Provides runtime-only helpers such as iteration-path lookup and input hydration.""" + + def __init__(self, state: "GraphExecutionState") -> None: + self._state = state + + def _get_cached_iteration_path(self, exec_node_id: str) -> Optional[tuple[int, ...]]: + registry = self._state._prepared_registry() + metadata_iteration_path = registry.get_iteration_path(exec_node_id) + if metadata_iteration_path is not None: + return metadata_iteration_path + + return self._state._iteration_path_cache.get(exec_node_id) + + def _get_iteration_source_node_id(self, exec_node_id: str) -> Optional[str]: + if exec_node_id not in self._state.prepared_source_mapping: + return None + return self._state._prepared_registry().get_source_node_id(exec_node_id) + + def _get_ordered_iterator_sources(self, source_node_id: str) -> list[str]: + iterator_graph = self._state._iterator_graph(self._state.graph.nx_graph()) + iterator_sources = [ + node_id + for node_id in nx.ancestors(iterator_graph, source_node_id) + if isinstance(self._state.graph.get_node(node_id), IterateInvocation) + ] + + topo = list(nx.topological_sort(iterator_graph)) + topo_index = {node_id: i for i, node_id in enumerate(topo)} + iterator_sources.sort(key=lambda node_id: topo_index.get(node_id, 0)) + return iterator_sources + + def _get_iterator_exec_id( + self, iterator_source_id: str, exec_node_id: str, execution_graph: nx.DiGraph + ) -> Optional[str]: + prepared = self._state.source_prepared_mapping.get(iterator_source_id) + if not prepared: + return None + return next((pid for pid in prepared if nx.has_path(execution_graph, pid, exec_node_id)), None) + + def _build_iteration_path(self, exec_node_id: str, source_node_id: str) -> tuple[int, ...]: + iterator_sources = self._get_ordered_iterator_sources(source_node_id) + execution_graph = self._state.execution_graph.nx_graph() + path: list[int] = [] + for iterator_source_id in iterator_sources: + iterator_exec_id = self._get_iterator_exec_id(iterator_source_id, exec_node_id, execution_graph) + if iterator_exec_id is None: + continue + iterator_node = self._state.execution_graph.nodes.get(iterator_exec_id) + if isinstance(iterator_node, IterateInvocation): + path.append(iterator_node.index) + + node_obj = self._state.execution_graph.nodes.get(exec_node_id) + if isinstance(node_obj, IterateInvocation): + 
path.append(node_obj.index)
+
+        return tuple(path)
+
+    def _cache_iteration_path(self, exec_node_id: str, iteration_path: tuple[int, ...]) -> tuple[int, ...]:
+        self._state._iteration_path_cache[exec_node_id] = iteration_path
+        self._state._prepared_registry().set_iteration_path(exec_node_id, iteration_path)
+        return iteration_path
+
+    def get_iteration_path(self, exec_node_id: str) -> tuple[int, ...]:
+        """Best-effort outer->inner iteration indices for an execution node, stopping at collectors."""
+        cached = self._get_cached_iteration_path(exec_node_id)
+        if cached is not None:
+            return cached
+
+        source_node_id = self._get_iteration_source_node_id(exec_node_id)
+        if source_node_id is None:
+            return self._cache_iteration_path(exec_node_id, ())
+
+        return self._cache_iteration_path(exec_node_id, self._build_iteration_path(exec_node_id, source_node_id))
+
+    def _sort_collect_input_edges(self, input_edges: list[Edge], field_name: str) -> list[Edge]:
+        matching_edges = [edge for edge in input_edges if edge.destination.field == field_name]
+        matching_edges.sort(key=lambda edge: (self.get_iteration_path(edge.source.node_id), edge.source.node_id))
+        return matching_edges
+
+    def _get_copied_result_value(self, edge: Edge) -> Any:
+        return copy.deepcopy(getattr(self._state.results[edge.source.node_id], edge.source.field))
+
+    def _build_collect_collection(self, input_edges: list[Edge]) -> list[Any]:
+        item_edges = self._sort_collect_input_edges(input_edges, ITEM_FIELD)
+        collection_edges = self._sort_collect_input_edges(input_edges, COLLECTION_FIELD)
+
+        output_collection = []
+        for edge in collection_edges:
+            source_value = self._get_copied_result_value(edge)
+            if isinstance(source_value, list):
+                output_collection.extend(source_value)
+            else:
+                output_collection.append(source_value)
+        output_collection.extend(self._get_copied_result_value(edge) for edge in item_edges)
+        return output_collection
+
+    def _set_node_inputs(
+        self, node: BaseInvocation, input_edges: list[Edge], allowed_fields: Optional[set[str]] = None
+    ) -> None:
+        for edge in input_edges:
+            if allowed_fields is not None and edge.destination.field not in allowed_fields:
+                continue
+            setattr(node, edge.destination.field, self._get_copied_result_value(edge))
+
+    def _prepare_collect_inputs(self, node: "CollectInvocation", input_edges: list[Edge]) -> None:
+        node.collection = self._build_collect_collection(input_edges)
+
+    def _prepare_if_inputs(self, node: IfInvocation, input_edges: list[Edge]) -> None:
+        selected_field = self._state._resolved_if_exec_branches.get(node.id)
+        allowed_fields = {"condition", selected_field} if selected_field is not None else {"condition"}
+        self._set_node_inputs(node, input_edges, allowed_fields)
+
+    def _prepare_default_inputs(self, node: BaseInvocation, input_edges: list[Edge]) -> None:
+        self._set_node_inputs(node, input_edges)
+
+    def prepare_inputs(self, node: BaseInvocation) -> None:
+        input_edges = self._state.execution_graph._get_input_edges(node.id)
+
+        if isinstance(node, CollectInvocation):
+            self._prepare_collect_inputs(node, input_edges)
+            return
+
+        if isinstance(node, IfInvocation):
+            self._prepare_if_inputs(node, input_edges)
+            return
+
+        self._prepare_default_inputs(node, input_edges)
+
+
 def get_output_field_type(node: BaseInvocation, field: str) -> Any:
     # TODO(psyche): This is awkward - if field_info is None, it means the field is not defined in the output, which
     # really should raise. The consumers of this utility expect it to never raise, and return None instead. 
Fixing this @@ -123,6 +850,23 @@ def is_any(t: Any) -> bool: return t == Any or Any in get_args(t) +def extract_collection_item_types(t: Any) -> set[Any]: + """Extracts list item types from a collection annotation, including unions containing list branches.""" + if is_any(t): + return {Any} + + if get_origin(t) is list: + return {arg for arg in get_args(t) if arg != NoneType} + + item_types: set[Any] = set() + for arg in get_args(t): + if is_any(arg): + item_types.add(Any) + elif get_origin(arg) is list: + item_types.update(item_arg for item_arg in get_args(arg) if item_arg != NoneType) + return item_types + + def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool: if not from_type or not to_type: return False @@ -280,7 +1024,7 @@ class CollectInvocationOutput(BaseInvocationOutput): ) -@invocation("collect", version="1.0.0") +@invocation("collect", version="1.1.0") class CollectInvocation(BaseInvocation): """Collects values into a collection""" @@ -292,7 +1036,10 @@ class CollectInvocation(BaseInvocation): input=Input.Connection, ) collection: list[Any] = InputField( - description="The collection, will be provided on execution", default=[], ui_hidden=True + description="An optional collection to append to", + default=[], + ui_type=UIType._Collection, + input=Input.Connection, ) def invoke(self, context: InvocationContext) -> CollectInvocationOutput: @@ -344,6 +1091,8 @@ class AnyInvocationOutput(BaseInvocationOutput): class Graph(BaseModel): + """A validated invocation graph made of nodes and typed edges.""" + id: str = Field(description="The id of this graph", default_factory=uuid_string) # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me nodes: dict[str, AnyInvocation] = Field(description="The nodes in this graph", default_factory=dict) @@ -402,6 +1151,65 @@ class Graph(BaseModel): except ValueError: pass + def _validate_unique_node_ids(self) -> None: + node_ids = [n.id for n in self.nodes.values()] + seen = set() + duplicate_node_ids = {nid for nid in node_ids if (nid in seen) or seen.add(nid)} + if duplicate_node_ids: + raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}") + + def _validate_node_id_mapping(self) -> None: + for node_dict_id, node in self.nodes.items(): + if node_dict_id != node.id: + raise NodeIdMismatchError(f"Node ids must match, got {node_dict_id} and {node.id}") + + def _validate_edge_nodes_and_fields(self) -> None: + for edge in self.edges: + source_node = self.nodes.get(edge.source.node_id, None) + if source_node is None: + raise NodeNotFoundError(f"Edge source node {edge.source.node_id} does not exist in the graph") + + destination_node = self.nodes.get(edge.destination.node_id, None) + if destination_node is None: + raise NodeNotFoundError(f"Edge destination node {edge.destination.node_id} does not exist in the graph") + + if edge.source.field not in source_node.get_output_annotation().model_fields: + raise NodeFieldNotFoundError( + f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}" + ) + + if edge.destination.field not in type(destination_node).model_fields: + raise NodeFieldNotFoundError( + f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}" + ) + + def _validate_graph_is_acyclic(self) -> None: + graph = self.nx_graph_flat() + if not nx.is_directed_acyclic_graph(graph): + raise CyclicalGraphError("Graph contains cycles") + + def _validate_edge_type_compatibility(self) -> None: + for 
edge in self.edges: + if not are_connections_compatible( + self.get_node(edge.source.node_id), + edge.source.field, + self.get_node(edge.destination.node_id), + edge.destination.field, + ): + raise InvalidEdgeError(f"Edge source and target types do not match ({edge})") + + def _validate_special_nodes(self) -> None: + # TODO: may need to validate all iterators & collectors in subgraphs so edge connections in parent graphs will be available + for node in self.nodes.values(): + if isinstance(node, IterateInvocation): + err = self._is_iterator_connection_valid(node.id) + if err is not None: + raise InvalidEdgeError(f"Invalid iterator node ({node.id}): {err}") + if isinstance(node, CollectInvocation): + err = self._is_collector_connection_valid(node.id) + if err is not None: + raise InvalidEdgeError(f"Invalid collector node ({node.id}): {err}") + def validate_self(self) -> None: """ Validates the graph. @@ -416,67 +1224,12 @@ class Graph(BaseModel): - `InvalidEdgeError` """ - # Validate that all node ids are unique - node_ids = [n.id for n in self.nodes.values()] - seen = set() - duplicate_node_ids = {nid for nid in node_ids if (nid in seen) or seen.add(nid)} - if duplicate_node_ids: - raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}") - - # Validate that all node ids match the keys in the nodes dict - for k, v in self.nodes.items(): - if k != v.id: - raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}") - - # Validate that all edges match nodes and fields in the graph - for edge in self.edges: - source_node = self.nodes.get(edge.source.node_id, None) - if source_node is None: - raise NodeNotFoundError(f"Edge source node {edge.source.node_id} does not exist in the graph") - - destination_node = self.nodes.get(edge.destination.node_id, None) - if destination_node is None: - raise NodeNotFoundError(f"Edge destination node {edge.destination.node_id} does not exist in the graph") - - # output fields are not on the node object directly, they are on the output type - if edge.source.field not in source_node.get_output_annotation().model_fields: - raise NodeFieldNotFoundError( - f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}" - ) - - # input fields are on the node - if edge.destination.field not in type(destination_node).model_fields: - raise NodeFieldNotFoundError( - f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}" - ) - - # Validate there are no cycles - g = self.nx_graph_flat() - if not nx.is_directed_acyclic_graph(g): - raise CyclicalGraphError("Graph contains cycles") - - # Validate all edge connections are valid - for edge in self.edges: - if not are_connections_compatible( - self.get_node(edge.source.node_id), - edge.source.field, - self.get_node(edge.destination.node_id), - edge.destination.field, - ): - raise InvalidEdgeError(f"Edge source and target types do not match ({edge})") - - # Validate all iterators & collectors - # TODO: may need to validate all iterators & collectors in subgraphs so edge connections in parent graphs will be available - for node in self.nodes.values(): - if isinstance(node, IterateInvocation): - err = self._is_iterator_connection_valid(node.id) - if err is not None: - raise InvalidEdgeError(f"Invalid iterator node ({node.id}): {err}") - if isinstance(node, CollectInvocation): - err = self._is_collector_connection_valid(node.id) - if err is not None: - raise InvalidEdgeError(f"Invalid collector node ({node.id}): {err}") - 
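+        # Each helper raises one of the specific error types enumerated in the
+        # docstring above; validate_self itself is just a sequence of checks.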
+ self._validate_unique_node_ids() + self._validate_node_id_mapping() + self._validate_edge_nodes_and_fields() + self._validate_graph_is_acyclic() + self._validate_edge_type_compatibility() + self._validate_special_nodes() return None def is_valid(self) -> bool: @@ -508,52 +1261,56 @@ class Graph(BaseModel): """Checks if the destination field for an edge is of type typing.Any""" return get_input_field_type(self.get_node(edge.destination.node_id), edge.destination.field) == list[Any] - def _validate_edge(self, edge: Edge): - """Validates that a new edge doesn't create a cycle in the graph""" - - # Validate that the nodes exist + def _get_edge_nodes(self, edge: Edge) -> tuple[BaseInvocation, BaseInvocation]: try: - from_node = self.get_node(edge.source.node_id) - to_node = self.get_node(edge.destination.node_id) + return self.get_node(edge.source.node_id), self.get_node(edge.destination.node_id) except NodeNotFoundError: raise InvalidEdgeError(f"One or both nodes don't exist ({edge})") - # Validate that an edge to this node+field doesn't already exist + def _validate_edge_destination_uniqueness(self, edge: Edge, destination_node: BaseInvocation) -> None: input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field) - if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation): + if len(input_edges) > 0 and ( + not isinstance(destination_node, CollectInvocation) or edge.destination.field != ITEM_FIELD + ): raise InvalidEdgeError(f"Edge already exists ({edge})") - # Validate that no cycles would be created - g = self.nx_graph_flat() - g.add_edge(edge.source.node_id, edge.destination.node_id) - if not nx.is_directed_acyclic_graph(g): + def _validate_edge_would_not_create_cycle(self, edge: Edge) -> None: + graph = self.nx_graph_flat() + graph.add_edge(edge.source.node_id, edge.destination.node_id) + if not nx.is_directed_acyclic_graph(graph): raise InvalidEdgeError(f"Edge creates a cycle in the graph ({edge})") - # Validate that the field types are compatible - if not are_connections_compatible(from_node, edge.source.field, to_node, edge.destination.field): + def _validate_edge_field_compatibility( + self, edge: Edge, source_node: BaseInvocation, destination_node: BaseInvocation + ) -> None: + if not are_connections_compatible(source_node, edge.source.field, destination_node, edge.destination.field): raise InvalidEdgeError(f"Field types are incompatible ({edge})") - # Validate if iterator output type matches iterator input type (if this edge results in both being set) - if isinstance(to_node, IterateInvocation) and edge.destination.field == COLLECTION_FIELD: + def _validate_iterator_edge_rules( + self, edge: Edge, source_node: BaseInvocation, destination_node: BaseInvocation + ) -> None: + if isinstance(destination_node, IterateInvocation) and edge.destination.field == COLLECTION_FIELD: err = self._is_iterator_connection_valid(edge.destination.node_id, new_input=edge.source) if err is not None: raise InvalidEdgeError(f"Iterator input type does not match iterator output type ({edge}): {err}") - # Validate if iterator input type matches output type (if this edge results in both being set) - if isinstance(from_node, IterateInvocation) and edge.source.field == ITEM_FIELD: + if isinstance(source_node, IterateInvocation) and edge.source.field == ITEM_FIELD: err = self._is_iterator_connection_valid(edge.source.node_id, new_output=edge.destination) if err is not None: raise InvalidEdgeError(f"Iterator output type does not match iterator input type ({edge}): {err}") - # 
Validate if collector input type matches output type (if this edge results in both being set) - if isinstance(to_node, CollectInvocation) and edge.destination.field == ITEM_FIELD: - err = self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source) + def _validate_collector_edge_rules( + self, edge: Edge, source_node: BaseInvocation, destination_node: BaseInvocation + ) -> None: + if isinstance(destination_node, CollectInvocation) and edge.destination.field in (ITEM_FIELD, COLLECTION_FIELD): + err = self._is_collector_connection_valid( + edge.destination.node_id, new_input=edge.source, new_input_field=edge.destination.field + ) if err is not None: raise InvalidEdgeError(f"Collector output type does not match collector input type ({edge}): {err}") - # Validate if collector output type matches input type (if this edge results in both being set) - skip if the destination field is not Any or list[Any] if ( - isinstance(from_node, CollectInvocation) + isinstance(source_node, CollectInvocation) and edge.source.field == COLLECTION_FIELD and not self._is_destination_field_list_of_Any(edge) and not self._is_destination_field_Any(edge) @@ -562,6 +1319,15 @@ class Graph(BaseModel): if err is not None: raise InvalidEdgeError(f"Collector input type does not match collector output type ({edge}): {err}") + def _validate_edge(self, edge: Edge): + """Validates that a new edge doesn't create a cycle in the graph""" + source_node, destination_node = self._get_edge_nodes(edge) + self._validate_edge_destination_uniqueness(edge, destination_node) + self._validate_edge_would_not_create_cycle(edge) + self._validate_edge_field_compatibility(edge, source_node, destination_node) + self._validate_iterator_edge_rules(edge, source_node, destination_node) + self._validate_collector_edge_rules(edge, source_node, destination_node) + def has_node(self, node_id: str) -> bool: """Determines whether or not a node exists in the graph.""" try: @@ -652,99 +1418,253 @@ class Graph(BaseModel): if new_output is not None: outputs.append(new_output) - if len(inputs) == 0: - return "Iterator must have a collection input edge" + return self._validate_iterator_connections(inputs, outputs) - # Only one input is allowed for iterators - if len(inputs) > 1: - return "Iterator may only have one input edge" + def _validate_iterator_connections(self, inputs: list[EdgeConnection], outputs: list[EdgeConnection]) -> str | None: + presence_error = self._validate_iterator_input_presence(inputs) + if presence_error is not None: + return presence_error input_node = self.get_node(inputs[0].node_id) - - # Get input and output fields (the fields linked to the iterator's input/output) input_field_type = get_output_field_type(input_node, inputs[0].field) - output_field_types = [get_input_field_type(self.get_node(e.node_id), e.field) for e in outputs] + output_field_types = self._get_iterator_output_field_types(outputs) - # Input type must be a list + input_type_error = self._validate_iterator_input_type(input_field_type) + if input_type_error is not None: + return input_type_error + + output_type_error = self._validate_iterator_output_types(input_field_type, output_field_types) + if output_type_error is not None: + return output_type_error + + return self._validate_iterator_collector_input(input_node, output_field_types) + + def _validate_iterator_input_presence(self, inputs: list[EdgeConnection]) -> str | None: + if len(inputs) == 0: + return "Iterator must have a collection input edge" + if len(inputs) > 1: + return "Iterator may 
only have one input edge" + return None + + def _get_iterator_output_field_types(self, outputs: list[EdgeConnection]) -> list[Any]: + return [get_input_field_type(self.get_node(e.node_id), e.field) for e in outputs] + + def _validate_iterator_input_type(self, input_field_type: Any) -> str | None: if get_origin(input_field_type) is not list: return "Iterator input must be a collection" + return None - # Validate that all outputs match the input type + def _validate_iterator_output_types(self, input_field_type: Any, output_field_types: list[Any]) -> str | None: input_field_item_type = get_args(input_field_type)[0] - if not all((are_connection_types_compatible(input_field_item_type, t) for t in output_field_types)): + if not all(are_connection_types_compatible(input_field_item_type, t) for t in output_field_types): return "Iterator outputs must connect to an input with a matching type" + return None - # Collector input type must match all iterator output types - if isinstance(input_node, CollectInvocation): - collector_inputs = self._get_input_edges(input_node.id, ITEM_FIELD) - if len(collector_inputs) == 0: - return "Iterator input collector must have at least one item input edge" + def _validate_iterator_collector_input( + self, input_node: BaseInvocation, output_field_types: list[Any] + ) -> str | None: + if not isinstance(input_node, CollectInvocation): + return None - # Traverse the graph to find the first collector input edge. Collectors validate that their collection - # inputs are all of the same type, so we can use the first input edge to determine the collector's type - first_collector_input_edge = collector_inputs[0] - first_collector_input_type = get_output_field_type( - self.get_node(first_collector_input_edge.source.node_id), first_collector_input_edge.source.field + input_root_type = self._get_collector_input_root_type(input_node.id) + if input_root_type is None: + return "Iterator input collector must have at least one item or collection input edge" + if not all(are_connection_types_compatible(input_root_type, t) for t in output_field_types): + return "Iterator collection type must match all iterator output types" + return None + + def _resolve_collector_input_types(self, node_id: str, visited: Optional[set[str]] = None) -> set[Any]: + """Resolves possible item types for a collector's inputs, recursively following chained collectors.""" + visited = visited or set() + if node_id in visited: + return set() + visited.add(node_id) + + input_types: set[Any] = set() + + for edge in self._get_input_edges(node_id, ITEM_FIELD): + input_field_type = get_output_field_type(self.get_node(edge.source.node_id), edge.source.field) + resolved_types = [input_field_type] if get_origin(input_field_type) is None else get_args(input_field_type) + input_types.update(t for t in resolved_types if t != NoneType) + + for edge in self._get_input_edges(node_id, COLLECTION_FIELD): + source_node = self.get_node(edge.source.node_id) + if isinstance(source_node, CollectInvocation) and edge.source.field == COLLECTION_FIELD: + input_types.update(self._resolve_collector_input_types(source_node.id, visited.copy())) + continue + + input_field_type = get_output_field_type(source_node, edge.source.field) + input_types.update(extract_collection_item_types(input_field_type)) + + return input_types + + def _get_type_tree_root_types(self, input_types: set[Any]) -> list[Any]: + type_tree = nx.DiGraph() + type_tree.add_nodes_from(input_types) + type_tree.add_edges_from([e for e in itertools.permutations(input_types, 2) if 
issubclass(e[1], e[0])]) + type_degrees = type_tree.in_degree(type_tree.nodes) + return [t[0] for t in type_degrees if t[1] == 0] # type: ignore + + def _get_collector_input_root_type(self, node_id: str) -> Any | None: + input_types = self._resolve_collector_input_types(node_id) + non_any_input_types = {t for t in input_types if t != Any} + if len(non_any_input_types) == 0 and Any in input_types: + return Any + if len(non_any_input_types) == 0: + return None + + root_types = self._get_type_tree_root_types(non_any_input_types) + if len(root_types) != 1: + return Any + return root_types[0] + + def _get_collector_connections( + self, + node_id: str, + new_input: Optional[EdgeConnection] = None, + new_input_field: Optional[str] = None, + new_output: Optional[EdgeConnection] = None, + ) -> tuple[list[EdgeConnection], list[EdgeConnection], list[EdgeConnection]]: + item_inputs = [e.source for e in self._get_input_edges(node_id, ITEM_FIELD)] + collection_inputs = [e.source for e in self._get_input_edges(node_id, COLLECTION_FIELD)] + outputs = [e.destination for e in self._get_output_edges(node_id, COLLECTION_FIELD)] + + if new_input is not None: + field = new_input_field or ITEM_FIELD + if field == ITEM_FIELD: + item_inputs.append(new_input) + elif field == COLLECTION_FIELD: + collection_inputs.append(new_input) + + if new_output is not None: + outputs.append(new_output) + + return item_inputs, collection_inputs, outputs + + def _get_collector_port_types( + self, + item_inputs: list[EdgeConnection], + collection_inputs: list[EdgeConnection], + outputs: list[EdgeConnection], + ) -> tuple[list[Any], list[Any], list[Any]]: + item_input_field_types = [get_output_field_type(self.get_node(e.node_id), e.field) for e in item_inputs] + collection_input_field_types = [ + get_output_field_type(self.get_node(e.node_id), e.field) for e in collection_inputs + ] + output_field_types = [get_input_field_type(self.get_node(e.node_id), e.field) for e in outputs] + return item_input_field_types, collection_input_field_types, output_field_types + + def _resolve_item_input_types(self, item_input_field_types: list[Any]) -> set[Any]: + return { + resolved_type + for input_field_type in item_input_field_types + for resolved_type in ( + [input_field_type] if get_origin(input_field_type) is None else get_args(input_field_type) ) - resolved_collector_type = ( - first_collector_input_type - if get_origin(first_collector_input_type) is None - else get_args(first_collector_input_type) - ) - if not all((are_connection_types_compatible(resolved_collector_type, t) for t in output_field_types)): - return "Iterator collection type must match all iterator output types" + if resolved_type != NoneType + } + def _resolve_collection_input_types( + self, collection_inputs: list[EdgeConnection], collection_input_field_types: list[Any] + ) -> set[Any]: + input_field_types: set[Any] = set() + for input_conn, input_field_type in zip(collection_inputs, collection_input_field_types, strict=False): + source_node = self.get_node(input_conn.node_id) + if isinstance(source_node, CollectInvocation) and input_conn.field == COLLECTION_FIELD: + input_field_types.update(self._resolve_collector_input_types(source_node.id)) + continue + input_field_types.update(extract_collection_item_types(input_field_type)) + return input_field_types + + def _validate_collector_collection_inputs(self, collection_input_field_types: list[Any]) -> str | None: + if not all((is_list_or_contains_list(t) or is_any(t) for t in collection_input_field_types)): + return "Collector 
collection input must be a collection" + return None + + def _get_collector_input_root_type_from_resolved_types( + self, input_field_types: set[Any] + ) -> tuple[bool, Any | None]: + non_any_input_field_types = {t for t in input_field_types if t != Any} + root_types = self._get_type_tree_root_types(non_any_input_field_types) + if len(root_types) > 1: + return True, None + return False, root_types[0] if len(root_types) == 1 else None + + def _validate_collector_output_types( + self, output_field_types: list[Any], input_root_type: Any | None + ) -> str | None: + if not all(is_list_or_contains_list(t) or is_any(t) for t in output_field_types): + return "Collector output must connect to a collection input" + + if input_root_type is not None: + if not all( + is_any(t) + or is_union_subtype(input_root_type, get_args(t)[0]) + or issubclass(input_root_type, get_args(t)[0]) + for t in output_field_types + ): + return "Collector outputs must connect to a collection input with a matching type" + elif any(not is_any(t) and get_args(t)[0] != Any for t in output_field_types): + return "Collector outputs must connect to a collection input with a matching type" + + return None + + def _validate_downstream_collector_outputs( + self, outputs: list[EdgeConnection], input_root_type: Any | None + ) -> str | None: + for output in outputs: + output_node = self.get_node(output.node_id) + if not isinstance(output_node, CollectInvocation) or output.field != COLLECTION_FIELD: + continue + output_root_type = self._get_collector_input_root_type(output_node.id) + if output_root_type is None: + continue + if input_root_type is None: + if output_root_type != Any: + return "Collector outputs must connect to a collection input with a matching type" + continue + if not are_connection_types_compatible(input_root_type, output_root_type): + return "Collector outputs must connect to a collection input with a matching type" return None def _is_collector_connection_valid( self, node_id: str, new_input: Optional[EdgeConnection] = None, + new_input_field: Optional[str] = None, new_output: Optional[EdgeConnection] = None, ) -> str | None: - inputs = [e.source for e in self._get_input_edges(node_id, ITEM_FIELD)] - outputs = [e.destination for e in self._get_output_edges(node_id, COLLECTION_FIELD)] + item_inputs, collection_inputs, outputs = self._get_collector_connections( + node_id, new_input=new_input, new_input_field=new_input_field, new_output=new_output + ) - if new_input is not None: - inputs.append(new_input) - if new_output is not None: - outputs.append(new_output) + if len(item_inputs) == 0 and len(collection_inputs) == 0: + return "Collector must have at least one item or collection input edge" - # Get input and output fields (the fields linked to the iterator's input/output) - input_field_types = [get_output_field_type(self.get_node(e.node_id), e.field) for e in inputs] - output_field_types = [get_input_field_type(self.get_node(e.node_id), e.field) for e in outputs] + item_input_field_types, collection_input_field_types, output_field_types = self._get_collector_port_types( + item_inputs, collection_inputs, outputs + ) - # Validate that all inputs are derived from or match a single type - input_field_types = { - resolved_type - for input_field_type in input_field_types - for resolved_type in ( - [input_field_type] if get_origin(input_field_type) is None else get_args(input_field_type) - ) - if resolved_type != NoneType - } # Get unique types - type_tree = nx.DiGraph() - type_tree.add_nodes_from(input_field_types) - 
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])]) - type_degrees = type_tree.in_degree(type_tree.nodes) - if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore + collection_input_error = self._validate_collector_collection_inputs(collection_input_field_types) + if collection_input_error is not None: + return collection_input_error + + input_field_types = self._resolve_item_input_types(item_input_field_types) + input_field_types.update(self._resolve_collection_input_types(collection_inputs, collection_input_field_types)) + + has_multiple_root_types, input_root_type = self._get_collector_input_root_type_from_resolved_types( + input_field_types + ) + if has_multiple_root_types: return "Collector input collection items must be of a single type" - # Get the input root type - input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore + output_type_error = self._validate_collector_output_types(output_field_types, input_root_type) + if output_type_error is not None: + return output_type_error - # Verify that all outputs are lists - if not all(is_list_or_contains_list(t) or is_any(t) for t in output_field_types): - return "Collector output must connect to a collection input" - - # Verify that all outputs match the input type (are a base class or the same class) - if not all( - is_any(t) - or is_union_subtype(input_root_type, get_args(t)[0]) - or issubclass(input_root_type, get_args(t)[0]) - for t in output_field_types - ): - return "Collector outputs must connect to a collection input with a matching type" + downstream_output_error = self._validate_downstream_collector_outputs(outputs, input_root_type) + if downstream_output_error is not None: + return downstream_output_error return None @@ -769,7 +1689,7 @@ class Graph(BaseModel): class GraphExecutionState(BaseModel): - """Tracks the state of a graph execution""" + """Tracks source-graph expansion, execution progress, and runtime results.""" id: str = Field(description="The id of the execution state", default_factory=uuid_string) # TODO: Store a reference to the graph instead of the actual graph? 
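Aside: the single-root-type check above reduces to a small graph computation. Add an edge from base type to subclass for every related pair of input types, then keep the zero-indegree types. Here is a minimal standalone sketch of that idea, mirroring `_get_type_tree_root_types` in the diff above; the helper name `get_root_types` is illustrative only, not part of the patch:

```python
# Sketch of the subclass-tree root computation used by collector validation.
import itertools

import networkx as nx


def get_root_types(types: set[type]) -> list[type]:
    tree = nx.DiGraph()
    tree.add_nodes_from(types)
    # Edge base -> subclass for each ordered pair where the second type
    # subclasses the first.
    tree.add_edges_from(e for e in itertools.permutations(types, 2) if issubclass(e[1], e[0]))
    # Root types are those nothing else generalizes: in-degree zero.
    return [t for t, degree in tree.in_degree(tree.nodes) if degree == 0]


# bool subclasses int, so {int, bool} collapses to the single root `int`;
# {int, str} has two roots and would fail collector validation.
assert get_root_types({int, bool}) == [int]
assert len(get_root_types({int, str})) == 2
```

In `_is_collector_connection_valid`, a unique root becomes the collector's item type for the output checks, while multiple non-`Any` roots are rejected with "Collector input collection items must be of a single type".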
@@ -813,62 +1733,70 @@ class GraphExecutionState(BaseModel): ready_order: list[str] = Field(default_factory=list) indegree: dict[str, int] = Field(default_factory=dict, description="Remaining unmet input count for exec nodes") _iteration_path_cache: dict[str, tuple[int, ...]] = PrivateAttr(default_factory=dict) + _if_branch_exclusive_sources: dict[str, dict[str, set[str]]] = PrivateAttr(default_factory=dict) + _resolved_if_exec_branches: dict[str, str] = PrivateAttr(default_factory=dict) + _prepared_exec_metadata: dict[str, _PreparedExecNodeMetadata] = PrivateAttr(default_factory=dict) + _prepared_exec_registry: Optional[_PreparedExecRegistry] = PrivateAttr(default=None) + _if_branch_scheduler: Optional[_IfBranchScheduler] = PrivateAttr(default=None) + _execution_materializer: Optional[_ExecutionMaterializer] = PrivateAttr(default=None) + _execution_scheduler: Optional[_ExecutionScheduler] = PrivateAttr(default=None) + _execution_runtime: Optional[_ExecutionRuntime] = PrivateAttr(default=None) def _type_key(self, node_obj: BaseInvocation) -> str: return node_obj.__class__.__name__ + def _prepared_registry(self) -> _PreparedExecRegistry: + if self._prepared_exec_registry is None: + self._prepared_exec_registry = _PreparedExecRegistry( + prepared_source_mapping=self.prepared_source_mapping, + source_prepared_mapping=self.source_prepared_mapping, + metadata=self._prepared_exec_metadata, + ) + return self._prepared_exec_registry + + def _if_scheduler(self) -> _IfBranchScheduler: + if self._if_branch_scheduler is None: + self._if_branch_scheduler = _IfBranchScheduler(self) + return self._if_branch_scheduler + + def _materializer(self) -> _ExecutionMaterializer: + if self._execution_materializer is None: + self._execution_materializer = _ExecutionMaterializer(self) + return self._execution_materializer + + def _scheduler(self) -> _ExecutionScheduler: + if self._execution_scheduler is None: + self._execution_scheduler = _ExecutionScheduler(self) + return self._execution_scheduler + + def _runtime(self) -> _ExecutionRuntime: + if self._execution_runtime is None: + self._execution_runtime = _ExecutionRuntime(self) + return self._execution_runtime + + def _register_prepared_exec_node(self, exec_node_id: str, source_node_id: str) -> None: + self._prepared_registry().register(exec_node_id, source_node_id) + + def _get_prepared_exec_metadata(self, exec_node_id: str) -> _PreparedExecNodeMetadata: + return self._prepared_registry().get_metadata(exec_node_id) + + def _set_prepared_exec_state(self, exec_node_id: str, state: PreparedExecState) -> None: + self._prepared_registry().set_state(exec_node_id, state) + def _get_iteration_path(self, exec_node_id: str) -> tuple[int, ...]: - """Best-effort outer->inner iteration indices for an execution node, stopping at collectors.""" - cached = self._iteration_path_cache.get(exec_node_id) - if cached is not None: - return cached - - # Only prepared execution nodes participate; otherwise treat as non-iterated. - source_node_id = self.prepared_source_mapping.get(exec_node_id) - if source_node_id is None: - self._iteration_path_cache[exec_node_id] = () - return () - - # Source-graph iterator ancestry, with edges into collectors removed so iteration context doesn't leak. - it_g = self._iterator_graph(self.graph.nx_graph()) - iterator_sources = [ - n for n in nx.ancestors(it_g, source_node_id) if isinstance(self.graph.get_node(n), IterateInvocation) - ] - - # Order iterators outer->inner via topo order of the iterator graph. 
- topo = list(nx.topological_sort(it_g)) - topo_index = {n: i for i, n in enumerate(topo)} - iterator_sources.sort(key=lambda n: topo_index.get(n, 0)) - - # Map iterator source nodes to the prepared iterator exec nodes that are ancestors of exec_node_id. - eg = self.execution_graph.nx_graph() - path: list[int] = [] - for it_src in iterator_sources: - prepared = self.source_prepared_mapping.get(it_src) - if not prepared: - continue - it_exec = next((p for p in prepared if nx.has_path(eg, p, exec_node_id)), None) - if it_exec is None: - continue - it_node = self.execution_graph.nodes.get(it_exec) - if isinstance(it_node, IterateInvocation): - path.append(it_node.index) - - # If this exec node is itself an iterator, include its own index as the innermost element. - node_obj = self.execution_graph.nodes.get(exec_node_id) - if isinstance(node_obj, IterateInvocation): - path.append(node_obj.index) - - result = tuple(path) - self._iteration_path_cache[exec_node_id] = result - return result + return self._runtime().get_iteration_path(exec_node_id) def _queue_for(self, cls_name: str) -> Deque[str]: - q = self._ready_queues.get(cls_name) - if q is None: - q = deque() - self._ready_queues[cls_name] = q - return q + return self._scheduler().queue_for(cls_name) + + def _is_deferred_by_unresolved_if(self, exec_node_id: str) -> bool: + return self._if_scheduler().is_deferred_by_unresolved_if(exec_node_id) + + def _remove_from_ready_queues(self, exec_node_id: str) -> None: + self._scheduler().remove_from_ready_queues(exec_node_id) + + def _try_resolve_if_node(self, exec_node_id: str) -> None: + self._if_scheduler().try_resolve_if_node(exec_node_id) def set_ready_order(self, order: Iterable[Type[BaseInvocation] | str]) -> None: names: list[str] = [] @@ -877,24 +1805,19 @@ class GraphExecutionState(BaseModel): self.ready_order = names def _enqueue_if_ready(self, nid: str) -> None: - """Push nid to its class queue if unmet inputs == 0.""" - # Invariants: exec node exists and has an indegree entry - if nid not in self.execution_graph.nodes: - raise KeyError(f"exec node {nid} missing from execution_graph") - if nid not in self.indegree: - raise KeyError(f"indegree missing for exec node {nid}") - if self.indegree[nid] != 0 or nid in self.executed: - return - node_obj = self.execution_graph.nodes[nid] - q = self._queue_for(self._type_key(node_obj)) - nid_path = self._get_iteration_path(nid) - # Insert in lexicographic outer->inner order; preserve FIFO for equal paths. 
- for i, existing in enumerate(q): - if self._get_iteration_path(existing) > nid_path: - q.insert(i, nid) - break - else: - q.append(nid) + self._scheduler().enqueue_if_ready(nid) + + def _prepare_until_node_ready(self) -> Optional[BaseInvocation]: + base_graph = self.graph.nx_graph_flat() + prepared_id = self._materializer().prepare(base_graph) + next_node: Optional[BaseInvocation] = None + + while prepared_id is not None: + prepared_id = self._materializer().prepare(base_graph) + if next_node is None: + next_node = self._get_next_node() + + return next_node model_config = ConfigDict( json_schema_extra={ @@ -927,14 +1850,7 @@ class GraphExecutionState(BaseModel): # If there are no prepared nodes, prepare some nodes next_node = self._get_next_node() if next_node is None: - base_g = self.graph.nx_graph_flat() - prepared_id = self._prepare(base_g) - - # Prepare as many nodes as we can - while prepared_id is not None: - prepared_id = self._prepare(base_g) - if next_node is None: - next_node = self._get_next_node() + next_node = self._prepare_until_node_ready() # Get values from edges if next_node is not None: @@ -948,33 +1864,7 @@ class GraphExecutionState(BaseModel): def complete(self, node_id: str, output: BaseInvocationOutput) -> None: """Marks a node as complete""" - - if node_id not in self.execution_graph.nodes: - return # TODO: log error? - - # Mark node as executed - self.executed.add(node_id) - self.results[node_id] = output - - # Check if source node is complete (all prepared nodes are complete) - source_node = self.prepared_source_mapping[node_id] - prepared_nodes = self.source_prepared_mapping[source_node] - - if all(n in self.executed for n in prepared_nodes): - self.executed.add(source_node) - self.executed_history.append(source_node) - - # Decrement children indegree and enqueue when ready - for e in self.execution_graph._get_output_edges(node_id): - child = e.destination.node_id - if child not in self.indegree: - raise KeyError(f"indegree missing for exec node {child}") - # Only decrement if there's something to satisfy - if self.indegree[child] == 0: - raise RuntimeError(f"indegree underflow for {child} from parent {node_id}") - self.indegree[child] -= 1 - if self.indegree[child] == 0: - self._enqueue_if_ready(child) + self._scheduler().complete(node_id, output) def set_node_error(self, node_id: str, error: str): """Marks a node as errored""" @@ -990,164 +1880,16 @@ class GraphExecutionState(BaseModel): return len(self.errors) > 0 def _create_execution_node(self, node_id: str, iteration_node_map: list[tuple[str, str]]) -> list[str]: - """Prepares an iteration node and connects all edges, returning the new node id""" - - node = self.graph.get_node(node_id) - - self_iteration_count = -1 - - # If this is an iterator node, we must create a copy for each iteration - if isinstance(node, IterateInvocation): - # Get input collection edge (should error if there are no inputs) - input_collection_edge = next(iter(self.graph._get_input_edges(node_id, COLLECTION_FIELD))) - input_collection_prepared_node_id = next( - n[1] for n in iteration_node_map if n[0] == input_collection_edge.source.node_id - ) - input_collection_prepared_node_output = self.results[input_collection_prepared_node_id] - input_collection = getattr(input_collection_prepared_node_output, input_collection_edge.source.field) - self_iteration_count = len(input_collection) - - new_nodes: list[str] = [] - if self_iteration_count == 0: - # TODO: should this raise a warning? 
It might just happen if an empty collection is input, and should be valid. - return new_nodes - - # Get all input edges - input_edges = self.graph._get_input_edges(node_id) - - # Create new edges for this iteration - # For collect nodes, this may contain multiple inputs to the same field - new_edges: list[Edge] = [] - for edge in input_edges: - for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge.source.node_id): - new_edge = Edge( - source=EdgeConnection(node_id=input_node_id, field=edge.source.field), - destination=EdgeConnection(node_id="", field=edge.destination.field), - ) - new_edges.append(new_edge) - - # Create a new node (or one for each iteration of this iterator) - for i in range(self_iteration_count) if self_iteration_count > 0 else [-1]: - # Create a new node - new_node = node.model_copy(deep=True) - - # Create the node id (use a random uuid) - new_node.id = uuid_string() - - # Set the iteration index for iteration invocations - if isinstance(new_node, IterateInvocation): - new_node.index = i - - # Add to execution graph - self.execution_graph.add_node(new_node) - self.prepared_source_mapping[new_node.id] = node_id - if node_id not in self.source_prepared_mapping: - self.source_prepared_mapping[node_id] = set() - self.source_prepared_mapping[node_id].add(new_node.id) - - # Add new edges to execution graph - for edge in new_edges: - new_edge = Edge( - source=edge.source, - destination=EdgeConnection(node_id=new_node.id, field=edge.destination.field), - ) - self.execution_graph.add_edge(new_edge) - - # Initialize indegree as unmet inputs only and enqueue if ready - inputs = self.execution_graph._get_input_edges(new_node.id) - unmet = sum(1 for e in inputs if e.source.node_id not in self.executed) - self.indegree[new_node.id] = unmet - self._enqueue_if_ready(new_node.id) - - new_nodes.append(new_node.id) - - return new_nodes + return self._materializer().create_execution_node(node_id, iteration_node_map) def _iterator_graph(self, base: Optional[nx.DiGraph] = None) -> nx.DiGraph: - """Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node""" - g = base.copy() if base is not None else self.graph.nx_graph_flat() - collectors = (n for n in self.graph.nodes if isinstance(self.graph.get_node(n), CollectInvocation)) - for c in collectors: - g.remove_edges_from(list(g.in_edges(c))) - return g + return self._materializer().iterator_graph(base) def _get_node_iterators(self, node_id: str, it_graph: Optional[nx.DiGraph] = None) -> list[str]: - """Gets iterators for a node""" - g = it_graph or self._iterator_graph() - return [n for n in nx.ancestors(g, node_id) if isinstance(self.graph.get_node(n), IterateInvocation)] + return self._materializer().get_node_iterators(node_id, it_graph) def _prepare(self, base_g: Optional[nx.DiGraph] = None) -> Optional[str]: - # Get flattened source graph - g = base_g or self.graph.nx_graph_flat() - - # Find next node that: - # - was not already prepared - # - is not an iterate node whose inputs have not been executed - # - does not have an unexecuted iterate ancestor - sorted_nodes = nx.topological_sort(g) - - def unprepared(n: str) -> bool: - return n not in self.source_prepared_mapping - - def iter_inputs_ready(n: str) -> bool: - if not isinstance(self.graph.get_node(n), IterateInvocation): - return True - return all(u in self.executed for u, _ in g.in_edges(n)) - - def no_unexecuted_iter_ancestors(n: str) -> bool: - return not any( - isinstance(self.graph.get_node(a), 
IterateInvocation) and a not in self.executed - for a in nx.ancestors(g, n) - ) - - next_node_id = next( - (n for n in sorted_nodes if unprepared(n) and iter_inputs_ready(n) and no_unexecuted_iter_ancestors(n)), - None, - ) - - if next_node_id is None: - return None - - # Get all parents of the next node - next_node_parents = [u for u, _ in g.in_edges(next_node_id)] - - # Create execution nodes - next_node = self.graph.get_node(next_node_id) - new_node_ids = [] - if isinstance(next_node, CollectInvocation): - # Collapse all iterator input mappings and create a single execution node for the collect invocation - all_iteration_mappings = [] - for source_node_id in next_node_parents: - prepared_nodes = self.source_prepared_mapping[source_node_id] - all_iteration_mappings.extend([(source_node_id, p) for p in prepared_nodes]) - - create_results = self._create_execution_node(next_node_id, all_iteration_mappings) - if create_results is not None: - new_node_ids.extend(create_results) - else: # Iterators or normal nodes - # Get all iterator combinations for this node - # Will produce a list of lists of prepared iterator nodes, from which results can be iterated - it_g = self._iterator_graph(g) - iterator_nodes = self._get_node_iterators(next_node_id, it_g) - iterator_nodes_prepared = [list(self.source_prepared_mapping[n]) for n in iterator_nodes] - iterator_node_prepared_combinations = list(itertools.product(*iterator_nodes_prepared)) - - # Select the correct prepared parents for each iteration - # For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator - eg = self.execution_graph.nx_graph_flat() - prepared_parent_mappings = [ - [(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] - for it in iterator_node_prepared_combinations - ] # type: ignore - prepared_parent_mappings = [m for m in prepared_parent_mappings if all(p[1] is not None for p in m)] - - # Create execution node for each iteration - for iteration_mappings in prepared_parent_mappings: - create_results = self._create_execution_node(next_node_id, iteration_mappings) # type: ignore - if create_results is not None: - new_node_ids.extend(create_results) - - return next(iter(new_node_ids), None) + return self._materializer().prepare(base_g) def _get_iteration_node( self, @@ -1156,71 +1898,13 @@ class GraphExecutionState(BaseModel): execution_graph: nx.DiGraph, prepared_iterator_nodes: list[str], ) -> Optional[str]: - """Gets the prepared version of the specified source node that matches every iteration specified""" - prepared_nodes = self.source_prepared_mapping[source_node_id] - if len(prepared_nodes) == 1: - return next(iter(prepared_nodes)) - - # Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source) - iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes] - parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_id)] - - # If the requested node is an iterator, only accept it if it is compatible with all parent iterators - prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None) - if prepared_iterator is not None: - if all(nx.has_path(execution_graph, pit[0], prepared_iterator) for pit in parent_iterators): - return prepared_iterator - return None - - return next( - (n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in 
parent_iterators)), - None, - ) + return self._materializer().get_iteration_node(source_node_id, graph, execution_graph, prepared_iterator_nodes) def _get_next_node(self) -> Optional[BaseInvocation]: - """Gets the next ready node: FIFO within class, drain class before switching.""" - # 1) Continue draining the active class - if self._active_class: - q = self._ready_queues.get(self._active_class) - while q: - nid = q.popleft() - if nid not in self.executed: - return self.execution_graph.nodes[nid] - # emptied: release active class - self._active_class = None - - # 2) Pick next class by priority, then by class name - seen = set(self.ready_order) - for cls_name in self.ready_order: - q = self._ready_queues.get(cls_name) - if q: - self._active_class = cls_name - # recurse to drain newly set active class - return self._get_next_node() - for cls_name in sorted(k for k in self._ready_queues.keys() if k not in seen): - q = self._ready_queues[cls_name] - if q: - self._active_class = cls_name - return self._get_next_node() - return None + return self._scheduler().get_next_node() def _prepare_inputs(self, node: BaseInvocation): - input_edges = self.execution_graph._get_input_edges(node.id) - # Inputs must be deep-copied, else if a node mutates the object, other nodes that get the same input - # will see the mutation. - if isinstance(node, CollectInvocation): - item_edges = [e for e in input_edges if e.destination.field == ITEM_FIELD] - item_edges.sort(key=lambda e: (self._get_iteration_path(e.source.node_id), e.source.node_id)) - - output_collection = [copydeep(getattr(self.results[e.source.node_id], e.source.field)) for e in item_edges] - node.collection = output_collection - else: - for edge in input_edges: - setattr( - node, - edge.destination.field, - copydeep(getattr(self.results[edge.source.node_id], edge.source.field)), - ) + self._runtime().prepare_inputs(node) # TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state def _is_edge_valid(self, edge: Edge) -> bool: diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 645509f1dd..fb8ca9fca3 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -30,6 +30,8 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_24 import from invokeai.app.services.shared.sqlite_migrator.migrations.migration_25 import build_migration_25 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_26 import build_migration_26 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_27 import build_migration_27 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_28 import build_migration_28 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_29 import build_migration_29 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -77,6 +79,8 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_25(app_config=config, logger=logger)) migrator.register_migration(build_migration_26(app_config=config, logger=logger)) migrator.register_migration(build_migration_27()) + migrator.register_migration(build_migration_28()) + migrator.register_migration(build_migration_29()) migrator.run_migrations() return db diff --git 
a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_28.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_28.py new file mode 100644 index 0000000000..60e5d8f19b --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_28.py @@ -0,0 +1,48 @@ +"""Migration 28: Add per-user workflow isolation columns to workflow_library. + +This migration adds the database columns required for multiuser workflow isolation +to the workflow_library table: +- user_id: the owner of the workflow (defaults to 'system' for existing workflows) +- is_public: whether the workflow is shared with all users +""" + +import sqlite3 + +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + + +class Migration28Callback: + """Migration to add user_id and is_public to the workflow_library table.""" + + def __call__(self, cursor: sqlite3.Cursor) -> None: + self._update_workflow_library_table(cursor) + + def _update_workflow_library_table(self, cursor: sqlite3.Cursor) -> None: + """Add user_id and is_public columns to workflow_library table.""" + cursor.execute("PRAGMA table_info(workflow_library);") + columns = [row[1] for row in cursor.fetchall()] + + if "user_id" not in columns: + cursor.execute("ALTER TABLE workflow_library ADD COLUMN user_id TEXT DEFAULT 'system';") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_workflow_library_user_id ON workflow_library(user_id);") + + if "is_public" not in columns: + cursor.execute("ALTER TABLE workflow_library ADD COLUMN is_public BOOLEAN NOT NULL DEFAULT FALSE;") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_workflow_library_is_public ON workflow_library(is_public);") + cursor.execute( + "UPDATE workflow_library SET is_public = TRUE WHERE user_id = 'system';" + ) # one-time fix for legacy workflows + + +def build_migration_28() -> Migration: + """Builds the migration object for migrating from version 27 to version 28. + + This migration adds per-user workflow isolation to the workflow_library table: + - user_id column: identifies the owner of each workflow + - is_public column: controls whether a workflow is shared with all users + """ + return Migration( + from_version=27, + to_version=28, + callback=Migration28Callback(), + ) diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_29.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_29.py new file mode 100644 index 0000000000..c9eb7c901b --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_29.py @@ -0,0 +1,53 @@ +"""Migration 29: Add board_visibility column to boards table. + +This migration adds a board_visibility column to the boards table to support +three visibility levels: + - 'private': only the board owner (and admins) can view/modify + - 'shared': all users can view, but only the owner (and admins) can modify + - 'public': all users can view; only the owner (and admins) can modify the + board structure (rename/archive/delete) + +Existing boards with is_public = 1 are migrated to 'public'. +All other existing boards default to 'private'. 
+""" + +import sqlite3 + +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + + +class Migration29Callback: + """Migration to add board_visibility column to the boards table.""" + + def __call__(self, cursor: sqlite3.Cursor) -> None: + self._update_boards_table(cursor) + + def _update_boards_table(self, cursor: sqlite3.Cursor) -> None: + """Add board_visibility column to boards table.""" + # Check if boards table exists + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='boards';") + if cursor.fetchone() is None: + return + + cursor.execute("PRAGMA table_info(boards);") + columns = [row[1] for row in cursor.fetchall()] + + if "board_visibility" not in columns: + cursor.execute("ALTER TABLE boards ADD COLUMN board_visibility TEXT NOT NULL DEFAULT 'private';") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_boards_board_visibility ON boards(board_visibility);") + # Migrate existing is_public = 1 boards to 'public' + if "is_public" in columns: + cursor.execute("UPDATE boards SET board_visibility = 'public' WHERE is_public = 1;") + + +def build_migration_29() -> Migration: + """Builds the migration object for migrating from version 28 to version 29. + + This migration adds the board_visibility column to the boards table, + supporting 'private', 'shared', and 'public' visibility levels. + """ + return Migration( + from_version=28, + to_version=29, + callback=Migration29Callback(), + ) diff --git a/invokeai/app/services/users/users_base.py b/invokeai/app/services/users/users_base.py index 728a0adfa3..dd789b561e 100644 --- a/invokeai/app/services/users/users_base.py +++ b/invokeai/app/services/users/users_base.py @@ -131,6 +131,15 @@ class UserServiceBase(ABC): """ pass + @abstractmethod + def get_admin_email(self) -> str | None: + """Get the email address of the first active admin user. + + Returns: + Email address of the first active admin, or None if no admin exists + """ + pass + @abstractmethod def count_admins(self) -> int: """Count active admin users. 
diff --git a/invokeai/app/services/users/users_default.py b/invokeai/app/services/users/users_default.py index 709e4cb82c..6e47288212 100644 --- a/invokeai/app/services/users/users_default.py +++ b/invokeai/app/services/users/users_default.py @@ -256,6 +256,20 @@ class UserService(UserServiceBase): for row in rows ] + def get_admin_email(self) -> str | None: + """Get the email address of the first active admin user.""" + with self._db.transaction() as cursor: + cursor.execute( + """ + SELECT email FROM users + WHERE is_admin = TRUE AND is_active = TRUE + ORDER BY created_at ASC + LIMIT 1 + """, + ) + row = cursor.fetchone() + return row[0] if row else None + def count_admins(self) -> int: """Count active admin users.""" with self._db.transaction() as cursor: diff --git a/invokeai/app/services/workflow_records/workflow_records_base.py b/invokeai/app/services/workflow_records/workflow_records_base.py index d5cf319594..c07daa2662 100644 --- a/invokeai/app/services/workflow_records/workflow_records_base.py +++ b/invokeai/app/services/workflow_records/workflow_records_base.py @@ -4,6 +4,7 @@ from typing import Optional from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.app.services.workflow_records.workflow_records_common import ( + WORKFLOW_LIBRARY_DEFAULT_USER_ID, Workflow, WorkflowCategory, WorkflowRecordDTO, @@ -22,18 +23,23 @@ class WorkflowRecordsStorageBase(ABC): pass @abstractmethod - def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO: + def create( + self, + workflow: WorkflowWithoutID, + user_id: str = WORKFLOW_LIBRARY_DEFAULT_USER_ID, + is_public: bool = False, + ) -> WorkflowRecordDTO: """Creates a workflow.""" pass @abstractmethod - def update(self, workflow: Workflow) -> WorkflowRecordDTO: - """Updates a workflow.""" + def update(self, workflow: Workflow, user_id: Optional[str] = None) -> WorkflowRecordDTO: + """Updates a workflow. When user_id is provided, the UPDATE is scoped to that user.""" pass @abstractmethod - def delete(self, workflow_id: str) -> None: - """Deletes a workflow.""" + def delete(self, workflow_id: str, user_id: Optional[str] = None) -> None: + """Deletes a workflow. When user_id is provided, the DELETE is scoped to that user.""" pass @abstractmethod @@ -47,6 +53,8 @@ class WorkflowRecordsStorageBase(ABC): query: Optional[str], tags: Optional[list[str]], has_been_opened: Optional[bool], + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> PaginatedResults[WorkflowRecordListItemDTO]: """Gets many workflows.""" pass @@ -56,6 +64,8 @@ class WorkflowRecordsStorageBase(ABC): self, categories: list[WorkflowCategory], has_been_opened: Optional[bool] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> dict[str, int]: """Gets a dictionary of counts for each of the provided categories.""" pass @@ -66,19 +76,28 @@ class WorkflowRecordsStorageBase(ABC): tags: list[str], categories: Optional[list[WorkflowCategory]] = None, has_been_opened: Optional[bool] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> dict[str, int]: """Gets a dictionary of counts for each of the provided tags.""" pass @abstractmethod - def update_opened_at(self, workflow_id: str) -> None: - """Open a workflow.""" + def update_opened_at(self, workflow_id: str, user_id: Optional[str] = None) -> None: + """Open a workflow. 
When user_id is provided, the UPDATE is scoped to that user.""" pass @abstractmethod def get_all_tags( self, categories: Optional[list[WorkflowCategory]] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> list[str]: """Gets all unique tags from workflows.""" pass + + @abstractmethod + def update_is_public(self, workflow_id: str, is_public: bool, user_id: Optional[str] = None) -> WorkflowRecordDTO: + """Updates the is_public field of a workflow. When user_id is provided, the UPDATE is scoped to that user.""" + pass diff --git a/invokeai/app/services/workflow_records/workflow_records_common.py b/invokeai/app/services/workflow_records/workflow_records_common.py index e0cea37468..9c505530c9 100644 --- a/invokeai/app/services/workflow_records/workflow_records_common.py +++ b/invokeai/app/services/workflow_records/workflow_records_common.py @@ -9,6 +9,9 @@ from invokeai.app.util.metaenum import MetaEnum __workflow_meta_version__ = semver.Version.parse("1.0.0") +WORKFLOW_LIBRARY_DEFAULT_USER_ID = "system" +"""Default user_id for workflows created in single-user mode or migrated from pre-multiuser databases.""" + class ExposedField(BaseModel): nodeId: str @@ -26,6 +29,7 @@ class WorkflowRecordOrderBy(str, Enum, metaclass=MetaEnum): UpdatedAt = "updated_at" OpenedAt = "opened_at" Name = "name" + IsPublic = "is_public" class WorkflowCategory(str, Enum, metaclass=MetaEnum): @@ -100,6 +104,8 @@ class WorkflowRecordDTOBase(BaseModel): opened_at: Optional[Union[datetime.datetime, str]] = Field( default=None, description="The opened timestamp of the workflow." ) + user_id: str = Field(description="The id of the user who owns this workflow.") + is_public: bool = Field(description="Whether this workflow is shared with all users.") class WorkflowRecordDTO(WorkflowRecordDTOBase): diff --git a/invokeai/app/services/workflow_records/workflow_records_sqlite.py b/invokeai/app/services/workflow_records/workflow_records_sqlite.py index 0f72f7cd92..a62dbb9dfa 100644 --- a/invokeai/app/services/workflow_records/workflow_records_sqlite.py +++ b/invokeai/app/services/workflow_records/workflow_records_sqlite.py @@ -7,6 +7,7 @@ from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase from invokeai.app.services.workflow_records.workflow_records_common import ( + WORKFLOW_LIBRARY_DEFAULT_USER_ID, Workflow, WorkflowCategory, WorkflowNotFoundError, @@ -36,7 +37,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): with self._db.transaction() as cursor: cursor.execute( """--sql - SELECT workflow_id, workflow, name, created_at, updated_at, opened_at + SELECT workflow_id, workflow, name, created_at, updated_at, opened_at, user_id, is_public FROM workflow_library WHERE workflow_id = ?; """, @@ -47,7 +48,12 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found") return WorkflowRecordDTO.from_dict(dict(row)) - def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO: + def create( + self, + workflow: WorkflowWithoutID, + user_id: str = WORKFLOW_LIBRARY_DEFAULT_USER_ID, + is_public: bool = False, + ) -> WorkflowRecordDTO: if workflow.meta.category is WorkflowCategory.Default: raise ValueError("Default workflows cannot be created via this method") @@ -57,43 +63,99 @@ class 
SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): """--sql INSERT OR IGNORE INTO workflow_library ( workflow_id, - workflow + workflow, + user_id, + is_public ) - VALUES (?, ?); + VALUES (?, ?, ?, ?); """, - (workflow_with_id.id, workflow_with_id.model_dump_json()), + (workflow_with_id.id, workflow_with_id.model_dump_json(), user_id, is_public), ) return self.get(workflow_with_id.id) - def update(self, workflow: Workflow) -> WorkflowRecordDTO: + def update(self, workflow: Workflow, user_id: Optional[str] = None) -> WorkflowRecordDTO: if workflow.meta.category is WorkflowCategory.Default: raise ValueError("Default workflows cannot be updated") with self._db.transaction() as cursor: - cursor.execute( - """--sql - UPDATE workflow_library - SET workflow = ? - WHERE workflow_id = ? AND category = 'user'; - """, - (workflow.model_dump_json(), workflow.id), - ) + if user_id is not None: + cursor.execute( + """--sql + UPDATE workflow_library + SET workflow = ? + WHERE workflow_id = ? AND category = 'user' AND user_id = ?; + """, + (workflow.model_dump_json(), workflow.id, user_id), + ) + else: + cursor.execute( + """--sql + UPDATE workflow_library + SET workflow = ? + WHERE workflow_id = ? AND category = 'user'; + """, + (workflow.model_dump_json(), workflow.id), + ) return self.get(workflow.id) - def delete(self, workflow_id: str) -> None: + def delete(self, workflow_id: str, user_id: Optional[str] = None) -> None: if self.get(workflow_id).workflow.meta.category is WorkflowCategory.Default: raise ValueError("Default workflows cannot be deleted") with self._db.transaction() as cursor: - cursor.execute( - """--sql - DELETE from workflow_library - WHERE workflow_id = ? AND category = 'user'; - """, - (workflow_id,), - ) + if user_id is not None: + cursor.execute( + """--sql + DELETE from workflow_library + WHERE workflow_id = ? AND category = 'user' AND user_id = ?; + """, + (workflow_id, user_id), + ) + else: + cursor.execute( + """--sql + DELETE from workflow_library + WHERE workflow_id = ? AND category = 'user'; + """, + (workflow_id,), + ) return None + def update_is_public(self, workflow_id: str, is_public: bool, user_id: Optional[str] = None) -> WorkflowRecordDTO: + """Updates the is_public field of a workflow and manages the 'shared' tag automatically.""" + record = self.get(workflow_id) + workflow = record.workflow + + # Manage "shared" tag: add when public, remove when private + tags_list = [t.strip() for t in workflow.tags.split(",") if t.strip()] if workflow.tags else [] + if is_public and "shared" not in tags_list: + tags_list.append("shared") + elif not is_public and "shared" in tags_list: + tags_list.remove("shared") + updated_tags = ", ".join(tags_list) + updated_workflow = workflow.model_copy(update={"tags": updated_tags}) + + with self._db.transaction() as cursor: + if user_id is not None: + cursor.execute( + """--sql + UPDATE workflow_library + SET workflow = ?, is_public = ? + WHERE workflow_id = ? AND category = 'user' AND user_id = ?; + """, + (updated_workflow.model_dump_json(), is_public, workflow_id, user_id), + ) + else: + cursor.execute( + """--sql + UPDATE workflow_library + SET workflow = ?, is_public = ? + WHERE workflow_id = ? 
AND category = 'user'; + """, + (updated_workflow.model_dump_json(), is_public, workflow_id), + ) + return self.get(workflow_id) + def get_many( self, order_by: WorkflowRecordOrderBy, @@ -104,6 +166,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): query: Optional[str] = None, tags: Optional[list[str]] = None, has_been_opened: Optional[bool] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> PaginatedResults[WorkflowRecordListItemDTO]: with self._db.transaction() as cursor: # sanitize! @@ -122,7 +186,9 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): created_at, updated_at, opened_at, - tags + tags, + user_id, + is_public FROM workflow_library """ count_query = "SELECT COUNT(*) FROM workflow_library" @@ -177,6 +243,16 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): conditions.append(query_condition) params.extend([wildcard_query, wildcard_query, wildcard_query]) + if user_id is not None: + # Scope to the given user but always include default workflows + conditions.append("(user_id = ? OR category = 'default')") + params.append(user_id) + + if is_public is True: + conditions.append("is_public = TRUE") + elif is_public is False: + conditions.append("is_public = FALSE") + if conditions: # If there are conditions, add a WHERE clause and then join the conditions main_query += " WHERE " @@ -226,6 +302,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): tags: list[str], categories: Optional[list[WorkflowCategory]] = None, has_been_opened: Optional[bool] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> dict[str, int]: if not tags: return {} @@ -248,6 +326,16 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): elif has_been_opened is False: base_conditions.append("opened_at IS NULL") + if user_id is not None: + # Scope to the given user but always include default workflows + base_conditions.append("(user_id = ? OR category = 'default')") + base_params.append(user_id) + + if is_public is True: + base_conditions.append("is_public = TRUE") + elif is_public is False: + base_conditions.append("is_public = FALSE") + # For each tag to count, run a separate query for tag in tags: # Start with the base conditions @@ -277,6 +365,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): self, categories: list[WorkflowCategory], has_been_opened: Optional[bool] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> dict[str, int]: with self._db.transaction() as cursor: result: dict[str, int] = {} @@ -296,6 +386,16 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): elif has_been_opened is False: base_conditions.append("opened_at IS NULL") + if user_id is not None: + # Scope to the given user but always include default workflows + base_conditions.append("(user_id = ? 
OR category = 'default')") + base_params.append(user_id) + + if is_public is True: + base_conditions.append("is_public = TRUE") + elif is_public is False: + base_conditions.append("is_public = FALSE") + # For each category to count, run a separate query for category in categories: # Start with the base conditions @@ -321,20 +421,32 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): return result - def update_opened_at(self, workflow_id: str) -> None: + def update_opened_at(self, workflow_id: str, user_id: Optional[str] = None) -> None: with self._db.transaction() as cursor: - cursor.execute( - f"""--sql - UPDATE workflow_library - SET opened_at = STRFTIME('{SQL_TIME_FORMAT}', 'NOW') - WHERE workflow_id = ?; - """, - (workflow_id,), - ) + if user_id is not None: + cursor.execute( + f"""--sql + UPDATE workflow_library + SET opened_at = STRFTIME('{SQL_TIME_FORMAT}', 'NOW') + WHERE workflow_id = ? AND user_id = ?; + """, + (workflow_id, user_id), + ) + else: + cursor.execute( + f"""--sql + UPDATE workflow_library + SET opened_at = STRFTIME('{SQL_TIME_FORMAT}', 'NOW') + WHERE workflow_id = ?; + """, + (workflow_id,), + ) def get_all_tags( self, categories: Optional[list[WorkflowCategory]] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> list[str]: with self._db.transaction() as cursor: conditions: list[str] = [] @@ -349,6 +461,16 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): conditions.append(f"category IN ({placeholders})") params.extend([category.value for category in categories]) + if user_id is not None: + # Scope to the given user but always include default workflows + conditions.append("(user_id = ? OR category = 'default')") + params.append(user_id) + + if is_public is True: + conditions.append("is_public = TRUE") + elif is_public is False: + conditions.append("is_public = FALSE") + stmt = """--sql SELECT DISTINCT tags FROM workflow_library diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index 990fdd51d8..08dc9a2265 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -93,6 +93,29 @@ COGVIEW4_LATENT_RGB_FACTORS = [ [-0.00955853, -0.00980067, -0.00977842], ] +# Qwen Image uses the same VAE as Wan 2.1 (16-channel). +# Factors from ComfyUI: https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py +QWEN_IMAGE_LATENT_RGB_FACTORS = [ + [-0.1299, -0.1692, 0.2932], + [0.0671, 0.0406, 0.0442], + [0.3568, 0.2548, 0.1747], + [0.0372, 0.2344, 0.1420], + [0.0313, 0.0189, -0.0328], + [0.0296, -0.0956, -0.0665], + [-0.3477, -0.4059, -0.2925], + [0.0166, 0.1902, 0.1975], + [-0.0412, 0.0267, -0.1364], + [-0.1293, 0.0740, 0.1636], + [0.0680, 0.3019, 0.1128], + [0.0032, 0.0581, 0.0639], + [-0.1251, 0.0927, 0.1699], + [0.0060, -0.0633, 0.0005], + [0.3477, 0.2275, 0.2950], + [0.1984, 0.0913, 0.1861], +] + +QWEN_IMAGE_LATENT_RGB_BIAS = [-0.1835, -0.0868, -0.3360] + # FLUX.2 uses 32 latent channels. # Factors from ComfyUI: https://github.com/Comfy-Org/ComfyUI/blob/main/comfy/latent_formats.py FLUX2_LATENT_RGB_FACTORS = [ @@ -133,6 +156,29 @@ FLUX2_LATENT_RGB_FACTORS = [ FLUX2_LATENT_RGB_BIAS = [-0.0329, -0.0718, -0.0851] +# Anima uses Wan 2.1 VAE with 16 latent channels. 
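The Qwen Image and FLUX.2 factor tables above, and the Anima table this comment introduces, all feed the same low-resolution preview projection: each latent channel is mixed into RGB through a `[C, 3]` matrix, with an optional per-channel bias. A rough sketch of that projection for a `[C, H, W]` latent (this approximates what `sample_to_lowres_estimated_image` does; it is not that function's actual code):

```python
# Sketch: project [C, H, W] latents to an RGB preview estimate using
# per-channel factors [C, 3] and an optional per-channel RGB bias.
import torch

def latents_to_rgb_preview(
    latents: torch.Tensor,
    factors: list[list[float]],
    bias: list[float] | None = None,
) -> torch.Tensor:
    weights = torch.tensor(factors, dtype=latents.dtype, device=latents.device)  # [C, 3]
    rgb = torch.einsum("chw,cr->rhw", latents, weights)  # [3, H, W]
    if bias is not None:
        rgb = rgb + torch.tensor(bias, dtype=latents.dtype, device=latents.device).view(3, 1, 1)
    return rgb
```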
+# Factors from ComfyUI: https://github.com/Comfy-Org/ComfyUI/blob/main/comfy/latent_formats.py +ANIMA_LATENT_RGB_FACTORS = [ + [-0.1299, -0.1692, 0.2932], + [0.0671, 0.0406, 0.0442], + [0.3568, 0.2548, 0.1747], + [0.0372, 0.2344, 0.1420], + [0.0313, 0.0189, -0.0328], + [0.0296, -0.0956, -0.0665], + [-0.3477, -0.4059, -0.2925], + [0.0166, 0.1902, 0.1975], + [-0.0412, 0.0267, -0.1364], + [-0.1293, 0.0740, 0.1636], + [0.0680, 0.3019, 0.1128], + [0.0032, 0.0581, 0.0639], + [-0.1251, 0.0927, 0.1699], + [0.0060, -0.0633, 0.0005], + [0.3477, 0.2275, 0.2950], + [0.1984, 0.0913, 0.1861], +] + +ANIMA_LATENT_RGB_BIAS = [-0.1835, -0.0868, -0.3360] + def sample_to_lowres_estimated_image( samples: torch.Tensor, @@ -209,6 +255,9 @@ def diffusion_step_callback( latent_rgb_factors = SD3_5_LATENT_RGB_FACTORS elif base_model == BaseModelType.CogView4: latent_rgb_factors = COGVIEW4_LATENT_RGB_FACTORS + elif base_model == BaseModelType.QwenImage: + latent_rgb_factors = QWEN_IMAGE_LATENT_RGB_FACTORS + latent_rgb_bias = QWEN_IMAGE_LATENT_RGB_BIAS elif base_model == BaseModelType.Flux: latent_rgb_factors = FLUX_LATENT_RGB_FACTORS elif base_model == BaseModelType.Flux2: @@ -217,6 +266,10 @@ def diffusion_step_callback( elif base_model == BaseModelType.ZImage: # Z-Image uses FLUX-compatible VAE with 16 latent channels latent_rgb_factors = FLUX_LATENT_RGB_FACTORS + elif base_model == BaseModelType.Anima: + # Anima uses Wan 2.1 VAE with 16 latent channels + latent_rgb_factors = ANIMA_LATENT_RGB_FACTORS + latent_rgb_bias = ANIMA_LATENT_RGB_BIAS else: raise ValueError(f"Unsupported base model: {base_model}") diff --git a/invokeai/backend/anima/__init__.py b/invokeai/backend/anima/__init__.py new file mode 100644 index 0000000000..01a1a952e9 --- /dev/null +++ b/invokeai/backend/anima/__init__.py @@ -0,0 +1,6 @@ +"""Anima model backend module. + +Anima is a 2B-parameter anime-focused text-to-image model built on NVIDIA's +Cosmos Predict2 DiT architecture with a custom LLM Adapter that bridges Qwen3 +0.6B text encoder outputs to the DiT backbone. +""" diff --git a/invokeai/backend/anima/anima_transformer.py b/invokeai/backend/anima/anima_transformer.py new file mode 100644 index 0000000000..36c5764e97 --- /dev/null +++ b/invokeai/backend/anima/anima_transformer.py @@ -0,0 +1,1040 @@ +"""Anima transformer model: Cosmos Predict2 MiniTrainDIT + LLM Adapter. + +The Anima architecture combines: +1. MiniTrainDIT: A Cosmos Predict2 DiT backbone with 28 blocks, 2048-dim hidden state, + and 3D RoPE positional embeddings. +2. LLMAdapter: A 6-layer cross-attention transformer that fuses Qwen3 0.6B hidden states + with learned T5-XXL token embeddings to produce conditioning for the DiT. + +Original source code: +- MiniTrainDIT backbone and positional embeddings: https://github.com/nvidia-cosmos/cosmos-predict2 + SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ SPDX-License-Identifier: Apache-2.0 +- LLMAdapter and Anima wrapper: Clean-room implementation based on + https://github.com/hdae/diffusers-anima (Apache-2.0) +""" + +import logging +import math +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +from torch import nn + +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Positional Embeddings +# Original source: https://github.com/nvidia-cosmos/cosmos-predict2 +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. Apache-2.0 +# ============================================================================ + + +class VideoRopePosition3DEmb(nn.Module): + """3D Rotary Position Embedding for video/image transformers. + + Generates rotary embeddings with separate frequency components for + height, width, and temporal dimensions. + """ + + def __init__( + self, + *, + head_dim: int, + len_h: int, + len_w: int, + len_t: int, + base_fps: int = 24, + h_extrapolation_ratio: float = 1.0, + w_extrapolation_ratio: float = 1.0, + t_extrapolation_ratio: float = 1.0, + enable_fps_modulation: bool = True, + device: Optional[torch.device] = None, + **kwargs, + ): + super().__init__() + self.base_fps = base_fps + self.max_h = len_h + self.max_w = len_w + self.enable_fps_modulation = enable_fps_modulation + + dim = head_dim + dim_h = dim // 6 * 2 + dim_w = dim_h + dim_t = dim - 2 * dim_h + assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" + + self.register_buffer( + "dim_spatial_range", + torch.arange(0, dim_h, 2, device=device)[: (dim_h // 2)].float() / dim_h, + persistent=False, + ) + self.register_buffer( + "dim_temporal_range", + torch.arange(0, dim_t, 2, device=device)[: (dim_t // 2)].float() / dim_t, + persistent=False, + ) + + self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2)) + self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2)) + self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2)) + + def forward( + self, + x_B_T_H_W_C: torch.Tensor, + fps: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + return self.generate_embeddings(x_B_T_H_W_C.shape, fps=fps, device=device) + + def generate_embeddings( + self, + B_T_H_W_C: torch.Size, + fps: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> torch.Tensor: + h_theta = 10000.0 * self.h_ntk_factor + w_theta = 10000.0 * self.w_ntk_factor + t_theta = 10000.0 * self.t_ntk_factor + + h_spatial_freqs = 1.0 / (h_theta ** self.dim_spatial_range.to(device=device)) + w_spatial_freqs = 1.0 / (w_theta ** self.dim_spatial_range.to(device=device)) + temporal_freqs = 1.0 / (t_theta ** self.dim_temporal_range.to(device=device)) + + B, T, H, W, _ = B_T_H_W_C + seq = torch.arange(max(H, W, T), dtype=torch.float, device=device) + + half_emb_h = torch.outer(seq[:H].to(device=device), h_spatial_freqs) + half_emb_w = torch.outer(seq[:W].to(device=device), w_spatial_freqs) + + if fps is None or self.enable_fps_modulation is False: + half_emb_t = torch.outer(seq[:T].to(device=device), temporal_freqs) + else: + half_emb_t = torch.outer(seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs) + + half_emb_h = torch.stack( + [torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1 + ) + half_emb_w = torch.stack( + 
[torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1 + ) + half_emb_t = torch.stack( + [torch.cos(half_emb_t), -torch.sin(half_emb_t), torch.sin(half_emb_t), torch.cos(half_emb_t)], dim=-1 + ) + + em_T_H_W_D = torch.cat( + [ + repeat(half_emb_t, "t d x -> t h w d x", h=H, w=W), + repeat(half_emb_h, "h d x -> t h w d x", t=T, w=W), + repeat(half_emb_w, "w d x -> t h w d x", t=T, h=H), + ], + dim=-2, + ) + + return rearrange(em_T_H_W_D, "t h w d (i j) -> (t h w) d i j", i=2, j=2).float() + + +def _normalize(x: torch.Tensor, dim: Optional[list[int]] = None, eps: float = 0) -> torch.Tensor: + if dim is None: + dim = list(range(1, x.ndim)) + norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32) + norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel())) + return x / norm.to(x.dtype) + + +class LearnablePosEmbAxis(nn.Module): + """Learnable per-axis positional embeddings.""" + + def __init__( + self, + *, + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + **kwargs, + ): + super().__init__() + self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype)) + self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype)) + self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype)) + + def forward( + self, + x_B_T_H_W_C: torch.Tensor, + fps: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> torch.Tensor: + return self.generate_embeddings(x_B_T_H_W_C.shape, device=device, dtype=dtype) + + def generate_embeddings( + self, + B_T_H_W_C: torch.Size, + fps: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> torch.Tensor: + B, T, H, W, _ = B_T_H_W_C + emb_h_H = self.pos_emb_h[:H].to(device=device, dtype=dtype) + emb_w_W = self.pos_emb_w[:W].to(device=device, dtype=dtype) + emb_t_T = self.pos_emb_t[:T].to(device=device, dtype=dtype) + emb = ( + repeat(emb_t_T, "t d -> b t h w d", b=B, h=H, w=W) + + repeat(emb_h_H, "h d -> b t h w d", b=B, t=T, w=W) + + repeat(emb_w_W, "w d -> b t h w d", b=B, t=T, h=H) + ) + return _normalize(emb, dim=-1, eps=1e-6) + + +# ============================================================================ +# Cosmos Predict2 MiniTrainDIT +# Original source: https://github.com/nvidia-cosmos/cosmos-predict2 +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
Apache-2.0 +# ============================================================================ + + +def apply_rotary_pos_emb_cosmos(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + """Apply rotary position embeddings in Cosmos format (2x2 rotation matrices).""" + t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float() + t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1] + t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t) + return t_out + + +class GPT2FeedForward(nn.Module): + def __init__(self, d_model: int, d_ff: int) -> None: + super().__init__() + self.activation = nn.GELU() + self.layer1 = nn.Linear(d_model, d_ff, bias=False) + self.layer2 = nn.Linear(d_ff, d_model, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.layer2(self.activation(self.layer1(x))) + + +class CosmosAttention(nn.Module): + """Multi-head attention for the Cosmos DiT backbone. + + Supports both self-attention and cross-attention with QK normalization + and rotary position embeddings. + """ + + def __init__( + self, + query_dim: int, + context_dim: Optional[int] = None, + n_heads: int = 8, + head_dim: int = 64, + dropout: float = 0.0, + ) -> None: + super().__init__() + self.is_selfattn = context_dim is None + context_dim = query_dim if context_dim is None else context_dim + inner_dim = head_dim * n_heads + + self.n_heads = n_heads + self.head_dim = head_dim + + self.q_proj = nn.Linear(query_dim, inner_dim, bias=False) + self.q_norm = nn.RMSNorm(head_dim, eps=1e-6) + + self.k_proj = nn.Linear(context_dim, inner_dim, bias=False) + self.k_norm = nn.RMSNorm(head_dim, eps=1e-6) + + self.v_proj = nn.Linear(context_dim, inner_dim, bias=False) + self.v_norm = nn.Identity() + + self.output_proj = nn.Linear(inner_dim, query_dim, bias=False) + self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity() + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + rope_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + q = self.q_proj(x) + context = x if context is None else context + k = self.k_proj(context) + v = self.v_proj(context) + q, k, v = (rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim) for t in (q, k, v)) + + q = self.q_norm(q) + k = self.k_norm(k) + v = self.v_norm(v) + + if self.is_selfattn and rope_emb is not None: + q = apply_rotary_pos_emb_cosmos(q, rope_emb) + k = apply_rotary_pos_emb_cosmos(k, rope_emb) + + # Reshape for scaled_dot_product_attention: (B, heads, seq, dim) + in_q_shape = q.shape + in_k_shape = k.shape + q = rearrange(q, "b ... h d -> b h ... d").reshape(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1]) + k = rearrange(k, "b ... h d -> b h ... d").reshape(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) + v = rearrange(v, "b ... h d -> b h ... 
d").reshape(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) + + result = F.scaled_dot_product_attention(q, k, v) + result = rearrange(result, "b h s d -> b s (h d)") + return self.output_dropout(self.output_proj(result)) + + +class Timesteps(nn.Module): + """Sinusoidal timestep embeddings.""" + + def __init__(self, num_channels: int): + super().__init__() + self.num_channels = num_channels + + def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor: + assert timesteps_B_T.ndim == 2 + timesteps = timesteps_B_T.flatten().float() + half_dim = self.num_channels // 2 + exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) / half_dim + emb = timesteps[:, None].float() * torch.exp(exponent)[None, :] + emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1) + return rearrange(emb, "(b t) d -> b t d", b=timesteps_B_T.shape[0], t=timesteps_B_T.shape[1]) + + +class TimestepEmbedding(nn.Module): + """Projects sinusoidal timestep embeddings to model dimension.""" + + def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False): + super().__init__() + self.use_adaln_lora = use_adaln_lora + self.linear_1 = nn.Linear(in_features, out_features, bias=not use_adaln_lora) + self.activation = nn.SiLU() + if use_adaln_lora: + self.linear_2 = nn.Linear(out_features, 3 * out_features, bias=False) + else: + self.linear_2 = nn.Linear(out_features, out_features, bias=False) + + def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + emb = self.linear_2(self.activation(self.linear_1(sample))) + if self.use_adaln_lora: + return sample, emb + return emb, None + + +class PatchEmbed(nn.Module): + """Patchify input tensor via rearrange + linear projection.""" + + def __init__( + self, + spatial_patch_size: int, + temporal_patch_size: int, + in_channels: int = 3, + out_channels: int = 768, + ): + super().__init__() + self.spatial_patch_size = spatial_patch_size + self.temporal_patch_size = temporal_patch_size + self.proj = nn.Sequential( + Rearrange( + "b c (t r) (h m) (w n) -> b t h w (c r m n)", + r=temporal_patch_size, + m=spatial_patch_size, + n=spatial_patch_size, + ), + nn.Linear( + in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, + out_channels, + bias=False, + ), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + assert x.dim() == 5 + return self.proj(x) + + +class FinalLayer(nn.Module): + """Final AdaLN-modulated output projection.""" + + def __init__( + self, + hidden_size: int, + spatial_patch_size: int, + temporal_patch_size: int, + out_channels: int, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + ): + super().__init__() + self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear( + hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False + ) + self.hidden_size = hidden_size + self.use_adaln_lora = use_adaln_lora + + if use_adaln_lora: + self.adaln_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, adaln_lora_dim, bias=False), + nn.Linear(adaln_lora_dim, 2 * hidden_size, bias=False), + ) + else: + self.adaln_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size, bias=False), + ) + + def forward( + self, + x_B_T_H_W_D: torch.Tensor, + emb_B_T_D: torch.Tensor, + adaln_lora_B_T_3D: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if self.use_adaln_lora: + assert adaln_lora_B_T_3D is not None + shift, scale = 
(self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]).chunk( + 2, dim=-1 + ) + else: + shift, scale = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1) + + shift = rearrange(shift, "b t d -> b t 1 1 d") + scale = rearrange(scale, "b t d -> b t 1 1 d") + + x_B_T_H_W_D = self.layer_norm(x_B_T_H_W_D) * (1 + scale) + shift + return self.linear(x_B_T_H_W_D) + + +class DiTBlock(nn.Module): + """Cosmos DiT transformer block with self-attention, cross-attention, and MLP. + + Each component uses AdaLN (Adaptive Layer Normalization) modulation from + the timestep embedding. + """ + + def __init__( + self, + x_dim: int, + context_dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + ): + super().__init__() + self.x_dim = x_dim + self.use_adaln_lora = use_adaln_lora + + self.layer_norm_self_attn = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6) + self.self_attn = CosmosAttention(x_dim, None, num_heads, x_dim // num_heads) + + self.layer_norm_cross_attn = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6) + self.cross_attn = CosmosAttention(x_dim, context_dim, num_heads, x_dim // num_heads) + + self.layer_norm_mlp = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6) + self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio)) + + # AdaLN modulation layers (shift, scale, gate for each of 3 components) + if use_adaln_lora: + self.adaln_modulation_self_attn = nn.Sequential( + nn.SiLU(), + nn.Linear(x_dim, adaln_lora_dim, bias=False), + nn.Linear(adaln_lora_dim, 3 * x_dim, bias=False), + ) + self.adaln_modulation_cross_attn = nn.Sequential( + nn.SiLU(), + nn.Linear(x_dim, adaln_lora_dim, bias=False), + nn.Linear(adaln_lora_dim, 3 * x_dim, bias=False), + ) + self.adaln_modulation_mlp = nn.Sequential( + nn.SiLU(), + nn.Linear(x_dim, adaln_lora_dim, bias=False), + nn.Linear(adaln_lora_dim, 3 * x_dim, bias=False), + ) + else: + self.adaln_modulation_self_attn = nn.Sequential(nn.SiLU(), nn.Linear(x_dim, 3 * x_dim, bias=False)) + self.adaln_modulation_cross_attn = nn.Sequential(nn.SiLU(), nn.Linear(x_dim, 3 * x_dim, bias=False)) + self.adaln_modulation_mlp = nn.Sequential(nn.SiLU(), nn.Linear(x_dim, 3 * x_dim, bias=False)) + + def forward( + self, + x_B_T_H_W_D: torch.Tensor, + emb_B_T_D: torch.Tensor, + crossattn_emb: torch.Tensor, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_T_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + residual_dtype = x_B_T_H_W_D.dtype + compute_dtype = emb_B_T_D.dtype + + if extra_per_block_pos_emb is not None: + x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb + + # Compute AdaLN modulations + if self.use_adaln_lora: + assert adaln_lora_B_T_3D is not None + shift_sa, scale_sa, gate_sa = (self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D).chunk( + 3, dim=-1 + ) + shift_ca, scale_ca, gate_ca = (self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D).chunk( + 3, dim=-1 + ) + shift_mlp, scale_mlp, gate_mlp = (self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D).chunk(3, dim=-1) + else: + shift_sa, scale_sa, gate_sa = self.adaln_modulation_self_attn(emb_B_T_D).chunk(3, dim=-1) + shift_ca, scale_ca, gate_ca = self.adaln_modulation_cross_attn(emb_B_T_D).chunk(3, dim=-1) + shift_mlp, scale_mlp, gate_mlp = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1) + + # Reshape for broadcasting: (B, T, D) -> (B, T, 1, 1, D) + shift_sa, scale_sa, gate_sa = (rearrange(t, "b t 
d -> b t 1 1 d") for t in (shift_sa, scale_sa, gate_sa)) + shift_ca, scale_ca, gate_ca = (rearrange(t, "b t d -> b t 1 1 d") for t in (shift_ca, scale_ca, gate_ca)) + shift_mlp, scale_mlp, gate_mlp = (rearrange(t, "b t d -> b t 1 1 d") for t in (shift_mlp, scale_mlp, gate_mlp)) + + B, T, H, W, D = x_B_T_H_W_D.shape + + def _adaln(x: torch.Tensor, norm: nn.Module, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor: + return norm(x) * (1 + scale) + shift + + # Self-attention + normed = _adaln(x_B_T_H_W_D, self.layer_norm_self_attn, scale_sa, shift_sa) + result = rearrange( + self.self_attn( + rearrange(normed.to(compute_dtype), "b t h w d -> b (t h w) d"), None, rope_emb=rope_emb_L_1_1_D + ), + "b (t h w) d -> b t h w d", + t=T, + h=H, + w=W, + ) + x_B_T_H_W_D = x_B_T_H_W_D + gate_sa.to(residual_dtype) * result.to(residual_dtype) + + # Cross-attention + normed = _adaln(x_B_T_H_W_D, self.layer_norm_cross_attn, scale_ca, shift_ca) + result = rearrange( + self.cross_attn( + rearrange(normed.to(compute_dtype), "b t h w d -> b (t h w) d"), + crossattn_emb, + rope_emb=rope_emb_L_1_1_D, + ), + "b (t h w) d -> b t h w d", + t=T, + h=H, + w=W, + ) + x_B_T_H_W_D = result.to(residual_dtype) * gate_ca.to(residual_dtype) + x_B_T_H_W_D + + # MLP + normed = _adaln(x_B_T_H_W_D, self.layer_norm_mlp, scale_mlp, shift_mlp) + result = self.mlp(normed.to(compute_dtype)) + x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp.to(residual_dtype) * result.to(residual_dtype) + + return x_B_T_H_W_D + + +class MiniTrainDIT(nn.Module): + """Cosmos Predict2 DiT backbone for video/image generation. + + This is the core transformer architecture that Anima extends. It processes + 3D latent tensors (B, C, T, H, W) with patch embedding, positional encoding, + and adaptive layer normalization. + + Args: + max_img_h: Maximum image height in pixels. + max_img_w: Maximum image width in pixels. + max_frames: Maximum number of video frames. + in_channels: Number of input latent channels. + out_channels: Number of output channels. + patch_spatial: Spatial patch size. + patch_temporal: Temporal patch size. + concat_padding_mask: Whether to concatenate a padding mask channel. + model_channels: Hidden dimension of the transformer. + num_blocks: Number of DiT blocks. + num_heads: Number of attention heads. + mlp_ratio: MLP expansion ratio. + crossattn_emb_channels: Cross-attention context dimension. + use_adaln_lora: Whether to use AdaLN-LoRA. + adaln_lora_dim: AdaLN-LoRA bottleneck dimension. + extra_per_block_abs_pos_emb: Whether to use extra learnable positional embeddings. 
+ """ + + def __init__( + self, + max_img_h: int = 240, + max_img_w: int = 240, + max_frames: int = 1, + in_channels: int = 16, + out_channels: int = 16, + patch_spatial: int = 2, + patch_temporal: int = 1, + concat_padding_mask: bool = True, + model_channels: int = 2048, + num_blocks: int = 28, + num_heads: int = 16, + mlp_ratio: float = 4.0, + crossattn_emb_channels: int = 1024, + pos_emb_cls: str = "rope3d", + pos_emb_learnable: bool = False, + pos_emb_interpolation: str = "crop", + min_fps: int = 1, + max_fps: int = 30, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + rope_h_extrapolation_ratio: float = 1.0, + rope_w_extrapolation_ratio: float = 1.0, + rope_t_extrapolation_ratio: float = 1.0, + extra_per_block_abs_pos_emb: bool = False, + extra_h_extrapolation_ratio: float = 1.0, + extra_w_extrapolation_ratio: float = 1.0, + extra_t_extrapolation_ratio: float = 1.0, + rope_enable_fps_modulation: bool = True, + image_model: Optional[str] = None, + ) -> None: + super().__init__() + self.max_img_h = max_img_h + self.max_img_w = max_img_w + self.max_frames = max_frames + self.in_channels = in_channels + self.out_channels = out_channels + self.patch_spatial = patch_spatial + self.patch_temporal = patch_temporal + self.num_heads = num_heads + self.num_blocks = num_blocks + self.model_channels = model_channels + self.concat_padding_mask = concat_padding_mask + self.pos_emb_cls = pos_emb_cls + self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb + + # Positional embeddings + self.pos_embedder = VideoRopePosition3DEmb( + head_dim=model_channels // num_heads, + len_h=max_img_h // patch_spatial, + len_w=max_img_w // patch_spatial, + len_t=max_frames // patch_temporal, + max_fps=max_fps, + min_fps=min_fps, + h_extrapolation_ratio=rope_h_extrapolation_ratio, + w_extrapolation_ratio=rope_w_extrapolation_ratio, + t_extrapolation_ratio=rope_t_extrapolation_ratio, + enable_fps_modulation=rope_enable_fps_modulation, + ) + + if extra_per_block_abs_pos_emb: + self.extra_pos_embedder = LearnablePosEmbAxis( + model_channels=model_channels, + len_h=max_img_h // patch_spatial, + len_w=max_img_w // patch_spatial, + len_t=max_frames // patch_temporal, + ) + + self.use_adaln_lora = use_adaln_lora + self.adaln_lora_dim = adaln_lora_dim + + # Timestep embedding + self.t_embedder = nn.Sequential( + Timesteps(model_channels), + TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora), + ) + self.t_embedding_norm = nn.RMSNorm(model_channels, eps=1e-6) + + # Patch embedding + embed_in_channels = in_channels + 1 if concat_padding_mask else in_channels + self.x_embedder = PatchEmbed( + spatial_patch_size=patch_spatial, + temporal_patch_size=patch_temporal, + in_channels=embed_in_channels, + out_channels=model_channels, + ) + + # Transformer blocks + self.blocks = nn.ModuleList( + [ + DiTBlock( + x_dim=model_channels, + context_dim=crossattn_emb_channels, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + use_adaln_lora=use_adaln_lora, + adaln_lora_dim=adaln_lora_dim, + ) + for _ in range(num_blocks) + ] + ) + + # Final output layer + self.final_layer = FinalLayer( + hidden_size=model_channels, + spatial_patch_size=patch_spatial, + temporal_patch_size=patch_temporal, + out_channels=out_channels, + use_adaln_lora=use_adaln_lora, + adaln_lora_dim=adaln_lora_dim, + ) + + def _pad_to_patch_size(self, x: torch.Tensor) -> torch.Tensor: + """Pad input tensor so dimensions are divisible by patch sizes.""" + _, _, T, H, W = x.shape + pad_t = (self.patch_temporal - T % 
self.patch_temporal) % self.patch_temporal + pad_h = (self.patch_spatial - H % self.patch_spatial) % self.patch_spatial + pad_w = (self.patch_spatial - W % self.patch_spatial) % self.patch_spatial + if pad_t > 0 or pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, pad_w, 0, pad_h, 0, pad_t)) + return x + + def prepare_embedded_sequence( + self, + x_B_C_T_H_W: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + if self.concat_padding_mask: + if padding_mask is None: + padding_mask = torch.zeros( + x_B_C_T_H_W.shape[0], + 1, + x_B_C_T_H_W.shape[3], + x_B_C_T_H_W.shape[4], + dtype=x_B_C_T_H_W.dtype, + device=x_B_C_T_H_W.device, + ) + x_B_C_T_H_W = torch.cat( + [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 + ) + + x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) + + extra_pos_emb = None + if self.extra_per_block_abs_pos_emb: + extra_pos_emb = self.extra_pos_embedder( + x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype + ) + + if "rope" in self.pos_emb_cls.lower(): + return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb + + return x_B_T_H_W_D, None, extra_pos_emb + + def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor: + return rearrange( + x_B_T_H_W_M, + "B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)", + p1=self.patch_spatial, + p2=self.patch_spatial, + t=self.patch_temporal, + ) + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + orig_shape = list(x.shape) + x = self._pad_to_patch_size(x) + + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb = self.prepare_embedded_sequence( + x, fps=fps, padding_mask=padding_mask + ) + + if timesteps.ndim == 1: + timesteps = timesteps.unsqueeze(1) + t_emb, adaln_lora = self.t_embedder[1](self.t_embedder[0](timesteps).to(x_B_T_H_W_D.dtype)) + t_emb = self.t_embedding_norm(t_emb) + + block_kwargs = { + "rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0) if rope_emb_L_1_1_D is not None else None, + "adaln_lora_B_T_3D": adaln_lora, + "extra_per_block_pos_emb": extra_pos_emb, + } + + # Keep residual stream in fp32 for numerical stability with fp16 compute + if x_B_T_H_W_D.dtype == torch.float16: + x_B_T_H_W_D = x_B_T_H_W_D.float() + + for block in self.blocks: + x_B_T_H_W_D = block(x_B_T_H_W_D, t_emb, context, **block_kwargs) + + x_out = self.final_layer(x_B_T_H_W_D.to(context.dtype), t_emb, adaln_lora_B_T_3D=adaln_lora) + x_out = self.unpatchify(x_out)[:, :, : orig_shape[-3], : orig_shape[-2], : orig_shape[-1]] + return x_out + + +# ============================================================================ +# LLM Adapter +# Reference implementation: https://github.com/hdae/diffusers-anima +# SPDX-License-Identifier: Apache-2.0 +# ============================================================================ + + +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + """Split the last dimension in half and negate-swap: [-x2, x1].""" + half = x.shape[-1] // 2 + first, second = x[..., :half], x[..., half:] + return torch.cat((-second, first), dim=-1) + + +def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + """Apply rotary position embeddings to tensor x given precomputed cos/sin.""" + return (x * cos.unsqueeze(1)) + 
(_rotate_half(x) * sin.unsqueeze(1)) + + +class LLMAdapterRotaryEmbedding(nn.Module): + """Rotary position embedding for the LLM Adapter's attention layers.""" + + def __init__(self, head_dim: int, theta: float = 10000.0): + super().__init__() + half_dim = head_dim // 2 + index = torch.arange(half_dim, dtype=torch.float32) + exponent = (2.0 / float(head_dim)) * index + inv_freq = torch.reciprocal(torch.pow(torch.tensor(theta, dtype=torch.float32), exponent)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + pos = position_ids.to(device=x.device, dtype=torch.float32) + inv = self.inv_freq.to(device=x.device, dtype=torch.float32) + freqs = torch.einsum("bl,d->bld", pos, inv) + emb = freqs.repeat(1, 1, 2) + return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype) + + +class LLMAdapterAttention(nn.Module): + """Attention for the LLM Adapter with QK normalization and rotary position embeddings.""" + + def __init__(self, query_dim: int, context_dim: int, n_heads: int, head_dim: int): + super().__init__() + inner_dim = head_dim * n_heads + self.n_heads = n_heads + self.head_dim = head_dim + + self.q_proj = nn.Linear(query_dim, inner_dim, bias=False) + self.q_norm = nn.RMSNorm(head_dim, eps=1e-6) + self.k_proj = nn.Linear(context_dim, inner_dim, bias=False) + self.k_norm = nn.RMSNorm(head_dim, eps=1e-6) + self.v_proj = nn.Linear(context_dim, inner_dim, bias=False) + self.o_proj = nn.Linear(inner_dim, query_dim, bias=False) + + def forward( + self, + x: torch.Tensor, + *, + context: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + pos_q: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + pos_k: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + context = x if context is None else context + + q = self.q_proj(x).view(x.shape[0], x.shape[1], self.n_heads, self.head_dim).transpose(1, 2) + k = self.k_proj(context).view(context.shape[0], context.shape[1], self.n_heads, self.head_dim).transpose(1, 2) + v = self.v_proj(context).view(context.shape[0], context.shape[1], self.n_heads, self.head_dim).transpose(1, 2) + + q = self.q_norm(q) + k = self.k_norm(k) + + if pos_q is not None and pos_k is not None: + q = _apply_rope(q, *pos_q) + k = _apply_rope(k, *pos_k) + + y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + y = y.transpose(1, 2).reshape(x.shape[0], x.shape[1], -1).contiguous() + return self.o_proj(y) + + +class LLMAdapterTransformerBlock(nn.Module): + """Single transformer block in the LLM Adapter. + + Each block contains self-attention, cross-attention, and MLP with + RMSNorm pre-normalization. 
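+
+    The residual flow mirrors a standard pre-norm decoder block:
+    ``x = x + SelfAttn(norm(x)); x = x + CrossAttn(norm(x), context); x = x + MLP(norm(x))``.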
+ """ + + def __init__( + self, + source_dim: int, + model_dim: int, + num_heads: int = 16, + ): + super().__init__() + head_dim = model_dim // num_heads + + self.norm_self_attn = nn.RMSNorm(model_dim, eps=1e-6) + self.self_attn = LLMAdapterAttention(model_dim, model_dim, num_heads, head_dim) + + self.norm_cross_attn = nn.RMSNorm(model_dim, eps=1e-6) + self.cross_attn = LLMAdapterAttention(model_dim, source_dim, num_heads, head_dim) + + self.norm_mlp = nn.RMSNorm(model_dim, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(model_dim, model_dim * 4), + nn.GELU(), + nn.Linear(model_dim * 4, model_dim), + ) + + def forward( + self, + x: torch.Tensor, + *, + context: torch.Tensor, + target_mask: Optional[torch.Tensor] = None, + source_mask: Optional[torch.Tensor] = None, + pos_target: Tuple[torch.Tensor, torch.Tensor], + pos_source: Tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + x = x + self.self_attn( + self.norm_self_attn(x), + attn_mask=target_mask, + pos_q=pos_target, + pos_k=pos_target, + ) + x = x + self.cross_attn( + self.norm_cross_attn(x), + context=context, + attn_mask=source_mask, + pos_q=pos_target, + pos_k=pos_source, + ) + x = x + self.mlp(self.norm_mlp(x)) + return x + + +class LLMAdapter(nn.Module): + """LLM Adapter: bridges Qwen3 hidden states and T5-XXL token embeddings. + + Takes Qwen3 hidden states and T5-XXL token IDs, produces conditioning + embeddings for the Cosmos DiT via cross-attention through 6 transformer layers. + + Args: + vocab_size: Size of the T5 token vocabulary. + dim: Model dimension (used for embeddings, projections, and all layers). + num_layers: Number of transformer layers. + num_heads: Number of attention heads. + """ + + def __init__( + self, + vocab_size: int = 32128, + dim: int = 1024, + num_layers: int = 6, + num_heads: int = 16, + ): + super().__init__() + self.embed = nn.Embedding(vocab_size, dim) + self.blocks = nn.ModuleList( + [LLMAdapterTransformerBlock(source_dim=dim, model_dim=dim, num_heads=num_heads) for _ in range(num_layers)] + ) + self.out_proj = nn.Linear(dim, dim) + self.norm = nn.RMSNorm(dim, eps=1e-6) + self.rotary_emb = LLMAdapterRotaryEmbedding(dim // num_heads) + + def forward( + self, + source_hidden_states: torch.Tensor, + target_input_ids: torch.Tensor, + target_attention_mask: Optional[torch.Tensor] = None, + source_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # Expand attention masks for multi-head attention + if target_attention_mask is not None: + target_attention_mask = target_attention_mask.to(torch.bool) + if target_attention_mask.ndim == 2: + target_attention_mask = target_attention_mask[:, None, None, :] + + if source_attention_mask is not None: + source_attention_mask = source_attention_mask.to(torch.bool) + if source_attention_mask.ndim == 2: + source_attention_mask = source_attention_mask[:, None, None, :] + + context = source_hidden_states + x = self.embed(target_input_ids).to(dtype=context.dtype) + + # Build position IDs and compute rotary embeddings + target_pos_ids = torch.arange(x.shape[1], device=x.device, dtype=torch.long).unsqueeze(0) + source_pos_ids = torch.arange(context.shape[1], device=x.device, dtype=torch.long).unsqueeze(0) + pos_target = self.rotary_emb(x, target_pos_ids) + pos_source = self.rotary_emb(x, source_pos_ids) + + for block in self.blocks: + x = block( + x, + context=context, + target_mask=target_attention_mask, + source_mask=source_attention_mask, + pos_target=pos_target, + pos_source=pos_source, + ) + return self.norm(self.out_proj(x)) + + +# 
============================================================================ +# Anima: MiniTrainDIT + LLMAdapter +# Reference implementation: https://github.com/hdae/diffusers-anima +# SPDX-License-Identifier: Apache-2.0 +# ============================================================================ + + +class AnimaTransformer(MiniTrainDIT): + """Anima transformer: Cosmos Predict2 DiT with integrated LLM Adapter. + + Extends MiniTrainDIT by adding the LLMAdapter component that preprocesses + text embeddings before they are fed to the DiT cross-attention layers. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.llm_adapter = LLMAdapter() + + def preprocess_text_embeds( + self, + text_embeds: torch.Tensor, + text_ids: Optional[torch.Tensor], + t5xxl_weights: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Run the LLM Adapter to produce conditioning for the DiT. + + Args: + text_embeds: Qwen3 hidden states. Shape: (batch, seq_len, 1024). + text_ids: T5-XXL token IDs. Shape: (batch, seq_len). If None, returns text_embeds directly. + t5xxl_weights: Optional per-token weights. Shape: (batch, seq_len, 1). + + Returns: + Conditioning tensor. Shape: (batch, 512, 1024), zero-padded if needed. + """ + if text_ids is None: + return text_embeds + out = self.llm_adapter(text_embeds, text_ids) + if t5xxl_weights is not None: + out = out * t5xxl_weights + if out.shape[1] < 512: + out = F.pad(out, (0, 0, 0, 512 - out.shape[1])) + return out + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor, + t5xxl_ids: Optional[torch.Tensor] = None, + t5xxl_weights: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """Forward pass with LLM Adapter preprocessing. + + Args: + x: Input latent tensor. Shape: (B, C, T, H, W). + timesteps: Timestep values. Shape: (B,) or (B, T). + context: Qwen3 hidden states. Shape: (B, seq_len, 1024). + t5xxl_ids: T5-XXL token IDs. Shape: (B, seq_len). + t5xxl_weights: Per-token weights. Shape: (B, seq_len, 1). + + Returns: + Denoised output. Shape: (B, C, T, H, W). + """ + if t5xxl_ids is not None: + context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=t5xxl_weights) + return super().forward(x, timesteps, context, **kwargs) diff --git a/invokeai/backend/anima/anima_transformer_patch.py b/invokeai/backend/anima/anima_transformer_patch.py new file mode 100644 index 0000000000..4eff79830e --- /dev/null +++ b/invokeai/backend/anima/anima_transformer_patch.py @@ -0,0 +1,106 @@ +"""Utilities for patching the AnimaTransformer to support regional cross-attention masks.""" + +from contextlib import contextmanager +from typing import Optional + +import torch +import torch.nn.functional as F +from einops import rearrange + +from invokeai.backend.anima.regional_prompting import AnimaRegionalPromptingExtension + + +def _patched_cross_attn_forward( + original_forward, + attn_mask: torch.Tensor, +): + """Create a patched forward for CosmosAttention that injects a cross-attention mask. + + Args: + original_forward: The original CosmosAttention.forward method (bound to self). + attn_mask: Cross-attention mask of shape (img_seq_len, context_seq_len). + """ + + def forward(x, context=None, rope_emb=None): + # If the context sequence length doesn't match the mask (e.g. negative conditioning + # has a different number of tokens than positive regional conditioning), skip masking + # and use the original unmasked forward. 
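+        # (e.g. an unmasked negative prompt yields a 512-token context, while the concatenated
+        # regional context generally has a different length, so masking is skipped for it)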
+ actual_context = x if context is None else context + if actual_context.shape[-2] != attn_mask.shape[1]: + return original_forward(x, context, rope_emb=rope_emb) + + self = original_forward.__self__ + + q = self.q_proj(x) + context = x if context is None else context + k = self.k_proj(context) + v = self.v_proj(context) + q, k, v = (rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim) for t in (q, k, v)) + + q = self.q_norm(q) + k = self.k_norm(k) + v = self.v_norm(v) + + if self.is_selfattn and rope_emb is not None: + from invokeai.backend.anima.anima_transformer import apply_rotary_pos_emb_cosmos + + q = apply_rotary_pos_emb_cosmos(q, rope_emb) + k = apply_rotary_pos_emb_cosmos(k, rope_emb) + + in_q_shape = q.shape + in_k_shape = k.shape + q = rearrange(q, "b ... h d -> b h ... d").reshape(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1]) + k = rearrange(k, "b ... h d -> b h ... d").reshape(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) + v = rearrange(v, "b ... h d -> b h ... d").reshape(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) + + # Convert boolean mask to float additive mask for SDPA + # True (attend) -> 0.0, False (block) -> -inf + # Shape: (img_seq_len, context_seq_len) -> (1, 1, img_seq_len, context_seq_len) + float_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + float_mask[~attn_mask] = float("-inf") + expanded_mask = float_mask.unsqueeze(0).unsqueeze(0) + + result = F.scaled_dot_product_attention(q, k, v, attn_mask=expanded_mask) + result = rearrange(result, "b h s d -> b s (h d)") + return self.output_dropout(self.output_proj(result)) + + return forward + + +@contextmanager +def patch_anima_for_regional_prompting( + transformer, + regional_extension: Optional[AnimaRegionalPromptingExtension], +): + """Context manager to temporarily patch the Anima transformer for regional prompting. + + Patches the cross-attention in each DiT block to use a regional attention mask. + Uses alternating pattern: masked on even blocks, unmasked on odd blocks for + global coherence. + + Args: + transformer: The AnimaTransformer instance. + regional_extension: The regional prompting extension. If None or no mask, no patching. + + Yields: + The (possibly patched) transformer. + """ + if regional_extension is None or regional_extension.cross_attn_mask is None: + yield transformer + return + + # Store original forwards + original_forwards = [] + for block_idx, block in enumerate(transformer.blocks): + original_forwards.append(block.cross_attn.forward) + + mask = regional_extension.get_cross_attn_mask(block_idx) + if mask is not None: + block.cross_attn.forward = _patched_cross_attn_forward(block.cross_attn.forward, mask) + + try: + yield transformer + finally: + # Restore original forwards + for block_idx, block in enumerate(transformer.blocks): + block.cross_attn.forward = original_forwards[block_idx] diff --git a/invokeai/backend/anima/conditioning_data.py b/invokeai/backend/anima/conditioning_data.py new file mode 100644 index 0000000000..b96c807835 --- /dev/null +++ b/invokeai/backend/anima/conditioning_data.py @@ -0,0 +1,64 @@ +"""Anima text conditioning data structures. + +Anima uses a dual-conditioning scheme: +- Qwen3 0.6B hidden states (continuous embeddings) +- T5-XXL token IDs (discrete IDs, embedded by the LLM Adapter inside the transformer) + +Both are produced by the text encoder invocation and stored together. + +For regional prompting, multiple conditionings (each with an optional spatial mask) +are concatenated and processed together. 
The LLM Adapter runs on each region's +conditioning separately, producing per-region context vectors that are concatenated +for the DiT's cross-attention layers. An attention mask restricts which image tokens +attend to which regional context tokens. +""" + +from dataclasses import dataclass + +import torch + +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range + + +@dataclass +class AnimaTextConditioning: + """Anima text conditioning with Qwen3 hidden states, T5-XXL token IDs, and optional mask. + + Attributes: + qwen3_embeds: Text embeddings from Qwen3 0.6B encoder. + Shape: (seq_len, hidden_size) where hidden_size=1024. + t5xxl_ids: T5-XXL token IDs for the same prompt. + Shape: (seq_len,). + t5xxl_weights: Per-token weights for prompt weighting. + Shape: (seq_len,). Defaults to all ones if not provided. + mask: Optional binary mask for regional prompting. If None, the prompt is global. + Shape: (1, 1, img_seq_len) where img_seq_len = (H // patch_size) * (W // patch_size). + """ + + qwen3_embeds: torch.Tensor + t5xxl_ids: torch.Tensor + t5xxl_weights: torch.Tensor | None = None + mask: torch.Tensor | None = None + + +@dataclass +class AnimaRegionalTextConditioning: + """Container for multiple regional text conditionings processed by the LLM Adapter. + + After the LLM Adapter processes each region's conditioning, the outputs are concatenated. + The DiT cross-attention then uses an attention mask to restrict which image tokens + attend to which region's context tokens. + + Attributes: + context_embeds: Concatenated LLM Adapter outputs from all regional prompts. + Shape: (total_context_len, 1024). + image_masks: List of binary masks for each regional prompt. + If None, the prompt is global (applies to entire image). + Shape: (1, 1, img_seq_len). + context_ranges: List of ranges indicating which portion of context_embeds + corresponds to each regional prompt. + """ + + context_embeds: torch.Tensor + image_masks: list[torch.Tensor | None] + context_ranges: list[Range] diff --git a/invokeai/backend/anima/regional_prompting.py b/invokeai/backend/anima/regional_prompting.py new file mode 100644 index 0000000000..c0af366332 --- /dev/null +++ b/invokeai/backend/anima/regional_prompting.py @@ -0,0 +1,173 @@ +"""Regional prompting extension for Anima. + +Anima's architecture uses separate cross-attention in each DiT block: image tokens +(in 5D spatial layout) cross-attend to context tokens (LLM Adapter output). This is +different from Z-Image's unified [img, txt] sequence with self-attention. + +For regional prompting, we: +1. Run the LLM Adapter separately for each regional prompt +2. Concatenate the resulting context vectors +3. Build a cross-attention mask that restricts each image region to attend only to + its corresponding context tokens +4. Patch the DiT's cross-attention to use this mask + +The mask alternation strategy (masked on even blocks, full on odd blocks) helps +maintain global coherence across regions. +""" + +from typing import Optional + +import torch +import torchvision + +from invokeai.backend.anima.conditioning_data import AnimaRegionalTextConditioning +from invokeai.backend.util.devices import TorchDevice +from invokeai.backend.util.mask import to_standard_float_mask + + +class AnimaRegionalPromptingExtension: + """Manages regional prompting for Anima's cross-attention. + + Unlike Z-Image which uses a unified [img, txt] sequence, Anima has separate + cross-attention where image tokens (query) attend to context tokens (key/value). 
+ The cross-attention mask shape is (img_seq_len, context_seq_len). + """ + + def __init__( + self, + regional_text_conditioning: AnimaRegionalTextConditioning, + cross_attn_mask: torch.Tensor | None = None, + ): + self.regional_text_conditioning = regional_text_conditioning + self.cross_attn_mask = cross_attn_mask + + def get_cross_attn_mask(self, block_index: int) -> torch.Tensor | None: + """Get the cross-attention mask for a given block index. + + Uses alternating pattern: apply mask on even blocks, no mask on odd blocks. + This helps balance regional control with global coherence. + """ + if block_index % 2 == 0: + return self.cross_attn_mask + return None + + @classmethod + def from_regional_conditioning( + cls, + regional_text_conditioning: AnimaRegionalTextConditioning, + img_seq_len: int, + ) -> "AnimaRegionalPromptingExtension": + """Create extension from pre-processed regional conditioning. + + Args: + regional_text_conditioning: Regional conditioning with concatenated context and masks. + img_seq_len: Number of image tokens (H_patches * W_patches). + """ + cross_attn_mask = cls._prepare_cross_attn_mask(regional_text_conditioning, img_seq_len) + return cls( + regional_text_conditioning=regional_text_conditioning, + cross_attn_mask=cross_attn_mask, + ) + + @classmethod + def _prepare_cross_attn_mask( + cls, + regional_text_conditioning: AnimaRegionalTextConditioning, + img_seq_len: int, + ) -> torch.Tensor | None: + """Prepare a cross-attention mask for regional prompting. + + The mask shape is (img_seq_len, context_seq_len) where: + - Each image token can attend to context tokens from its assigned region + - Global prompts (mask=None) attend to background regions + + Args: + regional_text_conditioning: The regional text conditioning data. + img_seq_len: Number of image tokens. + + Returns: + Cross-attention mask of shape (img_seq_len, context_seq_len), or None + if no regional masks are present. 
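+
+        Example: with regional prompts A and B occupying context columns [0, a) and
+        [a, a + b), image tokens covered only by A's mask attend to columns [0, a);
+        uncovered (background) tokens attend to any global prompt's columns instead.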
+ """ + has_regional_masks = any(mask is not None for mask in regional_text_conditioning.image_masks) + if not has_regional_masks: + return None + + # Identify background region (area not covered by any mask) + background_region_mask: torch.Tensor | None = None + for image_mask in regional_text_conditioning.image_masks: + if image_mask is not None: + mask_flat = image_mask.view(-1) + if background_region_mask is None: + background_region_mask = torch.ones_like(mask_flat) + background_region_mask = background_region_mask * (1 - mask_flat) + + device = TorchDevice.choose_torch_device() + context_seq_len = regional_text_conditioning.context_embeds.shape[0] + + # Cross-attention mask: (img_seq_len, context_seq_len) + # img tokens are queries, context tokens are keys/values + cross_attn_mask = torch.zeros((img_seq_len, context_seq_len), device=device, dtype=torch.float16) + + for image_mask, context_range in zip( + regional_text_conditioning.image_masks, + regional_text_conditioning.context_ranges, + strict=True, + ): + ctx_start = context_range.start + ctx_end = context_range.end + + if image_mask is not None: + # Regional prompt: only masked image tokens attend to this region's context + mask_flat = image_mask.view(img_seq_len) + cross_attn_mask[:, ctx_start:ctx_end] = mask_flat.view(img_seq_len, 1) + else: + # Global prompt: background image tokens attend to this context + if background_region_mask is not None: + cross_attn_mask[:, ctx_start:ctx_end] = background_region_mask.view(img_seq_len, 1) + else: + cross_attn_mask[:, ctx_start:ctx_end] = 1.0 + + # Convert to boolean + cross_attn_mask = cross_attn_mask > 0.5 + return cross_attn_mask + + @staticmethod + def preprocess_regional_prompt_mask( + mask: Optional[torch.Tensor], + target_height: int, + target_width: int, + dtype: torch.dtype, + device: torch.device, + ) -> torch.Tensor: + """Preprocess a regional prompt mask to match the target image token grid. + + Args: + mask: Input mask tensor. If None, returns a mask of all ones. + target_height: Height of the image token grid (H // patch_size). + target_width: Width of the image token grid (W // patch_size). + dtype: Target dtype for the mask. + device: Target device for the mask. + + Returns: + Processed mask of shape (1, 1, target_height * target_width). 
+ """ + img_seq_len = target_height * target_width + + if mask is None: + return torch.ones((1, 1, img_seq_len), dtype=dtype, device=device) + + mask = to_standard_float_mask(mask, out_dtype=dtype) + + tf = torchvision.transforms.Resize( + (target_height, target_width), + interpolation=torchvision.transforms.InterpolationMode.NEAREST, + ) + + if mask.ndim == 2: + mask = mask.unsqueeze(0) + if mask.ndim == 3: + mask = mask.unsqueeze(0) + + resized_mask = tf(mask) + return resized_mask.flatten(start_dim=2).to(device=device) diff --git a/invokeai/backend/flux/modules/conditioner.py b/invokeai/backend/flux/modules/conditioner.py index 9deb442929..d48d78cd4a 100644 --- a/invokeai/backend/flux/modules/conditioner.py +++ b/invokeai/backend/flux/modules/conditioner.py @@ -3,6 +3,8 @@ from torch import Tensor, nn from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast +from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device + class HFEncoder(nn.Module): def __init__( @@ -32,7 +34,7 @@ class HFEncoder(nn.Module): ) # Move inputs to the same device as the model to support cpu_only models - model_device = next(self.hf_module.parameters()).device + model_device = get_effective_device(self.hf_module) outputs = self.hf_module( input_ids=batch_encoding["input_ids"].to(model_device), diff --git a/invokeai/backend/flux/schedulers.py b/invokeai/backend/flux/schedulers.py index e5a8a7137c..05e6bb085f 100644 --- a/invokeai/backend/flux/schedulers.py +++ b/invokeai/backend/flux/schedulers.py @@ -60,3 +60,23 @@ ZIMAGE_SCHEDULER_MAP: dict[str, Type[SchedulerMixin]] = { if _HAS_LCM: ZIMAGE_SCHEDULER_MAP["lcm"] = FlowMatchLCMScheduler + + +# Anima scheduler types (same Flow Matching schedulers as Flux/Z-Image) +# Anima uses rectified flow with shift=3.0 and multiplier=1000. +# Recommended: 30 steps with Euler, CFG 4-5. 
+ANIMA_SCHEDULER_NAME_VALUES = Literal["euler", "heun", "lcm"] + +ANIMA_SCHEDULER_LABELS: dict[str, str] = { + "euler": "Euler", + "heun": "Heun (2nd order)", + "lcm": "LCM", +} + +ANIMA_SCHEDULER_MAP: dict[str, Type[SchedulerMixin]] = { + "euler": FlowMatchEulerDiscreteScheduler, + "heun": FlowMatchHeunDiscreteScheduler, +} + +if _HAS_LCM: + ANIMA_SCHEDULER_MAP["lcm"] = FlowMatchLCMScheduler diff --git a/invokeai/backend/image_util/controlnet_processor.py b/invokeai/backend/image_util/controlnet_processor.py index 87739f69e1..81eed42097 100644 --- a/invokeai/backend/image_util/controlnet_processor.py +++ b/invokeai/backend/image_util/controlnet_processor.py @@ -14,43 +14,61 @@ def _get_processor_invocation_class(processor_type: str): """Get the invocation class for a processor type.""" # Import processor invocation classes on demand processor_class_map = { - "canny_image_processor": lambda: __import__( - "invokeai.app.invocations.canny", fromlist=["CannyEdgeDetectionInvocation"] - ).CannyEdgeDetectionInvocation, - "hed_image_processor": lambda: __import__( - "invokeai.app.invocations.hed", fromlist=["HEDEdgeDetectionInvocation"] - ).HEDEdgeDetectionInvocation, - "mlsd_image_processor": lambda: __import__( - "invokeai.app.invocations.mlsd", fromlist=["MLSDDetectionInvocation"] - ).MLSDDetectionInvocation, - "depth_anything_image_processor": lambda: __import__( - "invokeai.app.invocations.depth_anything", fromlist=["DepthAnythingDepthEstimationInvocation"] - ).DepthAnythingDepthEstimationInvocation, - "normalbae_image_processor": lambda: __import__( - "invokeai.app.invocations.normal_bae", fromlist=["NormalMapInvocation"] - ).NormalMapInvocation, - "pidi_image_processor": lambda: __import__( - "invokeai.app.invocations.pidi", fromlist=["PiDiNetEdgeDetectionInvocation"] - ).PiDiNetEdgeDetectionInvocation, - "lineart_image_processor": lambda: __import__( - "invokeai.app.invocations.lineart", fromlist=["LineartEdgeDetectionInvocation"] - ).LineartEdgeDetectionInvocation, - "lineart_anime_image_processor": lambda: __import__( - "invokeai.app.invocations.lineart_anime", fromlist=["LineartAnimeEdgeDetectionInvocation"] - ).LineartAnimeEdgeDetectionInvocation, - "content_shuffle_image_processor": lambda: __import__( - "invokeai.app.invocations.content_shuffle", fromlist=["ContentShuffleInvocation"] - ).ContentShuffleInvocation, - "dw_openpose_image_processor": lambda: __import__( - "invokeai.app.invocations.dw_openpose", fromlist=["DWOpenposeDetectionInvocation"] - ).DWOpenposeDetectionInvocation, - "mediapipe_face_processor": lambda: __import__( - "invokeai.app.invocations.mediapipe_face", fromlist=["MediaPipeFaceDetectionInvocation"] - ).MediaPipeFaceDetectionInvocation, + "canny_image_processor": lambda: ( + __import__( + "invokeai.app.invocations.canny", fromlist=["CannyEdgeDetectionInvocation"] + ).CannyEdgeDetectionInvocation + ), + "hed_image_processor": lambda: ( + __import__( + "invokeai.app.invocations.hed", fromlist=["HEDEdgeDetectionInvocation"] + ).HEDEdgeDetectionInvocation + ), + "mlsd_image_processor": lambda: ( + __import__("invokeai.app.invocations.mlsd", fromlist=["MLSDDetectionInvocation"]).MLSDDetectionInvocation + ), + "depth_anything_image_processor": lambda: ( + __import__( + "invokeai.app.invocations.depth_anything", fromlist=["DepthAnythingDepthEstimationInvocation"] + ).DepthAnythingDepthEstimationInvocation + ), + "normalbae_image_processor": lambda: ( + __import__("invokeai.app.invocations.normal_bae", fromlist=["NormalMapInvocation"]).NormalMapInvocation + ), + 
"pidi_image_processor": lambda: ( + __import__( + "invokeai.app.invocations.pidi", fromlist=["PiDiNetEdgeDetectionInvocation"] + ).PiDiNetEdgeDetectionInvocation + ), + "lineart_image_processor": lambda: ( + __import__( + "invokeai.app.invocations.lineart", fromlist=["LineartEdgeDetectionInvocation"] + ).LineartEdgeDetectionInvocation + ), + "lineart_anime_image_processor": lambda: ( + __import__( + "invokeai.app.invocations.lineart_anime", fromlist=["LineartAnimeEdgeDetectionInvocation"] + ).LineartAnimeEdgeDetectionInvocation + ), + "content_shuffle_image_processor": lambda: ( + __import__( + "invokeai.app.invocations.content_shuffle", fromlist=["ContentShuffleInvocation"] + ).ContentShuffleInvocation + ), + "dw_openpose_image_processor": lambda: ( + __import__( + "invokeai.app.invocations.dw_openpose", fromlist=["DWOpenposeDetectionInvocation"] + ).DWOpenposeDetectionInvocation + ), + "mediapipe_face_processor": lambda: ( + __import__( + "invokeai.app.invocations.mediapipe_face", fromlist=["MediaPipeFaceDetectionInvocation"] + ).MediaPipeFaceDetectionInvocation + ), # Note: zoe_depth_image_processor doesn't have a processor invocation implementation - "color_map_image_processor": lambda: __import__( - "invokeai.app.invocations.color_map", fromlist=["ColorMapInvocation"] - ).ColorMapInvocation, + "color_map_image_processor": lambda: ( + __import__("invokeai.app.invocations.color_map", fromlist=["ColorMapInvocation"]).ColorMapInvocation + ), } if processor_type in processor_class_map: diff --git a/invokeai/backend/image_util/invisible_watermark.py b/invokeai/backend/image_util/invisible_watermark.py index 5b0b2dbb5b..95c483848c 100644 --- a/invokeai/backend/image_util/invisible_watermark.py +++ b/invokeai/backend/image_util/invisible_watermark.py @@ -9,7 +9,7 @@ import numpy as np from PIL import Image import invokeai.backend.util.logging as logger -from invokeai.backend.image_util.imwatermark.vendor import WatermarkEncoder +from invokeai.backend.image_util.imwatermark.vendor import WatermarkDecoder, WatermarkEncoder class InvisibleWatermark: @@ -25,3 +25,25 @@ class InvisibleWatermark: encoder.set_watermark("bytes", watermark_text.encode("utf-8")) bgr_encoded = encoder.encode(bgr, "dwtDct") return Image.fromarray(cv2.cvtColor(bgr_encoded, cv2.COLOR_BGR2RGB)).convert("RGBA") + + @classmethod + def decode_watermark(cls, image: Image.Image, length: int = 8) -> str: + """Attempt to decode an invisible watermark from an image. + + Args: + image: The PIL Image to decode the watermark from. + length: The expected watermark length in bytes. Must match the length used when encoding. + The WatermarkDecoder requires the length in bits; this value is multiplied by 8 internally. + + Returns: + The decoded watermark text, or an empty string if no watermark is detected or decoding fails. 
+ """ + logger.debug("Attempting to decode invisible watermark") + try: + bgr = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR) + decoder = WatermarkDecoder("bytes", length * 8) + watermark_bytes = decoder.decode(bgr, "dwtDct") + return watermark_bytes.decode("utf-8", errors="ignore").rstrip("\x00") + except Exception: + logger.debug("Failed to decode invisible watermark") + return "" diff --git a/invokeai/backend/model_manager/configs/external_api.py b/invokeai/backend/model_manager/configs/external_api.py index 50c51e28cf..da58cba410 100644 --- a/invokeai/backend/model_manager/configs/external_api.py +++ b/invokeai/backend/model_manager/configs/external_api.py @@ -9,7 +9,7 @@ from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ExternalGenerationMode = Literal["txt2img", "img2img", "inpaint"] ExternalMaskFormat = Literal["alpha", "binary", "none"] -ExternalPanelControlName = Literal["negative_prompt", "reference_images", "dimensions", "seed", "steps", "guidance"] +ExternalPanelControlName = Literal["reference_images", "dimensions", "seed"] class ExternalImageSize(BaseModel): @@ -19,6 +19,16 @@ class ExternalImageSize(BaseModel): model_config = ConfigDict(extra="forbid") +class ExternalResolutionPreset(BaseModel): + label: str = Field(min_length=1, description="Display label, e.g. '1:1 (1K)'") + aspect_ratio: str = Field(min_length=1, description="Aspect ratio string, e.g. '1:1'") + image_size: str = Field(min_length=1, description="Image size preset, e.g. '1K'") + width: int = Field(gt=0) + height: int = Field(gt=0) + + model_config = ConfigDict(extra="forbid") + + class ExternalModelCapabilities(BaseModel): modes: list[ExternalGenerationMode] = Field(default_factory=lambda: ["txt2img"]) supports_reference_images: bool = Field(default=False) @@ -30,6 +40,7 @@ class ExternalModelCapabilities(BaseModel): max_image_size: ExternalImageSize | None = Field(default=None) allowed_aspect_ratios: list[str] | None = Field(default=None) aspect_ratio_sizes: dict[str, ExternalImageSize] | None = Field(default=None) + resolution_presets: list[ExternalResolutionPreset] | None = Field(default=None) max_reference_images: int | None = Field(default=None, gt=0) mask_format: ExternalMaskFormat = Field(default="none") input_image_required_for: list[ExternalGenerationMode] | None = Field(default=None) @@ -40,8 +51,6 @@ class ExternalModelCapabilities(BaseModel): class ExternalApiModelDefaultSettings(BaseModel): width: int | None = Field(default=None, gt=0) height: int | None = Field(default=None, gt=0) - steps: int | None = Field(default=None, gt=0) - guidance: float | None = Field(default=None, gt=0) num_images: int | None = Field(default=None, gt=0) model_config = ConfigDict(extra="forbid") diff --git a/invokeai/backend/model_manager/configs/factory.py b/invokeai/backend/model_manager/configs/factory.py index 81464a1a97..4d26b4c334 100644 --- a/invokeai/backend/model_manager/configs/factory.py +++ b/invokeai/backend/model_manager/configs/factory.py @@ -47,8 +47,10 @@ from invokeai.backend.model_manager.configs.lora import ( LoRA_Diffusers_SD2_Config, LoRA_Diffusers_SDXL_Config, LoRA_Diffusers_ZImage_Config, + LoRA_LyCORIS_Anima_Config, LoRA_LyCORIS_Flux2_Config, LoRA_LyCORIS_FLUX_Config, + LoRA_LyCORIS_QwenImage_Config, LoRA_LyCORIS_SD1_Config, LoRA_LyCORIS_SD2_Config, LoRA_LyCORIS_SDXL_Config, @@ -59,6 +61,7 @@ from invokeai.backend.model_manager.configs.lora import ( ) from invokeai.backend.model_manager.configs.main import ( Main_BnBNF4_FLUX_Config, + 
Main_Checkpoint_Anima_Config, Main_Checkpoint_Flux2_Config, Main_Checkpoint_FLUX_Config, Main_Checkpoint_SD1_Config, @@ -69,6 +72,7 @@ from invokeai.backend.model_manager.configs.main import ( Main_Diffusers_CogView4_Config, Main_Diffusers_Flux2_Config, Main_Diffusers_FLUX_Config, + Main_Diffusers_QwenImage_Config, Main_Diffusers_SD1_Config, Main_Diffusers_SD2_Config, Main_Diffusers_SD3_Config, @@ -77,6 +81,7 @@ from invokeai.backend.model_manager.configs.main import ( Main_Diffusers_ZImage_Config, Main_GGUF_Flux2_Config, Main_GGUF_FLUX_Config, + Main_GGUF_QwenImage_Config, Main_GGUF_ZImage_Config, MainModelDefaultSettings, ) @@ -102,6 +107,7 @@ from invokeai.backend.model_manager.configs.textual_inversion import ( ) from invokeai.backend.model_manager.configs.unknown import Unknown_Config from invokeai.backend.model_manager.configs.vae import ( + VAE_Checkpoint_Anima_Config, VAE_Checkpoint_Flux2_Config, VAE_Checkpoint_FLUX_Config, VAE_Checkpoint_SD1_Config, @@ -160,6 +166,7 @@ AnyModelConfig = Annotated[ Annotated[Main_Diffusers_FLUX_Config, Main_Diffusers_FLUX_Config.get_tag()], Annotated[Main_Diffusers_Flux2_Config, Main_Diffusers_Flux2_Config.get_tag()], Annotated[Main_Diffusers_CogView4_Config, Main_Diffusers_CogView4_Config.get_tag()], + Annotated[Main_Diffusers_QwenImage_Config, Main_Diffusers_QwenImage_Config.get_tag()], Annotated[Main_Diffusers_ZImage_Config, Main_Diffusers_ZImage_Config.get_tag()], # Main (Pipeline) - checkpoint format # IMPORTANT: FLUX.2 must be checked BEFORE FLUX.1 because FLUX.2 has specific validation @@ -171,12 +178,14 @@ AnyModelConfig = Annotated[ Annotated[Main_Checkpoint_Flux2_Config, Main_Checkpoint_Flux2_Config.get_tag()], Annotated[Main_Checkpoint_FLUX_Config, Main_Checkpoint_FLUX_Config.get_tag()], Annotated[Main_Checkpoint_ZImage_Config, Main_Checkpoint_ZImage_Config.get_tag()], + Annotated[Main_Checkpoint_Anima_Config, Main_Checkpoint_Anima_Config.get_tag()], # Main (Pipeline) - quantized formats # IMPORTANT: FLUX.2 must be checked BEFORE FLUX.1 because FLUX.2 has specific validation # that will reject FLUX.1 models, but FLUX.1 validation may incorrectly match FLUX.2 models Annotated[Main_BnBNF4_FLUX_Config, Main_BnBNF4_FLUX_Config.get_tag()], Annotated[Main_GGUF_Flux2_Config, Main_GGUF_Flux2_Config.get_tag()], Annotated[Main_GGUF_FLUX_Config, Main_GGUF_FLUX_Config.get_tag()], + Annotated[Main_GGUF_QwenImage_Config, Main_GGUF_QwenImage_Config.get_tag()], Annotated[Main_GGUF_ZImage_Config, Main_GGUF_ZImage_Config.get_tag()], # VAE - checkpoint format Annotated[VAE_Checkpoint_SD1_Config, VAE_Checkpoint_SD1_Config.get_tag()], @@ -184,6 +193,7 @@ AnyModelConfig = Annotated[ Annotated[VAE_Checkpoint_SDXL_Config, VAE_Checkpoint_SDXL_Config.get_tag()], Annotated[VAE_Checkpoint_FLUX_Config, VAE_Checkpoint_FLUX_Config.get_tag()], Annotated[VAE_Checkpoint_Flux2_Config, VAE_Checkpoint_Flux2_Config.get_tag()], + Annotated[VAE_Checkpoint_Anima_Config, VAE_Checkpoint_Anima_Config.get_tag()], # VAE - diffusers format Annotated[VAE_Diffusers_SD1_Config, VAE_Diffusers_SD1_Config.get_tag()], Annotated[VAE_Diffusers_SDXL_Config, VAE_Diffusers_SDXL_Config.get_tag()], @@ -208,6 +218,8 @@ AnyModelConfig = Annotated[ Annotated[LoRA_LyCORIS_Flux2_Config, LoRA_LyCORIS_Flux2_Config.get_tag()], Annotated[LoRA_LyCORIS_FLUX_Config, LoRA_LyCORIS_FLUX_Config.get_tag()], Annotated[LoRA_LyCORIS_ZImage_Config, LoRA_LyCORIS_ZImage_Config.get_tag()], + Annotated[LoRA_LyCORIS_QwenImage_Config, LoRA_LyCORIS_QwenImage_Config.get_tag()], + Annotated[LoRA_LyCORIS_Anima_Config, 
LoRA_LyCORIS_Anima_Config.get_tag()], # LoRA - OMI format Annotated[LoRA_OMI_SDXL_Config, LoRA_OMI_SDXL_Config.get_tag()], Annotated[LoRA_OMI_FLUX_Config, LoRA_OMI_FLUX_Config.get_tag()], diff --git a/invokeai/backend/model_manager/configs/lora.py b/invokeai/backend/model_manager/configs/lora.py index 1619c9d6f0..88f917d0d3 100644 --- a/invokeai/backend/model_manager/configs/lora.py +++ b/invokeai/backend/model_manager/configs/lora.py @@ -31,6 +31,10 @@ from invokeai.backend.model_manager.taxonomy import ( ZImageVariantType, ) from invokeai.backend.model_manager.util.model_util import lora_token_vector_length +from invokeai.backend.patches.lora_conversions.anima_lora_constants import ( + has_cosmos_dit_kohya_keys, + has_cosmos_dit_peft_keys, +) from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control @@ -637,6 +641,13 @@ class LoRA_LyCORIS_Config_Base(LoRA_Config_Base): return BaseModelType.Flux state_dict = mod.load_state_dict() + str_keys = [k for k in state_dict.keys() if isinstance(k, str)] + + # Rule out Anima LoRAs — their lora_te_ keys have shapes that + # lora_token_vector_length() misidentifies as SD2/SDXL. + if has_cosmos_dit_kohya_keys(str_keys) or has_cosmos_dit_peft_keys(str_keys): + raise NotAMatchError("model looks like an Anima LoRA, not a Stable Diffusion LoRA") + # If we've gotten here, we assume that the model is a Stable Diffusion model token_vector_length = lora_token_vector_length(state_dict) if token_vector_length == 768: @@ -711,6 +722,8 @@ class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base): state_dict, { "diffusion_model.layers.", # Z-Image S3-DiT layer pattern + "transformer.layers.", # OneTrainer/diffusers prefix variant + "base_model.model.transformer.layers.", # PEFT-wrapped variant }, ) @@ -747,6 +760,8 @@ class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base): state_dict, { "diffusion_model.layers.", # Z-Image S3-DiT layer pattern + "transformer.layers.", # OneTrainer/diffusers prefix variant + "base_model.model.transformer.layers.", # PEFT-wrapped variant }, ) @@ -757,6 +772,142 @@ class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base): raise NotAMatchError("model does not look like a Z-Image LoRA") +class LoRA_LyCORIS_QwenImage_Config(LoRA_LyCORIS_Config_Base, Config_Base): + """Model config for Qwen Image Edit LoRA models in LyCORIS format.""" + + base: Literal[BaseModelType.QwenImage] = Field(default=BaseModelType.QwenImage) + + @classmethod + def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None: + """Qwen Image Edit LoRAs have keys like transformer_blocks.X.attn.to_k.lora_down.weight.""" + state_dict = mod.load_state_dict() + + has_qwen_ie_keys = state_dict_has_any_keys_starting_with( + state_dict, + { + "transformer_blocks.", + "transformer.transformer_blocks.", + "lora_unet_transformer_blocks_", # Kohya format + }, + ) + has_lora_suffix = state_dict_has_any_keys_ending_with( + state_dict, + { + "lora_A.weight", + "lora_B.weight", + "lora_down.weight", + "lora_up.weight", + "dora_scale", + "lokr_w1", + "lokr_w2", # LoKR format + }, + ) + # Must NOT have diffusion_model.layers (Z-Image) or Flux-style keys. + # Flux LoRAs can have transformer.single_transformer_blocks or transformer.transformer_blocks + # (with the "transformer." prefix and "single_" variant) which would falsely match our check. + # Flux Kohya LoRAs use lora_unet_double_blocks or lora_unet_single_blocks. 
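+        # e.g. Qwen: "transformer_blocks.0.attn.to_k.lora_down.weight"
+        #      Flux: "lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight"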
+ has_z_image_keys = state_dict_has_any_keys_starting_with(state_dict, {"diffusion_model.layers."}) + has_flux_keys = state_dict_has_any_keys_starting_with( + state_dict, + { + "double_blocks.", + "single_blocks.", + "single_transformer_blocks.", + "transformer.single_transformer_blocks.", + "lora_unet_double_blocks_", + "lora_unet_single_blocks_", + "lora_unet_single_transformer_blocks_", + }, + ) + + if has_qwen_ie_keys and has_lora_suffix and not has_z_image_keys and not has_flux_keys: + return + + raise NotAMatchError("model does not match Qwen Image LoRA heuristics") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + state_dict = mod.load_state_dict() + has_qwen_ie_keys = state_dict_has_any_keys_starting_with( + state_dict, + {"transformer_blocks.", "transformer.transformer_blocks.", "lora_unet_transformer_blocks_"}, + ) + has_z_image_keys = state_dict_has_any_keys_starting_with(state_dict, {"diffusion_model.layers."}) + has_flux_keys = state_dict_has_any_keys_starting_with( + state_dict, + { + "double_blocks.", + "single_blocks.", + "single_transformer_blocks.", + "transformer.single_transformer_blocks.", + "lora_unet_double_blocks_", + "lora_unet_single_blocks_", + "lora_unet_single_transformer_blocks_", + }, + ) + + if has_qwen_ie_keys and not has_z_image_keys and not has_flux_keys: + return BaseModelType.QwenImage + raise NotAMatchError("model does not look like a Qwen Image Edit LoRA") + + +class LoRA_LyCORIS_Anima_Config(LoRA_LyCORIS_Config_Base, Config_Base): + """Model config for Anima LoRA models in LyCORIS format.""" + + base: Literal[BaseModelType.Anima] = Field(default=BaseModelType.Anima) + + @classmethod + def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None: + """Anima LoRAs use Kohya-style keys targeting Cosmos DiT blocks. + + Anima LoRAs have keys like: + - lora_unet_blocks_0_cross_attn_k_proj.lora_down.weight (Kohya format) + - diffusion_model.blocks.0.cross_attn.k_proj.lora_A.weight (diffusers PEFT format) + - transformer.blocks.0.cross_attn.k_proj.lora_A.weight (diffusers PEFT format) + + Detection requires Cosmos DiT-specific subcomponent names (cross_attn, + self_attn, mlp, adaln_modulation) to avoid false-positives on other + architectures that also use ``blocks`` in their paths. + """ + state_dict = mod.load_state_dict() + str_keys = [k for k in state_dict.keys() if isinstance(k, str)] + + has_cosmos_keys = has_cosmos_dit_kohya_keys(str_keys) or has_cosmos_dit_peft_keys(str_keys) + + # Also check for LoRA/LoKR weight suffixes + has_lora_suffix = state_dict_has_any_keys_ending_with( + state_dict, + { + "lora_A.weight", + "lora_B.weight", + "lora_down.weight", + "lora_up.weight", + "dora_scale", + ".lokr_w1", + ".lokr_w2", + }, + ) + + if has_cosmos_keys and has_lora_suffix: + return + + raise NotAMatchError("model does not match Anima LoRA heuristics") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + """Anima LoRAs target Cosmos DiT blocks (blocks.X.cross_attn, blocks.X.self_attn, etc.). + + Uses Cosmos DiT-specific subcomponent names to avoid false-positives. 
+ """ + state_dict = mod.load_state_dict() + str_keys = [k for k in state_dict.keys() if isinstance(k, str)] + + if has_cosmos_dit_kohya_keys(str_keys) or has_cosmos_dit_peft_keys(str_keys): + return BaseModelType.Anima + + raise NotAMatchError("model does not look like an Anima LoRA") + + class ControlAdapter_Config_Base(ABC, BaseModel): default_settings: ControlAdapterDefaultSettings | None = Field(None) diff --git a/invokeai/backend/model_manager/configs/main.py b/invokeai/backend/model_manager/configs/main.py index 6f737ceb92..1be349f394 100644 --- a/invokeai/backend/model_manager/configs/main.py +++ b/invokeai/backend/model_manager/configs/main.py @@ -28,6 +28,7 @@ from invokeai.backend.model_manager.taxonomy import ( ModelFormat, ModelType, ModelVariantType, + QwenImageVariantType, SchedulerPredictionType, SubModelType, ZImageVariantType, @@ -76,6 +77,8 @@ class MainModelDefaultSettings(BaseModel): else: # Turbo (distilled) uses fewer steps, no CFG return cls(steps=9, cfg_scale=1.0, width=1024, height=1024) + case BaseModelType.Anima: + return cls(steps=35, cfg_scale=4.5, width=1024, height=1024) case BaseModelType.Flux2: # Different defaults based on variant if variant == Flux2VariantType.Klein9BBase: @@ -84,6 +87,8 @@ class MainModelDefaultSettings(BaseModel): else: # Distilled models (Klein 4B, Klein 9B) use fewer steps return cls(steps=4, cfg_scale=1.0, width=1024, height=1024) + case BaseModelType.QwenImage: + return cls(steps=40, cfg_scale=4.0, width=1024, height=1024) case _: # TODO(psyche): Do we want defaults for other base types? return None @@ -194,9 +199,11 @@ class Main_SD_Checkpoint_Config_Base(Checkpoint_Config_Base, Main_Config_Base): cls._validate_base(mod) - prediction_type = override_fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod) + prediction_type = override_fields.pop("prediction_type", None) or cls._get_scheduler_prediction_type_or_raise( + mod + ) - variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod) return cls(**override_fields, prediction_type=prediction_type, variant=variant) @@ -323,6 +330,16 @@ def _is_flux2_model(state_dict: dict[str | int, Any]) -> bool: return False +def _filename_suggests_base(name: str) -> bool: + """Check if a model name/filename suggests it is a Base (undistilled) variant. + + Klein 9B Base and Klein 9B have identical architectures and cannot be distinguished + from the state dict. We use the filename as a heuristic: filenames containing "base" + (e.g. "flux-2-klein-base-9b", "FLUX.2-klein-base-9B") indicate the undistilled model. + """ + return "base" in name.lower() + + def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | None: """Determine FLUX.2 variant from state dict. @@ -330,9 +347,9 @@ def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | N - Klein 4B: context_in_dim = 7680 (3 × Qwen3-4B hidden_size 2560) - Klein 9B: context_in_dim = 12288 (3 × Qwen3-8B hidden_size 4096) - Note: Klein 9B Base (undistilled) also has context_in_dim = 12288 but is rare. - We default to Klein9B (distilled) for all 9B models since GGUF models may not - include guidance embedding keys needed to distinguish them. + Note: Klein 9B (distilled) and Klein 9B Base (undistilled) have identical architectures + and cannot be distinguished from the state dict alone. This function defaults to Klein9B + for all 9B models. 
Callers should use filename heuristics to detect Klein9BBase. Supports both BFL format (checkpoint) and diffusers format keys: - BFL format: txt_in.weight (context embedder) @@ -366,7 +383,7 @@ def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | N context_in_dim = shape[1] # Determine variant based on context dimension if context_in_dim == KLEIN_9B_CONTEXT_DIM: - # Default to Klein9B (distilled) - the official/common 9B model + # Default to Klein9B - callers use filename heuristics to detect Klein9BBase return Flux2VariantType.Klein9B elif context_in_dim == KLEIN_4B_CONTEXT_DIM: return Flux2VariantType.Klein4B @@ -459,7 +476,7 @@ class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Conf cls._validate_does_not_look_like_gguf_quantized(mod) - variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod) return cls(**override_fields, variant=variant) @@ -534,7 +551,7 @@ class Main_Checkpoint_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Con cls._validate_does_not_look_like_gguf_quantized(mod) - variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod) return cls(**override_fields, variant=variant) @@ -553,6 +570,11 @@ class Main_Checkpoint_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Con if variant is None: raise NotAMatchError("unable to determine FLUX.2 model variant from state dict") + # Klein 9B Base and Klein 9B have identical architectures. + # Use filename heuristic to detect the Base (undistilled) variant. + if variant == Flux2VariantType.Klein9B and _filename_suggests_base(mod.name): + return Flux2VariantType.Klein9BBase + return variant @classmethod @@ -592,7 +614,7 @@ class Main_BnBNF4_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_B cls._validate_model_looks_like_bnb_quantized(mod) - variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod) return cls(**override_fields, variant=variant) @@ -643,7 +665,7 @@ class Main_GGUF_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Bas cls._validate_is_not_flux2(mod) - variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod) return cls(**override_fields, variant=variant) @@ -701,7 +723,7 @@ class Main_GGUF_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Ba cls._validate_is_flux2(mod) - variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod) return cls(**override_fields, variant=variant) @@ -720,6 +742,11 @@ class Main_GGUF_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Ba if variant is None: raise NotAMatchError("unable to determine FLUX.2 model variant from state dict") + # Klein 9B Base and Klein 9B have identical architectures. + # Use filename heuristic to detect the Base (undistilled) variant. 
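+        # e.g. a file named "flux-2-klein-base-9b.gguf" resolves to Klein9BBase.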
+ if variant == Flux2VariantType.Klein9B and _filename_suggests_base(mod.name): + return Flux2VariantType.Klein9BBase + return variant @classmethod @@ -757,9 +784,9 @@ class Main_Diffusers_FLUX_Config(Diffusers_Config_Base, Main_Config_Base, Config }, ) - variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod) - repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) + repo_variant = override_fields.pop("repo_variant", None) or cls._get_repo_variant_or_raise(mod) return cls( **override_fields, @@ -811,9 +838,9 @@ class Main_Diffusers_Flux2_Config(Diffusers_Config_Base, Main_Config_Base, Confi }, ) - variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod) - repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) + repo_variant = override_fields.pop("repo_variant", None) or cls._get_repo_variant_or_raise(mod) return cls( **override_fields, @@ -829,12 +856,8 @@ class Main_Diffusers_Flux2_Config(Diffusers_Config_Base, Main_Config_Base, Confi - Klein 4B: joint_attention_dim = 7680 (3×Qwen3-4B hidden size) - Klein 9B/9B Base: joint_attention_dim = 12288 (3×Qwen3-8B hidden size) - To distinguish Klein 9B (distilled) from Klein 9B Base (undistilled), - we check guidance_embeds: - - Klein 9B (distilled): guidance_embeds = False (guidance is "baked in" during distillation) - - Klein 9B Base (undistilled): guidance_embeds = True (needs guidance at inference) - - Note: The official BFL Klein 9B model is the distilled version with guidance_embeds=False. + Klein 9B (distilled) and Klein 9B Base (undistilled) have identical architectures + and both have guidance_embeds=False. We use a filename heuristic to detect Base models. 
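+
+        Mapping: joint_attention_dim 7680 -> Klein 4B; 12288 -> Klein 9B, or Klein 9B Base
+        when the model name contains "base".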
""" KLEIN_4B_CONTEXT_DIM = 7680 # 3 × 2560 KLEIN_9B_CONTEXT_DIM = 12288 # 3 × 4096 @@ -842,17 +865,12 @@ class Main_Diffusers_Flux2_Config(Diffusers_Config_Base, Main_Config_Base, Confi transformer_config = get_config_dict_or_raise(mod.path / "transformer" / "config.json") joint_attention_dim = transformer_config.get("joint_attention_dim", 4096) - guidance_embeds = transformer_config.get("guidance_embeds", False) # Determine variant based on joint_attention_dim if joint_attention_dim == KLEIN_9B_CONTEXT_DIM: - # Check guidance_embeds to distinguish distilled from undistilled - # Klein 9B (distilled): guidance_embeds = False (guidance is baked in) - # Klein 9B Base (undistilled): guidance_embeds = True (needs guidance) - if guidance_embeds: + if _filename_suggests_base(mod.name): return Flux2VariantType.Klein9BBase - else: - return Flux2VariantType.Klein9B + return Flux2VariantType.Klein9B elif joint_attention_dim == KLEIN_4B_CONTEXT_DIM: return Flux2VariantType.Klein4B elif joint_attention_dim > 4096: @@ -891,11 +909,13 @@ class Main_SD_Diffusers_Config_Base(Diffusers_Config_Base, Main_Config_Base): cls._validate_base(mod) - variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod) - prediction_type = override_fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod) + prediction_type = override_fields.pop("prediction_type", None) or cls._get_scheduler_prediction_type_or_raise( + mod + ) - repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) + repo_variant = override_fields.pop("repo_variant", None) or cls._get_repo_variant_or_raise(mod) return cls( **override_fields, @@ -1001,9 +1021,9 @@ class Main_Diffusers_SD3_Config(Diffusers_Config_Base, Main_Config_Base, Config_ }, ) - submodels = override_fields.get("submodels") or cls._get_submodels_or_raise(mod) + submodels = override_fields.pop("submodels", None) or cls._get_submodels_or_raise(mod) - repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) + repo_variant = override_fields.pop("repo_variant", None) or cls._get_repo_variant_or_raise(mod) return cls( **override_fields, @@ -1076,7 +1096,7 @@ class Main_Diffusers_CogView4_Config(Diffusers_Config_Base, Main_Config_Base, Co }, ) - repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) + repo_variant = override_fields.pop("repo_variant", None) or cls._get_repo_variant_or_raise(mod) return cls( **override_fields, @@ -1084,6 +1104,44 @@ class Main_Diffusers_CogView4_Config(Diffusers_Config_Base, Main_Config_Base, Co ) +def _has_anima_keys(state_dict: dict[str | int, Any]) -> bool: + """Check if state dict contains Anima model keys. + + Anima models are identified by the presence of `llm_adapter` keys + (unique to Anima - the LLM Adapter that bridges Qwen3 text encoder to the Cosmos DiT) + alongside Cosmos Predict2 DiT keys (blocks, t_embedder, x_embedder, final_layer). + + The checkpoint keys may have a `net.` prefix (e.g. `net.llm_adapter.`, `net.blocks.`). 
+ """ + has_llm_adapter = False + has_cosmos_dit = False + + # Cosmos DiT key prefixes — support both with and without `net.` prefix + cosmos_prefixes = ( + "blocks.", + "t_embedder.", + "x_embedder.", + "final_layer.", + "net.blocks.", + "net.t_embedder.", + "net.x_embedder.", + "net.final_layer.", + ) + + for key in state_dict.keys(): + if isinstance(key, int): + continue + if key.startswith("llm_adapter.") or key.startswith("net.llm_adapter."): + has_llm_adapter = True + for prefix in cosmos_prefixes: + if key.startswith(prefix): + has_cosmos_dit = True + if has_llm_adapter and has_cosmos_dit: + return True + + return False + + class Main_Diffusers_ZImage_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base): """Model config for Z-Image diffusers models (Z-Image-Turbo, Z-Image-Base).""" @@ -1104,9 +1162,9 @@ class Main_Diffusers_ZImage_Config(Diffusers_Config_Base, Main_Config_Base, Conf }, ) - variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod) - repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) + repo_variant = override_fields.pop("repo_variant", None) or cls._get_repo_variant_or_raise(mod) return cls( **override_fields, @@ -1150,7 +1208,7 @@ class Main_Checkpoint_ZImage_Config(Checkpoint_Config_Base, Main_Config_Base, Co cls._validate_does_not_look_like_gguf_quantized(mod) - variant = override_fields.get("variant", ZImageVariantType.Turbo) + variant = override_fields.pop("variant", None) or ZImageVariantType.Turbo return cls(**override_fields, variant=variant) @@ -1184,7 +1242,7 @@ class Main_GGUF_ZImage_Config(Checkpoint_Config_Base, Main_Config_Base, Config_B cls._validate_looks_like_gguf_quantized(mod) - variant = override_fields.get("variant", ZImageVariantType.Turbo) + variant = override_fields.pop("variant", None) or ZImageVariantType.Turbo return cls(**override_fields, variant=variant) @@ -1199,3 +1257,130 @@ class Main_GGUF_ZImage_Config(Checkpoint_Config_Base, Main_Config_Base, Config_B has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict()) if not has_ggml_tensors: raise NotAMatchError("state dict does not look like GGUF quantized") + + +class Main_Diffusers_QwenImage_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base): + """Model config for Qwen Image diffusers models (both txt2img and edit).""" + + base: Literal[BaseModelType.QwenImage] = Field(BaseModelType.QwenImage) + variant: QwenImageVariantType | None = Field(default=None) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_dir(mod) + + raise_for_override_fields(cls, override_fields) + + # This check implies the base type - no further validation needed. 
+        raise_for_class_name(
+            common_config_paths(mod.path),
+            {
+                "QwenImagePlusPipeline",
+                "QwenImageEditPlusPipeline",
+                "QwenImagePipeline",
+            },
+        )
+
+        repo_variant = override_fields.pop("repo_variant", None) or cls._get_repo_variant_or_raise(mod)
+        variant = override_fields.pop("variant", None) or cls._get_qwen_image_variant(mod)
+
+        return cls(
+            **override_fields,
+            repo_variant=repo_variant,
+            variant=variant,
+        )
+
+    @classmethod
+    def _get_qwen_image_variant(cls, mod: ModelOnDisk) -> QwenImageVariantType:
+        """Detect whether this is an edit or txt2img model from the pipeline class name."""
+        import json
+
+        model_index = mod.path / "model_index.json"
+        if model_index.exists():
+            with open(model_index) as f:
+                config = json.load(f)
+            class_name = config.get("_class_name", "")
+            if "Edit" in class_name:
+                return QwenImageVariantType.Edit
+        return QwenImageVariantType.Generate
+
+
+def _has_qwen_image_keys(state_dict: dict[str | int, Any]) -> bool:
+    """Check if state dict contains Qwen Image transformer keys.
+
+    Qwen Image models (both txt2img and edit) use 'txt_in' and 'txt_norm' instead of 'context_embedder' (FLUX).
+    This distinguishes them from FLUX and other architectures.
+    """
+    has_txt_in = any(isinstance(k, str) and k.startswith("txt_in.") for k in state_dict.keys())
+    has_txt_norm = any(isinstance(k, str) and k.startswith("txt_norm.") for k in state_dict.keys())
+    has_img_in = any(isinstance(k, str) and k.startswith("img_in.") for k in state_dict.keys())
+    # Must NOT have context_embedder (which would indicate FLUX)
+    has_context_embedder = any(isinstance(k, str) and "context_embedder" in k for k in state_dict.keys())
+    return has_txt_in and has_txt_norm and has_img_in and not has_context_embedder
+
+
+class Main_GGUF_QwenImage_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
+    """Model config for GGUF-quantized Qwen Image transformer models."""
+
+    base: Literal[BaseModelType.QwenImage] = Field(default=BaseModelType.QwenImage)
+    format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
+    variant: QwenImageVariantType | None = Field(default=None)
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_file(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        sd = mod.load_state_dict()
+
+        if not _has_qwen_image_keys(sd):
+            raise NotAMatchError("state dict does not look like a Qwen Image model")
+
+        if not _has_ggml_tensors(sd):
+            raise NotAMatchError("state dict does not look like GGUF quantized")
+
+        # Infer variant from the state dict if not explicitly provided.
+        # The Edit variant includes an extra tensor `__index_timestep_zero__` (used by the
+        # `zero_cond_t` dual-modulation path in diffusers' QwenImageTransformer2DModel).
+        # If the marker tensor is missing, fall back to the filename heuristic since older
+        # or alternate GGUF converters may not emit it.
+        explicit_variant = override_fields.pop("variant", None)
+        if explicit_variant is None:
+            if "__index_timestep_zero__" in sd:
+                explicit_variant = QwenImageVariantType.Edit
+            else:
+                filename = mod.path.stem.lower()
+                if "edit" in filename:
+                    explicit_variant = QwenImageVariantType.Edit
+                else:
+                    explicit_variant = QwenImageVariantType.Generate
+
+        return cls(**override_fields, variant=explicit_variant)
+
+
+class Main_Checkpoint_Anima_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
+    """Model config for Anima single-file checkpoint models (safetensors).
+ + Anima is built on NVIDIA Cosmos Predict2 DiT with a custom LLM Adapter + that bridges Qwen3 0.6B text encoder outputs to the DiT. + """ + + base: Literal[BaseModelType.Anima] = Field(default=BaseModelType.Anima) + format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_file(mod) + + raise_for_override_fields(cls, override_fields) + + cls._validate_looks_like_anima_model(mod) + + return cls(**override_fields) + + @classmethod + def _validate_looks_like_anima_model(cls, mod: ModelOnDisk) -> None: + has_anima_keys = _has_anima_keys(mod.load_state_dict()) + if not has_anima_keys: + raise NotAMatchError("state dict does not look like an Anima model") diff --git a/invokeai/backend/model_manager/configs/qwen3_encoder.py b/invokeai/backend/model_manager/configs/qwen3_encoder.py index 2e24fee918..82cf3b62c8 100644 --- a/invokeai/backend/model_manager/configs/qwen3_encoder.py +++ b/invokeai/backend/model_manager/configs/qwen3_encoder.py @@ -47,9 +47,10 @@ def _has_ggml_tensors(state_dict: dict[str | int, Any]) -> bool: def _get_qwen3_variant_from_state_dict(state_dict: dict[str | int, Any]) -> Optional[Qwen3VariantType]: - """Determine Qwen3 variant (4B vs 8B) from state dict based on hidden_size. + """Determine Qwen3 variant (0.6B, 4B, or 8B) from state dict based on hidden_size. The hidden_size can be determined from the embed_tokens.weight tensor shape: + - Qwen3 0.6B: hidden_size = 1024 - Qwen3 4B: hidden_size = 2560 - Qwen3 8B: hidden_size = 4096 @@ -57,6 +58,7 @@ def _get_qwen3_variant_from_state_dict(state_dict: dict[str | int, Any]) -> Opti For PyTorch format, the key is 'model.embed_tokens.weight'. """ # Hidden size thresholds + QWEN3_06B_HIDDEN_SIZE = 1024 QWEN3_4B_HIDDEN_SIZE = 2560 QWEN3_8B_HIDDEN_SIZE = 4096 @@ -91,7 +93,9 @@ def _get_qwen3_variant_from_state_dict(state_dict: dict[str | int, Any]) -> Opti return None # Determine variant based on hidden_size - if hidden_size == QWEN3_4B_HIDDEN_SIZE: + if hidden_size == QWEN3_06B_HIDDEN_SIZE: + return Qwen3VariantType.Qwen3_06B + elif hidden_size == QWEN3_4B_HIDDEN_SIZE: return Qwen3VariantType.Qwen3_4B elif hidden_size == QWEN3_8B_HIDDEN_SIZE: return Qwen3VariantType.Qwen3_8B @@ -206,6 +210,7 @@ class Qwen3Encoder_Qwen3Encoder_Config(Config_Base): @classmethod def _get_variant_from_config(cls, config_path) -> Qwen3VariantType: """Get variant from config.json based on hidden_size.""" + QWEN3_06B_HIDDEN_SIZE = 1024 QWEN3_4B_HIDDEN_SIZE = 2560 QWEN3_8B_HIDDEN_SIZE = 4096 @@ -217,6 +222,8 @@ class Qwen3Encoder_Qwen3Encoder_Config(Config_Base): return Qwen3VariantType.Qwen3_8B elif hidden_size == QWEN3_4B_HIDDEN_SIZE: return Qwen3VariantType.Qwen3_4B + elif hidden_size == QWEN3_06B_HIDDEN_SIZE: + return Qwen3VariantType.Qwen3_06B else: # Default to 4B for unknown sizes return Qwen3VariantType.Qwen3_4B diff --git a/invokeai/backend/model_manager/configs/vae.py b/invokeai/backend/model_manager/configs/vae.py index cc079cb9aa..ce26a94a6e 100644 --- a/invokeai/backend/model_manager/configs/vae.py +++ b/invokeai/backend/model_manager/configs/vae.py @@ -175,6 +175,43 @@ class VAE_Checkpoint_Flux2_Config(Checkpoint_Config_Base, Config_Base): raise NotAMatchError("state dict does not look like a FLUX.2 VAE") +def _has_anima_vae_keys(state_dict: dict[str | int, Any]) -> bool: + """Check if state dict looks like an Anima QwenImage VAE (AutoencoderKLQwenImage). 
+ + The Anima VAE has a distinctive structure with: + - encoder.downsamples.* (instead of encoder.down_blocks) + - decoder.upsamples.* (instead of decoder.up_blocks) + - decoder.head.* / decoder.middle.* + - Top-level conv1/conv2 weights + """ + required_prefixes = { + "encoder.downsamples.", + "decoder.upsamples.", + "decoder.middle.", + } + return all(any(str(k).startswith(prefix) for k in state_dict) for prefix in required_prefixes) + + +class VAE_Checkpoint_Anima_Config(Checkpoint_Config_Base, Config_Base): + """Model config for Anima QwenImage VAE checkpoint models (AutoencoderKLQwenImage).""" + + type: Literal[ModelType.VAE] = Field(default=ModelType.VAE) + format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) + base: Literal[BaseModelType.Anima] = Field(default=BaseModelType.Anima) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_file(mod) + + raise_for_override_fields(cls, override_fields) + + state_dict = mod.load_state_dict() + if not _has_anima_vae_keys(state_dict): + raise NotAMatchError("state dict does not look like an Anima QwenImage VAE") + + return cls(**override_fields) + + class VAE_Diffusers_Config_Base(Diffusers_Config_Base): """Model config for standalone VAE models (diffusers version).""" diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index a4004afba7..4609a2e92a 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -14,6 +14,9 @@ import torch from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord +from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import ( + CachedModelWithPartialLoad, +) from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType @@ -55,7 +58,12 @@ class LoadedModelWithoutConfig: def __enter__(self) -> AnyModel: self._cache.lock(self._cache_record, None) - return self.model + try: + self.repair_required_tensors_on_device() + return self.model + except Exception: + self._cache.unlock(self._cache_record) + raise def __exit__(self, *args: Any, **kwargs: Any) -> None: self._cache.unlock(self._cache_record) @@ -71,6 +79,7 @@ class LoadedModelWithoutConfig: """ self._cache.lock(self._cache_record, working_mem_bytes) try: + self.repair_required_tensors_on_device() yield (self._cache_record.cached_model.get_cpu_state_dict(), self._cache_record.cached_model.model) finally: self._cache.unlock(self._cache_record) @@ -80,6 +89,13 @@ class LoadedModelWithoutConfig: """Return the model without locking it.""" return self._cache_record.cached_model.model + def repair_required_tensors_on_device(self) -> int: + """Repair required tensors that should be resident on the cached model's execution device.""" + cached_model = self._cache_record.cached_model + if not isinstance(cached_model, CachedModelWithPartialLoad): + return 0 + return cached_model.repair_required_tensors_on_compute_device() + class LoadedModel(LoadedModelWithoutConfig): """Context manager object that mediates transfer from RAM<->VRAM.""" diff --git a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py 
b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py index f80b017ba7..328978b45b 100644 --- a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py +++ b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py @@ -149,6 +149,27 @@ class CachedModelWithPartialLoad: """Unload all weights from VRAM.""" return self.partial_unload_from_vram(self.total_bytes()) + @torch.no_grad() + def repair_required_tensors_on_compute_device(self) -> int: + """Repair required non-autocast tensors that were left off the compute device. + + This can happen if an interrupted run leaves the model in a partially inconsistent state. Any repaired device + movement invalidates the cached VRAM accounting. + """ + cur_state_dict = self._model.state_dict() + keys_to_repair = { + key + for key in self._keys_in_modules_that_do_not_support_autocast + if cur_state_dict[key].device.type != self._compute_device.type + } + if len(keys_to_repair) == 0: + return 0 + + self._load_state_dict_with_device_conversion(cur_state_dict, keys_to_repair, self._compute_device) + self._move_non_persistent_buffers_to_device(self._compute_device) + self._cur_vram_bytes = None + return len(keys_to_repair) + def _load_state_dict_with_device_conversion( self, state_dict: dict[str, torch.Tensor], keys_to_convert: set[str], target_device: torch.device ): diff --git a/invokeai/backend/model_manager/load/model_loaders/anima.py b/invokeai/backend/model_manager/load/model_loaders/anima.py new file mode 100644 index 0000000000..6549c220a8 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/anima.py @@ -0,0 +1,140 @@ +# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team +"""Class for Anima model loading in InvokeAI.""" + +from pathlib import Path +from typing import Optional + +import accelerate + +from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base +from invokeai.backend.model_manager.configs.factory import AnyModelConfig +from invokeai.backend.model_manager.configs.main import Main_Checkpoint_Anima_Config +from invokeai.backend.model_manager.load.load_default import ModelLoader +from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry +from invokeai.backend.model_manager.taxonomy import ( + AnyModel, + BaseModelType, + ModelFormat, + ModelType, + SubModelType, +) +from invokeai.backend.util.devices import TorchDevice +from invokeai.backend.util.logging import InvokeAILogger + +logger = InvokeAILogger.get_logger(__name__) + + +@ModelLoaderRegistry.register(base=BaseModelType.Anima, type=ModelType.Main, format=ModelFormat.Checkpoint) +class AnimaCheckpointModel(ModelLoader): + """Class to load Anima transformer models from single-file checkpoints. + + The Anima checkpoint contains both the MiniTrainDIT backbone and the LLM Adapter + under a shared `net.` prefix. The loader strips this prefix and instantiates + the AnimaTransformer model with the correct architecture parameters. + """ + + def _load_model( + self, + config: AnyModelConfig, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if not isinstance(config, Checkpoint_Config_Base): + raise ValueError("Only CheckpointConfigBase models are currently supported here.") + + match submodel_type: + case SubModelType.Transformer: + return self._load_from_singlefile(config) + + raise ValueError( + f"Only Transformer submodels are currently supported. 
Received: {submodel_type.value if submodel_type else 'None'}" + ) + + def _load_from_singlefile( + self, + config: AnyModelConfig, + ) -> AnyModel: + from safetensors.torch import load_file + + from invokeai.backend.anima.anima_transformer import AnimaTransformer + + if not isinstance(config, Main_Checkpoint_Anima_Config): + raise TypeError( + f"Expected Main_Checkpoint_Anima_Config, got {type(config).__name__}. " + "Model configuration type mismatch." + ) + model_path = Path(config.path) + + # Load the state dict from safetensors + sd = load_file(model_path) + + # Strip the `net.` prefix that all Anima checkpoint keys have + # e.g., "net.blocks.0.self_attn.q_proj.weight" -> "blocks.0.self_attn.q_proj.weight" + prefix_to_strip = None + for prefix in ["net."]: + if any(k.startswith(prefix) for k in sd.keys() if isinstance(k, str)): + prefix_to_strip = prefix + break + + if prefix_to_strip: + stripped_sd = {} + for key, value in sd.items(): + if isinstance(key, str) and key.startswith(prefix_to_strip): + stripped_sd[key[len(prefix_to_strip) :]] = value + else: + stripped_sd[key] = value + sd = stripped_sd + + # Create an empty AnimaTransformer with Anima's default architecture parameters + with accelerate.init_empty_weights(): + model = AnimaTransformer( + max_img_h=240, + max_img_w=240, + max_frames=1, + in_channels=16, + out_channels=16, + patch_spatial=2, + patch_temporal=1, + concat_padding_mask=True, + model_channels=2048, + num_blocks=28, + num_heads=16, + mlp_ratio=4.0, + crossattn_emb_channels=1024, + pos_emb_cls="rope3d", + use_adaln_lora=True, + adaln_lora_dim=256, + extra_per_block_abs_pos_emb=False, + image_model="anima", + ) + + # Determine safe dtype + target_device = TorchDevice.choose_torch_device() + model_dtype = TorchDevice.choose_bfloat16_safe_dtype(target_device) + + # Handle memory management + new_sd_size = sum(ten.nelement() * model_dtype.itemsize for ten in sd.values()) + self._ram_cache.make_room(new_sd_size) + + # Convert to target dtype (skip non-float tensors like embedding indices) + for k in sd.keys(): + if sd[k].is_floating_point(): + sd[k] = sd[k].to(model_dtype) + + # Filter out rotary embedding inv_freq buffers that are regenerated at runtime + keys_to_remove = [k for k in sd.keys() if k.endswith(".inv_freq")] + for k in keys_to_remove: + del sd[k] + + load_result = model.load_state_dict(sd, assign=True, strict=False) + if load_result.unexpected_keys: + raise RuntimeError( + f"Checkpoint contains {len(load_result.unexpected_keys)} unexpected keys. " + f"This may indicate a corrupted or incompatible checkpoint. " + f"First 5 unexpected keys: {load_result.unexpected_keys[:5]}" + ) + if load_result.missing_keys: + logger.warning( + f"Checkpoint is missing {len(load_result.missing_keys)} keys " + f"(expected for inv_freq buffers). 
First 5: {load_result.missing_keys[:5]}" + ) + return model diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py index d39982456a..6cf06d4807 100644 --- a/invokeai/backend/model_manager/load/model_loaders/lora.py +++ b/invokeai/backend/model_manager/load/model_loaders/lora.py @@ -21,6 +21,7 @@ from invokeai.backend.model_manager.taxonomy import ( ModelType, SubModelType, ) +from invokeai.backend.patches.lora_conversions.anima_lora_conversion_utils import lora_model_from_anima_state_dict from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import ( is_state_dict_likely_in_flux_aitoolkit_format, lora_model_from_flux_aitoolkit_state_dict, @@ -44,6 +45,10 @@ from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils is_state_dict_likely_in_flux_kohya_format, lora_model_from_flux_kohya_state_dict, ) +from invokeai.backend.patches.lora_conversions.flux_onetrainer_bfl_lora_conversion_utils import ( + is_state_dict_likely_in_flux_onetrainer_bfl_format, + lora_model_from_flux_onetrainer_bfl_state_dict, +) from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import ( is_state_dict_likely_in_flux_onetrainer_format, lora_model_from_flux_onetrainer_state_dict, @@ -52,6 +57,9 @@ from invokeai.backend.patches.lora_conversions.flux_xlabs_lora_conversion_utils is_state_dict_likely_in_flux_xlabs_format, lora_model_from_flux_xlabs_state_dict, ) +from invokeai.backend.patches.lora_conversions.qwen_image_lora_conversion_utils import ( + lora_model_from_qwen_image_state_dict, +) from invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format from invokeai.backend.patches.lora_conversions.z_image_lora_conversion_utils import lora_model_from_z_image_state_dict @@ -128,6 +136,8 @@ class LoRALoader(ModelLoader): model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None) elif is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict): model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict) + elif is_state_dict_likely_in_flux_onetrainer_bfl_format(state_dict=state_dict): + model = lora_model_from_flux_onetrainer_bfl_state_dict(state_dict=state_dict) elif is_state_dict_likely_in_flux_onetrainer_format(state_dict=state_dict): model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict) elif is_state_dict_likely_flux_control(state_dict=state_dict): @@ -155,6 +165,11 @@ class LoRALoader(ModelLoader): # Z-Image LoRAs use diffusers PEFT format with transformer and/or Qwen3 encoder layers. # We set alpha=None to use rank as alpha (common default). model = lora_model_from_z_image_state_dict(state_dict=state_dict, alpha=None) + elif self._model_base == BaseModelType.QwenImage: + model = lora_model_from_qwen_image_state_dict(state_dict=state_dict, alpha=None) + elif self._model_base == BaseModelType.Anima: + # Anima LoRAs use Kohya-style or diffusers PEFT format targeting Cosmos DiT blocks. 
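+            # As in the Z-Image branch above, alpha=None lets the conversion fall back to
+            # using the LoRA rank as alpha (assumed to be the same convention here).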
+            model = lora_model_from_anima_state_dict(state_dict=state_dict, alpha=None)
         else:
             raise ValueError(f"Unsupported LoRA base model: {self._model_base}")
diff --git a/invokeai/backend/model_manager/load/model_loaders/qwen_image.py b/invokeai/backend/model_manager/load/model_loaders/qwen_image.py
new file mode 100644
index 0000000000..a025e72794
--- /dev/null
+++ b/invokeai/backend/model_manager/load/model_loaders/qwen_image.py
@@ -0,0 +1,177 @@
+from pathlib import Path
+from typing import Optional
+
+import accelerate
+import torch
+
+from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Diffusers_Config_Base
+from invokeai.backend.model_manager.configs.factory import AnyModelConfig
+from invokeai.backend.model_manager.configs.main import Main_GGUF_QwenImage_Config
+from invokeai.backend.model_manager.load.load_default import ModelLoader
+from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
+from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
+from invokeai.backend.model_manager.taxonomy import (
+    AnyModel,
+    BaseModelType,
+    ModelFormat,
+    ModelType,
+    QwenImageVariantType,
+    SubModelType,
+)
+from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
+from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
+from invokeai.backend.util.devices import TorchDevice
+
+
+@ModelLoaderRegistry.register(base=BaseModelType.QwenImage, type=ModelType.Main, format=ModelFormat.Diffusers)
+class QwenImageDiffusersModel(GenericDiffusersLoader):
+    """Class to load Qwen Image main models (both txt2img and edit)."""
+
+    def _load_model(
+        self,
+        config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        if isinstance(config, Checkpoint_Config_Base):
+            raise NotImplementedError("CheckpointConfigBase is not implemented for Qwen Image models.")
+
+        if submodel_type is None:
+            raise Exception("A submodel type must be provided when loading main pipelines.")
+
+        model_path = Path(config.path)
+        load_class = self.get_hf_load_class(model_path, submodel_type)
+        repo_variant = config.repo_variant if isinstance(config, Diffusers_Config_Base) else None
+        variant = repo_variant.value if repo_variant else None
+        model_path = model_path / submodel_type.value
+
+        # We force bfloat16 for Qwen Image models.
+        # Use `dtype` (newer) with fallback to `torch_dtype` (older diffusers).
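+        # Two fallbacks below: a TypeError triggers a retry with the legacy torch_dtype
+        # keyword, and an OSError mentioning "no file named" triggers a retry without the
+        # requested repo variant.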
+ dtype_kwarg = {"dtype": torch.bfloat16} + try: + result: AnyModel = load_class.from_pretrained( + model_path, + **dtype_kwarg, + variant=variant, + local_files_only=True, + ) + except TypeError: + # Older diffusers uses torch_dtype instead of dtype + dtype_kwarg = {"torch_dtype": torch.bfloat16} + result = load_class.from_pretrained( + model_path, + **dtype_kwarg, + variant=variant, + local_files_only=True, + ) + except OSError as e: + if variant and "no file named" in str(e): + result = load_class.from_pretrained(model_path, **dtype_kwarg, local_files_only=True) + else: + raise e + + return result + + +@ModelLoaderRegistry.register(base=BaseModelType.QwenImage, type=ModelType.Main, format=ModelFormat.GGUFQuantized) +class QwenImageGGUFCheckpointModel(ModelLoader): + """Class to load GGUF-quantized Qwen Image Edit transformer models.""" + + def _load_model( + self, + config: AnyModelConfig, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if not isinstance(config, Checkpoint_Config_Base): + raise ValueError("Only CheckpointConfigBase models are currently supported here.") + + match submodel_type: + case SubModelType.Transformer: + return self._load_from_singlefile(config) + + raise ValueError( + f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}" + ) + + def _load_from_singlefile(self, config: AnyModelConfig) -> AnyModel: + from diffusers import QwenImageTransformer2DModel + + if not isinstance(config, Main_GGUF_QwenImage_Config): + raise TypeError(f"Expected Main_GGUF_QwenImage_Config, got {type(config).__name__}.") + model_path = Path(config.path) + + target_device = TorchDevice.choose_torch_device() + compute_dtype = TorchDevice.choose_bfloat16_safe_dtype(target_device) + + sd = gguf_sd_loader(model_path, compute_dtype=compute_dtype) + + # Strip ComfyUI-style prefixes if present + prefix_to_strip = None + for prefix in ["model.diffusion_model.", "diffusion_model."]: + if any(k.startswith(prefix) for k in sd.keys() if isinstance(k, str)): + prefix_to_strip = prefix + break + + if prefix_to_strip: + stripped_sd = {} + for key, value in sd.items(): + if isinstance(key, str) and key.startswith(prefix_to_strip): + stripped_sd[key[len(prefix_to_strip) :]] = value + else: + stripped_sd[key] = value + sd = stripped_sd + + # Auto-detect architecture from state dict + num_layers = 0 + for key in sd.keys(): + if isinstance(key, str) and key.startswith("transformer_blocks."): + parts = key.split(".") + if len(parts) >= 2: + try: + layer_idx = int(parts[1]) + num_layers = max(num_layers, layer_idx + 1) + except ValueError: + pass + + # Detect dimensions from weights + num_attention_heads = 24 # default + attention_head_dim = 128 # default + + if "img_in.weight" in sd: + w = sd["img_in.weight"] + shape = w.tensor_shape if isinstance(w, GGMLTensor) else w.shape + hidden_dim = shape[0] + in_channels = shape[1] + num_attention_heads = hidden_dim // attention_head_dim + + joint_attention_dim = 3584 # default + if "txt_in.weight" in sd: + w = sd["txt_in.weight"] + shape = w.tensor_shape if isinstance(w, GGMLTensor) else w.shape + joint_attention_dim = shape[1] + + model_config: dict = { + "patch_size": 2, + "in_channels": in_channels if "img_in.weight" in sd else 64, + "out_channels": 16, + "num_layers": num_layers if num_layers > 0 else 60, + "attention_head_dim": attention_head_dim, + "num_attention_heads": num_attention_heads, + "joint_attention_dim": joint_attention_dim, + "guidance_embeds": False, + "axes_dims_rope": 
(16, 56, 56), + } + + # zero_cond_t is only used by edit-variant models. It enables dual modulation + # for noisy vs reference patches. Setting it on txt2img models produces garbage. + # Also requires diffusers 0.37+ (the parameter doesn't exist in older versions). + import inspect + + is_edit = getattr(config, "variant", None) == QwenImageVariantType.Edit + if is_edit and "zero_cond_t" in inspect.signature(QwenImageTransformer2DModel.__init__).parameters: + model_config["zero_cond_t"] = True + + with accelerate.init_empty_weights(): + model = QwenImageTransformer2DModel(**model_config) + + model.load_state_dict(sd, strict=False, assign=True) + return model diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py b/invokeai/backend/model_manager/load/model_loaders/vae.py index e91903ccda..db26e8c654 100644 --- a/invokeai/backend/model_manager/load/model_loaders/vae.py +++ b/invokeai/backend/model_manager/load/model_loaders/vae.py @@ -6,7 +6,7 @@ from typing import Optional from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from invokeai.backend.model_manager.configs.factory import AnyModelConfig -from invokeai.backend.model_manager.configs.vae import VAE_Checkpoint_Config_Base +from invokeai.backend.model_manager.configs.vae import VAE_Checkpoint_Anima_Config, VAE_Checkpoint_Config_Base from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader from invokeai.backend.model_manager.taxonomy import ( @@ -28,7 +28,14 @@ class VAELoader(GenericDiffusersLoader): config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, ) -> AnyModel: - if isinstance(config, VAE_Checkpoint_Config_Base): + if isinstance(config, VAE_Checkpoint_Anima_Config): + from diffusers.models.autoencoders import AutoencoderKLWan + + return AutoencoderKLWan.from_single_file( + config.path, + torch_dtype=self._torch_dtype, + ) + elif isinstance(config, VAE_Checkpoint_Config_Base): return AutoencoderKL.from_single_file( config.path, torch_dtype=self._torch_dtype, diff --git a/invokeai/backend/model_manager/metadata/fetch/huggingface.py b/invokeai/backend/model_manager/metadata/fetch/huggingface.py index 1b2b6c3674..30fe418fe1 100644 --- a/invokeai/backend/model_manager/metadata/fetch/huggingface.py +++ b/invokeai/backend/model_manager/metadata/fetch/huggingface.py @@ -19,8 +19,7 @@ from pathlib import Path from typing import Optional import requests -from huggingface_hub import HfApi, configure_http_backend, hf_hub_url -from huggingface_hub.errors import RepositoryNotFoundError, RevisionNotFoundError +from huggingface_hub import hf_hub_url from pydantic.networks import AnyHttpUrl from requests.sessions import Session @@ -47,7 +46,6 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase): this module without an internet connection. """ self._requests = session or requests.Session() - configure_http_backend(backend_factory=lambda: self._requests) @classmethod def from_json(cls, json: str) -> HuggingFaceMetadata: @@ -55,6 +53,22 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase): metadata = HuggingFaceMetadata.model_validate_json(json) return metadata + def _fetch_model_info(self, repo_id: str, variant: Optional[ModelRepoVariant] = None) -> dict: + """Fetch model info from HuggingFace API using self._requests session. + + This allows the session to be mocked in tests via requests_testadapter. 
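+        Only the response fields consumed by from_id are required: "id", plus "siblings"
+        entries carrying "rfilename", "size", and an optional "lfs" dict with a "sha256" key.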
+ """ + url = f"https://huggingface.co/api/models/{repo_id}" + params: dict[str, str] = {"blobs": "True"} + if variant is not None: + params["revision"] = str(variant) + + response = self._requests.get(url, params=params) + if response.status_code == 404: + raise UnknownMetadataException(f"'{repo_id}' not found.") + response.raise_for_status() + return response.json() + def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata: """Return a HuggingFaceMetadata object given the model's repo_id.""" # Little loop which tries fetching a revision corresponding to the selected variant. @@ -67,10 +81,10 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase): repo_id = id.split("::")[0] or id while not model_info: try: - model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True, revision=variant) - except RepositoryNotFoundError as excp: - raise UnknownMetadataException(f"'{repo_id}' not found. See trace for details.") from excp - except RevisionNotFoundError: + model_info = self._fetch_model_info(repo_id, variant) + except UnknownMetadataException: + raise + except requests.HTTPError: if variant is None: raise else: @@ -80,15 +94,18 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase): _, name = repo_id.split("/") - for s in model_info.siblings or []: - assert s.rfilename is not None - assert s.size is not None + for s in model_info.get("siblings") or []: + rfilename = s.get("rfilename") + size = s.get("size") + assert rfilename is not None + assert size is not None + lfs = s.get("lfs") files.append( RemoteModelFile( - url=hf_hub_url(repo_id, s.rfilename, revision=variant or "main"), - path=Path(name, s.rfilename), - size=s.size, - sha256=s.lfs.get("sha256") if s.lfs else None, + url=hf_hub_url(repo_id, rfilename, revision=variant or "main"), + path=Path(name, rfilename), + size=size, + sha256=lfs.get("sha256") if lfs else None, ) ) @@ -115,10 +132,10 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase): ) return HuggingFaceMetadata( - id=model_info.id, + id=model_info["id"], name=name, files=files, - api_response=json.dumps(model_info.__dict__, default=str), + api_response=json.dumps(model_info, default=str), is_diffusers=is_diffusers, ckpt_urls=ckpt_urls, ) diff --git a/invokeai/backend/model_manager/metadata/metadata_base.py b/invokeai/backend/model_manager/metadata/metadata_base.py index e16ad4cbc4..b048144e54 100644 --- a/invokeai/backend/model_manager/metadata/metadata_base.py +++ b/invokeai/backend/model_manager/metadata/metadata_base.py @@ -17,7 +17,7 @@ remote repo. from pathlib import Path from typing import List, Literal, Optional, Union -from huggingface_hub import configure_http_backend, hf_hub_url +from huggingface_hub import hf_hub_url from pydantic import BaseModel, Field, TypeAdapter from pydantic.networks import AnyHttpUrl from requests.sessions import Session @@ -111,7 +111,6 @@ class HuggingFaceMetadata(ModelMetadataWithFiles): full-precision model is returned. 
""" session = session or Session() - configure_http_backend(backend_factory=lambda: session) # used in testing paths = filter_files([x.path for x in self.files], variant, subfolder, subfolders) # all files in the model diff --git a/invokeai/backend/model_manager/starter_models.py b/invokeai/backend/model_manager/starter_models.py index 952a14bbdd..bb4ea87d46 100644 --- a/invokeai/backend/model_manager/starter_models.py +++ b/invokeai/backend/model_manager/starter_models.py @@ -7,8 +7,15 @@ from invokeai.backend.model_manager.configs.external_api import ( ExternalImageSize, ExternalModelCapabilities, ExternalModelPanelSchema, + ExternalResolutionPreset, +) +from invokeai.backend.model_manager.taxonomy import ( + AnyVariant, + BaseModelType, + ModelFormat, + ModelType, + QwenImageVariantType, ) -from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType class StarterModelWithoutDependencies(BaseModel): @@ -18,6 +25,7 @@ class StarterModelWithoutDependencies(BaseModel): base: BaseModelType type: ModelType format: Optional[ModelFormat] = None + variant: Optional[AnyVariant] = None is_installed: bool = False capabilities: ExternalModelCapabilities | None = None default_settings: ExternalApiModelDefaultSettings | None = None @@ -80,7 +88,7 @@ t5_base_encoder = StarterModel( name="t5_base_encoder", base=BaseModelType.Any, source="InvokeAI/t5-v1_1-xxl::bfloat16", - description="T5-XXL text encoder (used in FLUX pipelines). ~8GB", + description="T5-XXL text encoder (used in FLUX pipelines). ~9.5GB", type=ModelType.T5Encoder, ) @@ -165,7 +173,7 @@ flux_kontext_quantized = StarterModel( name="FLUX.1 Kontext dev (quantized)", base=BaseModelType.Flux, source="https://huggingface.co/unsloth/FLUX.1-Kontext-dev-GGUF/resolve/main/flux1-kontext-dev-Q4_K_M.gguf", - description="FLUX.1 Kontext dev quantized (q4_k_m). Total size with dependencies: ~14GB", + description="FLUX.1 Kontext dev quantized (q4_k_m). Total size with dependencies: ~12GB", type=ModelType.Main, dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder], ) @@ -173,7 +181,7 @@ flux_krea = StarterModel( name="FLUX.1 Krea dev", base=BaseModelType.Flux, source="https://huggingface.co/InvokeAI/FLUX.1-Krea-dev/resolve/main/flux1-krea-dev.safetensors", - description="FLUX.1 Krea dev. Total size with dependencies: ~33GB", + description="FLUX.1 Krea dev. Total size with dependencies: ~29GB", type=ModelType.Main, dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder], ) @@ -181,7 +189,7 @@ flux_krea_quantized = StarterModel( name="FLUX.1 Krea dev (quantized)", base=BaseModelType.Flux, source="https://huggingface.co/InvokeAI/FLUX.1-Krea-dev-GGUF/resolve/main/flux1-krea-dev-Q4_K_M.gguf", - description="FLUX.1 Krea dev quantized (q4_k_m). Total size with dependencies: ~14GB", + description="FLUX.1 Krea dev quantized (q4_k_m). 
Total size with dependencies: ~12GB", type=ModelType.Main, dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder], ) @@ -189,7 +197,7 @@ sd35_medium = StarterModel( name="SD3.5 Medium", base=BaseModelType.StableDiffusion3, source="stabilityai/stable-diffusion-3.5-medium", - description="Medium SD3.5 Model: ~15GB", + description="Medium SD3.5 Model: ~16GB", type=ModelType.Main, dependencies=[], ) @@ -197,7 +205,7 @@ sd35_large = StarterModel( name="SD3.5 Large", base=BaseModelType.StableDiffusion3, source="stabilityai/stable-diffusion-3.5-large", - description="Large SD3.5 Model: ~19G", + description="Large SD3.5 Model: ~28GB", type=ModelType.Main, dependencies=[], ) @@ -653,11 +661,143 @@ cogview4 = StarterModel( name="CogView4", base=BaseModelType.CogView4, source="THUDM/CogView4-6B", - description="The base CogView4 model (~29GB).", + description="The base CogView4 model (~31GB).", type=ModelType.Main, ) # endregion +# region Qwen Image Edit +qwen_image_edit = StarterModel( + name="Qwen Image Edit 2511", + base=BaseModelType.QwenImage, + source="Qwen/Qwen-Image-Edit-2511", + description="Qwen Image Edit 2511 full diffusers model. Supports text-guided image editing with multiple reference images. (~40GB)", + type=ModelType.Main, + variant=QwenImageVariantType.Edit, +) + +qwen_image_edit_gguf_q4_k_m = StarterModel( + name="Qwen Image Edit 2511 (Q4_K_M)", + base=BaseModelType.QwenImage, + source="https://huggingface.co/unsloth/Qwen-Image-Edit-2511-GGUF/resolve/main/qwen-image-edit-2511-Q4_K_M.gguf", + description="Qwen Image Edit 2511 - Q4_K_M quantized transformer. Good quality/size balance. (~13GB)", + type=ModelType.Main, + format=ModelFormat.GGUFQuantized, + variant=QwenImageVariantType.Edit, +) + +qwen_image_edit_gguf_q2_k = StarterModel( + name="Qwen Image Edit 2511 (Q2_K)", + base=BaseModelType.QwenImage, + source="https://huggingface.co/unsloth/Qwen-Image-Edit-2511-GGUF/resolve/main/qwen-image-edit-2511-Q2_K.gguf", + description="Qwen Image Edit 2511 - Q2_K heavily quantized transformer. Smallest size, lower quality. (~7.5GB)", + type=ModelType.Main, + format=ModelFormat.GGUFQuantized, + variant=QwenImageVariantType.Edit, +) + +qwen_image_edit_gguf_q6_k = StarterModel( + name="Qwen Image Edit 2511 (Q6_K)", + base=BaseModelType.QwenImage, + source="https://huggingface.co/unsloth/Qwen-Image-Edit-2511-GGUF/resolve/main/qwen-image-edit-2511-Q6_K.gguf", + description="Qwen Image Edit 2511 - Q6_K quantized transformer. Near-lossless quality. (~17GB)", + type=ModelType.Main, + format=ModelFormat.GGUFQuantized, + variant=QwenImageVariantType.Edit, +) + +qwen_image_edit_gguf_q8_0 = StarterModel( + name="Qwen Image Edit 2511 (Q8_0)", + base=BaseModelType.QwenImage, + source="https://huggingface.co/unsloth/Qwen-Image-Edit-2511-GGUF/resolve/main/qwen-image-edit-2511-Q8_0.gguf", + description="Qwen Image Edit 2511 - Q8_0 quantized transformer. Highest quality quantization. (~22GB)", + type=ModelType.Main, + format=ModelFormat.GGUFQuantized, + variant=QwenImageVariantType.Edit, +) + +qwen_image_edit_lightning_4step = StarterModel( + name="Qwen Image Edit Lightning (4-step, bf16)", + base=BaseModelType.QwenImage, + source="https://huggingface.co/lightx2v/Qwen-Image-Edit-2511-Lightning/resolve/main/Qwen-Image-Edit-2511-Lightning-4steps-V1.0-bf16.safetensors", + description="Lightning distillation LoRA for Qwen Image Edit — enables generation in just 4 steps. 
" + "Settings: Steps=4, CFG=1, Shift Override=3.", + type=ModelType.LoRA, +) + +qwen_image_edit_lightning_8step = StarterModel( + name="Qwen Image Edit Lightning (8-step, bf16)", + base=BaseModelType.QwenImage, + source="https://huggingface.co/lightx2v/Qwen-Image-Edit-2511-Lightning/resolve/main/Qwen-Image-Edit-2511-Lightning-8steps-V1.0-bf16.safetensors", + description="Lightning distillation LoRA for Qwen Image Edit — enables generation in 8 steps with better quality. " + "Settings: Steps=8, CFG=1, Shift Override=3.", + type=ModelType.LoRA, +) + +# Qwen Image (txt2img) +qwen_image = StarterModel( + name="Qwen Image 2512", + base=BaseModelType.QwenImage, + source="Qwen/Qwen-Image-2512", + description="Qwen Image 2512 full diffusers model. High-quality text-to-image generation. (~40GB)", + type=ModelType.Main, +) + +qwen_image_gguf_q4_k_m = StarterModel( + name="Qwen Image 2512 (Q4_K_M)", + base=BaseModelType.QwenImage, + source="https://huggingface.co/unsloth/Qwen-Image-2512-GGUF/resolve/main/qwen-image-2512-Q4_K_M.gguf", + description="Qwen Image 2512 - Q4_K_M quantized transformer. Good quality/size balance. (~13GB)", + type=ModelType.Main, + format=ModelFormat.GGUFQuantized, +) + +qwen_image_gguf_q2_k = StarterModel( + name="Qwen Image 2512 (Q2_K)", + base=BaseModelType.QwenImage, + source="https://huggingface.co/unsloth/Qwen-Image-2512-GGUF/resolve/main/qwen-image-2512-Q2_K.gguf", + description="Qwen Image 2512 - Q2_K heavily quantized transformer. Smallest size, lower quality. (~7.5GB)", + type=ModelType.Main, + format=ModelFormat.GGUFQuantized, +) + +qwen_image_gguf_q6_k = StarterModel( + name="Qwen Image 2512 (Q6_K)", + base=BaseModelType.QwenImage, + source="https://huggingface.co/unsloth/Qwen-Image-2512-GGUF/resolve/main/qwen-image-2512-Q6_K.gguf", + description="Qwen Image 2512 - Q6_K quantized transformer. Near-lossless quality. (~17GB)", + type=ModelType.Main, + format=ModelFormat.GGUFQuantized, +) + +qwen_image_gguf_q8_0 = StarterModel( + name="Qwen Image 2512 (Q8_0)", + base=BaseModelType.QwenImage, + source="https://huggingface.co/unsloth/Qwen-Image-2512-GGUF/resolve/main/qwen-image-2512-Q8_0.gguf", + description="Qwen Image 2512 - Q8_0 quantized transformer. Highest quality quantization. (~22GB)", + type=ModelType.Main, + format=ModelFormat.GGUFQuantized, +) + +qwen_image_lightning_4step = StarterModel( + name="Qwen Image Lightning (4-step, V2.0, bf16)", + base=BaseModelType.QwenImage, + source="https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-4steps-V2.0-bf16.safetensors", + description="Lightning distillation LoRA for Qwen Image — enables generation in just 4 steps. " + "Settings: Steps=4, CFG=1, Shift Override=3.", + type=ModelType.LoRA, +) + +qwen_image_lightning_8step = StarterModel( + name="Qwen Image Lightning (8-step, V2.0, bf16)", + base=BaseModelType.QwenImage, + source="https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-8steps-V2.0-bf16.safetensors", + description="Lightning distillation LoRA for Qwen Image — enables generation in 8 steps with better quality. " + "Settings: Steps=8, CFG=1, Shift Override=3.", + type=ModelType.LoRA, +) +# endregion + # region SigLIP siglip = StarterModel( name="SigLIP - google/siglip-so400m-patch14-384", @@ -704,7 +844,7 @@ flux2_vae = StarterModel( name="FLUX.2 VAE", base=BaseModelType.Flux2, source="black-forest-labs/FLUX.2-klein-4B::vae", - description="FLUX.2 VAE (16-channel, same architecture as FLUX.1 VAE). 
~335MB", + description="FLUX.2 VAE (16-channel, same architecture as FLUX.1 VAE). ~168MB", type=ModelType.VAE, ) @@ -728,7 +868,7 @@ flux2_klein_4b = StarterModel( name="FLUX.2 Klein 4B (Diffusers)", base=BaseModelType.Flux2, source="black-forest-labs/FLUX.2-klein-4B", - description="FLUX.2 Klein 4B in Diffusers format - includes transformer, VAE and Qwen3 encoder. ~10GB", + description="FLUX.2 Klein 4B in Diffusers format - includes transformer, VAE and Qwen3 encoder. ~16GB", type=ModelType.Main, ) @@ -754,7 +894,7 @@ flux2_klein_9b = StarterModel( name="FLUX.2 Klein 9B (Diffusers)", base=BaseModelType.Flux2, source="black-forest-labs/FLUX.2-klein-9B", - description="FLUX.2 Klein 9B in Diffusers format - includes transformer, VAE and Qwen3 encoder. ~20GB", + description="FLUX.2 Klein 9B in Diffusers format - includes transformer, VAE and Qwen3 encoder. ~35GB", type=ModelType.Main, ) @@ -830,7 +970,7 @@ z_image_turbo = StarterModel( name="Z-Image Turbo", base=BaseModelType.ZImage, source="Tongyi-MAI/Z-Image-Turbo", - description="Z-Image Turbo - fast 6B parameter text-to-image model with 8 inference steps. Supports bilingual prompts (English & Chinese). ~30.6GB", + description="Z-Image Turbo - fast 6B parameter text-to-image model with 8 inference steps. Supports bilingual prompts (English & Chinese). ~33GB", type=ModelType.Main, ) @@ -890,6 +1030,45 @@ GEMINI_3_IMAGE_ALLOWED_ASPECT_RATIOS = [ ] GEMINI_3_IMAGE_MAX_SIZE = ExternalImageSize(width=4096, height=4096) + +def _gemini_3_resolution_presets( + image_sizes: list[str], + aspect_ratios: list[str] | None = None, +) -> list[ExternalResolutionPreset]: + """Build resolution presets for Gemini 3 models. + + Each preset combines an aspect ratio with an image size preset (512/1K/2K/4K). + Pixel dimensions are approximations based on the preset name (longest side). + """ + if aspect_ratios is None: + aspect_ratios = GEMINI_3_IMAGE_ALLOWED_ASPECT_RATIOS + base_pixels = {"512": 512, "1K": 1024, "2K": 2048, "4K": 4096} + presets: list[ExternalResolutionPreset] = [] + for image_size in image_sizes: + base = base_pixels[image_size] + for ratio_str in aspect_ratios: + w_part, h_part = (int(x) for x in ratio_str.split(":")) + if w_part >= h_part: + w = base + h = max(1, round(base * h_part / w_part)) + else: + h = base + w = max(1, round(base * w_part / h_part)) + presets.append( + ExternalResolutionPreset( + label=f"{ratio_str} ({image_size}) — {w}\u00d7{h}", + aspect_ratio=ratio_str, + image_size=image_size, + width=w, + height=h, + ) + ) + return presets + + +GEMINI_3_PRO_RESOLUTION_PRESETS = _gemini_3_resolution_presets(["1K", "2K", "4K"]) +GEMINI_3_1_FLASH_RESOLUTION_PRESETS = _gemini_3_resolution_presets(["512", "1K", "2K", "4K"]) + gemini_flash_image = StarterModel( name="Gemini 2.5 Flash Image", base=BaseModelType.External, @@ -899,9 +1078,7 @@ gemini_flash_image = StarterModel( format=ModelFormat.ExternalApi, capabilities=ExternalModelCapabilities( modes=["txt2img", "img2img", "inpaint"], - supports_negative_prompt=True, supports_seed=True, - supports_guidance=True, supports_reference_images=True, max_images_per_request=1, allowed_aspect_ratios=[ @@ -936,19 +1113,18 @@ gemini_pro_image_preview = StarterModel( name="Gemini 3 Pro Image Preview", base=BaseModelType.External, source="external://gemini/gemini-3-pro-image-preview", - description="Google Gemini 3 Pro image generation preview model (external API). Supports up to 14 reference images, including up to 6 object references and up to 5 character references. 
Supports 512/1K/2K/4K resolution presets. Requires a configured Gemini API key and may incur provider usage costs.", + description="Google Gemini 3 Pro image generation preview model (external API). Supports up to 14 reference images, including up to 6 object references and up to 5 character references. Supports 1K/2K/4K resolution presets. Requires a configured Gemini API key and may incur provider usage costs.", type=ModelType.ExternalImageGenerator, format=ModelFormat.ExternalApi, capabilities=ExternalModelCapabilities( modes=["txt2img", "img2img", "inpaint"], - supports_negative_prompt=True, supports_seed=True, - supports_guidance=True, supports_reference_images=True, max_reference_images=14, max_images_per_request=1, max_image_size=GEMINI_3_IMAGE_MAX_SIZE, allowed_aspect_ratios=GEMINI_3_IMAGE_ALLOWED_ASPECT_RATIOS, + resolution_presets=GEMINI_3_PRO_RESOLUTION_PRESETS, ), default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1), panel_schema=ExternalModelPanelSchema(prompts=[{"name": "reference_images"}], image=[{"name": "dimensions"}]), @@ -962,14 +1138,13 @@ gemini_3_1_flash_image_preview = StarterModel( format=ModelFormat.ExternalApi, capabilities=ExternalModelCapabilities( modes=["txt2img", "img2img", "inpaint"], - supports_negative_prompt=True, supports_seed=True, - supports_guidance=True, supports_reference_images=True, max_reference_images=14, max_images_per_request=1, max_image_size=GEMINI_3_IMAGE_MAX_SIZE, allowed_aspect_ratios=GEMINI_3_IMAGE_ALLOWED_ASPECT_RATIOS, + resolution_presets=GEMINI_3_1_FLASH_RESOLUTION_PRESETS, ), default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1), panel_schema=ExternalModelPanelSchema(prompts=[{"name": "reference_images"}], image=[{"name": "dimensions"}]), @@ -1098,23 +1273,132 @@ alibabacloud_qwen_image_edit_max = StarterModel( default_settings=ExternalApiModelDefaultSettings(width=2048, height=2048, num_images=1), panel_schema=ExternalModelPanelSchema(image=[{"name": "dimensions"}]), ) -openai_gpt_image_1 = StarterModel( - name="ChatGPT Image", +OPENAI_GPT_IMAGE_ASPECT_RATIOS = ["1:1", "3:2", "2:3"] +OPENAI_GPT_IMAGE_ASPECT_RATIO_SIZES = { + "1:1": ExternalImageSize(width=1024, height=1024), + "3:2": ExternalImageSize(width=1536, height=1024), + "2:3": ExternalImageSize(width=1024, height=1536), +} +OPENAI_GPT_IMAGE_PANEL_SCHEMA = ExternalModelPanelSchema( + prompts=[{"name": "reference_images"}], image=[{"name": "dimensions"}] +) + +openai_gpt_image_1_5 = StarterModel( + name="GPT Image 1.5", base=BaseModelType.External, - source="external://openai/gpt-image-1", - description="OpenAI GPT-Image-1 image generation model (external API). Requires a configured OpenAI API key and may incur provider usage costs.", + source="external://openai/gpt-image-1.5", + description="OpenAI GPT-Image-1.5 image generation model. Fastest and most affordable GPT image model. 
Requires a configured OpenAI API key and may incur provider usage costs.",
     type=ModelType.ExternalImageGenerator,
     format=ModelFormat.ExternalApi,
     capabilities=ExternalModelCapabilities(
         modes=["txt2img", "img2img", "inpaint"],
-        supports_negative_prompt=True,
-        supports_seed=True,
-        supports_guidance=True,
         supports_reference_images=True,
-        max_images_per_request=1,
+        max_images_per_request=10,
+        allowed_aspect_ratios=OPENAI_GPT_IMAGE_ASPECT_RATIOS,
+        aspect_ratio_sizes=OPENAI_GPT_IMAGE_ASPECT_RATIO_SIZES,
     ),
     default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1),
-    panel_schema=ExternalModelPanelSchema(prompts=[{"name": "reference_images"}], image=[{"name": "dimensions"}]),
+    panel_schema=OPENAI_GPT_IMAGE_PANEL_SCHEMA,
+)
+openai_gpt_image_1 = StarterModel(
+    name="GPT Image 1",
+    base=BaseModelType.External,
+    source="external://openai/gpt-image-1",
+    description="OpenAI GPT-Image-1 image generation model. High quality image generation. Requires a configured OpenAI API key and may incur provider usage costs.",
+    type=ModelType.ExternalImageGenerator,
+    format=ModelFormat.ExternalApi,
+    capabilities=ExternalModelCapabilities(
+        modes=["txt2img", "img2img", "inpaint"],
+        supports_reference_images=True,
+        max_images_per_request=10,
+        allowed_aspect_ratios=OPENAI_GPT_IMAGE_ASPECT_RATIOS,
+        aspect_ratio_sizes=OPENAI_GPT_IMAGE_ASPECT_RATIO_SIZES,
+    ),
+    default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1),
+    panel_schema=OPENAI_GPT_IMAGE_PANEL_SCHEMA,
+)
+openai_gpt_image_1_mini = StarterModel(
+    name="GPT Image 1 Mini",
+    base=BaseModelType.External,
+    source="external://openai/gpt-image-1-mini",
+    description="OpenAI GPT-Image-1-Mini image generation model. Cost-efficient option, 80% cheaper than GPT-Image-1. Requires a configured OpenAI API key and may incur provider usage costs.",
+    type=ModelType.ExternalImageGenerator,
+    format=ModelFormat.ExternalApi,
+    capabilities=ExternalModelCapabilities(
+        modes=["txt2img", "img2img", "inpaint"],
+        supports_reference_images=True,
+        max_images_per_request=10,
+        allowed_aspect_ratios=OPENAI_GPT_IMAGE_ASPECT_RATIOS,
+        aspect_ratio_sizes=OPENAI_GPT_IMAGE_ASPECT_RATIO_SIZES,
+    ),
+    default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1),
+    panel_schema=OPENAI_GPT_IMAGE_PANEL_SCHEMA,
+)
+openai_dall_e_3 = StarterModel(
+    name="DALL-E 3",
+    base=BaseModelType.External,
+    source="external://openai/dall-e-3",
+    description="OpenAI DALL-E 3 image generation model. Supports vivid and natural styles. Only text-to-image, no editing. Requires a configured OpenAI API key and may incur provider usage costs.",
+    type=ModelType.ExternalImageGenerator,
+    format=ModelFormat.ExternalApi,
+    capabilities=ExternalModelCapabilities(
+        modes=["txt2img"],
+        max_images_per_request=1,
+        allowed_aspect_ratios=["1:1", "7:4", "4:7"],
+        aspect_ratio_sizes={
+            "1:1": ExternalImageSize(width=1024, height=1024),
+            "7:4": ExternalImageSize(width=1792, height=1024),
+            "4:7": ExternalImageSize(width=1024, height=1792),
+        },
+    ),
+    default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1),
+    panel_schema=ExternalModelPanelSchema(image=[{"name": "dimensions"}]),
+)
+openai_dall_e_2 = StarterModel(
+    name="DALL-E 2",
+    base=BaseModelType.External,
+    source="external://openai/dall-e-2",
+    description="OpenAI DALL-E 2 image generation model. Supports square images only.
Requires a configured OpenAI API key and may incur provider usage costs.", + type=ModelType.ExternalImageGenerator, + format=ModelFormat.ExternalApi, + capabilities=ExternalModelCapabilities( + modes=["txt2img", "img2img", "inpaint"], + max_images_per_request=10, + allowed_aspect_ratios=["1:1"], + aspect_ratio_sizes={ + "1:1": ExternalImageSize(width=1024, height=1024), + }, + ), + default_settings=ExternalApiModelDefaultSettings(width=1024, height=1024, num_images=1), + panel_schema=ExternalModelPanelSchema(image=[{"name": "dimensions"}]), +) +# region Anima +anima_qwen3_encoder = StarterModel( + name="Anima Qwen3 0.6B Text Encoder", + base=BaseModelType.Any, + source="https://huggingface.co/circlestone-labs/Anima/resolve/main/split_files/text_encoders/qwen_3_06b_base.safetensors", + description="Qwen3 0.6B text encoder for Anima. ~1.2GB", + type=ModelType.Qwen3Encoder, + format=ModelFormat.Checkpoint, +) + +anima_vae = StarterModel( + name="Anima QwenImage VAE", + base=BaseModelType.Anima, + source="https://huggingface.co/circlestone-labs/Anima/resolve/main/split_files/vae/qwen_image_vae.safetensors", + description="QwenImage VAE for Anima (fine-tuned Wan 2.1 VAE, 16 latent channels). ~200MB", + type=ModelType.VAE, + format=ModelFormat.Checkpoint, +) + +anima_preview3 = StarterModel( + name="Anima Preview 3", + base=BaseModelType.Anima, + source="https://huggingface.co/circlestone-labs/Anima/resolve/main/split_files/diffusion_models/anima-preview3-base.safetensors", + description="Anima Preview 3 - 2B parameter anime-focused text-to-image model built on Cosmos Predict2 DiT. ~4.5GB", + type=ModelType.Main, + format=ModelFormat.Checkpoint, + dependencies=[anima_qwen3_encoder, anima_vae, t5_base_encoder], ) # endregion @@ -1204,6 +1488,20 @@ STARTER_MODELS: list[StarterModel] = [ flux2_klein_qwen3_4b_encoder, flux2_klein_qwen3_8b_encoder, cogview4, + qwen_image_edit, + qwen_image_edit_gguf_q2_k, + qwen_image_edit_gguf_q4_k_m, + qwen_image_edit_gguf_q6_k, + qwen_image_edit_gguf_q8_0, + qwen_image_edit_lightning_4step, + qwen_image_edit_lightning_8step, + qwen_image, + qwen_image_gguf_q2_k, + qwen_image_gguf_q4_k_m, + qwen_image_gguf_q6_k, + qwen_image_gguf_q8_0, + qwen_image_lightning_4step, + qwen_image_lightning_8step, flux_krea, flux_krea_quantized, z_image_turbo, @@ -1216,12 +1514,19 @@ STARTER_MODELS: list[StarterModel] = [ gemini_flash_image, gemini_pro_image_preview, gemini_3_1_flash_image_preview, + openai_gpt_image_1_5, openai_gpt_image_1, + openai_gpt_image_1_mini, + openai_dall_e_3, + openai_dall_e_2, alibabacloud_qwen_image_2_pro, alibabacloud_qwen_image_2, alibabacloud_qwen_image_max, alibabacloud_wan26_t2i, alibabacloud_qwen_image_edit_max, + anima_preview3, + anima_qwen3_encoder, + anima_vae, ] sd1_bundle: list[StarterModel] = [ @@ -1290,12 +1595,34 @@ flux2_klein_bundle: list[StarterModel] = [ flux2_klein_qwen3_4b_encoder, ] +qwen_image_bundle: list[StarterModel] = [ + qwen_image_edit, + qwen_image_edit_gguf_q4_k_m, + qwen_image_edit_gguf_q8_0, + qwen_image_edit_lightning_4step, + qwen_image_edit_lightning_8step, + qwen_image, + qwen_image_gguf_q4_k_m, + qwen_image_gguf_q8_0, + qwen_image_lightning_4step, + qwen_image_lightning_8step, +] + +anima_bundle: list[StarterModel] = [ + anima_preview3, + anima_qwen3_encoder, + anima_vae, + t5_base_encoder, +] + STARTER_BUNDLES: dict[str, StarterModelBundle] = { BaseModelType.StableDiffusion1: StarterModelBundle(name="Stable Diffusion 1.5", models=sd1_bundle), BaseModelType.StableDiffusionXL: StarterModelBundle(name="SDXL", 
models=sdxl_bundle), BaseModelType.Flux: StarterModelBundle(name="FLUX.1 dev", models=flux_bundle), BaseModelType.Flux2: StarterModelBundle(name="FLUX.2 Klein", models=flux2_klein_bundle), BaseModelType.ZImage: StarterModelBundle(name="Z-Image Turbo", models=zimage_bundle), + BaseModelType.QwenImage: StarterModelBundle(name="Qwen Image", models=qwen_image_bundle), + BaseModelType.Anima: StarterModelBundle(name="Anima", models=anima_bundle), } assert len(STARTER_MODELS) == len({m.source for m in STARTER_MODELS}), "Duplicate starter models" diff --git a/invokeai/backend/model_manager/taxonomy.py b/invokeai/backend/model_manager/taxonomy.py index 4bf3461a8b..b2b55ebd3f 100644 --- a/invokeai/backend/model_manager/taxonomy.py +++ b/invokeai/backend/model_manager/taxonomy.py @@ -54,6 +54,10 @@ class BaseModelType(str, Enum): """Indicates the model is associated with Z-Image model architecture, including Z-Image-Turbo.""" External = "external" """Indicates the model is hosted by an external provider.""" + QwenImage = "qwen-image" + """Indicates the model is associated with Qwen Image Edit 2511 model architecture.""" + Anima = "anima" + """Indicates the model is associated with Anima model architecture (Cosmos Predict2 DiT + LLM Adapter).""" Unknown = "unknown" """Indicates the model's base architecture is unknown.""" @@ -146,6 +150,16 @@ class ZImageVariantType(str, Enum): """Z-Image Base - undistilled foundation model with full CFG and negative prompt support.""" +class QwenImageVariantType(str, Enum): + """Qwen Image model variants.""" + + Generate = "generate" + """Qwen Image - text-to-image generation model.""" + + Edit = "edit" + """Qwen Image Edit - image editing model with reference image support.""" + + class Qwen3VariantType(str, Enum): """Qwen3 text encoder variants based on model size.""" @@ -155,6 +169,9 @@ class Qwen3VariantType(str, Enum): Qwen3_8B = "qwen3_8b" """Qwen3 8B text encoder (hidden_size=4096). Used by FLUX.2 Klein 9B.""" + Qwen3_06B = "qwen3_06b" + """Qwen3 0.6B text encoder (hidden_size=1024). 
Used by Anima.""" + class ModelFormat(str, Enum): """Storage format of model.""" @@ -215,11 +232,32 @@ class FluxLoRAFormat(str, Enum): AIToolkit = "flux.aitoolkit" XLabs = "flux.xlabs" BflPeft = "flux.bfl_peft" + OneTrainerBfl = "flux.onetrainer_bfl" AnyVariant: TypeAlias = Union[ - ModelVariantType, ClipVariantType, FluxVariantType, Flux2VariantType, ZImageVariantType, Qwen3VariantType + ModelVariantType, + ClipVariantType, + FluxVariantType, + Flux2VariantType, + ZImageVariantType, + QwenImageVariantType, + Qwen3VariantType, ] variant_type_adapter = TypeAdapter[ - ModelVariantType | ClipVariantType | FluxVariantType | Flux2VariantType | ZImageVariantType | Qwen3VariantType -](ModelVariantType | ClipVariantType | FluxVariantType | Flux2VariantType | ZImageVariantType | Qwen3VariantType) + ModelVariantType + | ClipVariantType + | FluxVariantType + | Flux2VariantType + | ZImageVariantType + | QwenImageVariantType + | Qwen3VariantType +]( + ModelVariantType + | ClipVariantType + | FluxVariantType + | Flux2VariantType + | ZImageVariantType + | QwenImageVariantType + | Qwen3VariantType +) diff --git a/invokeai/backend/patches/lora_conversions/anima_lora_constants.py b/invokeai/backend/patches/lora_conversions/anima_lora_constants.py new file mode 100644 index 0000000000..380e31998a --- /dev/null +++ b/invokeai/backend/patches/lora_conversions/anima_lora_constants.py @@ -0,0 +1,45 @@ +# Anima LoRA prefix constants +# These prefixes are used for key mapping when applying LoRA patches to Anima models + +import re + +# Prefix for Anima transformer (Cosmos DiT architecture) LoRA layers +ANIMA_LORA_TRANSFORMER_PREFIX = "lora_transformer-" + +# Prefix for Qwen3 text encoder LoRA layers +ANIMA_LORA_QWEN3_PREFIX = "lora_qwen3-" + +# --------------------------------------------------------------------------- +# Cosmos DiT detection helpers +# +# Shared between ``anima_lora_conversion_utils.is_state_dict_likely_anima_lora`` +# and the config probing code in ``configs/lora.py``. Kept here (rather than +# in ``anima_lora_conversion_utils``) to avoid circular imports. +# --------------------------------------------------------------------------- + +# Cosmos DiT subcomponent names unique to the Anima / Cosmos Predict2 architecture. +_COSMOS_DIT_SUBCOMPONENTS_RE = r"(cross_attn|self_attn|mlp|adaln_modulation)" + +# Kohya format: lora_unet_[llm_adapter_]blocks_N_ +_KOHYA_ANIMA_RE = re.compile(r"lora_unet_(llm_adapter_)?blocks_\d+_" + _COSMOS_DIT_SUBCOMPONENTS_RE) + +# PEFT format: .blocks.N. +_PEFT_ANIMA_RE = re.compile( + r"(diffusion_model|transformer|base_model\.model\.transformer)\.blocks\.\d+\." + _COSMOS_DIT_SUBCOMPONENTS_RE +) + + +def has_cosmos_dit_kohya_keys(str_keys: list[str]) -> bool: + """Check for Kohya-style keys targeting Cosmos DiT blocks with specific subcomponents. + + Requires both the ``lora_unet_[llm_adapter_]blocks_N_`` prefix **and** a + Cosmos DiT subcomponent name (cross_attn, self_attn, mlp, adaln_modulation) + to avoid false-positives on other architectures that might also use bare + ``blocks`` in their key paths. 
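+
+    A hypothetical matching key (illustrative only) would be
+    ``lora_unet_blocks_3_cross_attn_q_proj.lora_down.weight``.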
+ """ + return any(_KOHYA_ANIMA_RE.search(k) is not None for k in str_keys) + + +def has_cosmos_dit_peft_keys(str_keys: list[str]) -> bool: + """Check for diffusers PEFT keys targeting Cosmos DiT blocks with specific subcomponents.""" + return any(_PEFT_ANIMA_RE.search(k) is not None for k in str_keys) diff --git a/invokeai/backend/patches/lora_conversions/anima_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/anima_lora_conversion_utils.py new file mode 100644 index 0000000000..b55a96dca7 --- /dev/null +++ b/invokeai/backend/patches/lora_conversions/anima_lora_conversion_utils.py @@ -0,0 +1,300 @@ +"""Anima LoRA conversion utilities. + +Anima uses a Cosmos Predict2 DiT transformer architecture. +LoRAs for Anima typically follow the Kohya-style format with underscore-separated keys +(e.g., lora_unet_blocks_0_cross_attn_k_proj) that map to model parameter paths +(e.g., blocks.0.cross_attn.k_proj). + +Some Anima LoRAs also target the Qwen3 text encoder with lora_te_ prefix keys +(e.g., lora_te_layers_0_self_attn_q_proj -> layers.0.self_attn.q_proj). +""" + +import re +from typing import Dict + +import torch + +from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch +from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict +from invokeai.backend.patches.lora_conversions.anima_lora_constants import ( + ANIMA_LORA_QWEN3_PREFIX, + ANIMA_LORA_TRANSFORMER_PREFIX, + has_cosmos_dit_kohya_keys, + has_cosmos_dit_peft_keys, +) +from invokeai.backend.patches.model_patch_raw import ModelPatchRaw +from invokeai.backend.util.logging import InvokeAILogger + +logger = InvokeAILogger.get_logger(__name__) + + +def is_state_dict_likely_anima_lora(state_dict: dict[str | int, torch.Tensor]) -> bool: + """Checks if the provided state dict is likely an Anima LoRA. + + Anima LoRAs use Kohya-style naming with lora_unet_ prefix and underscore-separated + model key paths targeting Cosmos DiT blocks. Detection requires Cosmos DiT-specific + subcomponent names (cross_attn, self_attn, mlp, adaln_modulation) to avoid + false-positives on other architectures that also use ``blocks`` in their paths. + """ + str_keys = [k for k in state_dict.keys() if isinstance(k, str)] + + if has_cosmos_dit_kohya_keys(str_keys): + return True + + return has_cosmos_dit_peft_keys(str_keys) + + +# Mapping from Kohya underscore-style substrings to model parameter names. +# Order matters: longer/more specific patterns should come first to avoid partial matches. +_KOHYA_UNET_KEY_REPLACEMENTS = [ + ("adaln_modulation_cross_attn_", "adaln_modulation_cross_attn."), + ("adaln_modulation_self_attn_", "adaln_modulation_self_attn."), + ("adaln_modulation_mlp_", "adaln_modulation_mlp."), + ("cross_attn_k_proj", "cross_attn.k_proj"), + ("cross_attn_q_proj", "cross_attn.q_proj"), + ("cross_attn_v_proj", "cross_attn.v_proj"), + ("cross_attn_output_proj", "cross_attn.output_proj"), + ("cross_attn_o_proj", "cross_attn.o_proj"), + ("self_attn_k_proj", "self_attn.k_proj"), + ("self_attn_q_proj", "self_attn.q_proj"), + ("self_attn_v_proj", "self_attn.v_proj"), + ("self_attn_output_proj", "self_attn.output_proj"), + ("self_attn_o_proj", "self_attn.o_proj"), + ("mlp_layer1", "mlp.layer1"), + ("mlp_layer2", "mlp.layer2"), +] + +# Mapping for Qwen3 text encoder Kohya keys. 
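+# Example (hypothetical key): lora_te_layers_0_self_attn_q_proj maps to
+# model.layers.0.self_attn.q_proj; the leading `model.` is added later by
+# _convert_kohya_te_key because Qwen3ForCausalLM wraps the base Qwen3Model.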
+_KOHYA_TE_KEY_REPLACEMENTS = [ + ("self_attn_k_proj", "self_attn.k_proj"), + ("self_attn_q_proj", "self_attn.q_proj"), + ("self_attn_v_proj", "self_attn.v_proj"), + ("self_attn_o_proj", "self_attn.o_proj"), + ("mlp_down_proj", "mlp.down_proj"), + ("mlp_gate_proj", "mlp.gate_proj"), + ("mlp_up_proj", "mlp.up_proj"), +] + + +def _convert_kohya_unet_key(kohya_layer_name: str) -> str: + """Convert a Kohya-style LoRA layer name to a model parameter path. + + Example: lora_unet_blocks_0_cross_attn_k_proj -> blocks.0.cross_attn.k_proj + Example: lora_unet_llm_adapter_blocks_0_cross_attn_k_proj -> llm_adapter.blocks.0.cross_attn.k_proj + """ + key = kohya_layer_name + if key.startswith("lora_unet_"): + key = key[len("lora_unet_") :] + + # Handle llm_adapter prefix: strip it, run the standard block conversion, then re-add with dot + llm_adapter_prefix = "" + if key.startswith("llm_adapter_"): + key = key[len("llm_adapter_") :] + llm_adapter_prefix = "llm_adapter." + + # Convert blocks_N_ to blocks.N. + key = re.sub(r"^blocks_(\d+)_", r"blocks.\1.", key) + + # Apply known replacements for subcomponent names + for old, new in _KOHYA_UNET_KEY_REPLACEMENTS: + if old in key: + key = key.replace(old, new, 1) + break + + return llm_adapter_prefix + key + + +def _convert_kohya_te_key(kohya_layer_name: str) -> str: + """Convert a Kohya-style text encoder LoRA layer name to a model parameter path. + + The Qwen3 text encoder is loaded as Qwen3ForCausalLM which wraps the base model + under a `model.` prefix, so the final path must include it. + + Example: lora_te_layers_0_self_attn_q_proj -> model.layers.0.self_attn.q_proj + """ + key = kohya_layer_name + if key.startswith("lora_te_"): + key = key[len("lora_te_") :] + + # Convert layers_N_ to layers.N. + key = re.sub(r"^layers_(\d+)_", r"layers.\1.", key) + + # Apply known replacements + for old, new in _KOHYA_TE_KEY_REPLACEMENTS: + if old in key: + key = key.replace(old, new, 1) + break + + # Qwen3ForCausalLM wraps the base Qwen3Model under `model.` + key = f"model.{key}" + + return key + + +def _make_layer_patch(layer_dict: dict[str, torch.Tensor]) -> BaseLayerPatch: + """Create a layer patch from a layer dict, handling DoRA+LoKR edge case. + + Some Anima LoRAs combine DoRA (dora_scale) with LoKR (lokr_w1/lokr_w2) weights. + The shared any_lora_layer_from_state_dict checks dora_scale first and expects + lora_up/lora_down keys, which don't exist in LoKR layers. We strip dora_scale + from LoKR layers so they fall through to the LoKR handler instead. + """ + has_lokr = "lokr_w1" in layer_dict or "lokr_w1_a" in layer_dict + has_dora = "dora_scale" in layer_dict + if has_lokr and has_dora: + layer_dict = {k: v for k, v in layer_dict.items() if k != "dora_scale"} + logger.warning("Stripped dora_scale from LoKR layer (DoRA+LoKR combination not supported, using LoKR only)") + return any_lora_layer_from_state_dict(layer_dict) + + +# Known suffixes for Kohya format +_KOHYA_KNOWN_SUFFIXES = [ + ".lora_A.weight", + ".lora_B.weight", + ".lora_down.weight", + ".lora_up.weight", + ".dora_scale", + ".alpha", +] + +# Additional suffixes for PEFT/LoKR format +_PEFT_EXTRA_SUFFIXES = [ + ".lokr_w1", + ".lokr_w2", + ".lokr_w1_a", + ".lokr_w1_b", + ".lokr_w2_a", + ".lokr_w2_b", +] + + +def _group_keys_by_layer( + state_dict: Dict[str, torch.Tensor], + extra_suffixes: list[str] | None = None, +) -> dict[str, dict[str, torch.Tensor]]: + """Group state dict keys by layer name based on known suffixes. + + Args: + state_dict: The LoRA state dict to group. 
+ extra_suffixes: Additional suffixes to recognize beyond the base Kohya set. + + Returns: + Dict mapping layer names to their component tensors. + """ + layer_dict: dict[str, dict[str, torch.Tensor]] = {} + + known_suffixes = list(_KOHYA_KNOWN_SUFFIXES) + if extra_suffixes: + known_suffixes.extend(extra_suffixes) + + for key in state_dict: + if not isinstance(key, str): + continue + + layer_name = None + key_name = None + for suffix in known_suffixes: + if key.endswith(suffix): + layer_name = key[: -len(suffix)] + key_name = suffix[1:] # Remove leading dot + break + + if layer_name is None: + parts = key.rsplit(".", maxsplit=2) + layer_name = parts[0] + key_name = ".".join(parts[1:]) + + if layer_name not in layer_dict: + layer_dict[layer_name] = {} + layer_dict[layer_name][key_name] = state_dict[key] + + return layer_dict + + +def _get_lora_layer_values(layer_dict: dict[str, torch.Tensor], alpha: float | None) -> dict[str, torch.Tensor]: + """Convert layer dict keys from PEFT format to internal format.""" + if "lora_A.weight" in layer_dict: + values = { + "lora_down.weight": layer_dict["lora_A.weight"], + "lora_up.weight": layer_dict["lora_B.weight"], + } + if alpha is not None: + values["alpha"] = torch.tensor(alpha) + return values + elif "lora_down.weight" in layer_dict: + return layer_dict + else: + return layer_dict + + +def lora_model_from_anima_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float | None = None) -> ModelPatchRaw: + """Convert an Anima LoRA state dict to a ModelPatchRaw. + + Supports both Kohya-style keys (lora_unet_blocks_0_...) and diffusers PEFT format. + Also supports text encoder LoRA keys (lora_te_layers_0_...) targeting the Qwen3 encoder. + + Args: + state_dict: The LoRA state dict + alpha: The alpha value for LoRA scaling. If None, uses rank as alpha. 
+ + Returns: + A ModelPatchRaw containing the LoRA layers + """ + layers: dict[str, BaseLayerPatch] = {} + + # Detect format + str_keys = [k for k in state_dict.keys() if isinstance(k, str)] + is_kohya = any(k.startswith(("lora_unet_", "lora_te_")) for k in str_keys) + + if is_kohya: + # Kohya format: group by layer name (everything before .lora_down/.lora_up/.alpha) + grouped = _group_keys_by_layer(state_dict) + for kohya_layer_name, layer_dict in grouped.items(): + if kohya_layer_name.startswith("lora_te_"): + model_key = _convert_kohya_te_key(kohya_layer_name) + final_key = f"{ANIMA_LORA_QWEN3_PREFIX}{model_key}" + else: + model_key = _convert_kohya_unet_key(kohya_layer_name) + final_key = f"{ANIMA_LORA_TRANSFORMER_PREFIX}{model_key}" + layer = _make_layer_patch(layer_dict) + layers[final_key] = layer + else: + # Diffusers PEFT format + grouped = _group_keys_by_layer(state_dict, extra_suffixes=_PEFT_EXTRA_SUFFIXES) + for layer_key, layer_dict in grouped.items(): + values = _get_lora_layer_values(layer_dict, alpha) + clean_key = layer_key + + # Check for text encoder prefixes + text_encoder_prefixes = [ + "base_model.model.text_encoder.", + "text_encoder.", + ] + + is_text_encoder = False + for prefix in text_encoder_prefixes: + if layer_key.startswith(prefix): + clean_key = layer_key[len(prefix) :] + is_text_encoder = True + break + + # If not text encoder, check transformer prefixes + if not is_text_encoder: + for prefix in [ + "base_model.model.transformer.", + "transformer.", + "diffusion_model.", + ]: + if layer_key.startswith(prefix): + clean_key = layer_key[len(prefix) :] + break + + if is_text_encoder: + final_key = f"{ANIMA_LORA_QWEN3_PREFIX}{clean_key}" + else: + final_key = f"{ANIMA_LORA_TRANSFORMER_PREFIX}{clean_key}" + + layer = _make_layer_patch(values) + layers[final_key] = layer + + return ModelPatchRaw(layers=layers) diff --git a/invokeai/backend/patches/lora_conversions/flux_onetrainer_bfl_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_onetrainer_bfl_lora_conversion_utils.py new file mode 100644 index 0000000000..b2109222a3 --- /dev/null +++ b/invokeai/backend/patches/lora_conversions/flux_onetrainer_bfl_lora_conversion_utils.py @@ -0,0 +1,168 @@ +"""Utilities for detecting and converting FLUX LoRAs in OneTrainer BFL format. + +This format is produced by newer versions of OneTrainer and uses BFL internal key names +(double_blocks, single_blocks, img_attn, etc.) with a 'transformer.' prefix and +InvokeAI-native LoRA suffixes (lora_down.weight, lora_up.weight, alpha). + +Unlike the standard BFL PEFT format (which uses 'diffusion_model.' 
prefix and lora_A/lora_B), +this format also has split QKV projections: + - double_blocks.{i}.img_attn.qkv.{0,1,2} (Q, K, V separate) + - double_blocks.{i}.txt_attn.qkv.{0,1,2} (Q, K, V separate) + - single_blocks.{i}.linear1.{0,1,2,3} (Q, K, V, MLP separate) + +Example keys: + transformer.double_blocks.0.img_attn.qkv.0.lora_down.weight + transformer.double_blocks.0.img_attn.qkv.0.lora_up.weight + transformer.double_blocks.0.img_attn.qkv.0.alpha + transformer.single_blocks.0.linear1.3.lora_down.weight + transformer.double_blocks.0.img_mlp.0.lora_down.weight +""" + +import re +from typing import Any, Dict + +import torch + +from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch +from invokeai.backend.patches.layers.merged_layer_patch import MergedLayerPatch, Range +from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict +from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX +from invokeai.backend.patches.model_patch_raw import ModelPatchRaw + +_TRANSFORMER_PREFIX = "transformer." + +# Valid LoRA weight suffixes in this format. +_LORA_SUFFIXES = ("lora_down.weight", "lora_up.weight", "alpha") + +# Regex to detect split QKV keys in double blocks: e.g. "double_blocks.0.img_attn.qkv.1" +_SPLIT_QKV_RE = re.compile(r"^(double_blocks\.\d+\.(img_attn|txt_attn)\.qkv)\.\d+$") + +# Regex to detect split linear1 keys in single blocks: e.g. "single_blocks.0.linear1.2" +_SPLIT_LINEAR1_RE = re.compile(r"^(single_blocks\.\d+\.linear1)\.\d+$") + + +def is_state_dict_likely_in_flux_onetrainer_bfl_format( + state_dict: dict[str | int, Any], + metadata: dict[str, Any] | None = None, +) -> bool: + """Checks if the provided state dict is likely in the OneTrainer BFL FLUX LoRA format. + + This format uses BFL internal key names with 'transformer.' prefix and split QKV projections. + """ + str_keys = [k for k in state_dict.keys() if isinstance(k, str)] + if not str_keys: + return False + + # All keys must start with 'transformer.' + if not all(k.startswith(_TRANSFORMER_PREFIX) for k in str_keys): + return False + + # All keys must end with recognized LoRA suffixes. + if not all(k.endswith(_LORA_SUFFIXES) for k in str_keys): + return False + + # Must have BFL block structure (double_blocks or single_blocks) under transformer prefix. + has_bfl_blocks = any( + k.startswith("transformer.double_blocks.") or k.startswith("transformer.single_blocks.") for k in str_keys + ) + if not has_bfl_blocks: + return False + + # Must have split QKV pattern (qkv.0, qkv.1, qkv.2) to distinguish from other formats + # that might use transformer. prefix in the future. + has_split_qkv = any(".qkv.0." in k or ".qkv.1." in k or ".qkv.2." in k or ".linear1.0." in k for k in str_keys) + if not has_split_qkv: + return False + + return True + + +def _split_key(key: str) -> tuple[str, str]: + """Split a key into (layer_name, weight_suffix). + + Handles: + - 2-component suffixes ending with '.weight': e.g., 'lora_down.weight' → split at 2nd-to-last dot + - 1-component suffixes: e.g., 'alpha' → split at last dot + """ + if key.endswith(".weight"): + parts = key.rsplit(".", maxsplit=2) + return parts[0], f"{parts[1]}.{parts[2]}" + else: + parts = key.rsplit(".", maxsplit=1) + return parts[0], parts[1] + + +def lora_model_from_flux_onetrainer_bfl_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw: + """Convert a OneTrainer BFL format FLUX LoRA state dict to a ModelPatchRaw. + + Strips the 'transformer.' 
prefix, groups by layer, and merges split QKV/linear1 + layers into MergedLayerPatch instances. + """ + # Step 1: Strip prefix and group by layer name. + grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {} + for key, value in state_dict.items(): + if not isinstance(key, str): + continue + + # Strip 'transformer.' prefix. + key = key[len(_TRANSFORMER_PREFIX) :] + + layer_name, suffix = _split_key(key) + + if layer_name not in grouped_state_dict: + grouped_state_dict[layer_name] = {} + grouped_state_dict[layer_name][suffix] = value + + # Step 2: Build LoRA layers, merging split QKV and linear1. + layers: dict[str, BaseLayerPatch] = {} + + # Identify which layers need merging. + merge_groups: dict[str, list[str]] = {} + standalone_keys: list[str] = [] + + for layer_key in grouped_state_dict: + qkv_match = _SPLIT_QKV_RE.match(layer_key) + linear1_match = _SPLIT_LINEAR1_RE.match(layer_key) + + if qkv_match: + parent = qkv_match.group(1) + if parent not in merge_groups: + merge_groups[parent] = [] + merge_groups[parent].append(layer_key) + elif linear1_match: + parent = linear1_match.group(1) + if parent not in merge_groups: + merge_groups[parent] = [] + merge_groups[parent].append(layer_key) + else: + standalone_keys.append(layer_key) + + # Process standalone layers. + for layer_key in standalone_keys: + layer_sd = grouped_state_dict[layer_key] + layers[f"{FLUX_LORA_TRANSFORMER_PREFIX}{layer_key}"] = any_lora_layer_from_state_dict(layer_sd) + + # Process merged layers. + for parent_key, sub_keys in merge_groups.items(): + # Sort by the numeric index at the end (e.g., qkv.0, qkv.1, qkv.2). + sub_keys.sort(key=lambda k: int(k.rsplit(".", maxsplit=1)[1])) + + sub_layers: list[BaseLayerPatch] = [] + sub_ranges: list[Range] = [] + dim_0_offset = 0 + + for sub_key in sub_keys: + layer_sd = grouped_state_dict[sub_key] + sub_layer = any_lora_layer_from_state_dict(layer_sd) + + # Determine the output dimension from the up weight shape. 
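+            # e.g., for FLUX.1 dev (hidden_size 3072), each double-block qkv
+            # sub-layer contributes 3072 output dims, so the merged ranges
+            # would be [0, 3072), [3072, 6144), [6144, 9216) (illustrative).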
+ up_weight = layer_sd["lora_up.weight"] + out_dim = up_weight.shape[0] + + sub_layers.append(sub_layer) + sub_ranges.append(Range(dim_0_offset, dim_0_offset + out_dim)) + dim_0_offset += out_dim + + layers[f"{FLUX_LORA_TRANSFORMER_PREFIX}{parent_key}"] = MergedLayerPatch(sub_layers, sub_ranges) + + return ModelPatchRaw(layers=layers) diff --git a/invokeai/backend/patches/lora_conversions/formats.py b/invokeai/backend/patches/lora_conversions/formats.py index 0b316602fc..b3e00c288b 100644 --- a/invokeai/backend/patches/lora_conversions/formats.py +++ b/invokeai/backend/patches/lora_conversions/formats.py @@ -14,6 +14,9 @@ from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_ut from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import ( is_state_dict_likely_in_flux_kohya_format, ) +from invokeai.backend.patches.lora_conversions.flux_onetrainer_bfl_lora_conversion_utils import ( + is_state_dict_likely_in_flux_onetrainer_bfl_format, +) from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import ( is_state_dict_likely_in_flux_onetrainer_format, ) @@ -28,6 +31,8 @@ def flux_format_from_state_dict( ) -> FluxLoRAFormat | None: if is_state_dict_likely_in_flux_kohya_format(state_dict): return FluxLoRAFormat.Kohya + elif is_state_dict_likely_in_flux_onetrainer_bfl_format(state_dict, metadata): + return FluxLoRAFormat.OneTrainerBfl elif is_state_dict_likely_in_flux_onetrainer_format(state_dict): return FluxLoRAFormat.OneTrainer elif is_state_dict_likely_in_flux_diffusers_format(state_dict): diff --git a/invokeai/backend/patches/lora_conversions/qwen_image_lora_constants.py b/invokeai/backend/patches/lora_conversions/qwen_image_lora_constants.py new file mode 100644 index 0000000000..727ee5a428 --- /dev/null +++ b/invokeai/backend/patches/lora_conversions/qwen_image_lora_constants.py @@ -0,0 +1,5 @@ +# Qwen Image Edit LoRA prefix constants +# These prefixes are used for key mapping when applying LoRA patches to Qwen Image Edit models + +# Prefix for Qwen Image Edit transformer LoRA layers +QWEN_IMAGE_EDIT_LORA_TRANSFORMER_PREFIX = "lora_transformer-" diff --git a/invokeai/backend/patches/lora_conversions/qwen_image_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/qwen_image_lora_conversion_utils.py new file mode 100644 index 0000000000..7fc01f7231 --- /dev/null +++ b/invokeai/backend/patches/lora_conversions/qwen_image_lora_conversion_utils.py @@ -0,0 +1,197 @@ +"""Qwen Image LoRA conversion utilities. + +Qwen Image uses QwenImageTransformer2DModel architecture. +Supports multiple LoRA formats: +- Diffusers/PEFT: transformer_blocks.0.attn.to_k.lora_down.weight +- With prefix: transformer.transformer_blocks.0.attn.to_k.lora_down.weight +- Kohya: lora_unet_transformer_blocks_0_attn_to_k.lora_down.weight (underscores instead of dots) +""" + +import re +from typing import Dict + +import torch + +from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch +from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict +from invokeai.backend.patches.lora_conversions.qwen_image_lora_constants import ( + QWEN_IMAGE_EDIT_LORA_TRANSFORMER_PREFIX, +) +from invokeai.backend.patches.model_patch_raw import ModelPatchRaw + +# Regex for Kohya-format Qwen Image LoRA keys. 
+# Example: lora_unet_transformer_blocks_0_attn_to_k +# Groups: (block_idx, sub_module_with_underscores) +_KOHYA_KEY_REGEX = re.compile(r"lora_unet_transformer_blocks_(\d+)_(.*)") + +# Mapping from Kohya underscore-separated sub-module names to dot-separated model paths. +# The Kohya format uses underscores everywhere, but some underscores are part of the +# module name (e.g., add_k_proj, to_out). We match the longest prefix first. +_KOHYA_MODULE_MAP: list[tuple[str, str]] = [ + # Attention projections + ("attn_add_k_proj", "attn.add_k_proj"), + ("attn_add_q_proj", "attn.add_q_proj"), + ("attn_add_v_proj", "attn.add_v_proj"), + ("attn_to_add_out", "attn.to_add_out"), + ("attn_to_out_0", "attn.to_out.0"), + ("attn_to_k", "attn.to_k"), + ("attn_to_q", "attn.to_q"), + ("attn_to_v", "attn.to_v"), + # Image stream MLP and modulation + ("img_mlp_net_0_proj", "img_mlp.net.0.proj"), + ("img_mlp_net_2", "img_mlp.net.2"), + ("img_mod_1", "img_mod.1"), + # Text stream MLP and modulation + ("txt_mlp_net_0_proj", "txt_mlp.net.0.proj"), + ("txt_mlp_net_2", "txt_mlp.net.2"), + ("txt_mod_1", "txt_mod.1"), +] + + +def is_state_dict_likely_kohya_qwen_image(state_dict: dict[str | int, torch.Tensor]) -> bool: + """Check if the state dict uses Kohya-format Qwen Image LoRA keys.""" + str_keys = [k for k in state_dict.keys() if isinstance(k, str)] + if not str_keys: + return False + # Check if any key matches the Kohya pattern + return any(k.startswith("lora_unet_transformer_blocks_") for k in str_keys) + + +def _convert_kohya_key(kohya_layer: str) -> str | None: + """Convert a Kohya-format layer name to a dot-separated model module path. + + Example: lora_unet_transformer_blocks_0_attn_to_k -> transformer_blocks.0.attn.to_k + """ + m = _KOHYA_KEY_REGEX.match(kohya_layer) + if not m: + return None + + block_idx = m.group(1) + sub_module = m.group(2) + + for kohya_name, model_path in _KOHYA_MODULE_MAP: + if sub_module == kohya_name: + return f"transformer_blocks.{block_idx}.{model_path}" + + # Fallback: unknown sub-module, return None so caller can warn/skip + return None + + +def lora_model_from_qwen_image_state_dict( + state_dict: Dict[str, torch.Tensor], alpha: float | None = None +) -> ModelPatchRaw: + """Convert a Qwen Image LoRA state dict to a ModelPatchRaw. + + Handles three key formats: + - Diffusers/PEFT: transformer_blocks.0.attn.to_k.lora_down.weight + - With prefix: transformer.transformer_blocks.0.attn.to_k.lora_down.weight + - Kohya: lora_unet_transformer_blocks_0_attn_to_k.lora_down.weight + """ + is_kohya = is_state_dict_likely_kohya_qwen_image(state_dict) + + if is_kohya: + return _convert_kohya_format(state_dict, alpha) + else: + return _convert_diffusers_format(state_dict, alpha) + + +def _convert_kohya_format(state_dict: Dict[str, torch.Tensor], alpha: float | None) -> ModelPatchRaw: + """Convert Kohya-format state dict. 
Keys are like lora_unet_transformer_blocks_0_attn_to_k.lokr_w1""" + layers: dict[str, BaseLayerPatch] = {} + + # Group by layer (split at first dot: layer_name.param_name) + grouped: dict[str, dict[str, torch.Tensor]] = {} + for key, value in state_dict.items(): + if not isinstance(key, str): + continue + layer_name, param_name = key.split(".", 1) + if layer_name not in grouped: + grouped[layer_name] = {} + grouped[layer_name][param_name] = value + + for kohya_layer, layer_dict in grouped.items(): + model_path = _convert_kohya_key(kohya_layer) + if model_path is None: + continue # Skip unrecognized layers + + layer = any_lora_layer_from_state_dict(layer_dict) + final_key = f"{QWEN_IMAGE_EDIT_LORA_TRANSFORMER_PREFIX}{model_path}" + layers[final_key] = layer + + return ModelPatchRaw(layers=layers) + + +def _convert_diffusers_format(state_dict: Dict[str, torch.Tensor], alpha: float | None) -> ModelPatchRaw: + """Convert Diffusers/PEFT format state dict.""" + layers: dict[str, BaseLayerPatch] = {} + + # Some LoRAs use a "transformer." prefix on keys + strip_prefixes = ["transformer."] + + grouped = _group_by_layer(state_dict) + + for layer_key, layer_dict in grouped.items(): + values = _normalize_lora_keys(layer_dict, alpha) + layer = any_lora_layer_from_state_dict(values) + clean_key = layer_key + for prefix in strip_prefixes: + if clean_key.startswith(prefix): + clean_key = clean_key[len(prefix) :] + break + final_key = f"{QWEN_IMAGE_EDIT_LORA_TRANSFORMER_PREFIX}{clean_key}" + layers[final_key] = layer + + return ModelPatchRaw(layers=layers) + + +def _normalize_lora_keys(layer_dict: dict[str, torch.Tensor], alpha: float | None) -> dict[str, torch.Tensor]: + """Normalize LoRA key names to internal format.""" + if "lora_A.weight" in layer_dict: + values: dict[str, torch.Tensor] = { + "lora_down.weight": layer_dict["lora_A.weight"], + "lora_up.weight": layer_dict["lora_B.weight"], + } + if alpha is not None: + values["alpha"] = torch.tensor(alpha) + return values + elif "lora_down.weight" in layer_dict: + return layer_dict + else: + return layer_dict + + +def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]: + """Group state dict keys by layer path.""" + layer_dict: dict[str, dict[str, torch.Tensor]] = {} + + known_suffixes = [ + ".lora_A.weight", + ".lora_B.weight", + ".lora_down.weight", + ".lora_up.weight", + ".dora_scale", + ".alpha", + ] + + for key in state_dict: + if not isinstance(key, str): + continue + + layer_name = None + key_name = None + for suffix in known_suffixes: + if key.endswith(suffix): + layer_name = key[: -len(suffix)] + key_name = suffix[1:] + break + + if layer_name is None: + parts = key.rsplit(".", maxsplit=2) + layer_name = parts[0] + key_name = ".".join(parts[1:]) + + if layer_name not in layer_dict: + layer_dict[layer_name] = {} + layer_dict[layer_name][key_name] = state_dict[key] + + return layer_dict diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index de5253f073..054e04dcb2 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -17,7 +17,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionS from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin from diffusers.utils.import_utils import is_xformers_available from pydantic import Field -from transformers import CLIPFeatureExtractor, CLIPTextModel, 
CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from invokeai.app.services.config.config_default import get_config from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData @@ -139,7 +139,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ @@ -151,7 +151,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: Optional[StableDiffusionSafetyChecker], - feature_extractor: Optional[CLIPFeatureExtractor], + feature_extractor: Optional[CLIPImageProcessor], requires_safety_checker: bool = False, ): super().__init__( diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 9d1bd67617..6a9959f1e8 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -88,6 +88,48 @@ class ZImageConditioningInfo: return self +@dataclass +class QwenImageConditioningInfo: + """Qwen Image Edit conditioning information from Qwen2.5-VL encoder.""" + + prompt_embeds: torch.Tensor + """Text/image embeddings from Qwen2.5-VL encoder. Shape: (batch_size, seq_len, hidden_size).""" + + prompt_embeds_mask: torch.Tensor | None = None + """Attention mask for prompt_embeds. Shape: (batch_size, seq_len). 1 for valid, 0 for padding.""" + + def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None): + self.prompt_embeds = self.prompt_embeds.to(device=device, dtype=dtype) + if self.prompt_embeds_mask is not None: + self.prompt_embeds_mask = self.prompt_embeds_mask.to(device=device) + return self + + +@dataclass +class AnimaConditioningInfo: + """Anima text conditioning information from Qwen3 0.6B encoder + T5-XXL tokenizer. + + Anima uses a dual-conditioning scheme where Qwen3 hidden states are combined + with T5-XXL token IDs inside the LLM Adapter (part of the transformer). + """ + + qwen3_embeds: torch.Tensor + """Qwen3 0.6B hidden states. Shape: (seq_len, hidden_size) where hidden_size=1024.""" + + t5xxl_ids: torch.Tensor + """T5-XXL token IDs. Shape: (seq_len,).""" + + t5xxl_weights: Optional[torch.Tensor] = None + """Per-token weights for prompt weighting. Shape: (seq_len,). 
None means uniform weight.""" + + def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None): + self.qwen3_embeds = self.qwen3_embeds.to(device=device, dtype=dtype) + self.t5xxl_ids = self.t5xxl_ids.to(device=device) + if self.t5xxl_weights is not None: + self.t5xxl_weights = self.t5xxl_weights.to(device=device, dtype=dtype) + return self + + @dataclass class ConditioningFieldData: # If you change this class, adding more types, you _must_ update the instantiation of ObjectSerializerDisk in @@ -100,6 +142,8 @@ class ConditioningFieldData: | List[SD3ConditioningInfo] | List[CogView4ConditioningInfo] | List[ZImageConditioningInfo] + | List[QwenImageConditioningInfo] + | List[AnimaConditioningInfo] ) diff --git a/invokeai/frontend/web/openapi.json b/invokeai/frontend/web/openapi.json index af8476528d..19e5a3a68e 100644 --- a/invokeai/frontend/web/openapi.json +++ b/invokeai/frontend/web/openapi.json @@ -6463,6 +6463,23 @@ "title": "Has Been Opened" }, "description": "Whether to include/exclude recent workflows" + }, + { + "name": "is_public", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Is Public" + }, + "description": "Filter by public/shared status" } ], "responses": { @@ -6655,6 +6672,23 @@ "title": "Categories" }, "description": "The categories to include" + }, + { + "name": "is_public", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Is Public" + }, + "description": "Filter by public/shared status" } ], "responses": { @@ -6744,6 +6778,23 @@ "title": "Has Been Opened" }, "description": "Whether to include/exclude recent workflows" + }, + { + "name": "is_public", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Is Public" + }, + "description": "Filter by public/shared status" } ], "responses": { @@ -6812,6 +6863,23 @@ "title": "Has Been Opened" }, "description": "Whether to include/exclude recent workflows" + }, + { + "name": "is_public", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Is Public" + }, + "description": "Filter by public/shared status" } ], "responses": { @@ -7352,6 +7420,67 @@ } } } + }, + "/api/v1/workflows/i/{workflow_id}/is_public": { + "patch": { + "tags": ["workflows"], + "summary": "Update Workflow Is Public", + "description": "Updates whether a workflow is shared publicly", + "operationId": "update_workflow_is_public", + "parameters": [ + { + "name": "workflow_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Workflow Id" + }, + "description": "The workflow to update" + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "properties": { + "is_public": { + "type": "boolean", + "title": "Is Public", + "description": "Whether the workflow should be shared publicly" + } + }, + "type": "object", + "required": ["is_public"], + "title": "Body_update_workflow_is_public" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/WorkflowRecordDTO" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": 
"#/components/schemas/HTTPValidationError" + } + } + } + } + } + } } }, "components": { @@ -59137,10 +59266,20 @@ "workflow": { "$ref": "#/components/schemas/Workflow", "description": "The workflow." + }, + "user_id": { + "type": "string", + "title": "User Id", + "description": "The id of the user who owns this workflow." + }, + "is_public": { + "type": "boolean", + "title": "Is Public", + "description": "Whether this workflow is shared with all users." } }, "type": "object", - "required": ["workflow_id", "name", "created_at", "updated_at", "workflow"], + "required": ["workflow_id", "name", "created_at", "updated_at", "workflow", "user_id", "is_public"], "title": "WorkflowRecordDTO" }, "WorkflowRecordListItemWithThumbnailDTO": { @@ -59222,15 +59361,35 @@ ], "title": "Thumbnail Url", "description": "The URL of the workflow thumbnail." + }, + "user_id": { + "type": "string", + "title": "User Id", + "description": "The id of the user who owns this workflow." + }, + "is_public": { + "type": "boolean", + "title": "Is Public", + "description": "Whether this workflow is shared with all users." } }, "type": "object", - "required": ["workflow_id", "name", "created_at", "updated_at", "description", "category", "tags"], + "required": [ + "workflow_id", + "name", + "created_at", + "updated_at", + "description", + "category", + "tags", + "user_id", + "is_public" + ], "title": "WorkflowRecordListItemWithThumbnailDTO" }, "WorkflowRecordOrderBy": { "type": "string", - "enum": ["created_at", "updated_at", "opened_at", "name"], + "enum": ["created_at", "updated_at", "opened_at", "name", "is_public"], "title": "WorkflowRecordOrderBy", "description": "The order by options for workflow records" }, @@ -59303,10 +59462,20 @@ ], "title": "Thumbnail Url", "description": "The URL of the workflow thumbnail." + }, + "user_id": { + "type": "string", + "title": "User Id", + "description": "The id of the user who owns this workflow." + }, + "is_public": { + "type": "boolean", + "title": "Is Public", + "description": "Whether this workflow is shared with all users." 
} }, "type": "object", - "required": ["workflow_id", "name", "created_at", "updated_at", "workflow"], + "required": ["workflow_id", "name", "created_at", "updated_at", "workflow", "user_id", "is_public"], "title": "WorkflowRecordWithThumbnailDTO" }, "WorkflowWithoutID": { diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json index da4e31142f..e537362801 100644 --- a/invokeai/frontend/web/package.json +++ b/invokeai/frontend/web/package.json @@ -21,7 +21,7 @@ "scripts": { "dev": "vite dev", "dev:host": "vite dev --host", - "build": "pnpm run lint && vite build", + "build": "pnpm run lint && vitest run && vite build", "typegen": "node scripts/typegen.js", "preview": "vite preview", "lint:knip": "knip --tags=-knipignore", @@ -35,6 +35,7 @@ "storybook": "storybook dev -p 6006", "build-storybook": "storybook build", "test": "vitest", + "test:run": "vitest run", "test:ui": "vitest --coverage --ui", "test:no-watch": "vitest --no-watch" }, @@ -65,6 +66,7 @@ "i18next-http-backend": "^3.0.2", "idb-keyval": "6.2.1", "jsondiffpatch": "^0.7.3", + "jszip": "^3.10.1", "konva": "^9.3.22", "linkify-react": "^4.3.1", "linkifyjs": "^4.3.1", diff --git a/invokeai/frontend/web/pnpm-lock.yaml b/invokeai/frontend/web/pnpm-lock.yaml index 3f94ba7d69..6a2ed95ab0 100644 --- a/invokeai/frontend/web/pnpm-lock.yaml +++ b/invokeai/frontend/web/pnpm-lock.yaml @@ -86,6 +86,9 @@ importers: jsondiffpatch: specifier: ^0.7.3 version: 0.7.3 + jszip: + specifier: ^3.10.1 + version: 3.10.1 konva: specifier: ^9.3.22 version: 9.3.22 @@ -2003,6 +2006,9 @@ packages: copy-to-clipboard@3.3.3: resolution: {integrity: sha512-2KV8NhB5JqC3ky0r9PMCAZKbUHSwtEo4CwCs0KXgruG43gX5PMqDEBbVU4OUzw2MuAWUfsuFmWvEKG5QRfSnJA==} + core-util-is@1.0.3: + resolution: {integrity: sha512-ZQBvi1DcpJ4GDqanjucZ2Hj3wEO5pZDS89BWbkcrvdxksJorwUDDZamX9ldFkp9aw2lmBDLgkObEA4DWNJ9FYQ==} + cosmiconfig@7.1.0: resolution: {integrity: sha512-AdmX6xUzdNASswsFtmwSt7Vj8po9IuqXm0UXz7QKPuEUmPB4XyjGfaAr2PSuELMwkRMVH1EpIkX5bTZGRB3eCA==} engines: {node: '>=10'} @@ -2672,6 +2678,9 @@ packages: resolution: {integrity: sha512-Hs59xBNfUIunMFgWAbGX5cq6893IbWg4KnrjbYwX3tx0ztorVgTDA6B2sxf8ejHJ4wz8BqGUMYlnzNBer5NvGg==} engines: {node: '>= 4'} + immediate@3.0.6: + resolution: {integrity: sha512-XXOFtyqDjNDAQxVfYxuF7g9Il/IbWmmlQg2MYKOH8ExIT1qg6xc4zyS3HaEEATgs1btfzxq15ciUiY7gjSXRGQ==} + immer@10.1.1: resolution: {integrity: sha512-s2MPrmjovJcoMaHtx6K11Ra7oD05NT97w1IC5zpMkT6Atjr7H8LjaDd81iIxUYpMKSRRNMJE703M1Fhr/TctHw==} @@ -2825,6 +2834,9 @@ packages: resolution: {integrity: sha512-fKzAra0rGJUUBwGBgNkHZuToZcn+TtXHpeCgmkMJMMYx1sQDYaCSyjJBSCa2nH1DGm7s3n1oBnohoVTBaN7Lww==} engines: {node: '>=8'} + isarray@1.0.0: + resolution: {integrity: sha512-VLghIWNM6ELQzo7zwmcg0NmTVyWKYjvIeM83yjp0wRDTmUnrM678fQbcKBo6n2CJEF0szoG//ytg+TKla89ALQ==} + isarray@2.0.5: resolution: {integrity: sha512-xHjhDr3cNBK0BzdUJSPXZntQUx/mwMS5Rw4A7lPJ90XGAO6ISP/ePDNuo0vhqOZU+UD5JoodwCAAoZQd3FeAKw==} @@ -2916,6 +2928,9 @@ packages: resolution: {integrity: sha512-ZZow9HBI5O6EPgSJLUb8n2NKgmVWTwCvHGwFuJlMjvLFqlGG6pjirPhtdsseaLZjSibD8eegzmYpUZwoIlj2cQ==} engines: {node: '>=4.0'} + jszip@3.10.1: + resolution: {integrity: sha512-xXDvecyTpGLrqFrvkrUSoxxfJI5AH7U8zxxtVclpsUtMCq4JQ290LY8AW5c7Ggnr/Y/oK+bQMbqK2qmtk3pN4g==} + keyv@4.5.4: resolution: {integrity: sha512-oxVHkHR/EJf2CNXnWxRLW6mg7JyCCUcG0DtEGmL2ctUo1PNTin1PUil+r/+4r5MpVgC/fn1kjsx7mjSujKqIpw==} @@ -2934,6 +2949,9 @@ packages: resolution: {integrity: sha512-+bT2uH4E5LGE7h/n3evcS/sQlJXCpIp6ym8OWJ5eV6+67Dsql/LaaT7qJBAt2rzfoa/5QBGBhxDix1dMt2kQKQ==} 
engines: {node: '>= 0.8.0'} + lie@3.3.0: + resolution: {integrity: sha512-UaiMJzeWRlEujzAuw5LokY1L5ecNQYZKfmyZ9L7wDHb/p5etKaxXhohBcrw0EYby+G/NA52vRSN4N39dxHAIwQ==} + lines-and-columns@1.2.4: resolution: {integrity: sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg==} @@ -3210,6 +3228,9 @@ packages: package-json-from-dist@1.0.1: resolution: {integrity: sha512-UEZIS3/by4OC8vL3P2dTXRETpebLI2NiI5vIrjaD/5UtrkFX/tNbwjTSRAGC/+7CAo2pIcBaRgWmcBBHcsaCIw==} + pako@1.0.11: + resolution: {integrity: sha512-4hLB8Py4zZce5s4yd9XzopqwVv/yGNhV1Bl8NTmCq1763HeK2+EwVTv+leGeL13Dnh2wfbqowVPXCIO0z4taYw==} + pako@2.1.0: resolution: {integrity: sha512-w+eufiZ1WuJYgPXbV/PO3NCMEc3xqylkKHzp8bxp1uW4qaSNQUkwmLLEc3kKsfz8lpV1F8Ht3U1Cm+9Srog2ug==} @@ -3298,6 +3319,9 @@ packages: resolution: {integrity: sha512-Qb1gy5OrP5+zDf2Bvnzdl3jsTf1qXVMazbvCoKhtKqVs4/YK4ozX4gKQJJVyNe+cajNPn0KoC0MC3FUmaHWEmQ==} engines: {node: ^10.13.0 || ^12.13.0 || ^14.15.0 || >=15.0.0} + process-nextick-args@2.0.1: + resolution: {integrity: sha512-3ouUOpQhtgrbOa17J7+uxOTpITYWaGP7/AhoR3+A+/1e9skrzelGi/dXzEYyvbxubEF6Wn2ypscTKiKJFFn1ag==} + prop-types@15.8.1: resolution: {integrity: sha512-oj87CgZICdulUohogVAR7AjlC0327U4el4L6eAvOqCeudMDVU0NThNaV+b9Df4dXgSP1gXMTnPdhfe/2qDH5cg==} @@ -3539,6 +3563,9 @@ packages: resolution: {integrity: sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==} engines: {node: '>=0.10.0'} + readable-stream@2.3.8: + resolution: {integrity: sha512-8p0AUk4XODgIewSi0l8Epjs+EVnWiK7NoDIEGU0HhE7+ZyY8D1IMY7odu5lRrFXGg71L15KG8QrPmum45RTtdA==} + readable-stream@3.6.2: resolution: {integrity: sha512-9u/sniCrY3D5WdsERHzHE4G2YCXqoG5FTHUiCC4SIbr6XcLZBY05ya9EKjYek9O5xOAwjGq+1JdGBAS7Q9ScoA==} engines: {node: '>= 6'} @@ -3661,6 +3688,9 @@ packages: resolution: {integrity: sha512-AURm5f0jYEOydBj7VQlVvDrjeFgthDdEF5H1dP+6mNpoXOMo1quQqJ4wvJDyRZ9+pO3kGWoOdmV08cSv2aJV6Q==} engines: {node: '>=0.4'} + safe-buffer@5.1.2: + resolution: {integrity: sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g==} + safe-buffer@5.2.1: resolution: {integrity: sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==} @@ -3718,6 +3748,9 @@ packages: resolution: {integrity: sha512-RJRdvCo6IAnPdsvP/7m6bsQqNnn1FCBX5ZNtFL98MmFF/4xAIJTIg1YbHW5DC2W5SKZanrC6i4HsJqlajw/dZw==} engines: {node: '>= 0.4'} + setimmediate@1.0.5: + resolution: {integrity: sha512-MATJdZp8sLqDl/68LfQmbP8zKPLQNV6BIZoIgrscFDQ+RsvK/BxeDQOgyxKKoh0y/8h3BqVFnCqQ/gd+reiIXA==} + shebang-command@2.0.0: resolution: {integrity: sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==} engines: {node: '>=8'} @@ -3857,6 +3890,9 @@ packages: resolution: {integrity: sha512-UXSH262CSZY1tfu3G3Secr6uGLCFVPMhIqHjlgCUtCCcgihYc/xKs9djMTMUOb2j1mVSeU8EU6NWc/iQKU6Gfg==} engines: {node: '>= 0.4'} + string_decoder@1.1.1: + resolution: {integrity: sha512-n/ShnvDi6FHbbVfviro+WojiFzv+s8MPMHBczVePfUpDJLwoLT0ht1l4YwBCbi8pJAveEEdnkHyPyTP/mzRfwg==} + string_decoder@1.3.0: resolution: {integrity: sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA==} @@ -6153,6 +6189,8 @@ snapshots: dependencies: toggle-selection: 1.0.6 + core-util-is@1.0.3: {} + cosmiconfig@7.1.0: dependencies: '@types/parse-json': 4.0.2 @@ -6957,6 +6995,8 @@ snapshots: ignore@7.0.5: {} + immediate@3.0.6: {} + immer@10.1.1: {} import-fresh@3.3.1: @@ -7103,6 +7143,8 @@ snapshots: dependencies: is-docker: 2.2.1 + isarray@1.0.0: 
{} + isarray@2.0.5: {} isexe@2.0.0: {} @@ -7192,6 +7234,13 @@ snapshots: object.assign: 4.1.7 object.values: 1.2.1 + jszip@3.10.1: + dependencies: + lie: 3.3.0 + pako: 1.0.11 + readable-stream: 2.3.8 + setimmediate: 1.0.5 + keyv@4.5.4: dependencies: json-buffer: 3.0.1 @@ -7221,6 +7270,10 @@ snapshots: prelude-ls: 1.2.1 type-check: 0.4.0 + lie@3.3.0: + dependencies: + immediate: 3.0.6 + lines-and-columns@1.2.4: {} linkify-react@4.3.1(linkifyjs@4.3.1)(react@18.3.1): @@ -7510,6 +7563,8 @@ snapshots: package-json-from-dist@1.0.1: {} + pako@1.0.11: {} + pako@2.1.0: {} parent-module@1.0.1: @@ -7578,6 +7633,8 @@ snapshots: ansi-styles: 5.2.0 react-is: 17.0.2 + process-nextick-args@2.0.1: {} + prop-types@15.8.1: dependencies: loose-envify: 1.4.0 @@ -7843,6 +7900,16 @@ snapshots: dependencies: loose-envify: 1.4.0 + readable-stream@2.3.8: + dependencies: + core-util-is: 1.0.3 + inherits: 2.0.4 + isarray: 1.0.0 + process-nextick-args: 2.0.1 + safe-buffer: 5.1.2 + string_decoder: 1.1.1 + util-deprecate: 1.0.2 + readable-stream@3.6.2: dependencies: inherits: 2.0.4 @@ -7994,6 +8061,8 @@ snapshots: has-symbols: 1.1.0 isarray: 2.0.5 + safe-buffer@5.1.2: {} + safe-buffer@5.2.1: {} safe-push-apply@1.0.0: @@ -8051,6 +8120,8 @@ snapshots: es-errors: 1.3.0 es-object-atoms: 1.1.1 + setimmediate@1.0.5: {} + shebang-command@2.0.0: dependencies: shebang-regex: 3.0.0 @@ -8236,6 +8307,10 @@ snapshots: define-properties: 1.2.1 es-object-atoms: 1.1.1 + string_decoder@1.1.1: + dependencies: + safe-buffer: 5.1.2 + string_decoder@1.3.0: dependencies: safe-buffer: 5.2.1 diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index c961b8e8fc..a0bf23843e 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -25,7 +25,8 @@ "rememberMe": "Remember me for 7 days", "signIn": "Sign In", "signingIn": "Signing in...", - "loginFailed": "Login failed. Please check your credentials." + "loginFailed": "Login failed. Please check your credentials.", + "sessionExpired": "Your credentials have expired. Please log in again to resume." 
}, "setup": { "title": "Welcome to InvokeAI", @@ -160,7 +161,17 @@ "imagesWithCount_other": "{{count}} images", "assetsWithCount_one": "{{count}} asset", "assetsWithCount_other": "{{count}} assets", - "updateBoardError": "Error updating board" + "updateBoardError": "Error updating board", + "setBoardVisibility": "Set Board Visibility", + "setVisibilityPrivate": "Set Private", + "setVisibilityShared": "Set Shared", + "setVisibilityPublic": "Set Public", + "visibilityPrivate": "Private", + "visibilityShared": "Shared", + "visibilityPublic": "Public", + "visibilityBadgeShared": "Shared board", + "visibilityBadgePublic": "Public board", + "updateBoardVisibilityError": "Error updating board visibility" }, "accordions": { "generation": { @@ -201,6 +212,7 @@ "copy": "Copy", "copyError": "$t(gallery.copy) Error", "clipboard": "Clipboard", + "collapseAll": "Collapse All", "crop": "Crop", "on": "On", "off": "Off", @@ -223,57 +235,76 @@ "discordLabel": "Discord", "dontAskMeAgain": "Don't ask me again", "dontShowMeThese": "Don't show me these", + "editName": "Edit name", "editor": "Editor", "error": "Error", "error_withCount_one": "{{count}} error", "error_withCount_other": "{{count}} errors", + "expandAll": "Expand All", "model_withCount_one": "{{count}} model", "model_withCount_other": "{{count}} models", "file": "File", + "fitView": "Fit View", "folder": "Folder", "format": "format", "githubLabel": "Github", "goTo": "Go to", "hotkeysLabel": "Hotkeys", - "loadingImage": "Loading Image", - "loadingModel": "Loading Model", + "hex": "Hex", "imageFailedToLoad": "Unable to Load Image", "img2img": "Image To Image", "inpaint": "inpaint", "input": "Input", "installed": "Installed", + "json": "JSON", "languagePickerLabel": "Language", "linear": "Linear", "load": "Load", "loading": "Loading", + "loadingImage": "Loading Image", + "loadingModel": "Loading Model", "localSystem": "Local System", + "minimize": "Minimize", + "next": "Next", + "noMatchingItems": "No matching items", + "notifications": "Notifications", "learnMore": "Learn More", "modelManager": "Model Manager", "noMatches": "No matches", "noOptions": "No options", "nodes": "Workflows", "notInstalled": "Not $t(common.installed)", + "openSlider": "Open slider", "openInNewTab": "Open in New Tab", "openInViewer": "Open in Viewer", "orderBy": "Order By", "outpaint": "outpaint", "outputs": "Outputs", "postprocessing": "Post Processing", + "previous": "Previous", "random": "Random", + "removeFromCollection": "Remove from Collection", "reportBugLabel": "Report Bug", + "resetView": "Reset View", "safetensors": "Safetensors", "save": "Save", "saveAs": "Save As", "saveChanges": "Save Changes", + "saveToAssets": "Save to Assets", + "settings": "Settings", "settingsLabel": "Settings", "simple": "Simple", "somethingWentWrong": "Something went wrong", "statusDisconnected": "Disconnected", "template": "Template", + "toggleRgbHex": "Toggle RGB/HEX", "toResolve": "To resolve", "txt2img": "Text To Image", "unknown": "Unknown", + "unpin": "Unpin", "upload": "Upload", + "zoomIn": "Zoom In", + "zoomOut": "Zoom Out", "updated": "Updated", "created": "Created", "prevPage": "Previous Page", @@ -343,13 +374,19 @@ "discard": "Discard", "noPromptHistory": "No prompt history recorded.", "noMatchingPrompts": "No matching prompts in history.", - "toSwitchBetweenPrompts": "to switch between prompts." 
+ "toSwitchBetweenPrompts": "to switch between prompts.", + "promptHistory": "Prompt History", + "clearHistory": "Clear History", + "usePrompt": "Use prompt", + "searchPrompts": "Search..." }, "queue": { "queue": "Queue", "queueFront": "Add to Front of Queue", "queueBack": "Add to Queue", + "queueActionsMenu": "Queue Actions Menu", "queueEmpty": "Queue Empty", + "queueItem": "Queue Item", "enqueueing": "Queueing Batch", "resume": "Resume", "resumeTooltip": "Resume Processor", @@ -492,7 +529,10 @@ "imagesSettings": "Gallery Images Settings", "jump": "Jump", "loading": "Loading", + "loadingGallery": "Loading gallery...", + "loadingMetadata": "Loading metadata...", "newestFirst": "Newest First", + "noImagesFound": "No images found", "oldestFirst": "Oldest First", "sortDirection": "Sort Direction", "showStarredImagesFirst": "Show Starred Images First", @@ -504,6 +544,8 @@ "unableToLoad": "Unable to load Gallery", "deleteSelection": "Delete Selection", "downloadSelection": "Download Selection", + "bulkDownloadReady": "Download ready", + "clickToDownload": "Click here to download", "bulkDownloadRequested": "Preparing Download", "bulkDownloadRequestedDesc": "Your download request is being prepared. This may take a few moments.", "bulkDownloadRequestFailed": "Problem Preparing Download", @@ -675,6 +717,10 @@ "title": "Rect Tool", "desc": "Select the rect tool." }, + "selectLassoTool": { + "title": "Lasso Tool", + "desc": "Select the lasso tool." + }, "selectViewTool": { "title": "View Tool", "desc": "Select the view tool." @@ -935,7 +981,8 @@ } }, "lora": { - "weight": "Weight" + "weight": "Weight", + "removeLoRA": "Remove LoRA" }, "metadata": { "allPrompts": "All Prompts", @@ -970,6 +1017,7 @@ "seedVarianceEnabled": "Seed Variance Enabled", "seedVarianceStrength": "Seed Variance Strength", "seedVarianceRandomizePercent": "Seed Variance Randomize %", + "zImageShift": "Z-Image Shift", "seed": "Seed", "steps": "Steps", "strength": "Image to image strength", @@ -982,6 +1030,14 @@ "modelManager": { "active": "active", "actions": "Bulk Actions", + "deleteModelsConfirm": "Are you sure you want to delete {{count}} model(s)? This action cannot be undone.", + "deleteWarning": "Models in your Invoke models directory will be permanently deleted from disk.", + "modelsDeleted": "Successfully deleted {{count}} model(s)", + "modelsDeleteFailed": "Failed to delete models", + "someModelsFailedToDelete": "{{count}} model(s) could not be deleted", + "modelsDeletedPartial": "Partially completed", + "someModelsDeleted": "{{deleted}} deleted, {{failed}} failed", + "modelsDeleteError": "Error deleting models", "pause": "Pause", "pauseAll": "Pause All", "pauseAllTooltip": "Pause all active downloads", @@ -1012,6 +1068,15 @@ "reidentifySuccess": "Model reidentified successfully", "reidentifyUnknown": "Unable to identify model", "reidentifyError": "Error reidentifying model", + "reidentifyModels": "Reidentify Models", + "reidentifyModelsConfirm": "Are you sure you want to reidentify {{count}} model(s)? 
This will re-probe their weights files to determine the correct format and settings.", + "reidentifyWarning": "This will reset any custom settings you may have applied to these models.", + "modelsReidentified": "Successfully reidentified {{count}} model(s)", + "modelsReidentifyFailed": "Failed to reidentify models", + "someModelsFailedToReidentify": "{{count}} model(s) could not be reidentified", + "modelsReidentifiedPartial": "Partially completed", + "someModelsReidentified": "{{succeeded}} reidentified, {{failed}} failed", + "modelsReidentifyError": "Error reidentifying models", "updatePath": "Update Path", "updatePathTooltip": "Update the file path for this model if you have moved the model files to a new location.", "updatePathDescription": "Enter the new path to the model file or directory. Use this if you have manually moved the model files on disk.", @@ -1150,7 +1215,9 @@ "numImages": "Num Images", "modelPickerFallbackNoModelsInstalled": "No models installed.", "modelPickerFallbackNoModelsInstalled2": "Visit the Model Manager to install models.", + "modelPickerFallbackNoModelsInstalledNonAdmin": "No models installed. Ask your InvokeAI administrator () to install some models.", "noModelsInstalledDesc1": "Install models with the", + "noModelsInstalledAskAdmin": "Ask your administrator to install some.", "noModelSelected": "No Model Selected", "noMatchingModels": "No matching models", "noModelsInstalled": "No models installed", @@ -1224,11 +1291,18 @@ "triggerPhrases": "Trigger Phrases", "loraTriggerPhrases": "LoRA Trigger Phrases", "mainModelTriggerPhrases": "Main Model Trigger Phrases", + "queueEmpty": "The install queue is empty.", "selectAll": "Select All", "selectModelToView": "Select a model to view its details", "typePhraseHere": "Type phrase here", "t5Encoder": "T5 Encoder", "qwen3Encoder": "Qwen3 Encoder", + "animaVae": "VAE", + "animaVaePlaceholder": "Select Anima-compatible VAE", + "animaQwen3Encoder": "Qwen3 0.6B Encoder", + "animaQwen3EncoderPlaceholder": "Select Qwen3 0.6B encoder", + "animaT5Encoder": "T5-XXL Encoder", + "animaT5EncoderPlaceholder": "Select T5-XXL encoder", "zImageVae": "VAE (optional)", "zImageVaePlaceholder": "From VAE source model", "zImageQwen3Encoder": "Qwen3 Encoder (optional)", @@ -1239,6 +1313,12 @@ "flux2KleinVaePlaceholder": "From main model", "flux2KleinQwen3Encoder": "Qwen3 Encoder (optional)", "flux2KleinQwen3EncoderPlaceholder": "From main model", + "qwenImageComponentSource": "VAE/Encoder Source (Diffusers)", + "qwenImageComponentSourcePlaceholder": "Required for GGUF models", + "qwenImageQuantization": "Encoder Quantization", + "qwenImageQuantizationNone": "None (bf16)", + "qwenImageQuantizationInt8": "8-bit (int8)", + "qwenImageQuantizationNf4": "4-bit (nf4)", "upcastAttention": "Upcast Attention", "uploadImage": "Upload Image", "urlOrLocalPath": "URL or Local Path", @@ -1337,6 +1417,8 @@ "fullyContainNodesHelp": "Nodes must be fully inside the selection box to be selected", "showEdgeLabels": "Show Edge Labels", "showEdgeLabelsHelp": "Show labels on edges, indicating the connected nodes", + "groupNodesByCategory": "Group Nodes by Category", + "groupNodesByCategoryHelp": "Group nodes by category in the add node dialog", "hideLegendNodes": "Hide Field Type Legend", "hideMinimapnodes": "Hide MiniMap", "inputMayOnlyHaveOneConnection": "Input may only have one connection", @@ -1347,6 +1429,7 @@ "noWorkflows": "No Workflows", "noMatchingWorkflows": "No Matching Workflows", "noWorkflow": "No Workflow", + "noWorkflowToSave": "No workflow to 
save", "unableToUpdateNode": "Node update failed: node {{node}} of type {{type}} (may require deleting and recreating)", "mismatchedVersion": "Invalid node: node {{node}} of type {{type}} has mismatched version (try updating?)", "missingTemplate": "Invalid node: node {{node}} of type {{type}} missing template (not installed?)", @@ -1370,9 +1453,12 @@ "nodeOpacity": "Node Opacity", "nodeVersion": "Node Version", "noOutputRecorded": "No outputs recorded", + "nodeData": "Node Data", "notes": "Notes", "description": "Description", "notesDescription": "Add notes about your workflow", + "addConnector": "Add Connector", + "deleteConnector": "Delete Connector", "problemSettingTitle": "Problem Setting Title", "resetToDefaultValue": "Reset to default value", "reloadNodeTemplates": "Reload Node Templates", @@ -1490,6 +1576,7 @@ "copyImage": "Copy Image", "denoisingStrength": "Denoising Strength", "disabledNoRasterContent": "Disabled (No Raster Content)", + "disabledNotSupported": "Not supported by model", "downloadImage": "Download Image", "general": "General", "guidance": "Guidance", @@ -1503,6 +1590,7 @@ "info": "Info", "invoke": { "addingImagesTo": "Adding images to", + "boardNotWritable": "You do not have write access to board \"{{boardName}}\". Select a board you own or switch to Uncategorized.", "modelDisabledForTrial": "Generating with {{modelName}} is not available on trial accounts. Visit your account settings to upgrade.", "invoke": "Invoke", "missingFieldTemplate": "Missing field template", @@ -1529,8 +1617,12 @@ "noFLUXVAEModelSelected": "No VAE model selected for FLUX generation", "noCLIPEmbedModelSelected": "No CLIP Embed model selected for FLUX generation", "noQwen3EncoderModelSelected": "No Qwen3 Encoder model selected for FLUX2 Klein generation", + "noQwenImageComponentSourceSelected": "GGUF Qwen Image models require a Diffusers Component Source for VAE/encoder", "noZImageVaeSourceSelected": "No VAE source: Select VAE (FLUX) or Qwen3 Source model", "noZImageQwen3EncoderSourceSelected": "No Qwen3 Encoder source: Select Qwen3 Encoder or Qwen3 Source model", + "noAnimaVaeModelSelected": "No Anima VAE model selected", + "noAnimaQwen3EncoderModelSelected": "No Anima Qwen3 Encoder model selected", + "noAnimaT5EncoderModelSelected": "No Anima T5 Encoder model selected", "fluxModelIncompatibleBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), bbox width is {{width}}", "fluxModelIncompatibleBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), bbox height is {{height}}", "fluxModelIncompatibleScaledBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), scaled bbox width is {{width}}", @@ -1579,6 +1671,7 @@ "sendToCanvas": "Send To Canvas", "sendToUpscale": "Send To Upscale", "showOptionsPanel": "Show Side Panel (O or T)", + "shift": "Shift", "shuffle": "Shuffle Seed", "steps": "Steps", "strength": "Strength", @@ -1601,6 +1694,7 @@ "boxBlur": "Box Blur", "staged": "Staged", "resolution": "Resolution", + "imageSize": "Image Size", "modelDisabledForTrial": "Generating with {{modelName}} is not available on trial accounts. Visit your account settings to upgrade." 
}, "dynamicPrompts": { @@ -1616,6 +1710,7 @@ "perPromptDesc": "Use a different seed for each image" }, "loading": "Generating Dynamic Prompts...", + "problemGeneratingPrompts": "Problem generating prompts", "promptsToGenerate": "Prompts to Generate" }, "sdxl": { @@ -1654,6 +1749,8 @@ "enableNSFWChecker": "Enable NSFW Checker", "general": "General", "generation": "Generation", + "maxQueueHistory": "Max Queue History", + "maxQueueHistorySaveFailed": "Failed to save Max Queue History", "models": "Models", "preferAttentionStyleNumeric": "Prefer Numeric Attention Style", "prompt": "Prompt", @@ -2253,6 +2350,8 @@ "tags": "Tags", "yourWorkflows": "Your Workflows", "recentlyOpened": "Recently Opened", + "sharedWorkflows": "Shared Workflows", + "shareWorkflow": "Shared workflow", "noRecentWorkflows": "No Recent Workflows", "private": "Private", "shared": "Shared", @@ -2403,6 +2502,27 @@ "pullBboxIntoReferenceImageError": "Problem Pulling BBox Into ReferenceImage", "addAdjustments": "Add Adjustments", "removeAdjustments": "Remove Adjustments", + "workflowIntegration": { + "title": "Run Workflow on Canvas", + "description": "Select a workflow with a Canvas Output node and an image parameter to run on the current canvas layer. You can adjust parameters before executing. The result will be added back to the canvas.", + "execute": "Execute Workflow", + "executing": "Executing...", + "runWorkflow": "Run Workflow", + "filteringWorkflows": "Filtering workflows...", + "loadingWorkflows": "Loading workflows...", + "noWorkflowsFound": "No workflows found.", + "noWorkflowsWithImageField": "No compatible workflows found. A workflow needs a Form Builder with an image input field and a Canvas Output node.", + "selectWorkflow": "Select Workflow", + "selectPlaceholder": "Choose a workflow...", + "unnamedWorkflow": "Unnamed Workflow", + "loadingParameters": "Loading workflow parameters...", + "noFormBuilderError": "This workflow has no form builder and cannot be used. 
Please select a different workflow.", + "imageFieldSelected": "This field will receive the canvas image", + "imageFieldNotSelected": "Click to use this field for canvas image", + "executionStarted": "Workflow execution started", + "executionStartedDescription": "The result will appear in the staging area when complete.", + "executionFailed": "Failed to execute workflow" + }, "compositeOperation": { "label": "Blend Mode", "add": "Add Blend Mode", @@ -2475,6 +2595,8 @@ "disableAutoNegative": "Disable Auto Negative", "deletePrompt": "Delete Prompt", "deleteReferenceImage": "Delete Reference Image", + "disableReferenceImage": "Disable Reference Image", + "enableReferenceImage": "Enable Reference Image", "showHUD": "Show HUD", "rectangle": "Rectangle", "maskFill": "Mask Fill", @@ -2495,6 +2617,7 @@ "controlLayer": "Control Layer", "inpaintMask": "Inpaint Mask", "invertMask": "Invert Mask", + "invertRegion": "Invert Region", "regionalGuidance": "Regional Guidance", "referenceImageRegional": "Reference Image (Regional)", "referenceImageGlobal": "Reference Image (Global)", @@ -2502,7 +2625,10 @@ "asRasterLayerResize": "As $t(controlLayers.rasterLayer) (Resize)", "asControlLayer": "As $t(controlLayers.controlLayer)", "asControlLayerResize": "As $t(controlLayers.controlLayer) (Resize)", + "invalidReferenceImage": "Invalid Reference Image:", "referenceImage": "Reference Image", + "removeImageFromCollection": "Remove Image from Collection", + "selectRefImage": "Select Ref Image", "maxRefImages": "Max Ref Images", "useAsReferenceImage": "Use as Reference Image", "regionalReferenceImage": "Regional Reference Image", @@ -2523,7 +2649,11 @@ "alignLeft": "Align Left", "alignCenter": "Align Center", "alignRight": "Align Right", - "px": "px" + "px": "px", + "lineHeight": "Spacing", + "lineHeightDense": "Dense", + "lineHeightNormal": "Normal", + "lineHeightSpacious": "Spacious" }, "newCanvasFromImage": "New Canvas from Image", "newImg2ImgCanvasFromImage": "New Img2Img from Image", @@ -2670,17 +2800,24 @@ "crosshatch": "Crosshatch", "vertical": "Vertical", "horizontal": "Horizontal", - "diagonal": "Diagonal" + "diagonal": "Diagonal", + "switchColors": "Switch FG/BG (X)" }, "gradient": { "linear": "Linear", "radial": "Radial", "clip": "Clip Gradient" }, + "lasso": { + "freehand": "Freehand", + "polygon": "Polygon", + "polygonHint": "Click to add points, click the first point to close." + }, "tool": { "brush": "Brush", "eraser": "Eraser", "rectangle": "Rectangle", + "lasso": "Lasso", "gradient": "Gradient", "bbox": "Bbox", "move": "Move", @@ -2920,6 +3057,19 @@ "copyCanvasToClipboard": "Copy Canvas to Clipboard", "copyBboxToClipboard": "Copy Bbox to Clipboard" }, + "canvasProject": { + "project": "Project", + "saveProject": "Save Canvas Project", + "loadProject": "Load Canvas Project", + "saveSuccess": "Project Saved", + "saveSuccessDesc": "Saved project with {{count}} images", + "saveError": "Failed to Save Project", + "loadSuccess": "Project Loaded", + "loadSuccessDesc": "Canvas state restored from project file", + "loadError": "Failed to Load Project", + "loadWarning": "Loading a project will replace your current canvas, including all layers, masks, reference images, and generation parameters. 
This action cannot be undone.", + "projectName": "Project Name" + }, "stagingArea": { "accept": "Accept", "discardAll": "Discard All", @@ -2927,13 +3077,18 @@ "previous": "Previous", "next": "Next", "saveToGallery": "Save To Gallery", + "hideThumbnails": "Hide Thumbnails", + "showThumbnails": "Show Thumbnails", "showResultsOn": "Showing Results", "showResultsOff": "Hiding Results" }, "autoSwitch": { "off": "Off", + "doNotAutoSwitch": "Do not auto-switch", "switchOnStart": "On Start", - "switchOnFinish": "On Finish" + "switchOnStartDesc": "Switch on start", + "switchOnFinish": "On Finish", + "switchOnFinishDesc": "Switch on finish" } }, "upscaling": { @@ -2950,6 +3105,7 @@ "tileOverlap": "Tile Overlap", "postProcessingMissingModelWarning": "Visit the Model Manager to install a post-processing (image to image) model.", "missingModelsWarning": "Visit the Model Manager to install the required models:", + "missingModelsWarningNonAdmin": "Ask your InvokeAI administrator () to install the required models:", "mainModelDesc": "Main model (SD1.5 or SDXL architecture)", "tileControlNetModelDesc": "Tile ControlNet model for the chosen main model architecture", "upscaleModelDesc": "Upscale (image to image) model", @@ -3058,6 +3214,7 @@ }, "workflows": { "description": "Workflows are reusable templates that automate image generation tasks, allowing you to quickly perform complex operations and get consistent results.", + "descriptionMultiuser": "Workflows are reusable templates that automate image generation tasks, allowing you to quickly perform complex operations and get consistent results. You may share your workflows with other users of the system by selecting 'Shared workflow' when you create or edit it.", "learnMoreLink": "Learn more about creating workflows", "browseTemplates": { "title": "Browse Workflow Templates", @@ -3136,25 +3293,39 @@ "toGetStartedLocal": "To get started, make sure to download or import models needed to run Invoke. Then, enter a prompt in the box and click Invoke to generate your first image. Select a prompt template to improve results. You can choose to save your images directly to the Gallery or edit them to the Canvas.", "toGetStarted": "To get started, enter a prompt in the box and click Invoke to generate your first image. Select a prompt template to improve results. You can choose to save your images directly to the Gallery or edit them to the Canvas.", "toGetStartedWorkflow": "To get started, fill in the fields on the left and press Invoke to generate your image. Want to explore more workflows? Click the folder icon next to the workflow title to see a list of other templates you can try.", + "toGetStartedNonAdmin": "To get started, ask your InvokeAI administrator () to install the AI models needed to run Invoke. Then, enter a prompt in the box and click Invoke to generate your first image. Select a prompt template to improve results. You can choose to save your images directly to the Gallery or edit them to the Canvas.", "gettingStartedSeries": "Want more guidance? Check out our Getting Started Series for tips on unlocking the full potential of the Invoke Studio.", "lowVRAMMode": "For best performance, follow our Low VRAM guide.", - "noModelsInstalled": "It looks like you don't have any models installed! You can download a starter model bundle or import models." + "noModelsInstalled": "It looks like you don't have any models installed! You can download a starter model bundle or import models.", + "noModelsInstalledAskAdmin": "Ask your administrator to install some." 
}, "whatsNew": { "whatsNewInInvoke": "What's New in Invoke", "items": [ - "FLUX.2 Klein Support: InvokeAI now supports the new FLUX.2 Klein models (4B and 9B variants) with GGUF, FP8, and Diffusers formats. Features include txt2img, img2img, inpainting, and outpainting. See 'Starter Models' to get started.", - "DyPE support for FLUX models improves high-resolution (>1536 px up to 4K) images. Go to the 'Advanced Options' section to activate.", - "Z-Image Turbo diversity: Active 'Seed Variance Enhancer' under 'Advanced Options' to add diversitiy to your ZiT gens." + "Multi-user mode supports multiple isolated users on the same server.", + "Enhanced support for Z-Image and FLUX.2 Models.", + "Multiple user interface enhancements and new canvas features." ], "takeUserSurvey": "📣 Let us know how you like InvokeAI. Take our User Experience Survey!", "readReleaseNotes": "Read Release Notes", "watchRecentReleaseVideos": "Watch Recent Release Videos", "watchUiUpdatesOverview": "Watch UI Updates Overview" }, + "cropper": { + "cropImage": "Crop Image", + "aspectRatio": "Aspect Ratio", + "free": "Free", + "mouseWheelZoom": "Mouse wheel: Zoom", + "spaceDragPan": "Space + Drag: Pan", + "dragCropBoxToAdjust": "Drag crop box or handles to adjust" + }, "supportVideos": { "supportVideos": "Support Videos", "gettingStarted": "Getting Started", + "gettingStartedPlaylist": "Getting Started playlist", + "studioSessionsPlaylist": "Studio Sessions playlist", + "discord": "Discord", + "github": "GitHub", "watch": "Watch", "studioSessionsDesc": "Join our to participate in the live sessions and ask questions. Sessions are uploaded to the playlist the following week.", "videos": { diff --git a/invokeai/frontend/web/public/locales/es.json b/invokeai/frontend/web/public/locales/es.json index 4c58ea87f5..8f68ea585c 100644 --- a/invokeai/frontend/web/public/locales/es.json +++ b/invokeai/frontend/web/public/locales/es.json @@ -92,7 +92,9 @@ "toResolve": "Para resolver", "outpaint": "outpaint", "simple": "Sencillo", - "close": "Cerrar" + "close": "Cerrar", + "board": "Tablero", + "crop": "Cortar" }, "gallery": { "galleryImageSize": "Tamaño de la imagen", @@ -327,7 +329,7 @@ "movingImagesToBoard_one": "Moviendo {{count}} imagen al panel:", "movingImagesToBoard_many": "Moviendo {{count}} imágenes al panel:", "movingImagesToBoard_other": "Moviendo {{count}} imágenes al panel:", - "bottomMessage": "Al eliminar este panel y las imágenes que contiene, se restablecerán las funciones que los estén utilizando actualmente.", + "bottomMessage": "Al eliminarlas imágenes, se restablecerán las funcionalidades que actualmente las estén utilizando.", "deleteBoardAndImages": "Borrar el panel y las imágenes", "loading": "Cargando...", "deletedBoardsCannotbeRestored": "Los paneles eliminados no se pueden restaurar. Al Seleccionar 'Borrar solo el panel' transferirá las imágenes a un estado sin categorizar.", @@ -354,9 +356,21 @@ "unarchiveBoard": "Desarchivar el panel", "noBoards": "No hay paneles {{boardType}}", "shared": "Paneles compartidos", - "deletedPrivateBoardsCannotbeRestored": "Los paneles eliminados no se pueden restaurar. Al elegir \"Eliminar solo el panel\", las imágenes se colocan en un estado privado y sin categoría para el creador de la imagen.", + "deletedPrivateBoardsCannotbeRestored": "Los paneles eliminados no se pueden restaurar. 
Al elegir \"Eliminar solo el panel\", las imágenes se colocarán en un estado privado y sin categoría para el creador de la imagen.", "private": "Paneles privados", - "updateBoardError": "No se pudo actualizar el panel" + "updateBoardError": "No se pudo actualizar el panel", + "pause": "Pausa", + "resume": "Reanudar", + "restartFailed": "Reinicio fallido", + "restartFile": "Reiniciar archivo", + "restartRequired": "Reinicio requerido", + "resumeRefused": "Reanudación rechazada por el servidor. Reinicio requerido.", + "uncategorizedImages": "Imágenes sin categoría", + "deleteAllUncategorizedImages": "Eliminar todas las imágenes sin categoría", + "deletedImagesCannotBeRestored": "Las imágenes eliminadas no pueden ser restauradas.", + "hideBoards": "Ocultar tableros", + "locateInGalery": "Ubicar en galeria", + "viewBoards": "Ver paneles" }, "accordions": { "compositing": { @@ -867,5 +881,100 @@ "noModelsInstalled": "Parece que no tienes ningún modelo instalado", "gettingStartedSeries": "¿Desea más orientación? Consulte nuestra Serie de introducción para obtener consejos sobre cómo aprovechar todo el potencial de Invoke Studio.", "toGetStartedLocal": "Para empezar, asegúrate de descargar o importar los modelos necesarios para ejecutar Invoke. A continuación, introduzca un mensaje en el cuadro y haga clic en Invocar para generar su primera imagen. Seleccione una plantilla para mejorar los resultados. Puede elegir guardar sus imágenes directamente en Galería o editarlas en el Lienzo." + }, + "auth": { + "login": { + "title": "Iniciar sesión en InvokeAI", + "email": "Email", + "emailPlaceholder": "Email", + "password": "Contraseña", + "passwordPlaceholder": "Contraseña", + "rememberMe": "Recordarme por 7 días", + "signIn": "Iniciar sesión", + "signingIn": "Iniciando sesión...", + "loginFailed": "Inicio de sesión fallido. Por favor revise sus credenciales." + }, + "setup": { + "title": "Bienvenido a InvokeAI", + "subtitle": "Configure su cuenta de administrador para empezar", + "email": "Email", + "emailPlaceholder": "admin@example.com", + "emailHelper": "Este será su nombre de usuario para iniciar sesión", + "displayName": "Nombre para mostrar", + "displayNamePlaceholder": "Administrador", + "displayNameHelper": "Su nombre como se mostrará en la aplicación", + "password": "Contraseña", + "passwordPlaceholder": "Contraseña", + "passwordHelper": "Debe tener al menos 8 caracteres con mayúsculas, minúsculas y números", + "passwordTooShort": "La contraseña debe tener al menos 8 caracteres", + "passwordMissingRequirements": "La contraseña debe contener mayúsculas, minúsculas y numeros", + "confirmPassword": "Confirmar contraseña", + "confirmPasswordPlaceholder": "Confirmar contraseña", + "passwordsDoNotMatch": "Las contraseñas no coinciden", + "createAccount": "Crear cuenta de administrador", + "creatingAccount": "Configurando...", + "setupFailed": "Configuración fallida. 
Por favor intente nuevamente.", + "passwordHelperRelaxed": "Ingrese una contraseña (se mostrará la fortaleza)" + }, + "userMenu": "Menú de usuario", + "admin": "Administrador", + "logout": "Cerrar Sesión", + "adminOnlyFeature": "Esta funcionalidad solo está disponible para administradores.", + "profile": { + "menuItem": "Mi perfil", + "title": "Mi perfil", + "email": "Email", + "emailReadOnly": "La dirección de email no puede ser cambiada", + "displayName": "Nombre para mostrar", + "displayNamePlaceholder": "Su nombre", + "changePassword": "Cambiar contraseña", + "currentPassword": "Contraseña Actual", + "currentPasswordPlaceholder": "Contraseña Actual", + "newPassword": "Nueva contraseña", + "newPasswordPlaceholder": "Nueva contraseña", + "confirmPassword": "Confirmar nueva contraseña", + "confirmPasswordPlaceholder": "Confirmar nueva contraseña", + "passwordsDoNotMatch": "Las contraseñas no coinciden", + "saveSuccess": "Perfil actualizado correctamente", + "saveFailed": "Falló el guardado del perfil. Por favor intente nuevamente." + }, + "userManagement": { + "menuItem": "Administración de usuario", + "title": "Administración de usuario", + "email": "Email", + "emailPlaceholder": "user@example.com", + "displayName": "Nombre para mostrar", + "displayNamePlaceholder": "Nombre para mostrar", + "password": "Contraseña", + "passwordPlaceholder": "Contraseña", + "newPassword": "Nueva contraseña", + "newPasswordPlaceholder": "Deje en blanco para conservar la contraseña actual", + "role": "Rol", + "status": "Estado", + "actions": "Acciones", + "isAdmin": "Administrador", + "user": "Usuario", + "you": "Tú", + "createUser": "Crear usuario", + "editUser": "Editar usuario", + "deleteUser": "Eliminar usuario", + "deleteConfirm": "¿Está seguro de que desea eliminar {{name}}? Esta acción no se podrá revertir.", + "generatePassword": "Generar contraseña robusta", + "showPassword": "Mostrar contraseña", + "hidePassword": "Ocultar contraseña", + "activate": "Activar", + "deactivate": "Desactivar", + "saveFailed": "Fallo al guardar usuario. Por favor intente nuevamente.", + "deleteFailed": "Fallo al borrar usuario. 
Por favor intente nuevamente.", + "loadFailed": "Fallo al cargar usuarios.", + "back": "Atrás", + "cannotDeleteSelf": "Usted no puede eliminar su propia cuenta", + "cannotDeactivateSelf": "Usted no puede desactivar su propia cuenta" + }, + "passwordStrength": { + "weak": "Contraseña débil", + "moderate": "Contraseña moderada", + "strong": "Contraseña fuerte" + } } } diff --git a/invokeai/frontend/web/public/locales/fi.json b/invokeai/frontend/web/public/locales/fi.json index f03c6f1aa1..54e5a66660 100644 --- a/invokeai/frontend/web/public/locales/fi.json +++ b/invokeai/frontend/web/public/locales/fi.json @@ -4,7 +4,8 @@ "uploadImage": "Lataa kuva", "invokeProgressBar": "Invoken edistymispalkki", "nextImage": "Seuraava kuva", - "previousImage": "Edellinen kuva" + "previousImage": "Edellinen kuva", + "uploadImages": "Lähetä kuva(t)" }, "common": { "languagePickerLabel": "Kielen valinta", @@ -29,5 +30,28 @@ "galleryImageSize": "Kuvan koko", "gallerySettings": "Gallerian asetukset", "autoSwitchNewImages": "Vaihda uusiin kuviin automaattisesti" + }, + "modelManager": { + "t5Encoder": "T5-kooderi", + "qwen3Encoder": "Qwen3-kooderi", + "zImageVae": "VAE (valinnainen)", + "zImageQwen3Encoder": "Qwen3-kooderi (valinnainen)", + "zImageQwen3SourcePlaceholder": "Pakollinen, jos VAE/Enkooderi on tyhjä", + "flux2KleinVae": "VAE (valinnainen)", + "flux2KleinQwen3Encoder": "Qwen3-kooderi (valinnainen)" + }, + "auth": { + "login": { + "title": "Kirjaudu sisään InvokeAI:hin", + "password": "Salasana", + "passwordPlaceholder": "Salasana", + "signIn": "Kirjaudu sisään", + "signingIn": "Kirjaudutaan sisään...", + "loginFailed": "Kirjautuminen epäonnistui. Tarkista käyttäjätunnuksesi tiedot." + }, + "setup": { + "title": "Tervetuloa InvokeAI:hin", + "subtitle": "Määritä ensimmäiseksi järjestelmänvalvojan tili" + } } } diff --git a/invokeai/frontend/web/public/locales/it.json b/invokeai/frontend/web/public/locales/it.json index d17d36d5c0..9fa7c5a894 100644 --- a/invokeai/frontend/web/public/locales/it.json +++ b/invokeai/frontend/web/public/locales/it.json @@ -132,7 +132,21 @@ "prevPage": "Pagina precedente", "nextPage": "Pagina successiva", "resetToDefaults": "Ripristina impostazioni predefinite", - "crop": "Ritaglia" + "crop": "Ritaglia", + "editName": "Modifica nome", + "fitView": "Adatta la vista", + "minimize": "Minimizza", + "next": "Prossimo", + "noMatchingItems": "Nessun elemento corrispondente", + "notifications": "Notifiche", + "previous": "Precedente", + "removeFromCollection": "Rimuovi dalla raccolta", + "resetView": "Ripristina la vista", + "saveToAssets": "Salva nelle risorse", + "settings": "Impostazioni", + "toggleRgbHex": "Attiva/disattiva RGB/HEX", + "unpin": "Sblocca", + "openSlider": "Apri il cursore" }, "gallery": { "galleryImageSize": "Dimensione dell'immagine", @@ -203,7 +217,12 @@ "selectAnImageToCompare": "Seleziona un'immagine da confrontare", "openViewer": "Apri Visualizzatore", "closeViewer": "Chiudi Visualizzatore", - "usePagedGalleryView": "Utilizza la visualizzazione Galleria a pagine" + "usePagedGalleryView": "Utilizza la visualizzazione Galleria a pagine", + "loadingGallery": "Caricamento galleria in corso...", + "loadingMetadata": "Caricamento dei metadati in corso...", + "noImagesFound": "Nessuna immagine trovata", + "bulkDownloadReady": "Download pronto", + "clickToDownload": "Clicca qui per scaricare" }, "hotkeys": { "searchHotkeys": "Cerca tasti di scelta rapida", @@ -844,7 +863,32 @@ "settingsImportedPartial": "Impostazioni del modello parzialmente importate. 
Le impostazioni incompatibili sono state ignorate: {{fields}}", "settingsImportFailed": "Impossibile importare le impostazioni del modello", "settingsImportIncompatible": "Il file delle impostazioni non contiene impostazioni compatibili per questo tipo di modello", - "settingsImportInvalidFile": "File di impostazioni non valido" + "settingsImportInvalidFile": "File di impostazioni non valido", + "reidentifyModels": "Re-identificare i modelli", + "reidentifyModelsConfirm": "Sei sicuro di voler re-identificare {{count}} modello(i)? Questa operazione eseguirà una nuova scansione dei relativi file dei pesi per determinarne il formato e le impostazioni corrette.", + "reidentifyWarning": "Questa operazione ripristinerà tutte le impostazioni personalizzate che potresti aver applicato a questi modelli.", + "modelsReidentified": "{{count}} modello(i) re-identificato(i) con successo", + "modelsReidentifyFailed": "Impossibile re-identificare i modelli", + "someModelsFailedToReidentify": "Non è stato possibile re-identificare {{count}} modello(i)", + "modelsReidentifiedPartial": "Completato parzialmente", + "someModelsReidentified": "{{succeeded}} re-identificato(i), {{failed}} fallito(i)", + "modelsReidentifyError": "Errore nella re-identificazione dei modelli", + "deleteModelsConfirm": "Sei sicuro di voler eliminare {{count}} modello(i)? Questa azione non può essere annullata.", + "deleteWarning": "I modelli presenti nella cartella dei modelli di Invoke verranno eliminati definitivamente dal disco.", + "modelsDeleted": "{{count}} modello(i) eliminato(i) con successo", + "modelsDeleteFailed": "Impossibile eliminare i modelli", + "someModelsFailedToDelete": "Non è stato possibile eliminare {{count}} modello(i)", + "modelsDeletedPartial": "Parzialmente completato", + "someModelsDeleted": "{{deleted}} eliminato(i), {{failed}} fallito(i)", + "modelsDeleteError": "Errore durante l'eliminazione dei modelli", + "queueEmpty": "La coda di installazione è vuota.", + "animaVaePlaceholder": "Seleziona VAE compatibile con Anima", + "animaQwen3EncoderPlaceholder": "Seleziona l'encoder Qwen3 0.6B", + "animaT5EncoderPlaceholder": "Seleziona l'encoder T5-XXL", + "qwenImageComponentSourcePlaceholder": "Necessario per i modelli GGUF", + "qwenImageComponentSource": "VAE/Sorgente Encoder (Diffusers)", + "qwenImageQuantization": "Quantizzazione dell'encoder", + "qwenImageQuantizationNone": "Nessuna (bf16)" }, "parameters": { "images": "Immagini", @@ -935,7 +979,11 @@ "fluxModelIncompatibleScaledBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), l'altezza ridimensionata del riquadro è {{height}}", "noZImageQwen3EncoderSourceSelected": "Nessuna sorgente Qwen3 Encoder: seleziona il modello Qwen3 Encoder o Qwen3 Source", "noZImageVaeSourceSelected": "Nessuna sorgente VAE: selezionare il modello di sorgente VAE (FLUX) o Qwen3", - "noQwen3EncoderModelSelected": "Nessun modello di encoder Qwen3 selezionato per la generazione Klein di FLUX2" + "noQwen3EncoderModelSelected": "Nessun modello di encoder Qwen3 selezionato per la generazione Klein di FLUX2", + "noAnimaVaeModelSelected": "Nessun modello VAE Anima selezionato", + "noAnimaQwen3EncoderModelSelected": "Nessun modello di encoder Anima Qwen3 selezionato", + "noAnimaT5EncoderModelSelected": "Nessun modello di encoder Anima T5 selezionato", + "noQwenImageComponentSourceSelected": "I modelli GGUF Qwen Image richiedono una sorgente componente Diffusers per VAE/encoder" }, "useCpuNoise": "Usa la CPU per generare rumore", "iterations": "Iterazioni", @@ 
-1336,7 +1384,9 @@ "versionUnknown": " Versione sconosciuta", "generateValues": "Genera valori", "floatRangeGenerator": "Generatore di intervallo di numeri decimali", - "integerRangeGenerator": "Generatore di intervallo di numeri interi" + "integerRangeGenerator": "Generatore di intervallo di numeri interi", + "noWorkflowToSave": "Nessun flusso di lavoro da salvare", + "nodeData": "Dati del nodo" }, "boards": { "autoAddBoard": "Aggiungi automaticamente bacheca", @@ -1489,7 +1539,9 @@ "clearFailedAccessDenied": "Problema durante la cancellazione della coda: accesso negato", "user": "Utente", "cannotViewDetails": "Non hai l'autorizzazione per visualizzare i dettagli di questo elemento della coda", - "fieldValuesHidden": "" + "fieldValuesHidden": "", + "queueActionsMenu": "Menu azioni in coda", + "queueItem": "Elemento della coda" }, "models": { "noMatchingModels": "Nessun modello corrispondente", @@ -1535,7 +1587,8 @@ "promptsPreview": "Anteprima dei prompt", "showDynamicPrompts": "Mostra prompt dinamici", "loading": "Generazione prompt dinamici...", - "promptsToGenerate": "Prompt da generare" + "promptsToGenerate": "Prompt da generare", + "problemGeneratingPrompts": "Problema nella generazione dei prompt" }, "popovers": { "paramScheduler": { @@ -2269,7 +2322,11 @@ "insert": "Inserisci", "noPromptHistory": "Nessuna cronologia di prompt registrata.", "noMatchingPrompts": "Nessun prompt corrispondente nella cronologia.", - "toSwitchBetweenPrompts": "per passare da un prompt all'altro." + "toSwitchBetweenPrompts": "per passare da un prompt all'altro.", + "promptHistory": "Cronologia dei prompt", + "clearHistory": "Cancella cronologia", + "usePrompt": "Utilizza il prompt", + "searchPrompts": "Ricerca..." }, "controlLayers": { "addLayer": "Aggiungi Livello", @@ -2524,7 +2581,8 @@ "horizontal": "Orizzontale", "diagonal": "Diagonale", "bgFillColor": "Colore di sfondo", - "fgFillColor": "Colore di primo piano" + "fgFillColor": "Colore di primo piano", + "switchColors": "Commuta FG/BG (X)" }, "locked": "Bloccato", "hidingType": "Nascondere {{type}}", @@ -2585,7 +2643,9 @@ "saveToGallery": "Salva nella Galleria", "previous": "Precedente", "showResultsOn": "Visualizzare i risultati", - "showResultsOff": "Nascondere i risultati" + "showResultsOff": "Nascondere i risultati", + "hideThumbnails": "Nascondi le miniature", + "showThumbnails": "Mostra miniature" }, "HUD": { "bbox": "Riquadro di delimitazione", @@ -2639,7 +2699,7 @@ "desc": "Seleziona un singolo oggetto di destinazione. Una volta completata la selezione, fai clic su Applica per eliminare tutto ciò che si trova al di fuori dell'area selezionata, oppure salva la selezione come nuovo livello.", "visualModeDesc": "La modalità visiva utilizza input di tipo riquadro e punto per selezionare un oggetto.", "visualMode1": "Fai clic e trascina per disegnare un riquadro attorno all'oggetto che desideri selezionare. 
Puoi ottenere risultati migliori disegnando il riquadro un po' più grande o più piccolo dell'oggetto.", - "visualMode2": "Fare clic per aggiungere un punto di iinclusionei verde oppure fare clic tenendo premuto Maiusc per aggiungere un punto di iesclusionei rosso per indicare al modello cosa includere o escludere.", + "visualMode2": "Fai clic per aggiungere un punto verde includi oppure fai clic tenendo premuto il tasto Maiusc per aggiungere un punto rosso escludi per indicare al modello cosa includere o escludere.", "visualMode3": "I punti possono essere utilizzati per perfezionare una selezione di caselle oppure in modo indipendente.", "promptModeDesc": "La modalità Prompt utilizza l'input di testo per selezionare un oggetto.", "promptMode1": "Digitare una breve descrizione dell'oggetto che si desidera selezionare.", @@ -2719,7 +2779,10 @@ "autoSwitch": { "off": "Spento", "switchOnStart": "All'inizio", - "switchOnFinish": "Alla fine" + "switchOnFinish": "Alla fine", + "doNotAutoSwitch": "Non commutare automaticamente", + "switchOnStartDesc": "Attiva all'avvio", + "switchOnFinishDesc": "Attiva al termine" }, "invertMask": "Inverti maschera", "fitBboxToMasks": "Adatta il riquadro di delimitazione alle maschere", @@ -2825,8 +2888,39 @@ "strikethrough": "Barrato", "alignLeft": "Allinea a sinistra", "alignCenter": "Allinea al centro", - "alignRight": "Allinea a destra" - } + "alignRight": "Allinea a destra", + "lineHeight": "Spaziatura", + "lineHeightDense": "Densa", + "lineHeightNormal": "Normale", + "lineHeightSpacious": "Spaziosa" + }, + "workflowIntegration": { + "title": "Eseguire il flusso di lavoro sulla Tela", + "description": "Seleziona un flusso di lavoro con un nodo Output su tela e un parametro immagine da eseguire sul livello corrente della tela. Puoi regolare i parametri prima dell'esecuzione. Il risultato verrà aggiunto nuovamente alla tela.", + "execute": "Eseguire il flusso di lavoro", + "executing": "Esecuzione in corso...", + "runWorkflow": "Avvia il flusso di lavoro", + "filteringWorkflows": "Filtraggio dei flussi di lavoro...", + "loadingWorkflows": "Caricamento dei flussi di lavoro...", + "noWorkflowsFound": "Nessun flusso di lavoro trovato.", + "noWorkflowsWithImageField": "Nessun flusso di lavoro compatibile trovato. Un flusso di lavoro richiede un generatore di moduli con un campo di input immagine e un nodo Output su tela.", + "selectWorkflow": "Seleziona il flusso di lavoro", + "selectPlaceholder": "Scegli un flusso di lavoro...", + "unnamedWorkflow": "Flusso di lavoro senza nome", + "loadingParameters": "Caricamento dei parametri del flusso di lavoro in corso...", + "noFormBuilderError": "Questo flusso di lavoro non dispone di un generatore di moduli e non può essere utilizzato. 
Selezionare un flusso di lavoro diverso.", + "imageFieldSelected": "Questo campo riceverà l'immagine della tela", + "imageFieldNotSelected": "Fai clic su questo campo per usarlo per l'immagine sulla tela", + "executionStarted": "L'esecuzione del flusso di lavoro è stata avviata", + "executionStartedDescription": "Il risultato apparirà nell'area di lavoro una volta completata l'operazione.", + "executionFailed": "Impossibile eseguire il flusso di lavoro" + }, + "disableReferenceImage": "Disabilita l'immagine di riferimento", + "enableReferenceImage": "Abilita l'immagine di riferimento", + "invertRegion": "Inverti la regione", + "invalidReferenceImage": "Immagine di riferimento non valida:", + "removeImageFromCollection": "Rimuovi l'immagine dalla raccolta", + "selectRefImage": "Seleziona l'immagine di riferimento" }, "ui": { "tabs": { @@ -3009,9 +3103,9 @@ "readReleaseNotes": "Leggi le note di rilascio", "watchRecentReleaseVideos": "Guarda i video su questa versione", "items": [ - "Supporto FLUX.2 Klein: InvokeAI ora supporta i nuovi modelli FLUX.2 Klein (varianti 4B e 9B) con formati GGUF, FP8 e Diffusers. Le funzionalità includono txt2img, img2img, inpainting e outpainting. Consultare la sezione \"Modelli di partenza\" per iniziare.", - "Il supporto DyPE per i modelli FLUX migliora le immagini ad alta risoluzione (da >1536 px fino a 4K). Vai alla sezione \"Opzioni avanzate\" per attivarlo.", - "Diversità Z-Image Turbo: attiva 'Seed Variance Enhancer' in 'Opzioni avanzate' per aggiungere diversità alle tue generazioni ZiT." + "La modalità multiutente supporta più utenti isolati sullo stesso server.", + "Supporto migliorato per i modelli Z-Image e FLUX.2.", + "Numerosi miglioramenti dell'interfaccia utente e nuove funzionalità Tela." ], "watchUiUpdatesOverview": "Guarda la panoramica degli aggiornamenti dell'interfaccia utente", "takeUserSurvey": "📣 Facci sapere cosa ne pensi di InvokeAI. Partecipa al nostro sondaggio sull'esperienza utente!" @@ -3056,7 +3150,9 @@ "title": "Sessioni in studio", "description": "Sessioni approfondite che esplorano le funzionalità avanzate di Invoke, i flussi di lavoro creativi e le discussioni della community." } - } + }, + "gettingStartedPlaylist": "Playlist per iniziare", + "studioSessionsPlaylist": "Playlist delle sessioni in studio" }, "modelCache": { "clear": "Cancella la cache del modello", @@ -3064,7 +3160,8 @@ "clearFailed": "Problema durante la cancellazione della cache del modello" }, "lora": { - "weight": "Peso" + "weight": "Peso", + "removeLoRA": "Rimuovi LoRA" }, "auth": { "login": { @@ -3072,7 +3169,8 @@ "rememberMe": "Ricordami per 7 giorni", "signIn": "Accedi", "signingIn": "Accesso in corso...", - "loginFailed": "Accesso non riuscito. Controlla le tue credenziali." + "loginFailed": "Accesso non riuscito. Controlla le tue credenziali.", + "sessionExpired": "Le tue credenziali sono scadute. Effettua nuovamente l'accesso per continuare." }, "setup": { "title": "Benvenuti a InvokeAI", @@ -3089,7 +3187,8 @@ "passwordsDoNotMatch": "Le password non corrispondono", "createAccount": "Crea un account amministratore", "creatingAccount": "Impostazione in corso...", - "setupFailed": "Installazione non riuscita. Riprova." + "setupFailed": "Installazione non riuscita. 
Riprova.", + "passwordHelperRelaxed": "Inserisci una password qualsiasi (verrà visualizzata la sua robustezza)" }, "userMenu": "Menu utente", "logout": "Esci", @@ -3139,6 +3238,19 @@ "back": "Indietro", "cannotDeleteSelf": "Non puoi eliminare il tuo account", "cannotDeactivateSelf": "Non puoi disattivare il tuo account" + }, + "passwordStrength": { + "weak": "Password debole", + "moderate": "Password moderata", + "strong": "Password forte" } + }, + "cropper": { + "cropImage": "Ritaglia l'immagine", + "aspectRatio": "Rapporto d'aspetto", + "free": "Libera", + "mouseWheelZoom": "Rotellina del mouse: Zoom", + "spaceDragPan": "Spazio + trascina: Panoramica", + "dragCropBoxToAdjust": "Trascina il riquadro di ritaglio o le maniglie per regolare" } } diff --git a/invokeai/frontend/web/public/locales/ja.json b/invokeai/frontend/web/public/locales/ja.json index 291b34cafa..ed8e438693 100644 --- a/invokeai/frontend/web/public/locales/ja.json +++ b/invokeai/frontend/web/public/locales/ja.json @@ -8,7 +8,7 @@ "back": "戻る", "statusDisconnected": "切断済", "cancel": "キャンセル", - "accept": "同意", + "accept": "確定", "img2img": "img2img", "loading": "ロード中", "githubLabel": "Github", @@ -33,9 +33,9 @@ "batch": "バッチマネージャー", "advanced": "高度", "created": "作成済", - "green": "緑", - "blue": "青", - "alpha": "アルファ", + "green": "G", + "blue": "B", + "alpha": "α", "outpaint": "outpaint", "unknown": "不明", "updated": "更新済", @@ -44,7 +44,7 @@ "copyError": "$t(gallery.copy) エラー", "data": "データ", "template": "テンプレート", - "red": "赤", + "red": "R", "or": "または", "checkpoint": "Checkpoint", "direction": "方向", @@ -157,7 +157,7 @@ "noImageSelected": "画像が選択されていません", "deleteSelection": "選択中のものを削除", "downloadSelection": "選択中のものをダウンロード", - "starImage": "スターをつける", + "starImage": "スター", "viewerImage": "閲覧画像", "compareImage": "比較画像", "openInViewer": "ビューアで開く", @@ -190,15 +190,16 @@ "selectAllOnPage": "ページ上のすべてを選択", "images": "画像", "assetsTab": "プロジェクトで使用するためにアップロードされたファイル。", - "imagesTab": "Invoke内で作成および保存された画像。", + "imagesTab": "Invoke内であなたが作成および保存した画像。", "assets": "アセット", "useForPromptGeneration": "プロンプト生成に使用する", "jump": "ジャンプ", - "noImagesInGallery": "ディスプレイに画像がありません", + "noImagesInGallery": "表示する画像がありません", "unableToLoad": "ギャラリーを読み込めません", "selectAnImageToCompare": "比較する画像を選択", "openViewer": "ビューアーを開く", - "closeViewer": "ビューアーを閉じる" + "closeViewer": "ビューアーを閉じる", + "usePagedGalleryView": "ページ型ギャラリービューを使う" }, "hotkeys": { "searchHotkeys": "ホットキーを検索", @@ -211,7 +212,7 @@ }, "useSize": { "title": "サイズを使用", - "desc": "現画像のサイズをbboxサイズとして使用する." + "desc": "現画像のサイズをバウンディングボックスのサイズとして使用する." 
}, "recallPrompts": { "title": "プロンプトを再使用", @@ -366,8 +367,8 @@ "desc": "矩形ツールを選択します。" }, "settings": { - "behavior": "行動", - "display": "ディスプレイ", + "behavior": "挙動", + "display": "表示", "grid": "グリッド", "debug": "デバッグ" }, @@ -388,25 +389,25 @@ "desc": "選択したインペイント マスクを反転し、反対の透明度を持つ新しいマスクを作成します。" }, "fitBboxToLayers": { - "title": "Bboxをレイヤーに合わせる", - "desc": "表示レイヤーに合わせて生成境界ボックスを自動的に調整します" + "title": "バウンディングボックスをレイヤー群に合わせる", + "desc": "表示されているレイヤーに合わせて生成バウンディングボックスを自動的に調整します" }, "fitBboxToMasks": { - "title": "Bboxをマスクにフィットさせる", - "desc": "目に見えるインペイントマスクに合わせて生成境界ボックスを自動的に調整します" + "title": "バウンディングボックスをマスクにフィットさせる", + "desc": "可視のインペイントマスクに合わせて生成バウンディングボックスを自動的に調整します" }, "toggleBbox": { - "title": "Bboxの表示/非表示を切り替える", - "desc": "生成境界ボックスを非表示または表示する" + "title": "バウンディングボックスの表示/非表示を切り替える", + "desc": "生成バウンディングボックスを非表示または表示する" }, "applySegmentAnything": { - "title": "何でもセグメント化を適用する", - "desc": "現在の「何でもセグメント」マスクを適用します。", + "title": "Segment Anythingを適用する", + "desc": "現在のSegment Anythingマスクを適用します。", "key": "入力" }, "cancelSegmentAnything": { "title": "セグメントをキャンセル", - "desc": "現在の「何でもセグメント」操作をキャンセルします。", + "desc": "現在のSegment Anything操作をキャンセルします。", "key": "エスケープ" } }, @@ -468,8 +469,8 @@ "title": "キャンバスタブを選択" }, "selectUpscalingTab": { - "desc": "アップスケーリングタブを選択します。", - "title": "アップスケーリングタブを選択" + "desc": "アップスケールタブを選択します。", + "title": "アップスケールタブを選択" }, "toggleRightPanel": { "desc": "右パネルを表示または非表示。", @@ -504,7 +505,7 @@ "desc": "カーソルをポジティブプロンプト欄に移動します。" }, "promptHistoryPrev": { - "title": "履歴の前のプロンプト", + "title": "ヒストリーの以前のプロンプト", "desc": "プロンプトにフォーカスがある場合は、履歴内の前の(古い)プロンプトに移動します。" }, "promptHistoryNext": { @@ -515,6 +516,9 @@ "title": "生成タブを選択", "desc": "生成タブを選択。", "key": "1" + }, + "promptWeightUp": { + "title": "選択したプロンプトの重みを増加" } }, "hotkeys": "ホットキー", @@ -568,7 +572,30 @@ "title": "画像にスターを付ける/スターを外す", "desc": "選択した画像にスターを付けたり、スターを外したりします。" } - } + }, + "editMode": "編集モード", + "viewMode": "ビューモード", + "editHotkey": "ホットキーの編集", + "addHotkey": "ホットキーの追加", + "resetToDefault": "デフォルトにリセット", + "resetAll": "全てをデフォルトにリセット", + "resetAllConfirmation": "すべてのホットキーをデフォルトに戻してよろしいですか?この操作は取り消せません。", + "enterHotkeys": "カンマ区切りでホットキーを入力してください", + "save": "保存", + "cancel": "キャンセル", + "modifiers": "モディファイア", + "syntaxHelp": "構文のヘルプ", + "multipleHotkeys": "カンマで区切られた複数のホットキー", + "help": "ヘルプ", + "noHotkeysRecorded": "まだホットキーが記録されていません", + "pressKeys": "キーを押してください...", + "setHotkey": "セット", + "setAnother": "他をセット", + "removeLastHotkey": "最後のホットキーを削除", + "clearAll": "全てをクリア", + "duplicateWarning": "このホットキーはすでに記録済みです", + "conflictWarning": "はすでに \"{{hotkeyTitle}}\" で使われています", + "thisHotkey": "このホットキー" }, "modelManager": { "modelManager": "モデルマネージャ", @@ -636,9 +663,9 @@ "controlLora": "コントロールLoRA", "triggerPhrases": "トリガーフレーズ", "t5Encoder": "T5エンコーダー", - "textualInversions": "テキスト反転", + "textualInversions": "Textual Inversions", "fluxRedux": "FLUX リダックス", - "installQueue": "キューをインストール", + "installQueue": "インストール進捗状況", "noMatchingModels": "マッチするモデルがありません", "noDefaultSettings": "このモデルには構成されたデフォルト設定がありません.デフォルト設定を追加するためにモデルマネージャーにアクセスしてください.", "usingDefaultSettings": "モデルのデフォルト設定を使用する", @@ -651,7 +678,7 @@ "main": "メイン", "defaultSettings": "デフォルト設定", "deleteModelImage": "モデル画像を削除", - "hfTokenInvalid": "ハギングフェイストークンが無効または見つかりません", + "hfTokenInvalid": "HuggingFaceトークンが無効または見つかりません", "hfForbiddenErrorMessage": "リポジトリにアクセスすることを勧めます.所有者はダウンロードにあたり利用規約への同意を要求する場合があります.", "noModelsInstalled": "インストールされているモデルがありません", "pathToConfig": "設定へのパス", @@ -665,8 +692,8 @@ "installRepo": "リポジトリをインストール", "localOnly": "ローカルのみ", 
"huggingFaceHelper": "いくつかのモデルがこのリポジトリで見つかった場合,1つを選択してインストールするように求められます.", - "hfTokenInvalidErrorMessage": "ハギングフェイストークンが無効または見つかりません.", - "hfTokenRequired": "有効なハギングフェイストークンが必要なモデルをダウンロードしようとしています.", + "hfTokenInvalidErrorMessage": "HuggingFaceトークンが無効または見つかりません。", + "hfTokenRequired": "有効なHuggingFaceトークンが必要なモデルをダウンロードしようとしています。", "hfTokenInvalidErrorMessage2": "更新してください ", "modelImageDeleted": "モデル画像削除", "repoVariant": "リポジトリバリアント", @@ -679,17 +706,17 @@ "urlOrLocalPath": "URLかローカルパス", "clipLEmbed": "クリップ-L 埋め込み", "defaultSettingsSaved": "デフォルト設定を保存しました", - "hfTokenUnableToVerify": "ハギングフェイストークンを確認できません", - "hfForbidden": "このハギングフェイスモデルにアクセスできません", - "hfTokenLabel": "ハギングフェイストークン(いくつかのモデルに必要)", + "hfTokenUnableToVerify": "HuggingFaceトークンを確認できません", + "hfForbidden": "このHuggingFaceモデルにアクセスできません", + "hfTokenLabel": "HuggingFaceトークン(いくつかのモデルに必要)", "noModelSelected": "モデルが選択されていません", "prune": "除去", - "hfTokenHelperText": "いくつかのモデルにハギングフェイストークンが必要です.ここをクリックしてあなたのトークンを作成してください.", + "hfTokenHelperText": "いくつかのモデルにHuggingFaceトークンが必要です。ここをクリックしてあなたのトークンを作成してください。", "starterBundleHelpText": "メインモデル,コントロールネット,IPアダプターなど,ベースモデルから始めるのに必要なすべてのモデルを簡単にインストールできます.バンドルを選択すると,すでにインストールされているモデルはスキップされます.", "inplaceInstallDesc": "ファイルを移動せずにモデルをインストールします.このモデルを使ったとき、元の場所からロードされます.利用できない場合、モデルファイルはInvoke管理モデルディレクトリにインストールしている間に移動されます。", - "hfTokenUnableToVerifyErrorMessage": "ハギングフェイストークンを確認できません.ネットワークによるエラーの可能性があります.後ほどトライしてください.", + "hfTokenUnableToVerifyErrorMessage": "HuggingFaceトークンを確認できません。ネットワークによるエラーの可能性があります。後ほどトライしてください。", "restoreDefaultSettings": "クリックするとモデルのデフォルト設定が使用されます.", - "hfTokenSaved": "ハギングフェイストークンを保存しました", + "hfTokenSaved": "HuggingFaceトークンを保存しました", "imageEncoderModelId": "画像エンコーダーモデルID", "includesNModels": "{{n}}個のモデルとこれらの依存関係を含みます。", "learnMoreAboutSupportedModels": "私たちのサポートしているモデルについて更に学ぶ", @@ -711,7 +738,7 @@ "modelPickerFallbackNoModelsInstalled2": "モデルマネージャー にアクセスしてモデルをインストールしてください.", "modelPickerFallbackNoModelsInstalled": "モデルがインストールされていません.", "manageModels": "モデル管理", - "hfTokenReset": "ハギングフェイストークンリセット", + "hfTokenReset": "HuggingFaceトークンをリセット", "relatedModels": "関連のあるモデル", "installedModelsCount": "{{total}} モデルのうち {{installed}} 個がインストールされています。", "allNModelsInstalled": "{{count}} 個のモデルがすべてインストールされています", @@ -719,7 +746,7 @@ "nAlreadyInstalled": "{{count}} 個すでにインストールされています", "bundleAlreadyInstalled": "バンドルがすでにインストールされています", "bundleAlreadyInstalledDesc": "{{bundleName}} バンドル内のすべてのモデルはすでにインストールされています。", - "launchpadTab": "ランチパッド", + "launchpadTab": "ローンチパッド", "launchpad": { "welcome": "モデルマネジメントへようこそ", "description": "Invoke プラットフォームのほとんどの機能を利用するには、モデルのインストールが必要です。手動インストールオプションから選択するか、厳選されたスターターモデルをご覧ください。", @@ -742,7 +769,10 @@ "installBundleMsg2": "このバンドルでは、次の {{count}} モデルがインストールされます:", "ipAdapters": "IPアダプター", "showOnlyRelatedModels": "関連している", - "starterModelsInModelManager": "スターターモデルはモデルマネージャーにあります" + "starterModelsInModelManager": "スターターモデルはモデルマネージャーにあります", + "actions": "一括操作", + "selectAll": "全て選択", + "deselectAll": "全て選択解除" }, "parameters": { "images": "画像", @@ -752,7 +782,7 @@ "seed": "シード値", "shuffle": "シャッフル", "strength": "強度", - "upscaling": "アップスケーリング", + "upscaling": "アップスケール", "scale": "スケール", "scaleBeforeProcessing": "処理前のスケール", "scaledWidth": "幅のスケール", @@ -794,10 +824,10 @@ "systemDisconnected": "システムが切断されました", "canvasIsTransforming": "キャンバスがビジー状態(変換)", "canvasIsRasterizing": "キャンバスがビジー状態(ラスタライズ)", - "modelIncompatibleBboxHeight": "Bboxの高さは{{height}}ですが,{{model}}は{{multiple}}の倍数が必要です", - "modelIncompatibleScaledBboxHeight": 
"bboxの高さは{{height}}ですが,{{model}}は{{multiple}}の倍数を必要です", - "modelIncompatibleBboxWidth": "Bboxの幅は{{width}}ですが, {{model}}は{{multiple}}の倍数が必要です", - "modelIncompatibleScaledBboxWidth": "bboxの幅は{{width}}ですが,{{model}}は{{multiple}}の倍数が必要です", + "modelIncompatibleBboxHeight": "バウンディングボックスの高さは{{height}}ですが,{{model}}は{{multiple}}の倍数が必要です", + "modelIncompatibleScaledBboxHeight": "バウンディングボックスの高さは{{height}}ですが,{{model}}は{{multiple}}の倍数を必要です", + "modelIncompatibleBboxWidth": "バウンディングボックスの幅は{{width}}ですが, {{model}}は{{multiple}}の倍数が必要です", + "modelIncompatibleScaledBboxWidth": "バウンディングボックスの幅は{{width}}ですが,{{model}}は{{multiple}}の倍数が必要です", "canvasIsSelectingObject": "キャンバスがビジー状態(オブジェクトの選択)", "noFLUXVAEModelSelected": "FLUX生成にVAEモデルが選択されていません", "noT5EncoderModelSelected": "FLUX生成にT5エンコーダモデルが選択されていません", @@ -806,10 +836,10 @@ "promptExpansionResultPending": "プロンプト拡張結果を受け入れるか破棄してください", "emptyBatches": "空のバッチ", "noStartingFrameImage": "開始フレーム画像がありません", - "fluxModelIncompatibleBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16)、bboxの幅は{{width}}です", - "fluxModelIncompatibleBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16)、bboxの高さは{{height}}です", - "fluxModelIncompatibleScaledBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16)、スケールされたbboxの幅は{{width}}です", - "fluxModelIncompatibleScaledBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16)、スケールされた bbox の高さは {{height}} です", + "fluxModelIncompatibleBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16)、バウンディングボックスの幅は{{width}}です", + "fluxModelIncompatibleBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16)、バウンディングボックスの高さは{{height}}です", + "fluxModelIncompatibleScaledBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16)、スケールされたバウンディングボックスの幅は{{width}}です", + "fluxModelIncompatibleScaledBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16)、スケールされたバウンディングボックスの高さは {{height}} です", "incompatibleLoRAs": "互換性のない LoRA が追加されました" }, "aspect": "縦横比", @@ -818,7 +848,7 @@ "sendToUpscale": "アップスケーラーに転送", "useSize": "サイズを使用", "postProcessing": "ポストプロセス (Shift + U)", - "denoisingStrength": "ノイズ除去強度", + "denoisingStrength": "除去ノイズ強度", "recallMetadata": "メタデータを再使用", "copyImage": "画像をコピー", "positivePromptPlaceholder": "ポジティブプロンプト", @@ -834,7 +864,7 @@ "imageFit": "初期画像を出力サイズに合わせる", "setToOptimalSizeTooLarge": "$t(parameters.setToOptimalSize) (おそらく大きすぎます)", "coherenceEdgeSize": "エッジサイズ", - "swapDimensions": "スワップ次元", + "swapDimensions": "縦横サイズを入れ替え", "controlNetControlMode": "制御モード", "infillColorValue": "塗りつぶし色", "coherenceMinDenoise": "最小ノイズ除去", @@ -845,7 +875,7 @@ "infillMethod": "充填法", "patchmatchDownScaleSize": "ダウンスケール", "boxBlur": "ボックスぼかし", - "remixImage": "リミックス画像", + "remixImage": "画像をリミックス", "processImage": "プロセス画像", "useCpuNoise": "CPUノイズの使用", "staged": "ステージ", @@ -997,8 +1027,8 @@ "noVisibleMasksDesc": "少なくとも1つのインペイントマスクを作成または有効にして反転します", "noInpaintMaskSelected": "インペイントマスクが選択されていません", "noInpaintMaskSelectedDesc": "反転するインペイントマスクを選択", - "invalidBbox": "無効な境界ボックス", - "invalidBboxDesc": "境界ボックスに有効な寸法がありません" + "invalidBbox": "無効なバウンディングボックス", + "invalidBboxDesc": "バウンディングボックスの寸法が有効ではありません" }, "accessibility": { "invokeProgressBar": "進捗バー", @@ -1080,7 +1110,7 @@ "batchQueuedDesc_other": "{{count}} セッションをキューの{{direction}}に追加しました", "graphQueued": "グラフをキューに追加しました", "batch": "バッチ", - "clearQueueAlertDialog": "キューをクリアすると、処理中の項目は直ちにキャンセルされ、キューは完全にクリアされます。保留中のフィルターもキャンセルされます。", + "clearQueueAlertDialog": 
"キューをクリアすると、処理中の項目は直ちにキャンセルされ、キューは完全にクリアされます。保留中のフィルターもキャンセルされ、ステージングエリアもリセットされます。", "pending": "保留中", "resumeFailed": "処理の再開に問題があります", "clear": "クリア", @@ -1132,7 +1162,13 @@ "sortColumn": "列の並べ替え", "sortBy": "{{column}}で並べ替え", "sortOrderAscending": "昇順", - "sortOrderDescending": "降順" + "sortOrderDescending": "降順", + "cancelFailedAccessDenied": "アイテムのキャンセル中に問題が発生しました:アクセスが拒否されました", + "clearFailedAccessDenied": "キューのクリア中に問題が発生しました:アクセスが拒否されました", + "paused": "一時停止中", + "user": "ユーザー", + "fieldValuesHidden": "<非表示>", + "cannotViewDetails": "このキューアイテムを閲覧する権限がありません" }, "models": { "noMatchingModels": "一致するモデルがありません", @@ -1179,7 +1215,7 @@ "cannotConnectInputToInput": "入力から入力には接続できません", "cannotConnectOutputToOutput": "出力から出力には接続できません", "cannotConnectToSelf": "自身のノードには接続できません", - "colorCodeEdges": "カラーコードエッジ", + "colorCodeEdges": "エッジのカラー化", "loadingNodes": "ノードを読み込み中...", "scheduler": "スケジューラー", "version": "バージョン", @@ -1197,7 +1233,7 @@ "enum": "Enum", "arithmeticSequence": "等差数列", "linearDistribution": "線形分布", - "animatedEdges": "アニメーションエッジ", + "animatedEdges": "エッジのアニメーション", "uniformRandomDistribution": "一様ランダム分布", "noBatchGroup": "グループなし", "parseString": "文字列の解析", @@ -1232,7 +1268,7 @@ "unableToUpdateNode": "ノードアップロード失敗:ノード {{node}} のタイプ {{type}} (削除か再生成が必要かもしれません)", "deletedInvalidEdge": "無効なエッジを削除しました{{source}} -> {{target}}", "collectionFieldType": "{{name}} (コレクション)", - "colorCodeEdgesHelp": "接続されたフィールドによるカラーコードエッジ", + "colorCodeEdgesHelp": "接続されたフィールド種ごとにエッジをカラー化", "showEdgeLabelsHelp": "エッジのラベルを表示,接続されているノードを示す", "sourceNodeFieldDoesNotExist": "無効なエッジ:ソース/アウトプットフィールド{{node}}.{{field}}が存在しません", "deletedMissingNodeFieldFormElement": "不足しているフォームフィールドを削除しました: ノード {{nodeId}} フィールド {{fieldName}}", @@ -1378,7 +1414,13 @@ "deletedImagesCannotBeRestored": "削除された画像は復元できません。", "hideBoards": "ボードを隠す", "locateInGalery": "ギャラリーで検索", - "viewBoards": "ボードを表示" + "viewBoards": "ボードを表示", + "pause": "一時停止", + "resume": "再開", + "restartFailed": "再起動に失敗しました", + "restartFile": "ファイルを再起動", + "restartRequired": "再起動が必要です", + "resumeRefused": "サーバーで再開が拒否されました。再起動が必要です。" }, "invocationCache": { "invocationCache": "呼び出しキャッシュ", @@ -1602,13 +1644,13 @@ "compositingMaskAdjustments": { "heading": "マスク調整", "paragraphs": [ - "マスクを調整する." + "マスクを調整する" ] }, "compositingCoherenceMinDenoise": { "paragraphs": [ - "コヒーレンスモードの最小ノイズ除去強度", - "インペインティングまたはアウトペインティング時のコヒーレンス領域の最小ノイズ除去強度" + "コヒーレンスモードの最小除去ノイズ強度", + "インペイント・アウトペイント時のコヒーレンス領域の最小除去ノイズ強度" ], "heading": "最小ノイズ除去" }, @@ -1691,7 +1733,7 @@ "たとえば, プロンプトが 5 つある場合, 各画像は同じシードを使用します.", "「画像ごと」では, 画像ごとに固有のシード値が使用されます. これにより、より多くのバリエーションが得られます." 
], - "heading": "シード行動" + "heading": "シードの挙動" }, "imageFit": { "paragraphs": [ @@ -1730,7 +1772,7 @@ "optimizedDenoising": { "heading": "イメージtoイメージの最適化", "paragraphs": [ - "「イメージtoイメージを最適化」を有効にすると、Fluxモデルを用いた画像間変換およびインペインティング変換において、より段階的なノイズ除去強度スケールが適用されます。この設定により、画像に適用される変化量を制御する能力が向上しますが、標準のノイズ除去強度スケールを使用したい場合はオフにすることができます。この設定は現在調整中で、ベータ版です。" + "「イメージtoイメージを最適化」を有効にすると、Fluxモデルを用いた画像間変換およびインペイント変換において、より段階的な除去ノイズ強度スケールが適用されます。この設定により、画像に適用される変化量を制御する能力が向上しますが、標準の除去ノイズ強度スケールを使用したい場合はオフにすることができます。この設定は現在調整中で、ベータ版です。" ] }, "refinerPositiveAestheticScore": { @@ -1756,8 +1798,8 @@ "refinerModel": { "heading": "リファイナーモデル", "paragraphs": [ - "生成プロセスの精製部分で使用されるモデル。", - "世代モデルに似ています。" + "生成プロセスのリファイナー部分で使用されるモデル。", + "生成モデルに似ています。" ] }, "refinerCfgScale": { @@ -1833,7 +1875,7 @@ "tileOverlap": { "heading": "タイルオーバーラップ", "paragraphs": [ - "アップスケーリング時の隣接するタイルの重なり具合を制御します。重なり具合の値を大きくするとタイル間の継ぎ目が見えにくくなりますが、メモリ使用量は増加します。", + "アップスケール時の隣接するタイルの重なり具合を制御します。重なり具合の値を大きくするとタイル間の継ぎ目が見えにくくなりますが、メモリ使用量は増加します。", "デフォルト値の 128 はほとんどの場合に適していますが、特定のニーズやメモリの制約に基づいて調整できます。" ] } @@ -1881,8 +1923,8 @@ "resultTitle": "プロンプト拡張完了", "resultSubtitle": "拡張プロンプトの処理方法を選択します:", "insert": "挿入", - "noPromptHistory": "プロンプト履歴が記録されていません。", - "noMatchingPrompts": "履歴にマッチするプロンプトがありません。", + "noPromptHistory": "プロンプトヒストリーが記録されていません。", + "noMatchingPrompts": "マッチするプロンプトがヒストリーにありません。", "toSwitchBetweenPrompts": "プロンプトを切り替えます。" }, "ui": { @@ -1894,7 +1936,7 @@ "gallery": "ギャラリー", "workflowsTab": "$t(ui.tabs.workflows) $t(common.tab)", "modelsTab": "$t(ui.tabs.models) $t(common.tab)", - "upscaling": "アップスケーリング", + "upscaling": "アップスケール", "upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)", "generate": "生成" }, @@ -1904,7 +1946,7 @@ "scale": "スケール", "helpText": { "promptAdvice": "アップスケールする際は、媒体とスタイルを説明するプロンプトを使用してください。画像内の具体的なコンテンツの詳細を説明することは避けてください。", - "styleAdvice": "アップスケーリングは、画像の全体的なスタイルに最適です。" + "styleAdvice": "アップスケールは、画像の全体的なスタイルに最適です。" }, "uploadImage": { "title": "アップスケール用の画像をアップロードする", @@ -1957,19 +1999,19 @@ "browseAndLoadWorkflows": "既存のワークフローを参照して読み込む", "addStyleRef": { "title": "スタイル参照を追加する", - "description": "画像を追加して外観を転送します。" + "description": "外観を参照するための画像を追加しましょう。" }, "editImage": { "title": "画像を編集", - "description": "絞り込むために画像を追加します。" + "description": "リファインする画像を追加しましょう。" }, "generateFromText": { "title": "テキストから生成", - "description": "プロンプトを入力して呼び出します。" + "description": "プロンプトを入力して生成しましょう。" }, "useALayoutImage": { "title": "レイアウト画像を使用", - "description": "構成を制御するために画像を追加します。" + "description": "構図を制御するための画像を追加しましょう。" }, "generate": { "canvasCalloutTitle": "画像をさらに細かく制御、編集、反復したいですか?", @@ -1997,13 +2039,13 @@ "canvasGroup": "キャンバス", "saveToGalleryGroup": "ギャラリーに保存", "saveCanvasToGallery": "キャンバスをギャラリーに保存", - "saveBboxToGallery": "Bボックスをギャラリーに保存", + "saveBboxToGallery": "バウンディングボックスをギャラリーに保存", "newControlLayer": "新規コントロールレイヤー", "newRasterLayer": "新規ラスターレイヤー", "newInpaintMask": "新規インペイントマスク", "copyToClipboard": "クリップボードにコピー", "copyCanvasToClipboard": "キャンバスをクリップボードにコピー", - "copyBboxToClipboard": "Bボックスをクリップボードにコピー", + "copyBboxToClipboard": "バウンディングボックスをクリップボードにコピー", "newResizedControlLayer": "新しくサイズ変更されたコントロールレイヤー" }, "regionalGuidance": "領域ガイダンス", @@ -2030,7 +2072,7 @@ "rectangle": "矩形", "move": "移動", "eraser": "消しゴム", - "bbox": "Bbox", + "bbox": "バウンディングボックス", "view": "ビュー" }, "saveCanvasToGallery": "キャンバスをギャラリーに保存", @@ -2064,7 +2106,7 @@ "label": "グリッドにスナップ" }, "preserveMask": { - "label": "マスクされた領域を保持", + "label": "マスクされた領域を保護", "alert": "マスクされた領域の保存" }, "isolatedStagingPreview": "分離されたステージングプレビュー", @@ -2072,10 +2114,10 
@@ "isolatedLayerPreview": "分離されたレイヤーのプレビュー", "isolatedLayerPreviewDesc": "フィルタリングや変換などの操作を実行するときに、このレイヤーのみを表示するかどうか。", "invertBrushSizeScrollDirection": "ブラシサイズのスクロール反転", - "pressureSensitivity": "圧力感度", + "pressureSensitivity": "筆圧検知", "saveAllImagesToGallery": { - "label": "ギャラリーに新しい生成を送る", - "alert": "キャンバスを経由せず、ギャラリーに新しい生成を送り込む" + "label": "ギャラリーに新しい生成画像を送る", + "alert": "キャンバスを経由せず、ギャラリーに新しい生成を送る" } }, "filter": { @@ -2093,14 +2135,14 @@ "cancel": "キャンセル", "filters": "フィルター", "filterType": "フィルタータイプ", - "autoProcess": "オートプロセス", + "autoProcess": "自動で実行", "process": "プロセス", - "advanced": "アドバンスド", + "advanced": "詳細設定", "processingLayerWith": "{{type}} フィルターを使用した処理レイヤー。", "forMoreControl": "さらに細かく制御するには、以下の「詳細設定」をクリックしてください。", "canny_edge_detection": { - "label": "キャニーエッジ検出", - "description": "Canny エッジ検出アルゴリズムを使用して、選択したレイヤーからエッジ マップを生成します。", + "label": "エッジ検出(Canny)", + "description": "Canny エッジ検出アルゴリズムを使用して、選択したレイヤーから線画を生成します。", "low_threshold": "低閾値", "high_threshold": "高閾値" }, @@ -2115,8 +2157,8 @@ "scale_factor": "スケール係数" }, "depth_anything_depth_estimation": { - "label": "デプスエニシング", - "description": "デプスエニシングモデルを使用して、選択したレイヤーから深度マップを生成します。", + "label": "深度抽出(Depth Anything)", + "description": "Depth Anthingモデルを使用して、選択したレイヤーから深度マップを生成します。", "model_size": "モデルサイズ", "model_size_small": "スモール", "model_size_small_v2": "スモールv2", @@ -2124,50 +2166,50 @@ "model_size_large": "ラージ" }, "dw_openpose_detection": { - "label": "DW オープンポーズ検出", + "label": "ポーズ検出(DW Openpose)", "description": "DW Openpose モデルを使用して、選択したレイヤー内の人間のポーズを検出します。", "draw_hands": "手を描く", "draw_face": "顔を描く", "draw_body": "体を描く" }, "hed_edge_detection": { - "label": "HEDエッジ検出", - "description": "HED エッジ検出モデルを使用して、選択したレイヤーからエッジ マップを生成します。", + "label": "エッジ検出(HED)", + "description": "HED エッジ検出モデルを使用して、選択したレイヤーから線画を生成します。", "scribble": "落書き" }, "lineart_anime_edge_detection": { - "label": "線画アニメのエッジ検出", - "description": "線画アニメエッジ検出モデルを使用して、選択したレイヤーからエッジ マップを生成します。" + "label": "エッジ検出(Lineart Anime)", + "description": "Lineart Animeエッジ検出モデルを使用して、選択したレイヤーから線画を生成します。" }, "lineart_edge_detection": { - "label": "線画エッジ検出", - "description": "線画エッジ検出モデルを使用して、選択したレイヤーからエッジ マップを生成します。", - "coarse": "粗い" + "label": "エッジ検出(Lineart)", + "description": "Linartエッジ検出モデルを使用して、選択したレイヤーから線画を生成します。", + "coarse": "粗く" }, "mediapipe_face_detection": { - "label": "メディアパイプ顔検出", - "description": "メディアパイプ顔検出モデルを使用して、選択したレイヤー内の顔を検出します。", - "max_faces": "マックスフェイス", + "label": "顔検出(MediaPipe)", + "description": "MediaPipe顔検出モデルを使用して、選択したレイヤー内の顔を検出します。", + "max_faces": "最大顔数", "min_confidence": "最小信頼度" }, "mlsd_detection": { - "label": "線分検出", - "description": "MLSD 線分検出モデルを使用して、選択したレイヤーから線分マップを生成します。", + "label": "直線検出(MLSD)", + "description": "MLSD 線分検出モデルを使用して、選択したレイヤーから直線部分を抽出します。", "score_threshold": "スコア閾値", "distance_threshold": "距離閾値" }, "normal_map": { - "label": "ノーマルマップ", + "label": "ノーマルマップ推定", "description": "選択したレイヤーからノーマルマップを生成します。" }, "pidi_edge_detection": { - "label": "PiDiNetエッジ検出", - "description": "PiDiNet エッジ検出モデルを使用して、選択したレイヤーからエッジ マップを生成します。", + "label": "エッジ検出(PiDiNet)", + "description": "PiDiNet エッジ検出モデルを使用して、選択したレイヤーから線画を生成します。", "scribble": "落書き", "quantize_edges": "エッジを量子化する" }, "img_blur": { - "label": "画像をぼかす", + "label": "ぼかし", "description": "選択したレイヤーをぼかします。", "blur_type": "ぼかしの種類", "blur_radius": "半径", @@ -2175,7 +2217,7 @@ "box_type": "ボックス" }, "img_noise": { - "label": "ノイズ画像", + "label": "ノイズ", "description": "選択したレイヤーにノイズを追加します。", "noise_type": "ノイズの種類", "noise_amount": "総計", @@ -2219,26 +2261,26 @@ 
"newGlobalReferenceImageError": "グローバル参照イメージの作成中に問題が発生しました", "newRegionalReferenceImageOk": "地域参照画像の作成", "newRegionalReferenceImageError": "地域参照画像の作成中に問題が発生しました", - "newControlLayerOk": "制御レイヤーの作成", + "newControlLayerOk": "作成されたコントロールレイヤー", "newControlLayerError": "制御層の作成中に問題が発生しました", "newRasterLayerOk": "ラスターレイヤーを作成しました", "newRasterLayerError": "ラスターレイヤーの作成中に問題が発生しました", - "pullBboxIntoLayerOk": "Bbox をレイヤーにプル", - "pullBboxIntoLayerError": "BBox をレイヤーにプルする際に問題が発生しました", - "pullBboxIntoReferenceImageOk": "Bbox が ReferenceImage にプルされました", - "pullBboxIntoReferenceImageError": "BBox を ReferenceImage にプルする際に問題が発生しました", + "pullBboxIntoLayerOk": "バウンディングボックスをレイヤーに", + "pullBboxIntoLayerError": "バウンディングボックスをレイヤーにする際に問題が発生しました", + "pullBboxIntoReferenceImageOk": "バウンディングボックスが参照画像にされました", + "pullBboxIntoReferenceImageError": "バウンディングボックスを参照画像にする際に問題が発生しました", "regionIsEmpty": "選択した領域は空です", "mergeVisible": "マージを可視化", "mergeVisibleOk": "マージされたレイヤー", "mergeVisibleError": "レイヤーの結合エラー", "mergingLayers": "レイヤーのマージ", "clearHistory": "履歴をクリア", - "bboxOverlay": "Bboxオーバーレイを表示", + "bboxOverlay": "バウンディングボックスのオーバーレイを表示", "ruleOfThirds": "三分割法を表示", "newSession": "新しいセッション", "clearCaches": "キャッシュをクリア", "recalculateRects": "長方形を再計算する", - "clipToBbox": "ストロークをBboxにクリップ", + "clipToBbox": "ストロークをバウンディングボックス内に制限", "outputOnlyMaskedRegions": "生成された領域のみを出力する", "width": "幅", "autoNegative": "オートネガティブ", @@ -2284,13 +2326,13 @@ "pasteTo": "貼り付け先", "pasteToAssets": "アセット", "pasteToAssetsDesc": "アセットに貼り付け", - "pasteToBbox": "Bボックス", - "pasteToBboxDesc": "新しいレイヤー(Bbox内)", + "pasteToBbox": "バウンディングボックス", + "pasteToBboxDesc": "新しいレイヤー(バウンディングボックス内)", "pasteToCanvas": "キャンバス", "pasteToCanvasDesc": "新しいレイヤー(キャンバス内)", - "transparency": "透明性", - "enableTransparencyEffect": "透明効果を有効にする", - "disableTransparencyEffect": "透明効果を無効にする", + "transparency": "透過表示", + "enableTransparencyEffect": "透過表示を有効にする", + "disableTransparencyEffect": "透過表示を無効にする", "hidingType": "{{type}} を非表示", "showingType": "{{type}}を表示", "showNonRasterLayers": "非ラスターレイヤーを表示 (Shift+H)", @@ -2301,24 +2343,24 @@ "unlocked": "ロック解除", "deleteSelected": "選択項目を削除", "replaceLayer": "レイヤーの置き換え", - "pullBboxIntoLayer": "Bboxをレイヤーに引き込む", - "pullBboxIntoReferenceImage": "Bboxを参照画像に取り込む", + "pullBboxIntoLayer": "バウンディングボックスをレイヤーに", + "pullBboxIntoReferenceImage": "バウンディングボックスを参照画像に", "showProgressOnCanvas": "キャンバスに進捗状況を表示", "useImage": "画像を使う", "negativePrompt": "ネガティブプロンプト", "beginEndStepPercentShort": "開始/終了 %", - "resetCanvasLayers": "キャンバスレイヤーをリセット", + "resetCanvasLayers": "キャンバスとレイヤーをリセット", "resetGenerationSettings": "生成設定をリセット", - "controlLayerEmptyState": "画像をアップロード、ギャラリーからこのレイヤーに画像をドラッグ、境界ボックスをこのレイヤーにプル、またはキャンバスに描画して開始します。", - "referenceImageEmptyStateWithCanvasOptions": "開始するには、画像をアップロードするか、ギャラリーからこの参照画像に画像をドラッグするか、境界ボックスをこの参照画像に引き込みます。", + "controlLayerEmptyState": "画像をアップロード、ギャラリーからこのレイヤーに画像をドラッグ、バウンディングボックスをこのレイヤーにする、またはキャンバスに描画して開始します。", + "referenceImageEmptyStateWithCanvasOptions": "開始するには、画像をアップロードするか、ギャラリーからこの参照画像に画像をドラッグするか、バウンディングボックスをこの参照画像にします。", "referenceImageEmptyState": "開始するには、画像をアップロードするか、ギャラリーからこの参照画像に画像をドラッグします。", "imageNoise": "画像ノイズ", "denoiseLimit": "ノイズ除去制限", "warnings": { "problemsFound": "問題が見つかりました", "unsupportedModel": "選択したベースモデルではレイヤーがサポートされていません", - "controlAdapterNoModelSelected": "制御レイヤーモデルが選択されていません", - "controlAdapterIncompatibleBaseModel": "互換性のない制御レイヤーベースモデル", + "controlAdapterNoModelSelected": "コントロールレイヤーのモデルが選択されていません", + "controlAdapterIncompatibleBaseModel": "コントロールレイヤーのベースモデルに互換性がありません", "controlAdapterNoControl": 
"コントロールが選択/描画されていません", "ipAdapterNoModelSelected": "参照画像モデルが選択されていません", "ipAdapterIncompatibleBaseModel": "互換性のない参照画像ベースモデル", @@ -2329,7 +2371,7 @@ "rgAutoNegativeNotSupported": "選択したベースモデルでは自動否定はサポートされていません", "rgNoRegion": "領域が描画されていません", "fluxFillIncompatibleWithControlLoRA": "コントロールLoRAはFLUX Fillと互換性がありません", - "bboxHidden": "境界ボックスは非表示です(Shift+O で切り替えます)" + "bboxHidden": "バウンディングボックスは非表示です(Shift+O で切り替え)" }, "errors": { "unableToFindImage": "画像が見つかりません", @@ -2370,7 +2412,7 @@ }, "selectObject": { "selectObject": "オブジェクトを選択", - "pointType": "ポイントタイプ", + "pointType": "点タイプ", "invertSelection": "選択範囲を反転", "include": "含む", "exclude": "除外", @@ -2384,7 +2426,7 @@ "dragToMove": "ポイントをドラッグして移動します", "clickToRemove": "ポイントをクリックして削除します", "desc": "対象オブジェクトを1つ選択します。選択が完了したら、適用 をクリックして選択範囲外のすべてを削除するか、選択範囲を新しいレイヤーとして保存します。", - "visualModeDesc": "ビジュアル モードでは、ボックスとポイントの入力を使用してオブジェクトを選択します。", + "visualModeDesc": "ビジュアル モードでは、ボックスと点の入力を使用してオブジェクトを選択します。", "visualMode1": "クリック&ドラッグして、選択したいオブジェクトの周囲にボックスを描きます。オブジェクトより少し大きいか小さいボックスを描くと、より良い結果が得られる場合があります。", "visualMode2": "クリックして緑の include ポイントを追加するか、Shift キーを押しながらクリックして赤の exclude ポイントを追加し、モデルに含める内容と除外する内容を指示します。", "visualMode3": "ポイントは、ボックスの選択を絞り込むために使用することも、独立して使用することもできます。", @@ -2392,13 +2434,13 @@ "promptMode1": "選択するオブジェクトの簡単な説明を入力します。", "promptMode2": "複雑な説明や複数のオブジェクトを避け、簡単な言葉を使用してください。", "model": "モデル", - "segmentAnything1": "何でもセグメント1", - "segmentAnything2": "何でもセグメント2", + "segmentAnything1": "Segment Anything 1", + "segmentAnything2": "Segment Anything 2", "prompt": "プロンプト選択" }, "HUD": { - "bbox": "Bボックス", - "scaledBbox": "スケールされたBボックス", + "bbox": "バウンディングボックス", + "scaledBbox": "スケールされたバウンディングボックス", "entityStatus": { "isFiltering": "{{title}} はフィルタリング中です", "isTransforming": "{{title}}は変化しています", @@ -2418,20 +2460,20 @@ "showResultsOn": "結果を表示", "showResultsOff": "結果を隠す" }, - "fitBboxToMasks": "Bboxをマスクにフィットさせる", + "fitBboxToMasks": "バウンディングボックスをマスクにフィットさせる", "addAdjustments": "調整を追加", "removeAdjustments": "調整を削除", "adjustments": { "simple": "シンプル", - "curves": "曲線", + "curves": "カーブ", "heading": "調整", "expand": "調整を拡張", "collapse": "折りたたみ調整", "brightness": "輝度", "contrast": "コントラスト", - "saturation": "飽和", - "temperature": "温度", - "tint": "色合い", + "saturation": "彩度", + "temperature": "色温度", + "tint": "色相", "sharpness": "シャープネス", "finish": "終了", "reset": "リセット", @@ -2475,7 +2517,8 @@ "off": "オフ", "switchOnStart": "開始時", "switchOnFinish": "終了時" - } + }, + "extractRegion": "領域を抽出" }, "stylePresets": { "clearTemplateSelection": "選択したテンプレートをクリア", @@ -2541,18 +2584,18 @@ "missingUpscaleInitialImage": "アップスケール用の初期画像がありません", "missingUpscaleModel": "アップスケールモデルがありません", "missingTileControlNetModel": "有効なタイル コントロールネットモデルがインストールされていません", - "incompatibleBaseModel": "アップスケーリングにサポートされていないメインモデルアーキテクチャです", - "incompatibleBaseModelDesc": "アップスケーリングはSD1.5およびSDXLアーキテクチャモデルでのみサポートされています。アップスケーリングを有効にするには、メインモデルを変更してください。", + "incompatibleBaseModel": "アップスケールにサポートされていないメインモデルアーキテクチャです", + "incompatibleBaseModelDesc": "アップスケールはSD1.5およびSDXLアーキテクチャモデルでのみサポートされています。アップスケールを有効にするには、メインモデルを変更してください。", "tileControl": "タイルコントロール", "tileSize": "タイルサイズ", "tileOverlap": "タイルオーバーラップ" }, "sdxl": { - "denoisingStrength": "ノイズ除去強度", + "denoisingStrength": "除去ノイズ強度", "scheduler": "スケジューラー", "loading": "ロード中...", "steps": "ステップ", - "refiner": "Refiner", + "refiner": "リファイナー", "noModelsAvailable": "利用できるモデルがありません", "cfgScale": "CFGスケール", "posAestheticScore": "ポジティブ美的スコア", @@ -2594,7 +2637,7 @@ "builder": "フォームビルダー", "text": "テキスト", "row": "行", - "multiLine": "マルチライン", + "multiLine": 
"テキスト(複数行)", "resetAllNodeFields": "すべてのノードフィールドをリセット", "slider": "スライダー", "layout": "レイアウト", @@ -2604,7 +2647,7 @@ "component": "コンポーネント", "textPlaceholder": "空のテキスト", "addOption": "オプションを追加", - "singleLine": "単線", + "singleLine": "テキスト", "numberInput": "数値入力", "column": "列", "container": "コンテナ", @@ -2682,7 +2725,7 @@ "delete": "削除", "loadMore": "もっと読み込む", "saveWorkflowToProject": "ワークフローをプロジェクトに保存", - "created": "作成されました", + "created": "作成順", "workflowEditorMenu": "ワークフローエディターメニュー", "recentlyOpened": "最近開いた", "opened": "オープン", @@ -2736,9 +2779,9 @@ "seedBehaviour": { "label": "シードの挙動", "perPromptLabel": "画像ごとのシード", - "perIterationLabel": "いてレーションごとのシード", + "perIterationLabel": "イテレーションごとのシード", "perPromptDesc": "それぞれの画像に足して別のシードを使う", - "perIterationDesc": "それぞれのいてレーションに別のシードを使う" + "perIterationDesc": "それぞれのイテレーションに別のシードを使う" }, "showDynamicPrompts": "ダイナミックプロンプトを表示する", "dynamicPrompts": "ダイナミックプロンプト", @@ -2758,7 +2801,7 @@ "whatsNewInInvoke": "Invokeの新機能", "items": [ "オブジェクトの選択 v2: ポイントおよびボックス入力またはテキスト プロンプトによるオブジェクト選択が改善されました。", - "ラスター レイヤーの調整: レイヤーの明るさ、コントラスト、彩度、曲線などを簡単に調整できます。" + "ラスター レイヤーの調整: レイヤーの明度、コントラスト、彩度、カーブなどを簡単に調整できます。" ], "readReleaseNotes": "リリースノートを読む", "watchRecentReleaseVideos": "最近のリリースビデオを見る", @@ -2782,5 +2825,100 @@ }, "lora": { "weight": "重み" + }, + "auth": { + "login": { + "title": "Invokeにサインイン", + "email": "Eメール", + "emailPlaceholder": "Eメール", + "password": "パスワード", + "passwordPlaceholder": "パスワード", + "rememberMe": "7日間は記憶", + "signIn": "サインイン", + "signingIn": "サインイン中...", + "loginFailed": "ログインに失敗しました。正しい内容かを確認してください。" + }, + "setup": { + "title": "Invokeへようこそ", + "subtitle": "管理者アカウントをセットアップします", + "email": "Eメール", + "emailPlaceholder": "hoge@example.com", + "emailHelper": "これはサインインに使うユーザー名になります", + "displayName": "表示名", + "displayNamePlaceholder": "管理者", + "displayNameHelper": "アプリケーションの中で表示される名前です", + "password": "パスワード", + "passwordPlaceholder": "パスワード", + "passwordHelper": "大文字、小文字、数字を組み合わせた8文字以上", + "passwordTooShort": "パスワードは8文字以上である必要があります", + "passwordMissingRequirements": "パスワードは小文字、大文字、数字を含まなければなりません", + "confirmPassword": "パスワードの確認", + "confirmPasswordPlaceholder": "パスワードの確認", + "passwordsDoNotMatch": "パスワードが一致しません", + "createAccount": "管理者アカウントを作る", + "creatingAccount": "設定中...", + "setupFailed": "セットアップに失敗しました。もう一度試してください。", + "passwordHelperRelaxed": "パスワードを入力してください(強度が表示されます)" + }, + "userMenu": "ユーザーメニュー", + "admin": "管理", + "logout": "ログアウト", + "adminOnlyFeature": "この機能は管理者のみ使用できます。", + "profile": { + "menuItem": "プロフィール", + "title": "プロフィール", + "email": "Eメール", + "emailReadOnly": "Eメールアドレスは変更できません", + "displayName": "表示名", + "displayNamePlaceholder": "あなたの名前", + "changePassword": "パスワードの変更", + "currentPassword": "現在のパスワード", + "currentPasswordPlaceholder": "現在のパスワード", + "newPassword": "新しいパスワード", + "newPasswordPlaceholder": "新しいパスワード", + "confirmPassword": "新しいパスワードの確認", + "confirmPasswordPlaceholder": "新しいパスワードの確認", + "passwordsDoNotMatch": "パスワードが一致しません", + "saveSuccess": "プロフィールのアップデートに成功しました", + "saveFailed": "プロフィールの保存に失敗しました。もう一度試してください。" + }, + "userManagement": { + "menuItem": "ユーザー管理", + "title": "ユーザー管理", + "email": "Eメール", + "emailPlaceholder": "hoge@example.com", + "displayName": "表示名", + "displayNamePlaceholder": "表示名", + "password": "パスワード", + "passwordPlaceholder": "パスワード", + "newPassword": "新しいパスワード", + "newPasswordPlaceholder": "現在のパスワードを維持するには空白にしておいてください", + "role": "ロール", + "status": "ステータス", + "actions": "アクション", + "isAdmin": "管理者", + "user": "ユーザー", + "you": "あなた", + "createUser": "ユーザーの作成", + "editUser": "ユーザーの編集", + 
"deleteUser": "ユーザーの削除", + "deleteConfirm": "本当に \"{{name}}\" を削除しますか?このアクションは取り消せません。", + "generatePassword": "強力なパスワードを生成", + "showPassword": "パスワードの表示", + "hidePassword": "パスワードを隠す", + "activate": "有効化", + "deactivate": "非有効化", + "saveFailed": "ユーザーの保存に失敗しました。もう一度実行してください。", + "deleteFailed": "ユーザーの削除に失敗しました。もう一度実行してください。", + "loadFailed": "ユーザーのロードに失敗しました。", + "back": "戻る", + "cannotDeleteSelf": "あなた自身のアカウントを削除することはできません", + "cannotDeactivateSelf": "あなた自身のアカウントを非有効化することはできません" + }, + "passwordStrength": { + "weak": "弱いパスワード", + "moderate": "適切なパスワード", + "strong": "強力なパスワード" + } } } diff --git a/invokeai/frontend/web/scripts/typegen.js b/invokeai/frontend/web/scripts/typegen.js index 7b3a747285..87c00a2883 100644 --- a/invokeai/frontend/web/scripts/typegen.js +++ b/invokeai/frontend/web/scripts/typegen.js @@ -34,7 +34,37 @@ async function generateTypes(schema) { }, defaultNonNullable: false, }); - fs.writeFileSync(OUTPUT_FILE, astToString(types)); + let output = astToString(types); + + // Post-process: openapi-typescript sometimes computes enum types from `const` + // usage in discriminated unions rather than from the enum definition itself, + // dropping values that only appear in some union members. Patch the generated + // output to match the OpenAPI schema's actual enum definitions. + // + // The `schema` parameter is a parsed JSON object when piped from stdin, or + // a URL/Buffer when passed as an argument. We only patch in the JSON case. + if (schema && typeof schema === 'object' && !Buffer.isBuffer(schema)) { + const schemas = schema.components?.schemas; + if (schemas) { + // Collect all string enum types and their expected values from the OpenAPI schema + for (const [typeName, typeDef] of Object.entries(schemas)) { + if (typeDef && typeDef.type === 'string' && Array.isArray(typeDef.enum)) { + const expectedUnion = typeDef.enum.map((v) => `"${v}"`).join(' | '); + // Match the type definition line. These appear as: + // `TypeName: "val1" | "val2" | ...;` + // Use word boundary to avoid matching types that contain this + // type name as a substring (e.g. ModelType vs BaseModelType). 
+          const regex = new RegExp(`(\\b${typeName}: )"[^;]+(;)`);
+          const match = output.match(regex);
+          if (match) {
+            output = output.replace(regex, `$1${expectedUnion}$2`);
+          }
+        }
+      }
+    }
+  }
+
+  fs.writeFileSync(OUTPUT_FILE, output);
   process.stdout.write(`\nOK!\r\n`);
 }
diff --git a/invokeai/frontend/web/src/app/components/GlobalModalIsolator.tsx b/invokeai/frontend/web/src/app/components/GlobalModalIsolator.tsx
index 5c1446662e..e5ec5ccc56 100644
--- a/invokeai/frontend/web/src/app/components/GlobalModalIsolator.tsx
+++ b/invokeai/frontend/web/src/app/components/GlobalModalIsolator.tsx
@@ -1,6 +1,9 @@
 import { GlobalImageHotkeys } from 'app/components/GlobalImageHotkeys';
 import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
 import { CanvasPasteModal } from 'features/controlLayers/components/CanvasPasteModal';
+import { CanvasWorkflowIntegrationModal } from 'features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationModal';
+import { LoadCanvasProjectConfirmationAlertDialog } from 'features/controlLayers/components/LoadCanvasProjectConfirmationAlertDialog';
+import { SaveCanvasProjectDialog } from 'features/controlLayers/components/SaveCanvasProjectDialog';
 import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
 import { CropImageModal } from 'features/cropper/components/CropImageModal';
 import { DeleteImageModal } from 'features/deleteImageModal/components/DeleteImageModal';
@@ -51,7 +54,10 @@ export const GlobalModalIsolator = memo(() => {
+      <CanvasWorkflowIntegrationModal />
+      <LoadCanvasProjectConfirmationAlertDialog />
+      <SaveCanvasProjectDialog />
diff --git a/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx b/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx
index 437cecd492..62b7114288 100644
--- a/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx
+++ b/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx
@@ -2,6 +2,7 @@ import '@fontsource-variable/inter';
 import 'overlayscrollbars/overlayscrollbars.css';
 import '@xyflow/react/dist/base.css';
 import 'common/components/OverlayScrollbars/overlayscrollbars.css';
+import 'app/components/touchDevice.css';

 import { ChakraProvider, DarkMode, extendTheme, theme as baseTheme, TOAST_OPTIONS } from '@invoke-ai/ui-library';
 import { useStore } from '@nanostores/react';
diff --git a/invokeai/frontend/web/src/app/components/touchDevice.css b/invokeai/frontend/web/src/app/components/touchDevice.css
new file mode 100644
index 0000000000..4753e2a9a8
--- /dev/null
+++ b/invokeai/frontend/web/src/app/components/touchDevice.css
@@ -0,0 +1,7 @@
+/* Hide tooltips on touch devices where hover gets "stuck" */
+@media (hover: none) {
+  [role='tooltip'] {
+    visibility: hidden !important;
+    opacity: 0 !important;
+  }
+}
diff --git a/invokeai/frontend/web/src/app/logging/logger.ts b/invokeai/frontend/web/src/app/logging/logger.ts
index 6c843068df..d20ef77090 100644
--- a/invokeai/frontend/web/src/app/logging/logger.ts
+++ b/invokeai/frontend/web/src/app/logging/logger.ts
@@ -16,6 +16,7 @@ const $logger = atom(Roarr.child(BASE_CONTEXT));
 export const zLogNamespace = z.enum([
   'canvas',
+  'canvas-workflow-integration',
   'config',
   'dnd',
   'events',
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx
index e0e72d12ff..fa4c29b8f4 100644
---
a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx
+++
b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx @@ -12,10 +12,14 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) = effect: (action) => { log.debug(action.payload, 'Bulk download requested'); - // If we have an item name, we are processing the bulk download locally and should use it as the toast id to - // prevent multiple toasts for the same item. + // Use a "preparing:" prefix so this toast cannot collide with the + // "ready to download" toast that arrives via the bulk_download_complete + // socket event. The background task can complete in under 20ms, so the + // socket event may arrive *before* this Redux middleware runs — without + // distinct IDs the "preparing" toast would overwrite the "ready" toast. + const itemName = action.payload.bulk_download_item_name; toast({ - id: action.payload.bulk_download_item_name ?? undefined, + id: itemName ? `preparing:${itemName}` : undefined, title: t('gallery.bulkDownloadRequested'), status: 'success', // Show the response message if it exists, otherwise show the default message diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts index ed2c67d529..1c7941106b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts @@ -4,9 +4,15 @@ import { bboxSyncedToOptimalDimension, rgRefImageModelChanged } from 'features/c import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice'; import { loraIsEnabledChanged } from 'features/controlLayers/store/lorasSlice'; import { + animaQwen3EncoderModelSelected, + animaT5EncoderModelSelected, + animaVaeModelSelected, + aspectRatioIdChanged, kleinQwen3EncoderModelSelected, kleinVaeModelSelected, modelChanged, + qwenImageComponentSourceSelected, + resolutionPresetSelected, setZImageScheduler, syncedToOptimalDimension, vaeSelected, @@ -24,12 +30,18 @@ import { selectBboxModelBase, selectCanvasSlice, } from 'features/controlLayers/store/selectors'; -import { getEntityIdentifier, isFlux2ReferenceImageConfig } from 'features/controlLayers/store/types'; +import { + getEntityIdentifier, + isAspectRatioID, + isFlux2ReferenceImageConfig, + isQwenImageReferenceImageConfig, +} from 'features/controlLayers/store/types'; import { initialFlux2ReferenceImage, initialFluxKontextReferenceImage, initialFLUXRedux, initialIPAdapter, + initialQwenImageReferenceImage, } from 'features/controlLayers/store/util'; import { SUPPORTS_REF_IMAGES_BASE_MODELS } from 'features/modelManagerV2/models'; import { zModelIdentifierField } from 'features/nodes/types/common'; @@ -39,14 +51,18 @@ import { toast } from 'features/toast/toast'; import { t } from 'i18next'; import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models'; import { + selectAnimaQwen3EncoderModels, + selectAnimaVAEModels, selectFluxVAEModels, selectGlobalRefImageModels, selectQwen3EncoderModels, + selectQwenImageDiffusersModels, selectRegionalRefImageModels, + selectT5EncoderModels, selectZImageDiffusersModels, } from 'services/api/hooks/modelsByType'; import type { FLUXKontextModelConfig, FLUXReduxModelConfig, IPAdapterModelConfig } from 'services/api/types'; -import { isFluxKontextModelConfig, isFluxReduxModelConfig } 
from 'services/api/types';
+import { isExternalApiModelConfig, isFluxKontextModelConfig, isFluxReduxModelConfig } from 'services/api/types';

 const log = logger('models');

@@ -155,6 +171,68 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
       }
     }

+    // handle incompatible Anima models - clear if switching away from anima
+    const { animaVaeModel, animaQwen3EncoderModel, animaT5EncoderModel } = state.params;
+    if (newBase !== 'anima') {
+      if (animaVaeModel) {
+        dispatch(animaVaeModelSelected(null));
+        modelsUpdatedDisabledOrCleared += 1;
+      }
+      if (animaQwen3EncoderModel) {
+        dispatch(animaQwen3EncoderModelSelected(null));
+        modelsUpdatedDisabledOrCleared += 1;
+      }
+      if (animaT5EncoderModel) {
+        dispatch(animaT5EncoderModelSelected(null));
+        modelsUpdatedDisabledOrCleared += 1;
+      }
+    } else {
+      // Switching to Anima - set defaults if no valid configuration exists
+      const hasValidConfig = animaVaeModel && animaQwen3EncoderModel && animaT5EncoderModel;
+
+      if (!hasValidConfig) {
+        const availableQwen3Encoders = selectAnimaQwen3EncoderModels(state);
+        const availableAnimaVAEs = selectAnimaVAEModels(state);
+        const availableT5Encoders = selectT5EncoderModels(state);
+
+        if (availableQwen3Encoders.length > 0 && availableAnimaVAEs.length > 0) {
+          const qwen3Encoder = availableQwen3Encoders[0];
+          const animaVAE = availableAnimaVAEs[0];
+
+          if (qwen3Encoder && !animaQwen3EncoderModel) {
+            dispatch(
+              animaQwen3EncoderModelSelected({
+                key: qwen3Encoder.key,
+                name: qwen3Encoder.name,
+                base: qwen3Encoder.base,
+              })
+            );
+          }
+          if (animaVAE && !animaVaeModel) {
+            dispatch(
+              animaVaeModelSelected({
+                key: animaVAE.key,
+                hash: animaVAE.hash,
+                name: animaVAE.name,
+                base: animaVAE.base,
+                type: animaVAE.type,
+              })
+            );
+          }
+          const t5Encoder = availableT5Encoders[0];
+          if (t5Encoder && !animaT5EncoderModel) {
+            dispatch(
+              animaT5EncoderModelSelected({
+                key: t5Encoder.key,
+                name: t5Encoder.name,
+                base: t5Encoder.base,
+              })
+            );
+          }
+        }
+      }
+    }
+
     // handle incompatible FLUX.2 Klein models - clear if switching away from flux2
     const { kleinVaeModel, kleinQwen3EncoderModel } = state.params;
     if (newBase !== 'flux2') {
@@ -168,6 +246,44 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
       }
     }

+    // handle incompatible Qwen Image Edit component source - clear if switching away
+    const { qwenImageComponentSource } = state.params;
+    if (newBase !== 'qwen-image') {
+      if (qwenImageComponentSource) {
+        dispatch(qwenImageComponentSourceSelected(null));
+        modelsUpdatedDisabledOrCleared += 1;
+      }
+    } else {
+      // Switching to Qwen Image - auto-default component source to a matching diffusers model
+      if (!qwenImageComponentSource) {
+        const availableQwenImageDiffusers = selectQwenImageDiffusersModels(state);
+
+        // Look up the new model's variant to match generate vs edit
+        const modelConfigsResult = selectModelConfigsQuery(state);
+        let selectedVariant: string | null = null;
+        if (modelConfigsResult.data) {
+          const newModelConfig = modelConfigsAdapterSelectors.selectById(modelConfigsResult.data, newModel.key);
+          if (newModelConfig && 'variant' in newModelConfig && typeof newModelConfig.variant === 'string') {
+            selectedVariant = newModelConfig.variant;
+          }
+        }
+
+        // Find a diffusers model matching the variant; if no variant on denoiser, prefer "generate" then "edit"
+        const variantToMatch = selectedVariant ??
'generate'; + const matchingModel = availableQwenImageDiffusers.find( + (m) => 'variant' in m && m.variant === variantToMatch + ); + const fallbackModel = availableQwenImageDiffusers.find( + (m) => 'variant' in m && m.variant !== variantToMatch + ); + const diffusersModel = matchingModel ?? fallbackModel ?? availableQwenImageDiffusers[0]; + + if (diffusersModel) { + dispatch(qwenImageComponentSourceSelected(zModelIdentifierField.parse(diffusersModel))); + } + } + } + if (newModel.base !== 'external' && SUPPORTS_REF_IMAGES_BASE_MODELS.includes(newModel.base)) { // Handle incompatible reference image models - switch to first compatible model, with some smart logic // to choose the best available model based on the new main model. @@ -210,6 +326,20 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) = continue; } + if (newBase === 'qwen-image') { + // Switching TO Qwen Image Edit - convert any non-qwen configs to qwen_image_reference_image + if (!isQwenImageReferenceImageConfig(entity.config)) { + dispatch( + refImageConfigChanged({ + id: entity.id, + config: { ...initialQwenImageReferenceImage }, + }) + ); + modelsUpdatedDisabledOrCleared += 1; + } + continue; + } + if (isFlux2ReferenceImageConfig(entity.config)) { // Switching AWAY from FLUX.2 - convert flux2_reference_image to the appropriate config type let newConfig; @@ -234,6 +364,30 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) = continue; } + if (isQwenImageReferenceImageConfig(entity.config)) { + // Switching AWAY from Qwen Image Edit - convert to the appropriate config type + let newConfig; + if (newGlobalRefImageModel) { + const parsedModel = zModelIdentifierField.parse(newGlobalRefImageModel); + if (newModel.base === 'flux' && newModel.name.toLowerCase().includes('kontext')) { + newConfig = { ...initialFluxKontextReferenceImage, model: parsedModel }; + } else if (newGlobalRefImageModel.type === 'flux_redux') { + newConfig = { ...initialFLUXRedux, model: parsedModel }; + } else { + newConfig = { ...initialIPAdapter, model: parsedModel }; + if (parsedModel.base === 'flux') { + newConfig.clipVisionModel = 'ViT-L'; + } + } + } else { + // No compatible model found - fall back to an empty IP adapter config + newConfig = { ...initialIPAdapter }; + } + dispatch(refImageConfigChanged({ id: entity.id, config: newConfig })); + modelsUpdatedDisabledOrCleared += 1; + continue; + } + // Standard handling for non-flux2 configs const shouldUpdateModel = (entity.config.model && entity.config.model.base !== newBase) || @@ -321,6 +475,32 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) = } } + // Handle Qwen Image model changes within the same base (variant may change between generate/edit) + // Auto-update the component source diffusers model to match the new variant + if ( + newBase === 'qwen-image' && + state.params.model?.base === 'qwen-image' && + newModel.key !== state.params.model?.key + ) { + const modelConfigsResult = selectModelConfigsQuery(state); + if (modelConfigsResult.data) { + const newModelConfig = modelConfigsAdapterSelectors.selectById(modelConfigsResult.data, newModel.key); + const newVariant = + newModelConfig && 'variant' in newModelConfig && typeof newModelConfig.variant === 'string' + ? 
newModelConfig.variant + : 'generate'; + + const availableQwenImageDiffusers = selectQwenImageDiffusersModels(state); + const matchingModel = availableQwenImageDiffusers.find((m) => 'variant' in m && m.variant === newVariant); + const fallbackModel = availableQwenImageDiffusers.find((m) => 'variant' in m && m.variant !== newVariant); + const diffusersModel = matchingModel ?? fallbackModel ?? availableQwenImageDiffusers[0]; + + if (diffusersModel) { + dispatch(qwenImageComponentSourceSelected(zModelIdentifierField.parse(diffusersModel))); + } + } + } + // Handle Z-Image scheduler when switching to Z-Image Base (zbase) model // LCM is not supported for undistilled models, so reset to euler if (newBase === 'z-image' && state.params.zImageScheduler === 'lcm') { @@ -352,6 +532,34 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) = dispatch(bboxSyncedToOptimalDimension()); } } + + // When switching to an external model, sync bbox to the model's first preset dimensions + if (newBase === 'external') { + const modelConfigsResult = selectModelConfigsQuery(getState()); + if (modelConfigsResult.data) { + const newModelConfig = modelConfigsAdapterSelectors.selectById(modelConfigsResult.data, newModel.key); + if (newModelConfig && isExternalApiModelConfig(newModelConfig)) { + const { aspect_ratio_sizes, resolution_presets } = newModelConfig.capabilities; + if (resolution_presets && resolution_presets.length > 0) { + const firstPreset = resolution_presets[0]!; + dispatch( + resolutionPresetSelected({ + imageSize: firstPreset.image_size, + aspectRatio: firstPreset.aspect_ratio, + width: firstPreset.width, + height: firstPreset.height, + }) + ); + } else if (aspect_ratio_sizes) { + const firstRatio = Object.keys(aspect_ratio_sizes)[0]; + const firstSize = firstRatio ? 
aspect_ratio_sizes[firstRatio] : undefined; + if (firstRatio && firstSize && isAspectRatioID(firstRatio)) { + dispatch(aspectRatioIdChanged({ id: firstRatio, fixedSize: firstSize })); + } + } + } + } + } }, }); }; diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index 8f077baaea..f24d2d0105 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -25,6 +25,7 @@ import { canvasSettingsSliceConfig } from 'features/controlLayers/store/canvasSe import { canvasSliceConfig } from 'features/controlLayers/store/canvasSlice'; import { canvasSessionSliceConfig } from 'features/controlLayers/store/canvasStagingAreaSlice'; import { canvasTextSliceConfig } from 'features/controlLayers/store/canvasTextSlice'; +import { canvasWorkflowIntegrationSliceConfig } from 'features/controlLayers/store/canvasWorkflowIntegrationSlice'; import { lorasSliceConfig } from 'features/controlLayers/store/lorasSlice'; import { paramsSliceConfig } from 'features/controlLayers/store/paramsSlice'; import { refImagesSliceConfig } from 'features/controlLayers/store/refImagesSlice'; @@ -67,6 +68,7 @@ const SLICE_CONFIGS = { [canvasSettingsSliceConfig.slice.reducerPath]: canvasSettingsSliceConfig, [canvasTextSliceConfig.slice.reducerPath]: canvasTextSliceConfig, [canvasSliceConfig.slice.reducerPath]: canvasSliceConfig, + [canvasWorkflowIntegrationSliceConfig.slice.reducerPath]: canvasWorkflowIntegrationSliceConfig, [changeBoardModalSliceConfig.slice.reducerPath]: changeBoardModalSliceConfig, [dynamicPromptsSliceConfig.slice.reducerPath]: dynamicPromptsSliceConfig, [gallerySliceConfig.slice.reducerPath]: gallerySliceConfig, @@ -98,6 +100,7 @@ const ALL_REDUCERS = { canvasSliceConfig.slice.reducer, canvasSliceConfig.undoableConfig?.reduxUndoOptions ), + [canvasWorkflowIntegrationSliceConfig.slice.reducerPath]: canvasWorkflowIntegrationSliceConfig.slice.reducer, [changeBoardModalSliceConfig.slice.reducerPath]: changeBoardModalSliceConfig.slice.reducer, [dynamicPromptsSliceConfig.slice.reducerPath]: dynamicPromptsSliceConfig.slice.reducer, [gallerySliceConfig.slice.reducerPath]: gallerySliceConfig.slice.reducer, diff --git a/invokeai/frontend/web/src/common/components/ColorPicker/RgbaColorPicker.tsx b/invokeai/frontend/web/src/common/components/ColorPicker/RgbaColorPicker.tsx index 6bcf546fd3..38fcb2696e 100644 --- a/invokeai/frontend/web/src/common/components/ColorPicker/RgbaColorPicker.tsx +++ b/invokeai/frontend/web/src/common/components/ColorPicker/RgbaColorPicker.tsx @@ -85,7 +85,7 @@ const RgbaColorPicker = (props: Props) => { h={10} whiteSpace="nowrap" onClick={onToggleMode} - aria-label="Toggle RGB/HEX" + aria-label={t('common.toggleRgbHex')} > RGB @@ -144,12 +144,12 @@ const RgbaColorPicker = (props: Props) => { h={10} whiteSpace="nowrap" onClick={onToggleMode} - aria-label="Toggle RGB/HEX" + aria-label={t('common.toggleRgbHex')} > HEX - {t('common.hex', { defaultValue: 'Hex' })} + {t('common.hex')} diff --git a/invokeai/frontend/web/src/common/components/Picker/Picker.tsx b/invokeai/frontend/web/src/common/components/Picker/Picker.tsx index ffd0b30242..b70e44dd64 100644 --- a/invokeai/frontend/web/src/common/components/Picker/Picker.tsx +++ b/invokeai/frontend/web/src/common/components/Picker/Picker.tsx @@ -867,7 +867,7 @@ const GroupToggleButtons = typedMemo(() => { } return ( - + {groups.map((group) => ( ))} @@ -927,6 +927,7 @@ const GroupToggleButton = typedMemo(({ group }: { group: Group size="xs" 
variant="solid" userSelect="none" + flexShrink={0} bg={bg} color={color} borderColor={groupColor} diff --git a/invokeai/frontend/web/src/common/hooks/focus.test.ts b/invokeai/frontend/web/src/common/hooks/focus.test.ts new file mode 100644 index 0000000000..c106fe1cec --- /dev/null +++ b/invokeai/frontend/web/src/common/hooks/focus.test.ts @@ -0,0 +1,13 @@ +import { describe, expect, it } from 'vitest'; + +import { getFocusedRegion, setFocusedRegion } from './focus'; + +describe('focus regions', () => { + it('supports the workflows region', () => { + setFocusedRegion('workflows'); + expect(getFocusedRegion()).toBe('workflows'); + + setFocusedRegion(null); + expect(getFocusedRegion()).toBe(null); + }); +}); diff --git a/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx b/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx index 445f58c9a6..fc173de979 100644 --- a/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx +++ b/invokeai/frontend/web/src/common/hooks/useImageUploadButton.tsx @@ -187,11 +187,12 @@ export const UploadImageIconButton = memo( onUpload?: (imageDTO: ImageDTO) => void; isError?: boolean; } & SetOptional) => { + const { t } = useTranslation(); const uploadApi = useImageUploadButton({ isDisabled, allowMultiple: false, onUpload }); return ( <> { + const { t } = useTranslation(); const { children, isDisabled = false, onUpload, isError = false, ...rest } = props; const uploadApi = useImageUploadButton({ isDisabled, allowMultiple: false, onUpload }); return ( <> + + + + + + + ); +}); + +CanvasWorkflowIntegrationModal.displayName = 'CanvasWorkflowIntegrationModal'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationParameterPanel.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationParameterPanel.tsx new file mode 100644 index 0000000000..f59a6c45ed --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationParameterPanel.tsx @@ -0,0 +1,13 @@ +import { Box } from '@invoke-ai/ui-library'; +import { WorkflowFormPreview } from 'features/controlLayers/components/CanvasWorkflowIntegration/WorkflowFormPreview'; +import { memo } from 'react'; + +export const CanvasWorkflowIntegrationParameterPanel = memo(() => { + return ( + + + + ); +}); + +CanvasWorkflowIntegrationParameterPanel.displayName = 'CanvasWorkflowIntegrationParameterPanel'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationWorkflowSelector.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationWorkflowSelector.tsx new file mode 100644 index 0000000000..30bc60605c --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationWorkflowSelector.tsx @@ -0,0 +1,92 @@ +import { Flex, FormControl, FormLabel, Select, Spinner, Text } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { + canvasWorkflowIntegrationWorkflowSelected, + selectCanvasWorkflowIntegrationSelectedWorkflowId, +} from 'features/controlLayers/store/canvasWorkflowIntegrationSlice'; +import type { ChangeEvent } from 'react'; +import { memo, useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { 
useListWorkflowsInfiniteInfiniteQuery } from 'services/api/endpoints/workflows'; + +import { useFilteredWorkflows } from './useFilteredWorkflows'; + +export const CanvasWorkflowIntegrationWorkflowSelector = memo(() => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + + const selectedWorkflowId = useAppSelector(selectCanvasWorkflowIntegrationSelectedWorkflowId); + const { data: workflowsData, isLoading } = useListWorkflowsInfiniteInfiniteQuery( + { + per_page: 100, // Get a reasonable number of workflows + page: 0, + }, + { + selectFromResult: ({ data, isLoading }) => ({ + data, + isLoading, + }), + } + ); + + const workflows = useMemo(() => { + if (!workflowsData) { + return []; + } + // Flatten all pages into a single list + return workflowsData.pages.flatMap((page) => page.items); + }, [workflowsData]); + + // Filter workflows to only show those with ImageFields + const { filteredWorkflows, isFiltering } = useFilteredWorkflows(workflows); + + const onChange = useCallback( + (e: ChangeEvent) => { + const workflowId = e.target.value || null; + dispatch(canvasWorkflowIntegrationWorkflowSelected({ workflowId })); + }, + [dispatch] + ); + + if (isLoading || isFiltering) { + return ( + + + + {isFiltering + ? t('controlLayers.workflowIntegration.filteringWorkflows') + : t('controlLayers.workflowIntegration.loadingWorkflows')} + + + ); + } + + if (filteredWorkflows.length === 0) { + return ( + + {workflows.length === 0 + ? t('controlLayers.workflowIntegration.noWorkflowsFound') + : t('controlLayers.workflowIntegration.noWorkflowsWithImageField')} + + ); + } + + return ( + + {t('controlLayers.workflowIntegration.selectWorkflow')} + + + ); +}); + +CanvasWorkflowIntegrationWorkflowSelector.displayName = 'CanvasWorkflowIntegrationWorkflowSelector'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/WorkflowFieldRenderer.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/WorkflowFieldRenderer.tsx new file mode 100644 index 0000000000..2d91be13bf --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasWorkflowIntegration/WorkflowFieldRenderer.tsx @@ -0,0 +1,548 @@ +import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; +import { + Combobox, + Flex, + FormControl, + FormLabel, + IconButton, + Input, + Radio, + Select, + Switch, + Text, + Textarea, +} from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; +import { skipToken } from '@reduxjs/toolkit/query'; +import { logger } from 'app/logging/logger'; +import { EMPTY_ARRAY } from 'app/store/constants'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { UploadImageIconButton } from 'common/hooks/useImageUploadButton'; +import { + canvasWorkflowIntegrationFieldValueChanged, + canvasWorkflowIntegrationImageFieldSelected, + selectCanvasWorkflowIntegrationFieldValues, + selectCanvasWorkflowIntegrationSelectedImageFieldKey, + selectCanvasWorkflowIntegrationSelectedWorkflowId, +} from 'features/controlLayers/store/canvasWorkflowIntegrationSlice'; +import { DndImage } from 'features/dnd/DndImage'; +import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox'; +import { $templates } from 'features/nodes/store/nodesSlice'; +import type { NodeFieldElement } from 'features/nodes/types/workflow'; +import { SCHEDULER_OPTIONS } from 'features/parameters/types/constants'; +import { 
isParameterScheduler } from 'features/parameters/types/parameterSchemas'; +import type { ChangeEvent } from 'react'; +import { memo, useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiTrashSimpleBold } from 'react-icons/pi'; +import { useListAllBoardsQuery } from 'services/api/endpoints/boards'; +import { useGetImageDTOQuery } from 'services/api/endpoints/images'; +import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models'; +import { useGetWorkflowQuery } from 'services/api/endpoints/workflows'; +import type { AnyModelConfig, ImageDTO } from 'services/api/types'; + +const log = logger('canvas-workflow-integration'); + +interface WorkflowFieldRendererProps { + el: NodeFieldElement; +} + +export const WorkflowFieldRenderer = memo(({ el }: WorkflowFieldRendererProps) => { + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + const selectedWorkflowId = useAppSelector(selectCanvasWorkflowIntegrationSelectedWorkflowId); + const fieldValues = useAppSelector(selectCanvasWorkflowIntegrationFieldValues); + const selectedImageFieldKey = useAppSelector(selectCanvasWorkflowIntegrationSelectedImageFieldKey); + const templates = useStore($templates); + + const { data: workflow } = useGetWorkflowQuery(selectedWorkflowId!, { + skip: !selectedWorkflowId, + }); + + // Load boards and models for BoardField and ModelIdentifierField + const { data: boardsData } = useListAllBoardsQuery({ include_archived: true }); + const { data: modelsData, isLoading: isLoadingModels } = useGetModelConfigsQuery(); + + const { fieldIdentifier } = el.data; + const fieldKey = `${fieldIdentifier.nodeId}.${fieldIdentifier.fieldName}`; + + log.debug({ fieldIdentifier, fieldKey }, 'Rendering workflow field'); + + // Get the node, field instance, and field template + const { field, fieldTemplate } = useMemo(() => { + if (!workflow?.workflow.nodes) { + log.warn('No workflow nodes found'); + return { field: null, fieldTemplate: null }; + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const foundNode = workflow.workflow.nodes.find((n: any) => n.data.id === fieldIdentifier.nodeId); + if (!foundNode) { + log.warn({ nodeId: fieldIdentifier.nodeId }, 'Node not found'); + return { field: null, fieldTemplate: null }; + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const foundField = (foundNode.data as any).inputs[fieldIdentifier.fieldName]; + if (!foundField) { + log.warn({ nodeId: fieldIdentifier.nodeId, fieldName: fieldIdentifier.fieldName }, 'Field not found in node'); + return { field: null, fieldTemplate: null }; + } + + // Get the field template from the invocation templates + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const nodeType = (foundNode.data as any).type; + const template = templates[nodeType]; + if (!template) { + log.warn({ nodeType }, 'No template found for node type'); + return { field: foundField, fieldTemplate: null }; + } + + const foundFieldTemplate = template.inputs[fieldIdentifier.fieldName]; + if (!foundFieldTemplate) { + log.warn({ nodeType, fieldName: fieldIdentifier.fieldName }, 'Field template not found'); + return { field: foundField, fieldTemplate: null }; + } + + return { field: foundField, fieldTemplate: foundFieldTemplate }; + }, [workflow, fieldIdentifier, templates]); + + // Get the current value from Redux or fallback to field default + const currentValue = useMemo(() => { + if (fieldValues && fieldKey in fieldValues) { + return 
fieldValues[fieldKey];
+    }
+
+    return field?.value ?? fieldTemplate?.default ?? '';
+  }, [fieldValues, fieldKey, field, fieldTemplate]);
+
+  // Get field type from the template
+  // eslint-disable-next-line @typescript-eslint/no-explicit-any
+  const fieldType = fieldTemplate ? (fieldTemplate as any).type?.name : null;
+
+  const handleChange = useCallback(
+    (value: unknown) => {
+      dispatch(canvasWorkflowIntegrationFieldValueChanged({ fieldName: fieldKey, value }));
+    },
+    [dispatch, fieldKey]
+  );
+
+  const handleStringChange = useCallback(
+    (e: ChangeEvent<HTMLInputElement | HTMLTextAreaElement>) => {
+      handleChange(e.target.value);
+    },
+    [handleChange]
+  );
+
+  const handleNumberChange = useCallback(
+    (e: ChangeEvent<HTMLInputElement>) => {
+      const val = fieldType === 'IntegerField' ? parseInt(e.target.value, 10) : parseFloat(e.target.value);
+      handleChange(isNaN(val) ? 0 : val);
+    },
+    [handleChange, fieldType]
+  );
+
+  const handleBooleanChange = useCallback(
+    (e: ChangeEvent<HTMLInputElement>) => {
+      handleChange(e.target.checked);
+    },
+    [handleChange]
+  );
+
+  const handleSelectChange = useCallback(
+    (e: ChangeEvent<HTMLSelectElement>) => {
+      handleChange(e.target.value);
+    },
+    [handleChange]
+  );
+
+  // SchedulerField handlers
+  const handleSchedulerChange = useCallback<ComboboxOnChange>(
+    (v) => {
+      if (!isParameterScheduler(v?.value)) {
+        return;
+      }
+      handleChange(v.value);
+    },
+    [handleChange]
+  );
+
+  const schedulerValue = useMemo(() => SCHEDULER_OPTIONS.find((o) => o.value === currentValue), [currentValue]);
+
+  // BoardField handlers
+  const handleBoardChange = useCallback<ComboboxOnChange>(
+    (v) => {
+      if (!v) {
+        return;
+      }
+      const value = v.value === 'auto' || v.value === 'none' ? v.value : { board_id: v.value };
+      handleChange(value);
+    },
+    [handleChange]
+  );
+
+  const boardOptions = useMemo(() => {
+    const _options: ComboboxOption[] = [
+      { label: t('common.auto'), value: 'auto' },
+      { label: `${t('common.none')} (${t('boards.uncategorized')})`, value: 'none' },
+    ];
+    if (boardsData) {
+      for (const board of boardsData) {
+        _options.push({
+          label: board.board_name,
+          value: board.board_id,
+        });
+      }
+    }
+    return _options;
+  }, [boardsData, t]);
+
+  const boardValue = useMemo(() => {
+    const _value = currentValue;
+    const autoOption = boardOptions[0];
+    const noneOption = boardOptions[1];
+    if (!_value || _value === 'auto') {
+      return autoOption;
+    }
+    if (_value === 'none') {
+      return noneOption;
+    }
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    const boardId = typeof _value === 'object' ? (_value as any).board_id : _value;
+    const boardOption = boardOptions.find((o) => o.value === boardId);
+    return boardOption ?? autoOption;
+  }, [currentValue, boardOptions]);
+
+  const noOptionsMessage = useCallback(() => t('boards.noMatching'), [t]);
+
+  // ModelIdentifierField handlers
+  const handleModelChange = useCallback(
+    (value: AnyModelConfig | null) => {
+      if (!value) {
+        return;
+      }
+      handleChange(value);
+    },
+    [handleChange]
+  );
+
+  const modelConfigs = useMemo(() => {
+    if (!modelsData) {
+      return EMPTY_ARRAY;
+    }
+
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    const ui_model_base = fieldTemplate ? (fieldTemplate as any)?.ui_model_base : null;
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    const ui_model_type = fieldTemplate ? (fieldTemplate as any)?.ui_model_type : null;
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    const ui_model_variant = fieldTemplate ?
(fieldTemplate as any)?.ui_model_variant : null; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const ui_model_format = fieldTemplate ? (fieldTemplate as any)?.ui_model_format : null; + + if (!ui_model_base && !ui_model_type) { + return modelConfigsAdapterSelectors.selectAll(modelsData); + } + + return modelConfigsAdapterSelectors.selectAll(modelsData).filter((config) => { + if (ui_model_base && !ui_model_base.includes(config.base)) { + return false; + } + if (ui_model_type && !ui_model_type.includes(config.type)) { + return false; + } + if (ui_model_variant && 'variant' in config && config.variant && !ui_model_variant.includes(config.variant)) { + return false; + } + if (ui_model_format && !ui_model_format.includes(config.format)) { + return false; + } + return true; + }); + }, [modelsData, fieldTemplate]); + + // ImageField handler + const handleImageFieldSelect = useCallback(() => { + dispatch(canvasWorkflowIntegrationImageFieldSelected({ fieldKey })); + }, [dispatch, fieldKey]); + + if (!field || !fieldTemplate) { + log.warn({ fieldIdentifier }, 'Field or template is null - not rendering'); + return null; + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const label = (field as any)?.label || (fieldTemplate as any)?.title || fieldIdentifier.fieldName; + + // Log the entire field structure to understand its shape + log.debug( + { fieldType, label, currentValue, fieldStructure: field, fieldTemplateStructure: fieldTemplate }, + 'Field info' + ); + + // ImageField - allow user to select which one receives the canvas image + if (fieldType === 'ImageField') { + return ( + + ); + } + + // Render different input types based on field type + if (fieldType === 'StringField') { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const isTextarea = (fieldTemplate as any)?.ui_component === 'textarea'; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const isRequired = (fieldTemplate as any)?.required ?? false; + + if (isTextarea) { + return ( + + {label} +