Compare commits

...

117 Commits

Author SHA1 Message Date
psychedelicious
82fb897b62 chore(ui): lint 2025-07-12 14:56:57 +10:00
psychedelicious
192b00d969 chore: bump version to v6.0.2 2025-07-12 14:56:57 +10:00
psychedelicious
7bb25ef1b4 fix(ui): gallery dnd 2025-07-12 14:56:57 +10:00
psychedelicious
62f52c74a8 fix(ui): linked negative style prompt not passed in
Closes #8256
2025-07-12 10:22:17 +10:00
psychedelicious
97439c1daa fix(ui): native context menu shown on right click on short fat images
Closes #8254
2025-07-12 10:22:17 +10:00
psychedelicious
b23bff1b53 fix(ui): center staging area images 2025-07-12 10:22:17 +10:00
psychedelicious
d9a1efbabf fix(ui): staging area images may be slightly too large 2025-07-12 10:22:17 +10:00
psychedelicious
d4e903ee2d chore: bump version to v6.0.1 2025-07-12 10:22:17 +10:00
Kevin Turner
bb3e5d16d8 feat(Model Manager): refuse to download a file when there's insufficient space 2025-07-12 10:14:25 +10:00
psychedelicious
e62d3f01a8 feat(app): better error message for failed model probe
- Old: No valid config found
- New: Unable to determine model type
2025-07-11 23:35:43 +10:00
psychedelicious
757ecdbf82 build(ui): downgrade idb-keyval
We have increased error rates after updating this package. Let's try
downgrading to see if that fixes the issue.
2025-07-11 15:00:10 +10:00
psychedelicious
694c85b041 fix(ui): language file filenames
Need to replace the underscores w/ dashes - this was missed in #8246.
2025-07-11 14:21:41 +10:00
psychedelicious
988d7ba24c chore: bump version to v6.0.1rc1 2025-07-11 09:05:24 +10:00
psychedelicious
ac981879ef fix(ui): runtime errors related to calling reduce on array iterator
Fix an issue in certain browsers/builds causing a runtime error.

A zod enum has a .options property, which is an array of all the options
for the enum. This is handy for when you need to derive something from a
zod schema.

In this case, we represented the possible focus regions in the zod enum,
then derived a mapping of region names to set of target HTML elements.
Why isn't important, but suffice to say, we were using the .options
property for this.

But actually, we were using .options.values(), then calling .reduce() on
that. An array's .values() method returns an _array iterator_. Array
iterators do not have .reduce() methods!

Except, apparently in some environments they do - it depends on the JS
engine and whether or not polyfills for iterator helpers were included
in the build.

Turns out my dev environment - and most user browsers - do provide
.reduce(), so we didn't catch this error. It took a large deployment and
error monitoring to catch it.

I've refactored the code to totally avoid deriving data from zod in this
way.
2025-07-11 08:25:47 +10:00
psychedelicious
fc71849c24 feat(app): expose a cursor, not a connection in db util 2025-07-11 08:20:06 +10:00
psychedelicious
a19aa3b032 feat(app): db abstraction to prevent threading conflicts
- Add a context manager to the SqliteDatabase class which abstracts away
creating a transaction, committing it on success and rolling back on
error.
- Use it everywhere. The context manager should be exited before
returning results. No business logic changes should be present.
2025-07-11 08:20:06 +10:00
psychedelicious
ef4d5d7377 feat(ui): virtualized list for staging area
Make the staging area a virtualized list so it doesn't choke when there
are a large number (i.e. more than a few hundred) of queue items.
2025-07-11 07:50:57 +10:00
Mary Hipp Rogers
6b0dfd8427 dont reset canvas if studio is loaded with canvas destination (#8252)
Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2025-07-10 09:36:41 -04:00
psychedelicious
471c010217 fix(ui): invalid language crashes app
- Apparently locales must use hyphens instead of underscores. This must
have been a fairly recent change that we didn't catch. It caused i18n to
throw for Brasilian Portuguese and both Simplified and Traditional
Mandarin. Change the locales to use the right strings.
- Move the theme + locale provider inside of the error boundary. This
allows errors with locals to be caught by the error boundary instead of
hard-crashing the app. The error screen is unstyled if this happens but
at least it has the reset button.
- Add a migration for the system slice to fix existing users' language
selections. For example, if the user had an incorrect language setting
of `zh_CN`, it will be changed to the correct `zh-CN`.
2025-07-10 14:27:36 +10:00
psychedelicious
b1193022f7 fix(ui): sometimes images added to gallery show as placeholder only
The range-based fetching logic had a subtle bug - it didn't keep track
of what the _current_ visible range is - only the ranges that the user
last scrolled to.

When an image was added to the gallery, the logic saw that the images
had changed, but thought it had already loaded everything it needed to,
so it didn't load the new image.

The updated logic tracks the current visible range separately from the
accumulated scroll ranges to address this issue.
2025-07-10 14:27:36 +10:00
psychedelicious
2152ca092c fix(ui): workaround for dockview bug that lets you drag tabs in certain ways 2025-07-10 14:27:36 +10:00
psychedelicious
ccc62ba56d perf(ui): revised range-based fetching strategy
When the user scrolls in the gallery, we are alerted of the new range of
visible images. Then we fetch those specific images.

Previously, each change of range triggered a throttled function to fetch
that range. The throttle timeout was 100ms.

Now, each change of range appends that range to a list of ranges and
triggers the throttled fetch. The timeout is increased to 500ms, but to
compensate, each fetch handles all ranges that had been accumulated
since the last fetch.

The result is far fewer network requests, but each of them gets more
images.
2025-07-10 14:27:36 +10:00
psychedelicious
9cf82de8c5 fix(ui): check for absolute value of scroll velocity to handle scrolling up 2025-07-10 14:27:36 +10:00
psychedelicious
aced349152 perf(ui): increase viewport in gallery
This allows us to prefetch more images and reduce how often placeholders
are shown as we fetch more images in the gallery.
2025-07-10 14:27:36 +10:00
psychedelicious
0d67ee6548 tests(ui): fix logging mock 2025-07-09 23:15:25 +10:00
psychedelicious
03c21d1607 fix(ui): gallery not updating when saving staging area image 2025-07-09 23:15:25 +10:00
psychedelicious
752e8db1f5 tidy(ui): demote logging in nav api to trace 2025-07-09 23:15:25 +10:00
psychedelicious
85fc861dd9 chore(ui): lint 2025-07-09 23:15:25 +10:00
psychedelicious
458cbfd874 fix(ui): selected model not highlighted 2025-07-09 23:15:25 +10:00
psychedelicious
04331c070a fix(ui): set denoise w/h when running flux fill 2025-07-09 23:15:25 +10:00
psychedelicious
632ddf0cb4 tests(ui): update tests for navigation api 2025-07-09 23:15:25 +10:00
psychedelicious
2b193ff416 fix(ui): delete stored state on error & save new state 2025-07-09 23:15:25 +10:00
psychedelicious
96ee394f9e refactor(ui): use dockview's own ser/de for persistence 2025-07-09 23:15:25 +10:00
psychedelicious
0badc80c0c fix(ui): ignore disabled ref images in readiness checks 2025-07-09 23:15:25 +10:00
psychedelicious
78e6cbf96e fix(ui): default tab is generate 2025-07-09 23:15:25 +10:00
psychedelicious
0b969a661b fix(ui): remove dep on focus from useDeleteImage 2025-07-09 23:15:25 +10:00
psychedelicious
6fe47ec9f8 feat(ui): improve ref image model autoswitch logic 2025-07-09 23:15:25 +10:00
Kent Keirsey
3850dd61f8 update comment 2025-07-09 23:15:25 +10:00
Kent Keirsey
75520eaf0f Match Chatgpt4o and kontext names exactly 2025-07-09 23:15:25 +10:00
Kent Keirsey
10e88c58c1 fix and lint 2025-07-09 23:15:25 +10:00
Kent Keirsey
30ed4dbd92 lint 2025-07-09 23:15:25 +10:00
Kent Keirsey
ed9c090f33 fixes 2025-07-09 23:15:25 +10:00
Kent Keirsey
d29f65ed22 lint fixes 2025-07-09 23:15:25 +10:00
Kent Keirsey
2062ec8ac0 Update invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts
Co-authored-by: Mary Hipp Rogers <maryhipp@gmail.com>
2025-07-09 23:15:25 +10:00
Cursor Agent
49e818338a Changes from background composer bc-abfadb27-a265-41a7-b0db-829879f4701e 2025-07-09 23:15:25 +10:00
Cursor Agent
1caab2b9c4 Implement automatic reference image model switching on base model change
Co-authored-by: kent <kent@invoke.ai>
2025-07-09 23:15:25 +10:00
psychedelicious
50079ea349 fix(ui): big red cancel button has diff behaviour than staging discard 2025-07-09 23:15:25 +10:00
psychedelicious
fffa1b24c4 fix(ui): isStaging selector could return wrong query cache 2025-07-09 23:15:25 +10:00
psychedelicious
a6d6170387 fix(ui): discarding 1 item when 2 items left in staging area discards both 2025-07-09 23:15:25 +10:00
psychedelicious
e5fceb0448 fix(ui): whole app scrolls while selecting staging area image 2025-07-09 23:15:25 +10:00
psychedelicious
059baf5b29 chore(ui): lint 2025-07-09 23:15:25 +10:00
psychedelicious
1be8a9a310 fix(ui): add metadata i18nKey to handler; fixes metadata toasts 2025-07-09 23:15:25 +10:00
psychedelicious
7adc33e04d refactor(ui): metadata recall buttons & hotkeys (WIP) 2025-07-09 23:15:25 +10:00
psychedelicious
7f2dd22d47 refactor(ui): metadata recall buttons & hotkeys (WIP) 2025-07-09 23:15:25 +10:00
psychedelicious
bb50f4b8a2 fix(ui): prevent panels from growing on init
This works but I think a better solution is to use dockview's provided
serialization API to store and restore layouts.
2025-07-09 23:15:25 +10:00
psychedelicious
a48958e0d4 chore(ui): lint 2025-07-09 23:15:25 +10:00
psychedelicious
e3a1e9af53 feat(ui): staging area updates
- Smaller staged image previews.
- Move autoswitch buttons to staging area toolbar, remove from settings
popover and the little three-dots menu. Use persisted autoswitch
setting, which is renamed from `defaultAutoSwitch` to
`stagingAreaAutoSwitch`.
- Fix issue with misaligned border radii in staging area preview images.
Required small changes to DndImage and its usage elsewhere.
- Fix issue where staging area toolbar could show up without any
previews in the list.
- Migrate canvas settings slice to use zod schema and inferred types for
its state.
2025-07-09 23:15:25 +10:00
psychedelicious
c6fe11c42f fix(ui): disable gallery hotkeys when in staging area 2025-07-09 23:15:25 +10:00
psychedelicious
4eb1bd67df fix(ui): hide staging area when there are no items 2025-07-09 23:15:25 +10:00
psychedelicious
c376f914d2 chore: bump version v6.0.0 2025-07-09 23:15:25 +10:00
Kent Keirsey
b5d1c47ef7 final link fix 2025-07-09 10:17:38 +10:00
Kent Keirsey
004a52ca65 fix to direct links 2025-07-09 10:17:38 +10:00
Kent Keirsey
b1d5a51ddf add-quantized-kontext-dev 2025-07-09 10:17:38 +10:00
Kent Keirsey
2b2498eaa1 fix prettier quirk 2025-07-08 14:54:29 -04:00
Kent Keirsey
10dda4440e Fix label 2025-07-08 14:54:29 -04:00
Cursor Agent
98f78abefa Add default auto-switch mode setting for canvas sessions
Co-authored-by: kent <kent@invoke.ai>
2025-07-08 14:54:29 -04:00
Mary Hipp Rogers
cc93fa270f update whats new for v6 (#8234)
Co-authored-by: Mary Hipp <maryhipp@Marys-Air.lan>
2025-07-08 18:24:33 +00:00
Mary Hipp Rogers
014b27680f fix flux kontext error (#8235)
Co-authored-by: Mary Hipp <maryhipp@Marys-Air.lan>
2025-07-08 13:42:48 -04:00
Mary Hipp Rogers
c3d8f875de if on generate tab, recall dimensions instead of bbox (#8233)
Co-authored-by: Mary Hipp <maryhipp@Marys-Air.lan>
2025-07-08 13:09:21 -04:00
Mary Hipp Rogers
79f9dc6e4a fix(ui): dont show option to add new layer from if on generate tab (#8231)
* dont show option to add new layer from if on generate tab

* only disable width/height recall is staging AND canvas tab

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-Air.lan>
2025-07-08 11:46:54 -04:00
psychedelicious
6e1c0c1105 chore: bump version to v6.0.0rc5 2025-07-08 11:26:47 -04:00
Mary Hipp Rogers
0362524040 remove hard-coded flux kontext dev guidance (#8230)
Co-authored-by: Mary Hipp <maryhipp@Marys-Air.lan>
2025-07-08 10:26:20 -04:00
psychedelicious
dc6656459b docs(ui): updated comments for navigation api 2025-07-08 07:30:36 -04:00
psychedelicious
3ea1b97f6f fix(ui): protect against getting stuck on tab loading screen 2025-07-08 07:30:36 -04:00
psychedelicious
a7c7405ccc feat(ui): style model picker selected item 2025-07-08 07:28:07 -04:00
psychedelicious
c391f1117a fix(ui): traverse groups when finding selected model in picker 2025-07-08 07:28:07 -04:00
psychedelicious
b1e2cb8401 fix(ui): queue tab list of queue items
Reverted incomplete change to how queue items are listed. In the future
I think we should redo it to work like the gallery. For now, it is back
the way it was in v5.
2025-07-08 07:22:51 -04:00
Emmanuel Ferdman
db6af134b7 fix: resolve FastAPI deprecation warning for example fields
Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
2025-07-08 20:54:08 +10:00
Emmanuel Ferdman
7e6cffb00c fix: resolve FastAPI deprecation warning for example fields
Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
2025-07-08 20:54:08 +10:00
psychedelicious
5b187bcb00 fix(ui): pull bbox into ref image component 2025-07-08 14:54:43 +10:00
psychedelicious
0843d609a3 feat(ui): add list of warnings in tooltip on ref image 2025-07-08 14:54:43 +10:00
Kent Keirsey
95bd9cef18 Lint 2025-07-08 14:54:43 +10:00
Kent Keirsey
931d6521f6 Adds bbox to ref image button 2025-07-08 14:54:43 +10:00
psychedelicious
e37665ff59 tests(ui): add wiggle room to timeout tests 2025-07-08 12:55:33 +10:00
psychedelicious
56857fbbe6 tests(ui): add tests for panel storage 2025-07-08 12:55:33 +10:00
psychedelicious
43cfb8a574 tests(ui): get tests passing
Still need tests for panel storage.
2025-07-08 12:55:33 +10:00
psychedelicious
05b1682d15 fix(ui): handle collapsed panels when rehydrating their state 2025-07-08 12:55:33 +10:00
psychedelicious
69a08ee7f2 feat(ui): panel state persistence (WIP) 2025-07-08 12:55:33 +10:00
psychedelicious
18212c7d8a feat(ui): clean up navigation API surface and add comments 2025-07-08 12:55:33 +10:00
psychedelicious
7de26f8e69 feat(ui): clean up auto layout context for panels 2025-07-08 12:55:33 +10:00
Kent Keirsey
0652b12a6f Address comments 2025-07-08 12:31:11 +10:00
Kent Keirsey
43a361a00f prettier 2025-07-08 12:31:11 +10:00
Kent Keirsey
cf68ad9cbc update links to playlist instead of video 2025-07-08 12:31:11 +10:00
Kent Keirsey
ec02a39325 fixes 2025-07-08 12:31:11 +10:00
Kent Keirsey
e52d7a05c2 Update support links. 2025-07-08 12:31:11 +10:00
Cursor Agent
c9d4e2b761 Refactor support videos modal to simplify video and playlist handling
Co-authored-by: kent <kent@invoke.ai>
2025-07-08 12:31:11 +10:00
Kent Keirsey
ac26aa9508 fix 2025-07-08 12:31:11 +10:00
Cursor Agent
9ff6ada15b Add support for video playlists in support videos modal
Co-authored-by: kent <kent@invoke.ai>
2025-07-08 12:31:11 +10:00
psychedelicious
e81a115169 chore(ui): lint 2025-07-08 12:23:57 +10:00
Kent Keirsey
52827807de remove ref image from upscale 2025-07-08 12:23:57 +10:00
Kent Keirsey
b631de4cb5 consistency 2025-07-08 12:20:08 +10:00
Kent Keirsey
099ebdbc37 fix 2025-07-08 12:20:08 +10:00
psychedelicious
4de6549be9 refactor(ui): track discarded items instead of using delete method 2025-07-08 12:12:55 +10:00
psychedelicious
368be34949 chore(ui): lint 2025-07-08 12:12:55 +10:00
psychedelicious
5baa4bd916 refactor(ui): use cancelation for staging area (mostly) 2025-07-08 12:12:55 +10:00
psychedelicious
4229377532 fix(app): ensure cancel events are emitted for current item when bulk canceling
There was a bug where bulk cancel operations would cancel the current
queue item in the DB but not emit the status changed events correctly.
2025-07-08 12:12:55 +10:00
psychedelicious
2610772ffd feat(ui): tighten up launchpad content to fit better 2025-07-08 08:57:44 +10:00
psychedelicious
193de6a8f2 feat(ui): add launchpad container component 2025-07-08 08:57:44 +10:00
psychedelicious
7ea343c787 tidy(ui): remove "staging" from the new settings verbiage 2025-07-08 07:10:55 +10:00
Kent Keirsey
12179dabba fix prettier 2025-07-08 07:10:55 +10:00
Cursor Agent
ef135f9923 Add option to save all staging images to gallery in canvas mode
Co-authored-by: kent <kent@invoke.ai>
2025-07-08 07:10:55 +10:00
Mary Hipp
e6c67cc00f update toast for prompt expansion failed 2025-07-08 06:42:00 +10:00
psychedelicious
179b988148 fix(ui): prompt concat derived state recall 2025-07-08 06:37:43 +10:00
psychedelicious
d913a3c85b fix(ui): reset selected ref image when replacing all
Fixes an unhandled error in a selector that can throw.
2025-07-08 06:37:43 +10:00
psychedelicious
e79525c40c docs(ui): update comments 2025-07-08 06:11:32 +10:00
psychedelicious
f409f913ac fix(ui): navigation api usage 2025-07-08 06:11:32 +10:00
Mary Hipp
7a79f61d4c add claude nodes to blacklist for publishing 2025-07-08 05:50:40 +10:00
151 changed files with 5277 additions and 3751 deletions

View File

@@ -72,7 +72,7 @@ async def upload_image(
resize_to: Optional[str] = Body(
default=None,
description=f"Dimensions to resize the image to, must be stringified tuple of 2 integers. Max total pixel count: {ResizeToDimensions.MAX_SIZE}",
example='"[1024,1024]"',
examples=['"[1024,1024]"'],
),
metadata: Optional[str] = Body(
default=None,

View File

@@ -292,7 +292,7 @@ async def get_hugging_face_models(
)
async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")],
changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)],
changes: Annotated[ModelRecordChanges, Body(description="Model config", examples=[example_model_input])],
) -> AnyModelConfig:
"""Update a model's config."""
logger = ApiDependencies.invoker.services.logger
@@ -450,7 +450,7 @@ async def install_model(
access_token: Optional[str] = Query(description="access token for the remote resource", default=None),
config: ModelRecordChanges = Body(
description="Object containing fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
example={"name": "string", "description": "string"},
examples=[{"name": "string", "description": "string"}],
),
) -> ModelInstallJob:
"""Install a model using a string identifier.

View File

@@ -14,15 +14,14 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._db = db
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
INSERT INTO board_images (board_id, image_name)
@@ -31,17 +30,12 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
""",
(board_id, image_name, board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
def remove_image_from_board(
self,
image_name: str,
) -> None:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE FROM board_images
@@ -49,10 +43,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
""",
(image_name,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
def get_images_for_board(
self,
@@ -60,27 +50,26 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[ImageRecord]:
# TODO: this isn't paginated yet?
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT images.*
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE board_images.board_id = ?
ORDER BY board_images.updated_at DESC;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT images.*
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE board_images.board_id = ?
ORDER BY board_images.updated_at DESC;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images WHERE 1=1;
"""
)
count = cast(int, cursor.fetchone()[0])
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images WHERE 1=1;
"""
)
count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
@@ -90,56 +79,55 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
categories: list[ImageCategory] | None,
is_intermediate: bool | None,
) -> list[str]:
params: list[str | bool] = []
with self._db.transaction() as cursor:
params: list[str | bool] = []
# Base query is a join between images and board_images
stmt = """
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
# Base query is a join between images and board_images
stmt = """
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
# Handle board_id filter
if board_id == "none":
stmt += """--sql
AND board_images.board_id IS NULL
"""
else:
stmt += """--sql
AND board_images.board_id = ?
"""
params.append(board_id)
# Handle board_id filter
if board_id == "none":
stmt += """--sql
AND board_images.board_id IS NULL
"""
else:
stmt += """--sql
AND board_images.board_id = ?
"""
params.append(board_id)
# Add the category filter
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
stmt += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
# Add the category filter
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
stmt += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
# Unpack the included categories into the query params
for c in category_strings:
params.append(c)
# Unpack the included categories into the query params
for c in category_strings:
params.append(c)
# Add the is_intermediate filter
if is_intermediate is not None:
stmt += """--sql
AND images.is_intermediate = ?
"""
params.append(is_intermediate)
# Add the is_intermediate filter
if is_intermediate is not None:
stmt += """--sql
AND images.is_intermediate = ?
"""
params.append(is_intermediate)
# Put a ring on it
stmt += ";"
# Put a ring on it
stmt += ";"
# Execute the query
cursor = self._conn.cursor()
cursor.execute(stmt, params)
cursor.execute(stmt, params)
result = cast(list[sqlite3.Row], cursor.fetchall())
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [r[0] for r in result]
return image_names
@@ -147,31 +135,31 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
self,
image_name: str,
) -> Optional[str]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT board_id
FROM board_images
WHERE image_name = ?;
""",
(image_name,),
)
result = cursor.fetchone()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT board_id
FROM board_images
WHERE image_name = ?;
""",
(image_name,),
)
result = cursor.fetchone()
if result is None:
return None
return cast(str, result[0])
def get_image_count_for_board(self, board_id: str) -> int:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT COUNT(*)
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE images.is_intermediate = FALSE
AND board_images.board_id = ?;
""",
(board_id,),
)
count = cast(int, cursor.fetchone()[0])
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT COUNT(*)
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE images.is_intermediate = FALSE
AND board_images.board_id = ?;
""",
(board_id,),
)
count = cast(int, cursor.fetchone()[0])
return count

View File

@@ -20,61 +20,57 @@ from invokeai.app.util.misc import uuid_string
class SqliteBoardRecordStorage(BoardRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._db = db
def delete(self, board_id: str) -> None:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
self._conn.commit()
except Exception as e:
self._conn.rollback()
raise BoardRecordDeleteException from e
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
DELETE FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
except Exception as e:
raise BoardRecordDeleteException from e
def save(
self,
board_name: str,
) -> BoardRecord:
try:
board_id = uuid_string()
cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO boards (board_id, board_name)
VALUES (?, ?);
""",
(board_id, board_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
with self._db.transaction() as cursor:
try:
board_id = uuid_string()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO boards (board_id, board_name)
VALUES (?, ?);
""",
(board_id, board_name),
)
except sqlite3.Error as e:
raise BoardRecordSaveException from e
return self.get(board_id)
def get(
self,
board_id: str,
) -> BoardRecord:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
SELECT *
FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
except sqlite3.Error as e:
raise BoardRecordNotFoundException from e
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
except sqlite3.Error as e:
raise BoardRecordNotFoundException from e
if result is None:
raise BoardRecordNotFoundException
return BoardRecord(**dict(result))
@@ -84,45 +80,43 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
board_id: str,
changes: BoardChanges,
) -> BoardRecord:
try:
cursor = self._conn.cursor()
# Change the name of a board
if changes.board_name is not None:
cursor.execute(
"""--sql
UPDATE boards
SET board_name = ?
WHERE board_id = ?;
""",
(changes.board_name, board_id),
)
with self._db.transaction() as cursor:
try:
# Change the name of a board
if changes.board_name is not None:
cursor.execute(
"""--sql
UPDATE boards
SET board_name = ?
WHERE board_id = ?;
""",
(changes.board_name, board_id),
)
# Change the cover image of a board
if changes.cover_image_name is not None:
cursor.execute(
"""--sql
UPDATE boards
SET cover_image_name = ?
WHERE board_id = ?;
""",
(changes.cover_image_name, board_id),
)
# Change the cover image of a board
if changes.cover_image_name is not None:
cursor.execute(
"""--sql
UPDATE boards
SET cover_image_name = ?
WHERE board_id = ?;
""",
(changes.cover_image_name, board_id),
)
# Change the archived status of a board
if changes.archived is not None:
cursor.execute(
"""--sql
UPDATE boards
SET archived = ?
WHERE board_id = ?;
""",
(changes.archived, board_id),
)
# Change the archived status of a board
if changes.archived is not None:
cursor.execute(
"""--sql
UPDATE boards
SET archived = ?
WHERE board_id = ?;
""",
(changes.archived, board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
except sqlite3.Error as e:
raise BoardRecordSaveException from e
return self.get(board_id)
def get_many(
@@ -133,78 +127,77 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
limit: int = 10,
include_archived: bool = False,
) -> OffsetPaginatedResults[BoardRecord]:
cursor = self._conn.cursor()
# Build base query
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
LIMIT ? OFFSET ?;
"""
# Determine archived filter condition
archived_filter = "" if include_archived else "WHERE archived = 0"
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
# Execute query to fetch boards
cursor.execute(final_query, (limit, offset))
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
# Determine count query
if include_archived:
count_query = """
SELECT COUNT(*)
FROM boards;
"""
else:
count_query = """
SELECT COUNT(*)
with self._db.transaction() as cursor:
# Build base query
base_query = """
SELECT *
FROM boards
WHERE archived = 0;
{archived_filter}
ORDER BY {order_by} {direction}
LIMIT ? OFFSET ?;
"""
# Execute count query
cursor.execute(count_query)
# Determine archived filter condition
archived_filter = "" if include_archived else "WHERE archived = 0"
count = cast(int, cursor.fetchone()[0])
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
# Execute query to fetch boards
cursor.execute(final_query, (limit, offset))
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
# Determine count query
if include_archived:
count_query = """
SELECT COUNT(*)
FROM boards;
"""
else:
count_query = """
SELECT COUNT(*)
FROM boards
WHERE archived = 0;
"""
# Execute count query
cursor.execute(count_query)
count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
def get_all(
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
) -> list[BoardRecord]:
cursor = self._conn.cursor()
if order_by == BoardRecordOrderBy.Name:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY LOWER(board_name) {direction}
"""
else:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
"""
with self._db.transaction() as cursor:
if order_by == BoardRecordOrderBy.Name:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY LOWER(board_name) {direction}
"""
else:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
"""
archived_filter = "" if include_archived else "WHERE archived = 0"
archived_filter = "" if include_archived else "WHERE archived = 0"
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
cursor.execute(final_query)
cursor.execute(final_query)
result = cast(list[sqlite3.Row], cursor.fetchall())
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
return boards

View File

@@ -8,6 +8,7 @@ import time
import traceback
from pathlib import Path
from queue import Empty, PriorityQueue
from shutil import disk_usage
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set
import requests
@@ -335,6 +336,14 @@ class DownloadQueueService(DownloadQueueServiceBase):
assert job.download_path
free_space = disk_usage(job.download_path.parent).free
GB = 2**30
self._logger.debug(f"Download is {job.total_bytes / GB:.2f} GB of {free_space / GB:.2f} GB free.")
if free_space < job.total_bytes:
raise RuntimeError(
f"Free disk space {free_space / GB:.2f} GB is not enough for download of {job.total_bytes / GB:.2f} GB."
)
# Don't clobber an existing file. See commit 82c2c85202f88c6d24ff84710f297cfc6ae174af
# for code that instead resumes an interrupted download.
if job.download_path.exists():

View File

@@ -24,22 +24,22 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteImageRecordStorage(ImageRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._db = db
def get(self, image_name: str) -> ImageRecord:
try:
cursor = self._conn.cursor()
cursor.execute(
f"""--sql
SELECT {IMAGE_DTO_COLS} FROM images
WHERE image_name = ?;
""",
(image_name,),
)
with self._db.transaction() as cursor:
try:
cursor.execute(
f"""--sql
SELECT {IMAGE_DTO_COLS} FROM images
WHERE image_name = ?;
""",
(image_name,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
except sqlite3.Error as e:
raise ImageRecordNotFoundException from e
result = cast(Optional[sqlite3.Row], cursor.fetchone())
except sqlite3.Error as e:
raise ImageRecordNotFoundException from e
if not result:
raise ImageRecordNotFoundException
@@ -47,17 +47,20 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
return deserialize_image_record(dict(result))
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT metadata FROM images
WHERE image_name = ?;
""",
(image_name,),
)
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
SELECT metadata FROM images
WHERE image_name = ?;
""",
(image_name,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
result = cast(Optional[sqlite3.Row], cursor.fetchone())
except sqlite3.Error as e:
raise ImageRecordNotFoundException from e
if not result:
raise ImageRecordNotFoundException
@@ -65,64 +68,60 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
as_dict = dict(result)
metadata_raw = cast(Optional[str], as_dict.get("metadata", None))
return MetadataFieldValidator.validate_json(metadata_raw) if metadata_raw is not None else None
except sqlite3.Error as e:
raise ImageRecordNotFoundException from e
def update(
self,
image_name: str,
changes: ImageRecordChanges,
) -> None:
try:
cursor = self._conn.cursor()
# Change the category of the image
if changes.image_category is not None:
cursor.execute(
"""--sql
UPDATE images
SET image_category = ?
WHERE image_name = ?;
""",
(changes.image_category, image_name),
)
with self._db.transaction() as cursor:
try:
# Change the category of the image
if changes.image_category is not None:
cursor.execute(
"""--sql
UPDATE images
SET image_category = ?
WHERE image_name = ?;
""",
(changes.image_category, image_name),
)
# Change the session associated with the image
if changes.session_id is not None:
cursor.execute(
"""--sql
UPDATE images
SET session_id = ?
WHERE image_name = ?;
""",
(changes.session_id, image_name),
)
# Change the session associated with the image
if changes.session_id is not None:
cursor.execute(
"""--sql
UPDATE images
SET session_id = ?
WHERE image_name = ?;
""",
(changes.session_id, image_name),
)
# Change the image's `is_intermediate`` flag
if changes.is_intermediate is not None:
cursor.execute(
"""--sql
UPDATE images
SET is_intermediate = ?
WHERE image_name = ?;
""",
(changes.is_intermediate, image_name),
)
# Change the image's `is_intermediate`` flag
if changes.is_intermediate is not None:
cursor.execute(
"""--sql
UPDATE images
SET is_intermediate = ?
WHERE image_name = ?;
""",
(changes.is_intermediate, image_name),
)
# Change the image's `starred`` state
if changes.starred is not None:
cursor.execute(
"""--sql
UPDATE images
SET starred = ?
WHERE image_name = ?;
""",
(changes.starred, image_name),
)
# Change the image's `starred`` state
if changes.starred is not None:
cursor.execute(
"""--sql
UPDATE images
SET starred = ?
WHERE image_name = ?;
""",
(changes.starred, image_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordSaveException from e
except sqlite3.Error as e:
raise ImageRecordSaveException from e
def get_many(
self,
@@ -136,170 +135,162 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
cursor = self._conn.cursor()
# Manually build two queries - one for the count, one for the records
count_query = """--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
images_query = f"""--sql
SELECT {IMAGE_DTO_COLS}
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
with self._db.transaction() as cursor:
# Manually build two queries - one for the count, one for the records
count_query = """--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
# Unpack the included categories into the query params
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
images_query = f"""--sql
SELECT {IMAGE_DTO_COLS}
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
query_params.append(is_intermediate)
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
# board_id of "none" is reserved for images without a board
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
# Search term condition
if search_term:
query_conditions += """--sql
AND (
images.metadata LIKE ?
OR images.created_at LIKE ?
)
"""
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
if starred_first:
query_pagination = f"""--sql
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
else:
query_pagination = f"""--sql
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy()
# Add the pagination parameters
images_params.extend([limit, offset])
# Unpack the included categories into the query params
for c in category_strings:
query_params.append(c)
# Build the list of images, deserializing each row
cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
count_params = query_params.copy()
cursor.execute(count_query, count_params)
count = cast(int, cursor.fetchone()[0])
query_params.append(is_intermediate)
# board_id of "none" is reserved for images without a board
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
# Search term condition
if search_term:
query_conditions += """--sql
AND (
images.metadata LIKE ?
OR images.created_at LIKE ?
)
"""
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
if starred_first:
query_pagination = f"""--sql
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
else:
query_pagination = f"""--sql
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy()
# Add the pagination parameters
images_params.extend([limit, offset])
# Build the list of images, deserializing each row
cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
count_params = query_params.copy()
cursor.execute(count_query, count_params)
count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
def delete(self, image_name: str) -> None:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE FROM images
WHERE image_name = ?;
""",
(image_name,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
DELETE FROM images
WHERE image_name = ?;
""",
(image_name,),
)
except sqlite3.Error as e:
raise ImageRecordDeleteException from e
def delete_many(self, image_names: list[str]) -> None:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
try:
placeholders = ",".join("?" for _ in image_names)
placeholders = ",".join("?" for _ in image_names)
# Construct the SQLite query with the placeholders
query = f"DELETE FROM images WHERE image_name IN ({placeholders})"
# Construct the SQLite query with the placeholders
query = f"DELETE FROM images WHERE image_name IN ({placeholders})"
# Execute the query with the list of IDs as parameters
cursor.execute(query, image_names)
# Execute the query with the list of IDs as parameters
cursor.execute(query, image_names)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
except sqlite3.Error as e:
raise ImageRecordDeleteException from e
def get_intermediates_count(self) -> int:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images
WHERE is_intermediate = TRUE;
"""
)
count = cast(int, cursor.fetchone()[0])
self._conn.commit()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images
WHERE is_intermediate = TRUE;
"""
)
count = cast(int, cursor.fetchone()[0])
return count
def delete_intermediates(self) -> list[str]:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT image_name FROM images
WHERE is_intermediate = TRUE;
"""
)
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [r[0] for r in result]
cursor.execute(
"""--sql
DELETE FROM images
WHERE is_intermediate = TRUE;
"""
)
self._conn.commit()
return image_names
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
SELECT image_name FROM images
WHERE is_intermediate = TRUE;
"""
)
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [r[0] for r in result]
cursor.execute(
"""--sql
DELETE FROM images
WHERE is_intermediate = TRUE;
"""
)
except sqlite3.Error as e:
raise ImageRecordDeleteException from e
return image_names
def save(
self,
@@ -315,73 +306,71 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id: Optional[str] = None,
metadata: Optional[str] = None,
) -> datetime:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO images (
image_name,
image_origin,
image_category,
width,
height,
node_id,
session_id,
metadata,
is_intermediate,
starred,
has_workflow
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
""",
(
image_name,
image_origin.value,
image_category.value,
width,
height,
node_id,
session_id,
metadata,
is_intermediate,
starred,
has_workflow,
),
)
self._conn.commit()
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
INSERT OR IGNORE INTO images (
image_name,
image_origin,
image_category,
width,
height,
node_id,
session_id,
metadata,
is_intermediate,
starred,
has_workflow
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
""",
(
image_name,
image_origin.value,
image_category.value,
width,
height,
node_id,
session_id,
metadata,
is_intermediate,
starred,
has_workflow,
),
)
cursor.execute(
"""--sql
SELECT created_at
FROM images
WHERE image_name = ?;
""",
(image_name,),
)
cursor.execute(
"""--sql
SELECT created_at
FROM images
WHERE image_name = ?;
""",
(image_name,),
)
created_at = datetime.fromisoformat(cursor.fetchone()[0])
created_at = datetime.fromisoformat(cursor.fetchone()[0])
return created_at
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordSaveException from e
except sqlite3.Error as e:
raise ImageRecordSaveException from e
return created_at
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT images.*
FROM images
JOIN board_images ON images.image_name = board_images.image_name
WHERE board_images.board_id = ?
AND images.is_intermediate = FALSE
ORDER BY images.starred DESC, images.created_at DESC
LIMIT 1;
""",
(board_id,),
)
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT images.*
FROM images
JOIN board_images ON images.image_name = board_images.image_name
WHERE board_images.board_id = ?
AND images.is_intermediate = FALSE
ORDER BY images.starred DESC, images.created_at DESC
LIMIT 1;
""",
(board_id,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
result = cast(Optional[sqlite3.Row], cursor.fetchone())
if result is None:
return None
@@ -398,85 +387,84 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageNamesResult:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
# Build query conditions (reused for both starred count and image names queries)
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
# Build query conditions (reused for both starred count and image names queries)
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if categories is not None:
category_strings = [c.value for c in set(categories)]
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
for c in category_strings:
query_params.append(c)
if categories is not None:
category_strings = [c.value for c in set(categories)]
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
query_params.append(is_intermediate)
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
query_params.append(is_intermediate)
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
if search_term:
query_conditions += """--sql
AND (
images.metadata LIKE ?
OR images.created_at LIKE ?
)
"""
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
if search_term:
query_conditions += """--sql
AND (
images.metadata LIKE ?
OR images.created_at LIKE ?
)
"""
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
# Get starred count if starred_first is enabled
starred_count = 0
if starred_first:
starred_count_query = f"""--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE images.starred = TRUE AND (1=1{query_conditions})
"""
cursor.execute(starred_count_query, query_params)
starred_count = cast(int, cursor.fetchone()[0])
# Get starred count if starred_first is enabled
starred_count = 0
if starred_first:
starred_count_query = f"""--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE images.starred = TRUE AND (1=1{query_conditions})
"""
cursor.execute(starred_count_query, query_params)
starred_count = cast(int, cursor.fetchone()[0])
# Get all image names with proper ordering
if starred_first:
names_query = f"""--sql
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1{query_conditions}
ORDER BY images.starred DESC, images.created_at {order_dir.value}
"""
else:
names_query = f"""--sql
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1{query_conditions}
ORDER BY images.created_at {order_dir.value}
"""
# Get all image names with proper ordering
if starred_first:
names_query = f"""--sql
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1{query_conditions}
ORDER BY images.starred DESC, images.created_at {order_dir.value}
"""
else:
names_query = f"""--sql
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1{query_conditions}
ORDER BY images.created_at {order_dir.value}
"""
cursor.execute(names_query, query_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
cursor.execute(names_query, query_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [row[0] for row in result]
return ImageNamesResult(image_names=image_names, starred_count=starred_count, total_count=len(image_names))

View File

@@ -78,11 +78,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db = db
self._logger = logger
@property
def db(self) -> SqliteDatabase:
"""Return the underlying database."""
return self._db
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
"""
Add a model to the database.
@@ -93,38 +88,33 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
try:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
INSERT INTO models (
id,
config
)
VALUES (?,?);
""",
(
config.key,
config.model_dump_json(),
),
)
self._db.conn.commit()
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
INSERT INTO models (
id,
config
)
VALUES (?,?);
""",
(
config.key,
config.model_dump_json(),
),
)
except sqlite3.IntegrityError as e:
self._db.conn.rollback()
if "UNIQUE constraint failed" in str(e):
if "models.path" in str(e):
msg = f"A model with path '{config.path}' is already installed"
elif "models.name" in str(e):
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
except sqlite3.IntegrityError as e:
if "UNIQUE constraint failed" in str(e):
if "models.path" in str(e):
msg = f"A model with path '{config.path}' is already installed"
elif "models.name" in str(e):
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
else:
msg = f"A model with key '{config.key}' is already installed"
raise DuplicateModelException(msg) from e
else:
msg = f"A model with key '{config.key}' is already installed"
raise DuplicateModelException(msg) from e
else:
raise e
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
raise e
return self.get_model(config.key)
@@ -136,8 +126,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise an UnknownModelException
"""
try:
cursor = self._db.conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE FROM models
@@ -147,22 +136,17 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
)
if cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
record = self.get_model(key)
with self._db.transaction() as cursor:
record = self.get_model(key)
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
for field_name in changes.model_fields_set:
setattr(record, field_name, getattr(changes, field_name))
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
for field_name in changes.model_fields_set:
setattr(record, field_name, getattr(changes, field_name))
json_serialized = record.model_dump_json()
json_serialized = record.model_dump_json()
try:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
UPDATE models
@@ -174,10 +158,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
)
if cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_model(key)
@@ -189,30 +169,30 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Exceptions: UnknownModelException
"""
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE id=?;
""",
(key,),
)
rows = cursor.fetchone()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE id=?;
""",
(key,),
)
rows = cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
rows = cursor.fetchone()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
rows = cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
@@ -224,15 +204,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
:param key: Unique key for the model to be deleted
"""
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
select count(*) FROM models
WHERE id=?;
""",
(key,),
)
count = cursor.fetchone()[0]
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
select count(*) FROM models
WHERE id=?;
""",
(key,),
)
count = cursor.fetchone()[0]
return count > 0
def search_by_attr(
@@ -255,43 +235,42 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
If none of the optional filters are passed, will return all
models in the database.
"""
with self._db.transaction() as cursor:
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
where_clause: list[str] = []
bindings: list[str] = []
if model_name:
where_clause.append("name=?")
bindings.append(model_name)
if base_model:
where_clause.append("base=?")
bindings.append(base_model)
if model_type:
where_clause.append("type=?")
bindings.append(model_type)
if model_format:
where_clause.append("format=?")
bindings.append(model_format)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
where_clause: list[str] = []
bindings: list[str] = []
if model_name:
where_clause.append("name=?")
bindings.append(model_name)
if base_model:
where_clause.append("base=?")
bindings.append(base_model)
if model_type:
where_clause.append("type=?")
bindings.append(model_type)
if model_format:
where_clause.append("format=?")
bindings.append(model_format)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
cursor = self._db.conn.cursor()
cursor.execute(
f"""--sql
SELECT config, strftime('%s',updated_at)
FROM models
{where}
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
""",
tuple(bindings),
)
result = cursor.fetchall()
cursor.execute(
f"""--sql
SELECT config, strftime('%s',updated_at)
FROM models
{where}
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
""",
tuple(bindings),
)
result = cursor.fetchall()
# Parse the model configs.
results: list[AnyModelConfig] = []
@@ -313,69 +292,68 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
"""Return models with the indicated path."""
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
return results
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
"""Return models with the indicated hash."""
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
return results
def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
) -> PaginatedResults[ModelSummary]:
"""Return a paginated summary listing of each model in the database."""
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
with self._db.transaction() as cursor:
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
cursor = self._db.conn.cursor()
# Lock so that the database isn't updated while we're doing the two queries.
# query1: get the total number of model configs
cursor.execute(
"""--sql
select count(*) from models;
""",
(),
)
total = int(cursor.fetchone()[0])
# Lock so that the database isn't updated while we're doing the two queries.
# query1: get the total number of model configs
cursor.execute(
"""--sql
select count(*) from models;
""",
(),
)
total = int(cursor.fetchone()[0])
# query2: fetch key fields
cursor.execute(
f"""--sql
SELECT config
FROM models
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
LIMIT ?
OFFSET ?;
""",
(
per_page,
page * per_page,
),
)
rows = cursor.fetchall()
# query2: fetch key fields
cursor.execute(
f"""--sql
SELECT config
FROM models
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
LIMIT ?
OFFSET ?;
""",
(
per_page,
page * per_page,
),
)
rows = cursor.fetchall()
items = [ModelSummary.model_validate(dict(x)) for x in rows]
return PaginatedResults(page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items)

