mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-22 12:58:22 -05:00
Compare commits
117 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
82fb897b62 | ||
|
|
192b00d969 | ||
|
|
7bb25ef1b4 | ||
|
|
62f52c74a8 | ||
|
|
97439c1daa | ||
|
|
b23bff1b53 | ||
|
|
d9a1efbabf | ||
|
|
d4e903ee2d | ||
|
|
bb3e5d16d8 | ||
|
|
e62d3f01a8 | ||
|
|
757ecdbf82 | ||
|
|
694c85b041 | ||
|
|
988d7ba24c | ||
|
|
ac981879ef | ||
|
|
fc71849c24 | ||
|
|
a19aa3b032 | ||
|
|
ef4d5d7377 | ||
|
|
6b0dfd8427 | ||
|
|
471c010217 | ||
|
|
b1193022f7 | ||
|
|
2152ca092c | ||
|
|
ccc62ba56d | ||
|
|
9cf82de8c5 | ||
|
|
aced349152 | ||
|
|
0d67ee6548 | ||
|
|
03c21d1607 | ||
|
|
752e8db1f5 | ||
|
|
85fc861dd9 | ||
|
|
458cbfd874 | ||
|
|
04331c070a | ||
|
|
632ddf0cb4 | ||
|
|
2b193ff416 | ||
|
|
96ee394f9e | ||
|
|
0badc80c0c | ||
|
|
78e6cbf96e | ||
|
|
0b969a661b | ||
|
|
6fe47ec9f8 | ||
|
|
3850dd61f8 | ||
|
|
75520eaf0f | ||
|
|
10e88c58c1 | ||
|
|
30ed4dbd92 | ||
|
|
ed9c090f33 | ||
|
|
d29f65ed22 | ||
|
|
2062ec8ac0 | ||
|
|
49e818338a | ||
|
|
1caab2b9c4 | ||
|
|
50079ea349 | ||
|
|
fffa1b24c4 | ||
|
|
a6d6170387 | ||
|
|
e5fceb0448 | ||
|
|
059baf5b29 | ||
|
|
1be8a9a310 | ||
|
|
7adc33e04d | ||
|
|
7f2dd22d47 | ||
|
|
bb50f4b8a2 | ||
|
|
a48958e0d4 | ||
|
|
e3a1e9af53 | ||
|
|
c6fe11c42f | ||
|
|
4eb1bd67df | ||
|
|
c376f914d2 | ||
|
|
b5d1c47ef7 | ||
|
|
004a52ca65 | ||
|
|
b1d5a51ddf | ||
|
|
2b2498eaa1 | ||
|
|
10dda4440e | ||
|
|
98f78abefa | ||
|
|
cc93fa270f | ||
|
|
014b27680f | ||
|
|
c3d8f875de | ||
|
|
79f9dc6e4a | ||
|
|
6e1c0c1105 | ||
|
|
0362524040 | ||
|
|
dc6656459b | ||
|
|
3ea1b97f6f | ||
|
|
a7c7405ccc | ||
|
|
c391f1117a | ||
|
|
b1e2cb8401 | ||
|
|
db6af134b7 | ||
|
|
7e6cffb00c | ||
|
|
5b187bcb00 | ||
|
|
0843d609a3 | ||
|
|
95bd9cef18 | ||
|
|
931d6521f6 | ||
|
|
e37665ff59 | ||
|
|
56857fbbe6 | ||
|
|
43cfb8a574 | ||
|
|
05b1682d15 | ||
|
|
69a08ee7f2 | ||
|
|
18212c7d8a | ||
|
|
7de26f8e69 | ||
|
|
0652b12a6f | ||
|
|
43a361a00f | ||
|
|
cf68ad9cbc | ||
|
|
ec02a39325 | ||
|
|
e52d7a05c2 | ||
|
|
c9d4e2b761 | ||
|
|
ac26aa9508 | ||
|
|
9ff6ada15b | ||
|
|
e81a115169 | ||
|
|
52827807de | ||
|
|
b631de4cb5 | ||
|
|
099ebdbc37 | ||
|
|
4de6549be9 | ||
|
|
368be34949 | ||
|
|
5baa4bd916 | ||
|
|
4229377532 | ||
|
|
2610772ffd | ||
|
|
193de6a8f2 | ||
|
|
7ea343c787 | ||
|
|
12179dabba | ||
|
|
ef135f9923 | ||
|
|
e6c67cc00f | ||
|
|
179b988148 | ||
|
|
d913a3c85b | ||
|
|
e79525c40c | ||
|
|
f409f913ac | ||
|
|
7a79f61d4c |
@@ -72,7 +72,7 @@ async def upload_image(
|
||||
resize_to: Optional[str] = Body(
|
||||
default=None,
|
||||
description=f"Dimensions to resize the image to, must be stringified tuple of 2 integers. Max total pixel count: {ResizeToDimensions.MAX_SIZE}",
|
||||
example='"[1024,1024]"',
|
||||
examples=['"[1024,1024]"'],
|
||||
),
|
||||
metadata: Optional[str] = Body(
|
||||
default=None,
|
||||
|
||||
@@ -292,7 +292,7 @@ async def get_hugging_face_models(
|
||||
)
|
||||
async def update_model_record(
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)],
|
||||
changes: Annotated[ModelRecordChanges, Body(description="Model config", examples=[example_model_input])],
|
||||
) -> AnyModelConfig:
|
||||
"""Update a model's config."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
@@ -450,7 +450,7 @@ async def install_model(
|
||||
access_token: Optional[str] = Query(description="access token for the remote resource", default=None),
|
||||
config: ModelRecordChanges = Body(
|
||||
description="Object containing fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
example={"name": "string", "description": "string"},
|
||||
examples=[{"name": "string", "description": "string"}],
|
||||
),
|
||||
) -> ModelInstallJob:
|
||||
"""Install a model using a string identifier.
|
||||
|
||||
@@ -14,15 +14,14 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._conn = db.conn
|
||||
self._db = db
|
||||
|
||||
def add_image_to_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO board_images (board_id, image_name)
|
||||
@@ -31,17 +30,12 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
""",
|
||||
(board_id, image_name, board_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM board_images
|
||||
@@ -49,10 +43,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
|
||||
def get_images_for_board(
|
||||
self,
|
||||
@@ -60,27 +50,26 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
# TODO: this isn't paginated yet?
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
FROM board_images
|
||||
INNER JOIN images ON board_images.image_name = images.image_name
|
||||
WHERE board_images.board_id = ?
|
||||
ORDER BY board_images.updated_at DESC;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
images = [deserialize_image_record(dict(r)) for r in result]
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
FROM board_images
|
||||
INNER JOIN images ON board_images.image_name = images.image_name
|
||||
WHERE board_images.board_id = ?
|
||||
ORDER BY board_images.updated_at DESC;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
images = [deserialize_image_record(dict(r)) for r in result]
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*) FROM images WHERE 1=1;
|
||||
"""
|
||||
)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*) FROM images WHERE 1=1;
|
||||
"""
|
||||
)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
|
||||
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
|
||||
|
||||
@@ -90,56 +79,55 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
categories: list[ImageCategory] | None,
|
||||
is_intermediate: bool | None,
|
||||
) -> list[str]:
|
||||
params: list[str | bool] = []
|
||||
with self._db.transaction() as cursor:
|
||||
params: list[str | bool] = []
|
||||
|
||||
# Base query is a join between images and board_images
|
||||
stmt = """
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
# Base query is a join between images and board_images
|
||||
stmt = """
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
# Handle board_id filter
|
||||
if board_id == "none":
|
||||
stmt += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
else:
|
||||
stmt += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
params.append(board_id)
|
||||
# Handle board_id filter
|
||||
if board_id == "none":
|
||||
stmt += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
else:
|
||||
stmt += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
params.append(board_id)
|
||||
|
||||
# Add the category filter
|
||||
if categories is not None:
|
||||
# Convert the enum values to unique list of strings
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
stmt += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
# Add the category filter
|
||||
if categories is not None:
|
||||
# Convert the enum values to unique list of strings
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
stmt += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
|
||||
# Unpack the included categories into the query params
|
||||
for c in category_strings:
|
||||
params.append(c)
|
||||
# Unpack the included categories into the query params
|
||||
for c in category_strings:
|
||||
params.append(c)
|
||||
|
||||
# Add the is_intermediate filter
|
||||
if is_intermediate is not None:
|
||||
stmt += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
params.append(is_intermediate)
|
||||
# Add the is_intermediate filter
|
||||
if is_intermediate is not None:
|
||||
stmt += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
params.append(is_intermediate)
|
||||
|
||||
# Put a ring on it
|
||||
stmt += ";"
|
||||
# Put a ring on it
|
||||
stmt += ";"
|
||||
|
||||
# Execute the query
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(stmt, params)
|
||||
cursor.execute(stmt, params)
|
||||
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
image_names = [r[0] for r in result]
|
||||
return image_names
|
||||
|
||||
@@ -147,31 +135,31 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Optional[str]:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT board_id
|
||||
FROM board_images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT board_id
|
||||
FROM board_images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
if result is None:
|
||||
return None
|
||||
return cast(str, result[0])
|
||||
|
||||
def get_image_count_for_board(self, board_id: str) -> int:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM board_images
|
||||
INNER JOIN images ON board_images.image_name = images.image_name
|
||||
WHERE images.is_intermediate = FALSE
|
||||
AND board_images.board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM board_images
|
||||
INNER JOIN images ON board_images.image_name = images.image_name
|
||||
WHERE images.is_intermediate = FALSE
|
||||
AND board_images.board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
return count
|
||||
|
||||
@@ -20,61 +20,57 @@ from invokeai.app.util.misc import uuid_string
|
||||
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._conn = db.conn
|
||||
self._db = db
|
||||
|
||||
def delete(self, board_id: str) -> None:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM boards
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordDeleteException from e
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM boards
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
except Exception as e:
|
||||
raise BoardRecordDeleteException from e
|
||||
|
||||
def save(
|
||||
self,
|
||||
board_name: str,
|
||||
) -> BoardRecord:
|
||||
try:
|
||||
board_id = uuid_string()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO boards (board_id, board_name)
|
||||
VALUES (?, ?);
|
||||
""",
|
||||
(board_id, board_name),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordSaveException from e
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
board_id = uuid_string()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO boards (board_id, board_name)
|
||||
VALUES (?, ?);
|
||||
""",
|
||||
(board_id, board_name),
|
||||
)
|
||||
except sqlite3.Error as e:
|
||||
raise BoardRecordSaveException from e
|
||||
return self.get(board_id)
|
||||
|
||||
def get(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> BoardRecord:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM boards
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM boards
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
raise BoardRecordNotFoundException from e
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
raise BoardRecordNotFoundException from e
|
||||
if result is None:
|
||||
raise BoardRecordNotFoundException
|
||||
return BoardRecord(**dict(result))
|
||||
@@ -84,45 +80,43 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
board_id: str,
|
||||
changes: BoardChanges,
|
||||
) -> BoardRecord:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
# Change the name of a board
|
||||
if changes.board_name is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE boards
|
||||
SET board_name = ?
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(changes.board_name, board_id),
|
||||
)
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
# Change the name of a board
|
||||
if changes.board_name is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE boards
|
||||
SET board_name = ?
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(changes.board_name, board_id),
|
||||
)
|
||||
|
||||
# Change the cover image of a board
|
||||
if changes.cover_image_name is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE boards
|
||||
SET cover_image_name = ?
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(changes.cover_image_name, board_id),
|
||||
)
|
||||
# Change the cover image of a board
|
||||
if changes.cover_image_name is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE boards
|
||||
SET cover_image_name = ?
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(changes.cover_image_name, board_id),
|
||||
)
|
||||
|
||||
# Change the archived status of a board
|
||||
if changes.archived is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE boards
|
||||
SET archived = ?
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(changes.archived, board_id),
|
||||
)
|
||||
# Change the archived status of a board
|
||||
if changes.archived is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE boards
|
||||
SET archived = ?
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(changes.archived, board_id),
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordSaveException from e
|
||||
except sqlite3.Error as e:
|
||||
raise BoardRecordSaveException from e
|
||||
return self.get(board_id)
|
||||
|
||||
def get_many(
|
||||
@@ -133,78 +127,77 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
limit: int = 10,
|
||||
include_archived: bool = False,
|
||||
) -> OffsetPaginatedResults[BoardRecord]:
|
||||
cursor = self._conn.cursor()
|
||||
|
||||
# Build base query
|
||||
base_query = """
|
||||
SELECT *
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
ORDER BY {order_by} {direction}
|
||||
LIMIT ? OFFSET ?;
|
||||
"""
|
||||
|
||||
# Determine archived filter condition
|
||||
archived_filter = "" if include_archived else "WHERE archived = 0"
|
||||
|
||||
final_query = base_query.format(
|
||||
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
|
||||
)
|
||||
|
||||
# Execute query to fetch boards
|
||||
cursor.execute(final_query, (limit, offset))
|
||||
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
boards = [deserialize_board_record(dict(r)) for r in result]
|
||||
|
||||
# Determine count query
|
||||
if include_archived:
|
||||
count_query = """
|
||||
SELECT COUNT(*)
|
||||
FROM boards;
|
||||
"""
|
||||
else:
|
||||
count_query = """
|
||||
SELECT COUNT(*)
|
||||
with self._db.transaction() as cursor:
|
||||
# Build base query
|
||||
base_query = """
|
||||
SELECT *
|
||||
FROM boards
|
||||
WHERE archived = 0;
|
||||
{archived_filter}
|
||||
ORDER BY {order_by} {direction}
|
||||
LIMIT ? OFFSET ?;
|
||||
"""
|
||||
|
||||
# Execute count query
|
||||
cursor.execute(count_query)
|
||||
# Determine archived filter condition
|
||||
archived_filter = "" if include_archived else "WHERE archived = 0"
|
||||
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
final_query = base_query.format(
|
||||
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
|
||||
)
|
||||
|
||||
# Execute query to fetch boards
|
||||
cursor.execute(final_query, (limit, offset))
|
||||
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
boards = [deserialize_board_record(dict(r)) for r in result]
|
||||
|
||||
# Determine count query
|
||||
if include_archived:
|
||||
count_query = """
|
||||
SELECT COUNT(*)
|
||||
FROM boards;
|
||||
"""
|
||||
else:
|
||||
count_query = """
|
||||
SELECT COUNT(*)
|
||||
FROM boards
|
||||
WHERE archived = 0;
|
||||
"""
|
||||
|
||||
# Execute count query
|
||||
cursor.execute(count_query)
|
||||
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
|
||||
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
|
||||
|
||||
def get_all(
|
||||
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
|
||||
) -> list[BoardRecord]:
|
||||
cursor = self._conn.cursor()
|
||||
if order_by == BoardRecordOrderBy.Name:
|
||||
base_query = """
|
||||
SELECT *
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
ORDER BY LOWER(board_name) {direction}
|
||||
"""
|
||||
else:
|
||||
base_query = """
|
||||
SELECT *
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
ORDER BY {order_by} {direction}
|
||||
"""
|
||||
with self._db.transaction() as cursor:
|
||||
if order_by == BoardRecordOrderBy.Name:
|
||||
base_query = """
|
||||
SELECT *
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
ORDER BY LOWER(board_name) {direction}
|
||||
"""
|
||||
else:
|
||||
base_query = """
|
||||
SELECT *
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
ORDER BY {order_by} {direction}
|
||||
"""
|
||||
|
||||
archived_filter = "" if include_archived else "WHERE archived = 0"
|
||||
archived_filter = "" if include_archived else "WHERE archived = 0"
|
||||
|
||||
final_query = base_query.format(
|
||||
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
|
||||
)
|
||||
final_query = base_query.format(
|
||||
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
|
||||
)
|
||||
|
||||
cursor.execute(final_query)
|
||||
cursor.execute(final_query)
|
||||
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
boards = [deserialize_board_record(dict(r)) for r in result]
|
||||
|
||||
return boards
|
||||
|
||||
@@ -8,6 +8,7 @@ import time
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from queue import Empty, PriorityQueue
|
||||
from shutil import disk_usage
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set
|
||||
|
||||
import requests
|
||||
@@ -335,6 +336,14 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
|
||||
assert job.download_path
|
||||
|
||||
free_space = disk_usage(job.download_path.parent).free
|
||||
GB = 2**30
|
||||
self._logger.debug(f"Download is {job.total_bytes / GB:.2f} GB of {free_space / GB:.2f} GB free.")
|
||||
if free_space < job.total_bytes:
|
||||
raise RuntimeError(
|
||||
f"Free disk space {free_space / GB:.2f} GB is not enough for download of {job.total_bytes / GB:.2f} GB."
|
||||
)
|
||||
|
||||
# Don't clobber an existing file. See commit 82c2c85202f88c6d24ff84710f297cfc6ae174af
|
||||
# for code that instead resumes an interrupted download.
|
||||
if job.download_path.exists():
|
||||
|
||||
@@ -24,22 +24,22 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._conn = db.conn
|
||||
self._db = db
|
||||
|
||||
def get(self, image_name: str) -> ImageRecord:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT {IMAGE_DTO_COLS} FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT {IMAGE_DTO_COLS} FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
result = cast(Optional[sqlite3.Row], cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
raise ImageRecordNotFoundException from e
|
||||
result = cast(Optional[sqlite3.Row], cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
raise ImageRecordNotFoundException from e
|
||||
|
||||
if not result:
|
||||
raise ImageRecordNotFoundException
|
||||
@@ -47,17 +47,20 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
return deserialize_image_record(dict(result))
|
||||
|
||||
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
result = cast(Optional[sqlite3.Row], cursor.fetchone())
|
||||
result = cast(Optional[sqlite3.Row], cursor.fetchone())
|
||||
|
||||
except sqlite3.Error as e:
|
||||
raise ImageRecordNotFoundException from e
|
||||
|
||||
if not result:
|
||||
raise ImageRecordNotFoundException
|
||||
@@ -65,64 +68,60 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
as_dict = dict(result)
|
||||
metadata_raw = cast(Optional[str], as_dict.get("metadata", None))
|
||||
return MetadataFieldValidator.validate_json(metadata_raw) if metadata_raw is not None else None
|
||||
except sqlite3.Error as e:
|
||||
raise ImageRecordNotFoundException from e
|
||||
|
||||
def update(
|
||||
self,
|
||||
image_name: str,
|
||||
changes: ImageRecordChanges,
|
||||
) -> None:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
# Change the category of the image
|
||||
if changes.image_category is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET image_category = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.image_category, image_name),
|
||||
)
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
# Change the category of the image
|
||||
if changes.image_category is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET image_category = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.image_category, image_name),
|
||||
)
|
||||
|
||||
# Change the session associated with the image
|
||||
if changes.session_id is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET session_id = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.session_id, image_name),
|
||||
)
|
||||
# Change the session associated with the image
|
||||
if changes.session_id is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET session_id = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.session_id, image_name),
|
||||
)
|
||||
|
||||
# Change the image's `is_intermediate`` flag
|
||||
if changes.is_intermediate is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET is_intermediate = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.is_intermediate, image_name),
|
||||
)
|
||||
# Change the image's `is_intermediate`` flag
|
||||
if changes.is_intermediate is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET is_intermediate = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.is_intermediate, image_name),
|
||||
)
|
||||
|
||||
# Change the image's `starred`` state
|
||||
if changes.starred is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET starred = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.starred, image_name),
|
||||
)
|
||||
# Change the image's `starred`` state
|
||||
if changes.starred is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET starred = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.starred, image_name),
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordSaveException from e
|
||||
except sqlite3.Error as e:
|
||||
raise ImageRecordSaveException from e
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
@@ -136,170 +135,162 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
cursor = self._conn.cursor()
|
||||
|
||||
# Manually build two queries - one for the count, one for the records
|
||||
count_query = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
images_query = f"""--sql
|
||||
SELECT {IMAGE_DTO_COLS}
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
query_conditions = ""
|
||||
query_params: list[Union[int, str, bool]] = []
|
||||
|
||||
if image_origin is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.image_origin = ?
|
||||
"""
|
||||
query_params.append(image_origin.value)
|
||||
|
||||
if categories is not None:
|
||||
# Convert the enum values to unique list of strings
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
|
||||
query_conditions += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
with self._db.transaction() as cursor:
|
||||
# Manually build two queries - one for the count, one for the records
|
||||
count_query = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
# Unpack the included categories into the query params
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
|
||||
if is_intermediate is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
images_query = f"""--sql
|
||||
SELECT {IMAGE_DTO_COLS}
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
query_params.append(is_intermediate)
|
||||
query_conditions = ""
|
||||
query_params: list[Union[int, str, bool]] = []
|
||||
|
||||
# board_id of "none" is reserved for images without a board
|
||||
if board_id == "none":
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
elif board_id is not None:
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
query_params.append(board_id)
|
||||
if image_origin is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.image_origin = ?
|
||||
"""
|
||||
query_params.append(image_origin.value)
|
||||
|
||||
# Search term condition
|
||||
if search_term:
|
||||
query_conditions += """--sql
|
||||
AND (
|
||||
images.metadata LIKE ?
|
||||
OR images.created_at LIKE ?
|
||||
)
|
||||
"""
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
if categories is not None:
|
||||
# Convert the enum values to unique list of strings
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
|
||||
if starred_first:
|
||||
query_pagination = f"""--sql
|
||||
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
|
||||
"""
|
||||
else:
|
||||
query_pagination = f"""--sql
|
||||
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
|
||||
"""
|
||||
query_conditions += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
|
||||
# Final images query with pagination
|
||||
images_query += query_conditions + query_pagination + ";"
|
||||
# Add all the parameters
|
||||
images_params = query_params.copy()
|
||||
# Add the pagination parameters
|
||||
images_params.extend([limit, offset])
|
||||
# Unpack the included categories into the query params
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
|
||||
# Build the list of images, deserializing each row
|
||||
cursor.execute(images_query, images_params)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
images = [deserialize_image_record(dict(r)) for r in result]
|
||||
if is_intermediate is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
|
||||
# Set up and execute the count query, without pagination
|
||||
count_query += query_conditions + ";"
|
||||
count_params = query_params.copy()
|
||||
cursor.execute(count_query, count_params)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
query_params.append(is_intermediate)
|
||||
|
||||
# board_id of "none" is reserved for images without a board
|
||||
if board_id == "none":
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
elif board_id is not None:
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
query_params.append(board_id)
|
||||
|
||||
# Search term condition
|
||||
if search_term:
|
||||
query_conditions += """--sql
|
||||
AND (
|
||||
images.metadata LIKE ?
|
||||
OR images.created_at LIKE ?
|
||||
)
|
||||
"""
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
|
||||
if starred_first:
|
||||
query_pagination = f"""--sql
|
||||
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
|
||||
"""
|
||||
else:
|
||||
query_pagination = f"""--sql
|
||||
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
|
||||
"""
|
||||
|
||||
# Final images query with pagination
|
||||
images_query += query_conditions + query_pagination + ";"
|
||||
# Add all the parameters
|
||||
images_params = query_params.copy()
|
||||
# Add the pagination parameters
|
||||
images_params.extend([limit, offset])
|
||||
|
||||
# Build the list of images, deserializing each row
|
||||
cursor.execute(images_query, images_params)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
|
||||
images = [deserialize_image_record(dict(r)) for r in result]
|
||||
|
||||
# Set up and execute the count query, without pagination
|
||||
count_query += query_conditions + ";"
|
||||
count_params = query_params.copy()
|
||||
cursor.execute(count_query, count_params)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
|
||||
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
|
||||
|
||||
def delete(self, image_name: str) -> None:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordDeleteException from e
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
except sqlite3.Error as e:
|
||||
raise ImageRecordDeleteException from e
|
||||
|
||||
def delete_many(self, image_names: list[str]) -> None:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
placeholders = ",".join("?" for _ in image_names)
|
||||
|
||||
placeholders = ",".join("?" for _ in image_names)
|
||||
# Construct the SQLite query with the placeholders
|
||||
query = f"DELETE FROM images WHERE image_name IN ({placeholders})"
|
||||
|
||||
# Construct the SQLite query with the placeholders
|
||||
query = f"DELETE FROM images WHERE image_name IN ({placeholders})"
|
||||
# Execute the query with the list of IDs as parameters
|
||||
cursor.execute(query, image_names)
|
||||
|
||||
# Execute the query with the list of IDs as parameters
|
||||
cursor.execute(query, image_names)
|
||||
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordDeleteException from e
|
||||
except sqlite3.Error as e:
|
||||
raise ImageRecordDeleteException from e
|
||||
|
||||
def get_intermediates_count(self) -> int:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*) FROM images
|
||||
WHERE is_intermediate = TRUE;
|
||||
"""
|
||||
)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
self._conn.commit()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*) FROM images
|
||||
WHERE is_intermediate = TRUE;
|
||||
"""
|
||||
)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
return count
|
||||
|
||||
def delete_intermediates(self) -> list[str]:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT image_name FROM images
|
||||
WHERE is_intermediate = TRUE;
|
||||
"""
|
||||
)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
image_names = [r[0] for r in result]
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM images
|
||||
WHERE is_intermediate = TRUE;
|
||||
"""
|
||||
)
|
||||
self._conn.commit()
|
||||
return image_names
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordDeleteException from e
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT image_name FROM images
|
||||
WHERE is_intermediate = TRUE;
|
||||
"""
|
||||
)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
image_names = [r[0] for r in result]
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM images
|
||||
WHERE is_intermediate = TRUE;
|
||||
"""
|
||||
)
|
||||
except sqlite3.Error as e:
|
||||
raise ImageRecordDeleteException from e
|
||||
return image_names
|
||||
|
||||
def save(
|
||||
self,
|
||||
@@ -315,73 +306,71 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
node_id: Optional[str] = None,
|
||||
metadata: Optional[str] = None,
|
||||
) -> datetime:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO images (
|
||||
image_name,
|
||||
image_origin,
|
||||
image_category,
|
||||
width,
|
||||
height,
|
||||
node_id,
|
||||
session_id,
|
||||
metadata,
|
||||
is_intermediate,
|
||||
starred,
|
||||
has_workflow
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
image_name,
|
||||
image_origin.value,
|
||||
image_category.value,
|
||||
width,
|
||||
height,
|
||||
node_id,
|
||||
session_id,
|
||||
metadata,
|
||||
is_intermediate,
|
||||
starred,
|
||||
has_workflow,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO images (
|
||||
image_name,
|
||||
image_origin,
|
||||
image_category,
|
||||
width,
|
||||
height,
|
||||
node_id,
|
||||
session_id,
|
||||
metadata,
|
||||
is_intermediate,
|
||||
starred,
|
||||
has_workflow
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
image_name,
|
||||
image_origin.value,
|
||||
image_category.value,
|
||||
width,
|
||||
height,
|
||||
node_id,
|
||||
session_id,
|
||||
metadata,
|
||||
is_intermediate,
|
||||
starred,
|
||||
has_workflow,
|
||||
),
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT created_at
|
||||
FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT created_at
|
||||
FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
created_at = datetime.fromisoformat(cursor.fetchone()[0])
|
||||
created_at = datetime.fromisoformat(cursor.fetchone()[0])
|
||||
|
||||
return created_at
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordSaveException from e
|
||||
except sqlite3.Error as e:
|
||||
raise ImageRecordSaveException from e
|
||||
return created_at
|
||||
|
||||
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
FROM images
|
||||
JOIN board_images ON images.image_name = board_images.image_name
|
||||
WHERE board_images.board_id = ?
|
||||
AND images.is_intermediate = FALSE
|
||||
ORDER BY images.starred DESC, images.created_at DESC
|
||||
LIMIT 1;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
FROM images
|
||||
JOIN board_images ON images.image_name = board_images.image_name
|
||||
WHERE board_images.board_id = ?
|
||||
AND images.is_intermediate = FALSE
|
||||
ORDER BY images.starred DESC, images.created_at DESC
|
||||
LIMIT 1;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
|
||||
result = cast(Optional[sqlite3.Row], cursor.fetchone())
|
||||
result = cast(Optional[sqlite3.Row], cursor.fetchone())
|
||||
|
||||
if result is None:
|
||||
return None
|
||||
@@ -398,85 +387,84 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> ImageNamesResult:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
# Build query conditions (reused for both starred count and image names queries)
|
||||
query_conditions = ""
|
||||
query_params: list[Union[int, str, bool]] = []
|
||||
|
||||
# Build query conditions (reused for both starred count and image names queries)
|
||||
query_conditions = ""
|
||||
query_params: list[Union[int, str, bool]] = []
|
||||
if image_origin is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.image_origin = ?
|
||||
"""
|
||||
query_params.append(image_origin.value)
|
||||
|
||||
if image_origin is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.image_origin = ?
|
||||
"""
|
||||
query_params.append(image_origin.value)
|
||||
if categories is not None:
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
query_conditions += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
|
||||
if categories is not None:
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
query_conditions += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
if is_intermediate is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
query_params.append(is_intermediate)
|
||||
|
||||
if is_intermediate is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
query_params.append(is_intermediate)
|
||||
if board_id == "none":
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
elif board_id is not None:
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
query_params.append(board_id)
|
||||
|
||||
if board_id == "none":
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
elif board_id is not None:
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
query_params.append(board_id)
|
||||
if search_term:
|
||||
query_conditions += """--sql
|
||||
AND (
|
||||
images.metadata LIKE ?
|
||||
OR images.created_at LIKE ?
|
||||
)
|
||||
"""
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
|
||||
if search_term:
|
||||
query_conditions += """--sql
|
||||
AND (
|
||||
images.metadata LIKE ?
|
||||
OR images.created_at LIKE ?
|
||||
)
|
||||
"""
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
# Get starred count if starred_first is enabled
|
||||
starred_count = 0
|
||||
if starred_first:
|
||||
starred_count_query = f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE images.starred = TRUE AND (1=1{query_conditions})
|
||||
"""
|
||||
cursor.execute(starred_count_query, query_params)
|
||||
starred_count = cast(int, cursor.fetchone()[0])
|
||||
|
||||
# Get starred count if starred_first is enabled
|
||||
starred_count = 0
|
||||
if starred_first:
|
||||
starred_count_query = f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE images.starred = TRUE AND (1=1{query_conditions})
|
||||
"""
|
||||
cursor.execute(starred_count_query, query_params)
|
||||
starred_count = cast(int, cursor.fetchone()[0])
|
||||
# Get all image names with proper ordering
|
||||
if starred_first:
|
||||
names_query = f"""--sql
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1{query_conditions}
|
||||
ORDER BY images.starred DESC, images.created_at {order_dir.value}
|
||||
"""
|
||||
else:
|
||||
names_query = f"""--sql
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1{query_conditions}
|
||||
ORDER BY images.created_at {order_dir.value}
|
||||
"""
|
||||
|
||||
# Get all image names with proper ordering
|
||||
if starred_first:
|
||||
names_query = f"""--sql
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1{query_conditions}
|
||||
ORDER BY images.starred DESC, images.created_at {order_dir.value}
|
||||
"""
|
||||
else:
|
||||
names_query = f"""--sql
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1{query_conditions}
|
||||
ORDER BY images.created_at {order_dir.value}
|
||||
"""
|
||||
|
||||
cursor.execute(names_query, query_params)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
cursor.execute(names_query, query_params)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
image_names = [row[0] for row in result]
|
||||
|
||||
return ImageNamesResult(image_names=image_names, starred_count=starred_count, total_count=len(image_names))
|
||||
|
||||
@@ -78,11 +78,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
self._db = db
|
||||
self._logger = logger
|
||||
|
||||
@property
|
||||
def db(self) -> SqliteDatabase:
|
||||
"""Return the underlying database."""
|
||||
return self._db
|
||||
|
||||
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model to the database.
|
||||
@@ -93,38 +88,33 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
|
||||
"""
|
||||
try:
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO models (
|
||||
id,
|
||||
config
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
config.key,
|
||||
config.model_dump_json(),
|
||||
),
|
||||
)
|
||||
self._db.conn.commit()
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO models (
|
||||
id,
|
||||
config
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
config.key,
|
||||
config.model_dump_json(),
|
||||
),
|
||||
)
|
||||
|
||||
except sqlite3.IntegrityError as e:
|
||||
self._db.conn.rollback()
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
if "models.path" in str(e):
|
||||
msg = f"A model with path '{config.path}' is already installed"
|
||||
elif "models.name" in str(e):
|
||||
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
|
||||
except sqlite3.IntegrityError as e:
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
if "models.path" in str(e):
|
||||
msg = f"A model with path '{config.path}' is already installed"
|
||||
elif "models.name" in str(e):
|
||||
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
|
||||
else:
|
||||
msg = f"A model with key '{config.key}' is already installed"
|
||||
raise DuplicateModelException(msg) from e
|
||||
else:
|
||||
msg = f"A model with key '{config.key}' is already installed"
|
||||
raise DuplicateModelException(msg) from e
|
||||
else:
|
||||
raise e
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
raise e
|
||||
|
||||
return self.get_model(config.key)
|
||||
|
||||
@@ -136,8 +126,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
Can raise an UnknownModelException
|
||||
"""
|
||||
try:
|
||||
cursor = self._db.conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM models
|
||||
@@ -147,22 +136,17 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
)
|
||||
if cursor.rowcount == 0:
|
||||
raise UnknownModelException("model not found")
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
|
||||
record = self.get_model(key)
|
||||
with self._db.transaction() as cursor:
|
||||
record = self.get_model(key)
|
||||
|
||||
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
|
||||
for field_name in changes.model_fields_set:
|
||||
setattr(record, field_name, getattr(changes, field_name))
|
||||
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
|
||||
for field_name in changes.model_fields_set:
|
||||
setattr(record, field_name, getattr(changes, field_name))
|
||||
|
||||
json_serialized = record.model_dump_json()
|
||||
json_serialized = record.model_dump_json()
|
||||
|
||||
try:
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE models
|
||||
@@ -174,10 +158,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
)
|
||||
if cursor.rowcount == 0:
|
||||
raise UnknownModelException("model not found")
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_model(key)
|
||||
|
||||
@@ -189,30 +169,30 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
Exceptions: UnknownModelException
|
||||
"""
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
rows = cursor.fetchone()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
rows = cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownModelException("model not found")
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||
return model
|
||||
|
||||
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
rows = cursor.fetchone()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
rows = cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownModelException("model not found")
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||
@@ -224,15 +204,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
:param key: Unique key for the model to be deleted
|
||||
"""
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
select count(*) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
count = cursor.fetchone()[0]
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
select count(*) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
count = cursor.fetchone()[0]
|
||||
return count > 0
|
||||
|
||||
def search_by_attr(
|
||||
@@ -255,43 +235,42 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
If none of the optional filters are passed, will return all
|
||||
models in the database.
|
||||
"""
|
||||
with self._db.transaction() as cursor:
|
||||
assert isinstance(order_by, ModelRecordOrderBy)
|
||||
ordering = {
|
||||
ModelRecordOrderBy.Default: "type, base, name, format",
|
||||
ModelRecordOrderBy.Type: "type",
|
||||
ModelRecordOrderBy.Base: "base",
|
||||
ModelRecordOrderBy.Name: "name",
|
||||
ModelRecordOrderBy.Format: "format",
|
||||
}
|
||||
|
||||
assert isinstance(order_by, ModelRecordOrderBy)
|
||||
ordering = {
|
||||
ModelRecordOrderBy.Default: "type, base, name, format",
|
||||
ModelRecordOrderBy.Type: "type",
|
||||
ModelRecordOrderBy.Base: "base",
|
||||
ModelRecordOrderBy.Name: "name",
|
||||
ModelRecordOrderBy.Format: "format",
|
||||
}
|
||||
where_clause: list[str] = []
|
||||
bindings: list[str] = []
|
||||
if model_name:
|
||||
where_clause.append("name=?")
|
||||
bindings.append(model_name)
|
||||
if base_model:
|
||||
where_clause.append("base=?")
|
||||
bindings.append(base_model)
|
||||
if model_type:
|
||||
where_clause.append("type=?")
|
||||
bindings.append(model_type)
|
||||
if model_format:
|
||||
where_clause.append("format=?")
|
||||
bindings.append(model_format)
|
||||
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
|
||||
|
||||
where_clause: list[str] = []
|
||||
bindings: list[str] = []
|
||||
if model_name:
|
||||
where_clause.append("name=?")
|
||||
bindings.append(model_name)
|
||||
if base_model:
|
||||
where_clause.append("base=?")
|
||||
bindings.append(base_model)
|
||||
if model_type:
|
||||
where_clause.append("type=?")
|
||||
bindings.append(model_type)
|
||||
if model_format:
|
||||
where_clause.append("format=?")
|
||||
bindings.append(model_format)
|
||||
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
|
||||
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config, strftime('%s',updated_at)
|
||||
FROM models
|
||||
{where}
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
|
||||
""",
|
||||
tuple(bindings),
|
||||
)
|
||||
result = cursor.fetchall()
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config, strftime('%s',updated_at)
|
||||
FROM models
|
||||
{where}
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
|
||||
""",
|
||||
tuple(bindings),
|
||||
)
|
||||
result = cursor.fetchall()
|
||||
|
||||
# Parse the model configs.
|
||||
results: list[AnyModelConfig] = []
|
||||
@@ -313,69 +292,68 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
|
||||
"""Return models with the indicated path."""
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE path=?;
|
||||
""",
|
||||
(str(path),),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE path=?;
|
||||
""",
|
||||
(str(path),),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
|
||||
return results
|
||||
|
||||
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
|
||||
"""Return models with the indicated hash."""
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
|
||||
return results
|
||||
|
||||
def list_models(
|
||||
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
||||
) -> PaginatedResults[ModelSummary]:
|
||||
"""Return a paginated summary listing of each model in the database."""
|
||||
assert isinstance(order_by, ModelRecordOrderBy)
|
||||
ordering = {
|
||||
ModelRecordOrderBy.Default: "type, base, name, format",
|
||||
ModelRecordOrderBy.Type: "type",
|
||||
ModelRecordOrderBy.Base: "base",
|
||||
ModelRecordOrderBy.Name: "name",
|
||||
ModelRecordOrderBy.Format: "format",
|
||||
}
|
||||
with self._db.transaction() as cursor:
|
||||
assert isinstance(order_by, ModelRecordOrderBy)
|
||||
ordering = {
|
||||
ModelRecordOrderBy.Default: "type, base, name, format",
|
||||
ModelRecordOrderBy.Type: "type",
|
||||
ModelRecordOrderBy.Base: "base",
|
||||
ModelRecordOrderBy.Name: "name",
|
||||
ModelRecordOrderBy.Format: "format",
|
||||
}
|
||||
|
||||
cursor = self._db.conn.cursor()
|
||||
# Lock so that the database isn't updated while we're doing the two queries.
|
||||
# query1: get the total number of model configs
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
select count(*) from models;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
total = int(cursor.fetchone()[0])
|
||||
|
||||
# Lock so that the database isn't updated while we're doing the two queries.
|
||||
# query1: get the total number of model configs
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
select count(*) from models;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
total = int(cursor.fetchone()[0])
|
||||
|
||||
# query2: fetch key fields
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config
|
||||
FROM models
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
|
||||
LIMIT ?
|
||||
OFFSET ?;
|
||||
""",
|
||||
(
|
||||
per_page,
|
||||
page * per_page,
|
||||
),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
# query2: fetch key fields
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config
|
||||
FROM models
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
|
||||
LIMIT ?
|
||||
OFFSET ?;
|
||||
""",
|
||||
(
|
||||
per_page,
|
||||
page * per_page,
|
||||
),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
items = [ModelSummary.model_validate(dict(x)) for x in rows]
|
||||
return PaginatedResults(page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.model_relationship_records.model_relationship_records_base import (
|
||||
ModelRelationshipRecordStorageBase,
|
||||
)
|
||||
@@ -9,58 +7,49 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
class SqliteModelRelationshipRecordStorage(ModelRelationshipRecordStorageBase):
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._conn = db.conn
|
||||
self._db = db
|
||||
|
||||
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
|
||||
if model_key_1 == model_key_2:
|
||||
raise ValueError("Cannot relate a model to itself.")
|
||||
a, b = sorted([model_key_1, model_key_2])
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
if model_key_1 == model_key_2:
|
||||
raise ValueError("Cannot relate a model to itself.")
|
||||
a, b = sorted([model_key_1, model_key_2])
|
||||
cursor.execute(
|
||||
"INSERT OR IGNORE INTO model_relationships (model_key_1, model_key_2) VALUES (?, ?)",
|
||||
(a, b),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
|
||||
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
|
||||
a, b = sorted([model_key_1, model_key_2])
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
a, b = sorted([model_key_1, model_key_2])
|
||||
cursor.execute(
|
||||
"DELETE FROM model_relationships WHERE model_key_1 = ? AND model_key_2 = ?",
|
||||
(a, b),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
|
||||
def get_related_model_keys(self, model_key: str) -> list[str]:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ?
|
||||
UNION
|
||||
SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ?
|
||||
""",
|
||||
(model_key, model_key),
|
||||
)
|
||||
return [row[0] for row in cursor.fetchall()]
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ?
|
||||
UNION
|
||||
SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ?
|
||||
""",
|
||||
(model_key, model_key),
|
||||
)
|
||||
result = [row[0] for row in cursor.fetchall()]
|
||||
return result
|
||||
|
||||
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
|
||||
cursor = self._conn.cursor()
|
||||
|
||||
key_list = ",".join("?" for _ in model_keys)
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT model_key_2 FROM model_relationships WHERE model_key_1 IN ({key_list})
|
||||
UNION
|
||||
SELECT model_key_1 FROM model_relationships WHERE model_key_2 IN ({key_list})
|
||||
""",
|
||||
model_keys + model_keys,
|
||||
)
|
||||
return [row[0] for row in cursor.fetchall()]
|
||||
with self._db.transaction() as cursor:
|
||||
key_list = ",".join("?" for _ in model_keys)
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT model_key_2 FROM model_relationships WHERE model_key_1 IN ({key_list})
|
||||
UNION
|
||||
SELECT model_key_1 FROM model_relationships WHERE model_key_2 IN ({key_list})
|
||||
""",
|
||||
model_keys + model_keys,
|
||||
)
|
||||
result = [row[0] for row in cursor.fetchall()]
|
||||
return result
|
||||
|
||||
@@ -50,15 +50,14 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._conn = db.conn
|
||||
self._db = db
|
||||
|
||||
def _set_in_progress_to_canceled(self) -> None:
|
||||
"""
|
||||
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
|
||||
This is necessary because the invoker may have been killed while processing a queue item.
|
||||
"""
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE session_queue
|
||||
@@ -66,87 +65,79 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
WHERE status = 'in_progress';
|
||||
"""
|
||||
)
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
|
||||
def _get_current_queue_size(self, queue_id: str) -> int:
|
||||
"""Gets the current number of pending queue items"""
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
return cast(int, cursor.fetchone()[0])
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
return count
|
||||
|
||||
def _get_highest_priority(self, queue_id: str) -> int:
|
||||
"""Gets the highest priority value in the queue"""
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT MAX(priority)
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
return cast(Union[int, None], cursor.fetchone()[0]) or 0
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT MAX(priority)
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
priority = cast(Union[int, None], cursor.fetchone()[0]) or 0
|
||||
return priority
|
||||
|
||||
async def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
|
||||
try:
|
||||
# TODO: how does this work in a multi-user scenario?
|
||||
current_queue_size = self._get_current_queue_size(queue_id)
|
||||
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||
max_new_queue_items = max_queue_size - current_queue_size
|
||||
current_queue_size = self._get_current_queue_size(queue_id)
|
||||
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||
max_new_queue_items = max_queue_size - current_queue_size
|
||||
|
||||
priority = 0
|
||||
if prepend:
|
||||
priority = self._get_highest_priority(queue_id) + 1
|
||||
priority = 0
|
||||
if prepend:
|
||||
priority = self._get_highest_priority(queue_id) + 1
|
||||
|
||||
requested_count = await asyncio.to_thread(
|
||||
calc_session_count,
|
||||
batch=batch,
|
||||
)
|
||||
values_to_insert = await asyncio.to_thread(
|
||||
prepare_values_to_insert,
|
||||
queue_id=queue_id,
|
||||
batch=batch,
|
||||
priority=priority,
|
||||
max_new_queue_items=max_new_queue_items,
|
||||
)
|
||||
enqueued_count = len(values_to_insert)
|
||||
requested_count = await asyncio.to_thread(
|
||||
calc_session_count,
|
||||
batch=batch,
|
||||
)
|
||||
values_to_insert = await asyncio.to_thread(
|
||||
prepare_values_to_insert,
|
||||
queue_id=queue_id,
|
||||
batch=batch,
|
||||
priority=priority,
|
||||
max_new_queue_items=max_new_queue_items,
|
||||
)
|
||||
enqueued_count = len(values_to_insert)
|
||||
|
||||
with self._conn:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.executemany(
|
||||
"""--sql
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.executemany(
|
||||
"""--sql
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
values_to_insert,
|
||||
)
|
||||
with self._conn:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
values_to_insert,
|
||||
)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT item_id
|
||||
FROM session_queue
|
||||
WHERE batch_id = ?
|
||||
ORDER BY item_id DESC;
|
||||
""",
|
||||
(batch.batch_id,),
|
||||
)
|
||||
item_ids = [row[0] for row in cursor.fetchall()]
|
||||
except Exception:
|
||||
raise
|
||||
(batch.batch_id,),
|
||||
)
|
||||
item_ids = [row[0] for row in cursor.fetchall()]
|
||||
enqueue_result = EnqueueBatchResult(
|
||||
queue_id=queue_id,
|
||||
requested=requested_count,
|
||||
@@ -159,19 +150,19 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return enqueue_result
|
||||
|
||||
def dequeue(self) -> Optional[SessionQueueItem]:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE status = 'pending'
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
LIMIT 1
|
||||
"""
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE status = 'pending'
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
LIMIT 1
|
||||
"""
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
if result is None:
|
||||
return None
|
||||
queue_item = SessionQueueItem.queue_item_from_dict(dict(result))
|
||||
@@ -179,40 +170,40 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return queue_item
|
||||
|
||||
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
created_at ASC
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
created_at ASC
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
if result is None:
|
||||
return None
|
||||
return SessionQueueItem.queue_item_from_dict(dict(result))
|
||||
|
||||
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'in_progress'
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'in_progress'
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
if result is None:
|
||||
return None
|
||||
return SessionQueueItem.queue_item_from_dict(dict(result))
|
||||
@@ -225,8 +216,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
error_message: Optional[str] = None,
|
||||
error_traceback: Optional[str] = None,
|
||||
) -> SessionQueueItem:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status FROM session_queue WHERE item_id = ?
|
||||
@@ -234,12 +224,15 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
(item_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
|
||||
current_status = row[0]
|
||||
# Only update if not already finished (completed, failed or canceled)
|
||||
if current_status in ("completed", "failed", "canceled"):
|
||||
return self.get_queue_item(item_id)
|
||||
if row is None:
|
||||
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
|
||||
current_status = row[0]
|
||||
|
||||
# Only update if not already finished (completed, failed or canceled)
|
||||
if current_status in ("completed", "failed", "canceled"):
|
||||
return self.get_queue_item(item_id)
|
||||
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE session_queue
|
||||
@@ -248,10 +241,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(status, error_type, error_message, error_traceback, item_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
|
||||
@@ -259,35 +249,34 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return queue_item
|
||||
|
||||
def is_empty(self, queue_id: str) -> IsEmptyResult:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
is_empty = cast(int, cursor.fetchone()[0]) == 0
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
is_empty = cast(int, cursor.fetchone()[0]) == 0
|
||||
return IsEmptyResult(is_empty=is_empty)
|
||||
|
||||
def is_full(self, queue_id: str) -> IsFullResult:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||
is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||
is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
|
||||
return IsFullResult(is_full=is_full)
|
||||
|
||||
def clear(self, queue_id: str) -> ClearResult:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*)
|
||||
@@ -305,24 +294,19 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
self.__invoker.services.events.emit_queue_cleared(queue_id)
|
||||
return ClearResult(deleted=count)
|
||||
|
||||
def prune(self, queue_id: str) -> PruneResult:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND (
|
||||
queue_id = ?
|
||||
AND (
|
||||
status = 'completed'
|
||||
OR status = 'failed'
|
||||
OR status = 'canceled'
|
||||
)
|
||||
)
|
||||
"""
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
@@ -341,10 +325,6 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return PruneResult(deleted=count)
|
||||
|
||||
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
@@ -357,8 +337,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self.cancel_queue_item(item_id)
|
||||
except SessionQueueItemNotFoundError:
|
||||
pass
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE
|
||||
@@ -367,10 +346,6 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(item_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
|
||||
def complete_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
queue_item = self._set_queue_item_status(item_id=item_id, status="completed")
|
||||
@@ -393,8 +368,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return queue_item
|
||||
|
||||
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
placeholders = ", ".join(["?" for _ in batch_ids])
|
||||
where = f"""--sql
|
||||
@@ -404,6 +378,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
AND status != 'canceled'
|
||||
AND status != 'completed'
|
||||
AND status != 'failed'
|
||||
-- We will cancel the current item separately below - skip it here
|
||||
AND status != 'in_progress'
|
||||
"""
|
||||
params = [queue_id] + batch_ids
|
||||
cursor.execute(
|
||||
@@ -423,17 +399,14 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
self._conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
|
||||
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
|
||||
return CancelByBatchIDsResult(canceled=count)
|
||||
|
||||
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
where = """--sql
|
||||
WHERE
|
||||
@@ -442,6 +415,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
AND status != 'canceled'
|
||||
AND status != 'completed'
|
||||
AND status != 'failed'
|
||||
-- We will cancel the current item separately below - skip it here
|
||||
AND status != 'in_progress'
|
||||
"""
|
||||
params = (queue_id, destination)
|
||||
cursor.execute(
|
||||
@@ -461,17 +436,12 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
params,
|
||||
)
|
||||
self._conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.destination == destination:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
if current_queue_item is not None and current_queue_item.destination == destination:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
return CancelByDestinationResult(canceled=count)
|
||||
|
||||
def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
if current_queue_item is not None and current_queue_item.destination == destination:
|
||||
self.cancel_queue_item(current_queue_item.item_id)
|
||||
@@ -497,15 +467,10 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
params,
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return DeleteByDestinationResult(deleted=count)
|
||||
|
||||
def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id == ?
|
||||
@@ -528,15 +493,10 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return DeleteAllExceptCurrentResult(deleted=count)
|
||||
|
||||
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
where = """--sql
|
||||
WHERE
|
||||
@@ -544,6 +504,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
AND status != 'canceled'
|
||||
AND status != 'completed'
|
||||
AND status != 'failed'
|
||||
-- We will cancel the current item separately below - skip it here
|
||||
AND status != 'in_progress'
|
||||
"""
|
||||
params = [queue_id]
|
||||
cursor.execute(
|
||||
@@ -563,21 +525,13 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
self._conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
|
||||
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(
|
||||
current_queue_item, batch_status, queue_status
|
||||
)
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
|
||||
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
return CancelByQueueIDResult(canceled=count)
|
||||
|
||||
def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id == ?
|
||||
@@ -600,30 +554,25 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return CancelAllExceptCurrentResult(canceled=count)
|
||||
|
||||
def get_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT * FROM session_queue
|
||||
WHERE
|
||||
item_id = ?
|
||||
""",
|
||||
(item_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT * FROM session_queue
|
||||
WHERE
|
||||
item_id = ?
|
||||
""",
|
||||
(item_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
if result is None:
|
||||
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
|
||||
return SessionQueueItem.queue_item_from_dict(dict(result))
|
||||
|
||||
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
|
||||
# when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
|
||||
# during execution.
|
||||
@@ -636,10 +585,6 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(session_json, item_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return self.get_queue_item(item_id)
|
||||
|
||||
def list_queue_items(
|
||||
@@ -651,42 +596,42 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
status: Optional[QUEUE_ITEM_STATUS] = None,
|
||||
destination: Optional[str] = None,
|
||||
) -> CursorPaginatedResults[SessionQueueItem]:
|
||||
cursor_ = self._conn.cursor()
|
||||
item_id = cursor
|
||||
query = """--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
params: list[Union[str, int]] = [queue_id]
|
||||
|
||||
if status is not None:
|
||||
query += """--sql
|
||||
AND status = ?
|
||||
"""
|
||||
params.append(status)
|
||||
|
||||
if destination is not None:
|
||||
query += """---sql
|
||||
AND destination = ?
|
||||
with self._db.transaction() as cursor_:
|
||||
item_id = cursor
|
||||
query = """--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
params.append(destination)
|
||||
params: list[Union[str, int]] = [queue_id]
|
||||
|
||||
if item_id is not None:
|
||||
query += """--sql
|
||||
AND (priority < ?) OR (priority = ? AND item_id > ?)
|
||||
if status is not None:
|
||||
query += """--sql
|
||||
AND status = ?
|
||||
"""
|
||||
params.append(status)
|
||||
|
||||
if destination is not None:
|
||||
query += """---sql
|
||||
AND destination = ?
|
||||
"""
|
||||
params.extend([priority, priority, item_id])
|
||||
params.append(destination)
|
||||
|
||||
query += """--sql
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
LIMIT ?
|
||||
"""
|
||||
params.append(limit + 1)
|
||||
cursor_.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], cursor_.fetchall())
|
||||
if item_id is not None:
|
||||
query += """--sql
|
||||
AND (priority < ?) OR (priority = ? AND item_id > ?)
|
||||
"""
|
||||
params.extend([priority, priority, item_id])
|
||||
|
||||
query += """--sql
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
LIMIT ?
|
||||
"""
|
||||
params.append(limit + 1)
|
||||
cursor_.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], cursor_.fetchall())
|
||||
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
|
||||
has_more = False
|
||||
if len(items) > limit:
|
||||
@@ -701,43 +646,43 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
destination: Optional[str] = None,
|
||||
) -> list[SessionQueueItem]:
|
||||
"""Gets all queue items that match the given parameters"""
|
||||
cursor_ = self._conn.cursor()
|
||||
query = """--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
params: list[Union[str, int]] = [queue_id]
|
||||
|
||||
if destination is not None:
|
||||
query += """---sql
|
||||
AND destination = ?
|
||||
with self._db.transaction() as cursor:
|
||||
query = """--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
params.append(destination)
|
||||
params: list[Union[str, int]] = [queue_id]
|
||||
|
||||
query += """--sql
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
;
|
||||
"""
|
||||
cursor_.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], cursor_.fetchall())
|
||||
if destination is not None:
|
||||
query += """---sql
|
||||
AND destination = ?
|
||||
"""
|
||||
params.append(destination)
|
||||
|
||||
query += """--sql
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
;
|
||||
"""
|
||||
cursor.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
|
||||
return items
|
||||
|
||||
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
|
||||
current_item = self.get_current(queue_id=queue_id)
|
||||
total = sum(row[1] or 0 for row in counts_result)
|
||||
@@ -756,19 +701,19 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
|
||||
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*), origin, destination
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND batch_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, batch_id),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*), origin, destination
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND batch_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, batch_id),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
total = sum(row[1] or 0 for row in result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in result}
|
||||
origin = result[0]["origin"] if result else None
|
||||
@@ -788,18 +733,18 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
|
||||
def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
AND destination = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, destination),
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
AND destination = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, destination),
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
|
||||
total = sum(row[1] or 0 for row in counts_result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
|
||||
@@ -817,8 +762,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult:
|
||||
"""Retries the given queue items"""
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
values_to_insert: list[ValueToInsertTuple] = []
|
||||
retried_item_ids: list[int] = []
|
||||
|
||||
@@ -869,10 +813,6 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
values_to_insert,
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
retry_result = RetryItemsResult(
|
||||
queue_id=queue_id,
|
||||
retried_item_ids=retried_item_ids,
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
|
||||
@@ -26,46 +29,65 @@ class SqliteDatabase:
|
||||
|
||||
def __init__(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
|
||||
"""Initializes the database. This is used internally by the class constructor."""
|
||||
self.logger = logger
|
||||
self.db_path = db_path
|
||||
self.verbose = verbose
|
||||
self._logger = logger
|
||||
self._db_path = db_path
|
||||
self._verbose = verbose
|
||||
self._lock = threading.RLock()
|
||||
|
||||
if not self.db_path:
|
||||
if not self._db_path:
|
||||
logger.info("Initializing in-memory database")
|
||||
else:
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.logger.info(f"Initializing database at {self.db_path}")
|
||||
self._db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._logger.info(f"Initializing database at {self._db_path}")
|
||||
|
||||
self.conn = sqlite3.connect(database=self.db_path or sqlite_memory, check_same_thread=False)
|
||||
self.conn.row_factory = sqlite3.Row
|
||||
self._conn = sqlite3.connect(database=self._db_path or sqlite_memory, check_same_thread=False)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
|
||||
if self.verbose:
|
||||
self.conn.set_trace_callback(self.logger.debug)
|
||||
if self._verbose:
|
||||
self._conn.set_trace_callback(self._logger.debug)
|
||||
|
||||
# Enable foreign key constraints
|
||||
self.conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
|
||||
# Enable Write-Ahead Logging (WAL) mode for better concurrency
|
||||
self.conn.execute("PRAGMA journal_mode = WAL;")
|
||||
self._conn.execute("PRAGMA journal_mode = WAL;")
|
||||
|
||||
# Set a busy timeout to prevent database lockups during writes
|
||||
self.conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
|
||||
self._conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
|
||||
|
||||
def clean(self) -> None:
|
||||
"""
|
||||
Cleans the database by running the VACUUM command, reporting on the freed space.
|
||||
"""
|
||||
# No need to clean in-memory database
|
||||
if not self.db_path:
|
||||
if not self._db_path:
|
||||
return
|
||||
try:
|
||||
initial_db_size = Path(self.db_path).stat().st_size
|
||||
self.conn.execute("VACUUM;")
|
||||
self.conn.commit()
|
||||
final_db_size = Path(self.db_path).stat().st_size
|
||||
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
|
||||
if freed_space_in_mb > 0:
|
||||
self.logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
|
||||
with self._conn as conn:
|
||||
initial_db_size = Path(self._db_path).stat().st_size
|
||||
conn.execute("VACUUM;")
|
||||
conn.commit()
|
||||
final_db_size = Path(self._db_path).stat().st_size
|
||||
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
|
||||
if freed_space_in_mb > 0:
|
||||
self._logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error cleaning database: {e}")
|
||||
self._logger.error(f"Error cleaning database: {e}")
|
||||
raise
|
||||
|
||||
@contextmanager
|
||||
def transaction(self) -> Generator[sqlite3.Cursor, None, None]:
|
||||
"""
|
||||
Thread-safe context manager for DB work.
|
||||
Acquires the RLock, yields a Cursor, then commits or rolls back.
|
||||
"""
|
||||
with self._lock:
|
||||
cursor = self._conn.cursor()
|
||||
try:
|
||||
yield cursor
|
||||
self._conn.commit()
|
||||
except:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
@@ -32,7 +32,7 @@ class SqliteMigrator:
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
self._db = db
|
||||
self._logger = db.logger
|
||||
self._logger = db._logger
|
||||
self._migration_set = MigrationSet()
|
||||
self._backup_path: Optional[Path] = None
|
||||
|
||||
@@ -45,7 +45,7 @@ class SqliteMigrator:
|
||||
"""Migrates the database to the latest version."""
|
||||
# This throws if there is a problem.
|
||||
self._migration_set.validate_migration_chain()
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor = self._db._conn.cursor()
|
||||
self._create_migrations_table(cursor=cursor)
|
||||
|
||||
if self._migration_set.count == 0:
|
||||
@@ -59,13 +59,13 @@ class SqliteMigrator:
|
||||
self._logger.info("Database update needed")
|
||||
|
||||
# Make a backup of the db if it needs to be updated and is a file db
|
||||
if self._db.db_path is not None:
|
||||
if self._db._db_path is not None:
|
||||
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
|
||||
self._backup_path = self._db._db_path.parent / f"{self._db._db_path.stem}_backup_{timestamp}.db"
|
||||
self._logger.info(f"Backing up database to {str(self._backup_path)}")
|
||||
# Use SQLite to do the backup
|
||||
with closing(sqlite3.connect(self._backup_path)) as backup_conn:
|
||||
self._db.conn.backup(backup_conn)
|
||||
self._db._conn.backup(backup_conn)
|
||||
else:
|
||||
self._logger.info("Using in-memory database, no backup needed")
|
||||
|
||||
@@ -81,7 +81,7 @@ class SqliteMigrator:
|
||||
try:
|
||||
# Using sqlite3.Connection as a context manager commits a the transaction on exit, or rolls it back if an
|
||||
# exception is raised.
|
||||
with self._db.conn as conn:
|
||||
with self._db._conn as conn:
|
||||
cursor = conn.cursor()
|
||||
if self._get_current_version(cursor) != migration.from_version:
|
||||
raise MigrationError(
|
||||
|
||||
@@ -17,7 +17,7 @@ from invokeai.app.util.misc import uuid_string
|
||||
class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._conn = db.conn
|
||||
self._db = db
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
@@ -25,24 +25,23 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
|
||||
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
|
||||
"""Gets a style preset by ID."""
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM style_presets
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(style_preset_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM style_presets
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(style_preset_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
|
||||
return StylePresetRecordDTO.from_dict(dict(row))
|
||||
|
||||
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
|
||||
style_preset_id = uuid_string()
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO style_presets (
|
||||
@@ -60,16 +59,11 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
style_preset.type,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return self.get(style_preset_id)
|
||||
|
||||
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
|
||||
style_preset_ids = []
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
for style_preset in style_presets:
|
||||
style_preset_id = uuid_string()
|
||||
style_preset_ids.append(style_preset_id)
|
||||
@@ -90,16 +84,11 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
style_preset.type,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
|
||||
return None
|
||||
|
||||
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
# Change the name of a style preset
|
||||
if changes.name is not None:
|
||||
cursor.execute(
|
||||
@@ -122,15 +111,10 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
(changes.preset_data.model_dump_json(), style_preset_id),
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return self.get(style_preset_id)
|
||||
|
||||
def delete(self, style_preset_id: str) -> None:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE from style_presets
|
||||
@@ -138,51 +122,41 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
""",
|
||||
(style_preset_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return None
|
||||
|
||||
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
|
||||
main_query = """
|
||||
SELECT
|
||||
*
|
||||
FROM style_presets
|
||||
"""
|
||||
with self._db.transaction() as cursor:
|
||||
main_query = """
|
||||
SELECT
|
||||
*
|
||||
FROM style_presets
|
||||
"""
|
||||
|
||||
if type is not None:
|
||||
main_query += "WHERE type = ? "
|
||||
if type is not None:
|
||||
main_query += "WHERE type = ? "
|
||||
|
||||
main_query += "ORDER BY LOWER(name) ASC"
|
||||
main_query += "ORDER BY LOWER(name) ASC"
|
||||
|
||||
cursor = self._conn.cursor()
|
||||
if type is not None:
|
||||
cursor.execute(main_query, (type,))
|
||||
else:
|
||||
cursor.execute(main_query)
|
||||
if type is not None:
|
||||
cursor.execute(main_query, (type,))
|
||||
else:
|
||||
cursor.execute(main_query)
|
||||
|
||||
rows = cursor.fetchall()
|
||||
rows = cursor.fetchall()
|
||||
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
|
||||
|
||||
return style_presets
|
||||
|
||||
def _sync_default_style_presets(self) -> None:
|
||||
"""Syncs default style presets to the database. Internal use only."""
|
||||
|
||||
# First delete all existing default style presets
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
# First delete all existing default style presets
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM style_presets
|
||||
WHERE type = "default";
|
||||
"""
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
# Next, parse and create the default style presets
|
||||
with open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
|
||||
presets = json.load(file)
|
||||
|
||||
@@ -25,7 +25,7 @@ SQL_TIME_FORMAT = "%Y-%m-%d %H:%M:%f"
|
||||
class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._conn = db.conn
|
||||
self._db = db
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
@@ -33,16 +33,16 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
|
||||
def get(self, workflow_id: str) -> WorkflowRecordDTO:
|
||||
"""Gets a workflow by ID. Updates the opened_at column."""
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
|
||||
FROM workflow_library
|
||||
WHERE workflow_id = ?;
|
||||
""",
|
||||
(workflow_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
|
||||
FROM workflow_library
|
||||
WHERE workflow_id = ?;
|
||||
""",
|
||||
(workflow_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
|
||||
return WorkflowRecordDTO.from_dict(dict(row))
|
||||
@@ -51,9 +51,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
if workflow.meta.category is WorkflowCategory.Default:
|
||||
raise ValueError("Default workflows cannot be created via this method")
|
||||
|
||||
try:
|
||||
with self._db.transaction() as cursor:
|
||||
workflow_with_id = Workflow(**workflow.model_dump(), id=uuid_string())
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO workflow_library (
|
||||
@@ -64,18 +63,13 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
""",
|
||||
(workflow_with_id.id, workflow_with_id.model_dump_json()),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return self.get(workflow_with_id.id)
|
||||
|
||||
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
|
||||
if workflow.meta.category is WorkflowCategory.Default:
|
||||
raise ValueError("Default workflows cannot be updated")
|
||||
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE workflow_library
|
||||
@@ -84,18 +78,13 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
""",
|
||||
(workflow.model_dump_json(), workflow.id),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return self.get(workflow.id)
|
||||
|
||||
def delete(self, workflow_id: str) -> None:
|
||||
if self.get(workflow_id).workflow.meta.category is WorkflowCategory.Default:
|
||||
raise ValueError("Default workflows cannot be deleted")
|
||||
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE from workflow_library
|
||||
@@ -103,10 +92,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
""",
|
||||
(workflow_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return None
|
||||
|
||||
def get_many(
|
||||
@@ -121,108 +106,108 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
has_been_opened: Optional[bool] = None,
|
||||
is_published: Optional[bool] = None,
|
||||
) -> PaginatedResults[WorkflowRecordListItemDTO]:
|
||||
# sanitize!
|
||||
assert order_by in WorkflowRecordOrderBy
|
||||
assert direction in SQLiteDirection
|
||||
with self._db.transaction() as cursor:
|
||||
# sanitize!
|
||||
assert order_by in WorkflowRecordOrderBy
|
||||
assert direction in SQLiteDirection
|
||||
|
||||
# We will construct the query dynamically based on the query params
|
||||
# We will construct the query dynamically based on the query params
|
||||
|
||||
# The main query to get the workflows / counts
|
||||
main_query = """
|
||||
SELECT
|
||||
workflow_id,
|
||||
category,
|
||||
name,
|
||||
description,
|
||||
created_at,
|
||||
updated_at,
|
||||
opened_at,
|
||||
tags
|
||||
FROM workflow_library
|
||||
"""
|
||||
count_query = "SELECT COUNT(*) FROM workflow_library"
|
||||
# The main query to get the workflows / counts
|
||||
main_query = """
|
||||
SELECT
|
||||
workflow_id,
|
||||
category,
|
||||
name,
|
||||
description,
|
||||
created_at,
|
||||
updated_at,
|
||||
opened_at,
|
||||
tags
|
||||
FROM workflow_library
|
||||
"""
|
||||
count_query = "SELECT COUNT(*) FROM workflow_library"
|
||||
|
||||
# Start with an empty list of conditions and params
|
||||
conditions: list[str] = []
|
||||
params: list[str | int] = []
|
||||
# Start with an empty list of conditions and params
|
||||
conditions: list[str] = []
|
||||
params: list[str | int] = []
|
||||
|
||||
if categories:
|
||||
# Categories is a list of WorkflowCategory enum values, and a single string in the DB
|
||||
if categories:
|
||||
# Categories is a list of WorkflowCategory enum values, and a single string in the DB
|
||||
|
||||
# Ensure all categories are valid (is this necessary?)
|
||||
assert all(c in WorkflowCategory for c in categories)
|
||||
# Ensure all categories are valid (is this necessary?)
|
||||
assert all(c in WorkflowCategory for c in categories)
|
||||
|
||||
# Construct a placeholder string for the number of categories
|
||||
placeholders = ", ".join("?" for _ in categories)
|
||||
# Construct a placeholder string for the number of categories
|
||||
placeholders = ", ".join("?" for _ in categories)
|
||||
|
||||
# Construct the condition string & params
|
||||
category_condition = f"category IN ({placeholders})"
|
||||
category_params = [category.value for category in categories]
|
||||
# Construct the condition string & params
|
||||
category_condition = f"category IN ({placeholders})"
|
||||
category_params = [category.value for category in categories]
|
||||
|
||||
conditions.append(category_condition)
|
||||
params.extend(category_params)
|
||||
conditions.append(category_condition)
|
||||
params.extend(category_params)
|
||||
|
||||
if tags:
|
||||
# Tags is a list of strings, and a single string in the DB
|
||||
# The string in the DB has no guaranteed format
|
||||
if tags:
|
||||
# Tags is a list of strings, and a single string in the DB
|
||||
# The string in the DB has no guaranteed format
|
||||
|
||||
# Construct a list of conditions for each tag
|
||||
tags_conditions = ["tags LIKE ?" for _ in tags]
|
||||
tags_conditions_joined = " OR ".join(tags_conditions)
|
||||
tags_condition = f"({tags_conditions_joined})"
|
||||
# Construct a list of conditions for each tag
|
||||
tags_conditions = ["tags LIKE ?" for _ in tags]
|
||||
tags_conditions_joined = " OR ".join(tags_conditions)
|
||||
tags_condition = f"({tags_conditions_joined})"
|
||||
|
||||
# And the params for the tags, case-insensitive
|
||||
tags_params = [f"%{t.strip()}%" for t in tags]
|
||||
# And the params for the tags, case-insensitive
|
||||
tags_params = [f"%{t.strip()}%" for t in tags]
|
||||
|
||||
conditions.append(tags_condition)
|
||||
params.extend(tags_params)
|
||||
conditions.append(tags_condition)
|
||||
params.extend(tags_params)
|
||||
|
||||
if has_been_opened:
|
||||
conditions.append("opened_at IS NOT NULL")
|
||||
elif has_been_opened is False:
|
||||
conditions.append("opened_at IS NULL")
|
||||
if has_been_opened:
|
||||
conditions.append("opened_at IS NOT NULL")
|
||||
elif has_been_opened is False:
|
||||
conditions.append("opened_at IS NULL")
|
||||
|
||||
# Ignore whitespace in the query
|
||||
stripped_query = query.strip() if query else None
|
||||
if stripped_query:
|
||||
# Construct a wildcard query for the name, description, and tags
|
||||
wildcard_query = "%" + stripped_query + "%"
|
||||
query_condition = "(name LIKE ? OR description LIKE ? OR tags LIKE ?)"
|
||||
# Ignore whitespace in the query
|
||||
stripped_query = query.strip() if query else None
|
||||
if stripped_query:
|
||||
# Construct a wildcard query for the name, description, and tags
|
||||
wildcard_query = "%" + stripped_query + "%"
|
||||
query_condition = "(name LIKE ? OR description LIKE ? OR tags LIKE ?)"
|
||||
|
||||
conditions.append(query_condition)
|
||||
params.extend([wildcard_query, wildcard_query, wildcard_query])
|
||||
conditions.append(query_condition)
|
||||
params.extend([wildcard_query, wildcard_query, wildcard_query])
|
||||
|
||||
if conditions:
|
||||
# If there are conditions, add a WHERE clause and then join the conditions
|
||||
main_query += " WHERE "
|
||||
count_query += " WHERE "
|
||||
if conditions:
|
||||
# If there are conditions, add a WHERE clause and then join the conditions
|
||||
main_query += " WHERE "
|
||||
count_query += " WHERE "
|
||||
|
||||
all_conditions = " AND ".join(conditions)
|
||||
main_query += all_conditions
|
||||
count_query += all_conditions
|
||||
all_conditions = " AND ".join(conditions)
|
||||
main_query += all_conditions
|
||||
count_query += all_conditions
|
||||
|
||||
# After this point, the query and params differ for the main query and the count query
|
||||
main_params = params.copy()
|
||||
count_params = params.copy()
|
||||
# After this point, the query and params differ for the main query and the count query
|
||||
main_params = params.copy()
|
||||
count_params = params.copy()
|
||||
|
||||
# Main query also gets ORDER BY and LIMIT/OFFSET
|
||||
main_query += f" ORDER BY {order_by.value} {direction.value}"
|
||||
# Main query also gets ORDER BY and LIMIT/OFFSET
|
||||
main_query += f" ORDER BY {order_by.value} {direction.value}"
|
||||
|
||||
if per_page:
|
||||
main_query += " LIMIT ? OFFSET ?"
|
||||
main_params.extend([per_page, page * per_page])
|
||||
if per_page:
|
||||
main_query += " LIMIT ? OFFSET ?"
|
||||
main_params.extend([per_page, page * per_page])
|
||||
|
||||
# Put a ring on it
|
||||
main_query += ";"
|
||||
count_query += ";"
|
||||
# Put a ring on it
|
||||
main_query += ";"
|
||||
count_query += ";"
|
||||
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(main_query, main_params)
|
||||
rows = cursor.fetchall()
|
||||
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
|
||||
cursor.execute(main_query, main_params)
|
||||
rows = cursor.fetchall()
|
||||
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
|
||||
|
||||
cursor.execute(count_query, count_params)
|
||||
total = cursor.fetchone()[0]
|
||||
cursor.execute(count_query, count_params)
|
||||
total = cursor.fetchone()[0]
|
||||
|
||||
if per_page:
|
||||
pages = total // per_page + (total % per_page > 0)
|
||||
@@ -247,46 +232,46 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
if not tags:
|
||||
return {}
|
||||
|
||||
cursor = self._conn.cursor()
|
||||
result: dict[str, int] = {}
|
||||
# Base conditions for categories and selected tags
|
||||
base_conditions: list[str] = []
|
||||
base_params: list[str | int] = []
|
||||
with self._db.transaction() as cursor:
|
||||
result: dict[str, int] = {}
|
||||
# Base conditions for categories and selected tags
|
||||
base_conditions: list[str] = []
|
||||
base_params: list[str | int] = []
|
||||
|
||||
# Add category conditions
|
||||
if categories:
|
||||
assert all(c in WorkflowCategory for c in categories)
|
||||
placeholders = ", ".join("?" for _ in categories)
|
||||
base_conditions.append(f"category IN ({placeholders})")
|
||||
base_params.extend([category.value for category in categories])
|
||||
# Add category conditions
|
||||
if categories:
|
||||
assert all(c in WorkflowCategory for c in categories)
|
||||
placeholders = ", ".join("?" for _ in categories)
|
||||
base_conditions.append(f"category IN ({placeholders})")
|
||||
base_params.extend([category.value for category in categories])
|
||||
|
||||
if has_been_opened:
|
||||
base_conditions.append("opened_at IS NOT NULL")
|
||||
elif has_been_opened is False:
|
||||
base_conditions.append("opened_at IS NULL")
|
||||
if has_been_opened:
|
||||
base_conditions.append("opened_at IS NOT NULL")
|
||||
elif has_been_opened is False:
|
||||
base_conditions.append("opened_at IS NULL")
|
||||
|
||||
# For each tag to count, run a separate query
|
||||
for tag in tags:
|
||||
# Start with the base conditions
|
||||
conditions = base_conditions.copy()
|
||||
params = base_params.copy()
|
||||
# For each tag to count, run a separate query
|
||||
for tag in tags:
|
||||
# Start with the base conditions
|
||||
conditions = base_conditions.copy()
|
||||
params = base_params.copy()
|
||||
|
||||
# Add this specific tag condition
|
||||
conditions.append("tags LIKE ?")
|
||||
params.append(f"%{tag.strip()}%")
|
||||
# Add this specific tag condition
|
||||
conditions.append("tags LIKE ?")
|
||||
params.append(f"%{tag.strip()}%")
|
||||
|
||||
# Construct the full query
|
||||
stmt = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM workflow_library
|
||||
"""
|
||||
# Construct the full query
|
||||
stmt = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM workflow_library
|
||||
"""
|
||||
|
||||
if conditions:
|
||||
stmt += " WHERE " + " AND ".join(conditions)
|
||||
if conditions:
|
||||
stmt += " WHERE " + " AND ".join(conditions)
|
||||
|
||||
cursor.execute(stmt, params)
|
||||
count = cursor.fetchone()[0]
|
||||
result[tag] = count
|
||||
cursor.execute(stmt, params)
|
||||
count = cursor.fetchone()[0]
|
||||
result[tag] = count
|
||||
|
||||
return result
|
||||
|
||||
@@ -296,52 +281,51 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
has_been_opened: Optional[bool] = None,
|
||||
is_published: Optional[bool] = None,
|
||||
) -> dict[str, int]:
|
||||
cursor = self._conn.cursor()
|
||||
result: dict[str, int] = {}
|
||||
# Base conditions for categories
|
||||
base_conditions: list[str] = []
|
||||
base_params: list[str | int] = []
|
||||
with self._db.transaction() as cursor:
|
||||
result: dict[str, int] = {}
|
||||
# Base conditions for categories
|
||||
base_conditions: list[str] = []
|
||||
base_params: list[str | int] = []
|
||||
|
||||
# Add category conditions
|
||||
if categories:
|
||||
assert all(c in WorkflowCategory for c in categories)
|
||||
placeholders = ", ".join("?" for _ in categories)
|
||||
base_conditions.append(f"category IN ({placeholders})")
|
||||
base_params.extend([category.value for category in categories])
|
||||
# Add category conditions
|
||||
if categories:
|
||||
assert all(c in WorkflowCategory for c in categories)
|
||||
placeholders = ", ".join("?" for _ in categories)
|
||||
base_conditions.append(f"category IN ({placeholders})")
|
||||
base_params.extend([category.value for category in categories])
|
||||
|
||||
if has_been_opened:
|
||||
base_conditions.append("opened_at IS NOT NULL")
|
||||
elif has_been_opened is False:
|
||||
base_conditions.append("opened_at IS NULL")
|
||||
if has_been_opened:
|
||||
base_conditions.append("opened_at IS NOT NULL")
|
||||
elif has_been_opened is False:
|
||||
base_conditions.append("opened_at IS NULL")
|
||||
|
||||
# For each category to count, run a separate query
|
||||
for category in categories:
|
||||
# Start with the base conditions
|
||||
conditions = base_conditions.copy()
|
||||
params = base_params.copy()
|
||||
# For each category to count, run a separate query
|
||||
for category in categories:
|
||||
# Start with the base conditions
|
||||
conditions = base_conditions.copy()
|
||||
params = base_params.copy()
|
||||
|
||||
# Add this specific category condition
|
||||
conditions.append("category = ?")
|
||||
params.append(category.value)
|
||||
# Add this specific category condition
|
||||
conditions.append("category = ?")
|
||||
params.append(category.value)
|
||||
|
||||
# Construct the full query
|
||||
stmt = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM workflow_library
|
||||
"""
|
||||
# Construct the full query
|
||||
stmt = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM workflow_library
|
||||
"""
|
||||
|
||||
if conditions:
|
||||
stmt += " WHERE " + " AND ".join(conditions)
|
||||
if conditions:
|
||||
stmt += " WHERE " + " AND ".join(conditions)
|
||||
|
||||
cursor.execute(stmt, params)
|
||||
count = cursor.fetchone()[0]
|
||||
result[category.value] = count
|
||||
cursor.execute(stmt, params)
|
||||
count = cursor.fetchone()[0]
|
||||
result[category.value] = count
|
||||
|
||||
return result
|
||||
|
||||
def update_opened_at(self, workflow_id: str) -> None:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE workflow_library
|
||||
@@ -350,10 +334,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
""",
|
||||
(workflow_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
|
||||
def _sync_default_workflows(self) -> None:
|
||||
"""Syncs default workflows to the database. Internal use only."""
|
||||
@@ -368,8 +348,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
meaningless, as they are overwritten every time the server starts.
|
||||
"""
|
||||
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
workflows_from_file: list[Workflow] = []
|
||||
workflows_to_update: list[Workflow] = []
|
||||
workflows_to_add: list[Workflow] = []
|
||||
@@ -449,8 +428,3 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
""",
|
||||
(w.model_dump_json(), w.id),
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
|
||||
@@ -187,7 +187,7 @@ class ModelConfigBase(ABC, BaseModel):
|
||||
else:
|
||||
return config_cls.from_model_on_disk(mod, **overrides)
|
||||
|
||||
raise InvalidModelConfigException("No valid config found")
|
||||
raise InvalidModelConfigException("Unable to determine model type")
|
||||
|
||||
@classmethod
|
||||
def get_tag(cls) -> Tag:
|
||||
|
||||
@@ -143,11 +143,19 @@ flux_dev = StarterModel(
|
||||
flux_kontext = StarterModel(
|
||||
name="FLUX.1 Kontext dev",
|
||||
base=BaseModelType.Flux,
|
||||
source="black-forest-labs/FLUX.1-Kontext-dev::flux1-kontext-dev.safetensors",
|
||||
source="https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev/resolve/main/flux1-kontext-dev.safetensors",
|
||||
description="FLUX.1 Kontext dev transformer in bfloat16. Total size with dependencies: ~33GB",
|
||||
type=ModelType.Main,
|
||||
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
|
||||
)
|
||||
flux_kontext_quantized = StarterModel(
|
||||
name="FLUX.1 Kontext dev (Quantized)",
|
||||
base=BaseModelType.Flux,
|
||||
source="https://huggingface.co/unsloth/FLUX.1-Kontext-dev-GGUF/resolve/main/flux1-kontext-dev-Q4_K_M.gguf",
|
||||
description="FLUX.1 Kontext dev quantized (q4_k_m). Total size with dependencies: ~14GB",
|
||||
type=ModelType.Main,
|
||||
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
|
||||
)
|
||||
sd35_medium = StarterModel(
|
||||
name="SD3.5 Medium",
|
||||
base=BaseModelType.StableDiffusion3,
|
||||
@@ -664,7 +672,7 @@ flux_fill = StarterModel(
|
||||
# List of starter models, displayed on the frontend.
|
||||
# The order/sort of this list is not changed by the frontend - set it how you want it here.
|
||||
STARTER_MODELS: list[StarterModel] = [
|
||||
flux_kontext,
|
||||
flux_kontext_quantized,
|
||||
flux_schnell_quantized,
|
||||
flux_dev_quantized,
|
||||
flux_schnell,
|
||||
@@ -785,7 +793,7 @@ flux_bundle: list[StarterModel] = [
|
||||
flux_depth_control_lora,
|
||||
flux_redux,
|
||||
flux_fill,
|
||||
flux_kontext,
|
||||
flux_kontext_quantized,
|
||||
]
|
||||
|
||||
STARTER_BUNDLES: dict[str, StarterModelBundle] = {
|
||||
|
||||
@@ -12,6 +12,8 @@ const config: KnipConfig = {
|
||||
'src/features/parameters/types/parameterSchemas.ts',
|
||||
// TODO(psyche): maybe we can clean up these utils after canvas v2 release
|
||||
'src/features/controlLayers/konva/util.ts',
|
||||
// Will be using this
|
||||
'src/common/hooks/useAsyncState.ts',
|
||||
],
|
||||
ignoreBinaries: ['only-allow'],
|
||||
paths: {
|
||||
|
||||
@@ -63,7 +63,7 @@
|
||||
"framer-motion": "^11.10.0",
|
||||
"i18next": "^25.2.1",
|
||||
"i18next-http-backend": "^3.0.2",
|
||||
"idb-keyval": "^6.2.2",
|
||||
"idb-keyval": "6.2.1",
|
||||
"jsondiffpatch": "^0.7.3",
|
||||
"konva": "^9.3.20",
|
||||
"linkify-react": "^4.3.1",
|
||||
|
||||
10
invokeai/frontend/web/pnpm-lock.yaml
generated
10
invokeai/frontend/web/pnpm-lock.yaml
generated
@@ -81,8 +81,8 @@ importers:
|
||||
specifier: ^3.0.2
|
||||
version: 3.0.2
|
||||
idb-keyval:
|
||||
specifier: ^6.2.2
|
||||
version: 6.2.2
|
||||
specifier: 6.2.1
|
||||
version: 6.2.1
|
||||
jsondiffpatch:
|
||||
specifier: ^0.7.3
|
||||
version: 0.7.3
|
||||
@@ -2927,8 +2927,8 @@ packages:
|
||||
typescript:
|
||||
optional: true
|
||||
|
||||
idb-keyval@6.2.2:
|
||||
resolution: {integrity: sha512-yjD9nARJ/jb1g+CvD0tlhUHOrJ9Sy0P8T9MF3YaLlHnSRpwPfpTX0XIvpmw3gAJUmEu3FiICLBDPXVwyEvrleg==}
|
||||
idb-keyval@6.2.1:
|
||||
resolution: {integrity: sha512-8Sb3veuYCyrZL+VBt9LJfZjLUPWVvqn8tG28VqYNFCo43KHcKuq+b4EiXGeuaLAQWL2YmyDgMp2aSpH9JHsEQg==}
|
||||
|
||||
ieee754@1.2.1:
|
||||
resolution: {integrity: sha512-dcyqhDvX1C46lXZcVqCpK+FtMRQVdIMN6/Df5js2zouUsqG7I6sFxitIC+7KYK29KdXOLHdu9zL4sFnoVQnqaA==}
|
||||
@@ -7720,7 +7720,7 @@ snapshots:
|
||||
optionalDependencies:
|
||||
typescript: 5.8.3
|
||||
|
||||
idb-keyval@6.2.2: {}
|
||||
idb-keyval@6.2.1: {}
|
||||
|
||||
ieee754@1.2.1: {}
|
||||
|
||||
|
||||
@@ -1399,7 +1399,7 @@
|
||||
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill is not compatible with Text to Image or Image to Image. Use other FLUX models for these tasks.",
|
||||
"imagenIncompatibleGenerationMode": "Google {{model}} supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
|
||||
"chatGPT4oIncompatibleGenerationMode": "ChatGPT 4o supports Text to Image and Image to Image only. Use other models Inpainting and Outpainting tasks.",
|
||||
"fluxKontextIncompatibleGenerationMode": "FLUX Kontext supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
|
||||
"fluxKontextIncompatibleGenerationMode": "FLUX Kontext does not support generation from images placed on the canvas. Re-try using the Reference Image section and disable any Raster Layers.",
|
||||
"problemUnpublishingWorkflow": "Problem Unpublishing Workflow",
|
||||
"problemUnpublishingWorkflowDescription": "There was a problem unpublishing the workflow. Please try again.",
|
||||
"workflowUnpublished": "Workflow Unpublished",
|
||||
@@ -1407,7 +1407,7 @@
|
||||
"sentToUpscale": "Sent to Upscale",
|
||||
"promptGenerationStarted": "Prompt generation started",
|
||||
"uploadAndPromptGenerationFailed": "Failed to upload image and generate prompt",
|
||||
"promptExpansionFailed": "Prompt expansion failed"
|
||||
"promptExpansionFailed": "We ran into an issue. Please try prompt expansion again."
|
||||
},
|
||||
"popovers": {
|
||||
"clipSkip": {
|
||||
@@ -1962,6 +1962,7 @@
|
||||
"recalculateRects": "Recalculate Rects",
|
||||
"clipToBbox": "Clip Strokes to Bbox",
|
||||
"outputOnlyMaskedRegions": "Output Only Generated Regions",
|
||||
"saveAllImagesToGallery": "Save All Images to Gallery",
|
||||
"addLayer": "Add Layer",
|
||||
"duplicate": "Duplicate",
|
||||
"moveToFront": "Move to Front",
|
||||
@@ -2330,6 +2331,9 @@
|
||||
"label": "Preserve Masked Region",
|
||||
"alert": "Preserving Masked Region"
|
||||
},
|
||||
"saveAllImagesToGallery": {
|
||||
"alert": "Saving All Images to Gallery"
|
||||
},
|
||||
"isolatedStagingPreview": "Isolated Staging Preview",
|
||||
"isolatedPreview": "Isolated Preview",
|
||||
"isolatedLayerPreview": "Isolated Layer Preview",
|
||||
@@ -2376,6 +2380,11 @@
|
||||
"saveToGallery": "Save To Gallery",
|
||||
"showResultsOn": "Showing Results",
|
||||
"showResultsOff": "Hiding Results"
|
||||
},
|
||||
"autoSwitch": {
|
||||
"off": "Off",
|
||||
"switchOnStart": "On Start",
|
||||
"switchOnFinish": "On Finish"
|
||||
}
|
||||
},
|
||||
"upscaling": {
|
||||
@@ -2551,8 +2560,9 @@
|
||||
"whatsNew": {
|
||||
"whatsNewInInvoke": "What's New in Invoke",
|
||||
"items": [
|
||||
"Inpainting: Per-mask noise levels and denoise limits.",
|
||||
"Canvas: Smarter aspect ratios for SDXL and improved scroll-to-zoom."
|
||||
"Generate images faster with new Launchpads and a simplified Generate tab.",
|
||||
"Edit with prompts using Flux Kontext Dev.",
|
||||
"Export to PSD, bulk-hide overlays, organize models & images — all in a reimagined interface built for control."
|
||||
],
|
||||
"readReleaseNotes": "Read Release Notes",
|
||||
"watchRecentReleaseVideos": "Watch Recent Release Videos",
|
||||
@@ -2561,62 +2571,16 @@
|
||||
"supportVideos": {
|
||||
"supportVideos": "Support Videos",
|
||||
"gettingStarted": "Getting Started",
|
||||
"controlCanvas": "Control Canvas",
|
||||
"watch": "Watch",
|
||||
"studioSessionsDesc1": "Check out the <StudioSessionsPlaylistLink /> for Invoke deep dives.",
|
||||
"studioSessionsDesc2": "Join our <DiscordLink /> to participate in the live sessions and ask questions. Sessions are uploaded to the playlist the following week.",
|
||||
"studioSessionsDesc": "Join our <DiscordLink /> to participate in the live sessions and ask questions. Sessions are uploaded to the playlist the following week.",
|
||||
"videos": {
|
||||
"creatingYourFirstImage": {
|
||||
"title": "Creating Your First Image",
|
||||
"description": "Introduction to creating an image from scratch using Invoke's tools."
|
||||
"gettingStarted": {
|
||||
"title": "Getting Started with Invoke",
|
||||
"description": "Complete video series covering everything you need to know to get started with Invoke, from creating your first image to advanced techniques."
|
||||
},
|
||||
"usingControlLayersAndReferenceGuides": {
|
||||
"title": "Using Control Layers and Reference Guides",
|
||||
"description": "Learn how to guide your image creation with control layers and reference images."
|
||||
},
|
||||
"understandingImageToImageAndDenoising": {
|
||||
"title": "Understanding Image-to-Image and Denoising",
|
||||
"description": "Overview of image-to-image transformations and denoising in Invoke."
|
||||
},
|
||||
"exploringAIModelsAndConceptAdapters": {
|
||||
"title": "Exploring AI Models and Concept Adapters",
|
||||
"description": "Dive into AI models and how to use concept adapters for creative control."
|
||||
},
|
||||
"creatingAndComposingOnInvokesControlCanvas": {
|
||||
"title": "Creating and Composing on Invoke's Control Canvas",
|
||||
"description": "Learn to compose images using Invoke's control canvas."
|
||||
},
|
||||
"upscaling": {
|
||||
"title": "Upscaling",
|
||||
"description": "How to upscale images with Invoke's tools to enhance resolution."
|
||||
},
|
||||
"howDoIGenerateAndSaveToTheGallery": {
|
||||
"title": "How Do I Generate and Save to the Gallery?",
|
||||
"description": "Steps to generate and save images to the gallery."
|
||||
},
|
||||
"howDoIEditOnTheCanvas": {
|
||||
"title": "How Do I Edit on the Canvas?",
|
||||
"description": "Guide to editing images directly on the canvas."
|
||||
},
|
||||
"howDoIDoImageToImageTransformation": {
|
||||
"title": "How Do I Do Image-to-Image Transformation?",
|
||||
"description": "Tutorial on performing image-to-image transformations in Invoke."
|
||||
},
|
||||
"howDoIUseControlNetsAndControlLayers": {
|
||||
"title": "How Do I Use Control Nets and Control Layers?",
|
||||
"description": "Learn to apply control layers and controlnets to your images."
|
||||
},
|
||||
"howDoIUseGlobalIPAdaptersAndReferenceImages": {
|
||||
"title": "How Do I Use Global IP Adapters and Reference Images?",
|
||||
"description": "Introduction to adding reference images and global IP adapters."
|
||||
},
|
||||
"howDoIUseInpaintMasks": {
|
||||
"title": "How Do I Use Inpaint Masks?",
|
||||
"description": "How to apply inpaint masks for image correction and variation."
|
||||
},
|
||||
"howDoIOutpaint": {
|
||||
"title": "How Do I Outpaint?",
|
||||
"description": "Guide to outpainting beyond the original image borders."
|
||||
"studioSessions": {
|
||||
"title": "Studio Sessions",
|
||||
"description": "Deep dive sessions exploring advanced Invoke features, creative workflows, and community discussions."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import { memo, useCallback } from 'react';
|
||||
import { ErrorBoundary } from 'react-error-boundary';
|
||||
|
||||
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
|
||||
import ThemeLocaleProvider from './ThemeLocaleProvider';
|
||||
const DEFAULT_CONFIG = {};
|
||||
|
||||
interface Props {
|
||||
@@ -30,12 +31,14 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
|
||||
|
||||
return (
|
||||
<ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
|
||||
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
|
||||
<AppContent />
|
||||
{!didStudioInit && <Loading />}
|
||||
</Box>
|
||||
<GlobalHookIsolator config={config} studioInitAction={studioInitAction} />
|
||||
<GlobalModalIsolator />
|
||||
<ThemeLocaleProvider>
|
||||
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
|
||||
<AppContent />
|
||||
{!didStudioInit && <Loading />}
|
||||
</Box>
|
||||
<GlobalHookIsolator config={config} studioInitAction={studioInitAction} />
|
||||
<GlobalModalIsolator />
|
||||
</ThemeLocaleProvider>
|
||||
</ErrorBoundary>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useIsRegionFocused } from 'common/hooks/focus';
|
||||
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { useImageActions } from 'features/gallery/hooks/useImageActions';
|
||||
import { useLoadWorkflow } from 'features/gallery/hooks/useLoadWorkflow';
|
||||
import { useRecallAll } from 'features/gallery/hooks/useRecallAll';
|
||||
import { useRecallDimensions } from 'features/gallery/hooks/useRecallDimensions';
|
||||
import { useRecallPrompts } from 'features/gallery/hooks/useRecallPrompts';
|
||||
import { useRecallRemix } from 'features/gallery/hooks/useRecallRemix';
|
||||
import { useRecallSeed } from 'features/gallery/hooks/useRecallSeed';
|
||||
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
|
||||
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { memo } from 'react';
|
||||
import { useImageDTO } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
@@ -27,59 +30,64 @@ GlobalImageHotkeys.displayName = 'GlobalImageHotkeys';
|
||||
const GlobalImageHotkeysInternal = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
|
||||
const isGalleryFocused = useIsRegionFocused('gallery');
|
||||
const isViewerFocused = useIsRegionFocused('viewer');
|
||||
const imageActions = useImageActions(imageDTO);
|
||||
const isStaging = useAppSelector(selectIsStaging);
|
||||
const isUpscalingEnabled = useFeatureStatus('upscaling');
|
||||
|
||||
const isFocusOK = isGalleryFocused || isViewerFocused;
|
||||
|
||||
const recallAll = useRecallAll(imageDTO);
|
||||
const recallRemix = useRecallRemix(imageDTO);
|
||||
const recallPrompts = useRecallPrompts(imageDTO);
|
||||
const recallSeed = useRecallSeed(imageDTO);
|
||||
const recallDimensions = useRecallDimensions(imageDTO);
|
||||
const loadWorkflow = useLoadWorkflow(imageDTO);
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'loadWorkflow',
|
||||
category: 'viewer',
|
||||
callback: imageActions.loadWorkflow,
|
||||
options: { enabled: isGalleryFocused || isViewerFocused },
|
||||
dependencies: [imageActions.loadWorkflow, isGalleryFocused, isViewerFocused],
|
||||
callback: loadWorkflow.load,
|
||||
options: { enabled: loadWorkflow.isEnabled && isFocusOK },
|
||||
dependencies: [loadWorkflow, isFocusOK],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'recallAll',
|
||||
category: 'viewer',
|
||||
callback: imageActions.recallAll,
|
||||
options: { enabled: !isStaging && (isGalleryFocused || isViewerFocused) },
|
||||
dependencies: [imageActions.recallAll, isStaging, isGalleryFocused, isViewerFocused],
|
||||
callback: recallAll.recall,
|
||||
options: { enabled: recallAll.isEnabled && isFocusOK },
|
||||
dependencies: [recallAll, isFocusOK],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'recallSeed',
|
||||
category: 'viewer',
|
||||
callback: imageActions.recallSeed,
|
||||
options: { enabled: isGalleryFocused || isViewerFocused },
|
||||
dependencies: [imageActions.recallSeed, isGalleryFocused, isViewerFocused],
|
||||
callback: recallSeed.recall,
|
||||
options: { enabled: recallSeed.isEnabled && isFocusOK },
|
||||
dependencies: [recallSeed, isFocusOK],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'recallPrompts',
|
||||
category: 'viewer',
|
||||
callback: imageActions.recallPrompts,
|
||||
options: { enabled: isGalleryFocused || isViewerFocused },
|
||||
dependencies: [imageActions.recallPrompts, isGalleryFocused, isViewerFocused],
|
||||
callback: recallPrompts.recall,
|
||||
options: { enabled: recallPrompts.isEnabled && isFocusOK },
|
||||
dependencies: [recallPrompts, isFocusOK],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'remix',
|
||||
category: 'viewer',
|
||||
callback: imageActions.remix,
|
||||
options: { enabled: isGalleryFocused || isViewerFocused },
|
||||
dependencies: [imageActions.remix, isGalleryFocused, isViewerFocused],
|
||||
callback: recallRemix.recall,
|
||||
options: { enabled: recallRemix.isEnabled && isFocusOK },
|
||||
dependencies: [recallRemix, isFocusOK],
|
||||
});
|
||||
|
||||
useRegisteredHotkeys({
|
||||
id: 'useSize',
|
||||
category: 'viewer',
|
||||
callback: imageActions.recallSize,
|
||||
options: { enabled: !isStaging && (isGalleryFocused || isViewerFocused) },
|
||||
dependencies: [imageActions.recallSize, isStaging, isGalleryFocused, isViewerFocused],
|
||||
});
|
||||
useRegisteredHotkeys({
|
||||
id: 'runPostprocessing',
|
||||
category: 'viewer',
|
||||
callback: imageActions.upscale,
|
||||
options: { enabled: isUpscalingEnabled && isViewerFocused },
|
||||
dependencies: [isUpscalingEnabled, imageDTO, isViewerFocused],
|
||||
callback: recallDimensions.recall,
|
||||
options: { enabled: recallDimensions.isEnabled && isFocusOK },
|
||||
dependencies: [recallDimensions, isFocusOK],
|
||||
});
|
||||
|
||||
return null;
|
||||
});
|
||||
|
||||
|
||||
@@ -42,7 +42,6 @@ import { $socketOptions } from 'services/events/stores';
|
||||
import type { ManagerOptions, SocketOptions } from 'socket.io-client';
|
||||
|
||||
const App = lazy(() => import('./App'));
|
||||
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
|
||||
|
||||
interface Props extends PropsWithChildren {
|
||||
apiUrl?: string;
|
||||
@@ -330,9 +329,7 @@ const InvokeAIUI = ({
|
||||
<React.StrictMode>
|
||||
<Provider store={store}>
|
||||
<React.Suspense fallback={<Loading />}>
|
||||
<ThemeLocaleProvider>
|
||||
<App config={config} studioInitAction={studioInitAction} />
|
||||
</ThemeLocaleProvider>
|
||||
<App config={config} studioInitAction={studioInitAction} />
|
||||
</React.Suspense>
|
||||
</Provider>
|
||||
</React.StrictMode>
|
||||
|
||||
@@ -20,6 +20,7 @@ import {
|
||||
import { $isStylePresetsMenuOpen, activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { navigationApi } from 'features/ui/layouts/navigation-api';
|
||||
import { LAUNCHPAD_PANEL_ID, WORKSPACE_PANEL_ID } from 'features/ui/layouts/shared';
|
||||
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
|
||||
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
|
||||
import { atom } from 'nanostores';
|
||||
@@ -91,6 +92,7 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
|
||||
const overrides: Partial<CanvasRasterLayerState> = {
|
||||
objects: [imageObject],
|
||||
};
|
||||
await navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
|
||||
store.dispatch(canvasReset());
|
||||
store.dispatch(rasterLayerAdded({ overrides, isSelected: true }));
|
||||
store.dispatch(sentImageToCanvas());
|
||||
@@ -157,16 +159,17 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
|
||||
);
|
||||
|
||||
const handleGoToDestination = useCallback(
|
||||
(destination: StudioDestinationAction['data']['destination']) => {
|
||||
async (destination: StudioDestinationAction['data']['destination']) => {
|
||||
switch (destination) {
|
||||
case 'generation':
|
||||
// Go to the canvas tab, open the image viewer, and enable send-to-gallery mode
|
||||
// Go to the generate tab, open the launchpad
|
||||
await navigationApi.focusPanel('generate', LAUNCHPAD_PANEL_ID);
|
||||
store.dispatch(paramsReset());
|
||||
store.dispatch(activeTabCanvasRightPanelChanged('gallery'));
|
||||
break;
|
||||
case 'canvas':
|
||||
// Go to the canvas tab, close the image viewer, and disable send-to-gallery mode
|
||||
store.dispatch(canvasReset());
|
||||
// Go to the canvas tab, open the launchpad
|
||||
await navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
|
||||
break;
|
||||
case 'workflows':
|
||||
// Go to the workflows tab
|
||||
|
||||
@@ -1,14 +1,28 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { bboxSyncedToOptimalDimension } from 'features/controlLayers/store/canvasSlice';
|
||||
import { bboxSyncedToOptimalDimension, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
|
||||
import { modelChanged, syncedToOptimalDimension, vaeSelected } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectBboxModelBase } from 'features/controlLayers/store/selectors';
|
||||
import { refImageModelChanged, selectReferenceImageEntities } from 'features/controlLayers/store/refImagesSlice';
|
||||
import {
|
||||
selectAllEntitiesOfType,
|
||||
selectBboxModelBase,
|
||||
selectCanvasSlice,
|
||||
} from 'features/controlLayers/store/selectors';
|
||||
import { getEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { selectGlobalRefImageModels, selectRegionalRefImageModels } from 'services/api/hooks/modelsByType';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import {
|
||||
isChatGPT4oModelConfig,
|
||||
isFluxKontextApiModelConfig,
|
||||
isFluxKontextModelConfig,
|
||||
isFluxReduxModelConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
const log = logger('models');
|
||||
|
||||
@@ -25,9 +39,8 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
|
||||
}
|
||||
|
||||
const newModel = result.data;
|
||||
|
||||
const newBaseModel = newModel.base;
|
||||
const didBaseModelChange = state.params.model?.base !== newBaseModel;
|
||||
const newBase = newModel.base;
|
||||
const didBaseModelChange = state.params.model?.base !== newBase;
|
||||
|
||||
if (didBaseModelChange) {
|
||||
// we may need to reset some incompatible submodels
|
||||
@@ -35,7 +48,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
|
||||
|
||||
// handle incompatible loras
|
||||
state.loras.loras.forEach((lora) => {
|
||||
if (lora.model.base !== newBaseModel) {
|
||||
if (lora.model.base !== newBase) {
|
||||
dispatch(loraDeleted({ id: lora.id }));
|
||||
modelsCleared += 1;
|
||||
}
|
||||
@@ -43,20 +56,82 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
|
||||
|
||||
// handle incompatible vae
|
||||
const { vae } = state.params;
|
||||
if (vae && vae.base !== newBaseModel) {
|
||||
if (vae && vae.base !== newBase) {
|
||||
dispatch(vaeSelected(null));
|
||||
modelsCleared += 1;
|
||||
}
|
||||
|
||||
// handle incompatible controlnets
|
||||
// state.canvas.present.controlAdapters.entities.forEach((ca) => {
|
||||
// if (ca.model?.base !== newBaseModel) {
|
||||
// modelsCleared += 1;
|
||||
// if (ca.isEnabled) {
|
||||
// dispatch(entityIsEnabledToggled({ entityIdentifier: { id: ca.id, type: 'control_adapter' } }));
|
||||
// }
|
||||
// }
|
||||
// });
|
||||
// Handle incompatible reference image models - switch to first compatible model, with some smart logic
|
||||
// to choose the best available model based on the new main model.
|
||||
const allRefImageModels = selectGlobalRefImageModels(state).filter(({ base }) => base === newBase);
|
||||
|
||||
let newGlobalRefImageModel = null;
|
||||
|
||||
// Certain models require the ref image model to be the same as the main model - others just need a matching
|
||||
// base. Helper to grab the first exact match or the first available model if no exact match is found.
|
||||
const exactMatchOrFirst = <T extends AnyModelConfig>(candidates: T[]): T | null =>
|
||||
candidates.find(({ key }) => key === newModel.key) ?? candidates[0] ?? null;
|
||||
|
||||
// The only way we can differentiate between FLUX and FLUX Kontext is to check for "kontext" in the name
|
||||
if (newModel.base === 'flux' && newModel.name.toLowerCase().includes('kontext')) {
|
||||
const fluxKontextDevModels = allRefImageModels.filter(isFluxKontextModelConfig);
|
||||
newGlobalRefImageModel = exactMatchOrFirst(fluxKontextDevModels);
|
||||
} else if (newModel.base === 'chatgpt-4o') {
|
||||
const chatGPT4oModels = allRefImageModels.filter(isChatGPT4oModelConfig);
|
||||
newGlobalRefImageModel = exactMatchOrFirst(chatGPT4oModels);
|
||||
} else if (newModel.base === 'flux-kontext') {
|
||||
const fluxKontextApiModels = allRefImageModels.filter(isFluxKontextApiModelConfig);
|
||||
newGlobalRefImageModel = exactMatchOrFirst(fluxKontextApiModels);
|
||||
} else if (newModel.base === 'flux') {
|
||||
const fluxReduxModels = allRefImageModels.filter(isFluxReduxModelConfig);
|
||||
newGlobalRefImageModel = fluxReduxModels[0] ?? null;
|
||||
} else {
|
||||
newGlobalRefImageModel = allRefImageModels[0] ?? null;
|
||||
}
|
||||
|
||||
// All ref image entities are updated to use the same new model
|
||||
const refImageEntities = selectReferenceImageEntities(state);
|
||||
for (const entity of refImageEntities) {
|
||||
const shouldUpdateModel =
|
||||
(entity.config.model && entity.config.model.base !== newBase) ||
|
||||
(!entity.config.model && newGlobalRefImageModel);
|
||||
|
||||
if (shouldUpdateModel) {
|
||||
dispatch(
|
||||
refImageModelChanged({
|
||||
id: entity.id,
|
||||
modelConfig: newGlobalRefImageModel,
|
||||
})
|
||||
);
|
||||
modelsCleared += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// For regional guidance, there is no smart logic - we just pick the first available model.
|
||||
const newRegionalRefImageModel = selectRegionalRefImageModels(state)[0] ?? null;
|
||||
|
||||
// All regional guidance entities are updated to use the same new model.
|
||||
const canvasState = selectCanvasSlice(state);
|
||||
const canvasRegionalGuidanceEntities = selectAllEntitiesOfType(canvasState, 'regional_guidance');
|
||||
for (const entity of canvasRegionalGuidanceEntities) {
|
||||
for (const refImage of entity.referenceImages) {
|
||||
// Only change the model if the current one is not compatible with the new base model.
|
||||
const shouldUpdateModel =
|
||||
(refImage.config.model && refImage.config.model.base !== newBase) ||
|
||||
(!refImage.config.model && newRegionalRefImageModel);
|
||||
|
||||
if (shouldUpdateModel) {
|
||||
dispatch(
|
||||
rgRefImageModelChanged({
|
||||
entityIdentifier: getEntityIdentifier(entity),
|
||||
referenceImageId: refImage.id,
|
||||
modelConfig: newRegionalRefImageModel,
|
||||
})
|
||||
);
|
||||
modelsCleared += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (modelsCleared > 0) {
|
||||
toast({
|
||||
|
||||
@@ -3,6 +3,7 @@ import { isNil } from 'es-toolkit';
|
||||
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasSlice';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import {
|
||||
heightChanged,
|
||||
setCfgRescaleMultiplier,
|
||||
setCfgScale,
|
||||
setGuidance,
|
||||
@@ -10,6 +11,7 @@ import {
|
||||
setSteps,
|
||||
vaePrecisionChanged,
|
||||
vaeSelected,
|
||||
widthChanged,
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
import { setDefaultSettings } from 'features/parameters/store/actions';
|
||||
import {
|
||||
@@ -24,6 +26,7 @@ import {
|
||||
zParameterVAEModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { t } from 'i18next';
|
||||
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
|
||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
@@ -113,15 +116,24 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
|
||||
const setSizeOptions = { updateAspectRatio: true, clamp: true };
|
||||
|
||||
const isStaging = selectIsStaging(getState());
|
||||
if (!isStaging && width) {
|
||||
const activeTab = selectActiveTab(getState());
|
||||
if (activeTab === 'generate') {
|
||||
if (isParameterWidth(width)) {
|
||||
dispatch(bboxWidthChanged({ width, ...setSizeOptions }));
|
||||
dispatch(widthChanged({ width, ...setSizeOptions }));
|
||||
}
|
||||
if (isParameterHeight(height)) {
|
||||
dispatch(heightChanged({ height, ...setSizeOptions }));
|
||||
}
|
||||
}
|
||||
|
||||
if (!isStaging && height) {
|
||||
if (isParameterHeight(height)) {
|
||||
dispatch(bboxHeightChanged({ height, ...setSizeOptions }));
|
||||
if (activeTab === 'canvas') {
|
||||
if (!isStaging) {
|
||||
if (isParameterWidth(width)) {
|
||||
dispatch(bboxWidthChanged({ width, ...setSizeOptions }));
|
||||
}
|
||||
if (isParameterHeight(height)) {
|
||||
dispatch(bboxHeightChanged({ height, ...setSizeOptions }));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -87,14 +87,10 @@ export const buildGroup = <T extends object>(group: Omit<Group<T>, typeof unique
|
||||
[uniqueGroupKey]: true,
|
||||
});
|
||||
|
||||
const isGroup = <T extends object>(optionOrGroup: OptionOrGroup<T>): optionOrGroup is Group<T> => {
|
||||
export const isGroup = <T extends object>(optionOrGroup: OptionOrGroup<T>): optionOrGroup is Group<T> => {
|
||||
return uniqueGroupKey in optionOrGroup && optionOrGroup[uniqueGroupKey] === true;
|
||||
};
|
||||
|
||||
export const isOption = <T extends object>(optionOrGroup: OptionOrGroup<T>): optionOrGroup is T => {
|
||||
return !(uniqueGroupKey in optionOrGroup);
|
||||
};
|
||||
|
||||
const DefaultOptionComponent = typedMemo(<T extends object>({ option }: { option: T }) => {
|
||||
const { getOptionId } = usePickerContext();
|
||||
return <Text fontWeight="bold">{getOptionId(option)}</Text>;
|
||||
|
||||
@@ -6,7 +6,6 @@ import { atom, computed } from 'nanostores';
|
||||
import type { RefObject } from 'react';
|
||||
import { useEffect } from 'react';
|
||||
import { objectKeys } from 'tsafe';
|
||||
import z from 'zod/v4';
|
||||
|
||||
/**
|
||||
* We need to manage focus regions to conditionally enable hotkeys:
|
||||
@@ -28,10 +27,7 @@ import z from 'zod/v4';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
/**
|
||||
* The names of the focus regions.
|
||||
*/
|
||||
const zFocusRegionName = z.enum([
|
||||
const REGION_NAMES = [
|
||||
'launchpad',
|
||||
'viewer',
|
||||
'gallery',
|
||||
@@ -41,13 +37,16 @@ const zFocusRegionName = z.enum([
|
||||
'workflows',
|
||||
'progress',
|
||||
'settings',
|
||||
]);
|
||||
export type FocusRegionName = z.infer<typeof zFocusRegionName>;
|
||||
] as const;
|
||||
/**
|
||||
* The names of the focus regions.
|
||||
*/
|
||||
export type FocusRegionName = (typeof REGION_NAMES)[number];
|
||||
|
||||
/**
|
||||
* A map of focus regions to the elements that are part of that region.
|
||||
*/
|
||||
const REGION_TARGETS: Record<FocusRegionName, Set<HTMLElement>> = zFocusRegionName.options.values().reduce(
|
||||
const REGION_TARGETS: Record<FocusRegionName, Set<HTMLElement>> = REGION_NAMES.reduce(
|
||||
(acc, region) => {
|
||||
acc[region] = new Set<HTMLElement>();
|
||||
return acc;
|
||||
|
||||
115
invokeai/frontend/web/src/common/hooks/useAsyncState.ts
Normal file
115
invokeai/frontend/web/src/common/hooks/useAsyncState.ts
Normal file
@@ -0,0 +1,115 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { WrappedError } from 'common/util/result';
|
||||
import type { Atom } from 'nanostores';
|
||||
import { atom } from 'nanostores';
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
|
||||
type SuccessState<T> = {
|
||||
status: 'success';
|
||||
value: T;
|
||||
error: null;
|
||||
};
|
||||
|
||||
type ErrorState = {
|
||||
status: 'error';
|
||||
value: null;
|
||||
error: Error;
|
||||
};
|
||||
|
||||
type PendingState = {
|
||||
status: 'pending';
|
||||
value: null;
|
||||
error: null;
|
||||
};
|
||||
|
||||
type IdleState = {
|
||||
status: 'idle';
|
||||
value: null;
|
||||
error: null;
|
||||
};
|
||||
|
||||
export type State<T> = IdleState | PendingState | SuccessState<T> | ErrorState;
|
||||
|
||||
type UseAsyncStateOptions = {
|
||||
immediate?: boolean;
|
||||
};
|
||||
|
||||
type UseAsyncReturn<T> = {
|
||||
$state: Atom<State<T>>;
|
||||
trigger: () => Promise<void>;
|
||||
reset: () => void;
|
||||
};
|
||||
|
||||
export const useAsyncState = <T>(execute: () => Promise<T>, options?: UseAsyncStateOptions): UseAsyncReturn<T> => {
|
||||
const $state = useState(() =>
|
||||
atom<State<T>>({
|
||||
status: 'idle',
|
||||
value: null,
|
||||
error: null,
|
||||
})
|
||||
)[0];
|
||||
|
||||
const trigger = useCallback(async () => {
|
||||
$state.set({
|
||||
status: 'pending',
|
||||
value: null,
|
||||
error: null,
|
||||
});
|
||||
try {
|
||||
const value = await execute();
|
||||
$state.set({
|
||||
status: 'success',
|
||||
value,
|
||||
error: null,
|
||||
});
|
||||
} catch (error) {
|
||||
$state.set({
|
||||
status: 'error',
|
||||
value: null,
|
||||
error: WrappedError.wrap(error),
|
||||
});
|
||||
}
|
||||
}, [$state, execute]);
|
||||
|
||||
const reset = useCallback(() => {
|
||||
$state.set({
|
||||
status: 'idle',
|
||||
value: null,
|
||||
error: null,
|
||||
});
|
||||
}, [$state]);
|
||||
|
||||
useEffect(() => {
|
||||
if (options?.immediate) {
|
||||
trigger();
|
||||
}
|
||||
}, [options?.immediate, trigger]);
|
||||
|
||||
const api = useMemo(
|
||||
() =>
|
||||
({
|
||||
$state,
|
||||
trigger,
|
||||
reset,
|
||||
}) satisfies UseAsyncReturn<T>,
|
||||
[$state, trigger, reset]
|
||||
);
|
||||
|
||||
return api;
|
||||
};
|
||||
|
||||
type UseAsyncReturnReactive<T> = {
|
||||
state: State<T>;
|
||||
trigger: () => Promise<void>;
|
||||
reset: () => void;
|
||||
};
|
||||
|
||||
export const useAsyncStateReactive = <T>(
|
||||
execute: () => Promise<T>,
|
||||
options?: UseAsyncStateOptions
|
||||
): UseAsyncReturnReactive<T> => {
|
||||
const { $state, trigger, reset } = useAsyncState(execute, options);
|
||||
const state = useStore($state);
|
||||
|
||||
return { state, trigger, reset };
|
||||
};
|
||||
@@ -0,0 +1,23 @@
|
||||
import { Alert, AlertIcon, AlertTitle } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectSaveAllImagesToGallery } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const CanvasAlertsSaveAllImagesToGallery = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const saveAllImagesToGallery = useAppSelector(selectSaveAllImagesToGallery);
|
||||
|
||||
if (!saveAllImagesToGallery) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Alert status="info" borderRadius="base" fontSize="sm" shadow="md" w="fit-content">
|
||||
<AlertIcon />
|
||||
<AlertTitle>{t('controlLayers.settings.saveAllImagesToGallery.alert')}</AlertTitle>
|
||||
</Alert>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasAlertsSaveAllImagesToGallery.displayName = 'CanvasAlertsSaveAllImagesToGallery';
|
||||
@@ -4,13 +4,17 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useRefImageEntity } from 'features/controlLayers/components/RefImage/useRefImageEntity';
|
||||
import { useRefImageIdContext } from 'features/controlLayers/contexts/RefImageIdContext';
|
||||
import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice';
|
||||
import {
|
||||
refImageDeleted,
|
||||
refImageIsEnabledToggled,
|
||||
selectRefImageEntityIds,
|
||||
} from 'features/controlLayers/store/refImagesSlice';
|
||||
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { PiCircleBold, PiCircleFill, PiTrashBold } from 'react-icons/pi';
|
||||
import { PiCircleBold, PiCircleFill, PiTrashBold, PiWarningBold } from 'react-icons/pi';
|
||||
|
||||
import { RefImageWarningTooltipContent } from './RefImageWarningTooltipContent';
|
||||
|
||||
const textSx: SystemStyleObject = {
|
||||
color: 'base.300',
|
||||
@@ -28,6 +32,12 @@ export const RefImageHeader = memo(() => {
|
||||
);
|
||||
const refImageNumber = useAppSelector(selectRefImageNumber);
|
||||
const entity = useRefImageEntity(id);
|
||||
const mainModelConfig = useAppSelector(selectMainModelConfig);
|
||||
|
||||
const warnings = useMemo(() => {
|
||||
return getGlobalReferenceImageWarnings(entity, mainModelConfig);
|
||||
}, [entity, mainModelConfig]);
|
||||
|
||||
const deleteRefImage = useCallback(() => {
|
||||
dispatch(refImageDeleted({ id }));
|
||||
}, [dispatch, id]);
|
||||
@@ -42,6 +52,18 @@ export const RefImageHeader = memo(() => {
|
||||
Reference Image #{refImageNumber}
|
||||
</Text>
|
||||
<Flex alignItems="center" gap={1}>
|
||||
{warnings.length > 0 && (
|
||||
<IconButton
|
||||
as="span"
|
||||
size="sm"
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
aria-label="warnings"
|
||||
tooltip={<RefImageWarningTooltipContent warnings={warnings} />}
|
||||
icon={<PiWarningBold />}
|
||||
colorScheme="warning"
|
||||
/>
|
||||
)}
|
||||
{!entity.isEnabled && (
|
||||
<Text fontSize="xs" fontStyle="italic" color="base.400">
|
||||
Disabled
|
||||
|
||||
@@ -61,7 +61,7 @@ export const RefImageImage = memo(
|
||||
)}
|
||||
{imageDTO && (
|
||||
<>
|
||||
<DndImage imageDTO={imageDTO} borderWidth={1} borderStyle="solid" w="full" />
|
||||
<DndImage imageDTO={imageDTO} borderRadius="base" borderWidth={1} borderStyle="solid" w="full" />
|
||||
<Flex position="absolute" flexDir="column" top={2} insetInlineEnd={2} gap={1}>
|
||||
<DndImageIcon
|
||||
onClick={handleResetControlImage}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import { Button, Collapse, Divider, Flex } from '@invoke-ai/ui-library';
|
||||
import { Button, Collapse, Divider, Flex, IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
|
||||
import { RefImagePreview } from 'features/controlLayers/components/RefImage/RefImagePreview';
|
||||
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { RefImageIdContext } from 'features/controlLayers/contexts/RefImageIdContext';
|
||||
import { getDefaultRefImageConfig } from 'features/controlLayers/hooks/addLayerHooks';
|
||||
import { useNewGlobalReferenceImageFromBbox } from 'features/controlLayers/hooks/saveCanvasHooks';
|
||||
import { useCanvasIsBusySafe } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
import {
|
||||
refImageAdded,
|
||||
selectIsRefImagePanelOpen,
|
||||
@@ -13,8 +16,10 @@ import {
|
||||
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
|
||||
import { addGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { PiUploadBold } from 'react-icons/pi';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiBoundingBoxBold, PiUploadBold } from 'react-icons/pi';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
import { RefImageHeader } from './RefImageHeader';
|
||||
@@ -78,6 +83,7 @@ MaxRefImages.displayName = 'MaxRefImages';
|
||||
|
||||
const AddRefImageDropTargetAndButton = memo(() => {
|
||||
const { dispatch, getState } = useAppStore();
|
||||
const tab = useAppSelector(selectActiveTab);
|
||||
|
||||
const uploadOptions = useMemo(
|
||||
() =>
|
||||
@@ -95,7 +101,7 @@ const AddRefImageDropTargetAndButton = memo(() => {
|
||||
const uploadApi = useImageUploadButton(uploadOptions);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Flex gap={1} h="full" w="full">
|
||||
<Button
|
||||
position="relative"
|
||||
size="sm"
|
||||
@@ -112,7 +118,31 @@ const AddRefImageDropTargetAndButton = memo(() => {
|
||||
<input {...uploadApi.getUploadInputProps()} />
|
||||
<DndDropTarget label="Drop" dndTarget={addGlobalReferenceImageDndTarget} dndTargetData={dndTargetData} />
|
||||
</Button>
|
||||
</>
|
||||
{tab === 'canvas' && (
|
||||
<CanvasManagerProviderGate>
|
||||
<BboxButton />
|
||||
</CanvasManagerProviderGate>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
const BboxButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const isBusy = useCanvasIsBusySafe();
|
||||
const newGlobalReferenceImageFromBbox = useNewGlobalReferenceImageFromBbox();
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
size="lg"
|
||||
variant="outline"
|
||||
h="full"
|
||||
icon={<PiBoundingBoxBold />}
|
||||
onClick={newGlobalReferenceImageFromBbox}
|
||||
isDisabled={isBusy}
|
||||
aria-label={t('controlLayers.pullBboxIntoReferenceImage')}
|
||||
tooltip={t('controlLayers.pullBboxIntoReferenceImage')}
|
||||
/>
|
||||
);
|
||||
});
|
||||
AddRefImageDropTargetAndButton.displayName = 'AddRefImageDropTargetAndButton';
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Flex, Icon, IconButton, Image, Skeleton, Text } from '@invoke-ai/ui-library';
|
||||
import { Flex, Icon, IconButton, Image, Skeleton, Text, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { round } from 'es-toolkit/compat';
|
||||
@@ -17,6 +17,8 @@ import { memo, useCallback, useEffect, useMemo, useState } from 'react';
|
||||
import { PiExclamationMarkBold, PiEyeSlashBold, PiImageBold } from 'react-icons/pi';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
|
||||
import { RefImageWarningTooltipContent } from './RefImageWarningTooltipContent';
|
||||
|
||||
const baseSx: SystemStyleObject = {
|
||||
'&[data-is-open="true"]': {
|
||||
borderColor: 'invokeBlue.300',
|
||||
@@ -51,9 +53,6 @@ const getImageSxWithWeight = (weight: number): SystemStyleObject => {
|
||||
|
||||
return {
|
||||
...baseSx,
|
||||
'&[data-is-disabled="true"]': {
|
||||
opacity: 0.4,
|
||||
},
|
||||
_after: {
|
||||
content: '""',
|
||||
position: 'absolute',
|
||||
@@ -95,8 +94,8 @@ export const RefImagePreview = memo(() => {
|
||||
};
|
||||
}, [entity.config]);
|
||||
|
||||
const isInvalid = useMemo(() => {
|
||||
return getGlobalReferenceImageWarnings(entity, mainModelConfig).length > 0;
|
||||
const warnings = useMemo(() => {
|
||||
return getGlobalReferenceImageWarnings(entity, mainModelConfig);
|
||||
}, [entity, mainModelConfig]);
|
||||
|
||||
const onClick = useCallback(() => {
|
||||
@@ -126,74 +125,76 @@ export const RefImagePreview = memo(() => {
|
||||
);
|
||||
}
|
||||
return (
|
||||
<Flex
|
||||
position="relative"
|
||||
borderWidth={1}
|
||||
borderStyle="solid"
|
||||
borderRadius="base"
|
||||
aspectRatio="1/1"
|
||||
maxW="full"
|
||||
maxH="full"
|
||||
flexShrink={0}
|
||||
sx={sx}
|
||||
data-is-open={selectedEntityId === id && isPanelOpen}
|
||||
data-is-error={isInvalid}
|
||||
data-is-disabled={!entity.isEnabled}
|
||||
role="button"
|
||||
onClick={onClick}
|
||||
cursor="pointer"
|
||||
>
|
||||
<Image
|
||||
src={imageDTO?.thumbnail_url}
|
||||
objectFit="contain"
|
||||
<Tooltip label={warnings.length > 0 ? <RefImageWarningTooltipContent warnings={warnings} /> : undefined}>
|
||||
<Flex
|
||||
position="relative"
|
||||
borderWidth={1}
|
||||
borderStyle="solid"
|
||||
borderRadius="base"
|
||||
aspectRatio="1/1"
|
||||
height={imageDTO?.height}
|
||||
fallback={<Skeleton h="full" aspectRatio="1/1" />}
|
||||
maxW="full"
|
||||
maxH="full"
|
||||
borderRadius="base"
|
||||
/>
|
||||
{isIPAdapterConfig(entity.config) && (
|
||||
<Flex
|
||||
position="absolute"
|
||||
inset={0}
|
||||
fontWeight="semibold"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
zIndex={1}
|
||||
data-visible={showWeightDisplay}
|
||||
sx={weightDisplaySx}
|
||||
>
|
||||
<Text filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))">
|
||||
{`${round(entity.config.weight * 100, 2)}%`}
|
||||
</Text>
|
||||
</Flex>
|
||||
)}
|
||||
{!entity.isEnabled && (
|
||||
<Icon
|
||||
position="absolute"
|
||||
top="50%"
|
||||
left="50%"
|
||||
transform="translateX(-50%) translateY(-50%)"
|
||||
filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))"
|
||||
color="base.300"
|
||||
boxSize={8}
|
||||
as={PiEyeSlashBold}
|
||||
flexShrink={0}
|
||||
sx={sx}
|
||||
data-is-open={selectedEntityId === id && isPanelOpen}
|
||||
data-is-error={warnings.length > 0}
|
||||
data-is-disabled={!entity.isEnabled}
|
||||
role="button"
|
||||
onClick={onClick}
|
||||
cursor="pointer"
|
||||
overflow="hidden"
|
||||
>
|
||||
<Image
|
||||
src={imageDTO?.thumbnail_url}
|
||||
objectFit="contain"
|
||||
aspectRatio="1/1"
|
||||
height={imageDTO?.height}
|
||||
fallback={<Skeleton h="full" aspectRatio="1/1" />}
|
||||
maxW="full"
|
||||
maxH="full"
|
||||
/>
|
||||
)}
|
||||
{entity.isEnabled && isInvalid && (
|
||||
<Icon
|
||||
position="absolute"
|
||||
top="50%"
|
||||
left="50%"
|
||||
transform="translateX(-50%) translateY(-50%)"
|
||||
filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))"
|
||||
color="error.500"
|
||||
boxSize={12}
|
||||
as={PiExclamationMarkBold}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
{isIPAdapterConfig(entity.config) && (
|
||||
<Flex
|
||||
position="absolute"
|
||||
inset={0}
|
||||
fontWeight="semibold"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
zIndex={1}
|
||||
data-visible={showWeightDisplay}
|
||||
sx={weightDisplaySx}
|
||||
>
|
||||
<Text filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))">
|
||||
{`${round(entity.config.weight * 100, 2)}%`}
|
||||
</Text>
|
||||
</Flex>
|
||||
)}
|
||||
{!entity.isEnabled && (
|
||||
<Icon
|
||||
position="absolute"
|
||||
top="50%"
|
||||
left="50%"
|
||||
transform="translateX(-50%) translateY(-50%)"
|
||||
filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))"
|
||||
color="base.300"
|
||||
boxSize={8}
|
||||
as={PiEyeSlashBold}
|
||||
/>
|
||||
)}
|
||||
{entity.isEnabled && warnings.length > 0 && (
|
||||
<Icon
|
||||
position="absolute"
|
||||
top="50%"
|
||||
left="50%"
|
||||
transform="translateX(-50%) translateY(-50%)"
|
||||
filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))"
|
||||
color="error.500"
|
||||
boxSize={12}
|
||||
as={PiExclamationMarkBold}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
</Tooltip>
|
||||
);
|
||||
});
|
||||
RefImagePreview.displayName = 'RefImagePreview';
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
import { Flex, ListItem, Text, UnorderedList } from '@invoke-ai/ui-library';
|
||||
import { upperFirst } from 'es-toolkit/compat';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const RefImageWarningTooltipContent = ({ warnings }: { warnings: string[] }) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Flex flexDir="column">
|
||||
<Text fontWeight="semibold">Invalid Reference Image:</Text>
|
||||
<UnorderedList>
|
||||
{warnings.map((tKey) => (
|
||||
<ListItem key={tKey}>{upperFirst(t(tKey))}</ListItem>
|
||||
))}
|
||||
</UnorderedList>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
@@ -26,6 +26,7 @@ import { CanvasSettingsPreserveMaskCheckbox } from 'features/controlLayers/compo
|
||||
import { CanvasSettingsPressureSensitivityCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsPressureSensitivity';
|
||||
import { CanvasSettingsRecalculateRectsButton } from 'features/controlLayers/components/Settings/CanvasSettingsRecalculateRectsButton';
|
||||
import { CanvasSettingsRuleOfThirdsSwitch } from 'features/controlLayers/components/Settings/CanvasSettingsRuleOfThirdsGuideSwitch';
|
||||
import { CanvasSettingsSaveAllImagesToGalleryCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsSaveAllImagesToGalleryCheckbox';
|
||||
import { CanvasSettingsShowHUDSwitch } from 'features/controlLayers/components/Settings/CanvasSettingsShowHUDSwitch';
|
||||
import { CanvasSettingsShowProgressOnCanvas } from 'features/controlLayers/components/Settings/CanvasSettingsShowProgressOnCanvasSwitch';
|
||||
import { memo } from 'react';
|
||||
@@ -61,6 +62,7 @@ export const CanvasSettingsPopover = memo(() => {
|
||||
<CanvasSettingsPreserveMaskCheckbox />
|
||||
<CanvasSettingsClipToBboxCheckbox />
|
||||
<CanvasSettingsOutputOnlyMaskedRegionsCheckbox />
|
||||
<CanvasSettingsSaveAllImagesToGalleryCheckbox />
|
||||
</Flex>
|
||||
|
||||
<Divider />
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
import { Checkbox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
selectSaveAllImagesToGallery,
|
||||
settingsSaveAllImagesToGalleryToggled,
|
||||
} from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const CanvasSettingsSaveAllImagesToGalleryCheckbox = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const saveAllImagesToGallery = useAppSelector(selectSaveAllImagesToGallery);
|
||||
const onChange = useCallback(() => {
|
||||
dispatch(settingsSaveAllImagesToGalleryToggled());
|
||||
}, [dispatch]);
|
||||
return (
|
||||
<FormControl w="full">
|
||||
<FormLabel flexGrow={1}>{t('controlLayers.saveAllImagesToGallery')}</FormLabel>
|
||||
<Checkbox isChecked={saveAllImagesToGallery} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
});
|
||||
|
||||
CanvasSettingsSaveAllImagesToGalleryCheckbox.displayName = 'CanvasSettingsSaveAllImagesToGalleryCheckbox';
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Button, Flex, Grid, Heading, Text } from '@invoke-ai/ui-library';
|
||||
import { Button, Flex, Grid, Text } from '@invoke-ai/ui-library';
|
||||
import { navigationApi } from 'features/ui/layouts/navigation-api';
|
||||
import { WORKSPACE_PANEL_ID } from 'features/ui/layouts/shared';
|
||||
import { memo, useCallback } from 'react';
|
||||
@@ -6,6 +6,7 @@ import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { InitialStateMainModelPicker } from './InitialStateMainModelPicker';
|
||||
import { LaunchpadAddStyleReference } from './LaunchpadAddStyleReference';
|
||||
import { LaunchpadContainer } from './LaunchpadContainer';
|
||||
import { LaunchpadEditImageButton } from './LaunchpadEditImageButton';
|
||||
import { LaunchpadGenerateFromTextButton } from './LaunchpadGenerateFromTextButton';
|
||||
import { LaunchpadUseALayoutImageButton } from './LaunchpadUseALayoutImageButton';
|
||||
@@ -16,35 +17,30 @@ export const CanvasLaunchpadPanel = memo(() => {
|
||||
navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
|
||||
}, []);
|
||||
return (
|
||||
<Flex flexDir="column" h="full" w="full" alignItems="center" gap={2}>
|
||||
<Flex flexDir="column" w="full" gap={4} px={14} maxW={768} pt="20vh">
|
||||
<Heading mb={4}>{t('ui.launchpad.canvasTitle')}</Heading>
|
||||
<Flex flexDir="column" gap={8}>
|
||||
<Grid gridTemplateColumns="1fr 1fr" gap={8}>
|
||||
<InitialStateMainModelPicker />
|
||||
<Flex flexDir="column" gap={2} justifyContent="center">
|
||||
<Text>
|
||||
{t('ui.launchpad.modelGuideText')}{' '}
|
||||
<Button
|
||||
as="a"
|
||||
variant="link"
|
||||
href="https://support.invoke.ai/support/solutions/articles/151000216086-model-guide"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
size="sm"
|
||||
>
|
||||
{t('ui.launchpad.modelGuideLink')}
|
||||
</Button>
|
||||
</Text>
|
||||
</Flex>
|
||||
</Grid>
|
||||
<LaunchpadGenerateFromTextButton extraAction={focusCanvas} />
|
||||
<LaunchpadAddStyleReference extraAction={focusCanvas} />
|
||||
<LaunchpadEditImageButton extraAction={focusCanvas} />
|
||||
<LaunchpadUseALayoutImageButton extraAction={focusCanvas} />
|
||||
<LaunchpadContainer heading={t('ui.launchpad.canvasTitle')}>
|
||||
<Grid gridTemplateColumns="1fr 1fr" gap={8}>
|
||||
<InitialStateMainModelPicker />
|
||||
<Flex flexDir="column" gap={2} justifyContent="center">
|
||||
<Text>
|
||||
{t('ui.launchpad.modelGuideText')}{' '}
|
||||
<Button
|
||||
as="a"
|
||||
variant="link"
|
||||
href="https://support.invoke.ai/support/solutions/articles/151000216086-model-guide"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
size="sm"
|
||||
>
|
||||
{t('ui.launchpad.modelGuideLink')}
|
||||
</Button>
|
||||
</Text>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Grid>
|
||||
<LaunchpadGenerateFromTextButton extraAction={focusCanvas} />
|
||||
<LaunchpadAddStyleReference extraAction={focusCanvas} />
|
||||
<LaunchpadEditImageButton extraAction={focusCanvas} />
|
||||
<LaunchpadUseALayoutImageButton extraAction={focusCanvas} />
|
||||
</LaunchpadContainer>
|
||||
);
|
||||
});
|
||||
CanvasLaunchpadPanel.displayName = 'CanvasLaunchpadPanel';
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import { Alert, Button, Flex, Grid, Heading, Text } from '@invoke-ai/ui-library';
|
||||
import { Alert, Button, Flex, Grid, Text } from '@invoke-ai/ui-library';
|
||||
import { InitialStateMainModelPicker } from 'features/controlLayers/components/SimpleSession/InitialStateMainModelPicker';
|
||||
import { LaunchpadAddStyleReference } from 'features/controlLayers/components/SimpleSession/LaunchpadAddStyleReference';
|
||||
import { navigationApi } from 'features/ui/layouts/navigation-api';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
||||
import { LaunchpadContainer } from './LaunchpadContainer';
|
||||
import { LaunchpadGenerateFromTextButton } from './LaunchpadGenerateFromTextButton';
|
||||
|
||||
export const GenerateLaunchpadPanel = memo(() => {
|
||||
@@ -12,41 +13,36 @@ export const GenerateLaunchpadPanel = memo(() => {
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" h="full" w="full" alignItems="center" gap={2}>
|
||||
<Flex flexDir="column" w="full" gap={4} px={14} maxW={768} pt="20vh">
|
||||
<Heading mb={4}>Generate images from text prompts.</Heading>
|
||||
<Flex flexDir="column" gap={8}>
|
||||
<Grid gridTemplateColumns="1fr 1fr" gap={8}>
|
||||
<InitialStateMainModelPicker />
|
||||
<Flex flexDir="column" gap={2} justifyContent="center">
|
||||
<Text>
|
||||
Want to learn what prompts work best for each model?{' '}
|
||||
<Button
|
||||
as="a"
|
||||
variant="link"
|
||||
href="https://support.invoke.ai/support/solutions/articles/151000216086-model-guide"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
size="sm"
|
||||
>
|
||||
Check out our Model Guide.
|
||||
</Button>
|
||||
</Text>
|
||||
</Flex>
|
||||
</Grid>
|
||||
<LaunchpadGenerateFromTextButton />
|
||||
<LaunchpadAddStyleReference />
|
||||
<Alert status="info" borderRadius="base" flexDir="column" gap={2} overflow="unset">
|
||||
<Text fontSize="md" fontWeight="semibold">
|
||||
Looking to get more control, edit, and iterate on your images?
|
||||
</Text>
|
||||
<Button variant="link" onClick={newCanvasSession}>
|
||||
Navigate to Canvas for more capabilities.
|
||||
<LaunchpadContainer heading="Generate images from text prompts.">
|
||||
<Grid gridTemplateColumns="1fr 1fr" gap={8}>
|
||||
<InitialStateMainModelPicker />
|
||||
<Flex flexDir="column" gap={2} justifyContent="center">
|
||||
<Text>
|
||||
Want to learn what prompts work best for each model?{' '}
|
||||
<Button
|
||||
as="a"
|
||||
variant="link"
|
||||
href="https://support.invoke.ai/support/solutions/articles/151000216086-model-guide"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
size="sm"
|
||||
>
|
||||
Check out our Model Guide.
|
||||
</Button>
|
||||
</Alert>
|
||||
</Text>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Grid>
|
||||
<LaunchpadGenerateFromTextButton />
|
||||
<LaunchpadAddStyleReference />
|
||||
<Alert status="info" borderRadius="base" flexDir="column" gap={2} overflow="unset">
|
||||
<Text fontSize="md" fontWeight="semibold">
|
||||
Looking to get more control, edit, and iterate on your images?
|
||||
</Text>
|
||||
<Button variant="link" onClick={newCanvasSession}>
|
||||
Navigate to Canvas for more capabilities.
|
||||
</Button>
|
||||
</Alert>
|
||||
</LaunchpadContainer>
|
||||
);
|
||||
});
|
||||
GenerateLaunchpadPanel.displayName = 'GenerateLaunchpad';
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
import { Flex, Heading } from '@invoke-ai/ui-library';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const LaunchpadContainer = memo((props: PropsWithChildren<{ heading: string }>) => {
|
||||
return (
|
||||
<Flex flexDir="column" h="full" w="full" alignItems="center" justifyContent="center" gap={2}>
|
||||
<Flex flexDir="column" w="full" gap={4} px={14} maxW={768}>
|
||||
<Heading>{props.heading}</Heading>
|
||||
<Flex flexDir="column" gap={4}>
|
||||
{props.children}
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
LaunchpadContainer.displayName = 'LaunchpadContainer';
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
useCanvasSessionContext,
|
||||
useOutputImageDTO,
|
||||
@@ -10,6 +11,10 @@ import { QueueItemNumber } from 'features/controlLayers/components/SimpleSession
|
||||
import { QueueItemProgressImage } from 'features/controlLayers/components/SimpleSession/QueueItemProgressImage';
|
||||
import { QueueItemStatusLabel } from 'features/controlLayers/components/SimpleSession/QueueItemStatusLabel';
|
||||
import { getQueueItemElementId } from 'features/controlLayers/components/SimpleSession/shared';
|
||||
import {
|
||||
selectStagingAreaAutoSwitch,
|
||||
settingsStagingAreaAutoSwitchChanged,
|
||||
} from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { DndImage } from 'features/dnd/DndImage';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, useCallback } from 'react';
|
||||
@@ -21,12 +26,13 @@ const sx = {
|
||||
pos: 'relative',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
h: 108,
|
||||
w: 108,
|
||||
flexShrink: 0,
|
||||
h: 'full',
|
||||
aspectRatio: '1/1',
|
||||
borderWidth: 2,
|
||||
borderRadius: 'base',
|
||||
bg: 'base.900',
|
||||
overflow: 'hidden',
|
||||
'&[data-selected="true"]': {
|
||||
borderColor: 'invokeBlue.300',
|
||||
},
|
||||
@@ -34,28 +40,29 @@ const sx = {
|
||||
|
||||
type Props = {
|
||||
item: S['SessionQueueItem'];
|
||||
number: number;
|
||||
index: number;
|
||||
isSelected: boolean;
|
||||
};
|
||||
|
||||
export const QueueItemPreviewMini = memo(({ item, isSelected, number }: Props) => {
|
||||
export const QueueItemPreviewMini = memo(({ item, isSelected, index }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const ctx = useCanvasSessionContext();
|
||||
const { imageLoaded } = useProgressData(ctx.$progressData, item.item_id);
|
||||
const imageDTO = useOutputImageDTO(item);
|
||||
const autoSwitch = useAppSelector(selectStagingAreaAutoSwitch);
|
||||
|
||||
const onClick = useCallback(() => {
|
||||
ctx.$selectedItemId.set(item.item_id);
|
||||
}, [ctx.$selectedItemId, item.item_id]);
|
||||
|
||||
const onDoubleClick = useCallback(() => {
|
||||
const autoSwitch = ctx.$autoSwitch.get();
|
||||
if (autoSwitch !== 'off') {
|
||||
ctx.$autoSwitch.set('off');
|
||||
dispatch(settingsStagingAreaAutoSwitchChanged('off'));
|
||||
toast({
|
||||
title: 'Auto-Switch Disabled',
|
||||
});
|
||||
}
|
||||
}, [ctx.$autoSwitch]);
|
||||
}, [autoSwitch, dispatch]);
|
||||
|
||||
const onLoad = useCallback(() => {
|
||||
ctx.onImageLoad(item.item_id);
|
||||
@@ -63,16 +70,16 @@ export const QueueItemPreviewMini = memo(({ item, isSelected, number }: Props) =
|
||||
|
||||
return (
|
||||
<Flex
|
||||
id={getQueueItemElementId(item.item_id)}
|
||||
id={getQueueItemElementId(index)}
|
||||
sx={sx}
|
||||
data-selected={isSelected}
|
||||
onClick={onClick}
|
||||
onDoubleClick={onDoubleClick}
|
||||
>
|
||||
<QueueItemStatusLabel item={item} position="absolute" margin="auto" />
|
||||
{imageDTO && <DndImage imageDTO={imageDTO} onLoad={onLoad} asThumbnail />}
|
||||
{imageDTO && <DndImage imageDTO={imageDTO} onLoad={onLoad} asThumbnail position="absolute" />}
|
||||
{!imageLoaded && <QueueItemProgressImage itemId={item.item_id} position="absolute" />}
|
||||
<QueueItemNumber number={number} position="absolute" top={0} left={1} />
|
||||
<QueueItemNumber number={index + 1} position="absolute" top={0} left={1} />
|
||||
<QueueItemCircularProgress itemId={item.item_id} status={item.status} position="absolute" top={1} right={2} />
|
||||
</Flex>
|
||||
);
|
||||
|
||||
@@ -16,21 +16,21 @@ export const QueueItemStatusLabel = memo(({ item, ...rest }: Props) => {
|
||||
|
||||
if (item.status === 'pending') {
|
||||
return (
|
||||
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="base.300" {...rest}>
|
||||
<Text fontSize="xs" pointerEvents="none" userSelect="none" fontWeight="semibold" color="base.300" {...rest}>
|
||||
Pending
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
if (item.status === 'canceled') {
|
||||
return (
|
||||
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="warning.300" {...rest}>
|
||||
<Text fontSize="xs" pointerEvents="none" userSelect="none" fontWeight="semibold" color="warning.300" {...rest}>
|
||||
Canceled
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
if (item.status === 'failed') {
|
||||
return (
|
||||
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="error.300" {...rest}>
|
||||
<Text fontSize="xs" pointerEvents="none" userSelect="none" fontWeight="semibold" color="error.300" {...rest}>
|
||||
Failed
|
||||
</Text>
|
||||
);
|
||||
@@ -38,7 +38,7 @@ export const QueueItemStatusLabel = memo(({ item, ...rest }: Props) => {
|
||||
|
||||
if (item.status === 'in_progress') {
|
||||
return (
|
||||
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="invokeBlue.300" {...rest}>
|
||||
<Text fontSize="xs" pointerEvents="none" userSelect="none" fontWeight="semibold" color="invokeBlue.300" {...rest}>
|
||||
In Progress
|
||||
</Text>
|
||||
);
|
||||
@@ -46,7 +46,14 @@ export const QueueItemStatusLabel = memo(({ item, ...rest }: Props) => {
|
||||
|
||||
if (item.status === 'completed') {
|
||||
return (
|
||||
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="invokeGreen.300" {...rest}>
|
||||
<Text
|
||||
fontSize="xs"
|
||||
pointerEvents="none"
|
||||
userSelect="none"
|
||||
fontWeight="semibold"
|
||||
color="invokeGreen.300"
|
||||
{...rest}
|
||||
>
|
||||
Completed
|
||||
</Text>
|
||||
);
|
||||
|
||||
@@ -1,17 +1,149 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex, forwardRef } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
|
||||
import { QueueItemPreviewMini } from 'features/controlLayers/components/SimpleSession/QueueItemPreviewMini';
|
||||
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { memo, useEffect } from 'react';
|
||||
import { useOverlayScrollbars } from 'overlayscrollbars-react';
|
||||
import type { CSSProperties, RefObject } from 'react';
|
||||
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import type { Components, ItemContent, ListRange, VirtuosoHandle, VirtuosoProps } from 'react-virtuoso';
|
||||
import { Virtuoso } from 'react-virtuoso';
|
||||
import type { S } from 'services/api/types';
|
||||
|
||||
import { getQueueItemElementId } from './shared';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
const virtuosoStyles = {
|
||||
width: '100%',
|
||||
height: '72px',
|
||||
} satisfies CSSProperties;
|
||||
|
||||
type VirtuosoContext = { selectedItemId: number | null };
|
||||
|
||||
/**
|
||||
* Scroll the item at the given index into view if it is not currently visible.
|
||||
*/
|
||||
const scrollIntoView = (
|
||||
targetIndex: number,
|
||||
rootEl: HTMLDivElement,
|
||||
virtuosoHandle: VirtuosoHandle,
|
||||
range: ListRange
|
||||
) => {
|
||||
if (range.endIndex === 0) {
|
||||
// No range is rendered; no need to scroll to anything.
|
||||
return;
|
||||
}
|
||||
|
||||
const targetItem = rootEl.querySelector(`#${getQueueItemElementId(targetIndex)}`);
|
||||
|
||||
if (!targetItem) {
|
||||
if (targetIndex > range.endIndex) {
|
||||
virtuosoHandle.scrollToIndex({
|
||||
index: targetIndex,
|
||||
behavior: 'auto',
|
||||
align: 'end',
|
||||
});
|
||||
} else if (targetIndex < range.startIndex) {
|
||||
virtuosoHandle.scrollToIndex({
|
||||
index: targetIndex,
|
||||
behavior: 'auto',
|
||||
align: 'start',
|
||||
});
|
||||
} else {
|
||||
log.debug(
|
||||
`Unable to find queue item at index ${targetIndex} but it is in the rendered range ${range.startIndex}-${range.endIndex}`
|
||||
);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// We found the image in the DOM, but it might be in the overscan range - rendered but not in the visible viewport.
|
||||
// Check if it is in the viewport and scroll if necessary.
|
||||
|
||||
const itemRect = targetItem.getBoundingClientRect();
|
||||
const rootRect = rootEl.getBoundingClientRect();
|
||||
|
||||
if (itemRect.left < rootRect.left) {
|
||||
virtuosoHandle.scrollToIndex({
|
||||
index: targetIndex,
|
||||
behavior: 'auto',
|
||||
align: 'start',
|
||||
});
|
||||
} else if (itemRect.right > rootRect.right) {
|
||||
virtuosoHandle.scrollToIndex({
|
||||
index: targetIndex,
|
||||
behavior: 'auto',
|
||||
align: 'end',
|
||||
});
|
||||
} else {
|
||||
// Image is already in view
|
||||
}
|
||||
|
||||
return;
|
||||
};
|
||||
|
||||
const useScrollableStagingArea = (rootRef: RefObject<HTMLDivElement>) => {
|
||||
const [scroller, scrollerRef] = useState<HTMLElement | null>(null);
|
||||
const [initialize, osInstance] = useOverlayScrollbars({
|
||||
defer: true,
|
||||
events: {
|
||||
initialized(osInstance) {
|
||||
// force overflow styles
|
||||
const { viewport } = osInstance.elements();
|
||||
viewport.style.overflowX = `var(--os-viewport-overflow-x)`;
|
||||
viewport.style.overflowY = `var(--os-viewport-overflow-y)`;
|
||||
viewport.style.textAlign = 'center';
|
||||
},
|
||||
},
|
||||
options: {
|
||||
scrollbars: {
|
||||
visibility: 'auto',
|
||||
autoHide: 'scroll',
|
||||
autoHideDelay: 1300,
|
||||
theme: 'os-theme-dark',
|
||||
},
|
||||
overflow: {
|
||||
y: 'hidden',
|
||||
x: 'scroll',
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
const { current: root } = rootRef;
|
||||
|
||||
if (scroller && root) {
|
||||
initialize({
|
||||
target: root,
|
||||
elements: {
|
||||
viewport: scroller,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
return () => {
|
||||
osInstance()?.destroy();
|
||||
};
|
||||
}, [scroller, initialize, osInstance, rootRef]);
|
||||
|
||||
return scrollerRef;
|
||||
};
|
||||
|
||||
export const StagingAreaItemsList = memo(() => {
|
||||
const canvasManager = useCanvasManagerSafe();
|
||||
const ctx = useCanvasSessionContext();
|
||||
const virtuosoRef = useRef<VirtuosoHandle>(null);
|
||||
const rangeRef = useRef<ListRange>({ startIndex: 0, endIndex: 0 });
|
||||
const rootRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const items = useStore(ctx.$items);
|
||||
const selectedItemId = useStore(ctx.$selectedItemId);
|
||||
|
||||
const context = useMemo(() => ({ selectedItemId }), [selectedItemId]);
|
||||
const scrollerRef = useScrollableStagingArea(rootRef);
|
||||
|
||||
useEffect(() => {
|
||||
if (!canvasManager) {
|
||||
return;
|
||||
@@ -20,19 +152,64 @@ export const StagingAreaItemsList = memo(() => {
|
||||
return canvasManager.stagingArea.connectToSession(ctx.$selectedItemId, ctx.$progressData, ctx.$isPending);
|
||||
}, [canvasManager, ctx.$progressData, ctx.$selectedItemId, ctx.$isPending]);
|
||||
|
||||
useEffect(() => {
|
||||
return ctx.$selectedItemIndex.listen((index) => {
|
||||
if (!virtuosoRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!rootRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (index === null) {
|
||||
return;
|
||||
}
|
||||
|
||||
scrollIntoView(index, rootRef.current, virtuosoRef.current, rangeRef.current);
|
||||
});
|
||||
}, [ctx.$selectedItemIndex]);
|
||||
|
||||
const onRangeChanged = useCallback((range: ListRange) => {
|
||||
rangeRef.current = range;
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<ScrollableContent overflowX="scroll" overflowY="hidden">
|
||||
<Flex gap={2} w="full" h="full" justifyContent="safe center">
|
||||
{items.map((item, i) => (
|
||||
<QueueItemPreviewMini
|
||||
key={`${item.item_id}-mini`}
|
||||
item={item}
|
||||
number={i + 1}
|
||||
isSelected={selectedItemId === item.item_id}
|
||||
/>
|
||||
))}
|
||||
</Flex>
|
||||
</ScrollableContent>
|
||||
<Box data-overlayscrollbars-initialize="" ref={rootRef} position="relative" w="full" h="full">
|
||||
<Virtuoso<S['SessionQueueItem'], VirtuosoContext>
|
||||
ref={virtuosoRef}
|
||||
context={context}
|
||||
data={items}
|
||||
horizontalDirection
|
||||
style={virtuosoStyles}
|
||||
itemContent={itemContent}
|
||||
components={components}
|
||||
rangeChanged={onRangeChanged}
|
||||
// Virtuoso expects the ref to be of HTMLElement | null | Window, but overlayscrollbars doesn't allow Window
|
||||
scrollerRef={scrollerRef as VirtuosoProps<S['SessionQueueItem'], VirtuosoContext>['scrollerRef']}
|
||||
/>
|
||||
</Box>
|
||||
);
|
||||
});
|
||||
StagingAreaItemsList.displayName = 'StagingAreaItemsList';
|
||||
|
||||
const itemContent: ItemContent<S['SessionQueueItem'], VirtuosoContext> = (index, item, { selectedItemId }) => (
|
||||
<QueueItemPreviewMini
|
||||
key={`${item.item_id}-mini`}
|
||||
item={item}
|
||||
index={index}
|
||||
isSelected={selectedItemId === item.item_id}
|
||||
/>
|
||||
);
|
||||
|
||||
const listSx = {
|
||||
'& > * + *': {
|
||||
pl: 2,
|
||||
},
|
||||
};
|
||||
|
||||
const components: Components<S['SessionQueueItem'], VirtuosoContext> = {
|
||||
List: forwardRef(({ context: _, ...rest }, ref) => {
|
||||
return <Flex ref={ref} sx={listSx} {...rest} />;
|
||||
}),
|
||||
};
|
||||
|
||||
@@ -24,6 +24,7 @@ import {
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
import { LaunchpadButton } from './LaunchpadButton';
|
||||
import { LaunchpadContainer } from './LaunchpadContainer';
|
||||
|
||||
export const UpscalingLaunchpadPanel = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
@@ -65,108 +66,104 @@ export const UpscalingLaunchpadPanel = memo(() => {
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" h="full" w="full" alignItems="center" gap={2}>
|
||||
<Flex flexDir="column" w="full" gap={8} px={14} maxW={768} pt="20vh">
|
||||
<Heading>{t('ui.launchpad.upscalingTitle')}</Heading>
|
||||
|
||||
{/* Upload Area */}
|
||||
<LaunchpadButton {...uploadApi.getUploadButtonProps()} position="relative" gap={8}>
|
||||
{!upscaleInitialImage ? (
|
||||
<>
|
||||
<Icon as={PiImageBold} boxSize={8} color="base.500" />
|
||||
<Flex flexDir="column" alignItems="flex-start" gap={2}>
|
||||
<Heading size="sm">{t('ui.launchpad.upscaling.uploadImage.title')}</Heading>
|
||||
<Text color="base.300">{t('ui.launchpad.upscaling.uploadImage.description')}</Text>
|
||||
</Flex>
|
||||
<Flex position="absolute" right={3} bottom={3}>
|
||||
<PiUploadBold />
|
||||
<input {...uploadApi.getUploadInputProps()} />
|
||||
</Flex>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Icon as={PiImageBold} boxSize={8} color="base.500" />
|
||||
<Flex flexDir="column" alignItems="flex-start" gap={2}>
|
||||
<Heading size="sm">{t('ui.launchpad.upscaling.replaceImage.title')}</Heading>
|
||||
<Text color="base.300">{t('ui.launchpad.upscaling.replaceImage.description')}</Text>
|
||||
</Flex>
|
||||
<Flex position="absolute" right={3} bottom={3}>
|
||||
<PiUploadBold />
|
||||
<input {...uploadApi.getUploadInputProps()} />
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
<DndDropTarget
|
||||
dndTarget={setUpscaleInitialImageDndTarget}
|
||||
dndTargetData={dndTargetData}
|
||||
label={t('gallery.drop')}
|
||||
/>
|
||||
</LaunchpadButton>
|
||||
|
||||
{/* Guidance text */}
|
||||
{upscaleInitialImage && (
|
||||
<Flex bg="base.800" p={4} borderRadius="base" border="1px solid" borderColor="base.700">
|
||||
<Text variant="subtext" fontSize="sm" lineHeight="1.6">
|
||||
<strong>{t('ui.launchpad.upscaling.readyToUpscale.title')}</strong>{' '}
|
||||
{t('ui.launchpad.upscaling.readyToUpscale.description')}
|
||||
</Text>
|
||||
</Flex>
|
||||
<LaunchpadContainer heading={t('ui.launchpad.upscalingTitle')}>
|
||||
{/* Upload Area */}
|
||||
<LaunchpadButton {...uploadApi.getUploadButtonProps()} position="relative" gap={8}>
|
||||
{!upscaleInitialImage ? (
|
||||
<>
|
||||
<Icon as={PiImageBold} boxSize={8} color="base.500" />
|
||||
<Flex flexDir="column" alignItems="flex-start" gap={2}>
|
||||
<Heading size="sm">{t('ui.launchpad.upscaling.uploadImage.title')}</Heading>
|
||||
<Text color="base.300">{t('ui.launchpad.upscaling.uploadImage.description')}</Text>
|
||||
</Flex>
|
||||
<Flex position="absolute" right={3} bottom={3}>
|
||||
<PiUploadBold />
|
||||
<input {...uploadApi.getUploadInputProps()} />
|
||||
</Flex>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Icon as={PiImageBold} boxSize={8} color="base.500" />
|
||||
<Flex flexDir="column" alignItems="flex-start" gap={2}>
|
||||
<Heading size="sm">{t('ui.launchpad.upscaling.replaceImage.title')}</Heading>
|
||||
<Text color="base.300">{t('ui.launchpad.upscaling.replaceImage.description')}</Text>
|
||||
</Flex>
|
||||
<Flex position="absolute" right={3} bottom={3}>
|
||||
<PiUploadBold />
|
||||
<input {...uploadApi.getUploadInputProps()} />
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
<DndDropTarget
|
||||
dndTarget={setUpscaleInitialImageDndTarget}
|
||||
dndTargetData={dndTargetData}
|
||||
label={t('gallery.drop')}
|
||||
/>
|
||||
</LaunchpadButton>
|
||||
|
||||
{/* Controls */}
|
||||
<Grid gridTemplateColumns="1fr 1fr" gap={8} alignItems="start">
|
||||
{/* Left Column: Creativity and Structural Defaults */}
|
||||
<Box>
|
||||
<Text fontWeight="semibold" fontSize="sm" mb={3}>
|
||||
Creativity & Structure Defaults
|
||||
</Text>
|
||||
<ButtonGroup size="sm" orientation="vertical" variant="outline" w="full">
|
||||
<Button
|
||||
colorScheme={creativity === -5 && structure === 5 ? 'invokeBlue' : undefined}
|
||||
justifyContent="center"
|
||||
onClick={onConservativeClick}
|
||||
leftIcon={<PiShieldCheckBold />}
|
||||
>
|
||||
Conservative
|
||||
</Button>
|
||||
<Button
|
||||
colorScheme={creativity === 0 && structure === 0 ? 'invokeBlue' : undefined}
|
||||
justifyContent="center"
|
||||
onClick={onBalancedClick}
|
||||
leftIcon={<PiScalesBold />}
|
||||
>
|
||||
Balanced
|
||||
</Button>
|
||||
<Button
|
||||
colorScheme={creativity === 5 && structure === -2 ? 'invokeBlue' : undefined}
|
||||
justifyContent="center"
|
||||
onClick={onCreativeClick}
|
||||
leftIcon={<PiPaletteBold />}
|
||||
>
|
||||
Creative
|
||||
</Button>
|
||||
<Button
|
||||
colorScheme={creativity === 8 && structure === -5 ? 'invokeBlue' : undefined}
|
||||
justifyContent="center"
|
||||
onClick={onArtisticClick}
|
||||
leftIcon={<PiSparkleBold />}
|
||||
>
|
||||
Artistic
|
||||
</Button>
|
||||
</ButtonGroup>
|
||||
</Box>
|
||||
{/* Right Column: Description/help text */}
|
||||
<Box>
|
||||
<Text variant="subtext" fontSize="sm" lineHeight="1.6">
|
||||
{t('ui.launchpad.upscaling.helpText.promptAdvice')}
|
||||
</Text>
|
||||
<Text variant="subtext" fontSize="sm" lineHeight="1.6" mt={3}>
|
||||
{t('ui.launchpad.upscaling.helpText.styleAdvice')}
|
||||
</Text>
|
||||
</Box>
|
||||
</Grid>
|
||||
</Flex>
|
||||
</Flex>
|
||||
{/* Guidance text */}
|
||||
{upscaleInitialImage && (
|
||||
<Flex bg="base.800" p={4} borderRadius="base" border="1px solid" borderColor="base.700">
|
||||
<Text variant="subtext" fontSize="sm" lineHeight="1.6">
|
||||
<strong>{t('ui.launchpad.upscaling.readyToUpscale.title')}</strong>{' '}
|
||||
{t('ui.launchpad.upscaling.readyToUpscale.description')}
|
||||
</Text>
|
||||
</Flex>
|
||||
)}
|
||||
|
||||
{/* Controls */}
|
||||
<Grid gridTemplateColumns="1fr 1fr" gap={8} alignItems="start">
|
||||
{/* Left Column: Creativity and Structural Defaults */}
|
||||
<Box>
|
||||
<Text fontWeight="semibold" fontSize="sm" mb={3}>
|
||||
Creativity & Structure Defaults
|
||||
</Text>
|
||||
<ButtonGroup size="sm" orientation="vertical" variant="outline" w="full">
|
||||
<Button
|
||||
colorScheme={creativity === -5 && structure === 5 ? 'invokeBlue' : undefined}
|
||||
justifyContent="center"
|
||||
onClick={onConservativeClick}
|
||||
leftIcon={<PiShieldCheckBold />}
|
||||
>
|
||||
Conservative
|
||||
</Button>
|
||||
<Button
|
||||
colorScheme={creativity === 0 && structure === 0 ? 'invokeBlue' : undefined}
|
||||
justifyContent="center"
|
||||
onClick={onBalancedClick}
|
||||
leftIcon={<PiScalesBold />}
|
||||
>
|
||||
Balanced
|
||||
</Button>
|
||||
<Button
|
||||
colorScheme={creativity === 5 && structure === -2 ? 'invokeBlue' : undefined}
|
||||
justifyContent="center"
|
||||
onClick={onCreativeClick}
|
||||
leftIcon={<PiPaletteBold />}
|
||||
>
|
||||
Creative
|
||||
</Button>
|
||||
<Button
|
||||
colorScheme={creativity === 8 && structure === -5 ? 'invokeBlue' : undefined}
|
||||
justifyContent="center"
|
||||
onClick={onArtisticClick}
|
||||
leftIcon={<PiSparkleBold />}
|
||||
>
|
||||
Artistic
|
||||
</Button>
|
||||
</ButtonGroup>
|
||||
</Box>
|
||||
{/* Right Column: Description/help text */}
|
||||
<Box>
|
||||
<Text variant="subtext" fontSize="sm" lineHeight="1.6">
|
||||
{t('ui.launchpad.upscaling.helpText.promptAdvice')}
|
||||
</Text>
|
||||
<Text variant="subtext" fontSize="sm" lineHeight="1.6" mt={3}>
|
||||
{t('ui.launchpad.upscaling.helpText.styleAdvice')}
|
||||
</Text>
|
||||
</Box>
|
||||
</Grid>
|
||||
</LaunchpadContainer>
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import { useTranslation } from 'react-i18next';
|
||||
import { PiFilePlusBold, PiFolderOpenBold, PiUploadBold } from 'react-icons/pi';
|
||||
|
||||
import { LaunchpadButton } from './LaunchpadButton';
|
||||
import { LaunchpadContainer } from './LaunchpadContainer';
|
||||
|
||||
export const WorkflowsLaunchpadPanel = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
@@ -45,63 +46,59 @@ export const WorkflowsLaunchpadPanel = memo(() => {
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" h="full" w="full" alignItems="center" gap={2}>
|
||||
<Flex flexDir="column" w="full" gap={4} px={14} maxW={768} pt="20vh">
|
||||
<Heading>{t('ui.launchpad.workflowsTitle')}</Heading>
|
||||
<LaunchpadContainer heading={t('ui.launchpad.workflowsTitle')}>
|
||||
{/* Description */}
|
||||
<Text variant="subtext" fontSize="md" lineHeight="1.6">
|
||||
{t('ui.launchpad.workflows.description')}
|
||||
</Text>
|
||||
|
||||
{/* Description */}
|
||||
<Text variant="subtext" fontSize="md" lineHeight="1.6">
|
||||
{t('ui.launchpad.workflows.description')}
|
||||
</Text>
|
||||
<Text>
|
||||
<Button
|
||||
as="a"
|
||||
variant="link"
|
||||
href="https://support.invoke.ai/support/solutions/articles/151000189610-getting-started-with-workflows-denoise-latents"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
size="sm"
|
||||
>
|
||||
{t('ui.launchpad.workflows.learnMoreLink')}
|
||||
</Button>
|
||||
</Text>
|
||||
|
||||
<Text>
|
||||
<Button
|
||||
as="a"
|
||||
variant="link"
|
||||
href="https://support.invoke.ai/support/solutions/articles/151000189610-getting-started-with-workflows-denoise-latents"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
size="sm"
|
||||
>
|
||||
{t('ui.launchpad.workflows.learnMoreLink')}
|
||||
</Button>
|
||||
</Text>
|
||||
{/* Action Buttons */}
|
||||
<Flex flexDir="column" gap={8}>
|
||||
{/* Browse Workflow Templates */}
|
||||
<LaunchpadButton onClick={handleBrowseTemplates} position="relative" gap={8}>
|
||||
<Icon as={PiFolderOpenBold} boxSize={8} color="base.500" />
|
||||
<Flex flexDir="column" alignItems="flex-start" gap={2}>
|
||||
<Heading size="sm">{t('ui.launchpad.workflows.browseTemplates.title')}</Heading>
|
||||
<Text color="base.300">{t('ui.launchpad.workflows.browseTemplates.description')}</Text>
|
||||
</Flex>
|
||||
</LaunchpadButton>
|
||||
|
||||
{/* Action Buttons */}
|
||||
<Flex flexDir="column" gap={8}>
|
||||
{/* Browse Workflow Templates */}
|
||||
<LaunchpadButton onClick={handleBrowseTemplates} position="relative" gap={8}>
|
||||
<Icon as={PiFolderOpenBold} boxSize={8} color="base.500" />
|
||||
<Flex flexDir="column" alignItems="flex-start" gap={2}>
|
||||
<Heading size="sm">{t('ui.launchpad.workflows.browseTemplates.title')}</Heading>
|
||||
<Text color="base.300">{t('ui.launchpad.workflows.browseTemplates.description')}</Text>
|
||||
</Flex>
|
||||
</LaunchpadButton>
|
||||
{/* Create a new Workflow */}
|
||||
<LaunchpadButton onClick={handleCreateNew} position="relative" gap={8}>
|
||||
<Icon as={PiFilePlusBold} boxSize={8} color="base.500" />
|
||||
<Flex flexDir="column" alignItems="flex-start" gap={2}>
|
||||
<Heading size="sm">{t('ui.launchpad.workflows.createNew.title')}</Heading>
|
||||
<Text color="base.300">{t('ui.launchpad.workflows.createNew.description')}</Text>
|
||||
</Flex>
|
||||
</LaunchpadButton>
|
||||
|
||||
{/* Create a new Workflow */}
|
||||
<LaunchpadButton onClick={handleCreateNew} position="relative" gap={8}>
|
||||
<Icon as={PiFilePlusBold} boxSize={8} color="base.500" />
|
||||
<Flex flexDir="column" alignItems="flex-start" gap={2}>
|
||||
<Heading size="sm">{t('ui.launchpad.workflows.createNew.title')}</Heading>
|
||||
<Text color="base.300">{t('ui.launchpad.workflows.createNew.description')}</Text>
|
||||
</Flex>
|
||||
</LaunchpadButton>
|
||||
|
||||
{/* Load workflow from existing image or file */}
|
||||
<LaunchpadButton {...uploadApi.getRootProps()} position="relative" gap={8}>
|
||||
<Icon as={PiUploadBold} boxSize={8} color="base.500" />
|
||||
<Flex flexDir="column" alignItems="flex-start" gap={2}>
|
||||
<Heading size="sm">{t('ui.launchpad.workflows.loadFromFile.title')}</Heading>
|
||||
<Text color="base.300">{t('ui.launchpad.workflows.loadFromFile.description')}</Text>
|
||||
</Flex>
|
||||
<Flex position="absolute" right={3} bottom={3}>
|
||||
<PiUploadBold />
|
||||
<input {...uploadApi.getInputProps()} />
|
||||
</Flex>
|
||||
</LaunchpadButton>
|
||||
</Flex>
|
||||
{/* Load workflow from existing image or file */}
|
||||
<LaunchpadButton {...uploadApi.getRootProps()} position="relative" gap={8}>
|
||||
<Icon as={PiUploadBold} boxSize={8} color="base.500" />
|
||||
<Flex flexDir="column" alignItems="flex-start" gap={2}>
|
||||
<Heading size="sm">{t('ui.launchpad.workflows.loadFromFile.title')}</Heading>
|
||||
<Text color="base.300">{t('ui.launchpad.workflows.loadFromFile.description')}</Text>
|
||||
</Flex>
|
||||
<Flex position="absolute" right={3} bottom={3}>
|
||||
<PiUploadBold />
|
||||
<input {...uploadApi.getInputProps()} />
|
||||
</Flex>
|
||||
</LaunchpadButton>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</LaunchpadContainer>
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { buildZodTypeGuard } from 'common/util/zodUtils';
|
||||
import { getOutputImageName } from 'features/controlLayers/components/SimpleSession/shared';
|
||||
import { selectStagingAreaAutoSwitch } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import {
|
||||
buildSelectSessionQueueItems,
|
||||
canvasQueueItemDiscarded,
|
||||
canvasSessionReset,
|
||||
} from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import type { ProgressImage } from 'features/nodes/types/common';
|
||||
import type { Atom, MapStore, StoreValue, WritableAtom } from 'nanostores';
|
||||
import { atom, computed, effect, map, subscribeKeys } from 'nanostores';
|
||||
@@ -14,11 +17,6 @@ import { queueApi } from 'services/api/endpoints/queue';
|
||||
import type { ImageDTO, S } from 'services/api/types';
|
||||
import { $socket } from 'services/events/stores';
|
||||
import { assert, objectEntries } from 'tsafe';
|
||||
import { z } from 'zod/v4';
|
||||
|
||||
const zAutoSwitchMode = z.enum(['off', 'switch_on_start', 'switch_on_finish']);
|
||||
export const isAutoSwitchMode = buildZodTypeGuard(zAutoSwitchMode);
|
||||
type AutoSwitchMode = z.infer<typeof zAutoSwitchMode>;
|
||||
|
||||
export type ProgressData = {
|
||||
itemId: number;
|
||||
@@ -98,12 +96,13 @@ type CanvasSessionContextValue = {
|
||||
$selectedItem: Atom<S['SessionQueueItem'] | null>;
|
||||
$selectedItemIndex: Atom<number | null>;
|
||||
$selectedItemOutputImageDTO: Atom<ImageDTO | null>;
|
||||
$autoSwitch: WritableAtom<AutoSwitchMode>;
|
||||
selectNext: () => void;
|
||||
selectPrev: () => void;
|
||||
selectFirst: () => void;
|
||||
selectLast: () => void;
|
||||
onImageLoad: (itemId: number) => void;
|
||||
discard: (itemId: number) => void;
|
||||
discardAll: () => void;
|
||||
};
|
||||
|
||||
const CanvasSessionContext = createContext<CanvasSessionContextValue | null>(null);
|
||||
@@ -140,11 +139,6 @@ export const CanvasSessionContextProvider = memo(
|
||||
*/
|
||||
const $items = useState(() => atom<S['SessionQueueItem'][]>([]))[0];
|
||||
|
||||
/**
|
||||
* Whether auto-switch is enabled.
|
||||
*/
|
||||
const $autoSwitch = useState(() => atom<AutoSwitchMode>('switch_on_start'))[0];
|
||||
|
||||
/**
|
||||
* An internal flag used to work around race conditions with auto-switch switching to queue items before their
|
||||
* output images have fully loaded.
|
||||
@@ -226,19 +220,21 @@ export const CanvasSessionContextProvider = memo(
|
||||
)[0];
|
||||
|
||||
/**
|
||||
* A redux selector to select all queue items from the RTK Query cache. It's important that this returns stable
|
||||
* references if possible to reduce re-renders. All derivations of the queue items (e.g. filtering out canceled
|
||||
* items) should be done in a nanostores computed.
|
||||
* A redux selector to select all queue items from the RTK Query cache.
|
||||
*/
|
||||
const selectQueueItems = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
queueApi.endpoints.listAllQueueItems.select({ destination: session.id }),
|
||||
({ data }) => data ?? EMPTY_ARRAY
|
||||
),
|
||||
[session.id]
|
||||
const selectQueueItems = useMemo(() => buildSelectSessionQueueItems(session.id), [session.id]);
|
||||
|
||||
const discard = useCallback(
|
||||
(itemId: number) => {
|
||||
store.dispatch(canvasQueueItemDiscarded({ itemId }));
|
||||
},
|
||||
[store]
|
||||
);
|
||||
|
||||
const discardAll = useCallback(() => {
|
||||
store.dispatch(canvasSessionReset());
|
||||
}, [store]);
|
||||
|
||||
const selectNext = useCallback(() => {
|
||||
const selectedItemId = $selectedItemId.get();
|
||||
if (selectedItemId === null) {
|
||||
@@ -300,12 +296,15 @@ export const CanvasSessionContextProvider = memo(
|
||||
imageLoaded: true,
|
||||
});
|
||||
}
|
||||
if ($lastCompletedItemId.get() === itemId && $autoSwitch.get() === 'switch_on_finish') {
|
||||
if (
|
||||
$lastCompletedItemId.get() === itemId &&
|
||||
selectStagingAreaAutoSwitch(store.getState()) === 'switch_on_finish'
|
||||
) {
|
||||
$selectedItemId.set(itemId);
|
||||
$lastCompletedItemId.set(null);
|
||||
}
|
||||
},
|
||||
[$autoSwitch, $lastCompletedItemId, $progressData, $selectedItemId]
|
||||
[$lastCompletedItemId, $progressData, $selectedItemId, store]
|
||||
);
|
||||
|
||||
// Set up socket listeners
|
||||
@@ -340,7 +339,7 @@ export const CanvasSessionContextProvider = memo(
|
||||
socket.off('invocation_progress', onProgress);
|
||||
socket.off('queue_item_status_changed', onQueueItemStatusChanged);
|
||||
};
|
||||
}, [$autoSwitch, $lastCompletedItemId, $lastStartedItemId, $progressData, $selectedItemId, session.id, socket]);
|
||||
}, [$lastCompletedItemId, $lastStartedItemId, $progressData, $selectedItemId, session.id, socket]);
|
||||
|
||||
// Set up state subscriptions and effects
|
||||
useEffect(() => {
|
||||
@@ -362,33 +361,32 @@ export const CanvasSessionContextProvider = memo(
|
||||
const unsubEnsureSelectedItemIdExists = effect(
|
||||
[$items, $selectedItemId, $lastStartedItemId],
|
||||
(items, selectedItemId, lastStartedItemId) => {
|
||||
// If there are no items, cannot have a selected item.
|
||||
if (items.length === 0) {
|
||||
// If there are no items, cannot have a selected item.
|
||||
$selectedItemId.set(null);
|
||||
return;
|
||||
}
|
||||
// If there is no selected item but there are items, select the first one.
|
||||
if (selectedItemId === null && items.length > 0) {
|
||||
} else if (selectedItemId === null && items.length > 0) {
|
||||
// If there is no selected item but there are items, select the first one.
|
||||
$selectedItemId.set(items[0]?.item_id ?? null);
|
||||
return;
|
||||
}
|
||||
if (
|
||||
$autoSwitch.get() === 'switch_on_start' &&
|
||||
} else if (
|
||||
selectStagingAreaAutoSwitch(store.getState()) === 'switch_on_start' &&
|
||||
items.findIndex(({ item_id }) => item_id === lastStartedItemId) !== -1
|
||||
) {
|
||||
$selectedItemId.set(lastStartedItemId);
|
||||
$lastStartedItemId.set(null);
|
||||
}
|
||||
// If an item is selected and it is not in the list of items, un-set it. This effect will run again and we'll
|
||||
// the above case, selecting the first item if there are any.
|
||||
if (selectedItemId !== null && items.findIndex(({ item_id }) => item_id === selectedItemId) === -1) {
|
||||
} else if (selectedItemId !== null && items.findIndex(({ item_id }) => item_id === selectedItemId) === -1) {
|
||||
// If an item is selected and it is not in the list of items, un-set it. This effect will run again and we'll
|
||||
// the above case, selecting the first item if there are any.
|
||||
let prevIndex = _prevItems.findIndex(({ item_id }) => item_id === selectedItemId);
|
||||
if (prevIndex >= items.length) {
|
||||
prevIndex = items.length - 1;
|
||||
}
|
||||
const nextItem = items[prevIndex];
|
||||
$selectedItemId.set(nextItem?.item_id ?? null);
|
||||
return;
|
||||
}
|
||||
|
||||
if (items !== _prevItems) {
|
||||
_prevItems = items;
|
||||
}
|
||||
}
|
||||
);
|
||||
@@ -474,7 +472,7 @@ export const CanvasSessionContextProvider = memo(
|
||||
if (lastLoadedItemId === null) {
|
||||
return;
|
||||
}
|
||||
if ($autoSwitch.get() === 'switch_on_finish') {
|
||||
if (selectStagingAreaAutoSwitch(store.getState()) === 'switch_on_finish') {
|
||||
$selectedItemId.set(lastLoadedItemId);
|
||||
}
|
||||
$lastLoadedItemId.set(null);
|
||||
@@ -486,6 +484,22 @@ export const CanvasSessionContextProvider = memo(
|
||||
queueApi.endpoints.listAllQueueItems.initiate({ destination: session.id })
|
||||
);
|
||||
|
||||
// const unsubListener = store.dispatch(
|
||||
// addAppListener({
|
||||
// matcher: queueApi.endpoints.cancelQueueItem.matchFulfilled,
|
||||
// effect: ({ payload }, { getState }) => {
|
||||
// const { item_id } = payload;
|
||||
|
||||
// const items = selectQueueItems(getState());
|
||||
// if (items.length === 0) {
|
||||
// $selectedItemId.set(null);
|
||||
// } else if ($selectedItemId.get() === null) {
|
||||
// $selectedItemId.set(items[0].item_id);
|
||||
// }
|
||||
// },
|
||||
// })
|
||||
// );
|
||||
|
||||
// Clean up all subscriptions and top-level (i.e. non-computed/derived state)
|
||||
return () => {
|
||||
unsubHandleAutoSwitch();
|
||||
@@ -498,7 +512,6 @@ export const CanvasSessionContextProvider = memo(
|
||||
$selectedItemId.set(null);
|
||||
};
|
||||
}, [
|
||||
$autoSwitch,
|
||||
$items,
|
||||
$lastLoadedItemId,
|
||||
$lastStartedItemId,
|
||||
@@ -517,7 +530,6 @@ export const CanvasSessionContextProvider = memo(
|
||||
$isPending,
|
||||
$progressData,
|
||||
$selectedItemId,
|
||||
$autoSwitch,
|
||||
$selectedItem,
|
||||
$selectedItemIndex,
|
||||
$selectedItemOutputImageDTO,
|
||||
@@ -527,9 +539,10 @@ export const CanvasSessionContextProvider = memo(
|
||||
selectFirst,
|
||||
selectLast,
|
||||
onImageLoad,
|
||||
discard,
|
||||
discardAll,
|
||||
}),
|
||||
[
|
||||
$autoSwitch,
|
||||
$items,
|
||||
$hasItems,
|
||||
$isPending,
|
||||
@@ -545,6 +558,8 @@ export const CanvasSessionContextProvider = memo(
|
||||
selectFirst,
|
||||
selectLast,
|
||||
onImageLoad,
|
||||
discard,
|
||||
discardAll,
|
||||
]
|
||||
);
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ export const getProgressMessage = (data?: S['InvocationProgressEvent'] | null) =
|
||||
|
||||
export const DROP_SHADOW = 'drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 4px rgba(0, 0, 0, 0.3))';
|
||||
|
||||
export const getQueueItemElementId = (itemId: number) => `queue-item-status-card-${itemId}`;
|
||||
export const getQueueItemElementId = (index: number) => `queue-item-preview-${index}`;
|
||||
|
||||
export const getOutputImageName = (item: S['SessionQueueItem']) => {
|
||||
const nodeId = Object.entries(item.session.source_prepared_mapping).find(([nodeId]) =>
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
selectStagingAreaAutoSwitch,
|
||||
settingsStagingAreaAutoSwitchChanged,
|
||||
} from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { PiCaretLineRightBold, PiCaretRightBold, PiMoonBold } from 'react-icons/pi';
|
||||
|
||||
export const StagingAreaAutoSwitchButtons = memo(() => {
|
||||
const autoSwitch = useAppSelector(selectStagingAreaAutoSwitch);
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const onClickOff = useCallback(() => {
|
||||
dispatch(settingsStagingAreaAutoSwitchChanged('off'));
|
||||
}, [dispatch]);
|
||||
const onClickSwitchOnStart = useCallback(() => {
|
||||
dispatch(settingsStagingAreaAutoSwitchChanged('switch_on_start'));
|
||||
}, [dispatch]);
|
||||
const onClickSwitchOnFinished = useCallback(() => {
|
||||
dispatch(settingsStagingAreaAutoSwitchChanged('switch_on_finish'));
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<IconButton
|
||||
aria-label="Do not auto-switch"
|
||||
tooltip="Do not auto-switch"
|
||||
icon={<PiMoonBold />}
|
||||
colorScheme={autoSwitch === 'off' ? 'invokeBlue' : 'base'}
|
||||
onClick={onClickOff}
|
||||
/>
|
||||
<IconButton
|
||||
aria-label="Switch on start"
|
||||
tooltip="Switch on start"
|
||||
icon={<PiCaretRightBold />}
|
||||
colorScheme={autoSwitch === 'switch_on_start' ? 'invokeBlue' : 'base'}
|
||||
onClick={onClickSwitchOnStart}
|
||||
/>
|
||||
<IconButton
|
||||
aria-label="Switch on finish"
|
||||
tooltip="Switch on finish"
|
||||
icon={<PiCaretLineRightBold />}
|
||||
colorScheme={autoSwitch === 'switch_on_finish' ? 'invokeBlue' : 'base'}
|
||||
onClick={onClickSwitchOnFinished}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
});
|
||||
StagingAreaAutoSwitchButtons.displayName = 'StagingAreaAutoSwitchButtons';
|
||||
@@ -1,7 +1,6 @@
|
||||
import { ButtonGroup } from '@invoke-ai/ui-library';
|
||||
import { ButtonGroup, Flex } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
|
||||
import { getQueueItemElementId } from 'features/controlLayers/components/SimpleSession/shared';
|
||||
import { StagingAreaToolbarAcceptButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarAcceptButton';
|
||||
import { StagingAreaToolbarDiscardAllButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarDiscardAllButton';
|
||||
import { StagingAreaToolbarDiscardSelectedButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarDiscardSelectedButton';
|
||||
@@ -12,27 +11,22 @@ import { StagingAreaToolbarPrevButton } from 'features/controlLayers/components/
|
||||
import { StagingAreaToolbarSaveSelectedToGalleryButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarSaveSelectedToGalleryButton';
|
||||
import { StagingAreaToolbarToggleShowResultsButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarToggleShowResultsButton';
|
||||
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { memo, useEffect } from 'react';
|
||||
import { memo } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
|
||||
import { StagingAreaAutoSwitchButtons } from './StagingAreaAutoSwitchButtons';
|
||||
|
||||
export const StagingAreaToolbar = memo(() => {
|
||||
const canvasManager = useCanvasManager();
|
||||
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
|
||||
|
||||
const ctx = useCanvasSessionContext();
|
||||
|
||||
useEffect(() => {
|
||||
return ctx.$selectedItemId.listen((id) => {
|
||||
if (id !== null) {
|
||||
document.getElementById(getQueueItemElementId(id))?.scrollIntoView();
|
||||
}
|
||||
});
|
||||
}, [ctx.$selectedItemId]);
|
||||
|
||||
useHotkeys('meta+left', ctx.selectFirst, { preventDefault: true });
|
||||
useHotkeys('meta+right', ctx.selectLast, { preventDefault: true });
|
||||
|
||||
return (
|
||||
<>
|
||||
<Flex gap={2}>
|
||||
<ButtonGroup borderRadius="base" shadow="dark-lg">
|
||||
<StagingAreaToolbarPrevButton isDisabled={!shouldShowStagedImage} />
|
||||
<StagingAreaToolbarImageCountButton />
|
||||
@@ -44,9 +38,14 @@ export const StagingAreaToolbar = memo(() => {
|
||||
<StagingAreaToolbarSaveSelectedToGalleryButton />
|
||||
<StagingAreaToolbarMenu />
|
||||
<StagingAreaToolbarDiscardSelectedButton isDisabled={!shouldShowStagedImage} />
|
||||
</ButtonGroup>
|
||||
<ButtonGroup borderRadius="base" shadow="dark-lg">
|
||||
<StagingAreaAutoSwitchButtons />
|
||||
</ButtonGroup>
|
||||
<ButtonGroup borderRadius="base" shadow="dark-lg">
|
||||
<StagingAreaToolbarDiscardAllButton isDisabled={!shouldShowStagedImage} />
|
||||
</ButtonGroup>
|
||||
</>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import { canvasSessionReset } from 'features/controlLayers/store/canvasStagingAr
|
||||
import { selectBboxRect, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
|
||||
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
|
||||
import { imageNameToImageObject } from 'features/controlLayers/store/util';
|
||||
import { useDeleteQueueItemsByDestination } from 'features/queue/hooks/useDeleteQueueItemsByDestination';
|
||||
import { useCancelQueueItemsByDestination } from 'features/queue/hooks/useCancelQueueItemsByDestination';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -24,7 +24,7 @@ export const StagingAreaToolbarAcceptButton = memo(() => {
|
||||
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
|
||||
const isCanvasFocused = useIsRegionFocused('canvas');
|
||||
const selectedItemImageDTO = useStore(ctx.$selectedItemOutputImageDTO);
|
||||
const deleteQueueItemsByDestination = useDeleteQueueItemsByDestination();
|
||||
const cancelQueueItemsByDestination = useCancelQueueItemsByDestination();
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
@@ -41,13 +41,13 @@ export const StagingAreaToolbarAcceptButton = memo(() => {
|
||||
|
||||
dispatch(rasterLayerAdded({ overrides, isSelected: selectedEntityIdentifier?.type === 'raster_layer' }));
|
||||
dispatch(canvasSessionReset());
|
||||
deleteQueueItemsByDestination.trigger(ctx.session.id);
|
||||
cancelQueueItemsByDestination.trigger(ctx.session.id, { withToast: false });
|
||||
}, [
|
||||
selectedItemImageDTO,
|
||||
bboxRect,
|
||||
dispatch,
|
||||
selectedEntityIdentifier?.type,
|
||||
deleteQueueItemsByDestination,
|
||||
cancelQueueItemsByDestination,
|
||||
ctx.session.id,
|
||||
]);
|
||||
|
||||
@@ -68,8 +68,8 @@ export const StagingAreaToolbarAcceptButton = memo(() => {
|
||||
icon={<PiCheckBold />}
|
||||
onClick={acceptSelected}
|
||||
colorScheme="invokeBlue"
|
||||
isDisabled={!selectedItemImageDTO || !shouldShowStagedImage || deleteQueueItemsByDestination.isDisabled}
|
||||
isLoading={deleteQueueItemsByDestination.isLoading}
|
||||
isDisabled={!selectedItemImageDTO || !shouldShowStagedImage || cancelQueueItemsByDestination.isDisabled}
|
||||
isLoading={cancelQueueItemsByDestination.isLoading}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1,27 +1,19 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
|
||||
import { canvasSessionReset, generateSessionReset } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { useDeleteQueueItemsByDestination } from 'features/queue/hooks/useDeleteQueueItemsByDestination';
|
||||
import { useCancelQueueItemsByDestination } from 'features/queue/hooks/useCancelQueueItemsByDestination';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiTrashSimpleBold } from 'react-icons/pi';
|
||||
|
||||
export const StagingAreaToolbarDiscardAllButton = memo(({ isDisabled }: { isDisabled?: boolean }) => {
|
||||
const ctx = useCanvasSessionContext();
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const deleteQueueItemsByDestination = useDeleteQueueItemsByDestination();
|
||||
const cancelQueueItemsByDestination = useCancelQueueItemsByDestination();
|
||||
|
||||
const discardAll = useCallback(() => {
|
||||
deleteQueueItemsByDestination.trigger(ctx.session.id);
|
||||
if (ctx.session.type === 'advanced') {
|
||||
dispatch(canvasSessionReset());
|
||||
} else {
|
||||
// ctx.session.type === 'simple'
|
||||
dispatch(generateSessionReset());
|
||||
}
|
||||
}, [deleteQueueItemsByDestination, ctx.session.id, ctx.session.type, dispatch]);
|
||||
ctx.discardAll();
|
||||
cancelQueueItemsByDestination.trigger(ctx.session.id, { withToast: false });
|
||||
}, [cancelQueueItemsByDestination, ctx]);
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
@@ -30,9 +22,8 @@ export const StagingAreaToolbarDiscardAllButton = memo(({ isDisabled }: { isDisa
|
||||
icon={<PiTrashSimpleBold />}
|
||||
onClick={discardAll}
|
||||
colorScheme="error"
|
||||
fontSize={16}
|
||||
isDisabled={isDisabled || deleteQueueItemsByDestination.isDisabled}
|
||||
isLoading={deleteQueueItemsByDestination.isLoading}
|
||||
isDisabled={isDisabled || cancelQueueItemsByDestination.isDisabled}
|
||||
isLoading={cancelQueueItemsByDestination.isLoading}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
|
||||
import { canvasSessionReset, generateSessionReset } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { useDeleteQueueItem } from 'features/queue/hooks/useDeleteQueueItem';
|
||||
import { useCancelQueueItem } from 'features/queue/hooks/useCancelQueueItem';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
|
||||
export const StagingAreaToolbarDiscardSelectedButton = memo(({ isDisabled }: { isDisabled?: boolean }) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const ctx = useCanvasSessionContext();
|
||||
const deleteQueueItem = useDeleteQueueItem();
|
||||
const cancelQueueItem = useCancelQueueItem();
|
||||
const selectedItemId = useStore(ctx.$selectedItemId);
|
||||
|
||||
const { t } = useTranslation();
|
||||
@@ -20,17 +17,9 @@ export const StagingAreaToolbarDiscardSelectedButton = memo(({ isDisabled }: { i
|
||||
if (selectedItemId === null) {
|
||||
return;
|
||||
}
|
||||
await deleteQueueItem.trigger(selectedItemId);
|
||||
const itemCount = ctx.$itemCount.get();
|
||||
if (itemCount <= 1) {
|
||||
if (ctx.session.type === 'advanced') {
|
||||
dispatch(canvasSessionReset());
|
||||
} else {
|
||||
// ctx.session.type === 'simple'
|
||||
dispatch(generateSessionReset());
|
||||
}
|
||||
}
|
||||
}, [selectedItemId, deleteQueueItem, ctx.$itemCount, ctx.session.type, dispatch]);
|
||||
ctx.discard(selectedItemId);
|
||||
await cancelQueueItem.trigger(selectedItemId, { withToast: false });
|
||||
}, [selectedItemId, ctx, cancelQueueItem]);
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
@@ -39,9 +28,8 @@ export const StagingAreaToolbarDiscardSelectedButton = memo(({ isDisabled }: { i
|
||||
icon={<PiXBold />}
|
||||
onClick={discardSelected}
|
||||
colorScheme="invokeBlue"
|
||||
fontSize={16}
|
||||
isDisabled={selectedItemId === null || deleteQueueItem.isDisabled || isDisabled}
|
||||
isLoading={deleteQueueItem.isLoading}
|
||||
isDisabled={selectedItemId === null || cancelQueueItem.isDisabled || isDisabled}
|
||||
isLoading={cancelQueueItem.isLoading}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1,16 +1,13 @@
|
||||
import { IconButton, Menu, MenuButton, MenuDivider, MenuList } from '@invoke-ai/ui-library';
|
||||
import { StagingAreaToolbarMenuAutoSwitch } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarMenuAutoSwitch';
|
||||
import { IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library';
|
||||
import { StagingAreaToolbarNewLayerFromImageMenuItems } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarMenuNewLayerFromImage';
|
||||
import { memo } from 'react';
|
||||
import { PiDotsThreeBold } from 'react-icons/pi';
|
||||
import { PiDotsThreeVerticalBold } from 'react-icons/pi';
|
||||
|
||||
export const StagingAreaToolbarMenu = memo(() => {
|
||||
return (
|
||||
<Menu>
|
||||
<MenuButton as={IconButton} icon={<PiDotsThreeBold />} colorScheme="invokeBlue" />
|
||||
<MenuButton as={IconButton} icon={<PiDotsThreeVerticalBold />} colorScheme="invokeBlue" />
|
||||
<MenuList>
|
||||
<StagingAreaToolbarMenuAutoSwitch />
|
||||
<MenuDivider />
|
||||
<StagingAreaToolbarNewLayerFromImageMenuItems />
|
||||
</MenuList>
|
||||
</Menu>
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
import { MenuItemOption, MenuOptionGroup } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { isAutoSwitchMode, useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const StagingAreaToolbarMenuAutoSwitch = memo(() => {
|
||||
const ctx = useCanvasSessionContext();
|
||||
const autoSwitch = useStore(ctx.$autoSwitch);
|
||||
|
||||
const onChange = useCallback(
|
||||
(val: string | string[]) => {
|
||||
assert(isAutoSwitchMode(val));
|
||||
ctx.$autoSwitch.set(val);
|
||||
},
|
||||
[ctx.$autoSwitch]
|
||||
);
|
||||
|
||||
return (
|
||||
<MenuOptionGroup value={autoSwitch} onChange={onChange} title="Auto-Switch" type="radio">
|
||||
<MenuItemOption value="off" closeOnSelect={false}>
|
||||
Off
|
||||
</MenuItemOption>
|
||||
<MenuItemOption value="switch_on_start" closeOnSelect={false}>
|
||||
Switch on Start
|
||||
</MenuItemOption>
|
||||
<MenuItemOption value="switch_on_finish" closeOnSelect={false}>
|
||||
Switch on Finish
|
||||
</MenuItemOption>
|
||||
</MenuOptionGroup>
|
||||
);
|
||||
});
|
||||
|
||||
StagingAreaToolbarMenuAutoSwitch.displayName = 'StagingAreaToolbarMenuAutoSwitch';
|
||||
@@ -1,38 +1,41 @@
|
||||
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
|
||||
import { createSelector, createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import type { RgbaColor } from 'features/controlLayers/store/types';
|
||||
import { zRgbaColor } from 'features/controlLayers/store/types';
|
||||
import { z } from 'zod/v4';
|
||||
|
||||
type CanvasSettingsState = {
|
||||
const zAutoSwitchMode = z.enum(['off', 'switch_on_start', 'switch_on_finish']);
|
||||
|
||||
const zCanvasSettingsState = z.object({
|
||||
/**
|
||||
* Whether to show HUD (Heads-Up Display) on the canvas.
|
||||
*/
|
||||
showHUD: boolean;
|
||||
showHUD: z.boolean().default(true),
|
||||
/**
|
||||
* Whether to clip lines and shapes to the generation bounding box. If disabled, lines and shapes will be clipped to
|
||||
* the canvas bounds.
|
||||
*/
|
||||
clipToBbox: boolean;
|
||||
clipToBbox: z.boolean().default(false),
|
||||
/**
|
||||
* Whether to show a dynamic grid on the canvas. If disabled, a checkerboard pattern will be shown instead.
|
||||
*/
|
||||
dynamicGrid: boolean;
|
||||
dynamicGrid: z.boolean().default(false),
|
||||
/**
|
||||
* Whether to invert the scroll direction when adjusting the brush or eraser width with the scroll wheel.
|
||||
*/
|
||||
invertScrollForToolWidth: boolean;
|
||||
invertScrollForToolWidth: z.boolean().default(false),
|
||||
/**
|
||||
* The width of the brush tool.
|
||||
*/
|
||||
brushWidth: number;
|
||||
brushWidth: z.int().gt(0).default(50),
|
||||
/**
|
||||
* The width of the eraser tool.
|
||||
*/
|
||||
eraserWidth: number;
|
||||
eraserWidth: z.int().gt(0).default(50),
|
||||
/**
|
||||
* The color to use when drawing lines or filling shapes.
|
||||
*/
|
||||
color: RgbaColor;
|
||||
color: zRgbaColor.default({ r: 31, g: 160, b: 224, a: 1 }), // invokeBlue.500
|
||||
/**
|
||||
* Whether to composite inpainted/outpainted regions back onto the source image when saving canvas generations.
|
||||
*
|
||||
@@ -40,70 +43,61 @@ type CanvasSettingsState = {
|
||||
*
|
||||
* When `sendToCanvas` is disabled, this setting is ignored, masked regions will always be composited.
|
||||
*/
|
||||
outputOnlyMaskedRegions: boolean;
|
||||
outputOnlyMaskedRegions: z.boolean().default(true),
|
||||
/**
|
||||
* Whether to automatically process the operations like filtering and auto-masking.
|
||||
*/
|
||||
autoProcess: boolean;
|
||||
autoProcess: z.boolean().default(true),
|
||||
/**
|
||||
* The snap-to-grid setting for the canvas.
|
||||
*/
|
||||
snapToGrid: boolean;
|
||||
snapToGrid: z.boolean().default(true),
|
||||
/**
|
||||
* Whether to show progress on the canvas when generating images.
|
||||
*/
|
||||
showProgressOnCanvas: boolean;
|
||||
showProgressOnCanvas: z.boolean().default(true),
|
||||
/**
|
||||
* Whether to show the bounding box overlay on the canvas.
|
||||
*/
|
||||
bboxOverlay: boolean;
|
||||
bboxOverlay: z.boolean().default(false),
|
||||
/**
|
||||
* Whether to preserve the masked region instead of inpainting it.
|
||||
*/
|
||||
preserveMask: boolean;
|
||||
preserveMask: z.boolean().default(false),
|
||||
/**
|
||||
* Whether to show only raster layers while staging.
|
||||
*/
|
||||
isolatedStagingPreview: boolean;
|
||||
isolatedStagingPreview: z.boolean().default(true),
|
||||
/**
|
||||
* Whether to show only the selected layer while filtering, transforming, or doing other operations.
|
||||
*/
|
||||
isolatedLayerPreview: boolean;
|
||||
isolatedLayerPreview: z.boolean().default(true),
|
||||
/**
|
||||
* Whether to use pressure sensitivity for the brush and eraser tool when a pen device is used.
|
||||
*/
|
||||
pressureSensitivity: boolean;
|
||||
pressureSensitivity: z.boolean().default(true),
|
||||
/**
|
||||
* Whether to show the rule of thirds composition guide overlay on the canvas.
|
||||
*/
|
||||
ruleOfThirds: boolean;
|
||||
};
|
||||
ruleOfThirds: z.boolean().default(false),
|
||||
/**
|
||||
* Whether to save all staging images to the gallery instead of keeping them as intermediate images.
|
||||
*/
|
||||
saveAllImagesToGallery: z.boolean().default(false),
|
||||
/**
|
||||
* The auto-switch mode for the canvas staging area.
|
||||
*/
|
||||
stagingAreaAutoSwitch: zAutoSwitchMode.default('switch_on_start'),
|
||||
});
|
||||
|
||||
const initialState: CanvasSettingsState = {
|
||||
showHUD: true,
|
||||
clipToBbox: false,
|
||||
dynamicGrid: false,
|
||||
brushWidth: 50,
|
||||
eraserWidth: 50,
|
||||
invertScrollForToolWidth: false,
|
||||
color: { r: 31, g: 160, b: 224, a: 1 }, // invokeBlue.500
|
||||
outputOnlyMaskedRegions: true,
|
||||
autoProcess: true,
|
||||
snapToGrid: true,
|
||||
showProgressOnCanvas: true,
|
||||
bboxOverlay: false,
|
||||
preserveMask: false,
|
||||
isolatedStagingPreview: true,
|
||||
isolatedLayerPreview: true,
|
||||
pressureSensitivity: true,
|
||||
ruleOfThirds: false,
|
||||
};
|
||||
type CanvasSettingsState = z.infer<typeof zCanvasSettingsState>;
|
||||
const getInitialState = () => zCanvasSettingsState.parse({});
|
||||
|
||||
export const canvasSettingsSlice = createSlice({
|
||||
name: 'canvasSettings',
|
||||
initialState,
|
||||
initialState: getInitialState(),
|
||||
reducers: {
|
||||
settingsClipToBboxChanged: (state, action: PayloadAction<boolean>) => {
|
||||
settingsClipToBboxChanged: (state, action: PayloadAction<CanvasSettingsState['clipToBbox']>) => {
|
||||
state.clipToBbox = action.payload;
|
||||
},
|
||||
settingsDynamicGridToggled: (state) => {
|
||||
@@ -112,16 +106,19 @@ export const canvasSettingsSlice = createSlice({
|
||||
settingsShowHUDToggled: (state) => {
|
||||
state.showHUD = !state.showHUD;
|
||||
},
|
||||
settingsBrushWidthChanged: (state, action: PayloadAction<number>) => {
|
||||
settingsBrushWidthChanged: (state, action: PayloadAction<CanvasSettingsState['brushWidth']>) => {
|
||||
state.brushWidth = Math.round(action.payload);
|
||||
},
|
||||
settingsEraserWidthChanged: (state, action: PayloadAction<number>) => {
|
||||
settingsEraserWidthChanged: (state, action: PayloadAction<CanvasSettingsState['eraserWidth']>) => {
|
||||
state.eraserWidth = Math.round(action.payload);
|
||||
},
|
||||
settingsColorChanged: (state, action: PayloadAction<RgbaColor>) => {
|
||||
settingsColorChanged: (state, action: PayloadAction<CanvasSettingsState['color']>) => {
|
||||
state.color = action.payload;
|
||||
},
|
||||
settingsInvertScrollForToolWidthChanged: (state, action: PayloadAction<boolean>) => {
|
||||
settingsInvertScrollForToolWidthChanged: (
|
||||
state,
|
||||
action: PayloadAction<CanvasSettingsState['invertScrollForToolWidth']>
|
||||
) => {
|
||||
state.invertScrollForToolWidth = action.payload;
|
||||
},
|
||||
settingsOutputOnlyMaskedRegionsToggled: (state) => {
|
||||
@@ -154,6 +151,15 @@ export const canvasSettingsSlice = createSlice({
|
||||
settingsRuleOfThirdsToggled: (state) => {
|
||||
state.ruleOfThirds = !state.ruleOfThirds;
|
||||
},
|
||||
settingsSaveAllImagesToGalleryToggled: (state) => {
|
||||
state.saveAllImagesToGallery = !state.saveAllImagesToGallery;
|
||||
},
|
||||
settingsStagingAreaAutoSwitchChanged: (
|
||||
state,
|
||||
action: PayloadAction<CanvasSettingsState['stagingAreaAutoSwitch']>
|
||||
) => {
|
||||
state.stagingAreaAutoSwitch = action.payload;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
@@ -175,6 +181,8 @@ export const {
|
||||
settingsIsolatedLayerPreviewToggled,
|
||||
settingsPressureSensitivityToggled,
|
||||
settingsRuleOfThirdsToggled,
|
||||
settingsSaveAllImagesToGalleryToggled,
|
||||
settingsStagingAreaAutoSwitchChanged,
|
||||
} = canvasSettingsSlice.actions;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
@@ -184,7 +192,7 @@ const migrate = (state: any): any => {
|
||||
|
||||
export const canvasSettingsPersistConfig: PersistConfig<CanvasSettingsState> = {
|
||||
name: canvasSettingsSlice.name,
|
||||
initialState,
|
||||
initialState: getInitialState(),
|
||||
migrate,
|
||||
persistDenylist: [],
|
||||
};
|
||||
@@ -209,3 +217,5 @@ export const selectIsolatedStagingPreview = createCanvasSettingsSelector((settin
|
||||
export const selectIsolatedLayerPreview = createCanvasSettingsSelector((settings) => settings.isolatedLayerPreview);
|
||||
export const selectPressureSensitivity = createCanvasSettingsSelector((settings) => settings.pressureSensitivity);
|
||||
export const selectRuleOfThirds = createCanvasSettingsSelector((settings) => settings.ruleOfThirds);
|
||||
export const selectSaveAllImagesToGallery = createCanvasSettingsSelector((settings) => settings.saveAllImagesToGallery);
|
||||
export const selectStagingAreaAutoSwitch = createCanvasSettingsSelector((settings) => settings.stagingAreaAutoSwitch);
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
import { createSelector, createSlice, type PayloadAction } from '@reduxjs/toolkit';
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { canvasReset } from 'features/controlLayers/store/actions';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
|
||||
type CanvasStagingAreaState = {
|
||||
generateSessionId: string | null;
|
||||
canvasSessionId: string | null;
|
||||
canvasDiscardedQueueItems: number[];
|
||||
};
|
||||
|
||||
const INITIAL_STATE: CanvasStagingAreaState = {
|
||||
generateSessionId: null,
|
||||
canvasSessionId: null,
|
||||
canvasDiscardedQueueItems: [],
|
||||
};
|
||||
|
||||
const getInitialState = (): CanvasStagingAreaState => deepClone(INITIAL_STATE);
|
||||
@@ -26,12 +30,20 @@ export const canvasSessionSlice = createSlice({
|
||||
generateSessionReset: (state) => {
|
||||
state.generateSessionId = null;
|
||||
},
|
||||
canvasQueueItemDiscarded: (state, action: PayloadAction<{ itemId: number }>) => {
|
||||
const { itemId } = action.payload;
|
||||
if (!state.canvasDiscardedQueueItems.includes(itemId)) {
|
||||
state.canvasDiscardedQueueItems.push(itemId);
|
||||
}
|
||||
},
|
||||
canvasSessionIdChanged: (state, action: PayloadAction<{ id: string }>) => {
|
||||
const { id } = action.payload;
|
||||
state.canvasSessionId = id;
|
||||
state.canvasDiscardedQueueItems = [];
|
||||
},
|
||||
canvasSessionReset: (state) => {
|
||||
state.canvasSessionId = null;
|
||||
state.canvasDiscardedQueueItems = [];
|
||||
},
|
||||
},
|
||||
extraReducers(builder) {
|
||||
@@ -41,8 +53,13 @@ export const canvasSessionSlice = createSlice({
|
||||
},
|
||||
});
|
||||
|
||||
export const { generateSessionIdChanged, generateSessionReset, canvasSessionIdChanged, canvasSessionReset } =
|
||||
canvasSessionSlice.actions;
|
||||
export const {
|
||||
generateSessionIdChanged,
|
||||
generateSessionReset,
|
||||
canvasSessionIdChanged,
|
||||
canvasSessionReset,
|
||||
canvasQueueItemDiscarded,
|
||||
} = canvasSessionSlice.actions;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
const migrate = (state: any): any => {
|
||||
@@ -63,4 +80,34 @@ export const selectGenerateSessionId = createSelector(
|
||||
selectCanvasSessionSlice,
|
||||
({ generateSessionId }) => generateSessionId
|
||||
);
|
||||
export const selectIsStaging = createSelector(selectCanvasSessionId, (canvasSessionId) => canvasSessionId !== null);
|
||||
export const buildSelectSessionQueueItems = (sessionId: string) =>
|
||||
createSelector(
|
||||
[queueApi.endpoints.listAllQueueItems.select({ destination: sessionId }), selectDiscardedItems],
|
||||
({ data }, discardedItems) => {
|
||||
if (!data) {
|
||||
return EMPTY_ARRAY;
|
||||
}
|
||||
return data.filter(
|
||||
({ status, item_id }) => status !== 'canceled' && status !== 'failed' && !discardedItems.includes(item_id)
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
export const selectIsStaging = (state: RootState) => {
|
||||
const sessionId = selectCanvasSessionId(state);
|
||||
if (!sessionId) {
|
||||
return false;
|
||||
}
|
||||
const { data } = queueApi.endpoints.listAllQueueItems.select({ destination: sessionId })(state);
|
||||
if (!data) {
|
||||
return false;
|
||||
}
|
||||
const discardedItems = selectDiscardedItems(state);
|
||||
return data.some(
|
||||
({ status, item_id }) => status !== 'canceled' && status !== 'failed' && !discardedItems.includes(item_id)
|
||||
);
|
||||
};
|
||||
const selectDiscardedItems = createSelector(
|
||||
selectCanvasSessionSlice,
|
||||
({ canvasDiscardedQueueItems }) => canvasDiscardedQueueItems
|
||||
);
|
||||
|
||||
@@ -57,6 +57,8 @@ export const refImagesSlice = createSlice({
|
||||
const { entities, replace } = action.payload;
|
||||
if (replace) {
|
||||
state.entities = entities;
|
||||
state.isPanelOpen = false;
|
||||
state.selectedEntityId = null;
|
||||
} else {
|
||||
state.entities.push(...entities);
|
||||
}
|
||||
|
||||
@@ -98,7 +98,7 @@ const zRgbColor = z.object({
|
||||
b: z.number().int().min(0).max(255),
|
||||
});
|
||||
export type RgbColor = z.infer<typeof zRgbColor>;
|
||||
const zRgbaColor = zRgbColor.extend({
|
||||
export const zRgbaColor = zRgbColor.extend({
|
||||
a: z.number().min(0).max(1),
|
||||
});
|
||||
export type RgbaColor = z.infer<typeof zRgbaColor>;
|
||||
|
||||
@@ -15,7 +15,6 @@ const sx = {
|
||||
objectFit: 'contain',
|
||||
maxW: 'full',
|
||||
maxH: 'full',
|
||||
borderRadius: 'base',
|
||||
cursor: 'grab',
|
||||
'&[data-is-dragging=true]': {
|
||||
opacity: 0.3,
|
||||
|
||||
@@ -1,25 +1,33 @@
|
||||
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
|
||||
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
|
||||
import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
|
||||
import { useImageActions } from 'features/gallery/hooks/useImageActions';
|
||||
import { useRecallAll } from 'features/gallery/hooks/useRecallAll';
|
||||
import { useRecallDimensions } from 'features/gallery/hooks/useRecallDimensions';
|
||||
import { useRecallPrompts } from 'features/gallery/hooks/useRecallPrompts';
|
||||
import { useRecallRemix } from 'features/gallery/hooks/useRecallRemix';
|
||||
import { useRecallSeed } from 'features/gallery/hooks/useRecallSeed';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
PiArrowBendUpLeftBold,
|
||||
PiArrowsCounterClockwiseBold,
|
||||
PiAsteriskBold,
|
||||
PiPaintBrushBold,
|
||||
PiPlantBold,
|
||||
PiQuotesBold,
|
||||
PiRulerBold,
|
||||
} from 'react-icons/pi';
|
||||
|
||||
export const ImageMenuItemMetadataRecallActions = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const imageDTO = useImageDTOContext();
|
||||
const subMenu = useSubMenu();
|
||||
|
||||
const { recallAll, remix, recallSeed, recallPrompts, hasMetadata, hasSeed, hasPrompts, createAsPreset } =
|
||||
useImageActions(imageDTO);
|
||||
const imageDTO = useImageDTOContext();
|
||||
|
||||
const recallAll = useRecallAll(imageDTO);
|
||||
const recallRemix = useRecallRemix(imageDTO);
|
||||
const recallPrompts = useRecallPrompts(imageDTO);
|
||||
const recallSeed = useRecallSeed(imageDTO);
|
||||
const recallDimensions = useRecallDimensions(imageDTO);
|
||||
|
||||
return (
|
||||
<MenuItem {...subMenu.parentMenuItemProps} icon={<PiArrowBendUpLeftBold />}>
|
||||
@@ -28,20 +36,24 @@ export const ImageMenuItemMetadataRecallActions = memo(() => {
|
||||
<SubMenuButtonContent label={t('parameters.recallMetadata')} />
|
||||
</MenuButton>
|
||||
<MenuList {...subMenu.menuListProps}>
|
||||
<MenuItem icon={<PiArrowsCounterClockwiseBold />} onClick={remix} isDisabled={!hasMetadata}>
|
||||
<MenuItem
|
||||
icon={<PiArrowsCounterClockwiseBold />}
|
||||
onClick={recallRemix.recall}
|
||||
isDisabled={!recallRemix.isEnabled}
|
||||
>
|
||||
{t('parameters.remixImage')}
|
||||
</MenuItem>
|
||||
<MenuItem icon={<PiQuotesBold />} onClick={recallPrompts} isDisabled={!hasPrompts}>
|
||||
<MenuItem icon={<PiQuotesBold />} onClick={recallPrompts.recall} isDisabled={!recallPrompts.isEnabled}>
|
||||
{t('parameters.usePrompt')}
|
||||
</MenuItem>
|
||||
<MenuItem icon={<PiPlantBold />} onClick={recallSeed} isDisabled={!hasSeed}>
|
||||
<MenuItem icon={<PiPlantBold />} onClick={recallSeed.recall} isDisabled={!recallSeed.isEnabled}>
|
||||
{t('parameters.useSeed')}
|
||||
</MenuItem>
|
||||
<MenuItem icon={<PiAsteriskBold />} onClick={recallAll} isDisabled={!hasMetadata}>
|
||||
<MenuItem icon={<PiAsteriskBold />} onClick={recallAll.recall} isDisabled={!recallAll.isEnabled}>
|
||||
{t('parameters.useAll')}
|
||||
</MenuItem>
|
||||
<MenuItem icon={<PiPaintBrushBold />} onClick={createAsPreset} isDisabled={!hasPrompts}>
|
||||
{t('stylePresets.useForTemplate')}
|
||||
<MenuItem icon={<PiRulerBold />} onClick={recallDimensions.recall} isDisabled={!recallDimensions.isEnabled}>
|
||||
{t('parameters.useSize')}
|
||||
</MenuItem>
|
||||
</MenuList>
|
||||
</Menu>
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
import { MenuItem } from '@invoke-ai/ui-library';
|
||||
import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
|
||||
import { useCreateStylePresetFromMetadata } from 'features/gallery/hooks/useCreateStylePresetFromMetadata';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPaintBrushBold } from 'react-icons/pi';
|
||||
|
||||
export const ImageMenuItemUseAsPromptTemplate = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const imageDTO = useImageDTOContext();
|
||||
const stylePreset = useCreateStylePresetFromMetadata(imageDTO);
|
||||
|
||||
return (
|
||||
<MenuItem icon={<PiPaintBrushBold />} onClickCapture={stylePreset.create} isDisabled={!stylePreset.isEnabled}>
|
||||
{t('stylePresets.useForTemplate')}
|
||||
</MenuItem>
|
||||
);
|
||||
});
|
||||
|
||||
ImageMenuItemUseAsPromptTemplate.displayName = 'ImageMenuItemUseAsPromptTemplate';
|
||||
@@ -1,4 +1,5 @@
|
||||
import { MenuDivider } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { IconMenuItemGroup } from 'common/components/IconMenuItem';
|
||||
import { ImageMenuItemChangeBoard } from 'features/gallery/components/ImageContextMenu/ImageMenuItemChangeBoard';
|
||||
import { ImageMenuItemCopy } from 'features/gallery/components/ImageContextMenu/ImageMenuItemCopy';
|
||||
@@ -16,14 +17,19 @@ import { ImageMenuItemStarUnstar } from 'features/gallery/components/ImageContex
|
||||
import { ImageMenuItemUseAsRefImage } from 'features/gallery/components/ImageContextMenu/ImageMenuItemUseAsRefImage';
|
||||
import { ImageMenuItemUseForPromptGeneration } from 'features/gallery/components/ImageContextMenu/ImageMenuItemUseForPromptGeneration';
|
||||
import { ImageDTOContextProvider } from 'features/gallery/contexts/ImageDTOContext';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { memo } from 'react';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
import { ImageMenuItemUseAsPromptTemplate } from './ImageMenuItemUseAsPromptTemplate';
|
||||
|
||||
type SingleSelectionMenuItemsProps = {
|
||||
imageDTO: ImageDTO;
|
||||
};
|
||||
|
||||
const SingleSelectionMenuItems = ({ imageDTO }: SingleSelectionMenuItemsProps) => {
|
||||
const tab = useAppSelector(selectActiveTab);
|
||||
|
||||
return (
|
||||
<ImageDTOContextProvider value={imageDTO}>
|
||||
<IconMenuItemGroup>
|
||||
@@ -36,13 +42,14 @@ const SingleSelectionMenuItems = ({ imageDTO }: SingleSelectionMenuItemsProps) =
|
||||
</IconMenuItemGroup>
|
||||
<MenuDivider />
|
||||
<ImageMenuItemLoadWorkflow />
|
||||
<ImageMenuItemMetadataRecallActions />
|
||||
{(tab === 'canvas' || tab === 'generate') && <ImageMenuItemMetadataRecallActions />}
|
||||
<MenuDivider />
|
||||
<ImageMenuItemSendToUpscale />
|
||||
<ImageMenuItemUseForPromptGeneration />
|
||||
<ImageMenuItemUseAsRefImage />
|
||||
{(tab === 'canvas' || tab === 'generate') && <ImageMenuItemUseAsRefImage />}
|
||||
<ImageMenuItemUseAsPromptTemplate />
|
||||
<ImageMenuItemNewCanvasFromImageSubMenu />
|
||||
<ImageMenuItemNewLayerFromImageSubMenu />
|
||||
{tab === 'canvas' && <ImageMenuItemNewLayerFromImageSubMenu />}
|
||||
<MenuDivider />
|
||||
<ImageMenuItemChangeBoard />
|
||||
<ImageMenuItemStarUnstar />
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { combine } from '@atlaskit/pragmatic-drag-and-drop/combine';
|
||||
import { draggable, monitorForElements } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';
|
||||
import type { FlexProps, SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex, Icon, Image } from '@invoke-ai/ui-library';
|
||||
import { Flex, Icon, Image } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import type { AppDispatch, AppGetState } from 'app/store/store';
|
||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
@@ -23,13 +23,11 @@ import { imageToCompareChanged, selectGallerySlice, selectionChanged } from 'fea
|
||||
import { navigationApi } from 'features/ui/layouts/navigation-api';
|
||||
import { VIEWER_PANEL_ID } from 'features/ui/layouts/shared';
|
||||
import type { MouseEvent, MouseEventHandler } from 'react';
|
||||
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
|
||||
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { PiImageBold } from 'react-icons/pi';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
const GALLERY_IMAGE_CLASS = 'gallery-image';
|
||||
|
||||
const galleryImageContainerSX = {
|
||||
containerType: 'inline-size',
|
||||
w: 'full',
|
||||
@@ -42,45 +40,42 @@ const galleryImageContainerSX = {
|
||||
'&[data-is-dragging=true]': {
|
||||
opacity: 0.3,
|
||||
},
|
||||
[`.${GALLERY_IMAGE_CLASS}`]: {
|
||||
touchAction: 'none',
|
||||
userSelect: 'none',
|
||||
webkitUserSelect: 'none',
|
||||
position: 'relative',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
aspectRatio: '1/1',
|
||||
'::before': {
|
||||
content: '""',
|
||||
display: 'inline-block',
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
left: 0,
|
||||
right: 0,
|
||||
bottom: 0,
|
||||
pointerEvents: 'none',
|
||||
borderRadius: 'base',
|
||||
},
|
||||
'&[data-selected=true]::before': {
|
||||
boxShadow:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-500), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
|
||||
},
|
||||
'&[data-selected-for-compare=true]::before': {
|
||||
boxShadow:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-300), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
|
||||
},
|
||||
'&:hover::before': {
|
||||
boxShadow:
|
||||
'inset 0px 0px 0px 1px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-800)',
|
||||
},
|
||||
'&:hover[data-selected=true]::before': {
|
||||
boxShadow:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-400), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
|
||||
},
|
||||
'&:hover[data-selected-for-compare=true]::before': {
|
||||
boxShadow:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-200), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
|
||||
},
|
||||
userSelect: 'none',
|
||||
webkitUserSelect: 'none',
|
||||
position: 'relative',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
aspectRatio: '1/1',
|
||||
'::before': {
|
||||
content: '""',
|
||||
display: 'inline-block',
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
left: 0,
|
||||
right: 0,
|
||||
bottom: 0,
|
||||
pointerEvents: 'none',
|
||||
borderRadius: 'base',
|
||||
},
|
||||
'&[data-selected=true]::before': {
|
||||
boxShadow:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-500), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
|
||||
},
|
||||
'&[data-selected-for-compare=true]::before': {
|
||||
boxShadow:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-300), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
|
||||
},
|
||||
'&:hover::before': {
|
||||
boxShadow:
|
||||
'inset 0px 0px 0px 1px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-800)',
|
||||
},
|
||||
'&:hover[data-selected=true]::before': {
|
||||
boxShadow:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-400), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
|
||||
},
|
||||
'&:hover[data-selected-for-compare=true]::before': {
|
||||
boxShadow:
|
||||
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-200), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
|
||||
},
|
||||
} satisfies SystemStyleObject;
|
||||
|
||||
@@ -142,8 +137,7 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
|
||||
const [dragPreviewState, setDragPreviewState] = useState<
|
||||
DndDragPreviewSingleImageState | DndDragPreviewMultipleImageState | null
|
||||
>(null);
|
||||
// Must use callback ref - else chakra's Image fallback prop will break the ref & dnd
|
||||
const [element, ref] = useState<HTMLImageElement | null>(null);
|
||||
const ref = useRef<HTMLDivElement>(null);
|
||||
const selectIsSelectedForCompare = useMemo(
|
||||
() => createSelector(selectGallerySlice, (gallery) => gallery.imageToCompare === imageDTO.image_name),
|
||||
[imageDTO.image_name]
|
||||
@@ -156,6 +150,7 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
|
||||
const isSelected = useAppSelector(selectIsSelected);
|
||||
|
||||
useEffect(() => {
|
||||
const element = ref.current;
|
||||
if (!element) {
|
||||
return;
|
||||
}
|
||||
@@ -221,7 +216,7 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
|
||||
},
|
||||
})
|
||||
);
|
||||
}, [element, imageDTO, store]);
|
||||
}, [imageDTO, store]);
|
||||
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
|
||||
@@ -240,34 +235,35 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
|
||||
navigationApi.focusPanelInActiveTab(VIEWER_PANEL_ID);
|
||||
}, [store]);
|
||||
|
||||
useImageContextMenu(imageDTO, element);
|
||||
useImageContextMenu(imageDTO, ref);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Box sx={galleryImageContainerSX} data-is-dragging={isDragging} data-image-name={imageDTO.image_name}>
|
||||
<Flex
|
||||
role="button"
|
||||
className={GALLERY_IMAGE_CLASS}
|
||||
onMouseOver={onMouseOver}
|
||||
onMouseOut={onMouseOut}
|
||||
onClick={onClick}
|
||||
onDoubleClick={onDoubleClick}
|
||||
data-selected={isSelected}
|
||||
data-selected-for-compare={isSelectedForCompare}
|
||||
>
|
||||
<Image
|
||||
ref={ref}
|
||||
src={imageDTO.thumbnail_url}
|
||||
w={imageDTO.width}
|
||||
fallback={<GalleryImagePlaceholder />}
|
||||
objectFit="contain"
|
||||
maxW="full"
|
||||
maxH="full"
|
||||
borderRadius="base"
|
||||
/>
|
||||
<GalleryImageHoverIcons imageDTO={imageDTO} isHovered={isHovered} />
|
||||
</Flex>
|
||||
</Box>
|
||||
<Flex
|
||||
ref={ref}
|
||||
sx={galleryImageContainerSX}
|
||||
data-is-dragging={isDragging}
|
||||
data-image-name={imageDTO.image_name}
|
||||
role="button"
|
||||
onMouseOver={onMouseOver}
|
||||
onMouseOut={onMouseOut}
|
||||
onClick={onClick}
|
||||
onDoubleClick={onDoubleClick}
|
||||
data-selected={isSelected}
|
||||
data-selected-for-compare={isSelectedForCompare}
|
||||
>
|
||||
<Image
|
||||
pointerEvents="none"
|
||||
src={imageDTO.thumbnail_url}
|
||||
w={imageDTO.width}
|
||||
fallback={<GalleryImagePlaceholder />}
|
||||
objectFit="contain"
|
||||
maxW="full"
|
||||
maxH="full"
|
||||
borderRadius="base"
|
||||
/>
|
||||
<GalleryImageHoverIcons imageDTO={imageDTO} isHovered={isHovered} />
|
||||
</Flex>
|
||||
{dragPreviewState?.type === 'multiple-image' ? createMultipleImageDragPreview(dragPreviewState) : null}
|
||||
{dragPreviewState?.type === 'single-image' ? createSingleImageDragPreview(dragPreviewState) : null}
|
||||
</>
|
||||
|
||||
@@ -85,7 +85,7 @@ const UnrecallableMetadataParsed = typedMemo(
|
||||
|
||||
return (
|
||||
<Box as="span" lineHeight={1}>
|
||||
<LabelComponent />
|
||||
<LabelComponent i18nKey={handler.i18nKey} />
|
||||
<ValueComponent value={data.value} />
|
||||
</Box>
|
||||
);
|
||||
@@ -128,7 +128,7 @@ const SingleMetadataParsed = typedMemo(
|
||||
onClick={onClick}
|
||||
/>
|
||||
<Box as="span" lineHeight={1}>
|
||||
<LabelComponent />
|
||||
<LabelComponent i18nKey={handler.i18nKey} />
|
||||
<ValueComponent value={data.value} />
|
||||
</Box>
|
||||
</Flex>
|
||||
@@ -178,7 +178,7 @@ const CollectionMetadataParsed = typedMemo(
|
||||
onClick={onClick}
|
||||
/>
|
||||
<Box as="span" lineHeight={1}>
|
||||
<LabelComponent />
|
||||
<LabelComponent i18nKey={handler.i18nKey} />
|
||||
<ValueComponent value={value} />
|
||||
</Box>
|
||||
</Flex>
|
||||
|
||||
@@ -1,21 +1,19 @@
|
||||
import { Button, Divider, IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton';
|
||||
import SingleSelectionMenuItems from 'features/gallery/components/ImageContextMenu/SingleSelectionMenuItems';
|
||||
import { useImageActions } from 'features/gallery/hooks/useImageActions';
|
||||
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
|
||||
import { newCanvasFromImage } from 'features/imageActions/actions';
|
||||
import { $hasTemplates } from 'features/nodes/store/nodesSlice';
|
||||
import { useDeleteImage } from 'features/gallery/hooks/useDeleteImage';
|
||||
import { useEditImage } from 'features/gallery/hooks/useEditImage';
|
||||
import { useLoadWorkflow } from 'features/gallery/hooks/useLoadWorkflow';
|
||||
import { useRecallAll } from 'features/gallery/hooks/useRecallAll';
|
||||
import { useRecallDimensions } from 'features/gallery/hooks/useRecallDimensions';
|
||||
import { useRecallPrompts } from 'features/gallery/hooks/useRecallPrompts';
|
||||
import { useRecallRemix } from 'features/gallery/hooks/useRecallRemix';
|
||||
import { useRecallSeed } from 'features/gallery/hooks/useRecallSeed';
|
||||
import { PostProcessingPopover } from 'features/parameters/components/PostProcessing/PostProcessingPopover';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { navigationApi } from 'features/ui/layouts/navigation-api';
|
||||
import { WORKSPACE_PANEL_ID } from 'features/ui/layouts/shared';
|
||||
import { selectShouldShowProgressInViewer } from 'features/ui/store/uiSelectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
PiArrowsCounterClockwiseBold,
|
||||
@@ -27,51 +25,23 @@ import {
|
||||
PiQuotesBold,
|
||||
PiRulerBold,
|
||||
} from 'react-icons/pi';
|
||||
import { useImageDTO } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
import { useImageViewerContext } from './context';
|
||||
|
||||
export const CurrentImageButtons = memo(() => {
|
||||
export const CurrentImageButtons = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
|
||||
const { t } = useTranslation();
|
||||
const ctx = useImageViewerContext();
|
||||
const hasProgressImage = useStore(ctx.$hasProgressImage);
|
||||
const shouldShowProgressInViewer = useAppSelector(selectShouldShowProgressInViewer);
|
||||
const isDisabledOverride = hasProgressImage && shouldShowProgressInViewer;
|
||||
const tab = useAppSelector(selectActiveTab);
|
||||
const isCanvasOrGenerateTab = tab === 'canvas' || tab === 'generate';
|
||||
|
||||
const imageName = useAppSelector(selectLastSelectedImage);
|
||||
const imageDTO = useImageDTO(imageName);
|
||||
const hasTemplates = useStore($hasTemplates);
|
||||
const imageActions = useImageActions(imageDTO);
|
||||
const isStaging = useAppSelector(selectIsStaging);
|
||||
const isUpscalingEnabled = useFeatureStatus('upscaling');
|
||||
const { getState, dispatch } = useAppStore();
|
||||
const canvasManager = useCanvasManagerSafe();
|
||||
|
||||
const handleEdit = useCallback(async () => {
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
|
||||
await newCanvasFromImage({
|
||||
imageDTO,
|
||||
type: 'raster_layer',
|
||||
withInpaintMask: true,
|
||||
getState,
|
||||
dispatch,
|
||||
});
|
||||
navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
|
||||
|
||||
// Automatically select the brush tool when editing an image
|
||||
if (canvasManager) {
|
||||
canvasManager.tool.$tool.set('brush');
|
||||
}
|
||||
|
||||
toast({
|
||||
id: 'SENT_TO_CANVAS',
|
||||
title: t('toast.sentToCanvas'),
|
||||
status: 'success',
|
||||
});
|
||||
}, [imageDTO, getState, dispatch, t, canvasManager]);
|
||||
const recallAll = useRecallAll(imageDTO);
|
||||
const recallRemix = useRecallRemix(imageDTO);
|
||||
const recallPrompts = useRecallPrompts(imageDTO);
|
||||
const recallSeed = useRecallSeed(imageDTO);
|
||||
const recallDimensions = useRecallDimensions(imageDTO);
|
||||
const loadWorkflow = useLoadWorkflow(imageDTO);
|
||||
const editImage = useEditImage(imageDTO);
|
||||
const deleteImage = useDeleteImage(imageDTO);
|
||||
|
||||
return (
|
||||
<>
|
||||
@@ -80,7 +50,7 @@ export const CurrentImageButtons = memo(() => {
|
||||
as={IconButton}
|
||||
aria-label={t('parameters.imageActions')}
|
||||
tooltip={t('parameters.imageActions')}
|
||||
isDisabled={isDisabledOverride || !imageDTO}
|
||||
isDisabled={!imageDTO}
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
icon={<PiDotsThreeOutlineFill />}
|
||||
@@ -92,8 +62,8 @@ export const CurrentImageButtons = memo(() => {
|
||||
|
||||
<Button
|
||||
leftIcon={<PiPencilBold />}
|
||||
onClick={handleEdit}
|
||||
isDisabled={isDisabledOverride || !imageDTO}
|
||||
onClick={editImage.edit}
|
||||
isDisabled={!editImage.isEnabled}
|
||||
variant="link"
|
||||
size="sm"
|
||||
alignSelf="stretch"
|
||||
@@ -108,62 +78,72 @@ export const CurrentImageButtons = memo(() => {
|
||||
icon={<PiFlowArrowBold />}
|
||||
tooltip={`${t('nodes.loadWorkflow')} (W)`}
|
||||
aria-label={`${t('nodes.loadWorkflow')} (W)`}
|
||||
isDisabled={isDisabledOverride || !imageDTO || !imageActions.hasWorkflow || !hasTemplates}
|
||||
isDisabled={!loadWorkflow.isEnabled}
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
onClick={imageActions.loadWorkflow}
|
||||
/>
|
||||
<IconButton
|
||||
icon={<PiArrowsCounterClockwiseBold />}
|
||||
tooltip={`${t('parameters.remixImage')} (R)`}
|
||||
aria-label={`${t('parameters.remixImage')} (R)`}
|
||||
isDisabled={isDisabledOverride || !imageDTO || !imageActions.hasMetadata}
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
onClick={imageActions.remix}
|
||||
/>
|
||||
<IconButton
|
||||
icon={<PiQuotesBold />}
|
||||
tooltip={`${t('parameters.usePrompt')} (P)`}
|
||||
aria-label={`${t('parameters.usePrompt')} (P)`}
|
||||
isDisabled={isDisabledOverride || !imageDTO || !imageActions.hasPrompts}
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
onClick={imageActions.recallPrompts}
|
||||
/>
|
||||
<IconButton
|
||||
icon={<PiPlantBold />}
|
||||
tooltip={`${t('parameters.useSeed')} (S)`}
|
||||
aria-label={`${t('parameters.useSeed')} (S)`}
|
||||
isDisabled={isDisabledOverride || !imageDTO || !imageActions.hasSeed}
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
onClick={imageActions.recallSeed}
|
||||
/>
|
||||
<IconButton
|
||||
icon={<PiRulerBold />}
|
||||
tooltip={`${t('parameters.useSize')} (D)`}
|
||||
aria-label={`${t('parameters.useSize')} (D)`}
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
onClick={imageActions.recallSize}
|
||||
isDisabled={isDisabledOverride || !imageDTO || isStaging}
|
||||
/>
|
||||
<IconButton
|
||||
icon={<PiAsteriskBold />}
|
||||
tooltip={`${t('parameters.useAll')} (A)`}
|
||||
aria-label={`${t('parameters.useAll')} (A)`}
|
||||
isDisabled={isDisabledOverride || !imageDTO || !imageActions.hasMetadata}
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
onClick={imageActions.recallAll}
|
||||
onClick={loadWorkflow.load}
|
||||
/>
|
||||
{isCanvasOrGenerateTab && (
|
||||
<IconButton
|
||||
icon={<PiArrowsCounterClockwiseBold />}
|
||||
tooltip={`${t('parameters.remixImage')} (R)`}
|
||||
aria-label={`${t('parameters.remixImage')} (R)`}
|
||||
isDisabled={!recallRemix.isEnabled}
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
onClick={recallRemix.recall}
|
||||
/>
|
||||
)}
|
||||
{isCanvasOrGenerateTab && (
|
||||
<IconButton
|
||||
icon={<PiQuotesBold />}
|
||||
tooltip={`${t('parameters.usePrompt')} (P)`}
|
||||
aria-label={`${t('parameters.usePrompt')} (P)`}
|
||||
isDisabled={!recallPrompts.isEnabled}
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
onClick={recallPrompts.recall}
|
||||
/>
|
||||
)}
|
||||
{isCanvasOrGenerateTab && (
|
||||
<IconButton
|
||||
icon={<PiPlantBold />}
|
||||
tooltip={`${t('parameters.useSeed')} (S)`}
|
||||
aria-label={`${t('parameters.useSeed')} (S)`}
|
||||
isDisabled={!recallSeed.isEnabled}
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
onClick={recallSeed.recall}
|
||||
/>
|
||||
)}
|
||||
{isCanvasOrGenerateTab && (
|
||||
<IconButton
|
||||
icon={<PiRulerBold />}
|
||||
tooltip={`${t('parameters.useSize')} (D)`}
|
||||
aria-label={`${t('parameters.useSize')} (D)`}
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
onClick={recallDimensions.recall}
|
||||
isDisabled={!recallDimensions.isEnabled}
|
||||
/>
|
||||
)}
|
||||
{isCanvasOrGenerateTab && (
|
||||
<IconButton
|
||||
icon={<PiAsteriskBold />}
|
||||
tooltip={`${t('parameters.useAll')} (A)`}
|
||||
aria-label={`${t('parameters.useAll')} (A)`}
|
||||
isDisabled={!recallAll.isEnabled}
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
onClick={recallAll.recall}
|
||||
/>
|
||||
)}
|
||||
|
||||
{isUpscalingEnabled && <PostProcessingPopover imageDTO={imageDTO} isDisabled={isDisabledOverride} />}
|
||||
{isUpscalingEnabled && <PostProcessingPopover imageDTO={imageDTO} isDisabled={false} />}
|
||||
|
||||
<Divider orientation="vertical" h={8} mx={2} />
|
||||
|
||||
<DeleteImageButton onClick={imageActions.delete} isDisabled={isDisabledOverride || !imageDTO} />
|
||||
<DeleteImageButton onClick={deleteImage.delete} isDisabled={!deleteImage.isEnabled} />
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -50,7 +50,7 @@ export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | nu
|
||||
>
|
||||
{imageDTO && (
|
||||
<Flex w="full" h="full" position="absolute" alignItems="center" justifyContent="center">
|
||||
<DndImage imageDTO={imageDTO} onLoad={onLoadImage} />
|
||||
<DndImage imageDTO={imageDTO} onLoad={onLoadImage} borderRadius="base" />
|
||||
</Flex>
|
||||
)}
|
||||
{!imageDTO && <NoContentForViewer />}
|
||||
|
||||
@@ -1,18 +1,24 @@
|
||||
import { Flex, Spacer } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { ToggleMetadataViewerButton } from 'features/gallery/components/ImageViewer/ToggleMetadataViewerButton';
|
||||
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
|
||||
import { memo } from 'react';
|
||||
import { useImageDTO } from 'services/api/endpoints/images';
|
||||
|
||||
import { CurrentImageButtons } from './CurrentImageButtons';
|
||||
import { ToggleProgressButton } from './ToggleProgressButton';
|
||||
|
||||
export const ViewerToolbar = memo(() => {
|
||||
const imageName = useAppSelector(selectLastSelectedImage);
|
||||
const imageDTO = useImageDTO(imageName);
|
||||
|
||||
return (
|
||||
<Flex w="full" justifyContent="center" h={8}>
|
||||
<ToggleProgressButton />
|
||||
<Spacer />
|
||||
<CurrentImageButtons />
|
||||
{imageDTO && <CurrentImageButtons imageDTO={imageDTO} />}
|
||||
<Spacer />
|
||||
<ToggleMetadataViewerButton />
|
||||
{imageDTO && <ToggleMetadataViewerButton />}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -2,6 +2,7 @@ import { Box, Flex, forwardRef, Grid, GridItem, Spinner, Text } from '@invoke-ai
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { getFocusedRegion } from 'common/hooks/focus';
|
||||
import { useRangeBasedImageFetching } from 'features/gallery/hooks/useRangeBasedImageFetching';
|
||||
import type { selectGetImageNamesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import {
|
||||
@@ -221,6 +222,10 @@ const useKeyboardNavigation = (
|
||||
|
||||
const handleKeyDown = useCallback(
|
||||
(event: KeyboardEvent) => {
|
||||
if (getFocusedRegion() !== 'gallery') {
|
||||
// Only handle keyboard navigation when the gallery is focused
|
||||
return;
|
||||
}
|
||||
// Only handle arrow keys
|
||||
if (!['ArrowUp', 'ArrowDown', 'ArrowLeft', 'ArrowRight'].includes(event.key)) {
|
||||
return;
|
||||
@@ -477,11 +482,6 @@ export const NewGallery = memo(() => {
|
||||
|
||||
const context = useMemo<GridContext>(() => ({ imageNames, queryArgs }), [imageNames, queryArgs]);
|
||||
|
||||
// Item content function
|
||||
const itemContent: GridItemContent<string, GridContext> = useCallback((index, imageName) => {
|
||||
return <ImageAtPosition index={index} imageName={imageName} />;
|
||||
}, []);
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<Flex w="full" h="full" alignItems="center" justifyContent="center" gap={4}>
|
||||
@@ -506,7 +506,7 @@ export const NewGallery = memo(() => {
|
||||
ref={virtuosoRef}
|
||||
context={context}
|
||||
data={imageNames}
|
||||
increaseViewportBy={2048}
|
||||
increaseViewportBy={4096}
|
||||
itemContent={itemContent}
|
||||
computeItemKey={computeItemKey}
|
||||
components={components}
|
||||
@@ -523,8 +523,12 @@ export const NewGallery = memo(() => {
|
||||
NewGallery.displayName = 'NewGallery';
|
||||
|
||||
const scrollSeekConfiguration: ScrollSeekConfiguration = {
|
||||
enter: (velocity) => velocity > 4096,
|
||||
exit: (velocity) => velocity === 0,
|
||||
enter: (velocity) => {
|
||||
return Math.abs(velocity) > 2048;
|
||||
},
|
||||
exit: (velocity) => {
|
||||
return velocity === 0;
|
||||
},
|
||||
};
|
||||
|
||||
// Styles
|
||||
@@ -544,6 +548,10 @@ const ListComponent: GridComponents<GridContext>['List'] = forwardRef(({ context
|
||||
});
|
||||
ListComponent.displayName = 'ListComponent';
|
||||
|
||||
const itemContent: GridItemContent<string, GridContext> = (index, imageName) => {
|
||||
return <ImageAtPosition index={index} imageName={imageName} />;
|
||||
};
|
||||
|
||||
const ItemComponent: GridComponents<GridContext>['Item'] = forwardRef(({ context: _, ...rest }, ref) => (
|
||||
<GridItem ref={ref} aspectRatio="1/1" {...rest} />
|
||||
));
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import {
|
||||
activeStylePresetIdChanged,
|
||||
selectStylePresetActivePresetId,
|
||||
} from 'features/stylePresets/store/stylePresetSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export const useClearStylePresetWithToast = () => {
|
||||
const store = useAppStore();
|
||||
const { t } = useTranslation();
|
||||
const activeStylePresetId = useAppSelector(selectStylePresetActivePresetId);
|
||||
|
||||
const clearStylePreset = useCallback(() => {
|
||||
if (activeStylePresetId) {
|
||||
store.dispatch(activeStylePresetIdChanged(null));
|
||||
toast({
|
||||
status: 'info',
|
||||
title: t('stylePresets.promptTemplateCleared'),
|
||||
});
|
||||
}
|
||||
}, [activeStylePresetId, store, t]);
|
||||
|
||||
return clearStylePreset;
|
||||
};
|
||||
@@ -0,0 +1,81 @@
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { MetadataHandlers, MetadataUtils } from 'features/metadata/parsing';
|
||||
import { $stylePresetModalState } from 'features/stylePresets/store/stylePresetModal';
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
export const useCreateStylePresetFromMetadata = (imageDTO?: ImageDTO | null) => {
|
||||
const store = useAppStore();
|
||||
const [hasPrompts, setHasPrompts] = useState(false);
|
||||
|
||||
const { metadata } = useDebouncedMetadata(imageDTO?.image_name);
|
||||
|
||||
useEffect(() => {
|
||||
MetadataUtils.hasMetadataByHandlers({
|
||||
handlers: [MetadataHandlers.PositivePrompt, MetadataHandlers.NegativePrompt],
|
||||
metadata,
|
||||
store,
|
||||
require: 'some',
|
||||
})
|
||||
.then((result) => {
|
||||
setHasPrompts(result);
|
||||
})
|
||||
.catch(() => {
|
||||
setHasPrompts(false);
|
||||
});
|
||||
}, [metadata, store]);
|
||||
|
||||
const isEnabled = useMemo(() => {
|
||||
if (!imageDTO) {
|
||||
return false;
|
||||
}
|
||||
if (!hasPrompts) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}, [hasPrompts, imageDTO]);
|
||||
|
||||
const create = useCallback(async () => {
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
if (!metadata) {
|
||||
return;
|
||||
}
|
||||
if (!isEnabled) {
|
||||
return;
|
||||
}
|
||||
|
||||
let positivePrompt: string;
|
||||
let negativePrompt: string;
|
||||
|
||||
try {
|
||||
positivePrompt = await MetadataHandlers.PositivePrompt.parse(metadata, store);
|
||||
} catch (error) {
|
||||
positivePrompt = '';
|
||||
}
|
||||
try {
|
||||
negativePrompt = (await MetadataHandlers.NegativePrompt.parse(metadata, store)) ?? '';
|
||||
} catch (error) {
|
||||
negativePrompt = '';
|
||||
}
|
||||
|
||||
$stylePresetModalState.set({
|
||||
prefilledFormData: {
|
||||
name: '',
|
||||
positivePrompt,
|
||||
negativePrompt,
|
||||
imageUrl: imageDTO.image_url,
|
||||
type: 'user',
|
||||
},
|
||||
updatingStylePresetId: null,
|
||||
isModalOpen: true,
|
||||
});
|
||||
}, [imageDTO, isEnabled, metadata, store]);
|
||||
|
||||
return {
|
||||
create,
|
||||
isEnabled,
|
||||
};
|
||||
};
|
||||
@@ -0,0 +1,28 @@
|
||||
import { useDeleteImageModalApi } from 'features/deleteImageModal/store/state';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
export const useDeleteImage = (imageDTO?: ImageDTO | null) => {
|
||||
const deleteImageModal = useDeleteImageModalApi();
|
||||
|
||||
const isEnabled = useMemo(() => {
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
return true;
|
||||
}, [imageDTO]);
|
||||
const _delete = useCallback(() => {
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
if (!isEnabled) {
|
||||
return;
|
||||
}
|
||||
deleteImageModal.delete([imageDTO.image_name]);
|
||||
}, [deleteImageModal, imageDTO, isEnabled]);
|
||||
|
||||
return {
|
||||
delete: _delete,
|
||||
isEnabled,
|
||||
};
|
||||
};
|
||||
@@ -0,0 +1,57 @@
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import { newCanvasFromImage } from 'features/imageActions/actions';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { navigationApi } from 'features/ui/layouts/navigation-api';
|
||||
import { WORKSPACE_PANEL_ID } from 'features/ui/layouts/shared';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
export const useEditImage = (imageDTO?: ImageDTO | null) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { getState, dispatch } = useAppStore();
|
||||
const canvasManager = useCanvasManagerSafe();
|
||||
|
||||
const isEnabled = useMemo(() => {
|
||||
if (!imageDTO) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}, [imageDTO]);
|
||||
|
||||
const edit = useCallback(async () => {
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!isEnabled) {
|
||||
return;
|
||||
}
|
||||
|
||||
await newCanvasFromImage({
|
||||
imageDTO,
|
||||
type: 'raster_layer',
|
||||
withInpaintMask: true,
|
||||
getState,
|
||||
dispatch,
|
||||
});
|
||||
navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
|
||||
|
||||
if (canvasManager) {
|
||||
canvasManager.tool.$tool.set('brush');
|
||||
}
|
||||
|
||||
toast({
|
||||
id: 'SENT_TO_CANVAS',
|
||||
title: t('toast.sentToCanvas'),
|
||||
status: 'success',
|
||||
});
|
||||
}, [imageDTO, isEnabled, getState, dispatch, canvasManager, t]);
|
||||
|
||||
return {
|
||||
edit,
|
||||
isEnabled,
|
||||
};
|
||||
};
|
||||
@@ -1,209 +0,0 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { adHocPostProcessingRequested } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
|
||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { useDeleteImageModalApi } from 'features/deleteImageModal/store/state';
|
||||
import { MetadataHandlers, MetadataUtils } from 'features/metadata/parsing';
|
||||
import { $hasTemplates } from 'features/nodes/store/nodesSlice';
|
||||
import { $stylePresetModalState } from 'features/stylePresets/store/stylePresetModal';
|
||||
import {
|
||||
activeStylePresetIdChanged,
|
||||
selectStylePresetActivePresetId,
|
||||
} from 'features/stylePresets/store/stylePresetSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
|
||||
import { useCallback, useEffect, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
export const useImageActions = (imageDTO: ImageDTO | null) => {
|
||||
const store = useAppStore();
|
||||
const { t } = useTranslation();
|
||||
const activeStylePresetId = useAppSelector(selectStylePresetActivePresetId);
|
||||
const isStaging = useAppSelector(selectIsStaging);
|
||||
const { metadata } = useDebouncedMetadata(imageDTO?.image_name ?? null);
|
||||
const [hasMetadata, setHasMetadata] = useState(false);
|
||||
const [hasSeed, setHasSeed] = useState(false);
|
||||
const [hasPrompts, setHasPrompts] = useState(false);
|
||||
const hasTemplates = useStore($hasTemplates);
|
||||
const deleteImageModal = useDeleteImageModalApi();
|
||||
|
||||
useEffect(() => {
|
||||
const parseMetadata = async () => {
|
||||
if (metadata) {
|
||||
setHasMetadata(true);
|
||||
try {
|
||||
await MetadataHandlers.Seed.parse(metadata, store);
|
||||
setHasSeed(true);
|
||||
} catch {
|
||||
setHasSeed(false);
|
||||
}
|
||||
|
||||
let hasPrompt = false;
|
||||
// Need to catch all of these to avoid unhandled promise rejections bubbling up to instrumented error handlers
|
||||
for (const handler of [
|
||||
MetadataHandlers.PositivePrompt,
|
||||
MetadataHandlers.NegativePrompt,
|
||||
MetadataHandlers.PositiveStylePrompt,
|
||||
MetadataHandlers.NegativeStylePrompt,
|
||||
]) {
|
||||
try {
|
||||
await handler.parse(metadata, store);
|
||||
hasPrompt = true;
|
||||
break;
|
||||
} catch {
|
||||
// noop
|
||||
}
|
||||
}
|
||||
setHasPrompts(hasPrompt);
|
||||
} else {
|
||||
setHasMetadata(false);
|
||||
setHasSeed(false);
|
||||
setHasPrompts(false);
|
||||
}
|
||||
};
|
||||
parseMetadata();
|
||||
}, [metadata, store]);
|
||||
|
||||
const clearStylePreset = useCallback(() => {
|
||||
if (activeStylePresetId) {
|
||||
store.dispatch(activeStylePresetIdChanged(null));
|
||||
toast({
|
||||
status: 'info',
|
||||
title: t('stylePresets.promptTemplateCleared'),
|
||||
});
|
||||
}
|
||||
}, [activeStylePresetId, store, t]);
|
||||
|
||||
const recallAll = useCallback(() => {
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
if (!metadata) {
|
||||
return;
|
||||
}
|
||||
MetadataUtils.recallAll(metadata, store, isStaging ? [MetadataHandlers.Width, MetadataHandlers.Height] : []);
|
||||
clearStylePreset();
|
||||
}, [imageDTO, metadata, store, isStaging, clearStylePreset]);
|
||||
|
||||
const remix = useCallback(() => {
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
if (!metadata) {
|
||||
return;
|
||||
}
|
||||
// Recalls all metadata parameters except seed
|
||||
MetadataUtils.recallAll(metadata, store, [MetadataHandlers.Seed]);
|
||||
clearStylePreset();
|
||||
}, [imageDTO, metadata, store, clearStylePreset]);
|
||||
|
||||
const recallSeed = useCallback(() => {
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
if (!metadata) {
|
||||
return;
|
||||
}
|
||||
MetadataUtils.recallByHandler({ metadata, store, handler: MetadataHandlers.Seed });
|
||||
}, [imageDTO, metadata, store]);
|
||||
|
||||
const recallPrompts = useCallback(() => {
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
if (!metadata) {
|
||||
return;
|
||||
}
|
||||
MetadataUtils.recallPrompts(metadata, store);
|
||||
clearStylePreset();
|
||||
}, [imageDTO, metadata, store, clearStylePreset]);
|
||||
|
||||
const createAsPreset = useCallback(async () => {
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
if (!metadata) {
|
||||
return;
|
||||
}
|
||||
let positivePrompt: string;
|
||||
let negativePrompt: string;
|
||||
|
||||
try {
|
||||
positivePrompt = await MetadataHandlers.PositivePrompt.parse(metadata, store);
|
||||
} catch (error) {
|
||||
positivePrompt = '';
|
||||
}
|
||||
try {
|
||||
negativePrompt = (await MetadataHandlers.NegativePrompt.parse(metadata, store)) ?? '';
|
||||
} catch (error) {
|
||||
negativePrompt = '';
|
||||
}
|
||||
|
||||
$stylePresetModalState.set({
|
||||
prefilledFormData: {
|
||||
name: '',
|
||||
positivePrompt,
|
||||
negativePrompt,
|
||||
imageUrl: imageDTO.image_url,
|
||||
type: 'user',
|
||||
},
|
||||
updatingStylePresetId: null,
|
||||
isModalOpen: true,
|
||||
});
|
||||
}, [imageDTO, metadata, store]);
|
||||
|
||||
const loadWorkflowWithDialog = useLoadWorkflowWithDialog();
|
||||
|
||||
const loadWorkflowFromImage = useCallback(() => {
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
if (!imageDTO.has_workflow || !hasTemplates) {
|
||||
return;
|
||||
}
|
||||
|
||||
loadWorkflowWithDialog({ type: 'image', data: imageDTO.image_name });
|
||||
}, [hasTemplates, imageDTO, loadWorkflowWithDialog]);
|
||||
|
||||
const recallSize = useCallback(() => {
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
if (isStaging) {
|
||||
return;
|
||||
}
|
||||
MetadataUtils.recallDimensions(imageDTO, store);
|
||||
}, [imageDTO, isStaging, store]);
|
||||
|
||||
const upscale = useCallback(() => {
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
store.dispatch(adHocPostProcessingRequested({ imageDTO }));
|
||||
}, [imageDTO, store]);
|
||||
|
||||
const _delete = useCallback(() => {
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
deleteImageModal.delete([imageDTO.image_name]);
|
||||
}, [deleteImageModal, imageDTO]);
|
||||
|
||||
return {
|
||||
hasMetadata,
|
||||
hasSeed,
|
||||
hasPrompts,
|
||||
recallAll,
|
||||
remix,
|
||||
recallSeed,
|
||||
recallPrompts,
|
||||
createAsPreset,
|
||||
loadWorkflow: loadWorkflowFromImage,
|
||||
hasWorkflow: imageDTO?.has_workflow ?? false,
|
||||
recallSize,
|
||||
upscale,
|
||||
delete: _delete,
|
||||
};
|
||||
};
|
||||
@@ -0,0 +1,34 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $hasTemplates } from 'features/nodes/store/nodesSlice';
|
||||
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
export const useLoadWorkflow = (imageDTO: ImageDTO) => {
|
||||
const hasTemplates = useStore($hasTemplates);
|
||||
|
||||
const loadWorkflowWithDialog = useLoadWorkflowWithDialog();
|
||||
|
||||
const isEnabled = useMemo(() => {
|
||||
if (!imageDTO.has_workflow) {
|
||||
return false;
|
||||
}
|
||||
if (!hasTemplates) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}, [hasTemplates, imageDTO]);
|
||||
|
||||
const load = useCallback(() => {
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
if (!isEnabled) {
|
||||
return;
|
||||
}
|
||||
|
||||
loadWorkflowWithDialog({ type: 'image', data: imageDTO.image_name });
|
||||
}, [imageDTO, isEnabled, loadWorkflowWithDialog]);
|
||||
|
||||
return { load, isEnabled };
|
||||
};
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { useCallback } from 'react';
|
||||
import { useCallback, useEffect, useState } from 'react';
|
||||
import type { ListRange } from 'react-virtuoso';
|
||||
import { imagesApi, useGetImageDTOsByNamesMutation } from 'services/api/endpoints/images';
|
||||
import { useThrottledCallback } from 'use-debounce';
|
||||
@@ -13,33 +13,20 @@ interface UseRangeBasedImageFetchingReturn {
|
||||
onRangeChanged: (range: ListRange) => void;
|
||||
}
|
||||
|
||||
const getUncachedNames = (imageNames: string[], cachedImageNames: string[], range: ListRange): string[] => {
|
||||
if (range.startIndex === range.endIndex) {
|
||||
// If the start and end indices are the same, no range to fetch
|
||||
return [];
|
||||
}
|
||||
const getUncachedNames = (imageNames: string[], cachedImageNames: string[], ranges: ListRange[]): string[] => {
|
||||
const uncachedNamesSet = new Set<string>();
|
||||
const cachedImageNamesSet = new Set(cachedImageNames);
|
||||
|
||||
if (imageNames.length === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const start = Math.max(0, range.startIndex);
|
||||
const end = Math.min(imageNames.length - 1, range.endIndex);
|
||||
|
||||
if (cachedImageNames.length === 0) {
|
||||
return imageNames.slice(start, end + 1);
|
||||
}
|
||||
|
||||
const uncachedNames: string[] = [];
|
||||
|
||||
for (let i = start; i <= end; i++) {
|
||||
const imageName = imageNames[i]!;
|
||||
if (!cachedImageNames.includes(imageName)) {
|
||||
uncachedNames.push(imageName);
|
||||
for (const range of ranges) {
|
||||
for (let i = range.startIndex; i <= range.endIndex; i++) {
|
||||
const n = imageNames[i]!;
|
||||
if (n && !cachedImageNamesSet.has(n)) {
|
||||
uncachedNamesSet.add(n);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return uncachedNames;
|
||||
return Array.from(uncachedNamesSet);
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -53,30 +40,36 @@ export const useRangeBasedImageFetching = ({
|
||||
}: UseRangeBasedImageFetchingArgs): UseRangeBasedImageFetchingReturn => {
|
||||
const store = useAppStore();
|
||||
const [getImageDTOsByNames] = useGetImageDTOsByNamesMutation();
|
||||
const [lastRange, setLastRange] = useState<ListRange | null>(null);
|
||||
const [pendingRanges, setPendingRanges] = useState<ListRange[]>([]);
|
||||
|
||||
const fetchImages = useCallback(
|
||||
(visibleRange: ListRange) => {
|
||||
(ranges: ListRange[], imageNames: string[]) => {
|
||||
if (!enabled) {
|
||||
return;
|
||||
}
|
||||
const cachedImageNames = imagesApi.util.selectCachedArgsForQuery(store.getState(), 'getImageDTO');
|
||||
const uncachedNames = getUncachedNames(imageNames, cachedImageNames, visibleRange);
|
||||
const uncachedNames = getUncachedNames(imageNames, cachedImageNames, ranges);
|
||||
if (uncachedNames.length === 0) {
|
||||
return;
|
||||
}
|
||||
getImageDTOsByNames({ image_names: uncachedNames });
|
||||
setPendingRanges([]);
|
||||
},
|
||||
[enabled, getImageDTOsByNames, imageNames, store]
|
||||
[enabled, getImageDTOsByNames, store]
|
||||
);
|
||||
|
||||
const throttledFetchImages = useThrottledCallback(fetchImages, 100);
|
||||
const throttledFetchImages = useThrottledCallback(fetchImages, 500);
|
||||
|
||||
const onRangeChanged = useCallback(
|
||||
(range: ListRange) => {
|
||||
throttledFetchImages(range);
|
||||
},
|
||||
[throttledFetchImages]
|
||||
);
|
||||
const onRangeChanged = useCallback((range: ListRange) => {
|
||||
setLastRange(range);
|
||||
setPendingRanges((prev) => [...prev, range]);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
const combinedRanges = lastRange ? [...pendingRanges, lastRange] : pendingRanges;
|
||||
throttledFetchImages(combinedRanges, imageNames);
|
||||
}, [imageNames, lastRange, pendingRanges, throttledFetchImages]);
|
||||
|
||||
return {
|
||||
onRangeChanged,
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { MetadataHandlers, MetadataUtils } from 'features/metadata/parsing';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
import { useClearStylePresetWithToast } from './useClearStylePresetWithToast';
|
||||
|
||||
export const useRecallAll = (imageDTO: ImageDTO) => {
|
||||
const store = useAppStore();
|
||||
const tab = useAppSelector(selectActiveTab);
|
||||
const { metadata, isLoading } = useDebouncedMetadata(imageDTO.image_name);
|
||||
const isStaging = useAppSelector(selectIsStaging);
|
||||
const clearStylePreset = useClearStylePresetWithToast();
|
||||
|
||||
const isEnabled = useMemo(() => {
|
||||
if (isLoading) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (tab !== 'canvas' && tab !== 'generate') {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!metadata) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}, [isLoading, metadata, tab]);
|
||||
|
||||
const handlersToSkip = useMemo(() => {
|
||||
if (tab === 'canvas' && isStaging) {
|
||||
// When we are staging and on canvas, the bbox is locked - we cannot recall width and height
|
||||
return [MetadataHandlers.Width, MetadataHandlers.Height];
|
||||
}
|
||||
return undefined;
|
||||
}, [isStaging, tab]);
|
||||
|
||||
const recall = useCallback(() => {
|
||||
if (!metadata) {
|
||||
return;
|
||||
}
|
||||
if (!isEnabled) {
|
||||
return;
|
||||
}
|
||||
MetadataUtils.recallAll(metadata, store, handlersToSkip);
|
||||
clearStylePreset();
|
||||
}, [metadata, isEnabled, store, handlersToSkip, clearStylePreset]);
|
||||
|
||||
return {
|
||||
recall,
|
||||
isEnabled,
|
||||
};
|
||||
};
|
||||
@@ -0,0 +1,36 @@
|
||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { MetadataUtils } from 'features/metadata/parsing';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
export const useRecallDimensions = (imageDTO: ImageDTO) => {
|
||||
const store = useAppStore();
|
||||
const tab = useAppSelector(selectActiveTab);
|
||||
const isStaging = useAppSelector(selectIsStaging);
|
||||
|
||||
const isEnabled = useMemo(() => {
|
||||
if (tab !== 'canvas' && tab !== 'generate') {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (tab === 'canvas' && isStaging) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}, [isStaging, tab]);
|
||||
|
||||
const recall = useCallback(() => {
|
||||
if (!isEnabled) {
|
||||
return;
|
||||
}
|
||||
MetadataUtils.recallDimensions(imageDTO, store);
|
||||
}, [isEnabled, imageDTO, store]);
|
||||
|
||||
return {
|
||||
recall,
|
||||
isEnabled,
|
||||
};
|
||||
};
|
||||
@@ -0,0 +1,72 @@
|
||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { MetadataHandlers, MetadataUtils } from 'features/metadata/parsing';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
import { useClearStylePresetWithToast } from './useClearStylePresetWithToast';
|
||||
|
||||
export const useRecallPrompts = (imageDTO: ImageDTO) => {
|
||||
const store = useAppStore();
|
||||
const tab = useAppSelector(selectActiveTab);
|
||||
const clearStylePreset = useClearStylePresetWithToast();
|
||||
const [hasPrompts, setHasPrompts] = useState(false);
|
||||
|
||||
const { metadata, isLoading } = useDebouncedMetadata(imageDTO.image_name);
|
||||
|
||||
useEffect(() => {
|
||||
const parse = async () => {
|
||||
try {
|
||||
const result = await MetadataUtils.hasMetadataByHandlers({
|
||||
handlers: [
|
||||
MetadataHandlers.PositivePrompt,
|
||||
MetadataHandlers.NegativePrompt,
|
||||
MetadataHandlers.PositiveStylePrompt,
|
||||
MetadataHandlers.NegativeStylePrompt,
|
||||
],
|
||||
metadata,
|
||||
store,
|
||||
require: 'some',
|
||||
});
|
||||
setHasPrompts(result);
|
||||
} catch {
|
||||
setHasPrompts(false);
|
||||
}
|
||||
};
|
||||
|
||||
parse();
|
||||
}, [metadata, store]);
|
||||
|
||||
const isEnabled = useMemo(() => {
|
||||
if (isLoading) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (tab !== 'canvas' && tab !== 'generate') {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!hasPrompts) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}, [hasPrompts, isLoading, tab]);
|
||||
|
||||
const recall = useCallback(() => {
|
||||
if (!metadata) {
|
||||
return;
|
||||
}
|
||||
if (!isEnabled) {
|
||||
return;
|
||||
}
|
||||
MetadataUtils.recallPrompts(metadata, store);
|
||||
clearStylePreset();
|
||||
}, [metadata, isEnabled, store, clearStylePreset]);
|
||||
|
||||
return {
|
||||
recall,
|
||||
isEnabled,
|
||||
};
|
||||
};
|
||||
@@ -0,0 +1,60 @@
|
||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { MetadataHandlers, MetadataUtils } from 'features/metadata/parsing';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
import { useClearStylePresetWithToast } from './useClearStylePresetWithToast';
|
||||
|
||||
export const useRecallRemix = (imageDTO: ImageDTO) => {
|
||||
const store = useAppStore();
|
||||
const tab = useAppSelector(selectActiveTab);
|
||||
const isStaging = useAppSelector(selectIsStaging);
|
||||
const clearStylePreset = useClearStylePresetWithToast();
|
||||
|
||||
const { metadata, isLoading } = useDebouncedMetadata(imageDTO.image_name);
|
||||
|
||||
const isEnabled = useMemo(() => {
|
||||
if (isLoading) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (tab !== 'canvas' && tab !== 'generate') {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!metadata) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}, [isLoading, metadata, tab]);
|
||||
|
||||
const handlersToSkip = useMemo(() => {
|
||||
// Remix always skips the seed handler
|
||||
const _handlersToSkip = [MetadataHandlers.Seed];
|
||||
if (tab === 'canvas' && isStaging) {
|
||||
// When we are staging and on canvas, the bbox is locked - we cannot recall width and height
|
||||
_handlersToSkip.push(MetadataHandlers.Width, MetadataHandlers.Height);
|
||||
}
|
||||
return _handlersToSkip;
|
||||
}, [isStaging, tab]);
|
||||
|
||||
const recall = useCallback(() => {
|
||||
if (!metadata) {
|
||||
return;
|
||||
}
|
||||
if (!isEnabled) {
|
||||
return;
|
||||
}
|
||||
MetadataUtils.recallAll(metadata, store, handlersToSkip);
|
||||
clearStylePreset();
|
||||
}, [metadata, isEnabled, store, handlersToSkip, clearStylePreset]);
|
||||
|
||||
return {
|
||||
recall,
|
||||
isEnabled,
|
||||
};
|
||||
};
|
||||
@@ -0,0 +1,62 @@
|
||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { MetadataHandlers, MetadataUtils } from 'features/metadata/parsing';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
export const useRecallSeed = (imageDTO: ImageDTO) => {
|
||||
const store = useAppStore();
|
||||
const tab = useAppSelector(selectActiveTab);
|
||||
const [hasSeed, setHasSeed] = useState(false);
|
||||
|
||||
const { metadata, isLoading } = useDebouncedMetadata(imageDTO.image_name);
|
||||
|
||||
useEffect(() => {
|
||||
const parse = async () => {
|
||||
try {
|
||||
await MetadataHandlers.Seed.parse(metadata, store);
|
||||
setHasSeed(true);
|
||||
} catch {
|
||||
setHasSeed(false);
|
||||
}
|
||||
};
|
||||
|
||||
parse();
|
||||
}, [metadata, store]);
|
||||
|
||||
const isEnabled = useMemo(() => {
|
||||
if (isLoading) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (tab !== 'canvas' && tab !== 'generate') {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!metadata) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!hasSeed) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}, [hasSeed, isLoading, metadata, tab]);
|
||||
|
||||
const recall = useCallback(() => {
|
||||
if (!metadata) {
|
||||
return;
|
||||
}
|
||||
if (!isEnabled) {
|
||||
return;
|
||||
}
|
||||
MetadataUtils.recallByHandler({ metadata, handler: MetadataHandlers.Seed, store });
|
||||
}, [metadata, isEnabled, store]);
|
||||
|
||||
return {
|
||||
recall,
|
||||
isEnabled,
|
||||
};
|
||||
};
|
||||
@@ -85,7 +85,12 @@ export const createNewCanvasEntityFromImage = async (arg: {
|
||||
}) => {
|
||||
const { type, imageDTO, dispatch, getState, withResize, overrides: _overrides } = arg;
|
||||
const state = getState();
|
||||
const { x, y, width, height } = selectBboxRect(state);
|
||||
const { x, y } = selectBboxRect(state);
|
||||
|
||||
const base = selectBboxModelBase(state);
|
||||
const ratio = imageDTO.width / imageDTO.height;
|
||||
const optimalDimension = getOptimalDimension(base);
|
||||
const { width, height } = calculateNewSize(ratio, optimalDimension ** 2, base);
|
||||
|
||||
let imageObject: CanvasImageState;
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { bboxHeightChanged, bboxWidthChanged, canvasMetadataRecalled } from 'features/controlLayers/store/canvasSlice';
|
||||
import { loraAllDeleted, loraRecalled } from 'features/controlLayers/store/lorasSlice';
|
||||
import {
|
||||
heightChanged,
|
||||
negativePrompt2Changed,
|
||||
negativePromptChanged,
|
||||
positivePrompt2Changed,
|
||||
@@ -31,6 +32,7 @@ import {
|
||||
setSteps,
|
||||
shouldConcatPromptsChanged,
|
||||
vaeSelected,
|
||||
widthChanged,
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
import { refImagesRecalled } from 'features/controlLayers/store/refImagesSlice';
|
||||
import type { CanvasMetadata, LoRA, RefImageState } from 'features/controlLayers/store/types';
|
||||
@@ -82,8 +84,9 @@ import {
|
||||
zParameterStrength,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { t } from 'i18next';
|
||||
import type { ComponentType, ReactNode } from 'react';
|
||||
import type { ComponentType } from 'react';
|
||||
import { useCallback, useEffect, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
@@ -170,7 +173,8 @@ export type SingleMetadataHandler<T> = {
|
||||
type: string;
|
||||
parse: (metadata: unknown, store: AppStore) => Promise<T>;
|
||||
recall: (value: T, store: AppStore) => void;
|
||||
LabelComponent: ComponentType;
|
||||
i18nKey: string;
|
||||
LabelComponent: ComponentType<{ i18nKey: string }>;
|
||||
ValueComponent: ComponentType<SingleMetadataValueProps<T>>;
|
||||
};
|
||||
|
||||
@@ -184,7 +188,8 @@ export type CollectionMetadataHandler<T extends any[]> = {
|
||||
parse: (metadata: unknown, store: AppStore) => Promise<T>;
|
||||
recall: (values: T, store: AppStore) => void;
|
||||
recallOne: (value: T[number], store: AppStore) => void;
|
||||
LabelComponent: ComponentType;
|
||||
i18nKey: string;
|
||||
LabelComponent: ComponentType<{ i18nKey: string }>;
|
||||
ValueComponent: ComponentType<CollectionMetadataValueProps<T>>;
|
||||
};
|
||||
|
||||
@@ -196,7 +201,8 @@ export type UnrecallableMetadataHandler<T> = {
|
||||
[UnrecallableMetadataKey]: true;
|
||||
type: string;
|
||||
parse: (metadata: unknown, store: AppStore) => Promise<T>;
|
||||
LabelComponent: ComponentType;
|
||||
i18nKey: string;
|
||||
LabelComponent: ComponentType<{ i18nKey: string }>;
|
||||
ValueComponent: ComponentType<UnrecallableMetadataValueProps<T>>;
|
||||
};
|
||||
|
||||
@@ -221,7 +227,8 @@ const CreatedBy: UnrecallableMetadataHandler<string> = {
|
||||
const parsed = z.string().parse(raw);
|
||||
return Promise.resolve(parsed);
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="metadata.createdBy" />,
|
||||
i18nKey: 'metadata.createdBy',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: UnrecallableMetadataValueProps<string>) => <MetadataPrimitiveValue value={value} />,
|
||||
};
|
||||
//#endregion Created By
|
||||
@@ -235,7 +242,8 @@ const GenerationMode: UnrecallableMetadataHandler<string> = {
|
||||
const parsed = z.string().parse(raw);
|
||||
return Promise.resolve(parsed);
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="metadata.generationMode" />,
|
||||
i18nKey: 'metadata.generationMode',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: UnrecallableMetadataValueProps<string>) => <MetadataPrimitiveValue value={value} />,
|
||||
};
|
||||
//#endregion Generation Mode
|
||||
@@ -252,7 +260,8 @@ const PositivePrompt: SingleMetadataHandler<ParameterPositivePrompt> = {
|
||||
recall: (value, store) => {
|
||||
store.dispatch(positivePromptChanged(value));
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="metadata.positivePrompt" />,
|
||||
i18nKey: 'metadata.positivePrompt',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterPositivePrompt>) => (
|
||||
<MetadataPrimitiveValue value={value} />
|
||||
),
|
||||
@@ -271,7 +280,8 @@ const NegativePrompt: SingleMetadataHandler<ParameterNegativePrompt> = {
|
||||
recall: (value, store) => {
|
||||
store.dispatch(negativePromptChanged(value || null));
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="metadata.negativePrompt" />,
|
||||
i18nKey: 'metadata.negativePrompt',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterNegativePrompt>) => (
|
||||
<MetadataPrimitiveValue value={value} />
|
||||
),
|
||||
@@ -290,7 +300,8 @@ const PositiveStylePrompt: SingleMetadataHandler<ParameterPositiveStylePromptSDX
|
||||
recall: (value, store) => {
|
||||
store.dispatch(positivePrompt2Changed(value));
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="sdxl.posStylePrompt" />,
|
||||
i18nKey: 'sdxl.posStylePrompt',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterPositiveStylePromptSDXL>) => (
|
||||
<MetadataPrimitiveValue value={value} />
|
||||
),
|
||||
@@ -309,7 +320,8 @@ const NegativeStylePrompt: SingleMetadataHandler<ParameterPositiveStylePromptSDX
|
||||
recall: (value, store) => {
|
||||
store.dispatch(negativePrompt2Changed(value));
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="sdxl.negStylePrompt" />,
|
||||
i18nKey: 'sdxl.negStylePrompt',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterPositiveStylePromptSDXL>) => (
|
||||
<MetadataPrimitiveValue value={value} />
|
||||
),
|
||||
@@ -328,7 +340,8 @@ const CFGScale: SingleMetadataHandler<ParameterCFGScale> = {
|
||||
recall: (value, store) => {
|
||||
store.dispatch(setCfgScale(value));
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="metadata.cfgScale" />,
|
||||
i18nKey: 'metadata.cfgScale',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterCFGScale>) => <MetadataPrimitiveValue value={value} />,
|
||||
};
|
||||
//#endregion CFG Scale
|
||||
@@ -345,7 +358,8 @@ const CFGRescaleMultiplier: SingleMetadataHandler<ParameterCFGRescaleMultiplier>
|
||||
recall: (value, store) => {
|
||||
store.dispatch(setCfgRescaleMultiplier(value));
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="metadata.cfgRescaleMultiplier" />,
|
||||
i18nKey: 'metadata.cfgRescaleMultiplier',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterCFGRescaleMultiplier>) => (
|
||||
<MetadataPrimitiveValue value={value} />
|
||||
),
|
||||
@@ -364,7 +378,8 @@ const Guidance: SingleMetadataHandler<ParameterGuidance> = {
|
||||
recall: (value, store) => {
|
||||
store.dispatch(setGuidance(value));
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="metadata.guidance" />,
|
||||
i18nKey: 'metadata.guidance',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterGuidance>) => <MetadataPrimitiveValue value={value} />,
|
||||
};
|
||||
//#endregion Guidance
|
||||
@@ -381,7 +396,8 @@ const Scheduler: SingleMetadataHandler<ParameterScheduler> = {
|
||||
recall: (value, store) => {
|
||||
store.dispatch(setScheduler(value));
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="metadata.scheduler" />,
|
||||
i18nKey: 'metadata.scheduler',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterScheduler>) => <MetadataPrimitiveValue value={value} />,
|
||||
};
|
||||
//#endregion Scheduler
|
||||
@@ -396,9 +412,15 @@ const Width: SingleMetadataHandler<ParameterWidth> = {
|
||||
return Promise.resolve(parsed);
|
||||
},
|
||||
recall: (value, store) => {
|
||||
store.dispatch(bboxWidthChanged({ width: value, updateAspectRatio: true, clamp: true }));
|
||||
const activeTab = selectActiveTab(store.getState());
|
||||
if (activeTab === 'canvas') {
|
||||
store.dispatch(bboxWidthChanged({ width: value, updateAspectRatio: true, clamp: true }));
|
||||
} else if (activeTab === 'generate') {
|
||||
store.dispatch(widthChanged({ width: value, updateAspectRatio: true, clamp: true }));
|
||||
}
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="metadata.width" />,
|
||||
i18nKey: 'metadata.width',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterWidth>) => <MetadataPrimitiveValue value={value} />,
|
||||
};
|
||||
//#endregion Width
|
||||
@@ -413,9 +435,15 @@ const Height: SingleMetadataHandler<ParameterHeight> = {
|
||||
return Promise.resolve(parsed);
|
||||
},
|
||||
recall: (value, store) => {
|
||||
store.dispatch(bboxHeightChanged({ height: value, updateAspectRatio: true, clamp: true }));
|
||||
const activeTab = selectActiveTab(store.getState());
|
||||
if (activeTab === 'canvas') {
|
||||
store.dispatch(bboxHeightChanged({ height: value, updateAspectRatio: true, clamp: true }));
|
||||
} else if (activeTab === 'generate') {
|
||||
store.dispatch(heightChanged({ height: value, updateAspectRatio: true, clamp: true }));
|
||||
}
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="metadata.height" />,
|
||||
i18nKey: 'metadata.height',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterHeight>) => <MetadataPrimitiveValue value={value} />,
|
||||
};
|
||||
//#endregion Height
|
||||
@@ -432,7 +460,8 @@ const Seed: SingleMetadataHandler<ParameterSeed> = {
|
||||
recall: (value, store) => {
|
||||
store.dispatch(setSeed(value));
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="metadata.seed" />,
|
||||
i18nKey: 'metadata.seed',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterSeed>) => <MetadataPrimitiveValue value={value} />,
|
||||
};
|
||||
//#endregion Seed
|
||||
@@ -449,7 +478,8 @@ const Steps: SingleMetadataHandler<ParameterSteps> = {
|
||||
recall: (value, store) => {
|
||||
store.dispatch(setSteps(value));
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="metadata.steps" />,
|
||||
i18nKey: 'metadata.steps',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterSteps>) => <MetadataPrimitiveValue value={value} />,
|
||||
};
|
||||
//#endregion Steps
|
||||
@@ -466,7 +496,8 @@ const DenoisingStrength: SingleMetadataHandler<ParameterStrength> = {
|
||||
recall: (value, store) => {
|
||||
store.dispatch(setImg2imgStrength(value));
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="metadata.strength" />,
|
||||
i18nKey: 'metadata.strength',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterStrength>) => <MetadataPrimitiveValue value={value} />,
|
||||
};
|
||||
//#endregion DenoisingStrength
|
||||
@@ -483,7 +514,8 @@ const SeamlessX: SingleMetadataHandler<ParameterSeamlessX> = {
|
||||
recall: (value, store) => {
|
||||
store.dispatch(setSeamlessXAxis(value));
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="metadata.seamlessXAxis" />,
|
||||
i18nKey: 'metadata.seamlessXAxis',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterSeamlessX>) => <MetadataPrimitiveValue value={value} />,
|
||||
};
|
||||
//#endregion SeamlessX
|
||||
@@ -500,7 +532,8 @@ const SeamlessY: SingleMetadataHandler<ParameterSeamlessY> = {
|
||||
recall: (value, store) => {
|
||||
store.dispatch(setSeamlessYAxis(value));
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="metadata.seamlessYAxis" />,
|
||||
i18nKey: 'metadata.seamlessYAxis',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterSeamlessY>) => <MetadataPrimitiveValue value={value} />,
|
||||
};
|
||||
//#endregion SeamlessY
|
||||
@@ -520,7 +553,8 @@ const RefinerModel: SingleMetadataHandler<ParameterSDXLRefinerModel> = {
|
||||
recall: (value, store) => {
|
||||
store.dispatch(refinerModelChanged(value));
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="sdxl.refinermodel" />,
|
||||
i18nKey: 'sdxl.refinermodel',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterSDXLRefinerModel>) => (
|
||||
<MetadataPrimitiveValue value={`${value.name} (${value.base.toUpperCase()})`} />
|
||||
),
|
||||
@@ -539,7 +573,8 @@ const RefinerSteps: SingleMetadataHandler<ParameterSteps> = {
|
||||
recall: (value, store) => {
|
||||
store.dispatch(setRefinerSteps(value));
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="sdxl.refinerSteps" />,
|
||||
i18nKey: 'sdxl.refinerSteps',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterSteps>) => <MetadataPrimitiveValue value={value} />,
|
||||
};
|
||||
//#endregion RefinerSteps
|
||||
@@ -556,7 +591,8 @@ const RefinerCFGScale: SingleMetadataHandler<ParameterCFGScale> = {
|
||||
recall: (value, store) => {
|
||||
store.dispatch(setRefinerCFGScale(value));
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="sdxl.cfgScale" />,
|
||||
i18nKey: 'sdxl.cfgScale',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterCFGScale>) => <MetadataPrimitiveValue value={value} />,
|
||||
};
|
||||
//#endregion RefinerCFGScale
|
||||
@@ -573,7 +609,8 @@ const RefinerScheduler: SingleMetadataHandler<ParameterScheduler> = {
|
||||
recall: (value, store) => {
|
||||
store.dispatch(setRefinerScheduler(value));
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="sdxl.scheduler" />,
|
||||
i18nKey: 'sdxl.scheduler',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterScheduler>) => <MetadataPrimitiveValue value={value} />,
|
||||
};
|
||||
//#endregion RefinerScheduler
|
||||
@@ -590,7 +627,8 @@ const RefinerPositiveAestheticScore: SingleMetadataHandler<ParameterSDXLRefinerP
|
||||
recall: (value, store) => {
|
||||
store.dispatch(setRefinerPositiveAestheticScore(value));
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="sdxl.posAestheticScore" />,
|
||||
i18nKey: 'sdxl.posAestheticScore',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterSDXLRefinerPositiveAestheticScore>) => (
|
||||
<MetadataPrimitiveValue value={value} />
|
||||
),
|
||||
@@ -609,7 +647,8 @@ const RefinerNegativeAestheticScore: SingleMetadataHandler<ParameterSDXLRefinerN
|
||||
recall: (value, store) => {
|
||||
store.dispatch(setRefinerNegativeAestheticScore(value));
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="sdxl.negAestheticScore" />,
|
||||
i18nKey: 'sdxl.negAestheticScore',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterSDXLRefinerNegativeAestheticScore>) => (
|
||||
<MetadataPrimitiveValue value={value} />
|
||||
),
|
||||
@@ -628,7 +667,8 @@ const RefinerDenoisingStart: SingleMetadataHandler<ParameterSDXLRefinerStart> =
|
||||
recall: (value, store) => {
|
||||
store.dispatch(setRefinerStart(value));
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="sdxl.refinerStart" />,
|
||||
i18nKey: 'sdxl.refinerStart',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterSDXLRefinerStart>) => (
|
||||
<MetadataPrimitiveValue value={value} />
|
||||
),
|
||||
@@ -648,7 +688,8 @@ const MainModel: SingleMetadataHandler<ParameterModel> = {
|
||||
recall: (value, store) => {
|
||||
store.dispatch(modelSelected(value));
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="metadata.model" />,
|
||||
i18nKey: 'metadata.model',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterModel>) => (
|
||||
<MetadataPrimitiveValue value={`${value.name} (${value.base.toUpperCase()})`} />
|
||||
),
|
||||
@@ -669,7 +710,8 @@ const VAEModel: SingleMetadataHandler<ParameterVAEModel> = {
|
||||
recall: (value, store) => {
|
||||
store.dispatch(vaeSelected(value));
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="metadata.vae" />,
|
||||
i18nKey: 'metadata.vae',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: SingleMetadataValueProps<ParameterVAEModel>) => (
|
||||
<MetadataPrimitiveValue value={`${value.name} (${value.base.toUpperCase()})`} />
|
||||
),
|
||||
@@ -733,7 +775,8 @@ const LoRAs: CollectionMetadataHandler<LoRA[]> = {
|
||||
store.dispatch(loraRecalled({ lora }));
|
||||
}
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="models.lora" />,
|
||||
i18nKey: 'models.lora',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: CollectionMetadataValueProps<LoRA[]>) => (
|
||||
<MetadataPrimitiveValue value={`${value.model.name} (${value.model.base.toUpperCase()}) - ${value.weight}`} />
|
||||
),
|
||||
@@ -763,7 +806,8 @@ const CanvasLayers: SingleMetadataHandler<CanvasMetadata> = {
|
||||
}
|
||||
store.dispatch(canvasMetadataRecalled(value));
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="metadata.canvasV2Metadata" />,
|
||||
i18nKey: 'metadata.canvasV2Metadata',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: SingleMetadataValueProps<CanvasMetadata>) => {
|
||||
const { t } = useTranslation();
|
||||
const count =
|
||||
@@ -810,7 +854,8 @@ const RefImages: CollectionMetadataHandler<RefImageState[]> = {
|
||||
const entities = [{ ...data, id: getPrefixedId('reference_image') }];
|
||||
store.dispatch(refImagesRecalled({ entities, replace: false }));
|
||||
},
|
||||
LabelComponent: () => <MetadataLabel i18nKey="controlLayers.referenceImage" />,
|
||||
i18nKey: 'controlLayers.referenceImage',
|
||||
LabelComponent: MetadataLabel,
|
||||
ValueComponent: ({ value }: CollectionMetadataValueProps<RefImageState[]>) => {
|
||||
if (value.config.model) {
|
||||
return <MetadataPrimitiveValue value={value.config.model.name} />;
|
||||
@@ -862,7 +907,7 @@ export const MetadataHandlers = {
|
||||
// ipAdapterToIPAdapterLayer: parseIPAdapterToIPAdapterLayer,
|
||||
} as const;
|
||||
|
||||
const successToast = (parameter: ReactNode) => {
|
||||
const successToast = (parameter: string) => {
|
||||
toast({
|
||||
id: 'PARAMETER_SET',
|
||||
title: t('toast.parameterSet'),
|
||||
@@ -871,7 +916,7 @@ const successToast = (parameter: ReactNode) => {
|
||||
});
|
||||
};
|
||||
|
||||
const failedToast = (parameter: ReactNode, message?: ReactNode) => {
|
||||
const failedToast = (parameter: string, message?: string) => {
|
||||
toast({
|
||||
id: 'PARAMETER_NOT_SET',
|
||||
title: t('toast.parameterNotSet'),
|
||||
@@ -902,9 +947,9 @@ const recallByHandler = async (arg: {
|
||||
|
||||
if (!silent) {
|
||||
if (didRecall) {
|
||||
successToast(<handler.LabelComponent />);
|
||||
successToast(t(handler.i18nKey));
|
||||
} else {
|
||||
failedToast(<handler.LabelComponent />);
|
||||
failedToast(t(handler.i18nKey));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -950,21 +995,24 @@ const recallByHandlers = async (arg: {
|
||||
}
|
||||
}
|
||||
|
||||
// If we recalled style prompts, and they were _different_ from the positive prompt, we need to disable prompt concat.
|
||||
// We may need to update the prompt concat flag based on the recalled prompts
|
||||
const positivePrompt = recalled.get(MetadataHandlers.PositivePrompt);
|
||||
const negativePrompt = recalled.get(MetadataHandlers.NegativePrompt);
|
||||
const positiveStylePrompt = recalled.get(MetadataHandlers.PositiveStylePrompt);
|
||||
const negativeStylePrompt = recalled.get(MetadataHandlers.NegativeStylePrompt);
|
||||
|
||||
// The values will be undefined if the handler was not recalled
|
||||
if (
|
||||
(positiveStylePrompt && positiveStylePrompt !== positivePrompt) ||
|
||||
(negativeStylePrompt && negativeStylePrompt !== negativePrompt)
|
||||
positivePrompt !== undefined ||
|
||||
negativePrompt !== undefined ||
|
||||
positiveStylePrompt !== undefined ||
|
||||
negativeStylePrompt !== undefined
|
||||
) {
|
||||
// If we set the negative style prompt or positive style prompt, we should disable prompt concat
|
||||
store.dispatch(shouldConcatPromptsChanged(false));
|
||||
} else {
|
||||
// Otherwise, we should enable prompt concat
|
||||
store.dispatch(shouldConcatPromptsChanged(true));
|
||||
const concat =
|
||||
(Boolean(positiveStylePrompt) && positiveStylePrompt === positivePrompt) ||
|
||||
(Boolean(negativeStylePrompt) && negativeStylePrompt === negativePrompt);
|
||||
|
||||
store.dispatch(shouldConcatPromptsChanged(concat));
|
||||
}
|
||||
|
||||
if (!silent) {
|
||||
@@ -1003,6 +1051,28 @@ const recallPrompts = async (metadata: unknown, store: AppStore) => {
|
||||
}
|
||||
};
|
||||
|
||||
const hasMetadataByHandlers = async (arg: {
|
||||
metadata: unknown;
|
||||
handlers: (SingleMetadataHandler<any> | CollectionMetadataHandler<any[]>)[];
|
||||
store: AppStore;
|
||||
require: 'some' | 'all';
|
||||
}) => {
|
||||
const { metadata, handlers, store, require } = arg;
|
||||
for (const handler of handlers) {
|
||||
try {
|
||||
await handler.parse(metadata, store);
|
||||
if (require === 'some') {
|
||||
return true;
|
||||
}
|
||||
} catch {
|
||||
if (require === 'all') {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
const recallDimensions = async (metadata: unknown, store: AppStore) => {
|
||||
const recalled = await recallByHandlers({
|
||||
metadata,
|
||||
@@ -1032,6 +1102,7 @@ const recallAll = async (
|
||||
};
|
||||
|
||||
export const MetadataUtils = {
|
||||
hasMetadataByHandlers,
|
||||
recallByHandler,
|
||||
recallByHandlers,
|
||||
recallAll,
|
||||
|
||||
@@ -32,7 +32,7 @@ const CurrentImageNode = (props: NodeProps) => {
|
||||
if (imageDTO) {
|
||||
return (
|
||||
<Wrapper nodeProps={props}>
|
||||
<DndImage imageDTO={imageDTO} />
|
||||
<DndImage imageDTO={imageDTO} borderRadius="base" />
|
||||
</Wrapper>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -146,6 +146,7 @@ const ImageGridItemContent = memo(
|
||||
return (
|
||||
<>
|
||||
<DndImage
|
||||
borderRadius="base"
|
||||
imageDTO={query.data}
|
||||
asThumbnail
|
||||
objectFit="contain"
|
||||
|
||||
@@ -76,7 +76,7 @@ const ImageFieldInputComponent = (props: FieldComponentProps<ImageFieldInputInst
|
||||
)}
|
||||
{imageDTO && (
|
||||
<>
|
||||
<Flex borderRadius="base" borderWidth={1} borderStyle="solid">
|
||||
<Flex borderRadius="base" borderWidth={1} borderStyle="solid" overflow="hidden">
|
||||
<DndImage imageDTO={imageDTO} asThumbnail />
|
||||
</Flex>
|
||||
<Text
|
||||
|
||||
@@ -14,7 +14,7 @@ const ImageOutputPreview = ({ output }: Props) => {
|
||||
return null;
|
||||
}
|
||||
|
||||
return <DndImage imageDTO={imageDTO} />;
|
||||
return <DndImage imageDTO={imageDTO} borderRadius="base" />;
|
||||
};
|
||||
|
||||
export default memo(ImageOutputPreview);
|
||||
|
||||
@@ -137,6 +137,8 @@ const NODE_TYPE_PUBLISH_DENYLIST = [
|
||||
'chatgpt_4o_edit_image',
|
||||
'flux_kontext_generate_image',
|
||||
'flux_kontext_edit_image',
|
||||
'claude_expand_prompt',
|
||||
'claude_analyze_image',
|
||||
];
|
||||
|
||||
export const selectHasUnpublishableNodes = createSelector(selectNodes, (nodes) => {
|
||||
|
||||
@@ -30,11 +30,14 @@ export const addFLUXFill = async ({
|
||||
denoise.denoising_start = denoising_start;
|
||||
denoise.denoising_end = denoising_end;
|
||||
|
||||
const { originalSize, scaledSize, rect } = getOriginalAndScaledSizesForOtherModes(state);
|
||||
|
||||
denoise.width = scaledSize.width;
|
||||
denoise.height = scaledSize.height;
|
||||
|
||||
const params = selectParamsSlice(state);
|
||||
const canvasSettings = selectCanvasSettingsSlice(state);
|
||||
|
||||
const { originalSize, scaledSize, rect } = getOriginalAndScaledSizesForOtherModes(state);
|
||||
|
||||
const rasterAdapters = manager.compositor.getVisibleAdaptersOfType('raster_layer');
|
||||
const initialImage = await manager.compositor.getCompositeImageDTO(rasterAdapters, rect, {
|
||||
is_intermediate: true,
|
||||
|
||||
@@ -78,8 +78,6 @@ export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
|
||||
if (generationMode !== 'txt2img') {
|
||||
throw new UnsupportedGenerationModeError(t('toast.fluxKontextIncompatibleGenerationMode'));
|
||||
}
|
||||
|
||||
guidance = 30;
|
||||
}
|
||||
|
||||
const g = new Graph(getPrefixedId('flux_graph'));
|
||||
|
||||
@@ -26,7 +26,7 @@ export const buildFluxKontextGraph = (arg: GraphBuilderArg): GraphBuilderReturn
|
||||
assert(model.base === 'flux-kontext', 'Selected model is not a FLUX Kontext API model');
|
||||
|
||||
if (generationMode !== 'txt2img') {
|
||||
throw new UnsupportedGenerationModeError(t('toast.imagenIncompatibleGenerationMode', { model: 'FLUX Kontext' }));
|
||||
throw new UnsupportedGenerationModeError(t('toast.fluxKontextIncompatibleGenerationMode'));
|
||||
}
|
||||
|
||||
log.debug({ generationMode, manager: manager?.id }, 'Building FLUX Kontext graph');
|
||||
|
||||
@@ -78,7 +78,7 @@ export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: getPrefixedId('neg_cond'),
|
||||
prompt: prompts.negative,
|
||||
style: prompts.negativeStyle,
|
||||
style: prompts.useMainPromptsForStyle ? prompts.negative : prompts.negativeStyle,
|
||||
});
|
||||
const negCondCollect = g.addNode({
|
||||
type: 'collect',
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { selectSaveAllImagesToGallery } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import {
|
||||
selectImg2imgStrength,
|
||||
selectMainModelConfig,
|
||||
@@ -44,8 +45,11 @@ export const selectCanvasOutputFields = (state: RootState) => {
|
||||
// Advanced session means working on canvas - images are not saved to gallery or added to a board.
|
||||
// Simple session means working in YOLO mode - images are saved to gallery & board.
|
||||
const tab = selectActiveTab(state);
|
||||
const is_intermediate = tab === 'canvas';
|
||||
const board = tab === 'canvas' ? undefined : getBoardField(state);
|
||||
const saveAllImagesToGallery = selectSaveAllImagesToGallery(state);
|
||||
|
||||
// If we're on canvas and the save all images setting is enabled, save to gallery
|
||||
const is_intermediate = tab === 'canvas' && !saveAllImagesToGallery;
|
||||
const board = tab === 'canvas' && !saveAllImagesToGallery ? undefined : getBoardField(state);
|
||||
|
||||
return {
|
||||
is_intermediate,
|
||||
|
||||
@@ -18,7 +18,7 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { $onClickGoToModelManager } from 'app/store/nanostores/onClickGoToModelManager';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { Group, PickerContextState } from 'common/components/Picker/Picker';
|
||||
import { buildGroup, getRegex, isOption, Picker, usePickerContext } from 'common/components/Picker/Picker';
|
||||
import { buildGroup, getRegex, isGroup, Picker, usePickerContext } from 'common/components/Picker/Picker';
|
||||
import { useDisclosure } from 'common/hooks/useBoolean';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { uniq } from 'es-toolkit/compat';
|
||||
@@ -277,8 +277,22 @@ export const ModelPicker = typedMemo(
|
||||
if (!selectedModelConfig) {
|
||||
return undefined;
|
||||
}
|
||||
let _selectedOption: WithStarred<T> | undefined = undefined;
|
||||
|
||||
return options.filter(isOption).find((o) => o.key === selectedModelConfig.key);
|
||||
for (const optionOrGroup of options) {
|
||||
if (isGroup(optionOrGroup)) {
|
||||
const result = optionOrGroup.options.find((o) => o.key === selectedModelConfig.key);
|
||||
if (result) {
|
||||
_selectedOption = result;
|
||||
break;
|
||||
}
|
||||
} else if (optionOrGroup.key === selectedModelConfig.key) {
|
||||
_selectedOption = optionOrGroup;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return _selectedOption;
|
||||
}, [options, selectedModelConfig]);
|
||||
|
||||
const onClose = useCallback(() => {
|
||||
@@ -361,9 +375,19 @@ const optionSx: SystemStyleObject = {
|
||||
cursor: 'pointer',
|
||||
borderRadius: 'base',
|
||||
'&[data-selected="true"]': {
|
||||
bg: 'base.700',
|
||||
bg: 'invokeBlue.300',
|
||||
color: 'base.900',
|
||||
'.extra-info': {
|
||||
color: 'base.700',
|
||||
},
|
||||
'.picker-option': {
|
||||
fontWeight: 'bold',
|
||||
'&[data-is-compact="true"]': {
|
||||
fontWeight: 'semibold',
|
||||
},
|
||||
},
|
||||
'&[data-active="true"]': {
|
||||
bg: 'base.650',
|
||||
bg: 'invokeBlue.250',
|
||||
},
|
||||
},
|
||||
'&[data-active="true"]': {
|
||||
@@ -400,17 +424,31 @@ const PickerOptionComponent = typedMemo(
|
||||
<Flex flexDir="column" gap={1} flex={1}>
|
||||
<Flex gap={2} alignItems="center">
|
||||
{option.starred && <Icon as={PiLinkSimple} color="invokeYellow.500" boxSize={4} />}
|
||||
<Text sx={optionNameSx} data-is-compact={compactView}>
|
||||
<Text className="picker-option" sx={optionNameSx} data-is-compact={compactView}>
|
||||
{option.name}
|
||||
</Text>
|
||||
<Spacer />
|
||||
{option.file_size > 0 && (
|
||||
<Text variant="subtext" fontStyle="italic" noOfLines={1} flexShrink={0} overflow="visible">
|
||||
<Text
|
||||
className="extra-info"
|
||||
variant="subtext"
|
||||
fontStyle="italic"
|
||||
noOfLines={1}
|
||||
flexShrink={0}
|
||||
overflow="visible"
|
||||
>
|
||||
{filesize(option.file_size)}
|
||||
</Text>
|
||||
)}
|
||||
{option.usage_info && (
|
||||
<Text variant="subtext" fontStyle="italic" noOfLines={1} flexShrink={0} overflow="visible">
|
||||
<Text
|
||||
className="extra-info"
|
||||
variant="subtext"
|
||||
fontStyle="italic"
|
||||
noOfLines={1}
|
||||
flexShrink={0}
|
||||
overflow="visible"
|
||||
>
|
||||
{option.usage_info}
|
||||
</Text>
|
||||
)}
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
createParamsSelector,
|
||||
selectHasNegativePrompt,
|
||||
selectModelSupportsNegativePrompt,
|
||||
} from 'features/controlLayers/store/paramsSlice';
|
||||
import { ParamNegativePrompt } from 'features/parameters/components/Core/ParamNegativePrompt';
|
||||
import { ParamPositivePrompt } from 'features/parameters/components/Core/ParamPositivePrompt';
|
||||
import { ParamSDXLNegativeStylePrompt } from 'features/sdxl/components/SDXLPrompts/ParamSDXLNegativeStylePrompt';
|
||||
import { ParamSDXLPositiveStylePrompt } from 'features/sdxl/components/SDXLPrompts/ParamSDXLPositiveStylePrompt';
|
||||
import { memo } from 'react';
|
||||
|
||||
const selectWithStylePrompts = createParamsSelector((params) => {
|
||||
const isSDXL = params.model?.base === 'sdxl';
|
||||
const shouldConcatPrompts = params.shouldConcatPrompts;
|
||||
return isSDXL && !shouldConcatPrompts;
|
||||
});
|
||||
|
||||
export const UpscalePrompts = memo(() => {
|
||||
const withStylePrompts = useAppSelector(selectWithStylePrompts);
|
||||
const modelSupportsNegativePrompt = useAppSelector(selectModelSupportsNegativePrompt);
|
||||
const hasNegativePrompt = useAppSelector(selectHasNegativePrompt);
|
||||
return (
|
||||
<Flex flexDir="column" gap={2}>
|
||||
<ParamPositivePrompt />
|
||||
{withStylePrompts && <ParamSDXLPositiveStylePrompt />}
|
||||
{modelSupportsNegativePrompt && hasNegativePrompt && <ParamNegativePrompt />}
|
||||
{withStylePrompts && <ParamSDXLNegativeStylePrompt />}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
UpscalePrompts.displayName = 'UpscalePrompts';
|
||||
@@ -0,0 +1,28 @@
|
||||
import type { ButtonProps } from '@invoke-ai/ui-library';
|
||||
import { Button } from '@invoke-ai/ui-library';
|
||||
import { useCancelAllExceptCurrentQueueItemDialog } from 'features/queue/components/CancelAllExceptCurrentQueueItemConfirmationAlertDialog';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiXCircle } from 'react-icons/pi';
|
||||
|
||||
export const CancelAllExceptCurrentButton = memo((props: ButtonProps) => {
|
||||
const { t } = useTranslation();
|
||||
const api = useCancelAllExceptCurrentQueueItemDialog();
|
||||
|
||||
return (
|
||||
<Button
|
||||
isDisabled={api.isDisabled}
|
||||
isLoading={api.isLoading}
|
||||
aria-label={t('queue.clear')}
|
||||
tooltip={t('queue.cancelAllExceptCurrentTooltip')}
|
||||
leftIcon={<PiXCircle />}
|
||||
colorScheme="error"
|
||||
onClick={api.openDialog}
|
||||
{...props}
|
||||
>
|
||||
{t('queue.clear')}
|
||||
</Button>
|
||||
);
|
||||
});
|
||||
|
||||
CancelAllExceptCurrentButton.displayName = 'CancelAllExceptCurrentButton';
|
||||
@@ -0,0 +1,25 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useCancelAllExceptCurrentQueueItemDialog } from 'features/queue/components/CancelAllExceptCurrentQueueItemConfirmationAlertDialog';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiXCircle } from 'react-icons/pi';
|
||||
|
||||
export const CancelAllExceptCurrentIconButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const api = useCancelAllExceptCurrentQueueItemDialog();
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
size="lg"
|
||||
isDisabled={api.isDisabled}
|
||||
isLoading={api.isLoading}
|
||||
aria-label={t('queue.clear')}
|
||||
tooltip={t('queue.cancelAllExceptCurrentTooltip')}
|
||||
icon={<PiXCircle />}
|
||||
colorScheme="error"
|
||||
onClick={api.openDialog}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
CancelAllExceptCurrentIconButton.displayName = 'CancelAllExceptCurrentIconButton';
|
||||
@@ -7,7 +7,7 @@ import { useTranslation } from 'react-i18next';
|
||||
|
||||
const [useCancelAllExceptCurrentQueueItemConfirmationAlertDialog] = buildUseBoolean(false);
|
||||
|
||||
const useCancelAllExceptCurrentQueueItemDialog = () => {
|
||||
export const useCancelAllExceptCurrentQueueItemDialog = () => {
|
||||
const dialog = useCancelAllExceptCurrentQueueItemConfirmationAlertDialog();
|
||||
const cancelAllExceptCurrentQueueItem = useCancelAllExceptCurrentQueueItem();
|
||||
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useCancelCurrentQueueItem } from 'features/queue/hooks/useCancelCurrentQueueItem';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
|
||||
export const CancelCurrentQueueItemIconButton = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const cancelCurrentQueueItem = useCancelCurrentQueueItem();
|
||||
|
||||
const cancelCurrentQueueItemWithToast = useCallback(() => {
|
||||
cancelCurrentQueueItem.trigger({ withToast: true });
|
||||
}, [cancelCurrentQueueItem]);
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
size="lg"
|
||||
onClick={cancelCurrentQueueItemWithToast}
|
||||
isDisabled={cancelCurrentQueueItem.isDisabled}
|
||||
isLoading={cancelCurrentQueueItem.isLoading}
|
||||
aria-label={t('queue.cancel')}
|
||||
tooltip={t('queue.cancelTooltip')}
|
||||
icon={<PiXBold />}
|
||||
colorScheme="error"
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
CancelCurrentQueueItemIconButton.displayName = 'CancelCurrentQueueItemIconButton';
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user