Mirror of https://github.com/invoke-ai/InvokeAI.git, synced 2026-01-20 04:17:59 -05:00

Compare commits: 1 commit, v6.1.0rc2...psychedeli

| Author | SHA1 | Date |
|---|---|---|
|  | c1a4376b75 |  |

26 changes: .github/ISSUE_TEMPLATE/BUG_REPORT.yml (vendored)
@@ -21,20 +21,6 @@ body:
         - label: I have searched the existing issues
           required: true
 
-  - type: dropdown
-    id: install_method
-    attributes:
-      label: Install method
-      description: How did you install Invoke?
-      multiple: false
-      options:
-        - "Invoke's Launcher"
-        - 'Stability Matrix'
-        - 'Pinokio'
-        - 'Manual'
-    validations:
-      required: true
-
   - type: markdown
     attributes:
       value: __Describe your environment__
@@ -90,8 +76,8 @@ body:
     attributes:
      label: Version number
      description: |
-        The version of Invoke you have installed. If it is not the [latest version](https://github.com/invoke-ai/InvokeAI/releases/latest), please update and try again to confirm the issue still exists. If you are testing main, please include the commit hash instead.
-      placeholder: ex. v6.0.2
+        The version of Invoke you have installed. If it is not the latest version, please update and try again to confirm the issue still exists. If you are testing main, please include the commit hash instead.
+      placeholder: ex. 3.6.1
    validations:
      required: true
 
@@ -99,17 +85,17 @@ body:
    id: browser-version
    attributes:
      label: Browser
-      description: Your web browser and version, if you do not use the Launcher's provided GUI.
+      description: Your web browser and version.
      placeholder: ex. Firefox 123.0b3
    validations:
-      required: false
+      required: true
 
  - type: textarea
    id: python-deps
    attributes:
-      label: System Information
+      label: Python dependencies
      description: |
-        Click the gear icon at the bottom left corner, then click "About". Click the copy button and then paste here.
+        If the problem occurred during image generation, click the gear icon at the bottom left corner, click "About", click the copy button and then paste here.
    validations:
      required: false
2 changes: .gitignore (vendored)

@@ -190,5 +190,3 @@ installer/update.bat
 installer/update.sh
 installer/InvokeAI-Installer/
 .aider*
-
-.claude/
@@ -5,7 +5,8 @@
 FROM docker.io/node:22-slim AS web-builder
 ENV PNPM_HOME="/pnpm"
 ENV PATH="$PNPM_HOME:$PATH"
-RUN corepack use pnpm@10.x && corepack enable
+RUN corepack use pnpm@8.x
+RUN corepack enable
 
 WORKDIR /build
 COPY invokeai/frontend/web/ ./
@@ -41,7 +41,7 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
 With the modifications made, the install command should look something like this:
 
 ```sh
-uv pip install -e ".[dev,test,docs,xformers]" --python 3.12 --python-preference only-managed --index=https://download.pytorch.org/whl/cu128 --reinstall
+uv pip install -e ".[dev,test,docs,xformers]" --python 3.12 --python-preference only-managed --index=https://download.pytorch.org/whl/cu126 --reinstall
 ```
 
 6. At this point, you should have Invoke installed, a venv set up and activated, and the server running. But you will see a warning in the terminal that no UI was found. If you go to the URL for the server, you won't get a UI.
@@ -50,11 +50,11 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
 
 If you only want to edit the docs, you can stop here and skip to the **Documentation** section below.
 
-7. Install the frontend dev toolchain, paying attention to versions:
+7. Install the frontend dev toolchain:
 
-   - [`nodejs`](https://nodejs.org/) (tested on LTS, v22)
+   - [`nodejs`](https://nodejs.org/) (v20+)
 
-   - [`pnpm`](https://pnpm.io/installation) (tested on v10)
+   - [`pnpm`](https://pnpm.io/8.x/installation) (must be v8 - not v9!)
 
 8. Do a production build of the frontend:
@@ -297,7 +297,7 @@ Migration logic is in [migrations.ts].
 <!-- links -->
 
 [pydantic]: https://github.com/pydantic/pydantic 'pydantic'
-[zod]: https://github.com/colinhacks/zod 'zod'
+[zod]: https://github.com/colinhacks/zod 'zod/v4'
 [openapi-types]: https://github.com/kogosoftwarellc/open-api/tree/main/packages/openapi-types 'openapi-types'
 [reactflow]: https://github.com/xyflow/xyflow 'reactflow'
 [reactflow-concepts]: https://reactflow.dev/learn/concepts/terms-and-definitions
@@ -430,15 +430,6 @@ class FluxConditioningOutput(BaseInvocationOutput):
         return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
 
 
-@invocation_output("flux_conditioning_collection_output")
-class FluxConditioningCollectionOutput(BaseInvocationOutput):
-    """Base class for nodes that output a collection of conditioning tensors"""
-
-    collection: list[FluxConditioningField] = OutputField(
-        description="The output conditioning tensors",
-    )
-
-
 @invocation_output("sd3_conditioning_output")
 class SD3ConditioningOutput(BaseInvocationOutput):
     """Base class for nodes that output a single SD3 conditioning tensor"""
@@ -14,14 +14,15 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._conn = db.conn
|
||||
|
||||
def add_image_to_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO board_images (board_id, image_name)
|
||||
@@ -30,12 +31,17 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
""",
|
||||
(board_id, image_name, board_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM board_images
|
||||
@@ -43,6 +49,10 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
|
||||
def get_images_for_board(
|
||||
self,
|
||||
@@ -50,26 +60,27 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
FROM board_images
|
||||
INNER JOIN images ON board_images.image_name = images.image_name
|
||||
WHERE board_images.board_id = ?
|
||||
ORDER BY board_images.updated_at DESC;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
images = [deserialize_image_record(dict(r)) for r in result]
|
||||
# TODO: this isn't paginated yet?
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
FROM board_images
|
||||
INNER JOIN images ON board_images.image_name = images.image_name
|
||||
WHERE board_images.board_id = ?
|
||||
ORDER BY board_images.updated_at DESC;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
images = [deserialize_image_record(dict(r)) for r in result]
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*) FROM images WHERE 1=1;
|
||||
"""
|
||||
)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*) FROM images WHERE 1=1;
|
||||
"""
|
||||
)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
|
||||
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
|
||||
|
||||
@@ -79,55 +90,56 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
categories: list[ImageCategory] | None,
|
||||
is_intermediate: bool | None,
|
||||
) -> list[str]:
|
||||
with self._db.transaction() as cursor:
|
||||
params: list[str | bool] = []
|
||||
params: list[str | bool] = []
|
||||
|
||||
# Base query is a join between images and board_images
|
||||
stmt = """
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
# Base query is a join between images and board_images
|
||||
stmt = """
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
# Handle board_id filter
|
||||
if board_id == "none":
|
||||
stmt += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
else:
|
||||
stmt += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
params.append(board_id)
|
||||
# Handle board_id filter
|
||||
if board_id == "none":
|
||||
stmt += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
else:
|
||||
stmt += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
params.append(board_id)
|
||||
|
||||
# Add the category filter
|
||||
if categories is not None:
|
||||
# Convert the enum values to unique list of strings
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
stmt += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
# Add the category filter
|
||||
if categories is not None:
|
||||
# Convert the enum values to unique list of strings
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
stmt += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
|
||||
# Unpack the included categories into the query params
|
||||
for c in category_strings:
|
||||
params.append(c)
|
||||
# Unpack the included categories into the query params
|
||||
for c in category_strings:
|
||||
params.append(c)
|
||||
|
||||
# Add the is_intermediate filter
|
||||
if is_intermediate is not None:
|
||||
stmt += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
params.append(is_intermediate)
|
||||
# Add the is_intermediate filter
|
||||
if is_intermediate is not None:
|
||||
stmt += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
params.append(is_intermediate)
|
||||
|
||||
# Put a ring on it
|
||||
stmt += ";"
|
||||
# Put a ring on it
|
||||
stmt += ";"
|
||||
|
||||
cursor.execute(stmt, params)
|
||||
# Execute the query
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(stmt, params)
|
||||
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
image_names = [r[0] for r in result]
|
||||
return image_names
|
||||
|
||||
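Both versions of the image-name query above build their category filter the same way: the category enums are de-duplicated, one `?` placeholder is generated per value, and the values are appended to the bound parameters. The following is a minimal, self-contained sketch of that technique in isolation; the table and column names are illustrative only, not taken from the project's schema.

```python
import sqlite3


def select_image_names_by_category(conn: sqlite3.Connection, categories: list[str]) -> list[str]:
    """Filter with a parameterized IN (...) clause, one "?" per category value."""
    category_strings = list(set(categories))  # de-duplicate, as the diff does
    if not category_strings:
        return []  # "IN ()" is a syntax error in SQLite, so bail out early
    placeholders = ",".join("?" * len(category_strings))  # e.g. "?,?,?"
    stmt = f"SELECT image_name FROM images WHERE image_category IN ( {placeholders} );"
    cursor = conn.execute(stmt, category_strings)
    return [row[0] for row in cursor.fetchall()]
```

Generating the placeholders keeps the values themselves out of the SQL string, so the query stays parameterized even though its length varies.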
@@ -135,31 +147,31 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Optional[str]:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT board_id
|
||||
FROM board_images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT board_id
|
||||
FROM board_images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
if result is None:
|
||||
return None
|
||||
return cast(str, result[0])
|
||||
|
||||
def get_image_count_for_board(self, board_id: str) -> int:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM board_images
|
||||
INNER JOIN images ON board_images.image_name = images.image_name
|
||||
WHERE images.is_intermediate = FALSE
|
||||
AND board_images.board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM board_images
|
||||
INNER JOIN images ON board_images.image_name = images.image_name
|
||||
WHERE images.is_intermediate = FALSE
|
||||
AND board_images.board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
return count
|
||||
|
||||
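Throughout these storage classes, one side of the diff wraps its queries in `with self._db.transaction() as cursor:` while the other creates a cursor from `self._conn` and calls `commit()` / `rollback()` by hand in every method. The `transaction()` helper itself is not part of this diff; the block below is only a sketch of how such a context manager is commonly written, assuming `SqliteDatabase` owns a `sqlite3.Connection` (the class name, lock, and constructor here are illustrative, not the project's actual code).

```python
import sqlite3
import threading
from contextlib import contextmanager
from typing import Iterator


class SqliteDatabase:
    """Illustrative wrapper; the real class has more responsibilities."""

    def __init__(self, db_path: str = ":memory:") -> None:
        self._lock = threading.Lock()
        self.conn = sqlite3.connect(db_path, check_same_thread=False)
        self.conn.row_factory = sqlite3.Row

    @contextmanager
    def transaction(self) -> Iterator[sqlite3.Cursor]:
        """Yield a cursor; commit on success, roll back on any error."""
        with self._lock:
            cursor = self.conn.cursor()
            try:
                yield cursor
                self.conn.commit()
            except Exception:
                self.conn.rollback()
                raise
            finally:
                cursor.close()


# Usage, mirroring the pattern seen on one side of the diff:
# with self._db.transaction() as cursor:
#     cursor.execute("INSERT INTO board_images (board_id, image_name) VALUES (?, ?)", (b, n))
```

Centralizing commit and rollback in one place is what lets the per-method code on that side of the diff drop the explicit `self._conn.commit()` / `self._conn.rollback()` calls.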
@@ -20,57 +20,61 @@ from invokeai.app.util.misc import uuid_string
|
||||
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._conn = db.conn
|
||||
|
||||
def delete(self, board_id: str) -> None:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM boards
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
except Exception as e:
|
||||
raise BoardRecordDeleteException from e
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM boards
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordDeleteException from e
|
||||
|
||||
def save(
|
||||
self,
|
||||
board_name: str,
|
||||
) -> BoardRecord:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
board_id = uuid_string()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO boards (board_id, board_name)
|
||||
VALUES (?, ?);
|
||||
""",
|
||||
(board_id, board_name),
|
||||
)
|
||||
except sqlite3.Error as e:
|
||||
raise BoardRecordSaveException from e
|
||||
try:
|
||||
board_id = uuid_string()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO boards (board_id, board_name)
|
||||
VALUES (?, ?);
|
||||
""",
|
||||
(board_id, board_name),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordSaveException from e
|
||||
return self.get(board_id)
|
||||
|
||||
def get(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> BoardRecord:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM boards
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM boards
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
raise BoardRecordNotFoundException from e
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
raise BoardRecordNotFoundException from e
|
||||
if result is None:
|
||||
raise BoardRecordNotFoundException
|
||||
return BoardRecord(**dict(result))
|
||||
@@ -80,43 +84,45 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
board_id: str,
|
||||
changes: BoardChanges,
|
||||
) -> BoardRecord:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
# Change the name of a board
|
||||
if changes.board_name is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE boards
|
||||
SET board_name = ?
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(changes.board_name, board_id),
|
||||
)
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
# Change the name of a board
|
||||
if changes.board_name is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE boards
|
||||
SET board_name = ?
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(changes.board_name, board_id),
|
||||
)
|
||||
|
||||
# Change the cover image of a board
|
||||
if changes.cover_image_name is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE boards
|
||||
SET cover_image_name = ?
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(changes.cover_image_name, board_id),
|
||||
)
|
||||
# Change the cover image of a board
|
||||
if changes.cover_image_name is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE boards
|
||||
SET cover_image_name = ?
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(changes.cover_image_name, board_id),
|
||||
)
|
||||
|
||||
# Change the archived status of a board
|
||||
if changes.archived is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE boards
|
||||
SET archived = ?
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(changes.archived, board_id),
|
||||
)
|
||||
# Change the archived status of a board
|
||||
if changes.archived is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE boards
|
||||
SET archived = ?
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(changes.archived, board_id),
|
||||
)
|
||||
|
||||
except sqlite3.Error as e:
|
||||
raise BoardRecordSaveException from e
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordSaveException from e
|
||||
return self.get(board_id)
|
||||
|
||||
def get_many(
|
||||
@@ -127,77 +133,78 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
limit: int = 10,
|
||||
include_archived: bool = False,
|
||||
) -> OffsetPaginatedResults[BoardRecord]:
|
||||
with self._db.transaction() as cursor:
|
||||
# Build base query
|
||||
base_query = """
|
||||
SELECT *
|
||||
cursor = self._conn.cursor()
|
||||
|
||||
# Build base query
|
||||
base_query = """
|
||||
SELECT *
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
ORDER BY {order_by} {direction}
|
||||
LIMIT ? OFFSET ?;
|
||||
"""
|
||||
|
||||
# Determine archived filter condition
|
||||
archived_filter = "" if include_archived else "WHERE archived = 0"
|
||||
|
||||
final_query = base_query.format(
|
||||
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
|
||||
)
|
||||
|
||||
# Execute query to fetch boards
|
||||
cursor.execute(final_query, (limit, offset))
|
||||
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
boards = [deserialize_board_record(dict(r)) for r in result]
|
||||
|
||||
# Determine count query
|
||||
if include_archived:
|
||||
count_query = """
|
||||
SELECT COUNT(*)
|
||||
FROM boards;
|
||||
"""
|
||||
else:
|
||||
count_query = """
|
||||
SELECT COUNT(*)
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
ORDER BY {order_by} {direction}
|
||||
LIMIT ? OFFSET ?;
|
||||
WHERE archived = 0;
|
||||
"""
|
||||
|
||||
# Determine archived filter condition
|
||||
archived_filter = "" if include_archived else "WHERE archived = 0"
|
||||
# Execute count query
|
||||
cursor.execute(count_query)
|
||||
|
||||
final_query = base_query.format(
|
||||
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
|
||||
)
|
||||
|
||||
# Execute query to fetch boards
|
||||
cursor.execute(final_query, (limit, offset))
|
||||
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
boards = [deserialize_board_record(dict(r)) for r in result]
|
||||
|
||||
# Determine count query
|
||||
if include_archived:
|
||||
count_query = """
|
||||
SELECT COUNT(*)
|
||||
FROM boards;
|
||||
"""
|
||||
else:
|
||||
count_query = """
|
||||
SELECT COUNT(*)
|
||||
FROM boards
|
||||
WHERE archived = 0;
|
||||
"""
|
||||
|
||||
# Execute count query
|
||||
cursor.execute(count_query)
|
||||
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
|
||||
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
|
||||
|
||||
def get_all(
|
||||
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
|
||||
) -> list[BoardRecord]:
|
||||
with self._db.transaction() as cursor:
|
||||
if order_by == BoardRecordOrderBy.Name:
|
||||
base_query = """
|
||||
SELECT *
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
ORDER BY LOWER(board_name) {direction}
|
||||
"""
|
||||
else:
|
||||
base_query = """
|
||||
SELECT *
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
ORDER BY {order_by} {direction}
|
||||
"""
|
||||
cursor = self._conn.cursor()
|
||||
if order_by == BoardRecordOrderBy.Name:
|
||||
base_query = """
|
||||
SELECT *
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
ORDER BY LOWER(board_name) {direction}
|
||||
"""
|
||||
else:
|
||||
base_query = """
|
||||
SELECT *
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
ORDER BY {order_by} {direction}
|
||||
"""
|
||||
|
||||
archived_filter = "" if include_archived else "WHERE archived = 0"
|
||||
archived_filter = "" if include_archived else "WHERE archived = 0"
|
||||
|
||||
final_query = base_query.format(
|
||||
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
|
||||
)
|
||||
final_query = base_query.format(
|
||||
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
|
||||
)
|
||||
|
||||
cursor.execute(final_query)
|
||||
cursor.execute(final_query)
|
||||
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
boards = [deserialize_board_record(dict(r)) for r in result]
|
||||
|
||||
return boards
|
||||
|
||||
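In the board-listing methods above, only trusted values ever reach `str.format`: the archived filter is a fixed string and `order_by.value` / `direction.value` come from enums, while `limit` and `offset` are still bound with `?`. A small sketch of that pattern follows, under illustrative enum and function names (only the query-building shape is taken from the diff).

```python
import sqlite3
from enum import Enum


class OrderBy(str, Enum):
    CreatedAt = "created_at"
    Name = "board_name"


class Direction(str, Enum):
    Asc = "ASC"
    Desc = "DESC"


def list_boards(
    conn: sqlite3.Connection,
    order_by: OrderBy,
    direction: Direction,
    limit: int,
    offset: int,
    include_archived: bool = False,
) -> list[sqlite3.Row]:
    # Only enum values and a fixed filter string reach the format slots, so this is
    # not an injection risk; user-controlled limit/offset are still bound with "?".
    # Assumes conn.row_factory = sqlite3.Row, as the storage code's dict(row) usage implies.
    base_query = """
        SELECT *
        FROM boards
        {archived_filter}
        ORDER BY {order_by} {direction}
        LIMIT ? OFFSET ?;
    """
    archived_filter = "" if include_archived else "WHERE archived = 0"
    final_query = base_query.format(
        archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
    )
    cursor = conn.execute(final_query, (limit, offset))
    return cursor.fetchall()
```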
@@ -8,7 +8,6 @@ import time
 import traceback
 from pathlib import Path
 from queue import Empty, PriorityQueue
-from shutil import disk_usage
 from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set
 
 import requests
@@ -336,14 +335,6 @@ class DownloadQueueService(DownloadQueueServiceBase):
 
         assert job.download_path
 
-        free_space = disk_usage(job.download_path.parent).free
-        GB = 2**30
-        self._logger.debug(f"Download is {job.total_bytes / GB:.2f} GB of {free_space / GB:.2f} GB free.")
-        if free_space < job.total_bytes:
-            raise RuntimeError(
-                f"Free disk space {free_space / GB:.2f} GB is not enough for download of {job.total_bytes / GB:.2f} GB."
-            )
-
         # Don't clobber an existing file. See commit 82c2c85202f88c6d24ff84710f297cfc6ae174af
         # for code that instead resumes an interrupted download.
         if job.download_path.exists():
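One side of the hunk above performs a free-disk-space check before a download starts. For reference, the check boils down to the following; `total_bytes` and `download_path` are the job fields named in the diff, and the plain function wrapper here is only for illustration.

```python
from pathlib import Path
from shutil import disk_usage

GB = 2**30  # 1 GiB in bytes, matching the constant in the diff


def ensure_free_space(download_path: Path, total_bytes: int) -> None:
    """Raise if the target filesystem cannot hold the pending download."""
    free_space = disk_usage(download_path.parent).free
    if free_space < total_bytes:
        raise RuntimeError(
            f"Free disk space {free_space / GB:.2f} GB is not enough for download of {total_bytes / GB:.2f} GB."
        )


# Example: ensure_free_space(Path("/tmp/model.safetensors"), 12 * GB)
```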
@@ -24,22 +24,22 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._conn = db.conn
|
||||
|
||||
def get(self, image_name: str) -> ImageRecord:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT {IMAGE_DTO_COLS} FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT {IMAGE_DTO_COLS} FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
result = cast(Optional[sqlite3.Row], cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
raise ImageRecordNotFoundException from e
|
||||
result = cast(Optional[sqlite3.Row], cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
raise ImageRecordNotFoundException from e
|
||||
|
||||
if not result:
|
||||
raise ImageRecordNotFoundException
|
||||
@@ -47,20 +47,17 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
return deserialize_image_record(dict(result))
|
||||
|
||||
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
result = cast(Optional[sqlite3.Row], cursor.fetchone())
|
||||
|
||||
except sqlite3.Error as e:
|
||||
raise ImageRecordNotFoundException from e
|
||||
result = cast(Optional[sqlite3.Row], cursor.fetchone())
|
||||
|
||||
if not result:
|
||||
raise ImageRecordNotFoundException
|
||||
@@ -68,60 +65,64 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
as_dict = dict(result)
|
||||
metadata_raw = cast(Optional[str], as_dict.get("metadata", None))
|
||||
return MetadataFieldValidator.validate_json(metadata_raw) if metadata_raw is not None else None
|
||||
except sqlite3.Error as e:
|
||||
raise ImageRecordNotFoundException from e
|
||||
|
||||
def update(
|
||||
self,
|
||||
image_name: str,
|
||||
changes: ImageRecordChanges,
|
||||
) -> None:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
# Change the category of the image
|
||||
if changes.image_category is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET image_category = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.image_category, image_name),
|
||||
)
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
# Change the category of the image
|
||||
if changes.image_category is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET image_category = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.image_category, image_name),
|
||||
)
|
||||
|
||||
# Change the session associated with the image
|
||||
if changes.session_id is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET session_id = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.session_id, image_name),
|
||||
)
|
||||
# Change the session associated with the image
|
||||
if changes.session_id is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET session_id = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.session_id, image_name),
|
||||
)
|
||||
|
||||
# Change the image's `is_intermediate`` flag
|
||||
if changes.is_intermediate is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET is_intermediate = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.is_intermediate, image_name),
|
||||
)
|
||||
# Change the image's `is_intermediate`` flag
|
||||
if changes.is_intermediate is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET is_intermediate = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.is_intermediate, image_name),
|
||||
)
|
||||
|
||||
# Change the image's `starred`` state
|
||||
if changes.starred is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET starred = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.starred, image_name),
|
||||
)
|
||||
# Change the image's `starred`` state
|
||||
if changes.starred is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET starred = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.starred, image_name),
|
||||
)
|
||||
|
||||
except sqlite3.Error as e:
|
||||
raise ImageRecordSaveException from e
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordSaveException from e
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
@@ -135,162 +136,170 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
with self._db.transaction() as cursor:
|
||||
# Manually build two queries - one for the count, one for the records
|
||||
count_query = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
cursor = self._conn.cursor()
|
||||
|
||||
# Manually build two queries - one for the count, one for the records
|
||||
count_query = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
images_query = f"""--sql
|
||||
SELECT {IMAGE_DTO_COLS}
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
query_conditions = ""
|
||||
query_params: list[Union[int, str, bool]] = []
|
||||
|
||||
if image_origin is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.image_origin = ?
|
||||
"""
|
||||
query_params.append(image_origin.value)
|
||||
|
||||
if categories is not None:
|
||||
# Convert the enum values to unique list of strings
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
|
||||
query_conditions += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
|
||||
images_query = f"""--sql
|
||||
SELECT {IMAGE_DTO_COLS}
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
# Unpack the included categories into the query params
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
|
||||
if is_intermediate is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
|
||||
query_conditions = ""
|
||||
query_params: list[Union[int, str, bool]] = []
|
||||
query_params.append(is_intermediate)
|
||||
|
||||
if image_origin is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.image_origin = ?
|
||||
"""
|
||||
query_params.append(image_origin.value)
|
||||
# board_id of "none" is reserved for images without a board
|
||||
if board_id == "none":
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
elif board_id is not None:
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
query_params.append(board_id)
|
||||
|
||||
if categories is not None:
|
||||
# Convert the enum values to unique list of strings
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
# Search term condition
|
||||
if search_term:
|
||||
query_conditions += """--sql
|
||||
AND (
|
||||
images.metadata LIKE ?
|
||||
OR images.created_at LIKE ?
|
||||
)
|
||||
"""
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
|
||||
query_conditions += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
if starred_first:
|
||||
query_pagination = f"""--sql
|
||||
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
|
||||
"""
|
||||
else:
|
||||
query_pagination = f"""--sql
|
||||
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
|
||||
"""
|
||||
|
||||
# Unpack the included categories into the query params
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
# Final images query with pagination
|
||||
images_query += query_conditions + query_pagination + ";"
|
||||
# Add all the parameters
|
||||
images_params = query_params.copy()
|
||||
# Add the pagination parameters
|
||||
images_params.extend([limit, offset])
|
||||
|
||||
if is_intermediate is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
# Build the list of images, deserializing each row
|
||||
cursor.execute(images_query, images_params)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
images = [deserialize_image_record(dict(r)) for r in result]
|
||||
|
||||
query_params.append(is_intermediate)
|
||||
|
||||
# board_id of "none" is reserved for images without a board
|
||||
if board_id == "none":
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
elif board_id is not None:
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
query_params.append(board_id)
|
||||
|
||||
# Search term condition
|
||||
if search_term:
|
||||
query_conditions += """--sql
|
||||
AND (
|
||||
images.metadata LIKE ?
|
||||
OR images.created_at LIKE ?
|
||||
)
|
||||
"""
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
|
||||
if starred_first:
|
||||
query_pagination = f"""--sql
|
||||
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
|
||||
"""
|
||||
else:
|
||||
query_pagination = f"""--sql
|
||||
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
|
||||
"""
|
||||
|
||||
# Final images query with pagination
|
||||
images_query += query_conditions + query_pagination + ";"
|
||||
# Add all the parameters
|
||||
images_params = query_params.copy()
|
||||
# Add the pagination parameters
|
||||
images_params.extend([limit, offset])
|
||||
|
||||
# Build the list of images, deserializing each row
|
||||
cursor.execute(images_query, images_params)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
|
||||
images = [deserialize_image_record(dict(r)) for r in result]
|
||||
|
||||
# Set up and execute the count query, without pagination
|
||||
count_query += query_conditions + ";"
|
||||
count_params = query_params.copy()
|
||||
cursor.execute(count_query, count_params)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
# Set up and execute the count query, without pagination
|
||||
count_query += query_conditions + ";"
|
||||
count_params = query_params.copy()
|
||||
cursor.execute(count_query, count_params)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
|
||||
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
|
||||
|
||||
def delete(self, image_name: str) -> None:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
except sqlite3.Error as e:
|
||||
raise ImageRecordDeleteException from e
|
||||
|
||||
def delete_many(self, image_names: list[str]) -> None:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
placeholders = ",".join("?" for _ in image_names)
|
||||
|
||||
# Construct the SQLite query with the placeholders
|
||||
query = f"DELETE FROM images WHERE image_name IN ({placeholders})"
|
||||
|
||||
# Execute the query with the list of IDs as parameters
|
||||
cursor.execute(query, image_names)
|
||||
|
||||
except sqlite3.Error as e:
|
||||
raise ImageRecordDeleteException from e
|
||||
|
||||
def get_intermediates_count(self) -> int:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*) FROM images
|
||||
WHERE is_intermediate = TRUE;
|
||||
"""
|
||||
DELETE FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordDeleteException from e
|
||||
|
||||
def delete_many(self, image_names: list[str]) -> None:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
|
||||
placeholders = ",".join("?" for _ in image_names)
|
||||
|
||||
# Construct the SQLite query with the placeholders
|
||||
query = f"DELETE FROM images WHERE image_name IN ({placeholders})"
|
||||
|
||||
# Execute the query with the list of IDs as parameters
|
||||
cursor.execute(query, image_names)
|
||||
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordDeleteException from e
|
||||
|
||||
def get_intermediates_count(self) -> int:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*) FROM images
|
||||
WHERE is_intermediate = TRUE;
|
||||
"""
|
||||
)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
self._conn.commit()
|
||||
return count
|
||||
|
||||
def delete_intermediates(self) -> list[str]:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT image_name FROM images
|
||||
WHERE is_intermediate = TRUE;
|
||||
"""
|
||||
)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
image_names = [r[0] for r in result]
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM images
|
||||
WHERE is_intermediate = TRUE;
|
||||
"""
|
||||
)
|
||||
except sqlite3.Error as e:
|
||||
raise ImageRecordDeleteException from e
|
||||
return image_names
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT image_name FROM images
|
||||
WHERE is_intermediate = TRUE;
|
||||
"""
|
||||
)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
image_names = [r[0] for r in result]
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM images
|
||||
WHERE is_intermediate = TRUE;
|
||||
"""
|
||||
)
|
||||
self._conn.commit()
|
||||
return image_names
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordDeleteException from e
|
||||
|
||||
def save(
|
||||
self,
|
||||
@@ -306,71 +315,73 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
node_id: Optional[str] = None,
|
||||
metadata: Optional[str] = None,
|
||||
) -> datetime:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO images (
|
||||
image_name,
|
||||
image_origin,
|
||||
image_category,
|
||||
width,
|
||||
height,
|
||||
node_id,
|
||||
session_id,
|
||||
metadata,
|
||||
is_intermediate,
|
||||
starred,
|
||||
has_workflow
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
image_name,
|
||||
image_origin.value,
|
||||
image_category.value,
|
||||
width,
|
||||
height,
|
||||
node_id,
|
||||
session_id,
|
||||
metadata,
|
||||
is_intermediate,
|
||||
starred,
|
||||
has_workflow,
|
||||
),
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT created_at
|
||||
FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
created_at = datetime.fromisoformat(cursor.fetchone()[0])
|
||||
|
||||
except sqlite3.Error as e:
|
||||
raise ImageRecordSaveException from e
|
||||
return created_at
|
||||
|
||||
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
FROM images
|
||||
JOIN board_images ON images.image_name = board_images.image_name
|
||||
WHERE board_images.board_id = ?
|
||||
AND images.is_intermediate = FALSE
|
||||
ORDER BY images.starred DESC, images.created_at DESC
|
||||
LIMIT 1;
|
||||
INSERT OR IGNORE INTO images (
|
||||
image_name,
|
||||
image_origin,
|
||||
image_category,
|
||||
width,
|
||||
height,
|
||||
node_id,
|
||||
session_id,
|
||||
metadata,
|
||||
is_intermediate,
|
||||
starred,
|
||||
has_workflow
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||
""",
|
||||
(board_id,),
|
||||
(
|
||||
image_name,
|
||||
image_origin.value,
|
||||
image_category.value,
|
||||
width,
|
||||
height,
|
||||
node_id,
|
||||
session_id,
|
||||
metadata,
|
||||
is_intermediate,
|
||||
starred,
|
||||
has_workflow,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT created_at
|
||||
FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
result = cast(Optional[sqlite3.Row], cursor.fetchone())
|
||||
created_at = datetime.fromisoformat(cursor.fetchone()[0])
|
||||
|
||||
return created_at
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordSaveException from e
|
||||
|
||||
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
FROM images
|
||||
JOIN board_images ON images.image_name = board_images.image_name
|
||||
WHERE board_images.board_id = ?
|
||||
AND images.is_intermediate = FALSE
|
||||
ORDER BY images.starred DESC, images.created_at DESC
|
||||
LIMIT 1;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
|
||||
result = cast(Optional[sqlite3.Row], cursor.fetchone())
|
||||
|
||||
if result is None:
|
||||
return None
|
||||
@@ -387,84 +398,85 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> ImageNamesResult:
|
||||
with self._db.transaction() as cursor:
|
||||
# Build query conditions (reused for both starred count and image names queries)
|
||||
query_conditions = ""
|
||||
query_params: list[Union[int, str, bool]] = []
|
||||
cursor = self._conn.cursor()
|
||||
|
||||
if image_origin is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.image_origin = ?
|
||||
"""
|
||||
query_params.append(image_origin.value)
|
||||
# Build query conditions (reused for both starred count and image names queries)
|
||||
query_conditions = ""
|
||||
query_params: list[Union[int, str, bool]] = []
|
||||
|
||||
if categories is not None:
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
query_conditions += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
if image_origin is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.image_origin = ?
|
||||
"""
|
||||
query_params.append(image_origin.value)
|
||||
|
||||
if is_intermediate is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
query_params.append(is_intermediate)
|
||||
if categories is not None:
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
query_conditions += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
|
||||
if board_id == "none":
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
elif board_id is not None:
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
query_params.append(board_id)
|
||||
if is_intermediate is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
query_params.append(is_intermediate)
|
||||
|
||||
if search_term:
|
||||
query_conditions += """--sql
|
||||
AND (
|
||||
images.metadata LIKE ?
|
||||
OR images.created_at LIKE ?
|
||||
)
|
||||
"""
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
if board_id == "none":
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
elif board_id is not None:
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
query_params.append(board_id)
|
||||
|
||||
# Get starred count if starred_first is enabled
|
||||
starred_count = 0
|
||||
if starred_first:
|
||||
starred_count_query = f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE images.starred = TRUE AND (1=1{query_conditions})
|
||||
"""
|
||||
cursor.execute(starred_count_query, query_params)
|
||||
starred_count = cast(int, cursor.fetchone()[0])
|
||||
if search_term:
|
||||
query_conditions += """--sql
|
||||
AND (
|
||||
images.metadata LIKE ?
|
||||
OR images.created_at LIKE ?
|
||||
)
|
||||
"""
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
|
||||
# Get all image names with proper ordering
|
||||
if starred_first:
|
||||
names_query = f"""--sql
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1{query_conditions}
|
||||
ORDER BY images.starred DESC, images.created_at {order_dir.value}
|
||||
"""
|
||||
else:
|
||||
names_query = f"""--sql
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1{query_conditions}
|
||||
ORDER BY images.created_at {order_dir.value}
|
||||
"""
|
||||
# Get starred count if starred_first is enabled
|
||||
starred_count = 0
|
||||
if starred_first:
|
||||
starred_count_query = f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE images.starred = TRUE AND (1=1{query_conditions})
|
||||
"""
|
||||
cursor.execute(starred_count_query, query_params)
|
||||
starred_count = cast(int, cursor.fetchone()[0])
|
||||
|
||||
cursor.execute(names_query, query_params)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
# Get all image names with proper ordering
|
||||
if starred_first:
|
||||
names_query = f"""--sql
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1{query_conditions}
|
||||
ORDER BY images.starred DESC, images.created_at {order_dir.value}
|
||||
"""
|
||||
else:
|
||||
names_query = f"""--sql
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1{query_conditions}
|
||||
ORDER BY images.created_at {order_dir.value}
|
||||
"""
|
||||
|
||||
cursor.execute(names_query, query_params)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
image_names = [row[0] for row in result]
|
||||
|
||||
return ImageNamesResult(image_names=image_names, starred_count=starred_count, total_count=len(image_names))
|
||||
|
||||
@@ -78,6 +78,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
         self._db = db
         self._logger = logger
 
+    @property
+    def db(self) -> SqliteDatabase:
+        """Return the underlying database."""
+        return self._db
+
     def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
         """
         Add a model to the database.
@@ -88,33 +93,38 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
|
||||
"""
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO models (
|
||||
id,
|
||||
config
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
config.key,
|
||||
config.model_dump_json(),
|
||||
),
|
||||
)
|
||||
try:
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO models (
|
||||
id,
|
||||
config
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
config.key,
|
||||
config.model_dump_json(),
|
||||
),
|
||||
)
|
||||
self._db.conn.commit()
|
||||
|
||||
except sqlite3.IntegrityError as e:
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
if "models.path" in str(e):
|
||||
msg = f"A model with path '{config.path}' is already installed"
|
||||
elif "models.name" in str(e):
|
||||
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
|
||||
else:
|
||||
msg = f"A model with key '{config.key}' is already installed"
|
||||
raise DuplicateModelException(msg) from e
|
||||
except sqlite3.IntegrityError as e:
|
||||
self._db.conn.rollback()
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
if "models.path" in str(e):
|
||||
msg = f"A model with path '{config.path}' is already installed"
|
||||
elif "models.name" in str(e):
|
||||
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
|
||||
else:
|
||||
raise e
|
||||
msg = f"A model with key '{config.key}' is already installed"
|
||||
raise DuplicateModelException(msg) from e
|
||||
else:
|
||||
raise e
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_model(config.key)
|
||||
|
||||
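Both versions of `add_model` above translate a UNIQUE-constraint violation into a `DuplicateModelException` whose message depends on which column collided. Extracted into a standalone helper, the mapping looks roughly like this; the exception class is a stand-in for the real one, and `config` is only assumed to expose `path`, `name`, `type`, `base`, and `key`, as the diff suggests.

```python
import sqlite3
from typing import Any


class DuplicateModelException(Exception):
    """Stand-in for the exception type referenced in the diff."""


def raise_for_integrity_error(e: sqlite3.IntegrityError, config: Any) -> None:
    """Map a UNIQUE-constraint failure to a user-facing duplicate-model error."""
    if "UNIQUE constraint failed" not in str(e):
        raise e
    if "models.path" in str(e):
        msg = f"A model with path '{config.path}' is already installed"
    elif "models.name" in str(e):
        msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
    else:
        msg = f"A model with key '{config.key}' is already installed"
    raise DuplicateModelException(msg) from e
```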
@@ -126,7 +136,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
Can raise an UnknownModelException
|
||||
"""
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM models
|
||||
@@ -136,17 +147,22 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
)
|
||||
if cursor.rowcount == 0:
|
||||
raise UnknownModelException("model not found")
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
|
||||
with self._db.transaction() as cursor:
|
||||
record = self.get_model(key)
|
||||
record = self.get_model(key)
|
||||
|
||||
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
|
||||
for field_name in changes.model_fields_set:
|
||||
setattr(record, field_name, getattr(changes, field_name))
|
||||
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
|
||||
for field_name in changes.model_fields_set:
|
||||
setattr(record, field_name, getattr(changes, field_name))
|
||||
|
||||
json_serialized = record.model_dump_json()
|
||||
json_serialized = record.model_dump_json()
|
||||
|
||||
try:
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE models
|
||||
@@ -158,6 +174,10 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
)
|
||||
if cursor.rowcount == 0:
|
||||
raise UnknownModelException("model not found")
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_model(key)
|
||||
|
||||
@@ -169,30 +189,30 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
Exceptions: UnknownModelException
|
||||
"""
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
rows = cursor.fetchone()
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
rows = cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownModelException("model not found")
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||
return model
|
||||
|
||||
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
rows = cursor.fetchone()
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
rows = cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownModelException("model not found")
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||
@@ -204,15 +224,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
:param key: Unique key for the model to be deleted
|
||||
"""
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
select count(*) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
count = cursor.fetchone()[0]
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
select count(*) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
count = cursor.fetchone()[0]
|
||||
return count > 0
|
||||
|
||||
def search_by_attr(
|
||||
@@ -235,42 +255,43 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
If none of the optional filters are passed, will return all
|
||||
models in the database.
|
||||
"""
|
||||
with self._db.transaction() as cursor:
|
||||
assert isinstance(order_by, ModelRecordOrderBy)
|
||||
ordering = {
|
||||
ModelRecordOrderBy.Default: "type, base, name, format",
|
||||
ModelRecordOrderBy.Type: "type",
|
||||
ModelRecordOrderBy.Base: "base",
|
||||
ModelRecordOrderBy.Name: "name",
|
||||
ModelRecordOrderBy.Format: "format",
|
||||
}
|
||||
|
||||
where_clause: list[str] = []
|
||||
bindings: list[str] = []
|
||||
if model_name:
|
||||
where_clause.append("name=?")
|
||||
bindings.append(model_name)
|
||||
if base_model:
|
||||
where_clause.append("base=?")
|
||||
bindings.append(base_model)
|
||||
if model_type:
|
||||
where_clause.append("type=?")
|
||||
bindings.append(model_type)
|
||||
if model_format:
|
||||
where_clause.append("format=?")
|
||||
bindings.append(model_format)
|
||||
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
|
||||
assert isinstance(order_by, ModelRecordOrderBy)
|
||||
ordering = {
|
||||
ModelRecordOrderBy.Default: "type, base, name, format",
|
||||
ModelRecordOrderBy.Type: "type",
|
||||
ModelRecordOrderBy.Base: "base",
|
||||
ModelRecordOrderBy.Name: "name",
|
||||
ModelRecordOrderBy.Format: "format",
|
||||
}
|
||||
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config, strftime('%s',updated_at)
|
||||
FROM models
|
||||
{where}
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
|
||||
""",
|
||||
tuple(bindings),
|
||||
)
|
||||
result = cursor.fetchall()
|
||||
where_clause: list[str] = []
|
||||
bindings: list[str] = []
|
||||
if model_name:
|
||||
where_clause.append("name=?")
|
||||
bindings.append(model_name)
|
||||
if base_model:
|
||||
where_clause.append("base=?")
|
||||
bindings.append(base_model)
|
||||
if model_type:
|
||||
where_clause.append("type=?")
|
||||
bindings.append(model_type)
|
||||
if model_format:
|
||||
where_clause.append("format=?")
|
||||
bindings.append(model_format)
|
||||
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
|
||||
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config, strftime('%s',updated_at)
|
||||
FROM models
|
||||
{where}
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
|
||||
""",
|
||||
tuple(bindings),
|
||||
)
|
||||
result = cursor.fetchall()
|
||||
|
||||
# Parse the model configs.
|
||||
results: list[AnyModelConfig] = []
|
||||
@@ -292,68 +313,69 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):

def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
"""Return models with the indicated path."""
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
return results

def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
"""Return models with the indicated hash."""
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
return results

def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
) -> PaginatedResults[ModelSummary]:
"""Return a paginated summary listing of each model in the database."""
with self._db.transaction() as cursor:
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}

# Lock so that the database isn't updated while we're doing the two queries.
# query1: get the total number of model configs
cursor.execute(
"""--sql
select count(*) from models;
""",
(),
)
total = int(cursor.fetchone()[0])
cursor = self._db.conn.cursor()

# query2: fetch key fields
cursor.execute(
f"""--sql
SELECT config
FROM models
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
LIMIT ?
OFFSET ?;
""",
(
per_page,
page * per_page,
),
)
rows = cursor.fetchall()
# Lock so that the database isn't updated while we're doing the two queries.
# query1: get the total number of model configs
cursor.execute(
"""--sql
select count(*) from models;
""",
(),
)
total = int(cursor.fetchone()[0])

# query2: fetch key fields
cursor.execute(
f"""--sql
SELECT config
FROM models
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
LIMIT ?
OFFSET ?;
""",
(
per_page,
page * per_page,
),
)
rows = cursor.fetchall()
items = [ModelSummary.model_validate(dict(x)) for x in rows]
return PaginatedResults(page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items)

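list_models() above paginates with two queries: a COUNT(*) for the total and a LIMIT/OFFSET fetch for the requested page, with the page count derived as ceil(total / per_page). A minimal, self-contained sketch of that arithmetic follows; the table is a stand-in, not the real models schema.

```py
# Sketch of two-query offset pagination: count, then fetch one page.
import sqlite3
from math import ceil

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE models (name TEXT)")
conn.executemany("INSERT INTO models VALUES (?)", [(f"model-{i:02d}",) for i in range(25)])

def list_page(page: int = 0, per_page: int = 10) -> tuple[list[str], int, int]:
    total = conn.execute("SELECT count(*) FROM models").fetchone()[0]
    rows = conn.execute(
        "SELECT name FROM models ORDER BY name LIMIT ? OFFSET ?",
        (per_page, page * per_page),
    ).fetchall()
    return [r[0] for r in rows], total, ceil(total / per_page)

items, total, pages = list_page(page=2, per_page=10)
print(total, pages, len(items))  # 25 3 5 (the last page holds the remainder)
```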
@@ -1,3 +1,5 @@
import sqlite3

from invokeai.app.services.model_relationship_records.model_relationship_records_base import (
ModelRelationshipRecordStorageBase,
)
@@ -7,49 +9,58 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteModelRelationshipRecordStorage(ModelRelationshipRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
self._conn = db.conn

def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
with self._db.transaction() as cursor:
if model_key_1 == model_key_2:
raise ValueError("Cannot relate a model to itself.")
a, b = sorted([model_key_1, model_key_2])
if model_key_1 == model_key_2:
raise ValueError("Cannot relate a model to itself.")
a, b = sorted([model_key_1, model_key_2])
try:
cursor = self._conn.cursor()
cursor.execute(
"INSERT OR IGNORE INTO model_relationships (model_key_1, model_key_2) VALUES (?, ?)",
(a, b),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e

def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
with self._db.transaction() as cursor:
a, b = sorted([model_key_1, model_key_2])
a, b = sorted([model_key_1, model_key_2])
try:
cursor = self._conn.cursor()
cursor.execute(
"DELETE FROM model_relationships WHERE model_key_1 = ? AND model_key_2 = ?",
(a, b),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e

def get_related_model_keys(self, model_key: str) -> list[str]:
with self._db.transaction() as cursor:
cursor.execute(
"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ?
UNION
SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ?
""",
(model_key, model_key),
)
result = [row[0] for row in cursor.fetchall()]
return result
cursor = self._conn.cursor()
cursor.execute(
"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ?
UNION
SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ?
""",
(model_key, model_key),
)
return [row[0] for row in cursor.fetchall()]

def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
with self._db.transaction() as cursor:
key_list = ",".join("?" for _ in model_keys)
cursor.execute(
f"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 IN ({key_list})
UNION
SELECT model_key_1 FROM model_relationships WHERE model_key_2 IN ({key_list})
""",
model_keys + model_keys,
)
result = [row[0] for row in cursor.fetchall()]
return result
cursor = self._conn.cursor()

key_list = ",".join("?" for _ in model_keys)
cursor.execute(
f"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 IN ({key_list})
UNION
SELECT model_key_1 FROM model_relationships WHERE model_key_2 IN ({key_list})
""",
model_keys + model_keys,
)
return [row[0] for row in cursor.fetchall()]
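The relationship store above keeps each undirected pair once, with the two keys sorted so (x, y) and (y, x) map to the same row, and looks up neighbours by UNIONing both directions. A hedged sketch of just that idea, on an illustrative schema rather than the real migration, looks like this:

```py
# Sketch of undirected-pair storage: sort the keys on insert, UNION both columns on read.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE model_relationships (model_key_1 TEXT, model_key_2 TEXT, "
    "PRIMARY KEY (model_key_1, model_key_2))"
)

def add_relationship(key_1: str, key_2: str) -> None:
    if key_1 == key_2:
        raise ValueError("Cannot relate a model to itself.")
    a, b = sorted([key_1, key_2])  # normalize so (x, y) and (y, x) collide
    conn.execute("INSERT OR IGNORE INTO model_relationships VALUES (?, ?)", (a, b))

def related_keys(key: str) -> list[str]:
    rows = conn.execute(
        """
        SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ?
        UNION
        SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ?
        """,
        (key, key),
    ).fetchall()
    return [r[0] for r in rows]

add_relationship("sdxl-base", "sdxl-lora")
print(related_keys("sdxl-lora"))  # ['sdxl-base']
```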
|
||||
|
||||
@@ -50,14 +50,15 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._conn = db.conn
|
||||
|
||||
def _set_in_progress_to_canceled(self) -> None:
|
||||
"""
|
||||
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
|
||||
This is necessary because the invoker may have been killed while processing a queue item.
|
||||
"""
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE session_queue
|
||||
@@ -65,79 +66,87 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
WHERE status = 'in_progress';
|
||||
"""
|
||||
)
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
|
||||
def _get_current_queue_size(self, queue_id: str) -> int:
|
||||
"""Gets the current number of pending queue items"""
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
return count
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
return cast(int, cursor.fetchone()[0])
|
||||
|
||||
def _get_highest_priority(self, queue_id: str) -> int:
|
||||
"""Gets the highest priority value in the queue"""
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT MAX(priority)
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
priority = cast(Union[int, None], cursor.fetchone()[0]) or 0
|
||||
return priority
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT MAX(priority)
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
return cast(Union[int, None], cursor.fetchone()[0]) or 0
|
||||
|
||||
async def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
|
||||
current_queue_size = self._get_current_queue_size(queue_id)
|
||||
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||
max_new_queue_items = max_queue_size - current_queue_size
|
||||
try:
|
||||
# TODO: how does this work in a multi-user scenario?
|
||||
current_queue_size = self._get_current_queue_size(queue_id)
|
||||
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||
max_new_queue_items = max_queue_size - current_queue_size
|
||||
|
||||
priority = 0
|
||||
if prepend:
|
||||
priority = self._get_highest_priority(queue_id) + 1
|
||||
priority = 0
|
||||
if prepend:
|
||||
priority = self._get_highest_priority(queue_id) + 1
|
||||
|
||||
requested_count = await asyncio.to_thread(
|
||||
calc_session_count,
|
||||
batch=batch,
|
||||
)
|
||||
values_to_insert = await asyncio.to_thread(
|
||||
prepare_values_to_insert,
|
||||
queue_id=queue_id,
|
||||
batch=batch,
|
||||
priority=priority,
|
||||
max_new_queue_items=max_new_queue_items,
|
||||
)
|
||||
enqueued_count = len(values_to_insert)
|
||||
requested_count = await asyncio.to_thread(
|
||||
calc_session_count,
|
||||
batch=batch,
|
||||
)
|
||||
values_to_insert = await asyncio.to_thread(
|
||||
prepare_values_to_insert,
|
||||
queue_id=queue_id,
|
||||
batch=batch,
|
||||
priority=priority,
|
||||
max_new_queue_items=max_new_queue_items,
|
||||
)
|
||||
enqueued_count = len(values_to_insert)
|
||||
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.executemany(
|
||||
"""--sql
|
||||
with self._conn:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.executemany(
|
||||
"""--sql
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
values_to_insert,
|
||||
)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
values_to_insert,
|
||||
)
|
||||
with self._conn:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT item_id
|
||||
FROM session_queue
|
||||
WHERE batch_id = ?
|
||||
ORDER BY item_id DESC;
|
||||
""",
|
||||
(batch.batch_id,),
|
||||
)
|
||||
item_ids = [row[0] for row in cursor.fetchall()]
|
||||
(batch.batch_id,),
|
||||
)
|
||||
item_ids = [row[0] for row in cursor.fetchall()]
|
||||
except Exception:
|
||||
raise
|
||||
enqueue_result = EnqueueBatchResult(
|
||||
queue_id=queue_id,
|
||||
requested=requested_count,
|
||||
@@ -150,19 +159,19 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return enqueue_result
|
||||
|
||||
def dequeue(self) -> Optional[SessionQueueItem]:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE status = 'pending'
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
LIMIT 1
|
||||
"""
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE status = 'pending'
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
LIMIT 1
|
||||
"""
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
if result is None:
|
||||
return None
|
||||
queue_item = SessionQueueItem.queue_item_from_dict(dict(result))
|
||||
@@ -170,40 +179,40 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return queue_item
|
||||
|
||||
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
created_at ASC
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
created_at ASC
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
if result is None:
|
||||
return None
|
||||
return SessionQueueItem.queue_item_from_dict(dict(result))
|
||||
|
||||
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'in_progress'
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'in_progress'
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
if result is None:
|
||||
return None
|
||||
return SessionQueueItem.queue_item_from_dict(dict(result))
|
||||
@@ -216,7 +225,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
error_message: Optional[str] = None,
|
||||
error_traceback: Optional[str] = None,
|
||||
) -> SessionQueueItem:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status FROM session_queue WHERE item_id = ?
|
||||
@@ -224,15 +234,12 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
(item_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
|
||||
current_status = row[0]
|
||||
|
||||
# Only update if not already finished (completed, failed or canceled)
|
||||
if current_status in ("completed", "failed", "canceled"):
|
||||
return self.get_queue_item(item_id)
|
||||
|
||||
with self._db.transaction() as cursor:
|
||||
if row is None:
|
||||
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
|
||||
current_status = row[0]
|
||||
# Only update if not already finished (completed, failed or canceled)
|
||||
if current_status in ("completed", "failed", "canceled"):
|
||||
return self.get_queue_item(item_id)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE session_queue
|
||||
@@ -241,7 +248,10 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(status, error_type, error_message, error_traceback, item_id),
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
|
||||
@@ -249,34 +259,35 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return queue_item
|
||||
|
||||
def is_empty(self, queue_id: str) -> IsEmptyResult:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
is_empty = cast(int, cursor.fetchone()[0]) == 0
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
is_empty = cast(int, cursor.fetchone()[0]) == 0
|
||||
return IsEmptyResult(is_empty=is_empty)
|
||||
|
||||
def is_full(self, queue_id: str) -> IsFullResult:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||
is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||
is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
|
||||
return IsFullResult(is_full=is_full)
|
||||
|
||||
def clear(self, queue_id: str) -> ClearResult:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*)
|
||||
@@ -294,19 +305,24 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
self.__invoker.services.events.emit_queue_cleared(queue_id)
|
||||
return ClearResult(deleted=count)
|
||||
|
||||
def prune(self, queue_id: str) -> PruneResult:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND (
|
||||
queue_id = ?
|
||||
AND (
|
||||
status = 'completed'
|
||||
OR status = 'failed'
|
||||
OR status = 'canceled'
|
||||
)
|
||||
)
|
||||
"""
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
@@ -325,6 +341,10 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return PruneResult(deleted=count)
|
||||
|
||||
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
@@ -337,7 +357,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self.cancel_queue_item(item_id)
|
||||
except SessionQueueItemNotFoundError:
|
||||
pass
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE
|
||||
@@ -346,6 +367,10 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(item_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
|
||||
def complete_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
queue_item = self._set_queue_item_status(item_id=item_id, status="completed")
|
||||
@@ -368,7 +393,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return queue_item
|
||||
|
||||
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
placeholders = ", ".join(["?" for _ in batch_ids])
|
||||
where = f"""--sql
|
||||
@@ -399,14 +425,17 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
|
||||
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
|
||||
self._conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return CancelByBatchIDsResult(canceled=count)
|
||||
|
||||
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
where = """--sql
|
||||
WHERE
|
||||
@@ -436,12 +465,17 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
params,
|
||||
)
|
||||
if current_queue_item is not None and current_queue_item.destination == destination:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
self._conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.destination == destination:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return CancelByDestinationResult(canceled=count)
|
||||
|
||||
def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
if current_queue_item is not None and current_queue_item.destination == destination:
|
||||
self.cancel_queue_item(current_queue_item.item_id)
|
||||
@@ -467,10 +501,15 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
params,
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return DeleteByDestinationResult(deleted=count)
|
||||
|
||||
def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id == ?
|
||||
@@ -493,10 +532,15 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return DeleteAllExceptCurrentResult(deleted=count)
|
||||
|
||||
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
where = """--sql
|
||||
WHERE
|
||||
@@ -525,13 +569,18 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return CancelByQueueIDResult(canceled=count)
|
||||
|
||||
def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id == ?
|
||||
@@ -554,25 +603,30 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return CancelAllExceptCurrentResult(canceled=count)
|
||||
|
||||
def get_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT * FROM session_queue
|
||||
WHERE
|
||||
item_id = ?
|
||||
""",
|
||||
(item_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT * FROM session_queue
|
||||
WHERE
|
||||
item_id = ?
|
||||
""",
|
||||
(item_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
if result is None:
|
||||
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
|
||||
return SessionQueueItem.queue_item_from_dict(dict(result))
|
||||
|
||||
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
|
||||
# when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
|
||||
# during execution.
|
||||
@@ -585,6 +639,10 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(session_json, item_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return self.get_queue_item(item_id)
|
||||
|
||||
def list_queue_items(
|
||||
@@ -596,42 +654,42 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
status: Optional[QUEUE_ITEM_STATUS] = None,
|
||||
destination: Optional[str] = None,
|
||||
) -> CursorPaginatedResults[SessionQueueItem]:
|
||||
with self._db.transaction() as cursor_:
|
||||
item_id = cursor
|
||||
query = """--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
params: list[Union[str, int]] = [queue_id]
|
||||
|
||||
if status is not None:
|
||||
query += """--sql
|
||||
AND status = ?
|
||||
"""
|
||||
params.append(status)
|
||||
|
||||
if destination is not None:
|
||||
query += """---sql
|
||||
AND destination = ?
|
||||
"""
|
||||
params.append(destination)
|
||||
|
||||
if item_id is not None:
|
||||
query += """--sql
|
||||
AND (priority < ?) OR (priority = ? AND item_id > ?)
|
||||
"""
|
||||
params.extend([priority, priority, item_id])
|
||||
cursor_ = self._conn.cursor()
|
||||
item_id = cursor
|
||||
query = """--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
params: list[Union[str, int]] = [queue_id]
|
||||
|
||||
if status is not None:
|
||||
query += """--sql
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
LIMIT ?
|
||||
AND status = ?
|
||||
"""
|
||||
params.append(limit + 1)
|
||||
cursor_.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], cursor_.fetchall())
|
||||
params.append(status)
|
||||
|
||||
if destination is not None:
|
||||
query += """---sql
|
||||
AND destination = ?
|
||||
"""
|
||||
params.append(destination)
|
||||
|
||||
if item_id is not None:
|
||||
query += """--sql
|
||||
AND (priority < ?) OR (priority = ? AND item_id > ?)
|
||||
"""
|
||||
params.extend([priority, priority, item_id])
|
||||
|
||||
query += """--sql
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
LIMIT ?
|
||||
"""
|
||||
params.append(limit + 1)
|
||||
cursor_.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], cursor_.fetchall())
|
||||
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
|
||||
has_more = False
|
||||
if len(items) > limit:
|
||||
@@ -646,43 +704,43 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
destination: Optional[str] = None,
|
||||
) -> list[SessionQueueItem]:
|
||||
"""Gets all queue items that match the given parameters"""
|
||||
with self._db.transaction() as cursor:
|
||||
query = """--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
cursor_ = self._conn.cursor()
|
||||
query = """--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
params: list[Union[str, int]] = [queue_id]
|
||||
|
||||
if destination is not None:
|
||||
query += """---sql
|
||||
AND destination = ?
|
||||
"""
|
||||
params: list[Union[str, int]] = [queue_id]
|
||||
params.append(destination)
|
||||
|
||||
if destination is not None:
|
||||
query += """---sql
|
||||
AND destination = ?
|
||||
"""
|
||||
params.append(destination)
|
||||
|
||||
query += """--sql
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
;
|
||||
"""
|
||||
cursor.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
query += """--sql
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
;
|
||||
"""
|
||||
cursor_.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], cursor_.fetchall())
|
||||
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
|
||||
return items
|
||||
|
||||
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
|
||||
current_item = self.get_current(queue_id=queue_id)
|
||||
total = sum(row[1] or 0 for row in counts_result)
|
||||
@@ -701,19 +759,19 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
|
||||
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*), origin, destination
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND batch_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, batch_id),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*), origin, destination
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND batch_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, batch_id),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
total = sum(row[1] or 0 for row in result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in result}
|
||||
origin = result[0]["origin"] if result else None
|
||||
@@ -733,18 +791,18 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
|
||||
def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
AND destination = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, destination),
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
AND destination = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, destination),
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
|
||||
total = sum(row[1] or 0 for row in counts_result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
|
||||
@@ -762,7 +820,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult:
|
||||
"""Retries the given queue items"""
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
values_to_insert: list[ValueToInsertTuple] = []
|
||||
retried_item_ids: list[int] = []
|
||||
|
||||
@@ -813,6 +872,10 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
values_to_insert,
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
retry_result = RetryItemsResult(
|
||||
queue_id=queue_id,
|
||||
retried_item_ids=retried_item_ids,
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
|
||||
@@ -29,65 +26,46 @@ class SqliteDatabase:
|
||||
|
||||
def __init__(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
|
||||
"""Initializes the database. This is used internally by the class constructor."""
|
||||
self._logger = logger
|
||||
self._db_path = db_path
|
||||
self._verbose = verbose
|
||||
self._lock = threading.RLock()
|
||||
self.logger = logger
|
||||
self.db_path = db_path
|
||||
self.verbose = verbose
|
||||
|
||||
if not self._db_path:
|
||||
if not self.db_path:
|
||||
logger.info("Initializing in-memory database")
|
||||
else:
|
||||
self._db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._logger.info(f"Initializing database at {self._db_path}")
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.logger.info(f"Initializing database at {self.db_path}")
|
||||
|
||||
self._conn = sqlite3.connect(database=self._db_path or sqlite_memory, check_same_thread=False)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self.conn = sqlite3.connect(database=self.db_path or sqlite_memory, check_same_thread=False)
|
||||
self.conn.row_factory = sqlite3.Row
|
||||
|
||||
if self._verbose:
|
||||
self._conn.set_trace_callback(self._logger.debug)
|
||||
if self.verbose:
|
||||
self.conn.set_trace_callback(self.logger.debug)
|
||||
|
||||
# Enable foreign key constraints
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self.conn.execute("PRAGMA foreign_keys = ON;")
|
||||
|
||||
# Enable Write-Ahead Logging (WAL) mode for better concurrency
|
||||
self._conn.execute("PRAGMA journal_mode = WAL;")
|
||||
self.conn.execute("PRAGMA journal_mode = WAL;")
|
||||
|
||||
# Set a busy timeout to prevent database lockups during writes
|
||||
self._conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
|
||||
self.conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
|
||||
|
||||
def clean(self) -> None:
|
||||
"""
|
||||
Cleans the database by running the VACUUM command, reporting on the freed space.
|
||||
"""
|
||||
# No need to clean in-memory database
|
||||
if not self._db_path:
|
||||
if not self.db_path:
|
||||
return
|
||||
try:
|
||||
with self._conn as conn:
|
||||
initial_db_size = Path(self._db_path).stat().st_size
|
||||
conn.execute("VACUUM;")
|
||||
conn.commit()
|
||||
final_db_size = Path(self._db_path).stat().st_size
|
||||
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
|
||||
if freed_space_in_mb > 0:
|
||||
self._logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
|
||||
initial_db_size = Path(self.db_path).stat().st_size
|
||||
self.conn.execute("VACUUM;")
|
||||
self.conn.commit()
|
||||
final_db_size = Path(self.db_path).stat().st_size
|
||||
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
|
||||
if freed_space_in_mb > 0:
|
||||
self.logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
|
||||
except Exception as e:
|
||||
self._logger.error(f"Error cleaning database: {e}")
|
||||
self.logger.error(f"Error cleaning database: {e}")
|
||||
raise
|
||||
|
||||
@contextmanager
|
||||
def transaction(self) -> Generator[sqlite3.Cursor, None, None]:
|
||||
"""
|
||||
Thread-safe context manager for DB work.
|
||||
Acquires the RLock, yields a Cursor, then commits or rolls back.
|
||||
"""
|
||||
with self._lock:
|
||||
cursor = self._conn.cursor()
|
||||
try:
|
||||
yield cursor
|
||||
self._conn.commit()
|
||||
except:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
@@ -32,7 +32,7 @@ class SqliteMigrator:
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
self._db = db
|
||||
self._logger = db._logger
|
||||
self._logger = db.logger
|
||||
self._migration_set = MigrationSet()
|
||||
self._backup_path: Optional[Path] = None
|
||||
|
||||
@@ -45,7 +45,7 @@ class SqliteMigrator:
|
||||
"""Migrates the database to the latest version."""
|
||||
# This throws if there is a problem.
|
||||
self._migration_set.validate_migration_chain()
|
||||
cursor = self._db._conn.cursor()
|
||||
cursor = self._db.conn.cursor()
|
||||
self._create_migrations_table(cursor=cursor)
|
||||
|
||||
if self._migration_set.count == 0:
|
||||
@@ -59,13 +59,13 @@ class SqliteMigrator:
|
||||
self._logger.info("Database update needed")
|
||||
|
||||
# Make a backup of the db if it needs to be updated and is a file db
|
||||
if self._db._db_path is not None:
|
||||
if self._db.db_path is not None:
|
||||
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
self._backup_path = self._db._db_path.parent / f"{self._db._db_path.stem}_backup_{timestamp}.db"
|
||||
self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
|
||||
self._logger.info(f"Backing up database to {str(self._backup_path)}")
|
||||
# Use SQLite to do the backup
|
||||
with closing(sqlite3.connect(self._backup_path)) as backup_conn:
|
||||
self._db._conn.backup(backup_conn)
|
||||
self._db.conn.backup(backup_conn)
|
||||
else:
|
||||
self._logger.info("Using in-memory database, no backup needed")
|
||||
|
||||
@@ -81,7 +81,7 @@ class SqliteMigrator:
|
||||
try:
|
||||
# Using sqlite3.Connection as a context manager commits the transaction on exit, or rolls it back if an
|
||||
# exception is raised.
|
||||
with self._db._conn as conn:
|
||||
with self._db.conn as conn:
|
||||
cursor = conn.cursor()
|
||||
if self._get_current_version(cursor) != migration.from_version:
|
||||
raise MigrationError(
|
||||
|
||||
@@ -17,7 +17,7 @@ from invokeai.app.util.misc import uuid_string
|
||||
class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._conn = db.conn
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
@@ -25,23 +25,24 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
|
||||
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
|
||||
"""Gets a style preset by ID."""
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM style_presets
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(style_preset_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM style_presets
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(style_preset_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
|
||||
return StylePresetRecordDTO.from_dict(dict(row))
|
||||
|
||||
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
|
||||
style_preset_id = uuid_string()
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO style_presets (
|
||||
@@ -59,11 +60,16 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
style_preset.type,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return self.get(style_preset_id)
|
||||
|
||||
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
|
||||
style_preset_ids = []
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
for style_preset in style_presets:
|
||||
style_preset_id = uuid_string()
|
||||
style_preset_ids.append(style_preset_id)
|
||||
@@ -84,11 +90,16 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
style_preset.type,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
|
||||
return None
|
||||
|
||||
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
# Change the name of a style preset
|
||||
if changes.name is not None:
|
||||
cursor.execute(
|
||||
@@ -111,10 +122,15 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
(changes.preset_data.model_dump_json(), style_preset_id),
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return self.get(style_preset_id)
|
||||
|
||||
def delete(self, style_preset_id: str) -> None:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE from style_presets
|
||||
@@ -122,41 +138,51 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
""",
|
||||
(style_preset_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return None
|
||||
|
||||
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
|
||||
with self._db.transaction() as cursor:
|
||||
main_query = """
|
||||
SELECT
|
||||
*
|
||||
FROM style_presets
|
||||
"""
|
||||
main_query = """
|
||||
SELECT
|
||||
*
|
||||
FROM style_presets
|
||||
"""
|
||||
|
||||
if type is not None:
|
||||
main_query += "WHERE type = ? "
|
||||
if type is not None:
|
||||
main_query += "WHERE type = ? "
|
||||
|
||||
main_query += "ORDER BY LOWER(name) ASC"
|
||||
main_query += "ORDER BY LOWER(name) ASC"
|
||||
|
||||
if type is not None:
|
||||
cursor.execute(main_query, (type,))
|
||||
else:
|
||||
cursor.execute(main_query)
|
||||
cursor = self._conn.cursor()
|
||||
if type is not None:
|
||||
cursor.execute(main_query, (type,))
|
||||
else:
|
||||
cursor.execute(main_query)
|
||||
|
||||
rows = cursor.fetchall()
|
||||
rows = cursor.fetchall()
|
||||
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
|
||||
|
||||
return style_presets
|
||||
|
||||
def _sync_default_style_presets(self) -> None:
|
||||
"""Syncs default style presets to the database. Internal use only."""
|
||||
with self._db.transaction() as cursor:
|
||||
# First delete all existing default style presets
|
||||
|
||||
# First delete all existing default style presets
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM style_presets
|
||||
WHERE type = "default";
|
||||
"""
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
# Next, parse and create the default style presets
|
||||
with open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
|
||||
presets = json.load(file)
|
||||
|
||||
@@ -25,7 +25,7 @@ SQL_TIME_FORMAT = "%Y-%m-%d %H:%M:%f"
|
||||
class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._conn = db.conn
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
@@ -33,16 +33,16 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
|
||||
def get(self, workflow_id: str) -> WorkflowRecordDTO:
|
||||
"""Gets a workflow by ID. Updates the opened_at column."""
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
|
||||
FROM workflow_library
|
||||
WHERE workflow_id = ?;
|
||||
""",
|
||||
(workflow_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
|
||||
FROM workflow_library
|
||||
WHERE workflow_id = ?;
|
||||
""",
|
||||
(workflow_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
|
||||
return WorkflowRecordDTO.from_dict(dict(row))
|
||||
@@ -51,8 +51,9 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
if workflow.meta.category is WorkflowCategory.Default:
|
||||
raise ValueError("Default workflows cannot be created via this method")
|
||||
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
workflow_with_id = Workflow(**workflow.model_dump(), id=uuid_string())
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO workflow_library (
|
||||
@@ -63,13 +64,18 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
""",
|
||||
(workflow_with_id.id, workflow_with_id.model_dump_json()),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return self.get(workflow_with_id.id)
|
||||
|
||||
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
|
||||
if workflow.meta.category is WorkflowCategory.Default:
|
||||
raise ValueError("Default workflows cannot be updated")
|
||||
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE workflow_library
|
||||
@@ -78,13 +84,18 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
""",
|
||||
(workflow.model_dump_json(), workflow.id),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return self.get(workflow.id)
|
||||
|
||||
def delete(self, workflow_id: str) -> None:
|
||||
if self.get(workflow_id).workflow.meta.category is WorkflowCategory.Default:
|
||||
raise ValueError("Default workflows cannot be deleted")
|
||||
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE from workflow_library
|
||||
@@ -92,6 +103,10 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
""",
|
||||
(workflow_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return None
|
||||
|
||||
def get_many(
|
||||
@@ -106,108 +121,108 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
has_been_opened: Optional[bool] = None,
|
||||
is_published: Optional[bool] = None,
|
||||
) -> PaginatedResults[WorkflowRecordListItemDTO]:
|
||||
with self._db.transaction() as cursor:
|
||||
# sanitize!
|
||||
assert order_by in WorkflowRecordOrderBy
|
||||
assert direction in SQLiteDirection
|
||||
# sanitize!
|
||||
assert order_by in WorkflowRecordOrderBy
|
||||
assert direction in SQLiteDirection
|
||||
|
||||
# We will construct the query dynamically based on the query params
|
||||
# We will construct the query dynamically based on the query params
|
||||
|
||||
# The main query to get the workflows / counts
|
||||
main_query = """
|
||||
SELECT
|
||||
workflow_id,
|
||||
category,
|
||||
name,
|
||||
description,
|
||||
created_at,
|
||||
updated_at,
|
||||
opened_at,
|
||||
tags
|
||||
FROM workflow_library
|
||||
"""
|
||||
count_query = "SELECT COUNT(*) FROM workflow_library"
|
||||
# The main query to get the workflows / counts
|
||||
main_query = """
|
||||
SELECT
|
||||
workflow_id,
|
||||
category,
|
||||
name,
|
||||
description,
|
||||
created_at,
|
||||
updated_at,
|
||||
opened_at,
|
||||
tags
|
||||
FROM workflow_library
|
||||
"""
|
||||
count_query = "SELECT COUNT(*) FROM workflow_library"
|
||||
|
||||
# Start with an empty list of conditions and params
|
||||
conditions: list[str] = []
|
||||
params: list[str | int] = []
|
||||
# Start with an empty list of conditions and params
|
||||
conditions: list[str] = []
|
||||
params: list[str | int] = []
|
||||
|
||||
if categories:
|
||||
# Categories is a list of WorkflowCategory enum values, and a single string in the DB
|
||||
if categories:
|
||||
# Categories is a list of WorkflowCategory enum values, and a single string in the DB
|
||||
|
||||
# Ensure all categories are valid (is this necessary?)
|
||||
assert all(c in WorkflowCategory for c in categories)
|
||||
# Ensure all categories are valid (is this necessary?)
|
||||
assert all(c in WorkflowCategory for c in categories)
|
||||
|
||||
# Construct a placeholder string for the number of categories
|
||||
placeholders = ", ".join("?" for _ in categories)
|
||||
# Construct a placeholder string for the number of categories
|
||||
placeholders = ", ".join("?" for _ in categories)
|
||||
|
||||
# Construct the condition string & params
|
||||
category_condition = f"category IN ({placeholders})"
|
||||
category_params = [category.value for category in categories]
|
||||
# Construct the condition string & params
|
||||
category_condition = f"category IN ({placeholders})"
|
||||
category_params = [category.value for category in categories]
|
||||
|
||||
conditions.append(category_condition)
|
||||
params.extend(category_params)
|
||||
conditions.append(category_condition)
|
||||
params.extend(category_params)
|
||||
|
||||
if tags:
|
||||
# Tags is a list of strings, and a single string in the DB
|
||||
# The string in the DB has no guaranteed format
|
||||
if tags:
|
||||
# Tags is a list of strings, and a single string in the DB
|
||||
# The string in the DB has no guaranteed format
|
||||
|
||||
# Construct a list of conditions for each tag
|
||||
tags_conditions = ["tags LIKE ?" for _ in tags]
|
||||
tags_conditions_joined = " OR ".join(tags_conditions)
|
||||
tags_condition = f"({tags_conditions_joined})"
|
||||
# Construct a list of conditions for each tag
|
||||
tags_conditions = ["tags LIKE ?" for _ in tags]
|
||||
tags_conditions_joined = " OR ".join(tags_conditions)
|
||||
tags_condition = f"({tags_conditions_joined})"
|
||||
|
||||
# And the params for the tags, case-insensitive
|
||||
tags_params = [f"%{t.strip()}%" for t in tags]
|
||||
# And the params for the tags, case-insensitive
|
||||
tags_params = [f"%{t.strip()}%" for t in tags]
|
||||
|
||||
conditions.append(tags_condition)
|
||||
params.extend(tags_params)
|
||||
conditions.append(tags_condition)
|
||||
params.extend(tags_params)
|
||||
|
||||
if has_been_opened:
|
||||
conditions.append("opened_at IS NOT NULL")
|
||||
elif has_been_opened is False:
|
||||
conditions.append("opened_at IS NULL")
|
||||
if has_been_opened:
|
||||
conditions.append("opened_at IS NOT NULL")
|
||||
elif has_been_opened is False:
|
||||
conditions.append("opened_at IS NULL")
|
||||
|
||||
# Ignore whitespace in the query
|
||||
stripped_query = query.strip() if query else None
|
||||
if stripped_query:
|
||||
# Construct a wildcard query for the name, description, and tags
|
||||
wildcard_query = "%" + stripped_query + "%"
|
||||
query_condition = "(name LIKE ? OR description LIKE ? OR tags LIKE ?)"
|
||||
# Ignore whitespace in the query
|
||||
stripped_query = query.strip() if query else None
|
||||
if stripped_query:
|
||||
# Construct a wildcard query for the name, description, and tags
|
||||
wildcard_query = "%" + stripped_query + "%"
|
||||
query_condition = "(name LIKE ? OR description LIKE ? OR tags LIKE ?)"
|
||||
|
||||
conditions.append(query_condition)
|
||||
params.extend([wildcard_query, wildcard_query, wildcard_query])
|
||||
conditions.append(query_condition)
|
||||
params.extend([wildcard_query, wildcard_query, wildcard_query])
|
||||
|
||||
if conditions:
|
||||
# If there are conditions, add a WHERE clause and then join the conditions
|
||||
main_query += " WHERE "
|
||||
count_query += " WHERE "
|
||||
if conditions:
|
||||
# If there are conditions, add a WHERE clause and then join the conditions
|
||||
main_query += " WHERE "
|
||||
count_query += " WHERE "
|
||||
|
||||
all_conditions = " AND ".join(conditions)
|
||||
main_query += all_conditions
|
||||
count_query += all_conditions
|
||||
all_conditions = " AND ".join(conditions)
|
||||
main_query += all_conditions
|
||||
count_query += all_conditions
|
||||
|
||||
# After this point, the query and params differ for the main query and the count query
|
||||
main_params = params.copy()
|
||||
count_params = params.copy()
|
||||
# After this point, the query and params differ for the main query and the count query
|
||||
main_params = params.copy()
|
||||
count_params = params.copy()
|
||||
|
||||
# Main query also gets ORDER BY and LIMIT/OFFSET
|
||||
main_query += f" ORDER BY {order_by.value} {direction.value}"
|
||||
# Main query also gets ORDER BY and LIMIT/OFFSET
|
||||
main_query += f" ORDER BY {order_by.value} {direction.value}"
|
||||
|
||||
if per_page:
|
||||
main_query += " LIMIT ? OFFSET ?"
|
||||
main_params.extend([per_page, page * per_page])
|
||||
if per_page:
|
||||
main_query += " LIMIT ? OFFSET ?"
|
||||
main_params.extend([per_page, page * per_page])
|
||||
|
||||
# Put a ring on it
|
||||
main_query += ";"
|
||||
count_query += ";"
|
||||
# Put a ring on it
|
||||
main_query += ";"
|
||||
count_query += ";"
|
||||
|
||||
cursor.execute(main_query, main_params)
|
||||
rows = cursor.fetchall()
|
||||
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(main_query, main_params)
|
||||
rows = cursor.fetchall()
|
||||
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
|
||||
|
||||
cursor.execute(count_query, count_params)
|
||||
total = cursor.fetchone()[0]
|
||||
cursor.execute(count_query, count_params)
|
||||
total = cursor.fetchone()[0]
|
||||
|
||||
if per_page:
|
||||
pages = total // per_page + (total % per_page > 0)
|
||||
@@ -232,46 +247,46 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
if not tags:
return {}

with self._db.transaction() as cursor:
cursor = self._conn.cursor()
result: dict[str, int] = {}
# Base conditions for categories and selected tags
base_conditions: list[str] = []
base_params: list[str | int] = []

# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])

if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")

# For each tag to count, run a separate query
for tag in tags:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()

# Add this specific tag condition
conditions.append("tags LIKE ?")
params.append(f"%{tag.strip()}%")

# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""

if conditions:
stmt += " WHERE " + " AND ".join(conditions)

cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[tag] = count

return result
@@ -281,51 +296,52 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> dict[str, int]:
with self._db.transaction() as cursor:
cursor = self._conn.cursor()
result: dict[str, int] = {}
# Base conditions for categories
base_conditions: list[str] = []
base_params: list[str | int] = []

# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])

if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")

# For each category to count, run a separate query
for category in categories:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()

# Add this specific category condition
conditions.append("category = ?")
params.append(category.value)

# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""

if conditions:
stmt += " WHERE " + " AND ".join(conditions)

cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[category.value] = count

return result
def update_opened_at(self, workflow_id: str) -> None:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
cursor.execute(
f"""--sql
UPDATE workflow_library
@@ -334,6 +350,10 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise

def _sync_default_workflows(self) -> None:
"""Syncs default workflows to the database. Internal use only."""
@@ -348,7 +368,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
meaningless, as they are overwritten every time the server starts.
"""

with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
workflows_from_file: list[Workflow] = []
workflows_to_update: list[Workflow] = []
workflows_to_add: list[Workflow] = []
@@ -428,3 +449,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(w.model_dump_json(), w.id),
)

self._conn.commit()
except Exception:
self._conn.rollback()
raise
@@ -187,7 +187,7 @@ class ModelConfigBase(ABC, BaseModel):
else:
return config_cls.from_model_on_disk(mod, **overrides)

raise InvalidModelConfigException("Unable to determine model type")
raise InvalidModelConfigException("No valid config found")

@classmethod
def get_tag(cls) -> Tag:
10  invokeai/frontend/web/.eslintignore  Normal file
@@ -0,0 +1,10 @@
dist/
static/
.husky/
node_modules/
patches/
stats.html
index.html
.yarn/
*.scss
src/services/api/schema.ts
88  invokeai/frontend/web/.eslintrc.js  Normal file
@@ -0,0 +1,88 @@
|
||||
module.exports = {
|
||||
extends: ['@invoke-ai/eslint-config-react'],
|
||||
plugins: ['path', 'i18next'],
|
||||
rules: {
|
||||
// TODO(psyche): Enable this rule. Requires no default exports in components - many changes.
|
||||
'react-refresh/only-export-components': 'off',
|
||||
// TODO(psyche): Enable this rule. Requires a lot of eslint-disable-next-line comments.
|
||||
'@typescript-eslint/consistent-type-assertions': 'off',
|
||||
// https://github.com/qdanik/eslint-plugin-path
|
||||
'path/no-relative-imports': ['error', { maxDepth: 0 }],
|
||||
// https://github.com/edvardchen/eslint-plugin-i18next/blob/HEAD/docs/rules/no-literal-string.md
|
||||
// TODO: ENABLE THIS RULE BEFORE v6.0.0
|
||||
// 'i18next/no-literal-string': 'error',
|
||||
// https://eslint.org/docs/latest/rules/no-console
|
||||
'no-console': 'warn',
|
||||
// https://eslint.org/docs/latest/rules/no-promise-executor-return
|
||||
'no-promise-executor-return': 'error',
|
||||
// https://eslint.org/docs/latest/rules/require-await
|
||||
'require-await': 'error',
|
||||
// Restrict setActiveTab calls to only use-navigation-api.tsx
|
||||
'no-restricted-syntax': [
|
||||
'error',
|
||||
{
|
||||
selector: 'CallExpression[callee.name="setActiveTab"]',
|
||||
message:
|
||||
'setActiveTab() can only be called from use-navigation-api.tsx. Use navigationApi.switchToTab() instead.',
|
||||
},
|
||||
],
|
||||
// TODO: ENABLE THIS RULE BEFORE v6.0.0
|
||||
'react/display-name': 'off',
|
||||
'no-restricted-properties': [
|
||||
'error',
|
||||
{
|
||||
object: 'crypto',
|
||||
property: 'randomUUID',
|
||||
message: 'Use of crypto.randomUUID is not allowed as it is not available in all browsers.',
|
||||
},
|
||||
{
|
||||
object: 'navigator',
|
||||
property: 'clipboard',
|
||||
message:
|
||||
'The Clipboard API is not available by default in Firefox. Use the `useClipboard` hook instead, which wraps clipboard access to prevent errors.',
|
||||
},
|
||||
],
|
||||
'no-restricted-imports': [
|
||||
'error',
|
||||
{
|
||||
paths: [
|
||||
{
|
||||
name: 'lodash-es',
|
||||
importNames: ['isEqual'],
|
||||
message: 'Please use objectEquals from @observ33r/object-equals instead.',
|
||||
},
|
||||
{
|
||||
name: 'lodash-es',
|
||||
message: 'Please use es-toolkit instead.',
|
||||
},
|
||||
{
|
||||
name: 'es-toolkit',
|
||||
importNames: ['isEqual'],
|
||||
message: 'Please use objectEquals from @observ33r/object-equals instead.',
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
overrides: [
|
||||
/**
|
||||
* Allow setActiveTab calls only in use-navigation-api.tsx
|
||||
*/
|
||||
{
|
||||
files: ['**/use-navigation-api.tsx'],
|
||||
rules: {
|
||||
'no-restricted-syntax': 'off',
|
||||
},
|
||||
},
|
||||
/**
|
||||
* Overrides for stories
|
||||
*/
|
||||
{
|
||||
files: ['*.stories.tsx'],
|
||||
rules: {
|
||||
// We may not have i18n available in stories.
|
||||
'i18next/no-literal-string': 'off',
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
@@ -14,4 +14,3 @@ static/
src/theme/css/overlayscrollbars.css
src/theme_/css/overlayscrollbars.css
pnpm-lock.yaml
.claude
11  invokeai/frontend/web/.prettierrc.js  Normal file
@@ -0,0 +1,11 @@
module.exports = {
  ...require('@invoke-ai/prettier-config-react'),
  overrides: [
    {
      files: ['public/locales/*.json'],
      options: {
        tabWidth: 4,
      },
    },
  ],
};
@@ -1,17 +0,0 @@
{
  "$schema": "http://json.schemastore.org/prettierrc",
  "trailingComma": "es5",
  "printWidth": 120,
  "tabWidth": 2,
  "semi": true,
  "singleQuote": true,
  "endOfLine": "auto",
  "overrides": [
    {
      "files": ["public/locales/*.json"],
      "options": {
        "tabWidth": 4
      }
    }
  ]
}
@@ -1,23 +1,21 @@
import { useGlobalModifiersInit } from '@invoke-ai/ui-library';
import type { PropsWithChildren } from 'react';
import { memo, useEffect } from 'react';

import { useAppDispatch } from '../src/app/store/storeHooks';
import { PropsWithChildren, memo, useEffect } from 'react';
import { modelChanged } from '../src/features/controlLayers/store/paramsSlice';
import { useAppDispatch } from '../src/app/store/storeHooks';
import { useGlobalModifiersInit } from '@invoke-ai/ui-library';
/**
 * Initializes some state for storybook. Must be in a different component
 * so that it is run inside the redux context.
 */
export const ReduxInit = memo(({ children }: PropsWithChildren) => {
export const ReduxInit = memo((props: PropsWithChildren) => {
  const dispatch = useAppDispatch();
  useGlobalModifiersInit();
  useEffect(() => {
    dispatch(
      modelChanged({ model: { key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' } })
    );
  }, [dispatch]);
  }, []);

  return children;
  return props.children;
});

ReduxInit.displayName = 'ReduxInit';
@@ -2,13 +2,19 @@ import type { StorybookConfig } from '@storybook/react-vite';

const config: StorybookConfig = {
  stories: ['../src/**/*.mdx', '../src/**/*.stories.@(js|jsx|mjs|ts|tsx)'],
  addons: ['@storybook/addon-links', '@storybook/addon-docs'],
  addons: [
    '@storybook/addon-links',
    '@storybook/addon-essentials',
    '@storybook/addon-interactions',
    '@storybook/addon-storysource',
  ],
  framework: {
    name: '@storybook/react-vite',
    options: {},
  },
  docs: {
    autodocs: 'tag',
  },
  core: {
    disableTelemetry: true,
  },
@@ -1,5 +1,5 @@
import { addons } from 'storybook/manager-api';
import { themes } from 'storybook/theming';
import { addons } from '@storybook/manager-api';
import { themes } from '@storybook/theming';

addons.setConfig({
  theme: themes.dark,
@@ -1,18 +1,17 @@
|
||||
import type { Preview } from '@storybook/react-vite';
|
||||
import { themes } from 'storybook/theming';
|
||||
import { $store } from 'app/store/nanostores/store';
|
||||
import { Preview } from '@storybook/react';
|
||||
import { themes } from '@storybook/theming';
|
||||
import i18n from 'i18next';
|
||||
import { initReactI18next } from 'react-i18next';
|
||||
import { Provider } from 'react-redux';
|
||||
|
||||
import ThemeLocaleProvider from '../src/app/components/ThemeLocaleProvider';
|
||||
import { $baseUrl } from '../src/app/store/nanostores/baseUrl';
|
||||
import { createStore } from '../src/app/store/store';
|
||||
// TODO: Disabled for IDE performance issues with our translation JSON
|
||||
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
|
||||
// @ts-ignore
|
||||
import translationEN from '../public/locales/en.json';
|
||||
import ThemeLocaleProvider from '../src/app/components/ThemeLocaleProvider';
|
||||
import { $baseUrl } from '../src/app/store/nanostores/baseUrl';
|
||||
import { createStore } from '../src/app/store/store';
|
||||
import { ReduxInit } from './ReduxInit';
|
||||
import { $store } from 'app/store/nanostores/store';
|
||||
|
||||
i18n.use(initReactI18next).init({
|
||||
lng: 'en',
|
||||
@@ -47,7 +46,6 @@ const preview: Preview = {
|
||||
parameters: {
|
||||
docs: {
|
||||
theme: themes.dark,
|
||||
codePanel: true,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
@@ -1,242 +0,0 @@
|
||||
import js from '@eslint/js';
|
||||
import typescriptEslint from '@typescript-eslint/eslint-plugin';
|
||||
import typescriptParser from '@typescript-eslint/parser';
|
||||
import pluginI18Next from 'eslint-plugin-i18next';
|
||||
import pluginImport from 'eslint-plugin-import';
|
||||
import pluginPath from 'eslint-plugin-path';
|
||||
import pluginReact from 'eslint-plugin-react';
|
||||
import pluginReactHooks from 'eslint-plugin-react-hooks';
|
||||
import pluginReactRefresh from 'eslint-plugin-react-refresh';
|
||||
import pluginSimpleImportSort from 'eslint-plugin-simple-import-sort';
|
||||
import pluginStorybook from 'eslint-plugin-storybook';
|
||||
import pluginUnusedImports from 'eslint-plugin-unused-imports';
|
||||
import globals from 'globals';
|
||||
|
||||
export default [
|
||||
js.configs.recommended,
|
||||
|
||||
{
|
||||
languageOptions: {
|
||||
parser: typescriptParser,
|
||||
parserOptions: {
|
||||
ecmaFeatures: {
|
||||
jsx: true,
|
||||
},
|
||||
},
|
||||
globals: {
|
||||
...globals.browser,
|
||||
...globals.node,
|
||||
GlobalCompositeOperation: 'readonly',
|
||||
RequestInit: 'readonly',
|
||||
},
|
||||
},
|
||||
|
||||
files: ['**/*.ts', '**/*.tsx', '**/*.js', '**/*.jsx'],
|
||||
|
||||
plugins: {
|
||||
react: pluginReact,
|
||||
'@typescript-eslint': typescriptEslint,
|
||||
'react-hooks': pluginReactHooks,
|
||||
import: pluginImport,
|
||||
'unused-imports': pluginUnusedImports,
|
||||
'simple-import-sort': pluginSimpleImportSort,
|
||||
'react-refresh': pluginReactRefresh.configs.vite,
|
||||
path: pluginPath,
|
||||
i18next: pluginI18Next,
|
||||
storybook: pluginStorybook,
|
||||
},
|
||||
|
||||
rules: {
|
||||
...typescriptEslint.configs.recommended.rules,
|
||||
...pluginReact.configs.recommended.rules,
|
||||
...pluginReact.configs['jsx-runtime'].rules,
|
||||
...pluginReactHooks.configs.recommended.rules,
|
||||
...pluginStorybook.configs.recommended.rules,
|
||||
|
||||
'react/jsx-no-bind': [
|
||||
'error',
|
||||
{
|
||||
allowBind: true,
|
||||
},
|
||||
],
|
||||
|
||||
'react/jsx-curly-brace-presence': [
|
||||
'error',
|
||||
{
|
||||
props: 'never',
|
||||
children: 'never',
|
||||
},
|
||||
],
|
||||
|
||||
'react-hooks/exhaustive-deps': 'error',
|
||||
|
||||
curly: 'error',
|
||||
'no-var': 'error',
|
||||
'brace-style': 'error',
|
||||
'prefer-template': 'error',
|
||||
radix: 'error',
|
||||
'space-before-blocks': 'error',
|
||||
eqeqeq: 'error',
|
||||
'one-var': ['error', 'never'],
|
||||
'no-eval': 'error',
|
||||
'no-extend-native': 'error',
|
||||
'no-implied-eval': 'error',
|
||||
'no-label-var': 'error',
|
||||
'no-return-assign': 'error',
|
||||
'no-sequences': 'error',
|
||||
'no-template-curly-in-string': 'error',
|
||||
'no-throw-literal': 'error',
|
||||
'no-unmodified-loop-condition': 'error',
|
||||
'import/no-duplicates': 'error',
|
||||
'import/prefer-default-export': 'off',
|
||||
'unused-imports/no-unused-imports': 'error',
|
||||
|
||||
'unused-imports/no-unused-vars': [
|
||||
'error',
|
||||
{
|
||||
vars: 'all',
|
||||
varsIgnorePattern: '^_',
|
||||
args: 'after-used',
|
||||
argsIgnorePattern: '^_',
|
||||
},
|
||||
],
|
||||
|
||||
'simple-import-sort/imports': 'error',
|
||||
'simple-import-sort/exports': 'error',
|
||||
'@typescript-eslint/no-unused-vars': 'off',
|
||||
|
||||
'@typescript-eslint/ban-ts-comment': [
|
||||
'error',
|
||||
{
|
||||
'ts-expect-error': 'allow-with-description',
|
||||
'ts-ignore': true,
|
||||
'ts-nocheck': true,
|
||||
'ts-check': false,
|
||||
minimumDescriptionLength: 10,
|
||||
},
|
||||
],
|
||||
|
||||
'@typescript-eslint/no-empty-interface': [
|
||||
'error',
|
||||
{
|
||||
allowSingleExtends: true,
|
||||
},
|
||||
],
|
||||
|
||||
'@typescript-eslint/consistent-type-imports': [
|
||||
'error',
|
||||
{
|
||||
prefer: 'type-imports',
|
||||
fixStyle: 'separate-type-imports',
|
||||
disallowTypeAnnotations: true,
|
||||
},
|
||||
],
|
||||
|
||||
'@typescript-eslint/no-import-type-side-effects': 'error',
|
||||
|
||||
'@typescript-eslint/consistent-type-assertions': [
|
||||
'error',
|
||||
{
|
||||
assertionStyle: 'as',
|
||||
},
|
||||
],
|
||||
|
||||
'path/no-relative-imports': [
|
||||
'error',
|
||||
{
|
||||
maxDepth: 0,
|
||||
},
|
||||
],
|
||||
|
||||
'no-console': 'warn',
|
||||
'no-promise-executor-return': 'error',
|
||||
'require-await': 'error',
|
||||
|
||||
'no-restricted-syntax': [
|
||||
'error',
|
||||
{
|
||||
selector: 'CallExpression[callee.name="setActiveTab"]',
|
||||
message:
|
||||
'setActiveTab() can only be called from use-navigation-api.tsx. Use navigationApi.switchToTab() instead.',
|
||||
},
|
||||
],
|
||||
|
||||
'no-restricted-properties': [
|
||||
'error',
|
||||
{
|
||||
object: 'crypto',
|
||||
property: 'randomUUID',
|
||||
message: 'Use of crypto.randomUUID is not allowed as it is not available in all browsers.',
|
||||
},
|
||||
{
|
||||
object: 'navigator',
|
||||
property: 'clipboard',
|
||||
message:
|
||||
'The Clipboard API is not available by default in Firefox. Use the `useClipboard` hook instead, which wraps clipboard access to prevent errors.',
|
||||
},
|
||||
],
|
||||
|
||||
// Typescript handles this for us: https://eslint.org/docs/latest/rules/no-redeclare#handled_by_typescript
|
||||
'no-redeclare': 'off',
|
||||
|
||||
'no-restricted-imports': [
|
||||
'error',
|
||||
{
|
||||
paths: [
|
||||
{
|
||||
name: 'lodash-es',
|
||||
importNames: ['isEqual'],
|
||||
message: 'Please use objectEquals from @observ33r/object-equals instead.',
|
||||
},
|
||||
{
|
||||
name: 'lodash-es',
|
||||
message: 'Please use es-toolkit instead.',
|
||||
},
|
||||
{
|
||||
name: 'es-toolkit',
|
||||
importNames: ['isEqual'],
|
||||
message: 'Please use objectEquals from @observ33r/object-equals instead.',
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
|
||||
settings: {
|
||||
react: {
|
||||
version: 'detect',
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
files: ['**/use-navigation-api.tsx'],
|
||||
rules: {
|
||||
'no-restricted-syntax': 'off',
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
files: ['**/*.stories.tsx'],
|
||||
rules: {
|
||||
'i18next/no-literal-string': 'off',
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
ignores: [
|
||||
'**/dist/',
|
||||
'**/static/',
|
||||
'**/.husky/',
|
||||
'**/node_modules/',
|
||||
'**/patches/',
|
||||
'**/stats.html',
|
||||
'**/index.html',
|
||||
'**/.yarn/',
|
||||
'**/*.scss',
|
||||
'src/services/api/schema.ts',
|
||||
'.prettierrc.js',
|
||||
'.storybook',
|
||||
],
|
||||
},
|
||||
];
|
||||
@@ -14,7 +14,6 @@ const config: KnipConfig = {
|
||||
'src/features/controlLayers/konva/util.ts',
|
||||
// Will be using this
|
||||
'src/common/hooks/useAsyncState.ts',
|
||||
'src/app/store/use-debounced-app-selector.ts',
|
||||
],
|
||||
ignoreBinaries: ['only-allow'],
|
||||
paths: {
|
||||
|
||||
@@ -47,25 +47,25 @@
"@fontsource-variable/inter": "^5.2.6",
"@invoke-ai/ui-library": "^0.0.46",
"@nanostores/react": "^1.0.0",
"@observ33r/object-equals": "^1.1.5",
"@observ33r/object-equals": "^1.1.4",
"@reduxjs/toolkit": "2.8.2",
"@roarr/browser-log-writer": "^1.3.0",
"@xyflow/react": "^12.8.2",
"ag-psd": "^28.2.2",
"@xyflow/react": "^12.7.1",
"ag-psd": "^28.2.1",
"async-mutex": "^0.5.0",
"chakra-react-select": "^4.9.2",
"cmdk": "^1.1.1",
"compare-versions": "^6.1.1",
"dockview": "^4.4.1",
"es-toolkit": "^1.39.7",
"dockview": "^4.4.0",
"es-toolkit": "^1.39.5",
"filesize": "^10.1.6",
"fracturedjsonjs": "^4.1.0",
"framer-motion": "^11.10.0",
"i18next": "^25.3.2",
"i18next": "^25.2.1",
"i18next-http-backend": "^3.0.2",
"idb-keyval": "6.2.2",
"idb-keyval": "^6.2.2",
"jsondiffpatch": "^0.7.3",
"konva": "^9.3.22",
"konva": "^9.3.20",
"linkify-react": "^4.3.1",
"linkifyjs": "^4.3.1",
"lru-cache": "^11.1.0",
@@ -83,7 +83,7 @@
"react-dom": "^18.3.1",
"react-dropzone": "^14.3.8",
"react-error-boundary": "^5.0.0",
"react-hook-form": "^7.60.0",
"react-hook-form": "^7.58.1",
"react-hotkeys-hook": "4.5.0",
"react-i18next": "^15.5.3",
"react-icons": "^5.5.0",
@@ -103,7 +103,7 @@
"use-debounce": "^10.0.5",
"use-device-pixel-ratio": "^1.1.2",
"uuid": "^11.1.0",
"zod": "^4.0.5",
"zod": "^3.25.67",
"zod-validation-error": "^3.5.2"
},
"peerDependencies": {
@@ -111,43 +111,39 @@
"react-dom": "^18.2.0"
},
"devDependencies": {
"@eslint/js": "^9.31.0",
"@storybook/addon-docs": "^9.0.17",
"@storybook/addon-links": "^9.0.17",
"@storybook/react-vite": "^9.0.17",
"@invoke-ai/eslint-config-react": "^0.0.14",
"@invoke-ai/prettier-config-react": "^0.0.7",
"@storybook/addon-essentials": "^8.6.12",
"@storybook/addon-interactions": "^8.6.12",
"@storybook/addon-links": "^8.6.12",
"@storybook/addon-storysource": "^8.6.12",
"@storybook/manager-api": "^8.6.12",
"@storybook/react": "^8.6.12",
"@storybook/react-vite": "^8.6.12",
"@storybook/theming": "^8.6.12",
"@types/node": "^22.15.1",
"@types/react": "^18.3.11",
"@types/react-dom": "^18.3.0",
"@types/uuid": "^10.0.0",
"@typescript-eslint/eslint-plugin": "^8.37.0",
"@typescript-eslint/parser": "^8.37.0",
"@vitejs/plugin-react-swc": "^3.9.0",
"@vitest/coverage-v8": "^3.1.2",
"@vitest/ui": "^3.1.2",
"concurrently": "^9.1.2",
"csstype": "^3.1.3",
"dpdm": "^3.14.0",
"eslint": "^9.31.0",
"eslint-plugin-i18next": "^6.1.2",
"eslint-plugin-import": "^2.29.1",
"eslint-plugin-path": "^2.0.3",
"eslint-plugin-react": "^7.33.2",
"eslint-plugin-react-hooks": "^5.2.0",
"eslint-plugin-react-refresh": "^0.4.5",
"eslint-plugin-simple-import-sort": "^12.0.0",
"eslint-plugin-storybook": "^9.0.17",
"eslint-plugin-unused-imports": "^4.1.4",
"globals": "^16.3.0",
"eslint": "^8.57.1",
"eslint-plugin-i18next": "^6.1.1",
"eslint-plugin-path": "^1.3.0",
"knip": "^5.61.3",
"openapi-types": "^12.1.3",
"openapi-typescript": "^7.6.1",
"prettier": "^3.5.3",
"rollup-plugin-visualizer": "^6.0.3",
"storybook": "^9.0.17",
"rollup-plugin-visualizer": "^5.14.0",
"storybook": "^8.6.12",
"tsafe": "^1.8.5",
"type-fest": "^4.40.0",
"typescript": "^5.8.3",
"vite": "^7.0.5",
"vite": "^7.0.2",
"vite-plugin-css-injected-by-js": "^3.5.2",
"vite-plugin-dts": "^4.5.3",
"vite-plugin-eslint": "^1.8.1",
2152  invokeai/frontend/web/pnpm-lock.yaml  (generated)
File diff suppressed because it is too large
@@ -1775,20 +1775,6 @@
|
||||
"Structure controls how closely the output image will keep to the layout of the original. Low structure allows major changes, while high structure strictly maintains the original composition and layout."
|
||||
]
|
||||
},
|
||||
"tileSize": {
|
||||
"heading": "Tile Size",
|
||||
"paragraphs": [
|
||||
"Controls the size of tiles used during the upscaling process. Larger tiles use more memory but may produce better results.",
|
||||
"SD1.5 models default to 768, while SDXL models default to 1024. Reduce tile size if you encounter memory issues."
|
||||
]
|
||||
},
|
||||
"tileOverlap": {
|
||||
"heading": "Tile Overlap",
|
||||
"paragraphs": [
|
||||
"Controls the overlap between adjacent tiles during upscaling. Higher overlap values help reduce visible seams between tiles but use more memory.",
|
||||
"The default value of 128 works well for most cases, but you can adjust based on your specific needs and memory constraints."
|
||||
]
|
||||
},
|
||||
"fluxDevLicense": {
|
||||
"heading": "Non-Commercial License",
|
||||
"paragraphs": [
|
||||
@@ -1976,6 +1962,7 @@
|
||||
"recalculateRects": "Recalculate Rects",
|
||||
"clipToBbox": "Clip Strokes to Bbox",
|
||||
"outputOnlyMaskedRegions": "Output Only Generated Regions",
|
||||
"saveAllImagesToGallery": "Save All Images to Gallery",
|
||||
"addLayer": "Add Layer",
|
||||
"duplicate": "Duplicate",
|
||||
"moveToFront": "Move to Front",
|
||||
@@ -2100,9 +2087,9 @@
|
||||
"resetCanvasLayers": "Reset Canvas Layers",
|
||||
"resetGenerationSettings": "Reset Generation Settings",
|
||||
"replaceCurrent": "Replace Current",
|
||||
"controlLayerEmptyState": "<UploadButton>Upload an image</UploadButton>, drag an image from the gallery onto this layer, <PullBboxButton>pull the bounding box into this layer</PullBboxButton>, or draw on the canvas to get started.",
|
||||
"referenceImageEmptyStateWithCanvasOptions": "<UploadButton>Upload an image</UploadButton>, drag an image from the gallery onto this Reference Image or <PullBboxButton>pull the bounding box into this Reference Image</PullBboxButton> to get started.",
|
||||
"referenceImageEmptyState": "<UploadButton>Upload an image</UploadButton> or drag an image from the gallery onto this Reference Image to get started.",
|
||||
"controlLayerEmptyState": "<UploadButton>Upload an image</UploadButton>, drag an image from the <GalleryButton>gallery</GalleryButton> onto this layer, <PullBboxButton>pull the bounding box into this layer</PullBboxButton>, or draw on the canvas to get started.",
|
||||
"referenceImageEmptyStateWithCanvasOptions": "<UploadButton>Upload an image</UploadButton>, drag an image from the <GalleryButton>gallery</GalleryButton> onto this Reference Image or <PullBboxButton>pull the bounding box into this Reference Image</PullBboxButton> to get started.",
|
||||
"referenceImageEmptyState": "<UploadButton>Upload an image</UploadButton> or drag an image from the <GalleryButton>gallery</GalleryButton> onto this Reference Image to get started.",
|
||||
"uploadOrDragAnImage": "Drag an image from the gallery or <UploadButton>upload an image</UploadButton>.",
|
||||
"imageNoise": "Image Noise",
|
||||
"denoiseLimit": "Denoise Limit",
|
||||
@@ -2345,8 +2332,7 @@
|
||||
"alert": "Preserving Masked Region"
|
||||
},
|
||||
"saveAllImagesToGallery": {
|
||||
"label": "Send New Generations to Gallery",
|
||||
"alert": "Sending new generations to Gallery, bypassing Canvas"
|
||||
"alert": "Saving All Images to Gallery"
|
||||
},
|
||||
"isolatedStagingPreview": "Isolated Staging Preview",
|
||||
"isolatedPreview": "Isolated Preview",
|
||||
@@ -2410,9 +2396,6 @@
|
||||
"upscaleModel": "Upscale Model",
|
||||
"postProcessingModel": "Post-Processing Model",
|
||||
"scale": "Scale",
|
||||
"tileControl": "Tile Control",
|
||||
"tileSize": "Tile Size",
|
||||
"tileOverlap": "Tile Overlap",
|
||||
"postProcessingMissingModelWarning": "Visit the <LinkComponent>Model Manager</LinkComponent> to install a post-processing (image to image) model.",
|
||||
"missingModelsWarning": "Visit the <LinkComponent>Model Manager</LinkComponent> to install the required models:",
|
||||
"mainModelDesc": "Main model (SD1.5 or SDXL architecture)",
|
||||
|
||||
@@ -30,16 +30,16 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
|
||||
}, [clearStorage]);
|
||||
|
||||
return (
|
||||
<ThemeLocaleProvider>
|
||||
<ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
|
||||
<ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
|
||||
<ThemeLocaleProvider>
|
||||
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
|
||||
<AppContent />
|
||||
{!didStudioInit && <Loading />}
|
||||
</Box>
|
||||
<GlobalHookIsolator config={config} studioInitAction={studioInitAction} />
|
||||
<GlobalModalIsolator />
|
||||
</ErrorBoundary>
|
||||
</ThemeLocaleProvider>
|
||||
</ThemeLocaleProvider>
|
||||
</ErrorBoundary>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ import { useGlobalModifiersInit } from '@invoke-ai/ui-library';
|
||||
import { setupListeners } from '@reduxjs/toolkit/query';
|
||||
import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
|
||||
import { useStudioInitAction } from 'app/hooks/useStudioInitAction';
|
||||
import { useSyncLangDirection } from 'app/hooks/useSyncLangDirection';
|
||||
import { useSyncQueueStatus } from 'app/hooks/useSyncQueueStatus';
|
||||
import { useLogger } from 'app/logging/useLogger';
|
||||
import { useSyncLoggingConfig } from 'app/logging/useSyncLoggingConfig';
|
||||
@@ -16,8 +15,6 @@ import { useDndMonitor } from 'features/dnd/useDndMonitor';
|
||||
import { useDynamicPromptsWatcher } from 'features/dynamicPrompts/hooks/useDynamicPromptsWatcher';
|
||||
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
|
||||
import { useWorkflowBuilderWatcher } from 'features/nodes/components/sidePanel/workflow/IsolatedWorkflowBuilderWatcher';
|
||||
import { useSyncExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
|
||||
import { useSyncNodeErrors } from 'features/nodes/store/util/fieldValidators';
|
||||
import { useReadinessWatcher } from 'features/queue/store/readiness';
|
||||
import { configChanged } from 'features/system/store/configSlice';
|
||||
import { selectLanguage } from 'features/system/store/systemSelectors';
|
||||
@@ -50,13 +47,10 @@ export const GlobalHookIsolator = memo(
|
||||
useCloseChakraTooltipsOnDragFix();
|
||||
useNavigationApi();
|
||||
useDndMonitor();
|
||||
useSyncNodeErrors();
|
||||
useSyncLangDirection();
|
||||
|
||||
// Persistent subscription to the queue counts query - canvas relies on this to know if there are pending
|
||||
// and/or in progress canvas sessions.
|
||||
useGetQueueCountsByDestinationQuery(queueCountArg);
|
||||
useSyncExecutionState();
|
||||
|
||||
useEffect(() => {
|
||||
i18n.changeLanguage(language);
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import { GlobalImageHotkeys } from 'app/components/GlobalImageHotkeys';
|
||||
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
|
||||
import { CanvasPasteModal } from 'features/controlLayers/components/CanvasPasteModal';
|
||||
import {
|
||||
NewCanvasSessionDialog,
|
||||
NewGallerySessionDialog,
|
||||
} from 'features/controlLayers/components/NewSessionConfirmationAlertDialog';
|
||||
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { DeleteImageModal } from 'features/deleteImageModal/components/DeleteImageModal';
|
||||
import { FullscreenDropzone } from 'features/dnd/FullscreenDropzone';
|
||||
@@ -46,6 +50,8 @@ export const GlobalModalIsolator = memo(() => {
|
||||
<RefreshAfterResetModal />
|
||||
<DeleteBoardModal />
|
||||
<GlobalImageHotkeys />
|
||||
<NewGallerySessionDialog />
|
||||
<NewCanvasSessionDialog />
|
||||
<ImageContextMenu />
|
||||
<FullscreenDropzone />
|
||||
<VideosModal />
|
||||
|
||||
@@ -317,7 +317,7 @@ const InvokeAIUI = ({
|
||||
if (import.meta.env.MODE === 'development') {
|
||||
window.$store = $store;
|
||||
}
|
||||
return () => {
|
||||
() => {
|
||||
$store.set(undefined);
|
||||
if (import.meta.env.MODE === 'development') {
|
||||
window.$store = undefined;
|
||||
|
||||
@@ -3,39 +3,43 @@ import 'overlayscrollbars/overlayscrollbars.css';
|
||||
import '@xyflow/react/dist/base.css';
|
||||
import 'common/components/OverlayScrollbars/overlayscrollbars.css';
|
||||
|
||||
import { ChakraProvider, DarkMode, extendTheme, theme as baseTheme, TOAST_OPTIONS } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $direction } from 'app/hooks/useSyncLangDirection';
|
||||
import { ChakraProvider, DarkMode, extendTheme, theme as _theme, TOAST_OPTIONS } from '@invoke-ai/ui-library';
|
||||
import type { ReactNode } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { memo, useEffect, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type ThemeLocaleProviderProps = {
|
||||
children: ReactNode;
|
||||
};
|
||||
|
||||
const buildTheme = (direction: 'ltr' | 'rtl') => {
|
||||
return extendTheme({
|
||||
...baseTheme,
|
||||
direction,
|
||||
shadows: {
|
||||
...baseTheme.shadows,
|
||||
selected:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-500), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
|
||||
hoverSelected:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-400), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
|
||||
hoverUnselected:
|
||||
'inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-800)',
|
||||
selectedForCompare:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-300), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
|
||||
hoverSelectedForCompare:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-200), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
|
||||
const direction = useStore($direction);
|
||||
const theme = useMemo(() => buildTheme(direction), [direction]);
|
||||
const { i18n } = useTranslation();
|
||||
|
||||
const direction = i18n.dir();
|
||||
|
||||
const theme = useMemo(() => {
|
||||
return extendTheme({
|
||||
..._theme,
|
||||
direction,
|
||||
shadows: {
|
||||
..._theme.shadows,
|
||||
selected:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-500), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
|
||||
hoverSelected:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-400), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
|
||||
hoverUnselected:
|
||||
'inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-800)',
|
||||
selectedForCompare:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-300), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
|
||||
hoverSelectedForCompare:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-200), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
|
||||
},
|
||||
});
|
||||
}, [direction]);
|
||||
|
||||
useEffect(() => {
|
||||
document.body.dir = direction;
|
||||
}, [direction]);
|
||||
|
||||
return (
|
||||
<ChakraProvider theme={theme} toastOptions={TOAST_OPTIONS}>
|
||||
|
||||
@@ -21,6 +21,7 @@ import { $isStylePresetsMenuOpen, activeStylePresetIdChanged } from 'features/st
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { navigationApi } from 'features/ui/layouts/navigation-api';
|
||||
import { LAUNCHPAD_PANEL_ID, WORKSPACE_PANEL_ID } from 'features/ui/layouts/shared';
|
||||
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
|
||||
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
|
||||
import { atom } from 'nanostores';
|
||||
import { useCallback, useEffect } from 'react';
|
||||
@@ -164,6 +165,7 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
|
||||
// Go to the generate tab, open the launchpad
|
||||
await navigationApi.focusPanel('generate', LAUNCHPAD_PANEL_ID);
|
||||
store.dispatch(paramsReset());
|
||||
store.dispatch(activeTabCanvasRightPanelChanged('gallery'));
|
||||
break;
|
||||
case 'canvas':
|
||||
// Go to the canvas tab, open the launchpad
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
|
||||
import { atom } from 'nanostores';
|
||||
import { useEffect } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
/**
|
||||
* Global atom storing the language direction, to be consumed by the Chakra theme.
|
||||
*
|
||||
* Why do we need this? We have a kind of catch-22:
|
||||
* - The Chakra theme needs to know the language direction to apply the correct styles.
|
||||
* - The language direction is determined by i18n and the language selection.
|
||||
* - We want our error boundary to be themed.
|
||||
* - It's possible that i18n can throw if the language selection is invalid or not supported.
|
||||
*
|
||||
* Previously, we had the logic in this file in the theme provider, which wrapped the error boundary. The error
|
||||
* was properly themed. But then, if i18n threw in the theme provider, the error boundary does not catch the
|
||||
* error. The app would crash to a white screen.
|
||||
*
|
||||
* We tried swapping the component hierarchy so that the error boundary wraps the theme provider, but then the
|
||||
* error boundary isn't themed!
|
||||
*
|
||||
* The solution is to move this i18n direction logic out of the theme provider and into a hook that we can use
|
||||
* within the error boundary. The error boundary will be themed, _and_ catch any i18n errors.
|
||||
*/
|
||||
export const $direction = atom<'ltr' | 'rtl'>('ltr');
|
||||
|
||||
export const useSyncLangDirection = () => {
|
||||
useAssertSingleton('useSyncLangDirection');
|
||||
const { i18n, t } = useTranslation();
|
||||
|
||||
useEffect(() => {
|
||||
const direction = i18n.dir();
|
||||
$direction.set(direction);
|
||||
document.body.dir = direction;
|
||||
}, [i18n, t]);
|
||||
};
|
||||
@@ -2,7 +2,7 @@ import { createLogWriter } from '@roarr/browser-log-writer';
|
||||
import { atom } from 'nanostores';
|
||||
import type { Logger, MessageSerializer } from 'roarr';
|
||||
import { ROARR, Roarr } from 'roarr';
|
||||
import { z } from 'zod';
|
||||
import { z } from 'zod/v4';
|
||||
|
||||
const serializeMessage: MessageSerializer = (message) => {
|
||||
return JSON.stringify(message);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { bboxSyncedToOptimalDimension, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
|
||||
import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
|
||||
import { modelChanged, syncedToOptimalDimension, vaeSelected } from 'features/controlLayers/store/paramsSlice';
|
||||
import { refImageModelChanged, selectReferenceImageEntities } from 'features/controlLayers/store/refImagesSlice';
|
||||
@@ -152,8 +152,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
|
||||
if (modelBase !== state.params.model?.base) {
|
||||
// Sync generate tab settings whenever the model base changes
|
||||
dispatch(syncedToOptimalDimension());
|
||||
const isStaging = buildSelectIsStaging(selectCanvasSessionId(state))(state);
|
||||
if (!isStaging) {
|
||||
if (!selectIsStaging(state)) {
|
||||
// Canvas tab only syncs if not staging
|
||||
dispatch(bboxSyncedToOptimalDimension());
|
||||
}
|
||||
|
||||
@@ -15,11 +15,7 @@ import { refImageModelChanged, selectRefImagesSlice } from 'features/controlLaye
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { getEntityIdentifier, isFLUXReduxConfig, isIPAdapterConfig } from 'features/controlLayers/store/types';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import {
|
||||
postProcessingModelChanged,
|
||||
tileControlnetModelChanged,
|
||||
upscaleModelChanged,
|
||||
} from 'features/parameters/store/upscaleSlice';
|
||||
import { postProcessingModelChanged, upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import {
|
||||
zParameterCLIPEmbedModel,
|
||||
zParameterSpandrelImageToImageModel,
|
||||
@@ -32,7 +28,6 @@ import type { AnyModelConfig } from 'services/api/types';
|
||||
import {
|
||||
isCLIPEmbedModelConfig,
|
||||
isControlLayerModelConfig,
|
||||
isControlNetModelConfig,
|
||||
isFluxReduxModelConfig,
|
||||
isFluxVAEModelConfig,
|
||||
isIPAdapterModelConfig,
|
||||
@@ -76,7 +71,6 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
|
||||
handleControlAdapterModels(models, state, dispatch, log);
|
||||
handlePostProcessingModel(models, state, dispatch, log);
|
||||
handleUpscaleModel(models, state, dispatch, log);
|
||||
handleTileControlNetModel(models, state, dispatch, log);
|
||||
handleIPAdapterModels(models, state, dispatch, log);
|
||||
handleT5EncoderModels(models, state, dispatch, log);
|
||||
handleCLIPEmbedModels(models, state, dispatch, log);
|
||||
@@ -351,46 +345,6 @@ const handleUpscaleModel: ModelHandler = (models, state, dispatch, log) => {
|
||||
}
|
||||
};
|
||||
|
||||
const handleTileControlNetModel: ModelHandler = (models, state, dispatch, log) => {
|
||||
const selectedTileControlNetModel = state.upscale.tileControlnetModel;
|
||||
const controlNetModels = models.filter(isControlNetModelConfig);
|
||||
|
||||
// If the currently selected model is available, we don't need to do anything
|
||||
if (selectedTileControlNetModel && controlNetModels.some((m) => m.key === selectedTileControlNetModel.key)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// The only way we have to identify a model as a tile model is by its name containing 'tile' :)
|
||||
const tileModel = controlNetModels.find((m) => m.name.toLowerCase().includes('tile'));
|
||||
|
||||
// If we have a tile model, select it
|
||||
if (tileModel) {
|
||||
log.debug(
|
||||
{ selectedTileControlNetModel, tileModel },
|
||||
'No selected tile ControlNet model or selected model is not available, selecting tile model'
|
||||
);
|
||||
dispatch(tileControlnetModelChanged(tileModel));
|
||||
return;
|
||||
}
|
||||
|
||||
// Otherwise, select the first available ControlNet model
|
||||
const firstModel = controlNetModels[0] || null;
|
||||
if (firstModel) {
|
||||
log.debug(
|
||||
{ selectedTileControlNetModel, firstModel },
|
||||
'No tile ControlNet model found, selecting first available ControlNet model'
|
||||
);
|
||||
dispatch(tileControlnetModelChanged(firstModel));
|
||||
return;
|
||||
}
|
||||
|
||||
// No available models, we should clear the selected model - but only if we have one selected
|
||||
if (selectedTileControlNetModel) {
|
||||
log.debug({ selectedTileControlNetModel }, 'Selected tile ControlNet model is not available, clearing');
|
||||
dispatch(tileControlnetModelChanged(null));
|
||||
}
|
||||
};
|
||||
|
||||
const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => {
|
||||
const selectedT5EncoderModel = state.params.t5EncoderModel;
|
||||
const t5EncoderModels = models.filter((m) => isT5EncoderModelConfig(m));
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { isNil } from 'es-toolkit';
|
||||
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasSlice';
|
||||
import { buildSelectIsStaging, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import {
|
||||
heightChanged,
|
||||
setCfgRescaleMultiplier,
|
||||
@@ -115,8 +115,7 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
|
||||
}
|
||||
const setSizeOptions = { updateAspectRatio: true, clamp: true };
|
||||
|
||||
const isStaging = buildSelectIsStaging(selectCanvasSessionId(state))(state);
|
||||
|
||||
const isStaging = selectIsStaging(getState());
|
||||
const activeTab = selectActiveTab(getState());
|
||||
if (activeTab === 'generate') {
|
||||
if (isParameterWidth(width)) {
|
||||
|
||||
@@ -67,8 +67,6 @@ export type Feature =
|
||||
| 'scale'
|
||||
| 'creativity'
|
||||
| 'structure'
|
||||
| 'tileSize'
|
||||
| 'tileOverlap'
|
||||
| 'optimizedDenoising'
|
||||
| 'fluxDevLicense';
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import { atom, computed } from 'nanostores';
|
||||
import type { RefObject } from 'react';
|
||||
import { useEffect } from 'react';
|
||||
import { objectKeys } from 'tsafe';
|
||||
import z from 'zod/v4';
|
||||
|
||||
/**
|
||||
* We need to manage focus regions to conditionally enable hotkeys:
|
||||
@@ -27,7 +28,10 @@ import { objectKeys } from 'tsafe';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
const REGION_NAMES = [
|
||||
/**
|
||||
* The names of the focus regions.
|
||||
*/
|
||||
const zFocusRegionName = z.enum([
|
||||
'launchpad',
|
||||
'viewer',
|
||||
'gallery',
|
||||
@@ -37,16 +41,13 @@ const REGION_NAMES = [
|
||||
'workflows',
|
||||
'progress',
|
||||
'settings',
|
||||
] as const;
|
||||
/**
|
||||
* The names of the focus regions.
|
||||
*/
|
||||
export type FocusRegionName = (typeof REGION_NAMES)[number];
|
||||
]);
|
||||
export type FocusRegionName = z.infer<typeof zFocusRegionName>;
|
||||
|
||||
/**
|
||||
* A map of focus regions to the elements that are part of that region.
|
||||
*/
|
||||
const REGION_TARGETS: Record<FocusRegionName, Set<HTMLElement>> = REGION_NAMES.reduce(
|
||||
const REGION_TARGETS: Record<FocusRegionName, Set<HTMLElement>> = zFocusRegionName.options.values().reduce(
|
||||
(acc, region) => {
|
||||
acc[region] = new Set<HTMLElement>();
|
||||
return acc;
|
||||
|
||||
@@ -6,7 +6,7 @@ import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
|
||||
import { selectIsClientSideUploadEnabled } from 'features/system/store/configSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback } from 'react';
|
||||
import type { Accept, FileRejection } from 'react-dropzone';
|
||||
import type { FileRejection } from 'react-dropzone';
|
||||
import { useDropzone } from 'react-dropzone';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiUploadBold } from 'react-icons/pi';
|
||||
@@ -15,18 +15,6 @@ import type { ImageDTO } from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
import type { SetOptional } from 'type-fest';
|
||||
|
||||
const addUpperCaseReducer = (acc: string[], ext: string) => {
|
||||
acc.push(ext);
|
||||
acc.push(ext.toUpperCase());
|
||||
return acc;
|
||||
};
|
||||
|
||||
export const dropzoneAccept: Accept = {
|
||||
'image/png': ['.png'].reduce(addUpperCaseReducer, [] as string[]),
|
||||
'image/jpeg': ['.jpg', '.jpeg', '.png'].reduce(addUpperCaseReducer, [] as string[]),
|
||||
'image/webp': ['.webp'].reduce(addUpperCaseReducer, [] as string[]),
|
||||
};
|
||||
|
||||
import { useClientSideUpload } from './useClientSideUpload';
|
||||
type UseImageUploadButtonArgs =
|
||||
| {
|
||||
@@ -176,7 +164,11 @@ export const useImageUploadButton = ({
|
||||
getInputProps: getUploadInputProps,
|
||||
open: openUploader,
|
||||
} = useDropzone({
|
||||
accept: dropzoneAccept,
|
||||
accept: {
|
||||
'image/png': ['.png'],
|
||||
'image/jpeg': ['.jpg', '.jpeg', '.png'],
|
||||
'image/webp': ['.webp'],
|
||||
},
|
||||
onDropAccepted,
|
||||
onDropRejected,
|
||||
disabled: isDisabled,
|
||||
|
||||
@@ -1,5 +1,3 @@
import type { MouseEvent } from 'react';

export const preventDefault = (e: MouseEvent) => {
export const preventDefault = (e: React.MouseEvent) => {
  e.preventDefault();
};
@@ -1,11 +1,10 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import type React from 'react';
import { memo } from 'react';

/**
 * A typed version of React.memo, useful for components that take generics.
 */
export const typedMemo: <T extends keyof React.JSX.IntrinsicElements | React.JSXElementConstructor<any>>(
export const typedMemo: <T extends keyof JSX.IntrinsicElements | React.JSXElementConstructor<any>>(
  component: T,
  propsAreEqual?: (prevProps: React.ComponentProps<T>, nextProps: React.ComponentProps<T>) => boolean
) => T & { displayName?: string } = memo;
@@ -1,4 +1,4 @@
import type { z } from 'zod';
import type { z } from 'zod/v4';

/**
 * Helper to create a type guard from a zod schema. The type guard will infer the schema's TS type.
@@ -1,6 +1,7 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox, ConfirmationAlertDialog, Flex, FormControl, Text } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
|
||||
import {
|
||||
@@ -13,7 +14,7 @@ import { useTranslation } from 'react-i18next';
|
||||
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
|
||||
import { useAddImagesToBoardMutation, useRemoveImagesFromBoardMutation } from 'services/api/endpoints/images';
|
||||
|
||||
const selectImagesToChange = createSelector(
|
||||
const selectImagesToChange = createMemoizedSelector(
|
||||
selectChangeBoardModalSlice,
|
||||
(changeBoardModal) => changeBoardModal.image_names
|
||||
);
|
||||
|
||||
@@ -13,7 +13,7 @@ export const CanvasAlertsSaveAllImagesToGallery = memo(() => {
|
||||
}
|
||||
|
||||
return (
|
||||
<Alert status="warning" borderRadius="base" fontSize="sm" shadow="md" w="fit-content">
|
||||
<Alert status="info" borderRadius="base" fontSize="sm" shadow="md" w="fit-content">
|
||||
<AlertIcon />
|
||||
<AlertTitle>{t('controlLayers.settings.saveAllImagesToGallery.alert')}</AlertTitle>
|
||||
</Alert>
|
||||
|
||||
@@ -57,21 +57,21 @@ const CanvasAlertsSelectedEntityStatusContent = memo(({ entityIdentifier, adapte
|
||||
const alert = useMemo<AlertData | null>(() => {
|
||||
if (isFiltering) {
|
||||
return {
|
||||
status: 'warning',
|
||||
status: 'info',
|
||||
title: t('controlLayers.HUD.entityStatus.isFiltering', { title }),
|
||||
};
|
||||
}
|
||||
|
||||
if (isTransforming) {
|
||||
return {
|
||||
status: 'warning',
|
||||
status: 'info',
|
||||
title: t('controlLayers.HUD.entityStatus.isTransforming', { title }),
|
||||
};
|
||||
}
|
||||
|
||||
if (isEmpty) {
|
||||
return {
|
||||
status: 'warning',
|
||||
status: 'info',
|
||||
title: t('controlLayers.HUD.entityStatus.isEmpty', { title }),
|
||||
};
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import { useEntityIdentifierContext } from 'features/controlLayers/contexts/Enti
|
||||
import { usePullBboxIntoLayer } from 'features/controlLayers/hooks/saveCanvasHooks';
|
||||
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
import { replaceCanvasEntityObjectsWithImage } from 'features/imageActions/actions';
|
||||
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { Trans } from 'react-i18next';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
@@ -20,6 +21,9 @@ export const ControlLayerSettingsEmptyState = memo(() => {
|
||||
[dispatch, entityIdentifier, getState]
|
||||
);
|
||||
const uploadApi = useImageUploadButton({ onUpload, allowMultiple: false });
|
||||
const onClickGalleryButton = useCallback(() => {
|
||||
dispatch(activeTabCanvasRightPanelChanged('gallery'));
|
||||
}, [dispatch]);
|
||||
const pullBboxIntoLayer = usePullBboxIntoLayer(entityIdentifier);
|
||||
|
||||
const components = useMemo(
|
||||
@@ -27,11 +31,14 @@ export const ControlLayerSettingsEmptyState = memo(() => {
|
||||
UploadButton: (
|
||||
<Button isDisabled={isBusy} size="sm" variant="link" color="base.300" {...uploadApi.getUploadButtonProps()} />
|
||||
),
|
||||
GalleryButton: (
|
||||
<Button onClick={onClickGalleryButton} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
|
||||
),
|
||||
PullBboxButton: (
|
||||
<Button onClick={pullBboxIntoLayer} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
|
||||
),
|
||||
}),
|
||||
[isBusy, pullBboxIntoLayer, uploadApi]
|
||||
[isBusy, onClickGalleryButton, pullBboxIntoLayer, uploadApi]
|
||||
);
|
||||
|
||||
return (
|
||||
|
||||
@@ -0,0 +1,131 @@
|
||||
import { Checkbox, ConfirmationAlertDialog, Flex, FormControl, FormLabel, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
|
||||
import { buildUseBoolean } from 'common/hooks/useBoolean';
|
||||
import { canvasSessionReset, generateSessionReset } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import {
|
||||
selectSystemShouldConfirmOnNewSession,
|
||||
shouldConfirmOnNewSessionToggled,
|
||||
} from 'features/system/store/systemSlice';
|
||||
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const [useNewGallerySessionDialog] = buildUseBoolean(false);
|
||||
const [useNewCanvasSessionDialog] = buildUseBoolean(false);
|
||||
|
||||
const useNewGallerySession = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession);
|
||||
const newSessionDialog = useNewGallerySessionDialog();
|
||||
|
||||
const newGallerySessionImmediate = useCallback(() => {
|
||||
dispatch(generateSessionReset());
|
||||
dispatch(activeTabCanvasRightPanelChanged('gallery'));
|
||||
}, [dispatch]);
|
||||
|
||||
const newGallerySessionWithDialog = useCallback(() => {
|
||||
if (shouldConfirmOnNewSession) {
|
||||
newSessionDialog.setTrue();
|
||||
return;
|
||||
}
|
||||
newGallerySessionImmediate();
|
||||
}, [newGallerySessionImmediate, newSessionDialog, shouldConfirmOnNewSession]);
|
||||
|
||||
return { newGallerySessionImmediate, newGallerySessionWithDialog };
|
||||
};
|
||||
|
||||
const useNewCanvasSession = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession);
|
||||
const newSessionDialog = useNewCanvasSessionDialog();
|
||||
|
||||
const newCanvasSessionImmediate = useCallback(() => {
|
||||
dispatch(canvasSessionReset());
|
||||
dispatch(activeTabCanvasRightPanelChanged('layers'));
|
||||
}, [dispatch]);
|
||||
|
||||
const newCanvasSessionWithDialog = useCallback(() => {
|
||||
if (shouldConfirmOnNewSession) {
|
||||
newSessionDialog.setTrue();
|
||||
return;
|
||||
}
|
||||
|
||||
newCanvasSessionImmediate();
|
||||
}, [newCanvasSessionImmediate, newSessionDialog, shouldConfirmOnNewSession]);
|
||||
|
||||
return { newCanvasSessionImmediate, newCanvasSessionWithDialog };
|
||||
};
|
||||
|
||||
export const NewGallerySessionDialog = memo(() => {
|
||||
useAssertSingleton('NewGallerySessionDialog');
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const dialog = useNewGallerySessionDialog();
|
||||
const { newGallerySessionImmediate } = useNewGallerySession();
|
||||
|
||||
const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession);
|
||||
const onToggleConfirm = useCallback(() => {
|
||||
dispatch(shouldConfirmOnNewSessionToggled());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<ConfirmationAlertDialog
|
||||
isOpen={dialog.isTrue}
|
||||
onClose={dialog.setFalse}
|
||||
title={t('controlLayers.newGallerySession')}
|
||||
acceptCallback={newGallerySessionImmediate}
|
||||
acceptButtonText={t('common.ok')}
|
||||
useInert={false}
|
||||
>
|
||||
<Flex direction="column" gap={3}>
|
||||
<Text>{t('controlLayers.newGallerySessionDesc')}</Text>
|
||||
<Text>{t('common.areYouSure')}</Text>
|
||||
<FormControl>
|
||||
<FormLabel>{t('common.dontAskMeAgain')}</FormLabel>
|
||||
<Checkbox isChecked={!shouldConfirmOnNewSession} onChange={onToggleConfirm} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
</ConfirmationAlertDialog>
|
||||
);
|
||||
});
|
||||
|
||||
NewGallerySessionDialog.displayName = 'NewGallerySessionDialog';
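Both `useNewGallerySession` and `useNewCanvasSession` above apply the same gate: consult `shouldConfirmOnNewSession`, then either open the confirmation dialog or run the reset immediately. Factored out of the app, the shape of that gate is roughly as follows; this is an illustrative helper, not part of the codebase:

```ts
import { useCallback } from 'react';

// Hypothetical generic version of the newGallerySession / newCanvasSession pair.
export const useConfirmGate = (args: {
  shouldConfirm: boolean; // e.g. the persisted "confirm on new session" setting
  openDialog: () => void; // e.g. newSessionDialog.setTrue
  action: () => void; // e.g. dispatch the session reset
}) => {
  const { shouldConfirm, openDialog, action } = args;

  const runImmediate = useCallback(() => action(), [action]);

  const runWithConfirm = useCallback(() => {
    if (shouldConfirm) {
      // The dialog's accept callback is expected to invoke runImmediate.
      openDialog();
      return;
    }
    runImmediate();
  }, [openDialog, runImmediate, shouldConfirm]);

  return { runImmediate, runWithConfirm };
};
```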
|
||||
|
||||
export const NewCanvasSessionDialog = memo(() => {
|
||||
useAssertSingleton('NewCanvasSessionDialog');
|
||||
const { t } = useTranslation();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const dialog = useNewCanvasSessionDialog();
|
||||
const { newCanvasSessionImmediate } = useNewCanvasSession();
|
||||
|
||||
const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession);
|
||||
const onToggleConfirm = useCallback(() => {
|
||||
dispatch(shouldConfirmOnNewSessionToggled());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<ConfirmationAlertDialog
|
||||
isOpen={dialog.isTrue}
|
||||
onClose={dialog.setFalse}
|
||||
title={t('controlLayers.newCanvasSession')}
|
||||
acceptCallback={newCanvasSessionImmediate}
|
||||
acceptButtonText={t('common.ok')}
|
||||
useInert={false}
|
||||
>
|
||||
<Flex direction="column" gap={3}>
|
||||
<Text>{t('controlLayers.newCanvasSessionDesc')}</Text>
|
||||
<Text>{t('common.areYouSure')}</Text>
|
||||
<FormControl>
|
||||
<FormLabel>{t('common.dontAskMeAgain')}</FormLabel>
|
||||
<Checkbox isChecked={!shouldConfirmOnNewSession} onChange={onToggleConfirm} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
</ConfirmationAlertDialog>
|
||||
);
|
||||
});
|
||||
|
||||
NewCanvasSessionDialog.displayName = 'NewCanvasSessionDialog';
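The two dialogs share one piece of plumbing: `buildUseBoolean`, which creates a hook bound to a single shared boolean (here, "is the confirmation dialog open?"). The sketch below shows one way such a factory could be written with a nanostores atom; it is inferred from how the hook is used above (`isTrue`, `setTrue`, `setFalse`) and may differ from the real implementation in `common/hooks/useBoolean`:

```ts
import { useStore } from '@nanostores/react';
import { atom } from 'nanostores';
import { useCallback } from 'react';

// Hypothetical factory: returns a tuple whose first element is a React hook
// backed by a module-level atom, so every component sees the same boolean.
export const buildUseBoolean = (initialValue: boolean) => {
  const $boolean = atom(initialValue);

  const useBoolean = () => {
    const isTrue = useStore($boolean);
    const setTrue = useCallback(() => $boolean.set(true), []);
    const setFalse = useCallback(() => $boolean.set(false), []);
    const toggle = useCallback(() => $boolean.set(!$boolean.get()), []);
    return { isTrue, setTrue, setFalse, toggle };
  };

  return [useBoolean, $boolean] as const;
};
```

This matches the call sites above: `const [useNewGallerySessionDialog] = buildUseBoolean(false)` yields a hook, and `dialog.setTrue()` / `dialog.setFalse()` open and close the confirmation dialog.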
|
||||
@@ -126,7 +126,6 @@ const AddRefImageDropTargetAndButton = memo(() => {
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
AddRefImageDropTargetAndButton.displayName = 'AddRefImageDropTargetAndButton';
|
||||
|
||||
const BboxButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
@@ -146,4 +145,4 @@ const BboxButton = memo(() => {
|
||||
/>
|
||||
);
|
||||
});
|
||||
BboxButton.displayName = 'BboxButton';
|
||||
AddRefImageDropTargetAndButton.displayName = 'AddRefImageDropTargetAndButton';
|
||||
|
||||
@@ -6,6 +6,7 @@ import type { SetGlobalReferenceImageDndTargetData } from 'features/dnd/dnd';
|
||||
import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
import { setGlobalReferenceImage } from 'features/imageActions/actions';
|
||||
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { Trans, useTranslation } from 'react-i18next';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
@@ -21,6 +22,9 @@ export const RefImageNoImageState = memo(() => {
|
||||
[dispatch, id]
|
||||
);
|
||||
const uploadApi = useImageUploadButton({ onUpload, allowMultiple: false });
|
||||
const onClickGalleryButton = useCallback(() => {
|
||||
dispatch(activeTabCanvasRightPanelChanged('gallery'));
|
||||
}, [dispatch]);
|
||||
|
||||
const dndTargetData = useMemo<SetGlobalReferenceImageDndTargetData>(
|
||||
() => setGlobalReferenceImageDndTarget.getData({ id }),
|
||||
@@ -30,8 +34,9 @@ export const RefImageNoImageState = memo(() => {
|
||||
const components = useMemo(
|
||||
() => ({
|
||||
UploadButton: <Button size="sm" variant="link" color="base.300" {...uploadApi.getUploadButtonProps()} />,
|
||||
GalleryButton: <Button onClick={onClickGalleryButton} size="sm" variant="link" color="base.300" />,
|
||||
}),
|
||||
[uploadApi]
|
||||
[onClickGalleryButton, uploadApi]
|
||||
);
|
||||
|
||||
return (
|
||||
|
||||
@@ -8,6 +8,7 @@ import type { SetGlobalReferenceImageDndTargetData } from 'features/dnd/dnd';
|
||||
import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
import { setGlobalReferenceImage } from 'features/imageActions/actions';
|
||||
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { Trans, useTranslation } from 'react-i18next';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
@@ -24,6 +25,9 @@ export const RefImageNoImageStateWithCanvasOptions = memo(() => {
|
||||
[dispatch, id]
|
||||
);
|
||||
const uploadApi = useImageUploadButton({ onUpload, allowMultiple: false });
|
||||
const onClickGalleryButton = useCallback(() => {
|
||||
dispatch(activeTabCanvasRightPanelChanged('gallery'));
|
||||
}, [dispatch]);
|
||||
const pullBboxIntoIPAdapter = usePullBboxIntoGlobalReferenceImage(id);
|
||||
|
||||
const dndTargetData = useMemo<SetGlobalReferenceImageDndTargetData>(
|
||||
@@ -36,11 +40,14 @@ export const RefImageNoImageStateWithCanvasOptions = memo(() => {
|
||||
UploadButton: (
|
||||
<Button isDisabled={isBusy} size="sm" variant="link" color="base.300" {...uploadApi.getUploadButtonProps()} />
|
||||
),
|
||||
GalleryButton: (
|
||||
<Button onClick={onClickGalleryButton} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
|
||||
),
|
||||
PullBboxButton: (
|
||||
<Button onClick={pullBboxIntoIPAdapter} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
|
||||
),
|
||||
}),
|
||||
[isBusy, pullBboxIntoIPAdapter, uploadApi]
|
||||
[isBusy, onClickGalleryButton, pullBboxIntoIPAdapter, uploadApi]
|
||||
);
|
||||
|
||||
return (
|
||||
|
||||
@@ -9,6 +9,7 @@ import type { SetRegionalGuidanceReferenceImageDndTargetData } from 'features/dn
|
||||
import { setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
import { setRegionalGuidanceReferenceImage } from 'features/imageActions/actions';
|
||||
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { Trans, useTranslation } from 'react-i18next';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
@@ -30,6 +31,9 @@ export const RegionalGuidanceIPAdapterSettingsEmptyState = memo(({ referenceImag
|
||||
[dispatch, entityIdentifier, referenceImageId]
|
||||
);
|
||||
const uploadApi = useImageUploadButton({ onUpload, allowMultiple: false });
|
||||
const onClickGalleryButton = useCallback(() => {
|
||||
dispatch(activeTabCanvasRightPanelChanged('gallery'));
|
||||
}, [dispatch]);
|
||||
const onDeleteIPAdapter = useCallback(() => {
|
||||
dispatch(rgRefImageDeleted({ entityIdentifier, referenceImageId }));
|
||||
}, [dispatch, entityIdentifier, referenceImageId]);
|
||||
@@ -49,11 +53,14 @@ export const RegionalGuidanceIPAdapterSettingsEmptyState = memo(({ referenceImag
|
||||
UploadButton: (
|
||||
<Button isDisabled={isBusy} size="sm" variant="link" color="base.300" {...uploadApi.getUploadButtonProps()} />
|
||||
),
|
||||
GalleryButton: (
|
||||
<Button onClick={onClickGalleryButton} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
|
||||
),
|
||||
PullBboxButton: (
|
||||
<Button onClick={pullBboxIntoIPAdapter} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
|
||||
),
|
||||
}),
|
||||
[isBusy, pullBboxIntoIPAdapter, uploadApi]
|
||||
[isBusy, onClickGalleryButton, pullBboxIntoIPAdapter, uploadApi]
|
||||
);
|
||||
|
||||
return (
|
||||
|
||||
@@ -16,7 +16,7 @@ export const CanvasSettingsSaveAllImagesToGalleryCheckbox = memo(() => {
|
||||
}, [dispatch]);
|
||||
return (
|
||||
<FormControl w="full">
|
||||
<FormLabel flexGrow={1}>{t('controlLayers.settings.saveAllImagesToGallery.label')}</FormLabel>
|
||||
<FormLabel flexGrow={1}>{t('controlLayers.saveAllImagesToGallery')}</FormLabel>
|
||||
<Checkbox isChecked={saveAllImagesToGallery} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { Alert, Button, Flex, Grid, Text } from '@invoke-ai/ui-library';
|
||||
import { InitialStateMainModelPicker } from 'features/controlLayers/components/SimpleSession/InitialStateMainModelPicker';
|
||||
import { LaunchpadAddStyleReference } from 'features/controlLayers/components/SimpleSession/LaunchpadAddStyleReference';
|
||||
import { navigationApi } from 'features/ui/layouts/navigation-api';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
||||
import { InitialStateMainModelPicker } from './InitialStateMainModelPicker';
|
||||
import { LaunchpadAddStyleReference } from './LaunchpadAddStyleReference';
|
||||
import { LaunchpadContainer } from './LaunchpadContainer';
|
||||
import { LaunchpadGenerateFromTextButton } from './LaunchpadGenerateFromTextButton';
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import { Flex, Heading, Icon, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
|
||||
import { LaunchpadButton } from 'features/controlLayers/components/SimpleSession/LaunchpadButton';
|
||||
import { getDefaultRefImageConfig } from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import { refImageAdded } from 'features/controlLayers/store/refImagesSlice';
|
||||
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
|
||||
import { addGlobalReferenceImageDndTarget, newCanvasFromImageDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
import { LaunchpadButton } from 'features/ui/layouts/LaunchpadButton';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { PiUploadBold, PiUserCircleGearBold } from 'react-icons/pi';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
@@ -1,10 +1,10 @@
|
||||
import { Flex, Heading, Icon, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
|
||||
import { LaunchpadButton } from 'features/controlLayers/components/SimpleSession/LaunchpadButton';
|
||||
import { newCanvasFromImageDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
import { newCanvasFromImage } from 'features/imageActions/actions';
|
||||
import { LaunchpadButton } from 'features/ui/layouts/LaunchpadButton';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { PiPencilBold, PiUploadBold } from 'react-icons/pi';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
@@ -1,5 +1,5 @@
|
||||
import { Flex, Heading, Icon, Text } from '@invoke-ai/ui-library';
|
||||
import { LaunchpadButton } from 'features/ui/layouts/LaunchpadButton';
|
||||
import { LaunchpadButton } from 'features/controlLayers/components/SimpleSession/LaunchpadButton';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { PiCursorTextBold, PiTextAaBold } from 'react-icons/pi';
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import type { CircularProgressProps, SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { CircularProgress, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { getProgressMessage } from 'features/controlLayers/components/StagingArea/shared';
|
||||
import { useCanvasSessionContext, useProgressData } from 'features/controlLayers/components/SimpleSession/context';
|
||||
import { getProgressMessage } from 'features/controlLayers/components/SimpleSession/shared';
|
||||
import { memo } from 'react';
|
||||
import type { S } from 'services/api/types';
|
||||
|
||||
import { useProgressDatum } from './context';
|
||||
|
||||
const circleStyles: SystemStyleObject = {
|
||||
circle: {
|
||||
transitionProperty: 'none',
|
||||
@@ -19,7 +18,8 @@ const circleStyles: SystemStyleObject = {
|
||||
type Props = { itemId: number; status: S['SessionQueueItem']['status'] } & CircularProgressProps;
|
||||
|
||||
export const QueueItemCircularProgress = memo(({ itemId, status, ...rest }: Props) => {
|
||||
const { progressEvent } = useProgressDatum(itemId);
|
||||
const { $progressData } = useCanvasSessionContext();
|
||||
const { progressEvent } = useProgressData($progressData, itemId);
|
||||
|
||||
if (status !== 'in_progress') {
|
||||
return null;
|
||||
@@ -1,9 +1,8 @@
|
||||
import type { TextProps } from '@invoke-ai/ui-library';
|
||||
import { Text } from '@invoke-ai/ui-library';
|
||||
import { DROP_SHADOW } from 'features/controlLayers/components/SimpleSession/shared';
|
||||
import { memo } from 'react';
|
||||
|
||||
import { DROP_SHADOW } from './shared';
|
||||
|
||||
export const QueueItemNumber = memo(({ number, ...rest }: { number: number } & TextProps) => {
|
||||
return <Text pointerEvents="none" userSelect="none" filter={DROP_SHADOW} {...rest}>{`#${number}`}</Text>;
|
||||
});
|
||||
@@ -1,23 +1,25 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { QueueItemCircularProgress } from 'features/controlLayers/components/StagingArea/QueueItemCircularProgress';
|
||||
import { QueueItemProgressImage } from 'features/controlLayers/components/StagingArea/QueueItemProgressImage';
|
||||
import { QueueItemStatusLabel } from 'features/controlLayers/components/StagingArea/QueueItemStatusLabel';
|
||||
import { getQueueItemElementId } from 'features/controlLayers/components/StagingArea/shared';
|
||||
import {
|
||||
useCanvasSessionContext,
|
||||
useOutputImageDTO,
|
||||
useProgressData,
|
||||
} from 'features/controlLayers/components/SimpleSession/context';
|
||||
import { QueueItemCircularProgress } from 'features/controlLayers/components/SimpleSession/QueueItemCircularProgress';
|
||||
import { QueueItemNumber } from 'features/controlLayers/components/SimpleSession/QueueItemNumber';
|
||||
import { QueueItemProgressImage } from 'features/controlLayers/components/SimpleSession/QueueItemProgressImage';
|
||||
import { QueueItemStatusLabel } from 'features/controlLayers/components/SimpleSession/QueueItemStatusLabel';
|
||||
import { getQueueItemElementId } from 'features/controlLayers/components/SimpleSession/shared';
|
||||
import {
|
||||
selectStagingAreaAutoSwitch,
|
||||
settingsStagingAreaAutoSwitchChanged,
|
||||
} from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { DndImage } from 'features/dnd/DndImage';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import type { S } from 'services/api/types';
|
||||
|
||||
import { useOutputImageDTO, useStagingAreaContext } from './context';
|
||||
import { QueueItemNumber } from './QueueItemNumber';
|
||||
|
||||
const sx = {
|
||||
cursor: 'pointer',
|
||||
userSelect: 'none',
|
||||
@@ -39,19 +41,19 @@ const sx = {
|
||||
type Props = {
|
||||
item: S['SessionQueueItem'];
|
||||
index: number;
|
||||
isSelected: boolean;
|
||||
};
|
||||
|
||||
export const QueueItemPreviewMini = memo(({ item, index }: Props) => {
|
||||
const ctx = useStagingAreaContext();
|
||||
export const QueueItemPreviewMini = memo(({ item, isSelected, index }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const $isSelected = useMemo(() => ctx.buildIsSelectedComputed(item.item_id), [ctx, item.item_id]);
|
||||
const isSelected = useStore($isSelected);
|
||||
const imageDTO = useOutputImageDTO(item.item_id);
|
||||
const ctx = useCanvasSessionContext();
|
||||
const { imageLoaded } = useProgressData(ctx.$progressData, item.item_id);
|
||||
const imageDTO = useOutputImageDTO(item);
|
||||
const autoSwitch = useAppSelector(selectStagingAreaAutoSwitch);
|
||||
|
||||
const onClick = useCallback(() => {
|
||||
ctx.select(item.item_id);
|
||||
}, [ctx, item.item_id]);
|
||||
ctx.$selectedItemId.set(item.item_id);
|
||||
}, [ctx.$selectedItemId, item.item_id]);
|
||||
|
||||
const onDoubleClick = useCallback(() => {
|
||||
if (autoSwitch !== 'off') {
|
||||
@@ -62,6 +64,10 @@ export const QueueItemPreviewMini = memo(({ item, index }: Props) => {
|
||||
}
|
||||
}, [autoSwitch, dispatch]);
|
||||
|
||||
const onLoad = useCallback(() => {
|
||||
ctx.onImageLoad(item.item_id);
|
||||
}, [ctx, item.item_id]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
id={getQueueItemElementId(index)}
|
||||
@@ -71,8 +77,8 @@ export const QueueItemPreviewMini = memo(({ item, index }: Props) => {
|
||||
onDoubleClick={onDoubleClick}
|
||||
>
|
||||
<QueueItemStatusLabel item={item} position="absolute" margin="auto" />
|
||||
{imageDTO && <DndImage imageDTO={imageDTO} asThumbnail position="absolute" />}
|
||||
<QueueItemProgressImage itemId={item.item_id} position="absolute" />
|
||||
{imageDTO && <DndImage imageDTO={imageDTO} onLoad={onLoad} asThumbnail />}
|
||||
{!imageLoaded && <QueueItemProgressImage itemId={item.item_id} position="absolute" />}
|
||||
<QueueItemNumber number={index + 1} position="absolute" top={0} left={1} />
|
||||
<QueueItemCircularProgress itemId={item.item_id} status={item.status} position="absolute" top={1} right={2} />
|
||||
</Flex>
|
||||
@@ -1,13 +1,13 @@
|
||||
import type { ImageProps } from '@invoke-ai/ui-library';
|
||||
import { Image } from '@invoke-ai/ui-library';
|
||||
import { useCanvasSessionContext, useProgressData } from 'features/controlLayers/components/SimpleSession/context';
|
||||
import { memo } from 'react';
|
||||
|
||||
import { useProgressDatum } from './context';
|
||||
|
||||
type Props = { itemId: number } & ImageProps;
|
||||
|
||||
export const QueueItemProgressImage = memo(({ itemId, ...rest }: Props) => {
|
||||
const { progressImage } = useProgressDatum(itemId);
|
||||
const ctx = useCanvasSessionContext();
|
||||
const { progressImage } = useProgressData(ctx.$progressData, itemId);
|
||||
|
||||
if (!progressImage) {
|
||||
return null;
|
||||
@@ -1,16 +1,16 @@
|
||||
import type { TextProps } from '@invoke-ai/ui-library';
|
||||
import { Text } from '@invoke-ai/ui-library';
|
||||
import { useCanvasSessionContext, useProgressData } from 'features/controlLayers/components/SimpleSession/context';
|
||||
import { memo } from 'react';
|
||||
import type { S } from 'services/api/types';
|
||||
|
||||
import { useProgressDatum } from './context';
|
||||
|
||||
type Props = { item: S['SessionQueueItem'] } & TextProps;
|
||||
|
||||
export const QueueItemStatusLabel = memo(({ item, ...rest }: Props) => {
|
||||
const { progressImage } = useProgressDatum(item.item_id);
|
||||
const ctx = useCanvasSessionContext();
|
||||
const { progressImage, imageLoaded } = useProgressData(ctx.$progressData, item.item_id);
|
||||
|
||||
if (progressImage) {
|
||||
if (progressImage || imageLoaded) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
import { Box, Flex, forwardRef } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { QueueItemPreviewMini } from 'features/controlLayers/components/StagingArea/QueueItemPreviewMini';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
|
||||
import { QueueItemPreviewMini } from 'features/controlLayers/components/SimpleSession/QueueItemPreviewMini';
|
||||
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { useOverlayScrollbars } from 'overlayscrollbars-react';
|
||||
import type { CSSProperties, RefObject } from 'react';
|
||||
import { memo, useCallback, useEffect, useRef, useState } from 'react';
|
||||
import type { Components, ComputeItemKey, ItemContent, ListRange, VirtuosoHandle, VirtuosoProps } from 'react-virtuoso';
|
||||
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import type { Components, ItemContent, ListRange, VirtuosoHandle, VirtuosoProps } from 'react-virtuoso';
|
||||
import { Virtuoso } from 'react-virtuoso';
|
||||
import type { S } from 'services/api/types';
|
||||
|
||||
import { useStagingAreaContext } from './context';
|
||||
import { getQueueItemElementId } from './shared';
|
||||
|
||||
const log = logger('system');
|
||||
@@ -20,6 +20,8 @@ const virtuosoStyles = {
|
||||
height: '72px',
|
||||
} satisfies CSSProperties;
|
||||
|
||||
type VirtuosoContext = { selectedItemId: number | null };
|
||||
|
||||
/**
|
||||
* Scroll the item at the given index into view if it is not currently visible.
|
||||
*/
|
||||
@@ -92,7 +94,6 @@ const useScrollableStagingArea = (rootRef: RefObject<HTMLDivElement>) => {
|
||||
const { viewport } = osInstance.elements();
|
||||
viewport.style.overflowX = `var(--os-viewport-overflow-x)`;
|
||||
viewport.style.overflowY = `var(--os-viewport-overflow-y)`;
|
||||
viewport.style.textAlign = 'center';
|
||||
},
|
||||
},
|
||||
options: {
|
||||
@@ -130,26 +131,28 @@ const useScrollableStagingArea = (rootRef: RefObject<HTMLDivElement>) => {
|
||||
};
|
||||
|
||||
export const StagingAreaItemsList = memo(() => {
|
||||
const canvasManager = useCanvasManager();
|
||||
|
||||
const ctx = useStagingAreaContext();
|
||||
const canvasManager = useCanvasManagerSafe();
|
||||
const ctx = useCanvasSessionContext();
|
||||
const virtuosoRef = useRef<VirtuosoHandle>(null);
|
||||
const rangeRef = useRef<ListRange>({ startIndex: 0, endIndex: 0 });
|
||||
const rootRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const items = useStore(ctx.$items);
|
||||
const selectedItemId = useStore(ctx.$selectedItemId);
|
||||
|
||||
const context = useMemo(() => ({ selectedItemId }), [selectedItemId]);
|
||||
const scrollerRef = useScrollableStagingArea(rootRef);
|
||||
|
||||
useEffect(() => {
|
||||
return canvasManager.stagingArea.connectToSession(ctx.$items, ctx.$selectedItem);
|
||||
}, [canvasManager, ctx.$progressData, ctx.$items, ctx.$selectedItem]);
|
||||
if (!canvasManager) {
|
||||
return;
|
||||
}
|
||||
|
||||
return canvasManager.stagingArea.connectToSession(ctx.$selectedItemId, ctx.$progressData, ctx.$isPending);
|
||||
}, [canvasManager, ctx.$progressData, ctx.$selectedItemId, ctx.$isPending]);
|
||||
|
||||
useEffect(() => {
|
||||
return ctx.$selectedItemIndex.listen((selectedItemIndex) => {
|
||||
if (selectedItemIndex === null) {
|
||||
return;
|
||||
}
|
||||
return ctx.$selectedItemIndex.listen((index) => {
|
||||
if (!virtuosoRef.current) {
|
||||
return;
|
||||
}
|
||||
@@ -158,7 +161,11 @@ export const StagingAreaItemsList = memo(() => {
|
||||
return;
|
||||
}
|
||||
|
||||
scrollIntoView(selectedItemIndex, rootRef.current, virtuosoRef.current, rangeRef.current);
|
||||
if (index === null) {
|
||||
return;
|
||||
}
|
||||
|
||||
scrollIntoView(index, rootRef.current, virtuosoRef.current, rangeRef.current);
|
||||
});
|
||||
}, [ctx.$selectedItemIndex]);
|
||||
|
||||
@@ -168,46 +175,40 @@ export const StagingAreaItemsList = memo(() => {
|
||||
|
||||
return (
|
||||
<Box data-overlayscrollbars-initialize="" ref={rootRef} position="relative" w="full" h="full">
|
||||
<Virtuoso<S['SessionQueueItem']>
|
||||
<Virtuoso<S['SessionQueueItem'], VirtuosoContext>
|
||||
ref={virtuosoRef}
|
||||
context={context}
|
||||
data={items}
|
||||
horizontalDirection
|
||||
style={virtuosoStyles}
|
||||
computeItemKey={computeItemKey}
|
||||
increaseViewportBy={2048}
|
||||
itemContent={itemContent}
|
||||
components={components}
|
||||
rangeChanged={onRangeChanged}
|
||||
// Virtuoso expects the ref to be of HTMLElement | null | Window, but overlayscrollbars doesn't allow Window
|
||||
scrollerRef={scrollerRef as VirtuosoProps<S['SessionQueueItem'], void>['scrollerRef']}
|
||||
scrollerRef={scrollerRef as VirtuosoProps<S['SessionQueueItem'], VirtuosoContext>['scrollerRef']}
|
||||
/>
|
||||
</Box>
|
||||
);
|
||||
});
|
||||
StagingAreaItemsList.displayName = 'StagingAreaItemsList';
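The switch to `Virtuoso<S['SessionQueueItem'], VirtuosoContext>` uses react-virtuoso's `context` prop to pass `selectedItemId` down to `itemContent` without recreating the item renderer whenever the selection changes; the app memoizes the context object so unrelated renders don't invalidate the list. A reduced, standalone sketch of that pattern with placeholder types:

```tsx
import { Virtuoso } from 'react-virtuoso';
import type { ItemContent } from 'react-virtuoso';

type Item = { item_id: number; label: string };
type Context = { selectedItemId: number | null };

// The third argument of itemContent is the context passed to <Virtuoso>, so the
// renderer can stay a stable module-level function while still seeing selection state.
const itemContent: ItemContent<Item, Context> = (index, item, { selectedItemId }) => (
  <div style={{ fontWeight: item.item_id === selectedItemId ? 'bold' : 'normal' }}>
    #{index + 1} {item.label}
  </div>
);

export const ItemStrip = ({ items, selectedItemId }: { items: Item[]; selectedItemId: number | null }) => (
  <Virtuoso<Item, Context>
    style={{ height: 72 }}
    horizontalDirection
    data={items}
    // In the real component this object is memoized with useMemo.
    context={{ selectedItemId }}
    itemContent={itemContent}
  />
);
```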
|
||||
|
||||
const computeItemKey: ComputeItemKey<S['SessionQueueItem'], void> = (_, item: S['SessionQueueItem']) => {
|
||||
return item.item_id;
|
||||
};
|
||||
|
||||
const itemContent: ItemContent<S['SessionQueueItem'], void> = (index, item) => (
|
||||
<QueueItemPreviewMini key={`${item.item_id}-mini`} item={item} index={index} />
|
||||
const itemContent: ItemContent<S['SessionQueueItem'], VirtuosoContext> = (index, item, { selectedItemId }) => (
|
||||
<QueueItemPreviewMini
|
||||
key={`${item.item_id}-mini`}
|
||||
item={item}
|
||||
index={index}
|
||||
isSelected={selectedItemId === item.item_id}
|
||||
/>
|
||||
);
|
||||
|
||||
const listSx = {
|
||||
'& > * + *': {
|
||||
pl: 2,
|
||||
},
|
||||
'&[data-disabled="true"]': {
|
||||
filter: 'grayscale(1) opacity(0.5)',
|
||||
},
|
||||
};
|
||||
|
||||
const components: Components<S['SessionQueueItem']> = {
|
||||
const components: Components<S['SessionQueueItem'], VirtuosoContext> = {
|
||||
List: forwardRef(({ context: _, ...rest }, ref) => {
|
||||
const canvasManager = useCanvasManager();
|
||||
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
|
||||
|
||||
return <Flex ref={ref} sx={listSx} data-disabled={!shouldShowStagedImage} {...rest} />;
|
||||
return <Flex ref={ref} sx={listSx} {...rest} />;
|
||||
}),
|
||||
};
|
||||
@@ -0,0 +1,585 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { getOutputImageName } from 'features/controlLayers/components/SimpleSession/shared';
|
||||
import { selectStagingAreaAutoSwitch } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import {
|
||||
buildSelectSessionQueueItems,
|
||||
canvasQueueItemDiscarded,
|
||||
canvasSessionReset,
|
||||
} from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import type { ProgressImage } from 'features/nodes/types/common';
|
||||
import type { Atom, MapStore, StoreValue, WritableAtom } from 'nanostores';
|
||||
import { atom, computed, effect, map, subscribeKeys } from 'nanostores';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { createContext, memo, useCallback, useContext, useEffect, useMemo, useState } from 'react';
|
||||
import { getImageDTOSafe } from 'services/api/endpoints/images';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import type { ImageDTO, S } from 'services/api/types';
|
||||
import { $socket } from 'services/events/stores';
|
||||
import { assert, objectEntries } from 'tsafe';
|
||||
|
||||
export type ProgressData = {
|
||||
itemId: number;
|
||||
progressEvent: S['InvocationProgressEvent'] | null;
|
||||
progressImage: ProgressImage | null;
|
||||
imageDTO: ImageDTO | null;
|
||||
imageLoaded: boolean;
|
||||
};
|
||||
|
||||
const getInitialProgressData = (itemId: number): ProgressData => ({
|
||||
itemId,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTO: null,
|
||||
imageLoaded: false,
|
||||
});
|
||||
|
||||
export const useProgressData = ($progressData: ProgressDataMap, itemId: number): ProgressData => {
|
||||
const getInitialValue = useCallback(
|
||||
() => $progressData.get()[itemId] ?? getInitialProgressData(itemId),
|
||||
[$progressData, itemId]
|
||||
);
|
||||
const [value, setValue] = useState(getInitialValue);
|
||||
useEffect(() => {
|
||||
const unsub = subscribeKeys($progressData, [itemId], (data) => {
|
||||
const progressData = data[itemId];
|
||||
if (!progressData) {
|
||||
return;
|
||||
}
|
||||
setValue(progressData);
|
||||
});
|
||||
return () => {
|
||||
unsub();
|
||||
};
|
||||
}, [$progressData, itemId]);
|
||||
|
||||
return value;
|
||||
};
|
||||
|
||||
const setProgress = ($progressData: ProgressDataMap, data: S['InvocationProgressEvent']) => {
|
||||
const progressData = $progressData.get();
|
||||
const current = progressData[data.item_id];
|
||||
if (current) {
|
||||
const next = { ...current };
|
||||
next.progressEvent = data;
|
||||
if (data.image) {
|
||||
next.progressImage = data.image;
|
||||
}
|
||||
$progressData.set({
|
||||
...progressData,
|
||||
[data.item_id]: next,
|
||||
});
|
||||
} else {
|
||||
$progressData.set({
|
||||
...progressData,
|
||||
[data.item_id]: {
|
||||
itemId: data.item_id,
|
||||
progressEvent: data,
|
||||
progressImage: data.image ?? null,
|
||||
imageDTO: null,
|
||||
imageLoaded: false,
|
||||
},
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
export type ProgressDataMap = MapStore<Record<number, ProgressData | undefined>>;
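The per-item progress store leans on nanostores' `map` and `subscribeKeys`, which is what lets `useProgressData` re-render a component only when its own item's entry changes. A small standalone illustration of that mechanism; the names here are illustrative, not part of the app:

```ts
import { map, subscribeKeys } from 'nanostores';

type Progress = { step: number; total: number };

// A map store keyed by queue item ID, mirroring the shape of $progressData above.
const $progressByItem = map<Record<number, Progress | undefined>>({});

// Subscribe to a single key: the listener runs once on subscription and again
// whenever key 42 changes; updates to other items' keys are ignored.
const unsubscribe = subscribeKeys($progressByItem, [42], (value) => {
  const progress = value[42];
  if (!progress) {
    return;
  }
  console.log(`item 42: step ${progress.step}/${progress.total}`);
});

$progressByItem.setKey(42, { step: 1, total: 10 }); // listener runs
$progressByItem.setKey(7, { step: 3, total: 10 }); // listener does not run

unsubscribe();
```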
|
||||
|
||||
type CanvasSessionContextValue = {
|
||||
session: { id: string; type: 'simple' | 'advanced' };
|
||||
$items: Atom<S['SessionQueueItem'][]>;
|
||||
$itemCount: Atom<number>;
|
||||
$hasItems: Atom<boolean>;
|
||||
$isPending: Atom<boolean>;
|
||||
$progressData: ProgressDataMap;
|
||||
$selectedItemId: WritableAtom<number | null>;
|
||||
$selectedItem: Atom<S['SessionQueueItem'] | null>;
|
||||
$selectedItemIndex: Atom<number | null>;
|
||||
$selectedItemOutputImageDTO: Atom<ImageDTO | null>;
|
||||
selectNext: () => void;
|
||||
selectPrev: () => void;
|
||||
selectFirst: () => void;
|
||||
selectLast: () => void;
|
||||
onImageLoad: (itemId: number) => void;
|
||||
discard: (itemId: number) => void;
|
||||
discardAll: () => void;
|
||||
};
|
||||
|
||||
const CanvasSessionContext = createContext<CanvasSessionContextValue | null>(null);
|
||||
|
||||
export const CanvasSessionContextProvider = memo(
|
||||
({ id, type, children }: PropsWithChildren<{ id: string; type: 'simple' | 'advanced' }>) => {
|
||||
/**
|
||||
* For best performance and interop with the Canvas, which is outside react but needs to interact with the react
|
||||
* app, all canvas session state is packaged as nanostores atoms. The trickiest part is syncing the queue items
|
||||
* with a nanostores atom.
|
||||
*/
|
||||
const session = useMemo(() => ({ type, id }), [type, id]);
|
||||
|
||||
/**
|
||||
* App store
|
||||
*/
|
||||
const store = useAppStore();
|
||||
|
||||
const socket = useStore($socket);
|
||||
|
||||
/**
|
||||
* Track the last completed item. Used to implement autoswitch.
|
||||
*/
|
||||
const $lastCompletedItemId = useState(() => atom<number | null>(null))[0];
|
||||
|
||||
/**
|
||||
* Track the last started item. Used to implement autoswitch.
|
||||
*/
|
||||
const $lastStartedItemId = useState(() => atom<number | null>(null))[0];
|
||||
|
||||
/**
|
||||
* Manually-synced atom containing queue items for the current session. This is populated from the RTK Query cache
|
||||
* and kept in sync with it via a redux subscription.
|
||||
*/
|
||||
const $items = useState(() => atom<S['SessionQueueItem'][]>([]))[0];
|
||||
|
||||
/**
|
||||
* An internal flag used to work around race conditions with auto-switch switching to queue items before their
|
||||
* output images have fully loaded.
|
||||
*/
|
||||
const $lastLoadedItemId = useState(() => atom<number | null>(null))[0];
|
||||
|
||||
/**
|
||||
* An ephemeral store of progress events and images for all items in the current session.
|
||||
*/
|
||||
const $progressData = useState(() => map<StoreValue<ProgressDataMap>>({}))[0];
|
||||
|
||||
/**
|
||||
* The currently selected queue item's ID, or null if one is not selected.
|
||||
*/
|
||||
const $selectedItemId = useState(() => atom<number | null>(null))[0];
|
||||
|
||||
/**
|
||||
* The number of items. Computed from the queue items array.
|
||||
*/
|
||||
const $itemCount = useState(() => computed([$items], (items) => items.length))[0];
|
||||
|
||||
/**
|
||||
* Whether there are any items. Computed from the queue items array.
|
||||
*/
|
||||
const $hasItems = useState(() => computed([$items], (items) => items.length > 0))[0];
|
||||
|
||||
/**
|
||||
* Whether there are any pending or in-progress items. Computed from the queue items array.
|
||||
*/
|
||||
const $isPending = useState(() =>
|
||||
computed([$items], (items) => items.some((item) => item.status === 'pending' || item.status === 'in_progress'))
|
||||
)[0];
|
||||
|
||||
/**
|
||||
* The currently selected queue item, or null if one is not selected.
|
||||
*/
|
||||
const $selectedItem = useState(() =>
|
||||
computed([$items, $selectedItemId], (items, selectedItemId) => {
|
||||
if (items.length === 0) {
|
||||
return null;
|
||||
}
|
||||
if (selectedItemId === null) {
|
||||
return null;
|
||||
}
|
||||
return items.find(({ item_id }) => item_id === selectedItemId) ?? null;
|
||||
})
|
||||
)[0];
|
||||
|
||||
/**
|
||||
* The currently selected queue item's index in the list of items, or null if one is not selected.
|
||||
*/
|
||||
const $selectedItemIndex = useState(() =>
|
||||
computed([$items, $selectedItemId], (items, selectedItemId) => {
|
||||
if (items.length === 0) {
|
||||
return null;
|
||||
}
|
||||
if (selectedItemId === null) {
|
||||
return null;
|
||||
}
|
||||
return items.findIndex(({ item_id }) => item_id === selectedItemId) ?? null;
|
||||
})
|
||||
)[0];
|
||||
|
||||
/**
|
||||
* The currently selected queue item's output image DTO, or null if one is not selected or there is no output
|
||||
* image recorded.
|
||||
*/
|
||||
const $selectedItemOutputImageDTO = useState(() =>
|
||||
computed([$selectedItemId, $progressData], (selectedItemId, progressData) => {
|
||||
if (selectedItemId === null) {
|
||||
return null;
|
||||
}
|
||||
const datum = progressData[selectedItemId];
|
||||
if (!datum) {
|
||||
return null;
|
||||
}
|
||||
return datum.imageDTO;
|
||||
})
|
||||
)[0];
|
||||
|
||||
/**
|
||||
* A redux selector to select all queue items from the RTK Query cache.
|
||||
*/
|
||||
const selectQueueItems = useMemo(() => buildSelectSessionQueueItems(session.id), [session.id]);
|
||||
|
||||
const discard = useCallback(
|
||||
(itemId: number) => {
|
||||
store.dispatch(canvasQueueItemDiscarded({ itemId }));
|
||||
},
|
||||
[store]
|
||||
);
|
||||
|
||||
const discardAll = useCallback(() => {
|
||||
store.dispatch(canvasSessionReset());
|
||||
}, [store]);
|
||||
|
||||
const selectNext = useCallback(() => {
|
||||
const selectedItemId = $selectedItemId.get();
|
||||
if (selectedItemId === null) {
|
||||
return;
|
||||
}
|
||||
const items = $items.get();
|
||||
const currentIndex = items.findIndex((item) => item.item_id === selectedItemId);
|
||||
const nextIndex = (currentIndex + 1) % items.length;
|
||||
const nextItem = items[nextIndex];
|
||||
if (!nextItem) {
|
||||
return;
|
||||
}
|
||||
$selectedItemId.set(nextItem.item_id);
|
||||
}, [$items, $selectedItemId]);
|
||||
|
||||
const selectPrev = useCallback(() => {
|
||||
const selectedItemId = $selectedItemId.get();
|
||||
if (selectedItemId === null) {
|
||||
return;
|
||||
}
|
||||
const items = $items.get();
|
||||
const currentIndex = items.findIndex((item) => item.item_id === selectedItemId);
|
||||
const prevIndex = (currentIndex - 1 + items.length) % items.length;
|
||||
const prevItem = items[prevIndex];
|
||||
if (!prevItem) {
|
||||
return;
|
||||
}
|
||||
$selectedItemId.set(prevItem.item_id);
|
||||
}, [$items, $selectedItemId]);
|
||||
|
||||
const selectFirst = useCallback(() => {
|
||||
const items = $items.get();
|
||||
const first = items.at(0);
|
||||
if (!first) {
|
||||
return;
|
||||
}
|
||||
$selectedItemId.set(first.item_id);
|
||||
}, [$items, $selectedItemId]);
|
||||
|
||||
const selectLast = useCallback(() => {
|
||||
const items = $items.get();
|
||||
const last = items.at(-1);
|
||||
if (!last) {
|
||||
return;
|
||||
}
|
||||
$selectedItemId.set(last.item_id);
|
||||
}, [$items, $selectedItemId]);
|
||||
|
||||
const onImageLoad = useCallback(
|
||||
(itemId: number) => {
|
||||
const progressData = $progressData.get();
|
||||
const current = progressData[itemId];
|
||||
if (current) {
|
||||
const next = { ...current, imageLoaded: true };
|
||||
$progressData.setKey(itemId, next);
|
||||
} else {
|
||||
$progressData.setKey(itemId, {
|
||||
...getInitialProgressData(itemId),
|
||||
imageLoaded: true,
|
||||
});
|
||||
}
|
||||
if (
|
||||
$lastCompletedItemId.get() === itemId &&
|
||||
selectStagingAreaAutoSwitch(store.getState()) === 'switch_on_finish'
|
||||
) {
|
||||
$selectedItemId.set(itemId);
|
||||
$lastCompletedItemId.set(null);
|
||||
}
|
||||
},
|
||||
[$lastCompletedItemId, $progressData, $selectedItemId, store]
|
||||
);
|
||||
|
||||
// Set up socket listeners
|
||||
useEffect(() => {
|
||||
if (!socket) {
|
||||
return;
|
||||
}
|
||||
|
||||
const onProgress = (data: S['InvocationProgressEvent']) => {
|
||||
if (data.destination !== session.id) {
|
||||
return;
|
||||
}
|
||||
setProgress($progressData, data);
|
||||
};
|
||||
|
||||
const onQueueItemStatusChanged = (data: S['QueueItemStatusChangedEvent']) => {
|
||||
if (data.destination !== session.id) {
|
||||
return;
|
||||
}
|
||||
if (data.status === 'completed') {
|
||||
$lastCompletedItemId.set(data.item_id);
|
||||
}
|
||||
if (data.status === 'in_progress') {
|
||||
$lastStartedItemId.set(data.item_id);
|
||||
}
|
||||
};
|
||||
|
||||
socket.on('invocation_progress', onProgress);
|
||||
socket.on('queue_item_status_changed', onQueueItemStatusChanged);
|
||||
|
||||
return () => {
|
||||
socket.off('invocation_progress', onProgress);
|
||||
socket.off('queue_item_status_changed', onQueueItemStatusChanged);
|
||||
};
|
||||
}, [$lastCompletedItemId, $lastStartedItemId, $progressData, $selectedItemId, session.id, socket]);
|
||||
|
||||
// Set up state subscriptions and effects
|
||||
useEffect(() => {
|
||||
let _prevItems: readonly S['SessionQueueItem'][] = [];
|
||||
// Seed the $items atom with the initial query cache state
|
||||
$items.set(selectQueueItems(store.getState()));
|
||||
|
||||
// Manually keep the $items atom in sync as the query cache is updated
|
||||
const unsubReduxSyncToItemsAtom = store.subscribe(() => {
|
||||
const prevItems = $items.get();
|
||||
const items = selectQueueItems(store.getState());
|
||||
if (items !== prevItems) {
|
||||
_prevItems = prevItems;
|
||||
$items.set(items);
|
||||
}
|
||||
});
|
||||
|
||||
// Handle cases that could result in a nonexistent queue item being selected.
|
||||
const unsubEnsureSelectedItemIdExists = effect(
|
||||
[$items, $selectedItemId, $lastStartedItemId],
|
||||
(items, selectedItemId, lastStartedItemId) => {
|
||||
if (items.length === 0) {
|
||||
// If there are no items, cannot have a selected item.
|
||||
$selectedItemId.set(null);
|
||||
} else if (selectedItemId === null && items.length > 0) {
|
||||
// If there is no selected item but there are items, select the first one.
|
||||
$selectedItemId.set(items[0]?.item_id ?? null);
|
||||
return;
|
||||
} else if (
|
||||
selectStagingAreaAutoSwitch(store.getState()) === 'switch_on_start' &&
|
||||
items.findIndex(({ item_id }) => item_id === lastStartedItemId) !== -1
|
||||
) {
|
||||
$selectedItemId.set(lastStartedItemId);
|
||||
$lastStartedItemId.set(null);
|
||||
} else if (selectedItemId !== null && items.findIndex(({ item_id }) => item_id === selectedItemId) === -1) {
|
||||
// If an item is selected and it is not in the list of items, un-set it. This effect will run again and we'll hit
|
||||
// the above case, selecting the first item if there are any.
|
||||
let prevIndex = _prevItems.findIndex(({ item_id }) => item_id === selectedItemId);
|
||||
if (prevIndex >= items.length) {
|
||||
prevIndex = items.length - 1;
|
||||
}
|
||||
const nextItem = items[prevIndex];
|
||||
$selectedItemId.set(nextItem?.item_id ?? null);
|
||||
}
|
||||
|
||||
if (items !== _prevItems) {
|
||||
_prevItems = items;
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
// Clean up the progress data when a queue item is discarded.
|
||||
const unsubCleanUpProgressData = $items.subscribe(async (items) => {
|
||||
const progressData = $progressData.get();
|
||||
|
||||
const toDelete: number[] = [];
|
||||
const toUpdate: ProgressData[] = [];
|
||||
|
||||
for (const [id, datum] of objectEntries(progressData)) {
|
||||
if (!datum) {
|
||||
toDelete.push(id);
|
||||
continue;
|
||||
}
|
||||
const item = items.find(({ item_id }) => item_id === datum.itemId);
|
||||
if (!item) {
|
||||
toDelete.push(datum.itemId);
|
||||
} else if (item.status === 'canceled' || item.status === 'failed') {
|
||||
toUpdate.push({
|
||||
...datum,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTO: null,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
for (const item of items) {
|
||||
const datum = progressData[item.item_id];
|
||||
|
||||
if (datum) {
|
||||
if (datum.imageDTO) {
|
||||
continue;
|
||||
}
|
||||
const outputImageName = getOutputImageName(item);
|
||||
if (!outputImageName) {
|
||||
continue;
|
||||
}
|
||||
const imageDTO = await getImageDTOSafe(outputImageName);
|
||||
if (!imageDTO) {
|
||||
continue;
|
||||
}
|
||||
toUpdate.push({
|
||||
...datum,
|
||||
imageDTO,
|
||||
});
|
||||
} else {
|
||||
const outputImageName = getOutputImageName(item);
|
||||
if (!outputImageName) {
|
||||
continue;
|
||||
}
|
||||
const imageDTO = await getImageDTOSafe(outputImageName);
|
||||
if (!imageDTO) {
|
||||
continue;
|
||||
}
|
||||
toUpdate.push({
|
||||
...getInitialProgressData(item.item_id),
|
||||
imageDTO,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
for (const itemId of toDelete) {
|
||||
$progressData.setKey(itemId, undefined);
|
||||
}
|
||||
|
||||
for (const datum of toUpdate) {
|
||||
$progressData.setKey(datum.itemId, datum);
|
||||
}
|
||||
});
|
||||
|
||||
// We only want to auto-switch to completed queue items once their images have fully loaded to prevent flashes
|
||||
// of fallback content and/or progress images. The only surefire way to determine when images have fully loaded
|
||||
// is via the image elements' `onLoad` callback. Images set `$lastLoadedItemId` to their queue item ID in their
|
||||
// `onLoad` handler, and we listen for that here. If auto-switch is enabled, we then switch to that item.
|
||||
//
|
||||
// TODO: This isn't perfect... we set $lastLoadedItemId in the mini preview component, but the full view
|
||||
// component still needs to retrieve the image from the browser cache... can result in a flash of the progress
|
||||
// image...
|
||||
const unsubHandleAutoSwitch = $lastLoadedItemId.listen((lastLoadedItemId) => {
|
||||
if (lastLoadedItemId === null) {
|
||||
return;
|
||||
}
|
||||
if (selectStagingAreaAutoSwitch(store.getState()) === 'switch_on_finish') {
|
||||
$selectedItemId.set(lastLoadedItemId);
|
||||
}
|
||||
$lastLoadedItemId.set(null);
|
||||
});
|
||||
|
||||
// Create an RTK Query subscription. Without this, the query cache selector will never return anything because RTK
|
||||
// doesn't know we care about it.
|
||||
const { unsubscribe: unsubQueueItemsQuery } = store.dispatch(
|
||||
queueApi.endpoints.listAllQueueItems.initiate({ destination: session.id })
|
||||
);
|
||||
|
||||
// const unsubListener = store.dispatch(
|
||||
// addAppListener({
|
||||
// matcher: queueApi.endpoints.cancelQueueItem.matchFulfilled,
|
||||
// effect: ({ payload }, { getState }) => {
|
||||
// const { item_id } = payload;
|
||||
|
||||
// const items = selectQueueItems(getState());
|
||||
// if (items.length === 0) {
|
||||
// $selectedItemId.set(null);
|
||||
// } else if ($selectedItemId.get() === null) {
|
||||
// $selectedItemId.set(items[0].item_id);
|
||||
// }
|
||||
// },
|
||||
// })
|
||||
// );
|
||||
|
||||
// Clean up all subscriptions and reset top-level (i.e. non-computed/derived) state
|
||||
return () => {
|
||||
unsubHandleAutoSwitch();
|
||||
unsubQueueItemsQuery();
|
||||
unsubReduxSyncToItemsAtom();
|
||||
unsubEnsureSelectedItemIdExists();
|
||||
unsubCleanUpProgressData();
|
||||
$items.set([]);
|
||||
$progressData.set({});
|
||||
$selectedItemId.set(null);
|
||||
};
|
||||
}, [
|
||||
$items,
|
||||
$lastLoadedItemId,
|
||||
$lastStartedItemId,
|
||||
$progressData,
|
||||
$selectedItemId,
|
||||
selectQueueItems,
|
||||
session.id,
|
||||
store,
|
||||
]);
|
||||
|
||||
const value = useMemo<CanvasSessionContextValue>(
|
||||
() => ({
|
||||
session,
|
||||
$items,
|
||||
$hasItems,
|
||||
$isPending,
|
||||
$progressData,
|
||||
$selectedItemId,
|
||||
$selectedItem,
|
||||
$selectedItemIndex,
|
||||
$selectedItemOutputImageDTO,
|
||||
$itemCount,
|
||||
selectNext,
|
||||
selectPrev,
|
||||
selectFirst,
|
||||
selectLast,
|
||||
onImageLoad,
|
||||
discard,
|
||||
discardAll,
|
||||
}),
|
||||
[
|
||||
$items,
|
||||
$hasItems,
|
||||
$isPending,
|
||||
$progressData,
|
||||
$selectedItem,
|
||||
$selectedItemId,
|
||||
$selectedItemIndex,
|
||||
session,
|
||||
$selectedItemOutputImageDTO,
|
||||
$itemCount,
|
||||
selectNext,
|
||||
selectPrev,
|
||||
selectFirst,
|
||||
selectLast,
|
||||
onImageLoad,
|
||||
discard,
|
||||
discardAll,
|
||||
]
|
||||
);
|
||||
|
||||
return <CanvasSessionContext.Provider value={value}>{children}</CanvasSessionContext.Provider>;
|
||||
}
|
||||
);
|
||||
CanvasSessionContextProvider.displayName = 'CanvasSessionContextProvider';
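Nearly every store in the provider above is created with the `useState(() => atom(...))[0]` idiom: passing a factory to `useState` and discarding the setter guarantees the store is built exactly once per mounted provider, whereas a `useMemo` value may be recomputed by React. In isolation the idiom looks like this (illustrative names, not app code):

```ts
import { atom, computed } from 'nanostores';
import { useState } from 'react';

// Lazily create stable stores for the lifetime of the component instance.
export const useItemCountStores = () => {
  // The factory runs once on mount; [0] keeps the store and drops the unused setter.
  const $items = useState(() => atom<string[]>([]))[0];
  const $count = useState(() => computed([$items], (items) => items.length))[0];
  return { $items, $count };
};
```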
|
||||
|
||||
export const useCanvasSessionContext = () => {
|
||||
const ctx = useContext(CanvasSessionContext);
|
||||
assert(ctx !== null, "'useCanvasSessionContext' must be used within a CanvasSessionContextProvider");
|
||||
return ctx;
|
||||
};
|
||||
|
||||
export const useOutputImageDTO = (item: S['SessionQueueItem']) => {
|
||||
const ctx = useCanvasSessionContext();
|
||||
const $imageDTO = useState(() =>
|
||||
computed([ctx.$progressData], (progressData) => progressData[item.item_id]?.imageDTO ?? null)
|
||||
)[0];
|
||||
const imageDTO = useStore($imageDTO);
|
||||
|
||||
return imageDTO;
|
||||
};
|
||||
@@ -1,7 +1,5 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import {
|
||||
selectStagingAreaAutoSwitch,
|
||||
settingsStagingAreaAutoSwitchChanged,
|
||||
@@ -10,9 +8,6 @@ import { memo, useCallback } from 'react';
|
||||
import { PiCaretLineRightBold, PiCaretRightBold, PiMoonBold } from 'react-icons/pi';
|
||||
|
||||
export const StagingAreaAutoSwitchButtons = memo(() => {
|
||||
const canvasManager = useCanvasManager();
|
||||
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
|
||||
|
||||
const autoSwitch = useAppSelector(selectStagingAreaAutoSwitch);
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
@@ -34,7 +29,6 @@ export const StagingAreaAutoSwitchButtons = memo(() => {
|
||||
icon={<PiMoonBold />}
|
||||
colorScheme={autoSwitch === 'off' ? 'invokeBlue' : 'base'}
|
||||
onClick={onClickOff}
|
||||
isDisabled={!shouldShowStagedImage}
|
||||
/>
|
||||
<IconButton
|
||||
aria-label="Switch on start"
|
||||
@@ -42,7 +36,6 @@ export const StagingAreaAutoSwitchButtons = memo(() => {
|
||||
icon={<PiCaretRightBold />}
|
||||
colorScheme={autoSwitch === 'switch_on_start' ? 'invokeBlue' : 'base'}
|
||||
onClick={onClickSwitchOnStart}
|
||||
isDisabled={!shouldShowStagedImage}
|
||||
/>
|
||||
<IconButton
|
||||
aria-label="Switch on finish"
|
||||
@@ -50,7 +43,6 @@ export const StagingAreaAutoSwitchButtons = memo(() => {
|
||||
icon={<PiCaretLineRightBold />}
|
||||
colorScheme={autoSwitch === 'switch_on_finish' ? 'invokeBlue' : 'base'}
|
||||
onClick={onClickSwitchOnFinished}
|
||||
isDisabled={!shouldShowStagedImage}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { ButtonGroup, Flex } from '@invoke-ai/ui-library';
|
||||
import { useStagingAreaContext } from 'features/controlLayers/components/StagingArea/context';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
|
||||
import { StagingAreaToolbarAcceptButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarAcceptButton';
|
||||
import { StagingAreaToolbarDiscardAllButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarDiscardAllButton';
|
||||
import { StagingAreaToolbarDiscardSelectedButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarDiscardSelectedButton';
|
||||
@@ -9,13 +10,17 @@ import { StagingAreaToolbarNextButton } from 'features/controlLayers/components/
|
||||
import { StagingAreaToolbarPrevButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarPrevButton';
|
||||
import { StagingAreaToolbarSaveSelectedToGalleryButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarSaveSelectedToGalleryButton';
|
||||
import { StagingAreaToolbarToggleShowResultsButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarToggleShowResultsButton';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { memo } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
|
||||
import { StagingAreaAutoSwitchButtons } from './StagingAreaAutoSwitchButtons';
|
||||
|
||||
export const StagingAreaToolbar = memo(() => {
|
||||
const ctx = useStagingAreaContext();
|
||||
const canvasManager = useCanvasManager();
|
||||
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
|
||||
|
||||
const ctx = useCanvasSessionContext();
|
||||
|
||||
useHotkeys('meta+left', ctx.selectFirst, { preventDefault: true });
|
||||
useHotkeys('meta+right', ctx.selectLast, { preventDefault: true });
|
||||
@@ -23,22 +28,22 @@ export const StagingAreaToolbar = memo(() => {
|
||||
return (
|
||||
<Flex gap={2}>
|
||||
<ButtonGroup borderRadius="base" shadow="dark-lg">
|
||||
<StagingAreaToolbarPrevButton />
|
||||
<StagingAreaToolbarPrevButton isDisabled={!shouldShowStagedImage} />
|
||||
<StagingAreaToolbarImageCountButton />
|
||||
<StagingAreaToolbarNextButton />
|
||||
<StagingAreaToolbarNextButton isDisabled={!shouldShowStagedImage} />
|
||||
</ButtonGroup>
|
||||
<ButtonGroup borderRadius="base" shadow="dark-lg">
|
||||
<StagingAreaToolbarAcceptButton />
|
||||
<StagingAreaToolbarToggleShowResultsButton />
|
||||
<StagingAreaToolbarSaveSelectedToGalleryButton />
|
||||
<StagingAreaToolbarMenu />
|
||||
<StagingAreaToolbarDiscardSelectedButton />
|
||||
<StagingAreaToolbarDiscardSelectedButton isDisabled={!shouldShowStagedImage} />
|
||||
</ButtonGroup>
|
||||
<ButtonGroup borderRadius="base" shadow="dark-lg">
|
||||
<StagingAreaAutoSwitchButtons />
|
||||
</ButtonGroup>
|
||||
<ButtonGroup borderRadius="base" shadow="dark-lg">
|
||||
<StagingAreaToolbarDiscardAllButton />
|
||||
<StagingAreaToolbarDiscardAllButton isDisabled={!shouldShowStagedImage} />
|
||||
</ButtonGroup>
|
||||
</Flex>
|
||||
);
|
||||
|
||||
@@ -1,32 +1,64 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useIsRegionFocused } from 'common/hooks/focus';
|
||||
import { useStagingAreaContext } from 'features/controlLayers/components/StagingArea/context';
|
||||
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
|
||||
import { canvasSessionReset } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { selectBboxRect, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
|
||||
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
|
||||
import { imageNameToImageObject } from 'features/controlLayers/store/util';
|
||||
import { useCancelQueueItemsByDestination } from 'features/queue/hooks/useCancelQueueItemsByDestination';
|
||||
import { memo } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCheckBold } from 'react-icons/pi';
|
||||
|
||||
export const StagingAreaToolbarAcceptButton = memo(() => {
|
||||
const ctx = useStagingAreaContext();
|
||||
const ctx = useCanvasSessionContext();
|
||||
const dispatch = useAppDispatch();
|
||||
const canvasManager = useCanvasManager();
|
||||
const bboxRect = useAppSelector(selectBboxRect);
|
||||
const selectedEntityIdentifier = useAppSelector(selectSelectedEntityIdentifier);
|
||||
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
|
||||
const isCanvasFocused = useIsRegionFocused('canvas');
|
||||
const selectedItemImageDTO = useStore(ctx.$selectedItemOutputImageDTO);
|
||||
const cancelQueueItemsByDestination = useCancelQueueItemsByDestination();
|
||||
const acceptSelectedIsEnabled = useStore(ctx.$acceptSelectedIsEnabled);
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const acceptSelected = useCallback(() => {
|
||||
if (!selectedItemImageDTO) {
|
||||
return;
|
||||
}
|
||||
const { x, y, width, height } = bboxRect;
|
||||
const imageObject = imageNameToImageObject(selectedItemImageDTO.image_name, { width, height });
|
||||
const overrides: Partial<CanvasRasterLayerState> = {
|
||||
position: { x, y },
|
||||
objects: [imageObject],
|
||||
};
|
||||
|
||||
dispatch(rasterLayerAdded({ overrides, isSelected: selectedEntityIdentifier?.type === 'raster_layer' }));
|
||||
dispatch(canvasSessionReset());
|
||||
cancelQueueItemsByDestination.trigger(ctx.session.id, { withToast: false });
|
||||
}, [
|
||||
selectedItemImageDTO,
|
||||
bboxRect,
|
||||
dispatch,
|
||||
selectedEntityIdentifier?.type,
|
||||
cancelQueueItemsByDestination,
|
||||
ctx.session.id,
|
||||
]);
|
||||
|
||||
useHotkeys(
|
||||
['enter'],
|
||||
ctx.acceptSelected,
|
||||
acceptSelected,
|
||||
{
|
||||
preventDefault: true,
|
||||
enabled: isCanvasFocused && shouldShowStagedImage && acceptSelectedIsEnabled,
|
||||
enabled: isCanvasFocused && shouldShowStagedImage && selectedItemImageDTO !== null,
|
||||
},
|
||||
[ctx.acceptSelected, isCanvasFocused, shouldShowStagedImage, acceptSelectedIsEnabled]
|
||||
[isCanvasFocused, shouldShowStagedImage, selectedItemImageDTO]
|
||||
);
|
||||
|
||||
return (
|
||||
@@ -34,9 +66,9 @@ export const StagingAreaToolbarAcceptButton = memo(() => {
|
||||
tooltip={`${t('common.accept')} (Enter)`}
|
||||
aria-label={`${t('common.accept')} (Enter)`}
|
||||
icon={<PiCheckBold />}
|
||||
onClick={ctx.acceptSelected}
|
||||
onClick={acceptSelected}
|
||||
colorScheme="invokeBlue"
|
||||
isDisabled={!acceptSelectedIsEnabled || !shouldShowStagedImage || cancelQueueItemsByDestination.isDisabled}
|
||||
isDisabled={!selectedItemImageDTO || !shouldShowStagedImage || cancelQueueItemsByDestination.isDisabled}
|
||||
isLoading={cancelQueueItemsByDestination.isLoading}
|
||||
/>
|
||||
);
|
||||
|
||||
@@ -1,28 +1,28 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useStagingAreaContext } from 'features/controlLayers/components/StagingArea/context';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { useCancelQueueItemsByDestination } from 'features/queue/hooks/useCancelQueueItemsByDestination';
import { memo } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold } from 'react-icons/pi';

export const StagingAreaToolbarDiscardAllButton = memo(() => {
  const canvasManager = useCanvasManager();
  const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);

  const ctx = useStagingAreaContext();
export const StagingAreaToolbarDiscardAllButton = memo(({ isDisabled }: { isDisabled?: boolean }) => {
  const ctx = useCanvasSessionContext();
  const { t } = useTranslation();
  const cancelQueueItemsByDestination = useCancelQueueItemsByDestination();

  const discardAll = useCallback(() => {
    ctx.discardAll();
    cancelQueueItemsByDestination.trigger(ctx.session.id, { withToast: false });
  }, [cancelQueueItemsByDestination, ctx]);

  return (
    <IconButton
      tooltip={`${t('controlLayers.stagingArea.discardAll')} (Esc)`}
      aria-label={t('controlLayers.stagingArea.discardAll')}
      icon={<PiTrashSimpleBold />}
      onClick={ctx.discardAll}
      onClick={discardAll}
      colorScheme="error"
      isDisabled={cancelQueueItemsByDestination.isDisabled || !shouldShowStagedImage}
      isDisabled={isDisabled || cancelQueueItemsByDestination.isDisabled}
      isLoading={cancelQueueItemsByDestination.isLoading}
    />
  );

@@ -1,30 +1,34 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useStagingAreaContext } from 'features/controlLayers/components/StagingArea/context';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { useCancelQueueItem } from 'features/queue/hooks/useCancelQueueItem';
import { memo } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';

export const StagingAreaToolbarDiscardSelectedButton = memo(() => {
  const canvasManager = useCanvasManager();
  const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);

  const ctx = useStagingAreaContext();
export const StagingAreaToolbarDiscardSelectedButton = memo(({ isDisabled }: { isDisabled?: boolean }) => {
  const ctx = useCanvasSessionContext();
  const cancelQueueItem = useCancelQueueItem();
  const discardSelectedIsEnabled = useStore(ctx.$discardSelectedIsEnabled);
  const selectedItemId = useStore(ctx.$selectedItemId);

  const { t } = useTranslation();

  const discardSelected = useCallback(async () => {
    if (selectedItemId === null) {
      return;
    }
    ctx.discard(selectedItemId);
    await cancelQueueItem.trigger(selectedItemId, { withToast: false });
  }, [selectedItemId, ctx, cancelQueueItem]);

  return (
    <IconButton
      tooltip={t('controlLayers.stagingArea.discard')}
      aria-label={t('controlLayers.stagingArea.discard')}
      icon={<PiXBold />}
      onClick={ctx.discardSelected}
      onClick={discardSelected}
      colorScheme="invokeBlue"
      isDisabled={!discardSelectedIsEnabled || cancelQueueItem.isDisabled || !shouldShowStagedImage}
      isDisabled={selectedItemId === null || cancelQueueItem.isDisabled || isDisabled}
      isLoading={cancelQueueItem.isLoading}
    />
  );

@@ -1,27 +1,23 @@
import { Button } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useStagingAreaContext } from 'features/controlLayers/components/StagingArea/context';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { memo, useMemo } from 'react';

export const StagingAreaToolbarImageCountButton = memo(() => {
  const canvasManager = useCanvasManager();
  const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);

  const ctx = useStagingAreaContext();
  const selectedItem = useStore(ctx.$selectedItem);
  const ctx = useCanvasSessionContext();
  const selectItemIndex = useStore(ctx.$selectedItemIndex);
  const itemCount = useStore(ctx.$itemCount);

  const counterText = useMemo(() => {
    if (itemCount > 0 && selectedItem !== null) {
      return `${selectedItem.index + 1} of ${itemCount}`;
    if (itemCount > 0 && selectItemIndex !== null) {
      return `${selectItemIndex + 1} of ${itemCount}`;
    } else {
      return `0 of 0`;
    }
  }, [itemCount, selectedItem]);
  }, [itemCount, selectItemIndex]);

  return (
    <Button colorScheme="base" pointerEvents="none" minW={28} isDisabled={!shouldShowStagedImage}>
    <Button colorScheme="base" pointerEvents="none" minW={28}>
      {counterText}
    </Button>
  );

@@ -1,23 +1,12 @@
import { IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { StagingAreaToolbarNewLayerFromImageMenuItems } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarMenuNewLayerFromImage';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { memo } from 'react';
import { PiDotsThreeVerticalBold } from 'react-icons/pi';

export const StagingAreaToolbarMenu = memo(() => {
  const canvasManager = useCanvasManager();
  const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);

  return (
    <Menu>
      <MenuButton
        tooltip="Image Actions"
        as={IconButton}
        icon={<PiDotsThreeVerticalBold />}
        colorScheme="invokeBlue"
        isDisabled={!shouldShowStagedImage}
      />
      <MenuButton as={IconButton} icon={<PiDotsThreeVerticalBold />} colorScheme="invokeBlue" />
      <MenuList>
        <StagingAreaToolbarNewLayerFromImageMenuItems />
      </MenuList>

@@ -2,7 +2,7 @@ import { MenuGroup, MenuItem } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { NewLayerIcon } from 'features/controlLayers/components/common/icons';
|
||||
import { useStagingAreaContext } from 'features/controlLayers/components/StagingArea/context';
|
||||
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { createNewCanvasEntityFromImage } from 'features/imageActions/actions';
|
||||
import { toast } from 'features/toast/toast';
|
||||
@@ -15,8 +15,8 @@ const uploadImageArg = { image_category: 'general', is_intermediate: true, silen
|
||||
export const StagingAreaToolbarNewLayerFromImageMenuItems = memo(() => {
|
||||
const canvasManager = useCanvasManager();
|
||||
const { t } = useTranslation();
|
||||
const ctx = useStagingAreaContext();
|
||||
const selectedItemImageDTO = useStore(ctx.$selectedItemImageDTO);
|
||||
const ctx = useCanvasSessionContext();
|
||||
const selectedItemOutputImageDTO = useStore(ctx.$selectedItemOutputImageDTO);
|
||||
const store = useAppStore();
|
||||
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
|
||||
|
||||
@@ -29,11 +29,11 @@ export const StagingAreaToolbarNewLayerFromImageMenuItems = memo(() => {
|
||||
}, [t]);
|
||||
|
||||
const onClickNewRasterLayerFromImage = useCallback(async () => {
|
||||
if (!selectedItemImageDTO) {
|
||||
if (!selectedItemOutputImageDTO) {
|
||||
return;
|
||||
}
|
||||
const { dispatch, getState } = store;
|
||||
const imageDTO = await copyImage(selectedItemImageDTO.image_name, uploadImageArg);
|
||||
const imageDTO = await copyImage(selectedItemOutputImageDTO.image_name, uploadImageArg);
|
||||
createNewCanvasEntityFromImage({
|
||||
imageDTO,
|
||||
type: 'raster_layer',
|
||||
@@ -42,14 +42,14 @@ export const StagingAreaToolbarNewLayerFromImageMenuItems = memo(() => {
|
||||
overrides: { isEnabled: false }, // We are adding the layer while staging, it should be disabled by default
|
||||
});
|
||||
toastSentToCanvas();
|
||||
}, [selectedItemImageDTO, store, toastSentToCanvas]);
|
||||
}, [selectedItemOutputImageDTO, store, toastSentToCanvas]);
|
||||
|
||||
const onClickNewControlLayerFromImage = useCallback(async () => {
|
||||
if (!selectedItemImageDTO) {
|
||||
if (!selectedItemOutputImageDTO) {
|
||||
return;
|
||||
}
|
||||
const { dispatch, getState } = store;
|
||||
const imageDTO = await copyImage(selectedItemImageDTO.image_name, uploadImageArg);
|
||||
const imageDTO = await copyImage(selectedItemOutputImageDTO.image_name, uploadImageArg);
|
||||
|
||||
createNewCanvasEntityFromImage({
|
||||
imageDTO,
|
||||
@@ -59,14 +59,14 @@ export const StagingAreaToolbarNewLayerFromImageMenuItems = memo(() => {
|
||||
overrides: { isEnabled: false }, // We are adding the layer while staging, it should be disabled by default
|
||||
});
|
||||
toastSentToCanvas();
|
||||
}, [selectedItemImageDTO, store, toastSentToCanvas]);
|
||||
}, [selectedItemOutputImageDTO, store, toastSentToCanvas]);
|
||||
|
||||
const onClickNewInpaintMaskFromImage = useCallback(async () => {
|
||||
if (!selectedItemImageDTO) {
|
||||
if (!selectedItemOutputImageDTO) {
|
||||
return;
|
||||
}
|
||||
const { dispatch, getState } = store;
|
||||
const imageDTO = await copyImage(selectedItemImageDTO.image_name, uploadImageArg);
|
||||
const imageDTO = await copyImage(selectedItemOutputImageDTO.image_name, uploadImageArg);
|
||||
|
||||
createNewCanvasEntityFromImage({
|
||||
imageDTO,
|
||||
@@ -76,14 +76,14 @@ export const StagingAreaToolbarNewLayerFromImageMenuItems = memo(() => {
|
||||
overrides: { isEnabled: false }, // We are adding the layer while staging, it should be disabled by default
|
||||
});
|
||||
toastSentToCanvas();
|
||||
}, [selectedItemImageDTO, store, toastSentToCanvas]);
|
||||
}, [selectedItemOutputImageDTO, store, toastSentToCanvas]);
|
||||
|
||||
const onClickNewRegionalGuidanceFromImage = useCallback(async () => {
|
||||
if (!selectedItemImageDTO) {
|
||||
if (!selectedItemOutputImageDTO) {
|
||||
return;
|
||||
}
|
||||
const { dispatch, getState } = store;
|
||||
const imageDTO = await copyImage(selectedItemImageDTO.image_name, uploadImageArg);
|
||||
const imageDTO = await copyImage(selectedItemOutputImageDTO.image_name, uploadImageArg);
|
||||
|
||||
createNewCanvasEntityFromImage({
|
||||
imageDTO,
|
||||
@@ -93,35 +93,35 @@ export const StagingAreaToolbarNewLayerFromImageMenuItems = memo(() => {
|
||||
overrides: { isEnabled: false }, // We are adding the layer while staging, it should be disabled by default
|
||||
});
|
||||
toastSentToCanvas();
|
||||
}, [selectedItemImageDTO, store, toastSentToCanvas]);
|
||||
}, [selectedItemOutputImageDTO, store, toastSentToCanvas]);
|
||||
|
||||
return (
|
||||
<MenuGroup title="New Layer From Image">
|
||||
<MenuItem
|
||||
icon={<NewLayerIcon />}
|
||||
onClickCapture={onClickNewInpaintMaskFromImage}
|
||||
isDisabled={!selectedItemImageDTO || !shouldShowStagedImage}
|
||||
isDisabled={!selectedItemOutputImageDTO || !shouldShowStagedImage}
|
||||
>
|
||||
{t('controlLayers.inpaintMask')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
icon={<NewLayerIcon />}
|
||||
onClickCapture={onClickNewRegionalGuidanceFromImage}
|
||||
isDisabled={!selectedItemImageDTO || !shouldShowStagedImage}
|
||||
isDisabled={!selectedItemOutputImageDTO || !shouldShowStagedImage}
|
||||
>
|
||||
{t('controlLayers.regionalGuidance')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
icon={<NewLayerIcon />}
|
||||
onClickCapture={onClickNewControlLayerFromImage}
|
||||
isDisabled={!selectedItemImageDTO || !shouldShowStagedImage}
|
||||
isDisabled={!selectedItemOutputImageDTO || !shouldShowStagedImage}
|
||||
>
|
||||
{t('controlLayers.controlLayer')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
icon={<NewLayerIcon />}
|
||||
onClickCapture={onClickNewRasterLayerFromImage}
|
||||
isDisabled={!selectedItemImageDTO || !shouldShowStagedImage}
|
||||
isDisabled={!selectedItemOutputImageDTO || !shouldShowStagedImage}
|
||||
>
|
||||
{t('controlLayers.rasterLayer')}
|
||||
</MenuItem>
|
||||
|
||||
@@ -1,18 +1,14 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useIsRegionFocused } from 'common/hooks/focus';
|
||||
import { useStagingAreaContext } from 'features/controlLayers/components/StagingArea/context';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowRightBold } from 'react-icons/pi';
|
||||
|
||||
export const StagingAreaToolbarNextButton = memo(() => {
|
||||
const canvasManager = useCanvasManager();
|
||||
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
|
||||
|
||||
const ctx = useStagingAreaContext();
|
||||
export const StagingAreaToolbarNextButton = memo(({ isDisabled }: { isDisabled?: boolean }) => {
|
||||
const ctx = useCanvasSessionContext();
|
||||
const itemCount = useStore(ctx.$itemCount);
|
||||
const isCanvasFocused = useIsRegionFocused('canvas');
|
||||
|
||||
@@ -27,9 +23,9 @@ export const StagingAreaToolbarNextButton = memo(() => {
|
||||
ctx.selectNext,
|
||||
{
|
||||
preventDefault: true,
|
||||
enabled: isCanvasFocused && shouldShowStagedImage && itemCount > 1,
|
||||
enabled: isCanvasFocused && !isDisabled && itemCount > 1,
|
||||
},
|
||||
[isCanvasFocused, shouldShowStagedImage, itemCount, ctx.selectNext]
|
||||
[isCanvasFocused, isDisabled, itemCount, ctx.selectNext]
|
||||
);
|
||||
|
||||
return (
|
||||
@@ -39,7 +35,7 @@ export const StagingAreaToolbarNextButton = memo(() => {
|
||||
icon={<PiArrowRightBold />}
|
||||
onClick={selectNext}
|
||||
colorScheme="invokeBlue"
|
||||
isDisabled={itemCount <= 1 || !shouldShowStagedImage}
|
||||
isDisabled={itemCount <= 1 || isDisabled}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useIsRegionFocused } from 'common/hooks/focus';
|
||||
import { useStagingAreaContext } from 'features/controlLayers/components/StagingArea/context';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowLeftBold } from 'react-icons/pi';
|
||||
|
||||
export const StagingAreaToolbarPrevButton = memo(() => {
|
||||
const canvasManager = useCanvasManager();
|
||||
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
|
||||
const ctx = useStagingAreaContext();
|
||||
export const StagingAreaToolbarPrevButton = memo(({ isDisabled }: { isDisabled?: boolean }) => {
|
||||
const ctx = useCanvasSessionContext();
|
||||
const itemCount = useStore(ctx.$itemCount);
|
||||
const isCanvasFocused = useIsRegionFocused('canvas');
|
||||
|
||||
@@ -26,9 +23,9 @@ export const StagingAreaToolbarPrevButton = memo(() => {
|
||||
ctx.selectPrev,
|
||||
{
|
||||
preventDefault: true,
|
||||
enabled: isCanvasFocused && shouldShowStagedImage && itemCount > 1,
|
||||
enabled: isCanvasFocused && !isDisabled && itemCount > 1,
|
||||
},
|
||||
[isCanvasFocused, shouldShowStagedImage, itemCount, ctx.selectPrev]
|
||||
[isCanvasFocused, isDisabled, itemCount, ctx.selectPrev]
|
||||
);
|
||||
|
||||
return (
|
||||
@@ -38,7 +35,7 @@ export const StagingAreaToolbarPrevButton = memo(() => {
|
||||
icon={<PiArrowLeftBold />}
|
||||
onClick={selectPrev}
|
||||
colorScheme="invokeBlue"
|
||||
isDisabled={itemCount <= 1 || !shouldShowStagedImage}
|
||||
isDisabled={itemCount <= 1 || isDisabled}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -2,7 +2,7 @@ import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { withResultAsync } from 'common/util/result';
|
||||
import { useStagingAreaContext } from 'features/controlLayers/components/StagingArea/context';
|
||||
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
|
||||
import { toast } from 'features/toast/toast';
|
||||
@@ -16,14 +16,14 @@ const TOAST_ID = 'SAVE_STAGING_AREA_IMAGE_TO_GALLERY';
|
||||
export const StagingAreaToolbarSaveSelectedToGalleryButton = memo(() => {
|
||||
const canvasManager = useCanvasManager();
|
||||
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
|
||||
const ctx = useStagingAreaContext();
|
||||
const selectedItemImageDTO = useStore(ctx.$selectedItemImageDTO);
|
||||
const ctx = useCanvasSessionContext();
|
||||
const selectedItemOutputImageDTO = useStore(ctx.$selectedItemOutputImageDTO);
|
||||
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const saveSelectedImageToGallery = useCallback(async () => {
|
||||
if (!selectedItemImageDTO) {
|
||||
if (!selectedItemOutputImageDTO) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ export const StagingAreaToolbarSaveSelectedToGalleryButton = memo(() => {
|
||||
// the gallery without borking the canvas, which may need this image to exist.
|
||||
const result = await withResultAsync(async () => {
|
||||
// Create a new file with the same name, which we will upload
|
||||
await copyImage(selectedItemImageDTO.image_name, {
|
||||
await copyImage(selectedItemOutputImageDTO.image_name, {
|
||||
// Image should show up in the Images tab
|
||||
image_category: 'general',
|
||||
is_intermediate: false,
|
||||
@@ -55,7 +55,7 @@ export const StagingAreaToolbarSaveSelectedToGalleryButton = memo(() => {
|
||||
status: 'error',
|
||||
});
|
||||
}
|
||||
}, [autoAddBoardId, selectedItemImageDTO, t]);
|
||||
}, [autoAddBoardId, selectedItemOutputImageDTO, t]);
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
@@ -64,7 +64,7 @@ export const StagingAreaToolbarSaveSelectedToGalleryButton = memo(() => {
|
||||
icon={<PiFloppyDiskBold />}
|
||||
onClick={saveSelectedImageToGallery}
|
||||
colorScheme="invokeBlue"
|
||||
isDisabled={!selectedItemImageDTO || !shouldShowStagedImage}
|
||||
isDisabled={!selectedItemOutputImageDTO || !shouldShowStagedImage}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1,193 +0,0 @@
|
||||
import { merge } from 'es-toolkit';
|
||||
import type { StagingAreaAppApi } from 'features/controlLayers/components/StagingArea/state';
|
||||
import type { AutoSwitchMode } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import type { ImageDTO, S } from 'services/api/types';
|
||||
import type { PartialDeep } from 'type-fest';
|
||||
import { vi } from 'vitest';
|
||||
|
||||
export const createMockStagingAreaApp = (): StagingAreaAppApi & {
|
||||
// Additional methods for testing
|
||||
_triggerItemsChanged: (items: S['SessionQueueItem'][]) => void;
|
||||
_triggerQueueItemStatusChanged: (data: S['QueueItemStatusChangedEvent']) => void;
|
||||
_triggerInvocationProgress: (data: S['InvocationProgressEvent']) => void;
|
||||
_setAutoSwitchMode: (mode: AutoSwitchMode) => void;
|
||||
_setImageDTO: (imageName: string, imageDTO: ImageDTO | null) => void;
|
||||
_setLoadImageDelay: (delay: number) => void;
|
||||
} => {
|
||||
const itemsChangedHandlers = new Set<(items: S['SessionQueueItem'][]) => void>();
|
||||
const queueItemStatusChangedHandlers = new Set<(data: S['QueueItemStatusChangedEvent']) => void>();
|
||||
const invocationProgressHandlers = new Set<(data: S['InvocationProgressEvent']) => void>();
|
||||
|
||||
let autoSwitchMode: AutoSwitchMode = 'switch_on_start';
|
||||
const imageDTOs = new Map<string, ImageDTO | null>();
|
||||
let loadImageDelay = 0;
|
||||
|
||||
return {
|
||||
onDiscard: vi.fn(),
|
||||
onDiscardAll: vi.fn(),
|
||||
onAccept: vi.fn(),
|
||||
onSelect: vi.fn(),
|
||||
onSelectPrev: vi.fn(),
|
||||
onSelectNext: vi.fn(),
|
||||
onSelectFirst: vi.fn(),
|
||||
onSelectLast: vi.fn(),
|
||||
getAutoSwitch: vi.fn(() => autoSwitchMode),
|
||||
onAutoSwitchChange: vi.fn(),
|
||||
getImageDTO: vi.fn((imageName: string) => {
|
||||
return Promise.resolve(imageDTOs.get(imageName) || null);
|
||||
}),
|
||||
loadImage: vi.fn(async (imageName: string) => {
|
||||
if (loadImageDelay > 0) {
|
||||
await new Promise((resolve) => {
|
||||
setTimeout(resolve, loadImageDelay);
|
||||
});
|
||||
}
|
||||
// Mock HTMLImageElement for testing environment
|
||||
const mockImage = {
|
||||
src: imageName,
|
||||
width: 512,
|
||||
height: 512,
|
||||
onload: null,
|
||||
onerror: null,
|
||||
} as HTMLImageElement;
|
||||
return mockImage;
|
||||
}),
|
||||
onItemsChanged: vi.fn((handler) => {
|
||||
itemsChangedHandlers.add(handler);
|
||||
return () => itemsChangedHandlers.delete(handler);
|
||||
}),
|
||||
onQueueItemStatusChanged: vi.fn((handler) => {
|
||||
queueItemStatusChangedHandlers.add(handler);
|
||||
return () => queueItemStatusChangedHandlers.delete(handler);
|
||||
}),
|
||||
onInvocationProgress: vi.fn((handler) => {
|
||||
invocationProgressHandlers.add(handler);
|
||||
return () => invocationProgressHandlers.delete(handler);
|
||||
}),
|
||||
|
||||
// Testing helper methods
|
||||
_triggerItemsChanged: (items: S['SessionQueueItem'][]) => {
|
||||
itemsChangedHandlers.forEach((handler) => handler(items));
|
||||
},
|
||||
_triggerQueueItemStatusChanged: (data: S['QueueItemStatusChangedEvent']) => {
|
||||
queueItemStatusChangedHandlers.forEach((handler) => handler(data));
|
||||
},
|
||||
_triggerInvocationProgress: (data: S['InvocationProgressEvent']) => {
|
||||
invocationProgressHandlers.forEach((handler) => handler(data));
|
||||
},
|
||||
_setAutoSwitchMode: (mode: AutoSwitchMode) => {
|
||||
autoSwitchMode = mode;
|
||||
},
|
||||
_setImageDTO: (imageName: string, imageDTO: ImageDTO | null) => {
|
||||
imageDTOs.set(imageName, imageDTO);
|
||||
},
|
||||
_setLoadImageDelay: (delay: number) => {
|
||||
loadImageDelay = delay;
|
||||
},
|
||||
};
|
||||
};
|
||||
|
||||
export const createMockQueueItem = (overrides: PartialDeep<S['SessionQueueItem']> = {}): S['SessionQueueItem'] =>
|
||||
merge(
|
||||
{
|
||||
item_id: 1,
|
||||
batch_id: 'test-batch-id',
|
||||
session_id: 'test-session',
|
||||
queue_id: 'test-queue-id',
|
||||
status: 'pending',
|
||||
priority: 0,
|
||||
origin: null,
|
||||
destination: 'test-session',
|
||||
error_type: null,
|
||||
error_message: null,
|
||||
error_traceback: null,
|
||||
created_at: '2024-01-01T00:00:00Z',
|
||||
updated_at: '2024-01-01T00:00:00Z',
|
||||
started_at: null,
|
||||
completed_at: null,
|
||||
field_values: null,
|
||||
retried_from_item_id: null,
|
||||
is_api_validation_run: false,
|
||||
published_workflow_id: null,
|
||||
session: {
|
||||
id: 'test-session',
|
||||
graph: {},
|
||||
execution_graph: {},
|
||||
executed: [],
|
||||
executed_history: [],
|
||||
results: {
|
||||
'test-node-id': {
|
||||
image: {
|
||||
image_name: 'test-image.png',
|
||||
},
|
||||
},
|
||||
},
|
||||
errors: {},
|
||||
prepared_source_mapping: {},
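// Maps source node ids to prepared node ids; getOutputImageName looks up 'canvas_output' here and then reads the image from session.results.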
|
||||
source_prepared_mapping: {
|
||||
canvas_output: ['test-node-id'],
|
||||
},
|
||||
},
|
||||
workflow: null,
|
||||
},
|
||||
overrides
|
||||
) as S['SessionQueueItem'];
|
||||
|
||||
export const createMockImageDTO = (overrides: Partial<ImageDTO> = {}): ImageDTO => ({
|
||||
image_name: 'test-image.png',
|
||||
image_url: 'http://test.com/test-image.png',
|
||||
thumbnail_url: 'http://test.com/test-image-thumb.png',
|
||||
image_origin: 'internal',
|
||||
image_category: 'general',
|
||||
width: 512,
|
||||
height: 512,
|
||||
created_at: '2024-01-01T00:00:00Z',
|
||||
updated_at: '2024-01-01T00:00:00Z',
|
||||
deleted_at: null,
|
||||
is_intermediate: false,
|
||||
starred: false,
|
||||
has_workflow: false,
|
||||
session_id: 'test-session',
|
||||
node_id: 'test-node-id',
|
||||
board_id: null,
|
||||
...overrides,
|
||||
});
|
||||
|
||||
export const createMockProgressEvent = (
|
||||
overrides: PartialDeep<S['InvocationProgressEvent']> = {}
|
||||
): S['InvocationProgressEvent'] =>
|
||||
merge(
|
||||
{
|
||||
timestamp: Date.now(),
|
||||
queue_id: 'test-queue-id',
|
||||
item_id: 1,
|
||||
batch_id: 'test-batch-id',
|
||||
session_id: 'test-session',
|
||||
origin: null,
|
||||
destination: 'test-session',
|
||||
invocation: {},
|
||||
invocation_source_id: 'test-invocation-source-id',
|
||||
message: 'Processing...',
|
||||
percentage: 50,
|
||||
image: null,
|
||||
} as S['InvocationProgressEvent'],
|
||||
overrides
|
||||
);
|
||||
|
||||
export const createMockQueueItemStatusChangedEvent = (
|
||||
overrides: PartialDeep<S['QueueItemStatusChangedEvent']> = {}
|
||||
): S['QueueItemStatusChangedEvent'] =>
|
||||
merge(
|
||||
{
|
||||
timestamp: Date.now(),
|
||||
queue_id: 'test-queue-id',
|
||||
item_id: 1,
|
||||
batch_id: 'test-batch-id',
|
||||
origin: null,
|
||||
destination: 'test-session',
|
||||
status: 'completed',
|
||||
error_type: null,
|
||||
error_message: null,
|
||||
} as S['QueueItemStatusChangedEvent'],
|
||||
overrides
|
||||
);
|
||||
@@ -1,134 +0,0 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { loadImage } from 'features/controlLayers/konva/util';
|
||||
import {
|
||||
selectStagingAreaAutoSwitch,
|
||||
settingsStagingAreaAutoSwitchChanged,
|
||||
} from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
|
||||
import {
|
||||
buildSelectCanvasQueueItems,
|
||||
canvasQueueItemDiscarded,
|
||||
canvasSessionReset,
|
||||
} from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { selectBboxRect, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
|
||||
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
|
||||
import { imageNameToImageObject } from 'features/controlLayers/store/util';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { createContext, memo, useContext, useEffect, useMemo, useState } from 'react';
|
||||
import { getImageDTOSafe } from 'services/api/endpoints/images';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import type { S } from 'services/api/types';
|
||||
import { $socket } from 'services/events/stores';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import type { ProgressData, StagingAreaAppApi } from './state';
|
||||
import { getInitialProgressData, StagingAreaApi } from './state';
|
||||
|
||||
const StagingAreaContext = createContext<StagingAreaApi | null>(null);
|
||||
|
||||
export const StagingAreaContextProvider = memo(({ children, sessionId }: PropsWithChildren<{ sessionId: string }>) => {
|
||||
const store = useAppStore();
|
||||
const socket = useStore($socket);
|
||||
const stagingAreaAppApi = useMemo<StagingAreaAppApi>(() => {
|
||||
const selectQueueItems = buildSelectCanvasQueueItems(sessionId);
|
||||
|
||||
const _stagingAreaAppApi: StagingAreaAppApi = {
|
||||
getAutoSwitch: () => selectStagingAreaAutoSwitch(store.getState()),
|
||||
getImageDTO: (imageName: string) => getImageDTOSafe(imageName),
|
||||
loadImage: (imageUrl: string) => loadImage(imageUrl, true),
|
||||
onInvocationProgress: (handler) => {
|
||||
socket?.on('invocation_progress', handler);
|
||||
return () => {
|
||||
socket?.off('invocation_progress', handler);
|
||||
};
|
||||
},
|
||||
onQueueItemStatusChanged: (handler) => {
|
||||
socket?.on('queue_item_status_changed', handler);
|
||||
return () => {
|
||||
socket?.off('queue_item_status_changed', handler);
|
||||
};
|
||||
},
|
||||
onItemsChanged: (handler) => {
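// Notify the handler only when the memoized selector returns a new array reference for this session's queue items.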
|
||||
let prev: S['SessionQueueItem'][] = [];
|
||||
return store.subscribe(() => {
|
||||
const next = selectQueueItems(store.getState());
|
||||
if (prev !== next) {
|
||||
prev = next;
|
||||
handler(next);
|
||||
}
|
||||
});
|
||||
},
|
||||
onDiscard: ({ item_id, status }) => {
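// Drop the item from the canvas staging slice; if it is still pending or in progress, also cancel it in the queue.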
|
||||
store.dispatch(canvasQueueItemDiscarded({ itemId: item_id }));
|
||||
if (status === 'in_progress' || status === 'pending') {
|
||||
store.dispatch(queueApi.endpoints.cancelQueueItem.initiate({ item_id }, { track: false }));
|
||||
}
|
||||
},
|
||||
onDiscardAll: () => {
|
||||
store.dispatch(canvasSessionReset());
|
||||
store.dispatch(
|
||||
queueApi.endpoints.cancelQueueItemsByDestination.initiate({ destination: sessionId }, { track: false })
|
||||
);
|
||||
},
|
||||
onAccept: (item, imageDTO) => {
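// Bake the accepted image into a new raster layer at the bbox position, then reset the staging session and cancel its remaining queue items.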
|
||||
const bboxRect = selectBboxRect(store.getState());
|
||||
const { x, y, width, height } = bboxRect;
|
||||
const imageObject = imageNameToImageObject(imageDTO.image_name, { width, height });
|
||||
const selectedEntityIdentifier = selectSelectedEntityIdentifier(store.getState());
|
||||
const overrides: Partial<CanvasRasterLayerState> = {
|
||||
position: { x, y },
|
||||
objects: [imageObject],
|
||||
};
|
||||
|
||||
store.dispatch(rasterLayerAdded({ overrides, isSelected: selectedEntityIdentifier?.type === 'raster_layer' }));
|
||||
store.dispatch(canvasSessionReset());
|
||||
store.dispatch(
|
||||
queueApi.endpoints.cancelQueueItemsByDestination.initiate({ destination: sessionId }, { track: false })
|
||||
);
|
||||
},
|
||||
onAutoSwitchChange: (mode) => {
|
||||
store.dispatch(settingsStagingAreaAutoSwitchChanged(mode));
|
||||
},
|
||||
};
|
||||
|
||||
return _stagingAreaAppApi;
|
||||
}, [sessionId, socket, store]);
|
||||
|
||||
const [stagingAreaApi] = useState(() => new StagingAreaApi());
|
||||
|
||||
useEffect(() => {
|
||||
stagingAreaApi.connectToApp(sessionId, stagingAreaAppApi);
|
||||
|
||||
// We need to subscribe to the queue items query manually to ensure the staging area actually gets the items
|
||||
const { unsubscribe: unsubQueueItemsQuery } = store.dispatch(
|
||||
queueApi.endpoints.listAllQueueItems.initiate({ destination: sessionId })
|
||||
);
|
||||
|
||||
return () => {
|
||||
stagingAreaApi.cleanup();
|
||||
unsubQueueItemsQuery();
|
||||
};
|
||||
}, [sessionId, stagingAreaApi, stagingAreaAppApi, store]);
|
||||
|
||||
return <StagingAreaContext.Provider value={stagingAreaApi}>{children}</StagingAreaContext.Provider>;
|
||||
});
|
||||
StagingAreaContextProvider.displayName = 'StagingAreaContextProvider';
|
||||
|
||||
export const useStagingAreaContext = () => {
|
||||
const ctx = useContext(StagingAreaContext);
|
||||
assert(ctx !== null, "'useStagingAreaContext' must be used within a StagingAreaContextProvider");
|
||||
return ctx;
|
||||
};
|
||||
|
||||
export const useOutputImageDTO = (itemId: number) => {
|
||||
const ctx = useStagingAreaContext();
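// Subscribing with { keys: [itemId] } re-renders only when this item's entry changes, not on every progress update.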
|
||||
const allProgressData = useStore(ctx.$progressData, { keys: [itemId] });
|
||||
return allProgressData[itemId]?.imageDTO ?? null;
|
||||
};
|
||||
|
||||
export const useProgressDatum = (itemId: number): ProgressData => {
|
||||
const ctx = useStagingAreaContext();
|
||||
const allProgressData = useStore(ctx.$progressData, { keys: [itemId] });
|
||||
return allProgressData[itemId] ?? getInitialProgressData(itemId);
|
||||
};
|
||||
@@ -1,205 +0,0 @@
|
||||
import type { S } from 'services/api/types';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { getOutputImageName, getProgressMessage, getQueueItemElementId } from './shared';
|
||||
|
||||
describe('StagingAreaApi Utility Functions', () => {
|
||||
describe('getProgressMessage', () => {
|
||||
it('should return default message when no data provided', () => {
|
||||
expect(getProgressMessage()).toBe('Generating');
|
||||
expect(getProgressMessage(null)).toBe('Generating');
|
||||
});
|
||||
|
||||
it('should format progress message when data is provided', () => {
|
||||
const progressEvent: S['InvocationProgressEvent'] = {
|
||||
item_id: 1,
|
||||
destination: 'test-session',
|
||||
node_id: 'test-node',
|
||||
source_node_id: 'test-source-node',
|
||||
progress: 0.5,
|
||||
message: 'Processing image...',
|
||||
image: null,
|
||||
} as unknown as S['InvocationProgressEvent'];
|
||||
|
||||
const result = getProgressMessage(progressEvent);
|
||||
expect(result).toBe('Processing image...');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getQueueItemElementId', () => {
|
||||
it('should generate correct element ID for queue item', () => {
|
||||
expect(getQueueItemElementId(0)).toBe('queue-item-preview-0');
|
||||
expect(getQueueItemElementId(5)).toBe('queue-item-preview-5');
|
||||
expect(getQueueItemElementId(99)).toBe('queue-item-preview-99');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getOutputImageName', () => {
|
||||
it('should extract image name from completed queue item', () => {
|
||||
const queueItem: S['SessionQueueItem'] = {
|
||||
item_id: 1,
|
||||
status: 'completed',
|
||||
priority: 0,
|
||||
destination: 'test-session',
|
||||
created_at: '2024-01-01T00:00:00Z',
|
||||
updated_at: '2024-01-01T00:00:00Z',
|
||||
started_at: '2024-01-01T00:00:01Z',
|
||||
completed_at: '2024-01-01T00:01:00Z',
|
||||
error: null,
|
||||
session: {
|
||||
id: 'test-session',
|
||||
source_prepared_mapping: {
|
||||
canvas_output: ['output-node-id'],
|
||||
},
|
||||
results: {
|
||||
'output-node-id': {
|
||||
image: {
|
||||
image_name: 'test-output.png',
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
} as unknown as S['SessionQueueItem'];
|
||||
|
||||
expect(getOutputImageName(queueItem)).toBe('test-output.png');
|
||||
});
|
||||
|
||||
it('should return null when no canvas output node found', () => {
|
||||
const queueItem = {
|
||||
item_id: 1,
|
||||
status: 'completed',
|
||||
priority: 0,
|
||||
destination: 'test-session',
|
||||
created_at: '2024-01-01T00:00:00Z',
|
||||
updated_at: '2024-01-01T00:00:00Z',
|
||||
started_at: '2024-01-01T00:00:01Z',
|
||||
completed_at: '2024-01-01T00:01:00Z',
|
||||
error: null,
|
||||
session: {
|
||||
id: 'test-session',
|
||||
source_prepared_mapping: {
|
||||
some_other_node: ['other-node-id'],
|
||||
},
|
||||
results: {
|
||||
'other-node-id': {
|
||||
type: 'image_output',
|
||||
image: {
|
||||
image_name: 'test-output.png',
|
||||
},
|
||||
width: 512,
|
||||
height: 512,
|
||||
},
|
||||
},
|
||||
},
|
||||
} as unknown as S['SessionQueueItem'];
|
||||
|
||||
expect(getOutputImageName(queueItem)).toBe(null);
|
||||
});
|
||||
|
||||
it('should return null when output node has no results', () => {
|
||||
const queueItem: S['SessionQueueItem'] = {
|
||||
item_id: 1,
|
||||
status: 'completed',
|
||||
priority: 0,
|
||||
destination: 'test-session',
|
||||
created_at: '2024-01-01T00:00:00Z',
|
||||
updated_at: '2024-01-01T00:00:00Z',
|
||||
started_at: '2024-01-01T00:00:01Z',
|
||||
completed_at: '2024-01-01T00:01:00Z',
|
||||
error: null,
|
||||
session: {
|
||||
id: 'test-session',
|
||||
source_prepared_mapping: {
|
||||
canvas_output: ['output-node-id'],
|
||||
},
|
||||
results: {},
|
||||
},
|
||||
} as unknown as S['SessionQueueItem'];
|
||||
|
||||
expect(getOutputImageName(queueItem)).toBe(null);
|
||||
});
|
||||
|
||||
it('should return null when results contain no image fields', () => {
|
||||
const queueItem: S['SessionQueueItem'] = {
|
||||
item_id: 1,
|
||||
status: 'completed',
|
||||
priority: 0,
|
||||
destination: 'test-session',
|
||||
created_at: '2024-01-01T00:00:00Z',
|
||||
updated_at: '2024-01-01T00:00:00Z',
|
||||
started_at: '2024-01-01T00:00:01Z',
|
||||
completed_at: '2024-01-01T00:01:00Z',
|
||||
error: null,
|
||||
session: {
|
||||
id: 'test-session',
|
||||
source_prepared_mapping: {
|
||||
canvas_output: ['output-node-id'],
|
||||
},
|
||||
results: {
|
||||
'output-node-id': {
|
||||
text: 'some text output',
|
||||
number: 42,
|
||||
},
|
||||
},
|
||||
},
|
||||
} as unknown as S['SessionQueueItem'];
|
||||
|
||||
expect(getOutputImageName(queueItem)).toBe(null);
|
||||
});
|
||||
|
||||
it('should handle multiple outputs and return first image', () => {
|
||||
const queueItem: S['SessionQueueItem'] = {
|
||||
item_id: 1,
|
||||
status: 'completed',
|
||||
priority: 0,
|
||||
destination: 'test-session',
|
||||
created_at: '2024-01-01T00:00:00Z',
|
||||
updated_at: '2024-01-01T00:00:00Z',
|
||||
started_at: '2024-01-01T00:00:01Z',
|
||||
completed_at: '2024-01-01T00:01:00Z',
|
||||
error: null,
|
||||
session: {
|
||||
id: 'test-session',
|
||||
source_prepared_mapping: {
|
||||
canvas_output: ['output-node-id'],
|
||||
},
|
||||
results: {
|
||||
'output-node-id': {
|
||||
text: 'some text',
|
||||
first_image: {
|
||||
image_name: 'first-image.png',
|
||||
},
|
||||
second_image: {
|
||||
image_name: 'second-image.png',
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
} as unknown as S['SessionQueueItem'];
|
||||
|
||||
const result = getOutputImageName(queueItem);
|
||||
expect(result).toBe('first-image.png');
|
||||
});
|
||||
|
||||
it('should handle empty session mapping', () => {
|
||||
const queueItem: S['SessionQueueItem'] = {
|
||||
item_id: 1,
|
||||
status: 'completed',
|
||||
priority: 0,
|
||||
destination: 'test-session',
|
||||
created_at: '2024-01-01T00:00:00Z',
|
||||
updated_at: '2024-01-01T00:00:00Z',
|
||||
started_at: '2024-01-01T00:00:01Z',
|
||||
completed_at: '2024-01-01T00:01:00Z',
|
||||
error: null,
|
||||
session: {
|
||||
id: 'test-session',
|
||||
source_prepared_mapping: {},
|
||||
results: {},
|
||||
},
|
||||
} as unknown as S['SessionQueueItem'];
|
||||
|
||||
expect(getOutputImageName(queueItem)).toBe(null);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,784 +0,0 @@
|
||||
import { afterEach, beforeEach, describe, expect, it } from 'vitest';
|
||||
|
||||
import {
|
||||
createMockImageDTO,
|
||||
createMockProgressEvent,
|
||||
createMockQueueItem,
|
||||
createMockQueueItemStatusChangedEvent,
|
||||
createMockStagingAreaApp,
|
||||
} from './__mocks__/mockStagingAreaApp';
|
||||
import { StagingAreaApi } from './state';
|
||||
|
||||
describe('StagingAreaApi', () => {
|
||||
let api: StagingAreaApi;
|
||||
let mockApp: ReturnType<typeof createMockStagingAreaApp>;
|
||||
const sessionId = 'test-session';
|
||||
|
||||
beforeEach(() => {
|
||||
mockApp = createMockStagingAreaApp();
|
||||
api = new StagingAreaApi();
|
||||
api.connectToApp(sessionId, mockApp);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
api.cleanup();
|
||||
});
|
||||
|
||||
describe('Constructor and Setup', () => {
|
||||
it('should initialize with correct session ID', () => {
|
||||
expect(api._sessionId).toBe(sessionId);
|
||||
});
|
||||
|
||||
it('should set up event subscriptions', () => {
|
||||
expect(mockApp.onItemsChanged).toHaveBeenCalledWith(expect.any(Function));
|
||||
expect(mockApp.onQueueItemStatusChanged).toHaveBeenCalledWith(expect.any(Function));
|
||||
expect(mockApp.onInvocationProgress).toHaveBeenCalledWith(expect.any(Function));
|
||||
});
|
||||
|
||||
it('should initialize atoms with default values', () => {
|
||||
expect(api.$lastStartedItemId.get()).toBe(null);
|
||||
expect(api.$lastCompletedItemId.get()).toBe(null);
|
||||
expect(api.$items.get()).toEqual([]);
|
||||
expect(api.$progressData.get()).toEqual({});
|
||||
expect(api.$selectedItemId.get()).toBe(null);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Computed Values', () => {
|
||||
it('should compute item count correctly', () => {
|
||||
expect(api.$itemCount.get()).toBe(0);
|
||||
|
||||
const items = [createMockQueueItem({ item_id: 1 })];
|
||||
api.$items.set(items);
|
||||
expect(api.$itemCount.get()).toBe(1);
|
||||
});
|
||||
|
||||
it('should compute hasItems correctly', () => {
|
||||
expect(api.$hasItems.get()).toBe(false);
|
||||
|
||||
const items = [createMockQueueItem({ item_id: 1 })];
|
||||
api.$items.set(items);
|
||||
expect(api.$hasItems.get()).toBe(true);
|
||||
});
|
||||
|
||||
it('should compute isPending correctly', () => {
|
||||
expect(api.$isPending.get()).toBe(false);
|
||||
|
||||
const items = [
|
||||
createMockQueueItem({ item_id: 1, status: 'pending' }),
|
||||
createMockQueueItem({ item_id: 2, status: 'completed' }),
|
||||
];
|
||||
api.$items.set(items);
|
||||
expect(api.$isPending.get()).toBe(true);
|
||||
});
|
||||
|
||||
it('should compute selectedItem correctly', () => {
|
||||
expect(api.$selectedItem.get()).toBe(null);
|
||||
|
||||
const items = [createMockQueueItem({ item_id: 1 }), createMockQueueItem({ item_id: 2 })];
|
||||
api.$items.set(items);
|
||||
api.$selectedItemId.set(1);
|
||||
|
||||
const selectedItem = api.$selectedItem.get();
|
||||
expect(selectedItem).not.toBe(null);
|
||||
expect(selectedItem?.item.item_id).toBe(1);
|
||||
expect(selectedItem?.index).toBe(0);
|
||||
});
|
||||
|
||||
it('should compute selectedItemImageDTO correctly', () => {
|
||||
const items = [createMockQueueItem({ item_id: 1 })];
|
||||
const imageDTO = createMockImageDTO();
|
||||
|
||||
api.$items.set(items);
|
||||
api.$selectedItemId.set(1);
|
||||
api.$progressData.setKey(1, {
|
||||
itemId: 1,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTO,
|
||||
});
|
||||
|
||||
expect(api.$selectedItemImageDTO.get()).toBe(imageDTO);
|
||||
});
|
||||
|
||||
it('should compute selectedItemIndex correctly', () => {
|
||||
const items = [createMockQueueItem({ item_id: 1 }), createMockQueueItem({ item_id: 2 })];
|
||||
api.$items.set(items);
|
||||
api.$selectedItemId.set(2);
|
||||
|
||||
expect(api.$selectedItemIndex.get()).toBe(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Selection Methods', () => {
|
||||
beforeEach(() => {
|
||||
const items = [
|
||||
createMockQueueItem({ item_id: 1 }),
|
||||
createMockQueueItem({ item_id: 2 }),
|
||||
createMockQueueItem({ item_id: 3 }),
|
||||
];
|
||||
api.$items.set(items);
|
||||
});
|
||||
|
||||
it('should select item by ID', () => {
|
||||
api.select(2);
|
||||
expect(api.$selectedItemId.get()).toBe(2);
|
||||
expect(mockApp.onSelect).toHaveBeenCalledWith(2);
|
||||
});
|
||||
|
||||
it('should select next item', () => {
|
||||
api.$selectedItemId.set(1);
|
||||
api.selectNext();
|
||||
expect(api.$selectedItemId.get()).toBe(2);
|
||||
expect(mockApp.onSelectNext).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should wrap to first item when selecting next from last', () => {
|
||||
api.$selectedItemId.set(3);
|
||||
api.selectNext();
|
||||
expect(api.$selectedItemId.get()).toBe(1);
|
||||
});
|
||||
|
||||
it('should select previous item', () => {
|
||||
api.$selectedItemId.set(2);
|
||||
api.selectPrev();
|
||||
expect(api.$selectedItemId.get()).toBe(1);
|
||||
expect(mockApp.onSelectPrev).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should wrap to last item when selecting previous from first', () => {
|
||||
api.$selectedItemId.set(1);
|
||||
api.selectPrev();
|
||||
expect(api.$selectedItemId.get()).toBe(3);
|
||||
});
|
||||
|
||||
it('should select first item', () => {
|
||||
api.selectFirst();
|
||||
expect(api.$selectedItemId.get()).toBe(1);
|
||||
expect(mockApp.onSelectFirst).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should select last item', () => {
|
||||
api.selectLast();
|
||||
expect(api.$selectedItemId.get()).toBe(3);
|
||||
expect(mockApp.onSelectLast).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should do nothing when no items exist', () => {
|
||||
api.$items.set([]);
|
||||
api.selectNext();
|
||||
api.selectPrev();
|
||||
api.selectFirst();
|
||||
api.selectLast();
|
||||
|
||||
expect(api.$selectedItemId.get()).toBe(null);
|
||||
});
|
||||
|
||||
it('should do nothing when no item is selected', () => {
|
||||
api.$selectedItemId.set(null);
|
||||
api.selectNext();
|
||||
api.selectPrev();
|
||||
|
||||
expect(api.$selectedItemId.get()).toBe(null);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Discard Methods', () => {
|
||||
beforeEach(() => {
|
||||
const items = [
|
||||
createMockQueueItem({ item_id: 1 }),
|
||||
createMockQueueItem({ item_id: 2 }),
|
||||
createMockQueueItem({ item_id: 3 }),
|
||||
];
|
||||
api.$items.set(items);
|
||||
});
|
||||
|
||||
it('should discard selected item and select next', () => {
|
||||
api.$selectedItemId.set(2);
|
||||
const selectedItem = api.$selectedItem.get();
|
||||
|
||||
api.discardSelected();
|
||||
|
||||
expect(api.$selectedItemId.get()).toBe(3);
|
||||
expect(mockApp.onDiscard).toHaveBeenCalledWith(selectedItem?.item);
|
||||
});
|
||||
|
||||
it('should discard selected item and clamp to last valid index', () => {
|
||||
api.$selectedItemId.set(3);
|
||||
const selectedItem = api.$selectedItem.get();
|
||||
|
||||
api.discardSelected();
|
||||
|
||||
// The logic clamps to the next index, so when discarding last item (index 2),
|
||||
// it tries to select index 3 which clamps to index 2 (item 3)
|
||||
expect(api.$selectedItemId.get()).toBe(3);
|
||||
expect(mockApp.onDiscard).toHaveBeenCalledWith(selectedItem?.item);
|
||||
});
|
||||
|
||||
it('should set selectedItemId to null when discarding last item', () => {
|
||||
api.$items.set([createMockQueueItem({ item_id: 1 })]);
|
||||
api.$selectedItemId.set(1);
|
||||
|
||||
api.discardSelected();
|
||||
|
||||
// When there's only one item, after clamping we get the same item, so it stays selected
|
||||
expect(api.$selectedItemId.get()).toBe(1);
|
||||
});
|
||||
|
||||
it('should do nothing when no item is selected', () => {
|
||||
api.$selectedItemId.set(null);
|
||||
api.discardSelected();
|
||||
|
||||
expect(mockApp.onDiscard).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should discard all items', () => {
|
||||
api.$selectedItemId.set(2);
|
||||
api.discardAll();
|
||||
|
||||
expect(api.$selectedItemId.get()).toBe(null);
|
||||
expect(mockApp.onDiscardAll).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should compute discardSelectedIsEnabled correctly', () => {
|
||||
expect(api.$discardSelectedIsEnabled.get()).toBe(false);
|
||||
|
||||
api.$selectedItemId.set(1);
|
||||
expect(api.$discardSelectedIsEnabled.get()).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Accept Methods', () => {
|
||||
beforeEach(() => {
|
||||
const items = [createMockQueueItem({ item_id: 1 })];
|
||||
api.$items.set(items);
|
||||
api.$selectedItemId.set(1);
|
||||
});
|
||||
|
||||
it('should accept selected item when image is available', () => {
|
||||
const imageDTO = createMockImageDTO();
|
||||
api.$progressData.setKey(1, {
|
||||
itemId: 1,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTO,
|
||||
});
|
||||
|
||||
const selectedItem = api.$selectedItem.get();
|
||||
api.acceptSelected();
|
||||
|
||||
expect(mockApp.onAccept).toHaveBeenCalledWith(selectedItem?.item, imageDTO);
|
||||
});
|
||||
|
||||
it('should do nothing when no image is available', () => {
|
||||
api.$progressData.setKey(1, {
|
||||
itemId: 1,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTO: null,
|
||||
});
|
||||
|
||||
api.acceptSelected();
|
||||
|
||||
expect(mockApp.onAccept).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should do nothing when no item is selected', () => {
|
||||
api.$selectedItemId.set(null);
|
||||
api.acceptSelected();
|
||||
|
||||
expect(mockApp.onAccept).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should compute acceptSelectedIsEnabled correctly', () => {
|
||||
expect(api.$acceptSelectedIsEnabled.get()).toBe(false);
|
||||
|
||||
const imageDTO = createMockImageDTO();
|
||||
api.$progressData.setKey(1, {
|
||||
itemId: 1,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTO,
|
||||
});
|
||||
|
||||
expect(api.$acceptSelectedIsEnabled.get()).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Progress Event Handling', () => {
|
||||
it('should handle invocation progress events', () => {
|
||||
const progressEvent = createMockProgressEvent({
|
||||
item_id: 1,
|
||||
destination: sessionId,
|
||||
});
|
||||
|
||||
api.onInvocationProgressEvent(progressEvent);
|
||||
|
||||
const progressData = api.$progressData.get();
|
||||
expect(progressData[1]?.progressEvent).toBe(progressEvent);
|
||||
});
|
||||
|
||||
it('should ignore events for different sessions', () => {
|
||||
const progressEvent = createMockProgressEvent({
|
||||
item_id: 1,
|
||||
destination: 'different-session',
|
||||
});
|
||||
|
||||
api.onInvocationProgressEvent(progressEvent);
|
||||
|
||||
const progressData = api.$progressData.get();
|
||||
expect(progressData[1]).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should update existing progress data', () => {
|
||||
api.$progressData.setKey(1, {
|
||||
itemId: 1,
|
||||
progressEvent: null,
|
||||
progressImage: null,
|
||||
imageDTO: createMockImageDTO(),
|
||||
});
|
||||
|
||||
const progressEvent = createMockProgressEvent({
|
||||
item_id: 1,
|
||||
destination: sessionId,
|
||||
});
|
||||
|
||||
api.onInvocationProgressEvent(progressEvent);
|
||||
|
||||
const progressData = api.$progressData.get();
|
||||
expect(progressData[1]?.progressEvent).toBe(progressEvent);
|
||||
expect(progressData[1]?.imageDTO).toBeTruthy();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Queue Item Status Change Handling', () => {
|
||||
it('should handle completed status and set last completed item', () => {
|
||||
const statusEvent = createMockQueueItemStatusChangedEvent({
|
||||
item_id: 1,
|
||||
destination: sessionId,
|
||||
status: 'completed',
|
||||
});
|
||||
|
||||
api.onQueueItemStatusChangedEvent(statusEvent);
|
||||
|
||||
expect(api.$lastCompletedItemId.get()).toBe(1);
|
||||
});
|
||||
|
||||
it('should handle in_progress status with switch_on_start', () => {
|
||||
mockApp._setAutoSwitchMode('switch_on_start');
|
||||
|
||||
const statusEvent = createMockQueueItemStatusChangedEvent({
|
||||
item_id: 1,
|
||||
destination: sessionId,
|
||||
status: 'in_progress',
|
||||
});
|
||||
|
||||
api.onQueueItemStatusChangedEvent(statusEvent);
|
||||
|
||||
expect(api.$lastStartedItemId.get()).toBe(1);
|
||||
});
|
||||
|
||||
it('should ignore events for different sessions', () => {
|
||||
const statusEvent = createMockQueueItemStatusChangedEvent({
|
||||
item_id: 1,
|
||||
destination: 'different-session',
|
||||
status: 'completed',
|
||||
});
|
||||
|
||||
api.onQueueItemStatusChangedEvent(statusEvent);
|
||||
|
||||
expect(api.$lastCompletedItemId.get()).toBe(null);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Items Changed Event Handling', () => {
|
||||
it('should update items and auto-select first item', async () => {
|
||||
const items = [createMockQueueItem({ item_id: 1 }), createMockQueueItem({ item_id: 2 })];
|
||||
|
||||
await api.onItemsChangedEvent(items);
|
||||
|
||||
expect(api.$items.get()).toBe(items);
|
||||
expect(api.$selectedItemId.get()).toBe(1);
|
||||
});
|
||||
|
||||
it('should clear selection when no items', async () => {
|
||||
api.$selectedItemId.set(1);
|
||||
|
||||
await api.onItemsChangedEvent([]);
|
||||
|
||||
expect(api.$selectedItemId.get()).toBe(null);
|
||||
});
|
||||
|
||||
it('should not change selection if item already selected', async () => {
|
||||
api.$selectedItemId.set(2);
|
||||
|
||||
const items = [createMockQueueItem({ item_id: 1 }), createMockQueueItem({ item_id: 2 })];
|
||||
|
||||
await api.onItemsChangedEvent(items);
|
||||
|
||||
expect(api.$selectedItemId.get()).toBe(2);
|
||||
});
|
||||
|
||||
it('should load images for completed items', async () => {
|
||||
const imageDTO = createMockImageDTO({ image_name: 'test-image.png' });
|
||||
mockApp._setImageDTO('test-image.png', imageDTO);
|
||||
|
||||
const items = [
|
||||
createMockQueueItem({
|
||||
item_id: 1,
|
||||
status: 'completed',
|
||||
}),
|
||||
];
|
||||
|
||||
await api.onItemsChangedEvent(items);
|
||||
|
||||
const progressData = api.$progressData.get();
|
||||
expect(progressData[1]?.imageDTO).toBe(imageDTO);
|
||||
});
|
||||
|
||||
it('should handle auto-switch on completion', async () => {
|
||||
mockApp._setAutoSwitchMode('switch_on_finish');
|
||||
api.$lastCompletedItemId.set(1);
|
||||
|
||||
const imageDTO = createMockImageDTO({ image_name: 'test-image.png' });
|
||||
mockApp._setImageDTO('test-image.png', imageDTO);
|
||||
|
||||
const items = [
|
||||
createMockQueueItem({
|
||||
item_id: 1,
|
||||
status: 'completed',
|
||||
}),
|
||||
];
|
||||
|
||||
await api.onItemsChangedEvent(items);
|
||||
|
||||
// Wait for async image loading - the loadImage promise needs to complete
|
||||
await new Promise((resolve) => {
|
||||
setTimeout(resolve, 50);
|
||||
});
|
||||
|
||||
expect(api.$selectedItemId.get()).toBe(1);
|
||||
// The lastCompletedItemId should be reset after the loadImage promise resolves
|
||||
expect(api.$lastCompletedItemId.get()).toBe(null);
|
    });

    it('should clean up progress data for removed items', async () => {
      api.$progressData.setKey(999, {
        itemId: 999,
        progressEvent: null,
        progressImage: null,
        imageDTO: null,
      });

      const items = [createMockQueueItem({ item_id: 1 })];

      await api.onItemsChangedEvent(items);

      const progressData = api.$progressData.get();
      expect(progressData[999]).toBeUndefined();
    });

    it('should clear progress data for canceled/failed items', async () => {
      api.$progressData.setKey(1, {
        itemId: 1,
        progressEvent: createMockProgressEvent({ item_id: 1 }),
        progressImage: null,
        imageDTO: createMockImageDTO(),
      });

      const items = [createMockQueueItem({ item_id: 1, status: 'canceled' })];

      await api.onItemsChangedEvent(items);

      const progressData = api.$progressData.get();
      expect(progressData[1]?.progressEvent).toBe(null);
      expect(progressData[1]?.progressImage).toBe(null);
      expect(progressData[1]?.imageDTO).toBe(null);
    });
  });

  describe('Auto Switch', () => {
    it('should set auto switch mode', () => {
      api.setAutoSwitch('switch_on_finish');
      expect(mockApp.onAutoSwitchChange).toHaveBeenCalledWith('switch_on_finish');
    });
  });

  describe('Utility Methods', () => {
    it('should build isSelected computed correctly', () => {
      const isSelected = api.buildIsSelectedComputed(1);
      expect(isSelected.get()).toBe(false);

      api.$selectedItemId.set(1);
      expect(isSelected.get()).toBe(true);
    });
  });

  describe('Cleanup', () => {
    it('should reset all state on cleanup', () => {
      api.$selectedItemId.set(1);
      api.$items.set([createMockQueueItem({ item_id: 1 })]);
      api.$lastStartedItemId.set(1);
      api.$lastCompletedItemId.set(1);
      api.$progressData.setKey(1, {
        itemId: 1,
        progressEvent: null,
        progressImage: null,
        imageDTO: null,
      });

      api.cleanup();

      expect(api.$selectedItemId.get()).toBe(null);
      expect(api.$items.get()).toEqual([]);
      expect(api.$lastStartedItemId.get()).toBe(null);
      expect(api.$lastCompletedItemId.get()).toBe(null);
      expect(api.$progressData.get()).toEqual({});
    });
  });

  describe('Edge Cases and Error Handling', () => {
    describe('Selection with Empty or Single Item Lists', () => {
      it('should handle selection operations with single item', () => {
        const items = [createMockQueueItem({ item_id: 1 })];
        api.$items.set(items);
        api.$selectedItemId.set(1);

        // Navigation should wrap around to the same item
        api.selectNext();
        expect(api.$selectedItemId.get()).toBe(1);

        api.selectPrev();
        expect(api.$selectedItemId.get()).toBe(1);
      });

      it('should handle selection operations with empty list', () => {
        api.$items.set([]);

        api.selectFirst();
        api.selectLast();
        api.selectNext();
        api.selectPrev();

        expect(api.$selectedItemId.get()).toBe(null);
      });
    });

    describe('Progress Data Edge Cases', () => {
      it('should handle progress updates with image data', () => {
        const progressEvent = createMockProgressEvent({
          item_id: 1,
          destination: sessionId,
          image: { width: 512, height: 512, dataURL: 'foo' },
        });

        api.onInvocationProgressEvent(progressEvent);

        const progressData = api.$progressData.get();
        expect(progressData[1]?.progressImage).toBe(progressEvent.image);
        expect(progressData[1]?.progressEvent).toBe(progressEvent);
      });

      it('should preserve imageDTO when updating progress', () => {
        const imageDTO = createMockImageDTO();
        api.$progressData.setKey(1, {
          itemId: 1,
          progressEvent: null,
          progressImage: null,
          imageDTO,
        });

        const progressEvent = createMockProgressEvent({
          item_id: 1,
          destination: sessionId,
        });

        api.onInvocationProgressEvent(progressEvent);

        const progressData = api.$progressData.get();
        expect(progressData[1]?.imageDTO).toBe(imageDTO);
        expect(progressData[1]?.progressEvent).toBe(progressEvent);
      });
    });

    describe('Auto-Switch Edge Cases', () => {
      it('should handle auto-switch when item is not in current items list', async () => {
        mockApp._setAutoSwitchMode('switch_on_start');
        api.$lastStartedItemId.set(999); // Non-existent item

        const items = [createMockQueueItem({ item_id: 1 })];
        await api.onItemsChangedEvent(items);

        // Should not switch to non-existent item
        expect(api.$selectedItemId.get()).toBe(1);
        expect(api.$lastStartedItemId.get()).toBe(999);
      });

      it('should handle auto-switch on finish when image loading fails', async () => {
        mockApp._setAutoSwitchMode('switch_on_finish');
        api.$lastCompletedItemId.set(1);

        // Mock image loading failure
        mockApp._setImageDTO('test-image.png', null);

        const items = [
          createMockQueueItem({
            item_id: 1,
            status: 'completed',
            session: {
              id: sessionId,
              source_prepared_mapping: { canvas_output: ['test-node-id'] },
              results: {
                'test-node-id': {
                  type: 'image_output',
                  image: { image_name: 'test-image.png' },
                  width: 512,
                  height: 512,
                },
              },
            },
          }),
        ];

        await api.onItemsChangedEvent(items);

        // Should not switch when image loading fails
        expect(api.$selectedItemId.get()).toBe(1);
        expect(api.$lastCompletedItemId.get()).toBe(1);
      });

      it('should handle auto-switch on finish with slow image loading', async () => {
        mockApp._setAutoSwitchMode('switch_on_finish');
        api.$lastCompletedItemId.set(1);

        const imageDTO = createMockImageDTO({ image_name: 'test-image.png' });
        mockApp._setImageDTO('test-image.png', imageDTO);
        mockApp._setLoadImageDelay(50); // Add delay to image loading

        const items = [
          createMockQueueItem({
            item_id: 1,
            status: 'completed',
            session: {
              id: sessionId,
              source_prepared_mapping: { canvas_output: ['test-node-id'] },
              results: { 'test-node-id': { image: { image_name: 'test-image.png' } } },
            },
          }),
        ];

        await api.onItemsChangedEvent(items);

        // Should switch after image loads - wait for both the delay and promise resolution
        await new Promise((resolve) => {
          setTimeout(resolve, 150);
        });

        expect(api.$selectedItemId.get()).toBe(1);
        // The lastCompletedItemId should be reset after the loadImage promise resolves
        expect(api.$lastCompletedItemId.get()).toBe(null);
      });
    });

    describe('Concurrent Operations', () => {
      it('should handle rapid item changes', async () => {
        const items1 = [createMockQueueItem({ item_id: 1 })];
        const items2 = [createMockQueueItem({ item_id: 2 })];

        // Fire multiple events rapidly
        const promise1 = api.onItemsChangedEvent(items1);
        const promise2 = api.onItemsChangedEvent(items2);

        await Promise.all([promise1, promise2]);

        // Should end up with the last set of items
        expect(api.$items.get()).toBe(items2);
        // The selectedItemId retains the old value (1) but $selectedItem will be null
        // because item 1 is no longer in the items list
        expect(api.$selectedItemId.get()).toBe(1);
        expect(api.$selectedItem.get()).toBe(null);
      });

      it('should handle multiple progress events for same item', () => {
        const event1 = createMockProgressEvent({
          item_id: 1,
          destination: sessionId,
          percentage: 0.3,
        });
        const event2 = createMockProgressEvent({
          item_id: 1,
          destination: sessionId,
          percentage: 0.7,
        });

        api.onInvocationProgressEvent(event1);
        api.onInvocationProgressEvent(event2);

        const progressData = api.$progressData.get();
        expect(progressData[1]?.progressEvent).toBe(event2);
      });
    });

    describe('Memory Management', () => {
      it('should clean up progress data for large number of items', async () => {
        // Create progress data for many items
        for (let i = 1; i <= 1000; i++) {
          api.$progressData.setKey(i, {
            itemId: i,
            progressEvent: null,
            progressImage: null,
            imageDTO: null,
          });
        }

        // Only keep a few items
        const items = [createMockQueueItem({ item_id: 1 }), createMockQueueItem({ item_id: 2 })];

        await api.onItemsChangedEvent(items);

        const progressData = api.$progressData.get();
        const progressDataKeys = Object.keys(progressData);

        // Should only have progress data for current items
        expect(progressDataKeys.length).toBeLessThanOrEqual(2);
        expect(progressData[1]).toBeDefined();
        expect(progressData[2]).toBeDefined();
      });
    });

    describe('Event Subscription Management', () => {
      it('should handle multiple subscriptions and unsubscriptions', () => {
        const api2 = new StagingAreaApi();
        api2.connectToApp(sessionId, mockApp);
        const api3 = new StagingAreaApi();
        api3.connectToApp(sessionId, mockApp);

        // All should be subscribed
        expect(mockApp.onItemsChanged).toHaveBeenCalledTimes(3);

        api2.cleanup();
        api3.cleanup();

        // Should not affect original api
        expect(api.$items.get()).toBeDefined();
      });

      it('should handle events after cleanup', () => {
        api.cleanup();

        // These should not crash
        const progressEvent = createMockProgressEvent({
          item_id: 1,
          destination: sessionId,
        });

        api.onInvocationProgressEvent(progressEvent);

        // State should remain clean - but the event handler still works
        // so it will add progress data even after cleanup
        const progressData = api.$progressData.get();
        expect(progressData[1]).toBeDefined();
      });
    });
  });
});
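The factories and mock app used throughout these tests (`createMockQueueItem`, `createMockProgressEvent`, `createMockImageDTO`, and `mockApp` with its `_set*` knobs) are defined in the test setup outside this hunk. A minimal sketch of what such factories could look like, with every field name and default value assumed rather than taken from the real setup:

// Hypothetical test factories - shapes are assumptions for illustration only.
const createMockQueueItem = (overrides: Record<string, unknown> = {}) => ({
  item_id: 1,
  status: 'pending',
  session: { id: 'test-session-id', source_prepared_mapping: {}, results: {} },
  ...overrides,
});

const createMockProgressEvent = (overrides: Record<string, unknown> = {}) => ({
  item_id: 1,
  destination: 'test-session-id',
  percentage: 0,
  image: null,
  ...overrides,
});

const createMockImageDTO = (overrides: Record<string, unknown> = {}) => ({
  image_name: 'test-image.png',
  image_url: 'http://localhost/test-image.png',
  ...overrides,
});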
@@ -1,426 +0,0 @@
import { clamp } from 'es-toolkit';
import type { AutoSwitchMode } from 'features/controlLayers/store/canvasSettingsSlice';
import type { ProgressImage } from 'features/nodes/types/common';
import type { MapStore } from 'nanostores';
import { atom, computed, map } from 'nanostores';
import type { ImageDTO, S } from 'services/api/types';
import { objectEntries } from 'tsafe';

import { getOutputImageName } from './shared';

/**
 * Interface for the app-level API that the StagingAreaApi depends on.
 * This provides the connection between the staging area and the rest of the application.
 */
export type StagingAreaAppApi = {
  onDiscard?: (item: S['SessionQueueItem']) => void;
  onDiscardAll?: () => void;
  onAccept?: (item: S['SessionQueueItem'], imageDTO: ImageDTO) => void;
  onSelect?: (itemId: number) => void;
  onSelectPrev?: () => void;
  onSelectNext?: () => void;
  onSelectFirst?: () => void;
  onSelectLast?: () => void;
  getAutoSwitch: () => AutoSwitchMode;
  onAutoSwitchChange?: (mode: AutoSwitchMode) => void;
  getImageDTO: (imageName: string) => Promise<ImageDTO | null>;
  loadImage: (imageName: string) => Promise<HTMLImageElement>;
  onItemsChanged: (handler: (data: S['SessionQueueItem'][]) => Promise<void> | void) => () => void;
  onQueueItemStatusChanged: (handler: (data: S['QueueItemStatusChangedEvent']) => Promise<void> | void) => () => void;
  onInvocationProgress: (handler: (data: S['InvocationProgressEvent']) => Promise<void> | void) => () => void;
};
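// Not part of the original module: a sketch of how a host application might satisfy the
// contract above. `eventBus` and `fetchImageDTO` are placeholder names, not InvokeAI APIs;
// each `on*` registration is assumed to return an unsubscribe function.
declare const eventBus: { on: (event: string, handler: (payload: any) => void) => () => void };
declare function fetchImageDTO(imageName: string): Promise<ImageDTO | null>;

const exampleAppApi: StagingAreaAppApi = {
  getAutoSwitch: () => 'switch_on_start',
  getImageDTO: (imageName) => fetchImageDTO(imageName),
  loadImage: (imageName) =>
    new Promise((resolve, reject) => {
      // Preload via an HTMLImageElement; the class passes an image URL here.
      const el = new Image();
      el.onload = () => resolve(el);
      el.onerror = reject;
      el.src = imageName;
    }),
  onItemsChanged: (handler) => eventBus.on('queue_items_changed', handler),
  onQueueItemStatusChanged: (handler) => eventBus.on('queue_item_status_changed', handler),
  onInvocationProgress: (handler) => eventBus.on('invocation_progress', handler),
};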

/** Progress data for a single queue item */
export type ProgressData = {
  itemId: number;
  progressEvent: S['InvocationProgressEvent'] | null;
  progressImage: ProgressImage | null;
  imageDTO: ImageDTO | null;
};

/** Combined data for the currently selected item */
export type SelectedItemData = {
  item: S['SessionQueueItem'];
  index: number;
  progressData: ProgressData;
};

/** Creates initial progress data for a queue item */
export const getInitialProgressData = (itemId: number): ProgressData => ({
  itemId,
  progressEvent: null,
  progressImage: null,
  imageDTO: null,
});
type ProgressDataMap = Record<number, ProgressData | undefined>;

/**
 * API for managing the Canvas Staging Area - a view of the image generation queue.
 * Provides reactive state management for pending, in-progress, and completed images.
 * Users can accept images to place on canvas, discard them, navigate between items,
 * and configure auto-switching behavior.
 */
export class StagingAreaApi {
  /** The current session ID. */
  _sessionId: string | null = null;

  /** The app API */
  _app: StagingAreaAppApi | null = null;

  /** A set of subscriptions to be cleaned up when we are finished with a session */
  _subscriptions = new Set<() => void>();

  /** Item ID of the last started item. Used for auto-switch on start. */
  $lastStartedItemId = atom<number | null>(null);

  /** Item ID of the last completed item. Used for auto-switch on completion. */
  $lastCompletedItemId = atom<number | null>(null);

  /** All queue items for the current session. */
  $items = atom<S['SessionQueueItem'][]>([]);

  /** Progress data for all items including events, images, and ImageDTOs. */
  $progressData = map<ProgressDataMap>({});

  /** ID of the currently selected queue item, or null if none selected. */
  $selectedItemId = atom<number | null>(null);

  /** Total number of items in the queue. */
  $itemCount = computed([this.$items], (items) => items.length);

  /** Whether there are any items in the queue. */
  $hasItems = computed([this.$items], (items) => items.length > 0);

  /** Whether there are any pending or in-progress items. */
  $isPending = computed([this.$items], (items) =>
    items.some((item) => item.status === 'pending' || item.status === 'in_progress')
  );

  /** The currently selected queue item with its index and progress data, or null if none selected. */
  $selectedItem = computed(
    [this.$items, this.$selectedItemId, this.$progressData],
    (items, selectedItemId, progressData) => {
      if (items.length === 0) {
        return null;
      }
      if (selectedItemId === null) {
        return null;
      }
      const item = items.find(({ item_id }) => item_id === selectedItemId);
      if (!item) {
        return null;
      }

      return {
        item,
        index: items.findIndex(({ item_id }) => item_id === selectedItemId),
        progressData: progressData[selectedItemId] || getInitialProgressData(selectedItemId),
      };
    }
  );

  /** The ImageDTO of the currently selected item, or null if none available. */
  $selectedItemImageDTO = computed([this.$selectedItem], (selectedItem) => {
    return selectedItem?.progressData.imageDTO ?? null;
  });

  /** The index of the currently selected item, or null if none selected. */
  $selectedItemIndex = computed([this.$selectedItem], (selectedItem) => {
    return selectedItem?.index ?? null;
  });

  /** Selects a queue item by ID. */
  select = (itemId: number) => {
    this.$selectedItemId.set(itemId);
    this._app?.onSelect?.(itemId);
  };

  /** Selects the next item in the queue, wrapping to the first item if at the end. */
  selectNext = () => {
    const selectedItem = this.$selectedItem.get();
    if (selectedItem === null) {
      return;
    }
    const items = this.$items.get();
    const nextIndex = (selectedItem.index + 1) % items.length;
    const nextItem = items[nextIndex];
    if (!nextItem) {
      return;
    }
    this.$selectedItemId.set(nextItem.item_id);
    this._app?.onSelectNext?.();
  };

  /** Selects the previous item in the queue, wrapping to the last item if at the beginning. */
  selectPrev = () => {
    const selectedItem = this.$selectedItem.get();
    if (selectedItem === null) {
      return;
    }
    const items = this.$items.get();
    const prevIndex = (selectedItem.index - 1 + items.length) % items.length;
    const prevItem = items[prevIndex];
    if (!prevItem) {
      return;
    }
    this.$selectedItemId.set(prevItem.item_id);
    this._app?.onSelectPrev?.();
  };

  /** Selects the first item in the queue. */
  selectFirst = () => {
    const items = this.$items.get();
    const first = items.at(0);
    if (!first) {
      return;
    }
    this.$selectedItemId.set(first.item_id);
    this._app?.onSelectFirst?.();
  };

  /** Selects the last item in the queue. */
  selectLast = () => {
    const items = this.$items.get();
    const last = items.at(-1);
    if (!last) {
      return;
    }
    this.$selectedItemId.set(last.item_id);
    this._app?.onSelectLast?.();
  };

  /** Discards the currently selected item and selects the next available item. */
  discardSelected = () => {
    const selectedItem = this.$selectedItem.get();
    if (selectedItem === null) {
      return;
    }
    const items = this.$items.get();
    const nextIndex = clamp(selectedItem.index + 1, 0, items.length - 1);
    const nextItem = items[nextIndex];
    if (nextItem) {
      this.$selectedItemId.set(nextItem.item_id);
    } else {
      this.$selectedItemId.set(null);
    }
    this._app?.onDiscard?.(selectedItem.item);
  };

  /** Whether the discard selected action is enabled. */
  $discardSelectedIsEnabled = computed([this.$selectedItem], (selectedItem) => {
    if (selectedItem === null) {
      return false;
    }
    return true;
  });

  /** Connects to the app, registering listeners and such */
  connectToApp = (sessionId: string, app: StagingAreaAppApi) => {
    if (this._sessionId !== sessionId) {
      this.cleanup();
      this._sessionId = sessionId;
    }
    this._app = app;

    this._subscriptions.add(this._app.onItemsChanged(this.onItemsChangedEvent));
    this._subscriptions.add(this._app.onQueueItemStatusChanged(this.onQueueItemStatusChangedEvent));
    this._subscriptions.add(this._app.onInvocationProgress(this.onInvocationProgressEvent));
  };

  /** Discards all items in the queue. */
  discardAll = () => {
    this.$selectedItemId.set(null);
    this._app?.onDiscardAll?.();
  };

  /** Accepts the currently selected item if an image is available. */
  acceptSelected = () => {
    const selectedItem = this.$selectedItem.get();
    if (selectedItem === null) {
      return;
    }
    const progressData = this.$progressData.get();
    const datum = progressData[selectedItem.item.item_id];
    if (!datum || !datum.imageDTO) {
      return;
    }
    this._app?.onAccept?.(selectedItem.item, datum.imageDTO);
  };

  /** Whether the accept selected action is enabled. */
  $acceptSelectedIsEnabled = computed([this.$selectedItem, this.$progressData], (selectedItem, progressData) => {
    if (selectedItem === null) {
      return false;
    }
    const datum = progressData[selectedItem.item.item_id];
    return !!datum && !!datum.imageDTO;
  });

  /** Sets the auto-switch mode. */
  setAutoSwitch = (mode: AutoSwitchMode) => {
    this._app?.onAutoSwitchChange?.(mode);
  };

  /** Handles invocation progress events from the WebSocket. */
  onInvocationProgressEvent = (data: S['InvocationProgressEvent']) => {
    if (data.destination !== this._sessionId) {
      return;
    }
    setProgress(this.$progressData, data);
  };

  /** Handles queue item status change events from the WebSocket. */
  onQueueItemStatusChangedEvent = (data: S['QueueItemStatusChangedEvent']) => {
    if (data.destination !== this._sessionId) {
      return;
    }
    if (data.status === 'completed') {
      /**
       * There is an unpleasant bit of indirection here. When an item is completed and auto-switch is set to
       * switch_on_finish, we want to load the image and switch to it. In this socket handler, we don't have
       * access to the full queue item, which we need to get the output image and load it. We get the full
       * queue items as part of the list query, so it's rather inefficient to fetch it again here.
       *
       * To reduce the number of extra network requests, we instead store this item as the last completed item.
       * Then, in the progress data sync effect, we process the queue item and load its image.
       */
      this.$lastCompletedItemId.set(data.item_id);
    }
    if (data.status === 'in_progress' && this._app?.getAutoSwitch() === 'switch_on_start') {
      this.$lastStartedItemId.set(data.item_id);
    }
  };
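  // Not part of the original module: an illustrative trace of the two-step flow described in the
  // comment above, with item_id 42 assumed.
  //
  //   1. The socket delivers QueueItemStatusChangedEvent { item_id: 42, status: 'completed' } and
  //      onQueueItemStatusChangedEvent records 42 in $lastCompletedItemId; nothing is selected yet.
  //   2. The list query refetches, onItemsChangedEvent receives the full queue item 42, resolves its
  //      output ImageDTO, preloads the image via loadImage(), and only then selects item 42 and
  //      clears $lastCompletedItemId.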

  /**
   * Handles queue items changed events. Updates items, manages progress data,
   * handles auto-selection, and implements auto-switch behavior.
   */
  onItemsChangedEvent = async (items: S['SessionQueueItem'][]) => {
    const oldItems = this.$items.get();

    if (items === oldItems) {
      return;
    }

    if (items.length === 0) {
      // If there are no items, cannot have a selected item.
      this.$selectedItemId.set(null);
    } else if (this.$selectedItemId.get() === null && items.length > 0) {
      // If there is no selected item but there are items, select the first one.
      this.$selectedItemId.set(items[0]?.item_id ?? null);
    }

    const progressData = this.$progressData.get();

    const toDelete: number[] = [];
    const toUpdate: ProgressData[] = [];

    for (const [id, datum] of objectEntries(progressData)) {
      if (!datum) {
        toDelete.push(id);
        continue;
      }
      const item = items.find(({ item_id }) => item_id === datum.itemId);
      if (!item) {
        toDelete.push(datum.itemId);
      } else if (item.status === 'canceled' || item.status === 'failed') {
        toUpdate.push({
          ...datum,
          progressEvent: null,
          progressImage: null,
          imageDTO: null,
        });
      }
    }

    for (const item of items) {
      const datum = progressData[item.item_id];

      if (this.$lastStartedItemId.get() === item.item_id && this._app?.getAutoSwitch() === 'switch_on_start') {
        this.$selectedItemId.set(item.item_id);
        this.$lastStartedItemId.set(null);
      }

      if (datum?.imageDTO) {
        continue;
      }
      const outputImageName = getOutputImageName(item);
      if (!outputImageName) {
        continue;
      }
      const imageDTO = await this._app?.getImageDTO(outputImageName);
      if (!imageDTO) {
        continue;
      }

      // This is the load logic mentioned in the comment in the QueueItemStatusChangedEvent handler above.
      if (this.$lastCompletedItemId.get() === item.item_id && this._app?.getAutoSwitch() === 'switch_on_finish') {
        this._app.loadImage(imageDTO.image_url).then(() => {
          this.$selectedItemId.set(item.item_id);
          this.$lastCompletedItemId.set(null);
        });
      }

      toUpdate.push({
        ...getInitialProgressData(item.item_id),
        ...datum,
        imageDTO,
      });
    }

    for (const itemId of toDelete) {
      this.$progressData.setKey(itemId, undefined);
    }

    for (const datum of toUpdate) {
      this.$progressData.setKey(datum.itemId, datum);
    }

    this.$items.set(items);
  };

  /** Creates a computed value that returns true if the given item ID is selected. */
  buildIsSelectedComputed = (itemId: number) => {
    return computed([this.$selectedItemId], (selectedItemId) => {
      return selectedItemId === itemId;
    });
  };

  /** Cleans up all state and unsubscribes from all events. */
  cleanup = () => {
    this.$lastStartedItemId.set(null);
    this.$lastCompletedItemId.set(null);
    this.$items.set([]);
    this.$progressData.set({});
    this.$selectedItemId.set(null);
    this._subscriptions.forEach((unsubscribe) => unsubscribe());
    this._subscriptions.clear();
  };
}

/** Updates progress data for a queue item with the latest progress event. */
const setProgress = ($progressData: MapStore<ProgressDataMap>, data: S['InvocationProgressEvent']) => {
  const progressData = $progressData.get();
  const current = progressData[data.item_id];
  if (current) {
    const next = { ...current };
    next.progressEvent = data;
    if (data.image) {
      next.progressImage = data.image;
    }
    $progressData.set({
      ...progressData,
      [data.item_id]: next,
    });
  } else {
    $progressData.set({
      ...progressData,
      [data.item_id]: {
        itemId: data.item_id,
        progressEvent: data,
        progressImage: data.image ?? null,
        imageDTO: null,
      },
    });
  }
};
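For reference, a minimal usage sketch of the staging area API listed above. The import path, `sessionId`, and `appApi` are illustrative placeholders rather than values taken from the codebase, and the subscription call follows standard nanostores usage.

import { StagingAreaApi } from 'features/controlLayers/components/StagingArea/state'; // path assumed

const stagingArea = new StagingAreaApi();
stagingArea.connectToApp(sessionId, appApi); // appApi implements StagingAreaAppApi

// React to the selected item's image becoming available (nanostores subscription).
const unsubscribe = stagingArea.$selectedItemImageDTO.subscribe((imageDTO) => {
  if (imageDTO) {
    // render the staged image
  }
});

// Navigation, accept, and discard are plain method calls suitable for hotkey handlers.
stagingArea.selectNext();
stagingArea.acceptSelected();
stagingArea.discardSelected();

// When the session ends:
unsubscribe();
stagingArea.cleanup();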
@@ -4,7 +4,6 @@ import { ToolBrushButton } from 'features/controlLayers/components/Tool/ToolBrus
import { ToolColorPickerButton } from 'features/controlLayers/components/Tool/ToolColorPickerButton';
import { ToolMoveButton } from 'features/controlLayers/components/Tool/ToolMoveButton';
import { ToolRectButton } from 'features/controlLayers/components/Tool/ToolRectButton';
import React from 'react';

import { ToolEraserButton } from './ToolEraserButton';
import { ToolViewButton } from './ToolViewButton';

@@ -6,7 +6,7 @@ import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import { z } from 'zod';
import { z } from 'zod/v4';

const zMode = z.enum(['fill', 'contain', 'cover']);
type Mode = z.infer<typeof zMode>;

@@ -7,7 +7,7 @@ import { useEntityAdapter } from 'features/controlLayers/contexts/EntityAdapterC
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { TRANSPARENCY_CHECKERBOARD_PATTERN_DARK_DATAURL } from 'features/controlLayers/konva/patterns/transparency-checkerboard-pattern';
import { selectCanvasSlice, selectEntity } from 'features/controlLayers/store/selectors';
import React, { memo, useEffect, useMemo, useRef } from 'react';
import { memo, useEffect, useMemo, useRef } from 'react';
import { useSelector } from 'react-redux';

const ChakraCanvas = chakra.canvas;
Some files were not shown because too many files have changed in this diff.