View File

@@ -1,5 +1,3 @@
import sqlite3
from invokeai.app.services.model_relationship_records.model_relationship_records_base import (
ModelRelationshipRecordStorageBase,
)
@@ -9,58 +7,49 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteModelRelationshipRecordStorage(ModelRelationshipRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._db = db
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
if model_key_1 == model_key_2:
raise ValueError("Cannot relate a model to itself.")
a, b = sorted([model_key_1, model_key_2])
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
if model_key_1 == model_key_2:
raise ValueError("Cannot relate a model to itself.")
a, b = sorted([model_key_1, model_key_2])
cursor.execute(
"INSERT OR IGNORE INTO model_relationships (model_key_1, model_key_2) VALUES (?, ?)",
(a, b),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
a, b = sorted([model_key_1, model_key_2])
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
a, b = sorted([model_key_1, model_key_2])
cursor.execute(
"DELETE FROM model_relationships WHERE model_key_1 = ? AND model_key_2 = ?",
(a, b),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
def get_related_model_keys(self, model_key: str) -> list[str]:
cursor = self._conn.cursor()
cursor.execute(
"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ?
UNION
SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ?
""",
(model_key, model_key),
)
return [row[0] for row in cursor.fetchall()]
with self._db.transaction() as cursor:
cursor.execute(
"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ?
UNION
SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ?
""",
(model_key, model_key),
)
result = [row[0] for row in cursor.fetchall()]
return result
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
cursor = self._conn.cursor()
key_list = ",".join("?" for _ in model_keys)
cursor.execute(
f"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 IN ({key_list})
UNION
SELECT model_key_1 FROM model_relationships WHERE model_key_2 IN ({key_list})
""",
model_keys + model_keys,
)
return [row[0] for row in cursor.fetchall()]
with self._db.transaction() as cursor:
key_list = ",".join("?" for _ in model_keys)
cursor.execute(
f"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 IN ({key_list})
UNION
SELECT model_key_1 FROM model_relationships WHERE model_key_2 IN ({key_list})
""",
model_keys + model_keys,
)
result = [row[0] for row in cursor.fetchall()]
return result

View File

@@ -50,15 +50,14 @@ class SqliteSessionQueue(SessionQueueBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._db = db
def _set_in_progress_to_canceled(self) -> None:
"""
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
This is necessary because the invoker may have been killed while processing a queue item.
"""
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
UPDATE session_queue
@@ -66,87 +65,79 @@ class SqliteSessionQueue(SessionQueueBase):
WHERE status = 'in_progress';
"""
)
except Exception:
self._conn.rollback()
raise
def _get_current_queue_size(self, queue_id: str) -> int:
"""Gets the current number of pending queue items"""
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
return cast(int, cursor.fetchone()[0])
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
count = cast(int, cursor.fetchone()[0])
return count
def _get_highest_priority(self, queue_id: str) -> int:
"""Gets the highest priority value in the queue"""
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT MAX(priority)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
return cast(Union[int, None], cursor.fetchone()[0]) or 0
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT MAX(priority)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
priority = cast(Union[int, None], cursor.fetchone()[0]) or 0
return priority
async def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
try:
# TODO: how does this work in a multi-user scenario?
current_queue_size = self._get_current_queue_size(queue_id)
max_queue_size = self.__invoker.services.configuration.max_queue_size
max_new_queue_items = max_queue_size - current_queue_size
current_queue_size = self._get_current_queue_size(queue_id)
max_queue_size = self.__invoker.services.configuration.max_queue_size
max_new_queue_items = max_queue_size - current_queue_size
priority = 0
if prepend:
priority = self._get_highest_priority(queue_id) + 1
priority = 0
if prepend:
priority = self._get_highest_priority(queue_id) + 1
requested_count = await asyncio.to_thread(
calc_session_count,
batch=batch,
)
values_to_insert = await asyncio.to_thread(
prepare_values_to_insert,
queue_id=queue_id,
batch=batch,
priority=priority,
max_new_queue_items=max_new_queue_items,
)
enqueued_count = len(values_to_insert)
requested_count = await asyncio.to_thread(
calc_session_count,
batch=batch,
)
values_to_insert = await asyncio.to_thread(
prepare_values_to_insert,
queue_id=queue_id,
batch=batch,
priority=priority,
max_new_queue_items=max_new_queue_items,
)
enqueued_count = len(values_to_insert)
with self._conn:
cursor = self._conn.cursor()
cursor.executemany(
"""--sql
with self._db.transaction() as cursor:
cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)
with self._conn:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
values_to_insert,
)
cursor.execute(
"""--sql
SELECT item_id
FROM session_queue
WHERE batch_id = ?
ORDER BY item_id DESC;
""",
(batch.batch_id,),
)
item_ids = [row[0] for row in cursor.fetchall()]
except Exception:
raise
(batch.batch_id,),
)
item_ids = [row[0] for row in cursor.fetchall()]
enqueue_result = EnqueueBatchResult(
queue_id=queue_id,
requested=requested_count,
@@ -159,19 +150,19 @@ class SqliteSessionQueue(SessionQueueBase):
return enqueue_result
def dequeue(self) -> Optional[SessionQueueItem]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE status = 'pending'
ORDER BY
priority DESC,
item_id ASC
LIMIT 1
"""
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE status = 'pending'
ORDER BY
priority DESC,
item_id ASC
LIMIT 1
"""
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
queue_item = SessionQueueItem.queue_item_from_dict(dict(result))
@@ -179,40 +170,40 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
ORDER BY
priority DESC,
created_at ASC
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
ORDER BY
priority DESC,
created_at ASC
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
return SessionQueueItem.queue_item_from_dict(dict(result))
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'in_progress'
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'in_progress'
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
return SessionQueueItem.queue_item_from_dict(dict(result))
@@ -225,8 +216,7 @@ class SqliteSessionQueue(SessionQueueBase):
error_message: Optional[str] = None,
error_traceback: Optional[str] = None,
) -> SessionQueueItem:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status FROM session_queue WHERE item_id = ?
@@ -234,12 +224,15 @@ class SqliteSessionQueue(SessionQueueBase):
(item_id,),
)
row = cursor.fetchone()
if row is None:
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
current_status = row[0]
# Only update if not already finished (completed, failed or canceled)
if current_status in ("completed", "failed", "canceled"):
return self.get_queue_item(item_id)
if row is None:
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
current_status = row[0]
# Only update if not already finished (completed, failed or canceled)
if current_status in ("completed", "failed", "canceled"):
return self.get_queue_item(item_id)
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
UPDATE session_queue
@@ -248,10 +241,7 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(status, error_type, error_message, error_traceback, item_id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
queue_item = self.get_queue_item(item_id)
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
@@ -259,35 +249,34 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def is_empty(self, queue_id: str) -> IsEmptyResult:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
is_empty = cast(int, cursor.fetchone()[0]) == 0
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
is_empty = cast(int, cursor.fetchone()[0]) == 0
return IsEmptyResult(is_empty=is_empty)
def is_full(self, queue_id: str) -> IsFullResult:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
max_queue_size = self.__invoker.services.configuration.max_queue_size
is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
max_queue_size = self.__invoker.services.configuration.max_queue_size
is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
return IsFullResult(is_full=is_full)
def clear(self, queue_id: str) -> ClearResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT COUNT(*)
@@ -305,24 +294,19 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
self.__invoker.services.events.emit_queue_cleared(queue_id)
return ClearResult(deleted=count)
def prune(self, queue_id: str) -> PruneResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
where = """--sql
WHERE
queue_id = ?
AND (
queue_id = ?
AND (
status = 'completed'
OR status = 'failed'
OR status = 'canceled'
)
)
"""
cursor.execute(
f"""--sql
@@ -341,10 +325,6 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return PruneResult(deleted=count)
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
@@ -357,8 +337,7 @@ class SqliteSessionQueue(SessionQueueBase):
self.cancel_queue_item(item_id)
except SessionQueueItemNotFoundError:
pass
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE
@@ -367,10 +346,6 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(item_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
def complete_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self._set_queue_item_status(item_id=item_id, status="completed")
@@ -393,8 +368,7 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
placeholders = ", ".join(["?" for _ in batch_ids])
where = f"""--sql
@@ -404,6 +378,8 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
-- We will cancel the current item separately below - skip it here
AND status != 'in_progress'
"""
params = [queue_id] + batch_ids
cursor.execute(
@@ -423,17 +399,14 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
self._conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self._conn.rollback()
raise
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
return CancelByBatchIDsResult(canceled=count)
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
where = """--sql
WHERE
@@ -442,6 +415,8 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
-- We will cancel the current item separately below - skip it here
AND status != 'in_progress'
"""
params = (queue_id, destination)
cursor.execute(
@@ -461,17 +436,12 @@ class SqliteSessionQueue(SessionQueueBase):
""",
params,
)
self._conn.commit()
if current_queue_item is not None and current_queue_item.destination == destination:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self._conn.rollback()
raise
if current_queue_item is not None and current_queue_item.destination == destination:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
return CancelByDestinationResult(canceled=count)
def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
if current_queue_item is not None and current_queue_item.destination == destination:
self.cancel_queue_item(current_queue_item.item_id)
@@ -497,15 +467,10 @@ class SqliteSessionQueue(SessionQueueBase):
""",
params,
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return DeleteByDestinationResult(deleted=count)
def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
where = """--sql
WHERE
queue_id == ?
@@ -528,15 +493,10 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return DeleteAllExceptCurrentResult(deleted=count)
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
where = """--sql
WHERE
@@ -544,6 +504,8 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
-- We will cancel the current item separately below - skip it here
AND status != 'in_progress'
"""
params = [queue_id]
cursor.execute(
@@ -563,21 +525,13 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
self._conn.commit()
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
current_queue_item, batch_status, queue_status
)
except Exception:
self._conn.rollback()
raise
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
return CancelByQueueIDResult(canceled=count)
def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
where = """--sql
WHERE
queue_id == ?
@@ -600,30 +554,25 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return CancelAllExceptCurrentResult(canceled=count)
def get_queue_item(self, item_id: int) -> SessionQueueItem:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT * FROM session_queue
WHERE
item_id = ?
""",
(item_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT * FROM session_queue
WHERE
item_id = ?
""",
(item_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
return SessionQueueItem.queue_item_from_dict(dict(result))
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
# when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
# during execution.
@@ -636,10 +585,6 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(session_json, item_id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get_queue_item(item_id)
def list_queue_items(
@@ -651,42 +596,42 @@ class SqliteSessionQueue(SessionQueueBase):
status: Optional[QUEUE_ITEM_STATUS] = None,
destination: Optional[str] = None,
) -> CursorPaginatedResults[SessionQueueItem]:
cursor_ = self._conn.cursor()
item_id = cursor
query = """--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
"""
params: list[Union[str, int]] = [queue_id]
if status is not None:
query += """--sql
AND status = ?
"""
params.append(status)
if destination is not None:
query += """---sql
AND destination = ?
with self._db.transaction() as cursor_:
item_id = cursor
query = """--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
"""
params.append(destination)
params: list[Union[str, int]] = [queue_id]
if item_id is not None:
query += """--sql
AND (priority < ?) OR (priority = ? AND item_id > ?)
if status is not None:
query += """--sql
AND status = ?
"""
params.append(status)
if destination is not None:
query += """---sql
AND destination = ?
"""
params.extend([priority, priority, item_id])
params.append(destination)
query += """--sql
ORDER BY
priority DESC,
item_id ASC
LIMIT ?
"""
params.append(limit + 1)
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
if item_id is not None:
query += """--sql
AND (priority < ?) OR (priority = ? AND item_id > ?)
"""
params.extend([priority, priority, item_id])
query += """--sql
ORDER BY
priority DESC,
item_id ASC
LIMIT ?
"""
params.append(limit + 1)
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
has_more = False
if len(items) > limit:
@@ -701,43 +646,43 @@ class SqliteSessionQueue(SessionQueueBase):
destination: Optional[str] = None,
) -> list[SessionQueueItem]:
"""Gets all queue items that match the given parameters"""
cursor_ = self._conn.cursor()
query = """--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
"""
params: list[Union[str, int]] = [queue_id]
if destination is not None:
query += """---sql
AND destination = ?
with self._db.transaction() as cursor:
query = """--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
"""
params.append(destination)
params: list[Union[str, int]] = [queue_id]
query += """--sql
ORDER BY
priority DESC,
item_id ASC
;
"""
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
if destination is not None:
query += """---sql
AND destination = ?
"""
params.append(destination)
query += """--sql
ORDER BY
priority DESC,
item_id ASC
;
"""
cursor.execute(query, params)
results = cast(list[sqlite3.Row], cursor.fetchall())
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
return items
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
GROUP BY status
""",
(queue_id,),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
GROUP BY status
""",
(queue_id,),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
current_item = self.get_current(queue_id=queue_id)
total = sum(row[1] or 0 for row in counts_result)
@@ -756,19 +701,19 @@ class SqliteSessionQueue(SessionQueueBase):
)
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*), origin, destination
FROM session_queue
WHERE
queue_id = ?
AND batch_id = ?
GROUP BY status
""",
(queue_id, batch_id),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status, count(*), origin, destination
FROM session_queue
WHERE
queue_id = ?
AND batch_id = ?
GROUP BY status
""",
(queue_id, batch_id),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] or 0 for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result}
origin = result[0]["origin"] if result else None
@@ -788,18 +733,18 @@ class SqliteSessionQueue(SessionQueueBase):
)
def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
AND destination = ?
GROUP BY status
""",
(queue_id, destination),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
AND destination = ?
GROUP BY status
""",
(queue_id, destination),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] or 0 for row in counts_result)
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
@@ -817,8 +762,7 @@ class SqliteSessionQueue(SessionQueueBase):
def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult:
"""Retries the given queue items"""
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
values_to_insert: list[ValueToInsertTuple] = []
retried_item_ids: list[int] = []
@@ -869,10 +813,6 @@ class SqliteSessionQueue(SessionQueueBase):
values_to_insert,
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
retry_result = RetryItemsResult(
queue_id=queue_id,
retried_item_ids=retried_item_ids,

View File

@@ -1,4 +1,7 @@
import sqlite3
import threading
from collections.abc import Generator
from contextlib import contextmanager
from logging import Logger
from pathlib import Path
@@ -26,46 +29,65 @@ class SqliteDatabase:
def __init__(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
"""Initializes the database. This is used internally by the class constructor."""
self.logger = logger
self.db_path = db_path
self.verbose = verbose
self._logger = logger
self._db_path = db_path
self._verbose = verbose
self._lock = threading.RLock()
if not self.db_path:
if not self._db_path:
logger.info("Initializing in-memory database")
else:
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self.logger.info(f"Initializing database at {self.db_path}")
self._db_path.parent.mkdir(parents=True, exist_ok=True)
self._logger.info(f"Initializing database at {self._db_path}")
self.conn = sqlite3.connect(database=self.db_path or sqlite_memory, check_same_thread=False)
self.conn.row_factory = sqlite3.Row
self._conn = sqlite3.connect(database=self._db_path or sqlite_memory, check_same_thread=False)
self._conn.row_factory = sqlite3.Row
if self.verbose:
self.conn.set_trace_callback(self.logger.debug)
if self._verbose:
self._conn.set_trace_callback(self._logger.debug)
# Enable foreign key constraints
self.conn.execute("PRAGMA foreign_keys = ON;")
self._conn.execute("PRAGMA foreign_keys = ON;")
# Enable Write-Ahead Logging (WAL) mode for better concurrency
self.conn.execute("PRAGMA journal_mode = WAL;")
self._conn.execute("PRAGMA journal_mode = WAL;")
# Set a busy timeout to prevent database lockups during writes
self.conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
self._conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
def clean(self) -> None:
"""
Cleans the database by running the VACUUM command, reporting on the freed space.
"""
# No need to clean in-memory database
if not self.db_path:
if not self._db_path:
return
try:
initial_db_size = Path(self.db_path).stat().st_size
self.conn.execute("VACUUM;")
self.conn.commit()
final_db_size = Path(self.db_path).stat().st_size
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
if freed_space_in_mb > 0:
self.logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
with self._conn as conn:
initial_db_size = Path(self._db_path).stat().st_size
conn.execute("VACUUM;")
conn.commit()
final_db_size = Path(self._db_path).stat().st_size
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
if freed_space_in_mb > 0:
self._logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
except Exception as e:
self.logger.error(f"Error cleaning database: {e}")
self._logger.error(f"Error cleaning database: {e}")
raise
@contextmanager
def transaction(self) -> Generator[sqlite3.Cursor, None, None]:
"""
Thread-safe context manager for DB work.
Acquires the RLock, yields a Cursor, then commits or rolls back.
"""
with self._lock:
cursor = self._conn.cursor()
try:
yield cursor
self._conn.commit()
except:
self._conn.rollback()
raise
finally:
cursor.close()

View File

@@ -32,7 +32,7 @@ class SqliteMigrator:
def __init__(self, db: SqliteDatabase) -> None:
self._db = db
self._logger = db.logger
self._logger = db._logger
self._migration_set = MigrationSet()
self._backup_path: Optional[Path] = None
@@ -45,7 +45,7 @@ class SqliteMigrator:
"""Migrates the database to the latest version."""
# This throws if there is a problem.
self._migration_set.validate_migration_chain()
cursor = self._db.conn.cursor()
cursor = self._db._conn.cursor()
self._create_migrations_table(cursor=cursor)
if self._migration_set.count == 0:
@@ -59,13 +59,13 @@ class SqliteMigrator:
self._logger.info("Database update needed")
# Make a backup of the db if it needs to be updated and is a file db
if self._db.db_path is not None:
if self._db._db_path is not None:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
self._backup_path = self._db._db_path.parent / f"{self._db._db_path.stem}_backup_{timestamp}.db"
self._logger.info(f"Backing up database to {str(self._backup_path)}")
# Use SQLite to do the backup
with closing(sqlite3.connect(self._backup_path)) as backup_conn:
self._db.conn.backup(backup_conn)
self._db._conn.backup(backup_conn)
else:
self._logger.info("Using in-memory database, no backup needed")
@@ -81,7 +81,7 @@ class SqliteMigrator:
try:
# Using sqlite3.Connection as a context manager commits a the transaction on exit, or rolls it back if an
# exception is raised.
with self._db.conn as conn:
with self._db._conn as conn:
cursor = conn.cursor()
if self._get_current_version(cursor) != migration.from_version:
raise MigrationError(

View File

@@ -17,7 +17,7 @@ from invokeai.app.util.misc import uuid_string
class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._db = db
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
@@ -25,24 +25,23 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
"""Gets a style preset by ID."""
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM style_presets
WHERE id = ?;
""",
(style_preset_id,),
)
row = cursor.fetchone()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
FROM style_presets
WHERE id = ?;
""",
(style_preset_id,),
)
row = cursor.fetchone()
if row is None:
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
return StylePresetRecordDTO.from_dict(dict(row))
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
style_preset_id = uuid_string()
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
INSERT OR IGNORE INTO style_presets (
@@ -60,16 +59,11 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
style_preset.type,
),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get(style_preset_id)
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
style_preset_ids = []
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
for style_preset in style_presets:
style_preset_id = uuid_string()
style_preset_ids.append(style_preset_id)
@@ -90,16 +84,11 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
style_preset.type,
),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return None
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
# Change the name of a style preset
if changes.name is not None:
cursor.execute(
@@ -122,15 +111,10 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
(changes.preset_data.model_dump_json(), style_preset_id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get(style_preset_id)
def delete(self, style_preset_id: str) -> None:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE from style_presets
@@ -138,51 +122,41 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
""",
(style_preset_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return None
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
main_query = """
SELECT
*
FROM style_presets
"""
with self._db.transaction() as cursor:
main_query = """
SELECT
*
FROM style_presets
"""
if type is not None:
main_query += "WHERE type = ? "
if type is not None:
main_query += "WHERE type = ? "
main_query += "ORDER BY LOWER(name) ASC"
main_query += "ORDER BY LOWER(name) ASC"
cursor = self._conn.cursor()
if type is not None:
cursor.execute(main_query, (type,))
else:
cursor.execute(main_query)
if type is not None:
cursor.execute(main_query, (type,))
else:
cursor.execute(main_query)
rows = cursor.fetchall()
rows = cursor.fetchall()
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
return style_presets
def _sync_default_style_presets(self) -> None:
"""Syncs default style presets to the database. Internal use only."""
# First delete all existing default style presets
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
# First delete all existing default style presets
cursor.execute(
"""--sql
DELETE FROM style_presets
WHERE type = "default";
"""
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
# Next, parse and create the default style presets
with open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
presets = json.load(file)

View File

@@ -25,7 +25,7 @@ SQL_TIME_FORMAT = "%Y-%m-%d %H:%M:%f"
class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
self._db = db
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
@@ -33,16 +33,16 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
def get(self, workflow_id: str) -> WorkflowRecordDTO:
"""Gets a workflow by ID. Updates the opened_at column."""
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
FROM workflow_library
WHERE workflow_id = ?;
""",
(workflow_id,),
)
row = cursor.fetchone()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
FROM workflow_library
WHERE workflow_id = ?;
""",
(workflow_id,),
)
row = cursor.fetchone()
if row is None:
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
return WorkflowRecordDTO.from_dict(dict(row))
@@ -51,9 +51,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
if workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be created via this method")
try:
with self._db.transaction() as cursor:
workflow_with_id = Workflow(**workflow.model_dump(), id=uuid_string())
cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO workflow_library (
@@ -64,18 +63,13 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow_with_id.id, workflow_with_id.model_dump_json()),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get(workflow_with_id.id)
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
if workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be updated")
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
UPDATE workflow_library
@@ -84,18 +78,13 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow.model_dump_json(), workflow.id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get(workflow.id)
def delete(self, workflow_id: str) -> None:
if self.get(workflow_id).workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be deleted")
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE from workflow_library
@@ -103,10 +92,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return None
def get_many(
@@ -121,108 +106,108 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> PaginatedResults[WorkflowRecordListItemDTO]:
# sanitize!
assert order_by in WorkflowRecordOrderBy
assert direction in SQLiteDirection
with self._db.transaction() as cursor:
# sanitize!
assert order_by in WorkflowRecordOrderBy
assert direction in SQLiteDirection
# We will construct the query dynamically based on the query params
# We will construct the query dynamically based on the query params
# The main query to get the workflows / counts
main_query = """
SELECT
workflow_id,
category,
name,
description,
created_at,
updated_at,
opened_at,
tags
FROM workflow_library
"""
count_query = "SELECT COUNT(*) FROM workflow_library"
# The main query to get the workflows / counts
main_query = """
SELECT
workflow_id,
category,
name,
description,
created_at,
updated_at,
opened_at,
tags
FROM workflow_library
"""
count_query = "SELECT COUNT(*) FROM workflow_library"
# Start with an empty list of conditions and params
conditions: list[str] = []
params: list[str | int] = []
# Start with an empty list of conditions and params
conditions: list[str] = []
params: list[str | int] = []
if categories:
# Categories is a list of WorkflowCategory enum values, and a single string in the DB
if categories:
# Categories is a list of WorkflowCategory enum values, and a single string in the DB
# Ensure all categories are valid (is this necessary?)
assert all(c in WorkflowCategory for c in categories)
# Ensure all categories are valid (is this necessary?)
assert all(c in WorkflowCategory for c in categories)
# Construct a placeholder string for the number of categories
placeholders = ", ".join("?" for _ in categories)
# Construct a placeholder string for the number of categories
placeholders = ", ".join("?" for _ in categories)
# Construct the condition string & params
category_condition = f"category IN ({placeholders})"
category_params = [category.value for category in categories]
# Construct the condition string & params
category_condition = f"category IN ({placeholders})"
category_params = [category.value for category in categories]
conditions.append(category_condition)
params.extend(category_params)
conditions.append(category_condition)
params.extend(category_params)
if tags:
# Tags is a list of strings, and a single string in the DB
# The string in the DB has no guaranteed format
if tags:
# Tags is a list of strings, and a single string in the DB
# The string in the DB has no guaranteed format
# Construct a list of conditions for each tag
tags_conditions = ["tags LIKE ?" for _ in tags]
tags_conditions_joined = " OR ".join(tags_conditions)
tags_condition = f"({tags_conditions_joined})"
# Construct a list of conditions for each tag
tags_conditions = ["tags LIKE ?" for _ in tags]
tags_conditions_joined = " OR ".join(tags_conditions)
tags_condition = f"({tags_conditions_joined})"
# And the params for the tags, case-insensitive
tags_params = [f"%{t.strip()}%" for t in tags]
# And the params for the tags, case-insensitive
tags_params = [f"%{t.strip()}%" for t in tags]
conditions.append(tags_condition)
params.extend(tags_params)
conditions.append(tags_condition)
params.extend(tags_params)
if has_been_opened:
conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
conditions.append("opened_at IS NULL")
if has_been_opened:
conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
conditions.append("opened_at IS NULL")
# Ignore whitespace in the query
stripped_query = query.strip() if query else None
if stripped_query:
# Construct a wildcard query for the name, description, and tags
wildcard_query = "%" + stripped_query + "%"
query_condition = "(name LIKE ? OR description LIKE ? OR tags LIKE ?)"
# Ignore whitespace in the query
stripped_query = query.strip() if query else None
if stripped_query:
# Construct a wildcard query for the name, description, and tags
wildcard_query = "%" + stripped_query + "%"
query_condition = "(name LIKE ? OR description LIKE ? OR tags LIKE ?)"
conditions.append(query_condition)
params.extend([wildcard_query, wildcard_query, wildcard_query])
conditions.append(query_condition)
params.extend([wildcard_query, wildcard_query, wildcard_query])
if conditions:
# If there are conditions, add a WHERE clause and then join the conditions
main_query += " WHERE "
count_query += " WHERE "
if conditions:
# If there are conditions, add a WHERE clause and then join the conditions
main_query += " WHERE "
count_query += " WHERE "
all_conditions = " AND ".join(conditions)
main_query += all_conditions
count_query += all_conditions
all_conditions = " AND ".join(conditions)
main_query += all_conditions
count_query += all_conditions
# After this point, the query and params differ for the main query and the count query
main_params = params.copy()
count_params = params.copy()
# After this point, the query and params differ for the main query and the count query
main_params = params.copy()
count_params = params.copy()
# Main query also gets ORDER BY and LIMIT/OFFSET
main_query += f" ORDER BY {order_by.value} {direction.value}"
# Main query also gets ORDER BY and LIMIT/OFFSET
main_query += f" ORDER BY {order_by.value} {direction.value}"
if per_page:
main_query += " LIMIT ? OFFSET ?"
main_params.extend([per_page, page * per_page])
if per_page:
main_query += " LIMIT ? OFFSET ?"
main_params.extend([per_page, page * per_page])
# Put a ring on it
main_query += ";"
count_query += ";"
# Put a ring on it
main_query += ";"
count_query += ";"
cursor = self._conn.cursor()
cursor.execute(main_query, main_params)
rows = cursor.fetchall()
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
cursor.execute(main_query, main_params)
rows = cursor.fetchall()
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
cursor.execute(count_query, count_params)
total = cursor.fetchone()[0]
cursor.execute(count_query, count_params)
total = cursor.fetchone()[0]
if per_page:
pages = total // per_page + (total % per_page > 0)
@@ -247,46 +232,46 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
if not tags:
return {}
cursor = self._conn.cursor()
result: dict[str, int] = {}
# Base conditions for categories and selected tags
base_conditions: list[str] = []
base_params: list[str | int] = []
with self._db.transaction() as cursor:
result: dict[str, int] = {}
# Base conditions for categories and selected tags
base_conditions: list[str] = []
base_params: list[str | int] = []
# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])
# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])
if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
# For each tag to count, run a separate query
for tag in tags:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()
# For each tag to count, run a separate query
for tag in tags:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()
# Add this specific tag condition
conditions.append("tags LIKE ?")
params.append(f"%{tag.strip()}%")
# Add this specific tag condition
conditions.append("tags LIKE ?")
params.append(f"%{tag.strip()}%")
# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""
# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[tag] = count
cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[tag] = count
return result
@@ -296,52 +281,51 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> dict[str, int]:
cursor = self._conn.cursor()
result: dict[str, int] = {}
# Base conditions for categories
base_conditions: list[str] = []
base_params: list[str | int] = []
with self._db.transaction() as cursor:
result: dict[str, int] = {}
# Base conditions for categories
base_conditions: list[str] = []
base_params: list[str | int] = []
# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])
# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])
if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
# For each category to count, run a separate query
for category in categories:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()
# For each category to count, run a separate query
for category in categories:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()
# Add this specific category condition
conditions.append("category = ?")
params.append(category.value)
# Add this specific category condition
conditions.append("category = ?")
params.append(category.value)
# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""
# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[category.value] = count
cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[category.value] = count
return result
def update_opened_at(self, workflow_id: str) -> None:
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
f"""--sql
UPDATE workflow_library
@@ -350,10 +334,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
def _sync_default_workflows(self) -> None:
"""Syncs default workflows to the database. Internal use only."""
@@ -368,8 +348,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
meaningless, as they are overwritten every time the server starts.
"""
try:
cursor = self._conn.cursor()
with self._db.transaction() as cursor:
workflows_from_file: list[Workflow] = []
workflows_to_update: list[Workflow] = []
workflows_to_add: list[Workflow] = []
@@ -449,8 +428,3 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(w.model_dump_json(), w.id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise

View File

@@ -187,7 +187,7 @@ class ModelConfigBase(ABC, BaseModel):
else:
return config_cls.from_model_on_disk(mod, **overrides)
raise InvalidModelConfigException("No valid config found")
raise InvalidModelConfigException("Unable to determine model type")
@classmethod
def get_tag(cls) -> Tag:

View File

@@ -143,11 +143,19 @@ flux_dev = StarterModel(
flux_kontext = StarterModel(
name="FLUX.1 Kontext dev",
base=BaseModelType.Flux,
source="black-forest-labs/FLUX.1-Kontext-dev::flux1-kontext-dev.safetensors",
source="https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev/resolve/main/flux1-kontext-dev.safetensors",
description="FLUX.1 Kontext dev transformer in bfloat16. Total size with dependencies: ~33GB",
type=ModelType.Main,
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
)
flux_kontext_quantized = StarterModel(
name="FLUX.1 Kontext dev (Quantized)",
base=BaseModelType.Flux,
source="https://huggingface.co/unsloth/FLUX.1-Kontext-dev-GGUF/resolve/main/flux1-kontext-dev-Q4_K_M.gguf",
description="FLUX.1 Kontext dev quantized (q4_k_m). Total size with dependencies: ~14GB",
type=ModelType.Main,
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
)
sd35_medium = StarterModel(
name="SD3.5 Medium",
base=BaseModelType.StableDiffusion3,
@@ -664,7 +672,7 @@ flux_fill = StarterModel(
# List of starter models, displayed on the frontend.
# The order/sort of this list is not changed by the frontend - set it how you want it here.
STARTER_MODELS: list[StarterModel] = [
flux_kontext,
flux_kontext_quantized,
flux_schnell_quantized,
flux_dev_quantized,
flux_schnell,
@@ -785,7 +793,7 @@ flux_bundle: list[StarterModel] = [
flux_depth_control_lora,
flux_redux,
flux_fill,
flux_kontext,
flux_kontext_quantized,
]
STARTER_BUNDLES: dict[str, StarterModelBundle] = {

View File

@@ -12,6 +12,8 @@ const config: KnipConfig = {
'src/features/parameters/types/parameterSchemas.ts',
// TODO(psyche): maybe we can clean up these utils after canvas v2 release
'src/features/controlLayers/konva/util.ts',
// Will be using this
'src/common/hooks/useAsyncState.ts',
],
ignoreBinaries: ['only-allow'],
paths: {

View File

@@ -63,7 +63,7 @@
"framer-motion": "^11.10.0",
"i18next": "^25.2.1",
"i18next-http-backend": "^3.0.2",
"idb-keyval": "^6.2.2",
"idb-keyval": "6.2.1",
"jsondiffpatch": "^0.7.3",
"konva": "^9.3.20",
"linkify-react": "^4.3.1",

View File

@@ -81,8 +81,8 @@ importers:
specifier: ^3.0.2
version: 3.0.2
idb-keyval:
specifier: ^6.2.2
version: 6.2.2
specifier: 6.2.1
version: 6.2.1
jsondiffpatch:
specifier: ^0.7.3
version: 0.7.3
@@ -2927,8 +2927,8 @@ packages:
typescript:
optional: true
idb-keyval@6.2.2:
resolution: {integrity: sha512-yjD9nARJ/jb1g+CvD0tlhUHOrJ9Sy0P8T9MF3YaLlHnSRpwPfpTX0XIvpmw3gAJUmEu3FiICLBDPXVwyEvrleg==}
idb-keyval@6.2.1:
resolution: {integrity: sha512-8Sb3veuYCyrZL+VBt9LJfZjLUPWVvqn8tG28VqYNFCo43KHcKuq+b4EiXGeuaLAQWL2YmyDgMp2aSpH9JHsEQg==}
ieee754@1.2.1:
resolution: {integrity: sha512-dcyqhDvX1C46lXZcVqCpK+FtMRQVdIMN6/Df5js2zouUsqG7I6sFxitIC+7KYK29KdXOLHdu9zL4sFnoVQnqaA==}
@@ -7720,7 +7720,7 @@ snapshots:
optionalDependencies:
typescript: 5.8.3
idb-keyval@6.2.2: {}
idb-keyval@6.2.1: {}
ieee754@1.2.1: {}

View File

@@ -1399,7 +1399,7 @@
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill is not compatible with Text to Image or Image to Image. Use other FLUX models for these tasks.",
"imagenIncompatibleGenerationMode": "Google {{model}} supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
"chatGPT4oIncompatibleGenerationMode": "ChatGPT 4o supports Text to Image and Image to Image only. Use other models Inpainting and Outpainting tasks.",
"fluxKontextIncompatibleGenerationMode": "FLUX Kontext supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
"fluxKontextIncompatibleGenerationMode": "FLUX Kontext does not support generation from images placed on the canvas. Re-try using the Reference Image section and disable any Raster Layers.",
"problemUnpublishingWorkflow": "Problem Unpublishing Workflow",
"problemUnpublishingWorkflowDescription": "There was a problem unpublishing the workflow. Please try again.",
"workflowUnpublished": "Workflow Unpublished",
@@ -1407,7 +1407,7 @@
"sentToUpscale": "Sent to Upscale",
"promptGenerationStarted": "Prompt generation started",
"uploadAndPromptGenerationFailed": "Failed to upload image and generate prompt",
"promptExpansionFailed": "Prompt expansion failed"
"promptExpansionFailed": "We ran into an issue. Please try prompt expansion again."
},
"popovers": {
"clipSkip": {
@@ -1962,6 +1962,7 @@
"recalculateRects": "Recalculate Rects",
"clipToBbox": "Clip Strokes to Bbox",
"outputOnlyMaskedRegions": "Output Only Generated Regions",
"saveAllImagesToGallery": "Save All Images to Gallery",
"addLayer": "Add Layer",
"duplicate": "Duplicate",
"moveToFront": "Move to Front",
@@ -2330,6 +2331,9 @@
"label": "Preserve Masked Region",
"alert": "Preserving Masked Region"
},
"saveAllImagesToGallery": {
"alert": "Saving All Images to Gallery"
},
"isolatedStagingPreview": "Isolated Staging Preview",
"isolatedPreview": "Isolated Preview",
"isolatedLayerPreview": "Isolated Layer Preview",
@@ -2376,6 +2380,11 @@
"saveToGallery": "Save To Gallery",
"showResultsOn": "Showing Results",
"showResultsOff": "Hiding Results"
},
"autoSwitch": {
"off": "Off",
"switchOnStart": "On Start",
"switchOnFinish": "On Finish"
}
},
"upscaling": {
@@ -2551,8 +2560,9 @@
"whatsNew": {
"whatsNewInInvoke": "What's New in Invoke",
"items": [
"Inpainting: Per-mask noise levels and denoise limits.",
"Canvas: Smarter aspect ratios for SDXL and improved scroll-to-zoom."
"Generate images faster with new Launchpads and a simplified Generate tab.",
"Edit with prompts using Flux Kontext Dev.",
"Export to PSD, bulk-hide overlays, organize models & images — all in a reimagined interface built for control."
],
"readReleaseNotes": "Read Release Notes",
"watchRecentReleaseVideos": "Watch Recent Release Videos",
@@ -2561,62 +2571,16 @@
"supportVideos": {
"supportVideos": "Support Videos",
"gettingStarted": "Getting Started",
"controlCanvas": "Control Canvas",
"watch": "Watch",
"studioSessionsDesc1": "Check out the <StudioSessionsPlaylistLink /> for Invoke deep dives.",
"studioSessionsDesc2": "Join our <DiscordLink /> to participate in the live sessions and ask questions. Sessions are uploaded to the playlist the following week.",
"studioSessionsDesc": "Join our <DiscordLink /> to participate in the live sessions and ask questions. Sessions are uploaded to the playlist the following week.",
"videos": {
"creatingYourFirstImage": {
"title": "Creating Your First Image",
"description": "Introduction to creating an image from scratch using Invoke's tools."
"gettingStarted": {
"title": "Getting Started with Invoke",
"description": "Complete video series covering everything you need to know to get started with Invoke, from creating your first image to advanced techniques."
},
"usingControlLayersAndReferenceGuides": {
"title": "Using Control Layers and Reference Guides",
"description": "Learn how to guide your image creation with control layers and reference images."
},
"understandingImageToImageAndDenoising": {
"title": "Understanding Image-to-Image and Denoising",
"description": "Overview of image-to-image transformations and denoising in Invoke."
},
"exploringAIModelsAndConceptAdapters": {
"title": "Exploring AI Models and Concept Adapters",
"description": "Dive into AI models and how to use concept adapters for creative control."
},
"creatingAndComposingOnInvokesControlCanvas": {
"title": "Creating and Composing on Invoke's Control Canvas",
"description": "Learn to compose images using Invoke's control canvas."
},
"upscaling": {
"title": "Upscaling",
"description": "How to upscale images with Invoke's tools to enhance resolution."
},
"howDoIGenerateAndSaveToTheGallery": {
"title": "How Do I Generate and Save to the Gallery?",
"description": "Steps to generate and save images to the gallery."
},
"howDoIEditOnTheCanvas": {
"title": "How Do I Edit on the Canvas?",
"description": "Guide to editing images directly on the canvas."
},
"howDoIDoImageToImageTransformation": {
"title": "How Do I Do Image-to-Image Transformation?",
"description": "Tutorial on performing image-to-image transformations in Invoke."
},
"howDoIUseControlNetsAndControlLayers": {
"title": "How Do I Use Control Nets and Control Layers?",
"description": "Learn to apply control layers and controlnets to your images."
},
"howDoIUseGlobalIPAdaptersAndReferenceImages": {
"title": "How Do I Use Global IP Adapters and Reference Images?",
"description": "Introduction to adding reference images and global IP adapters."
},
"howDoIUseInpaintMasks": {
"title": "How Do I Use Inpaint Masks?",
"description": "How to apply inpaint masks for image correction and variation."
},
"howDoIOutpaint": {
"title": "How Do I Outpaint?",
"description": "Guide to outpainting beyond the original image borders."
"studioSessions": {
"title": "Studio Sessions",
"description": "Deep dive sessions exploring advanced Invoke features, creative workflows, and community discussions."
}
}
}

View File

@@ -11,6 +11,7 @@ import { memo, useCallback } from 'react';
import { ErrorBoundary } from 'react-error-boundary';
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
import ThemeLocaleProvider from './ThemeLocaleProvider';
const DEFAULT_CONFIG = {};
interface Props {
@@ -30,12 +31,14 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
return (
<ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
<AppContent />
{!didStudioInit && <Loading />}
</Box>
<GlobalHookIsolator config={config} studioInitAction={studioInitAction} />
<GlobalModalIsolator />
<ThemeLocaleProvider>
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
<AppContent />
{!didStudioInit && <Loading />}
</Box>
<GlobalHookIsolator config={config} studioInitAction={studioInitAction} />
<GlobalModalIsolator />
</ThemeLocaleProvider>
</ErrorBoundary>
);
};

View File

@@ -1,11 +1,14 @@
import { useAppSelector } from 'app/store/storeHooks';
import { useIsRegionFocused } from 'common/hooks/focus';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { useImageActions } from 'features/gallery/hooks/useImageActions';
import { useLoadWorkflow } from 'features/gallery/hooks/useLoadWorkflow';
import { useRecallAll } from 'features/gallery/hooks/useRecallAll';
import { useRecallDimensions } from 'features/gallery/hooks/useRecallDimensions';
import { useRecallPrompts } from 'features/gallery/hooks/useRecallPrompts';
import { useRecallRemix } from 'features/gallery/hooks/useRecallRemix';
import { useRecallSeed } from 'features/gallery/hooks/useRecallSeed';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useImageDTO } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
@@ -27,59 +30,64 @@ GlobalImageHotkeys.displayName = 'GlobalImageHotkeys';
const GlobalImageHotkeysInternal = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
const isGalleryFocused = useIsRegionFocused('gallery');
const isViewerFocused = useIsRegionFocused('viewer');
const imageActions = useImageActions(imageDTO);
const isStaging = useAppSelector(selectIsStaging);
const isUpscalingEnabled = useFeatureStatus('upscaling');
const isFocusOK = isGalleryFocused || isViewerFocused;
const recallAll = useRecallAll(imageDTO);
const recallRemix = useRecallRemix(imageDTO);
const recallPrompts = useRecallPrompts(imageDTO);
const recallSeed = useRecallSeed(imageDTO);
const recallDimensions = useRecallDimensions(imageDTO);
const loadWorkflow = useLoadWorkflow(imageDTO);
useRegisteredHotkeys({
id: 'loadWorkflow',
category: 'viewer',
callback: imageActions.loadWorkflow,
options: { enabled: isGalleryFocused || isViewerFocused },
dependencies: [imageActions.loadWorkflow, isGalleryFocused, isViewerFocused],
callback: loadWorkflow.load,
options: { enabled: loadWorkflow.isEnabled && isFocusOK },
dependencies: [loadWorkflow, isFocusOK],
});
useRegisteredHotkeys({
id: 'recallAll',
category: 'viewer',
callback: imageActions.recallAll,
options: { enabled: !isStaging && (isGalleryFocused || isViewerFocused) },
dependencies: [imageActions.recallAll, isStaging, isGalleryFocused, isViewerFocused],
callback: recallAll.recall,
options: { enabled: recallAll.isEnabled && isFocusOK },
dependencies: [recallAll, isFocusOK],
});
useRegisteredHotkeys({
id: 'recallSeed',
category: 'viewer',
callback: imageActions.recallSeed,
options: { enabled: isGalleryFocused || isViewerFocused },
dependencies: [imageActions.recallSeed, isGalleryFocused, isViewerFocused],
callback: recallSeed.recall,
options: { enabled: recallSeed.isEnabled && isFocusOK },
dependencies: [recallSeed, isFocusOK],
});
useRegisteredHotkeys({
id: 'recallPrompts',
category: 'viewer',
callback: imageActions.recallPrompts,
options: { enabled: isGalleryFocused || isViewerFocused },
dependencies: [imageActions.recallPrompts, isGalleryFocused, isViewerFocused],
callback: recallPrompts.recall,
options: { enabled: recallPrompts.isEnabled && isFocusOK },
dependencies: [recallPrompts, isFocusOK],
});
useRegisteredHotkeys({
id: 'remix',
category: 'viewer',
callback: imageActions.remix,
options: { enabled: isGalleryFocused || isViewerFocused },
dependencies: [imageActions.remix, isGalleryFocused, isViewerFocused],
callback: recallRemix.recall,
options: { enabled: recallRemix.isEnabled && isFocusOK },
dependencies: [recallRemix, isFocusOK],
});
useRegisteredHotkeys({
id: 'useSize',
category: 'viewer',
callback: imageActions.recallSize,
options: { enabled: !isStaging && (isGalleryFocused || isViewerFocused) },
dependencies: [imageActions.recallSize, isStaging, isGalleryFocused, isViewerFocused],
});
useRegisteredHotkeys({
id: 'runPostprocessing',
category: 'viewer',
callback: imageActions.upscale,
options: { enabled: isUpscalingEnabled && isViewerFocused },
dependencies: [isUpscalingEnabled, imageDTO, isViewerFocused],
callback: recallDimensions.recall,
options: { enabled: recallDimensions.isEnabled && isFocusOK },
dependencies: [recallDimensions, isFocusOK],
});
return null;
});

View File

@@ -42,7 +42,6 @@ import { $socketOptions } from 'services/events/stores';
import type { ManagerOptions, SocketOptions } from 'socket.io-client';
const App = lazy(() => import('./App'));
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
interface Props extends PropsWithChildren {
apiUrl?: string;
@@ -330,9 +329,7 @@ const InvokeAIUI = ({
<React.StrictMode>
<Provider store={store}>
<React.Suspense fallback={<Loading />}>
<ThemeLocaleProvider>
<App config={config} studioInitAction={studioInitAction} />
</ThemeLocaleProvider>
<App config={config} studioInitAction={studioInitAction} />
</React.Suspense>
</Provider>
</React.StrictMode>

View File

@@ -20,6 +20,7 @@ import {
import { $isStylePresetsMenuOpen, activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import { toast } from 'features/toast/toast';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { LAUNCHPAD_PANEL_ID, WORKSPACE_PANEL_ID } from 'features/ui/layouts/shared';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
import { atom } from 'nanostores';
@@ -91,6 +92,7 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
const overrides: Partial<CanvasRasterLayerState> = {
objects: [imageObject],
};
await navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
store.dispatch(canvasReset());
store.dispatch(rasterLayerAdded({ overrides, isSelected: true }));
store.dispatch(sentImageToCanvas());
@@ -157,16 +159,17 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
);
const handleGoToDestination = useCallback(
(destination: StudioDestinationAction['data']['destination']) => {
async (destination: StudioDestinationAction['data']['destination']) => {
switch (destination) {
case 'generation':
// Go to the canvas tab, open the image viewer, and enable send-to-gallery mode
// Go to the generate tab, open the launchpad
await navigationApi.focusPanel('generate', LAUNCHPAD_PANEL_ID);
store.dispatch(paramsReset());
store.dispatch(activeTabCanvasRightPanelChanged('gallery'));
break;
case 'canvas':
// Go to the canvas tab, close the image viewer, and disable send-to-gallery mode
store.dispatch(canvasReset());
// Go to the canvas tab, open the launchpad
await navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
break;
case 'workflows':
// Go to the workflows tab

View File

@@ -1,14 +1,28 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { bboxSyncedToOptimalDimension } from 'features/controlLayers/store/canvasSlice';
import { bboxSyncedToOptimalDimension, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import { modelChanged, syncedToOptimalDimension, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import { selectBboxModelBase } from 'features/controlLayers/store/selectors';
import { refImageModelChanged, selectReferenceImageEntities } from 'features/controlLayers/store/refImagesSlice';
import {
selectAllEntitiesOfType,
selectBboxModelBase,
selectCanvasSlice,
} from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { modelSelected } from 'features/parameters/store/actions';
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { selectGlobalRefImageModels, selectRegionalRefImageModels } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig } from 'services/api/types';
import {
isChatGPT4oModelConfig,
isFluxKontextApiModelConfig,
isFluxKontextModelConfig,
isFluxReduxModelConfig,
} from 'services/api/types';
const log = logger('models');
@@ -25,9 +39,8 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
}
const newModel = result.data;
const newBaseModel = newModel.base;
const didBaseModelChange = state.params.model?.base !== newBaseModel;
const newBase = newModel.base;
const didBaseModelChange = state.params.model?.base !== newBase;
if (didBaseModelChange) {
// we may need to reset some incompatible submodels
@@ -35,7 +48,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
// handle incompatible loras
state.loras.loras.forEach((lora) => {
if (lora.model.base !== newBaseModel) {
if (lora.model.base !== newBase) {
dispatch(loraDeleted({ id: lora.id }));
modelsCleared += 1;
}
@@ -43,20 +56,82 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
// handle incompatible vae
const { vae } = state.params;
if (vae && vae.base !== newBaseModel) {
if (vae && vae.base !== newBase) {
dispatch(vaeSelected(null));
modelsCleared += 1;
}
// handle incompatible controlnets
// state.canvas.present.controlAdapters.entities.forEach((ca) => {
// if (ca.model?.base !== newBaseModel) {
// modelsCleared += 1;
// if (ca.isEnabled) {
// dispatch(entityIsEnabledToggled({ entityIdentifier: { id: ca.id, type: 'control_adapter' } }));
// }
// }
// });
// Handle incompatible reference image models - switch to first compatible model, with some smart logic
// to choose the best available model based on the new main model.
const allRefImageModels = selectGlobalRefImageModels(state).filter(({ base }) => base === newBase);
let newGlobalRefImageModel = null;
// Certain models require the ref image model to be the same as the main model - others just need a matching
// base. Helper to grab the first exact match or the first available model if no exact match is found.
const exactMatchOrFirst = <T extends AnyModelConfig>(candidates: T[]): T | null =>
candidates.find(({ key }) => key === newModel.key) ?? candidates[0] ?? null;
// The only way we can differentiate between FLUX and FLUX Kontext is to check for "kontext" in the name
if (newModel.base === 'flux' && newModel.name.toLowerCase().includes('kontext')) {
const fluxKontextDevModels = allRefImageModels.filter(isFluxKontextModelConfig);
newGlobalRefImageModel = exactMatchOrFirst(fluxKontextDevModels);
} else if (newModel.base === 'chatgpt-4o') {
const chatGPT4oModels = allRefImageModels.filter(isChatGPT4oModelConfig);
newGlobalRefImageModel = exactMatchOrFirst(chatGPT4oModels);
} else if (newModel.base === 'flux-kontext') {
const fluxKontextApiModels = allRefImageModels.filter(isFluxKontextApiModelConfig);
newGlobalRefImageModel = exactMatchOrFirst(fluxKontextApiModels);
} else if (newModel.base === 'flux') {
const fluxReduxModels = allRefImageModels.filter(isFluxReduxModelConfig);
newGlobalRefImageModel = fluxReduxModels[0] ?? null;
} else {
newGlobalRefImageModel = allRefImageModels[0] ?? null;
}
// All ref image entities are updated to use the same new model
const refImageEntities = selectReferenceImageEntities(state);
for (const entity of refImageEntities) {
const shouldUpdateModel =
(entity.config.model && entity.config.model.base !== newBase) ||
(!entity.config.model && newGlobalRefImageModel);
if (shouldUpdateModel) {
dispatch(
refImageModelChanged({
id: entity.id,
modelConfig: newGlobalRefImageModel,
})
);
modelsCleared += 1;
}
}
// For regional guidance, there is no smart logic - we just pick the first available model.
const newRegionalRefImageModel = selectRegionalRefImageModels(state)[0] ?? null;
// All regional guidance entities are updated to use the same new model.
const canvasState = selectCanvasSlice(state);
const canvasRegionalGuidanceEntities = selectAllEntitiesOfType(canvasState, 'regional_guidance');
for (const entity of canvasRegionalGuidanceEntities) {
for (const refImage of entity.referenceImages) {
// Only change the model if the current one is not compatible with the new base model.
const shouldUpdateModel =
(refImage.config.model && refImage.config.model.base !== newBase) ||
(!refImage.config.model && newRegionalRefImageModel);
if (shouldUpdateModel) {
dispatch(
rgRefImageModelChanged({
entityIdentifier: getEntityIdentifier(entity),
referenceImageId: refImage.id,
modelConfig: newRegionalRefImageModel,
})
);
modelsCleared += 1;
}
}
}
if (modelsCleared > 0) {
toast({

View File

@@ -3,6 +3,7 @@ import { isNil } from 'es-toolkit';
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasSlice';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import {
heightChanged,
setCfgRescaleMultiplier,
setCfgScale,
setGuidance,
@@ -10,6 +11,7 @@ import {
setSteps,
vaePrecisionChanged,
vaeSelected,
widthChanged,
} from 'features/controlLayers/store/paramsSlice';
import { setDefaultSettings } from 'features/parameters/store/actions';
import {
@@ -24,6 +26,7 @@ import {
zParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
import { toast } from 'features/toast/toast';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { t } from 'i18next';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import { isNonRefinerMainModelConfig } from 'services/api/types';
@@ -113,15 +116,24 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
const setSizeOptions = { updateAspectRatio: true, clamp: true };
const isStaging = selectIsStaging(getState());
if (!isStaging && width) {
const activeTab = selectActiveTab(getState());
if (activeTab === 'generate') {
if (isParameterWidth(width)) {
dispatch(bboxWidthChanged({ width, ...setSizeOptions }));
dispatch(widthChanged({ width, ...setSizeOptions }));
}
if (isParameterHeight(height)) {
dispatch(heightChanged({ height, ...setSizeOptions }));
}
}
if (!isStaging && height) {
if (isParameterHeight(height)) {
dispatch(bboxHeightChanged({ height, ...setSizeOptions }));
if (activeTab === 'canvas') {
if (!isStaging) {
if (isParameterWidth(width)) {
dispatch(bboxWidthChanged({ width, ...setSizeOptions }));
}
if (isParameterHeight(height)) {
dispatch(bboxHeightChanged({ height, ...setSizeOptions }));
}
}
}

View File

@@ -87,14 +87,10 @@ export const buildGroup = <T extends object>(group: Omit<Group<T>, typeof unique
[uniqueGroupKey]: true,
});
const isGroup = <T extends object>(optionOrGroup: OptionOrGroup<T>): optionOrGroup is Group<T> => {
export const isGroup = <T extends object>(optionOrGroup: OptionOrGroup<T>): optionOrGroup is Group<T> => {
return uniqueGroupKey in optionOrGroup && optionOrGroup[uniqueGroupKey] === true;
};
export const isOption = <T extends object>(optionOrGroup: OptionOrGroup<T>): optionOrGroup is T => {
return !(uniqueGroupKey in optionOrGroup);
};
const DefaultOptionComponent = typedMemo(<T extends object>({ option }: { option: T }) => {
const { getOptionId } = usePickerContext();
return <Text fontWeight="bold">{getOptionId(option)}</Text>;

View File

@@ -6,7 +6,6 @@ import { atom, computed } from 'nanostores';
import type { RefObject } from 'react';
import { useEffect } from 'react';
import { objectKeys } from 'tsafe';
import z from 'zod/v4';
/**
* We need to manage focus regions to conditionally enable hotkeys:
@@ -28,10 +27,7 @@ import z from 'zod/v4';
const log = logger('system');
/**
* The names of the focus regions.
*/
const zFocusRegionName = z.enum([
const REGION_NAMES = [
'launchpad',
'viewer',
'gallery',
@@ -41,13 +37,16 @@ const zFocusRegionName = z.enum([
'workflows',
'progress',
'settings',
]);
export type FocusRegionName = z.infer<typeof zFocusRegionName>;
] as const;
/**
* The names of the focus regions.
*/
export type FocusRegionName = (typeof REGION_NAMES)[number];
/**
* A map of focus regions to the elements that are part of that region.
*/
const REGION_TARGETS: Record<FocusRegionName, Set<HTMLElement>> = zFocusRegionName.options.values().reduce(
const REGION_TARGETS: Record<FocusRegionName, Set<HTMLElement>> = REGION_NAMES.reduce(
(acc, region) => {
acc[region] = new Set<HTMLElement>();
return acc;

View File

@@ -0,0 +1,115 @@
import { useStore } from '@nanostores/react';
import { WrappedError } from 'common/util/result';
import type { Atom } from 'nanostores';
import { atom } from 'nanostores';
import { useCallback, useEffect, useMemo, useState } from 'react';
type SuccessState<T> = {
status: 'success';
value: T;
error: null;
};
type ErrorState = {
status: 'error';
value: null;
error: Error;
};
type PendingState = {
status: 'pending';
value: null;
error: null;
};
type IdleState = {
status: 'idle';
value: null;
error: null;
};
export type State<T> = IdleState | PendingState | SuccessState<T> | ErrorState;
type UseAsyncStateOptions = {
immediate?: boolean;
};
type UseAsyncReturn<T> = {
$state: Atom<State<T>>;
trigger: () => Promise<void>;
reset: () => void;
};
export const useAsyncState = <T>(execute: () => Promise<T>, options?: UseAsyncStateOptions): UseAsyncReturn<T> => {
const $state = useState(() =>
atom<State<T>>({
status: 'idle',
value: null,
error: null,
})
)[0];
const trigger = useCallback(async () => {
$state.set({
status: 'pending',
value: null,
error: null,
});
try {
const value = await execute();
$state.set({
status: 'success',
value,
error: null,
});
} catch (error) {
$state.set({
status: 'error',
value: null,
error: WrappedError.wrap(error),
});
}
}, [$state, execute]);
const reset = useCallback(() => {
$state.set({
status: 'idle',
value: null,
error: null,
});
}, [$state]);
useEffect(() => {
if (options?.immediate) {
trigger();
}
}, [options?.immediate, trigger]);
const api = useMemo(
() =>
({
$state,
trigger,
reset,
}) satisfies UseAsyncReturn<T>,
[$state, trigger, reset]
);
return api;
};
type UseAsyncReturnReactive<T> = {
state: State<T>;
trigger: () => Promise<void>;
reset: () => void;
};
export const useAsyncStateReactive = <T>(
execute: () => Promise<T>,
options?: UseAsyncStateOptions
): UseAsyncReturnReactive<T> => {
const { $state, trigger, reset } = useAsyncState(execute, options);
const state = useStore($state);
return { state, trigger, reset };
};

View File

@@ -0,0 +1,23 @@
import { Alert, AlertIcon, AlertTitle } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { selectSaveAllImagesToGallery } from 'features/controlLayers/store/canvasSettingsSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
export const CanvasAlertsSaveAllImagesToGallery = memo(() => {
const { t } = useTranslation();
const saveAllImagesToGallery = useAppSelector(selectSaveAllImagesToGallery);
if (!saveAllImagesToGallery) {
return null;
}
return (
<Alert status="info" borderRadius="base" fontSize="sm" shadow="md" w="fit-content">
<AlertIcon />
<AlertTitle>{t('controlLayers.settings.saveAllImagesToGallery.alert')}</AlertTitle>
</Alert>
);
});
CanvasAlertsSaveAllImagesToGallery.displayName = 'CanvasAlertsSaveAllImagesToGallery';

View File

@@ -4,13 +4,17 @@ import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useRefImageEntity } from 'features/controlLayers/components/RefImage/useRefImageEntity';
import { useRefImageIdContext } from 'features/controlLayers/contexts/RefImageIdContext';
import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice';
import {
refImageDeleted,
refImageIsEnabledToggled,
selectRefImageEntityIds,
} from 'features/controlLayers/store/refImagesSlice';
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
import { memo, useCallback, useMemo } from 'react';
import { PiCircleBold, PiCircleFill, PiTrashBold } from 'react-icons/pi';
import { PiCircleBold, PiCircleFill, PiTrashBold, PiWarningBold } from 'react-icons/pi';
import { RefImageWarningTooltipContent } from './RefImageWarningTooltipContent';
const textSx: SystemStyleObject = {
color: 'base.300',
@@ -28,6 +32,12 @@ export const RefImageHeader = memo(() => {
);
const refImageNumber = useAppSelector(selectRefImageNumber);
const entity = useRefImageEntity(id);
const mainModelConfig = useAppSelector(selectMainModelConfig);
const warnings = useMemo(() => {
return getGlobalReferenceImageWarnings(entity, mainModelConfig);
}, [entity, mainModelConfig]);
const deleteRefImage = useCallback(() => {
dispatch(refImageDeleted({ id }));
}, [dispatch, id]);
@@ -42,6 +52,18 @@ export const RefImageHeader = memo(() => {
Reference Image #{refImageNumber}
</Text>
<Flex alignItems="center" gap={1}>
{warnings.length > 0 && (
<IconButton
as="span"
size="sm"
variant="link"
alignSelf="stretch"
aria-label="warnings"
tooltip={<RefImageWarningTooltipContent warnings={warnings} />}
icon={<PiWarningBold />}
colorScheme="warning"
/>
)}
{!entity.isEnabled && (
<Text fontSize="xs" fontStyle="italic" color="base.400">
Disabled

View File

@@ -61,7 +61,7 @@ export const RefImageImage = memo(
)}
{imageDTO && (
<>
<DndImage imageDTO={imageDTO} borderWidth={1} borderStyle="solid" w="full" />
<DndImage imageDTO={imageDTO} borderRadius="base" borderWidth={1} borderStyle="solid" w="full" />
<Flex position="absolute" flexDir="column" top={2} insetInlineEnd={2} gap={1}>
<DndImageIcon
onClick={handleResetControlImage}

View File

@@ -1,9 +1,12 @@
import { Button, Collapse, Divider, Flex } from '@invoke-ai/ui-library';
import { Button, Collapse, Divider, Flex, IconButton } from '@invoke-ai/ui-library';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { RefImagePreview } from 'features/controlLayers/components/RefImage/RefImagePreview';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { RefImageIdContext } from 'features/controlLayers/contexts/RefImageIdContext';
import { getDefaultRefImageConfig } from 'features/controlLayers/hooks/addLayerHooks';
import { useNewGlobalReferenceImageFromBbox } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusySafe } from 'features/controlLayers/hooks/useCanvasIsBusy';
import {
refImageAdded,
selectIsRefImagePanelOpen,
@@ -13,8 +16,10 @@ import {
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { addGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo, useMemo } from 'react';
import { PiUploadBold } from 'react-icons/pi';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold, PiUploadBold } from 'react-icons/pi';
import type { ImageDTO } from 'services/api/types';
import { RefImageHeader } from './RefImageHeader';
@@ -78,6 +83,7 @@ MaxRefImages.displayName = 'MaxRefImages';
const AddRefImageDropTargetAndButton = memo(() => {
const { dispatch, getState } = useAppStore();
const tab = useAppSelector(selectActiveTab);
const uploadOptions = useMemo(
() =>
@@ -95,7 +101,7 @@ const AddRefImageDropTargetAndButton = memo(() => {
const uploadApi = useImageUploadButton(uploadOptions);
return (
<>
<Flex gap={1} h="full" w="full">
<Button
position="relative"
size="sm"
@@ -112,7 +118,31 @@ const AddRefImageDropTargetAndButton = memo(() => {
<input {...uploadApi.getUploadInputProps()} />
<DndDropTarget label="Drop" dndTarget={addGlobalReferenceImageDndTarget} dndTargetData={dndTargetData} />
</Button>
</>
{tab === 'canvas' && (
<CanvasManagerProviderGate>
<BboxButton />
</CanvasManagerProviderGate>
)}
</Flex>
);
});
const BboxButton = memo(() => {
const { t } = useTranslation();
const isBusy = useCanvasIsBusySafe();
const newGlobalReferenceImageFromBbox = useNewGlobalReferenceImageFromBbox();
return (
<IconButton
size="lg"
variant="outline"
h="full"
icon={<PiBoundingBoxBold />}
onClick={newGlobalReferenceImageFromBbox}
isDisabled={isBusy}
aria-label={t('controlLayers.pullBboxIntoReferenceImage')}
tooltip={t('controlLayers.pullBboxIntoReferenceImage')}
/>
);
});
AddRefImageDropTargetAndButton.displayName = 'AddRefImageDropTargetAndButton';

View File

@@ -1,5 +1,5 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, Icon, IconButton, Image, Skeleton, Text } from '@invoke-ai/ui-library';
import { Flex, Icon, IconButton, Image, Skeleton, Text, Tooltip } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { round } from 'es-toolkit/compat';
@@ -17,6 +17,8 @@ import { memo, useCallback, useEffect, useMemo, useState } from 'react';
import { PiExclamationMarkBold, PiEyeSlashBold, PiImageBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { RefImageWarningTooltipContent } from './RefImageWarningTooltipContent';
const baseSx: SystemStyleObject = {
'&[data-is-open="true"]': {
borderColor: 'invokeBlue.300',
@@ -51,9 +53,6 @@ const getImageSxWithWeight = (weight: number): SystemStyleObject => {
return {
...baseSx,
'&[data-is-disabled="true"]': {
opacity: 0.4,
},
_after: {
content: '""',
position: 'absolute',
@@ -95,8 +94,8 @@ export const RefImagePreview = memo(() => {
};
}, [entity.config]);
const isInvalid = useMemo(() => {
return getGlobalReferenceImageWarnings(entity, mainModelConfig).length > 0;
const warnings = useMemo(() => {
return getGlobalReferenceImageWarnings(entity, mainModelConfig);
}, [entity, mainModelConfig]);
const onClick = useCallback(() => {
@@ -126,74 +125,76 @@ export const RefImagePreview = memo(() => {
);
}
return (
<Flex
position="relative"
borderWidth={1}
borderStyle="solid"
borderRadius="base"
aspectRatio="1/1"
maxW="full"
maxH="full"
flexShrink={0}
sx={sx}
data-is-open={selectedEntityId === id && isPanelOpen}
data-is-error={isInvalid}
data-is-disabled={!entity.isEnabled}
role="button"
onClick={onClick}
cursor="pointer"
>
<Image
src={imageDTO?.thumbnail_url}
objectFit="contain"
<Tooltip label={warnings.length > 0 ? <RefImageWarningTooltipContent warnings={warnings} /> : undefined}>
<Flex
position="relative"
borderWidth={1}
borderStyle="solid"
borderRadius="base"
aspectRatio="1/1"
height={imageDTO?.height}
fallback={<Skeleton h="full" aspectRatio="1/1" />}
maxW="full"
maxH="full"
borderRadius="base"
/>
{isIPAdapterConfig(entity.config) && (
<Flex
position="absolute"
inset={0}
fontWeight="semibold"
alignItems="center"
justifyContent="center"
zIndex={1}
data-visible={showWeightDisplay}
sx={weightDisplaySx}
>
<Text filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))">
{`${round(entity.config.weight * 100, 2)}%`}
</Text>
</Flex>
)}
{!entity.isEnabled && (
<Icon
position="absolute"
top="50%"
left="50%"
transform="translateX(-50%) translateY(-50%)"
filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))"
color="base.300"
boxSize={8}
as={PiEyeSlashBold}
flexShrink={0}
sx={sx}
data-is-open={selectedEntityId === id && isPanelOpen}
data-is-error={warnings.length > 0}
data-is-disabled={!entity.isEnabled}
role="button"
onClick={onClick}
cursor="pointer"
overflow="hidden"
>
<Image
src={imageDTO?.thumbnail_url}
objectFit="contain"
aspectRatio="1/1"
height={imageDTO?.height}
fallback={<Skeleton h="full" aspectRatio="1/1" />}
maxW="full"
maxH="full"
/>
)}
{entity.isEnabled && isInvalid && (
<Icon
position="absolute"
top="50%"
left="50%"
transform="translateX(-50%) translateY(-50%)"
filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))"
color="error.500"
boxSize={12}
as={PiExclamationMarkBold}
/>
)}
</Flex>
{isIPAdapterConfig(entity.config) && (
<Flex
position="absolute"
inset={0}
fontWeight="semibold"
alignItems="center"
justifyContent="center"
zIndex={1}
data-visible={showWeightDisplay}
sx={weightDisplaySx}
>
<Text filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))">
{`${round(entity.config.weight * 100, 2)}%`}
</Text>
</Flex>
)}
{!entity.isEnabled && (
<Icon
position="absolute"
top="50%"
left="50%"
transform="translateX(-50%) translateY(-50%)"
filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))"
color="base.300"
boxSize={8}
as={PiEyeSlashBold}
/>
)}
{entity.isEnabled && warnings.length > 0 && (
<Icon
position="absolute"
top="50%"
left="50%"
transform="translateX(-50%) translateY(-50%)"
filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))"
color="error.500"
boxSize={12}
as={PiExclamationMarkBold}
/>
)}
</Flex>
</Tooltip>
);
});
RefImagePreview.displayName = 'RefImagePreview';

View File

@@ -0,0 +1,18 @@
import { Flex, ListItem, Text, UnorderedList } from '@invoke-ai/ui-library';
import { upperFirst } from 'es-toolkit/compat';
import { useTranslation } from 'react-i18next';
export const RefImageWarningTooltipContent = ({ warnings }: { warnings: string[] }) => {
const { t } = useTranslation();
return (
<Flex flexDir="column">
<Text fontWeight="semibold">Invalid Reference Image:</Text>
<UnorderedList>
{warnings.map((tKey) => (
<ListItem key={tKey}>{upperFirst(t(tKey))}</ListItem>
))}
</UnorderedList>
</Flex>
);
};

View File

@@ -26,6 +26,7 @@ import { CanvasSettingsPreserveMaskCheckbox } from 'features/controlLayers/compo
import { CanvasSettingsPressureSensitivityCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsPressureSensitivity';
import { CanvasSettingsRecalculateRectsButton } from 'features/controlLayers/components/Settings/CanvasSettingsRecalculateRectsButton';
import { CanvasSettingsRuleOfThirdsSwitch } from 'features/controlLayers/components/Settings/CanvasSettingsRuleOfThirdsGuideSwitch';
import { CanvasSettingsSaveAllImagesToGalleryCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsSaveAllImagesToGalleryCheckbox';
import { CanvasSettingsShowHUDSwitch } from 'features/controlLayers/components/Settings/CanvasSettingsShowHUDSwitch';
import { CanvasSettingsShowProgressOnCanvas } from 'features/controlLayers/components/Settings/CanvasSettingsShowProgressOnCanvasSwitch';
import { memo } from 'react';
@@ -61,6 +62,7 @@ export const CanvasSettingsPopover = memo(() => {
<CanvasSettingsPreserveMaskCheckbox />
<CanvasSettingsClipToBboxCheckbox />
<CanvasSettingsOutputOnlyMaskedRegionsCheckbox />
<CanvasSettingsSaveAllImagesToGalleryCheckbox />
</Flex>
<Divider />

View File

@@ -0,0 +1,25 @@
import { Checkbox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
selectSaveAllImagesToGallery,
settingsSaveAllImagesToGalleryToggled,
} from 'features/controlLayers/store/canvasSettingsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export const CanvasSettingsSaveAllImagesToGalleryCheckbox = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const saveAllImagesToGallery = useAppSelector(selectSaveAllImagesToGallery);
const onChange = useCallback(() => {
dispatch(settingsSaveAllImagesToGalleryToggled());
}, [dispatch]);
return (
<FormControl w="full">
<FormLabel flexGrow={1}>{t('controlLayers.saveAllImagesToGallery')}</FormLabel>
<Checkbox isChecked={saveAllImagesToGallery} onChange={onChange} />
</FormControl>
);
});
CanvasSettingsSaveAllImagesToGalleryCheckbox.displayName = 'CanvasSettingsSaveAllImagesToGalleryCheckbox';

View File

@@ -1,4 +1,4 @@
import { Button, Flex, Grid, Heading, Text } from '@invoke-ai/ui-library';
import { Button, Flex, Grid, Text } from '@invoke-ai/ui-library';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { WORKSPACE_PANEL_ID } from 'features/ui/layouts/shared';
import { memo, useCallback } from 'react';
@@ -6,6 +6,7 @@ import { useTranslation } from 'react-i18next';
import { InitialStateMainModelPicker } from './InitialStateMainModelPicker';
import { LaunchpadAddStyleReference } from './LaunchpadAddStyleReference';
import { LaunchpadContainer } from './LaunchpadContainer';
import { LaunchpadEditImageButton } from './LaunchpadEditImageButton';
import { LaunchpadGenerateFromTextButton } from './LaunchpadGenerateFromTextButton';
import { LaunchpadUseALayoutImageButton } from './LaunchpadUseALayoutImageButton';
@@ -16,35 +17,30 @@ export const CanvasLaunchpadPanel = memo(() => {
navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
}, []);
return (
<Flex flexDir="column" h="full" w="full" alignItems="center" gap={2}>
<Flex flexDir="column" w="full" gap={4} px={14} maxW={768} pt="20vh">
<Heading mb={4}>{t('ui.launchpad.canvasTitle')}</Heading>
<Flex flexDir="column" gap={8}>
<Grid gridTemplateColumns="1fr 1fr" gap={8}>
<InitialStateMainModelPicker />
<Flex flexDir="column" gap={2} justifyContent="center">
<Text>
{t('ui.launchpad.modelGuideText')}{' '}
<Button
as="a"
variant="link"
href="https://support.invoke.ai/support/solutions/articles/151000216086-model-guide"
target="_blank"
rel="noopener noreferrer"
size="sm"
>
{t('ui.launchpad.modelGuideLink')}
</Button>
</Text>
</Flex>
</Grid>
<LaunchpadGenerateFromTextButton extraAction={focusCanvas} />
<LaunchpadAddStyleReference extraAction={focusCanvas} />
<LaunchpadEditImageButton extraAction={focusCanvas} />
<LaunchpadUseALayoutImageButton extraAction={focusCanvas} />
<LaunchpadContainer heading={t('ui.launchpad.canvasTitle')}>
<Grid gridTemplateColumns="1fr 1fr" gap={8}>
<InitialStateMainModelPicker />
<Flex flexDir="column" gap={2} justifyContent="center">
<Text>
{t('ui.launchpad.modelGuideText')}{' '}
<Button
as="a"
variant="link"
href="https://support.invoke.ai/support/solutions/articles/151000216086-model-guide"
target="_blank"
rel="noopener noreferrer"
size="sm"
>
{t('ui.launchpad.modelGuideLink')}
</Button>
</Text>
</Flex>
</Flex>
</Flex>
</Grid>
<LaunchpadGenerateFromTextButton extraAction={focusCanvas} />
<LaunchpadAddStyleReference extraAction={focusCanvas} />
<LaunchpadEditImageButton extraAction={focusCanvas} />
<LaunchpadUseALayoutImageButton extraAction={focusCanvas} />
</LaunchpadContainer>
);
});
CanvasLaunchpadPanel.displayName = 'CanvasLaunchpadPanel';

View File

@@ -1,9 +1,10 @@
import { Alert, Button, Flex, Grid, Heading, Text } from '@invoke-ai/ui-library';
import { Alert, Button, Flex, Grid, Text } from '@invoke-ai/ui-library';
import { InitialStateMainModelPicker } from 'features/controlLayers/components/SimpleSession/InitialStateMainModelPicker';
import { LaunchpadAddStyleReference } from 'features/controlLayers/components/SimpleSession/LaunchpadAddStyleReference';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { memo, useCallback } from 'react';
import { LaunchpadContainer } from './LaunchpadContainer';
import { LaunchpadGenerateFromTextButton } from './LaunchpadGenerateFromTextButton';
export const GenerateLaunchpadPanel = memo(() => {
@@ -12,41 +13,36 @@ export const GenerateLaunchpadPanel = memo(() => {
}, []);
return (
<Flex flexDir="column" h="full" w="full" alignItems="center" gap={2}>
<Flex flexDir="column" w="full" gap={4} px={14} maxW={768} pt="20vh">
<Heading mb={4}>Generate images from text prompts.</Heading>
<Flex flexDir="column" gap={8}>
<Grid gridTemplateColumns="1fr 1fr" gap={8}>
<InitialStateMainModelPicker />
<Flex flexDir="column" gap={2} justifyContent="center">
<Text>
Want to learn what prompts work best for each model?{' '}
<Button
as="a"
variant="link"
href="https://support.invoke.ai/support/solutions/articles/151000216086-model-guide"
target="_blank"
rel="noopener noreferrer"
size="sm"
>
Check out our Model Guide.
</Button>
</Text>
</Flex>
</Grid>
<LaunchpadGenerateFromTextButton />
<LaunchpadAddStyleReference />
<Alert status="info" borderRadius="base" flexDir="column" gap={2} overflow="unset">
<Text fontSize="md" fontWeight="semibold">
Looking to get more control, edit, and iterate on your images?
</Text>
<Button variant="link" onClick={newCanvasSession}>
Navigate to Canvas for more capabilities.
<LaunchpadContainer heading="Generate images from text prompts.">
<Grid gridTemplateColumns="1fr 1fr" gap={8}>
<InitialStateMainModelPicker />
<Flex flexDir="column" gap={2} justifyContent="center">
<Text>
Want to learn what prompts work best for each model?{' '}
<Button
as="a"
variant="link"
href="https://support.invoke.ai/support/solutions/articles/151000216086-model-guide"
target="_blank"
rel="noopener noreferrer"
size="sm"
>
Check out our Model Guide.
</Button>
</Alert>
</Text>
</Flex>
</Flex>
</Flex>
</Grid>
<LaunchpadGenerateFromTextButton />
<LaunchpadAddStyleReference />
<Alert status="info" borderRadius="base" flexDir="column" gap={2} overflow="unset">
<Text fontSize="md" fontWeight="semibold">
Looking to get more control, edit, and iterate on your images?
</Text>
<Button variant="link" onClick={newCanvasSession}>
Navigate to Canvas for more capabilities.
</Button>
</Alert>
</LaunchpadContainer>
);
});
GenerateLaunchpadPanel.displayName = 'GenerateLaunchpad';

View File

@@ -0,0 +1,17 @@
import { Flex, Heading } from '@invoke-ai/ui-library';
import type { PropsWithChildren } from 'react';
import { memo } from 'react';
export const LaunchpadContainer = memo((props: PropsWithChildren<{ heading: string }>) => {
return (
<Flex flexDir="column" h="full" w="full" alignItems="center" justifyContent="center" gap={2}>
<Flex flexDir="column" w="full" gap={4} px={14} maxW={768}>
<Heading>{props.heading}</Heading>
<Flex flexDir="column" gap={4}>
{props.children}
</Flex>
</Flex>
</Flex>
);
});
LaunchpadContainer.displayName = 'LaunchpadContainer';

View File

@@ -1,5 +1,6 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
useCanvasSessionContext,
useOutputImageDTO,
@@ -10,6 +11,10 @@ import { QueueItemNumber } from 'features/controlLayers/components/SimpleSession
import { QueueItemProgressImage } from 'features/controlLayers/components/SimpleSession/QueueItemProgressImage';
import { QueueItemStatusLabel } from 'features/controlLayers/components/SimpleSession/QueueItemStatusLabel';
import { getQueueItemElementId } from 'features/controlLayers/components/SimpleSession/shared';
import {
selectStagingAreaAutoSwitch,
settingsStagingAreaAutoSwitchChanged,
} from 'features/controlLayers/store/canvasSettingsSlice';
import { DndImage } from 'features/dnd/DndImage';
import { toast } from 'features/toast/toast';
import { memo, useCallback } from 'react';
@@ -21,12 +26,13 @@ const sx = {
pos: 'relative',
alignItems: 'center',
justifyContent: 'center',
h: 108,
w: 108,
flexShrink: 0,
h: 'full',
aspectRatio: '1/1',
borderWidth: 2,
borderRadius: 'base',
bg: 'base.900',
overflow: 'hidden',
'&[data-selected="true"]': {
borderColor: 'invokeBlue.300',
},
@@ -34,28 +40,29 @@ const sx = {
type Props = {
item: S['SessionQueueItem'];
number: number;
index: number;
isSelected: boolean;
};
export const QueueItemPreviewMini = memo(({ item, isSelected, number }: Props) => {
export const QueueItemPreviewMini = memo(({ item, isSelected, index }: Props) => {
const dispatch = useAppDispatch();
const ctx = useCanvasSessionContext();
const { imageLoaded } = useProgressData(ctx.$progressData, item.item_id);
const imageDTO = useOutputImageDTO(item);
const autoSwitch = useAppSelector(selectStagingAreaAutoSwitch);
const onClick = useCallback(() => {
ctx.$selectedItemId.set(item.item_id);
}, [ctx.$selectedItemId, item.item_id]);
const onDoubleClick = useCallback(() => {
const autoSwitch = ctx.$autoSwitch.get();
if (autoSwitch !== 'off') {
ctx.$autoSwitch.set('off');
dispatch(settingsStagingAreaAutoSwitchChanged('off'));
toast({
title: 'Auto-Switch Disabled',
});
}
}, [ctx.$autoSwitch]);
}, [autoSwitch, dispatch]);
const onLoad = useCallback(() => {
ctx.onImageLoad(item.item_id);
@@ -63,16 +70,16 @@ export const QueueItemPreviewMini = memo(({ item, isSelected, number }: Props) =
return (
<Flex
id={getQueueItemElementId(item.item_id)}
id={getQueueItemElementId(index)}
sx={sx}
data-selected={isSelected}
onClick={onClick}
onDoubleClick={onDoubleClick}
>
<QueueItemStatusLabel item={item} position="absolute" margin="auto" />
{imageDTO && <DndImage imageDTO={imageDTO} onLoad={onLoad} asThumbnail />}
{imageDTO && <DndImage imageDTO={imageDTO} onLoad={onLoad} asThumbnail position="absolute" />}
{!imageLoaded && <QueueItemProgressImage itemId={item.item_id} position="absolute" />}
<QueueItemNumber number={number} position="absolute" top={0} left={1} />
<QueueItemNumber number={index + 1} position="absolute" top={0} left={1} />
<QueueItemCircularProgress itemId={item.item_id} status={item.status} position="absolute" top={1} right={2} />
</Flex>
);

View File

@@ -16,21 +16,21 @@ export const QueueItemStatusLabel = memo(({ item, ...rest }: Props) => {
if (item.status === 'pending') {
return (
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="base.300" {...rest}>
<Text fontSize="xs" pointerEvents="none" userSelect="none" fontWeight="semibold" color="base.300" {...rest}>
Pending
</Text>
);
}
if (item.status === 'canceled') {
return (
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="warning.300" {...rest}>
<Text fontSize="xs" pointerEvents="none" userSelect="none" fontWeight="semibold" color="warning.300" {...rest}>
Canceled
</Text>
);
}
if (item.status === 'failed') {
return (
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="error.300" {...rest}>
<Text fontSize="xs" pointerEvents="none" userSelect="none" fontWeight="semibold" color="error.300" {...rest}>
Failed
</Text>
);
@@ -38,7 +38,7 @@ export const QueueItemStatusLabel = memo(({ item, ...rest }: Props) => {
if (item.status === 'in_progress') {
return (
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="invokeBlue.300" {...rest}>
<Text fontSize="xs" pointerEvents="none" userSelect="none" fontWeight="semibold" color="invokeBlue.300" {...rest}>
In Progress
</Text>
);
@@ -46,7 +46,14 @@ export const QueueItemStatusLabel = memo(({ item, ...rest }: Props) => {
if (item.status === 'completed') {
return (
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="invokeGreen.300" {...rest}>
<Text
fontSize="xs"
pointerEvents="none"
userSelect="none"
fontWeight="semibold"
color="invokeGreen.300"
{...rest}
>
Completed
</Text>
);

View File

@@ -1,17 +1,149 @@
import { Flex } from '@invoke-ai/ui-library';
import { Box, Flex, forwardRef } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { logger } from 'app/logging/logger';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { QueueItemPreviewMini } from 'features/controlLayers/components/SimpleSession/QueueItemPreviewMini';
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { memo, useEffect } from 'react';
import { useOverlayScrollbars } from 'overlayscrollbars-react';
import type { CSSProperties, RefObject } from 'react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import type { Components, ItemContent, ListRange, VirtuosoHandle, VirtuosoProps } from 'react-virtuoso';
import { Virtuoso } from 'react-virtuoso';
import type { S } from 'services/api/types';
import { getQueueItemElementId } from './shared';
const log = logger('system');
const virtuosoStyles = {
width: '100%',
height: '72px',
} satisfies CSSProperties;
type VirtuosoContext = { selectedItemId: number | null };
/**
* Scroll the item at the given index into view if it is not currently visible.
*/
const scrollIntoView = (
targetIndex: number,
rootEl: HTMLDivElement,
virtuosoHandle: VirtuosoHandle,
range: ListRange
) => {
if (range.endIndex === 0) {
// No range is rendered; no need to scroll to anything.
return;
}
const targetItem = rootEl.querySelector(`#${getQueueItemElementId(targetIndex)}`);
if (!targetItem) {
if (targetIndex > range.endIndex) {
virtuosoHandle.scrollToIndex({
index: targetIndex,
behavior: 'auto',
align: 'end',
});
} else if (targetIndex < range.startIndex) {
virtuosoHandle.scrollToIndex({
index: targetIndex,
behavior: 'auto',
align: 'start',
});
} else {
log.debug(
`Unable to find queue item at index ${targetIndex} but it is in the rendered range ${range.startIndex}-${range.endIndex}`
);
}
return;
}
// We found the image in the DOM, but it might be in the overscan range - rendered but not in the visible viewport.
// Check if it is in the viewport and scroll if necessary.
const itemRect = targetItem.getBoundingClientRect();
const rootRect = rootEl.getBoundingClientRect();
if (itemRect.left < rootRect.left) {
virtuosoHandle.scrollToIndex({
index: targetIndex,
behavior: 'auto',
align: 'start',
});
} else if (itemRect.right > rootRect.right) {
virtuosoHandle.scrollToIndex({
index: targetIndex,
behavior: 'auto',
align: 'end',
});
} else {
// Image is already in view
}
return;
};
const useScrollableStagingArea = (rootRef: RefObject<HTMLDivElement>) => {
const [scroller, scrollerRef] = useState<HTMLElement | null>(null);
const [initialize, osInstance] = useOverlayScrollbars({
defer: true,
events: {
initialized(osInstance) {
// force overflow styles
const { viewport } = osInstance.elements();
viewport.style.overflowX = `var(--os-viewport-overflow-x)`;
viewport.style.overflowY = `var(--os-viewport-overflow-y)`;
viewport.style.textAlign = 'center';
},
},
options: {
scrollbars: {
visibility: 'auto',
autoHide: 'scroll',
autoHideDelay: 1300,
theme: 'os-theme-dark',
},
overflow: {
y: 'hidden',
x: 'scroll',
},
},
});
useEffect(() => {
const { current: root } = rootRef;
if (scroller && root) {
initialize({
target: root,
elements: {
viewport: scroller,
},
});
}
return () => {
osInstance()?.destroy();
};
}, [scroller, initialize, osInstance, rootRef]);
return scrollerRef;
};
export const StagingAreaItemsList = memo(() => {
const canvasManager = useCanvasManagerSafe();
const ctx = useCanvasSessionContext();
const virtuosoRef = useRef<VirtuosoHandle>(null);
const rangeRef = useRef<ListRange>({ startIndex: 0, endIndex: 0 });
const rootRef = useRef<HTMLDivElement>(null);
const items = useStore(ctx.$items);
const selectedItemId = useStore(ctx.$selectedItemId);
const context = useMemo(() => ({ selectedItemId }), [selectedItemId]);
const scrollerRef = useScrollableStagingArea(rootRef);
useEffect(() => {
if (!canvasManager) {
return;
@@ -20,19 +152,64 @@ export const StagingAreaItemsList = memo(() => {
return canvasManager.stagingArea.connectToSession(ctx.$selectedItemId, ctx.$progressData, ctx.$isPending);
}, [canvasManager, ctx.$progressData, ctx.$selectedItemId, ctx.$isPending]);
useEffect(() => {
return ctx.$selectedItemIndex.listen((index) => {
if (!virtuosoRef.current) {
return;
}
if (!rootRef.current) {
return;
}
if (index === null) {
return;
}
scrollIntoView(index, rootRef.current, virtuosoRef.current, rangeRef.current);
});
}, [ctx.$selectedItemIndex]);
const onRangeChanged = useCallback((range: ListRange) => {
rangeRef.current = range;
}, []);
return (
<ScrollableContent overflowX="scroll" overflowY="hidden">
<Flex gap={2} w="full" h="full" justifyContent="safe center">
{items.map((item, i) => (
<QueueItemPreviewMini
key={`${item.item_id}-mini`}
item={item}
number={i + 1}
isSelected={selectedItemId === item.item_id}
/>
))}
</Flex>
</ScrollableContent>
<Box data-overlayscrollbars-initialize="" ref={rootRef} position="relative" w="full" h="full">
<Virtuoso<S['SessionQueueItem'], VirtuosoContext>
ref={virtuosoRef}
context={context}
data={items}
horizontalDirection
style={virtuosoStyles}
itemContent={itemContent}
components={components}
rangeChanged={onRangeChanged}
// Virtuoso expects the ref to be of HTMLElement | null | Window, but overlayscrollbars doesn't allow Window
scrollerRef={scrollerRef as VirtuosoProps<S['SessionQueueItem'], VirtuosoContext>['scrollerRef']}
/>
</Box>
);
});
StagingAreaItemsList.displayName = 'StagingAreaItemsList';
const itemContent: ItemContent<S['SessionQueueItem'], VirtuosoContext> = (index, item, { selectedItemId }) => (
<QueueItemPreviewMini
key={`${item.item_id}-mini`}
item={item}
index={index}
isSelected={selectedItemId === item.item_id}
/>
);
const listSx = {
'& > * + *': {
pl: 2,
},
};
const components: Components<S['SessionQueueItem'], VirtuosoContext> = {
List: forwardRef(({ context: _, ...rest }, ref) => {
return <Flex ref={ref} sx={listSx} {...rest} />;
}),
};

View File

@@ -24,6 +24,7 @@ import {
import type { ImageDTO } from 'services/api/types';
import { LaunchpadButton } from './LaunchpadButton';
import { LaunchpadContainer } from './LaunchpadContainer';
export const UpscalingLaunchpadPanel = memo(() => {
const { t } = useTranslation();
@@ -65,108 +66,104 @@ export const UpscalingLaunchpadPanel = memo(() => {
}, [dispatch]);
return (
<Flex flexDir="column" h="full" w="full" alignItems="center" gap={2}>
<Flex flexDir="column" w="full" gap={8} px={14} maxW={768} pt="20vh">
<Heading>{t('ui.launchpad.upscalingTitle')}</Heading>
{/* Upload Area */}
<LaunchpadButton {...uploadApi.getUploadButtonProps()} position="relative" gap={8}>
{!upscaleInitialImage ? (
<>
<Icon as={PiImageBold} boxSize={8} color="base.500" />
<Flex flexDir="column" alignItems="flex-start" gap={2}>
<Heading size="sm">{t('ui.launchpad.upscaling.uploadImage.title')}</Heading>
<Text color="base.300">{t('ui.launchpad.upscaling.uploadImage.description')}</Text>
</Flex>
<Flex position="absolute" right={3} bottom={3}>
<PiUploadBold />
<input {...uploadApi.getUploadInputProps()} />
</Flex>
</>
) : (
<>
<Icon as={PiImageBold} boxSize={8} color="base.500" />
<Flex flexDir="column" alignItems="flex-start" gap={2}>
<Heading size="sm">{t('ui.launchpad.upscaling.replaceImage.title')}</Heading>
<Text color="base.300">{t('ui.launchpad.upscaling.replaceImage.description')}</Text>
</Flex>
<Flex position="absolute" right={3} bottom={3}>
<PiUploadBold />
<input {...uploadApi.getUploadInputProps()} />
</Flex>
</>
)}
<DndDropTarget
dndTarget={setUpscaleInitialImageDndTarget}
dndTargetData={dndTargetData}
label={t('gallery.drop')}
/>
</LaunchpadButton>
{/* Guidance text */}
{upscaleInitialImage && (
<Flex bg="base.800" p={4} borderRadius="base" border="1px solid" borderColor="base.700">
<Text variant="subtext" fontSize="sm" lineHeight="1.6">
<strong>{t('ui.launchpad.upscaling.readyToUpscale.title')}</strong>{' '}
{t('ui.launchpad.upscaling.readyToUpscale.description')}
</Text>
</Flex>
<LaunchpadContainer heading={t('ui.launchpad.upscalingTitle')}>
{/* Upload Area */}
<LaunchpadButton {...uploadApi.getUploadButtonProps()} position="relative" gap={8}>
{!upscaleInitialImage ? (
<>
<Icon as={PiImageBold} boxSize={8} color="base.500" />
<Flex flexDir="column" alignItems="flex-start" gap={2}>
<Heading size="sm">{t('ui.launchpad.upscaling.uploadImage.title')}</Heading>
<Text color="base.300">{t('ui.launchpad.upscaling.uploadImage.description')}</Text>
</Flex>
<Flex position="absolute" right={3} bottom={3}>
<PiUploadBold />
<input {...uploadApi.getUploadInputProps()} />
</Flex>
</>
) : (
<>
<Icon as={PiImageBold} boxSize={8} color="base.500" />
<Flex flexDir="column" alignItems="flex-start" gap={2}>
<Heading size="sm">{t('ui.launchpad.upscaling.replaceImage.title')}</Heading>
<Text color="base.300">{t('ui.launchpad.upscaling.replaceImage.description')}</Text>
</Flex>
<Flex position="absolute" right={3} bottom={3}>
<PiUploadBold />
<input {...uploadApi.getUploadInputProps()} />
</Flex>
</>
)}
<DndDropTarget
dndTarget={setUpscaleInitialImageDndTarget}
dndTargetData={dndTargetData}
label={t('gallery.drop')}
/>
</LaunchpadButton>
{/* Controls */}
<Grid gridTemplateColumns="1fr 1fr" gap={8} alignItems="start">
{/* Left Column: Creativity and Structural Defaults */}
<Box>
<Text fontWeight="semibold" fontSize="sm" mb={3}>
Creativity & Structure Defaults
</Text>
<ButtonGroup size="sm" orientation="vertical" variant="outline" w="full">
<Button
colorScheme={creativity === -5 && structure === 5 ? 'invokeBlue' : undefined}
justifyContent="center"
onClick={onConservativeClick}
leftIcon={<PiShieldCheckBold />}
>
Conservative
</Button>
<Button
colorScheme={creativity === 0 && structure === 0 ? 'invokeBlue' : undefined}
justifyContent="center"
onClick={onBalancedClick}
leftIcon={<PiScalesBold />}
>
Balanced
</Button>
<Button
colorScheme={creativity === 5 && structure === -2 ? 'invokeBlue' : undefined}
justifyContent="center"
onClick={onCreativeClick}
leftIcon={<PiPaletteBold />}
>
Creative
</Button>
<Button
colorScheme={creativity === 8 && structure === -5 ? 'invokeBlue' : undefined}
justifyContent="center"
onClick={onArtisticClick}
leftIcon={<PiSparkleBold />}
>
Artistic
</Button>
</ButtonGroup>
</Box>
{/* Right Column: Description/help text */}
<Box>
<Text variant="subtext" fontSize="sm" lineHeight="1.6">
{t('ui.launchpad.upscaling.helpText.promptAdvice')}
</Text>
<Text variant="subtext" fontSize="sm" lineHeight="1.6" mt={3}>
{t('ui.launchpad.upscaling.helpText.styleAdvice')}
</Text>
</Box>
</Grid>
</Flex>
</Flex>
{/* Guidance text */}
{upscaleInitialImage && (
<Flex bg="base.800" p={4} borderRadius="base" border="1px solid" borderColor="base.700">
<Text variant="subtext" fontSize="sm" lineHeight="1.6">
<strong>{t('ui.launchpad.upscaling.readyToUpscale.title')}</strong>{' '}
{t('ui.launchpad.upscaling.readyToUpscale.description')}
</Text>
</Flex>
)}
{/* Controls */}
<Grid gridTemplateColumns="1fr 1fr" gap={8} alignItems="start">
{/* Left Column: Creativity and Structural Defaults */}
<Box>
<Text fontWeight="semibold" fontSize="sm" mb={3}>
Creativity & Structure Defaults
</Text>
<ButtonGroup size="sm" orientation="vertical" variant="outline" w="full">
<Button
colorScheme={creativity === -5 && structure === 5 ? 'invokeBlue' : undefined}
justifyContent="center"
onClick={onConservativeClick}
leftIcon={<PiShieldCheckBold />}
>
Conservative
</Button>
<Button
colorScheme={creativity === 0 && structure === 0 ? 'invokeBlue' : undefined}
justifyContent="center"
onClick={onBalancedClick}
leftIcon={<PiScalesBold />}
>
Balanced
</Button>
<Button
colorScheme={creativity === 5 && structure === -2 ? 'invokeBlue' : undefined}
justifyContent="center"
onClick={onCreativeClick}
leftIcon={<PiPaletteBold />}
>
Creative
</Button>
<Button
colorScheme={creativity === 8 && structure === -5 ? 'invokeBlue' : undefined}
justifyContent="center"
onClick={onArtisticClick}
leftIcon={<PiSparkleBold />}
>
Artistic
</Button>
</ButtonGroup>
</Box>
{/* Right Column: Description/help text */}
<Box>
<Text variant="subtext" fontSize="sm" lineHeight="1.6">
{t('ui.launchpad.upscaling.helpText.promptAdvice')}
</Text>
<Text variant="subtext" fontSize="sm" lineHeight="1.6" mt={3}>
{t('ui.launchpad.upscaling.helpText.styleAdvice')}
</Text>
</Box>
</Grid>
</LaunchpadContainer>
);
});

View File

@@ -8,6 +8,7 @@ import { useTranslation } from 'react-i18next';
import { PiFilePlusBold, PiFolderOpenBold, PiUploadBold } from 'react-icons/pi';
import { LaunchpadButton } from './LaunchpadButton';
import { LaunchpadContainer } from './LaunchpadContainer';
export const WorkflowsLaunchpadPanel = memo(() => {
const { t } = useTranslation();
@@ -45,63 +46,59 @@ export const WorkflowsLaunchpadPanel = memo(() => {
});
return (
<Flex flexDir="column" h="full" w="full" alignItems="center" gap={2}>
<Flex flexDir="column" w="full" gap={4} px={14} maxW={768} pt="20vh">
<Heading>{t('ui.launchpad.workflowsTitle')}</Heading>
<LaunchpadContainer heading={t('ui.launchpad.workflowsTitle')}>
{/* Description */}
<Text variant="subtext" fontSize="md" lineHeight="1.6">
{t('ui.launchpad.workflows.description')}
</Text>
{/* Description */}
<Text variant="subtext" fontSize="md" lineHeight="1.6">
{t('ui.launchpad.workflows.description')}
</Text>
<Text>
<Button
as="a"
variant="link"
href="https://support.invoke.ai/support/solutions/articles/151000189610-getting-started-with-workflows-denoise-latents"
target="_blank"
rel="noopener noreferrer"
size="sm"
>
{t('ui.launchpad.workflows.learnMoreLink')}
</Button>
</Text>
<Text>
<Button
as="a"
variant="link"
href="https://support.invoke.ai/support/solutions/articles/151000189610-getting-started-with-workflows-denoise-latents"
target="_blank"
rel="noopener noreferrer"
size="sm"
>
{t('ui.launchpad.workflows.learnMoreLink')}
</Button>
</Text>
{/* Action Buttons */}
<Flex flexDir="column" gap={8}>
{/* Browse Workflow Templates */}
<LaunchpadButton onClick={handleBrowseTemplates} position="relative" gap={8}>
<Icon as={PiFolderOpenBold} boxSize={8} color="base.500" />
<Flex flexDir="column" alignItems="flex-start" gap={2}>
<Heading size="sm">{t('ui.launchpad.workflows.browseTemplates.title')}</Heading>
<Text color="base.300">{t('ui.launchpad.workflows.browseTemplates.description')}</Text>
</Flex>
</LaunchpadButton>
{/* Action Buttons */}
<Flex flexDir="column" gap={8}>
{/* Browse Workflow Templates */}
<LaunchpadButton onClick={handleBrowseTemplates} position="relative" gap={8}>
<Icon as={PiFolderOpenBold} boxSize={8} color="base.500" />
<Flex flexDir="column" alignItems="flex-start" gap={2}>
<Heading size="sm">{t('ui.launchpad.workflows.browseTemplates.title')}</Heading>
<Text color="base.300">{t('ui.launchpad.workflows.browseTemplates.description')}</Text>
</Flex>
</LaunchpadButton>
{/* Create a new Workflow */}
<LaunchpadButton onClick={handleCreateNew} position="relative" gap={8}>
<Icon as={PiFilePlusBold} boxSize={8} color="base.500" />
<Flex flexDir="column" alignItems="flex-start" gap={2}>
<Heading size="sm">{t('ui.launchpad.workflows.createNew.title')}</Heading>
<Text color="base.300">{t('ui.launchpad.workflows.createNew.description')}</Text>
</Flex>
</LaunchpadButton>
{/* Create a new Workflow */}
<LaunchpadButton onClick={handleCreateNew} position="relative" gap={8}>
<Icon as={PiFilePlusBold} boxSize={8} color="base.500" />
<Flex flexDir="column" alignItems="flex-start" gap={2}>
<Heading size="sm">{t('ui.launchpad.workflows.createNew.title')}</Heading>
<Text color="base.300">{t('ui.launchpad.workflows.createNew.description')}</Text>
</Flex>
</LaunchpadButton>
{/* Load workflow from existing image or file */}
<LaunchpadButton {...uploadApi.getRootProps()} position="relative" gap={8}>
<Icon as={PiUploadBold} boxSize={8} color="base.500" />
<Flex flexDir="column" alignItems="flex-start" gap={2}>
<Heading size="sm">{t('ui.launchpad.workflows.loadFromFile.title')}</Heading>
<Text color="base.300">{t('ui.launchpad.workflows.loadFromFile.description')}</Text>
</Flex>
<Flex position="absolute" right={3} bottom={3}>
<PiUploadBold />
<input {...uploadApi.getInputProps()} />
</Flex>
</LaunchpadButton>
</Flex>
{/* Load workflow from existing image or file */}
<LaunchpadButton {...uploadApi.getRootProps()} position="relative" gap={8}>
<Icon as={PiUploadBold} boxSize={8} color="base.500" />
<Flex flexDir="column" alignItems="flex-start" gap={2}>
<Heading size="sm">{t('ui.launchpad.workflows.loadFromFile.title')}</Heading>
<Text color="base.300">{t('ui.launchpad.workflows.loadFromFile.description')}</Text>
</Flex>
<Flex position="absolute" right={3} bottom={3}>
<PiUploadBold />
<input {...uploadApi.getInputProps()} />
</Flex>
</LaunchpadButton>
</Flex>
</Flex>
</LaunchpadContainer>
);
});

View File

@@ -1,9 +1,12 @@
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppStore } from 'app/store/storeHooks';
import { buildZodTypeGuard } from 'common/util/zodUtils';
import { getOutputImageName } from 'features/controlLayers/components/SimpleSession/shared';
import { selectStagingAreaAutoSwitch } from 'features/controlLayers/store/canvasSettingsSlice';
import {
buildSelectSessionQueueItems,
canvasQueueItemDiscarded,
canvasSessionReset,
} from 'features/controlLayers/store/canvasStagingAreaSlice';
import type { ProgressImage } from 'features/nodes/types/common';
import type { Atom, MapStore, StoreValue, WritableAtom } from 'nanostores';
import { atom, computed, effect, map, subscribeKeys } from 'nanostores';
@@ -14,11 +17,6 @@ import { queueApi } from 'services/api/endpoints/queue';
import type { ImageDTO, S } from 'services/api/types';
import { $socket } from 'services/events/stores';
import { assert, objectEntries } from 'tsafe';
import { z } from 'zod/v4';
const zAutoSwitchMode = z.enum(['off', 'switch_on_start', 'switch_on_finish']);
export const isAutoSwitchMode = buildZodTypeGuard(zAutoSwitchMode);
type AutoSwitchMode = z.infer<typeof zAutoSwitchMode>;
export type ProgressData = {
itemId: number;
@@ -98,12 +96,13 @@ type CanvasSessionContextValue = {
$selectedItem: Atom<S['SessionQueueItem'] | null>;
$selectedItemIndex: Atom<number | null>;
$selectedItemOutputImageDTO: Atom<ImageDTO | null>;
$autoSwitch: WritableAtom<AutoSwitchMode>;
selectNext: () => void;
selectPrev: () => void;
selectFirst: () => void;
selectLast: () => void;
onImageLoad: (itemId: number) => void;
discard: (itemId: number) => void;
discardAll: () => void;
};
const CanvasSessionContext = createContext<CanvasSessionContextValue | null>(null);
@@ -140,11 +139,6 @@ export const CanvasSessionContextProvider = memo(
*/
const $items = useState(() => atom<S['SessionQueueItem'][]>([]))[0];
/**
* Whether auto-switch is enabled.
*/
const $autoSwitch = useState(() => atom<AutoSwitchMode>('switch_on_start'))[0];
/**
* An internal flag used to work around race conditions with auto-switch switching to queue items before their
* output images have fully loaded.
@@ -226,19 +220,21 @@ export const CanvasSessionContextProvider = memo(
)[0];
/**
* A redux selector to select all queue items from the RTK Query cache. It's important that this returns stable
* references if possible to reduce re-renders. All derivations of the queue items (e.g. filtering out canceled
* items) should be done in a nanostores computed.
* A redux selector to select all queue items from the RTK Query cache.
*/
const selectQueueItems = useMemo(
() =>
createSelector(
queueApi.endpoints.listAllQueueItems.select({ destination: session.id }),
({ data }) => data ?? EMPTY_ARRAY
),
[session.id]
const selectQueueItems = useMemo(() => buildSelectSessionQueueItems(session.id), [session.id]);
const discard = useCallback(
(itemId: number) => {
store.dispatch(canvasQueueItemDiscarded({ itemId }));
},
[store]
);
const discardAll = useCallback(() => {
store.dispatch(canvasSessionReset());
}, [store]);
const selectNext = useCallback(() => {
const selectedItemId = $selectedItemId.get();
if (selectedItemId === null) {
@@ -300,12 +296,15 @@ export const CanvasSessionContextProvider = memo(
imageLoaded: true,
});
}
if ($lastCompletedItemId.get() === itemId && $autoSwitch.get() === 'switch_on_finish') {
if (
$lastCompletedItemId.get() === itemId &&
selectStagingAreaAutoSwitch(store.getState()) === 'switch_on_finish'
) {
$selectedItemId.set(itemId);
$lastCompletedItemId.set(null);
}
},
[$autoSwitch, $lastCompletedItemId, $progressData, $selectedItemId]
[$lastCompletedItemId, $progressData, $selectedItemId, store]
);
// Set up socket listeners
@@ -340,7 +339,7 @@ export const CanvasSessionContextProvider = memo(
socket.off('invocation_progress', onProgress);
socket.off('queue_item_status_changed', onQueueItemStatusChanged);
};
}, [$autoSwitch, $lastCompletedItemId, $lastStartedItemId, $progressData, $selectedItemId, session.id, socket]);
}, [$lastCompletedItemId, $lastStartedItemId, $progressData, $selectedItemId, session.id, socket]);
// Set up state subscriptions and effects
useEffect(() => {
@@ -362,33 +361,32 @@ export const CanvasSessionContextProvider = memo(
const unsubEnsureSelectedItemIdExists = effect(
[$items, $selectedItemId, $lastStartedItemId],
(items, selectedItemId, lastStartedItemId) => {
// If there are no items, cannot have a selected item.
if (items.length === 0) {
// If there are no items, cannot have a selected item.
$selectedItemId.set(null);
return;
}
// If there is no selected item but there are items, select the first one.
if (selectedItemId === null && items.length > 0) {
} else if (selectedItemId === null && items.length > 0) {
// If there is no selected item but there are items, select the first one.
$selectedItemId.set(items[0]?.item_id ?? null);
return;
}
if (
$autoSwitch.get() === 'switch_on_start' &&
} else if (
selectStagingAreaAutoSwitch(store.getState()) === 'switch_on_start' &&
items.findIndex(({ item_id }) => item_id === lastStartedItemId) !== -1
) {
$selectedItemId.set(lastStartedItemId);
$lastStartedItemId.set(null);
}
// If an item is selected and it is not in the list of items, un-set it. This effect will run again and we'll
// the above case, selecting the first item if there are any.
if (selectedItemId !== null && items.findIndex(({ item_id }) => item_id === selectedItemId) === -1) {
} else if (selectedItemId !== null && items.findIndex(({ item_id }) => item_id === selectedItemId) === -1) {
// If an item is selected and it is not in the list of items, un-set it. This effect will run again and we'll
// the above case, selecting the first item if there are any.
let prevIndex = _prevItems.findIndex(({ item_id }) => item_id === selectedItemId);
if (prevIndex >= items.length) {
prevIndex = items.length - 1;
}
const nextItem = items[prevIndex];
$selectedItemId.set(nextItem?.item_id ?? null);
return;
}
if (items !== _prevItems) {
_prevItems = items;
}
}
);
@@ -474,7 +472,7 @@ export const CanvasSessionContextProvider = memo(
if (lastLoadedItemId === null) {
return;
}
if ($autoSwitch.get() === 'switch_on_finish') {
if (selectStagingAreaAutoSwitch(store.getState()) === 'switch_on_finish') {
$selectedItemId.set(lastLoadedItemId);
}
$lastLoadedItemId.set(null);
@@ -486,6 +484,22 @@ export const CanvasSessionContextProvider = memo(
queueApi.endpoints.listAllQueueItems.initiate({ destination: session.id })
);
// const unsubListener = store.dispatch(
// addAppListener({
// matcher: queueApi.endpoints.cancelQueueItem.matchFulfilled,
// effect: ({ payload }, { getState }) => {
// const { item_id } = payload;
// const items = selectQueueItems(getState());
// if (items.length === 0) {
// $selectedItemId.set(null);
// } else if ($selectedItemId.get() === null) {
// $selectedItemId.set(items[0].item_id);
// }
// },
// })
// );
// Clean up all subscriptions and top-level (i.e. non-computed/derived state)
return () => {
unsubHandleAutoSwitch();
@@ -498,7 +512,6 @@ export const CanvasSessionContextProvider = memo(
$selectedItemId.set(null);
};
}, [
$autoSwitch,
$items,
$lastLoadedItemId,
$lastStartedItemId,
@@ -517,7 +530,6 @@ export const CanvasSessionContextProvider = memo(
$isPending,
$progressData,
$selectedItemId,
$autoSwitch,
$selectedItem,
$selectedItemIndex,
$selectedItemOutputImageDTO,
@@ -527,9 +539,10 @@ export const CanvasSessionContextProvider = memo(
selectFirst,
selectLast,
onImageLoad,
discard,
discardAll,
}),
[
$autoSwitch,
$items,
$hasItems,
$isPending,
@@ -545,6 +558,8 @@ export const CanvasSessionContextProvider = memo(
selectFirst,
selectLast,
onImageLoad,
discard,
discardAll,
]
);

View File

@@ -13,7 +13,7 @@ export const getProgressMessage = (data?: S['InvocationProgressEvent'] | null) =
export const DROP_SHADOW = 'drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 4px rgba(0, 0, 0, 0.3))';
export const getQueueItemElementId = (itemId: number) => `queue-item-status-card-${itemId}`;
export const getQueueItemElementId = (index: number) => `queue-item-preview-${index}`;
export const getOutputImageName = (item: S['SessionQueueItem']) => {
const nodeId = Object.entries(item.session.source_prepared_mapping).find(([nodeId]) =>

View File

@@ -0,0 +1,50 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
selectStagingAreaAutoSwitch,
settingsStagingAreaAutoSwitchChanged,
} from 'features/controlLayers/store/canvasSettingsSlice';
import { memo, useCallback } from 'react';
import { PiCaretLineRightBold, PiCaretRightBold, PiMoonBold } from 'react-icons/pi';
export const StagingAreaAutoSwitchButtons = memo(() => {
const autoSwitch = useAppSelector(selectStagingAreaAutoSwitch);
const dispatch = useAppDispatch();
const onClickOff = useCallback(() => {
dispatch(settingsStagingAreaAutoSwitchChanged('off'));
}, [dispatch]);
const onClickSwitchOnStart = useCallback(() => {
dispatch(settingsStagingAreaAutoSwitchChanged('switch_on_start'));
}, [dispatch]);
const onClickSwitchOnFinished = useCallback(() => {
dispatch(settingsStagingAreaAutoSwitchChanged('switch_on_finish'));
}, [dispatch]);
return (
<>
<IconButton
aria-label="Do not auto-switch"
tooltip="Do not auto-switch"
icon={<PiMoonBold />}
colorScheme={autoSwitch === 'off' ? 'invokeBlue' : 'base'}
onClick={onClickOff}
/>
<IconButton
aria-label="Switch on start"
tooltip="Switch on start"
icon={<PiCaretRightBold />}
colorScheme={autoSwitch === 'switch_on_start' ? 'invokeBlue' : 'base'}
onClick={onClickSwitchOnStart}
/>
<IconButton
aria-label="Switch on finish"
tooltip="Switch on finish"
icon={<PiCaretLineRightBold />}
colorScheme={autoSwitch === 'switch_on_finish' ? 'invokeBlue' : 'base'}
onClick={onClickSwitchOnFinished}
/>
</>
);
});
StagingAreaAutoSwitchButtons.displayName = 'StagingAreaAutoSwitchButtons';

View File

@@ -1,7 +1,6 @@
import { ButtonGroup } from '@invoke-ai/ui-library';
import { ButtonGroup, Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { getQueueItemElementId } from 'features/controlLayers/components/SimpleSession/shared';
import { StagingAreaToolbarAcceptButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarAcceptButton';
import { StagingAreaToolbarDiscardAllButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarDiscardAllButton';
import { StagingAreaToolbarDiscardSelectedButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarDiscardSelectedButton';
@@ -12,27 +11,22 @@ import { StagingAreaToolbarPrevButton } from 'features/controlLayers/components/
import { StagingAreaToolbarSaveSelectedToGalleryButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarSaveSelectedToGalleryButton';
import { StagingAreaToolbarToggleShowResultsButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarToggleShowResultsButton';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { memo, useEffect } from 'react';
import { memo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { StagingAreaAutoSwitchButtons } from './StagingAreaAutoSwitchButtons';
export const StagingAreaToolbar = memo(() => {
const canvasManager = useCanvasManager();
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
const ctx = useCanvasSessionContext();
useEffect(() => {
return ctx.$selectedItemId.listen((id) => {
if (id !== null) {
document.getElementById(getQueueItemElementId(id))?.scrollIntoView();
}
});
}, [ctx.$selectedItemId]);
useHotkeys('meta+left', ctx.selectFirst, { preventDefault: true });
useHotkeys('meta+right', ctx.selectLast, { preventDefault: true });
return (
<>
<Flex gap={2}>
<ButtonGroup borderRadius="base" shadow="dark-lg">
<StagingAreaToolbarPrevButton isDisabled={!shouldShowStagedImage} />
<StagingAreaToolbarImageCountButton />
@@ -44,9 +38,14 @@ export const StagingAreaToolbar = memo(() => {
<StagingAreaToolbarSaveSelectedToGalleryButton />
<StagingAreaToolbarMenu />
<StagingAreaToolbarDiscardSelectedButton isDisabled={!shouldShowStagedImage} />
</ButtonGroup>
<ButtonGroup borderRadius="base" shadow="dark-lg">
<StagingAreaAutoSwitchButtons />
</ButtonGroup>
<ButtonGroup borderRadius="base" shadow="dark-lg">
<StagingAreaToolbarDiscardAllButton isDisabled={!shouldShowStagedImage} />
</ButtonGroup>
</>
</Flex>
);
});

View File

@@ -9,7 +9,7 @@ import { canvasSessionReset } from 'features/controlLayers/store/canvasStagingAr
import { selectBboxRect, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageNameToImageObject } from 'features/controlLayers/store/util';
import { useDeleteQueueItemsByDestination } from 'features/queue/hooks/useDeleteQueueItemsByDestination';
import { useCancelQueueItemsByDestination } from 'features/queue/hooks/useCancelQueueItemsByDestination';
import { memo, useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
@@ -24,7 +24,7 @@ export const StagingAreaToolbarAcceptButton = memo(() => {
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
const isCanvasFocused = useIsRegionFocused('canvas');
const selectedItemImageDTO = useStore(ctx.$selectedItemOutputImageDTO);
const deleteQueueItemsByDestination = useDeleteQueueItemsByDestination();
const cancelQueueItemsByDestination = useCancelQueueItemsByDestination();
const { t } = useTranslation();
@@ -41,13 +41,13 @@ export const StagingAreaToolbarAcceptButton = memo(() => {
dispatch(rasterLayerAdded({ overrides, isSelected: selectedEntityIdentifier?.type === 'raster_layer' }));
dispatch(canvasSessionReset());
deleteQueueItemsByDestination.trigger(ctx.session.id);
cancelQueueItemsByDestination.trigger(ctx.session.id, { withToast: false });
}, [
selectedItemImageDTO,
bboxRect,
dispatch,
selectedEntityIdentifier?.type,
deleteQueueItemsByDestination,
cancelQueueItemsByDestination,
ctx.session.id,
]);
@@ -68,8 +68,8 @@ export const StagingAreaToolbarAcceptButton = memo(() => {
icon={<PiCheckBold />}
onClick={acceptSelected}
colorScheme="invokeBlue"
isDisabled={!selectedItemImageDTO || !shouldShowStagedImage || deleteQueueItemsByDestination.isDisabled}
isLoading={deleteQueueItemsByDestination.isLoading}
isDisabled={!selectedItemImageDTO || !shouldShowStagedImage || cancelQueueItemsByDestination.isDisabled}
isLoading={cancelQueueItemsByDestination.isLoading}
/>
);
});

View File

@@ -1,27 +1,19 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { canvasSessionReset, generateSessionReset } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { useDeleteQueueItemsByDestination } from 'features/queue/hooks/useDeleteQueueItemsByDestination';
import { useCancelQueueItemsByDestination } from 'features/queue/hooks/useCancelQueueItemsByDestination';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold } from 'react-icons/pi';
export const StagingAreaToolbarDiscardAllButton = memo(({ isDisabled }: { isDisabled?: boolean }) => {
const ctx = useCanvasSessionContext();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const deleteQueueItemsByDestination = useDeleteQueueItemsByDestination();
const cancelQueueItemsByDestination = useCancelQueueItemsByDestination();
const discardAll = useCallback(() => {
deleteQueueItemsByDestination.trigger(ctx.session.id);
if (ctx.session.type === 'advanced') {
dispatch(canvasSessionReset());
} else {
// ctx.session.type === 'simple'
dispatch(generateSessionReset());
}
}, [deleteQueueItemsByDestination, ctx.session.id, ctx.session.type, dispatch]);
ctx.discardAll();
cancelQueueItemsByDestination.trigger(ctx.session.id, { withToast: false });
}, [cancelQueueItemsByDestination, ctx]);
return (
<IconButton
@@ -30,9 +22,8 @@ export const StagingAreaToolbarDiscardAllButton = memo(({ isDisabled }: { isDisa
icon={<PiTrashSimpleBold />}
onClick={discardAll}
colorScheme="error"
fontSize={16}
isDisabled={isDisabled || deleteQueueItemsByDestination.isDisabled}
isLoading={deleteQueueItemsByDestination.isLoading}
isDisabled={isDisabled || cancelQueueItemsByDestination.isDisabled}
isLoading={cancelQueueItemsByDestination.isLoading}
/>
);
});

View File

@@ -1,17 +1,14 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { canvasSessionReset, generateSessionReset } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { useDeleteQueueItem } from 'features/queue/hooks/useDeleteQueueItem';
import { useCancelQueueItem } from 'features/queue/hooks/useCancelQueueItem';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
export const StagingAreaToolbarDiscardSelectedButton = memo(({ isDisabled }: { isDisabled?: boolean }) => {
const dispatch = useAppDispatch();
const ctx = useCanvasSessionContext();
const deleteQueueItem = useDeleteQueueItem();
const cancelQueueItem = useCancelQueueItem();
const selectedItemId = useStore(ctx.$selectedItemId);
const { t } = useTranslation();
@@ -20,17 +17,9 @@ export const StagingAreaToolbarDiscardSelectedButton = memo(({ isDisabled }: { i
if (selectedItemId === null) {
return;
}
await deleteQueueItem.trigger(selectedItemId);
const itemCount = ctx.$itemCount.get();
if (itemCount <= 1) {
if (ctx.session.type === 'advanced') {
dispatch(canvasSessionReset());
} else {
// ctx.session.type === 'simple'
dispatch(generateSessionReset());
}
}
}, [selectedItemId, deleteQueueItem, ctx.$itemCount, ctx.session.type, dispatch]);
ctx.discard(selectedItemId);
await cancelQueueItem.trigger(selectedItemId, { withToast: false });
}, [selectedItemId, ctx, cancelQueueItem]);
return (
<IconButton
@@ -39,9 +28,8 @@ export const StagingAreaToolbarDiscardSelectedButton = memo(({ isDisabled }: { i
icon={<PiXBold />}
onClick={discardSelected}
colorScheme="invokeBlue"
fontSize={16}
isDisabled={selectedItemId === null || deleteQueueItem.isDisabled || isDisabled}
isLoading={deleteQueueItem.isLoading}
isDisabled={selectedItemId === null || cancelQueueItem.isDisabled || isDisabled}
isLoading={cancelQueueItem.isLoading}
/>
);
});

View File

@@ -1,16 +1,13 @@
import { IconButton, Menu, MenuButton, MenuDivider, MenuList } from '@invoke-ai/ui-library';
import { StagingAreaToolbarMenuAutoSwitch } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarMenuAutoSwitch';
import { IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library';
import { StagingAreaToolbarNewLayerFromImageMenuItems } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarMenuNewLayerFromImage';
import { memo } from 'react';
import { PiDotsThreeBold } from 'react-icons/pi';
import { PiDotsThreeVerticalBold } from 'react-icons/pi';
export const StagingAreaToolbarMenu = memo(() => {
return (
<Menu>
<MenuButton as={IconButton} icon={<PiDotsThreeBold />} colorScheme="invokeBlue" />
<MenuButton as={IconButton} icon={<PiDotsThreeVerticalBold />} colorScheme="invokeBlue" />
<MenuList>
<StagingAreaToolbarMenuAutoSwitch />
<MenuDivider />
<StagingAreaToolbarNewLayerFromImageMenuItems />
</MenuList>
</Menu>

View File

@@ -1,34 +0,0 @@
import { MenuItemOption, MenuOptionGroup } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { isAutoSwitchMode, useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { memo, useCallback } from 'react';
import { assert } from 'tsafe';
export const StagingAreaToolbarMenuAutoSwitch = memo(() => {
const ctx = useCanvasSessionContext();
const autoSwitch = useStore(ctx.$autoSwitch);
const onChange = useCallback(
(val: string | string[]) => {
assert(isAutoSwitchMode(val));
ctx.$autoSwitch.set(val);
},
[ctx.$autoSwitch]
);
return (
<MenuOptionGroup value={autoSwitch} onChange={onChange} title="Auto-Switch" type="radio">
<MenuItemOption value="off" closeOnSelect={false}>
Off
</MenuItemOption>
<MenuItemOption value="switch_on_start" closeOnSelect={false}>
Switch on Start
</MenuItemOption>
<MenuItemOption value="switch_on_finish" closeOnSelect={false}>
Switch on Finish
</MenuItemOption>
</MenuOptionGroup>
);
});
StagingAreaToolbarMenuAutoSwitch.displayName = 'StagingAreaToolbarMenuAutoSwitch';

View File

@@ -1,38 +1,41 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import type { RgbaColor } from 'features/controlLayers/store/types';
import { zRgbaColor } from 'features/controlLayers/store/types';
import { z } from 'zod/v4';
type CanvasSettingsState = {
const zAutoSwitchMode = z.enum(['off', 'switch_on_start', 'switch_on_finish']);
const zCanvasSettingsState = z.object({
/**
* Whether to show HUD (Heads-Up Display) on the canvas.
*/
showHUD: boolean;
showHUD: z.boolean().default(true),
/**
* Whether to clip lines and shapes to the generation bounding box. If disabled, lines and shapes will be clipped to
* the canvas bounds.
*/
clipToBbox: boolean;
clipToBbox: z.boolean().default(false),
/**
* Whether to show a dynamic grid on the canvas. If disabled, a checkerboard pattern will be shown instead.
*/
dynamicGrid: boolean;
dynamicGrid: z.boolean().default(false),
/**
* Whether to invert the scroll direction when adjusting the brush or eraser width with the scroll wheel.
*/
invertScrollForToolWidth: boolean;
invertScrollForToolWidth: z.boolean().default(false),
/**
* The width of the brush tool.
*/
brushWidth: number;
brushWidth: z.int().gt(0).default(50),
/**
* The width of the eraser tool.
*/
eraserWidth: number;
eraserWidth: z.int().gt(0).default(50),
/**
* The color to use when drawing lines or filling shapes.
*/
color: RgbaColor;
color: zRgbaColor.default({ r: 31, g: 160, b: 224, a: 1 }), // invokeBlue.500
/**
* Whether to composite inpainted/outpainted regions back onto the source image when saving canvas generations.
*
@@ -40,70 +43,61 @@ type CanvasSettingsState = {
*
* When `sendToCanvas` is disabled, this setting is ignored, masked regions will always be composited.
*/
outputOnlyMaskedRegions: boolean;
outputOnlyMaskedRegions: z.boolean().default(true),
/**
* Whether to automatically process the operations like filtering and auto-masking.
*/
autoProcess: boolean;
autoProcess: z.boolean().default(true),
/**
* The snap-to-grid setting for the canvas.
*/
snapToGrid: boolean;
snapToGrid: z.boolean().default(true),
/**
* Whether to show progress on the canvas when generating images.
*/
showProgressOnCanvas: boolean;
showProgressOnCanvas: z.boolean().default(true),
/**
* Whether to show the bounding box overlay on the canvas.
*/
bboxOverlay: boolean;
bboxOverlay: z.boolean().default(false),
/**
* Whether to preserve the masked region instead of inpainting it.
*/
preserveMask: boolean;
preserveMask: z.boolean().default(false),
/**
* Whether to show only raster layers while staging.
*/
isolatedStagingPreview: boolean;
isolatedStagingPreview: z.boolean().default(true),
/**
* Whether to show only the selected layer while filtering, transforming, or doing other operations.
*/
isolatedLayerPreview: boolean;
isolatedLayerPreview: z.boolean().default(true),
/**
* Whether to use pressure sensitivity for the brush and eraser tool when a pen device is used.
*/
pressureSensitivity: boolean;
pressureSensitivity: z.boolean().default(true),
/**
* Whether to show the rule of thirds composition guide overlay on the canvas.
*/
ruleOfThirds: boolean;
};
ruleOfThirds: z.boolean().default(false),
/**
* Whether to save all staging images to the gallery instead of keeping them as intermediate images.
*/
saveAllImagesToGallery: z.boolean().default(false),
/**
* The auto-switch mode for the canvas staging area.
*/
stagingAreaAutoSwitch: zAutoSwitchMode.default('switch_on_start'),
});
const initialState: CanvasSettingsState = {
showHUD: true,
clipToBbox: false,
dynamicGrid: false,
brushWidth: 50,
eraserWidth: 50,
invertScrollForToolWidth: false,
color: { r: 31, g: 160, b: 224, a: 1 }, // invokeBlue.500
outputOnlyMaskedRegions: true,
autoProcess: true,
snapToGrid: true,
showProgressOnCanvas: true,
bboxOverlay: false,
preserveMask: false,
isolatedStagingPreview: true,
isolatedLayerPreview: true,
pressureSensitivity: true,
ruleOfThirds: false,
};
type CanvasSettingsState = z.infer<typeof zCanvasSettingsState>;
const getInitialState = () => zCanvasSettingsState.parse({});
export const canvasSettingsSlice = createSlice({
name: 'canvasSettings',
initialState,
initialState: getInitialState(),
reducers: {
settingsClipToBboxChanged: (state, action: PayloadAction<boolean>) => {
settingsClipToBboxChanged: (state, action: PayloadAction<CanvasSettingsState['clipToBbox']>) => {
state.clipToBbox = action.payload;
},
settingsDynamicGridToggled: (state) => {
@@ -112,16 +106,19 @@ export const canvasSettingsSlice = createSlice({
settingsShowHUDToggled: (state) => {
state.showHUD = !state.showHUD;
},
settingsBrushWidthChanged: (state, action: PayloadAction<number>) => {
settingsBrushWidthChanged: (state, action: PayloadAction<CanvasSettingsState['brushWidth']>) => {
state.brushWidth = Math.round(action.payload);
},
settingsEraserWidthChanged: (state, action: PayloadAction<number>) => {
settingsEraserWidthChanged: (state, action: PayloadAction<CanvasSettingsState['eraserWidth']>) => {
state.eraserWidth = Math.round(action.payload);
},
settingsColorChanged: (state, action: PayloadAction<RgbaColor>) => {
settingsColorChanged: (state, action: PayloadAction<CanvasSettingsState['color']>) => {
state.color = action.payload;
},
settingsInvertScrollForToolWidthChanged: (state, action: PayloadAction<boolean>) => {
settingsInvertScrollForToolWidthChanged: (
state,
action: PayloadAction<CanvasSettingsState['invertScrollForToolWidth']>
) => {
state.invertScrollForToolWidth = action.payload;
},
settingsOutputOnlyMaskedRegionsToggled: (state) => {
@@ -154,6 +151,15 @@ export const canvasSettingsSlice = createSlice({
settingsRuleOfThirdsToggled: (state) => {
state.ruleOfThirds = !state.ruleOfThirds;
},
settingsSaveAllImagesToGalleryToggled: (state) => {
state.saveAllImagesToGallery = !state.saveAllImagesToGallery;
},
settingsStagingAreaAutoSwitchChanged: (
state,
action: PayloadAction<CanvasSettingsState['stagingAreaAutoSwitch']>
) => {
state.stagingAreaAutoSwitch = action.payload;
},
},
});
@@ -175,6 +181,8 @@ export const {
settingsIsolatedLayerPreviewToggled,
settingsPressureSensitivityToggled,
settingsRuleOfThirdsToggled,
settingsSaveAllImagesToGalleryToggled,
settingsStagingAreaAutoSwitchChanged,
} = canvasSettingsSlice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
@@ -184,7 +192,7 @@ const migrate = (state: any): any => {
export const canvasSettingsPersistConfig: PersistConfig<CanvasSettingsState> = {
name: canvasSettingsSlice.name,
initialState,
initialState: getInitialState(),
migrate,
persistDenylist: [],
};
@@ -209,3 +217,5 @@ export const selectIsolatedStagingPreview = createCanvasSettingsSelector((settin
export const selectIsolatedLayerPreview = createCanvasSettingsSelector((settings) => settings.isolatedLayerPreview);
export const selectPressureSensitivity = createCanvasSettingsSelector((settings) => settings.pressureSensitivity);
export const selectRuleOfThirds = createCanvasSettingsSelector((settings) => settings.ruleOfThirds);
export const selectSaveAllImagesToGallery = createCanvasSettingsSelector((settings) => settings.saveAllImagesToGallery);
export const selectStagingAreaAutoSwitch = createCanvasSettingsSelector((settings) => settings.stagingAreaAutoSwitch);

View File

@@ -1,16 +1,20 @@
import { createSelector, createSlice, type PayloadAction } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { canvasReset } from 'features/controlLayers/store/actions';
import { queueApi } from 'services/api/endpoints/queue';
type CanvasStagingAreaState = {
generateSessionId: string | null;
canvasSessionId: string | null;
canvasDiscardedQueueItems: number[];
};
const INITIAL_STATE: CanvasStagingAreaState = {
generateSessionId: null,
canvasSessionId: null,
canvasDiscardedQueueItems: [],
};
const getInitialState = (): CanvasStagingAreaState => deepClone(INITIAL_STATE);
@@ -26,12 +30,20 @@ export const canvasSessionSlice = createSlice({
generateSessionReset: (state) => {
state.generateSessionId = null;
},
canvasQueueItemDiscarded: (state, action: PayloadAction<{ itemId: number }>) => {
const { itemId } = action.payload;
if (!state.canvasDiscardedQueueItems.includes(itemId)) {
state.canvasDiscardedQueueItems.push(itemId);
}
},
canvasSessionIdChanged: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
state.canvasSessionId = id;
state.canvasDiscardedQueueItems = [];
},
canvasSessionReset: (state) => {
state.canvasSessionId = null;
state.canvasDiscardedQueueItems = [];
},
},
extraReducers(builder) {
@@ -41,8 +53,13 @@ export const canvasSessionSlice = createSlice({
},
});
export const { generateSessionIdChanged, generateSessionReset, canvasSessionIdChanged, canvasSessionReset } =
canvasSessionSlice.actions;
export const {
generateSessionIdChanged,
generateSessionReset,
canvasSessionIdChanged,
canvasSessionReset,
canvasQueueItemDiscarded,
} = canvasSessionSlice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
@@ -63,4 +80,34 @@ export const selectGenerateSessionId = createSelector(
selectCanvasSessionSlice,
({ generateSessionId }) => generateSessionId
);
export const selectIsStaging = createSelector(selectCanvasSessionId, (canvasSessionId) => canvasSessionId !== null);
export const buildSelectSessionQueueItems = (sessionId: string) =>
createSelector(
[queueApi.endpoints.listAllQueueItems.select({ destination: sessionId }), selectDiscardedItems],
({ data }, discardedItems) => {
if (!data) {
return EMPTY_ARRAY;
}
return data.filter(
({ status, item_id }) => status !== 'canceled' && status !== 'failed' && !discardedItems.includes(item_id)
);
}
);
export const selectIsStaging = (state: RootState) => {
const sessionId = selectCanvasSessionId(state);
if (!sessionId) {
return false;
}
const { data } = queueApi.endpoints.listAllQueueItems.select({ destination: sessionId })(state);
if (!data) {
return false;
}
const discardedItems = selectDiscardedItems(state);
return data.some(
({ status, item_id }) => status !== 'canceled' && status !== 'failed' && !discardedItems.includes(item_id)
);
};
const selectDiscardedItems = createSelector(
selectCanvasSessionSlice,
({ canvasDiscardedQueueItems }) => canvasDiscardedQueueItems
);

View File

@@ -57,6 +57,8 @@ export const refImagesSlice = createSlice({
const { entities, replace } = action.payload;
if (replace) {
state.entities = entities;
state.isPanelOpen = false;
state.selectedEntityId = null;
} else {
state.entities.push(...entities);
}

View File

@@ -98,7 +98,7 @@ const zRgbColor = z.object({
b: z.number().int().min(0).max(255),
});
export type RgbColor = z.infer<typeof zRgbColor>;
const zRgbaColor = zRgbColor.extend({
export const zRgbaColor = zRgbColor.extend({
a: z.number().min(0).max(1),
});
export type RgbaColor = z.infer<typeof zRgbaColor>;

View File

@@ -15,7 +15,6 @@ const sx = {
objectFit: 'contain',
maxW: 'full',
maxH: 'full',
borderRadius: 'base',
cursor: 'grab',
'&[data-is-dragging=true]': {
opacity: 0.3,

View File

@@ -1,25 +1,33 @@
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
import { useImageActions } from 'features/gallery/hooks/useImageActions';
import { useRecallAll } from 'features/gallery/hooks/useRecallAll';
import { useRecallDimensions } from 'features/gallery/hooks/useRecallDimensions';
import { useRecallPrompts } from 'features/gallery/hooks/useRecallPrompts';
import { useRecallRemix } from 'features/gallery/hooks/useRecallRemix';
import { useRecallSeed } from 'features/gallery/hooks/useRecallSeed';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import {
PiArrowBendUpLeftBold,
PiArrowsCounterClockwiseBold,
PiAsteriskBold,
PiPaintBrushBold,
PiPlantBold,
PiQuotesBold,
PiRulerBold,
} from 'react-icons/pi';
export const ImageMenuItemMetadataRecallActions = memo(() => {
const { t } = useTranslation();
const imageDTO = useImageDTOContext();
const subMenu = useSubMenu();
const { recallAll, remix, recallSeed, recallPrompts, hasMetadata, hasSeed, hasPrompts, createAsPreset } =
useImageActions(imageDTO);
const imageDTO = useImageDTOContext();
const recallAll = useRecallAll(imageDTO);
const recallRemix = useRecallRemix(imageDTO);
const recallPrompts = useRecallPrompts(imageDTO);
const recallSeed = useRecallSeed(imageDTO);
const recallDimensions = useRecallDimensions(imageDTO);
return (
<MenuItem {...subMenu.parentMenuItemProps} icon={<PiArrowBendUpLeftBold />}>
@@ -28,20 +36,24 @@ export const ImageMenuItemMetadataRecallActions = memo(() => {
<SubMenuButtonContent label={t('parameters.recallMetadata')} />
</MenuButton>
<MenuList {...subMenu.menuListProps}>
<MenuItem icon={<PiArrowsCounterClockwiseBold />} onClick={remix} isDisabled={!hasMetadata}>
<MenuItem
icon={<PiArrowsCounterClockwiseBold />}
onClick={recallRemix.recall}
isDisabled={!recallRemix.isEnabled}
>
{t('parameters.remixImage')}
</MenuItem>
<MenuItem icon={<PiQuotesBold />} onClick={recallPrompts} isDisabled={!hasPrompts}>
<MenuItem icon={<PiQuotesBold />} onClick={recallPrompts.recall} isDisabled={!recallPrompts.isEnabled}>
{t('parameters.usePrompt')}
</MenuItem>
<MenuItem icon={<PiPlantBold />} onClick={recallSeed} isDisabled={!hasSeed}>
<MenuItem icon={<PiPlantBold />} onClick={recallSeed.recall} isDisabled={!recallSeed.isEnabled}>
{t('parameters.useSeed')}
</MenuItem>
<MenuItem icon={<PiAsteriskBold />} onClick={recallAll} isDisabled={!hasMetadata}>
<MenuItem icon={<PiAsteriskBold />} onClick={recallAll.recall} isDisabled={!recallAll.isEnabled}>
{t('parameters.useAll')}
</MenuItem>
<MenuItem icon={<PiPaintBrushBold />} onClick={createAsPreset} isDisabled={!hasPrompts}>
{t('stylePresets.useForTemplate')}
<MenuItem icon={<PiRulerBold />} onClick={recallDimensions.recall} isDisabled={!recallDimensions.isEnabled}>
{t('parameters.useSize')}
</MenuItem>
</MenuList>
</Menu>

View File

@@ -0,0 +1,20 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
import { useCreateStylePresetFromMetadata } from 'features/gallery/hooks/useCreateStylePresetFromMetadata';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPaintBrushBold } from 'react-icons/pi';
export const ImageMenuItemUseAsPromptTemplate = memo(() => {
const { t } = useTranslation();
const imageDTO = useImageDTOContext();
const stylePreset = useCreateStylePresetFromMetadata(imageDTO);
return (
<MenuItem icon={<PiPaintBrushBold />} onClickCapture={stylePreset.create} isDisabled={!stylePreset.isEnabled}>
{t('stylePresets.useForTemplate')}
</MenuItem>
);
});
ImageMenuItemUseAsPromptTemplate.displayName = 'ImageMenuItemUseAsPromptTemplate';

View File

@@ -1,4 +1,5 @@
import { MenuDivider } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { IconMenuItemGroup } from 'common/components/IconMenuItem';
import { ImageMenuItemChangeBoard } from 'features/gallery/components/ImageContextMenu/ImageMenuItemChangeBoard';
import { ImageMenuItemCopy } from 'features/gallery/components/ImageContextMenu/ImageMenuItemCopy';
@@ -16,14 +17,19 @@ import { ImageMenuItemStarUnstar } from 'features/gallery/components/ImageContex
import { ImageMenuItemUseAsRefImage } from 'features/gallery/components/ImageContextMenu/ImageMenuItemUseAsRefImage';
import { ImageMenuItemUseForPromptGeneration } from 'features/gallery/components/ImageContextMenu/ImageMenuItemUseForPromptGeneration';
import { ImageDTOContextProvider } from 'features/gallery/contexts/ImageDTOContext';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo } from 'react';
import type { ImageDTO } from 'services/api/types';
import { ImageMenuItemUseAsPromptTemplate } from './ImageMenuItemUseAsPromptTemplate';
type SingleSelectionMenuItemsProps = {
imageDTO: ImageDTO;
};
const SingleSelectionMenuItems = ({ imageDTO }: SingleSelectionMenuItemsProps) => {
const tab = useAppSelector(selectActiveTab);
return (
<ImageDTOContextProvider value={imageDTO}>
<IconMenuItemGroup>
@@ -36,13 +42,14 @@ const SingleSelectionMenuItems = ({ imageDTO }: SingleSelectionMenuItemsProps) =
</IconMenuItemGroup>
<MenuDivider />
<ImageMenuItemLoadWorkflow />
<ImageMenuItemMetadataRecallActions />
{(tab === 'canvas' || tab === 'generate') && <ImageMenuItemMetadataRecallActions />}
<MenuDivider />
<ImageMenuItemSendToUpscale />
<ImageMenuItemUseForPromptGeneration />
<ImageMenuItemUseAsRefImage />
{(tab === 'canvas' || tab === 'generate') && <ImageMenuItemUseAsRefImage />}
<ImageMenuItemUseAsPromptTemplate />
<ImageMenuItemNewCanvasFromImageSubMenu />
<ImageMenuItemNewLayerFromImageSubMenu />
{tab === 'canvas' && <ImageMenuItemNewLayerFromImageSubMenu />}
<MenuDivider />
<ImageMenuItemChangeBoard />
<ImageMenuItemStarUnstar />

View File

@@ -1,7 +1,7 @@
import { combine } from '@atlaskit/pragmatic-drag-and-drop/combine';
import { draggable, monitorForElements } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';
import type { FlexProps, SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, Icon, Image } from '@invoke-ai/ui-library';
import { Flex, Icon, Image } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import type { AppDispatch, AppGetState } from 'app/store/store';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
@@ -23,13 +23,11 @@ import { imageToCompareChanged, selectGallerySlice, selectionChanged } from 'fea
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { VIEWER_PANEL_ID } from 'features/ui/layouts/shared';
import type { MouseEvent, MouseEventHandler } from 'react';
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { PiImageBold } from 'react-icons/pi';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
const GALLERY_IMAGE_CLASS = 'gallery-image';
const galleryImageContainerSX = {
containerType: 'inline-size',
w: 'full',
@@ -42,45 +40,42 @@ const galleryImageContainerSX = {
'&[data-is-dragging=true]': {
opacity: 0.3,
},
[`.${GALLERY_IMAGE_CLASS}`]: {
touchAction: 'none',
userSelect: 'none',
webkitUserSelect: 'none',
position: 'relative',
justifyContent: 'center',
alignItems: 'center',
aspectRatio: '1/1',
'::before': {
content: '""',
display: 'inline-block',
position: 'absolute',
top: 0,
left: 0,
right: 0,
bottom: 0,
pointerEvents: 'none',
borderRadius: 'base',
},
'&[data-selected=true]::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-500), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
},
'&[data-selected-for-compare=true]::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-300), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
},
'&:hover::before': {
boxShadow:
'inset 0px 0px 0px 1px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-800)',
},
'&:hover[data-selected=true]::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-400), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
},
'&:hover[data-selected-for-compare=true]::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-200), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
},
userSelect: 'none',
webkitUserSelect: 'none',
position: 'relative',
justifyContent: 'center',
alignItems: 'center',
aspectRatio: '1/1',
'::before': {
content: '""',
display: 'inline-block',
position: 'absolute',
top: 0,
left: 0,
right: 0,
bottom: 0,
pointerEvents: 'none',
borderRadius: 'base',
},
'&[data-selected=true]::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-500), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
},
'&[data-selected-for-compare=true]::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-300), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
},
'&:hover::before': {
boxShadow:
'inset 0px 0px 0px 1px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-800)',
},
'&:hover[data-selected=true]::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-400), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
},
'&:hover[data-selected-for-compare=true]::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-200), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
},
} satisfies SystemStyleObject;
@@ -142,8 +137,7 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
const [dragPreviewState, setDragPreviewState] = useState<
DndDragPreviewSingleImageState | DndDragPreviewMultipleImageState | null
>(null);
// Must use callback ref - else chakra's Image fallback prop will break the ref & dnd
const [element, ref] = useState<HTMLImageElement | null>(null);
const ref = useRef<HTMLDivElement>(null);
const selectIsSelectedForCompare = useMemo(
() => createSelector(selectGallerySlice, (gallery) => gallery.imageToCompare === imageDTO.image_name),
[imageDTO.image_name]
@@ -156,6 +150,7 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
const isSelected = useAppSelector(selectIsSelected);
useEffect(() => {
const element = ref.current;
if (!element) {
return;
}
@@ -221,7 +216,7 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
},
})
);
}, [element, imageDTO, store]);
}, [imageDTO, store]);
const [isHovered, setIsHovered] = useState(false);
@@ -240,34 +235,35 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
navigationApi.focusPanelInActiveTab(VIEWER_PANEL_ID);
}, [store]);
useImageContextMenu(imageDTO, element);
useImageContextMenu(imageDTO, ref);
return (
<>
<Box sx={galleryImageContainerSX} data-is-dragging={isDragging} data-image-name={imageDTO.image_name}>
<Flex
role="button"
className={GALLERY_IMAGE_CLASS}
onMouseOver={onMouseOver}
onMouseOut={onMouseOut}
onClick={onClick}
onDoubleClick={onDoubleClick}
data-selected={isSelected}
data-selected-for-compare={isSelectedForCompare}
>
<Image
ref={ref}
src={imageDTO.thumbnail_url}
w={imageDTO.width}
fallback={<GalleryImagePlaceholder />}
objectFit="contain"
maxW="full"
maxH="full"
borderRadius="base"
/>
<GalleryImageHoverIcons imageDTO={imageDTO} isHovered={isHovered} />
</Flex>
</Box>
<Flex
ref={ref}
sx={galleryImageContainerSX}
data-is-dragging={isDragging}
data-image-name={imageDTO.image_name}
role="button"
onMouseOver={onMouseOver}
onMouseOut={onMouseOut}
onClick={onClick}
onDoubleClick={onDoubleClick}
data-selected={isSelected}
data-selected-for-compare={isSelectedForCompare}
>
<Image
pointerEvents="none"
src={imageDTO.thumbnail_url}
w={imageDTO.width}
fallback={<GalleryImagePlaceholder />}
objectFit="contain"
maxW="full"
maxH="full"
borderRadius="base"
/>
<GalleryImageHoverIcons imageDTO={imageDTO} isHovered={isHovered} />
</Flex>
{dragPreviewState?.type === 'multiple-image' ? createMultipleImageDragPreview(dragPreviewState) : null}
{dragPreviewState?.type === 'single-image' ? createSingleImageDragPreview(dragPreviewState) : null}
</>

View File

@@ -85,7 +85,7 @@ const UnrecallableMetadataParsed = typedMemo(
return (
<Box as="span" lineHeight={1}>
<LabelComponent />
<LabelComponent i18nKey={handler.i18nKey} />
<ValueComponent value={data.value} />
</Box>
);
@@ -128,7 +128,7 @@ const SingleMetadataParsed = typedMemo(
onClick={onClick}
/>
<Box as="span" lineHeight={1}>
<LabelComponent />
<LabelComponent i18nKey={handler.i18nKey} />
<ValueComponent value={data.value} />
</Box>
</Flex>
@@ -178,7 +178,7 @@ const CollectionMetadataParsed = typedMemo(
onClick={onClick}
/>
<Box as="span" lineHeight={1}>
<LabelComponent />
<LabelComponent i18nKey={handler.i18nKey} />
<ValueComponent value={value} />
</Box>
</Flex>

View File

@@ -1,21 +1,19 @@
import { Button, Divider, IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton';
import SingleSelectionMenuItems from 'features/gallery/components/ImageContextMenu/SingleSelectionMenuItems';
import { useImageActions } from 'features/gallery/hooks/useImageActions';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { newCanvasFromImage } from 'features/imageActions/actions';
import { $hasTemplates } from 'features/nodes/store/nodesSlice';
import { useDeleteImage } from 'features/gallery/hooks/useDeleteImage';
import { useEditImage } from 'features/gallery/hooks/useEditImage';
import { useLoadWorkflow } from 'features/gallery/hooks/useLoadWorkflow';
import { useRecallAll } from 'features/gallery/hooks/useRecallAll';
import { useRecallDimensions } from 'features/gallery/hooks/useRecallDimensions';
import { useRecallPrompts } from 'features/gallery/hooks/useRecallPrompts';
import { useRecallRemix } from 'features/gallery/hooks/useRecallRemix';
import { useRecallSeed } from 'features/gallery/hooks/useRecallSeed';
import { PostProcessingPopover } from 'features/parameters/components/PostProcessing/PostProcessingPopover';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { toast } from 'features/toast/toast';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { WORKSPACE_PANEL_ID } from 'features/ui/layouts/shared';
import { selectShouldShowProgressInViewer } from 'features/ui/store/uiSelectors';
import { memo, useCallback } from 'react';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import {
PiArrowsCounterClockwiseBold,
@@ -27,51 +25,23 @@ import {
PiQuotesBold,
PiRulerBold,
} from 'react-icons/pi';
import { useImageDTO } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { useImageViewerContext } from './context';
export const CurrentImageButtons = memo(() => {
export const CurrentImageButtons = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
const { t } = useTranslation();
const ctx = useImageViewerContext();
const hasProgressImage = useStore(ctx.$hasProgressImage);
const shouldShowProgressInViewer = useAppSelector(selectShouldShowProgressInViewer);
const isDisabledOverride = hasProgressImage && shouldShowProgressInViewer;
const tab = useAppSelector(selectActiveTab);
const isCanvasOrGenerateTab = tab === 'canvas' || tab === 'generate';
const imageName = useAppSelector(selectLastSelectedImage);
const imageDTO = useImageDTO(imageName);
const hasTemplates = useStore($hasTemplates);
const imageActions = useImageActions(imageDTO);
const isStaging = useAppSelector(selectIsStaging);
const isUpscalingEnabled = useFeatureStatus('upscaling');
const { getState, dispatch } = useAppStore();
const canvasManager = useCanvasManagerSafe();
const handleEdit = useCallback(async () => {
if (!imageDTO) {
return;
}
await newCanvasFromImage({
imageDTO,
type: 'raster_layer',
withInpaintMask: true,
getState,
dispatch,
});
navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
// Automatically select the brush tool when editing an image
if (canvasManager) {
canvasManager.tool.$tool.set('brush');
}
toast({
id: 'SENT_TO_CANVAS',
title: t('toast.sentToCanvas'),
status: 'success',
});
}, [imageDTO, getState, dispatch, t, canvasManager]);
const recallAll = useRecallAll(imageDTO);
const recallRemix = useRecallRemix(imageDTO);
const recallPrompts = useRecallPrompts(imageDTO);
const recallSeed = useRecallSeed(imageDTO);
const recallDimensions = useRecallDimensions(imageDTO);
const loadWorkflow = useLoadWorkflow(imageDTO);
const editImage = useEditImage(imageDTO);
const deleteImage = useDeleteImage(imageDTO);
return (
<>
@@ -80,7 +50,7 @@ export const CurrentImageButtons = memo(() => {
as={IconButton}
aria-label={t('parameters.imageActions')}
tooltip={t('parameters.imageActions')}
isDisabled={isDisabledOverride || !imageDTO}
isDisabled={!imageDTO}
variant="link"
alignSelf="stretch"
icon={<PiDotsThreeOutlineFill />}
@@ -92,8 +62,8 @@ export const CurrentImageButtons = memo(() => {
<Button
leftIcon={<PiPencilBold />}
onClick={handleEdit}
isDisabled={isDisabledOverride || !imageDTO}
onClick={editImage.edit}
isDisabled={!editImage.isEnabled}
variant="link"
size="sm"
alignSelf="stretch"
@@ -108,62 +78,72 @@ export const CurrentImageButtons = memo(() => {
icon={<PiFlowArrowBold />}
tooltip={`${t('nodes.loadWorkflow')} (W)`}
aria-label={`${t('nodes.loadWorkflow')} (W)`}
isDisabled={isDisabledOverride || !imageDTO || !imageActions.hasWorkflow || !hasTemplates}
isDisabled={!loadWorkflow.isEnabled}
variant="link"
alignSelf="stretch"
onClick={imageActions.loadWorkflow}
/>
<IconButton
icon={<PiArrowsCounterClockwiseBold />}
tooltip={`${t('parameters.remixImage')} (R)`}
aria-label={`${t('parameters.remixImage')} (R)`}
isDisabled={isDisabledOverride || !imageDTO || !imageActions.hasMetadata}
variant="link"
alignSelf="stretch"
onClick={imageActions.remix}
/>
<IconButton
icon={<PiQuotesBold />}
tooltip={`${t('parameters.usePrompt')} (P)`}
aria-label={`${t('parameters.usePrompt')} (P)`}
isDisabled={isDisabledOverride || !imageDTO || !imageActions.hasPrompts}
variant="link"
alignSelf="stretch"
onClick={imageActions.recallPrompts}
/>
<IconButton
icon={<PiPlantBold />}
tooltip={`${t('parameters.useSeed')} (S)`}
aria-label={`${t('parameters.useSeed')} (S)`}
isDisabled={isDisabledOverride || !imageDTO || !imageActions.hasSeed}
variant="link"
alignSelf="stretch"
onClick={imageActions.recallSeed}
/>
<IconButton
icon={<PiRulerBold />}
tooltip={`${t('parameters.useSize')} (D)`}
aria-label={`${t('parameters.useSize')} (D)`}
variant="link"
alignSelf="stretch"
onClick={imageActions.recallSize}
isDisabled={isDisabledOverride || !imageDTO || isStaging}
/>
<IconButton
icon={<PiAsteriskBold />}
tooltip={`${t('parameters.useAll')} (A)`}
aria-label={`${t('parameters.useAll')} (A)`}
isDisabled={isDisabledOverride || !imageDTO || !imageActions.hasMetadata}
variant="link"
alignSelf="stretch"
onClick={imageActions.recallAll}
onClick={loadWorkflow.load}
/>
{isCanvasOrGenerateTab && (
<IconButton
icon={<PiArrowsCounterClockwiseBold />}
tooltip={`${t('parameters.remixImage')} (R)`}
aria-label={`${t('parameters.remixImage')} (R)`}
isDisabled={!recallRemix.isEnabled}
variant="link"
alignSelf="stretch"
onClick={recallRemix.recall}
/>
)}
{isCanvasOrGenerateTab && (
<IconButton
icon={<PiQuotesBold />}
tooltip={`${t('parameters.usePrompt')} (P)`}
aria-label={`${t('parameters.usePrompt')} (P)`}
isDisabled={!recallPrompts.isEnabled}
variant="link"
alignSelf="stretch"
onClick={recallPrompts.recall}
/>
)}
{isCanvasOrGenerateTab && (
<IconButton
icon={<PiPlantBold />}
tooltip={`${t('parameters.useSeed')} (S)`}
aria-label={`${t('parameters.useSeed')} (S)`}
isDisabled={!recallSeed.isEnabled}
variant="link"
alignSelf="stretch"
onClick={recallSeed.recall}
/>
)}
{isCanvasOrGenerateTab && (
<IconButton
icon={<PiRulerBold />}
tooltip={`${t('parameters.useSize')} (D)`}
aria-label={`${t('parameters.useSize')} (D)`}
variant="link"
alignSelf="stretch"
onClick={recallDimensions.recall}
isDisabled={!recallDimensions.isEnabled}
/>
)}
{isCanvasOrGenerateTab && (
<IconButton
icon={<PiAsteriskBold />}
tooltip={`${t('parameters.useAll')} (A)`}
aria-label={`${t('parameters.useAll')} (A)`}
isDisabled={!recallAll.isEnabled}
variant="link"
alignSelf="stretch"
onClick={recallAll.recall}
/>
)}
{isUpscalingEnabled && <PostProcessingPopover imageDTO={imageDTO} isDisabled={isDisabledOverride} />}
{isUpscalingEnabled && <PostProcessingPopover imageDTO={imageDTO} isDisabled={false} />}
<Divider orientation="vertical" h={8} mx={2} />
<DeleteImageButton onClick={imageActions.delete} isDisabled={isDisabledOverride || !imageDTO} />
<DeleteImageButton onClick={deleteImage.delete} isDisabled={!deleteImage.isEnabled} />
</>
);
});

View File

@@ -50,7 +50,7 @@ export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | nu
>
{imageDTO && (
<Flex w="full" h="full" position="absolute" alignItems="center" justifyContent="center">
<DndImage imageDTO={imageDTO} onLoad={onLoadImage} />
<DndImage imageDTO={imageDTO} onLoad={onLoadImage} borderRadius="base" />
</Flex>
)}
{!imageDTO && <NoContentForViewer />}

View File

@@ -1,18 +1,24 @@
import { Flex, Spacer } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { ToggleMetadataViewerButton } from 'features/gallery/components/ImageViewer/ToggleMetadataViewerButton';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { memo } from 'react';
import { useImageDTO } from 'services/api/endpoints/images';
import { CurrentImageButtons } from './CurrentImageButtons';
import { ToggleProgressButton } from './ToggleProgressButton';
export const ViewerToolbar = memo(() => {
const imageName = useAppSelector(selectLastSelectedImage);
const imageDTO = useImageDTO(imageName);
return (
<Flex w="full" justifyContent="center" h={8}>
<ToggleProgressButton />
<Spacer />
<CurrentImageButtons />
{imageDTO && <CurrentImageButtons imageDTO={imageDTO} />}
<Spacer />
<ToggleMetadataViewerButton />
{imageDTO && <ToggleMetadataViewerButton />}
</Flex>
);
});

View File

@@ -2,6 +2,7 @@ import { Box, Flex, forwardRef, Grid, GridItem, Spinner, Text } from '@invoke-ai
import { createSelector } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { getFocusedRegion } from 'common/hooks/focus';
import { useRangeBasedImageFetching } from 'features/gallery/hooks/useRangeBasedImageFetching';
import type { selectGetImageNamesQueryArgs } from 'features/gallery/store/gallerySelectors';
import {
@@ -221,6 +222,10 @@ const useKeyboardNavigation = (
const handleKeyDown = useCallback(
(event: KeyboardEvent) => {
if (getFocusedRegion() !== 'gallery') {
// Only handle keyboard navigation when the gallery is focused
return;
}
// Only handle arrow keys
if (!['ArrowUp', 'ArrowDown', 'ArrowLeft', 'ArrowRight'].includes(event.key)) {
return;
@@ -477,11 +482,6 @@ export const NewGallery = memo(() => {
const context = useMemo<GridContext>(() => ({ imageNames, queryArgs }), [imageNames, queryArgs]);
// Item content function
const itemContent: GridItemContent<string, GridContext> = useCallback((index, imageName) => {
return <ImageAtPosition index={index} imageName={imageName} />;
}, []);
if (isLoading) {
return (
<Flex w="full" h="full" alignItems="center" justifyContent="center" gap={4}>
@@ -506,7 +506,7 @@ export const NewGallery = memo(() => {
ref={virtuosoRef}
context={context}
data={imageNames}
increaseViewportBy={2048}
increaseViewportBy={4096}
itemContent={itemContent}
computeItemKey={computeItemKey}
components={components}
@@ -523,8 +523,12 @@ export const NewGallery = memo(() => {
NewGallery.displayName = 'NewGallery';
const scrollSeekConfiguration: ScrollSeekConfiguration = {
enter: (velocity) => velocity > 4096,
exit: (velocity) => velocity === 0,
enter: (velocity) => {
return Math.abs(velocity) > 2048;
},
exit: (velocity) => {
return velocity === 0;
},
};
// Styles
@@ -544,6 +548,10 @@ const ListComponent: GridComponents<GridContext>['List'] = forwardRef(({ context
});
ListComponent.displayName = 'ListComponent';
const itemContent: GridItemContent<string, GridContext> = (index, imageName) => {
return <ImageAtPosition index={index} imageName={imageName} />;
};
const ItemComponent: GridComponents<GridContext>['Item'] = forwardRef(({ context: _, ...rest }, ref) => (
<GridItem ref={ref} aspectRatio="1/1" {...rest} />
));

View File

@@ -0,0 +1,26 @@
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import {
activeStylePresetIdChanged,
selectStylePresetActivePresetId,
} from 'features/stylePresets/store/stylePresetSlice';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export const useClearStylePresetWithToast = () => {
const store = useAppStore();
const { t } = useTranslation();
const activeStylePresetId = useAppSelector(selectStylePresetActivePresetId);
const clearStylePreset = useCallback(() => {
if (activeStylePresetId) {
store.dispatch(activeStylePresetIdChanged(null));
toast({
status: 'info',
title: t('stylePresets.promptTemplateCleared'),
});
}
}, [activeStylePresetId, store, t]);
return clearStylePreset;
};

View File

@@ -0,0 +1,81 @@
import { useAppStore } from 'app/store/storeHooks';
import { MetadataHandlers, MetadataUtils } from 'features/metadata/parsing';
import { $stylePresetModalState } from 'features/stylePresets/store/stylePresetModal';
import { useCallback, useEffect, useMemo, useState } from 'react';
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
import type { ImageDTO } from 'services/api/types';
export const useCreateStylePresetFromMetadata = (imageDTO?: ImageDTO | null) => {
const store = useAppStore();
const [hasPrompts, setHasPrompts] = useState(false);
const { metadata } = useDebouncedMetadata(imageDTO?.image_name);
useEffect(() => {
MetadataUtils.hasMetadataByHandlers({
handlers: [MetadataHandlers.PositivePrompt, MetadataHandlers.NegativePrompt],
metadata,
store,
require: 'some',
})
.then((result) => {
setHasPrompts(result);
})
.catch(() => {
setHasPrompts(false);
});
}, [metadata, store]);
const isEnabled = useMemo(() => {
if (!imageDTO) {
return false;
}
if (!hasPrompts) {
return false;
}
return true;
}, [hasPrompts, imageDTO]);
const create = useCallback(async () => {
if (!imageDTO) {
return;
}
if (!metadata) {
return;
}
if (!isEnabled) {
return;
}
let positivePrompt: string;
let negativePrompt: string;
try {
positivePrompt = await MetadataHandlers.PositivePrompt.parse(metadata, store);
} catch (error) {
positivePrompt = '';
}
try {
negativePrompt = (await MetadataHandlers.NegativePrompt.parse(metadata, store)) ?? '';
} catch (error) {
negativePrompt = '';
}
$stylePresetModalState.set({
prefilledFormData: {
name: '',
positivePrompt,
negativePrompt,
imageUrl: imageDTO.image_url,
type: 'user',
},
updatingStylePresetId: null,
isModalOpen: true,
});
}, [imageDTO, isEnabled, metadata, store]);
return {
create,
isEnabled,
};
};

View File

@@ -0,0 +1,28 @@
import { useDeleteImageModalApi } from 'features/deleteImageModal/store/state';
import { useCallback, useMemo } from 'react';
import type { ImageDTO } from 'services/api/types';
export const useDeleteImage = (imageDTO?: ImageDTO | null) => {
const deleteImageModal = useDeleteImageModalApi();
const isEnabled = useMemo(() => {
if (!imageDTO) {
return;
}
return true;
}, [imageDTO]);
const _delete = useCallback(() => {
if (!imageDTO) {
return;
}
if (!isEnabled) {
return;
}
deleteImageModal.delete([imageDTO.image_name]);
}, [deleteImageModal, imageDTO, isEnabled]);
return {
delete: _delete,
isEnabled,
};
};

View File

@@ -0,0 +1,57 @@
import { useAppStore } from 'app/store/storeHooks';
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { newCanvasFromImage } from 'features/imageActions/actions';
import { toast } from 'features/toast/toast';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { WORKSPACE_PANEL_ID } from 'features/ui/layouts/shared';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { ImageDTO } from 'services/api/types';
export const useEditImage = (imageDTO?: ImageDTO | null) => {
const { t } = useTranslation();
const { getState, dispatch } = useAppStore();
const canvasManager = useCanvasManagerSafe();
const isEnabled = useMemo(() => {
if (!imageDTO) {
return false;
}
return true;
}, [imageDTO]);
const edit = useCallback(async () => {
if (!imageDTO) {
return;
}
if (!isEnabled) {
return;
}
await newCanvasFromImage({
imageDTO,
type: 'raster_layer',
withInpaintMask: true,
getState,
dispatch,
});
navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
if (canvasManager) {
canvasManager.tool.$tool.set('brush');
}
toast({
id: 'SENT_TO_CANVAS',
title: t('toast.sentToCanvas'),
status: 'success',
});
}, [imageDTO, isEnabled, getState, dispatch, canvasManager, t]);
return {
edit,
isEnabled,
};
};

View File

@@ -1,209 +0,0 @@
import { useStore } from '@nanostores/react';
import { adHocPostProcessingRequested } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { useDeleteImageModalApi } from 'features/deleteImageModal/store/state';
import { MetadataHandlers, MetadataUtils } from 'features/metadata/parsing';
import { $hasTemplates } from 'features/nodes/store/nodesSlice';
import { $stylePresetModalState } from 'features/stylePresets/store/stylePresetModal';
import {
activeStylePresetIdChanged,
selectStylePresetActivePresetId,
} from 'features/stylePresets/store/stylePresetSlice';
import { toast } from 'features/toast/toast';
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
import { useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
import type { ImageDTO } from 'services/api/types';
export const useImageActions = (imageDTO: ImageDTO | null) => {
const store = useAppStore();
const { t } = useTranslation();
const activeStylePresetId = useAppSelector(selectStylePresetActivePresetId);
const isStaging = useAppSelector(selectIsStaging);
const { metadata } = useDebouncedMetadata(imageDTO?.image_name ?? null);
const [hasMetadata, setHasMetadata] = useState(false);
const [hasSeed, setHasSeed] = useState(false);
const [hasPrompts, setHasPrompts] = useState(false);
const hasTemplates = useStore($hasTemplates);
const deleteImageModal = useDeleteImageModalApi();
useEffect(() => {
const parseMetadata = async () => {
if (metadata) {
setHasMetadata(true);
try {
await MetadataHandlers.Seed.parse(metadata, store);
setHasSeed(true);
} catch {
setHasSeed(false);
}
let hasPrompt = false;
// Need to catch all of these to avoid unhandled promise rejections bubbling up to instrumented error handlers
for (const handler of [
MetadataHandlers.PositivePrompt,
MetadataHandlers.NegativePrompt,
MetadataHandlers.PositiveStylePrompt,
MetadataHandlers.NegativeStylePrompt,
]) {
try {
await handler.parse(metadata, store);
hasPrompt = true;
break;
} catch {
// noop
}
}
setHasPrompts(hasPrompt);
} else {
setHasMetadata(false);
setHasSeed(false);
setHasPrompts(false);
}
};
parseMetadata();
}, [metadata, store]);
const clearStylePreset = useCallback(() => {
if (activeStylePresetId) {
store.dispatch(activeStylePresetIdChanged(null));
toast({
status: 'info',
title: t('stylePresets.promptTemplateCleared'),
});
}
}, [activeStylePresetId, store, t]);
const recallAll = useCallback(() => {
if (!imageDTO) {
return;
}
if (!metadata) {
return;
}
MetadataUtils.recallAll(metadata, store, isStaging ? [MetadataHandlers.Width, MetadataHandlers.Height] : []);
clearStylePreset();
}, [imageDTO, metadata, store, isStaging, clearStylePreset]);
const remix = useCallback(() => {
if (!imageDTO) {
return;
}
if (!metadata) {
return;
}
// Recalls all metadata parameters except seed
MetadataUtils.recallAll(metadata, store, [MetadataHandlers.Seed]);
clearStylePreset();
}, [imageDTO, metadata, store, clearStylePreset]);
const recallSeed = useCallback(() => {
if (!imageDTO) {
return;
}
if (!metadata) {
return;
}
MetadataUtils.recallByHandler({ metadata, store, handler: MetadataHandlers.Seed });
}, [imageDTO, metadata, store]);
const recallPrompts = useCallback(() => {
if (!imageDTO) {
return;
}
if (!metadata) {
return;
}
MetadataUtils.recallPrompts(metadata, store);
clearStylePreset();
}, [imageDTO, metadata, store, clearStylePreset]);
const createAsPreset = useCallback(async () => {
if (!imageDTO) {
return;
}
if (!metadata) {
return;
}
let positivePrompt: string;
let negativePrompt: string;
try {
positivePrompt = await MetadataHandlers.PositivePrompt.parse(metadata, store);
} catch (error) {
positivePrompt = '';
}
try {
negativePrompt = (await MetadataHandlers.NegativePrompt.parse(metadata, store)) ?? '';
} catch (error) {
negativePrompt = '';
}
$stylePresetModalState.set({
prefilledFormData: {
name: '',
positivePrompt,
negativePrompt,
imageUrl: imageDTO.image_url,
type: 'user',
},
updatingStylePresetId: null,
isModalOpen: true,
});
}, [imageDTO, metadata, store]);
const loadWorkflowWithDialog = useLoadWorkflowWithDialog();
const loadWorkflowFromImage = useCallback(() => {
if (!imageDTO) {
return;
}
if (!imageDTO.has_workflow || !hasTemplates) {
return;
}
loadWorkflowWithDialog({ type: 'image', data: imageDTO.image_name });
}, [hasTemplates, imageDTO, loadWorkflowWithDialog]);
const recallSize = useCallback(() => {
if (!imageDTO) {
return;
}
if (isStaging) {
return;
}
MetadataUtils.recallDimensions(imageDTO, store);
}, [imageDTO, isStaging, store]);
const upscale = useCallback(() => {
if (!imageDTO) {
return;
}
store.dispatch(adHocPostProcessingRequested({ imageDTO }));
}, [imageDTO, store]);
const _delete = useCallback(() => {
if (!imageDTO) {
return;
}
deleteImageModal.delete([imageDTO.image_name]);
}, [deleteImageModal, imageDTO]);
return {
hasMetadata,
hasSeed,
hasPrompts,
recallAll,
remix,
recallSeed,
recallPrompts,
createAsPreset,
loadWorkflow: loadWorkflowFromImage,
hasWorkflow: imageDTO?.has_workflow ?? false,
recallSize,
upscale,
delete: _delete,
};
};

View File

@@ -0,0 +1,34 @@
import { useStore } from '@nanostores/react';
import { $hasTemplates } from 'features/nodes/store/nodesSlice';
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
import { useCallback, useMemo } from 'react';
import type { ImageDTO } from 'services/api/types';
export const useLoadWorkflow = (imageDTO: ImageDTO) => {
const hasTemplates = useStore($hasTemplates);
const loadWorkflowWithDialog = useLoadWorkflowWithDialog();
const isEnabled = useMemo(() => {
if (!imageDTO.has_workflow) {
return false;
}
if (!hasTemplates) {
return false;
}
return true;
}, [hasTemplates, imageDTO]);
const load = useCallback(() => {
if (!imageDTO) {
return;
}
if (!isEnabled) {
return;
}
loadWorkflowWithDialog({ type: 'image', data: imageDTO.image_name });
}, [imageDTO, isEnabled, loadWorkflowWithDialog]);
return { load, isEnabled };
};

View File

@@ -1,5 +1,5 @@
import { useAppStore } from 'app/store/storeHooks';
import { useCallback } from 'react';
import { useCallback, useEffect, useState } from 'react';
import type { ListRange } from 'react-virtuoso';
import { imagesApi, useGetImageDTOsByNamesMutation } from 'services/api/endpoints/images';
import { useThrottledCallback } from 'use-debounce';
@@ -13,33 +13,20 @@ interface UseRangeBasedImageFetchingReturn {
onRangeChanged: (range: ListRange) => void;
}
const getUncachedNames = (imageNames: string[], cachedImageNames: string[], range: ListRange): string[] => {
if (range.startIndex === range.endIndex) {
// If the start and end indices are the same, no range to fetch
return [];
}
const getUncachedNames = (imageNames: string[], cachedImageNames: string[], ranges: ListRange[]): string[] => {
const uncachedNamesSet = new Set<string>();
const cachedImageNamesSet = new Set(cachedImageNames);
if (imageNames.length === 0) {
return [];
}
const start = Math.max(0, range.startIndex);
const end = Math.min(imageNames.length - 1, range.endIndex);
if (cachedImageNames.length === 0) {
return imageNames.slice(start, end + 1);
}
const uncachedNames: string[] = [];
for (let i = start; i <= end; i++) {
const imageName = imageNames[i]!;
if (!cachedImageNames.includes(imageName)) {
uncachedNames.push(imageName);
for (const range of ranges) {
for (let i = range.startIndex; i <= range.endIndex; i++) {
const n = imageNames[i]!;
if (n && !cachedImageNamesSet.has(n)) {
uncachedNamesSet.add(n);
}
}
}
return uncachedNames;
return Array.from(uncachedNamesSet);
};
/**
@@ -53,30 +40,36 @@ export const useRangeBasedImageFetching = ({
}: UseRangeBasedImageFetchingArgs): UseRangeBasedImageFetchingReturn => {
const store = useAppStore();
const [getImageDTOsByNames] = useGetImageDTOsByNamesMutation();
const [lastRange, setLastRange] = useState<ListRange | null>(null);
const [pendingRanges, setPendingRanges] = useState<ListRange[]>([]);
const fetchImages = useCallback(
(visibleRange: ListRange) => {
(ranges: ListRange[], imageNames: string[]) => {
if (!enabled) {
return;
}
const cachedImageNames = imagesApi.util.selectCachedArgsForQuery(store.getState(), 'getImageDTO');
const uncachedNames = getUncachedNames(imageNames, cachedImageNames, visibleRange);
const uncachedNames = getUncachedNames(imageNames, cachedImageNames, ranges);
if (uncachedNames.length === 0) {
return;
}
getImageDTOsByNames({ image_names: uncachedNames });
setPendingRanges([]);
},
[enabled, getImageDTOsByNames, imageNames, store]
[enabled, getImageDTOsByNames, store]
);
const throttledFetchImages = useThrottledCallback(fetchImages, 100);
const throttledFetchImages = useThrottledCallback(fetchImages, 500);
const onRangeChanged = useCallback(
(range: ListRange) => {
throttledFetchImages(range);
},
[throttledFetchImages]
);
const onRangeChanged = useCallback((range: ListRange) => {
setLastRange(range);
setPendingRanges((prev) => [...prev, range]);
}, []);
useEffect(() => {
const combinedRanges = lastRange ? [...pendingRanges, lastRange] : pendingRanges;
throttledFetchImages(combinedRanges, imageNames);
}, [imageNames, lastRange, pendingRanges, throttledFetchImages]);
return {
onRangeChanged,

View File

@@ -0,0 +1,57 @@
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { MetadataHandlers, MetadataUtils } from 'features/metadata/parsing';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { useCallback, useMemo } from 'react';
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
import type { ImageDTO } from 'services/api/types';
import { useClearStylePresetWithToast } from './useClearStylePresetWithToast';
export const useRecallAll = (imageDTO: ImageDTO) => {
const store = useAppStore();
const tab = useAppSelector(selectActiveTab);
const { metadata, isLoading } = useDebouncedMetadata(imageDTO.image_name);
const isStaging = useAppSelector(selectIsStaging);
const clearStylePreset = useClearStylePresetWithToast();
const isEnabled = useMemo(() => {
if (isLoading) {
return false;
}
if (tab !== 'canvas' && tab !== 'generate') {
return false;
}
if (!metadata) {
return false;
}
return true;
}, [isLoading, metadata, tab]);
const handlersToSkip = useMemo(() => {
if (tab === 'canvas' && isStaging) {
// When we are staging and on canvas, the bbox is locked - we cannot recall width and height
return [MetadataHandlers.Width, MetadataHandlers.Height];
}
return undefined;
}, [isStaging, tab]);
const recall = useCallback(() => {
if (!metadata) {
return;
}
if (!isEnabled) {
return;
}
MetadataUtils.recallAll(metadata, store, handlersToSkip);
clearStylePreset();
}, [metadata, isEnabled, store, handlersToSkip, clearStylePreset]);
return {
recall,
isEnabled,
};
};

View File

@@ -0,0 +1,36 @@
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { MetadataUtils } from 'features/metadata/parsing';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { useCallback, useMemo } from 'react';
import type { ImageDTO } from 'services/api/types';
export const useRecallDimensions = (imageDTO: ImageDTO) => {
const store = useAppStore();
const tab = useAppSelector(selectActiveTab);
const isStaging = useAppSelector(selectIsStaging);
const isEnabled = useMemo(() => {
if (tab !== 'canvas' && tab !== 'generate') {
return false;
}
if (tab === 'canvas' && isStaging) {
return false;
}
return true;
}, [isStaging, tab]);
const recall = useCallback(() => {
if (!isEnabled) {
return;
}
MetadataUtils.recallDimensions(imageDTO, store);
}, [isEnabled, imageDTO, store]);
return {
recall,
isEnabled,
};
};

View File

@@ -0,0 +1,72 @@
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { MetadataHandlers, MetadataUtils } from 'features/metadata/parsing';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { useCallback, useEffect, useMemo, useState } from 'react';
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
import type { ImageDTO } from 'services/api/types';
import { useClearStylePresetWithToast } from './useClearStylePresetWithToast';
export const useRecallPrompts = (imageDTO: ImageDTO) => {
const store = useAppStore();
const tab = useAppSelector(selectActiveTab);
const clearStylePreset = useClearStylePresetWithToast();
const [hasPrompts, setHasPrompts] = useState(false);
const { metadata, isLoading } = useDebouncedMetadata(imageDTO.image_name);
useEffect(() => {
const parse = async () => {
try {
const result = await MetadataUtils.hasMetadataByHandlers({
handlers: [
MetadataHandlers.PositivePrompt,
MetadataHandlers.NegativePrompt,
MetadataHandlers.PositiveStylePrompt,
MetadataHandlers.NegativeStylePrompt,
],
metadata,
store,
require: 'some',
});
setHasPrompts(result);
} catch {
setHasPrompts(false);
}
};
parse();
}, [metadata, store]);
const isEnabled = useMemo(() => {
if (isLoading) {
return false;
}
if (tab !== 'canvas' && tab !== 'generate') {
return false;
}
if (!hasPrompts) {
return false;
}
return true;
}, [hasPrompts, isLoading, tab]);
const recall = useCallback(() => {
if (!metadata) {
return;
}
if (!isEnabled) {
return;
}
MetadataUtils.recallPrompts(metadata, store);
clearStylePreset();
}, [metadata, isEnabled, store, clearStylePreset]);
return {
recall,
isEnabled,
};
};

View File

@@ -0,0 +1,60 @@
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { MetadataHandlers, MetadataUtils } from 'features/metadata/parsing';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { useCallback, useMemo } from 'react';
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
import type { ImageDTO } from 'services/api/types';
import { useClearStylePresetWithToast } from './useClearStylePresetWithToast';
export const useRecallRemix = (imageDTO: ImageDTO) => {
const store = useAppStore();
const tab = useAppSelector(selectActiveTab);
const isStaging = useAppSelector(selectIsStaging);
const clearStylePreset = useClearStylePresetWithToast();
const { metadata, isLoading } = useDebouncedMetadata(imageDTO.image_name);
const isEnabled = useMemo(() => {
if (isLoading) {
return false;
}
if (tab !== 'canvas' && tab !== 'generate') {
return false;
}
if (!metadata) {
return false;
}
return true;
}, [isLoading, metadata, tab]);
const handlersToSkip = useMemo(() => {
// Remix always skips the seed handler
const _handlersToSkip = [MetadataHandlers.Seed];
if (tab === 'canvas' && isStaging) {
// When we are staging and on canvas, the bbox is locked - we cannot recall width and height
_handlersToSkip.push(MetadataHandlers.Width, MetadataHandlers.Height);
}
return _handlersToSkip;
}, [isStaging, tab]);
const recall = useCallback(() => {
if (!metadata) {
return;
}
if (!isEnabled) {
return;
}
MetadataUtils.recallAll(metadata, store, handlersToSkip);
clearStylePreset();
}, [metadata, isEnabled, store, handlersToSkip, clearStylePreset]);
return {
recall,
isEnabled,
};
};

View File

@@ -0,0 +1,62 @@
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { MetadataHandlers, MetadataUtils } from 'features/metadata/parsing';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { useCallback, useEffect, useMemo, useState } from 'react';
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
import type { ImageDTO } from 'services/api/types';
export const useRecallSeed = (imageDTO: ImageDTO) => {
const store = useAppStore();
const tab = useAppSelector(selectActiveTab);
const [hasSeed, setHasSeed] = useState(false);
const { metadata, isLoading } = useDebouncedMetadata(imageDTO.image_name);
useEffect(() => {
const parse = async () => {
try {
await MetadataHandlers.Seed.parse(metadata, store);
setHasSeed(true);
} catch {
setHasSeed(false);
}
};
parse();
}, [metadata, store]);
const isEnabled = useMemo(() => {
if (isLoading) {
return false;
}
if (tab !== 'canvas' && tab !== 'generate') {
return false;
}
if (!metadata) {
return false;
}
if (!hasSeed) {
return false;
}
return true;
}, [hasSeed, isLoading, metadata, tab]);
const recall = useCallback(() => {
if (!metadata) {
return;
}
if (!isEnabled) {
return;
}
MetadataUtils.recallByHandler({ metadata, handler: MetadataHandlers.Seed, store });
}, [metadata, isEnabled, store]);
return {
recall,
isEnabled,
};
};

View File

@@ -85,7 +85,12 @@ export const createNewCanvasEntityFromImage = async (arg: {
}) => {
const { type, imageDTO, dispatch, getState, withResize, overrides: _overrides } = arg;
const state = getState();
const { x, y, width, height } = selectBboxRect(state);
const { x, y } = selectBboxRect(state);
const base = selectBboxModelBase(state);
const ratio = imageDTO.width / imageDTO.height;
const optimalDimension = getOptimalDimension(base);
const { width, height } = calculateNewSize(ratio, optimalDimension ** 2, base);
let imageObject: CanvasImageState;

View File

@@ -8,6 +8,7 @@ import { getPrefixedId } from 'features/controlLayers/konva/util';
import { bboxHeightChanged, bboxWidthChanged, canvasMetadataRecalled } from 'features/controlLayers/store/canvasSlice';
import { loraAllDeleted, loraRecalled } from 'features/controlLayers/store/lorasSlice';
import {
heightChanged,
negativePrompt2Changed,
negativePromptChanged,
positivePrompt2Changed,
@@ -31,6 +32,7 @@ import {
setSteps,
shouldConcatPromptsChanged,
vaeSelected,
widthChanged,
} from 'features/controlLayers/store/paramsSlice';
import { refImagesRecalled } from 'features/controlLayers/store/refImagesSlice';
import type { CanvasMetadata, LoRA, RefImageState } from 'features/controlLayers/store/types';
@@ -82,8 +84,9 @@ import {
zParameterStrength,
} from 'features/parameters/types/parameterSchemas';
import { toast } from 'features/toast/toast';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { t } from 'i18next';
import type { ComponentType, ReactNode } from 'react';
import type { ComponentType } from 'react';
import { useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { modelsApi } from 'services/api/endpoints/models';
@@ -170,7 +173,8 @@ export type SingleMetadataHandler<T> = {
type: string;
parse: (metadata: unknown, store: AppStore) => Promise<T>;
recall: (value: T, store: AppStore) => void;
LabelComponent: ComponentType;
i18nKey: string;
LabelComponent: ComponentType<{ i18nKey: string }>;
ValueComponent: ComponentType<SingleMetadataValueProps<T>>;
};
@@ -184,7 +188,8 @@ export type CollectionMetadataHandler<T extends any[]> = {
parse: (metadata: unknown, store: AppStore) => Promise<T>;
recall: (values: T, store: AppStore) => void;
recallOne: (value: T[number], store: AppStore) => void;
LabelComponent: ComponentType;
i18nKey: string;
LabelComponent: ComponentType<{ i18nKey: string }>;
ValueComponent: ComponentType<CollectionMetadataValueProps<T>>;
};
@@ -196,7 +201,8 @@ export type UnrecallableMetadataHandler<T> = {
[UnrecallableMetadataKey]: true;
type: string;
parse: (metadata: unknown, store: AppStore) => Promise<T>;
LabelComponent: ComponentType;
i18nKey: string;
LabelComponent: ComponentType<{ i18nKey: string }>;
ValueComponent: ComponentType<UnrecallableMetadataValueProps<T>>;
};
@@ -221,7 +227,8 @@ const CreatedBy: UnrecallableMetadataHandler<string> = {
const parsed = z.string().parse(raw);
return Promise.resolve(parsed);
},
LabelComponent: () => <MetadataLabel i18nKey="metadata.createdBy" />,
i18nKey: 'metadata.createdBy',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: UnrecallableMetadataValueProps<string>) => <MetadataPrimitiveValue value={value} />,
};
//#endregion Created By
@@ -235,7 +242,8 @@ const GenerationMode: UnrecallableMetadataHandler<string> = {
const parsed = z.string().parse(raw);
return Promise.resolve(parsed);
},
LabelComponent: () => <MetadataLabel i18nKey="metadata.generationMode" />,
i18nKey: 'metadata.generationMode',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: UnrecallableMetadataValueProps<string>) => <MetadataPrimitiveValue value={value} />,
};
//#endregion Generation Mode
@@ -252,7 +260,8 @@ const PositivePrompt: SingleMetadataHandler<ParameterPositivePrompt> = {
recall: (value, store) => {
store.dispatch(positivePromptChanged(value));
},
LabelComponent: () => <MetadataLabel i18nKey="metadata.positivePrompt" />,
i18nKey: 'metadata.positivePrompt',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterPositivePrompt>) => (
<MetadataPrimitiveValue value={value} />
),
@@ -271,7 +280,8 @@ const NegativePrompt: SingleMetadataHandler<ParameterNegativePrompt> = {
recall: (value, store) => {
store.dispatch(negativePromptChanged(value || null));
},
LabelComponent: () => <MetadataLabel i18nKey="metadata.negativePrompt" />,
i18nKey: 'metadata.negativePrompt',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterNegativePrompt>) => (
<MetadataPrimitiveValue value={value} />
),
@@ -290,7 +300,8 @@ const PositiveStylePrompt: SingleMetadataHandler<ParameterPositiveStylePromptSDX
recall: (value, store) => {
store.dispatch(positivePrompt2Changed(value));
},
LabelComponent: () => <MetadataLabel i18nKey="sdxl.posStylePrompt" />,
i18nKey: 'sdxl.posStylePrompt',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterPositiveStylePromptSDXL>) => (
<MetadataPrimitiveValue value={value} />
),
@@ -309,7 +320,8 @@ const NegativeStylePrompt: SingleMetadataHandler<ParameterPositiveStylePromptSDX
recall: (value, store) => {
store.dispatch(negativePrompt2Changed(value));
},
LabelComponent: () => <MetadataLabel i18nKey="sdxl.negStylePrompt" />,
i18nKey: 'sdxl.negStylePrompt',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterPositiveStylePromptSDXL>) => (
<MetadataPrimitiveValue value={value} />
),
@@ -328,7 +340,8 @@ const CFGScale: SingleMetadataHandler<ParameterCFGScale> = {
recall: (value, store) => {
store.dispatch(setCfgScale(value));
},
LabelComponent: () => <MetadataLabel i18nKey="metadata.cfgScale" />,
i18nKey: 'metadata.cfgScale',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterCFGScale>) => <MetadataPrimitiveValue value={value} />,
};
//#endregion CFG Scale
@@ -345,7 +358,8 @@ const CFGRescaleMultiplier: SingleMetadataHandler<ParameterCFGRescaleMultiplier>
recall: (value, store) => {
store.dispatch(setCfgRescaleMultiplier(value));
},
LabelComponent: () => <MetadataLabel i18nKey="metadata.cfgRescaleMultiplier" />,
i18nKey: 'metadata.cfgRescaleMultiplier',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterCFGRescaleMultiplier>) => (
<MetadataPrimitiveValue value={value} />
),
@@ -364,7 +378,8 @@ const Guidance: SingleMetadataHandler<ParameterGuidance> = {
recall: (value, store) => {
store.dispatch(setGuidance(value));
},
LabelComponent: () => <MetadataLabel i18nKey="metadata.guidance" />,
i18nKey: 'metadata.guidance',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterGuidance>) => <MetadataPrimitiveValue value={value} />,
};
//#endregion Guidance
@@ -381,7 +396,8 @@ const Scheduler: SingleMetadataHandler<ParameterScheduler> = {
recall: (value, store) => {
store.dispatch(setScheduler(value));
},
LabelComponent: () => <MetadataLabel i18nKey="metadata.scheduler" />,
i18nKey: 'metadata.scheduler',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterScheduler>) => <MetadataPrimitiveValue value={value} />,
};
//#endregion Scheduler
@@ -396,9 +412,15 @@ const Width: SingleMetadataHandler<ParameterWidth> = {
return Promise.resolve(parsed);
},
recall: (value, store) => {
store.dispatch(bboxWidthChanged({ width: value, updateAspectRatio: true, clamp: true }));
const activeTab = selectActiveTab(store.getState());
if (activeTab === 'canvas') {
store.dispatch(bboxWidthChanged({ width: value, updateAspectRatio: true, clamp: true }));
} else if (activeTab === 'generate') {
store.dispatch(widthChanged({ width: value, updateAspectRatio: true, clamp: true }));
}
},
LabelComponent: () => <MetadataLabel i18nKey="metadata.width" />,
i18nKey: 'metadata.width',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterWidth>) => <MetadataPrimitiveValue value={value} />,
};
//#endregion Width
@@ -413,9 +435,15 @@ const Height: SingleMetadataHandler<ParameterHeight> = {
return Promise.resolve(parsed);
},
recall: (value, store) => {
store.dispatch(bboxHeightChanged({ height: value, updateAspectRatio: true, clamp: true }));
const activeTab = selectActiveTab(store.getState());
if (activeTab === 'canvas') {
store.dispatch(bboxHeightChanged({ height: value, updateAspectRatio: true, clamp: true }));
} else if (activeTab === 'generate') {
store.dispatch(heightChanged({ height: value, updateAspectRatio: true, clamp: true }));
}
},
LabelComponent: () => <MetadataLabel i18nKey="metadata.height" />,
i18nKey: 'metadata.height',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterHeight>) => <MetadataPrimitiveValue value={value} />,
};
//#endregion Height
@@ -432,7 +460,8 @@ const Seed: SingleMetadataHandler<ParameterSeed> = {
recall: (value, store) => {
store.dispatch(setSeed(value));
},
LabelComponent: () => <MetadataLabel i18nKey="metadata.seed" />,
i18nKey: 'metadata.seed',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterSeed>) => <MetadataPrimitiveValue value={value} />,
};
//#endregion Seed
@@ -449,7 +478,8 @@ const Steps: SingleMetadataHandler<ParameterSteps> = {
recall: (value, store) => {
store.dispatch(setSteps(value));
},
LabelComponent: () => <MetadataLabel i18nKey="metadata.steps" />,
i18nKey: 'metadata.steps',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterSteps>) => <MetadataPrimitiveValue value={value} />,
};
//#endregion Steps
@@ -466,7 +496,8 @@ const DenoisingStrength: SingleMetadataHandler<ParameterStrength> = {
recall: (value, store) => {
store.dispatch(setImg2imgStrength(value));
},
LabelComponent: () => <MetadataLabel i18nKey="metadata.strength" />,
i18nKey: 'metadata.strength',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterStrength>) => <MetadataPrimitiveValue value={value} />,
};
//#endregion DenoisingStrength
@@ -483,7 +514,8 @@ const SeamlessX: SingleMetadataHandler<ParameterSeamlessX> = {
recall: (value, store) => {
store.dispatch(setSeamlessXAxis(value));
},
LabelComponent: () => <MetadataLabel i18nKey="metadata.seamlessXAxis" />,
i18nKey: 'metadata.seamlessXAxis',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterSeamlessX>) => <MetadataPrimitiveValue value={value} />,
};
//#endregion SeamlessX
@@ -500,7 +532,8 @@ const SeamlessY: SingleMetadataHandler<ParameterSeamlessY> = {
recall: (value, store) => {
store.dispatch(setSeamlessYAxis(value));
},
LabelComponent: () => <MetadataLabel i18nKey="metadata.seamlessYAxis" />,
i18nKey: 'metadata.seamlessYAxis',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterSeamlessY>) => <MetadataPrimitiveValue value={value} />,
};
//#endregion SeamlessY
@@ -520,7 +553,8 @@ const RefinerModel: SingleMetadataHandler<ParameterSDXLRefinerModel> = {
recall: (value, store) => {
store.dispatch(refinerModelChanged(value));
},
LabelComponent: () => <MetadataLabel i18nKey="sdxl.refinermodel" />,
i18nKey: 'sdxl.refinermodel',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterSDXLRefinerModel>) => (
<MetadataPrimitiveValue value={`${value.name} (${value.base.toUpperCase()})`} />
),
@@ -539,7 +573,8 @@ const RefinerSteps: SingleMetadataHandler<ParameterSteps> = {
recall: (value, store) => {
store.dispatch(setRefinerSteps(value));
},
LabelComponent: () => <MetadataLabel i18nKey="sdxl.refinerSteps" />,
i18nKey: 'sdxl.refinerSteps',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterSteps>) => <MetadataPrimitiveValue value={value} />,
};
//#endregion RefinerSteps
@@ -556,7 +591,8 @@ const RefinerCFGScale: SingleMetadataHandler<ParameterCFGScale> = {
recall: (value, store) => {
store.dispatch(setRefinerCFGScale(value));
},
LabelComponent: () => <MetadataLabel i18nKey="sdxl.cfgScale" />,
i18nKey: 'sdxl.cfgScale',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterCFGScale>) => <MetadataPrimitiveValue value={value} />,
};
//#endregion RefinerCFGScale
@@ -573,7 +609,8 @@ const RefinerScheduler: SingleMetadataHandler<ParameterScheduler> = {
recall: (value, store) => {
store.dispatch(setRefinerScheduler(value));
},
LabelComponent: () => <MetadataLabel i18nKey="sdxl.scheduler" />,
i18nKey: 'sdxl.scheduler',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterScheduler>) => <MetadataPrimitiveValue value={value} />,
};
//#endregion RefinerScheduler
@@ -590,7 +627,8 @@ const RefinerPositiveAestheticScore: SingleMetadataHandler<ParameterSDXLRefinerP
recall: (value, store) => {
store.dispatch(setRefinerPositiveAestheticScore(value));
},
LabelComponent: () => <MetadataLabel i18nKey="sdxl.posAestheticScore" />,
i18nKey: 'sdxl.posAestheticScore',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterSDXLRefinerPositiveAestheticScore>) => (
<MetadataPrimitiveValue value={value} />
),
@@ -609,7 +647,8 @@ const RefinerNegativeAestheticScore: SingleMetadataHandler<ParameterSDXLRefinerN
recall: (value, store) => {
store.dispatch(setRefinerNegativeAestheticScore(value));
},
LabelComponent: () => <MetadataLabel i18nKey="sdxl.negAestheticScore" />,
i18nKey: 'sdxl.negAestheticScore',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterSDXLRefinerNegativeAestheticScore>) => (
<MetadataPrimitiveValue value={value} />
),
@@ -628,7 +667,8 @@ const RefinerDenoisingStart: SingleMetadataHandler<ParameterSDXLRefinerStart> =
recall: (value, store) => {
store.dispatch(setRefinerStart(value));
},
LabelComponent: () => <MetadataLabel i18nKey="sdxl.refinerStart" />,
i18nKey: 'sdxl.refinerStart',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterSDXLRefinerStart>) => (
<MetadataPrimitiveValue value={value} />
),
@@ -648,7 +688,8 @@ const MainModel: SingleMetadataHandler<ParameterModel> = {
recall: (value, store) => {
store.dispatch(modelSelected(value));
},
LabelComponent: () => <MetadataLabel i18nKey="metadata.model" />,
i18nKey: 'metadata.model',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterModel>) => (
<MetadataPrimitiveValue value={`${value.name} (${value.base.toUpperCase()})`} />
),
@@ -669,7 +710,8 @@ const VAEModel: SingleMetadataHandler<ParameterVAEModel> = {
recall: (value, store) => {
store.dispatch(vaeSelected(value));
},
LabelComponent: () => <MetadataLabel i18nKey="metadata.vae" />,
i18nKey: 'metadata.vae',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterVAEModel>) => (
<MetadataPrimitiveValue value={`${value.name} (${value.base.toUpperCase()})`} />
),
@@ -733,7 +775,8 @@ const LoRAs: CollectionMetadataHandler<LoRA[]> = {
store.dispatch(loraRecalled({ lora }));
}
},
LabelComponent: () => <MetadataLabel i18nKey="models.lora" />,
i18nKey: 'models.lora',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: CollectionMetadataValueProps<LoRA[]>) => (
<MetadataPrimitiveValue value={`${value.model.name} (${value.model.base.toUpperCase()}) - ${value.weight}`} />
),
@@ -763,7 +806,8 @@ const CanvasLayers: SingleMetadataHandler<CanvasMetadata> = {
}
store.dispatch(canvasMetadataRecalled(value));
},
LabelComponent: () => <MetadataLabel i18nKey="metadata.canvasV2Metadata" />,
i18nKey: 'metadata.canvasV2Metadata',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: SingleMetadataValueProps<CanvasMetadata>) => {
const { t } = useTranslation();
const count =
@@ -810,7 +854,8 @@ const RefImages: CollectionMetadataHandler<RefImageState[]> = {
const entities = [{ ...data, id: getPrefixedId('reference_image') }];
store.dispatch(refImagesRecalled({ entities, replace: false }));
},
LabelComponent: () => <MetadataLabel i18nKey="controlLayers.referenceImage" />,
i18nKey: 'controlLayers.referenceImage',
LabelComponent: MetadataLabel,
ValueComponent: ({ value }: CollectionMetadataValueProps<RefImageState[]>) => {
if (value.config.model) {
return <MetadataPrimitiveValue value={value.config.model.name} />;
@@ -862,7 +907,7 @@ export const MetadataHandlers = {
// ipAdapterToIPAdapterLayer: parseIPAdapterToIPAdapterLayer,
} as const;
const successToast = (parameter: ReactNode) => {
const successToast = (parameter: string) => {
toast({
id: 'PARAMETER_SET',
title: t('toast.parameterSet'),
@@ -871,7 +916,7 @@ const successToast = (parameter: ReactNode) => {
});
};
const failedToast = (parameter: ReactNode, message?: ReactNode) => {
const failedToast = (parameter: string, message?: string) => {
toast({
id: 'PARAMETER_NOT_SET',
title: t('toast.parameterNotSet'),
@@ -902,9 +947,9 @@ const recallByHandler = async (arg: {
if (!silent) {
if (didRecall) {
successToast(<handler.LabelComponent />);
successToast(t(handler.i18nKey));
} else {
failedToast(<handler.LabelComponent />);
failedToast(t(handler.i18nKey));
}
}
@@ -950,21 +995,24 @@ const recallByHandlers = async (arg: {
}
}
// If we recalled style prompts, and they were _different_ from the positive prompt, we need to disable prompt concat.
// We may need to update the prompt concat flag based on the recalled prompts
const positivePrompt = recalled.get(MetadataHandlers.PositivePrompt);
const negativePrompt = recalled.get(MetadataHandlers.NegativePrompt);
const positiveStylePrompt = recalled.get(MetadataHandlers.PositiveStylePrompt);
const negativeStylePrompt = recalled.get(MetadataHandlers.NegativeStylePrompt);
// The values will be undefined if the handler was not recalled
if (
(positiveStylePrompt && positiveStylePrompt !== positivePrompt) ||
(negativeStylePrompt && negativeStylePrompt !== negativePrompt)
positivePrompt !== undefined ||
negativePrompt !== undefined ||
positiveStylePrompt !== undefined ||
negativeStylePrompt !== undefined
) {
// If we set the negative style prompt or positive style prompt, we should disable prompt concat
store.dispatch(shouldConcatPromptsChanged(false));
} else {
// Otherwise, we should enable prompt concat
store.dispatch(shouldConcatPromptsChanged(true));
const concat =
(Boolean(positiveStylePrompt) && positiveStylePrompt === positivePrompt) ||
(Boolean(negativeStylePrompt) && negativeStylePrompt === negativePrompt);
store.dispatch(shouldConcatPromptsChanged(concat));
}
if (!silent) {
@@ -1003,6 +1051,28 @@ const recallPrompts = async (metadata: unknown, store: AppStore) => {
}
};
const hasMetadataByHandlers = async (arg: {
metadata: unknown;
handlers: (SingleMetadataHandler<any> | CollectionMetadataHandler<any[]>)[];
store: AppStore;
require: 'some' | 'all';
}) => {
const { metadata, handlers, store, require } = arg;
for (const handler of handlers) {
try {
await handler.parse(metadata, store);
if (require === 'some') {
return true;
}
} catch {
if (require === 'all') {
return false;
}
}
}
return true;
};
const recallDimensions = async (metadata: unknown, store: AppStore) => {
const recalled = await recallByHandlers({
metadata,
@@ -1032,6 +1102,7 @@ const recallAll = async (
};
export const MetadataUtils = {
hasMetadataByHandlers,
recallByHandler,
recallByHandlers,
recallAll,

View File

@@ -32,7 +32,7 @@ const CurrentImageNode = (props: NodeProps) => {
if (imageDTO) {
return (
<Wrapper nodeProps={props}>
<DndImage imageDTO={imageDTO} />
<DndImage imageDTO={imageDTO} borderRadius="base" />
</Wrapper>
);
}

View File

@@ -146,6 +146,7 @@ const ImageGridItemContent = memo(
return (
<>
<DndImage
borderRadius="base"
imageDTO={query.data}
asThumbnail
objectFit="contain"

View File

@@ -76,7 +76,7 @@ const ImageFieldInputComponent = (props: FieldComponentProps<ImageFieldInputInst
)}
{imageDTO && (
<>
<Flex borderRadius="base" borderWidth={1} borderStyle="solid">
<Flex borderRadius="base" borderWidth={1} borderStyle="solid" overflow="hidden">
<DndImage imageDTO={imageDTO} asThumbnail />
</Flex>
<Text

View File

@@ -14,7 +14,7 @@ const ImageOutputPreview = ({ output }: Props) => {
return null;
}
return <DndImage imageDTO={imageDTO} />;
return <DndImage imageDTO={imageDTO} borderRadius="base" />;
};
export default memo(ImageOutputPreview);

View File

@@ -137,6 +137,8 @@ const NODE_TYPE_PUBLISH_DENYLIST = [
'chatgpt_4o_edit_image',
'flux_kontext_generate_image',
'flux_kontext_edit_image',
'claude_expand_prompt',
'claude_analyze_image',
];
export const selectHasUnpublishableNodes = createSelector(selectNodes, (nodes) => {

View File

@@ -30,11 +30,14 @@ export const addFLUXFill = async ({
denoise.denoising_start = denoising_start;
denoise.denoising_end = denoising_end;
const { originalSize, scaledSize, rect } = getOriginalAndScaledSizesForOtherModes(state);
denoise.width = scaledSize.width;
denoise.height = scaledSize.height;
const params = selectParamsSlice(state);
const canvasSettings = selectCanvasSettingsSlice(state);
const { originalSize, scaledSize, rect } = getOriginalAndScaledSizesForOtherModes(state);
const rasterAdapters = manager.compositor.getVisibleAdaptersOfType('raster_layer');
const initialImage = await manager.compositor.getCompositeImageDTO(rasterAdapters, rect, {
is_intermediate: true,

View File

@@ -78,8 +78,6 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
if (generationMode !== 'txt2img') {
throw new UnsupportedGenerationModeError(t('toast.fluxKontextIncompatibleGenerationMode'));
}
guidance = 30;
}
const g = new Graph(getPrefixedId('flux_graph'));

View File

@@ -26,7 +26,7 @@ export const buildFluxKontextGraph = (arg: GraphBuilderArg): GraphBuilderReturn
assert(model.base === 'flux-kontext', 'Selected model is not a FLUX Kontext API model');
if (generationMode !== 'txt2img') {
throw new UnsupportedGenerationModeError(t('toast.imagenIncompatibleGenerationMode', { model: 'FLUX Kontext' }));
throw new UnsupportedGenerationModeError(t('toast.fluxKontextIncompatibleGenerationMode'));
}
log.debug({ generationMode, manager: manager?.id }, 'Building FLUX Kontext graph');

View File

@@ -78,7 +78,7 @@ export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
type: 'sdxl_compel_prompt',
id: getPrefixedId('neg_cond'),
prompt: prompts.negative,
style: prompts.negativeStyle,
style: prompts.useMainPromptsForStyle ? prompts.negative : prompts.negativeStyle,
});
const negCondCollect = g.addNode({
type: 'collect',

View File

@@ -1,6 +1,7 @@
import { createSelector } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectSaveAllImagesToGallery } from 'features/controlLayers/store/canvasSettingsSlice';
import {
selectImg2imgStrength,
selectMainModelConfig,
@@ -44,8 +45,11 @@ export const selectCanvasOutputFields = (state: RootState) => {
// Advanced session means working on canvas - images are not saved to gallery or added to a board.
// Simple session means working in YOLO mode - images are saved to gallery & board.
const tab = selectActiveTab(state);
const is_intermediate = tab === 'canvas';
const board = tab === 'canvas' ? undefined : getBoardField(state);
const saveAllImagesToGallery = selectSaveAllImagesToGallery(state);
// If we're on canvas and the save all images setting is enabled, save to gallery
const is_intermediate = tab === 'canvas' && !saveAllImagesToGallery;
const board = tab === 'canvas' && !saveAllImagesToGallery ? undefined : getBoardField(state);
return {
is_intermediate,

View File

@@ -18,7 +18,7 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { $onClickGoToModelManager } from 'app/store/nanostores/onClickGoToModelManager';
import { useAppSelector } from 'app/store/storeHooks';
import type { Group, PickerContextState } from 'common/components/Picker/Picker';
import { buildGroup, getRegex, isOption, Picker, usePickerContext } from 'common/components/Picker/Picker';
import { buildGroup, getRegex, isGroup, Picker, usePickerContext } from 'common/components/Picker/Picker';
import { useDisclosure } from 'common/hooks/useBoolean';
import { typedMemo } from 'common/util/typedMemo';
import { uniq } from 'es-toolkit/compat';
@@ -277,8 +277,22 @@ export const ModelPicker = typedMemo(
if (!selectedModelConfig) {
return undefined;
}
let _selectedOption: WithStarred<T> | undefined = undefined;
return options.filter(isOption).find((o) => o.key === selectedModelConfig.key);
for (const optionOrGroup of options) {
if (isGroup(optionOrGroup)) {
const result = optionOrGroup.options.find((o) => o.key === selectedModelConfig.key);
if (result) {
_selectedOption = result;
break;
}
} else if (optionOrGroup.key === selectedModelConfig.key) {
_selectedOption = optionOrGroup;
break;
}
}
return _selectedOption;
}, [options, selectedModelConfig]);
const onClose = useCallback(() => {
@@ -361,9 +375,19 @@ const optionSx: SystemStyleObject = {
cursor: 'pointer',
borderRadius: 'base',
'&[data-selected="true"]': {
bg: 'base.700',
bg: 'invokeBlue.300',
color: 'base.900',
'.extra-info': {
color: 'base.700',
},
'.picker-option': {
fontWeight: 'bold',
'&[data-is-compact="true"]': {
fontWeight: 'semibold',
},
},
'&[data-active="true"]': {
bg: 'base.650',
bg: 'invokeBlue.250',
},
},
'&[data-active="true"]': {
@@ -400,17 +424,31 @@ const PickerOptionComponent = typedMemo(
<Flex flexDir="column" gap={1} flex={1}>
<Flex gap={2} alignItems="center">
{option.starred && <Icon as={PiLinkSimple} color="invokeYellow.500" boxSize={4} />}
<Text sx={optionNameSx} data-is-compact={compactView}>
<Text className="picker-option" sx={optionNameSx} data-is-compact={compactView}>
{option.name}
</Text>
<Spacer />
{option.file_size > 0 && (
<Text variant="subtext" fontStyle="italic" noOfLines={1} flexShrink={0} overflow="visible">
<Text
className="extra-info"
variant="subtext"
fontStyle="italic"
noOfLines={1}
flexShrink={0}
overflow="visible"
>
{filesize(option.file_size)}
</Text>
)}
{option.usage_info && (
<Text variant="subtext" fontStyle="italic" noOfLines={1} flexShrink={0} overflow="visible">
<Text
className="extra-info"
variant="subtext"
fontStyle="italic"
noOfLines={1}
flexShrink={0}
overflow="visible"
>
{option.usage_info}
</Text>
)}

View File

@@ -0,0 +1,34 @@
import { Flex } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import {
createParamsSelector,
selectHasNegativePrompt,
selectModelSupportsNegativePrompt,
} from 'features/controlLayers/store/paramsSlice';
import { ParamNegativePrompt } from 'features/parameters/components/Core/ParamNegativePrompt';
import { ParamPositivePrompt } from 'features/parameters/components/Core/ParamPositivePrompt';
import { ParamSDXLNegativeStylePrompt } from 'features/sdxl/components/SDXLPrompts/ParamSDXLNegativeStylePrompt';
import { ParamSDXLPositiveStylePrompt } from 'features/sdxl/components/SDXLPrompts/ParamSDXLPositiveStylePrompt';
import { memo } from 'react';
const selectWithStylePrompts = createParamsSelector((params) => {
const isSDXL = params.model?.base === 'sdxl';
const shouldConcatPrompts = params.shouldConcatPrompts;
return isSDXL && !shouldConcatPrompts;
});
export const UpscalePrompts = memo(() => {
const withStylePrompts = useAppSelector(selectWithStylePrompts);
const modelSupportsNegativePrompt = useAppSelector(selectModelSupportsNegativePrompt);
const hasNegativePrompt = useAppSelector(selectHasNegativePrompt);
return (
<Flex flexDir="column" gap={2}>
<ParamPositivePrompt />
{withStylePrompts && <ParamSDXLPositiveStylePrompt />}
{modelSupportsNegativePrompt && hasNegativePrompt && <ParamNegativePrompt />}
{withStylePrompts && <ParamSDXLNegativeStylePrompt />}
</Flex>
);
});
UpscalePrompts.displayName = 'UpscalePrompts';

View File

@@ -0,0 +1,28 @@
import type { ButtonProps } from '@invoke-ai/ui-library';
import { Button } from '@invoke-ai/ui-library';
import { useCancelAllExceptCurrentQueueItemDialog } from 'features/queue/components/CancelAllExceptCurrentQueueItemConfirmationAlertDialog';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiXCircle } from 'react-icons/pi';
export const CancelAllExceptCurrentButton = memo((props: ButtonProps) => {
const { t } = useTranslation();
const api = useCancelAllExceptCurrentQueueItemDialog();
return (
<Button
isDisabled={api.isDisabled}
isLoading={api.isLoading}
aria-label={t('queue.clear')}
tooltip={t('queue.cancelAllExceptCurrentTooltip')}
leftIcon={<PiXCircle />}
colorScheme="error"
onClick={api.openDialog}
{...props}
>
{t('queue.clear')}
</Button>
);
});
CancelAllExceptCurrentButton.displayName = 'CancelAllExceptCurrentButton';

View File

@@ -0,0 +1,25 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useCancelAllExceptCurrentQueueItemDialog } from 'features/queue/components/CancelAllExceptCurrentQueueItemConfirmationAlertDialog';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiXCircle } from 'react-icons/pi';
export const CancelAllExceptCurrentIconButton = memo(() => {
const { t } = useTranslation();
const api = useCancelAllExceptCurrentQueueItemDialog();
return (
<IconButton
size="lg"
isDisabled={api.isDisabled}
isLoading={api.isLoading}
aria-label={t('queue.clear')}
tooltip={t('queue.cancelAllExceptCurrentTooltip')}
icon={<PiXCircle />}
colorScheme="error"
onClick={api.openDialog}
/>
);
});
CancelAllExceptCurrentIconButton.displayName = 'CancelAllExceptCurrentIconButton';

View File

@@ -7,7 +7,7 @@ import { useTranslation } from 'react-i18next';
const [useCancelAllExceptCurrentQueueItemConfirmationAlertDialog] = buildUseBoolean(false);
const useCancelAllExceptCurrentQueueItemDialog = () => {
export const useCancelAllExceptCurrentQueueItemDialog = () => {
const dialog = useCancelAllExceptCurrentQueueItemConfirmationAlertDialog();
const cancelAllExceptCurrentQueueItem = useCancelAllExceptCurrentQueueItem();

View File

@@ -0,0 +1,29 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useCancelCurrentQueueItem } from 'features/queue/hooks/useCancelCurrentQueueItem';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
export const CancelCurrentQueueItemIconButton = memo(() => {
const { t } = useTranslation();
const cancelCurrentQueueItem = useCancelCurrentQueueItem();
const cancelCurrentQueueItemWithToast = useCallback(() => {
cancelCurrentQueueItem.trigger({ withToast: true });
}, [cancelCurrentQueueItem]);
return (
<IconButton
size="lg"
onClick={cancelCurrentQueueItemWithToast}
isDisabled={cancelCurrentQueueItem.isDisabled}
isLoading={cancelCurrentQueueItem.isLoading}
aria-label={t('queue.cancel')}
tooltip={t('queue.cancelTooltip')}
icon={<PiXBold />}
colorScheme="error"
/>
);
});
CancelCurrentQueueItemIconButton.displayName = 'CancelCurrentQueueItemIconButton';

Some files were not shown because too many files have changed in this diff Show More