Compare commits

...

60 Commits

Author SHA1 Message Date
psychedelicious
fe7880dcd9 chore: bump version to v6.1.0rc1 2025-07-17 16:25:40 +10:00
psychedelicious
9e526d00c2 chore(ui): lint 2025-07-17 15:36:24 +10:00
psychedelicious
1a24396be8 feat(ui): styling when nodes have error 2025-07-17 15:36:24 +10:00
psychedelicious
d97e73a565 chore(ui): lint 2025-07-17 15:36:24 +10:00
psychedelicious
55b14c8aaf perf(ui): optimize redux selectors for workflow editor
- Build selectors for each node in a react context so components can
re-use the same selectors
- Cache the selectors in the context
2025-07-17 15:36:24 +10:00
psychedelicious
79f65e57eb fix(ui): remove unnecessary coalescing operator 2025-07-17 14:21:02 +10:00
Kent Keirsey
b4c8950278 address comments 2025-07-17 14:21:02 +10:00
Kent Keirsey
400b2e9a55 unlint. 2025-07-17 14:21:02 +10:00
Kent Keirsey
3a687c583a lint 2025-07-17 14:21:02 +10:00
Kent Keirsey
833950078d commit tile size controls 2025-07-17 14:21:02 +10:00
Kent Keirsey
e698dcb148 unlint. 2025-07-17 14:21:02 +10:00
Kent Keirsey
218386e077 lint 2025-07-17 14:21:02 +10:00
Kent Keirsey
4426be9e64 commit tile size controls 2025-07-17 14:21:02 +10:00
psychedelicious
86f4cf7857 feat(ui): related embedding styling/tidy 2025-07-17 14:12:29 +10:00
Kent Keirsey
49ae66d94a Added related model support 2025-07-17 14:12:29 +10:00
Cursor Agent
c10865c7ef Reorder embedding options in PromptTriggerSelect component
Co-authored-by: kent <kent@invoke.ai>
2025-07-17 14:12:29 +10:00
psychedelicious
f3478a189a fix(ui): able to drag empty space in tab bar and detach panels 2025-07-17 13:58:32 +10:00
psychedelicious
43db29176a chore(ui): lint 2025-07-17 13:52:24 +10:00
psychedelicious
f38922929c docs(ui): comments in modelsLoaded 2025-07-17 13:52:24 +10:00
psychedelicious
7d02c58f86 fix(ui): move <ParamTileControlNetModel /> to <UpscaleTabAdvancedSettingsAccordion /> 2025-07-17 13:52:24 +10:00
Kent Keirsey
6edce8be87 Add scaling in 2025-07-17 13:52:24 +10:00
Kent Keirsey
31f63e38bd lint 2025-07-17 13:52:24 +10:00
Kent Keirsey
78a68ac3a7 Updated 2025-07-17 13:52:24 +10:00
Kent Keirsey
8cd3bcd1c0 Updates 2025-07-17 13:52:24 +10:00
Cursor Agent
264cc5ef46 Add tile ControlNet model selection to upscale settings
Co-authored-by: kent <kent@invoke.ai>
2025-07-17 13:52:24 +10:00
JPPhoto
8bfbea5ed3 Updated __init__.py 2025-07-17 06:33:56 +10:00
JPPhoto
f06a66da07 Updated schema.ts 2025-07-17 06:33:56 +10:00
Jonathan
337cae9b22 Update __init__.py
Added FluxConditioningField, FluxConditioningCollectionOutput, and FluxConditioningCollectionOutput,
2025-07-17 06:33:56 +10:00
Jonathan
bf926bb7d5 Update primitives.py
Added FluxConditioningCollectionOutput
2025-07-17 06:33:56 +10:00
psychedelicious
18ad9a6af3 feat(ui): canvas/viewer panel tabs show progress 2025-07-17 06:20:05 +10:00
psychedelicious
b6ed31c222 feat(ui): clicking invoke switches to viewer tab instead of canvas when save all images to gallery is enabled 2025-07-17 06:20:05 +10:00
psychedelicious
200beb5af5 feat(ui): make save all images to gallery option also bypass canvas 2025-07-17 06:20:05 +10:00
psychedelicious
f82a948bdd refactor(ui): canvas autoswitch logic
Simplify the canvas auto-switch logic to not rely on the preview images
loading. This fixes an issue where offscreen preview images didn't get
auto-switched to. Images are now loaded directly.
2025-07-17 06:20:05 +10:00
psychedelicious
dd03e3ddcd refactor(ui): simplify canvas session logic 2025-07-17 06:20:05 +10:00
psychedelicious
7561b73e8f fix(ui): uppercase file extensions blocked for image upload
Closes #8284
2025-07-17 00:48:36 +10:00
psychedelicious
caa97608c7 fix(ui): aspect ratios out of order 2025-07-16 23:27:37 +10:00
Mary Hipp
72a6d1edc1 simplify descriptoin styling 2025-07-16 09:19:33 -04:00
Mary Hipp
b8bf89c2f1 add fallback image and make sure description text is legible for model picker noncompact 2025-07-16 09:19:33 -04:00
psychedelicious
a1ade2b8c0 feat(ui): export apis & actions from package 2025-07-16 08:21:03 -04:00
Eugene Brodsky
4bdcae1f8f fix(docker): switch to pnpm10.x 2025-07-15 13:03:15 -04:00
Jonathan
4b22c84407 Update dev-environment.md
Document the latest changes required to build Invoke 6.0.
2025-07-15 15:21:01 +10:00
Eugene Brodsky
c9daf1db30 (fix) remove timeout from image prompt expansion (#8281) 2025-07-14 11:19:20 -04:00
psychedelicious
06d3cfbe97 gh: update bug report template
- Add require drop down for install method
- Make browser version optional
- Link to latest release
- Update verbiage for sys info section
2025-07-14 12:18:52 +10:00
psychedelicious
71e4901313 fix(ui): ignore disalbed ref images in readiness checks 2025-07-14 10:51:51 +10:00
psychedelicious
82fb897b62 chore(ui): lint 2025-07-12 14:56:57 +10:00
psychedelicious
192b00d969 chore: bump version to v6.0.2 2025-07-12 14:56:57 +10:00
psychedelicious
7bb25ef1b4 fix(ui): gallery dnd 2025-07-12 14:56:57 +10:00
psychedelicious
62f52c74a8 fix(ui): linked negative style prompt not passed in
Closes #8256
2025-07-12 10:22:17 +10:00
psychedelicious
97439c1daa fix(ui): native context menu shown on right click on short fat images
Closes #8254
2025-07-12 10:22:17 +10:00
psychedelicious
b23bff1b53 fix(ui): center staging area images 2025-07-12 10:22:17 +10:00
psychedelicious
d9a1efbabf fix(ui): staging area images may be slightly too large 2025-07-12 10:22:17 +10:00
psychedelicious
d4e903ee2d chore: bump version to v6.0.1 2025-07-12 10:22:17 +10:00
Kevin Turner
bb3e5d16d8 feat(Model Manager): refuse to download a file when there's insufficient space 2025-07-12 10:14:25 +10:00
psychedelicious
e62d3f01a8 feat(app): better error message for failed model probe
- Old: No valid config found
- New: Unable to determine model type
2025-07-11 23:35:43 +10:00
psychedelicious
757ecdbf82 build(ui): downgrade idb-keyval
We have increased error rates after updating this package. Let's try
downgrading to see if that fixes the issue.
2025-07-11 15:00:10 +10:00
psychedelicious
694c85b041 fix(ui): language file filenames
Need to replace the underscores w/ dashes - this was missed in #8246.
2025-07-11 14:21:41 +10:00
psychedelicious
988d7ba24c chore: bump version to v6.0.1rc1 2025-07-11 09:05:24 +10:00
psychedelicious
ac981879ef fix(ui): runtime errors related to calling reduce on array iterator
Fix an issue in certain browsers/builds causing a runtime error.

A zod enum has a .options property, which is an array of all the options
for the enum. This is handy for when you need to derive something from a
zod schema.

In this case, we represented the possible focus regions in the zod enum,
then derived a mapping of region names to set of target HTML elements.
Why isn't important, but suffice to say, we were using the .options
property for this.

But actually, we were using .options.values(), then calling .reduce() on
that. An array's .values() method returns an _array iterator_. Array
iterators do not have .reduce() methods!

Except, apparently in some environments they do - it depends on the JS
engine and whether or not polyfills for iterator helpers were included
in the build.

Turns out my dev environment - and most user browsers - do provide
.reduce(), so we didn't catch this error. It took a large deployment and
error monitoring to catch it.

I've refactored the code to totally avoid deriving data from zod in this
way.
2025-07-11 08:25:47 +10:00
psychedelicious
fc71849c24 feat(app): expose a cursor, not a connection in db util 2025-07-11 08:20:06 +10:00
psychedelicious
a19aa3b032 feat(app): db abstraction to prevent threading conflicts
- Add a context manager to the SqliteDatabase class which abstracts away
creating a transaction, committing it on success and rolling back on
error.
- Use it everywhere. The context manager should be exited before
returning results. No business logic changes should be present.
2025-07-11 08:20:06 +10:00
202 changed files with 3508 additions and 3227 deletions

View File

@@ -21,6 +21,20 @@ body:
- label: I have searched the existing issues
required: true
- type: dropdown
id: install_method
attributes:
label: Install method
description: How did you install Invoke?
multiple: false
options:
- "Invoke's Launcher"
- 'Stability Matrix'
- 'Pinokio'
- 'Manual'
validations:
required: true
- type: markdown
attributes:
value: __Describe your environment__
@@ -76,8 +90,8 @@ body:
attributes:
label: Version number
description: |
The version of Invoke you have installed. If it is not the latest version, please update and try again to confirm the issue still exists. If you are testing main, please include the commit hash instead.
placeholder: ex. 3.6.1
The version of Invoke you have installed. If it is not the [latest version](https://github.com/invoke-ai/InvokeAI/releases/latest), please update and try again to confirm the issue still exists. If you are testing main, please include the commit hash instead.
placeholder: ex. v6.0.2
validations:
required: true
@@ -85,17 +99,17 @@ body:
id: browser-version
attributes:
label: Browser
description: Your web browser and version.
description: Your web browser and version, if you do not use the Launcher's provided GUI.
placeholder: ex. Firefox 123.0b3
validations:
required: true
required: false
- type: textarea
id: python-deps
attributes:
label: Python dependencies
label: System Information
description: |
If the problem occurred during image generation, click the gear icon at the bottom left corner, click "About", click the copy button and then paste here.
Click the gear icon at the bottom left corner, then click "About". Click the copy button and then paste here.
validations:
required: false

View File

@@ -5,8 +5,7 @@
FROM docker.io/node:22-slim AS web-builder
ENV PNPM_HOME="/pnpm"
ENV PATH="$PNPM_HOME:$PATH"
RUN corepack use pnpm@8.x
RUN corepack enable
RUN corepack use pnpm@10.x && corepack enable
WORKDIR /build
COPY invokeai/frontend/web/ ./

View File

@@ -41,7 +41,7 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
With the modifications made, the install command should look something like this:
```sh
uv pip install -e ".[dev,test,docs,xformers]" --python 3.12 --python-preference only-managed --index=https://download.pytorch.org/whl/cu126 --reinstall
uv pip install -e ".[dev,test,docs,xformers]" --python 3.12 --python-preference only-managed --index=https://download.pytorch.org/whl/cu128 --reinstall
```
6. At this point, you should have Invoke installed, a venv set up and activated, and the server running. But you will see a warning in the terminal that no UI was found. If you go to the URL for the server, you won't get a UI.
@@ -50,11 +50,11 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
If you only want to edit the docs, you can stop here and skip to the **Documentation** section below.
7. Install the frontend dev toolchain:
7. Install the frontend dev toolchain, paying attention to versions:
- [`nodejs`](https://nodejs.org/) (v20+)
- [`nodejs`](https://nodejs.org/) (tested on LTS, v22)
- [`pnpm`](https://pnpm.io/8.x/installation) (must be v8 - not v9!)
- [`pnpm`](https://pnpm.io/installation) (tested on v10)
8. Do a production build of the frontend:

View File

@@ -430,6 +430,15 @@ class FluxConditioningOutput(BaseInvocationOutput):
return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
@invocation_output("flux_conditioning_collection_output")
class FluxConditioningCollectionOutput(BaseInvocationOutput):
"""Base class for nodes that output a collection of conditioning tensors"""
collection: list[FluxConditioningField] = OutputField(
description="The output conditioning tensors",
)
@invocation_output("sd3_conditioning_output")
class SD3ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single SD3 conditioning tensor"""

View File

@@ -14,15 +14,14 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._db = db
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
INSERT INTO board_images (board_id, image_name)
@@ -31,17 +30,12 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
""",
(board_id, image_name, board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
def remove_image_from_board(
self,
image_name: str,
) -> None:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE FROM board_images
@@ -49,10 +43,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
""",
(image_name,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
def get_images_for_board(
self,
@@ -60,27 +50,26 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[ImageRecord]:
# TODO: this isn't paginated yet?
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT images.*
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE board_images.board_id = ?
ORDER BY board_images.updated_at DESC;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT images.*
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE board_images.board_id = ?
ORDER BY board_images.updated_at DESC;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images WHERE 1=1;
"""
)
count = cast(int, cursor.fetchone()[0])
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images WHERE 1=1;
"""
)
count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
@@ -90,56 +79,55 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
categories: list[ImageCategory] | None,
is_intermediate: bool | None,
) -> list[str]:
params: list[str | bool] = []
with self._db.transaction() as cursor:
params: list[str | bool] = []
# Base query is a join between images and board_images
stmt = """
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
# Base query is a join between images and board_images
stmt = """
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
# Handle board_id filter
if board_id == "none":
stmt += """--sql
AND board_images.board_id IS NULL
"""
else:
stmt += """--sql
AND board_images.board_id = ?
"""
params.append(board_id)
# Handle board_id filter
if board_id == "none":
stmt += """--sql
AND board_images.board_id IS NULL
"""
else:
stmt += """--sql
AND board_images.board_id = ?
"""
params.append(board_id)
# Add the category filter
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
stmt += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
# Add the category filter
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
stmt += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
# Unpack the included categories into the query params
for c in category_strings:
params.append(c)
# Unpack the included categories into the query params
for c in category_strings:
params.append(c)
# Add the is_intermediate filter
if is_intermediate is not None:
stmt += """--sql
AND images.is_intermediate = ?
"""
params.append(is_intermediate)
# Add the is_intermediate filter
if is_intermediate is not None:
stmt += """--sql
AND images.is_intermediate = ?
"""
params.append(is_intermediate)
# Put a ring on it
stmt += ";"
# Put a ring on it
stmt += ";"
# Execute the query
cursor = self._conn.cursor()
cursor.execute(stmt, params)
cursor.execute(stmt, params)
result = cast(list[sqlite3.Row], cursor.fetchall())
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [r[0] for r in result]
return image_names
@@ -147,31 +135,31 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
self,
image_name: str,
) -> Optional[str]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT board_id
FROM board_images
WHERE image_name = ?;
""",
(image_name,),
)
result = cursor.fetchone()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT board_id
FROM board_images
WHERE image_name = ?;
""",
(image_name,),
)
result = cursor.fetchone()
if result is None:
return None
return cast(str, result[0])
def get_image_count_for_board(self, board_id: str) -> int:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT COUNT(*)
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE images.is_intermediate = FALSE
AND board_images.board_id = ?;
""",
(board_id,),
)
count = cast(int, cursor.fetchone()[0])
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT COUNT(*)
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE images.is_intermediate = FALSE
AND board_images.board_id = ?;
""",
(board_id,),
)
count = cast(int, cursor.fetchone()[0])
return count

View File

@@ -20,61 +20,57 @@ from invokeai.app.util.misc import uuid_string
class SqliteBoardRecordStorage(BoardRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._db = db
def delete(self, board_id: str) -> None:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
self._conn.commit()
except Exception as e:
self._conn.rollback()
raise BoardRecordDeleteException from e
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
DELETE FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
except Exception as e:
raise BoardRecordDeleteException from e
def save(
self,
board_name: str,
) -> BoardRecord:
try:
board_id = uuid_string()
cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO boards (board_id, board_name)
VALUES (?, ?);
""",
(board_id, board_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
with self._db.transaction() as cursor:
try:
board_id = uuid_string()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO boards (board_id, board_name)
VALUES (?, ?);
""",
(board_id, board_name),
)
except sqlite3.Error as e:
raise BoardRecordSaveException from e
return self.get(board_id)
def get(
self,
board_id: str,
) -> BoardRecord:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
SELECT *
FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
except sqlite3.Error as e:
raise BoardRecordNotFoundException from e
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
except sqlite3.Error as e:
raise BoardRecordNotFoundException from e
if result is None:
raise BoardRecordNotFoundException
return BoardRecord(**dict(result))
@@ -84,45 +80,43 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
board_id: str,
changes: BoardChanges,
) -> BoardRecord:
try:
cursor = self._conn.cursor()
# Change the name of a board
if changes.board_name is not None:
cursor.execute(
"""--sql
UPDATE boards
SET board_name = ?
WHERE board_id = ?;
""",
(changes.board_name, board_id),
)
with self._db.transaction() as cursor:
try:
# Change the name of a board
if changes.board_name is not None:
cursor.execute(
"""--sql
UPDATE boards
SET board_name = ?
WHERE board_id = ?;
""",
(changes.board_name, board_id),
)
# Change the cover image of a board
if changes.cover_image_name is not None:
cursor.execute(
"""--sql
UPDATE boards
SET cover_image_name = ?
WHERE board_id = ?;
""",
(changes.cover_image_name, board_id),
)
# Change the cover image of a board
if changes.cover_image_name is not None:
cursor.execute(
"""--sql
UPDATE boards
SET cover_image_name = ?
WHERE board_id = ?;
""",
(changes.cover_image_name, board_id),
)
# Change the archived status of a board
if changes.archived is not None:
cursor.execute(
"""--sql
UPDATE boards
SET archived = ?
WHERE board_id = ?;
""",
(changes.archived, board_id),
)
# Change the archived status of a board
if changes.archived is not None:
cursor.execute(
"""--sql
UPDATE boards
SET archived = ?
WHERE board_id = ?;
""",
(changes.archived, board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
except sqlite3.Error as e:
raise BoardRecordSaveException from e
return self.get(board_id)
def get_many(
@@ -133,78 +127,77 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
limit: int = 10,
include_archived: bool = False,
) -> OffsetPaginatedResults[BoardRecord]:
cursor = self._conn.cursor()
# Build base query
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
LIMIT ? OFFSET ?;
"""
# Determine archived filter condition
archived_filter = "" if include_archived else "WHERE archived = 0"
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
# Execute query to fetch boards
cursor.execute(final_query, (limit, offset))
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
# Determine count query
if include_archived:
count_query = """
SELECT COUNT(*)
FROM boards;
"""
else:
count_query = """
SELECT COUNT(*)
with self._db.transaction() as cursor:
# Build base query
base_query = """
SELECT *
FROM boards
WHERE archived = 0;
{archived_filter}
ORDER BY {order_by} {direction}
LIMIT ? OFFSET ?;
"""
# Execute count query
cursor.execute(count_query)
# Determine archived filter condition
archived_filter = "" if include_archived else "WHERE archived = 0"
count = cast(int, cursor.fetchone()[0])
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
# Execute query to fetch boards
cursor.execute(final_query, (limit, offset))
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
# Determine count query
if include_archived:
count_query = """
SELECT COUNT(*)
FROM boards;
"""
else:
count_query = """
SELECT COUNT(*)
FROM boards
WHERE archived = 0;
"""
# Execute count query
cursor.execute(count_query)
count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
def get_all(
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
) -> list[BoardRecord]:
cursor = self._conn.cursor()
if order_by == BoardRecordOrderBy.Name:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY LOWER(board_name) {direction}
"""
else:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
"""
with self._db.transaction() as cursor:
if order_by == BoardRecordOrderBy.Name:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY LOWER(board_name) {direction}
"""
else:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
"""
archived_filter = "" if include_archived else "WHERE archived = 0"
archived_filter = "" if include_archived else "WHERE archived = 0"
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
cursor.execute(final_query)
cursor.execute(final_query)
result = cast(list[sqlite3.Row], cursor.fetchall())
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
return boards

View File

@@ -8,6 +8,7 @@ import time
import traceback
from pathlib import Path
from queue import Empty, PriorityQueue
from shutil import disk_usage
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set
import requests
@@ -335,6 +336,14 @@ class DownloadQueueService(DownloadQueueServiceBase):
assert job.download_path
free_space = disk_usage(job.download_path.parent).free
GB = 2**30
self._logger.debug(f"Download is {job.total_bytes / GB:.2f} GB of {free_space / GB:.2f} GB free.")
if free_space < job.total_bytes:
raise RuntimeError(
f"Free disk space {free_space / GB:.2f} GB is not enough for download of {job.total_bytes / GB:.2f} GB."
)
# Don't clobber an existing file. See commit 82c2c85202f88c6d24ff84710f297cfc6ae174af
# for code that instead resumes an interrupted download.
if job.download_path.exists():

View File

@@ -24,22 +24,22 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteImageRecordStorage(ImageRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._db = db
def get(self, image_name: str) -> ImageRecord:
try:
cursor = self._conn.cursor()
cursor.execute(
f"""--sql
SELECT {IMAGE_DTO_COLS} FROM images
WHERE image_name = ?;
""",
(image_name,),
)
with self._db.transaction() as cursor:
try:
cursor.execute(
f"""--sql
SELECT {IMAGE_DTO_COLS} FROM images
WHERE image_name = ?;
""",
(image_name,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
except sqlite3.Error as e:
raise ImageRecordNotFoundException from e
result = cast(Optional[sqlite3.Row], cursor.fetchone())
except sqlite3.Error as e:
raise ImageRecordNotFoundException from e
if not result:
raise ImageRecordNotFoundException
@@ -47,17 +47,20 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
return deserialize_image_record(dict(result))
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT metadata FROM images
WHERE image_name = ?;
""",
(image_name,),
)
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
SELECT metadata FROM images
WHERE image_name = ?;
""",
(image_name,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
result = cast(Optional[sqlite3.Row], cursor.fetchone())
except sqlite3.Error as e:
raise ImageRecordNotFoundException from e
if not result:
raise ImageRecordNotFoundException
@@ -65,64 +68,60 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
as_dict = dict(result)
metadata_raw = cast(Optional[str], as_dict.get("metadata", None))
return MetadataFieldValidator.validate_json(metadata_raw) if metadata_raw is not None else None
except sqlite3.Error as e:
raise ImageRecordNotFoundException from e
def update(
self,
image_name: str,
changes: ImageRecordChanges,
) -> None:
try:
cursor = self._conn.cursor()
# Change the category of the image
if changes.image_category is not None:
cursor.execute(
"""--sql
UPDATE images
SET image_category = ?
WHERE image_name = ?;
""",
(changes.image_category, image_name),
)
with self._db.transaction() as cursor:
try:
# Change the category of the image
if changes.image_category is not None:
cursor.execute(
"""--sql
UPDATE images
SET image_category = ?
WHERE image_name = ?;
""",
(changes.image_category, image_name),
)
# Change the session associated with the image
if changes.session_id is not None:
cursor.execute(
"""--sql
UPDATE images
SET session_id = ?
WHERE image_name = ?;
""",
(changes.session_id, image_name),
)
# Change the session associated with the image
if changes.session_id is not None:
cursor.execute(
"""--sql
UPDATE images
SET session_id = ?
WHERE image_name = ?;
""",
(changes.session_id, image_name),
)
# Change the image's `is_intermediate`` flag
if changes.is_intermediate is not None:
cursor.execute(
"""--sql
UPDATE images
SET is_intermediate = ?
WHERE image_name = ?;
""",
(changes.is_intermediate, image_name),
)
# Change the image's `is_intermediate`` flag
if changes.is_intermediate is not None:
cursor.execute(
"""--sql
UPDATE images
SET is_intermediate = ?
WHERE image_name = ?;
""",
(changes.is_intermediate, image_name),
)
# Change the image's `starred`` state
if changes.starred is not None:
cursor.execute(
"""--sql
UPDATE images
SET starred = ?
WHERE image_name = ?;
""",
(changes.starred, image_name),
)
# Change the image's `starred`` state
if changes.starred is not None:
cursor.execute(
"""--sql
UPDATE images
SET starred = ?
WHERE image_name = ?;
""",
(changes.starred, image_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordSaveException from e
except sqlite3.Error as e:
raise ImageRecordSaveException from e
def get_many(
self,
@@ -136,170 +135,162 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
cursor = self._conn.cursor()
# Manually build two queries - one for the count, one for the records
count_query = """--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
images_query = f"""--sql
SELECT {IMAGE_DTO_COLS}
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
with self._db.transaction() as cursor:
# Manually build two queries - one for the count, one for the records
count_query = """--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
# Unpack the included categories into the query params
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
images_query = f"""--sql
SELECT {IMAGE_DTO_COLS}
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
query_params.append(is_intermediate)
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
# board_id of "none" is reserved for images without a board
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
# Search term condition
if search_term:
query_conditions += """--sql
AND (
images.metadata LIKE ?
OR images.created_at LIKE ?
)
"""
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
if starred_first:
query_pagination = f"""--sql
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
else:
query_pagination = f"""--sql
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy()
# Add the pagination parameters
images_params.extend([limit, offset])
# Unpack the included categories into the query params
for c in category_strings:
query_params.append(c)
# Build the list of images, deserializing each row
cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
count_params = query_params.copy()
cursor.execute(count_query, count_params)
count = cast(int, cursor.fetchone()[0])
query_params.append(is_intermediate)
# board_id of "none" is reserved for images without a board
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
# Search term condition
if search_term:
query_conditions += """--sql
AND (
images.metadata LIKE ?
OR images.created_at LIKE ?
)
"""
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
if starred_first:
query_pagination = f"""--sql
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
else:
query_pagination = f"""--sql
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy()
# Add the pagination parameters
images_params.extend([limit, offset])
# Build the list of images, deserializing each row
cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
count_params = query_params.copy()
cursor.execute(count_query, count_params)
count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
def delete(self, image_name: str) -> None:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE FROM images
WHERE image_name = ?;
""",
(image_name,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
DELETE FROM images
WHERE image_name = ?;
""",
(image_name,),
)
except sqlite3.Error as e:
raise ImageRecordDeleteException from e
def delete_many(self, image_names: list[str]) -> None:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
try:
placeholders = ",".join("?" for _ in image_names)
placeholders = ",".join("?" for _ in image_names)
# Construct the SQLite query with the placeholders
query = f"DELETE FROM images WHERE image_name IN ({placeholders})"
# Construct the SQLite query with the placeholders
query = f"DELETE FROM images WHERE image_name IN ({placeholders})"
# Execute the query with the list of IDs as parameters
cursor.execute(query, image_names)
# Execute the query with the list of IDs as parameters
cursor.execute(query, image_names)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
except sqlite3.Error as e:
raise ImageRecordDeleteException from e
def get_intermediates_count(self) -> int:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images
WHERE is_intermediate = TRUE;
"""
)
count = cast(int, cursor.fetchone()[0])
self._conn.commit()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images
WHERE is_intermediate = TRUE;
"""
)
count = cast(int, cursor.fetchone()[0])
return count
def delete_intermediates(self) -> list[str]:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT image_name FROM images
WHERE is_intermediate = TRUE;
"""
)
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [r[0] for r in result]
cursor.execute(
"""--sql
DELETE FROM images
WHERE is_intermediate = TRUE;
"""
)
self._conn.commit()
return image_names
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
SELECT image_name FROM images
WHERE is_intermediate = TRUE;
"""
)
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [r[0] for r in result]
cursor.execute(
"""--sql
DELETE FROM images
WHERE is_intermediate = TRUE;
"""
)
except sqlite3.Error as e:
raise ImageRecordDeleteException from e
return image_names
def save(
self,
@@ -315,73 +306,71 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id: Optional[str] = None,
metadata: Optional[str] = None,
) -> datetime:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO images (
image_name,
image_origin,
image_category,
width,
height,
node_id,
session_id,
metadata,
is_intermediate,
starred,
has_workflow
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
""",
(
image_name,
image_origin.value,
image_category.value,
width,
height,
node_id,
session_id,
metadata,
is_intermediate,
starred,
has_workflow,
),
)
self._conn.commit()
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
INSERT OR IGNORE INTO images (
image_name,
image_origin,
image_category,
width,
height,
node_id,
session_id,
metadata,
is_intermediate,
starred,
has_workflow
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
""",
(
image_name,
image_origin.value,
image_category.value,
width,
height,
node_id,
session_id,
metadata,
is_intermediate,
starred,
has_workflow,
),
)
cursor.execute(
"""--sql
SELECT created_at
FROM images
WHERE image_name = ?;
""",
(image_name,),
)
cursor.execute(
"""--sql
SELECT created_at
FROM images
WHERE image_name = ?;
""",
(image_name,),
)
created_at = datetime.fromisoformat(cursor.fetchone()[0])
created_at = datetime.fromisoformat(cursor.fetchone()[0])
return created_at
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordSaveException from e
except sqlite3.Error as e:
raise ImageRecordSaveException from e
return created_at
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT images.*
FROM images
JOIN board_images ON images.image_name = board_images.image_name
WHERE board_images.board_id = ?
AND images.is_intermediate = FALSE
ORDER BY images.starred DESC, images.created_at DESC
LIMIT 1;
""",
(board_id,),
)
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT images.*
FROM images
JOIN board_images ON images.image_name = board_images.image_name
WHERE board_images.board_id = ?
AND images.is_intermediate = FALSE
ORDER BY images.starred DESC, images.created_at DESC
LIMIT 1;
""",
(board_id,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
result = cast(Optional[sqlite3.Row], cursor.fetchone())
if result is None:
return None
@@ -398,85 +387,84 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageNamesResult:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
# Build query conditions (reused for both starred count and image names queries)
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
# Build query conditions (reused for both starred count and image names queries)
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if categories is not None:
category_strings = [c.value for c in set(categories)]
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
for c in category_strings:
query_params.append(c)
if categories is not None:
category_strings = [c.value for c in set(categories)]
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
query_params.append(is_intermediate)
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
query_params.append(is_intermediate)
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
if search_term:
query_conditions += """--sql
AND (
images.metadata LIKE ?
OR images.created_at LIKE ?
)
"""
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
if search_term:
query_conditions += """--sql
AND (
images.metadata LIKE ?
OR images.created_at LIKE ?
)
"""
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
# Get starred count if starred_first is enabled
starred_count = 0
if starred_first:
starred_count_query = f"""--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE images.starred = TRUE AND (1=1{query_conditions})
"""
cursor.execute(starred_count_query, query_params)
starred_count = cast(int, cursor.fetchone()[0])
# Get starred count if starred_first is enabled
starred_count = 0
if starred_first:
starred_count_query = f"""--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE images.starred = TRUE AND (1=1{query_conditions})
"""
cursor.execute(starred_count_query, query_params)
starred_count = cast(int, cursor.fetchone()[0])
# Get all image names with proper ordering
if starred_first:
names_query = f"""--sql
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1{query_conditions}
ORDER BY images.starred DESC, images.created_at {order_dir.value}
"""
else:
names_query = f"""--sql
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1{query_conditions}
ORDER BY images.created_at {order_dir.value}
"""
# Get all image names with proper ordering
if starred_first:
names_query = f"""--sql
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1{query_conditions}
ORDER BY images.starred DESC, images.created_at {order_dir.value}
"""
else:
names_query = f"""--sql
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1{query_conditions}
ORDER BY images.created_at {order_dir.value}
"""
cursor.execute(names_query, query_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
cursor.execute(names_query, query_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [row[0] for row in result]
return ImageNamesResult(image_names=image_names, starred_count=starred_count, total_count=len(image_names))

View File

@@ -78,11 +78,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db = db
self._logger = logger
@property
def db(self) -> SqliteDatabase:
"""Return the underlying database."""
return self._db
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
"""
Add a model to the database.
@@ -93,38 +88,33 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
try:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
INSERT INTO models (
id,
config
)
VALUES (?,?);
""",
(
config.key,
config.model_dump_json(),
),
)
self._db.conn.commit()
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
INSERT INTO models (
id,
config
)
VALUES (?,?);
""",
(
config.key,
config.model_dump_json(),
),
)
except sqlite3.IntegrityError as e:
self._db.conn.rollback()
if "UNIQUE constraint failed" in str(e):
if "models.path" in str(e):
msg = f"A model with path '{config.path}' is already installed"
elif "models.name" in str(e):
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
except sqlite3.IntegrityError as e:
if "UNIQUE constraint failed" in str(e):
if "models.path" in str(e):
msg = f"A model with path '{config.path}' is already installed"
elif "models.name" in str(e):
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
else:
msg = f"A model with key '{config.key}' is already installed"
raise DuplicateModelException(msg) from e
else:
msg = f"A model with key '{config.key}' is already installed"
raise DuplicateModelException(msg) from e
else:
raise e
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
raise e
return self.get_model(config.key)
@@ -136,8 +126,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise an UnknownModelException
"""
try:
cursor = self._db.conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE FROM models
@@ -147,22 +136,17 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
)
if cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
record = self.get_model(key)
with self._db.transaction() as cursor:
record = self.get_model(key)
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
for field_name in changes.model_fields_set:
setattr(record, field_name, getattr(changes, field_name))
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
for field_name in changes.model_fields_set:
setattr(record, field_name, getattr(changes, field_name))
json_serialized = record.model_dump_json()
json_serialized = record.model_dump_json()
try:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
UPDATE models
@@ -174,10 +158,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
)
if cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_model(key)
@@ -189,30 +169,30 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Exceptions: UnknownModelException
"""
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE id=?;
""",
(key,),
)
rows = cursor.fetchone()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE id=?;
""",
(key,),
)
rows = cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
rows = cursor.fetchone()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
rows = cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
@@ -224,15 +204,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
:param key: Unique key for the model to be deleted
"""
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
select count(*) FROM models
WHERE id=?;
""",
(key,),
)
count = cursor.fetchone()[0]
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
select count(*) FROM models
WHERE id=?;
""",
(key,),
)
count = cursor.fetchone()[0]
return count > 0
def search_by_attr(
@@ -255,43 +235,42 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
If none of the optional filters are passed, will return all
models in the database.
"""
with self._db.transaction() as cursor:
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
where_clause: list[str] = []
bindings: list[str] = []
if model_name:
where_clause.append("name=?")
bindings.append(model_name)
if base_model:
where_clause.append("base=?")
bindings.append(base_model)
if model_type:
where_clause.append("type=?")
bindings.append(model_type)
if model_format:
where_clause.append("format=?")
bindings.append(model_format)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
where_clause: list[str] = []
bindings: list[str] = []
if model_name:
where_clause.append("name=?")
bindings.append(model_name)
if base_model:
where_clause.append("base=?")
bindings.append(base_model)
if model_type:
where_clause.append("type=?")
bindings.append(model_type)
if model_format:
where_clause.append("format=?")
bindings.append(model_format)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
cursor = self._db.conn.cursor()
cursor.execute(
f"""--sql
SELECT config, strftime('%s',updated_at)
FROM models
{where}
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
""",
tuple(bindings),
)
result = cursor.fetchall()
cursor.execute(
f"""--sql
SELECT config, strftime('%s',updated_at)
FROM models
{where}
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
""",
tuple(bindings),
)
result = cursor.fetchall()
# Parse the model configs.
results: list[AnyModelConfig] = []
@@ -313,69 +292,68 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
"""Return models with the indicated path."""
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
return results
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
"""Return models with the indicated hash."""
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
return results
def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
) -> PaginatedResults[ModelSummary]:
"""Return a paginated summary listing of each model in the database."""
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
with self._db.transaction() as cursor:
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
cursor = self._db.conn.cursor()
# Lock so that the database isn't updated while we're doing the two queries.
# query1: get the total number of model configs
cursor.execute(
"""--sql
select count(*) from models;
""",
(),
)
total = int(cursor.fetchone()[0])
# Lock so that the database isn't updated while we're doing the two queries.
# query1: get the total number of model configs
cursor.execute(
"""--sql
select count(*) from models;
""",
(),
)
total = int(cursor.fetchone()[0])
# query2: fetch key fields
cursor.execute(
f"""--sql
SELECT config
FROM models
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
LIMIT ?
OFFSET ?;
""",
(
per_page,
page * per_page,
),
)
rows = cursor.fetchall()
# query2: fetch key fields
cursor.execute(
f"""--sql
SELECT config
FROM models
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
LIMIT ?
OFFSET ?;
""",
(
per_page,
page * per_page,
),
)
rows = cursor.fetchall()
items = [ModelSummary.model_validate(dict(x)) for x in rows]
return PaginatedResults(page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items)

