Compare commits

...

89 Commits

Author SHA1 Message Date
psychedelicious
8a8f4c593f wip 2025-04-02 06:42:01 +10:00
psychedelicious
29c78f0e5e wip 2025-04-01 15:45:59 +10:00
psychedelicious
501534e2e1 chore(ui): typegen 2025-04-01 08:49:28 +10:00
psychedelicious
50c7318004 feat(app): add is_published to workflow models 2025-04-01 08:48:06 +10:00
psychedelicious
7f14597012 refactor(app): clean up compose_mode_from_fields util 2025-04-01 08:46:27 +10:00
psychedelicious
dbe68b364f feat(ui): publish toast links to project dashboard 2025-04-01 08:22:48 +10:00
psychedelicious
0c7aa85a5c feat(ui): add badge to queue indicating if run is validation run 2025-04-01 08:22:48 +10:00
psychedelicious
703e1c8001 feat(ui): publish toasts do not auto-close 2025-04-01 08:22:48 +10:00
psychedelicious
b056c93ea3 feat(ui): disable invoke button during publish operation 2025-04-01 08:22:48 +10:00
psychedelicious
4289241943 feat(ui): "isInDeployFlow" -> "isInPublishFlow" 2025-04-01 08:22:48 +10:00
psychedelicious
51f5abf5f9 feat(ui): wip publish flow 2025-04-01 08:22:48 +10:00
psychedelicious
e59fa59ad7 feat(ui): wip publish flow 2025-04-01 08:22:48 +10:00
psychedelicious
2407cb64b3 feat(app): truncate invalid model config warning to 64 chars
Previously it logged the whole config and flooded the terminal output.
2025-04-01 08:22:48 +10:00
psychedelicious
70f704ab44 feat(ui): publish button works 2025-04-01 08:22:48 +10:00
psychedelicious
b786032b89 feat(ui): make validation run logic conditional 2025-04-01 08:22:48 +10:00
psychedelicious
e8cc06cc92 feat(ui): disable all workflow editor interaction while in deploy flow 2025-04-01 08:22:48 +10:00
psychedelicious
8e6c56c93d wip 2025-04-01 08:22:48 +10:00
psychedelicious
69d4ee7f93 chore(ui): bump @xyflow/react to latest 2025-04-01 08:22:48 +10:00
psychedelicious
567fd3e0da refactor(ui): standardize more workflow editor hooks to use Safe and OrThrow suffixes for clarity 2025-04-01 08:22:47 +10:00
psychedelicious
0b8f88e554 wip 2025-04-01 08:22:47 +10:00
psychedelicious
60f0c4bf99 refactor(ui): standardize more workflow editor hooks to use Safe and OrThrow suffixes for clarity 2025-04-01 08:22:47 +10:00
psychedelicious
900ec92ef1 tidy(ui): remove extraneous scrollable container 2025-04-01 08:22:47 +10:00
psychedelicious
2594768479 revert(ui): remove api_fields from zod workflow schema 2025-04-01 08:22:47 +10:00
psychedelicious
91ab81eca9 chore(ui): typegen 2025-04-01 08:22:47 +10:00
psychedelicious
b20c745c6e revert(app): remove api_fields from workflow pydantic model 2025-04-01 08:22:47 +10:00
psychedelicious
e41a37bca0 refactor(ui): generalize node field dnd to drag node fields vs node field form elements 2025-04-01 08:22:47 +10:00
psychedelicious
9ca44f27a5 feat(ui): rough out state mgmt for workflow api fields 2025-04-01 08:22:47 +10:00
psychedelicious
b9ddf67853 refactor(ui): rejiggle enqueue actions to support api validation runs 2025-04-01 08:22:47 +10:00
psychedelicious
afe088045f chore(ui): rename type BatchConfig -> EnqueueBatchArg 2025-04-01 08:22:47 +10:00
psychedelicious
09ca61a962 chore(ui): typegen 2025-04-01 08:22:47 +10:00
psychedelicious
dd69a96c03 feat(queue): move session count calculation in to Batch class, cache it, add pydantic validator for validation runs 2025-04-01 08:22:46 +10:00
psychedelicious
4a54e594d0 tests(ui): update test for workflow types 2025-04-01 08:22:46 +10:00
psychedelicious
936ed1960a feat(ui): add api_fields to zod schemas 2025-04-01 08:22:46 +10:00
psychedelicious
9fac7986c7 chore(ui): typegen 2025-04-01 08:22:46 +10:00
psychedelicious
e4b603f44e feat(app): add api_fields to workflow pydantic schema 2025-04-01 08:22:46 +10:00
psychedelicious
7edfe6edcf chore(ui): bump tsafe dep 2025-04-01 08:22:46 +10:00
jazzhaiku
bfb117d0e0 Port LoRA to new classification API (#7849)
## Summary

- Port LoRA to new classification API
- Add 2 additional tests cases (ControlLora and Flux Diffusers LoRA)
- Moved `ModelOnDisk` to its own module

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-04-01 08:05:48 +11:00
jazzhaiku
b31c1022c3 Merge branch 'main' into lora-classification 2025-04-01 07:58:36 +11:00
Mary Hipp
a5851ca31c fix from leftover testing 2025-03-31 12:45:53 -04:00
Mary Hipp
77bf5c15bb GET presigned URLs directly instead of trying to use redirects 2025-03-31 12:45:53 -04:00
psychedelicious
595133463e feat(nodes): add methods to invalidate invocation typeadapters 2025-03-31 19:15:59 +11:00
psychedelicious
6155f9ff9e feat(nodes): move invocation/output registration to separate class 2025-03-31 19:15:59 +11:00
psychedelicious
7be87c8048 refactor(nodes): simpler logic for baseinvocation typeadapter handling 2025-03-31 19:15:59 +11:00
jazzhaiku
9868c3bfe3 Merge branch 'main' into lora-classification 2025-03-31 16:43:26 +11:00
psychedelicious
8b299d0bac chore: prep for v5.9.1 2025-03-31 13:40:07 +11:00
psychedelicious
a44bfb4658 fix(mm): handle FLUX models w/ diff in_channels keys
Before FLUX Fill was merged, we didn't do any checks for the model variant. We always returned "normal".

To determine if a model is a FLUX Fill model, we need to check the state dict for a specific key. Initially, this logic was too strict and rejected quantized FLUX models. This issue was resolved, but it turns out there is another failure mode - some fine-tunes use a different key.

This change further reduces the strictness, handling the alternate key and also falling back to "normal" if we don't see either key. This effectively restores the previous probing behaviour for all FLUX models.

Closes #7856
Closes #7859
2025-03-31 12:32:55 +11:00
psychedelicious
96fb5f6881 feat(ui): disable denoising strength when selected models flux fill 2025-03-31 11:31:02 +11:00
psychedelicious
4109ea5324 fix(nodes): expanded masks not 100% transparent outside the fade out region
The polynomial fit isn't perfect and we end up with alpha values of 1 instead of 0 when applying the mask. This in turn causes issues on canvas where outputs aren't 100% transparent and individual layer bbox calculations are incorrect.
2025-03-31 11:17:00 +11:00
jazzhaiku
f6c2ee5040 Merge branch 'main' into lora-classification 2025-03-31 09:01:16 +11:00
Billy
965753bf8b Ruff formatting 2025-03-31 08:18:00 +11:00
Billy
40c53ab95c Guard 2025-03-29 09:58:02 +11:00
psychedelicious
aaa6211625 chore(backend): ruff C420 2025-03-28 18:28:32 -04:00
psychedelicious
f6d770eac9 ci: add python 3.12 to test matrix 2025-03-28 18:28:32 -04:00
psychedelicious
47cb61cd62 ci: remove python 3.10 from test matrix 2025-03-28 18:28:32 -04:00
psychedelicious
b0fdc8ae1c ci: bump linux-cpu test runner to ubuntu 24.04 2025-03-28 18:28:32 -04:00
psychedelicious
ed9b30efda ci: bump uv to 0.6.10 2025-03-28 18:28:32 -04:00
psychedelicious
168e5eeff0 ci: use uv in typegen-checks
ci: use uv in typegen-checks to generate types

experiment: simulate typegen-checks failure

Revert "experiment: simulate typegen-checks failure"

This reverts commit f53c6876fe8311de236d974194abce93ed84930c.
2025-03-28 18:28:32 -04:00
psychedelicious
7acaa86bdf ci: get ci working with uv instead of pip
Lots of squashed experimentation heh:

ci: manually specify python version in tests

ci: whoops typo in ruff cmds

ci: specify python versions for uv python install

ci: install python verbosely

ci: try forcing python preference?

ci: try forcing python preference a different way?

ci: try in a venv?

ci: it works, but try without venv

ci: oh maybe we need --preview?

ci: poking it with a stick

ci: it works, add summary to pytest output

ci: fix pytest output

experiment: simulate test failure

Revert "experiment: simulate test failure"

This reverts commit b99ca512f6e61a2a04a1c0636d44018c11019954.

ci: just use default pytest output

cI: attempt again to use uv to install python

cI: attempt again again to use uv to install python

Revert "cI: attempt again again to use uv to install python"

This reverts commit 3cba861c90738081caeeb3eca97b60656ab63929.

Revert "cI: attempt again to use uv to install python"

This reverts commit b30f2277041dc999ed514f6c594c6d6a78f5c810.
2025-03-28 18:28:32 -04:00
psychedelicious
96c0393fe7 ci: bump ruff to 0.11.2
Need to bump both CI and pyproject.toml at the same time
2025-03-28 18:28:32 -04:00
psychedelicious
403f795c5e ci: remove linux-cuda-11_7 & linux-rocm-5_2 from test matrix
We only have CPU runners, so these tests are not doing anything useful.
2025-03-28 18:28:32 -04:00
psychedelicious
c0f88a083e ci: use uv for python-tests 2025-03-28 18:28:32 -04:00
psychedelicious
542b182899 ci: use uv for python-checks 2025-03-28 18:28:32 -04:00
Mary Hipp
3f58c68c09 fix tag invalidation 2025-03-28 10:52:27 -04:00
Mary Hipp
e50c7e5947 restore multiple key 2025-03-28 10:52:27 -04:00
Mary Hipp
4a83700fe4 if clientSideUploading is enabled, handle bulk uploads using that flow 2025-03-28 10:52:27 -04:00
jazzhaiku
c25f6d1f84 Merge branch 'main' into lora-classification 2025-03-28 12:32:22 +11:00
jazzhaiku
a53e1ccf08 Small improvements (#7842)
## Summary

- Extend `ModelOnDisk` with caching, type hints, default args
- Fail early if there is an error classifying a config

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-03-28 12:21:41 +11:00
jazzhaiku
1af9930951 Merge branch 'main' into small-improvements 2025-03-28 12:11:09 +11:00
Billy
c276c1cbee Comment 2025-03-28 10:57:46 +11:00
Billy
c619348f29 Extract ModelOnDisk to its own module 2025-03-28 10:35:13 +11:00
psychedelicious
c6f96613fc chore(ui): typegen 2025-03-28 08:14:06 +11:00
psychedelicious
258bf736da fix(nodes): handle zero fade size (e.g. mask blur 0)
Closes #7850
2025-03-28 08:14:06 +11:00
Billy
0d75c99476 Caching 2025-03-27 17:55:09 +11:00
Billy
323d409fb6 Make ruff happy 2025-03-27 17:47:57 +11:00
Billy
f251722f56 LoRA classification API 2025-03-27 17:47:01 +11:00
jazzhaiku
c9dc27afbb Merge branch 'main' into small-improvements 2025-03-27 08:14:48 +11:00
Billy
efd14ec0e4 Make ruff happy 2025-03-27 08:11:39 +11:00
Billy
21ee2b6251 Merge branch 'small-improvements' of github.com:invoke-ai/InvokeAI into small-improvements 2025-03-27 08:10:38 +11:00
Billy
82dd2d508f Deprecate checkpoint as file, diffusers as directory terminology 2025-03-27 08:10:12 +11:00
jazzhaiku
5a59f6e3b8 Merge branch 'main' into small-improvements 2025-03-27 07:38:13 +11:00
Billy
60b5aef16a Log error -> warning 2025-03-27 06:56:22 +11:00
Billy
0e8b5484d5 Error handling 2025-03-26 19:31:57 +11:00
Billy
454506c83e Type hints 2025-03-26 19:12:49 +11:00
Billy
8f6ab67376 Logs 2025-03-26 16:34:32 +11:00
Billy
5afcc7778f Redundant 2025-03-26 16:32:19 +11:00
Billy
325e07d330 Error handling 2025-03-26 16:30:45 +11:00
Billy
a016bdc159 Add todo 2025-03-26 16:17:26 +11:00
Billy
a14f0b2864 Fail early on invalid config 2025-03-26 16:10:32 +11:00
Billy
721483318a Extend ModelOnDisk 2025-03-26 16:10:00 +11:00
149 changed files with 2836 additions and 807 deletions

View File

@@ -34,6 +34,9 @@ on:
jobs:
python-checks:
env:
# uv requires a venv by default - but for this, we can simply use the system python
UV_SYSTEM_PYTHON: 1
runs-on: ubuntu-latest
timeout-minutes: 5 # expected run time: <1 min
steps:
@@ -57,25 +60,19 @@ jobs:
- '!invokeai/frontend/web/**'
- 'tests/**'
- name: setup python
- name: setup uv
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
uses: actions/setup-python@v5
uses: astral-sh/setup-uv@v5
with:
python-version: '3.10'
cache: pip
cache-dependency-path: pyproject.toml
- name: install ruff
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: pip install ruff==0.9.9
shell: bash
version: '0.6.10'
enable-cache: true
- name: ruff check
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: ruff check --output-format=github .
run: uv tool run ruff@0.11.2 check --output-format=github .
shell: bash
- name: ruff format
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
run: ruff format --check .
run: uv tool run ruff@0.11.2 format --check .
shell: bash

View File

@@ -39,24 +39,15 @@ jobs:
strategy:
matrix:
python-version:
- '3.10'
- '3.11'
- '3.12'
platform:
- linux-cuda-11_7
- linux-rocm-5_2
- linux-cpu
- macos-default
- windows-cpu
include:
- platform: linux-cuda-11_7
os: ubuntu-22.04
github-env: $GITHUB_ENV
- platform: linux-rocm-5_2
os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
github-env: $GITHUB_ENV
- platform: linux-cpu
os: ubuntu-22.04
os: ubuntu-24.04
extra-index-url: 'https://download.pytorch.org/whl/cpu'
github-env: $GITHUB_ENV
- platform: macos-default
@@ -70,6 +61,8 @@ jobs:
timeout-minutes: 15 # expected run time: 2-6 min, depending on platform
env:
PIP_USE_PEP517: '1'
UV_SYSTEM_PYTHON: 1
steps:
- name: checkout
# https://github.com/nschloe/action-cached-lfs-checkout
@@ -92,20 +85,25 @@ jobs:
- '!invokeai/frontend/web/**'
- 'tests/**'
- name: setup uv
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
uses: astral-sh/setup-uv@v5
with:
version: '0.6.10'
enable-cache: true
python-version: ${{ matrix.python-version }}
- name: setup python
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: pip
cache-dependency-path: pyproject.toml
- name: install dependencies
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
env:
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
run: >
pip3 install --editable=".[test]"
UV_INDEX: ${{ matrix.extra-index-url }}
run: uv pip install --editable ".[test]"
- name: run pytest
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}

View File

@@ -54,17 +54,25 @@ jobs:
- 'pyproject.toml'
- 'invokeai/**'
- name: setup uv
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
uses: astral-sh/setup-uv@v5
with:
version: '0.6.10'
enable-cache: true
python-version: '3.11'
- name: setup python
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
uses: actions/setup-python@v5
with:
python-version: '3.10'
cache: pip
cache-dependency-path: pyproject.toml
python-version: '3.11'
- name: install python dependencies
- name: install dependencies
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
run: pip3 install --use-pep517 --editable="."
env:
UV_INDEX: ${{ matrix.extra-index-url }}
run: uv pip install --editable .
- name: install frontend dependencies
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
@@ -77,7 +85,7 @@ jobs:
- name: generate schema
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
run: make frontend-typegen
run: cd invokeai/frontend/web && uv run ../../../scripts/generate_openapi_schema.py | pnpm typegen
shell: bash
- name: compare files

View File

@@ -96,6 +96,22 @@ async def upload_image(
raise HTTPException(status_code=500, detail="Failed to create image")
class ImageUploadEntry(BaseModel):
image_dto: ImageDTO = Body(description="The image DTO")
presigned_url: str = Body(description="The URL to get the presigned URL for the image upload")
@images_router.post("/", operation_id="create_image_upload_entry")
async def create_image_upload_entry(
width: int = Body(description="The width of the image"),
height: int = Body(description="The height of the image"),
board_id: Optional[str] = Body(default=None, description="The board to add this image to, if any"),
) -> ImageUploadEntry:
"""Uploads an image from a URL, not implemented"""
raise HTTPException(status_code=501, detail="Not implemented")
@images_router.delete("/i/{image_name}", operation_id="delete_image")
async def delete_image(
image_name: str = Path(description="The name of the image to delete"),

View File

@@ -1,10 +1,13 @@
from typing import Optional
import json
from typing import Any, Optional
from fastapi import Body, Path, Query
from fastapi.routing import APIRouter
from pydantic import BaseModel
from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.invocations.fields import BoardField
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
@@ -15,6 +18,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
CancelByDestinationResult,
ClearResult,
EnqueueBatchResult,
FieldIdentifier,
PruneResult,
RetryItemsResult,
SessionQueueCountsByDestination,
@@ -22,6 +26,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueItemDTO,
SessionQueueStatus,
)
from invokeai.app.services.shared.compose_pydantic_model import compose_model_from_fields
from invokeai.app.services.shared.pagination import CursorPaginatedResults
session_queue_router = APIRouter(prefix="/v1/queue", tags=["queue"])
@@ -34,6 +39,17 @@ class SessionQueueAndProcessorStatus(BaseModel):
processor: SessionProcessorStatus
class SimpleModelIdentifer(BaseModel):
id: str = Field(description="The model id")
model_field_overrides = {ModelIdentifierField: (SimpleModelIdentifer, Field(description="The model identifier"))}
def model_field_filter(field_type: type[Any]) -> bool:
return field_type not in {BoardField, Optional[BoardField]}
@session_queue_router.post(
"/{queue_id}/enqueue_batch",
operation_id="enqueue_batch",
@@ -45,9 +61,52 @@ async def enqueue_batch(
queue_id: str = Path(description="The queue id to perform this operation on"),
batch: Batch = Body(description="Batch to process"),
prepend: bool = Body(default=False, description="Whether or not to prepend this batch in the queue"),
is_api_validation_run: bool = Body(
default=False,
description="Whether or not this is a validation run.",
),
api_input_fields: Optional[list[FieldIdentifier]] = Body(
default=None, description="The fields that were used as input to the API"
),
api_output_fields: Optional[list[FieldIdentifier]] = Body(
default=None, description="The fields that were used as output from the API"
),
) -> EnqueueBatchResult:
"""Processes a batch and enqueues the output graphs for execution."""
if is_api_validation_run:
session_count = batch.get_session_count()
assert session_count == 1, "API validation run only supports single session batches"
if api_input_fields:
composed_model = compose_model_from_fields(
g=batch.graph,
field_identifiers=api_input_fields,
composed_model_class_name="APIInputModel",
model_field_overrides=model_field_overrides,
model_field_filter=model_field_filter,
)
json_schema = composed_model.model_json_schema(mode="validation")
print("API Input Model")
print(json.dumps(json_schema))
if api_output_fields:
composed_model = compose_model_from_fields(
g=batch.graph,
field_identifiers=api_output_fields,
composed_model_class_name="APIOutputModel",
)
json_schema = composed_model.model_json_schema(mode="validation")
print("API Output Model")
print(json.dumps(json_schema))
print("graph")
print(batch.graph.model_dump_json())
if batch.workflow is not None:
print("workflow")
print(batch.workflow.model_dump_json())
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
queue_id=queue_id, batch=batch, prepend=prepend
)

View File

@@ -1,4 +1,5 @@
import io
import random
import traceback
from typing import Optional
@@ -24,6 +25,37 @@ from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_common import
IMAGE_MAX_AGE = 31536000
workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
ids = {
"6614752a-0420-4d81-98fc-e110069d4f38": random.choice([True, False]),
"default_5e8b008d-c697-45d0-8883-085a954c6ace": random.choice([True, False]),
"4b2b297a-0d47-4f43-8113-ebbf3f403089": random.choice([True, False]),
"d0ce602a-049e-4368-97ae-977b49eed042": random.choice([True, False]),
"f170a187-fd74-40b8-ba9c-00de173ea4b9": random.choice([True, False]),
"default_f96e794f-eb3e-4d01-a960-9b4e43402bcf": random.choice([True, False]),
"default_cbf0e034-7b54-4b2c-b670-3b1e2e4b4a88": random.choice([True, False]),
"default_dec5a2e9-f59c-40d9-8869-a056751d79b8": random.choice([True, False]),
"default_dbe46d95-22aa-43fb-9c16-94400d0ce2fd": random.choice([True, False]),
"default_d7a1c60f-ca2f-4f90-9e33-75a826ca6d8f": random.choice([True, False]),
"default_e71d153c-2089-43c7-bd2c-f61f37d4c1c1": random.choice([True, False]),
"default_7dde3e36-d78f-4152-9eea-00ef9c8124ed": random.choice([True, False]),
"default_444fe292-896b-44fd-bfc6-c0b5d220fffc": random.choice([True, False]),
"default_2d05e719-a6b9-4e64-9310-b875d3b2f9d2": random.choice([True, False]),
"acae7e87-070b-4999-9074-c5b593c86618": random.choice([True, False]),
"3008fc77-1521-49c7-ba95-94c5a4508d1d": random.choice([True, False]),
"default_686bb1d0-d086-4c70-9fa3-2f600b922023": random.choice([True, False]),
"36905c46-e768-4dc3-8ecd-e55fe69bf03c": random.choice([True, False]),
"7c3e4951-183b-40ef-a890-28eef4d50097": random.choice([True, False]),
"7a053b2f-64e4-4152-80e9-296006e77131": random.choice([True, False]),
"27d4f1be-4156-46e9-8d22-d0508cd72d4f": random.choice([True, False]),
"e881dc06-70d2-438f-b007-6f3e0c3c0e78": random.choice([True, False]),
"265d2244-a1d7-495c-a2eb-88217f5eae37": random.choice([True, False]),
"caebcbc7-2bf0-41c4-b553-106b585fddda": random.choice([True, False]),
"a7998705-474e-417d-bd37-a2a9480beedf": random.choice([True, False]),
"554d94b5-94b3-4d8e-8aed-51ebfc9deea5": random.choice([True, False]),
"e6898540-c1bc-408b-b944-c1e242cddbcd": random.choice([True, False]),
"363b0960-ab2c-4902-8df3-f592d6194bb3": random.choice([True, False]),
}
@workflows_router.get(
"/i/{workflow_id}",
@@ -39,6 +71,8 @@ async def get_workflow(
try:
thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id)
workflow = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
workflow.is_published = ids.get(workflow_id, False)
workflow.workflow.is_published = ids.get(workflow_id, False)
return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump())
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")
@@ -106,10 +140,11 @@ async def list_workflows(
tags: Optional[list[str]] = Query(default=None, description="The tags of workflow to get"),
query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"),
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
is_published: Optional[bool] = Query(default=None, description="Whether to include/exclude published workflows"),
) -> PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]:
"""Gets a page of workflows"""
workflows_with_thumbnails: list[WorkflowRecordListItemWithThumbnailDTO] = []
workflows = ApiDependencies.invoker.services.workflow_records.get_many(
workflow_record_list_items = ApiDependencies.invoker.services.workflow_records.get_many(
order_by=order_by,
direction=direction,
page=page,
@@ -118,20 +153,23 @@ async def list_workflows(
categories=categories,
tags=tags,
has_been_opened=has_been_opened,
is_published=is_published,
)
for workflow in workflows.items:
for item in workflow_record_list_items.items:
data = item.model_dump()
data["is_published"] = ids.get(item.workflow_id, False)
workflows_with_thumbnails.append(
WorkflowRecordListItemWithThumbnailDTO(
thumbnail_url=ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow.workflow_id),
**workflow.model_dump(),
thumbnail_url=ApiDependencies.invoker.services.workflow_thumbnails.get_url(item.workflow_id),
**data,
)
)
return PaginatedResults[WorkflowRecordListItemWithThumbnailDTO](
items=workflows_with_thumbnails,
total=workflows.total,
page=workflows.page,
pages=workflows.pages,
per_page=workflows.per_page,
total=workflow_record_list_items.total,
page=workflow_record_list_items.page,
pages=workflow_record_list_items.pages,
per_page=workflow_record_list_items.per_page,
)

View File

@@ -8,6 +8,7 @@ import sys
import warnings
from abc import ABC, abstractmethod
from enum import Enum
from functools import lru_cache
from inspect import signature
from typing import (
TYPE_CHECKING,
@@ -27,7 +28,6 @@ import semver
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
from typing_extensions import TypeAliasType
from invokeai.app.invocations.fields import (
FieldKind,
@@ -100,37 +100,6 @@ class BaseInvocationOutput(BaseModel):
All invocation outputs must use the `@invocation_output` decorator to provide their unique type.
"""
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
_typeadapter_needs_update: ClassVar[bool] = False
@classmethod
def register_output(cls, output: BaseInvocationOutput) -> None:
"""Registers an invocation output."""
cls._output_classes.add(output)
cls._typeadapter_needs_update = True
@classmethod
def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
"""Gets all invocation outputs."""
return cls._output_classes
@classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
if not cls._typeadapter or cls._typeadapter_needs_update:
AnyInvocationOutput = TypeAliasType(
"AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(AnyInvocationOutput)
cls._typeadapter_needs_update = False
return cls._typeadapter
@classmethod
def get_output_types(cls) -> Iterable[str]:
"""Gets all invocation output types."""
return (i.get_type() for i in BaseInvocationOutput.get_outputs())
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
@@ -173,76 +142,16 @@ class BaseInvocation(ABC, BaseModel):
All invocations must use the `@invocation` decorator to provide their unique type.
"""
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
_typeadapter_needs_update: ClassVar[bool] = False
@classmethod
def get_type(cls) -> str:
"""Gets the invocation's type, as provided by the `@invocation` decorator."""
return cls.model_fields["type"].default
@classmethod
def register_invocation(cls, invocation: BaseInvocation) -> None:
"""Registers an invocation."""
cls._invocation_classes.add(invocation)
cls._typeadapter_needs_update = True
@classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
if not cls._typeadapter or cls._typeadapter_needs_update:
AnyInvocation = TypeAliasType(
"AnyInvocation", Annotated[Union[tuple(cls.get_invocations())], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(AnyInvocation)
cls._typeadapter_needs_update = False
return cls._typeadapter
@classmethod
def invalidate_typeadapter(cls) -> None:
"""Invalidates the typeadapter, forcing it to be rebuilt on next access. If the invocation allowlist or
denylist is changed, this should be called to ensure the typeadapter is updated and validation respects
the updated allowlist and denylist."""
cls._typeadapter_needs_update = True
@classmethod
def get_invocations(cls) -> Iterable[BaseInvocation]:
"""Gets all invocations, respecting the allowlist and denylist."""
app_config = get_config()
allowed_invocations: set[BaseInvocation] = set()
for sc in cls._invocation_classes:
invocation_type = sc.get_type()
is_in_allowlist = (
invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True
)
is_in_denylist = (
invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False
)
if is_in_allowlist and not is_in_denylist:
allowed_invocations.add(sc)
return allowed_invocations
@classmethod
def get_invocations_map(cls) -> dict[str, BaseInvocation]:
"""Gets a map of all invocation types to their invocation classes."""
return {i.get_type(): i for i in BaseInvocation.get_invocations()}
@classmethod
def get_invocation_types(cls) -> Iterable[str]:
"""Gets all invocation types."""
return (i.get_type() for i in BaseInvocation.get_invocations())
@classmethod
def get_output_annotation(cls) -> BaseInvocationOutput:
"""Gets the invocation's output annotation (i.e. the return annotation of its `invoke()` method)."""
return signature(cls.invoke).return_annotation
@classmethod
def get_invocation_for_type(cls, invocation_type: str) -> BaseInvocation | None:
"""Gets the invocation class for a given invocation type."""
return cls.get_invocations_map().get(invocation_type)
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
@@ -340,6 +249,105 @@ class BaseInvocation(ABC, BaseModel):
TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)
class InvocationRegistry:
_invocation_classes: ClassVar[set[type[BaseInvocation]]] = set()
_output_classes: ClassVar[set[type[BaseInvocationOutput]]] = set()
@classmethod
def register_invocation(cls, invocation: type[BaseInvocation]) -> None:
"""Registers an invocation."""
cls._invocation_classes.add(invocation)
cls.invalidate_invocation_typeadapter()
@classmethod
@lru_cache(maxsize=1)
def get_invocation_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantic TypeAdapter for the union of all invocation types.
This is used to parse serialized invocations into the correct invocation class.
This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or
denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects
the updated allowlist and denylist.
@see https://docs.pydantic.dev/latest/concepts/type_adapter/
"""
return TypeAdapter(Annotated[Union[tuple(cls.get_invocation_classes())], Field(discriminator="type")])
@classmethod
def invalidate_invocation_typeadapter(cls) -> None:
"""Invalidates the cached invocation type adapter."""
cls.get_invocation_typeadapter.cache_clear()
@classmethod
def get_invocation_classes(cls) -> Iterable[type[BaseInvocation]]:
"""Gets all invocations, respecting the allowlist and denylist."""
app_config = get_config()
allowed_invocations: set[type[BaseInvocation]] = set()
for sc in cls._invocation_classes:
invocation_type = sc.get_type()
is_in_allowlist = (
invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True
)
is_in_denylist = (
invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False
)
if is_in_allowlist and not is_in_denylist:
allowed_invocations.add(sc)
return allowed_invocations
@classmethod
def get_invocations_map(cls) -> dict[str, type[BaseInvocation]]:
"""Gets a map of all invocation types to their invocation classes."""
return {i.get_type(): i for i in cls.get_invocation_classes()}
@classmethod
def get_invocation_types(cls) -> Iterable[str]:
"""Gets all invocation types."""
return (i.get_type() for i in cls.get_invocation_classes())
@classmethod
def get_invocation_for_type(cls, invocation_type: str) -> type[BaseInvocation] | None:
"""Gets the invocation class for a given invocation type."""
return cls.get_invocations_map().get(invocation_type)
@classmethod
def register_output(cls, output: "type[TBaseInvocationOutput]") -> None:
"""Registers an invocation output."""
cls._output_classes.add(output)
cls.invalidate_output_typeadapter()
@classmethod
def get_output_classes(cls) -> Iterable[type[BaseInvocationOutput]]:
"""Gets all invocation outputs."""
return cls._output_classes
@classmethod
@lru_cache(maxsize=1)
def get_output_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantic TypeAdapter for the union of all invocation output types.
This is used to parse serialized invocation outputs into the correct invocation output class.
This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or
denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects
the updated allowlist and denylist.
@see https://docs.pydantic.dev/latest/concepts/type_adapter/
"""
return TypeAdapter(Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")])
@classmethod
def invalidate_output_typeadapter(cls) -> None:
"""Invalidates the cached invocation output type adapter."""
cls.get_output_typeadapter.cache_clear()
@classmethod
def get_output_types(cls) -> Iterable[str]:
"""Gets all invocation output types."""
return (i.get_type() for i in cls.get_output_classes())
RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
"id",
"is_intermediate",
@@ -453,8 +461,8 @@ def invocation(
node_pack = cls.__module__.split(".")[0]
# Handle the case where an existing node is being clobbered by the one we are registering
if invocation_type in BaseInvocation.get_invocation_types():
clobbered_invocation = BaseInvocation.get_invocation_for_type(invocation_type)
if invocation_type in InvocationRegistry.get_invocation_types():
clobbered_invocation = InvocationRegistry.get_invocation_for_type(invocation_type)
# This should always be true - we just checked if the invocation type was in the set
assert clobbered_invocation is not None
@@ -539,8 +547,7 @@ def invocation(
)
cls.__doc__ = docstring
# TODO: how to type this correctly? it's typed as ModelMetaclass, a private class in pydantic
BaseInvocation.register_invocation(cls) # type: ignore
InvocationRegistry.register_invocation(cls)
return cls
@@ -565,7 +572,7 @@ def invocation_output(
if re.compile(r"^\S+$").match(output_type) is None:
raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"')
if output_type in BaseInvocationOutput.get_output_types():
if output_type in InvocationRegistry.get_output_types():
raise ValueError(f'Invocation type "{output_type}" already exists')
validate_fields(cls.model_fields, output_type)
@@ -586,7 +593,7 @@ def invocation_output(
)
cls.__doc__ = docstring
BaseInvocationOutput.register_output(cls) # type: ignore # TODO: how to type this correctly?
InvocationRegistry.register_output(cls)
return cls

View File

@@ -1089,12 +1089,13 @@ class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
@invocation(
"expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="image", version="1.0.0"
"expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="image", version="1.0.1"
)
class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Expands a mask with a fade effect. The mask uses black to indicate areas to keep from the generated image and white for areas to discard.
The mask is thresholded to create a binary mask, and then a distance transform is applied to create a fade effect.
The fade size is specified in pixels, and the mask is expanded by that amount. The result is a mask with a smooth transition from black to white.
If the fade size is 0, the mask is returned as-is.
"""
mask: ImageField = InputField(description="The mask to expand")
@@ -1104,6 +1105,11 @@ class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_mask = context.images.get_pil(self.mask.image_name, mode="L")
if self.fade_size_px == 0:
# If the fade size is 0, just return the mask as-is.
image_dto = context.images.save(image=pil_mask, image_category=ImageCategory.MASK)
return ImageOutput.build(image_dto)
np_mask = numpy.array(pil_mask)
# Threshold the mask to create a binary mask - 0 for black, 255 for white
@@ -1141,8 +1147,21 @@ class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
coeffs = numpy.polyfit(x_control, y_control, 3)
poly = numpy.poly1d(coeffs)
# Evaluate and clip the smooth mapping
feather = numpy.clip(poly(d_norm), 0, 1)
# Evaluate the polynomial
feather = poly(d_norm)
# The polynomial fit isn't perfect. Points beyond the fade distance are likely to be slightly less than 1.0,
# even though the control points indicate that they should be exactly 1.0. This is due to the nature of the
# polynomial fit, which is a best approximation of the control points but not an exact match.
# When this occurs, the area outside the mask and fade-out will not be 100% transparent. For example, it may
# have an alpha value of 1 instead of 0. So we must force pixels at or beyond the fade distance to exactly 1.0.
# Force pixels at or beyond the fade distance to exactly 1.0
feather = numpy.where(d_norm >= 1.0, 1.0, feather)
# Clip any other values to ensure they're in the valid range [0,1]
feather = numpy.clip(feather, 0, 1)
# Build final image.
np_result = numpy.where(black_mask == 1, 0, (feather * 255).astype(numpy.uint8))

View File

@@ -302,7 +302,10 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
# We catch this error so that the app can still run if there are invalid model configs in the database.
# One reason that an invalid model config might be in the database is if someone had to rollback from a
# newer version of the app that added a new model type.
self._logger.warning(f"Found an invalid model config in the database. Ignoring this model. ({row[0]})")
row_data = f"{row[0][:64]}..." if len(row[0]) > 64 else row[0]
self._logger.warning(
f"Found an invalid model config in the database. Ignoring this model. ({row_data})"
)
else:
results.append(model_config)

View File

@@ -33,7 +33,12 @@ class SessionQueueBase(ABC):
pass
@abstractmethod
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> Coroutine[Any, Any, EnqueueBatchResult]:
def enqueue_batch(
self,
queue_id: str,
batch: Batch,
prepend: bool,
) -> Coroutine[Any, Any, EnqueueBatchResult]:
"""Enqueues all permutations of a batch for execution."""
pass

View File

@@ -157,6 +157,28 @@ class Batch(BaseModel):
v.validate_self()
return v
def get_session_count(self) -> int:
"""
Calculates the number of sessions that would be created by the batch, without incurring the overhead of actually
creating them, as is done in `create_session_nfv_tuples()`.
The count is used to communicate to the user how many sessions were _requested_ to be created, as opposed to how
many were _actually_ created (which may be less due to the maximum number of sessions).
If the session count has already been calculated, return the cached value.
"""
if not self.data:
return self.runs
data = []
for batch_datum_list in self.data:
to_zip = []
for batch_datum in batch_datum_list:
batch_data_items = range(len(batch_datum.items))
to_zip.append(batch_data_items)
data.append(list(zip(*to_zip, strict=True)))
data_product = list(product(*data))
return len(data_product) * self.runs
model_config = ConfigDict(
json_schema_extra={
"required": [
@@ -201,6 +223,12 @@ def get_workflow(queue_item_dict: dict) -> Optional[WorkflowWithoutID]:
return None
class FieldIdentifier(BaseModel):
kind: Literal["input", "output"] = Field(description="The kind of field")
node_id: str = Field(description="The ID of the node")
field_name: str = Field(description="The name of the field")
class SessionQueueItemWithoutGraph(BaseModel):
"""Session queue item without the full graph. Used for serialization."""
@@ -237,6 +265,16 @@ class SessionQueueItemWithoutGraph(BaseModel):
retried_from_item_id: Optional[int] = Field(
default=None, description="The item_id of the queue item that this item was retried from"
)
is_api_validation_run: bool = Field(
default=False,
description="Whether this queue item is an API validation run.",
)
api_input_fields: Optional[list[FieldIdentifier]] = Field(
default=None, description="The fields that were used as input to the API"
)
api_output_fields: Optional[list[FieldIdentifier]] = Field(
default=None, description="The nodes that were used as output from the API"
)
@classmethod
def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
@@ -536,28 +574,6 @@ def create_session_nfv_tuples(batch: Batch, maximum: int) -> Generator[tuple[str
count += 1
def calc_session_count(batch: Batch) -> int:
"""
Calculates the number of sessions that would be created by the batch, without incurring the overhead of actually
creating them, as is done in `create_session_nfv_tuples()`.
The count is used to communicate to the user how many sessions were _requested_ to be created, as opposed to how
many were _actually_ created (which may be less due to the maximum number of sessions).
"""
# TODO: Should this be a class method on Batch?
if not batch.data:
return batch.runs
data = []
for batch_datum_list in batch.data:
to_zip = []
for batch_datum in batch_datum_list:
batch_data_items = range(len(batch_datum.items))
to_zip.append(batch_data_items)
data.append(list(zip(*to_zip, strict=True)))
data_product = list(product(*data))
return len(data_product) * batch.runs
ValueToInsertTuple: TypeAlias = tuple[
str, # queue_id
str, # session (as stringified JSON)

View File

@@ -28,7 +28,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueItemNotFoundError,
SessionQueueStatus,
ValueToInsertTuple,
calc_session_count,
prepare_values_to_insert,
)
from invokeai.app.services.shared.graph import GraphExecutionState
@@ -118,7 +117,8 @@ class SqliteSessionQueue(SessionQueueBase):
if prepend:
priority = self._get_highest_priority(queue_id) + 1
requested_count = calc_session_count(batch)
requested_count = batch.get_session_count()
values_to_insert = prepare_values_to_insert(
queue_id=queue_id,
batch=batch,

View File

@@ -0,0 +1,204 @@
from copy import deepcopy
from typing import Any, Callable, TypeAlias, get_args
from pydantic import BaseModel, ConfigDict, create_model
from pydantic.fields import FieldInfo
from invokeai.app.services.session_queue.session_queue_common import FieldIdentifier
from invokeai.app.services.shared.graph import Graph
DictOfFieldsMetadata: TypeAlias = dict[str, tuple[type[Any], FieldInfo]]
class ComposedFieldMetadata(BaseModel):
node_id: str
field_name: str
field_type_class_name: str
def dedupe_field_name(field_metadata: DictOfFieldsMetadata, field_name: str) -> str:
"""Given a field name, return a name that is not already in the field metadata.
If the field name is not in the field metadata, return the field name.
If the field name is in the field metadata, generate a new name by appending an underscore and integer to the field name, starting with 2.
"""
if field_name not in field_metadata:
return field_name
i = 2
while True:
new_field_name = f"{field_name}_{i}"
if new_field_name not in field_metadata:
return new_field_name
i += 1
def compose_model_from_fields(
g: Graph,
field_identifiers: list[FieldIdentifier],
composed_model_class_name: str = "ComposedModel",
model_field_overrides: dict[type[Any], tuple[type[Any], FieldInfo]] | None = None,
model_field_filter: Callable[[type[Any]], bool] | None = None,
) -> type[BaseModel]:
"""Given a graph and a list of field identifiers, create a new pydantic model composed of the fields of the nodes in the graph.
The resultant model can be used to validate a JSON payload that contains the fields of the nodes in the graph, or generate an
OpenAPI schema for the model.
Args:
g: The graph containing the nodes whose fields will be composed into the new model.
field_identifiers: A list of FieldIdentifier instances, each representing a field on a node in the graph.
model_name: The name of the composed model.
kind: The kind of model to create. Must be "input" or "output". Defaults to "input".
model_field_overrides: A dictionary mapping type annotations to tuples of (new_type_annotation, new_field_info).
This can be used to override the type annotation and field info of a field in the composed model. For example,
if `ModelIdentifierField` should be replaced by a string, the dictionary would look like this:
```python
{ModelIdentifierField: (str, Field(description="The model id."))}
```
model_field_filter: A function that takes a type annotation and returns True if the field should be included in the composed model.
If None, all fields will be included. For example, to omit `BoardField` fields, the filter would look like this:
```python
def model_field_filter(field_type: type[Any]) -> bool:
return field_type not in {BoardField}
```
Optional fields - or any other complex field types like unions - must be explicitly included in the filter. For example,
to omit `BoardField` _and_ `Optional[BoardField]`:
```python
def model_field_filter(field_type: type[Any]) -> bool:
return field_type not in {BoardField, Optional[BoardField]}
```
Note that the filter is applied to the type annotation of the field, not the field itself.
Example usage:
```python
# Create some nodes.
add_node = AddInvocation()
sub_node = SubtractInvocation()
color_node = ColorInvocation()
# Create a graph with the nodes.
g = Graph(
nodes={
add_node.id: add_node,
sub_node.id: sub_node,
color_node.id: color_node,
}
)
# Select the fields to compose.
fields_to_compose = [
FieldIdentifier(node_id=add_node.id, field_name="a"),
FieldIdentifier(node_id=sub_node.id, field_name="a"), # this will be deduped to "a_2"
FieldIdentifier(node_id=add_node.id, field_name="b"),
FieldIdentifier(node_id=color_node.id, field_name="color"),
]
# Compose the model from the fields.
composed_model = compose_model_from_fields(g, fields_to_compose, model_name="ComposedModel")
# Generate the OpenAPI schema for the model.
json_schema = composed_model.model_json_schema(mode="validation")
```
"""
# Temp storage for the composed fields. Pydantic needs a type annotation and instance of FieldInfo to create a model.
field_metadata: DictOfFieldsMetadata = {}
model_field_overrides = model_field_overrides or {}
# The list of required fields. This is used to ensure the composed model's fields retain their required state.
required: list[str] = []
for field_identifier in field_identifiers:
node_id = field_identifier.node_id
field_name = field_identifier.field_name
# Pull the node instance from the graph so we can introspect it.
node_instance = g.nodes[node_id]
if field_identifier.kind == "input":
# Get the class of the node. This will be a BaseInvocation subclass, e.g. AddInvocation, DenoiseLatentsInvocation, etc.
pydantic_model = type(node_instance)
else:
# Otherwise the the type of the node's output class. This will be a BaseInvocationOutput subclass, e.g. IntegerOutput, ImageOutput, etc.
pydantic_model = type(node_instance).get_output_annotation()
# Get the FieldInfo instance for the field. For example:
# a: int = Field(..., description="The first number to add.")
# ^^^^^ The return value of this Field call is the FieldInfo instance (Field is a function).
og_field_info = pydantic_model.model_fields[field_name]
# Get the type annotation of the field. For example:
# a: int = Field(..., description="The first number to add.")
# ^^^ this is the type annotation
og_field_type = og_field_info.annotation
# Apparently pydantic allows fields without type annotations. We don't support that.
assert og_field_type is not None, (
f"{field_identifier.kind.capitalize()} field {field_name} on node {node_id} has no type annotation."
)
# Now that we have the type annotation, we can apply the filter to see if we should include the field in the composed model.
if model_field_filter and not model_field_filter(og_field_type):
continue
# Ok, we want this type of field. Retrieve any overrides for the field type. This is a dictionary mapping
# type annotations to tuples of (override_type_annotation, override_field_info).
(override_field_type, override_field_info) = model_field_overrides.get(og_field_type, (None, None))
# The override tuple's first element is the new type annotation, if it exists.
composed_field_type = override_field_type if override_field_type is not None else og_field_type
# Create a deep copy of the FieldInfo instance (or override it if it exists) so we can modify it without
# affecting the original. This is important because we are going to modify the FieldInfo instance and
# don't want to affect the original model's schema.
composed_field_info = deepcopy(override_field_info if override_field_info is not None else og_field_info)
json_schema_extra = og_field_info.json_schema_extra if isinstance(og_field_info.json_schema_extra, dict) else {}
# The field's original required state is stored in the json_schema_extra dict. For more information about why,
# see the definition of `InputField` in invokeai/app/invocations/fields.py.
#
# Add the field to the required list if it is required, which we will use when creating the composed model.
if json_schema_extra.get("orig_required", False):
required.append(field_name)
# Invocation fields have some extra metadata, used by the UI to render the field in the frontend. This data is
# included in the OpenAPI schema for each field. For example, we add a "ui_order" field, which the UI uses to
# sort fields when rendering them.
#
# The composed model's OpenAPI schema should not have this information. It should only have a standard OpenAPI
# schema for the field. We need to strip out the UI-specific metadata from the FieldInfo instance before adding
# it to the composed model.
#
# We will replace this metadata with some custom metadata:
# - node_id: The id of the node that this field belongs to.
# - field_name: The name of the field on the node.
# - original_data_type: The original data type of the field.
field_type_class = get_args(og_field_type)[0] if hasattr(og_field_type, "__args__") else og_field_type
field_type_class_name = field_type_class.__name__
composed_field_metadata = ComposedFieldMetadata(
node_id=node_id,
field_name=field_name,
field_type_class_name=field_type_class_name,
)
composed_field_info.json_schema_extra = {
"composed_field_extra": composed_field_metadata.model_dump(),
}
# Override the name, title and description if overrides are provided. Dedupe the field name if necessary.
final_field_name = dedupe_field_name(field_metadata, field_name)
# Store the field metadata.
field_metadata.update({final_field_name: (composed_field_type, composed_field_info)})
# Splat in the composed fields to create the new model. There are type errors here because create_model's kwargs are not typed,
# and for some reason pydantic's ConfigDict doesn't like lists in `json_schema_extra`. Anyways, the inputs here are correct.
return create_model(
composed_model_class_name,
**field_metadata,
__config__=ConfigDict(json_schema_extra={"required": required}),
)

View File

@@ -21,6 +21,7 @@ from invokeai.app.invocations import * # noqa: F401 F403
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationRegistry,
invocation,
invocation_output,
)
@@ -283,7 +284,7 @@ class AnyInvocation(BaseInvocation):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
def validate_invocation(v: Any) -> "AnyInvocation":
return BaseInvocation.get_typeadapter().validate_python(v)
return InvocationRegistry.get_invocation_typeadapter().validate_python(v)
return core_schema.no_info_plain_validator_function(validate_invocation)
@@ -294,7 +295,7 @@ class AnyInvocation(BaseInvocation):
# Nodes are too powerful, we have to make our own OpenAPI schema manually
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
oneOf: list[dict[str, str]] = []
names = [i.__name__ for i in BaseInvocation.get_invocations()]
names = [i.__name__ for i in InvocationRegistry.get_invocation_classes()]
for name in sorted(names):
oneOf.append({"$ref": f"#/components/schemas/{name}"})
return {"oneOf": oneOf}
@@ -304,7 +305,7 @@ class AnyInvocationOutput(BaseInvocationOutput):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
def validate_invocation_output(v: Any) -> "AnyInvocationOutput":
return BaseInvocationOutput.get_typeadapter().validate_python(v)
return InvocationRegistry.get_output_typeadapter().validate_python(v)
return core_schema.no_info_plain_validator_function(validate_invocation_output)
@@ -316,7 +317,7 @@ class AnyInvocationOutput(BaseInvocationOutput):
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
oneOf: list[dict[str, str]] = []
names = [i.__name__ for i in BaseInvocationOutput.get_outputs()]
names = [i.__name__ for i in InvocationRegistry.get_output_classes()]
for name in sorted(names):
oneOf.append({"$ref": f"#/components/schemas/{name}"})
return {"oneOf": oneOf}

View File

@@ -47,6 +47,7 @@ class WorkflowRecordsStorageBase(ABC):
query: Optional[str],
tags: Optional[list[str]],
has_been_opened: Optional[bool],
is_published: Optional[bool],
) -> PaginatedResults[WorkflowRecordListItemDTO]:
"""Gets many workflows."""
pass
@@ -56,6 +57,7 @@ class WorkflowRecordsStorageBase(ABC):
self,
categories: list[WorkflowCategory],
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> dict[str, int]:
"""Gets a dictionary of counts for each of the provided categories."""
pass
@@ -66,6 +68,7 @@ class WorkflowRecordsStorageBase(ABC):
tags: list[str],
categories: Optional[list[WorkflowCategory]] = None,
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> dict[str, int]:
"""Gets a dictionary of counts for each of the provided tags."""
pass

View File

@@ -67,6 +67,7 @@ class WorkflowWithoutID(BaseModel):
# This is typed as optional to prevent errors when pulling workflows from the DB. The frontend adds a default form if
# it is None.
form: dict[str, JsonValue] | None = Field(default=None, description="The form of the workflow.")
is_published: bool | None = Field(default=None, description="Whether the workflow is published or not.")
model_config = ConfigDict(extra="ignore")
@@ -101,6 +102,7 @@ class WorkflowRecordDTOBase(BaseModel):
opened_at: Optional[Union[datetime.datetime, str]] = Field(
default=None, description="The opened timestamp of the workflow."
)
is_published: bool | None = Field(default=None, description="Whether the workflow is published or not.")
class WorkflowRecordDTO(WorkflowRecordDTOBase):

View File

@@ -119,6 +119,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
query: Optional[str] = None,
tags: Optional[list[str]] = None,
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> PaginatedResults[WorkflowRecordListItemDTO]:
# sanitize!
assert order_by in WorkflowRecordOrderBy
@@ -241,6 +242,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
tags: list[str],
categories: Optional[list[WorkflowCategory]] = None,
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> dict[str, int]:
if not tags:
return {}
@@ -292,6 +294,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
self,
categories: list[WorkflowCategory],
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> dict[str, int]:
cursor = self._conn.cursor()
result: dict[str, int] = {}

View File

@@ -4,7 +4,10 @@ from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from pydantic.json_schema import models_json_schema
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, UIConfigBase
from invokeai.app.invocations.baseinvocation import (
InvocationRegistry,
UIConfigBase,
)
from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.events.events_common import EventBase
@@ -56,14 +59,14 @@ def get_openapi_func(
invocation_output_map_required: list[str] = []
# We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly.
for output in BaseInvocationOutput.get_outputs():
for output in InvocationRegistry.get_output_classes():
json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
move_defs_to_top_level(openapi_schema, json_schema)
openapi_schema["components"]["schemas"][output.__name__] = json_schema
# Technically, invocations are added to the schema by pydantic, but we still need to manually set their output
# property, so we'll just do it all manually.
for invocation in BaseInvocation.get_invocations():
for invocation in InvocationRegistry.get_invocation_classes():
json_schema = invocation.model_json_schema(
mode="serialization", ref_template="#/components/schemas/{model}"
)

View File

@@ -0,0 +1,23 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from invokeai.backend.model_manager.legacy_probe import CkptType
def get_flux_in_channels_from_state_dict(state_dict: "CkptType") -> int | None:
"""Gets the in channels from the state dict."""
# "Standard" FLUX models use "img_in.weight", but some community fine tunes use
# "model.diffusion_model.img_in.weight". Known models that use the latter key:
# - https://civitai.com/models/885098?modelVersionId=990775
# - https://civitai.com/models/1018060?modelVersionId=1596255
# - https://civitai.com/models/978314/ultrareal-fine-tune?modelVersionId=1413133
keys = {"img_in.weight", "model.diffusion_model.img_in.weight"}
for key in keys:
val = state_dict.get(key)
if val is not None:
return val.shape[1]
return None

View File

@@ -30,19 +30,18 @@ from inspect import isabstract
from pathlib import Path
from typing import ClassVar, Literal, Optional, TypeAlias, Union
import safetensors.torch
import torch
from picklescan.scanner import scan_file_path
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.hash_validator import validate_hash
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.taxonomy import (
AnyVariant,
BaseModelType,
ClipVariantType,
FluxLoRAFormat,
ModelFormat,
ModelRepoVariant,
ModelSourceType,
@@ -51,9 +50,8 @@ from invokeai.backend.model_manager.taxonomy import (
SchedulerPredictionType,
SubModelType,
)
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.util.silence_warnings import SilenceWarnings
logger = logging.getLogger(__name__)
@@ -97,68 +95,6 @@ class ControlAdapterDefaultSettings(BaseModel):
model_config = ConfigDict(extra="forbid")
class ModelOnDisk:
"""A utility class representing a model stored on disk."""
def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
self.path = path
self.format_type = ModelFormat.Diffusers if path.is_dir() else ModelFormat.Checkpoint
if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
self.name = path.stem
else:
self.name = path.name
self.hash_algo = hash_algo
def hash(self):
return ModelHash(algorithm=self.hash_algo).hash(self.path)
def size(self):
if self.format_type == ModelFormat.Checkpoint:
return self.path.stat().st_size
return sum(file.stat().st_size for file in self.path.rglob("*"))
def component_paths(self):
if self.format_type == ModelFormat.Checkpoint:
return {self.path}
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
return {f for f in self.path.rglob("*") if f.suffix in extensions}
def repo_variant(self):
if self.format_type == ModelFormat.Checkpoint:
return None
weight_files = list(self.path.glob("**/*.safetensors"))
weight_files.extend(list(self.path.glob("**/*.bin")))
for x in weight_files:
if ".fp16" in x.suffixes:
return ModelRepoVariant.FP16
if "openvino_model" in x.name:
return ModelRepoVariant.OpenVINO
if "flax_model" in x.name:
return ModelRepoVariant.Flax
if x.suffix == ".onnx":
return ModelRepoVariant.ONNX
return ModelRepoVariant.Default
@staticmethod
def load_state_dict(path: Path):
with SilenceWarnings():
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
scan_result = scan_file_path(path)
if scan_result.infected_files != 0 or scan_result.scan_err:
raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
checkpoint = torch.load(path, map_location="cpu")
elif path.suffix.endswith(".gguf"):
checkpoint = gguf_sd_loader(path, compute_dtype=torch.float32)
elif path.suffix.endswith(".safetensors"):
checkpoint = safetensors.torch.load_file(path)
else:
raise ValueError(f"Unrecognized model extension: {path.suffix}")
state_dict = checkpoint.get("state_dict", checkpoint)
return state_dict
class MatchSpeed(int, Enum):
"""Represents the estimated runtime speed of a config's 'matches' method."""
@@ -233,16 +169,18 @@ class ModelConfigBase(ABC, BaseModel):
Created to deprecate ModelProbe.probe
"""
candidates = ModelConfigBase._USING_CLASSIFY_API
sorted_by_match_speed = sorted(candidates, key=lambda cls: cls._MATCH_SPEED)
sorted_by_match_speed = sorted(candidates, key=lambda cls: (cls._MATCH_SPEED, cls.__name__))
mod = ModelOnDisk(model_path, hash_algo)
for config_cls in sorted_by_match_speed:
try:
return config_cls.from_model_on_disk(mod, **overrides)
except InvalidModelConfigException:
logger.debug(f"ModelConfig '{config_cls.__name__}' failed to parse '{mod.path}', trying next config")
if not config_cls.matches(mod):
continue
except Exception as e:
logger.error(f"Unexpected exception while parsing '{config_cls.__name__}': {e}, trying next config")
logger.warning(f"Unexpected exception while matching {mod.name} to '{config_cls.__name__}': {e}")
continue
else:
return config_cls.from_model_on_disk(mod, **overrides)
raise InvalidModelConfigException("No valid config found")
@@ -282,12 +220,12 @@ class ModelConfigBase(ABC, BaseModel):
if "source_type" in overrides:
overrides["source_type"] = ModelSourceType(overrides["source_type"])
if "variant" in overrides:
overrides["variant"] = ModelVariantType(overrides["variant"])
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
"""Creates an instance of this config or raises InvalidModelConfigException."""
if not cls.matches(mod):
raise InvalidModelConfigException(f"Path {mod.path} does not match {cls.__name__} format")
fields = cls.parse(mod)
cls.cast_overrides(overrides)
fields.update(overrides)
@@ -344,6 +282,38 @@ class LoRAConfigBase(ABC, BaseModel):
type: Literal[ModelType.LoRA] = ModelType.LoRA
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
@classmethod
def flux_lora_format(cls, mod: ModelOnDisk):
key = "FLUX_LORA_FORMAT"
if key in mod.cache:
return mod.cache[key]
from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict
sd = mod.load_state_dict(mod.path)
value = flux_format_from_state_dict(sd)
mod.cache[key] = value
return value
@classmethod
def base_model(cls, mod: ModelOnDisk) -> BaseModelType:
if cls.flux_lora_format(mod):
return BaseModelType.Flux
state_dict = mod.load_state_dict()
# If we've gotten here, we assume that the model is a Stable Diffusion model
token_vector_length = lora_token_vector_length(state_dict)
if token_vector_length == 768:
return BaseModelType.StableDiffusion1
elif token_vector_length == 1024:
return BaseModelType.StableDiffusion2
elif token_vector_length == 1280:
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
elif token_vector_length == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException("Unknown LoRA type")
class T5EncoderConfigBase(ABC, BaseModel):
"""Base class for diffusers-style models."""
@@ -359,11 +329,40 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, LegacyProbeMixin,
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
class LoRALyCORISConfig(LoRAConfigBase, LegacyProbeMixin, ModelConfigBase):
class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
"""Model config for LoRA/Lycoris models."""
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
if mod.path.is_dir():
return False
# Avoid false positive match against ControlLoRA and Diffusers
if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
return False
state_dict = mod.load_state_dict()
for key in state_dict.keys():
if type(key) is int:
continue
if key.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")):
return True
# "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT
# LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
if key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")):
return True
return False
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
return {
"base": cls.base_model(mod),
}
class ControlAdapterConfigBase(ABC, BaseModel):
default_settings: Optional[ControlAdapterDefaultSettings] = Field(
@@ -387,11 +386,26 @@ class ControlLoRADiffusersConfig(ControlAdapterConfigBase, LegacyProbeMixin, Mod
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class LoRADiffusersConfig(LoRAConfigBase, LegacyProbeMixin, ModelConfigBase):
class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
"""Model config for LoRA/Diffusers models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
if mod.path.is_file():
return cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers
suffixes = ["bin", "safetensors"]
weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
return any(wf.exists() for wf in weight_files)
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
return {
"base": cls.base_model(mod),
}
class VAECheckpointConfig(CheckpointConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for standalone VAE models."""
@@ -563,7 +577,7 @@ class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
if mod.format_type == ModelFormat.Checkpoint:
if mod.path.is_file():
return False
config_path = mod.path / "config.json"

View File

@@ -14,6 +14,7 @@ from invokeai.backend.flux.controlnet.state_dict_utils import (
is_state_dict_instantx_controlnet,
is_state_dict_xlabs_controlnet,
)
from invokeai.backend.flux.flux_state_dict_utils import get_flux_in_channels_from_state_dict
from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter
from invokeai.backend.flux.redux.flux_redux_state_dict_utils import is_state_dict_likely_flux_redux
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
@@ -564,7 +565,14 @@ class CheckpointProbeBase(ProbeBase):
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
if base_type == BaseModelType.Flux:
in_channels = state_dict["img_in.weight"].shape[1]
in_channels = get_flux_in_channels_from_state_dict(state_dict)
if in_channels is None:
# If we cannot find the in_channels, we assume that this is a normal variant. Log a warning.
logger.warning(
f"{self.model_path} does not have img_in.weight or model.diffusion_model.img_in.weight key. Assuming normal variant."
)
return ModelVariantType.Normal
# FLUX Model variant types are distinguished by input channels:
# - Unquantized Dev and Schnell have in_channels=64

View File

@@ -0,0 +1,96 @@
from pathlib import Path
from typing import Any, Optional, TypeAlias
import safetensors.torch
import torch
from picklescan.scanner import scan_file_path
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.util.silence_warnings import SilenceWarnings
StateDict: TypeAlias = dict[str | int, Any] # When are the keys int?
class ModelOnDisk:
"""A utility class representing a model stored on disk."""
def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
self.path = path
if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
self.name = path.stem
else:
self.name = path.name
self.hash_algo = hash_algo
# Having a cache helps users of ModelOnDisk (i.e. configs) to save state
# This prevents redundant computations during matching and parsing
self.cache = {"_CACHED_STATE_DICTS": {}}
def hash(self) -> str:
return ModelHash(algorithm=self.hash_algo).hash(self.path)
def size(self) -> int:
if self.path.is_file():
return self.path.stat().st_size
return sum(file.stat().st_size for file in self.path.rglob("*"))
def component_paths(self) -> set[Path]:
if self.path.is_file():
return {self.path}
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
return {f for f in self.path.rglob("*") if f.suffix in extensions}
def repo_variant(self) -> Optional[ModelRepoVariant]:
if self.path.is_file():
return None
weight_files = list(self.path.glob("**/*.safetensors"))
weight_files.extend(list(self.path.glob("**/*.bin")))
for x in weight_files:
if ".fp16" in x.suffixes:
return ModelRepoVariant.FP16
if "openvino_model" in x.name:
return ModelRepoVariant.OpenVINO
if "flax_model" in x.name:
return ModelRepoVariant.Flax
if x.suffix == ".onnx":
return ModelRepoVariant.ONNX
return ModelRepoVariant.Default
def load_state_dict(self, path: Optional[Path] = None) -> StateDict:
sd_cache = self.cache["_CACHED_STATE_DICTS"]
if path in sd_cache:
return sd_cache[path]
if not path:
components = list(self.component_paths())
match components:
case []:
raise ValueError("No weight files found for this model")
case [p]:
path = p
case ps if len(ps) >= 2:
raise ValueError(
f"Multiple weight files found for this model: {ps}. "
f"Please specify the intended file using the 'path' argument"
)
with SilenceWarnings():
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
scan_result = scan_file_path(path)
if scan_result.infected_files != 0 or scan_result.scan_err:
raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
checkpoint = torch.load(path, map_location="cpu")
assert isinstance(checkpoint, dict)
elif path.suffix.endswith(".gguf"):
checkpoint = gguf_sd_loader(path, compute_dtype=torch.float32)
elif path.suffix.endswith(".safetensors"):
checkpoint = safetensors.torch.load_file(path)
else:
raise ValueError(f"Unrecognized model extension: {path.suffix}")
state_dict = checkpoint.get("state_dict", checkpoint)
sd_cache[path] = state_dict
return state_dict

View File

@@ -126,4 +126,13 @@ class ModelSourceType(str, Enum):
HFRepoID = "hf_repo_id"
class FluxLoRAFormat(str, Enum):
"""Flux LoRA formats."""
Diffusers = "flux.diffusers"
Kohya = "flux.kohya"
OneTrainer = "flux.onetrainer"
Control = "flux.control"
AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]

View File

@@ -0,0 +1,24 @@
from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
)
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
is_state_dict_likely_in_flux_kohya_format,
)
from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
is_state_dict_likely_in_flux_onetrainer_format,
)
def flux_format_from_state_dict(state_dict):
if is_state_dict_likely_in_flux_kohya_format(state_dict):
return FluxLoRAFormat.Kohya
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict):
return FluxLoRAFormat.OneTrainer
elif is_state_dict_likely_in_flux_diffusers_format(state_dict):
return FluxLoRAFormat.Diffusers
elif is_state_dict_likely_flux_control(state_dict):
return FluxLoRAFormat.Control
else:
return None

View File

@@ -62,7 +62,7 @@
"@nanostores/react": "^0.7.3",
"@reduxjs/toolkit": "2.6.1",
"@roarr/browser-log-writer": "^1.3.0",
"@xyflow/react": "^12.4.2",
"@xyflow/react": "^12.5.1",
"async-mutex": "^0.5.0",
"chakra-react-select": "^4.9.2",
"cmdk": "^1.0.0",
@@ -150,7 +150,7 @@
"prettier": "^3.3.3",
"rollup-plugin-visualizer": "^5.12.0",
"storybook": "^8.3.4",
"tsafe": "^1.7.5",
"tsafe": "^1.8.5",
"type-fest": "^4.26.1",
"typescript": "^5.6.2",
"vite": "^6.1.0",

View File

@@ -36,8 +36,8 @@ dependencies:
specifier: ^1.3.0
version: 1.3.0
'@xyflow/react':
specifier: ^12.4.2
version: 12.4.2(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1)
specifier: ^12.5.1
version: 12.5.1(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1)
async-mutex:
specifier: ^0.5.0
version: 0.5.0
@@ -284,8 +284,8 @@ devDependencies:
specifier: ^8.3.4
version: 8.3.4
tsafe:
specifier: ^1.7.5
version: 1.7.5
specifier: ^1.8.5
version: 1.8.5
type-fest:
specifier: ^4.26.1
version: 4.26.1
@@ -3323,7 +3323,7 @@ packages:
/@types/d3-drag@3.0.7:
resolution: {integrity: sha512-HE3jVKlzU9AaMazNufooRJ5ZpWmLIoc90A37WU2JMmeq28w1FQqCZswHZ3xR+SuxYftzHq6WU6KJHvqxKzTxxQ==}
dependencies:
'@types/d3-selection': 3.0.10
'@types/d3-selection': 3.0.11
dev: false
/@types/d3-interpolate@3.0.4:
@@ -3332,21 +3332,21 @@ packages:
'@types/d3-color': 3.1.3
dev: false
/@types/d3-selection@3.0.10:
resolution: {integrity: sha512-cuHoUgS/V3hLdjJOLTT691+G2QoqAjCVLmr4kJXR4ha56w1Zdu8UUQ5TxLRqudgNjwXeQxKMq4j+lyf9sWuslg==}
/@types/d3-selection@3.0.11:
resolution: {integrity: sha512-bhAXu23DJWsrI45xafYpkQ4NtcKMwWnAC/vKrd2l+nxMFuvOT3XMYTIj2opv8vq8AO5Yh7Qac/nSeP/3zjTK0w==}
dev: false
/@types/d3-transition@3.0.8:
resolution: {integrity: sha512-ew63aJfQ/ms7QQ4X7pk5NxQ9fZH/z+i24ZfJ6tJSfqxJMrYLiK01EAs2/Rtw/JreGUsS3pLPNV644qXFGnoZNQ==}
/@types/d3-transition@3.0.9:
resolution: {integrity: sha512-uZS5shfxzO3rGlu0cC3bjmMFKsXv+SmZZcgp0KD22ts4uGXp5EVYGzu/0YdwZeKmddhcAccYtREJKkPfXkZuCg==}
dependencies:
'@types/d3-selection': 3.0.10
'@types/d3-selection': 3.0.11
dev: false
/@types/d3-zoom@3.0.8:
resolution: {integrity: sha512-iqMC4/YlFCSlO8+2Ii1GGGliCAY4XdeG748w5vQUbevlbDu0zSjH/+jojorQVBK/se0j6DUFNPBGSqD3YWYnDw==}
dependencies:
'@types/d3-interpolate': 3.0.4
'@types/d3-selection': 3.0.10
'@types/d3-selection': 3.0.11
dev: false
/@types/diff-match-patch@1.0.36:
@@ -3951,28 +3951,28 @@ packages:
resolution: {integrity: sha512-N8tkAACJx2ww8vFMneJmaAgmjAG1tnVBZJRLRcx061tmsLRZHSEZSLuGWnwPtunsSLvSqXQ2wfp7Mgqg1I+2dQ==}
dev: false
/@xyflow/react@12.4.2(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1):
resolution: {integrity: sha512-AFJKVc/fCPtgSOnRst3xdYJwiEcUN9lDY7EO/YiRvFHYCJGgfzg+jpvZjkTOnBLGyrMJre9378pRxAc3fsR06A==}
/@xyflow/react@12.5.1(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1):
resolution: {integrity: sha512-jMKQVqGwCz0x6pUyvxTIuCMbyehfua7CfEEWDj29zQSHigQpCy0/5d8aOmZrqK4cwur/pVHLQomT6Rm10gXfHg==}
peerDependencies:
react: '>=17'
react-dom: '>=17'
dependencies:
'@xyflow/system': 0.0.50
'@xyflow/system': 0.0.53
classcat: 5.0.5
react: 18.3.1
react-dom: 18.3.1(react@18.3.1)
zustand: 4.5.5(@types/react@18.3.11)(react@18.3.1)
zustand: 4.5.6(@types/react@18.3.11)(react@18.3.1)
transitivePeerDependencies:
- '@types/react'
- immer
dev: false
/@xyflow/system@0.0.50:
resolution: {integrity: sha512-HVUZd4LlY88XAaldFh2nwVxDOcdIBxGpQ5txzwfJPf+CAjj2BfYug1fHs2p4yS7YO8H6A3EFJQovBE8YuHkAdg==}
/@xyflow/system@0.0.53:
resolution: {integrity: sha512-QTWieiTtvNYyQAz1fxpzgtUGXNpnhfh6vvZa7dFWpWS2KOz6bEHODo/DTK3s07lDu0Bq0Db5lx/5M5mNjb9VDQ==}
dependencies:
'@types/d3-drag': 3.0.7
'@types/d3-selection': 3.0.10
'@types/d3-transition': 3.0.8
'@types/d3-selection': 3.0.11
'@types/d3-transition': 3.0.9
'@types/d3-zoom': 3.0.8
d3-drag: 3.0.0
d3-selection: 3.0.0
@@ -8791,8 +8791,8 @@ packages:
resolution: {integrity: sha512-tLJxacIQUM82IR7JO1UUkKlYuUTmoY9HBJAmNWFzheSlDS5SPMcNIepejHJa4BpPQLAcbRhRf3GDJzyj6rbKvA==}
dev: false
/tsafe@1.7.5:
resolution: {integrity: sha512-tbNyyBSbwfbilFfiuXkSOj82a6++ovgANwcoqBAcO9/REPoZMEQoE8kWPeO0dy5A2D/2Lajr8Ohue5T0ifIvLQ==}
/tsafe@1.8.5:
resolution: {integrity: sha512-LFWTWQrW6rwSY+IBNFl2ridGfUzVsPwrZ26T4KUJww/py8rzaQ/SY+MIz6YROozpUCaRcuISqagmlwub9YT9kw==}
dev: true
/tsconfck@3.1.5(typescript@5.6.2):
@@ -9123,6 +9123,14 @@ packages:
react: 18.3.1
dev: false
/use-sync-external-store@1.4.0(react@18.3.1):
resolution: {integrity: sha512-9WXSPC5fMv61vaupRkCKCxsPxBocVnwakBEkMIHHpkTTg6icbJtg6jzgtLDm4bl3cSHAca52rYWih0k4K3PfHw==}
peerDependencies:
react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
dependencies:
react: 18.3.1
dev: false
/util-deprecate@1.0.2:
resolution: {integrity: sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==}
dev: true
@@ -9567,8 +9575,8 @@ packages:
/zod@3.23.8:
resolution: {integrity: sha512-XBx9AXhXktjUqnepgTiE5flcKIYWi/rme0Eaj+5Y0lftuGBq+jyRu/md4WnuxqgP1ubdpNCsYEYPxrzVHD8d6g==}
/zustand@4.5.5(@types/react@18.3.11)(react@18.3.1):
resolution: {integrity: sha512-+0PALYNJNgK6hldkgDq2vLrw5f6g/jCInz52n9RTpropGgeAf/ioFUCdtsjCqu4gNhW9D01rUQBROoRjdzyn2Q==}
/zustand@4.5.6(@types/react@18.3.11)(react@18.3.1):
resolution: {integrity: sha512-ibr/n1hBzLLj5Y+yUcU7dYw8p6WnIVzdJbnX+1YpaScvZVF2ziugqHs+LAmHw4lWO9c/zRj+K1ncgWDQuthEdQ==}
engines: {node: '>=12.7.0'}
peerDependencies:
'@types/react': '>=16.8'
@@ -9584,5 +9592,5 @@ packages:
dependencies:
'@types/react': 18.3.11
react: 18.3.1
use-sync-external-store: 1.2.2(react@18.3.1)
use-sync-external-store: 1.4.0(react@18.3.1)
dev: false

View File

@@ -1783,7 +1783,37 @@
"textPlaceholder": "Empty Text",
"workflowBuilderAlphaWarning": "The workflow builder is currently in alpha. There may be breaking changes before the stable release.",
"minimum": "Minimum",
"maximum": "Maximum"
"maximum": "Maximum",
"publish": "Publish",
"published": "Published",
"unpublish": "Unpublish",
"workflowLocked": "Workflow Locked",
"workflowLockedPublished": "Published workflows are locked for editing.\nYou can unpublish the workflow to edit it, or make a copy of it.",
"workflowLockedDuringPublishing": "Workflow is locked while configuring for publishing.",
"selectOutputNode": "Select Output Node",
"changeOutputNode": "Change Output Node",
"publishedWorkflowOutputs": "Outputs",
"publishedWorkflowInputs": "Inputs",
"unpublishableInputs": "These unpublishable inputs will be omitted",
"noPublishableInputs": "No publishable inputs",
"noOutputNodeSelected": "No output node selected",
"cannotPublish": "Cannot publish workflow",
"publishWarnings": "Warnings",
"errorWorkflowHasUnsavedChanges": "Workflow has unsaved changes",
"errorWorkflowHasBatchOrGeneratorNodes": "Workflow has batch and/or generator nodes",
"errorWorkflowHasInvalidGraph": "Workflow graph invalid (hover Invoke button for details)",
"errorWorkflowHasNoOutputNode": "No output node selected",
"warningWorkflowHasNoPublishableInputFields": "No publishable input fields selected - published workflow will run with only default values",
"warningWorkflowHasUnpublishableInputFields": "Workflow has some unpublishable inputs - these will be omitted from the published workflow",
"publishFailed": "Publish failed",
"publishFailedDesc": "There was a problem publishing the workflow. Please try again.",
"publishSuccess": "Your workflow is being published",
"publishSuccessDesc": "Check your <LinkComponent>Project Dashboard</LinkComponent> to see its progress.",
"publishInProgress": "Publishing in progress",
"publishedWorkflowIsLocked": "Published workflow is locked",
"publishingValidationRun": "Publishing Validation Run",
"publishingValidationRunInProgress": "Publishing validation run in progress.",
"publishedWorkflowsLocked": "Published workflows are locked and cannot be edited or run. Either unpublish the workflow or save a copy to edit or run this workflow."
}
},
"controlLayers": {

View File

@@ -1,7 +0,0 @@
import { createAction } from '@reduxjs/toolkit';
import type { TabName } from 'features/ui/store/uiTypes';
export const enqueueRequested = createAction<{
tabName: TabName;
prepend: boolean;
}>('app/enqueueRequested');

View File

@@ -10,7 +10,6 @@ import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/l
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
import { addEnqueueRequestedNodes } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes';
import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
import { addGalleryOffsetChangedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryOffsetChanged';
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
@@ -63,7 +62,6 @@ addGalleryImageClickedListener(startAppListening);
addGalleryOffsetChangedListener(startAppListening);
// User Invoked
addEnqueueRequestedNodes(startAppListening);
addEnqueueRequestedLinear(startAppListening);
addEnqueueRequestedUpscale(startAppListening);
addAnyEnqueuedListener(startAppListening);

View File

@@ -5,7 +5,7 @@ import { buildAdHocPostProcessingGraph } from 'features/nodes/util/graph/buildAd
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig, ImageDTO } from 'services/api/types';
import type { EnqueueBatchArg, ImageDTO } from 'services/api/types';
import type { JsonObject } from 'type-fest';
const log = logger('queue');
@@ -19,7 +19,7 @@ export const addAdHocPostProcessingRequestedListener = (startAppListening: AppSt
const { imageDTO } = action.payload;
const state = getState();
const enqueueBatchArg: BatchConfig = {
const enqueueBatchArg: EnqueueBatchArg = {
prepend: true,
batch: {
graph: await buildAdHocPostProcessingGraph({

View File

@@ -1,5 +1,5 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import { withResult, withResultAsync } from 'common/util/result';
@@ -17,10 +17,11 @@ import { assert, AssertionError } from 'tsafe';
const log = logger('generation');
export const enqueueRequestedCanvas = createAction<{ prepend: boolean }>('app/enqueueRequestedCanvas');
export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => {
startAppListening({
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
enqueueRequested.match(action) && action.payload.tabName === 'canvas',
actionCreator: enqueueRequestedCanvas,
effect: async (action, { getState, dispatch }) => {
log.debug('Enqueue requested');
const state = getState();

View File

@@ -1,5 +1,5 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
@@ -9,10 +9,11 @@ import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endp
const log = logger('generation');
export const enqueueRequestedUpscaling = createAction<{ prepend: boolean }>('app/enqueueRequestedUpscaling');
export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening) => {
startAppListening({
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
enqueueRequested.match(action) && action.payload.tabName === 'upscaling',
actionCreator: enqueueRequestedUpscaling,
effect: async (action, { getState, dispatch }) => {
const state = getState();
const { prepend } = action.payload;

View File

@@ -1,6 +1,8 @@
import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { RootState } from 'app/store/store';
import { imageUploadedClientSide } from 'features/gallery/store/actions';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { boardIdSelected, galleryViewChanged } from 'features/gallery/store/gallerySlice';
import { toast } from 'features/toast/toast';
@@ -8,7 +10,8 @@ import { t } from 'i18next';
import { omit } from 'lodash-es';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { getCategories, getListImagesUrl } from 'services/api/util';
const log = logger('gallery');
/**
@@ -34,19 +37,56 @@ let lastUploadedToastTimeout: number | null = null;
export const addImageUploadedFulfilledListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: imagesApi.endpoints.uploadImage.matchFulfilled,
matcher: isAnyOf(imagesApi.endpoints.uploadImage.matchFulfilled, imageUploadedClientSide),
effect: (action, { dispatch, getState }) => {
const imageDTO = action.payload;
let imageDTO: ImageDTO;
let silent;
let isFirstUploadOfBatch = true;
if (imageUploadedClientSide.match(action)) {
imageDTO = action.payload.imageDTO;
silent = action.payload.silent;
isFirstUploadOfBatch = action.payload.isFirstUploadOfBatch;
} else if (imagesApi.endpoints.uploadImage.matchFulfilled(action)) {
imageDTO = action.payload;
silent = action.meta.arg.originalArgs.silent;
isFirstUploadOfBatch = action.meta.arg.originalArgs.isFirstUploadOfBatch ?? true;
} else {
return;
}
if (silent || imageDTO.is_intermediate) {
// If the image is silent or intermediate, we don't want to show a toast
return;
}
if (imageUploadedClientSide.match(action)) {
const categories = getCategories(imageDTO);
const boardId = imageDTO.board_id ?? 'none';
dispatch(
imagesApi.util.invalidateTags([
{
type: 'ImageList',
id: getListImagesUrl({
board_id: boardId,
categories,
}),
},
{
type: 'Board',
id: boardId,
},
{
type: 'BoardImagesTotal',
id: boardId,
},
])
);
}
const state = getState();
log.debug({ imageDTO }, 'Image uploaded');
if (action.meta.arg.originalArgs.silent || imageDTO.is_intermediate) {
// When a "silent" upload is requested, or the image is intermediate, we can skip all post-upload actions,
// like toasts and switching the gallery view
return;
}
const boardId = imageDTO.board_id ?? 'none';
const DEFAULT_UPLOADED_TOAST = {
@@ -80,7 +120,7 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
*
* Default to true to not require _all_ image upload handlers to set this value
*/
const isFirstUploadOfBatch = action.meta.arg.originalArgs.isFirstUploadOfBatch ?? true;
if (isFirstUploadOfBatch) {
dispatch(boardIdSelected({ boardId }));
dispatch(galleryViewChanged('assets'));

View File

@@ -3,6 +3,7 @@ import { autoBatchEnhancer, combineReducers, configureStore } from '@reduxjs/too
import { logger } from 'app/logging/logger';
import { idbKeyValDriver } from 'app/store/enhancers/reduxRemember/driver';
import { errorHandler } from 'app/store/enhancers/reduxRemember/errors';
import { getDebugLoggerMiddleware } from 'app/store/middleware/debugLoggerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { changeBoardModalSlice } from 'features/changeBoardModal/store/slice';
import { canvasSettingsPersistConfig, canvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
@@ -175,6 +176,7 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
.concat(api.middleware)
.concat(dynamicMiddlewares)
.concat(authToastMiddleware)
.concat(getDebugLoggerMiddleware())
.prepend(listenerMiddleware.middleware),
enhancers: (getDefaultEnhancers) => {
const _enhancers = getDefaultEnhancers().concat(autoBatchEnhancer());

View File

@@ -28,7 +28,8 @@ export type AppFeature =
| 'starterModels'
| 'hfToken'
| 'retryQueueItem'
| 'cancelAndClearAll';
| 'cancelAndClearAll'
| 'deployWorkflow';
/**
* A disable-able Stable Diffusion feature
*/
@@ -73,6 +74,7 @@ export type AppConfig = {
maxUpscaleDimension?: number;
allowPrivateBoards: boolean;
allowPrivateStylePresets: boolean;
allowClientSideUpload: boolean;
disabledTabs: TabName[];
disabledFeatures: AppFeature[];
disabledSDFeatures: SDFeature[];
@@ -81,7 +83,6 @@ export type AppConfig = {
metadataFetchDebounce?: number;
workflowFetchDebounce?: number;
isLocal?: boolean;
maxImageUploadCount?: number;
sd: {
defaultModel?: string;
disabledControlNetModels: string[];

View File

@@ -0,0 +1,121 @@
import { useStore } from '@nanostores/react';
import { $authToken } from 'app/store/nanostores/authToken';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { imageUploadedClientSide } from 'features/gallery/store/actions';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { useCallback } from 'react';
import { useCreateImageUploadEntryMutation } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
type PresignedUrlResponse = {
fullUrl: string;
thumbnailUrl: string;
};
const isPresignedUrlResponse = (response: unknown): response is PresignedUrlResponse => {
return typeof response === 'object' && response !== null && 'fullUrl' in response && 'thumbnailUrl' in response;
};
export const useClientSideUpload = () => {
const dispatch = useAppDispatch();
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const authToken = useStore($authToken);
const [createImageUploadEntry] = useCreateImageUploadEntryMutation();
const clientSideUpload = useCallback(
async (file: File, i: number): Promise<ImageDTO> => {
const image = new Image();
const objectURL = URL.createObjectURL(file);
image.src = objectURL;
let width = 0;
let height = 0;
let thumbnail: Blob | undefined;
await new Promise<void>((resolve) => {
image.onload = () => {
width = image.naturalWidth;
height = image.naturalHeight;
// Calculate thumbnail dimensions maintaining aspect ratio
let thumbWidth = width;
let thumbHeight = height;
if (width > height && width > 256) {
thumbWidth = 256;
thumbHeight = Math.round((height * 256) / width);
} else if (height > 256) {
thumbHeight = 256;
thumbWidth = Math.round((width * 256) / height);
}
const canvas = document.createElement('canvas');
canvas.width = thumbWidth;
canvas.height = thumbHeight;
const ctx = canvas.getContext('2d');
ctx?.drawImage(image, 0, 0, thumbWidth, thumbHeight);
canvas.toBlob(
(blob) => {
if (blob) {
thumbnail = blob;
// Clean up resources
URL.revokeObjectURL(objectURL);
image.src = ''; // Clear image source
image.remove(); // Remove the image element
canvas.width = 0; // Clear canvas
canvas.height = 0;
resolve();
}
},
'image/webp',
0.8
);
};
// Handle load errors
image.onerror = () => {
URL.revokeObjectURL(objectURL);
image.remove();
resolve();
};
});
const { presigned_url, image_dto } = await createImageUploadEntry({
width,
height,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
}).unwrap();
const response = await fetch(presigned_url, {
method: 'GET',
...(authToken && {
headers: {
Authorization: `Bearer ${authToken}`,
},
}),
}).then((res) => res.json());
if (!isPresignedUrlResponse(response)) {
throw new Error('Invalid response');
}
const fullUrl = response.fullUrl;
const thumbnailUrl = response.thumbnailUrl;
await fetch(fullUrl, {
method: 'PUT',
body: file,
});
await fetch(thumbnailUrl, {
method: 'PUT',
body: thumbnail,
});
dispatch(imageUploadedClientSide({ imageDTO: image_dto, silent: false, isFirstUploadOfBatch: i === 0 }));
return image_dto;
},
[autoAddBoardId, authToken, createImageUploadEntry, dispatch]
);
return clientSideUpload;
};

View File

@@ -14,7 +14,7 @@ export const useGlobalHotkeys = () => {
useRegisteredHotkeys({
id: 'invoke',
category: 'app',
callback: queue.queueBack,
callback: queue.enqueueBack,
options: {
enabled: !queue.isDisabled && !queue.isLoading,
preventDefault: true,
@@ -26,7 +26,7 @@ export const useGlobalHotkeys = () => {
useRegisteredHotkeys({
id: 'invokeFront',
category: 'app',
callback: queue.queueFront,
callback: queue.enqueueFront,
options: {
enabled: !queue.isDisabled && !queue.isLoading,
preventDefault: true,

View File

@@ -3,7 +3,7 @@ import { IconButton } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger';
import { useAppSelector } from 'app/store/storeHooks';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
import { selectIsClientSideUploadEnabled } from 'features/system/store/configSlice';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import type { FileRejection } from 'react-dropzone';
@@ -15,6 +15,7 @@ import type { ImageDTO } from 'services/api/types';
import { assert } from 'tsafe';
import type { SetOptional } from 'type-fest';
import { useClientSideUpload } from './useClientSideUpload';
type UseImageUploadButtonArgs =
| {
isDisabled?: boolean;
@@ -50,51 +51,65 @@ const log = logger('gallery');
*/
export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: UseImageUploadButtonArgs) => {
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const isClientSideUploadEnabled = useAppSelector(selectIsClientSideUploadEnabled);
const [uploadImage, request] = useUploadImageMutation();
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
const clientSideUpload = useClientSideUpload();
const { t } = useTranslation();
const onDropAccepted = useCallback(
async (files: File[]) => {
if (!allowMultiple) {
if (files.length > 1) {
log.warn('Multiple files dropped but only one allowed');
return;
}
if (files.length === 0) {
// Should never happen
log.warn('No files dropped');
return;
}
const file = files[0];
assert(file !== undefined); // should never happen
const imageDTO = await uploadImage({
file,
image_category: 'user',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
silent: true,
}).unwrap();
if (onUpload) {
onUpload(imageDTO);
}
} else {
const imageDTOs = await uploadImages(
files.map((file, i) => ({
try {
if (!allowMultiple) {
if (files.length > 1) {
log.warn('Multiple files dropped but only one allowed');
return;
}
if (files.length === 0) {
// Should never happen
log.warn('No files dropped');
return;
}
const file = files[0];
assert(file !== undefined); // should never happen
const imageDTO = await uploadImage({
file,
image_category: 'user',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
silent: false,
isFirstUploadOfBatch: i === 0,
}))
);
if (onUpload) {
onUpload(imageDTOs);
silent: true,
}).unwrap();
if (onUpload) {
onUpload(imageDTO);
}
} else {
let imageDTOs: ImageDTO[] = [];
if (isClientSideUploadEnabled) {
imageDTOs = await Promise.all(files.map((file, i) => clientSideUpload(file, i)));
} else {
imageDTOs = await uploadImages(
files.map((file, i) => ({
file,
image_category: 'user',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
silent: false,
isFirstUploadOfBatch: i === 0,
}))
);
}
if (onUpload) {
onUpload(imageDTOs);
}
}
} catch (error) {
toast({
id: 'UPLOAD_FAILED',
title: t('toast.imageUploadFailed'),
status: 'error',
});
}
},
[allowMultiple, autoAddBoardId, onUpload, uploadImage]
[allowMultiple, autoAddBoardId, onUpload, uploadImage, isClientSideUploadEnabled, clientSideUpload, t]
);
const onDropRejected = useCallback(
@@ -105,10 +120,7 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
file: rejection.file.path,
}));
log.error({ errors }, 'Invalid upload');
const description =
maxImageUploadCount === undefined
? t('toast.uploadFailedInvalidUploadDesc')
: t('toast.uploadFailedInvalidUploadDesc_withCount', { count: maxImageUploadCount });
const description = t('toast.uploadFailedInvalidUploadDesc');
toast({
id: 'UPLOAD_FAILED',
@@ -120,7 +132,7 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
return;
}
},
[maxImageUploadCount, t]
[t]
);
const {
@@ -137,8 +149,7 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
onDropRejected,
disabled: isDisabled,
noDrag: true,
multiple: allowMultiple && (maxImageUploadCount === undefined || maxImageUploadCount > 1),
maxFiles: maxImageUploadCount,
multiple: allowMultiple,
});
return { getUploadButtonProps, getUploadInputProps, openUploader, request };

View File

@@ -14,8 +14,9 @@ import WavyLine from 'common/components/WavyLine';
import { selectImg2imgStrength, setImg2imgStrength } from 'features/controlLayers/store/paramsSlice';
import { selectActiveRasterLayerEntities } from 'features/controlLayers/store/selectors';
import { selectImg2imgStrengthConfig } from 'features/system/store/configSlice';
import { memo, useCallback } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useSelectedModelConfig } from 'services/api/hooks/useSelectedModelConfig';
const selectHasRasterLayersWithContent = createSelector(
selectActiveRasterLayerEntities,
@@ -26,6 +27,7 @@ export const ParamDenoisingStrength = memo(() => {
const img2imgStrength = useAppSelector(selectImg2imgStrength);
const dispatch = useAppDispatch();
const hasRasterLayersWithContent = useAppSelector(selectHasRasterLayersWithContent);
const selectedModelConfig = useSelectedModelConfig();
const onChange = useCallback(
(v: number) => {
@@ -39,8 +41,24 @@ export const ParamDenoisingStrength = memo(() => {
const [invokeBlue300] = useToken('colors', ['invokeBlue.300']);
const isDisabled = useMemo(() => {
if (!hasRasterLayersWithContent) {
// Denoising strength does nothing if there are no raster layers w/ content
return true;
}
if (
selectedModelConfig?.type === 'main' &&
selectedModelConfig?.base === 'flux' &&
selectedModelConfig.variant === 'inpaint'
) {
// Denoising strength is ignored by FLUX Fill, which is indicated by the variant being 'inpaint'
return true;
}
return false;
}, [hasRasterLayersWithContent, selectedModelConfig]);
return (
<FormControl isDisabled={!hasRasterLayersWithContent} p={1} justifyContent="space-between" h={8}>
<FormControl isDisabled={isDisabled} p={1} justifyContent="space-between" h={8}>
<Flex gap={3} alignItems="center">
<InformationalPopover feature="paramDenoisingStrength">
<FormLabel mr={0}>{`${t('parameters.denoisingStrength')}`}</FormLabel>
@@ -49,7 +67,7 @@ export const ParamDenoisingStrength = memo(() => {
<WavyLine amplitude={img2imgStrength * 10} stroke={invokeBlue300} strokeWidth={1} width={40} height={14} />
)}
</Flex>
{hasRasterLayersWithContent ? (
{!isDisabled ? (
<>
<CompositeSlider
step={config.coarseStep}

View File

@@ -54,7 +54,7 @@ import { atom, computed } from 'nanostores';
import type { Logger } from 'roarr';
import { getImageDTO } from 'services/api/endpoints/images';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig, ImageDTO, S } from 'services/api/types';
import type { EnqueueBatchArg, ImageDTO, S } from 'services/api/types';
import { QueueError } from 'services/events/errors';
import type { Param0 } from 'tsafe';
import { assert } from 'tsafe';
@@ -291,7 +291,7 @@ export class CanvasStateApiModule extends CanvasModuleBase {
*/
const origin = getPrefixedId(graph.id);
const batch: BatchConfig = {
const batch: EnqueueBatchArg = {
prepend,
batch: {
graph: graph.getGraph(),

View File

@@ -8,12 +8,13 @@ import { useStore } from '@nanostores/react';
import { getStore } from 'app/store/nanostores/store';
import { useAppSelector } from 'app/store/storeHooks';
import { $focusedRegion } from 'common/hooks/focus';
import { useClientSideUpload } from 'common/hooks/useClientSideUpload';
import { setFileToPaste } from 'features/controlLayers/components/CanvasPasteModal';
import { DndDropOverlay } from 'features/dnd/DndDropOverlay';
import type { DndTargetState } from 'features/dnd/types';
import { $imageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
import { selectIsClientSideUploadEnabled } from 'features/system/store/configSlice';
import { toast } from 'features/toast/toast';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useEffect, useRef, useState } from 'react';
@@ -53,13 +54,6 @@ const zUploadFile = z
(file) => ({ message: `File extension .${file.name.split('.').at(-1)} is not supported` })
);
const getFilesSchema = (max?: number) => {
if (max === undefined) {
return z.array(zUploadFile);
}
return z.array(zUploadFile).max(max);
};
const sx = {
position: 'absolute',
top: 2,
@@ -74,22 +68,19 @@ const sx = {
export const FullscreenDropzone = memo(() => {
const { t } = useTranslation();
const ref = useRef<HTMLDivElement>(null);
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
const [dndState, setDndState] = useState<DndTargetState>('idle');
const activeTab = useAppSelector(selectActiveTab);
const isImageViewerOpen = useStore($imageViewer);
const isClientSideUploadEnabled = useAppSelector(selectIsClientSideUploadEnabled);
const clientSideUpload = useClientSideUpload();
const validateAndUploadFiles = useCallback(
(files: File[]) => {
async (files: File[]) => {
const { getState } = getStore();
const uploadFilesSchema = getFilesSchema(maxImageUploadCount);
const parseResult = uploadFilesSchema.safeParse(files);
const parseResult = z.array(zUploadFile).safeParse(files);
if (!parseResult.success) {
const description =
maxImageUploadCount === undefined
? t('toast.uploadFailedInvalidUploadDesc')
: t('toast.uploadFailedInvalidUploadDesc_withCount', { count: maxImageUploadCount });
const description = t('toast.uploadFailedInvalidUploadDesc');
toast({
id: 'UPLOAD_FAILED',
@@ -118,17 +109,23 @@ export const FullscreenDropzone = memo(() => {
const autoAddBoardId = selectAutoAddBoardId(getState());
const uploadArgs: UploadImageArg[] = files.map((file, i) => ({
file,
image_category: 'user',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
isFirstUploadOfBatch: i === 0,
}));
if (isClientSideUploadEnabled) {
for (const [i, file] of files.entries()) {
await clientSideUpload(file, i);
}
} else {
const uploadArgs: UploadImageArg[] = files.map((file, i) => ({
file,
image_category: 'user',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
isFirstUploadOfBatch: i === 0,
}));
uploadImages(uploadArgs);
uploadImages(uploadArgs);
}
},
[activeTab, isImageViewerOpen, maxImageUploadCount, t]
[activeTab, isImageViewerOpen, t, isClientSideUploadEnabled, clientSideUpload]
);
const onPaste = useCallback(

View File

@@ -1,31 +1,18 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
import { t } from 'i18next';
import { useMemo } from 'react';
import { PiUploadBold } from 'react-icons/pi';
export const GalleryUploadButton = () => {
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
const uploadOptions = useMemo(() => ({ allowMultiple: maxImageUploadCount !== 1 }), [maxImageUploadCount]);
const uploadApi = useImageUploadButton(uploadOptions);
const uploadApi = useImageUploadButton({ allowMultiple: true });
return (
<>
<IconButton
size="sm"
alignSelf="stretch"
variant="link"
aria-label={
maxImageUploadCount === undefined || maxImageUploadCount > 1
? t('accessibility.uploadImages')
: t('accessibility.uploadImage')
}
tooltip={
maxImageUploadCount === undefined || maxImageUploadCount > 1
? t('accessibility.uploadImages')
: t('accessibility.uploadImage')
}
aria-label={t('accessibility.uploadImages')}
tooltip={t('accessibility.uploadImages')}
icon={<PiUploadBold />}
{...uploadApi.getUploadButtonProps()}
/>

View File

@@ -1,4 +1,5 @@
import { createAction } from '@reduxjs/toolkit';
import type { ImageDTO } from 'services/api/types';
export const sentImageToCanvas = createAction('gallery/sentImageToCanvas');
@@ -7,3 +8,9 @@ export const imageDownloaded = createAction('gallery/imageDownloaded');
export const imageCopiedToClipboard = createAction('gallery/imageCopiedToClipboard');
export const imageOpenedInNewTab = createAction('gallery/imageOpenedInNewTab');
export const imageUploadedClientSide = createAction<{
imageDTO: ImageDTO;
silent: boolean;
isFirstUploadOfBatch: boolean;
}>('gallery/imageUploadedClientSide');

View File

@@ -2,7 +2,9 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { FocusRegionWrapper } from 'common/components/FocusRegionWrapper';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { AddNodeCmdk } from 'features/nodes/components/flow/AddNodeCmdk/AddNodeCmdk';
import TopPanel from 'features/nodes/components/flow/panels/TopPanel/TopPanel';
import { TopCenterPanel } from 'features/nodes/components/flow/panels/TopPanel/TopCenterPanel';
import { TopLeftPanel } from 'features/nodes/components/flow/panels/TopPanel/TopLeftPanel';
import { TopRightPanel } from 'features/nodes/components/flow/panels/TopPanel/TopRightPanel';
import WorkflowEditorSettings from 'features/nodes/components/flow/panels/TopRightPanel/WorkflowEditorSettings';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -32,7 +34,9 @@ const NodeEditor = () => {
<>
<Flow />
<AddNodeCmdk />
<TopPanel />
<TopLeftPanel />
<TopCenterPanel />
<TopRightPanel />
<BottomLeftPanel />
<MinimapPanel />
</>

View File

@@ -18,6 +18,7 @@ import { CommandEmpty, CommandItem, CommandList, CommandRoot } from 'cmdk';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
import {
$addNodeCmdk,
$cursorPos,
@@ -146,6 +147,7 @@ export const AddNodeCmdk = memo(() => {
const [searchTerm, setSearchTerm] = useState('');
const addNode = useAddNode();
const tab = useAppSelector(selectActiveTab);
const isLocked = useIsWorkflowEditorLocked();
// Filtering the list is expensive - debounce the search term to avoid stutters
const [debouncedSearchTerm] = useDebounce(searchTerm, 300);
const isOpen = useStore($addNodeCmdk);
@@ -160,8 +162,8 @@ export const AddNodeCmdk = memo(() => {
id: 'addNode',
category: 'workflows',
callback: open,
options: { enabled: tab === 'workflows', preventDefault: true },
dependencies: [open, tab],
options: { enabled: tab === 'workflows' && !isLocked, preventDefault: true },
dependencies: [open, tab, isLocked],
});
const onChange = useCallback((e: ChangeEvent<HTMLInputElement>) => {

View File

@@ -4,6 +4,7 @@ import type {
EdgeChange,
HandleType,
NodeChange,
NodeMouseHandler,
OnEdgesChange,
OnInit,
OnMoveEnd,
@@ -16,8 +17,10 @@ import type {
import { Background, ReactFlow, useStore as useReactFlowStore, useUpdateNodeInternals } from '@xyflow/react';
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
import { $isSelectingOutputNode, $outputNodeId } from 'features/nodes/components/sidePanel/workflow/publish';
import { useConnection } from 'features/nodes/hooks/useConnection';
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
import { useNodeCopyPaste } from 'features/nodes/hooks/useNodeCopyPaste';
import { useSyncExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
import {
@@ -44,7 +47,7 @@ import {
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
import { selectSelectionMode, selectShouldSnapToGrid } from 'features/nodes/store/workflowSettingsSlice';
import { NO_DRAG_CLASS, NO_PAN_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation';
import { type AnyEdge, type AnyNode, isInvocationNode } from 'features/nodes/types/invocation';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import type { CSSProperties, MouseEvent } from 'react';
import { memo, useCallback, useMemo, useRef } from 'react';
@@ -92,6 +95,8 @@ export const Flow = memo(() => {
const updateNodeInternals = useUpdateNodeInternals();
const store = useAppStore();
const isWorkflowsFocused = useIsRegionFocused('workflows');
const isLocked = useIsWorkflowEditorLocked();
useFocusRegion('workflows', flowWrapper);
useSyncExecutionState();
@@ -215,7 +220,7 @@ export const Flow = memo(() => {
id: 'copySelection',
category: 'workflows',
callback: copySelection,
options: { preventDefault: true },
options: { enabled: isWorkflowsFocused && !isLocked, preventDefault: true },
dependencies: [copySelection],
});
@@ -244,24 +249,24 @@ export const Flow = memo(() => {
id: 'selectAll',
category: 'workflows',
callback: selectAll,
options: { enabled: isWorkflowsFocused, preventDefault: true },
dependencies: [selectAll, isWorkflowsFocused],
options: { enabled: isWorkflowsFocused && !isLocked, preventDefault: true },
dependencies: [selectAll, isWorkflowsFocused, isLocked],
});
useRegisteredHotkeys({
id: 'pasteSelection',
category: 'workflows',
callback: pasteSelection,
options: { enabled: isWorkflowsFocused, preventDefault: true },
dependencies: [pasteSelection],
options: { enabled: isWorkflowsFocused && !isLocked, preventDefault: true },
dependencies: [pasteSelection, isLocked, isWorkflowsFocused],
});
useRegisteredHotkeys({
id: 'pasteSelectionWithEdges',
category: 'workflows',
callback: pasteSelectionWithEdges,
options: { enabled: isWorkflowsFocused, preventDefault: true },
dependencies: [pasteSelectionWithEdges],
options: { enabled: isWorkflowsFocused && !isLocked, preventDefault: true },
dependencies: [pasteSelectionWithEdges, isLocked, isWorkflowsFocused],
});
useRegisteredHotkeys({
@@ -270,8 +275,8 @@ export const Flow = memo(() => {
callback: () => {
dispatch(undo());
},
options: { enabled: isWorkflowsFocused && mayUndo, preventDefault: true },
dependencies: [mayUndo],
options: { enabled: isWorkflowsFocused && !isLocked && mayUndo, preventDefault: true },
dependencies: [mayUndo, isLocked, isWorkflowsFocused],
});
useRegisteredHotkeys({
@@ -280,8 +285,8 @@ export const Flow = memo(() => {
callback: () => {
dispatch(redo());
},
options: { enabled: isWorkflowsFocused && mayRedo, preventDefault: true },
dependencies: [mayRedo],
options: { enabled: isWorkflowsFocused && !isLocked && mayRedo, preventDefault: true },
dependencies: [mayRedo, isLocked, isWorkflowsFocused],
});
const onEscapeHotkey = useCallback(() => {
@@ -318,10 +323,22 @@ export const Flow = memo(() => {
id: 'deleteSelection',
category: 'workflows',
callback: deleteSelection,
options: { preventDefault: true, enabled: isWorkflowsFocused },
dependencies: [deleteSelection, isWorkflowsFocused],
options: { preventDefault: true, enabled: isWorkflowsFocused && !isLocked },
dependencies: [deleteSelection, isWorkflowsFocused, isLocked],
});
const onNodeClick = useCallback<NodeMouseHandler<AnyNode>>((e, node) => {
if (!$isSelectingOutputNode.get()) {
return;
}
if (!isInvocationNode(node)) {
return;
}
const { id } = node.data;
$outputNodeId.set(id);
$isSelectingOutputNode.set(false);
}, []);
return (
<ReactFlow<AnyNode, AnyEdge>
id="workflow-editor"
@@ -332,6 +349,7 @@ export const Flow = memo(() => {
nodes={nodes}
edges={edges}
onInit={onInit}
onNodeClick={onNodeClick}
onMouseMove={onMouseMove}
onNodesChange={onNodesChange}
onEdgesChange={onEdgesChange}
@@ -344,6 +362,12 @@ export const Flow = memo(() => {
onMoveEnd={handleMoveEnd}
connectionLineComponent={CustomConnectionLine}
isValidConnection={isValidConnection}
edgesFocusable={!isLocked}
edgesReconnectable={!isLocked}
nodesDraggable={!isLocked}
nodesConnectable={!isLocked}
nodesFocusable={!isLocked}
elementsSelectable={!isLocked}
minZoom={0.1}
snapToGrid={shouldSnapToGrid}
snapGrid={snapGrid}

View File

@@ -1,5 +1,5 @@
import { Handle, Position } from '@xyflow/react';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { useNodeTemplateOrThrow } from 'features/nodes/hooks/useNodeTemplateOrThrow';
import { map } from 'lodash-es';
import type { CSSProperties } from 'react';
import { memo } from 'react';
@@ -19,7 +19,7 @@ const collapsedHandleStyles: CSSProperties = {
};
const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => {
const template = useNodeTemplate(nodeId);
const template = useNodeTemplateOrThrow(nodeId);
if (!template) {
return null;

View File

@@ -1,9 +1,9 @@
import { Flex, Icon, Text, Tooltip } from '@invoke-ai/ui-library';
import { compare } from 'compare-versions';
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate';
import { useInvocationNodeNotes } from 'features/nodes/hooks/useNodeNotes';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { useNodeTemplateOrThrow } from 'features/nodes/hooks/useNodeTemplateOrThrow';
import { useNodeUserTitleSafe } from 'features/nodes/hooks/useNodeUserTitleSafe';
import { useNodeVersion } from 'features/nodes/hooks/useNodeVersion';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -27,9 +27,9 @@ InvocationNodeInfoIcon.displayName = 'InvocationNodeInfoIcon';
const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
const notes = useInvocationNodeNotes(nodeId);
const label = useNodeLabel(nodeId);
const label = useNodeUserTitleSafe(nodeId);
const version = useNodeVersion(nodeId);
const nodeTemplate = useNodeTemplate(nodeId);
const nodeTemplate = useNodeTemplateOrThrow(nodeId);
const { t } = useTranslation();
const title = useMemo(() => {

View File

@@ -8,7 +8,7 @@ import {
Textarea,
} from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useInputFieldDescriptionSafe } from 'features/nodes/hooks/useInputFieldDescriptionSafe';
import { useInputFieldUserDescriptionSafe } from 'features/nodes/hooks/useInputFieldUserDescriptionSafe';
import { fieldDescriptionChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_PAN_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { ChangeEvent } from 'react';
@@ -48,7 +48,7 @@ InputFieldDescriptionPopover.displayName = 'InputFieldDescriptionPopover';
const Content = memo(({ nodeId, fieldName }: Props) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const description = useInputFieldDescriptionSafe(nodeId, fieldName);
const description = useInputFieldUserDescriptionSafe(nodeId, fieldName);
const onChange = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(fieldDescriptionChanged({ nodeId, fieldName, val: e.target.value }));

View File

@@ -7,7 +7,7 @@ import { InputFieldResetToDefaultValueIconButton } from 'features/nodes/componen
import { useNodeFieldDnd } from 'features/nodes/components/sidePanel/builder/dnd-hooks';
import { useInputFieldIsConnected } from 'features/nodes/hooks/useInputFieldIsConnected';
import { useInputFieldIsInvalid } from 'features/nodes/hooks/useInputFieldIsInvalid';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
import { NO_DRAG_CLASS } from 'features/nodes/types/constants';
import type { FieldInputTemplate } from 'features/nodes/types/field';
import { memo, useRef } from 'react';
@@ -100,7 +100,7 @@ const DirectField = memo(({ nodeId, fieldName, isInvalid, isConnected, fieldTemp
const draggableRef = useRef<HTMLDivElement>(null);
const dragHandleRef = useRef<HTMLDivElement>(null);
const isDragging = useNodeFieldDnd({ nodeId, fieldName }, fieldTemplate, draggableRef, dragHandleRef);
const isDragging = useNodeFieldDnd(nodeId, fieldName, fieldTemplate, draggableRef, dragHandleRef);
return (
<InputFieldWrapper>

View File

@@ -7,7 +7,8 @@ import {
useIsConnectionInProgress,
useIsConnectionStartField,
} from 'features/nodes/hooks/useFieldConnectionState';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import type { FieldInputTemplate } from 'features/nodes/types/field';
@@ -105,9 +106,16 @@ type HandleCommonProps = {
};
const IdleHandle = memo(({ fieldTemplate, fieldTypeName, fieldColor, isModelField }: HandleCommonProps) => {
const isLocked = useIsWorkflowEditorLocked();
return (
<Tooltip label={fieldTypeName} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
<Handle type="target" id={fieldTemplate.name} position={Position.Left} style={handleStyles}>
<Handle
type="target"
id={fieldTemplate.name}
position={Position.Left}
style={handleStyles}
isConnectable={!isLocked}
>
<Box
sx={sx}
data-cardinality={fieldTemplate.type.cardinality}
@@ -130,6 +138,7 @@ const ConnectionInProgressHandle = memo(
const { t } = useTranslation();
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target');
const connectionError = useConnectionErrorTKey(nodeId, fieldName, 'target');
const isLocked = useIsWorkflowEditorLocked();
const tooltip = useMemo(() => {
if (connectionError !== null) {
@@ -140,7 +149,13 @@ const ConnectionInProgressHandle = memo(
return (
<Tooltip label={tooltip} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
<Handle type="target" id={fieldTemplate.name} position={Position.Left} style={handleStyles}>
<Handle
type="target"
id={fieldTemplate.name}
position={Position.Left}
style={handleStyles}
isConnectable={!isLocked}
>
<Box
sx={sx}
data-cardinality={fieldTemplate.type.cardinality}

View File

@@ -17,7 +17,7 @@ import { StringFieldDropdown } from 'features/nodes/components/flow/nodes/Invoca
import { StringFieldInput } from 'features/nodes/components/flow/nodes/Invocation/fields/StringField/StringFieldInput';
import { StringFieldTextarea } from 'features/nodes/components/flow/nodes/Invocation/fields/StringField/StringFieldTextarea';
import { useInputFieldInstance } from 'features/nodes/hooks/useInputFieldInstance';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
import {
isBoardFieldInputInstance,
isBoardFieldInputTemplate,

View File

@@ -9,8 +9,8 @@ import {
useIsConnectionStartField,
} from 'features/nodes/hooks/useFieldConnectionState';
import { useInputFieldIsConnected } from 'features/nodes/hooks/useInputFieldIsConnected';
import { useInputFieldLabelSafe } from 'features/nodes/hooks/useInputFieldLabelSafe';
import { useInputFieldTemplateTitle } from 'features/nodes/hooks/useInputFieldTemplateTitle';
import { useInputFieldTemplateTitleOrThrow } from 'features/nodes/hooks/useInputFieldTemplateTitleOrThrow';
import { useInputFieldUserTitleSafe } from 'features/nodes/hooks/useInputFieldUserTitleSafe';
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
import { HANDLE_TOOLTIP_OPEN_DELAY, NO_FIT_ON_DOUBLE_CLICK_CLASS } from 'features/nodes/types/constants';
import type { MouseEvent } from 'react';
@@ -43,8 +43,8 @@ interface Props {
export const InputFieldTitle = memo((props: Props) => {
const { nodeId, fieldName, isInvalid, isDragging } = props;
const inputRef = useRef<HTMLInputElement>(null);
const label = useInputFieldLabelSafe(nodeId, fieldName);
const fieldTemplateTitle = useInputFieldTemplateTitle(nodeId, fieldName);
const label = useInputFieldUserTitleSafe(nodeId, fieldName);
const fieldTemplateTitle = useInputFieldTemplateTitleOrThrow(nodeId, fieldName);
const { t } = useTranslation();
const isConnected = useInputFieldIsConnected(nodeId, fieldName);
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target');

View File

@@ -1,7 +1,7 @@
import { Flex, ListItem, Text, UnorderedList } from '@invoke-ai/ui-library';
import { useInputFieldErrors } from 'features/nodes/hooks/useInputFieldErrors';
import { useInputFieldInstance } from 'features/nodes/hooks/useInputFieldInstance';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
import { startCase } from 'lodash-es';
import { memo, useMemo } from 'react';

View File

@@ -7,6 +7,7 @@ import {
useIsConnectionInProgress,
useIsConnectionStartField,
} from 'features/nodes/hooks/useFieldConnectionState';
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
import { useOutputFieldTemplate } from 'features/nodes/hooks/useOutputFieldTemplate';
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
@@ -105,9 +106,17 @@ type HandleCommonProps = {
};
const IdleHandle = memo(({ fieldTemplate, fieldTypeName, fieldColor, isModelField }: HandleCommonProps) => {
const isLocked = useIsWorkflowEditorLocked();
return (
<Tooltip label={fieldTypeName} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
<Handle type="source" id={fieldTemplate.name} position={Position.Right} style={handleStyles}>
<Handle
type="source"
id={fieldTemplate.name}
position={Position.Right}
style={handleStyles}
isConnectable={!isLocked}
>
<Box
sx={sx}
data-cardinality={fieldTemplate.type.cardinality}
@@ -130,6 +139,7 @@ const ConnectionInProgressHandle = memo(
const { t } = useTranslation();
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target');
const connectionErrorTKey = useConnectionErrorTKey(nodeId, fieldName, 'target');
const isLocked = useIsWorkflowEditorLocked();
const tooltip = useMemo(() => {
if (connectionErrorTKey !== null) {
@@ -140,7 +150,13 @@ const ConnectionInProgressHandle = memo(
return (
<Tooltip label={tooltip} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
<Handle type="source" id={fieldTemplate.name} position={Position.Right} style={handleStyles}>
<Handle
type="source"
id={fieldTemplate.name}
position={Position.Right}
style={handleStyles}
isConnectable={!isLocked}
>
<Box
sx={sx}
data-cardinality={fieldTemplate.type.cardinality}

View File

@@ -3,8 +3,8 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { useEditable } from 'common/hooks/useEditable';
import { useBatchGroupColorToken } from 'features/nodes/hooks/useBatchGroupColorToken';
import { useBatchGroupId } from 'features/nodes/hooks/useBatchGroupId';
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
import { useNodeTemplateTitle } from 'features/nodes/hooks/useNodeTemplateTitle';
import { useNodeTemplateTitleSafe } from 'features/nodes/hooks/useNodeTemplateTitleSafe';
import { useNodeUserTitleSafe } from 'features/nodes/hooks/useNodeUserTitleSafe';
import { nodeLabelChanged } from 'features/nodes/store/nodesSlice';
import { NO_FIT_ON_DOUBLE_CLICK_CLASS } from 'features/nodes/types/constants';
import { memo, useCallback, useMemo, useRef } from 'react';
@@ -17,10 +17,10 @@ type Props = {
const NodeTitle = ({ nodeId, title }: Props) => {
const dispatch = useAppDispatch();
const label = useNodeLabel(nodeId);
const label = useNodeUserTitleSafe(nodeId);
const batchGroupId = useBatchGroupId(nodeId);
const batchGroupColorToken = useBatchGroupColorToken(batchGroupId);
const templateTitle = useNodeTemplateTitle(nodeId);
const templateTitle = useNodeTemplateTitleSafe(nodeId);
const { t } = useTranslation();
const inputRef = useRef<HTMLInputElement>(null);

View File

@@ -1,6 +1,7 @@
import type { ChakraProps, SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, useGlobalMenuClose } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
import { useMouseOverFormField, useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import { useNodeExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
import { useZoomToNode } from 'features/nodes/hooks/useZoomToNode';
@@ -62,6 +63,12 @@ const containerSx: SystemStyleObject = {
display: 'block',
shadow: '0 0 0 3px var(--invoke-colors-blue-300)',
},
'&[data-is-editor-locked="true"]': {
'& *': {
cursor: 'not-allowed',
pointerEvents: 'none',
},
},
};
const shadowsSx: SystemStyleObject = {
@@ -98,7 +105,8 @@ const NodeWrapper = (props: NodeWrapperProps) => {
const { nodeId, width, children, selected } = props;
const mouseOverNode = useMouseOverNode(nodeId);
const mouseOverFormField = useMouseOverFormField(nodeId);
const zoomToNode = useZoomToNode();
const zoomToNode = useZoomToNode(nodeId);
const isLocked = useIsWorkflowEditorLocked();
const executionState = useNodeExecutionState(nodeId);
const isInProgress = executionState?.status === zNodeStatus.enum.IN_PROGRESS;
@@ -126,9 +134,9 @@ const NodeWrapper = (props: NodeWrapperProps) => {
// This target is marked as not fitting the view on double click
return;
}
zoomToNode(nodeId);
zoomToNode();
},
[nodeId, zoomToNode]
[zoomToNode]
);
return (
@@ -141,6 +149,7 @@ const NodeWrapper = (props: NodeWrapperProps) => {
sx={containerSx}
width={width || NODE_WIDTH}
opacity={opacity}
data-is-editor-locked={isLocked}
data-is-selected={selected}
data-is-mouse-over-form-field={mouseOverFormField.isMouseOverFormField}
>

View File

@@ -0,0 +1,15 @@
import { Flex } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { WorkflowName } from 'features/nodes/components/sidePanel/WorkflowName';
import { selectWorkflowName } from 'features/nodes/store/workflowSlice';
import { memo } from 'react';
export const TopCenterPanel = memo(() => {
const name = useAppSelector(selectWorkflowName);
return (
<Flex gap={2} top={2} left="50%" transform="translateX(-50%)" position="absolute" pointerEvents="none">
{!!name.length && <WorkflowName />}
</Flex>
);
});
TopCenterPanel.displayName = 'TopCenterPanel';

View File

@@ -0,0 +1,54 @@
import { Alert, AlertDescription, AlertIcon, AlertTitle, Box, Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import AddNodeButton from 'features/nodes/components/flow/panels/TopPanel/AddNodeButton';
import UpdateNodesButton from 'features/nodes/components/flow/panels/TopPanel/UpdateNodesButton';
import { $isInPublishFlow, useIsValidationRunInProgress } from 'features/nodes/components/sidePanel/workflow/publish';
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
import { selectWorkflowIsPublished } from 'features/nodes/store/workflowSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
export const TopLeftPanel = memo(() => {
const isLocked = useIsWorkflowEditorLocked();
const isInPublishFlow = useStore($isInPublishFlow);
const isPublished = useAppSelector(selectWorkflowIsPublished);
const isValidationRunInProgress = useIsValidationRunInProgress();
const { t } = useTranslation();
return (
<Flex gap={2} top={2} left={2} position="absolute" alignItems="flex-start" pointerEvents="none">
{!isLocked && (
<Flex gap="2">
<AddNodeButton />
<UpdateNodesButton />
</Flex>
)}
{isLocked && (
<Alert status="info" borderRadius="base" fontSize="sm" shadow="md" w="fit-content">
<AlertIcon />
<Box>
<AlertTitle>{t('workflows.builder.workflowLocked')}</AlertTitle>
{isValidationRunInProgress && (
<AlertDescription whiteSpace="pre-wrap">
{t('workflows.builder.publishingValidationRunInProgress')}
</AlertDescription>
)}
{isInPublishFlow && !isValidationRunInProgress && (
<AlertDescription whiteSpace="pre-wrap">
{t('workflows.builder.workflowLockedDuringPublishing')}
</AlertDescription>
)}
{isPublished && (
<AlertDescription whiteSpace="pre-wrap">
{t('workflows.builder.workflowLockedPublished')}
</AlertDescription>
)}
</Box>
</Alert>
)}
</Flex>
);
});
TopLeftPanel.displayName = 'TopLeftPanel';

View File

@@ -1,40 +0,0 @@
import { Flex, IconButton, Spacer } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import AddNodeButton from 'features/nodes/components/flow/panels/TopPanel/AddNodeButton';
import ClearFlowButton from 'features/nodes/components/flow/panels/TopPanel/ClearFlowButton';
import SaveWorkflowButton from 'features/nodes/components/flow/panels/TopPanel/SaveWorkflowButton';
import UpdateNodesButton from 'features/nodes/components/flow/panels/TopPanel/UpdateNodesButton';
import { useWorkflowEditorSettingsModal } from 'features/nodes/components/flow/panels/TopRightPanel/WorkflowEditorSettings';
import { WorkflowName } from 'features/nodes/components/sidePanel/WorkflowName';
import { selectWorkflowName } from 'features/nodes/store/workflowSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiGearSixFill } from 'react-icons/pi';
const TopCenterPanel = () => {
const name = useAppSelector(selectWorkflowName);
const modal = useWorkflowEditorSettingsModal();
const { t } = useTranslation();
return (
<Flex gap={2} top={2} left={2} right={2} position="absolute" alignItems="flex-start" pointerEvents="none">
<Flex gap="2">
<AddNodeButton />
<UpdateNodesButton />
</Flex>
<Spacer />
{!!name.length && <WorkflowName />}
<Spacer />
<ClearFlowButton />
<SaveWorkflowButton />
<IconButton
pointerEvents="auto"
aria-label={t('workflows.workflowEditorMenu')}
icon={<PiGearSixFill />}
onClick={modal.setTrue}
/>
</Flex>
);
};
export default memo(TopCenterPanel);

View File

@@ -0,0 +1,34 @@
import { Flex, IconButton } from '@invoke-ai/ui-library';
import ClearFlowButton from 'features/nodes/components/flow/panels/TopPanel/ClearFlowButton';
import SaveWorkflowButton from 'features/nodes/components/flow/panels/TopPanel/SaveWorkflowButton';
import { useWorkflowEditorSettingsModal } from 'features/nodes/components/flow/panels/TopRightPanel/WorkflowEditorSettings';
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiGearSixFill } from 'react-icons/pi';
export const TopRightPanel = memo(() => {
const modal = useWorkflowEditorSettingsModal();
const isLocked = useIsWorkflowEditorLocked();
const { t } = useTranslation();
if (isLocked) {
return null;
}
return (
<Flex gap={2} top={2} right={2} position="absolute" alignItems="flex-end" pointerEvents="none">
<ClearFlowButton />
<SaveWorkflowButton />
<IconButton
pointerEvents="auto"
aria-label={t('workflows.workflowEditorMenu')}
icon={<PiGearSixFill />}
onClick={modal.setTrue}
/>
</Flex>
);
});
TopRightPanel.displayName = 'TopRightPanel';

View File

@@ -1,5 +1,4 @@
import { Box } from '@invoke-ai/ui-library';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { HorizontalResizeHandle } from 'features/ui/components/tabs/ResizeHandle';
import type { CSSProperties } from 'react';
import { memo, useCallback, useRef } from 'react';
@@ -23,23 +22,21 @@ export const EditModeLeftPanelContent = memo(() => {
return (
<Box position="relative" w="full" h="full">
<ScrollableContent>
<PanelGroup
ref={panelGroupRef}
id="workflow-panel-group"
autoSaveId="workflow-panel-group"
direction="vertical"
style={panelGroupStyles}
>
<Panel id="workflow" collapsible minSize={25}>
<WorkflowFieldsLinearViewPanel />
</Panel>
<HorizontalResizeHandle onDoubleClick={handleDoubleClickHandle} />
<Panel id="inspector" collapsible minSize={25}>
<WorkflowNodeInspectorPanel />
</Panel>
</PanelGroup>
</ScrollableContent>
<PanelGroup
ref={panelGroupRef}
id="workflow-panel-group"
autoSaveId="workflow-panel-group"
direction="vertical"
style={panelGroupStyles}
>
<Panel id="workflow" collapsible minSize={25}>
<WorkflowFieldsLinearViewPanel />
</Panel>
<HorizontalResizeHandle onDoubleClick={handleDoubleClickHandle} />
<Panel id="inspector" collapsible minSize={25}>
<WorkflowNodeInspectorPanel />
</Panel>
</PanelGroup>
</Box>
);
});

View File

@@ -0,0 +1,25 @@
import { Button, Flex, Heading, Text } from '@invoke-ai/ui-library';
import { useSaveOrSaveAsWorkflow } from 'features/workflowLibrary/hooks/useSaveOrSaveAsWorkflow';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCopyBold, PiLockOpenBold } from 'react-icons/pi';
export const PublishedWorkflowPanelContent = memo(() => {
const { t } = useTranslation();
const saveAs = useSaveOrSaveAsWorkflow();
return (
<Flex flexDir="column" w="full" h="full" gap={2} alignItems="center">
<Heading size="md" pt={32}>
{t('workflows.builder.workflowLocked')}
</Heading>
<Text fontSize="md">{t('workflows.builder.publishedWorkflowsLocked')}</Text>
<Button size="md" onClick={saveAs} variant="ghost" leftIcon={<PiCopyBold />}>
{t('common.saveAs')}
</Button>
<Button size="md" onClick={undefined} variant="ghost" leftIcon={<PiLockOpenBold />}>
{t('workflows.builder.unpublish')}
</Button>
</Flex>
);
});
PublishedWorkflowPanelContent.displayName = 'PublishedWorkflowPanelContent';

View File

@@ -2,7 +2,7 @@ import { Flex, Spacer } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { WorkflowListMenuTrigger } from 'features/nodes/components/sidePanel/WorkflowListMenu/WorkflowListMenuTrigger';
import { WorkflowViewEditToggleButton } from 'features/nodes/components/sidePanel/WorkflowViewEditToggleButton';
import { selectWorkflowMode } from 'features/nodes/store/workflowSlice';
import { selectWorkflowIsPublished, selectWorkflowMode } from 'features/nodes/store/workflowSlice';
import { WorkflowLibraryMenu } from 'features/workflowLibrary/components/WorkflowLibraryMenu/WorkflowLibraryMenu';
import { memo } from 'react';
@@ -10,12 +10,13 @@ import SaveWorkflowButton from './SaveWorkflowButton';
export const ActiveWorkflowNameAndActions = memo(() => {
const mode = useAppSelector(selectWorkflowMode);
const isPublished = useAppSelector(selectWorkflowIsPublished);
return (
<Flex w="full" alignItems="center" gap={1} minW={0}>
<WorkflowListMenuTrigger />
<Spacer />
{mode === 'edit' && <SaveWorkflowButton />}
{mode === 'edit' && !isPublished && <SaveWorkflowButton />}
<WorkflowViewEditToggleButton />
<WorkflowLibraryMenu />
</Flex>

View File

@@ -1,22 +1,30 @@
import { Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { EditModeLeftPanelContent } from 'features/nodes/components/sidePanel/EditModeLeftPanelContent';
import { PublishedWorkflowPanelContent } from 'features/nodes/components/sidePanel/PublishedWorkflowPanelContent';
import { $isInPublishFlow } from 'features/nodes/components/sidePanel/workflow/publish';
import { PublishWorkflowPanelContent } from 'features/nodes/components/sidePanel/workflow/PublishWorkflowPanelContent';
import { ActiveWorkflowDescription } from 'features/nodes/components/sidePanel/WorkflowListMenu/ActiveWorkflowDescription';
import { ActiveWorkflowNameAndActions } from 'features/nodes/components/sidePanel/WorkflowListMenu/ActiveWorkflowNameAndActions';
import { selectWorkflowMode } from 'features/nodes/store/workflowSlice';
import { selectWorkflowIsPublished, selectWorkflowMode } from 'features/nodes/store/workflowSlice';
import { memo } from 'react';
import { ViewModeLeftPanelContent } from './viewMode/ViewModeLeftPanelContent';
const WorkflowsTabLeftPanel = () => {
const mode = useAppSelector(selectWorkflowMode);
const isPublished = useAppSelector(selectWorkflowIsPublished);
const isInPublishFlow = useStore($isInPublishFlow);
return (
<Flex w="full" h="full" gap={2} flexDir="column">
<ActiveWorkflowNameAndActions />
{mode === 'view' && <ActiveWorkflowDescription />}
{mode === 'view' && <ViewModeLeftPanelContent />}
{mode === 'edit' && <EditModeLeftPanelContent />}
{isInPublishFlow && <PublishWorkflowPanelContent />}
{!isInPublishFlow && <ActiveWorkflowNameAndActions />}
{!isInPublishFlow && !isPublished && mode === 'view' && <ActiveWorkflowDescription />}
{!isInPublishFlow && !isPublished && mode === 'view' && <ViewModeLeftPanelContent />}
{!isInPublishFlow && !isPublished && mode === 'edit' && <EditModeLeftPanelContent />}
{isPublished && <PublishedWorkflowPanelContent />}
</Flex>
);
};

View File

@@ -67,11 +67,8 @@ FormElementEditModeHeader.displayName = 'FormElementEditModeHeader';
const ZoomToNodeButton = memo(({ element }: { element: NodeFieldElement }) => {
const { t } = useTranslation();
const { nodeId } = element.data.fieldIdentifier;
const zoomToNode = useZoomToNode();
const zoomToNode = useZoomToNode(nodeId);
const mouseOverFormField = useMouseOverFormField(nodeId);
const onClick = useCallback(() => {
zoomToNode(nodeId);
}, [nodeId, zoomToNode]);
return (
<IconButton
@@ -79,7 +76,7 @@ const ZoomToNodeButton = memo(({ element }: { element: NodeFieldElement }) => {
onMouseOut={mouseOverFormField.handleMouseOut}
tooltip={t('workflows.builder.zoomToNode')}
aria-label={t('workflows.builder.zoomToNode')}
onClick={onClick}
onClick={zoomToNode}
icon={<PiGpsFixBold />}
variant="link"
size="sm"

View File

@@ -2,8 +2,8 @@ import { FormHelperText, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { linkifyOptions, linkifySx } from 'common/components/linkify';
import { useEditable } from 'common/hooks/useEditable';
import { useInputFieldDescriptionSafe } from 'features/nodes/hooks/useInputFieldDescriptionSafe';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
import { useInputFieldUserDescriptionSafe } from 'features/nodes/hooks/useInputFieldUserDescriptionSafe';
import { fieldDescriptionChanged } from 'features/nodes/store/nodesSlice';
import type { NodeFieldElement } from 'features/nodes/types/workflow';
import Linkify from 'linkify-react';
@@ -13,7 +13,7 @@ export const NodeFieldElementDescriptionEditable = memo(({ el }: { el: NodeField
const { data } = el;
const { fieldIdentifier } = data;
const dispatch = useAppDispatch();
const description = useInputFieldDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const description = useInputFieldUserDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const inputRef = useRef<HTMLTextAreaElement>(null);

View File

@@ -39,7 +39,7 @@ export const NodeFieldElementEditMode = memo(({ el }: { el: NodeFieldElement })
return (
<Flex ref={draggableRef} id={id} className={NODE_FIELD_CLASS_NAME} sx={sx} data-parent-layout={containerCtx.layout}>
<NodeFieldElementEditModeContent dragHandleRef={dragHandleRef} el={el} isDragging={isDragging} />
<NodeFieldElementOverlay element={el} />
<NodeFieldElementOverlay nodeId={el.data.fieldIdentifier.nodeId} />
<DndListDropIndicator activeDropRegion={activeDropRegion} gap="var(--invoke-space-4)" />
</Flex>
);
@@ -105,9 +105,9 @@ const nodeFieldOverlaySx: SystemStyleObject = {
},
};
const NodeFieldElementOverlay = memo(({ element }: { element: NodeFieldElement }) => {
const mouseOverNode = useMouseOverNode(element.data.fieldIdentifier.nodeId);
const mouseOverFormField = useMouseOverFormField(element.data.fieldIdentifier.nodeId);
export const NodeFieldElementOverlay = memo(({ nodeId }: { nodeId: string }) => {
const mouseOverNode = useMouseOverNode(nodeId);
const mouseOverFormField = useMouseOverFormField(nodeId);
return (
<Box

View File

@@ -1,14 +1,14 @@
import { Flex, FormLabel, Spacer } from '@invoke-ai/ui-library';
import { NodeFieldElementResetToInitialValueIconButton } from 'features/nodes/components/flow/nodes/Invocation/fields/NodeFieldElementResetToInitialValueIconButton';
import { useInputFieldLabelSafe } from 'features/nodes/hooks/useInputFieldLabelSafe';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
import { useInputFieldUserTitleSafe } from 'features/nodes/hooks/useInputFieldUserTitleSafe';
import type { NodeFieldElement } from 'features/nodes/types/workflow';
import { memo, useMemo } from 'react';
export const NodeFieldElementLabel = memo(({ el }: { el: NodeFieldElement }) => {
const { data } = el;
const { fieldIdentifier } = data;
const label = useInputFieldLabelSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const label = useInputFieldUserTitleSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const _label = useMemo(() => label || fieldTemplate.title, [label, fieldTemplate.title]);

View File

@@ -2,8 +2,8 @@ import { Flex, FormLabel, Input, Spacer } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useEditable } from 'common/hooks/useEditable';
import { NodeFieldElementResetToInitialValueIconButton } from 'features/nodes/components/flow/nodes/Invocation/fields/NodeFieldElementResetToInitialValueIconButton';
import { useInputFieldLabelSafe } from 'features/nodes/hooks/useInputFieldLabelSafe';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
import { useInputFieldUserTitleSafe } from 'features/nodes/hooks/useInputFieldUserTitleSafe';
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
import type { NodeFieldElement } from 'features/nodes/types/workflow';
import { memo, useCallback, useRef } from 'react';
@@ -12,7 +12,7 @@ export const NodeFieldElementLabelEditable = memo(({ el }: { el: NodeFieldElemen
const { data } = el;
const { fieldIdentifier } = data;
const dispatch = useAppDispatch();
const label = useInputFieldLabelSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const label = useInputFieldUserTitleSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const inputRef = useRef<HTMLInputElement>(null);

View File

@@ -15,7 +15,7 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { NodeFieldElementFloatSettings } from 'features/nodes/components/sidePanel/builder/NodeFieldElementFloatSettings';
import { NodeFieldElementIntegerSettings } from 'features/nodes/components/sidePanel/builder/NodeFieldElementIntegerSettings';
import { NodeFieldElementStringSettings } from 'features/nodes/components/sidePanel/builder/NodeFieldElementStringSettings';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
import { formElementNodeFieldDataChanged } from 'features/nodes/store/workflowSlice';
import {
isFloatFieldInputTemplate,

View File

@@ -5,8 +5,9 @@ import { InputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/
import { InputFieldRenderer } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer';
import { useContainerContext } from 'features/nodes/components/sidePanel/builder/contexts';
import { NodeFieldElementLabel } from 'features/nodes/components/sidePanel/builder/NodeFieldElementLabel';
import { useInputFieldDescriptionSafe } from 'features/nodes/hooks/useInputFieldDescriptionSafe';
import { useInputFieldTemplateOrThrow, useInputFieldTemplateSafe } from 'features/nodes/hooks/useInputFieldTemplate';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
import { useInputFieldTemplateSafe } from 'features/nodes/hooks/useInputFieldTemplateSafe';
import { useInputFieldUserDescriptionSafe } from 'features/nodes/hooks/useInputFieldUserDescriptionSafe';
import type { NodeFieldElement } from 'features/nodes/types/workflow';
import { NODE_FIELD_CLASS_NAME } from 'features/nodes/types/workflow';
import Linkify from 'linkify-react';
@@ -36,7 +37,7 @@ const useFormatFallbackLabel = () => {
export const NodeFieldElementViewMode = memo(({ el }: { el: NodeFieldElement }) => {
const { id, data } = el;
const { fieldIdentifier, showDescription } = data;
const description = useInputFieldDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const description = useInputFieldUserDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const fieldTemplate = useInputFieldTemplateSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const containerCtx = useContainerContext();
const formatFallbackLabel = useFormatFallbackLabel();
@@ -69,7 +70,7 @@ NodeFieldElementViewMode.displayName = 'NodeFieldElementViewMode';
const NodeFieldElementViewModeContent = memo(({ el }: { el: NodeFieldElement }) => {
const { data } = el;
const { fieldIdentifier, showDescription } = data;
const description = useInputFieldDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const description = useInputFieldUserDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const _description = useMemo(

View File

@@ -1,4 +1,6 @@
import { combine } from '@atlaskit/pragmatic-drag-and-drop/combine';
import type { DropTargetRecord } from '@atlaskit/pragmatic-drag-and-drop/dist/types/internal-types';
import type { ElementDragPayload } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';
import {
draggable,
dropTargetForElements,
@@ -33,7 +35,7 @@ import {
selectFormRootElementId,
selectWorkflowSlice,
} from 'features/nodes/store/workflowSlice';
import type { FieldIdentifier, FieldInputTemplate, StatefulFieldValue } from 'features/nodes/types/field';
import type { FieldInputTemplate, StatefulFieldValue } from 'features/nodes/types/field';
import type { ElementId, FormElement } from 'features/nodes/types/workflow';
import { buildNodeFieldElement, isContainerElement } from 'features/nodes/types/workflow';
import type { RefObject } from 'react';
@@ -58,6 +60,27 @@ const isFormElementDndData = (data: Record<string | symbol, unknown>): data is F
return uniqueFormElementDndKey in data;
};
const uniqueNodeFieldDndKey = Symbol('node-field');
type NodeFieldDndData = {
[uniqueNodeFieldDndKey]: true;
nodeId: string;
fieldName: string;
fieldTemplate: FieldInputTemplate;
};
export const buildNodeFieldDndData = (
nodeId: string,
fieldName: string,
fieldTemplate: FieldInputTemplate
): NodeFieldDndData => ({
[uniqueNodeFieldDndKey]: true,
nodeId,
fieldName,
fieldTemplate,
});
const isNodeFieldDndData = (data: Record<string | symbol, unknown>): data is NodeFieldDndData => {
return uniqueNodeFieldDndKey in data;
};
/**
* Flashes an element by changing its background color. Used to indicate that an element has been moved.
* @param elementId The id of the element to flash
@@ -133,6 +156,27 @@ const useGetInitialValue = () => {
return _getInitialValue;
};
const getSourceElement = (source: ElementDragPayload) => {
if (isNodeFieldDndData(source.data)) {
const { nodeId, fieldName, fieldTemplate } = source.data;
return buildNodeFieldElement(nodeId, fieldName, fieldTemplate.type);
}
if (isFormElementDndData(source.data)) {
return source.data.element;
}
return null;
};
const getTargetElement = (target: DropTargetRecord) => {
if (isFormElementDndData(target.data)) {
return target.data.element;
}
return null;
};
/**
* Singleton hook that monitors for builder dnd events and dispatches actions accordingly.
*/
@@ -156,20 +200,20 @@ export const useBuilderDndMonitor = () => {
useEffect(() => {
return monitorForElements({
canMonitor: ({ source }) => isFormElementDndData(source.data),
canMonitor: ({ source }) => isFormElementDndData(source.data) || isNodeFieldDndData(source.data),
onDrop: ({ location, source }) => {
const target = location.current.dropTargets[0];
if (!target) {
return;
}
if (!isFormElementDndData(source.data) || !isFormElementDndData(target.data)) {
const sourceElement = getSourceElement(source);
const targetElement = getTargetElement(target);
if (!sourceElement || !targetElement) {
return;
}
const sourceElement = source.data.element;
const targetElement = target.data.element;
if (sourceElement.id === targetElement.id) {
// Dropping on self is a no-op
return;
@@ -359,8 +403,15 @@ export const useFormElementDnd = (
element: draggableElement,
// TODO(psyche): This causes a kinda jittery behaviour - need a better heuristic to determine stickiness
getIsSticky: () => false,
canDrop: ({ source }) =>
isFormElementDndData(source.data) && source.data.element.id !== getElement(elementId).parentId,
canDrop: ({ source }) => {
if (isNodeFieldDndData(source.data)) {
return true;
}
if (isFormElementDndData(source.data)) {
return source.data.element.id !== getElement(elementId).parentId;
}
return false;
},
getData: ({ input }) => {
const element = getElement(elementId);
@@ -423,8 +474,16 @@ export const useRootElementDropTarget = (droppableRef: RefObject<HTMLDivElement>
dropTargetForElements({
element: droppableElement,
getIsSticky: () => false,
canDrop: ({ source }) =>
getElement(rootElementId, isContainerElement).data.children.length === 0 && isFormElementDndData(source.data),
canDrop: ({ source }) => {
const rootElement = getElement(rootElementId, isContainerElement);
if (rootElement.data.children.length !== 0) {
return false;
}
if (isNodeFieldDndData(source.data) || isFormElementDndData(source.data)) {
return true;
}
return false;
},
getData: ({ input }) => {
const element = getElement(rootElementId, isContainerElement);
@@ -455,7 +514,8 @@ export const useRootElementDropTarget = (droppableRef: RefObject<HTMLDivElement>
/**
* Hook that provides dnd functionality for node fields.
*
* @param fieldIdentifier The identifier of the node field
* @param nodeId: The id of the node
* @param fieldName: The name of the field
* @param fieldTemplate The template of the node field, required to build the form element
* @param draggableRef The ref of the draggable HTML element
* @param dragHandleRef The ref of the drag handle HTML element
@@ -463,7 +523,8 @@ export const useRootElementDropTarget = (droppableRef: RefObject<HTMLDivElement>
* @returns Whether the node field is currently being dragged
*/
export const useNodeFieldDnd = (
fieldIdentifier: FieldIdentifier,
nodeId: string,
fieldName: string,
fieldTemplate: FieldInputTemplate,
draggableRef: RefObject<HTMLElement>,
dragHandleRef: RefObject<HTMLElement>
@@ -481,12 +542,7 @@ export const useNodeFieldDnd = (
draggable({
element: draggableElement,
dragHandle: dragHandleElement,
getInitialData: () => {
const { nodeId, fieldName } = fieldIdentifier;
const { type } = fieldTemplate;
const element = buildNodeFieldElement(nodeId, fieldName, type);
return buildFormElementDndData(element);
},
getInitialData: () => buildNodeFieldDndData(nodeId, fieldName, fieldTemplate),
onDragStart: () => {
setIsDragging(true);
},
@@ -495,7 +551,7 @@ export const useNodeFieldDnd = (
},
})
);
}, [dragHandleRef, draggableRef, fieldIdentifier, fieldTemplate]);
}, [dragHandleRef, draggableRef, fieldName, fieldTemplate, nodeId]);
return isDragging;
};

View File

@@ -1,6 +1,6 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useInputFieldInstance } from 'features/nodes/hooks/useInputFieldInstance';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
import { formElementAdded, selectFormRootElementId } from 'features/nodes/store/workflowSlice';
import { buildNodeFieldElement } from 'features/nodes/types/workflow';
import { useCallback } from 'react';

View File

@@ -5,7 +5,7 @@ import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableCon
import { InvocationNodeNotesTextarea } from 'features/nodes/components/flow/nodes/Invocation/InvocationNodeNotesTextarea';
import { TemplateGate } from 'features/nodes/components/sidePanel/inspector/NodeTemplateGate';
import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { useNodeTemplateOrThrow } from 'features/nodes/hooks/useNodeTemplateOrThrow';
import { useNodeVersion } from 'features/nodes/hooks/useNodeVersion';
import { selectLastSelectedNodeId } from 'features/nodes/store/selectors';
import { memo } from 'react';
@@ -36,7 +36,7 @@ export default memo(InspectorDetailsTab);
const Content = memo(({ nodeId }: { nodeId: string }) => {
const { t } = useTranslation();
const version = useNodeVersion(nodeId);
const template = useNodeTemplate(nodeId);
const template = useNodeTemplateOrThrow(nodeId);
const needsUpdate = useNodeNeedsUpdate(nodeId);
return (

View File

@@ -5,7 +5,7 @@ import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableCon
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { TemplateGate } from 'features/nodes/components/sidePanel/inspector/NodeTemplateGate';
import { useNodeExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { useNodeTemplateOrThrow } from 'features/nodes/hooks/useNodeTemplateOrThrow';
import { selectLastSelectedNodeId } from 'features/nodes/store/selectors';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -37,7 +37,7 @@ const getKey = (result: AnyInvocationOutput, i: number) => `${result.type}-${i}`
const Content = memo(({ nodeId }: { nodeId: string }) => {
const { t } = useTranslation();
const template = useNodeTemplate(nodeId);
const template = useNodeTemplateOrThrow(nodeId);
const nes = useNodeExecutionState(nodeId);
if (!nes || nes.outputs.length === 0) {

View File

@@ -1,8 +1,8 @@
import { Flex, Input, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useEditable } from 'common/hooks/useEditable';
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
import { useNodeTemplateTitle } from 'features/nodes/hooks/useNodeTemplateTitle';
import { useNodeTemplateTitleSafe } from 'features/nodes/hooks/useNodeTemplateTitleSafe';
import { useNodeUserTitleSafe } from 'features/nodes/hooks/useNodeUserTitleSafe';
import { nodeLabelChanged } from 'features/nodes/store/nodesSlice';
import { memo, useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
@@ -14,8 +14,8 @@ type Props = {
const InspectorTabEditableNodeTitle = ({ nodeId, title }: Props) => {
const dispatch = useAppDispatch();
const label = useNodeLabel(nodeId);
const templateTitle = useNodeTemplateTitle(nodeId);
const label = useNodeUserTitleSafe(nodeId);
const templateTitle = useNodeTemplateTitleSafe(nodeId);
const { t } = useTranslation();
const inputRef = useRef<HTMLInputElement>(null);
const onChange = useCallback(

View File

@@ -2,7 +2,7 @@ import { useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { TemplateGate } from 'features/nodes/components/sidePanel/inspector/NodeTemplateGate';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { useNodeTemplateOrThrow } from 'features/nodes/hooks/useNodeTemplateOrThrow';
import { selectLastSelectedNodeId } from 'features/nodes/store/selectors';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -29,7 +29,7 @@ export default memo(NodeTemplateInspector);
const Content = memo(({ nodeId }: { nodeId: string }) => {
const { t } = useTranslation();
const template = useNodeTemplate(nodeId);
const template = useNodeTemplateOrThrow(nodeId);
return <DataViewer data={template} label={t('nodes.nodeTemplate')} bg="base.850" color="base.200" />;
});

View File

@@ -1,4 +1,4 @@
import { useNodeTemplateSafe } from 'features/nodes/hooks/useNodeTemplate';
import { useNodeTemplateSafe } from 'features/nodes/hooks/useNodeTemplateSafe';
import type { PropsWithChildren, ReactNode } from 'react';
import { memo } from 'react';

View File

@@ -0,0 +1,429 @@
import {
Button,
ButtonGroup,
Divider,
Flex,
ListItem,
Spacer,
Text,
Tooltip,
UnorderedList,
} from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { logger } from 'app/logging/logger';
import { $projectUrl } from 'app/store/nanostores/projectId';
import { useAppSelector } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { withResultAsync } from 'common/util/result';
import { parseify } from 'common/util/serialize';
import { ExternalLink } from 'features/gallery/components/ImageViewer/NoContentForViewer';
import { NodeFieldElementOverlay } from 'features/nodes/components/sidePanel/builder/NodeFieldElementEditMode';
import {
$isInPublishFlow,
$isReadyToDoValidationRun,
$isSelectingOutputNode,
$outputNodeId,
$validationRunBatchId,
usePublishInputs,
} from 'features/nodes/components/sidePanel/workflow/publish';
import { useInputFieldTemplateTitleOrThrow } from 'features/nodes/hooks/useInputFieldTemplateTitleOrThrow';
import { useInputFieldUserTitleOrThrow } from 'features/nodes/hooks/useInputFieldUserTitleOrThrow';
import { useMouseOverFormField } from 'features/nodes/hooks/useMouseOverNode';
import { useNodeTemplateTitleOrThrow } from 'features/nodes/hooks/useNodeTemplateTitleOrThrow';
import { useNodeUserTitleOrThrow } from 'features/nodes/hooks/useNodeUserTitleOrThrow';
import { useOutputFieldNames } from 'features/nodes/hooks/useOutputFieldNames';
import { useOutputFieldTemplate } from 'features/nodes/hooks/useOutputFieldTemplate';
import { useZoomToNode } from 'features/nodes/hooks/useZoomToNode';
import { selectHasBatchOrGeneratorNodes } from 'features/nodes/store/selectors';
import { selectIsWorkflowSaved } from 'features/nodes/store/workflowSlice';
import { useEnqueueWorkflows } from 'features/queue/hooks/useEnqueueWorkflows';
import { $isReadyToEnqueue } from 'features/queue/store/readiness';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { toast } from 'features/toast/toast';
import type { PropsWithChildren } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { Trans, useTranslation } from 'react-i18next';
import { PiLightningFill, PiSignOutBold, PiXBold } from 'react-icons/pi';
import { serializeError } from 'serialize-error';
import { assert } from 'tsafe';
const log = logger('generation');
export const PublishWorkflowPanelContent = memo(() => {
return (
<Flex flexDir="column" gap={2} h="full">
<ButtonGroup isAttached={false} size="sm" variant="ghost">
<SelectOutputNodeButton />
<Spacer />
<CancelPublishButton />
<PublishWorkflowButton />
</ButtonGroup>
<ScrollableContent>
<Flex flexDir="column" gap={2} w="full" h="full">
<OutputFields />
<PublishableInputFields />
<UnpublishableInputFields />
</Flex>
</ScrollableContent>
</Flex>
);
});
PublishWorkflowPanelContent.displayName = 'DeployWorkflowPanelContent';
const OutputFields = memo(() => {
const { t } = useTranslation();
const outputNodeId = useStore($outputNodeId);
if (!outputNodeId) {
return (
<Flex flexDir="column" borderWidth={1} borderRadius="base" gap={2} p={2}>
<Text fontWeight="semibold" color="error.300">
{t('workflows.builder.noOutputNodeSelected')}
</Text>
</Flex>
);
}
return <OutputFieldsContent outputNodeId={outputNodeId} />;
});
OutputFields.displayName = 'OutputFields';
const OutputFieldsContent = memo(({ outputNodeId }: { outputNodeId: string }) => {
const { t } = useTranslation();
const outputFieldNames = useOutputFieldNames(outputNodeId);
return (
<Flex flexDir="column" borderWidth={1} borderRadius="base" gap={2} p={2}>
<Text fontWeight="semibold">{t('workflows.builder.publishedWorkflowOutputs')}</Text>
<Divider />
{outputFieldNames.map((fieldName) => (
<NodeOutputFieldPreview key={`${outputNodeId}-${fieldName}`} nodeId={outputNodeId} fieldName={fieldName} />
))}
</Flex>
);
});
OutputFieldsContent.displayName = 'OutputFieldsContent';
const PublishableInputFields = memo(() => {
const { t } = useTranslation();
const inputs = usePublishInputs();
if (inputs.publishable.length === 0) {
return (
<Flex flexDir="column" borderWidth={1} borderRadius="base" gap={2} p={2}>
<Text fontWeight="semibold" color="warning.300">
{t('workflows.builder.noPublishableInputs')}
</Text>
</Flex>
);
}
return (
<Flex flexDir="column" borderWidth={1} borderRadius="base" gap={2} p={2}>
<Text fontWeight="semibold">{t('workflows.builder.publishedWorkflowInputs')}</Text>
<Divider />
{inputs.publishable.map(({ nodeId, fieldName }) => {
return <NodeInputFieldPreview key={`${nodeId}-${fieldName}`} nodeId={nodeId} fieldName={fieldName} />;
})}
</Flex>
);
});
PublishableInputFields.displayName = 'PublishableInputFields';
const UnpublishableInputFields = memo(() => {
const { t } = useTranslation();
const inputs = usePublishInputs();
if (inputs.unpublishable.length === 0) {
return null;
}
return (
<Flex flexDir="column" borderWidth={1} borderRadius="base" gap={2} p={2}>
<Text fontWeight="semibold" color="warning.300">
{t('workflows.builder.unpublishableInputs')}
</Text>
<Divider />
{inputs.unpublishable.map(({ nodeId, fieldName }) => {
return <NodeInputFieldPreview key={`${nodeId}-${fieldName}`} nodeId={nodeId} fieldName={fieldName} />;
})}
</Flex>
);
});
UnpublishableInputFields.displayName = 'UnpublishableInputFields';
const SelectOutputNodeButton = memo(() => {
const { t } = useTranslation();
const outputNodeId = useStore($outputNodeId);
const isSelectingOutputNode = useStore($isSelectingOutputNode);
const onClick = useCallback(() => {
$outputNodeId.set(null);
$isSelectingOutputNode.set(true);
}, []);
return (
<Button leftIcon={<PiSignOutBold />} isDisabled={isSelectingOutputNode} onClick={onClick}>
{outputNodeId ? t('workflows.builder.changeOutputNode') : t('workflows.builder.selectOutputNode')}
</Button>
);
});
SelectOutputNodeButton.displayName = 'SelectOutputNodeButton';
const CancelPublishButton = memo(() => {
const { t } = useTranslation();
const onClick = useCallback(() => {
$isInPublishFlow.set(false);
$isSelectingOutputNode.set(false);
$outputNodeId.set(null);
}, []);
return (
<Button leftIcon={<PiXBold />} onClick={onClick}>
{t('common.cancel')}
</Button>
);
});
CancelPublishButton.displayName = 'CancelDeployButton';
const PublishWorkflowButton = memo(() => {
const { t } = useTranslation();
const isReadyToDoValidationRun = useStore($isReadyToDoValidationRun);
const isReadyToEnqueue = useStore($isReadyToEnqueue);
const isWorkflowSaved = useAppSelector(selectIsWorkflowSaved);
const hasBatchOrGeneratorNodes = useAppSelector(selectHasBatchOrGeneratorNodes);
const outputNodeId = useStore($outputNodeId);
const isSelectingOutputNode = useStore($isSelectingOutputNode);
const inputs = usePublishInputs();
const projectUrl = useStore($projectUrl);
const enqueue = useEnqueueWorkflows();
const onClick = useCallback(async () => {
const result = await withResultAsync(() => enqueue(true, true));
if (result.isErr()) {
toast({
id: 'TOAST_PUBLISH_FAILED',
status: 'error',
title: t('workflows.builder.publishFailed'),
description: t('workflows.builder.publishFailedDesc'),
duration: null,
});
log.error({ error: serializeError(result.error) }, 'Failed to enqueue batch');
} else {
toast({
id: 'TOAST_PUBLISH_SUCCESSFUL',
status: 'success',
title: t('workflows.builder.publishSuccess'),
description: (
<Trans
i18nKey="workflows.builder.publishSuccessDesc"
components={{
LinkComponent: <ExternalLink href={projectUrl ?? ''} />,
}}
/>
),
duration: null,
});
assert(result.value.enqueueResult.batch.batch_id);
$validationRunBatchId.set(result.value.enqueueResult.batch.batch_id);
log.debug(parseify(result.value), 'Enqueued batch');
}
}, [enqueue, projectUrl, t]);
return (
<PublishTooltip
isWorkflowSaved={isWorkflowSaved}
hasBatchOrGeneratorNodes={hasBatchOrGeneratorNodes}
isReadyToEnqueue={isReadyToEnqueue}
hasOutputNode={outputNodeId !== null && !isSelectingOutputNode}
hasPublishableInputs={inputs.publishable.length > 0}
hasUnpublishableInputs={inputs.unpublishable.length > 0}
>
<Button
leftIcon={<PiLightningFill />}
isDisabled={
!isReadyToDoValidationRun ||
!isReadyToEnqueue ||
hasBatchOrGeneratorNodes ||
!(outputNodeId !== null && !isSelectingOutputNode)
}
onClick={onClick}
>
{t('workflows.builder.publish')}
</Button>
</PublishTooltip>
);
});
PublishWorkflowButton.displayName = 'DoValidationRunButton';
const NodeInputFieldPreview = memo(({ nodeId, fieldName }: { nodeId: string; fieldName: string }) => {
const mouseOverFormField = useMouseOverFormField(nodeId);
const nodeUserTitle = useNodeUserTitleOrThrow(nodeId);
const nodeTemplateTitle = useNodeTemplateTitleOrThrow(nodeId);
const fieldUserTitle = useInputFieldUserTitleOrThrow(nodeId, fieldName);
const fieldTemplateTitle = useInputFieldTemplateTitleOrThrow(nodeId, fieldName);
const zoomToNode = useZoomToNode(nodeId);
return (
<Flex
flexDir="column"
position="relative"
p={2}
borderRadius="base"
onMouseOver={mouseOverFormField.handleMouseOver}
onMouseOut={mouseOverFormField.handleMouseOut}
onClick={zoomToNode}
>
<Text fontWeight="semibold">{`${nodeUserTitle || nodeTemplateTitle} -> ${fieldUserTitle || fieldTemplateTitle}`}</Text>
<Text variant="subtext">{`${nodeId} -> ${fieldName}`}</Text>
<NodeFieldElementOverlay nodeId={nodeId} />
</Flex>
);
});
NodeInputFieldPreview.displayName = 'NodeInputFieldPreview';
const NodeOutputFieldPreview = memo(({ nodeId, fieldName }: { nodeId: string; fieldName: string }) => {
const mouseOverFormField = useMouseOverFormField(nodeId);
const nodeUserTitle = useNodeUserTitleOrThrow(nodeId);
const nodeTemplateTitle = useNodeTemplateTitleOrThrow(nodeId);
const fieldTemplate = useOutputFieldTemplate(nodeId, fieldName);
const zoomToNode = useZoomToNode(nodeId);
return (
<Flex
flexDir="column"
position="relative"
p={2}
borderRadius="base"
onMouseOver={mouseOverFormField.handleMouseOver}
onMouseOut={mouseOverFormField.handleMouseOut}
onClick={zoomToNode}
>
<Text fontWeight="semibold">{`${nodeUserTitle || nodeTemplateTitle} -> ${fieldTemplate.title}`}</Text>
<Text variant="subtext">{`${nodeId} -> ${fieldName}`}</Text>
<NodeFieldElementOverlay nodeId={nodeId} />
</Flex>
);
});
NodeOutputFieldPreview.displayName = 'NodeOutputFieldPreview';
export const StartPublishFlowButton = memo(() => {
const { t } = useTranslation();
const deployWorkflowIsEnabled = useFeatureStatus('deployWorkflow');
const isReadyToEnqueue = useStore($isReadyToEnqueue);
const isWorkflowSaved = useAppSelector(selectIsWorkflowSaved);
const hasBatchOrGeneratorNodes = useAppSelector(selectHasBatchOrGeneratorNodes);
const inputs = usePublishInputs();
const onClick = useCallback(() => {
$isInPublishFlow.set(true);
}, []);
return (
<PublishTooltip
isWorkflowSaved={isWorkflowSaved}
hasBatchOrGeneratorNodes={hasBatchOrGeneratorNodes}
isReadyToEnqueue={isReadyToEnqueue}
hasOutputNode={true}
hasPublishableInputs={inputs.publishable.length > 0}
hasUnpublishableInputs={inputs.unpublishable.length > 0}
>
<Button
onClick={onClick}
leftIcon={<PiLightningFill />}
variant="ghost"
size="sm"
isDisabled={!deployWorkflowIsEnabled || !isWorkflowSaved || hasBatchOrGeneratorNodes}
>
{t('workflows.builder.publish')}
</Button>
</PublishTooltip>
);
});
StartPublishFlowButton.displayName = 'StartPublishFlowButton';
const PublishTooltip = memo(
({
isWorkflowSaved,
hasBatchOrGeneratorNodes,
isReadyToEnqueue,
hasOutputNode,
hasPublishableInputs,
hasUnpublishableInputs,
children,
}: PropsWithChildren<{
isWorkflowSaved: boolean;
hasBatchOrGeneratorNodes: boolean;
isReadyToEnqueue: boolean;
hasOutputNode: boolean;
hasPublishableInputs: boolean;
hasUnpublishableInputs: boolean;
}>) => {
const { t } = useTranslation();
const warnings = useMemo(() => {
const _warnings: string[] = [];
if (!hasPublishableInputs) {
_warnings.push(t('workflows.builder.warningWorkflowHasNoPublishableInputFields'));
}
if (hasUnpublishableInputs) {
_warnings.push(t('workflows.builder.warningWorkflowHasUnpublishableInputFields'));
}
return _warnings;
}, [hasPublishableInputs, hasUnpublishableInputs, t]);
const errors = useMemo(() => {
const _errors: string[] = [];
if (!isWorkflowSaved) {
_errors.push(t('workflows.builder.errorWorkflowHasUnsavedChanges'));
}
if (hasBatchOrGeneratorNodes) {
_errors.push(t('workflows.builder.errorWorkflowHasBatchOrGeneratorNodes'));
}
if (!isReadyToEnqueue) {
_errors.push(t('workflows.builder.errorWorkflowHasInvalidGraph'));
}
if (!hasOutputNode) {
_errors.push(t('workflows.builder.errorWorkflowHasNoOutputNode'));
}
return _errors;
}, [hasBatchOrGeneratorNodes, hasOutputNode, isReadyToEnqueue, isWorkflowSaved, t]);
if (errors.length === 0 && warnings.length === 0) {
return children;
}
return (
<Tooltip
label={
<Flex flexDir="column">
{errors.length > 0 && (
<>
<Text color="error.700" fontWeight="semibold">
{t('workflows.builder.cannotPublish')}:
</Text>
<UnorderedList>
{errors.map((problem, index) => (
<ListItem key={index}>{problem}</ListItem>
))}
</UnorderedList>
</>
)}
{warnings.length > 0 && (
<>
<Text color="warning.700" fontWeight="semibold">
{t('workflows.builder.publishWarnings')}:
</Text>
<UnorderedList>
{warnings.map((problem, index) => (
<ListItem key={index}>{problem}</ListItem>
))}
</UnorderedList>
</>
)}
</Flex>
}
>
{children}
</Tooltip>
);
}
);
PublishTooltip.displayName = 'PublishTooltip';

View File

@@ -0,0 +1,23 @@
import { IconButton, Tooltip } from '@invoke-ai/ui-library';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiLockBold } from 'react-icons/pi';
export const LockedWorkflowIcon = memo(() => {
const { t } = useTranslation();
return (
<Tooltip label={t('workflows.builder.publishedWorkflowsLocked')} closeOnScroll>
<IconButton
size="sm"
cursor='not-allowed'
variant="link"
alignSelf="stretch"
aria-label={t('workflows.builder.publishedWorkflowsLocked')}
icon={<PiLockBold />}
/>
</Tooltip>
);
});
LockedWorkflowIcon.displayName = 'LockedWorkflowIcon';

View File

@@ -26,6 +26,7 @@ import {
workflowLibraryTagToggled,
workflowLibraryViewChanged,
} from 'features/nodes/store/workflowLibrarySlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { NewWorkflowButton } from 'features/workflowLibrary/components/NewWorkflowButton';
import { UploadWorkflowButton } from 'features/workflowLibrary/components/UploadWorkflowButton';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
@@ -39,13 +40,12 @@ export const WorkflowLibrarySideNav = () => {
const { t } = useTranslation();
const categoryOptions = useStore($workflowLibraryCategoriesOptions);
const view = useAppSelector(selectWorkflowLibraryView);
const deployWorkflow = useFeatureStatus('deployWorkflow');
return (
<Flex h="full" minH={0} overflow="hidden" flexDir="column" w={64} gap={0}>
<Flex flexDir="column" w="full" pb={2}>
<Flex flexDir="column" w="full" pb={2} gap={2}>
<WorkflowLibraryViewButton view="recent">{t('workflows.recentlyOpened')}</WorkflowLibraryViewButton>
</Flex>
<Flex flexDir="column" w="full" pb={2}>
<WorkflowLibraryViewButton view="yours">{t('workflows.yourWorkflows')}</WorkflowLibraryViewButton>
{categoryOptions.includes('project') && (
<Collapse in={view === 'yours' || view === 'shared' || view === 'private'}>
@@ -60,6 +60,9 @@ export const WorkflowLibrarySideNav = () => {
</Flex>
</Collapse>
)}
{deployWorkflow && (
<WorkflowLibraryViewButton view="published">{t('workflows.publishedWorkflows')}</WorkflowLibraryViewButton>
)}
</Flex>
<Flex h="full" minH={0} overflow="hidden" flexDir="column">
<BrowseWorkflowsButton />

View File

@@ -36,6 +36,8 @@ const getCategories = (view: WorkflowLibraryView): WorkflowCategory[] => {
return ['user'];
case 'shared':
return ['project'];
case 'published':
return ['user', 'project', 'default'];
default:
assert<Equals<typeof view, never>>(false);
}
@@ -66,6 +68,7 @@ const useInfiniteQueryAry = () => {
query: debouncedSearchTerm,
tags: view === 'defaults' ? selectedTags : [],
has_been_opened: getHasBeenOpened(view),
is_published: view === 'published' ? true : undefined,
} satisfies Parameters<typeof useListWorkflowsInfiniteInfiniteQuery>[0];
}, [orderBy, direction, view, debouncedSearchTerm, selectedTags]);

View File

@@ -1,6 +1,7 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Badge, Flex, Icon, Image, Spacer, Text } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { LockedWorkflowIcon } from 'features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowLibraryListItemActions/LockedWorkflowIcon';
import { ShareWorkflowButton } from 'features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowLibraryListItemActions/ShareWorkflow';
import { selectWorkflowId, workflowModeChanged } from 'features/nodes/store/workflowSlice';
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
@@ -54,7 +55,6 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
position="relative"
role="button"
onClick={handleClickLoad}
cursor="pointer"
bg="base.750"
borderRadius="base"
w="full"
@@ -81,7 +81,7 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
<Flex gap={2} alignItems="flex-start" justifyContent="space-between" w="full">
<Text noOfLines={2}>{workflow.name}</Text>
<Flex gap={2} alignItems="center">
{isActive && (
{isActive && !workflow.is_published && (
<Badge
color="invokeBlue.400"
borderColor="invokeBlue.700"
@@ -93,6 +93,18 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
{t('workflows.opened')}
</Badge>
)}
{workflow.is_published && (
<Badge
color="invokeGreen.400"
borderColor="invokeGreen.700"
borderWidth={1}
bg="transparent"
flexShrink={0}
variant="subtle"
>
{t('workflows.builder.published')}
</Badge>
)}
{workflow.category === 'project' && <Icon as={PiUsersBold} color="base.200" />}
{workflow.category === 'default' && (
<Image
@@ -119,8 +131,10 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
</Text>
)}
<Spacer />
{workflow.category === 'default' && <ViewWorkflow workflowId={workflow.workflow_id} />}
{workflow.category !== 'default' && (
{workflow.category === 'default' && !workflow.is_published && (
<ViewWorkflow workflowId={workflow.workflow_id} />
)}
{workflow.category !== 'default' && !workflow.is_published && (
<>
<EditWorkflow workflowId={workflow.workflow_id} />
<DownloadWorkflow workflowId={workflow.workflow_id} />
@@ -128,6 +142,7 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
</>
)}
{workflow.category === 'project' && <ShareWorkflowButton workflow={workflow} />}
{workflow.is_published && <LockedWorkflowIcon />}
</Flex>
</Flex>
</Flex>

View File

@@ -1,5 +1,7 @@
import { Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
import { Spacer, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
import { WorkflowBuilder } from 'features/nodes/components/sidePanel/builder/WorkflowBuilder';
import { StartPublishFlowButton } from 'features/nodes/components/sidePanel/workflow/PublishWorkflowPanelContent';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -8,12 +10,15 @@ import WorkflowJSONTab from './WorkflowJSONTab';
const WorkflowFieldsLinearViewPanel = () => {
const { t } = useTranslation();
const deployWorkflowIsEnabled = useFeatureStatus('deployWorkflow');
return (
<Tabs variant="enclosed" display="flex" w="full" h="full" flexDir="column">
<TabList>
<Tab>{t('workflows.builder.builder')}</Tab>
<Tab>{t('common.details')}</Tab>
<Tab>JSON</Tab>
<Spacer />
{deployWorkflowIsEnabled && <StartPublishFlowButton />}
</TabList>
<TabPanels h="full" pt={2}>

View File

@@ -0,0 +1,90 @@
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { $templates } from 'features/nodes/store/nodesSlice';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import type { Templates } from 'features/nodes/store/types';
import { selectWorkflowFormNodeFieldFieldIdentifiersDeduped } from 'features/nodes/store/workflowSlice';
import type { FieldIdentifier } from 'features/nodes/types/field';
import { isBoardFieldType } from 'features/nodes/types/field';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { atom, computed } from 'nanostores';
import { useMemo } from 'react';
import { useGetBatchStatusQuery } from 'services/api/endpoints/queue';
import { assert } from 'tsafe';
export const $isInPublishFlow = atom(false);
export const $outputNodeId = atom<string | null>(null);
export const $isSelectingOutputNode = atom(false);
export const $isReadyToDoValidationRun = computed(
[$isInPublishFlow, $outputNodeId, $isSelectingOutputNode],
(isInPublishFlow, outputNodeId, isSelectingOutputNode) => {
return isInPublishFlow && outputNodeId !== null && !isSelectingOutputNode;
}
);
export const $validationRunBatchId = atom<string | null>(null);
export const useIsValidationRunInProgress = () => {
const validationRunBatchId = useStore($validationRunBatchId);
const { isValidationRunInProgress } = useGetBatchStatusQuery(
validationRunBatchId ? { batch_id: validationRunBatchId } : skipToken,
{
selectFromResult: ({ currentData }) => {
if (!currentData) {
return { isValidationRunInProgress: false };
}
if (currentData && currentData.in_progress > 0) {
return { isValidationRunInProgress: true };
}
return { isValidationRunInProgress: false };
},
}
);
return validationRunBatchId !== null || isValidationRunInProgress;
};
export const selectFieldIdentifiersWithInvocationTypes = createSelector(
selectWorkflowFormNodeFieldFieldIdentifiersDeduped,
selectNodesSlice,
(fieldIdentifiers, nodes) => {
const result: { nodeId: string; fieldName: string; type: string }[] = [];
for (const fieldIdentifier of fieldIdentifiers) {
const node = nodes.nodes.find((node) => node.id === fieldIdentifier.nodeId);
assert(isInvocationNode(node), `Node ${fieldIdentifier.nodeId} not found`);
result.push({ nodeId: fieldIdentifier.nodeId, fieldName: fieldIdentifier.fieldName, type: node.data.type });
}
return result;
}
);
export const getPublishInputs = (fieldIdentifiers: (FieldIdentifier & { type: string })[], templates: Templates) => {
// Certain field types are not allowed to be input fields on a published workflow
const publishable: FieldIdentifier[] = [];
const unpublishable: FieldIdentifier[] = [];
for (const fieldIdentifier of fieldIdentifiers) {
const fieldTemplate = templates[fieldIdentifier.type]?.inputs[fieldIdentifier.fieldName];
if (!fieldTemplate) {
unpublishable.push(fieldIdentifier);
continue;
}
if (isBoardFieldType(fieldTemplate.type)) {
unpublishable.push(fieldIdentifier);
continue;
}
publishable.push(fieldIdentifier);
}
return { publishable, unpublishable };
};
export const usePublishInputs = () => {
const templates = useStore($templates);
const fieldIdentifiersWithInvocationTypes = useAppSelector(selectFieldIdentifiersWithInvocationTypes);
const fieldIdentifiers = useMemo(
() => getPublishInputs(fieldIdentifiersWithInvocationTypes, templates),
[fieldIdentifiersWithInvocationTypes, templates]
);
return fieldIdentifiers;
};

View File

@@ -1,6 +1,6 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplateOrThrow';
import { fieldValueReset } from 'features/nodes/store/nodesSlice';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import { isInvocationNode } from 'features/nodes/types/invocation';

View File

@@ -1,10 +1,11 @@
import { useNodeData } from 'features/nodes/hooks/useNodeData';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import type { FieldInputTemplate } from 'features/nodes/types/field';
import { isSingleOrCollection } from 'features/nodes/types/field';
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
import { useMemo } from 'react';
import { useNodeTemplateOrThrow } from './useNodeTemplateOrThrow';
const isConnectionInputField = (field: FieldInputTemplate) => {
return (
(field.input === 'connection' && !isSingleOrCollection(field.type)) || !(field.type.name in TEMPLATE_BUILDER_MAP)
@@ -19,7 +20,7 @@ const isAnyOrDirectInputField = (field: FieldInputTemplate) => {
};
export const useInputFieldNamesMissing = (nodeId: string) => {
const template = useNodeTemplate(nodeId);
const template = useNodeTemplateOrThrow(nodeId);
const node = useNodeData(nodeId);
const fieldNames = useMemo(() => {
const instanceFields = new Set(Object.keys(node.inputs));
@@ -30,7 +31,7 @@ export const useInputFieldNamesMissing = (nodeId: string) => {
};
export const useInputFieldNamesAnyOrDirect = (nodeId: string) => {
const template = useNodeTemplate(nodeId);
const template = useNodeTemplateOrThrow(nodeId);
const fieldNames = useMemo(() => {
const anyOrDirectFields: string[] = [];
for (const [fieldName, fieldTemplate] of Object.entries(template.inputs)) {
@@ -44,7 +45,7 @@ export const useInputFieldNamesAnyOrDirect = (nodeId: string) => {
};
export const useInputFieldNamesConnection = (nodeId: string) => {
const template = useNodeTemplate(nodeId);
const template = useNodeTemplateOrThrow(nodeId);
const fieldNames = useMemo(() => {
const connectionFields: string[] = [];
for (const [fieldName, fieldTemplate] of Object.entries(template.inputs)) {

View File

@@ -1,8 +1,9 @@
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import type { FieldInputTemplate } from 'features/nodes/types/field';
import { useMemo } from 'react';
import { assert } from 'tsafe';
import { useNodeTemplateOrThrow } from './useNodeTemplateOrThrow';
/**
* Returns the template for a specific input field of a node.
*
@@ -13,7 +14,7 @@ import { assert } from 'tsafe';
* @throws Will throw an error if the template for the input field is not found.
*/
export const useInputFieldTemplateOrThrow = (nodeId: string, fieldName: string): FieldInputTemplate => {
const template = useNodeTemplate(nodeId);
const template = useNodeTemplateOrThrow(nodeId);
const fieldTemplate = useMemo(() => {
const _fieldTemplate = template.inputs[fieldName];
assert(_fieldTemplate, `Template for input field ${fieldName} not found.`);
@@ -21,17 +22,3 @@ export const useInputFieldTemplateOrThrow = (nodeId: string, fieldName: string):
}, [fieldName, template.inputs]);
return fieldTemplate;
};
/**
* Returns the template for a specific input field of a node.
*
* **Note:** This function is a safe version of `useInputFieldTemplate` and will not throw an error if the template is not found.
*
* @param nodeId - The ID of the node.
* @param fieldName - The name of the input field.
*/
export const useInputFieldTemplateSafe = (nodeId: string, fieldName: string): FieldInputTemplate | null => {
const template = useNodeTemplate(nodeId);
const fieldTemplate = useMemo(() => template.inputs[fieldName] ?? null, [fieldName, template.inputs]);
return fieldTemplate;
};

View File

@@ -0,0 +1,17 @@
import { useNodeTemplateSafe } from 'features/nodes/hooks/useNodeTemplateSafe';
import type { FieldInputTemplate } from 'features/nodes/types/field';
import { useMemo } from 'react';
/**
* Returns the template for a specific input field of a node.
*
* **Note:** This function is a safe version of `useInputFieldTemplate` and will not throw an error if the template is not found.
*
* @param nodeId - The ID of the node.
* @param fieldName - The name of the input field.
*/
export const useInputFieldTemplateSafe = (nodeId: string, fieldName: string): FieldInputTemplate | null => {
const template = useNodeTemplateSafe(nodeId);
const fieldTemplate = useMemo(() => template?.inputs[fieldName] ?? null, [fieldName, template?.inputs]);
return fieldTemplate;
};

View File

@@ -1,9 +1,10 @@
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { useMemo } from 'react';
import { assert } from 'tsafe';
export const useInputFieldTemplateTitle = (nodeId: string, fieldName: string): string => {
const template = useNodeTemplate(nodeId);
import { useNodeTemplateOrThrow } from './useNodeTemplateOrThrow';
export const useInputFieldTemplateTitleOrThrow = (nodeId: string, fieldName: string): string => {
const template = useNodeTemplateOrThrow(nodeId);
const title = useMemo(() => {
const fieldTemplate = template.inputs[fieldName];

View File

@@ -0,0 +1,9 @@
import { useMemo } from 'react';
import { useNodeTemplateOrThrow } from './useNodeTemplateOrThrow';
export const useInputFieldTemplateTitleSafe = (nodeId: string, fieldName: string): string => {
const template = useNodeTemplateOrThrow(nodeId);
const title = useMemo(() => template.inputs[fieldName]?.title ?? '', [fieldName, template.inputs]);
return title;
};

View File

@@ -0,0 +1,22 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
import { useMemo } from 'react';
/**
* Gets the user-defined description of an input field for a given node.
*
* If the node doesn't exist or is not an invocation node, an error is thrown.
*
* @param nodeId The ID of the node
* @param fieldName The name of the field
*/
export const useInputFieldUserDescriptionOrThrow = (nodeId: string, fieldName: string) => {
const selector = useMemo(
() => createSelector(selectNodesSlice, (nodes) => selectFieldInputInstance(nodes, nodeId, fieldName).description),
[fieldName, nodeId]
);
const description = useAppSelector(selector);
return description;
};

View File

@@ -11,7 +11,7 @@ import { useMemo } from 'react';
* @param nodeId The ID of the node
* @param fieldName The name of the field
*/
export const useInputFieldDescriptionSafe = (nodeId: string, fieldName: string) => {
export const useInputFieldUserDescriptionSafe = (nodeId: string, fieldName: string) => {
const selector = useMemo(
() =>
createSelector(

View File

@@ -0,0 +1,23 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
import { useMemo } from 'react';
/**
* Gets the user-defined title of an input field for a given node.
*
* If the node doesn't exist or is not an invocation node, an error is thrown.
*
* @param nodeId The ID of the node
* @param fieldName The name of the field
*/
export const useInputFieldUserTitleOrThrow = (nodeId: string, fieldName: string): string => {
const selector = useMemo(
() => createSelector(selectNodesSlice, (nodes) => selectFieldInputInstance(nodes, nodeId, fieldName).label),
[fieldName, nodeId]
);
const title = useAppSelector(selector);
return title;
};

View File

@@ -4,21 +4,21 @@ import { selectFieldInputInstanceSafe, selectNodesSlice } from 'features/nodes/s
import { useMemo } from 'react';
/**
* Gets the user-defined label of an input field for a given node.
* Gets the user-defined title of an input field for a given node.
*
* If the node doesn't exist or is not an invocation node, an empty string is returned.
*
* @param nodeId The ID of the node
* @param fieldName The name of the field
*/
export const useInputFieldLabelSafe = (nodeId: string, fieldName: string): string => {
export const useInputFieldUserTitleSafe = (nodeId: string, fieldName: string): string => {
const selector = useMemo(
() =>
createSelector(selectNodesSlice, (nodes) => selectFieldInputInstanceSafe(nodes, nodeId, fieldName)?.label ?? ''),
[fieldName, nodeId]
);
const label = useAppSelector(selector);
const title = useAppSelector(selector);
return label;
return title;
};

View File

@@ -1,9 +1,10 @@
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { isBatchNodeType, isGeneratorNodeType } from 'features/nodes/types/invocation';
import { useMemo } from 'react';
import { useNodeTemplateOrThrow } from './useNodeTemplateOrThrow';
export const useIsExecutableNode = (nodeId: string) => {
const template = useNodeTemplate(nodeId);
const template = useNodeTemplateOrThrow(nodeId);
const isExecutableNode = useMemo(
() => !isBatchNodeType(template.type) && !isGeneratorNodeType(template.type),
[template]

View File

@@ -0,0 +1,13 @@
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { $isInPublishFlow, useIsValidationRunInProgress } from 'features/nodes/components/sidePanel/workflow/publish';
import { selectWorkflowIsPublished } from 'features/nodes/store/workflowSlice';
export const useIsWorkflowEditorLocked = () => {
const isInPublishFlow = useStore($isInPublishFlow);
const isPublished = useAppSelector(selectWorkflowIsPublished);
const isValidationRunInProgress = useIsValidationRunInProgress();
const isLocked = isInPublishFlow || isPublished || isValidationRunInProgress;
return isLocked;
};

View File

@@ -1,9 +1,10 @@
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import type { Classification } from 'features/nodes/types/common';
import { useMemo } from 'react';
import { useNodeTemplateOrThrow } from './useNodeTemplateOrThrow';
export const useNodeClassification = (nodeId: string): Classification => {
const template = useNodeTemplate(nodeId);
const template = useNodeTemplateOrThrow(nodeId);
const classification = useMemo(() => template.classification, [template]);
return classification;
};

View File

@@ -1,9 +1,10 @@
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { some } from 'lodash-es';
import { useMemo } from 'react';
import { useNodeTemplateOrThrow } from './useNodeTemplateOrThrow';
export const useNodeHasImageOutput = (nodeId: string): boolean => {
const template = useNodeTemplate(nodeId);
const template = useNodeTemplateOrThrow(nodeId);
const hasImageOutput = useMemo(
() =>
some(

Some files were not shown because too many files have changed in this diff Show More