Compare commits

...

44 Commits

Author SHA1 Message Date
psychedelicious
0cfd713b93 fix(ui): typo 2025-03-06 08:52:10 +11:00
psychedelicious
45f5d7617a chore: bump version to v5.7.0 2025-03-06 08:38:59 +11:00
psychedelicious
f49df7d327 chore(ui): update whats new 2025-03-06 08:38:59 +11:00
Linos
87ed0ed48a translationBot(ui): update translation (Vietnamese)
Currently translated at 100.0% (1802 of 1802 strings)

Co-authored-by: Linos <linos.coding@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/vi/
Translation: InvokeAI/Web UI
2025-03-06 08:00:35 +11:00
Riccardo Giovanetti
d445c88e4c translationBot(ui): update translation (Italian)
Currently translated at 98.8% (1782 of 1802 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.8% (1782 of 1802 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-03-06 08:00:35 +11:00
Riku
c15c43ed2a translationBot(ui): update translation (German)
Currently translated at 67.2% (1212 of 1802 strings)

Co-authored-by: Riku <riku.block@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/de/
Translation: InvokeAI/Web UI
2025-03-06 08:00:35 +11:00
psychedelicious
d2f8db9745 tidy: remove unused utils 2025-03-06 07:49:35 +11:00
psychedelicious
c1cf01a038 tests: use dangerously_run_function_in_subprocess to fix configure_torch_cuda_allocator tests 2025-03-06 07:49:35 +11:00
psychedelicious
2bfb4fc79c tests: add util to run a function in separate process
This allows our tests to run in an isolated environment. For tests taht implicitly depend on import behaviour, this can prevent side-effects.

The function should only be used for tests.
2025-03-06 07:49:35 +11:00
psychedelicious
d037d8f9aa tests: update tests for configure_torch_cuda_allocator 2025-03-06 07:49:35 +11:00
psychedelicious
d5401e8443 tests: add testing utils to set/unset env var 2025-03-06 07:49:35 +11:00
psychedelicious
d193e4f02a feat(app): log warning instead of raising if PYTORCH_CUDA_ALLOC_CONF is already set 2025-03-06 07:49:35 +11:00
psychedelicious
ec493e30ee feat(app): make logger a required arg in configure_torch_cuda_allocator 2025-03-06 07:49:35 +11:00
Jonathan
081b931edf Update util.py
Changed string to a literal
2025-03-05 14:39:17 +11:00
Jonathan
8cd7035494 Fixed validation of begin and end steps
Fixed logic to match the error message - begin should be <= end.
2025-03-05 14:39:17 +11:00
Eugene Brodsky
4de6fd3ae6 chore(docker): reduce size between docker builds (#7571)
by adding a layer with all the pytorch dependencies that don't change
most of the time.

## Summary

Every time the [`main` docker
images](https://github.com/invoke-ai/InvokeAI/pkgs/container/invokeai)
rebuild and I pull `main-cuda`, it gets another 3+ GB, which seems like
about a zillion times too much since most things don't change from one
commit on `main` to the next.

This is an attempt to follow the guidance in [Using uv in Docker:
Intermediate
Layers](https://docs.astral.sh/uv/guides/integration/docker/#intermediate-layers)
so there's one layer that installs all the dependencies—including
PyTorch with its bundled nvidia libraries—_before_ the project's own
frequently-changing files are copied in to the image.


## Related Issues / Discussions

- [Improved docker layer cache with
uv](https://discord.com/channels/1020123559063990373/1329975172022927370)
- [astral: Can `uv pip install` torch, but not `uv sync`
it](https://discord.com/channels/1039017663004942429/1329986610770612347)


## QA Instructions

Hopefully the CI system building the docker images is sufficient.

But there is one change to `pyproject.toml` related to xformers, so it'd
be worth checking that `python -m xformers.info` still says it has
triton on the platforms that expect it.


## Merge Plan

I don't expect this to be a disruptive merge.

(An earlier revision of this PR moved the venv, but I've reverted that
change at ebr's recommendation.)


## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-03-04 20:42:28 -05:00
Eugene Brodsky
3feb1a6600 Merge branch 'main' into build/docker-dependency-layer 2025-03-04 20:33:24 -05:00
psychedelicious
ea2320c57b feat(ui): add button ref image layer empty state to pull bbox 2025-03-05 08:00:20 +11:00
psychedelicious
0ad0016c2d chore: bump version to v5.7.2rc2 2025-03-04 08:48:28 +11:00
psychedelicious
c2a3c66e49 feat(app): avoid nested cursors in workflow_records service 2025-03-04 08:33:42 +11:00
psychedelicious
c0a0d20935 feat(app): avoid nested cursors in style_preset_records service 2025-03-04 08:33:42 +11:00
psychedelicious
028d8d8ead feat(app): avoid nested cursors in model_records service 2025-03-04 08:33:42 +11:00
psychedelicious
657095d2e2 feat(app): avoid nested cursors in image_records service 2025-03-04 08:33:42 +11:00
psychedelicious
1c47dc997e feat(app): avoid nested cursors in board_records service 2025-03-04 08:33:42 +11:00
psychedelicious
a3de6b6165 feat(app): avoid nested cursors in board_image_records service 2025-03-04 08:33:42 +11:00
psychedelicious
e57f0ff055 experiment(app): avoid nested cursors in session_queue service
SQLite cursors are meant to be lightweight and not reused. For whatever reason, we reuse one per service for the entire app lifecycle.

This can cause issues where a cursor is used twice at the same time in different transactions.

This experiment makes the session queue use a fresh cursor for each method, hopefully fixing the issue.
2025-03-04 08:33:42 +11:00
Eugene Brodsky
0362bd5a06 Merge branch 'main' into build/docker-dependency-layer 2025-03-03 09:32:04 -05:00
Linos
feee4c49a2 translationBot(ui): update translation (Vietnamese)
Currently translated at 100.0% (1798 of 1798 strings)

Co-authored-by: Linos <linos.coding@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/vi/
Translation: InvokeAI/Web UI
2025-03-03 14:50:08 +11:00
Riccardo Giovanetti
42e052d6f2 translationBot(ui): update translation (Italian)
Currently translated at 98.8% (1777 of 1798 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-03-03 14:50:08 +11:00
psychedelicious
b03e429b26 fix(ui): add missing builder translations 2025-03-03 14:43:23 +11:00
psychedelicious
7399909029 feat(app): use simpler syntax for enqueue_batch threaded execution 2025-03-03 14:40:48 +11:00
psychedelicious
c8aaf5e76b tidy(app): remove extraneous class attr type annotations 2025-03-03 14:40:48 +11:00
psychedelicious
0cdf7a7048 Revert "experiment(app): simulate very long enqueue operations (15s)"
This reverts commit eb6a323d0b70004732de493d6530e08eb5ca8acf.
2025-03-03 14:40:48 +11:00
psychedelicious
41985487d3 Revert "experiment(app): make socketio server ping every 1s"
This reverts commit ddf00bf260167092a3bc2afdce1244c6b116ebfb.
2025-03-03 14:40:48 +11:00
psychedelicious
41d5a17114 fix(ui): set RTKQ tag invalidationBehaviour to immediate
This allows tags to be invalidated while mutations are executing, resolving an issue in this situation:
- A long-running mutation starts.
- A tag is invalidated; for example, user edits a board name, and the boards list query tag is invalidated.
- The boards list query isn't fired, and the board name isn't updated.
- The long-running mutation finishes.
- Finally, the boards list query fires and the board name is updated.

This is the "delayed" behaviour. The "immediately" behaviour has the fires requests from tag invalidation immediately, without waiting for all mutations to finish.

It may cause extra network requests and stale data if we are mutating a lot of things very quickly. I don't think it will be an issue in practice and the improved responsiveness will be a net benefit.
2025-03-03 14:40:48 +11:00
psychedelicious
14f9d5b6bc experiment(app): remove db locking logic
Rely on WAL mode and the busy timeout.

Also changed:
- Remove extraneous rollbacks when we were only doing a `SELECT`
- Remove try/catch blocks that were made extraneous when removing the extraneous rollbacks
2025-03-03 14:40:48 +11:00
psychedelicious
eec4bdb038 experiment(app): enable WAL mode and set busy_timeout
This allows for read and write concurrency without using a global mutex. Operations may still fail they take longer than the busy timeout (5s).

If we get a database lock error after waiting 5s for an operation, we have a problem. So, I think it's actually better to use a busy timeout instead of a global mutex.

Alternatively, we could add a timeout to the global mutex.
2025-03-03 14:40:48 +11:00
psychedelicious
f3dd44044a experiment(app): run enqueue_batch async in a thread 2025-03-03 14:40:48 +11:00
psychedelicious
61a22eb8cb experiment(app): make socketio server ping every 1s 2025-03-03 14:40:48 +11:00
psychedelicious
03ca83fe13 experiment(app): simulate very long enqueue operations (15s) 2025-03-03 14:40:48 +11:00
Kevin Turner
80d38c0e47 chore(docker): include fewer files while installing dependencies
including just invokeai/version seems sufficient to appease uv sync here. including everything else would invalidate the cache we're trying to establish.
2025-02-16 12:31:14 -08:00
Kevin Turner
22362350dc chore(docker): revert to keeping venv in /opt/venv 2025-02-16 11:26:06 -08:00
Kevin Turner
275d891f48 Merge branch 'main' into build/docker-dependency-layer 2025-02-16 10:34:17 -08:00
Kevin Turner
3848e1926b chore(docker): reduce size between docker builds
by adding a layer with all the pytorch dependencies that don't change most of the time.
2025-01-18 09:10:54 -08:00
27 changed files with 1243 additions and 1150 deletions

View File

@@ -13,48 +13,63 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
git
# Install `uv` for package management
COPY --from=ghcr.io/astral-sh/uv:0.5.5 /uv /uvx /bin/
COPY --from=ghcr.io/astral-sh/uv:0.6.0 /uv /uvx /bin/
ENV VIRTUAL_ENV=/opt/venv
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
ENV INVOKEAI_SRC=/opt/invokeai
ENV PYTHON_VERSION=3.11
ENV UV_PYTHON=3.11
ENV UV_COMPILE_BYTECODE=1
ENV UV_LINK_MODE=copy
ENV UV_PROJECT_ENVIRONMENT="$VIRTUAL_ENV"
ENV UV_INDEX="https://download.pytorch.org/whl/cu124"
ARG GPU_DRIVER=cuda
ARG TARGETPLATFORM="linux/amd64"
# unused but available
ARG BUILDPLATFORM
# Switch to the `ubuntu` user to work around dependency issues with uv-installed python
RUN mkdir -p ${VIRTUAL_ENV} && \
mkdir -p ${INVOKEAI_SRC} && \
chmod -R a+w /opt
chmod -R a+w /opt && \
mkdir ~ubuntu/.cache && chown ubuntu: ~ubuntu/.cache
USER ubuntu
# Install python and create the venv
RUN uv python install ${PYTHON_VERSION} && \
uv venv --relocatable --prompt "invoke" --python ${PYTHON_VERSION} ${VIRTUAL_ENV}
# Install python
RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
uv python install ${PYTHON_VERSION}
WORKDIR ${INVOKEAI_SRC}
COPY invokeai ./invokeai
COPY pyproject.toml ./
# Editable mode helps use the same image for development:
# the local working copy can be bind-mounted into the image
# at path defined by ${INVOKEAI_SRC}
# Install project's dependencies as a separate layer so they aren't rebuilt every commit.
# bind-mount instead of copy to defer adding sources to the image until next layer.
#
# NOTE: there are no pytorch builds for arm64 + cuda, only cpu
# x86_64/CUDA is the default
RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
--mount=type=bind,source=invokeai/version,target=invokeai/version \
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cpu"; \
UV_INDEX="https://download.pytorch.org/whl/cpu"; \
elif [ "$GPU_DRIVER" = "rocm" ]; then \
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/rocm6.1"; \
else \
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cu124"; \
UV_INDEX="https://download.pytorch.org/whl/rocm6.1"; \
fi && \
uv pip install --python ${PYTHON_VERSION} $extra_index_url_arg -e "."
uv sync --no-install-project
# Now that the bulk of the dependencies have been installed, copy in the project files that change more frequently.
COPY invokeai invokeai
COPY pyproject.toml .
RUN --mount=type=cache,target=/home/ubuntu/.cache/uv,uid=1000,gid=1000 \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
UV_INDEX="https://download.pytorch.org/whl/cpu"; \
elif [ "$GPU_DRIVER" = "rocm" ]; then \
UV_INDEX="https://download.pytorch.org/whl/rocm6.1"; \
fi && \
uv sync
#### Build the Web UI ------------------------------------
@@ -98,6 +113,7 @@ RUN apt update && apt install -y --no-install-recommends \
ENV INVOKEAI_SRC=/opt/invokeai
ENV VIRTUAL_ENV=/opt/venv
ENV UV_PROJECT_ENVIRONMENT="$VIRTUAL_ENV"
ENV PYTHON_VERSION=3.11
ENV INVOKEAI_ROOT=/invokeai
ENV INVOKEAI_HOST=0.0.0.0
@@ -109,7 +125,7 @@ ENV CONTAINER_GID=${CONTAINER_GID:-1000}
# Install `uv` for package management
# and install python for the ubuntu user (expected to exist on ubuntu >=24.x)
# this is too tiny to optimize with multi-stage builds, but maybe we'll come back to it
COPY --from=ghcr.io/astral-sh/uv:0.5.5 /uv /uvx /bin/
COPY --from=ghcr.io/astral-sh/uv:0.6.0 /uv /uvx /bin/
USER ubuntu
RUN uv python install ${PYTHON_VERSION}
USER root

View File

@@ -48,7 +48,9 @@ async def enqueue_batch(
) -> EnqueueBatchResult:
"""Processes a batch and enqueues the output graphs for execution."""
return ApiDependencies.invoker.services.session_queue.enqueue_batch(queue_id=queue_id, batch=batch, prepend=prepend)
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
queue_id=queue_id, batch=batch, prepend=prepend
)
@session_queue_router.get(

View File

@@ -9,6 +9,6 @@ def validate_weights(weights: Union[float, list[float]]) -> None:
def validate_begin_end_step(begin_step_percent: float, end_step_percent: float) -> None:
"""Validate that begin_step_percent is less than end_step_percent"""
if begin_step_percent >= end_step_percent:
"""Validate that begin_step_percent is less than or equal to end_step_percent"""
if begin_step_percent > end_step_percent:
raise ValueError("Begin step percent must be less than or equal to end step percent")

View File

@@ -1,5 +1,4 @@
import sqlite3
import threading
from typing import Optional, cast
from invokeai.app.services.board_image_records.board_image_records_base import BoardImageRecordStorageBase
@@ -13,15 +12,9 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.RLock
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
def add_image_to_board(
self,
@@ -29,8 +22,8 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
image_name: str,
) -> None:
try:
self._lock.acquire()
self._cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT INTO board_images (board_id, image_name)
VALUES (?, ?)
@@ -42,16 +35,14 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def remove_image_from_board(
self,
image_name: str,
) -> None:
try:
self._lock.acquire()
self._cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE FROM board_images
WHERE image_name = ?;
@@ -62,8 +53,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def get_images_for_board(
self,
@@ -72,33 +61,27 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
limit: int = 10,
) -> OffsetPaginatedResults[ImageRecord]:
# TODO: this isn't paginated yet?
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT images.*
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE board_images.board_id = ?
ORDER BY board_images.updated_at DESC;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT images.*
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE board_images.board_id = ?
ORDER BY board_images.updated_at DESC;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
self._cursor.execute(
"""--sql
SELECT COUNT(*) FROM images WHERE 1=1;
"""
)
count = cast(int, self._cursor.fetchone()[0])
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images WHERE 1=1;
"""
)
count = cast(int, cursor.fetchone()[0])
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
def get_all_board_image_names_for_board(
@@ -107,98 +90,79 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
categories: list[ImageCategory] | None,
is_intermediate: bool | None,
) -> list[str]:
try:
self._lock.acquire()
params: list[str | bool] = []
params: list[str | bool] = []
# Base query is a join between images and board_images
stmt = """
# Base query is a join between images and board_images
stmt = """
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
AND board_images.board_id = ?
"""
params.append(board_id)
params.append(board_id)
# Add the category filter
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
stmt += f"""--sql
# Add the category filter
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
stmt += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
# Unpack the included categories into the query params
for c in category_strings:
params.append(c)
# Unpack the included categories into the query params
for c in category_strings:
params.append(c)
# Add the is_intermediate filter
if is_intermediate is not None:
stmt += """--sql
# Add the is_intermediate filter
if is_intermediate is not None:
stmt += """--sql
AND images.is_intermediate = ?
"""
params.append(is_intermediate)
params.append(is_intermediate)
# Put a ring on it
stmt += ";"
# Put a ring on it
stmt += ";"
# Execute the query
self._cursor.execute(stmt, params)
# Execute the query
cursor = self._conn.cursor()
cursor.execute(stmt, params)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
image_names = [r[0] for r in result]
return image_names
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [r[0] for r in result]
return image_names
def get_board_for_image(
self,
image_name: str,
) -> Optional[str]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT board_id
FROM board_images
WHERE image_name = ?;
""",
(image_name,),
)
result = self._cursor.fetchone()
if result is None:
return None
return cast(str, result[0])
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
(image_name,),
)
result = cursor.fetchone()
if result is None:
return None
return cast(str, result[0])
def get_image_count_for_board(self, board_id: str) -> int:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT COUNT(*)
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE images.is_intermediate = FALSE
AND board_images.board_id = ?;
""",
(board_id,),
)
count = cast(int, self._cursor.fetchone()[0])
return count
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
(board_id,),
)
count = cast(int, cursor.fetchone()[0])
return count

View File

@@ -1,5 +1,4 @@
import sqlite3
import threading
from typing import Union, cast
from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
@@ -19,20 +18,14 @@ from invokeai.app.util.misc import uuid_string
class SqliteBoardRecordStorage(BoardRecordStorageBase):
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.RLock
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
def delete(self, board_id: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE FROM boards
WHERE board_id = ?;
@@ -40,14 +33,9 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
(board_id,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordDeleteException from e
except Exception as e:
self._conn.rollback()
raise BoardRecordDeleteException from e
finally:
self._lock.release()
def save(
self,
@@ -55,8 +43,8 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
) -> BoardRecord:
try:
board_id = uuid_string()
self._lock.acquire()
self._cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO boards (board_id, board_name)
VALUES (?, ?);
@@ -67,8 +55,6 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
finally:
self._lock.release()
return self.get(board_id)
def get(
@@ -76,8 +62,8 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
board_id: str,
) -> BoardRecord:
try:
self._lock.acquire()
self._cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM boards
@@ -86,12 +72,9 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
(board_id,),
)
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordNotFoundException from e
finally:
self._lock.release()
if result is None:
raise BoardRecordNotFoundException
return BoardRecord(**dict(result))
@@ -102,11 +85,10 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
changes: BoardChanges,
) -> BoardRecord:
try:
self._lock.acquire()
cursor = self._conn.cursor()
# Change the name of a board
if changes.board_name is not None:
self._cursor.execute(
cursor.execute(
"""--sql
UPDATE boards
SET board_name = ?
@@ -117,7 +99,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
# Change the cover image of a board
if changes.cover_image_name is not None:
self._cursor.execute(
cursor.execute(
"""--sql
UPDATE boards
SET cover_image_name = ?
@@ -128,7 +110,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
# Change the archived status of a board
if changes.archived is not None:
self._cursor.execute(
cursor.execute(
"""--sql
UPDATE boards
SET archived = ?
@@ -141,8 +123,6 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
finally:
self._lock.release()
return self.get(board_id)
def get_many(
@@ -153,11 +133,10 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
limit: int = 10,
include_archived: bool = False,
) -> OffsetPaginatedResults[BoardRecord]:
try:
self._lock.acquire()
cursor = self._conn.cursor()
# Build base query
base_query = """
# Build base query
base_query = """
SELECT *
FROM boards
{archived_filter}
@@ -165,81 +144,67 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
LIMIT ? OFFSET ?;
"""
# Determine archived filter condition
archived_filter = "" if include_archived else "WHERE archived = 0"
# Determine archived filter condition
archived_filter = "" if include_archived else "WHERE archived = 0"
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
# Execute query to fetch boards
self._cursor.execute(final_query, (limit, offset))
# Execute query to fetch boards
cursor.execute(final_query, (limit, offset))
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
# Determine count query
if include_archived:
count_query = """
# Determine count query
if include_archived:
count_query = """
SELECT COUNT(*)
FROM boards;
"""
else:
count_query = """
else:
count_query = """
SELECT COUNT(*)
FROM boards
WHERE archived = 0;
"""
# Execute count query
self._cursor.execute(count_query)
# Execute count query
cursor.execute(count_query)
count = cast(int, self._cursor.fetchone()[0])
count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
def get_all(
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
) -> list[BoardRecord]:
try:
self._lock.acquire()
if order_by == BoardRecordOrderBy.Name:
base_query = """
cursor = self._conn.cursor()
if order_by == BoardRecordOrderBy.Name:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY LOWER(board_name) {direction}
"""
else:
base_query = """
else:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
"""
archived_filter = "" if include_archived else "WHERE archived = 0"
archived_filter = "" if include_archived else "WHERE archived = 0"
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
self._cursor.execute(final_query)
cursor.execute(final_query)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
return boards
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
return boards

View File

@@ -1,5 +1,4 @@
import sqlite3
import threading
from datetime import datetime
from typing import Optional, Union, cast
@@ -22,21 +21,14 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteImageRecordStorage(ImageRecordStorageBase):
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.RLock
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
def get(self, image_name: str) -> ImageRecord:
try:
self._lock.acquire()
self._cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
f"""--sql
SELECT {IMAGE_DTO_COLS} FROM images
WHERE image_name = ?;
@@ -44,12 +36,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
(image_name,),
)
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
result = cast(Optional[sqlite3.Row], cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordNotFoundException from e
finally:
self._lock.release()
if not result:
raise ImageRecordNotFoundException
@@ -58,9 +47,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
try:
self._lock.acquire()
self._cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT metadata FROM images
WHERE image_name = ?;
@@ -68,7 +56,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
(image_name,),
)
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
result = cast(Optional[sqlite3.Row], cursor.fetchone())
if not result:
raise ImageRecordNotFoundException
@@ -77,10 +65,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
metadata_raw = cast(Optional[str], as_dict.get("metadata", None))
return MetadataFieldValidator.validate_json(metadata_raw) if metadata_raw is not None else None
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordNotFoundException from e
finally:
self._lock.release()
def update(
self,
@@ -88,10 +73,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
changes: ImageRecordChanges,
) -> None:
try:
self._lock.acquire()
cursor = self._conn.cursor()
# Change the category of the image
if changes.image_category is not None:
self._cursor.execute(
cursor.execute(
"""--sql
UPDATE images
SET image_category = ?
@@ -102,7 +87,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Change the session associated with the image
if changes.session_id is not None:
self._cursor.execute(
cursor.execute(
"""--sql
UPDATE images
SET session_id = ?
@@ -113,7 +98,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Change the image's `is_intermediate`` flag
if changes.is_intermediate is not None:
self._cursor.execute(
cursor.execute(
"""--sql
UPDATE images
SET is_intermediate = ?
@@ -124,7 +109,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Change the image's `starred`` state
if changes.starred is not None:
self._cursor.execute(
cursor.execute(
"""--sql
UPDATE images
SET starred = ?
@@ -137,8 +122,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordSaveException from e
finally:
self._lock.release()
def get_many(
self,
@@ -152,110 +135,104 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
try:
self._lock.acquire()
cursor = self._conn.cursor()
# Manually build two queries - one for the count, one for the records
count_query = """--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
# Manually build two queries - one for the count, one for the records
count_query = """--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
images_query = f"""--sql
SELECT {IMAGE_DTO_COLS}
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
images_query = f"""--sql
SELECT {IMAGE_DTO_COLS}
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
# Unpack the included categories into the query params
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
query_params.append(is_intermediate)
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
# board_id of "none" is reserved for images without a board
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
# Search term condition
if search_term:
query_conditions += """--sql
AND images.metadata LIKE ?
"""
query_params.append(f"%{search_term.lower()}%")
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
if starred_first:
query_pagination = f"""--sql
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
else:
query_pagination = f"""--sql
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
# Unpack the included categories into the query params
for c in category_strings:
query_params.append(c)
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy()
# Add the pagination parameters
images_params.extend([limit, offset])
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
# Build the list of images, deserializing each row
cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
query_params.append(is_intermediate)
# board_id of "none" is reserved for images without a board
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
# Search term condition
if search_term:
query_conditions += """--sql
AND images.metadata LIKE ?
"""
query_params.append(f"%{search_term.lower()}%")
if starred_first:
query_pagination = f"""--sql
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
else:
query_pagination = f"""--sql
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy()
# Add the pagination parameters
images_params.extend([limit, offset])
# Build the list of images, deserializing each row
self._cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
count_params = query_params.copy()
self._cursor.execute(count_query, count_params)
count = cast(int, self._cursor.fetchone()[0])
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
count_params = query_params.copy()
cursor.execute(count_query, count_params)
count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
def delete(self, image_name: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE FROM images
WHERE image_name = ?;
@@ -266,58 +243,48 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
finally:
self._lock.release()
def delete_many(self, image_names: list[str]) -> None:
try:
placeholders = ",".join("?" for _ in image_names)
cursor = self._conn.cursor()
self._lock.acquire()
placeholders = ",".join("?" for _ in image_names)
# Construct the SQLite query with the placeholders
query = f"DELETE FROM images WHERE image_name IN ({placeholders})"
# Execute the query with the list of IDs as parameters
self._cursor.execute(query, image_names)
cursor.execute(query, image_names)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
finally:
self._lock.release()
def get_intermediates_count(self) -> int:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT COUNT(*) FROM images
WHERE is_intermediate = TRUE;
"""
)
count = cast(int, self._cursor.fetchone()[0])
self._conn.commit()
return count
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
finally:
self._lock.release()
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images
WHERE is_intermediate = TRUE;
"""
)
count = cast(int, cursor.fetchone()[0])
self._conn.commit()
return count
def delete_intermediates(self) -> list[str]:
try:
self._lock.acquire()
self._cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT image_name FROM images
WHERE is_intermediate = TRUE;
"""
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [r[0] for r in result]
self._cursor.execute(
cursor.execute(
"""--sql
DELETE FROM images
WHERE is_intermediate = TRUE;
@@ -328,8 +295,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
finally:
self._lock.release()
def save(
self,
@@ -346,8 +311,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
metadata: Optional[str] = None,
) -> datetime:
try:
self._lock.acquire()
self._cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO images (
image_name,
@@ -380,7 +345,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
)
self._conn.commit()
self._cursor.execute(
cursor.execute(
"""--sql
SELECT created_at
FROM images
@@ -389,34 +354,30 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
(image_name,),
)
created_at = datetime.fromisoformat(self._cursor.fetchone()[0])
created_at = datetime.fromisoformat(cursor.fetchone()[0])
return created_at
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordSaveException from e
finally:
self._lock.release()
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT images.*
FROM images
JOIN board_images ON images.image_name = board_images.image_name
WHERE board_images.board_id = ?
AND images.is_intermediate = FALSE
ORDER BY images.starred DESC, images.created_at DESC
LIMIT 1;
""",
(board_id,),
)
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT images.*
FROM images
JOIN board_images ON images.image_name = board_images.image_name
WHERE board_images.board_id = ?
AND images.is_intermediate = FALSE
ORDER BY images.starred DESC, images.created_at DESC
LIMIT 1;
""",
(board_id,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
finally:
self._lock.release()
if result is None:
return None

View File

@@ -78,7 +78,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
"""
super().__init__()
self._db = db
self._cursor = db.conn.cursor()
self._logger = logger
@property
@@ -96,38 +95,38 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
with self._db.lock:
try:
self._cursor.execute(
"""--sql
INSERT INTO models (
id,
config
)
VALUES (?,?);
""",
(
config.key,
config.model_dump_json(),
),
)
self._db.conn.commit()
try:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
INSERT INTO models (
id,
config
)
VALUES (?,?);
""",
(
config.key,
config.model_dump_json(),
),
)
self._db.conn.commit()
except sqlite3.IntegrityError as e:
self._db.conn.rollback()
if "UNIQUE constraint failed" in str(e):
if "models.path" in str(e):
msg = f"A model with path '{config.path}' is already installed"
elif "models.name" in str(e):
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
else:
msg = f"A model with key '{config.key}' is already installed"
raise DuplicateModelException(msg) from e
except sqlite3.IntegrityError as e:
self._db.conn.rollback()
if "UNIQUE constraint failed" in str(e):
if "models.path" in str(e):
msg = f"A model with path '{config.path}' is already installed"
elif "models.name" in str(e):
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
else:
raise e
except sqlite3.Error as e:
self._db.conn.rollback()
msg = f"A model with key '{config.key}' is already installed"
raise DuplicateModelException(msg) from e
else:
raise e
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_model(config.key)
@@ -139,21 +138,21 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise an UnknownModelException
"""
with self._db.lock:
try:
self._cursor.execute(
"""--sql
DELETE FROM models
WHERE id=?;
""",
(key,),
)
if self._cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
try:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
DELETE FROM models
WHERE id=?;
""",
(key,),
)
if cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
record = self.get_model(key)
@@ -164,23 +163,23 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
json_serialized = record.model_dump_json()
with self._db.lock:
try:
self._cursor.execute(
"""--sql
UPDATE models
SET
config=?
WHERE id=?;
""",
(json_serialized, key),
)
if self._cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
try:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
UPDATE models
SET
config=?
WHERE id=?;
""",
(json_serialized, key),
)
if cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_model(key)
@@ -192,33 +191,33 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Exceptions: UnknownModelException
"""
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE id=?;
""",
(key,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE id=?;
""",
(key,),
)
rows = cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
rows = self._cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
rows = cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def exists(self, key: str) -> bool:
@@ -227,16 +226,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
:param key: Unique key for the model to be deleted
"""
count = 0
with self._db.lock:
self._cursor.execute(
"""--sql
select count(*) FROM models
WHERE id=?;
""",
(key,),
)
count = self._cursor.fetchone()[0]
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
select count(*) FROM models
WHERE id=?;
""",
(key,),
)
count = cursor.fetchone()[0]
return count > 0
def search_by_attr(
@@ -284,17 +282,18 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
where_clause.append("format=?")
bindings.append(model_format)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
with self._db.lock:
self._cursor.execute(
f"""--sql
SELECT config, strftime('%s',updated_at)
FROM models
{where}
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
""",
tuple(bindings),
)
result = self._cursor.fetchall()
cursor = self._db.conn.cursor()
cursor.execute(
f"""--sql
SELECT config, strftime('%s',updated_at)
FROM models
{where}
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
""",
tuple(bindings),
)
result = cursor.fetchall()
# Parse the model configs.
results: list[AnyModelConfig] = []
@@ -313,34 +312,28 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
"""Return models with the indicated path."""
results = []
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE path=?;
""",
(str(path),),
)
results = [
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
]
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
return results
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
"""Return models with the indicated hash."""
results = []
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
results = [
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
]
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
return results
def list_models(
@@ -356,33 +349,32 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
ModelRecordOrderBy.Format: "format",
}
# Lock so that the database isn't updated while we're doing the two queries.
with self._db.lock:
# query1: get the total number of model configs
self._cursor.execute(
"""--sql
select count(*) from models;
""",
(),
)
total = int(self._cursor.fetchone()[0])
cursor = self._db.conn.cursor()
# query2: fetch key fields
self._cursor.execute(
f"""--sql
SELECT config
FROM models
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
LIMIT ?
OFFSET ?;
""",
(
per_page,
page * per_page,
),
)
rows = self._cursor.fetchall()
items = [ModelSummary.model_validate(dict(x)) for x in rows]
return PaginatedResults(
page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
)
# Lock so that the database isn't updated while we're doing the two queries.
# query1: get the total number of model configs
cursor.execute(
"""--sql
select count(*) from models;
""",
(),
)
total = int(cursor.fetchone()[0])
# query2: fetch key fields
cursor.execute(
f"""--sql
SELECT config
FROM models
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
LIMIT ?
OFFSET ?;
""",
(
per_page,
page * per_page,
),
)
rows = cursor.fetchall()
items = [ModelSummary.model_validate(dict(x)) for x in rows]
return PaginatedResults(page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items)

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Optional
from typing import Any, Coroutine, Optional
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
@@ -33,7 +33,7 @@ class SessionQueueBase(ABC):
pass
@abstractmethod
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> Coroutine[Any, Any, EnqueueBatchResult]:
"""Enqueues all permutations of a batch for execution."""
pass

View File

@@ -1,6 +1,6 @@
import asyncio
import json
import sqlite3
import threading
from typing import Optional, Union, cast
from pydantic_core import to_jsonable_python
@@ -37,9 +37,6 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteSessionQueue(SessionQueueBase):
__invoker: Invoker
__conn: sqlite3.Connection
__cursor: sqlite3.Cursor
__lock: threading.RLock
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
@@ -55,9 +52,7 @@ class SqliteSessionQueue(SessionQueueBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self.__lock = db.lock
self.__conn = db.conn
self.__cursor = self.__conn.cursor()
self._conn = db.conn
def _set_in_progress_to_canceled(self) -> None:
"""
@@ -65,8 +60,8 @@ class SqliteSessionQueue(SessionQueueBase):
This is necessary because the invoker may have been killed while processing a queue item.
"""
try:
self.__lock.acquire()
self.__cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
UPDATE session_queue
SET status = 'canceled'
@@ -74,14 +69,13 @@ class SqliteSessionQueue(SessionQueueBase):
"""
)
except Exception:
self.__conn.rollback()
self._conn.rollback()
raise
finally:
self.__lock.release()
def _get_current_queue_size(self, queue_id: str) -> int:
"""Gets the current number of pending queue items"""
self.__cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
@@ -91,11 +85,12 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
return cast(int, self.__cursor.fetchone()[0])
return cast(int, cursor.fetchone()[0])
def _get_highest_priority(self, queue_id: str) -> int:
"""Gets the highest priority value in the queue"""
self.__cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT MAX(priority)
FROM session_queue
@@ -105,12 +100,14 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
return cast(Union[int, None], self.__cursor.fetchone()[0]) or 0
return cast(Union[int, None], cursor.fetchone()[0]) or 0
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
async def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
return await asyncio.to_thread(self._enqueue_batch, queue_id, batch, prepend)
def _enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
try:
self.__lock.acquire()
cursor = self._conn.cursor()
# TODO: how does this work in a multi-user scenario?
current_queue_size = self._get_current_queue_size(queue_id)
max_queue_size = self.__invoker.services.configuration.max_queue_size
@@ -132,19 +129,17 @@ class SqliteSessionQueue(SessionQueueBase):
if requested_count > enqueued_count:
values_to_insert = values_to_insert[:max_new_queue_items]
self.__cursor.executemany(
cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)
self.__conn.commit()
self._conn.commit()
except Exception:
self.__conn.rollback()
self._conn.rollback()
raise
finally:
self.__lock.release()
enqueue_result = EnqueueBatchResult(
queue_id=queue_id,
requested=requested_count,
@@ -156,25 +151,19 @@ class SqliteSessionQueue(SessionQueueBase):
return enqueue_result
def dequeue(self) -> Optional[SessionQueueItem]:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE status = 'pending'
ORDER BY
priority DESC,
item_id ASC
LIMIT 1
"""
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE status = 'pending'
ORDER BY
priority DESC,
item_id ASC
LIMIT 1
"""
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
queue_item = SessionQueueItem.queue_item_from_dict(dict(result))
@@ -182,52 +171,40 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
ORDER BY
priority DESC,
created_at ASC
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
ORDER BY
priority DESC,
created_at ASC
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
return SessionQueueItem.queue_item_from_dict(dict(result))
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'in_progress'
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'in_progress'
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
return SessionQueueItem.queue_item_from_dict(dict(result))
@@ -241,8 +218,8 @@ class SqliteSessionQueue(SessionQueueBase):
error_traceback: Optional[str] = None,
) -> SessionQueueItem:
try:
self.__lock.acquire()
self.__cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
UPDATE session_queue
SET status = ?, error_type = ?, error_message = ?, error_traceback = ?
@@ -250,12 +227,10 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(status, error_type, error_message, error_traceback, item_id),
)
self.__conn.commit()
self._conn.commit()
except Exception:
self.__conn.rollback()
self._conn.rollback()
raise
finally:
self.__lock.release()
queue_item = self.get_queue_item(item_id)
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
@@ -263,48 +238,36 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def is_empty(self, queue_id: str) -> IsEmptyResult:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
is_empty = cast(int, self.__cursor.fetchone()[0]) == 0
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
is_empty = cast(int, cursor.fetchone()[0]) == 0
return IsEmptyResult(is_empty=is_empty)
def is_full(self, queue_id: str) -> IsFullResult:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
max_queue_size = self.__invoker.services.configuration.max_queue_size
is_full = cast(int, self.__cursor.fetchone()[0]) >= max_queue_size
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
max_queue_size = self.__invoker.services.configuration.max_queue_size
is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
return IsFullResult(is_full=is_full)
def clear(self, queue_id: str) -> ClearResult:
try:
self.__lock.acquire()
self.__cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT COUNT(*)
FROM session_queue
@@ -312,8 +275,8 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
count = cursor.fetchone()[0]
cursor.execute(
"""--sql
DELETE
FROM session_queue
@@ -321,17 +284,16 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self.__conn.commit()
self._conn.commit()
except Exception:
self.__conn.rollback()
self._conn.rollback()
raise
finally:
self.__lock.release()
self.__invoker.services.events.emit_queue_cleared(queue_id)
return ClearResult(deleted=count)
def prune(self, queue_id: str) -> PruneResult:
try:
cursor = self._conn.cursor()
where = """--sql
WHERE
queue_id = ?
@@ -341,8 +303,7 @@ class SqliteSessionQueue(SessionQueueBase):
OR status = 'canceled'
)
"""
self.__lock.acquire()
self.__cursor.execute(
cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
@@ -350,8 +311,8 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
count = cursor.fetchone()[0]
cursor.execute(
f"""--sql
DELETE
FROM session_queue
@@ -359,12 +320,10 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self.__conn.commit()
self._conn.commit()
except Exception:
self.__conn.rollback()
self._conn.rollback()
raise
finally:
self.__lock.release()
return PruneResult(deleted=count)
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
@@ -393,8 +352,8 @@ class SqliteSessionQueue(SessionQueueBase):
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
try:
cursor = self._conn.cursor()
current_queue_item = self.get_current(queue_id)
self.__lock.acquire()
placeholders = ", ".join(["?" for _ in batch_ids])
where = f"""--sql
WHERE
@@ -405,7 +364,7 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'failed'
"""
params = [queue_id] + batch_ids
self.__cursor.execute(
cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
@@ -413,8 +372,8 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
count = cursor.fetchone()[0]
cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
@@ -422,20 +381,18 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
self.__conn.commit()
self._conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self.__conn.rollback()
self._conn.rollback()
raise
finally:
self.__lock.release()
return CancelByBatchIDsResult(canceled=count)
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
try:
cursor = self._conn.cursor()
current_queue_item = self.get_current(queue_id)
self.__lock.acquire()
where = """--sql
WHERE
queue_id == ?
@@ -445,7 +402,7 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'failed'
"""
params = (queue_id, destination)
self.__cursor.execute(
cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
@@ -453,8 +410,8 @@ class SqliteSessionQueue(SessionQueueBase):
""",
params,
)
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
count = cursor.fetchone()[0]
cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
@@ -462,20 +419,18 @@ class SqliteSessionQueue(SessionQueueBase):
""",
params,
)
self.__conn.commit()
self._conn.commit()
if current_queue_item is not None and current_queue_item.destination == destination:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self.__conn.rollback()
self._conn.rollback()
raise
finally:
self.__lock.release()
return CancelByDestinationResult(canceled=count)
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
try:
cursor = self._conn.cursor()
current_queue_item = self.get_current(queue_id)
self.__lock.acquire()
where = """--sql
WHERE
queue_id is ?
@@ -484,7 +439,7 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'failed'
"""
params = [queue_id]
self.__cursor.execute(
cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
@@ -492,8 +447,8 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
count = cursor.fetchone()[0]
cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
@@ -501,7 +456,7 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
self.__conn.commit()
self._conn.commit()
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
@@ -509,21 +464,19 @@ class SqliteSessionQueue(SessionQueueBase):
current_queue_item, batch_status, queue_status
)
except Exception:
self.__conn.rollback()
self._conn.rollback()
raise
finally:
self.__lock.release()
return CancelByQueueIDResult(canceled=count)
def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult:
try:
cursor = self._conn.cursor()
where = """--sql
WHERE
queue_id == ?
AND status == 'pending'
"""
self.__lock.acquire()
self.__cursor.execute(
cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
@@ -531,8 +484,8 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
count = cursor.fetchone()[0]
cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
@@ -540,43 +493,35 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self.__conn.commit()
self._conn.commit()
except Exception:
self.__conn.rollback()
self._conn.rollback()
raise
finally:
self.__lock.release()
return CancelAllExceptCurrentResult(canceled=count)
def get_queue_item(self, item_id: int) -> SessionQueueItem:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT * FROM session_queue
WHERE
item_id = ?
""",
(item_id,),
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT * FROM session_queue
WHERE
item_id = ?
""",
(item_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
return SessionQueueItem.queue_item_from_dict(dict(result))
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
try:
cursor = self._conn.cursor()
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
# when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
# during execution.
session_json = session.model_dump_json(warnings=False, exclude_none=True)
self.__lock.acquire()
self.__cursor.execute(
cursor.execute(
"""--sql
UPDATE session_queue
SET session = ?
@@ -584,12 +529,10 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(session_json, item_id),
)
self.__conn.commit()
self._conn.commit()
except Exception:
self.__conn.rollback()
self._conn.rollback()
raise
finally:
self.__lock.release()
return self.get_queue_item(item_id)
def list_queue_items(
@@ -600,83 +543,71 @@ class SqliteSessionQueue(SessionQueueBase):
cursor: Optional[int] = None,
status: Optional[QUEUE_ITEM_STATUS] = None,
) -> CursorPaginatedResults[SessionQueueItemDTO]:
try:
item_id = cursor
self.__lock.acquire()
query = """--sql
SELECT item_id,
status,
priority,
field_values,
error_type,
error_message,
error_traceback,
created_at,
updated_at,
completed_at,
started_at,
session_id,
batch_id,
queue_id,
origin,
destination
FROM session_queue
WHERE queue_id = ?
"""
params: list[Union[str, int]] = [queue_id]
if status is not None:
query += """--sql
AND status = ?
"""
params.append(status)
if item_id is not None:
query += """--sql
AND (priority < ?) OR (priority = ? AND item_id > ?)
"""
params.extend([priority, priority, item_id])
cursor_ = self._conn.cursor()
item_id = cursor
query = """--sql
SELECT item_id,
status,
priority,
field_values,
error_type,
error_message,
error_traceback,
created_at,
updated_at,
completed_at,
started_at,
session_id,
batch_id,
queue_id,
origin,
destination
FROM session_queue
WHERE queue_id = ?
"""
params: list[Union[str, int]] = [queue_id]
if status is not None:
query += """--sql
ORDER BY
priority DESC,
item_id ASC
LIMIT ?
AND status = ?
"""
params.append(limit + 1)
self.__cursor.execute(query, params)
results = cast(list[sqlite3.Row], self.__cursor.fetchall())
items = [SessionQueueItemDTO.queue_item_dto_from_dict(dict(result)) for result in results]
has_more = False
if len(items) > limit:
# remove the extra item
items.pop()
has_more = True
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
params.append(status)
if item_id is not None:
query += """--sql
AND (priority < ?) OR (priority = ? AND item_id > ?)
"""
params.extend([priority, priority, item_id])
query += """--sql
ORDER BY
priority DESC,
item_id ASC
LIMIT ?
"""
params.append(limit + 1)
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
items = [SessionQueueItemDTO.queue_item_dto_from_dict(dict(result)) for result in results]
has_more = False
if len(items) > limit:
# remove the extra item
items.pop()
has_more = True
return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
GROUP BY status
""",
(queue_id,),
)
counts_result = cast(list[sqlite3.Row], self.__cursor.fetchall())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
GROUP BY status
""",
(queue_id,),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
current_item = self.get_current(queue_id=queue_id)
total = sum(row[1] for row in counts_result)
@@ -695,29 +626,23 @@ class SqliteSessionQueue(SessionQueueBase):
)
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT status, count(*), origin, destination
FROM session_queue
WHERE
queue_id = ?
AND batch_id = ?
GROUP BY status
""",
(queue_id, batch_id),
)
result = cast(list[sqlite3.Row], self.__cursor.fetchall())
total = sum(row[1] for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result}
origin = result[0]["origin"] if result else None
destination = result[0]["destination"] if result else None
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*), origin, destination
FROM session_queue
WHERE
queue_id = ?
AND batch_id = ?
GROUP BY status
""",
(queue_id, batch_id),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result}
origin = result[0]["origin"] if result else None
destination = result[0]["destination"] if result else None
return BatchStatus(
batch_id=batch_id,
@@ -733,24 +658,18 @@ class SqliteSessionQueue(SessionQueueBase):
)
def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
AND destination = ?
GROUP BY status
""",
(queue_id, destination),
)
counts_result = cast(list[sqlite3.Row], self.__cursor.fetchall())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
AND destination = ?
GROUP BY status
""",
(queue_id, destination),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] for row in counts_result)
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
@@ -769,8 +688,7 @@ class SqliteSessionQueue(SessionQueueBase):
def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult:
"""Retries the given queue items"""
try:
self.__lock.acquire()
cursor = self._conn.cursor()
values_to_insert: list[tuple] = []
retried_item_ids: list[int] = []
@@ -813,7 +731,7 @@ class SqliteSessionQueue(SessionQueueBase):
# TODO(psyche): Handle max queue size?
self.__cursor.executemany(
cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
@@ -821,12 +739,10 @@ class SqliteSessionQueue(SessionQueueBase):
values_to_insert,
)
self.__conn.commit()
self._conn.commit()
except Exception:
self.__conn.rollback()
self._conn.rollback()
raise
finally:
self.__lock.release()
retry_result = RetryItemsResult(
queue_id=queue_id,
retried_item_ids=retried_item_ids,

View File

@@ -1,5 +1,4 @@
import sqlite3
import threading
from logging import Logger
from pathlib import Path
@@ -38,14 +37,20 @@ class SqliteDatabase:
self.logger.info(f"Initializing database at {self.db_path}")
self.conn = sqlite3.connect(database=self.db_path or sqlite_memory, check_same_thread=False)
self.lock = threading.RLock()
self.conn.row_factory = sqlite3.Row
if self.verbose:
self.conn.set_trace_callback(self.logger.debug)
# Enable foreign key constraints
self.conn.execute("PRAGMA foreign_keys = ON;")
# Enable Write-Ahead Logging (WAL) mode for better concurrency
self.conn.execute("PRAGMA journal_mode = WAL;")
# Set a busy timeout to prevent database lockups during writes
self.conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
def clean(self) -> None:
"""
Cleans the database by running the VACUUM command, reporting on the freed space.
@@ -53,15 +58,14 @@ class SqliteDatabase:
# No need to clean in-memory database
if not self.db_path:
return
with self.lock:
try:
initial_db_size = Path(self.db_path).stat().st_size
self.conn.execute("VACUUM;")
self.conn.commit()
final_db_size = Path(self.db_path).stat().st_size
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
if freed_space_in_mb > 0:
self.logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
except Exception as e:
self.logger.error(f"Error cleaning database: {e}")
raise
try:
initial_db_size = Path(self.db_path).stat().st_size
self.conn.execute("VACUUM;")
self.conn.commit()
final_db_size = Path(self.db_path).stat().st_size
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
if freed_space_in_mb > 0:
self.logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
except Exception as e:
self.logger.error(f"Error cleaning database: {e}")
raise

View File

@@ -43,46 +43,45 @@ class SqliteMigrator:
def run_migrations(self) -> bool:
"""Migrates the database to the latest version."""
with self._db.lock:
# This throws if there is a problem.
self._migration_set.validate_migration_chain()
cursor = self._db.conn.cursor()
self._create_migrations_table(cursor=cursor)
# This throws if there is a problem.
self._migration_set.validate_migration_chain()
cursor = self._db.conn.cursor()
self._create_migrations_table(cursor=cursor)
if self._migration_set.count == 0:
self._logger.debug("No migrations registered")
return False
if self._migration_set.count == 0:
self._logger.debug("No migrations registered")
return False
if self._get_current_version(cursor=cursor) == self._migration_set.latest_version:
self._logger.debug("Database is up to date, no migrations to run")
return False
if self._get_current_version(cursor=cursor) == self._migration_set.latest_version:
self._logger.debug("Database is up to date, no migrations to run")
return False
self._logger.info("Database update needed")
self._logger.info("Database update needed")
# Make a backup of the db if it needs to be updated and is a file db
if self._db.db_path is not None:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
self._logger.info(f"Backing up database to {str(self._backup_path)}")
# Use SQLite to do the backup
with closing(sqlite3.connect(self._backup_path)) as backup_conn:
self._db.conn.backup(backup_conn)
else:
self._logger.info("Using in-memory database, no backup needed")
# Make a backup of the db if it needs to be updated and is a file db
if self._db.db_path is not None:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
self._logger.info(f"Backing up database to {str(self._backup_path)}")
# Use SQLite to do the backup
with closing(sqlite3.connect(self._backup_path)) as backup_conn:
self._db.conn.backup(backup_conn)
else:
self._logger.info("Using in-memory database, no backup needed")
next_migration = self._migration_set.get(from_version=self._get_current_version(cursor))
while next_migration is not None:
self._run_migration(next_migration)
next_migration = self._migration_set.get(self._get_current_version(cursor))
self._logger.info("Database updated successfully")
return True
next_migration = self._migration_set.get(from_version=self._get_current_version(cursor))
while next_migration is not None:
self._run_migration(next_migration)
next_migration = self._migration_set.get(self._get_current_version(cursor))
self._logger.info("Database updated successfully")
return True
def _run_migration(self, migration: Migration) -> None:
"""Runs a single migration."""
try:
# Using sqlite3.Connection as a context manager commits a the transaction on exit, or rolls it back if an
# exception is raised.
with self._db.lock, self._db.conn as conn:
with self._db.conn as conn:
cursor = conn.cursor()
if self._get_current_version(cursor) != migration.from_version:
raise MigrationError(
@@ -108,27 +107,26 @@ class SqliteMigrator:
def _create_migrations_table(self, cursor: sqlite3.Cursor) -> None:
"""Creates the migrations table for the database, if one does not already exist."""
with self._db.lock:
try:
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
if cursor.fetchone() is not None:
return
cursor.execute(
"""--sql
CREATE TABLE migrations (
version INTEGER PRIMARY KEY,
migrated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
);
"""
)
cursor.execute("INSERT INTO migrations (version) VALUES (0);")
cursor.connection.commit()
self._logger.debug("Created migrations table")
except sqlite3.Error as e:
msg = f"Problem creating migrations table: {e}"
self._logger.error(msg)
cursor.connection.rollback()
raise MigrationError(msg) from e
try:
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
if cursor.fetchone() is not None:
return
cursor.execute(
"""--sql
CREATE TABLE migrations (
version INTEGER PRIMARY KEY,
migrated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
);
"""
)
cursor.execute("INSERT INTO migrations (version) VALUES (0);")
cursor.connection.commit()
self._logger.debug("Created migrations table")
except sqlite3.Error as e:
msg = f"Problem creating migrations table: {e}"
self._logger.error(msg)
cursor.connection.rollback()
raise MigrationError(msg) from e
@classmethod
def _get_current_version(cls, cursor: sqlite3.Cursor) -> int:

View File

@@ -17,9 +17,7 @@ from invokeai.app.util.misc import uuid_string
class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
@@ -27,31 +25,25 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
"""Gets a style preset by ID."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM style_presets
WHERE id = ?;
""",
(style_preset_id,),
)
row = self._cursor.fetchone()
if row is None:
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
return StylePresetRecordDTO.from_dict(dict(row))
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM style_presets
WHERE id = ?;
""",
(style_preset_id,),
)
row = cursor.fetchone()
if row is None:
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
return StylePresetRecordDTO.from_dict(dict(row))
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
style_preset_id = uuid_string()
try:
self._lock.acquire()
self._cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO style_presets (
id,
@@ -72,18 +64,16 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return self.get(style_preset_id)
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
style_preset_ids = []
try:
self._lock.acquire()
cursor = self._conn.cursor()
for style_preset in style_presets:
style_preset_id = uuid_string()
style_preset_ids.append(style_preset_id)
self._cursor.execute(
cursor.execute(
"""--sql
INSERT OR IGNORE INTO style_presets (
id,
@@ -104,17 +94,15 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return None
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
try:
self._lock.acquire()
cursor = self._conn.cursor()
# Change the name of a style preset
if changes.name is not None:
self._cursor.execute(
cursor.execute(
"""--sql
UPDATE style_presets
SET name = ?
@@ -125,7 +113,7 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
# Change the preset data for a style preset
if changes.preset_data is not None:
self._cursor.execute(
cursor.execute(
"""--sql
UPDATE style_presets
SET preset_data = ?
@@ -138,14 +126,12 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return self.get(style_preset_id)
def delete(self, style_preset_id: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE from style_presets
WHERE id = ?;
@@ -156,46 +142,38 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return None
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
try:
self._lock.acquire()
main_query = """
SELECT
*
FROM style_presets
"""
main_query = """
SELECT
*
FROM style_presets
"""
if type is not None:
main_query += "WHERE type = ? "
if type is not None:
main_query += "WHERE type = ? "
main_query += "ORDER BY LOWER(name) ASC"
main_query += "ORDER BY LOWER(name) ASC"
if type is not None:
self._cursor.execute(main_query, (type,))
else:
self._cursor.execute(main_query)
cursor = self._conn.cursor()
if type is not None:
cursor.execute(main_query, (type,))
else:
cursor.execute(main_query)
rows = self._cursor.fetchall()
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
rows = cursor.fetchall()
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
return style_presets
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return style_presets
def _sync_default_style_presets(self) -> None:
"""Syncs default style presets to the database. Internal use only."""
# First delete all existing default style presets
try:
self._lock.acquire()
self._cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE FROM style_presets
WHERE type = "default";
@@ -205,10 +183,8 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
# Next, parse and create the default style presets
with self._lock, open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
with open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
presets = json.load(file)
for preset in presets:
style_preset = StylePresetWithoutId.model_validate(preset)

View File

@@ -23,9 +23,7 @@ from invokeai.app.util.misc import uuid_string
class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._lock = db.lock
self._conn = db.conn
self._cursor = self._conn.cursor()
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
@@ -33,42 +31,36 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
def get(self, workflow_id: str) -> WorkflowRecordDTO:
"""Gets a workflow by ID. Updates the opened_at column."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
UPDATE workflow_library
SET opened_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = ?;
""",
(workflow_id,),
)
self._conn.commit()
self._cursor.execute(
"""--sql
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
FROM workflow_library
WHERE workflow_id = ?;
""",
(workflow_id,),
)
row = self._cursor.fetchone()
if row is None:
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
return WorkflowRecordDTO.from_dict(dict(row))
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
cursor = self._conn.cursor()
cursor.execute(
"""--sql
UPDATE workflow_library
SET opened_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE workflow_id = ?;
""",
(workflow_id,),
)
self._conn.commit()
cursor.execute(
"""--sql
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
FROM workflow_library
WHERE workflow_id = ?;
""",
(workflow_id,),
)
row = cursor.fetchone()
if row is None:
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
return WorkflowRecordDTO.from_dict(dict(row))
def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO:
try:
# Only user workflows may be created by this method
assert workflow.meta.category is WorkflowCategory.User
workflow_with_id = Workflow(**workflow.model_dump(), id=uuid_string())
self._lock.acquire()
self._cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO workflow_library (
workflow_id,
@@ -82,14 +74,12 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return self.get(workflow_with_id.id)
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
try:
self._lock.acquire()
self._cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
UPDATE workflow_library
SET workflow = ?
@@ -101,14 +91,12 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return self.get(workflow.id)
def delete(self, workflow_id: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE from workflow_library
WHERE workflow_id = ? AND category = 'user';
@@ -119,8 +107,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return None
def get_many(
@@ -132,66 +118,60 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
per_page: Optional[int] = None,
query: Optional[str] = None,
) -> PaginatedResults[WorkflowRecordListItemDTO]:
try:
self._lock.acquire()
# sanitize!
assert order_by in WorkflowRecordOrderBy
assert direction in SQLiteDirection
assert category in WorkflowCategory
count_query = "SELECT COUNT(*) FROM workflow_library WHERE category = ?"
main_query = """
SELECT
workflow_id,
category,
name,
description,
created_at,
updated_at,
opened_at
FROM workflow_library
WHERE category = ?
"""
main_params: list[int | str] = [category.value]
count_params: list[int | str] = [category.value]
# sanitize!
assert order_by in WorkflowRecordOrderBy
assert direction in SQLiteDirection
assert category in WorkflowCategory
count_query = "SELECT COUNT(*) FROM workflow_library WHERE category = ?"
main_query = """
SELECT
workflow_id,
category,
name,
description,
created_at,
updated_at,
opened_at
FROM workflow_library
WHERE category = ?
"""
main_params: list[int | str] = [category.value]
count_params: list[int | str] = [category.value]
stripped_query = query.strip() if query else None
if stripped_query:
wildcard_query = "%" + stripped_query + "%"
main_query += " AND name LIKE ? OR description LIKE ? "
count_query += " AND name LIKE ? OR description LIKE ?;"
main_params.extend([wildcard_query, wildcard_query])
count_params.extend([wildcard_query, wildcard_query])
stripped_query = query.strip() if query else None
if stripped_query:
wildcard_query = "%" + stripped_query + "%"
main_query += " AND name LIKE ? OR description LIKE ? "
count_query += " AND name LIKE ? OR description LIKE ?;"
main_params.extend([wildcard_query, wildcard_query])
count_params.extend([wildcard_query, wildcard_query])
main_query += f" ORDER BY {order_by.value} {direction.value}"
main_query += f" ORDER BY {order_by.value} {direction.value}"
if per_page:
main_query += " LIMIT ? OFFSET ?"
main_params.extend([per_page, page * per_page])
if per_page:
main_query += " LIMIT ? OFFSET ?"
main_params.extend([per_page, page * per_page])
self._cursor.execute(main_query, main_params)
rows = self._cursor.fetchall()
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
cursor = self._conn.cursor()
cursor.execute(main_query, main_params)
rows = cursor.fetchall()
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
self._cursor.execute(count_query, count_params)
total = self._cursor.fetchone()[0]
cursor.execute(count_query, count_params)
total = cursor.fetchone()[0]
if per_page:
pages = total // per_page + (total % per_page > 0)
else:
pages = 1 # If no pagination, there is only one page
if per_page:
pages = total // per_page + (total % per_page > 0)
else:
pages = 1 # If no pagination, there is only one page
return PaginatedResults(
items=workflows,
page=page,
per_page=per_page if per_page else total,
pages=pages,
total=total,
)
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()
return PaginatedResults(
items=workflows,
page=page,
per_page=per_page if per_page else total,
pages=pages,
total=total,
)
def _sync_default_workflows(self) -> None:
"""Syncs default workflows to the database. Internal use only."""
@@ -207,7 +187,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
"""
try:
self._lock.acquire()
workflows: list[Workflow] = []
workflows_dir = Path(__file__).parent / Path("default_workflows")
workflow_paths = workflows_dir.glob("*.json")
@@ -218,14 +197,15 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
workflows.append(workflow)
# Only default workflows may be managed by this method
assert all(w.meta.category is WorkflowCategory.Default for w in workflows)
self._cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE FROM workflow_library
WHERE category = 'default';
"""
)
for w in workflows:
self._cursor.execute(
cursor.execute(
"""--sql
INSERT OR REPLACE INTO workflow_library (
workflow_id,
@@ -239,5 +219,3 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
except Exception:
self._conn.rollback()
raise
finally:
self._lock.release()

View File

@@ -1,20 +1,31 @@
import logging
import os
import sys
def configure_torch_cuda_allocator(pytorch_cuda_alloc_conf: str, logger: logging.Logger | None = None):
def configure_torch_cuda_allocator(pytorch_cuda_alloc_conf: str, logger: logging.Logger):
"""Configure the PyTorch CUDA memory allocator. See
https://pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf for supported
configurations.
"""
# Raise if the PYTORCH_CUDA_ALLOC_CONF environment variable is already set.
if "torch" in sys.modules:
raise RuntimeError("configure_torch_cuda_allocator() must be called before importing torch.")
# Log a warning if the PYTORCH_CUDA_ALLOC_CONF environment variable is already set.
prev_cuda_alloc_conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", None)
if prev_cuda_alloc_conf is not None:
raise RuntimeError(
f"Attempted to configure the PyTorch CUDA memory allocator, but PYTORCH_CUDA_ALLOC_CONF is already set to "
f"'{prev_cuda_alloc_conf}'."
)
if prev_cuda_alloc_conf == pytorch_cuda_alloc_conf:
logger.info(
f"PYTORCH_CUDA_ALLOC_CONF is already set to '{pytorch_cuda_alloc_conf}'. Skipping configuration."
)
return
else:
logger.warning(
f"Attempted to configure the PyTorch CUDA memory allocator with '{pytorch_cuda_alloc_conf}', but PYTORCH_CUDA_ALLOC_CONF is already set to "
f"'{prev_cuda_alloc_conf}'. Skipping configuration."
)
return
# Configure the PyTorch CUDA memory allocator.
# NOTE: It is important that this happens before torch is imported.
@@ -38,5 +49,4 @@ def configure_torch_cuda_allocator(pytorch_cuda_alloc_conf: str, logger: logging
"not imported before calling configure_torch_cuda_allocator()."
)
if logger is not None:
logger.info(f"PyTorch CUDA memory allocator: {torch.cuda.get_allocator_backend()}")
logger.info(f"PyTorch CUDA memory allocator: {torch.cuda.get_allocator_backend()}")

View File

@@ -107,7 +107,13 @@
"min": "Min",
"max": "Max",
"resetToDefaults": "Auf Standard zurücksetzen",
"seed": "Seed"
"seed": "Seed",
"row": "Reihe",
"column": "Spalte",
"end": "Ende",
"layout": "Layout",
"board": "Ordner",
"combinatorial": "Kombinatorisch"
},
"gallery": {
"galleryImageSize": "Bildgröße",
@@ -616,7 +622,9 @@
"hfTokenUnableToVerify": "HF-Token kann nicht überprüft werden",
"hfTokenUnableToVerifyErrorMessage": "HuggingFace-Token kann nicht überprüft werden. Dies ist wahrscheinlich auf einen Netzwerkfehler zurückzuführen. Bitte versuchen Sie es später erneut.",
"hfTokenSaved": "HF-Token gespeichert",
"hfTokenRequired": "Sie versuchen, ein Modell herunterzuladen, für das ein gültiges HuggingFace-Token erforderlich ist."
"hfTokenRequired": "Sie versuchen, ein Modell herunterzuladen, für das ein gültiges HuggingFace-Token erforderlich ist.",
"urlUnauthorizedErrorMessage2": "Hier erfahren wie.",
"urlForbidden": "Sie haben keinen Zugriff auf dieses Modell"
},
"parameters": {
"images": "Bilder",
@@ -683,7 +691,8 @@
"iterations": "Iterationen",
"guidance": "Führung",
"coherenceMode": "Modus",
"recallMetadata": "Metadaten abrufen"
"recallMetadata": "Metadaten abrufen",
"gaussianBlur": "Gaußsche Unschärfe"
},
"settings": {
"displayInProgress": "Zwischenbilder anzeigen",
@@ -883,7 +892,8 @@
"canvas": "Leinwand",
"prompts_one": "Prompt",
"prompts_other": "Prompts",
"batchSize": "Stapelgröße"
"batchSize": "Stapelgröße",
"confirm": "Bestätigen"
},
"metadata": {
"negativePrompt": "Negativ Beschreibung",
@@ -1298,7 +1308,13 @@
"noBatchGroup": "keine Gruppe",
"generatorNoValues": "leer",
"generatorLoading": "wird geladen",
"generatorLoadFromFile": "Aus Datei laden"
"generatorLoadFromFile": "Aus Datei laden",
"showEdgeLabels": "Kantenbeschriftungen anzeigen",
"downloadWorkflowError": "Fehler beim Herunterladen des Arbeitsablaufs",
"nodeName": "Knotenname",
"description": "Beschreibung",
"loadWorkflowDesc": "Arbeitsablauf laden?",
"loadWorkflowDesc2": "Ihr aktueller Arbeitsablauf enthält nicht gespeicherte Änderungen."
},
"hrf": {
"enableHrf": "Korrektur für hohe Auflösungen",

View File

@@ -1726,6 +1726,10 @@
"layout": "Layout",
"row": "Row",
"column": "Column",
"container": "Container",
"heading": "Heading",
"text": "Text",
"divider": "Divider",
"nodeField": "Node Field",
"zoomToNode": "Zoom to Node",
"nodeFieldTooltip": "To add a node field, click the small plus sign button on the field in the Workflow Editor, or drag the field by its name into the form.",
@@ -1907,7 +1911,7 @@
"resetGenerationSettings": "Reset Generation Settings",
"replaceCurrent": "Replace Current",
"controlLayerEmptyState": "<UploadButton>Upload an image</UploadButton>, drag an image from the <GalleryButton>gallery</GalleryButton> onto this layer, or draw on the canvas to get started.",
"referenceImageEmptyState": "<UploadButton>Upload an image</UploadButton> or drag an image from the <GalleryButton>gallery</GalleryButton> onto this layer to get started.",
"referenceImageEmptyState": "<UploadButton>Upload an image</UploadButton>, drag an image from the <GalleryButton>gallery</GalleryButton> onto this layer, or <PullBboxButton>pull the bounding box into this layer</PullBboxButton> to get started.",
"warnings": {
"problemsFound": "Problems found",
"unsupportedModel": "layer not supported for selected base model",
@@ -2302,8 +2306,8 @@
"whatsNew": {
"whatsNewInInvoke": "What's New in Invoke",
"items": [
"Workflow Editor: New drag-and-drop form builder for easier workflow creation.",
"Other improvements: Faster batch queuing, better upscaling, improved color picker, and metadata nodes."
"Memory Management: New setting for users with Nvidia GPUs to reduce VRAM usage.",
"Performance: Continued improvements to overall application performance and responsiveness."
],
"readReleaseNotes": "Read Release Notes",
"watchRecentReleaseVideos": "Watch Recent Release Videos",

View File

@@ -105,7 +105,11 @@
"resetToDefaults": "Ripristina le impostazioni predefinite",
"seed": "Seme",
"combinatorial": "Combinatorio",
"count": "Quantità"
"count": "Quantità",
"board": "Bacheca",
"layout": "Schema",
"row": "Riga",
"column": "Colonna"
},
"gallery": {
"galleryImageSize": "Dimensione dell'immagine",
@@ -173,7 +177,8 @@
"assetsTab": "File che hai caricato per usarli nei tuoi progetti.",
"boardsSettings": "Impostazioni Bacheche",
"imagesSettings": "Impostazioni Immagini Galleria",
"assets": "Risorse"
"assets": "Risorse",
"images": "Immagini"
},
"hotkeys": {
"searchHotkeys": "Cerca tasti di scelta rapida",
@@ -703,7 +708,8 @@
"collectionNumberNotMultipleOf": "{{value}} non è multiplo di {{multipleOf}}",
"collectionNumberLTMin": "{{value}} < {{minimum}} (incr min)",
"collectionNumberGTExclusiveMax": "{{value}} >= {{exclusiveMaximum}} (excl max)",
"collectionNumberLTExclusiveMin": "{{value}} <= {{exclusiveMinimum}} (excl min)"
"collectionNumberLTExclusiveMin": "{{value}} <= {{exclusiveMinimum}} (excl min)",
"collectionEmpty": "raccolta vuota"
},
"useCpuNoise": "Usa la CPU per generare rumore",
"iterations": "Iterazioni",
@@ -920,8 +926,8 @@
"boolean": "Booleani",
"node": "Nodo",
"collection": "Raccolta",
"cannotConnectInputToInput": "Impossibile collegare Input a Input",
"cannotConnectOutputToOutput": "Impossibile collegare Output ad Output",
"cannotConnectInputToInput": "Impossibile collegare ingresso a ingresso",
"cannotConnectOutputToOutput": "Impossibile collegare uscita ad uscita",
"cannotConnectToSelf": "Impossibile connettersi a se stesso",
"mismatchedVersion": "Nodo non valido: il nodo {{node}} di tipo {{type}} ha una versione non corrispondente (provare ad aggiornare?)",
"loadingNodes": "Caricamento nodi...",
@@ -1014,7 +1020,21 @@
"generatorNRandomValues_one": "{{count}} valore casuale",
"generatorNRandomValues_many": "{{count}} valori casuali",
"generatorNRandomValues_other": "{{count}} valori casuali",
"arithmeticSequence": "Sequenza aritmetica"
"arithmeticSequence": "Sequenza aritmetica",
"nodeName": "Nome del nodo",
"loadWorkflowDesc": "Caricare il flusso di lavoro?",
"loadWorkflowDesc2": "Il flusso di lavoro corrente presenta modifiche non salvate.",
"downloadWorkflowError": "Errore durante lo scaricamento del flusso di lavoro",
"deletedMissingNodeFieldFormElement": "Campo modulo mancante eliminato: nodo {{nodeId}} campo {{fieldName}}",
"loadingTemplates": "Caricamento {{name}}",
"unableToUpdateNode": "Aggiornamento del nodo non riuscito: nodo {{node}} di tipo {{type}} (potrebbe essere necessario eliminarlo e ricrearlo)",
"description": "Descrizione",
"generatorImagesCategory": "Categoria",
"generatorImages_one": "{{count}} immagine",
"generatorImages_many": "{{count}} immagini",
"generatorImages_other": "{{count}} immagini",
"generatorImagesFromBoard": "Immagini dalla Bacheca",
"missingSourceOrTargetNode": "Nodo sorgente o di destinazione mancante"
},
"boards": {
"autoAddBoard": "Aggiungi automaticamente bacheca",
@@ -1140,7 +1160,10 @@
"cancelAllExceptCurrentQueueItemAlertDialog2": "Vuoi davvero annullare tutti gli elementi in coda in sospeso?",
"confirm": "Conferma",
"cancelAllExceptCurrentQueueItemAlertDialog": "L'annullamento di tutti gli elementi della coda, eccetto quello corrente, interromperà gli elementi in sospeso ma consentirà il completamento di quello in corso.",
"cancelAllExceptCurrentTooltip": "Annulla tutto tranne l'elemento corrente"
"cancelAllExceptCurrentTooltip": "Annulla tutto tranne l'elemento corrente",
"retrySucceeded": "Elemento rieseguito",
"retryItem": "Riesegui elemento",
"retryFailed": "Problema riesecuzione elemento"
},
"models": {
"noMatchingModels": "Nessun modello corrispondente",
@@ -1718,7 +1741,38 @@
"download": "Scarica",
"copyShareLink": "Copia Condividi Link",
"copyShareLinkForWorkflow": "Copia Condividi Link del Flusso di lavoro",
"delete": "Elimina"
"delete": "Elimina",
"openLibrary": "Apri la libreria",
"builder": {
"resetAllNodeFields": "Reimposta tutti i campi del nodo",
"row": "Riga",
"nodeField": "Campo del nodo",
"slider": "Cursore",
"emptyRootPlaceholderEditMode": "Per iniziare, trascina qui un elemento del modulo o un campo nodo.",
"containerPlaceholder": "Contenitore vuoto",
"headingPlaceholder": "Titolo vuoto",
"column": "Colonna",
"nodeFieldTooltip": "Per aggiungere un campo nodo, fare clic sul piccolo pulsante con il segno più sul campo nell'editor del flusso di lavoro, oppure trascinare il campo in base al suo nome nel modulo.",
"label": "Etichetta",
"deleteAllElements": "Elimina tutti gli elementi del modulo",
"addToForm": "Aggiungi al Modulo",
"layout": "Schema",
"builder": "Generatore Modulo",
"zoomToNode": "Zoom sul nodo",
"component": "Componente",
"showDescription": "Mostra Descrizione",
"singleLine": "Linea singola",
"multiLine": "Linea Multipla",
"both": "Entrambi",
"emptyRootPlaceholderViewMode": "Fare clic su Modifica per iniziare a creare un modulo per questo flusso di lavoro.",
"textPlaceholder": "Testo vuoto",
"workflowBuilderAlphaWarning": "Il generatore di flussi di lavoro è attualmente in versione alpha. Potrebbero esserci cambiamenti radicali prima della versione stabile.",
"heading": "Intestazione",
"divider": "Divisore",
"container": "Contenitore",
"text": "Testo",
"numberInput": "Ingresso numerico"
}
},
"accordions": {
"compositing": {
@@ -2140,7 +2194,7 @@
"controlLayerEmptyState": "<UploadButton>Carica un'immagine</UploadButton>, trascina un'immagine dalla <GalleryButton>galleria</GalleryButton> su questo livello oppure disegna sulla tela per iniziare.",
"useImage": "Usa immagine",
"resetGenerationSettings": "Ripristina impostazioni di generazione",
"referenceImageEmptyState": "Per iniziare, <UploadButton>carica un'immagine</UploadButton> oppure trascina un'immagine dalla <GalleryButton>galleria</GalleryButton> su questo livello.",
"referenceImageEmptyState": "Per iniziare, <UploadButton>carica un'immagine</UploadButton>, trascina un'immagine dalla <GalleryButton>galleria</GalleryButton>, oppure <PullBboxButton>trascina il riquadro di delimitazione in questo livello</PullBboxButton> su questo livello.",
"asRasterLayer": "Come $t(controlLayers.rasterLayer)",
"asRasterLayerResize": "Come $t(controlLayers.rasterLayer) (Ridimensiona)",
"asControlLayer": "Come $t(controlLayers.controlLayer)",
@@ -2276,8 +2330,8 @@
"watchRecentReleaseVideos": "Guarda i video su questa versione",
"watchUiUpdatesOverview": "Guarda le novità dell'interfaccia",
"items": [
"Impostazioni predefinite VRAM migliorate",
"Cancellazione della cache del modello su richiesta"
"Editor del flusso di lavoro: nuovo generatore di moduli trascina-e-rilascia per una creazione più facile del flusso di lavoro.",
"Altri miglioramenti: messa in coda dei lotti più rapida, migliore ampliamento, selettore colore migliorato e nodi metadati."
]
},
"system": {

View File

@@ -118,7 +118,8 @@
"compareHelp2": "Nhấn <Kbd>M</Kbd> để tuần hoàn trong chế độ so sánh.",
"boardsSettings": "Thiết Lập Bảng",
"imagesSettings": "Cài Đặt Ảnh Trong Thư Viện Ảnh",
"assets": "Tài Nguyên"
"assets": "Tài Nguyên",
"images": "Hình Ảnh"
},
"common": {
"ipAdapter": "IP Adapter",
@@ -233,7 +234,8 @@
"combinatorial": "Tổ Hợp",
"column": "Cột",
"layout": "Bố Cục",
"row": "Hàng"
"row": "Hàng",
"board": "Bảng"
},
"prompt": {
"addPromptTrigger": "Thêm Prompt Trigger",
@@ -873,7 +875,7 @@
"removeLinearView": "Xoá Khỏi Chế Độ Xem Tuyến Tính",
"unknownErrorValidatingWorkflow": "Lỗi không rõ khi xác thực workflow",
"unableToLoadWorkflow": "Không Thể Tải Workflow",
"workflowSettings": "Cài Đặt Trình Biên Tập Viên Workflow",
"workflowSettings": "Cài Đặt Biên Tập Workflow",
"workflowVersion": "Phiên Bản",
"unableToGetWorkflowVersion": "Không thể tìm phiên bản của lược đồ workflow",
"collection": "Đa tài nguyên",
@@ -1010,7 +1012,11 @@
"loadWorkflowDesc2": "Workflow hiện tại của bạn có những điều chỉnh chưa được lưu.",
"loadingTemplates": "Đang Tải {{name}}",
"nodeName": "Tên Node",
"unableToUpdateNode": "Cập nhật node thất bại: node {{node}} thuộc dạng {{type}} (có thể cần xóa và tạo lại)"
"unableToUpdateNode": "Cập nhật node thất bại: node {{node}} thuộc dạng {{type}} (có thể cần xóa và tạo lại)",
"downloadWorkflowError": "Lỗi tải xuống workflow",
"generatorImagesFromBoard": "Ảnh Từ Bảng",
"generatorImagesCategory": "Phân Loại",
"generatorImages_other": "{{count}} ảnh"
},
"popovers": {
"paramCFGRescaleMultiplier": {
@@ -1494,7 +1500,9 @@
"batchNodeCollectionSizeMismatch": "Kích cỡ tài nguyên không phù hợp với Lô {{batchGroupId}}",
"emptyBatches": "lô trống",
"batchNodeNotConnected": "Node Hàng Loạt chưa được kết nối: {{label}}",
"batchNodeEmptyCollection": "Một vài node hàng loạt có tài nguyên rỗng"
"batchNodeEmptyCollection": "Một vài node hàng loạt có tài nguyên rỗng",
"collectionEmpty": "tài nguyên trống",
"batchNodeCollectionSizeMismatchNoGroupId": "tài nguyên theo nhóm có kích thước sai lệch"
},
"cfgScale": "Thang CFG",
"useSeed": "Dùng Hạt Giống",
@@ -2159,7 +2167,7 @@
"parameterSetDesc": "Gợi lại {{parameter}}",
"loadedWithWarnings": "Đã Tải Workflow Với Cảnh Báo",
"outOfMemoryErrorDesc": "Thiết lập tạo sinh hiện tại đã vượt mức cho phép của thiết bị. Hãy điều chỉnh thiết lập và thử lại.",
"setNodeField": "Đặt làm vùng cho node",
"setNodeField": "Đặt làm vùng node",
"problemRetrievingWorkflow": "Có Vấn Đề Khi Lấy Lại Workflow",
"somethingWentWrong": "Có Vấn Đề Phát Sinh",
"problemDeletingWorkflow": "Có Vấn Đề Khi Xoá Workflow",
@@ -2247,11 +2255,11 @@
"workflowLibrary": "Thư Viện",
"opened": "Ngày Mở",
"deleteWorkflow": "Xoá Workflow",
"workflowEditorMenu": "Menu Biên Tập Viên Workflow",
"workflowEditorMenu": "Menu Biên Tập Workflow",
"uploadAndSaveWorkflow": "Tải Lên Thư Viện",
"openLibrary": "Mở Thư Viện",
"builder": {
"resetAllNodeFields": "Khởi Động Lại Tất Cả Vùng Cho Node",
"resetAllNodeFields": "Tải Lại Các Vùng Node",
"builder": "Trình Tạo Vùng Nhập",
"layout": "Bố Cục",
"row": "Hàng",
@@ -2266,15 +2274,19 @@
"slider": "Thanh Trượt",
"both": "Cả Hai",
"emptyRootPlaceholderViewMode": "Chọn Chỉnh Sửa để bắt đầu tạo nên một vùng nhập cho workflow này.",
"emptyRootPlaceholderEditMode": "Kéo thành phần vùng nhập hoặc vùng cho node vào đây để bắt đầu.",
"emptyRootPlaceholderEditMode": "Kéo thành phần vùng nhập hoặc vùng node vào đây để bắt đầu.",
"containerPlaceholder": "Hộp Chứa Trống",
"headingPlaceholder": "Đầu Dòng Trống",
"textPlaceholder": "Mô Tả Trống",
"textPlaceholder": "Văn Bản Trống",
"column": "Cột",
"deleteAllElements": "Xóa Tất Cả Thành Phần Vùng Nhập",
"nodeField": "Vùng Cho Node",
"nodeFieldTooltip": "Để thêm vùng cho node, bấm vào dấu cộng nhỏ trên vùng trong Vùng Biên Tập Workflow, hoặc kéo vùng theo tên của nó vào vùng nhập.",
"workflowBuilderAlphaWarning": "Trình tạo workflow đang trong giai đoạn alpha. Nó có thể xuất hiện những thay đổi đột ngột trước khi chính thức được phát hành."
"deleteAllElements": "Xóa Tất Cả Thành Phần",
"nodeField": "Vùng Node",
"nodeFieldTooltip": "Để thêm vùng node, bấm vào dấu cộng nhỏ trên vùng trong Trình Biên Tập Workflow, hoặc kéo vùng theo tên của nó vào vùng nhập.",
"workflowBuilderAlphaWarning": "Trình tạo vùng nhập đang trong giai đoạn alpha. Nó có thể xuất hiện những thay đổi đột ngột trước khi chính thức được phát hành.",
"container": "Hộp Chứa",
"heading": "Đầu Dòng",
"text": "Văn Bản",
"divider": "Gạch Chia"
}
},
"upscaling": {
@@ -2310,8 +2322,8 @@
"watchRecentReleaseVideos": "Xem Video Phát Hành Mới Nhất",
"watchUiUpdatesOverview": "Xem Tổng Quan Về Những Cập Nhật Cho Giao Diện Người Dùng",
"items": [
"Cải thiện các thiết lập mặc định của VRAM",
"Xoá bộ nhớ đệm của model theo yêu cầu"
"Trình Biên Tập Workflow: trình tạo vùng nhập dưới dạng kéo thả nhằm tạo dựng workflow dễ dàng hơn.",
"Các nâng cấp khác: Xếp hàng tạo sinh theo nhóm nhanh hơn, upscale tốt hơn, trình chọn màu được cải thiện, và node chứa metadata."
]
},
"upsell": {

View File

@@ -2,6 +2,7 @@ import { Button, Flex, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { usePullBboxIntoGlobalReferenceImage } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import type { SetGlobalReferenceImageDndTargetData } from 'features/dnd/dnd';
import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
@@ -27,6 +28,7 @@ export const IPAdapterSettingsEmptyState = memo(() => {
const onClickGalleryButton = useCallback(() => {
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}, [dispatch]);
const pullBboxIntoIPAdapter = usePullBboxIntoGlobalReferenceImage(entityIdentifier);
const dndTargetData = useMemo<SetGlobalReferenceImageDndTargetData>(
() => setGlobalReferenceImageDndTarget.getData({ entityIdentifier }),
@@ -41,8 +43,11 @@ export const IPAdapterSettingsEmptyState = memo(() => {
GalleryButton: (
<Button onClick={onClickGalleryButton} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
),
PullBboxButton: (
<Button onClick={pullBboxIntoIPAdapter} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
),
}),
[isBusy, onClickGalleryButton, uploadApi]
[isBusy, onClickGalleryButton, pullBboxIntoIPAdapter, uploadApi]
);
return (

View File

@@ -2,6 +2,7 @@ import { Button, Flex, IconButton, Spacer, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { usePullBboxIntoRegionalGuidanceReferenceImage } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { rgIPAdapterDeleted } from 'features/controlLayers/store/canvasSlice';
import type { SetRegionalGuidanceReferenceImageDndTargetData } from 'features/dnd/dnd';
@@ -36,6 +37,7 @@ export const RegionalGuidanceIPAdapterSettingsEmptyState = memo(({ referenceImag
const onDeleteIPAdapter = useCallback(() => {
dispatch(rgIPAdapterDeleted({ entityIdentifier, referenceImageId }));
}, [dispatch, entityIdentifier, referenceImageId]);
const pullBboxIntoIPAdapter = usePullBboxIntoRegionalGuidanceReferenceImage(entityIdentifier, referenceImageId);
const dndTargetData = useMemo<SetRegionalGuidanceReferenceImageDndTargetData>(
() =>
@@ -46,6 +48,21 @@ export const RegionalGuidanceIPAdapterSettingsEmptyState = memo(({ referenceImag
[entityIdentifier, referenceImageId]
);
const components = useMemo(
() => ({
UploadButton: (
<Button isDisabled={isBusy} size="sm" variant="link" color="base.300" {...uploadApi.getUploadButtonProps()} />
),
GalleryButton: (
<Button onClick={onClickGalleryButton} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
),
PullBboxButton: (
<Button onClick={pullBboxIntoIPAdapter} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
),
}),
[isBusy, onClickGalleryButton, pullBboxIntoIPAdapter, uploadApi]
);
return (
<Flex flexDir="column" gap={2} position="relative" w="full">
<Flex alignItems="center" gap={2}>
@@ -66,23 +83,7 @@ export const RegionalGuidanceIPAdapterSettingsEmptyState = memo(({ referenceImag
</Flex>
<Flex alignItems="center" gap={2} p={4}>
<Text textAlign="center" color="base.300">
<Trans
i18nKey="controlLayers.referenceImageEmptyState"
components={{
UploadButton: (
<Button
isDisabled={isBusy}
size="sm"
variant="link"
color="base.300"
{...uploadApi.getUploadButtonProps()}
/>
),
GalleryButton: (
<Button onClick={onClickGalleryButton} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
),
}}
/>
<Trans i18nKey="controlLayers.referenceImageEmptyState" components={components} />
</Text>
</Flex>
<input {...uploadApi.getUploadInputProps()} />

View File

@@ -14,8 +14,7 @@ import { $hasTemplates } from 'features/nodes/store/nodesSlice';
import { selectIsFormEmpty } from 'features/nodes/store/workflowSlice';
import type { FormElement } from 'features/nodes/types/workflow';
import { buildContainer, buildDivider, buildHeading, buildText } from 'features/nodes/types/workflow';
import { startCase } from 'lodash-es';
import type { RefObject } from 'react';
import type { PropsWithChildren, RefObject } from 'react';
import { memo, useEffect, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
@@ -39,10 +38,10 @@ export const WorkflowBuilder = memo(() => {
<Flex justifyContent="center" w="full" h="full">
<Flex flexDir="column" w="full" maxW="768px" gap={2}>
<Flex w="full" alignItems="center" gap={2} pt={3}>
<AddFormElementDndButton type="container" />
<AddFormElementDndButton type="divider" />
<AddFormElementDndButton type="heading" />
<AddFormElementDndButton type="text" />
<AddFormElementDndButton type="container">{t('workflows.builder.container')}</AddFormElementDndButton>
<AddFormElementDndButton type="divider">{t('workflows.builder.divider')}</AddFormElementDndButton>
<AddFormElementDndButton type="heading">{t('workflows.builder.heading')}</AddFormElementDndButton>
<AddFormElementDndButton type="text">{t('workflows.builder.text')}</AddFormElementDndButton>
<Button size="sm" variant="ghost" tooltip={t('workflows.builder.nodeFieldTooltip')}>
{t('workflows.builder.nodeField')}
</Button>
@@ -130,14 +129,17 @@ const addFormElementButtonSx: SystemStyleObject = {
_disabled: { borderStyle: 'dashed', opacity: 0.5 },
};
const AddFormElementDndButton = ({ type }: { type: Parameters<typeof useAddFormElementDnd>[0] }) => {
const AddFormElementDndButton = ({
type,
children,
}: PropsWithChildren<{ type: Parameters<typeof useAddFormElementDnd>[0] }>) => {
const draggableRef = useRef<HTMLDivElement>(null);
const isDragging = useAddFormElementDnd(type, draggableRef);
return (
// Must be as div for draggable to work correctly
<Button as="div" ref={draggableRef} size="sm" isDisabled={isDragging} variant="outline" sx={addFormElementButtonSx}>
{startCase(type)}
{children}
</Button>
);
};

View File

@@ -103,6 +103,7 @@ export const api = customCreateApi({
reducerPath: 'api',
tagTypes,
endpoints: () => ({}),
invalidationBehavior: 'immediately',
});
function getCircularReplacer() {

View File

@@ -1 +1 @@
__version__ = "5.7.2rc1"
__version__ = "5.7.2"

View File

@@ -101,8 +101,7 @@ dependencies = [
"xformers" = [
# Core generation dependencies, pinned for reproducible builds.
"xformers>=0.0.28.post1; sys_platform!='darwin'",
# Auxiliary dependencies, pinned only if necessary.
"triton; sys_platform=='linux'",
# torch 2.4+cu carries its own triton dependency
]
"onnx" = ["onnxruntime"]
"onnx-cuda" = ["onnxruntime-gpu"]

View File

@@ -1,13 +1,127 @@
import pytest
import torch
from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator
from tests.dangerously_run_function_in_subprocess import dangerously_run_function_in_subprocess
# These tests are a bit fiddly, because the depend on the import behaviour of torch. They use subprocesses to isolate
# the import behaviour of torch, and then check that the function behaves as expected. We have to hack in some logging
# to check that the tested function is behaving as expected.
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA device.")
def test_configure_torch_cuda_allocator_raises_if_torch_is_already_imported():
"""Test that configure_torch_cuda_allocator() raises a RuntimeError if torch is already imported."""
import torch # noqa: F401
def test_configure_torch_cuda_allocator_configures_backend():
"""Test that configure_torch_cuda_allocator() raises a RuntimeError if the configured backend does not match the
expected backend."""
with pytest.raises(RuntimeError, match="Failed to configure the PyTorch CUDA memory allocator."):
configure_torch_cuda_allocator("backend:cudaMallocAsync")
def test_func():
import os
# Unset the environment variable if it is set so that we can test setting it
try:
del os.environ["PYTORCH_CUDA_ALLOC_CONF"]
except KeyError:
pass
from unittest.mock import MagicMock
from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator
mock_logger = MagicMock()
# Set the PyTorch CUDA memory allocator to cudaMallocAsync
configure_torch_cuda_allocator("backend:cudaMallocAsync", logger=mock_logger)
# Verify that the PyTorch CUDA memory allocator was configured correctly
import torch
assert torch.cuda.get_allocator_backend() == "cudaMallocAsync"
# Verify that the logger was called with the correct message
mock_logger.info.assert_called_once()
args, _kwargs = mock_logger.info.call_args
logged_message = args[0]
print(logged_message)
stdout, _stderr, returncode = dangerously_run_function_in_subprocess(test_func)
assert returncode == 0
assert "PyTorch CUDA memory allocator: cudaMallocAsync" in stdout
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA device.")
def test_configure_torch_cuda_allocator_raises_if_torch_already_imported():
"""Test that configure_torch_cuda_allocator() raises a RuntimeError if torch was already imported."""
def test_func():
from unittest.mock import MagicMock
# Import torch before calling configure_torch_cuda_allocator()
import torch # noqa: F401
from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator
try:
configure_torch_cuda_allocator("backend:cudaMallocAsync", logger=MagicMock())
except RuntimeError as e:
print(e)
stdout, _stderr, returncode = dangerously_run_function_in_subprocess(test_func)
assert returncode == 0
assert "configure_torch_cuda_allocator() must be called before importing torch." in stdout
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA device.")
def test_configure_torch_cuda_allocator_warns_if_env_var_is_set_differently():
"""Test that configure_torch_cuda_allocator() logs at WARNING level if PYTORCH_CUDA_ALLOC_CONF is set and doesn't
match the requested configuration."""
def test_func():
import os
# Explicitly set the environment variable
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:native"
from unittest.mock import MagicMock
from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator
mock_logger = MagicMock()
# Set the PyTorch CUDA memory allocator a different configuration
configure_torch_cuda_allocator("backend:cudaMallocAsync", logger=mock_logger)
# Verify that the logger was called with the correct message
mock_logger.warning.assert_called_once()
args, _kwargs = mock_logger.warning.call_args
logged_message = args[0]
print(logged_message)
stdout, _stderr, returncode = dangerously_run_function_in_subprocess(test_func)
assert returncode == 0
assert "Attempted to configure the PyTorch CUDA memory allocator with 'backend:cudaMallocAsync'" in stdout
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA device.")
def test_configure_torch_cuda_allocator_logs_if_env_var_is_already_set_correctly():
"""Test that configure_torch_cuda_allocator() logs at INFO level if PYTORCH_CUDA_ALLOC_CONF is set and matches the
requested configuration."""
def test_func():
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:native"
from unittest.mock import MagicMock
from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator
mock_logger = MagicMock()
configure_torch_cuda_allocator("backend:native", logger=mock_logger)
mock_logger.info.assert_called_once()
args, _kwargs = mock_logger.info.call_args
logged_message = args[0]
print(logged_message)
stdout, _stderr, returncode = dangerously_run_function_in_subprocess(test_func)
assert returncode == 0
assert "PYTORCH_CUDA_ALLOC_CONF is already set to 'backend:native'" in stdout

View File

@@ -0,0 +1,46 @@
import inspect
import subprocess
import sys
import textwrap
from typing import Any, Callable
def dangerously_run_function_in_subprocess(func: Callable[[], Any]) -> tuple[str, str, int]:
"""**Use with caution! This should _only_ be used with trusted code!**
Extracts a function's source and runs it in a separate subprocess. Returns stdout, stderr, and return code
from the subprocess.
This is useful for tests where an isolated environment is required.
The function to be called must not have any arguments and must not have any closures over the scope in which is was
defined.
Any modules that the function depends on must be imported inside the function.
"""
source_code = inspect.getsource(func)
# Must dedent the source code to avoid indentation errors
dedented_source_code = textwrap.dedent(source_code)
# Get the function name so we can call it in the subprocess
func_name = func.__name__
# Create a script that calls the function
script = f"""
import sys
{dedented_source_code}
if __name__ == "__main__":
{func_name}()
"""
result = subprocess.run(
[sys.executable, "-c", textwrap.dedent(script)], # Run the script in a subprocess
capture_output=True, # Capture stdout and stderr
text=True,
)
return result.stdout, result.stderr, result.returncode

View File

@@ -0,0 +1,57 @@
from tests.dangerously_run_function_in_subprocess import dangerously_run_function_in_subprocess
def test_simple_function():
def test_func():
print("Hello, Test!")
stdout, stderr, returncode = dangerously_run_function_in_subprocess(test_func)
assert returncode == 0
assert stdout.strip() == "Hello, Test!"
assert stderr == ""
def test_function_with_error():
def test_func():
raise ValueError("This is an error")
_stdout, stderr, returncode = dangerously_run_function_in_subprocess(test_func)
assert returncode != 0 # Should fail
assert "ValueError: This is an error" in stderr
def test_function_with_imports():
def test_func():
import math
print(math.sqrt(4))
stdout, stderr, returncode = dangerously_run_function_in_subprocess(test_func)
assert returncode == 0
assert stdout.strip() == "2.0"
assert stderr == ""
def test_function_with_sys_exit():
def test_func():
import sys
sys.exit(42)
_stdout, _stderr, returncode = dangerously_run_function_in_subprocess(test_func)
assert returncode == 42 # Should return the custom exit code
def test_function_with_closure():
foo = "bar"
def test_func():
print(foo)
_stdout, _stderr, returncode = dangerously_run_function_in_subprocess(test_func)
assert returncode == 1 # Should fail because of closure