View File

@@ -1,5 +1,3 @@
import sqlite3
from invokeai.app.services.model_relationship_records.model_relationship_records_base import (
ModelRelationshipRecordStorageBase,
)
@@ -9,58 +7,49 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteModelRelationshipRecordStorage(ModelRelationshipRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._db = db
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
if model_key_1 == model_key_2:
raise ValueError("Cannot relate a model to itself.")
a, b = sorted([model_key_1, model_key_2])
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
if model_key_1 == model_key_2:
raise ValueError("Cannot relate a model to itself.")
a, b = sorted([model_key_1, model_key_2])
cursor.execute(
"INSERT OR IGNORE INTO model_relationships (model_key_1, model_key_2) VALUES (?, ?)",
(a, b),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
a, b = sorted([model_key_1, model_key_2])
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
a, b = sorted([model_key_1, model_key_2])
cursor.execute(
"DELETE FROM model_relationships WHERE model_key_1 = ? AND model_key_2 = ?",
(a, b),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
def get_related_model_keys(self, model_key: str) -> list[str]:
cursor = self._conn.cursor()
cursor.execute(
"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ?
UNION
SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ?
""",
(model_key, model_key),
)
return [row[0] for row in cursor.fetchall()]
with self._db.transaction() as cursor:
cursor.execute(
"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ?
UNION
SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ?
""",
(model_key, model_key),
)
result = [row[0] for row in cursor.fetchall()]
return result
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
cursor = self._conn.cursor()
key_list = ",".join("?" for _ in model_keys)
cursor.execute(
f"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 IN ({key_list})
UNION
SELECT model_key_1 FROM model_relationships WHERE model_key_2 IN ({key_list})
""",
model_keys + model_keys,
)
return [row[0] for row in cursor.fetchall()]
with self._db.transaction() as cursor:
key_list = ",".join("?" for _ in model_keys)
cursor.execute(
f"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 IN ({key_list})
UNION
SELECT model_key_1 FROM model_relationships WHERE model_key_2 IN ({key_list})
""",
model_keys + model_keys,
)
result = [row[0] for row in cursor.fetchall()]
return result

View File

@@ -50,15 +50,14 @@ class SqliteSessionQueue(SessionQueueBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._db = db
def _set_in_progress_to_canceled(self) -> None:
"""
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
This is necessary because the invoker may have been killed while processing a queue item.
"""
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
UPDATE session_queue
@@ -66,87 +65,79 @@ class SqliteSessionQueue(SessionQueueBase):
WHERE status = 'in_progress';
"""
)
except Exception:
self._conn.rollback()
raise
def _get_current_queue_size(self, queue_id: str) -> int:
"""Gets the current number of pending queue items"""
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
return cast(int, cursor.fetchone()[0])
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
count = cast(int, cursor.fetchone()[0])
return count
def _get_highest_priority(self, queue_id: str) -> int:
"""Gets the highest priority value in the queue"""
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT MAX(priority)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
return cast(Union[int, None], cursor.fetchone()[0]) or 0
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT MAX(priority)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
priority = cast(Union[int, None], cursor.fetchone()[0]) or 0
return priority
async def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
try:
# TODO: how does this work in a multi-user scenario?
current_queue_size = self._get_current_queue_size(queue_id)
max_queue_size = self.__invoker.services.configuration.max_queue_size
max_new_queue_items = max_queue_size - current_queue_size
current_queue_size = self._get_current_queue_size(queue_id)
max_queue_size = self.__invoker.services.configuration.max_queue_size
max_new_queue_items = max_queue_size - current_queue_size
priority = 0
if prepend:
priority = self._get_highest_priority(queue_id) + 1
priority = 0
if prepend:
priority = self._get_highest_priority(queue_id) + 1
requested_count = await asyncio.to_thread(
calc_session_count,
batch=batch,
)
values_to_insert = await asyncio.to_thread(
prepare_values_to_insert,
queue_id=queue_id,
batch=batch,
priority=priority,
max_new_queue_items=max_new_queue_items,
)
enqueued_count = len(values_to_insert)
requested_count = await asyncio.to_thread(
calc_session_count,
batch=batch,
)
values_to_insert = await asyncio.to_thread(
prepare_values_to_insert,
queue_id=queue_id,
batch=batch,
priority=priority,
max_new_queue_items=max_new_queue_items,
)
enqueued_count = len(values_to_insert)
with self._conn:
cursor = self._conn.cursor()
cursor.executemany(
"""--sql
with self._db.transaction() as cursor:
cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)
with self._conn:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
values_to_insert,
)
cursor.execute(
"""--sql
SELECT item_id
FROM session_queue
WHERE batch_id = ?
ORDER BY item_id DESC;
""",
(batch.batch_id,),
)
item_ids = [row[0] for row in cursor.fetchall()]
except Exception:
raise
(batch.batch_id,),
)
item_ids = [row[0] for row in cursor.fetchall()]
enqueue_result = EnqueueBatchResult(
queue_id=queue_id,
requested=requested_count,
@@ -159,19 +150,19 @@ class SqliteSessionQueue(SessionQueueBase):
return enqueue_result
def dequeue(self) -> Optional[SessionQueueItem]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE status = 'pending'
ORDER BY
priority DESC,
item_id ASC
LIMIT 1
"""
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE status = 'pending'
ORDER BY
priority DESC,
item_id ASC
LIMIT 1
"""
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
queue_item = SessionQueueItem.queue_item_from_dict(dict(result))
@@ -179,40 +170,40 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
ORDER BY
priority DESC,
created_at ASC
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
ORDER BY
priority DESC,
created_at ASC
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
return SessionQueueItem.queue_item_from_dict(dict(result))
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'in_progress'
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'in_progress'
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
return SessionQueueItem.queue_item_from_dict(dict(result))
@@ -225,8 +216,7 @@ class SqliteSessionQueue(SessionQueueBase):
error_message: Optional[str] = None,
error_traceback: Optional[str] = None,
) -> SessionQueueItem:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status FROM session_queue WHERE item_id = ?
@@ -234,12 +224,15 @@ class SqliteSessionQueue(SessionQueueBase):
(item_id,),
)
row = cursor.fetchone()
if row is None:
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
current_status = row[0]
# Only update if not already finished (completed, failed or canceled)
if current_status in ("completed", "failed", "canceled"):
return self.get_queue_item(item_id)
if row is None:
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
current_status = row[0]
# Only update if not already finished (completed, failed or canceled)
if current_status in ("completed", "failed", "canceled"):
return self.get_queue_item(item_id)
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
UPDATE session_queue
@@ -248,10 +241,7 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(status, error_type, error_message, error_traceback, item_id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
queue_item = self.get_queue_item(item_id)
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
@@ -259,35 +249,34 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def is_empty(self, queue_id: str) -> IsEmptyResult:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
is_empty = cast(int, cursor.fetchone()[0]) == 0
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
is_empty = cast(int, cursor.fetchone()[0]) == 0
return IsEmptyResult(is_empty=is_empty)
def is_full(self, queue_id: str) -> IsFullResult:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
max_queue_size = self.__invoker.services.configuration.max_queue_size
is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
max_queue_size = self.__invoker.services.configuration.max_queue_size
is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
return IsFullResult(is_full=is_full)
def clear(self, queue_id: str) -> ClearResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT COUNT(*)
@@ -305,24 +294,19 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
self.__invoker.services.events.emit_queue_cleared(queue_id)
return ClearResult(deleted=count)
def prune(self, queue_id: str) -> PruneResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
where = """--sql
WHERE
queue_id = ?
AND (
queue_id = ?
AND (
status = 'completed'
OR status = 'failed'
OR status = 'canceled'
)
)
"""
cursor.execute(
f"""--sql
@@ -341,10 +325,6 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return PruneResult(deleted=count)
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
@@ -357,8 +337,7 @@ class SqliteSessionQueue(SessionQueueBase):
self.cancel_queue_item(item_id)
except SessionQueueItemNotFoundError:
pass
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE
@@ -367,10 +346,6 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(item_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
def complete_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self._set_queue_item_status(item_id=item_id, status="completed")
@@ -393,8 +368,7 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
placeholders = ", ".join(["?" for _ in batch_ids])
where = f"""--sql
@@ -425,17 +399,14 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
self._conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self._conn.rollback()
raise
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
return CancelByBatchIDsResult(canceled=count)
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
where = """--sql
WHERE
@@ -465,17 +436,12 @@ class SqliteSessionQueue(SessionQueueBase):
""",
params,
)
self._conn.commit()
if current_queue_item is not None and current_queue_item.destination == destination:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self._conn.rollback()
raise
if current_queue_item is not None and current_queue_item.destination == destination:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
return CancelByDestinationResult(canceled=count)
def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
if current_queue_item is not None and current_queue_item.destination == destination:
self.cancel_queue_item(current_queue_item.item_id)
@@ -501,15 +467,10 @@ class SqliteSessionQueue(SessionQueueBase):
""",
params,
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return DeleteByDestinationResult(deleted=count)
def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
where = """--sql
WHERE
queue_id == ?
@@ -532,15 +493,10 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return DeleteAllExceptCurrentResult(deleted=count)
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
where = """--sql
WHERE
@@ -569,18 +525,13 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
self._conn.commit()
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self._conn.rollback()
raise
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
return CancelByQueueIDResult(canceled=count)
def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
where = """--sql
WHERE
queue_id == ?
@@ -603,30 +554,25 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return CancelAllExceptCurrentResult(canceled=count)
def get_queue_item(self, item_id: int) -> SessionQueueItem:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT * FROM session_queue
WHERE
item_id = ?
""",
(item_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT * FROM session_queue
WHERE
item_id = ?
""",
(item_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
return SessionQueueItem.queue_item_from_dict(dict(result))
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
# when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
# during execution.
@@ -639,10 +585,6 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(session_json, item_id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get_queue_item(item_id)
def list_queue_items(
@@ -654,42 +596,42 @@ class SqliteSessionQueue(SessionQueueBase):
status: Optional[QUEUE_ITEM_STATUS] = None,
destination: Optional[str] = None,
) -> CursorPaginatedResults[SessionQueueItem]:
cursor_ = self._conn.cursor()
item_id = cursor
query = """--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
"""
params: list[Union[str, int]] = [queue_id]
if status is not None:
query += """--sql
AND status = ?
"""
params.append(status)
if destination is not None:
query += """---sql
AND destination = ?
with self._db.transaction() as cursor_:
item_id = cursor
query = """--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
"""
params.append(destination)
params: list[Union[str, int]] = [queue_id]
if item_id is not None:
query += """--sql
AND (priority < ?) OR (priority = ? AND item_id > ?)
if status is not None:
query += """--sql
AND status = ?
"""
params.append(status)
if destination is not None:
query += """---sql
AND destination = ?
"""
params.extend([priority, priority, item_id])
params.append(destination)
query += """--sql
ORDER BY
priority DESC,
item_id ASC
LIMIT ?
"""
params.append(limit + 1)
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
if item_id is not None:
query += """--sql
AND (priority < ?) OR (priority = ? AND item_id > ?)
"""
params.extend([priority, priority, item_id])
query += """--sql
ORDER BY
priority DESC,
item_id ASC
LIMIT ?
"""
params.append(limit + 1)
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
has_more = False
if len(items) > limit:
@@ -704,43 +646,43 @@ class SqliteSessionQueue(SessionQueueBase):
destination: Optional[str] = None,
) -> list[SessionQueueItem]:
"""Gets all queue items that match the given parameters"""
cursor_ = self._conn.cursor()
query = """--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
"""
params: list[Union[str, int]] = [queue_id]
if destination is not None:
query += """---sql
AND destination = ?
with self._db.transaction() as cursor:
query = """--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
"""
params.append(destination)
params: list[Union[str, int]] = [queue_id]
query += """--sql
ORDER BY
priority DESC,
item_id ASC
;
"""
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
if destination is not None:
query += """---sql
AND destination = ?
"""
params.append(destination)
query += """--sql
ORDER BY
priority DESC,
item_id ASC
;
"""
cursor.execute(query, params)
results = cast(list[sqlite3.Row], cursor.fetchall())
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
return items
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
GROUP BY status
""",
(queue_id,),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
GROUP BY status
""",
(queue_id,),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
current_item = self.get_current(queue_id=queue_id)
total = sum(row[1] or 0 for row in counts_result)
@@ -759,19 +701,19 @@ class SqliteSessionQueue(SessionQueueBase):
)
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*), origin, destination
FROM session_queue
WHERE
queue_id = ?
AND batch_id = ?
GROUP BY status
""",
(queue_id, batch_id),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status, count(*), origin, destination
FROM session_queue
WHERE
queue_id = ?
AND batch_id = ?
GROUP BY status
""",
(queue_id, batch_id),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] or 0 for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result}
origin = result[0]["origin"] if result else None
@@ -791,18 +733,18 @@ class SqliteSessionQueue(SessionQueueBase):
)
def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
AND destination = ?
GROUP BY status
""",
(queue_id, destination),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
AND destination = ?
GROUP BY status
""",
(queue_id, destination),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] or 0 for row in counts_result)
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
@@ -820,8 +762,7 @@ class SqliteSessionQueue(SessionQueueBase):
def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult:
"""Retries the given queue items"""
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
values_to_insert: list[ValueToInsertTuple] = []
retried_item_ids: list[int] = []
@@ -872,10 +813,6 @@ class SqliteSessionQueue(SessionQueueBase):
values_to_insert,
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
retry_result = RetryItemsResult(
queue_id=queue_id,
retried_item_ids=retried_item_ids,

View File

@@ -1,4 +1,7 @@
import sqlite3
import threading
from collections.abc import Generator
from contextlib import contextmanager
from logging import Logger
from pathlib import Path
@@ -26,46 +29,65 @@ class SqliteDatabase:
def __init__(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
"""Initializes the database. This is used internally by the class constructor."""
self.logger = logger
self.db_path = db_path
self.verbose = verbose
self._logger = logger
self._db_path = db_path
self._verbose = verbose
self._lock = threading.RLock()
if not self.db_path:
if not self._db_path:
logger.info("Initializing in-memory database")
else:
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self.logger.info(f"Initializing database at {self.db_path}")
self._db_path.parent.mkdir(parents=True, exist_ok=True)
self._logger.info(f"Initializing database at {self._db_path}")
self.conn = sqlite3.connect(database=self.db_path or sqlite_memory, check_same_thread=False)
self.conn.row_factory = sqlite3.Row
self._conn = sqlite3.connect(database=self._db_path or sqlite_memory, check_same_thread=False)
self._conn.row_factory = sqlite3.Row
if self.verbose:
self.conn.set_trace_callback(self.logger.debug)
if self._verbose:
self._conn.set_trace_callback(self._logger.debug)
# Enable foreign key constraints
self.conn.execute("PRAGMA foreign_keys = ON;")
self._conn.execute("PRAGMA foreign_keys = ON;")
# Enable Write-Ahead Logging (WAL) mode for better concurrency
self.conn.execute("PRAGMA journal_mode = WAL;")
self._conn.execute("PRAGMA journal_mode = WAL;")
# Set a busy timeout to prevent database lockups during writes
self.conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
self._conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
def clean(self) -> None:
"""
Cleans the database by running the VACUUM command, reporting on the freed space.
"""
# No need to clean in-memory database
if not self.db_path:
if not self._db_path:
return
try:
initial_db_size = Path(self.db_path).stat().st_size
self.conn.execute("VACUUM;")
self.conn.commit()
final_db_size = Path(self.db_path).stat().st_size
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
if freed_space_in_mb > 0:
self.logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
with self._conn as conn:
initial_db_size = Path(self._db_path).stat().st_size
conn.execute("VACUUM;")
conn.commit()
final_db_size = Path(self._db_path).stat().st_size
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
if freed_space_in_mb > 0:
self._logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
except Exception as e:
self.logger.error(f"Error cleaning database: {e}")
self._logger.error(f"Error cleaning database: {e}")
raise
@contextmanager
def transaction(self) -> Generator[sqlite3.Cursor, None, None]:
"""
Thread-safe context manager for DB work.
Acquires the RLock, yields a Cursor, then commits or rolls back.
"""
with self._lock:
cursor = self._conn.cursor()
try:
yield cursor
self._conn.commit()
except:
self._conn.rollback()
raise
finally:
cursor.close()

View File

@@ -32,7 +32,7 @@ class SqliteMigrator:
def __init__(self, db: SqliteDatabase) -> None:
self._db = db
self._logger = db.logger
self._logger = db._logger
self._migration_set = MigrationSet()
self._backup_path: Optional[Path] = None
@@ -45,7 +45,7 @@ class SqliteMigrator:
"""Migrates the database to the latest version."""
# This throws if there is a problem.
self._migration_set.validate_migration_chain()
cursor = self._db.conn.cursor()
cursor = self._db._conn.cursor()
self._create_migrations_table(cursor=cursor)
if self._migration_set.count == 0:
@@ -59,13 +59,13 @@ class SqliteMigrator:
self._logger.info("Database update needed")
# Make a backup of the db if it needs to be updated and is a file db
if self._db.db_path is not None:
if self._db._db_path is not None:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
self._backup_path = self._db._db_path.parent / f"{self._db._db_path.stem}_backup_{timestamp}.db"
self._logger.info(f"Backing up database to {str(self._backup_path)}")
# Use SQLite to do the backup
with closing(sqlite3.connect(self._backup_path)) as backup_conn:
self._db.conn.backup(backup_conn)
self._db._conn.backup(backup_conn)
else:
self._logger.info("Using in-memory database, no backup needed")
@@ -81,7 +81,7 @@ class SqliteMigrator:
try:
# Using sqlite3.Connection as a context manager commits a the transaction on exit, or rolls it back if an
# exception is raised.
with self._db.conn as conn:
with self._db._conn as conn:
cursor = conn.cursor()
if self._get_current_version(cursor) != migration.from_version:
raise MigrationError(

View File

@@ -17,7 +17,7 @@ from invokeai.app.util.misc import uuid_string
class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._db = db
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
@@ -25,24 +25,23 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
"""Gets a style preset by ID."""
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM style_presets
WHERE id = ?;
""",
(style_preset_id,),
)
row = cursor.fetchone()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
FROM style_presets
WHERE id = ?;
""",
(style_preset_id,),
)
row = cursor.fetchone()
if row is None:
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
return StylePresetRecordDTO.from_dict(dict(row))
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
style_preset_id = uuid_string()
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
INSERT OR IGNORE INTO style_presets (
@@ -60,16 +59,11 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
style_preset.type,
),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get(style_preset_id)
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
style_preset_ids = []
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
for style_preset in style_presets:
style_preset_id = uuid_string()
style_preset_ids.append(style_preset_id)
@@ -90,16 +84,11 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
style_preset.type,
),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return None
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
# Change the name of a style preset
if changes.name is not None:
cursor.execute(
@@ -122,15 +111,10 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
(changes.preset_data.model_dump_json(), style_preset_id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get(style_preset_id)
def delete(self, style_preset_id: str) -> None:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE from style_presets
@@ -138,51 +122,41 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
""",
(style_preset_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return None
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
main_query = """
SELECT
*
FROM style_presets
"""
with self._db.transaction() as cursor:
main_query = """
SELECT
*
FROM style_presets
"""
if type is not None:
main_query += "WHERE type = ? "
if type is not None:
main_query += "WHERE type = ? "
main_query += "ORDER BY LOWER(name) ASC"
main_query += "ORDER BY LOWER(name) ASC"
cursor = self._conn.cursor()
if type is not None:
cursor.execute(main_query, (type,))
else:
cursor.execute(main_query)
if type is not None:
cursor.execute(main_query, (type,))
else:
cursor.execute(main_query)
rows = cursor.fetchall()
rows = cursor.fetchall()
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
return style_presets
def _sync_default_style_presets(self) -> None:
"""Syncs default style presets to the database. Internal use only."""
# First delete all existing default style presets
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
# First delete all existing default style presets
cursor.execute(
"""--sql
DELETE FROM style_presets
WHERE type = "default";
"""
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
# Next, parse and create the default style presets
with open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
presets = json.load(file)

View File

@@ -25,7 +25,7 @@ SQL_TIME_FORMAT = "%Y-%m-%d %H:%M:%f"
class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._db = db
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
@@ -33,16 +33,16 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
def get(self, workflow_id: str) -> WorkflowRecordDTO:
"""Gets a workflow by ID. Updates the opened_at column."""
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
FROM workflow_library
WHERE workflow_id = ?;
""",
(workflow_id,),
)
row = cursor.fetchone()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
FROM workflow_library
WHERE workflow_id = ?;
""",
(workflow_id,),
)
row = cursor.fetchone()
if row is None:
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
return WorkflowRecordDTO.from_dict(dict(row))
@@ -51,9 +51,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
if workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be created via this method")
try:
with self._db.transaction() as cursor:
workflow_with_id = Workflow(**workflow.model_dump(), id=uuid_string())
cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO workflow_library (
@@ -64,18 +63,13 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow_with_id.id, workflow_with_id.model_dump_json()),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get(workflow_with_id.id)
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
if workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be updated")
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
UPDATE workflow_library
@@ -84,18 +78,13 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow.model_dump_json(), workflow.id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get(workflow.id)
def delete(self, workflow_id: str) -> None:
if self.get(workflow_id).workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be deleted")
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE from workflow_library
@@ -103,10 +92,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return None
def get_many(
@@ -121,108 +106,108 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> PaginatedResults[WorkflowRecordListItemDTO]:
# sanitize!
assert order_by in WorkflowRecordOrderBy
assert direction in SQLiteDirection
with self._db.transaction() as cursor:
# sanitize!
assert order_by in WorkflowRecordOrderBy
assert direction in SQLiteDirection
# We will construct the query dynamically based on the query params
# We will construct the query dynamically based on the query params
# The main query to get the workflows / counts
main_query = """
SELECT
workflow_id,
category,
name,
description,
created_at,
updated_at,
opened_at,
tags
FROM workflow_library
"""
count_query = "SELECT COUNT(*) FROM workflow_library"
# The main query to get the workflows / counts
main_query = """
SELECT
workflow_id,
category,
name,
description,
created_at,
updated_at,
opened_at,
tags
FROM workflow_library
"""
count_query = "SELECT COUNT(*) FROM workflow_library"
# Start with an empty list of conditions and params
conditions: list[str] = []
params: list[str | int] = []
# Start with an empty list of conditions and params
conditions: list[str] = []
params: list[str | int] = []
if categories:
# Categories is a list of WorkflowCategory enum values, and a single string in the DB
if categories:
# Categories is a list of WorkflowCategory enum values, and a single string in the DB
# Ensure all categories are valid (is this necessary?)
assert all(c in WorkflowCategory for c in categories)
# Ensure all categories are valid (is this necessary?)
assert all(c in WorkflowCategory for c in categories)
# Construct a placeholder string for the number of categories
placeholders = ", ".join("?" for _ in categories)
# Construct a placeholder string for the number of categories
placeholders = ", ".join("?" for _ in categories)
# Construct the condition string & params
category_condition = f"category IN ({placeholders})"
category_params = [category.value for category in categories]
# Construct the condition string & params
category_condition = f"category IN ({placeholders})"
category_params = [category.value for category in categories]
conditions.append(category_condition)
params.extend(category_params)
conditions.append(category_condition)
params.extend(category_params)
if tags:
# Tags is a list of strings, and a single string in the DB
# The string in the DB has no guaranteed format
if tags:
# Tags is a list of strings, and a single string in the DB
# The string in the DB has no guaranteed format
# Construct a list of conditions for each tag
tags_conditions = ["tags LIKE ?" for _ in tags]
tags_conditions_joined = " OR ".join(tags_conditions)
tags_condition = f"({tags_conditions_joined})"
# Construct a list of conditions for each tag
tags_conditions = ["tags LIKE ?" for _ in tags]
tags_conditions_joined = " OR ".join(tags_conditions)
tags_condition = f"({tags_conditions_joined})"
# And the params for the tags, case-insensitive
tags_params = [f"%{t.strip()}%" for t in tags]
# And the params for the tags, case-insensitive
tags_params = [f"%{t.strip()}%" for t in tags]
conditions.append(tags_condition)
params.extend(tags_params)
conditions.append(tags_condition)
params.extend(tags_params)
if has_been_opened:
conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
conditions.append("opened_at IS NULL")
if has_been_opened:
conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
conditions.append("opened_at IS NULL")
# Ignore whitespace in the query
stripped_query = query.strip() if query else None
if stripped_query:
# Construct a wildcard query for the name, description, and tags
wildcard_query = "%" + stripped_query + "%"
query_condition = "(name LIKE ? OR description LIKE ? OR tags LIKE ?)"
# Ignore whitespace in the query
stripped_query = query.strip() if query else None
if stripped_query:
# Construct a wildcard query for the name, description, and tags
wildcard_query = "%" + stripped_query + "%"
query_condition = "(name LIKE ? OR description LIKE ? OR tags LIKE ?)"
conditions.append(query_condition)
params.extend([wildcard_query, wildcard_query, wildcard_query])
conditions.append(query_condition)
params.extend([wildcard_query, wildcard_query, wildcard_query])
if conditions:
# If there are conditions, add a WHERE clause and then join the conditions
main_query += " WHERE "
count_query += " WHERE "
if conditions:
# If there are conditions, add a WHERE clause and then join the conditions
main_query += " WHERE "
count_query += " WHERE "
all_conditions = " AND ".join(conditions)
main_query += all_conditions
count_query += all_conditions
all_conditions = " AND ".join(conditions)
main_query += all_conditions
count_query += all_conditions
# After this point, the query and params differ for the main query and the count query
main_params = params.copy()
count_params = params.copy()
# After this point, the query and params differ for the main query and the count query
main_params = params.copy()
count_params = params.copy()
# Main query also gets ORDER BY and LIMIT/OFFSET
main_query += f" ORDER BY {order_by.value} {direction.value}"
# Main query also gets ORDER BY and LIMIT/OFFSET
main_query += f" ORDER BY {order_by.value} {direction.value}"
if per_page:
main_query += " LIMIT ? OFFSET ?"
main_params.extend([per_page, page * per_page])
if per_page:
main_query += " LIMIT ? OFFSET ?"
main_params.extend([per_page, page * per_page])
# Put a ring on it
main_query += ";"
count_query += ";"
# Put a ring on it
main_query += ";"
count_query += ";"
cursor = self._conn.cursor()
cursor.execute(main_query, main_params)
rows = cursor.fetchall()
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
cursor.execute(main_query, main_params)
rows = cursor.fetchall()
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
cursor.execute(count_query, count_params)
total = cursor.fetchone()[0]
cursor.execute(count_query, count_params)
total = cursor.fetchone()[0]
if per_page:
pages = total // per_page + (total % per_page > 0)
@@ -247,46 +232,46 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
if not tags:
return {}
cursor = self._conn.cursor()
result: dict[str, int] = {}
# Base conditions for categories and selected tags
base_conditions: list[str] = []
base_params: list[str | int] = []
with self._db.transaction() as cursor:
result: dict[str, int] = {}
# Base conditions for categories and selected tags
base_conditions: list[str] = []
base_params: list[str | int] = []
# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])
# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])
if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
# For each tag to count, run a separate query
for tag in tags:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()
# For each tag to count, run a separate query
for tag in tags:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()
# Add this specific tag condition
conditions.append("tags LIKE ?")
params.append(f"%{tag.strip()}%")
# Add this specific tag condition
conditions.append("tags LIKE ?")
params.append(f"%{tag.strip()}%")
# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""
# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[tag] = count
cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[tag] = count
return result
@@ -296,52 +281,51 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> dict[str, int]:
cursor = self._conn.cursor()
result: dict[str, int] = {}
# Base conditions for categories
base_conditions: list[str] = []
base_params: list[str | int] = []
with self._db.transaction() as cursor:
result: dict[str, int] = {}
# Base conditions for categories
base_conditions: list[str] = []
base_params: list[str | int] = []
# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])
# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])
if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
# For each category to count, run a separate query
for category in categories:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()
# For each category to count, run a separate query
for category in categories:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()
# Add this specific category condition
conditions.append("category = ?")
params.append(category.value)
# Add this specific category condition
conditions.append("category = ?")
params.append(category.value)
# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""
# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[category.value] = count
cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[category.value] = count
return result
def update_opened_at(self, workflow_id: str) -> None:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
f"""--sql
UPDATE workflow_library
@@ -350,10 +334,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
def _sync_default_workflows(self) -> None:
"""Syncs default workflows to the database. Internal use only."""
@@ -368,8 +348,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
meaningless, as they are overwritten every time the server starts.
"""
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
workflows_from_file: list[Workflow] = []
workflows_to_update: list[Workflow] = []
workflows_to_add: list[Workflow] = []
@@ -449,8 +428,3 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(w.model_dump_json(), w.id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise

View File

@@ -187,7 +187,7 @@ class ModelConfigBase(ABC, BaseModel):
else:
return config_cls.from_model_on_disk(mod, **overrides)
raise InvalidModelConfigException("No valid config found")
raise InvalidModelConfigException("Unable to determine model type")
@classmethod
def get_tag(cls) -> Tag:

View File

@@ -14,6 +14,7 @@ const config: KnipConfig = {
'src/features/controlLayers/konva/util.ts',
// Will be using this
'src/common/hooks/useAsyncState.ts',
'src/app/store/use-debounced-app-selector.ts',
],
ignoreBinaries: ['only-allow'],
paths: {

View File

@@ -63,7 +63,7 @@
"framer-motion": "^11.10.0",
"i18next": "^25.2.1",
"i18next-http-backend": "^3.0.2",
"idb-keyval": "^6.2.2",
"idb-keyval": "6.2.1",
"jsondiffpatch": "^0.7.3",
"konva": "^9.3.20",
"linkify-react": "^4.3.1",

View File

@@ -81,8 +81,8 @@ importers:
specifier: ^3.0.2
version: 3.0.2
idb-keyval:
specifier: ^6.2.2
version: 6.2.2
specifier: 6.2.1
version: 6.2.1
jsondiffpatch:
specifier: ^0.7.3
version: 0.7.3
@@ -2927,8 +2927,8 @@ packages:
typescript:
optional: true
idb-keyval@6.2.2:
resolution: {integrity: sha512-yjD9nARJ/jb1g+CvD0tlhUHOrJ9Sy0P8T9MF3YaLlHnSRpwPfpTX0XIvpmw3gAJUmEu3FiICLBDPXVwyEvrleg==}
idb-keyval@6.2.1:
resolution: {integrity: sha512-8Sb3veuYCyrZL+VBt9LJfZjLUPWVvqn8tG28VqYNFCo43KHcKuq+b4EiXGeuaLAQWL2YmyDgMp2aSpH9JHsEQg==}
ieee754@1.2.1:
resolution: {integrity: sha512-dcyqhDvX1C46lXZcVqCpK+FtMRQVdIMN6/Df5js2zouUsqG7I6sFxitIC+7KYK29KdXOLHdu9zL4sFnoVQnqaA==}
@@ -7720,7 +7720,7 @@ snapshots:
optionalDependencies:
typescript: 5.8.3
idb-keyval@6.2.2: {}
idb-keyval@6.2.1: {}
ieee754@1.2.1: {}

View File

@@ -1775,6 +1775,20 @@
"Structure controls how closely the output image will keep to the layout of the original. Low structure allows major changes, while high structure strictly maintains the original composition and layout."
]
},
"tileSize": {
"heading": "Tile Size",
"paragraphs": [
"Controls the size of tiles used during the upscaling process. Larger tiles use more memory but may produce better results.",
"SD1.5 models default to 768, while SDXL models default to 1024. Reduce tile size if you encounter memory issues."
]
},
"tileOverlap": {
"heading": "Tile Overlap",
"paragraphs": [
"Controls the overlap between adjacent tiles during upscaling. Higher overlap values help reduce visible seams between tiles but use more memory.",
"The default value of 128 works well for most cases, but you can adjust based on your specific needs and memory constraints."
]
},
"fluxDevLicense": {
"heading": "Non-Commercial License",
"paragraphs": [
@@ -1962,7 +1976,6 @@
"recalculateRects": "Recalculate Rects",
"clipToBbox": "Clip Strokes to Bbox",
"outputOnlyMaskedRegions": "Output Only Generated Regions",
"saveAllImagesToGallery": "Save All Images to Gallery",
"addLayer": "Add Layer",
"duplicate": "Duplicate",
"moveToFront": "Move to Front",
@@ -2087,9 +2100,9 @@
"resetCanvasLayers": "Reset Canvas Layers",
"resetGenerationSettings": "Reset Generation Settings",
"replaceCurrent": "Replace Current",
"controlLayerEmptyState": "<UploadButton>Upload an image</UploadButton>, drag an image from the <GalleryButton>gallery</GalleryButton> onto this layer, <PullBboxButton>pull the bounding box into this layer</PullBboxButton>, or draw on the canvas to get started.",
"referenceImageEmptyStateWithCanvasOptions": "<UploadButton>Upload an image</UploadButton>, drag an image from the <GalleryButton>gallery</GalleryButton> onto this Reference Image or <PullBboxButton>pull the bounding box into this Reference Image</PullBboxButton> to get started.",
"referenceImageEmptyState": "<UploadButton>Upload an image</UploadButton> or drag an image from the <GalleryButton>gallery</GalleryButton> onto this Reference Image to get started.",
"controlLayerEmptyState": "<UploadButton>Upload an image</UploadButton>, drag an image from the gallery onto this layer, <PullBboxButton>pull the bounding box into this layer</PullBboxButton>, or draw on the canvas to get started.",
"referenceImageEmptyStateWithCanvasOptions": "<UploadButton>Upload an image</UploadButton>, drag an image from the gallery onto this Reference Image or <PullBboxButton>pull the bounding box into this Reference Image</PullBboxButton> to get started.",
"referenceImageEmptyState": "<UploadButton>Upload an image</UploadButton> or drag an image from the gallery onto this Reference Image to get started.",
"uploadOrDragAnImage": "Drag an image from the gallery or <UploadButton>upload an image</UploadButton>.",
"imageNoise": "Image Noise",
"denoiseLimit": "Denoise Limit",
@@ -2332,7 +2345,8 @@
"alert": "Preserving Masked Region"
},
"saveAllImagesToGallery": {
"alert": "Saving All Images to Gallery"
"label": "Send New Generations to Gallery",
"alert": "Sending new generations to Gallery, bypassing Canvas"
},
"isolatedStagingPreview": "Isolated Staging Preview",
"isolatedPreview": "Isolated Preview",
@@ -2396,6 +2410,9 @@
"upscaleModel": "Upscale Model",
"postProcessingModel": "Post-Processing Model",
"scale": "Scale",
"tileControl": "Tile Control",
"tileSize": "Tile Size",
"tileOverlap": "Tile Overlap",
"postProcessingMissingModelWarning": "Visit the <LinkComponent>Model Manager</LinkComponent> to install a post-processing (image to image) model.",
"missingModelsWarning": "Visit the <LinkComponent>Model Manager</LinkComponent> to install the required models:",
"mainModelDesc": "Main model (SD1.5 or SDXL architecture)",

View File

@@ -15,6 +15,8 @@ import { useDndMonitor } from 'features/dnd/useDndMonitor';
import { useDynamicPromptsWatcher } from 'features/dynamicPrompts/hooks/useDynamicPromptsWatcher';
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
import { useWorkflowBuilderWatcher } from 'features/nodes/components/sidePanel/workflow/IsolatedWorkflowBuilderWatcher';
import { useSyncExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
import { useSyncNodeErrors } from 'features/nodes/store/util/fieldValidators';
import { useReadinessWatcher } from 'features/queue/store/readiness';
import { configChanged } from 'features/system/store/configSlice';
import { selectLanguage } from 'features/system/store/systemSelectors';
@@ -47,10 +49,12 @@ export const GlobalHookIsolator = memo(
useCloseChakraTooltipsOnDragFix();
useNavigationApi();
useDndMonitor();
useSyncNodeErrors();
// Persistent subscription to the queue counts query - canvas relies on this to know if there are pending
// and/or in progress canvas sessions.
useGetQueueCountsByDestinationQuery(queueCountArg);
useSyncExecutionState();
useEffect(() => {
i18n.changeLanguage(language);

View File

@@ -1,10 +1,6 @@
import { GlobalImageHotkeys } from 'app/components/GlobalImageHotkeys';
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
import { CanvasPasteModal } from 'features/controlLayers/components/CanvasPasteModal';
import {
NewCanvasSessionDialog,
NewGallerySessionDialog,
} from 'features/controlLayers/components/NewSessionConfirmationAlertDialog';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { DeleteImageModal } from 'features/deleteImageModal/components/DeleteImageModal';
import { FullscreenDropzone } from 'features/dnd/FullscreenDropzone';
@@ -50,8 +46,6 @@ export const GlobalModalIsolator = memo(() => {
<RefreshAfterResetModal />
<DeleteBoardModal />
<GlobalImageHotkeys />
<NewGallerySessionDialog />
<NewCanvasSessionDialog />
<ImageContextMenu />
<FullscreenDropzone />
<VideosModal />

View File

@@ -21,7 +21,6 @@ import { $isStylePresetsMenuOpen, activeStylePresetIdChanged } from 'features/st
import { toast } from 'features/toast/toast';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { LAUNCHPAD_PANEL_ID, WORKSPACE_PANEL_ID } from 'features/ui/layouts/shared';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
import { atom } from 'nanostores';
import { useCallback, useEffect } from 'react';
@@ -165,7 +164,6 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
// Go to the generate tab, open the launchpad
await navigationApi.focusPanel('generate', LAUNCHPAD_PANEL_ID);
store.dispatch(paramsReset());
store.dispatch(activeTabCanvasRightPanelChanged('gallery'));
break;
case 'canvas':
// Go to the canvas tab, open the launchpad

View File

@@ -1,7 +1,7 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { bboxSyncedToOptimalDimension, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import { modelChanged, syncedToOptimalDimension, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import { refImageModelChanged, selectReferenceImageEntities } from 'features/controlLayers/store/refImagesSlice';
@@ -152,7 +152,8 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
if (modelBase !== state.params.model?.base) {
// Sync generate tab settings whenever the model base changes
dispatch(syncedToOptimalDimension());
if (!selectIsStaging(state)) {
const isStaging = buildSelectIsStaging(selectCanvasSessionId(state))(state);
if (!isStaging) {
// Canvas tab only syncs if not staging
dispatch(bboxSyncedToOptimalDimension());
}

View File

@@ -15,7 +15,11 @@ import { refImageModelChanged, selectRefImagesSlice } from 'features/controlLaye
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { getEntityIdentifier, isFLUXReduxConfig, isIPAdapterConfig } from 'features/controlLayers/store/types';
import { modelSelected } from 'features/parameters/store/actions';
import { postProcessingModelChanged, upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
import {
postProcessingModelChanged,
tileControlnetModelChanged,
upscaleModelChanged,
} from 'features/parameters/store/upscaleSlice';
import {
zParameterCLIPEmbedModel,
zParameterSpandrelImageToImageModel,
@@ -28,6 +32,7 @@ import type { AnyModelConfig } from 'services/api/types';
import {
isCLIPEmbedModelConfig,
isControlLayerModelConfig,
isControlNetModelConfig,
isFluxReduxModelConfig,
isFluxVAEModelConfig,
isIPAdapterModelConfig,
@@ -71,6 +76,7 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
handleControlAdapterModels(models, state, dispatch, log);
handlePostProcessingModel(models, state, dispatch, log);
handleUpscaleModel(models, state, dispatch, log);
handleTileControlNetModel(models, state, dispatch, log);
handleIPAdapterModels(models, state, dispatch, log);
handleT5EncoderModels(models, state, dispatch, log);
handleCLIPEmbedModels(models, state, dispatch, log);
@@ -345,6 +351,46 @@ const handleUpscaleModel: ModelHandler = (models, state, dispatch, log) => {
}
};
const handleTileControlNetModel: ModelHandler = (models, state, dispatch, log) => {
const selectedTileControlNetModel = state.upscale.tileControlnetModel;
const controlNetModels = models.filter(isControlNetModelConfig);
// If the currently selected model is available, we don't need to do anything
if (selectedTileControlNetModel && controlNetModels.some((m) => m.key === selectedTileControlNetModel.key)) {
return;
}
// The only way we have to identify a model as a tile model is by its name containing 'tile' :)
const tileModel = controlNetModels.find((m) => m.name.toLowerCase().includes('tile'));
// If we have a tile model, select it
if (tileModel) {
log.debug(
{ selectedTileControlNetModel, tileModel },
'No selected tile ControlNet model or selected model is not available, selecting tile model'
);
dispatch(tileControlnetModelChanged(tileModel));
return;
}
// Otherwise, select the first available ControlNet model
const firstModel = controlNetModels[0] || null;
if (firstModel) {
log.debug(
{ selectedTileControlNetModel, firstModel },
'No tile ControlNet model found, selecting first available ControlNet model'
);
dispatch(tileControlnetModelChanged(firstModel));
return;
}
// No available models, we should clear the selected model - but only if we have one selected
if (selectedTileControlNetModel) {
log.debug({ selectedTileControlNetModel }, 'Selected tile ControlNet model is not available, clearing');
dispatch(tileControlnetModelChanged(null));
}
};
const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => {
const selectedT5EncoderModel = state.params.t5EncoderModel;
const t5EncoderModels = models.filter((m) => isT5EncoderModelConfig(m));

View File

@@ -1,7 +1,7 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { isNil } from 'es-toolkit';
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasSlice';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import {
heightChanged,
setCfgRescaleMultiplier,
@@ -115,7 +115,8 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
}
const setSizeOptions = { updateAspectRatio: true, clamp: true };
const isStaging = selectIsStaging(getState());
const isStaging = buildSelectIsStaging(selectCanvasSessionId(state))(state);
const activeTab = selectActiveTab(getState());
if (activeTab === 'generate') {
if (isParameterWidth(width)) {

View File

@@ -67,6 +67,8 @@ export type Feature =
| 'scale'
| 'creativity'
| 'structure'
| 'tileSize'
| 'tileOverlap'
| 'optimizedDenoising'
| 'fluxDevLicense';

View File

@@ -6,7 +6,6 @@ import { atom, computed } from 'nanostores';
import type { RefObject } from 'react';
import { useEffect } from 'react';
import { objectKeys } from 'tsafe';
import z from 'zod/v4';
/**
* We need to manage focus regions to conditionally enable hotkeys:
@@ -28,10 +27,7 @@ import z from 'zod/v4';
const log = logger('system');
/**
* The names of the focus regions.
*/
const zFocusRegionName = z.enum([
const REGION_NAMES = [
'launchpad',
'viewer',
'gallery',
@@ -41,13 +37,16 @@ const zFocusRegionName = z.enum([
'workflows',
'progress',
'settings',
]);
export type FocusRegionName = z.infer<typeof zFocusRegionName>;
] as const;
/**
* The names of the focus regions.
*/
export type FocusRegionName = (typeof REGION_NAMES)[number];
/**
* A map of focus regions to the elements that are part of that region.
*/
const REGION_TARGETS: Record<FocusRegionName, Set<HTMLElement>> = zFocusRegionName.options.values().reduce(
const REGION_TARGETS: Record<FocusRegionName, Set<HTMLElement>> = REGION_NAMES.reduce(
(acc, region) => {
acc[region] = new Set<HTMLElement>();
return acc;

View File

@@ -6,7 +6,7 @@ import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { selectIsClientSideUploadEnabled } from 'features/system/store/configSlice';
import { toast } from 'features/toast/toast';
import { memo, useCallback } from 'react';
import type { FileRejection } from 'react-dropzone';
import type { Accept, FileRejection } from 'react-dropzone';
import { useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { PiUploadBold } from 'react-icons/pi';
@@ -15,6 +15,18 @@ import type { ImageDTO } from 'services/api/types';
import { assert } from 'tsafe';
import type { SetOptional } from 'type-fest';
const addUpperCaseReducer = (acc: string[], ext: string) => {
acc.push(ext);
acc.push(ext.toUpperCase());
return acc;
};
export const dropzoneAccept: Accept = {
'image/png': ['.png'].reduce(addUpperCaseReducer, [] as string[]),
'image/jpeg': ['.jpg', '.jpeg', '.png'].reduce(addUpperCaseReducer, [] as string[]),
'image/webp': ['.webp'].reduce(addUpperCaseReducer, [] as string[]),
};
import { useClientSideUpload } from './useClientSideUpload';
type UseImageUploadButtonArgs =
| {
@@ -164,11 +176,7 @@ export const useImageUploadButton = ({
getInputProps: getUploadInputProps,
open: openUploader,
} = useDropzone({
accept: {
'image/png': ['.png'],
'image/jpeg': ['.jpg', '.jpeg', '.png'],
'image/webp': ['.webp'],
},
accept: dropzoneAccept,
onDropAccepted,
onDropRejected,
disabled: isDisabled,

View File

@@ -1,7 +1,6 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, ConfirmationAlertDialog, Flex, FormControl, Text } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import {
@@ -14,7 +13,7 @@ import { useTranslation } from 'react-i18next';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
import { useAddImagesToBoardMutation, useRemoveImagesFromBoardMutation } from 'services/api/endpoints/images';
const selectImagesToChange = createMemoizedSelector(
const selectImagesToChange = createSelector(
selectChangeBoardModalSlice,
(changeBoardModal) => changeBoardModal.image_names
);

View File

@@ -13,7 +13,7 @@ export const CanvasAlertsSaveAllImagesToGallery = memo(() => {
}
return (
<Alert status="info" borderRadius="base" fontSize="sm" shadow="md" w="fit-content">
<Alert status="warning" borderRadius="base" fontSize="sm" shadow="md" w="fit-content">
<AlertIcon />
<AlertTitle>{t('controlLayers.settings.saveAllImagesToGallery.alert')}</AlertTitle>
</Alert>

View File

@@ -57,21 +57,21 @@ const CanvasAlertsSelectedEntityStatusContent = memo(({ entityIdentifier, adapte
const alert = useMemo<AlertData | null>(() => {
if (isFiltering) {
return {
status: 'info',
status: 'warning',
title: t('controlLayers.HUD.entityStatus.isFiltering', { title }),
};
}
if (isTransforming) {
return {
status: 'info',
status: 'warning',
title: t('controlLayers.HUD.entityStatus.isTransforming', { title }),
};
}
if (isEmpty) {
return {
status: 'info',
status: 'warning',
title: t('controlLayers.HUD.entityStatus.isEmpty', { title }),
};
}

View File

@@ -5,7 +5,6 @@ import { useEntityIdentifierContext } from 'features/controlLayers/contexts/Enti
import { usePullBboxIntoLayer } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { replaceCanvasEntityObjectsWithImage } from 'features/imageActions/actions';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { memo, useCallback, useMemo } from 'react';
import { Trans } from 'react-i18next';
import type { ImageDTO } from 'services/api/types';
@@ -21,9 +20,6 @@ export const ControlLayerSettingsEmptyState = memo(() => {
[dispatch, entityIdentifier, getState]
);
const uploadApi = useImageUploadButton({ onUpload, allowMultiple: false });
const onClickGalleryButton = useCallback(() => {
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}, [dispatch]);
const pullBboxIntoLayer = usePullBboxIntoLayer(entityIdentifier);
const components = useMemo(
@@ -31,14 +27,11 @@ export const ControlLayerSettingsEmptyState = memo(() => {
UploadButton: (
<Button isDisabled={isBusy} size="sm" variant="link" color="base.300" {...uploadApi.getUploadButtonProps()} />
),
GalleryButton: (
<Button onClick={onClickGalleryButton} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
),
PullBboxButton: (
<Button onClick={pullBboxIntoLayer} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
),
}),
[isBusy, onClickGalleryButton, pullBboxIntoLayer, uploadApi]
[isBusy, pullBboxIntoLayer, uploadApi]
);
return (

View File

@@ -1,131 +0,0 @@
import { Checkbox, ConfirmationAlertDialog, Flex, FormControl, FormLabel, Text } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { buildUseBoolean } from 'common/hooks/useBoolean';
import { canvasSessionReset, generateSessionReset } from 'features/controlLayers/store/canvasStagingAreaSlice';
import {
selectSystemShouldConfirmOnNewSession,
shouldConfirmOnNewSessionToggled,
} from 'features/system/store/systemSlice';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const [useNewGallerySessionDialog] = buildUseBoolean(false);
const [useNewCanvasSessionDialog] = buildUseBoolean(false);
const useNewGallerySession = () => {
const dispatch = useAppDispatch();
const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession);
const newSessionDialog = useNewGallerySessionDialog();
const newGallerySessionImmediate = useCallback(() => {
dispatch(generateSessionReset());
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}, [dispatch]);
const newGallerySessionWithDialog = useCallback(() => {
if (shouldConfirmOnNewSession) {
newSessionDialog.setTrue();
return;
}
newGallerySessionImmediate();
}, [newGallerySessionImmediate, newSessionDialog, shouldConfirmOnNewSession]);
return { newGallerySessionImmediate, newGallerySessionWithDialog };
};
const useNewCanvasSession = () => {
const dispatch = useAppDispatch();
const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession);
const newSessionDialog = useNewCanvasSessionDialog();
const newCanvasSessionImmediate = useCallback(() => {
dispatch(canvasSessionReset());
dispatch(activeTabCanvasRightPanelChanged('layers'));
}, [dispatch]);
const newCanvasSessionWithDialog = useCallback(() => {
if (shouldConfirmOnNewSession) {
newSessionDialog.setTrue();
return;
}
newCanvasSessionImmediate();
}, [newCanvasSessionImmediate, newSessionDialog, shouldConfirmOnNewSession]);
return { newCanvasSessionImmediate, newCanvasSessionWithDialog };
};
export const NewGallerySessionDialog = memo(() => {
useAssertSingleton('NewGallerySessionDialog');
const { t } = useTranslation();
const dispatch = useAppDispatch();
const dialog = useNewGallerySessionDialog();
const { newGallerySessionImmediate } = useNewGallerySession();
const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession);
const onToggleConfirm = useCallback(() => {
dispatch(shouldConfirmOnNewSessionToggled());
}, [dispatch]);
return (
<ConfirmationAlertDialog
isOpen={dialog.isTrue}
onClose={dialog.setFalse}
title={t('controlLayers.newGallerySession')}
acceptCallback={newGallerySessionImmediate}
acceptButtonText={t('common.ok')}
useInert={false}
>
<Flex direction="column" gap={3}>
<Text>{t('controlLayers.newGallerySessionDesc')}</Text>
<Text>{t('common.areYouSure')}</Text>
<FormControl>
<FormLabel>{t('common.dontAskMeAgain')}</FormLabel>
<Checkbox isChecked={!shouldConfirmOnNewSession} onChange={onToggleConfirm} />
</FormControl>
</Flex>
</ConfirmationAlertDialog>
);
});
NewGallerySessionDialog.displayName = 'NewGallerySessionDialog';
export const NewCanvasSessionDialog = memo(() => {
useAssertSingleton('NewCanvasSessionDialog');
const { t } = useTranslation();
const dispatch = useAppDispatch();
const dialog = useNewCanvasSessionDialog();
const { newCanvasSessionImmediate } = useNewCanvasSession();
const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession);
const onToggleConfirm = useCallback(() => {
dispatch(shouldConfirmOnNewSessionToggled());
}, [dispatch]);
return (
<ConfirmationAlertDialog
isOpen={dialog.isTrue}
onClose={dialog.setFalse}
title={t('controlLayers.newCanvasSession')}
acceptCallback={newCanvasSessionImmediate}
acceptButtonText={t('common.ok')}
useInert={false}
>
<Flex direction="column" gap={3}>
<Text>{t('controlLayers.newCanvasSessionDesc')}</Text>
<Text>{t('common.areYouSure')}</Text>
<FormControl>
<FormLabel>{t('common.dontAskMeAgain')}</FormLabel>
<Checkbox isChecked={!shouldConfirmOnNewSession} onChange={onToggleConfirm} />
</FormControl>
</Flex>
</ConfirmationAlertDialog>
);
});
NewCanvasSessionDialog.displayName = 'NewCanvasSessionDialog';

View File

@@ -6,7 +6,6 @@ import type { SetGlobalReferenceImageDndTargetData } from 'features/dnd/dnd';
import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { setGlobalReferenceImage } from 'features/imageActions/actions';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { memo, useCallback, useMemo } from 'react';
import { Trans, useTranslation } from 'react-i18next';
import type { ImageDTO } from 'services/api/types';
@@ -22,9 +21,6 @@ export const RefImageNoImageState = memo(() => {
[dispatch, id]
);
const uploadApi = useImageUploadButton({ onUpload, allowMultiple: false });
const onClickGalleryButton = useCallback(() => {
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}, [dispatch]);
const dndTargetData = useMemo<SetGlobalReferenceImageDndTargetData>(
() => setGlobalReferenceImageDndTarget.getData({ id }),
@@ -34,9 +30,8 @@ export const RefImageNoImageState = memo(() => {
const components = useMemo(
() => ({
UploadButton: <Button size="sm" variant="link" color="base.300" {...uploadApi.getUploadButtonProps()} />,
GalleryButton: <Button onClick={onClickGalleryButton} size="sm" variant="link" color="base.300" />,
}),
[onClickGalleryButton, uploadApi]
[uploadApi]
);
return (

View File

@@ -8,7 +8,6 @@ import type { SetGlobalReferenceImageDndTargetData } from 'features/dnd/dnd';
import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { setGlobalReferenceImage } from 'features/imageActions/actions';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { memo, useCallback, useMemo } from 'react';
import { Trans, useTranslation } from 'react-i18next';
import type { ImageDTO } from 'services/api/types';
@@ -25,9 +24,6 @@ export const RefImageNoImageStateWithCanvasOptions = memo(() => {
[dispatch, id]
);
const uploadApi = useImageUploadButton({ onUpload, allowMultiple: false });
const onClickGalleryButton = useCallback(() => {
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}, [dispatch]);
const pullBboxIntoIPAdapter = usePullBboxIntoGlobalReferenceImage(id);
const dndTargetData = useMemo<SetGlobalReferenceImageDndTargetData>(
@@ -40,14 +36,11 @@ export const RefImageNoImageStateWithCanvasOptions = memo(() => {
UploadButton: (
<Button isDisabled={isBusy} size="sm" variant="link" color="base.300" {...uploadApi.getUploadButtonProps()} />
),
GalleryButton: (
<Button onClick={onClickGalleryButton} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
),
PullBboxButton: (
<Button onClick={pullBboxIntoIPAdapter} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
),
}),
[isBusy, onClickGalleryButton, pullBboxIntoIPAdapter, uploadApi]
[isBusy, pullBboxIntoIPAdapter, uploadApi]
);
return (

View File

@@ -9,7 +9,6 @@ import type { SetRegionalGuidanceReferenceImageDndTargetData } from 'features/dn
import { setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { setRegionalGuidanceReferenceImage } from 'features/imageActions/actions';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { memo, useCallback, useMemo } from 'react';
import { Trans, useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
@@ -31,9 +30,6 @@ export const RegionalGuidanceIPAdapterSettingsEmptyState = memo(({ referenceImag
[dispatch, entityIdentifier, referenceImageId]
);
const uploadApi = useImageUploadButton({ onUpload, allowMultiple: false });
const onClickGalleryButton = useCallback(() => {
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}, [dispatch]);
const onDeleteIPAdapter = useCallback(() => {
dispatch(rgRefImageDeleted({ entityIdentifier, referenceImageId }));
}, [dispatch, entityIdentifier, referenceImageId]);
@@ -53,14 +49,11 @@ export const RegionalGuidanceIPAdapterSettingsEmptyState = memo(({ referenceImag
UploadButton: (
<Button isDisabled={isBusy} size="sm" variant="link" color="base.300" {...uploadApi.getUploadButtonProps()} />
),
GalleryButton: (
<Button onClick={onClickGalleryButton} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
),
PullBboxButton: (
<Button onClick={pullBboxIntoIPAdapter} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
),
}),
[isBusy, onClickGalleryButton, pullBboxIntoIPAdapter, uploadApi]
[isBusy, pullBboxIntoIPAdapter, uploadApi]
);
return (

View File

@@ -16,7 +16,7 @@ export const CanvasSettingsSaveAllImagesToGalleryCheckbox = memo(() => {
}, [dispatch]);
return (
<FormControl w="full">
<FormLabel flexGrow={1}>{t('controlLayers.saveAllImagesToGallery')}</FormLabel>
<FormLabel flexGrow={1}>{t('controlLayers.settings.saveAllImagesToGallery.label')}</FormLabel>
<Checkbox isChecked={saveAllImagesToGallery} onChange={onChange} />
</FormControl>
);

View File

@@ -64,10 +64,6 @@ export const QueueItemPreviewMini = memo(({ item, isSelected, index }: Props) =>
}
}, [autoSwitch, dispatch]);
const onLoad = useCallback(() => {
ctx.onImageLoad(item.item_id);
}, [ctx, item.item_id]);
return (
<Flex
id={getQueueItemElementId(index)}
@@ -77,7 +73,7 @@ export const QueueItemPreviewMini = memo(({ item, isSelected, index }: Props) =>
onDoubleClick={onDoubleClick}
>
<QueueItemStatusLabel item={item} position="absolute" margin="auto" />
{imageDTO && <DndImage imageDTO={imageDTO} onLoad={onLoad} asThumbnail />}
{imageDTO && <DndImage imageDTO={imageDTO} asThumbnail position="absolute" />}
{!imageLoaded && <QueueItemProgressImage itemId={item.item_id} position="absolute" />}
<QueueItemNumber number={index + 1} position="absolute" top={0} left={1} />
<QueueItemCircularProgress itemId={item.item_id} status={item.status} position="absolute" top={1} right={2} />

View File

@@ -94,6 +94,7 @@ const useScrollableStagingArea = (rootRef: RefObject<HTMLDivElement>) => {
const { viewport } = osInstance.elements();
viewport.style.overflowX = `var(--os-viewport-overflow-x)`;
viewport.style.overflowY = `var(--os-viewport-overflow-y)`;
viewport.style.textAlign = 'center';
},
},
options: {
@@ -148,8 +149,8 @@ export const StagingAreaItemsList = memo(() => {
return;
}
return canvasManager.stagingArea.connectToSession(ctx.$selectedItemId, ctx.$progressData, ctx.$isPending);
}, [canvasManager, ctx.$progressData, ctx.$selectedItemId, ctx.$isPending]);
return canvasManager.stagingArea.connectToSession(ctx.$items, ctx.$selectedItemId, ctx.$progressData);
}, [canvasManager, ctx.$progressData, ctx.$selectedItemId, ctx.$items]);
useEffect(() => {
return ctx.$selectedItemIndex.listen((index) => {

View File

@@ -1,11 +1,13 @@
import { useStore } from '@nanostores/react';
import { useAppStore } from 'app/store/storeHooks';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { getOutputImageName } from 'features/controlLayers/components/SimpleSession/shared';
import { loadImage } from 'features/controlLayers/konva/util';
import { selectStagingAreaAutoSwitch } from 'features/controlLayers/store/canvasSettingsSlice';
import {
buildSelectSessionQueueItems,
buildSelectCanvasQueueItems,
canvasQueueItemDiscarded,
canvasSessionReset,
selectCanvasSessionId,
} from 'features/controlLayers/store/canvasStagingAreaSlice';
import type { ProgressImage } from 'features/nodes/types/common';
import type { Atom, MapStore, StoreValue, WritableAtom } from 'nanostores';
@@ -86,11 +88,9 @@ const setProgress = ($progressData: ProgressDataMap, data: S['InvocationProgress
export type ProgressDataMap = MapStore<Record<number, ProgressData | undefined>>;
type CanvasSessionContextValue = {
session: { id: string; type: 'simple' | 'advanced' };
$items: Atom<S['SessionQueueItem'][]>;
$itemCount: Atom<number>;
$hasItems: Atom<boolean>;
$isPending: Atom<boolean>;
$progressData: ProgressDataMap;
$selectedItemId: WritableAtom<number | null>;
$selectedItem: Atom<S['SessionQueueItem'] | null>;
@@ -100,472 +100,385 @@ type CanvasSessionContextValue = {
selectPrev: () => void;
selectFirst: () => void;
selectLast: () => void;
onImageLoad: (itemId: number) => void;
discard: (itemId: number) => void;
discardAll: () => void;
};
const CanvasSessionContext = createContext<CanvasSessionContextValue | null>(null);
export const CanvasSessionContextProvider = memo(
({ id, type, children }: PropsWithChildren<{ id: string; type: 'simple' | 'advanced' }>) => {
/**
* For best performance and interop with the Canvas, which is outside react but needs to interact with the react
* app, all canvas session state is packaged as nanostores atoms. The trickiest part is syncing the queue items
* with a nanostores atom.
*/
const session = useMemo(() => ({ type, id }), [type, id]);
export const CanvasSessionContextProvider = memo(({ children }: PropsWithChildren) => {
/**
* For best performance and interop with the Canvas, which is outside react but needs to interact with the react
* app, all canvas session state is packaged as nanostores atoms. The trickiest part is syncing the queue items
* with a nanostores atom.
*/
/**
* App store
*/
const store = useAppStore();
/**
* App store
*/
const store = useAppStore();
const socket = useStore($socket);
const sessionId = useAppSelector(selectCanvasSessionId);
/**
* Track the last completed item. Used to implement autoswitch.
*/
const $lastCompletedItemId = useState(() => atom<number | null>(null))[0];
const socket = useStore($socket);
/**
* Track the last started item. Used to implement autoswitch.
*/
const $lastStartedItemId = useState(() => atom<number | null>(null))[0];
/**
* Track the last completed item. Used to implement autoswitch.
*/
const $lastCompletedItemId = useState(() => atom<number | null>(null))[0];
/**
* Manually-synced atom containing queue items for the current session. This is populated from the RTK Query cache
* and kept in sync with it via a redux subscription.
*/
const $items = useState(() => atom<S['SessionQueueItem'][]>([]))[0];
/**
* Manually-synced atom containing queue items for the current session. This is populated from the RTK Query cache
* and kept in sync with it via a redux subscription.
*/
const $items = useState(() => atom<S['SessionQueueItem'][]>([]))[0];
/**
* An internal flag used to work around race conditions with auto-switch switching to queue items before their
* output images have fully loaded.
*/
const $lastLoadedItemId = useState(() => atom<number | null>(null))[0];
/**
* An ephemeral store of progress events and images for all items in the current session.
*/
const $progressData = useState(() => map<StoreValue<ProgressDataMap>>({}))[0];
/**
* An ephemeral store of progress events and images for all items in the current session.
*/
const $progressData = useState(() => map<StoreValue<ProgressDataMap>>({}))[0];
/**
* The currently selected queue item's ID, or null if one is not selected.
*/
const $selectedItemId = useState(() => atom<number | null>(null))[0];
/**
* The currently selected queue item's ID, or null if one is not selected.
*/
const $selectedItemId = useState(() => atom<number | null>(null))[0];
/**
* The number of items. Computed from the queue items array.
*/
const $itemCount = useState(() => computed([$items], (items) => items.length))[0];
/**
* The number of items. Computed from the queue items array.
*/
const $itemCount = useState(() => computed([$items], (items) => items.length))[0];
/**
* Whether there are any items. Computed from the queue items array.
*/
const $hasItems = useState(() => computed([$items], (items) => items.length > 0))[0];
/**
* Whether there are any items. Computed from the queue items array.
*/
const $hasItems = useState(() => computed([$items], (items) => items.length > 0))[0];
/**
* Whether there are any pending or in-progress items. Computed from the queue items array.
*/
const $isPending = useState(() =>
computed([$items], (items) => items.some((item) => item.status === 'pending' || item.status === 'in_progress'))
)[0];
/**
* Whether there are any pending or in-progress items. Computed from the queue items array.
*/
const $isPending = useState(() =>
computed([$items], (items) => items.some((item) => item.status === 'pending' || item.status === 'in_progress'))
)[0];
/**
* The currently selected queue item, or null if one is not selected.
*/
const $selectedItem = useState(() =>
computed([$items, $selectedItemId], (items, selectedItemId) => {
if (items.length === 0) {
return null;
}
if (selectedItemId === null) {
return null;
}
return items.find(({ item_id }) => item_id === selectedItemId) ?? null;
})
)[0];
/**
* The currently selected queue item, or null if one is not selected.
*/
const $selectedItem = useState(() =>
computed([$items, $selectedItemId], (items, selectedItemId) => {
if (items.length === 0) {
return null;
/**
* The currently selected queue item's index in the list of items, or null if one is not selected.
*/
const $selectedItemIndex = useState(() =>
computed([$items, $selectedItemId], (items, selectedItemId) => {
if (items.length === 0) {
return null;
}
if (selectedItemId === null) {
return null;
}
return items.findIndex(({ item_id }) => item_id === selectedItemId) ?? null;
})
)[0];
/**
* The currently selected queue item's output image name, or null if one is not selected or there is no output
* image recorded.
*/
const $selectedItemOutputImageDTO = useState(() =>
computed([$selectedItemId, $progressData], (selectedItemId, progressData) => {
if (selectedItemId === null) {
return null;
}
const datum = progressData[selectedItemId];
if (!datum) {
return null;
}
return datum.imageDTO;
})
)[0];
/**
* A redux selector to select all queue items from the RTK Query cache.
*/
const selectQueueItems = useMemo(() => buildSelectCanvasQueueItems(sessionId), [sessionId]);
const discard = useCallback(
(itemId: number) => {
store.dispatch(canvasQueueItemDiscarded({ itemId }));
},
[store]
);
const discardAll = useCallback(() => {
store.dispatch(canvasSessionReset());
}, [store]);
const selectNext = useCallback(() => {
const selectedItemId = $selectedItemId.get();
if (selectedItemId === null) {
return;
}
const items = $items.get();
const currentIndex = items.findIndex((item) => item.item_id === selectedItemId);
const nextIndex = (currentIndex + 1) % items.length;
const nextItem = items[nextIndex];
if (!nextItem) {
return;
}
$selectedItemId.set(nextItem.item_id);
}, [$items, $selectedItemId]);
const selectPrev = useCallback(() => {
const selectedItemId = $selectedItemId.get();
if (selectedItemId === null) {
return;
}
const items = $items.get();
const currentIndex = items.findIndex((item) => item.item_id === selectedItemId);
const prevIndex = (currentIndex - 1 + items.length) % items.length;
const prevItem = items[prevIndex];
if (!prevItem) {
return;
}
$selectedItemId.set(prevItem.item_id);
}, [$items, $selectedItemId]);
const selectFirst = useCallback(() => {
const items = $items.get();
const first = items.at(0);
if (!first) {
return;
}
$selectedItemId.set(first.item_id);
}, [$items, $selectedItemId]);
const selectLast = useCallback(() => {
const items = $items.get();
const last = items.at(-1);
if (!last) {
return;
}
$selectedItemId.set(last.item_id);
}, [$items, $selectedItemId]);
// Set up socket listeners
useEffect(() => {
if (!socket) {
return;
}
const onProgress = (data: S['InvocationProgressEvent']) => {
if (data.destination !== sessionId) {
return;
}
setProgress($progressData, data);
};
const onQueueItemStatusChanged = (data: S['QueueItemStatusChangedEvent']) => {
if (data.destination !== sessionId) {
return;
}
if (data.status === 'completed') {
/**
* There is an unpleasant bit of indirection here. When an item is completed, and auto-switch is set to
* switch_on_finish, we want to load the image and switch to it. In this socket handler, we don't have
* access to the full queue item, which we need to get the output image and load it. We get the full
* queue items as part of the list query, so it's rather inefficient to fetch it again here.
*
* To reduce the number of extra network requests, we instead store this item as the last completed item.
* Then in the progress data sync effect, we process the queue item load its image.
*/
$lastCompletedItemId.set(data.item_id);
}
if (data.status === 'in_progress' && selectStagingAreaAutoSwitch(store.getState()) === 'switch_on_start') {
$selectedItemId.set(data.item_id);
}
};
socket.on('invocation_progress', onProgress);
socket.on('queue_item_status_changed', onQueueItemStatusChanged);
return () => {
socket.off('invocation_progress', onProgress);
socket.off('queue_item_status_changed', onQueueItemStatusChanged);
};
}, [$progressData, $selectedItemId, sessionId, socket, $lastCompletedItemId, store]);
// Set up state subscriptions and effects
useEffect(() => {
let _prevItems: readonly S['SessionQueueItem'][] = [];
// Seed the $items atom with the initial query cache state
$items.set(selectQueueItems(store.getState()));
// Manually keep the $items atom in sync as the query cache is updated
const unsubReduxSyncToItemsAtom = store.subscribe(() => {
const prevItems = $items.get();
const items = selectQueueItems(store.getState());
if (items !== prevItems) {
_prevItems = prevItems;
$items.set(items);
}
});
// Handle cases that could result in a nonexistent queue item being selected.
const unsubEnsureSelectedItemIdExists = effect([$items, $selectedItemId], (items, selectedItemId) => {
if (items.length === 0) {
// If there are no items, cannot have a selected item.
$selectedItemId.set(null);
} else if (selectedItemId === null && items.length > 0) {
// If there is no selected item but there are items, select the first one.
$selectedItemId.set(items[0]?.item_id ?? null);
return;
} else if (selectedItemId !== null && items.findIndex(({ item_id }) => item_id === selectedItemId) === -1) {
// If an item is selected and it is not in the list of items, un-set it. This effect will run again and we'll
// the above case, selecting the first item if there are any.
let prevIndex = _prevItems.findIndex(({ item_id }) => item_id === selectedItemId);
if (prevIndex >= items.length) {
prevIndex = items.length - 1;
}
if (selectedItemId === null) {
return null;
}
return items.find(({ item_id }) => item_id === selectedItemId) ?? null;
})
)[0];
const nextItem = items[prevIndex];
$selectedItemId.set(nextItem?.item_id ?? null);
}
/**
* The currently selected queue item's index in the list of items, or null if one is not selected.
*/
const $selectedItemIndex = useState(() =>
computed([$items, $selectedItemId], (items, selectedItemId) => {
if (items.length === 0) {
return null;
}
if (selectedItemId === null) {
return null;
}
return items.findIndex(({ item_id }) => item_id === selectedItemId) ?? null;
})
)[0];
if (items !== _prevItems) {
_prevItems = items;
}
});
/**
* The currently selected queue item's output image name, or null if one is not selected or there is no output
* image recorded.
*/
const $selectedItemOutputImageDTO = useState(() =>
computed([$selectedItemId, $progressData], (selectedItemId, progressData) => {
if (selectedItemId === null) {
return null;
}
const datum = progressData[selectedItemId];
// Sync progress data - remove canceled/failed items, update progress data with new images, and load images
const unsubSyncProgressData = $items.subscribe(async (items) => {
const progressData = $progressData.get();
const toDelete: number[] = [];
const toUpdate: ProgressData[] = [];
for (const [id, datum] of objectEntries(progressData)) {
if (!datum) {
return null;
toDelete.push(id);
continue;
}
return datum.imageDTO;
})
)[0];
/**
* A redux selector to select all queue items from the RTK Query cache.
*/
const selectQueueItems = useMemo(() => buildSelectSessionQueueItems(session.id), [session.id]);
const discard = useCallback(
(itemId: number) => {
store.dispatch(canvasQueueItemDiscarded({ itemId }));
},
[store]
);
const discardAll = useCallback(() => {
store.dispatch(canvasSessionReset());
}, [store]);
const selectNext = useCallback(() => {
const selectedItemId = $selectedItemId.get();
if (selectedItemId === null) {
return;
}
const items = $items.get();
const currentIndex = items.findIndex((item) => item.item_id === selectedItemId);
const nextIndex = (currentIndex + 1) % items.length;
const nextItem = items[nextIndex];
if (!nextItem) {
return;
}
$selectedItemId.set(nextItem.item_id);
}, [$items, $selectedItemId]);
const selectPrev = useCallback(() => {
const selectedItemId = $selectedItemId.get();
if (selectedItemId === null) {
return;
}
const items = $items.get();
const currentIndex = items.findIndex((item) => item.item_id === selectedItemId);
const prevIndex = (currentIndex - 1 + items.length) % items.length;
const prevItem = items[prevIndex];
if (!prevItem) {
return;
}
$selectedItemId.set(prevItem.item_id);
}, [$items, $selectedItemId]);
const selectFirst = useCallback(() => {
const items = $items.get();
const first = items.at(0);
if (!first) {
return;
}
$selectedItemId.set(first.item_id);
}, [$items, $selectedItemId]);
const selectLast = useCallback(() => {
const items = $items.get();
const last = items.at(-1);
if (!last) {
return;
}
$selectedItemId.set(last.item_id);
}, [$items, $selectedItemId]);
const onImageLoad = useCallback(
(itemId: number) => {
const progressData = $progressData.get();
const current = progressData[itemId];
if (current) {
const next = { ...current, imageLoaded: true };
$progressData.setKey(itemId, next);
} else {
$progressData.setKey(itemId, {
...getInitialProgressData(itemId),
imageLoaded: true,
const item = items.find(({ item_id }) => item_id === datum.itemId);
if (!item) {
toDelete.push(datum.itemId);
} else if (item.status === 'canceled' || item.status === 'failed') {
toUpdate.push({
...datum,
progressEvent: null,
progressImage: null,
imageDTO: null,
});
}
if (
$lastCompletedItemId.get() === itemId &&
selectStagingAreaAutoSwitch(store.getState()) === 'switch_on_finish'
) {
$selectedItemId.set(itemId);
$lastCompletedItemId.set(null);
}
},
[$lastCompletedItemId, $progressData, $selectedItemId, store]
);
// Set up socket listeners
useEffect(() => {
if (!socket) {
return;
}
const onProgress = (data: S['InvocationProgressEvent']) => {
if (data.destination !== session.id) {
return;
for (const item of items) {
const datum = progressData[item.item_id];
if (datum?.imageDTO) {
continue;
}
setProgress($progressData, data);
};
const onQueueItemStatusChanged = (data: S['QueueItemStatusChangedEvent']) => {
if (data.destination !== session.id) {
return;
const outputImageName = getOutputImageName(item);
if (!outputImageName) {
continue;
}
if (data.status === 'completed') {
$lastCompletedItemId.set(data.item_id);
}
if (data.status === 'in_progress') {
$lastStartedItemId.set(data.item_id);
}
};
socket.on('invocation_progress', onProgress);
socket.on('queue_item_status_changed', onQueueItemStatusChanged);
return () => {
socket.off('invocation_progress', onProgress);
socket.off('queue_item_status_changed', onQueueItemStatusChanged);
};
}, [$lastCompletedItemId, $lastStartedItemId, $progressData, $selectedItemId, session.id, socket]);
// Set up state subscriptions and effects
useEffect(() => {
let _prevItems: readonly S['SessionQueueItem'][] = [];
// Seed the $items atom with the initial query cache state
$items.set(selectQueueItems(store.getState()));
// Manually keep the $items atom in sync as the query cache is updated
const unsubReduxSyncToItemsAtom = store.subscribe(() => {
const prevItems = $items.get();
const items = selectQueueItems(store.getState());
if (items !== prevItems) {
_prevItems = prevItems;
$items.set(items);
}
});
// Handle cases that could result in a nonexistent queue item being selected.
const unsubEnsureSelectedItemIdExists = effect(
[$items, $selectedItemId, $lastStartedItemId],
(items, selectedItemId, lastStartedItemId) => {
if (items.length === 0) {
// If there are no items, cannot have a selected item.
$selectedItemId.set(null);
} else if (selectedItemId === null && items.length > 0) {
// If there is no selected item but there are items, select the first one.
$selectedItemId.set(items[0]?.item_id ?? null);
return;
} else if (
selectStagingAreaAutoSwitch(store.getState()) === 'switch_on_start' &&
items.findIndex(({ item_id }) => item_id === lastStartedItemId) !== -1
) {
$selectedItemId.set(lastStartedItemId);
$lastStartedItemId.set(null);
} else if (selectedItemId !== null && items.findIndex(({ item_id }) => item_id === selectedItemId) === -1) {
// If an item is selected and it is not in the list of items, un-set it. This effect will run again and we'll
// the above case, selecting the first item if there are any.
let prevIndex = _prevItems.findIndex(({ item_id }) => item_id === selectedItemId);
if (prevIndex >= items.length) {
prevIndex = items.length - 1;
}
const nextItem = items[prevIndex];
$selectedItemId.set(nextItem?.item_id ?? null);
}
if (items !== _prevItems) {
_prevItems = items;
}
}
);
// Clean up the progress data when a queue item is discarded.
const unsubCleanUpProgressData = $items.subscribe(async (items) => {
const progressData = $progressData.get();
const toDelete: number[] = [];
const toUpdate: ProgressData[] = [];
for (const [id, datum] of objectEntries(progressData)) {
if (!datum) {
toDelete.push(id);
continue;
}
const item = items.find(({ item_id }) => item_id === datum.itemId);
if (!item) {
toDelete.push(datum.itemId);
} else if (item.status === 'canceled' || item.status === 'failed') {
toUpdate.push({
...datum,
progressEvent: null,
progressImage: null,
imageDTO: null,
});
}
const imageDTO = await getImageDTOSafe(outputImageName);
if (!imageDTO) {
continue;
}
for (const item of items) {
const datum = progressData[item.item_id];
if (datum) {
if (datum.imageDTO) {
continue;
}
const outputImageName = getOutputImageName(item);
if (!outputImageName) {
continue;
}
const imageDTO = await getImageDTOSafe(outputImageName);
if (!imageDTO) {
continue;
}
toUpdate.push({
...datum,
imageDTO,
});
} else {
const outputImageName = getOutputImageName(item);
if (!outputImageName) {
continue;
}
const imageDTO = await getImageDTOSafe(outputImageName);
if (!imageDTO) {
continue;
}
toUpdate.push({
...getInitialProgressData(item.item_id),
imageDTO,
});
}
// This is the load logic mentioned in the comment in the QueueItemStatusChangedEvent handler above.
if (
$lastCompletedItemId.get() === item.item_id &&
selectStagingAreaAutoSwitch(store.getState()) === 'switch_on_finish'
) {
loadImage(imageDTO.image_url, true).then(() => {
$selectedItemId.set(item.item_id);
$lastCompletedItemId.set(null);
});
}
for (const itemId of toDelete) {
$progressData.setKey(itemId, undefined);
}
toUpdate.push({
...getInitialProgressData(item.item_id),
...datum,
imageDTO,
});
}
for (const datum of toUpdate) {
$progressData.setKey(datum.itemId, datum);
}
});
for (const itemId of toDelete) {
$progressData.setKey(itemId, undefined);
}
// We only want to auto-switch to completed queue items once their images have fully loaded to prevent flashes
// of fallback content and/or progress images. The only surefire way to determine when images have fully loaded
// is via the image elements' `onLoad` callback. Images set `$lastLoadedItemId` to their queue item ID in their
// `onLoad` handler, and we listen for that here. If auto-switch is enabled, we then switch the to the item.
//
// TODO: This isn't perfect... we set $lastLoadedItemId in the mini preview component, but the full view
// component still needs to retrieve the image from the browser cache... can result in a flash of the progress
// image...
const unsubHandleAutoSwitch = $lastLoadedItemId.listen((lastLoadedItemId) => {
if (lastLoadedItemId === null) {
return;
}
if (selectStagingAreaAutoSwitch(store.getState()) === 'switch_on_finish') {
$selectedItemId.set(lastLoadedItemId);
}
$lastLoadedItemId.set(null);
});
for (const datum of toUpdate) {
$progressData.setKey(datum.itemId, datum);
}
});
// Create an RTK Query subscription. Without this, the query cache selector will never return anything bc RTK
// doesn't know we care about it.
const { unsubscribe: unsubQueueItemsQuery } = store.dispatch(
queueApi.endpoints.listAllQueueItems.initiate({ destination: session.id })
);
// const unsubListener = store.dispatch(
// addAppListener({
// matcher: queueApi.endpoints.cancelQueueItem.matchFulfilled,
// effect: ({ payload }, { getState }) => {
// const { item_id } = payload;
// const items = selectQueueItems(getState());
// if (items.length === 0) {
// $selectedItemId.set(null);
// } else if ($selectedItemId.get() === null) {
// $selectedItemId.set(items[0].item_id);
// }
// },
// })
// );
// Clean up all subscriptions and top-level (i.e. non-computed/derived state)
return () => {
unsubHandleAutoSwitch();
unsubQueueItemsQuery();
unsubReduxSyncToItemsAtom();
unsubEnsureSelectedItemIdExists();
unsubCleanUpProgressData();
$items.set([]);
$progressData.set({});
$selectedItemId.set(null);
};
}, [
$items,
$lastLoadedItemId,
$lastStartedItemId,
$progressData,
$selectedItemId,
selectQueueItems,
session.id,
store,
]);
const value = useMemo<CanvasSessionContextValue>(
() => ({
session,
$items,
$hasItems,
$isPending,
$progressData,
$selectedItemId,
$selectedItem,
$selectedItemIndex,
$selectedItemOutputImageDTO,
$itemCount,
selectNext,
selectPrev,
selectFirst,
selectLast,
onImageLoad,
discard,
discardAll,
}),
[
$items,
$hasItems,
$isPending,
$progressData,
$selectedItem,
$selectedItemId,
$selectedItemIndex,
session,
$selectedItemOutputImageDTO,
$itemCount,
selectNext,
selectPrev,
selectFirst,
selectLast,
onImageLoad,
discard,
discardAll,
]
// Create an RTK Query subscription. Without this, the query cache selector will never return anything bc RTK
// doesn't know we care about it.
const { unsubscribe: unsubQueueItemsQuery } = store.dispatch(
queueApi.endpoints.listAllQueueItems.initiate({ destination: sessionId })
);
return <CanvasSessionContext.Provider value={value}>{children}</CanvasSessionContext.Provider>;
}
);
// Clean up all subscriptions and top-level (i.e. non-computed/derived state)
return () => {
unsubQueueItemsQuery();
unsubReduxSyncToItemsAtom();
unsubEnsureSelectedItemIdExists();
unsubSyncProgressData();
$items.set([]);
$progressData.set({});
$selectedItemId.set(null);
};
}, [$items, $progressData, $selectedItemId, selectQueueItems, sessionId, store, $lastCompletedItemId]);
const value = useMemo<CanvasSessionContextValue>(
() => ({
$items,
$hasItems,
$isPending,
$progressData,
$selectedItemId,
$selectedItem,
$selectedItemIndex,
$selectedItemOutputImageDTO,
$itemCount,
selectNext,
selectPrev,
selectFirst,
selectLast,
discard,
discardAll,
}),
[
$items,
$hasItems,
$isPending,
$progressData,
$selectedItem,
$selectedItemId,
$selectedItemIndex,
$selectedItemOutputImageDTO,
$itemCount,
selectNext,
selectPrev,
selectFirst,
selectLast,
discard,
discardAll,
]
);
return <CanvasSessionContext.Provider value={value}>{children}</CanvasSessionContext.Provider>;
});
CanvasSessionContextProvider.displayName = 'CanvasSessionContextProvider';
export const useCanvasSessionContext = () => {

View File

@@ -5,7 +5,7 @@ import { useIsRegionFocused } from 'common/hooks/focus';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
import { canvasSessionReset } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { canvasSessionReset, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { selectBboxRect, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageNameToImageObject } from 'features/controlLayers/store/util';
@@ -19,6 +19,7 @@ export const StagingAreaToolbarAcceptButton = memo(() => {
const ctx = useCanvasSessionContext();
const dispatch = useAppDispatch();
const canvasManager = useCanvasManager();
const canvasSessionId = useAppSelector(selectCanvasSessionId);
const bboxRect = useAppSelector(selectBboxRect);
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
@@ -41,14 +42,14 @@ export const StagingAreaToolbarAcceptButton = memo(() => {
dispatch(rasterLayerAdded({ overrides, isSelected: selectedEntityIdentifier?.type === 'raster_layer' }));
dispatch(canvasSessionReset());
cancelQueueItemsByDestination.trigger(ctx.session.id, { withToast: false });
cancelQueueItemsByDestination.trigger(canvasSessionId, { withToast: false });
}, [
selectedItemImageDTO,
bboxRect,
dispatch,
selectedEntityIdentifier?.type,
cancelQueueItemsByDestination,
ctx.session.id,
canvasSessionId,
]);
useHotkeys(

View File

@@ -1,5 +1,7 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { useCancelQueueItemsByDestination } from 'features/queue/hooks/useCancelQueueItemsByDestination';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
@@ -9,11 +11,12 @@ export const StagingAreaToolbarDiscardAllButton = memo(({ isDisabled }: { isDisa
const ctx = useCanvasSessionContext();
const { t } = useTranslation();
const cancelQueueItemsByDestination = useCancelQueueItemsByDestination();
const canvasSessionId = useAppSelector(selectCanvasSessionId);
const discardAll = useCallback(() => {
ctx.discardAll();
cancelQueueItemsByDestination.trigger(ctx.session.id, { withToast: false });
}, [cancelQueueItemsByDestination, ctx]);
cancelQueueItemsByDestination.trigger(canvasSessionId, { withToast: false });
}, [cancelQueueItemsByDestination, ctx, canvasSessionId]);
return (
<IconButton

View File

@@ -1,10 +1,10 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { getFocusedRegion } from 'common/hooks/focus';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { entityDeleted } from 'features/controlLayers/store/canvasSlice';
import { selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { selectActiveTabCanvasRightPanel } from 'features/ui/store/uiSelectors';
import { useCallback } from 'react';
export function useCanvasDeleteLayerHotkey() {
@@ -12,14 +12,13 @@ export function useCanvasDeleteLayerHotkey() {
const dispatch = useAppDispatch();
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
const isBusy = useCanvasIsBusy();
const canvasRightPanelTab = useAppSelector(selectActiveTabCanvasRightPanel);
const deleteSelectedLayer = useCallback(() => {
if (selectedEntityIdentifier === null || isBusy || canvasRightPanelTab !== 'layers') {
if (selectedEntityIdentifier === null || isBusy || getFocusedRegion() !== 'layers') {
return;
}
dispatch(entityDeleted({ entityIdentifier: selectedEntityIdentifier }));
}, [canvasRightPanelTab, dispatch, isBusy, selectedEntityIdentifier]);
}, [dispatch, isBusy, selectedEntityIdentifier]);
useRegisteredHotkeys({
id: 'deleteSelected',

View File

@@ -16,7 +16,6 @@ import {
selectIsolatedLayerPreview,
selectIsolatedStagingPreview,
} from 'features/controlLayers/store/canvasSettingsSlice';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import {
buildSelectIsSelected,
getSelectIsTypeHidden,
@@ -283,7 +282,7 @@ export abstract class CanvasEntityAdapterBase<T extends CanvasEntityState, U ext
this.subscriptions.add(
this.manager.stateApi.createStoreSubscription(selectIsolatedStagingPreview, this.syncVisibility)
);
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectIsStaging, this.syncVisibility));
this.subscriptions.add(this.manager.stagingArea.$isStaging.listen(this.syncVisibility));
this.subscriptions.add(this.manager.stateApi.$filteringAdapter.listen(this.syncVisibility));
this.subscriptions.add(this.manager.stateApi.$transformingAdapter.listen(this.syncVisibility));
this.subscriptions.add(this.manager.stateApi.$segmentingAdapter.listen(this.syncVisibility));
@@ -462,7 +461,7 @@ export abstract class CanvasEntityAdapterBase<T extends CanvasEntityState, U ext
* This allows the user to easily see how the new generation fits in with the rest of the canvas without the
* other layer types getting in the way.
*/
const isStaging = this.manager.stateApi.runSelector(selectIsStaging);
const isStaging = this.manager.stagingArea.$isStaging.get();
const isRasterLayer = isRasterLayerEntityIdentifier(this.entityIdentifier);
if (isStaging && !isRasterLayer) {
this.setVisibility(false);

View File

@@ -4,12 +4,12 @@ import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import type { CanvasImageState } from 'features/controlLayers/store/types';
import Konva from 'konva';
import type { Atom } from 'nanostores';
import { atom, effect } from 'nanostores';
import type { Logger } from 'roarr';
import type { S } from 'services/api/types';
// To get pixel sizes corresponding to our theme tokens, first find the theme token CSS var in browser dev tools.
// For example `var(--invoke-space-8)` is equivalent to using `8` as a space prop in a component.
@@ -121,14 +121,12 @@ export class CanvasStagingAreaModule extends CanvasModuleBase {
this.image = null;
/**
* When we change this flag, we need to re-render the staging area, which hides or shows the staged image.
*/
this.subscriptions.add(this.$shouldShowStagedImage.listen(this.render));
/**
* Rerender when the image source changes.
* Rerender when the anything important changes.
*/
this.subscriptions.add(this.$imageSrc.listen(this.render));
this.subscriptions.add(this.$shouldShowStagedImage.listen(this.render));
this.subscriptions.add(this.$isPending.listen(this.render));
this.subscriptions.add(this.$isStaging.listen(this.render));
/**
* Sync the $isStaging flag with the redux state. $isStaging is used by the manager to determine the global busy
@@ -138,8 +136,7 @@ export class CanvasStagingAreaModule extends CanvasModuleBase {
* even if the user disabled this in the last staging session.
*/
this.subscriptions.add(
this.manager.stateApi.createStoreSubscription(selectIsStaging, (isStaging, oldIsStaging) => {
this.$isStaging.set(isStaging);
this.$isStaging.listen((isStaging, oldIsStaging) => {
if (isStaging && !oldIsStaging) {
this.$shouldShowStagedImage.set(true);
}
@@ -150,15 +147,17 @@ export class CanvasStagingAreaModule extends CanvasModuleBase {
initialize = () => {
this.log.debug('Initializing module');
this.render();
this.$isStaging.set(this.manager.stateApi.runSelector(selectIsStaging));
};
connectToSession = (
$items: Atom<S['SessionQueueItem'][]>,
$selectedItemId: Atom<number | null>,
$progressData: ProgressDataMap,
$isPending: Atom<boolean>
$progressData: ProgressDataMap
) => {
const cb = (selectedItemId: number | null, progressData: Record<number, ProgressData | undefined>) => {
const imageSrcListener = (
selectedItemId: number | null,
progressData: Record<number, ProgressData | undefined>
) => {
if (!selectedItemId) {
this.$imageSrc.set(null);
return;
@@ -176,20 +175,30 @@ export class CanvasStagingAreaModule extends CanvasModuleBase {
this.$imageSrc.set(null);
}
};
const unsubImageSrc = effect([$selectedItemId, $progressData], imageSrcListener);
// Run the effect & forcibly render once to initialize
cb($selectedItemId.get(), $progressData.get());
const isPendingListener = (items: S['SessionQueueItem'][]) => {
this.$isPending.set(items.some((item) => item.status === 'pending' || item.status === 'in_progress'));
};
const unsubIsPending = effect([$items], isPendingListener);
const isStagingListener = (items: S['SessionQueueItem'][]) => {
this.$isStaging.set(items.length > 0);
};
const unsubIsStaging = effect([$items], isStagingListener);
// Run the effects & forcibly render once to initialize
isStagingListener($items.get());
isPendingListener($items.get());
imageSrcListener($selectedItemId.get(), $progressData.get());
this.render();
// Sync the $isPending flag with the computed
const unsubIsPending = effect([$isPending], (isPending) => {
this.$isPending.set(isPending);
});
const unsubImageSrc = effect([$selectedItemId, $progressData], cb);
return () => {
this.$isStaging.set(false);
unsubIsStaging();
this.$isPending.set(false);
unsubIsPending();
this.$imageSrc.set(null);
unsubImageSrc();
};
};

View File

@@ -1,19 +1,21 @@
import { createSelector, createSlice, type PayloadAction } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants';
import type { PersistConfig, RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { deepClone } from 'common/util/deepClone';
import { canvasReset } from 'features/controlLayers/store/actions';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { useMemo } from 'react';
import { queueApi } from 'services/api/endpoints/queue';
type CanvasStagingAreaState = {
generateSessionId: string | null;
canvasSessionId: string | null;
_version: 1;
canvasSessionId: string;
canvasDiscardedQueueItems: number[];
};
const INITIAL_STATE: CanvasStagingAreaState = {
generateSessionId: null,
canvasSessionId: null,
_version: 1,
canvasSessionId: getPrefixedId('canvas'),
canvasDiscardedQueueItems: [],
};
@@ -23,46 +25,38 @@ export const canvasSessionSlice = createSlice({
name: 'canvasSession',
initialState: getInitialState(),
reducers: {
generateSessionIdChanged: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
state.generateSessionId = id;
},
generateSessionReset: (state) => {
state.generateSessionId = null;
},
canvasQueueItemDiscarded: (state, action: PayloadAction<{ itemId: number }>) => {
const { itemId } = action.payload;
if (!state.canvasDiscardedQueueItems.includes(itemId)) {
state.canvasDiscardedQueueItems.push(itemId);
}
},
canvasSessionIdChanged: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
state.canvasSessionId = id;
state.canvasDiscardedQueueItems = [];
canvasSessionReset: {
reducer: (state, action: PayloadAction<{ canvasSessionId: string }>) => {
const { canvasSessionId } = action.payload;
state.canvasSessionId = canvasSessionId;
state.canvasDiscardedQueueItems = [];
},
prepare: () => {
return {
payload: {
canvasSessionId: getPrefixedId('canvas'),
},
};
},
},
canvasSessionReset: (state) => {
state.canvasSessionId = null;
state.canvasDiscardedQueueItems = [];
},
},
extraReducers(builder) {
builder.addCase(canvasReset, (state) => {
state.canvasSessionId = null;
});
},
});
export const {
generateSessionIdChanged,
generateSessionReset,
canvasSessionIdChanged,
canvasSessionReset,
canvasQueueItemDiscarded,
} = canvasSessionSlice.actions;
export const { canvasSessionReset, canvasQueueItemDiscarded } = canvasSessionSlice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
state.canvasSessionId = state.canvasSessionId ?? getPrefixedId('canvas');
}
return state;
};
@@ -74,13 +68,14 @@ export const canvasStagingAreaPersistConfig: PersistConfig<CanvasStagingAreaStat
};
export const selectCanvasSessionSlice = (s: RootState) => s[canvasSessionSlice.name];
export const selectCanvasSessionId = createSelector(selectCanvasSessionSlice, ({ canvasSessionId }) => canvasSessionId);
export const selectGenerateSessionId = createSelector(
const selectDiscardedItems = createSelector(
selectCanvasSessionSlice,
({ generateSessionId }) => generateSessionId
({ canvasDiscardedQueueItems }) => canvasDiscardedQueueItems
);
export const buildSelectSessionQueueItems = (sessionId: string) =>
export const buildSelectCanvasQueueItems = (sessionId: string) =>
createSelector(
[queueApi.endpoints.listAllQueueItems.select({ destination: sessionId }), selectDiscardedItems],
({ data }, discardedItems) => {
@@ -93,21 +88,12 @@ export const buildSelectSessionQueueItems = (sessionId: string) =>
}
);
export const selectIsStaging = (state: RootState) => {
const sessionId = selectCanvasSessionId(state);
if (!sessionId) {
return false;
}
const { data } = queueApi.endpoints.listAllQueueItems.select({ destination: sessionId })(state);
if (!data) {
return false;
}
const discardedItems = selectDiscardedItems(state);
return data.some(
({ status, item_id }) => status !== 'canceled' && status !== 'failed' && !discardedItems.includes(item_id)
);
export const buildSelectIsStaging = (sessionId: string) =>
createSelector([buildSelectCanvasQueueItems(sessionId)], (queueItems) => {
return queueItems.length > 0;
});
export const useCanvasIsStaging = () => {
const sessionId = useAppSelector(selectCanvasSessionId);
const selector = useMemo(() => buildSelectIsStaging(sessionId), [sessionId]);
return useAppSelector(selector);
};
const selectDiscardedItems = createSelector(
selectCanvasSessionSlice,
({ canvasDiscardedQueueItems }) => canvasDiscardedQueueItems
);

View File

@@ -433,7 +433,7 @@ export type LoRA = {
export type EphemeralProgressImage = { sessionId: string; image: ProgressImage };
export const zAspectRatioID = z.enum(['Free', '21:9', '9:21', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);
export const zAspectRatioID = z.enum(['Free', '21:9', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16', '9:21']);
export type AspectRatioID = z.infer<typeof zAspectRatioID>;
export const isAspectRatioID = (v: unknown): v is AspectRatioID => zAspectRatioID.safeParse(v).success;
export const ASPECT_RATIO_MAP: Record<Exclude<AspectRatioID, 'Free'>, { ratio: number; inverseID: AspectRatioID }> = {
@@ -469,7 +469,7 @@ export const CHATGPT_ASPECT_RATIOS: Record<ChatGPT4oAspectRatio, Dimensions> = {
'2:3': { width: 1024, height: 1536 },
} as const;
export const zFluxKontextAspectRatioID = z.enum(['21:9', '4:3', '1:1', '3:4', '9:21', '16:9', '9:16']);
export const zFluxKontextAspectRatioID = z.enum(['21:9', '16:9', '4:3', '1:1', '3:4', '9:16', '9:21']);
type FluxKontextAspectRatio = z.infer<typeof zFluxKontextAspectRatioID>;
export const isFluxKontextAspectRatioID = (v: unknown): v is z.infer<typeof zFluxKontextAspectRatioID> =>
zFluxKontextAspectRatioID.safeParse(v).success;

View File

@@ -41,13 +41,13 @@ const zUploadFile = z
// )
.refine(
(file) => {
return ACCEPTED_IMAGE_TYPES.includes(file.type);
return ACCEPTED_IMAGE_TYPES.includes(file.type.toLowerCase());
},
{ message: `File type is not supported` }
)
.refine(
(file) => {
return ACCEPTED_FILE_EXTENSIONS.some((ext) => file.name.endsWith(ext));
return ACCEPTED_FILE_EXTENSIONS.some((ext) => file.name.toLowerCase().endsWith(ext));
},
{ message: `File extension is not supported` }
);

View File

@@ -1,7 +1,7 @@
import { combine } from '@atlaskit/pragmatic-drag-and-drop/combine';
import { draggable, monitorForElements } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';
import type { FlexProps, SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, Icon, Image } from '@invoke-ai/ui-library';
import { Flex, Icon, Image } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import type { AppDispatch, AppGetState } from 'app/store/store';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
@@ -23,13 +23,11 @@ import { imageToCompareChanged, selectGallerySlice, selectionChanged } from 'fea
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { VIEWER_PANEL_ID } from 'features/ui/layouts/shared';
import type { MouseEvent, MouseEventHandler } from 'react';
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { PiImageBold } from 'react-icons/pi';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
const GALLERY_IMAGE_CLASS = 'gallery-image';
const galleryImageContainerSX = {
containerType: 'inline-size',
w: 'full',
@@ -42,45 +40,42 @@ const galleryImageContainerSX = {
'&[data-is-dragging=true]': {
opacity: 0.3,
},
[`.${GALLERY_IMAGE_CLASS}`]: {
touchAction: 'none',
userSelect: 'none',
webkitUserSelect: 'none',
position: 'relative',
justifyContent: 'center',
alignItems: 'center',
aspectRatio: '1/1',
'::before': {
content: '""',
display: 'inline-block',
position: 'absolute',
top: 0,
left: 0,
right: 0,
bottom: 0,
pointerEvents: 'none',
borderRadius: 'base',
},
'&[data-selected=true]::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-500), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
},
'&[data-selected-for-compare=true]::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-300), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
},
'&:hover::before': {
boxShadow:
'inset 0px 0px 0px 1px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-800)',
},
'&:hover[data-selected=true]::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-400), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
},
'&:hover[data-selected-for-compare=true]::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-200), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
},
userSelect: 'none',
webkitUserSelect: 'none',
position: 'relative',
justifyContent: 'center',
alignItems: 'center',
aspectRatio: '1/1',
'::before': {
content: '""',
display: 'inline-block',
position: 'absolute',
top: 0,
left: 0,
right: 0,
bottom: 0,
pointerEvents: 'none',
borderRadius: 'base',
},
'&[data-selected=true]::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-500), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
},
'&[data-selected-for-compare=true]::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-300), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
},
'&:hover::before': {
boxShadow:
'inset 0px 0px 0px 1px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-800)',
},
'&:hover[data-selected=true]::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-400), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
},
'&:hover[data-selected-for-compare=true]::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-200), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
},
} satisfies SystemStyleObject;
@@ -142,8 +137,7 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
const [dragPreviewState, setDragPreviewState] = useState<
DndDragPreviewSingleImageState | DndDragPreviewMultipleImageState | null
>(null);
// Must use callback ref - else chakra's Image fallback prop will break the ref & dnd
const [element, ref] = useState<HTMLImageElement | null>(null);
const ref = useRef<HTMLDivElement>(null);
const selectIsSelectedForCompare = useMemo(
() => createSelector(selectGallerySlice, (gallery) => gallery.imageToCompare === imageDTO.image_name),
[imageDTO.image_name]
@@ -156,6 +150,7 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
const isSelected = useAppSelector(selectIsSelected);
useEffect(() => {
const element = ref.current;
if (!element) {
return;
}
@@ -221,7 +216,7 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
},
})
);
}, [element, imageDTO, store]);
}, [imageDTO, store]);
const [isHovered, setIsHovered] = useState(false);
@@ -240,34 +235,35 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
navigationApi.focusPanelInActiveTab(VIEWER_PANEL_ID);
}, [store]);
useImageContextMenu(imageDTO, element);
useImageContextMenu(imageDTO, ref);
return (
<>
<Box sx={galleryImageContainerSX} data-is-dragging={isDragging} data-image-name={imageDTO.image_name}>
<Flex
role="button"
className={GALLERY_IMAGE_CLASS}
onMouseOver={onMouseOver}
onMouseOut={onMouseOut}
onClick={onClick}
onDoubleClick={onDoubleClick}
data-selected={isSelected}
data-selected-for-compare={isSelectedForCompare}
>
<Image
ref={ref}
src={imageDTO.thumbnail_url}
w={imageDTO.width}
fallback={<GalleryImagePlaceholder />}
objectFit="contain"
maxW="full"
maxH="full"
borderRadius="base"
/>
<GalleryImageHoverIcons imageDTO={imageDTO} isHovered={isHovered} />
</Flex>
</Box>
<Flex
ref={ref}
sx={galleryImageContainerSX}
data-is-dragging={isDragging}
data-image-name={imageDTO.image_name}
role="button"
onMouseOver={onMouseOver}
onMouseOut={onMouseOut}
onClick={onClick}
onDoubleClick={onDoubleClick}
data-selected={isSelected}
data-selected-for-compare={isSelectedForCompare}
>
<Image
pointerEvents="none"
src={imageDTO.thumbnail_url}
w={imageDTO.width}
fallback={<GalleryImagePlaceholder />}
objectFit="contain"
maxW="full"
maxH="full"
borderRadius="base"
/>
<GalleryImageHoverIcons imageDTO={imageDTO} isHovered={isHovered} />
</Flex>
{dragPreviewState?.type === 'multiple-image' ? createMultipleImageDragPreview(dragPreviewState) : null}
{dragPreviewState?.type === 'single-image' ? createSingleImageDragPreview(dragPreviewState) : null}
</>

View File

@@ -1,5 +1,5 @@
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { useCanvasIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { MetadataHandlers, MetadataUtils } from 'features/metadata/parsing';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { useCallback, useMemo } from 'react';
@@ -12,7 +12,7 @@ export const useRecallAll = (imageDTO: ImageDTO) => {
const store = useAppStore();
const tab = useAppSelector(selectActiveTab);
const { metadata, isLoading } = useDebouncedMetadata(imageDTO.image_name);
const isStaging = useAppSelector(selectIsStaging);
const isStaging = useCanvasIsStaging();
const clearStylePreset = useClearStylePresetWithToast();
const isEnabled = useMemo(() => {

View File

@@ -1,5 +1,5 @@
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { useCanvasIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { MetadataUtils } from 'features/metadata/parsing';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { useCallback, useMemo } from 'react';
@@ -8,7 +8,7 @@ import type { ImageDTO } from 'services/api/types';
export const useRecallDimensions = (imageDTO: ImageDTO) => {
const store = useAppStore();
const tab = useAppSelector(selectActiveTab);
const isStaging = useAppSelector(selectIsStaging);
const isStaging = useCanvasIsStaging();
const isEnabled = useMemo(() => {
if (tab !== 'canvas' && tab !== 'generate') {

View File

@@ -1,5 +1,5 @@
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { useCanvasIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { MetadataHandlers, MetadataUtils } from 'features/metadata/parsing';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { useCallback, useMemo } from 'react';
@@ -11,7 +11,7 @@ import { useClearStylePresetWithToast } from './useClearStylePresetWithToast';
export const useRecallRemix = (imageDTO: ImageDTO) => {
const store = useAppStore();
const tab = useAppSelector(selectActiveTab);
const isStaging = useAppSelector(selectIsStaging);
const isStaging = useCanvasIsStaging();
const clearStylePreset = useClearStylePresetWithToast();
const { metadata, isLoading } = useDebouncedMetadata(imageDTO.image_name);

View File

@@ -34,6 +34,17 @@ const ModelImage = ({ image_url }: Props) => {
minHeight={MODEL_IMAGE_THUMBNAIL_SIZE}
minWidth={MODEL_IMAGE_THUMBNAIL_SIZE}
borderRadius="base"
fallback={
<Flex
height={MODEL_IMAGE_THUMBNAIL_SIZE}
minWidth={MODEL_IMAGE_THUMBNAIL_SIZE}
borderRadius="base"
alignItems="center"
justifyContent="center"
>
<Icon color="base.500" as={PiImage} boxSize={FALLBACK_ICON_SIZE} />
</Flex>
}
/>
);
};

View File

@@ -1,4 +1,5 @@
import { Box, IconButton, Image } from '@invoke-ai/ui-library';
import { dropzoneAccept } from 'common/hooks/useImageUploadButton';
import { typedMemo } from 'common/util/typedMemo';
import { toast } from 'features/toast/toast';
import { useCallback, useState } from 'react';
@@ -72,11 +73,7 @@ const ModelImageUpload = ({ model_key, model_image }: Props) => {
}, [model_key, t, deleteModelImage]);
const { getInputProps, getRootProps } = useDropzone({
accept: {
'image/png': ['.png'],
'image/jpeg': ['.jpg', '.jpeg', '.png'],
'image/webp': ['.webp'],
},
accept: dropzoneAccept,
onDropAccepted,
noDrag: true,
multiple: false,

View File

@@ -22,7 +22,6 @@ import { useConnection } from 'features/nodes/hooks/useConnection';
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
import { useNodeCopyPaste } from 'features/nodes/hooks/useNodeCopyPaste';
import { useSyncExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
import {
$addNodeCmdk,
$cursorPos,
@@ -83,23 +82,16 @@ export const Flow = memo(() => {
const nodes = useAppSelector(selectNodes);
const edges = useAppSelector(selectEdges);
const viewport = useStore($viewport);
const needsFit = useStore($needsFit);
const mayUndo = useAppSelector(selectMayUndo);
const mayRedo = useAppSelector(selectMayRedo);
const shouldSnapToGrid = useAppSelector(selectShouldSnapToGrid);
const selectionMode = useAppSelector(selectSelectionMode);
const { onConnectStart, onConnect, onConnectEnd } = useConnection();
const flowWrapper = useRef<HTMLDivElement>(null);
const isValidConnection = useIsValidConnection();
const cancelConnection = useReactFlowStore(selectCancelConnection);
const updateNodeInternals = useUpdateNodeInternals();
const store = useAppStore();
const isWorkflowsFocused = useIsRegionFocused('workflows');
const isLocked = useIsWorkflowEditorLocked();
useFocusRegion('workflows', flowWrapper);
useSyncExecutionState();
const [borderRadius] = useToken('radii', ['base']);
const flowStyles = useMemo<CSSProperties>(() => ({ borderRadius }), [borderRadius]);
@@ -110,12 +102,12 @@ export const Flow = memo(() => {
if (!flow) {
return;
}
if (needsFit) {
if ($needsFit.get()) {
$needsFit.set(false);
flow.fitView();
}
},
[dispatch, needsFit]
[dispatch]
);
const onEdgesChange: OnEdgesChange<AnyEdge> = useCallback(
@@ -214,6 +206,83 @@ export const Flow = memo(() => {
// #endregion
const onNodeClick = useCallback<NodeMouseHandler<AnyNode>>((e, node) => {
if (!$isSelectingOutputNode.get()) {
return;
}
if (!isInvocationNode(node)) {
return;
}
const { id } = node.data;
$outputNodeId.set(id);
$isSelectingOutputNode.set(false);
}, []);
return (
<>
<ReactFlow<AnyNode, AnyEdge>
id="workflow-editor"
ref={flowWrapper}
defaultViewport={viewport}
nodeTypes={nodeTypes}
edgeTypes={edgeTypes}
nodes={nodes}
edges={edges}
onInit={onInit}
onNodeClick={onNodeClick}
onMouseMove={onMouseMove}
onNodesChange={onNodesChange}
onEdgesChange={onEdgesChange}
onReconnect={onReconnect}
onReconnectStart={onReconnectStart}
onReconnectEnd={onReconnectEnd}
onConnectStart={onConnectStart}
onConnect={onConnect}
onConnectEnd={onConnectEnd}
onMoveEnd={handleMoveEnd}
connectionLineComponent={CustomConnectionLine}
isValidConnection={isValidConnection}
edgesFocusable={!isLocked}
edgesReconnectable={!isLocked}
nodesDraggable={!isLocked}
nodesConnectable={!isLocked}
nodesFocusable={!isLocked}
elementsSelectable={!isLocked}
minZoom={0.1}
snapToGrid={shouldSnapToGrid}
snapGrid={snapGrid}
connectionRadius={30}
proOptions={proOptions}
style={flowStyles}
onPaneClick={handlePaneClick}
deleteKeyCode={null}
selectionMode={selectionMode}
elevateEdgesOnSelect
nodeDragThreshold={1}
noDragClassName={NO_DRAG_CLASS}
noWheelClassName={NO_WHEEL_CLASS}
noPanClassName={NO_PAN_CLASS}
>
<Background />
</ReactFlow>
<HotkeyIsolator />
</>
);
});
Flow.displayName = 'Flow';
const HotkeyIsolator = memo(() => {
const isLocked = useIsWorkflowEditorLocked();
const mayUndo = useAppSelector(selectMayUndo);
const mayRedo = useAppSelector(selectMayRedo);
const cancelConnection = useReactFlowStore(selectCancelConnection);
const store = useAppStore();
const isWorkflowsFocused = useIsRegionFocused('workflows');
const { copySelection, pasteSelection, pasteSelectionWithEdges } = useNodeCopyPaste();
useRegisteredHotkeys({
@@ -239,12 +308,12 @@ export const Flow = memo(() => {
}
});
if (nodeChanges.length > 0) {
dispatch(nodesChanged(nodeChanges));
store.dispatch(nodesChanged(nodeChanges));
}
if (edgeChanges.length > 0) {
dispatch(edgesChanged(edgeChanges));
store.dispatch(edgesChanged(edgeChanges));
}
}, [dispatch, store]);
}, [store]);
useRegisteredHotkeys({
id: 'selectAll',
category: 'workflows',
@@ -273,20 +342,20 @@ export const Flow = memo(() => {
id: 'undo',
category: 'workflows',
callback: () => {
dispatch(undo());
store.dispatch(undo());
},
options: { enabled: isWorkflowsFocused && !isLocked && mayUndo, preventDefault: true },
dependencies: [mayUndo, isLocked, isWorkflowsFocused],
dependencies: [store, mayUndo, isLocked, isWorkflowsFocused],
});
useRegisteredHotkeys({
id: 'redo',
category: 'workflows',
callback: () => {
dispatch(redo());
store.dispatch(redo());
},
options: { enabled: isWorkflowsFocused && !isLocked && mayRedo, preventDefault: true },
dependencies: [mayRedo, isLocked, isWorkflowsFocused],
dependencies: [store, mayRedo, isLocked, isWorkflowsFocused],
});
const onEscapeHotkey = useCallback(() => {
@@ -313,12 +382,12 @@ export const Flow = memo(() => {
edgeChanges.push({ type: 'remove', id });
});
if (nodeChanges.length > 0) {
dispatch(nodesChanged(nodeChanges));
store.dispatch(nodesChanged(nodeChanges));
}
if (edgeChanges.length > 0) {
dispatch(edgesChanged(edgeChanges));
store.dispatch(edgesChanged(edgeChanges));
}
}, [dispatch, store]);
}, [store]);
useRegisteredHotkeys({
id: 'deleteSelection',
category: 'workflows',
@@ -327,65 +396,6 @@ export const Flow = memo(() => {
dependencies: [deleteSelection, isWorkflowsFocused, isLocked],
});
const onNodeClick = useCallback<NodeMouseHandler<AnyNode>>((e, node) => {
if (!$isSelectingOutputNode.get()) {
return;
}
if (!isInvocationNode(node)) {
return;
}
const { id } = node.data;
$outputNodeId.set(id);
$isSelectingOutputNode.set(false);
}, []);
return (
<ReactFlow<AnyNode, AnyEdge>
id="workflow-editor"
ref={flowWrapper}
defaultViewport={viewport}
nodeTypes={nodeTypes}
edgeTypes={edgeTypes}
nodes={nodes}
edges={edges}
onInit={onInit}
onNodeClick={onNodeClick}
onMouseMove={onMouseMove}
onNodesChange={onNodesChange}
onEdgesChange={onEdgesChange}
onReconnect={onReconnect}
onReconnectStart={onReconnectStart}
onReconnectEnd={onReconnectEnd}
onConnectStart={onConnectStart}
onConnect={onConnect}
onConnectEnd={onConnectEnd}
onMoveEnd={handleMoveEnd}
connectionLineComponent={CustomConnectionLine}
isValidConnection={isValidConnection}
edgesFocusable={!isLocked}
edgesReconnectable={!isLocked}
nodesDraggable={!isLocked}
nodesConnectable={!isLocked}
nodesFocusable={!isLocked}
elementsSelectable={!isLocked}
minZoom={0.1}
snapToGrid={shouldSnapToGrid}
snapGrid={snapGrid}
connectionRadius={30}
proOptions={proOptions}
style={flowStyles}
onPaneClick={handlePaneClick}
deleteKeyCode={null}
selectionMode={selectionMode}
elevateEdgesOnSelect
nodeDragThreshold={1}
noDragClassName={NO_DRAG_CLASS}
noWheelClassName={NO_WHEEL_CLASS}
noPanClassName={NO_PAN_CLASS}
>
<Background />
</ReactFlow>
);
return null;
});
Flow.displayName = 'Flow';
HotkeyIsolator.displayName = 'HotkeyIsolator';

View File

@@ -1,6 +1,6 @@
import { createSelector } from '@reduxjs/toolkit';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import { selectNodes } from 'features/nodes/store/selectors';
import type { Templates } from 'features/nodes/store/types';
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { isInvocationNode } from 'features/nodes/types/invocation';
@@ -8,9 +8,9 @@ import { isInvocationNode } from 'features/nodes/types/invocation';
import { getFieldColor } from './getEdgeColor';
export const buildSelectAreConnectedNodesSelected = (source: string, target: string) =>
createSelector(selectNodesSlice, (nodes): boolean => {
const sourceNode = nodes.nodes.find((node) => node.id === source);
const targetNode = nodes.nodes.find((node) => node.id === target);
createSelector(selectNodes, (nodes): boolean => {
const sourceNode = nodes.find((node) => node.id === source);
const targetNode = nodes.find((node) => node.id === target);
return Boolean(sourceNode?.selected || targetNode?.selected);
});
@@ -22,10 +22,13 @@ export const buildSelectEdgeColor = (
target: string,
targetHandleId: string | null | undefined
) =>
createSelector(selectNodesSlice, selectWorkflowSettingsSlice, (nodes, workflowSettings): string => {
createSelector(selectNodes, selectWorkflowSettingsSlice, (nodes, workflowSettings): string => {
const { shouldColorEdges } = workflowSettings;
const sourceNode = nodes.nodes.find((node) => node.id === source);
const targetNode = nodes.nodes.find((node) => node.id === target);
if (!shouldColorEdges) {
return colorTokenToCssVar('base.500');
}
const sourceNode = nodes.find((node) => node.id === source);
const targetNode = nodes.find((node) => node.id === target);
if (!sourceNode || !sourceHandleId || !targetNode || !targetHandleId) {
return colorTokenToCssVar('base.500');
@@ -37,7 +40,7 @@ export const buildSelectEdgeColor = (
const outputFieldTemplate = sourceNodeTemplate?.outputs[sourceHandleId];
const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;
return sourceType && shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
return sourceType ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
});
export const buildSelectEdgeLabel = (
@@ -47,9 +50,9 @@ export const buildSelectEdgeLabel = (
target: string,
targetHandleId: string | null | undefined
) =>
createSelector(selectNodesSlice, (nodes): string | null => {
const sourceNode = nodes.nodes.find((node) => node.id === source);
const targetNode = nodes.nodes.find((node) => node.id === target);
createSelector(selectNodes, (nodes): string | null => {
const sourceNode = nodes.find((node) => node.id === source);
const targetNode = nodes.find((node) => node.id === target);
if (!sourceNode || !sourceHandleId || !targetNode || !targetHandleId) {
return null;

View File

@@ -37,7 +37,7 @@ const sx: SystemStyleObject = {
};
const InvocationNode = ({ nodeId, isOpen }: Props) => {
const withFooter = useWithFooter(nodeId);
const withFooter = useWithFooter();
return (
<>
@@ -64,7 +64,7 @@ const InvocationNode = ({ nodeId, isOpen }: Props) => {
export default memo(InvocationNode);
const ConnectionFields = memo(({ nodeId }: { nodeId: string }) => {
const fieldNames = useInputFieldNamesConnection(nodeId);
const fieldNames = useInputFieldNamesConnection();
return (
<>
{fieldNames.map((fieldName, i) => (
@@ -80,7 +80,7 @@ const ConnectionFields = memo(({ nodeId }: { nodeId: string }) => {
ConnectionFields.displayName = 'ConnectionFields';
const AnyOrDirectFields = memo(({ nodeId }: { nodeId: string }) => {
const fieldNames = useInputFieldNamesAnyOrDirect(nodeId);
const fieldNames = useInputFieldNamesAnyOrDirect();
return (
<>
{fieldNames.map((fieldName) => (
@@ -94,7 +94,7 @@ const AnyOrDirectFields = memo(({ nodeId }: { nodeId: string }) => {
AnyOrDirectFields.displayName = 'AnyOrDirectFields';
const MissingFields = memo(({ nodeId }: { nodeId: string }) => {
const fieldNames = useInputFieldNamesMissing(nodeId);
const fieldNames = useInputFieldNamesMissing();
return (
<>
{fieldNames.map((fieldName) => (
@@ -108,7 +108,7 @@ const MissingFields = memo(({ nodeId }: { nodeId: string }) => {
MissingFields.displayName = 'MissingFields';
const OutputFields = memo(({ nodeId }: { nodeId: string }) => {
const fieldNames = useOutputFieldNames(nodeId);
const fieldNames = useOutputFieldNames();
return (
<>
{fieldNames.map((fieldName, i) => (

View File

@@ -9,8 +9,8 @@ interface Props {
nodeId: string;
}
const InvocationNodeClassificationIcon = ({ nodeId }: Props) => {
const classification = useNodeClassification(nodeId);
const InvocationNodeClassificationIcon = (_: Props) => {
const classification = useNodeClassification();
if (!classification || classification === 'stable') {
return null;

View File

@@ -19,7 +19,7 @@ const collapsedHandleStyles: CSSProperties = {
};
const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => {
const template = useNodeTemplateOrThrow(nodeId);
const template = useNodeTemplateOrThrow();
if (!template) {
return null;

View File

@@ -16,8 +16,8 @@ type Props = {
const props: ChakraProps = { w: 'unset' };
const InvocationNodeFooter = ({ nodeId }: Props) => {
const hasImageOutput = useNodeHasImageOutput(nodeId);
const isExecutableNode = useIsExecutableNode(nodeId);
const hasImageOutput = useNodeHasImageOutput();
const isExecutableNode = useIsExecutableNode();
const isCacheEnabled = useFeatureStatus('invocationCache');
return (
<Flex

View File

@@ -3,7 +3,7 @@ import { Flex } from '@invoke-ai/ui-library';
import NodeCollapseButton from 'features/nodes/components/flow/nodes/common/NodeCollapseButton';
import NodeTitle from 'features/nodes/components/flow/nodes/common/NodeTitle';
import InvocationNodeClassificationIcon from 'features/nodes/components/flow/nodes/Invocation/InvocationNodeClassificationIcon';
import { useNodeIsInvalid } from 'features/nodes/hooks/useNodeIsInvalid';
import { useNodeHasErrors } from 'features/nodes/hooks/useNodeIsInvalid';
import { memo } from 'react';
import InvocationNodeCollapsedHandles from './InvocationNodeCollapsedHandles';
@@ -16,26 +16,23 @@ type Props = {
};
const sx: SystemStyleObject = {
bg: 'var(--header-bg-color)',
borderTopRadius: 'base',
alignItems: 'center',
justifyContent: 'space-between',
h: 8,
textAlign: 'center',
color: 'base.200',
borderBottomRadius: 'base',
'&[data-is-open="true"]': {
borderBottomRadius: 0,
},
'&[data-is-invalid="true"]': {
color: 'error.300',
},
};
const InvocationNodeHeader = ({ nodeId, isOpen }: Props) => {
const isInvalid = useNodeIsInvalid(nodeId);
const isInvalid = useNodeHasErrors();
return (
<Flex layerStyle="nodeHeader" sx={sx} data-is-open={isOpen} data-is-invalid={isInvalid}>
<Flex sx={sx} data-is-open={isOpen} data-is-invalid={isInvalid}>
<NodeCollapseButton nodeId={nodeId} isOpen={isOpen} />
<InvocationNodeClassificationIcon nodeId={nodeId} />
<NodeTitle nodeId={nodeId} />

View File

@@ -1,6 +1,5 @@
import { Flex, Icon, Text, Tooltip } from '@invoke-ai/ui-library';
import { compare } from 'compare-versions';
import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate';
import { useInvocationNodeNotes } from 'features/nodes/hooks/useNodeNotes';
import { useNodeTemplateOrThrow } from 'features/nodes/hooks/useNodeTemplateOrThrow';
import { useNodeUserTitleSafe } from 'features/nodes/hooks/useNodeUserTitleSafe';
@@ -14,22 +13,20 @@ interface Props {
}
export const InvocationNodeInfoIcon = memo(({ nodeId }: Props) => {
const needsUpdate = useNodeNeedsUpdate(nodeId);
return (
<Tooltip label={<TooltipContent nodeId={nodeId} />} placement="top" shouldWrapChildren>
<Icon as={PiInfoBold} display="block" boxSize={4} w={8} color={needsUpdate ? 'error.400' : 'base.400'} />
<Icon as={PiInfoBold} display="block" boxSize={4} w={8} />
</Tooltip>
);
});
InvocationNodeInfoIcon.displayName = 'InvocationNodeInfoIcon';
const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
const notes = useInvocationNodeNotes(nodeId);
const label = useNodeUserTitleSafe(nodeId);
const version = useNodeVersion(nodeId);
const nodeTemplate = useNodeTemplateOrThrow(nodeId);
const TooltipContent = memo((_: { nodeId: string }) => {
const notes = useInvocationNodeNotes();
const label = useNodeUserTitleSafe();
const version = useNodeVersion();
const nodeTemplate = useNodeTemplateOrThrow();
const { t } = useTranslation();
const title = useMemo(() => {

View File

@@ -13,7 +13,7 @@ type Props = {
export const InvocationNodeNotesTextarea = memo(({ nodeId }: Props) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const notes = useInvocationNodeNotes(nodeId);
const notes = useInvocationNodeNotes();
const handleNotesChanged = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(nodeNotesChanged({ nodeId, notes: e.target.value }));

View File

@@ -14,7 +14,7 @@ type Props = {
const InvocationNodeUnknownFallback = ({ nodeId, isOpen, label, type }: Props) => {
const { t } = useTranslation();
const nodePack = useNodePack(nodeId);
const nodePack = useNodePack();
return (
<>
<Flex
@@ -26,9 +26,10 @@ const InvocationNodeUnknownFallback = ({ nodeId, isOpen, label, type }: Props) =
h={8}
fontWeight="semibold"
fontSize="sm"
bg="error.700"
>
<NodeCollapseButton nodeId={nodeId} isOpen={isOpen} />
<Text w="full" textAlign="center" pe={8} color="error.300">
<Text w="full" textAlign="center" pe={8}>
{label ? `${label} (${type})` : type}
</Text>
</Flex>

View File

@@ -9,6 +9,7 @@ import { selectNodes } from 'features/nodes/store/selectors';
import type { InvocationNodeData } from 'features/nodes/types/invocation';
import { memo, useMemo } from 'react';
import { InvocationNodeContextProvider } from './context';
import InvocationNodeUnknownFallback from './InvocationNodeUnknownFallback';
const InvocationNodeWrapper = (props: NodeProps<Node<InvocationNodeData>>) => {
@@ -28,16 +29,20 @@ const InvocationNodeWrapper = (props: NodeProps<Node<InvocationNodeData>>) => {
if (!hasTemplate) {
return (
<NodeWrapper nodeId={nodeId} selected={selected}>
<InvocationNodeUnknownFallback nodeId={nodeId} isOpen={isOpen} label={label} type={type} />
</NodeWrapper>
<InvocationNodeContextProvider nodeId={nodeId}>
<NodeWrapper nodeId={nodeId} selected={selected} isMissingTemplate>
<InvocationNodeUnknownFallback nodeId={nodeId} isOpen={isOpen} label={label} type={type} />
</NodeWrapper>
</InvocationNodeContextProvider>
);
}
return (
<NodeWrapper nodeId={nodeId} selected={selected}>
<InvocationNode nodeId={nodeId} isOpen={isOpen} />
</NodeWrapper>
<InvocationNodeContextProvider nodeId={nodeId}>
<NodeWrapper nodeId={nodeId} selected={selected}>
<InvocationNode nodeId={nodeId} isOpen={isOpen} />
</NodeWrapper>
</InvocationNodeContextProvider>
);
};

View File

@@ -11,8 +11,8 @@ import { useTranslation } from 'react-i18next';
const SaveToGalleryCheckbox = ({ nodeId }: { nodeId: string }) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const hasImageOutput = useNodeHasImageOutput(nodeId);
const isIntermediate = useNodeIsIntermediate(nodeId);
const hasImageOutput = useNodeHasImageOutput();
const isIntermediate = useNodeIsIntermediate();
const handleChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(

View File

@@ -9,7 +9,7 @@ import { useTranslation } from 'react-i18next';
const UseCacheCheckbox = ({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const useCache = useUseCache(nodeId);
const useCache = useUseCache();
const handleChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(

View File

@@ -0,0 +1,230 @@
import { useStore } from '@nanostores/react';
import type { Selector } from '@reduxjs/toolkit';
import { createSelector } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import { $templates } from 'features/nodes/store/nodesSlice';
import { selectEdges, selectNodes } from 'features/nodes/store/selectors';
import type { InvocationNode, InvocationTemplate } from 'features/nodes/types/invocation';
import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate';
import type { PropsWithChildren } from 'react';
import { createContext, memo, useContext, useMemo } from 'react';
type InvocationNodeContextValue = {
nodeId: string;
selectNodeSafe: Selector<RootState, InvocationNode | null>;
selectNodeDataSafe: Selector<RootState, InvocationNode['data'] | null>;
selectNodeTypeSafe: Selector<RootState, string | null>;
selectNodeTemplateSafe: Selector<RootState, InvocationTemplate | null>;
selectNodeInputsSafe: Selector<RootState, InvocationNode['data']['inputs'] | null>;
buildSelectInputFieldSafe: (
fieldName: string
) => Selector<RootState, InvocationNode['data']['inputs'][string] | null>;
buildSelectInputFieldTemplateSafe: (
fieldName: string
) => Selector<RootState, InvocationTemplate['inputs'][string] | null>;
buildSelectOutputFieldTemplateSafe: (
fieldName: string
) => Selector<RootState, InvocationTemplate['outputs'][string] | null>;
selectNodeOrThrow: Selector<RootState, InvocationNode>;
selectNodeDataOrThrow: Selector<RootState, InvocationNode['data']>;
selectNodeTypeOrThrow: Selector<RootState, string>;
selectNodeTemplateOrThrow: Selector<RootState, InvocationTemplate>;
selectNodeInputsOrThrow: Selector<RootState, InvocationNode['data']['inputs']>;
buildSelectInputFieldOrThrow: (fieldName: string) => Selector<RootState, InvocationNode['data']['inputs'][string]>;
buildSelectInputFieldTemplateOrThrow: (
fieldName: string
) => Selector<RootState, InvocationTemplate['inputs'][string]>;
buildSelectOutputFieldTemplateOrThrow: (
fieldName: string
) => Selector<RootState, InvocationTemplate['outputs'][string]>;
buildSelectIsInputFieldConnected: (fieldName: string) => Selector<RootState, boolean>;
selectNodeNeedsUpdate: Selector<RootState, boolean>;
};
const InvocationNodeContext = createContext<InvocationNodeContextValue | null>(null);
const getSelectorFromCache = <T extends Selector>(cache: Map<string, Selector>, key: string, fallback: () => T): T => {
let selector = cache.get(key);
if (!selector) {
selector = fallback();
cache.set(key, selector);
}
return selector as T;
};
export const InvocationNodeContextProvider = memo(({ nodeId, children }: PropsWithChildren<{ nodeId: string }>) => {
const templates = useStore($templates);
const value = useMemo(() => {
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const cache: Map<string, Selector<RootState, any>> = new Map();
const selectNodeSafe = getSelectorFromCache(cache, 'selectNodeSafe', () =>
createSelector(selectNodes, (nodes) => {
return (nodes.find(({ id, type }) => type === 'invocation' && id === nodeId) ?? null) as InvocationNode | null;
})
);
const selectNodeDataSafe = getSelectorFromCache(cache, 'selectNodeDataSafe', () =>
createSelector(selectNodeSafe, (node) => {
return node?.data ?? null;
})
);
const selectNodeTypeSafe = getSelectorFromCache(cache, 'selectNodeTypeSafe', () =>
createSelector(selectNodeDataSafe, (data) => {
return data?.type ?? null;
})
);
const selectNodeTemplateSafe = getSelectorFromCache(cache, 'selectNodeTemplateSafe', () =>
createSelector(selectNodeTypeSafe, (type) => {
return type ? (templates[type] ?? null) : null;
})
);
const selectNodeInputsSafe = getSelectorFromCache(cache, 'selectNodeInputsSafe', () =>
createSelector(selectNodeDataSafe, (data) => {
return data?.inputs ?? null;
})
);
const buildSelectInputFieldSafe = (fieldName: string) =>
getSelectorFromCache(cache, `buildSelectInputFieldSafe-${fieldName}`, () =>
createSelector(selectNodeInputsSafe, (inputs) => {
return inputs?.[fieldName] ?? null;
})
);
const buildSelectInputFieldTemplateSafe = (fieldName: string) =>
getSelectorFromCache(cache, `buildSelectInputFieldTemplateSafe-${fieldName}`, () =>
createSelector(selectNodeTemplateSafe, (template) => {
return template?.inputs?.[fieldName] ?? null;
})
);
const buildSelectOutputFieldTemplateSafe = (fieldName: string) =>
getSelectorFromCache(cache, `buildSelectOutputFieldTemplateSafe-${fieldName}`, () =>
createSelector(selectNodeTemplateSafe, (template) => {
return template?.outputs?.[fieldName] ?? null;
})
);
const selectNodeOrThrow = getSelectorFromCache(cache, 'selectNodeOrThrow', () =>
createSelector(selectNodes, (nodes) => {
const node = nodes.find(({ id, type }) => type === 'invocation' && id === nodeId) as InvocationNode | undefined;
if (node === undefined) {
throw new Error(`Cannot find node with id ${nodeId}`);
}
return node;
})
);
const selectNodeDataOrThrow = getSelectorFromCache(cache, 'selectNodeDataOrThrow', () =>
createSelector(selectNodeOrThrow, (node) => {
return node.data;
})
);
const selectNodeTypeOrThrow = getSelectorFromCache(cache, 'selectNodeTypeOrThrow', () =>
createSelector(selectNodeDataOrThrow, (data) => {
return data.type;
})
);
const selectNodeTemplateOrThrow = getSelectorFromCache(cache, 'selectNodeTemplateOrThrow', () =>
createSelector(selectNodeTypeOrThrow, (type) => {
const template = templates[type];
if (template === undefined) {
throw new Error(`Cannot find template for node with id ${nodeId} with type ${type}`);
}
return template;
})
);
const selectNodeInputsOrThrow = getSelectorFromCache(cache, 'selectNodeInputsOrThrow', () =>
createSelector(selectNodeDataOrThrow, (data) => {
return data.inputs;
})
);
const buildSelectInputFieldOrThrow = (fieldName: string) =>
getSelectorFromCache(cache, `buildSelectInputFieldOrThrow-${fieldName}`, () =>
createSelector(selectNodeInputsOrThrow, (inputs) => {
const field = inputs[fieldName];
if (field === undefined) {
throw new Error(`Cannot find input field with name ${fieldName} in node ${nodeId}`);
}
return field;
})
);
const buildSelectInputFieldTemplateOrThrow = (fieldName: string) =>
getSelectorFromCache(cache, `buildSelectInputFieldTemplateOrThrow-${fieldName}`, () =>
createSelector(selectNodeTemplateOrThrow, (template) => {
const fieldTemplate = template.inputs[fieldName];
if (fieldTemplate === undefined) {
throw new Error(`Cannot find input field template with name ${fieldName} in node ${nodeId}`);
}
return fieldTemplate;
})
);
const buildSelectOutputFieldTemplateOrThrow = (fieldName: string) =>
getSelectorFromCache(cache, `buildSelectOutputFieldTemplateOrThrow-${fieldName}`, () =>
createSelector(selectNodeTemplateOrThrow, (template) => {
const fieldTemplate = template.outputs[fieldName];
if (fieldTemplate === undefined) {
throw new Error(`Cannot find output field template with name ${fieldName} in node ${nodeId}`);
}
return fieldTemplate;
})
);
const buildSelectIsInputFieldConnected = (fieldName: string) =>
getSelectorFromCache(cache, `buildSelectIsInputFieldConnected-${fieldName}`, () =>
createSelector(selectEdges, (edges) => {
return edges.some((edge) => {
return edge.target === nodeId && edge.targetHandle === fieldName;
});
})
);
const selectNodeNeedsUpdate = getSelectorFromCache(cache, 'selectNodeNeedsUpdate', () =>
createSelector([selectNodeDataSafe, selectNodeTemplateSafe], (data, template) => {
if (!data || !template) {
return false; // If there's no data or template, no update is possible
}
return getNeedsUpdate(data, template);
})
);
return {
nodeId,
selectNodeSafe,
selectNodeDataSafe,
selectNodeTypeSafe,
selectNodeTemplateSafe,
selectNodeInputsSafe,
buildSelectInputFieldSafe,
buildSelectInputFieldTemplateSafe,
buildSelectOutputFieldTemplateSafe,
selectNodeOrThrow,
selectNodeDataOrThrow,
selectNodeTypeOrThrow,
selectNodeTemplateOrThrow,
selectNodeInputsOrThrow,
buildSelectInputFieldOrThrow,
buildSelectInputFieldTemplateOrThrow,
buildSelectOutputFieldTemplateOrThrow,
buildSelectIsInputFieldConnected,
selectNodeNeedsUpdate,
} satisfies InvocationNodeContextValue;
}, [nodeId, templates]);
return <InvocationNodeContext.Provider value={value}>{children}</InvocationNodeContext.Provider>;
});
export const useInvocationNodeContext = () => {
const context = useContext(InvocationNodeContext);
if (!context) {
throw new Error('useInvocationNodeContext must be used within an InvocationNodeProvider');
}
return context;
};

View File

@@ -48,7 +48,7 @@ InputFieldDescriptionPopover.displayName = 'InputFieldDescriptionPopover';
const Content = memo(({ nodeId, fieldName }: Props) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const description = useInputFieldUserDescriptionSafe(nodeId, fieldName);
const description = useInputFieldUserDescriptionSafe(fieldName);
const onChange = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(fieldDescriptionChanged({ nodeId, fieldName, val: e.target.value }));

View File

@@ -22,9 +22,9 @@ interface Props {
}
export const InputFieldEditModeNodes = memo(({ nodeId, fieldName }: Props) => {
const fieldTemplate = useInputFieldTemplateOrThrow(nodeId, fieldName);
const isInvalid = useInputFieldIsInvalid(nodeId, fieldName);
const isConnected = useInputFieldIsConnected(nodeId, fieldName);
const fieldTemplate = useInputFieldTemplateOrThrow(fieldName);
const isInvalid = useInputFieldIsInvalid(fieldName);
const isConnected = useInputFieldIsConnected(fieldName);
if (fieldTemplate.input === 'connection' || isConnected) {
return (

View File

@@ -15,8 +15,8 @@ type Props = PropsWithChildren<{
}>;
export const InputFieldGate = memo(({ nodeId, fieldName, children, fallback, formatLabel }: Props) => {
const hasInstance = useInputFieldInstanceExists(nodeId, fieldName);
const hasTemplate = useInputFieldTemplateExists(nodeId, fieldName);
const hasInstance = useInputFieldInstanceExists(fieldName);
const hasTemplate = useInputFieldTemplateExists(fieldName);
if (!hasTemplate || !hasInstance) {
// fallback may be null, indicating we should render nothing at all - must check for undefined explicitly
@@ -41,7 +41,6 @@ InputFieldGate.displayName = 'InputFieldGate';
const Fallback = memo(
({
nodeId,
fieldName,
formatLabel,
hasTemplate,
@@ -54,7 +53,7 @@ const Fallback = memo(
hasInstance: boolean;
}) => {
const { t } = useTranslation();
const name = useInputFieldNameSafe(nodeId, fieldName);
const name = useInputFieldNameSafe(fieldName);
const label = useMemo(() => {
if (formatLabel) {
return formatLabel(name);
@@ -70,8 +69,8 @@ const Fallback = memo(
return (
<InputFieldWrapper>
<Flex w="full" px={1} py={1} justifyContent="center">
<Text fontWeight="semibold" color="error.300" whiteSpace="pre" textAlign="center">
<Flex w="full" px={1} py={1}>
<Text fontWeight="semibold" color="error.300" whiteSpace="pre">
{label}
</Text>
</Flex>

View File

@@ -63,7 +63,7 @@ const handleStyles = {
} satisfies CSSProperties;
export const InputFieldHandle = memo(({ nodeId, fieldName }: Props) => {
const fieldTemplate = useInputFieldTemplateOrThrow(nodeId, fieldName);
const fieldTemplate = useInputFieldTemplateOrThrow(fieldName);
const fieldTypeName = useFieldTypeName(fieldTemplate.type);
const fieldColor = useMemo(() => getFieldColor(fieldTemplate.type), [fieldTemplate.type]);
const isModelField = useMemo(() => isModelFieldType(fieldTemplate.type), [fieldTemplate.type]);

View File

@@ -150,8 +150,8 @@ type Props = {
};
export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props) => {
const field = useInputFieldInstance(nodeId, fieldName);
const template = useInputFieldTemplateOrThrow(nodeId, fieldName);
const field = useInputFieldInstance(fieldName);
const template = useInputFieldTemplateOrThrow(fieldName);
// When deciding which component to render, first we check the type of the template, which is more efficient than the
// instance type check. The instance type check uses zod and is slower.

View File

@@ -9,9 +9,9 @@ type Props = {
fieldName: string;
};
export const InputFieldResetToDefaultValueIconButton = memo(({ nodeId, fieldName }: Props) => {
export const InputFieldResetToDefaultValueIconButton = memo(({ fieldName }: Props) => {
const { t } = useTranslation();
const { isValueChanged, resetToDefaultValue } = useInputFieldDefaultValue(nodeId, fieldName);
const { isValueChanged, resetToDefaultValue } = useInputFieldDefaultValue(fieldName);
return (
<IconButton

View File

@@ -43,10 +43,10 @@ interface Props {
export const InputFieldTitle = memo((props: Props) => {
const { nodeId, fieldName, isInvalid, isDragging } = props;
const inputRef = useRef<HTMLInputElement>(null);
const label = useInputFieldUserTitleSafe(nodeId, fieldName);
const fieldTemplateTitle = useInputFieldTemplateTitleOrThrow(nodeId, fieldName);
const label = useInputFieldUserTitleSafe(fieldName);
const fieldTemplateTitle = useInputFieldTemplateTitleOrThrow(fieldName);
const { t } = useTranslation();
const isConnected = useInputFieldIsConnected(nodeId, fieldName);
const isConnected = useInputFieldIsConnected(fieldName);
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target');
const isConnectionInProgress = useIsConnectionInProgress();
const connectionError = useConnectionErrorTKey(nodeId, fieldName, 'target');

View File

@@ -12,13 +12,13 @@ interface Props {
fieldName: string;
}
export const InputFieldTooltipContent = memo(({ nodeId, fieldName }: Props) => {
export const InputFieldTooltipContent = memo(({ fieldName }: Props) => {
const { t } = useTranslation();
const fieldInstance = useInputFieldInstance(nodeId, fieldName);
const fieldTemplate = useInputFieldTemplateOrThrow(nodeId, fieldName);
const fieldInstance = useInputFieldInstance(fieldName);
const fieldTemplate = useInputFieldTemplateOrThrow(fieldName);
const fieldTypeName = useFieldTypeName(fieldTemplate.type);
const fieldErrors = useInputFieldErrors(nodeId, fieldName);
const fieldErrors = useInputFieldErrors(fieldName);
const fieldTitle = useMemo(() => {
if (fieldInstance.label && fieldTemplate.title) {

View File

@@ -12,7 +12,7 @@ type Props = PropsWithChildren<{
}>;
export const OutputFieldGate = memo(({ nodeId, fieldName, children }: Props) => {
const hasTemplate = useOutputFieldTemplateExists(nodeId, fieldName);
const hasTemplate = useOutputFieldTemplateExists(fieldName);
if (!hasTemplate) {
return <Fallback nodeId={nodeId} fieldName={fieldName} />;
@@ -23,9 +23,9 @@ export const OutputFieldGate = memo(({ nodeId, fieldName, children }: Props) =>
OutputFieldGate.displayName = 'OutputFieldGate';
const Fallback = memo(({ nodeId, fieldName }: Props) => {
const Fallback = memo(({ fieldName }: Props) => {
const { t } = useTranslation();
const name = useOutputFieldName(nodeId, fieldName);
const name = useOutputFieldName(fieldName);
return (
<OutputFieldWrapper>

View File

@@ -63,7 +63,7 @@ const handleStyles = {
} satisfies CSSProperties;
export const OutputFieldHandle = memo(({ nodeId, fieldName }: Props) => {
const fieldTemplate = useOutputFieldTemplate(nodeId, fieldName);
const fieldTemplate = useOutputFieldTemplate(fieldName);
const fieldTypeName = useFieldTypeName(fieldTemplate.type);
const fieldColor = useMemo(() => getFieldColor(fieldTemplate.type), [fieldTemplate.type]);
const isModelField = useMemo(() => isModelFieldType(fieldTemplate.type), [fieldTemplate.type]);

View File

@@ -27,8 +27,8 @@ type Props = {
};
export const OutputFieldTitle = memo(({ nodeId, fieldName }: Props) => {
const fieldTemplate = useOutputFieldTemplate(nodeId, fieldName);
const isConnected = useInputFieldIsConnected(nodeId, fieldName);
const fieldTemplate = useOutputFieldTemplate(fieldName);
const isConnected = useInputFieldIsConnected(fieldName);
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'source');
const isConnectionInProgress = useIsConnectionInProgress();
const connectionErrorTKey = useConnectionErrorTKey(nodeId, fieldName, 'source');

View File

@@ -9,8 +9,8 @@ interface Props {
fieldName: string;
}
export const OutputFieldTooltipContent = memo(({ nodeId, fieldName }: Props) => {
const fieldTemplate = useOutputFieldTemplate(nodeId, fieldName);
export const OutputFieldTooltipContent = memo(({ fieldName }: Props) => {
const fieldTemplate = useOutputFieldTemplate(fieldName);
const fieldTypeName = useFieldTypeName(fieldTemplate.type);
const { t } = useTranslation();

View File

@@ -40,7 +40,7 @@ export const FloatFieldCollectionInputComponent = memo(
const store = useAppStore();
const { t } = useTranslation();
const isInvalid = useInputFieldIsInvalid(nodeId, field.name);
const isInvalid = useInputFieldIsInvalid(field.name);
const onChangeValue = useCallback(
(value: FloatFieldCollectionInputInstance['value']) => {

View File

@@ -40,7 +40,7 @@ export const ImageFieldCollectionInputComponent = memo(
const { nodeId, field } = props;
const store = useAppStore();
const isInvalid = useInputFieldIsInvalid(nodeId, field.name);
const isInvalid = useInputFieldIsInvalid(field.name);
const dndTargetData = useMemo<AddImagesToNodeImageFieldCollection>(
() =>

View File

@@ -43,7 +43,7 @@ export const IntegerFieldCollectionInputComponent = memo(
const store = useAppStore();
const { t } = useTranslation();
const isInvalid = useInputFieldIsInvalid(nodeId, field.name);
const isInvalid = useInputFieldIsInvalid(field.name);
const onChangeValue = useCallback(
(value: IntegerFieldCollectionInputInstance['value']) => {

View File

@@ -33,7 +33,7 @@ export const StringFieldCollectionInputComponent = memo(
const { t } = useTranslation();
const store = useAppStore();
const isInvalid = useInputFieldIsInvalid(nodeId, field.name);
const isInvalid = useInputFieldIsInvalid(field.name);
const onRemoveString = useCallback(
(index: number) => {

View File

@@ -17,10 +17,10 @@ type Props = {
const NodeTitle = ({ nodeId, title }: Props) => {
const dispatch = useAppDispatch();
const label = useNodeUserTitleSafe(nodeId);
const label = useNodeUserTitleSafe();
const batchGroupId = useBatchGroupId(nodeId);
const batchGroupColorToken = useBatchGroupColorToken(batchGroupId);
const templateTitle = useNodeTemplateTitleSafe(nodeId);
const templateTitle = useNodeTemplateTitleSafe();
const { t } = useTranslation();
const inputRef = useRef<HTMLInputElement>(null);

View File

@@ -1,6 +1,7 @@
import type { ChakraProps, SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, useGlobalMenuClose } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useInvocationNodeContext } from 'features/nodes/components/flow/nodes/Invocation/context';
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
import { useMouseOverFormField, useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import { useNodeExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
@@ -15,6 +16,7 @@ type NodeWrapperProps = PropsWithChildren & {
nodeId: string;
selected: boolean;
width?: ChakraProps['w'];
isMissingTemplate?: boolean;
};
// Certain CSS transitions are disabled as a performance optimization - they can cause massive slowdowns in large
@@ -26,6 +28,19 @@ const containerSx: SystemStyleObject = {
borderRadius: 'base',
transitionProperty: 'none',
cursor: 'grab',
'--border-color': 'var(--invoke-colors-base-500)',
'--border-color-selected': 'var(--invoke-colors-blue-300)',
'--header-bg-color': 'var(--invoke-colors-base-900)',
'&[data-status="warning"]': {
'--border-color': 'var(--invoke-colors-warning-500)',
'--border-color-selected': 'var(--invoke-colors-warning-500)',
'--header-bg-color': 'var(--invoke-colors-warning-700)',
},
'&[data-status="error"]': {
'--border-color': 'var(--invoke-colors-error-500)',
'--border-color-selected': 'var(--invoke-colors-error-500)',
'--header-bg-color': 'var(--invoke-colors-error-700)',
},
// The action buttons are hidden by default and shown on hover
'& .node-selection-overlay': {
display: 'block',
@@ -37,7 +52,7 @@ const containerSx: SystemStyleObject = {
borderRadius: 'base',
transitionProperty: 'none',
pointerEvents: 'none',
shadow: '0 0 0 1px var(--invoke-colors-base-500)',
shadow: '0 0 0 1px var(--border-color)',
},
'&[data-is-mouse-over-node="true"] .node-selection-overlay': {
display: 'block',
@@ -49,16 +64,16 @@ const containerSx: SystemStyleObject = {
_hover: {
'& .node-selection-overlay': {
display: 'block',
shadow: '0 0 0 1px var(--invoke-colors-blue-300)',
shadow: '0 0 0 1px var(--border-color-selected)',
},
'&[data-is-selected="true"] .node-selection-overlay': {
display: 'block',
shadow: '0 0 0 2px var(--invoke-colors-blue-300)',
shadow: '0 0 0 2px var(--border-color-selected)',
},
},
'&[data-is-selected="true"] .node-selection-overlay': {
display: 'block',
shadow: '0 0 0 2px var(--invoke-colors-blue-300)',
shadow: '0 0 0 2px var(--border-color-selected)',
},
'&[data-is-editor-locked="true"]': {
'& *': {
@@ -99,7 +114,9 @@ const inProgressSx: SystemStyleObject = {
};
const NodeWrapper = (props: NodeWrapperProps) => {
const { nodeId, width, children, selected } = props;
const { nodeId, width, children, isMissingTemplate, selected } = props;
const ctx = useInvocationNodeContext();
const needsUpdate = useAppSelector(ctx.selectNodeNeedsUpdate);
const mouseOverNode = useMouseOverNode(nodeId);
const mouseOverFormField = useMouseOverFormField(nodeId);
const zoomToNode = useZoomToNode(nodeId);
@@ -149,6 +166,7 @@ const NodeWrapper = (props: NodeWrapperProps) => {
data-is-editor-locked={isLocked}
data-is-selected={selected}
data-is-mouse-over-form-field={mouseOverFormField.isMouseOverFormField}
data-status={isMissingTemplate ? 'error' : needsUpdate ? 'warning' : undefined}
>
<Box sx={shadowsSx} />
<Box sx={inProgressSx} data-is-in-progress={isInProgress} />

View File

@@ -2,6 +2,7 @@ import type { FlexProps, SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, IconButton, Spacer, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { camelCase } from 'es-toolkit/compat';
import { InvocationNodeContextProvider } from 'features/nodes/components/flow/nodes/Invocation/context';
import { InputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldGate';
import { ContainerElementSettings } from 'features/nodes/components/sidePanel/builder/ContainerElementSettings';
import { useDepthContext } from 'features/nodes/components/sidePanel/builder/contexts';
@@ -49,14 +50,16 @@ export const FormElementEditModeHeader = memo(({ element, dragHandleRef, ...rest
<Spacer />
{isContainerElement(element) && <ContainerElementSettings element={element} />}
{isNodeFieldElement(element) && (
<InputFieldGate
nodeId={element.data.fieldIdentifier.nodeId}
fieldName={element.data.fieldIdentifier.fieldName}
fallback={null} // Do not render these buttons if the field is not found
>
<ZoomToNodeButton element={element} />
<NodeFieldElementSettings element={element} />
</InputFieldGate>
<InvocationNodeContextProvider nodeId={element.data.fieldIdentifier.nodeId}>
<InputFieldGate
nodeId={element.data.fieldIdentifier.nodeId}
fieldName={element.data.fieldIdentifier.fieldName}
fallback={null} // Do not render these buttons if the field is not found
>
<ZoomToNodeButton element={element} />
<NodeFieldElementSettings element={element} />
</InputFieldGate>
</InvocationNodeContextProvider>
)}
<RemoveElementButton element={element} />
</Flex>

View File

@@ -1,4 +1,5 @@
import { useAppSelector } from 'app/store/storeHooks';
import { InvocationNodeContextProvider } from 'features/nodes/components/flow/nodes/Invocation/context';
import { NodeFieldElementEditMode } from 'features/nodes/components/sidePanel/builder/NodeFieldElementEditMode';
import { NodeFieldElementViewMode } from 'features/nodes/components/sidePanel/builder/NodeFieldElementViewMode';
import { useElement } from 'features/nodes/components/sidePanel/builder/use-element';
@@ -15,11 +16,19 @@ export const NodeFieldElement = memo(({ id }: { id: string }) => {
}
if (mode === 'view') {
return <NodeFieldElementViewMode el={el} />;
return (
<InvocationNodeContextProvider nodeId={el.data.fieldIdentifier.nodeId}>
<NodeFieldElementViewMode el={el} />
</InvocationNodeContextProvider>
);
}
// mode === 'edit'
return <NodeFieldElementEditMode el={el} />;
return (
<InvocationNodeContextProvider nodeId={el.data.fieldIdentifier.nodeId}>
<NodeFieldElementEditMode el={el} />
</InvocationNodeContextProvider>
);
});
NodeFieldElement.displayName = 'NodeFieldElement';

View File

@@ -13,8 +13,8 @@ export const NodeFieldElementDescriptionEditable = memo(({ el }: { el: NodeField
const { data } = el;
const { fieldIdentifier } = data;
const dispatch = useAppDispatch();
const description = useInputFieldUserDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const description = useInputFieldUserDescriptionSafe(fieldIdentifier.fieldName);
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.fieldName);
const inputRef = useRef<HTMLTextAreaElement>(null);
const onChange = useCallback(

View File

@@ -1,5 +1,6 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Divider, Flex, FormControl } from '@invoke-ai/ui-library';
import { InvocationNodeContextProvider } from 'features/nodes/components/flow/nodes/Invocation/context';
import { InputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldGate';
import { InputFieldRenderer } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer';
import { useContainerContext } from 'features/nodes/components/sidePanel/builder/contexts';
@@ -63,25 +64,27 @@ const NodeFieldElementEditModeContent = memo(
<>
<FormElementEditModeHeader dragHandleRef={dragHandleRef} element={el} data-is-dragging={isDragging} />
<FormElementEditModeContent data-is-dragging={isDragging} p={4}>
<InputFieldGate nodeId={fieldIdentifier.nodeId} fieldName={fieldIdentifier.fieldName}>
<FormControl flex="1 1 0" orientation="vertical">
<NodeFieldElementLabelEditable el={el} />
<Flex w="full" gap={4}>
<InputFieldRenderer
nodeId={fieldIdentifier.nodeId}
fieldName={fieldIdentifier.fieldName}
settings={data.settings}
/>
</Flex>
{showDescription && <NodeFieldElementDescriptionEditable el={el} />}
{data.settings?.type === 'string-field-config' && data.settings.component === 'dropdown' && (
<>
<Divider />
<NodeFieldElementStringDropdownSettings id={id} settings={data.settings} />
</>
)}
</FormControl>
</InputFieldGate>
<InvocationNodeContextProvider nodeId={fieldIdentifier.nodeId}>
<InputFieldGate nodeId={fieldIdentifier.nodeId} fieldName={fieldIdentifier.fieldName}>
<FormControl flex="1 1 0" orientation="vertical">
<NodeFieldElementLabelEditable el={el} />
<Flex w="full" gap={4}>
<InputFieldRenderer
nodeId={fieldIdentifier.nodeId}
fieldName={fieldIdentifier.fieldName}
settings={data.settings}
/>
</Flex>
{showDescription && <NodeFieldElementDescriptionEditable el={el} />}
{data.settings?.type === 'string-field-config' && data.settings.component === 'dropdown' && (
<>
<Divider />
<NodeFieldElementStringDropdownSettings id={id} settings={data.settings} />
</>
)}
</FormControl>
</InputFieldGate>
</InvocationNodeContextProvider>
</FormElementEditModeContent>
</>
);

View File

@@ -67,7 +67,7 @@ SettingComponent.displayName = 'SettingComponent';
const SettingMin = memo(({ id, settings, nodeId, fieldName, fieldTemplate }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const field = useInputFieldInstance<FloatFieldInputInstance>(nodeId, fieldName);
const field = useInputFieldInstance<FloatFieldInputInstance>(fieldName);
const floatField = useFloatField(nodeId, fieldName, fieldTemplate);
@@ -129,7 +129,7 @@ SettingMin.displayName = 'SettingMin';
const SettingMax = memo(({ id, settings, nodeId, fieldName, fieldTemplate }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const field = useInputFieldInstance<FloatFieldInputInstance>(nodeId, fieldName);
const field = useInputFieldInstance<FloatFieldInputInstance>(fieldName);
const floatField = useFloatField(nodeId, fieldName, fieldTemplate);

View File

@@ -68,7 +68,7 @@ SettingComponent.displayName = 'SettingComponent';
const SettingMin = memo(({ id, settings, nodeId, fieldName, fieldTemplate }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const field = useInputFieldInstance<IntegerFieldInputInstance>(nodeId, fieldName);
const field = useInputFieldInstance<IntegerFieldInputInstance>(fieldName);
const integerField = useIntegerField(nodeId, fieldName, fieldTemplate);
@@ -131,7 +131,7 @@ SettingMin.displayName = 'SettingMin';
const SettingMax = memo(({ id, settings, nodeId, fieldName, fieldTemplate }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const field = useInputFieldInstance<IntegerFieldInputInstance>(nodeId, fieldName);
const field = useInputFieldInstance<IntegerFieldInputInstance>(fieldName);
const integerField = useIntegerField(nodeId, fieldName, fieldTemplate);

View File

@@ -8,8 +8,8 @@ import { memo, useMemo } from 'react';
export const NodeFieldElementLabel = memo(({ el }: { el: NodeFieldElement }) => {
const { data } = el;
const { fieldIdentifier } = data;
const label = useInputFieldUserTitleSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const label = useInputFieldUserTitleSafe(fieldIdentifier.fieldName);
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.fieldName);
const _label = useMemo(() => label || fieldTemplate.title, [label, fieldTemplate.title]);

View File

@@ -12,8 +12,8 @@ export const NodeFieldElementLabelEditable = memo(({ el }: { el: NodeFieldElemen
const { data } = el;
const { fieldIdentifier } = data;
const dispatch = useAppDispatch();
const label = useInputFieldUserTitleSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const label = useInputFieldUserTitleSafe(fieldIdentifier.fieldName);
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.fieldName);
const inputRef = useRef<HTMLInputElement>(null);
const onChange = useCallback(

View File

@@ -36,7 +36,7 @@ export const NodeFieldElementSettings = memo(({ element }: { element: NodeFieldE
const { id, data } = element;
const { showDescription, fieldIdentifier } = data;
const { nodeId, fieldName } = fieldIdentifier;
const fieldTemplate = useInputFieldTemplateOrThrow(nodeId, fieldName);
const fieldTemplate = useInputFieldTemplateOrThrow(fieldName);
const { t } = useTranslation();
const dispatch = useAppDispatch();

View File

@@ -37,8 +37,8 @@ const useFormatFallbackLabel = () => {
export const NodeFieldElementViewMode = memo(({ el }: { el: NodeFieldElement }) => {
const { id, data } = el;
const { fieldIdentifier, showDescription } = data;
const description = useInputFieldUserDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const fieldTemplate = useInputFieldTemplateSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const description = useInputFieldUserDescriptionSafe(fieldIdentifier.fieldName);
const fieldTemplate = useInputFieldTemplateSafe(fieldIdentifier.fieldName);
const containerCtx = useContainerContext();
const formatFallbackLabel = useFormatFallbackLabel();
@@ -70,8 +70,8 @@ NodeFieldElementViewMode.displayName = 'NodeFieldElementViewMode';
const NodeFieldElementViewModeContent = memo(({ el }: { el: NodeFieldElement }) => {
const { data } = el;
const { fieldIdentifier, showDescription } = data;
const description = useInputFieldUserDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const description = useInputFieldUserDescriptionSafe(fieldIdentifier.fieldName);
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.fieldName);
const _description = useMemo(
() => description || fieldTemplate.description,

Some files were not shown because too many files have changed in this diff Show More