mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-16 19:28:18 -05:00
Compare commits
1 Commits
v6.0.1
...
psychedeli
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c1a4376b75 |
@@ -14,14 +14,15 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._conn = db.conn
|
||||
|
||||
def add_image_to_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO board_images (board_id, image_name)
|
||||
@@ -30,12 +31,17 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
""",
|
||||
(board_id, image_name, board_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM board_images
|
||||
@@ -43,6 +49,10 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
|
||||
def get_images_for_board(
|
||||
self,
|
||||
@@ -50,26 +60,27 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
FROM board_images
|
||||
INNER JOIN images ON board_images.image_name = images.image_name
|
||||
WHERE board_images.board_id = ?
|
||||
ORDER BY board_images.updated_at DESC;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
images = [deserialize_image_record(dict(r)) for r in result]
|
||||
# TODO: this isn't paginated yet?
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
FROM board_images
|
||||
INNER JOIN images ON board_images.image_name = images.image_name
|
||||
WHERE board_images.board_id = ?
|
||||
ORDER BY board_images.updated_at DESC;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
images = [deserialize_image_record(dict(r)) for r in result]
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*) FROM images WHERE 1=1;
|
||||
"""
|
||||
)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*) FROM images WHERE 1=1;
|
||||
"""
|
||||
)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
|
||||
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
|
||||
|
||||
@@ -79,55 +90,56 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
categories: list[ImageCategory] | None,
|
||||
is_intermediate: bool | None,
|
||||
) -> list[str]:
|
||||
with self._db.transaction() as cursor:
|
||||
params: list[str | bool] = []
|
||||
params: list[str | bool] = []
|
||||
|
||||
# Base query is a join between images and board_images
|
||||
stmt = """
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
# Base query is a join between images and board_images
|
||||
stmt = """
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
# Handle board_id filter
|
||||
if board_id == "none":
|
||||
stmt += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
else:
|
||||
stmt += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
params.append(board_id)
|
||||
# Handle board_id filter
|
||||
if board_id == "none":
|
||||
stmt += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
else:
|
||||
stmt += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
params.append(board_id)
|
||||
|
||||
# Add the category filter
|
||||
if categories is not None:
|
||||
# Convert the enum values to unique list of strings
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
stmt += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
# Add the category filter
|
||||
if categories is not None:
|
||||
# Convert the enum values to unique list of strings
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
stmt += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
|
||||
# Unpack the included categories into the query params
|
||||
for c in category_strings:
|
||||
params.append(c)
|
||||
# Unpack the included categories into the query params
|
||||
for c in category_strings:
|
||||
params.append(c)
|
||||
|
||||
# Add the is_intermediate filter
|
||||
if is_intermediate is not None:
|
||||
stmt += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
params.append(is_intermediate)
|
||||
# Add the is_intermediate filter
|
||||
if is_intermediate is not None:
|
||||
stmt += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
params.append(is_intermediate)
|
||||
|
||||
# Put a ring on it
|
||||
stmt += ";"
|
||||
# Put a ring on it
|
||||
stmt += ";"
|
||||
|
||||
cursor.execute(stmt, params)
|
||||
# Execute the query
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(stmt, params)
|
||||
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
image_names = [r[0] for r in result]
|
||||
return image_names
|
||||
|
||||
@@ -135,31 +147,31 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Optional[str]:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT board_id
|
||||
FROM board_images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT board_id
|
||||
FROM board_images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
if result is None:
|
||||
return None
|
||||
return cast(str, result[0])
|
||||
|
||||
def get_image_count_for_board(self, board_id: str) -> int:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM board_images
|
||||
INNER JOIN images ON board_images.image_name = images.image_name
|
||||
WHERE images.is_intermediate = FALSE
|
||||
AND board_images.board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM board_images
|
||||
INNER JOIN images ON board_images.image_name = images.image_name
|
||||
WHERE images.is_intermediate = FALSE
|
||||
AND board_images.board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
return count
|
||||
|
||||
@@ -20,57 +20,61 @@ from invokeai.app.util.misc import uuid_string
|
||||
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._conn = db.conn
|
||||
|
||||
def delete(self, board_id: str) -> None:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM boards
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
except Exception as e:
|
||||
raise BoardRecordDeleteException from e
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM boards
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordDeleteException from e
|
||||
|
||||
def save(
|
||||
self,
|
||||
board_name: str,
|
||||
) -> BoardRecord:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
board_id = uuid_string()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO boards (board_id, board_name)
|
||||
VALUES (?, ?);
|
||||
""",
|
||||
(board_id, board_name),
|
||||
)
|
||||
except sqlite3.Error as e:
|
||||
raise BoardRecordSaveException from e
|
||||
try:
|
||||
board_id = uuid_string()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO boards (board_id, board_name)
|
||||
VALUES (?, ?);
|
||||
""",
|
||||
(board_id, board_name),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordSaveException from e
|
||||
return self.get(board_id)
|
||||
|
||||
def get(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> BoardRecord:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM boards
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM boards
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
raise BoardRecordNotFoundException from e
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
raise BoardRecordNotFoundException from e
|
||||
if result is None:
|
||||
raise BoardRecordNotFoundException
|
||||
return BoardRecord(**dict(result))
|
||||
@@ -80,43 +84,45 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
board_id: str,
|
||||
changes: BoardChanges,
|
||||
) -> BoardRecord:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
# Change the name of a board
|
||||
if changes.board_name is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE boards
|
||||
SET board_name = ?
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(changes.board_name, board_id),
|
||||
)
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
# Change the name of a board
|
||||
if changes.board_name is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE boards
|
||||
SET board_name = ?
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(changes.board_name, board_id),
|
||||
)
|
||||
|
||||
# Change the cover image of a board
|
||||
if changes.cover_image_name is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE boards
|
||||
SET cover_image_name = ?
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(changes.cover_image_name, board_id),
|
||||
)
|
||||
# Change the cover image of a board
|
||||
if changes.cover_image_name is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE boards
|
||||
SET cover_image_name = ?
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(changes.cover_image_name, board_id),
|
||||
)
|
||||
|
||||
# Change the archived status of a board
|
||||
if changes.archived is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE boards
|
||||
SET archived = ?
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(changes.archived, board_id),
|
||||
)
|
||||
# Change the archived status of a board
|
||||
if changes.archived is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE boards
|
||||
SET archived = ?
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(changes.archived, board_id),
|
||||
)
|
||||
|
||||
except sqlite3.Error as e:
|
||||
raise BoardRecordSaveException from e
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordSaveException from e
|
||||
return self.get(board_id)
|
||||
|
||||
def get_many(
|
||||
@@ -127,77 +133,78 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
limit: int = 10,
|
||||
include_archived: bool = False,
|
||||
) -> OffsetPaginatedResults[BoardRecord]:
|
||||
with self._db.transaction() as cursor:
|
||||
# Build base query
|
||||
base_query = """
|
||||
SELECT *
|
||||
cursor = self._conn.cursor()
|
||||
|
||||
# Build base query
|
||||
base_query = """
|
||||
SELECT *
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
ORDER BY {order_by} {direction}
|
||||
LIMIT ? OFFSET ?;
|
||||
"""
|
||||
|
||||
# Determine archived filter condition
|
||||
archived_filter = "" if include_archived else "WHERE archived = 0"
|
||||
|
||||
final_query = base_query.format(
|
||||
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
|
||||
)
|
||||
|
||||
# Execute query to fetch boards
|
||||
cursor.execute(final_query, (limit, offset))
|
||||
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
boards = [deserialize_board_record(dict(r)) for r in result]
|
||||
|
||||
# Determine count query
|
||||
if include_archived:
|
||||
count_query = """
|
||||
SELECT COUNT(*)
|
||||
FROM boards;
|
||||
"""
|
||||
else:
|
||||
count_query = """
|
||||
SELECT COUNT(*)
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
ORDER BY {order_by} {direction}
|
||||
LIMIT ? OFFSET ?;
|
||||
WHERE archived = 0;
|
||||
"""
|
||||
|
||||
# Determine archived filter condition
|
||||
archived_filter = "" if include_archived else "WHERE archived = 0"
|
||||
# Execute count query
|
||||
cursor.execute(count_query)
|
||||
|
||||
final_query = base_query.format(
|
||||
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
|
||||
)
|
||||
|
||||
# Execute query to fetch boards
|
||||
cursor.execute(final_query, (limit, offset))
|
||||
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
boards = [deserialize_board_record(dict(r)) for r in result]
|
||||
|
||||
# Determine count query
|
||||
if include_archived:
|
||||
count_query = """
|
||||
SELECT COUNT(*)
|
||||
FROM boards;
|
||||
"""
|
||||
else:
|
||||
count_query = """
|
||||
SELECT COUNT(*)
|
||||
FROM boards
|
||||
WHERE archived = 0;
|
||||
"""
|
||||
|
||||
# Execute count query
|
||||
cursor.execute(count_query)
|
||||
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
|
||||
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
|
||||
|
||||
def get_all(
|
||||
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
|
||||
) -> list[BoardRecord]:
|
||||
with self._db.transaction() as cursor:
|
||||
if order_by == BoardRecordOrderBy.Name:
|
||||
base_query = """
|
||||
SELECT *
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
ORDER BY LOWER(board_name) {direction}
|
||||
"""
|
||||
else:
|
||||
base_query = """
|
||||
SELECT *
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
ORDER BY {order_by} {direction}
|
||||
"""
|
||||
cursor = self._conn.cursor()
|
||||
if order_by == BoardRecordOrderBy.Name:
|
||||
base_query = """
|
||||
SELECT *
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
ORDER BY LOWER(board_name) {direction}
|
||||
"""
|
||||
else:
|
||||
base_query = """
|
||||
SELECT *
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
ORDER BY {order_by} {direction}
|
||||
"""
|
||||
|
||||
archived_filter = "" if include_archived else "WHERE archived = 0"
|
||||
archived_filter = "" if include_archived else "WHERE archived = 0"
|
||||
|
||||
final_query = base_query.format(
|
||||
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
|
||||
)
|
||||
final_query = base_query.format(
|
||||
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
|
||||
)
|
||||
|
||||
cursor.execute(final_query)
|
||||
cursor.execute(final_query)
|
||||
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
boards = [deserialize_board_record(dict(r)) for r in result]
|
||||
|
||||
return boards
|
||||
|
||||
@@ -8,7 +8,6 @@ import time
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from queue import Empty, PriorityQueue
|
||||
from shutil import disk_usage
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set
|
||||
|
||||
import requests
|
||||
@@ -336,14 +335,6 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
|
||||
assert job.download_path
|
||||
|
||||
free_space = disk_usage(job.download_path.parent).free
|
||||
GB = 2**30
|
||||
self._logger.debug(f"Download is {job.total_bytes / GB:.2f} GB of {free_space / GB:.2f} GB free.")
|
||||
if free_space < job.total_bytes:
|
||||
raise RuntimeError(
|
||||
f"Free disk space {free_space / GB:.2f} GB is not enough for download of {job.total_bytes / GB:.2f} GB."
|
||||
)
|
||||
|
||||
# Don't clobber an existing file. See commit 82c2c85202f88c6d24ff84710f297cfc6ae174af
|
||||
# for code that instead resumes an interrupted download.
|
||||
if job.download_path.exists():
|
||||
|
||||
@@ -24,22 +24,22 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._conn = db.conn
|
||||
|
||||
def get(self, image_name: str) -> ImageRecord:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT {IMAGE_DTO_COLS} FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT {IMAGE_DTO_COLS} FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
result = cast(Optional[sqlite3.Row], cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
raise ImageRecordNotFoundException from e
|
||||
result = cast(Optional[sqlite3.Row], cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
raise ImageRecordNotFoundException from e
|
||||
|
||||
if not result:
|
||||
raise ImageRecordNotFoundException
|
||||
@@ -47,20 +47,17 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
return deserialize_image_record(dict(result))
|
||||
|
||||
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
result = cast(Optional[sqlite3.Row], cursor.fetchone())
|
||||
|
||||
except sqlite3.Error as e:
|
||||
raise ImageRecordNotFoundException from e
|
||||
result = cast(Optional[sqlite3.Row], cursor.fetchone())
|
||||
|
||||
if not result:
|
||||
raise ImageRecordNotFoundException
|
||||
@@ -68,60 +65,64 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
as_dict = dict(result)
|
||||
metadata_raw = cast(Optional[str], as_dict.get("metadata", None))
|
||||
return MetadataFieldValidator.validate_json(metadata_raw) if metadata_raw is not None else None
|
||||
except sqlite3.Error as e:
|
||||
raise ImageRecordNotFoundException from e
|
||||
|
||||
def update(
|
||||
self,
|
||||
image_name: str,
|
||||
changes: ImageRecordChanges,
|
||||
) -> None:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
# Change the category of the image
|
||||
if changes.image_category is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET image_category = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.image_category, image_name),
|
||||
)
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
# Change the category of the image
|
||||
if changes.image_category is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET image_category = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.image_category, image_name),
|
||||
)
|
||||
|
||||
# Change the session associated with the image
|
||||
if changes.session_id is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET session_id = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.session_id, image_name),
|
||||
)
|
||||
# Change the session associated with the image
|
||||
if changes.session_id is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET session_id = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.session_id, image_name),
|
||||
)
|
||||
|
||||
# Change the image's `is_intermediate`` flag
|
||||
if changes.is_intermediate is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET is_intermediate = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.is_intermediate, image_name),
|
||||
)
|
||||
# Change the image's `is_intermediate`` flag
|
||||
if changes.is_intermediate is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET is_intermediate = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.is_intermediate, image_name),
|
||||
)
|
||||
|
||||
# Change the image's `starred`` state
|
||||
if changes.starred is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET starred = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.starred, image_name),
|
||||
)
|
||||
# Change the image's `starred`` state
|
||||
if changes.starred is not None:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET starred = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.starred, image_name),
|
||||
)
|
||||
|
||||
except sqlite3.Error as e:
|
||||
raise ImageRecordSaveException from e
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordSaveException from e
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
@@ -135,162 +136,170 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
with self._db.transaction() as cursor:
|
||||
# Manually build two queries - one for the count, one for the records
|
||||
count_query = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
cursor = self._conn.cursor()
|
||||
|
||||
# Manually build two queries - one for the count, one for the records
|
||||
count_query = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
images_query = f"""--sql
|
||||
SELECT {IMAGE_DTO_COLS}
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
query_conditions = ""
|
||||
query_params: list[Union[int, str, bool]] = []
|
||||
|
||||
if image_origin is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.image_origin = ?
|
||||
"""
|
||||
query_params.append(image_origin.value)
|
||||
|
||||
if categories is not None:
|
||||
# Convert the enum values to unique list of strings
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
|
||||
query_conditions += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
|
||||
images_query = f"""--sql
|
||||
SELECT {IMAGE_DTO_COLS}
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
# Unpack the included categories into the query params
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
|
||||
if is_intermediate is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
|
||||
query_conditions = ""
|
||||
query_params: list[Union[int, str, bool]] = []
|
||||
query_params.append(is_intermediate)
|
||||
|
||||
if image_origin is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.image_origin = ?
|
||||
"""
|
||||
query_params.append(image_origin.value)
|
||||
# board_id of "none" is reserved for images without a board
|
||||
if board_id == "none":
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
elif board_id is not None:
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
query_params.append(board_id)
|
||||
|
||||
if categories is not None:
|
||||
# Convert the enum values to unique list of strings
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
# Search term condition
|
||||
if search_term:
|
||||
query_conditions += """--sql
|
||||
AND (
|
||||
images.metadata LIKE ?
|
||||
OR images.created_at LIKE ?
|
||||
)
|
||||
"""
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
|
||||
query_conditions += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
if starred_first:
|
||||
query_pagination = f"""--sql
|
||||
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
|
||||
"""
|
||||
else:
|
||||
query_pagination = f"""--sql
|
||||
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
|
||||
"""
|
||||
|
||||
# Unpack the included categories into the query params
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
# Final images query with pagination
|
||||
images_query += query_conditions + query_pagination + ";"
|
||||
# Add all the parameters
|
||||
images_params = query_params.copy()
|
||||
# Add the pagination parameters
|
||||
images_params.extend([limit, offset])
|
||||
|
||||
if is_intermediate is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
# Build the list of images, deserializing each row
|
||||
cursor.execute(images_query, images_params)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
images = [deserialize_image_record(dict(r)) for r in result]
|
||||
|
||||
query_params.append(is_intermediate)
|
||||
|
||||
# board_id of "none" is reserved for images without a board
|
||||
if board_id == "none":
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
elif board_id is not None:
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
query_params.append(board_id)
|
||||
|
||||
# Search term condition
|
||||
if search_term:
|
||||
query_conditions += """--sql
|
||||
AND (
|
||||
images.metadata LIKE ?
|
||||
OR images.created_at LIKE ?
|
||||
)
|
||||
"""
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
|
||||
if starred_first:
|
||||
query_pagination = f"""--sql
|
||||
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
|
||||
"""
|
||||
else:
|
||||
query_pagination = f"""--sql
|
||||
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
|
||||
"""
|
||||
|
||||
# Final images query with pagination
|
||||
images_query += query_conditions + query_pagination + ";"
|
||||
# Add all the parameters
|
||||
images_params = query_params.copy()
|
||||
# Add the pagination parameters
|
||||
images_params.extend([limit, offset])
|
||||
|
||||
# Build the list of images, deserializing each row
|
||||
cursor.execute(images_query, images_params)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
|
||||
images = [deserialize_image_record(dict(r)) for r in result]
|
||||
|
||||
# Set up and execute the count query, without pagination
|
||||
count_query += query_conditions + ";"
|
||||
count_params = query_params.copy()
|
||||
cursor.execute(count_query, count_params)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
# Set up and execute the count query, without pagination
|
||||
count_query += query_conditions + ";"
|
||||
count_params = query_params.copy()
|
||||
cursor.execute(count_query, count_params)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
|
||||
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
|
||||
|
||||
def delete(self, image_name: str) -> None:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
except sqlite3.Error as e:
|
||||
raise ImageRecordDeleteException from e
|
||||
|
||||
def delete_many(self, image_names: list[str]) -> None:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
placeholders = ",".join("?" for _ in image_names)
|
||||
|
||||
# Construct the SQLite query with the placeholders
|
||||
query = f"DELETE FROM images WHERE image_name IN ({placeholders})"
|
||||
|
||||
# Execute the query with the list of IDs as parameters
|
||||
cursor.execute(query, image_names)
|
||||
|
||||
except sqlite3.Error as e:
|
||||
raise ImageRecordDeleteException from e
|
||||
|
||||
def get_intermediates_count(self) -> int:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*) FROM images
|
||||
WHERE is_intermediate = TRUE;
|
||||
"""
|
||||
DELETE FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordDeleteException from e
|
||||
|
||||
def delete_many(self, image_names: list[str]) -> None:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
|
||||
placeholders = ",".join("?" for _ in image_names)
|
||||
|
||||
# Construct the SQLite query with the placeholders
|
||||
query = f"DELETE FROM images WHERE image_name IN ({placeholders})"
|
||||
|
||||
# Execute the query with the list of IDs as parameters
|
||||
cursor.execute(query, image_names)
|
||||
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordDeleteException from e
|
||||
|
||||
def get_intermediates_count(self) -> int:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*) FROM images
|
||||
WHERE is_intermediate = TRUE;
|
||||
"""
|
||||
)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
self._conn.commit()
|
||||
return count
|
||||
|
||||
def delete_intermediates(self) -> list[str]:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT image_name FROM images
|
||||
WHERE is_intermediate = TRUE;
|
||||
"""
|
||||
)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
image_names = [r[0] for r in result]
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM images
|
||||
WHERE is_intermediate = TRUE;
|
||||
"""
|
||||
)
|
||||
except sqlite3.Error as e:
|
||||
raise ImageRecordDeleteException from e
|
||||
return image_names
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT image_name FROM images
|
||||
WHERE is_intermediate = TRUE;
|
||||
"""
|
||||
)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
image_names = [r[0] for r in result]
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM images
|
||||
WHERE is_intermediate = TRUE;
|
||||
"""
|
||||
)
|
||||
self._conn.commit()
|
||||
return image_names
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordDeleteException from e
|
||||
|
||||
def save(
|
||||
self,
|
||||
@@ -306,71 +315,73 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
node_id: Optional[str] = None,
|
||||
metadata: Optional[str] = None,
|
||||
) -> datetime:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO images (
|
||||
image_name,
|
||||
image_origin,
|
||||
image_category,
|
||||
width,
|
||||
height,
|
||||
node_id,
|
||||
session_id,
|
||||
metadata,
|
||||
is_intermediate,
|
||||
starred,
|
||||
has_workflow
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
image_name,
|
||||
image_origin.value,
|
||||
image_category.value,
|
||||
width,
|
||||
height,
|
||||
node_id,
|
||||
session_id,
|
||||
metadata,
|
||||
is_intermediate,
|
||||
starred,
|
||||
has_workflow,
|
||||
),
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT created_at
|
||||
FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
created_at = datetime.fromisoformat(cursor.fetchone()[0])
|
||||
|
||||
except sqlite3.Error as e:
|
||||
raise ImageRecordSaveException from e
|
||||
return created_at
|
||||
|
||||
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
FROM images
|
||||
JOIN board_images ON images.image_name = board_images.image_name
|
||||
WHERE board_images.board_id = ?
|
||||
AND images.is_intermediate = FALSE
|
||||
ORDER BY images.starred DESC, images.created_at DESC
|
||||
LIMIT 1;
|
||||
INSERT OR IGNORE INTO images (
|
||||
image_name,
|
||||
image_origin,
|
||||
image_category,
|
||||
width,
|
||||
height,
|
||||
node_id,
|
||||
session_id,
|
||||
metadata,
|
||||
is_intermediate,
|
||||
starred,
|
||||
has_workflow
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||
""",
|
||||
(board_id,),
|
||||
(
|
||||
image_name,
|
||||
image_origin.value,
|
||||
image_category.value,
|
||||
width,
|
||||
height,
|
||||
node_id,
|
||||
session_id,
|
||||
metadata,
|
||||
is_intermediate,
|
||||
starred,
|
||||
has_workflow,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT created_at
|
||||
FROM images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
result = cast(Optional[sqlite3.Row], cursor.fetchone())
|
||||
created_at = datetime.fromisoformat(cursor.fetchone()[0])
|
||||
|
||||
return created_at
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordSaveException from e
|
||||
|
||||
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
FROM images
|
||||
JOIN board_images ON images.image_name = board_images.image_name
|
||||
WHERE board_images.board_id = ?
|
||||
AND images.is_intermediate = FALSE
|
||||
ORDER BY images.starred DESC, images.created_at DESC
|
||||
LIMIT 1;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
|
||||
result = cast(Optional[sqlite3.Row], cursor.fetchone())
|
||||
|
||||
if result is None:
|
||||
return None
|
||||
@@ -387,84 +398,85 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> ImageNamesResult:
|
||||
with self._db.transaction() as cursor:
|
||||
# Build query conditions (reused for both starred count and image names queries)
|
||||
query_conditions = ""
|
||||
query_params: list[Union[int, str, bool]] = []
|
||||
cursor = self._conn.cursor()
|
||||
|
||||
if image_origin is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.image_origin = ?
|
||||
"""
|
||||
query_params.append(image_origin.value)
|
||||
# Build query conditions (reused for both starred count and image names queries)
|
||||
query_conditions = ""
|
||||
query_params: list[Union[int, str, bool]] = []
|
||||
|
||||
if categories is not None:
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
query_conditions += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
if image_origin is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.image_origin = ?
|
||||
"""
|
||||
query_params.append(image_origin.value)
|
||||
|
||||
if is_intermediate is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
query_params.append(is_intermediate)
|
||||
if categories is not None:
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
query_conditions += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
|
||||
if board_id == "none":
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
elif board_id is not None:
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
query_params.append(board_id)
|
||||
if is_intermediate is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
query_params.append(is_intermediate)
|
||||
|
||||
if search_term:
|
||||
query_conditions += """--sql
|
||||
AND (
|
||||
images.metadata LIKE ?
|
||||
OR images.created_at LIKE ?
|
||||
)
|
||||
"""
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
if board_id == "none":
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
elif board_id is not None:
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
query_params.append(board_id)
|
||||
|
||||
# Get starred count if starred_first is enabled
|
||||
starred_count = 0
|
||||
if starred_first:
|
||||
starred_count_query = f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE images.starred = TRUE AND (1=1{query_conditions})
|
||||
"""
|
||||
cursor.execute(starred_count_query, query_params)
|
||||
starred_count = cast(int, cursor.fetchone()[0])
|
||||
if search_term:
|
||||
query_conditions += """--sql
|
||||
AND (
|
||||
images.metadata LIKE ?
|
||||
OR images.created_at LIKE ?
|
||||
)
|
||||
"""
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
|
||||
# Get all image names with proper ordering
|
||||
if starred_first:
|
||||
names_query = f"""--sql
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1{query_conditions}
|
||||
ORDER BY images.starred DESC, images.created_at {order_dir.value}
|
||||
"""
|
||||
else:
|
||||
names_query = f"""--sql
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1{query_conditions}
|
||||
ORDER BY images.created_at {order_dir.value}
|
||||
"""
|
||||
# Get starred count if starred_first is enabled
|
||||
starred_count = 0
|
||||
if starred_first:
|
||||
starred_count_query = f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE images.starred = TRUE AND (1=1{query_conditions})
|
||||
"""
|
||||
cursor.execute(starred_count_query, query_params)
|
||||
starred_count = cast(int, cursor.fetchone()[0])
|
||||
|
||||
cursor.execute(names_query, query_params)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
# Get all image names with proper ordering
|
||||
if starred_first:
|
||||
names_query = f"""--sql
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1{query_conditions}
|
||||
ORDER BY images.starred DESC, images.created_at {order_dir.value}
|
||||
"""
|
||||
else:
|
||||
names_query = f"""--sql
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1{query_conditions}
|
||||
ORDER BY images.created_at {order_dir.value}
|
||||
"""
|
||||
|
||||
cursor.execute(names_query, query_params)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
image_names = [row[0] for row in result]
|
||||
|
||||
return ImageNamesResult(image_names=image_names, starred_count=starred_count, total_count=len(image_names))
|
||||
|
||||
@@ -78,6 +78,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
self._db = db
|
||||
self._logger = logger
|
||||
|
||||
@property
|
||||
def db(self) -> SqliteDatabase:
|
||||
"""Return the underlying database."""
|
||||
return self._db
|
||||
|
||||
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model to the database.
|
||||
@@ -88,33 +93,38 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
|
||||
"""
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO models (
|
||||
id,
|
||||
config
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
config.key,
|
||||
config.model_dump_json(),
|
||||
),
|
||||
)
|
||||
try:
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO models (
|
||||
id,
|
||||
config
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
config.key,
|
||||
config.model_dump_json(),
|
||||
),
|
||||
)
|
||||
self._db.conn.commit()
|
||||
|
||||
except sqlite3.IntegrityError as e:
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
if "models.path" in str(e):
|
||||
msg = f"A model with path '{config.path}' is already installed"
|
||||
elif "models.name" in str(e):
|
||||
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
|
||||
else:
|
||||
msg = f"A model with key '{config.key}' is already installed"
|
||||
raise DuplicateModelException(msg) from e
|
||||
except sqlite3.IntegrityError as e:
|
||||
self._db.conn.rollback()
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
if "models.path" in str(e):
|
||||
msg = f"A model with path '{config.path}' is already installed"
|
||||
elif "models.name" in str(e):
|
||||
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
|
||||
else:
|
||||
raise e
|
||||
msg = f"A model with key '{config.key}' is already installed"
|
||||
raise DuplicateModelException(msg) from e
|
||||
else:
|
||||
raise e
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_model(config.key)
|
||||
|
||||
@@ -126,7 +136,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
Can raise an UnknownModelException
|
||||
"""
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM models
|
||||
@@ -136,17 +147,22 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
)
|
||||
if cursor.rowcount == 0:
|
||||
raise UnknownModelException("model not found")
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
|
||||
with self._db.transaction() as cursor:
|
||||
record = self.get_model(key)
|
||||
record = self.get_model(key)
|
||||
|
||||
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
|
||||
for field_name in changes.model_fields_set:
|
||||
setattr(record, field_name, getattr(changes, field_name))
|
||||
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
|
||||
for field_name in changes.model_fields_set:
|
||||
setattr(record, field_name, getattr(changes, field_name))
|
||||
|
||||
json_serialized = record.model_dump_json()
|
||||
json_serialized = record.model_dump_json()
|
||||
|
||||
try:
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE models
|
||||
@@ -158,6 +174,10 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
)
|
||||
if cursor.rowcount == 0:
|
||||
raise UnknownModelException("model not found")
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_model(key)
|
||||
|
||||
@@ -169,30 +189,30 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
Exceptions: UnknownModelException
|
||||
"""
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
rows = cursor.fetchone()
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
rows = cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownModelException("model not found")
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||
return model
|
||||
|
||||
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
rows = cursor.fetchone()
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
rows = cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownModelException("model not found")
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||
@@ -204,15 +224,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
:param key: Unique key for the model to be deleted
|
||||
"""
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
select count(*) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
count = cursor.fetchone()[0]
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
select count(*) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
count = cursor.fetchone()[0]
|
||||
return count > 0
|
||||
|
||||
def search_by_attr(
|
||||
@@ -235,42 +255,43 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
If none of the optional filters are passed, will return all
|
||||
models in the database.
|
||||
"""
|
||||
with self._db.transaction() as cursor:
|
||||
assert isinstance(order_by, ModelRecordOrderBy)
|
||||
ordering = {
|
||||
ModelRecordOrderBy.Default: "type, base, name, format",
|
||||
ModelRecordOrderBy.Type: "type",
|
||||
ModelRecordOrderBy.Base: "base",
|
||||
ModelRecordOrderBy.Name: "name",
|
||||
ModelRecordOrderBy.Format: "format",
|
||||
}
|
||||
|
||||
where_clause: list[str] = []
|
||||
bindings: list[str] = []
|
||||
if model_name:
|
||||
where_clause.append("name=?")
|
||||
bindings.append(model_name)
|
||||
if base_model:
|
||||
where_clause.append("base=?")
|
||||
bindings.append(base_model)
|
||||
if model_type:
|
||||
where_clause.append("type=?")
|
||||
bindings.append(model_type)
|
||||
if model_format:
|
||||
where_clause.append("format=?")
|
||||
bindings.append(model_format)
|
||||
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
|
||||
assert isinstance(order_by, ModelRecordOrderBy)
|
||||
ordering = {
|
||||
ModelRecordOrderBy.Default: "type, base, name, format",
|
||||
ModelRecordOrderBy.Type: "type",
|
||||
ModelRecordOrderBy.Base: "base",
|
||||
ModelRecordOrderBy.Name: "name",
|
||||
ModelRecordOrderBy.Format: "format",
|
||||
}
|
||||
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config, strftime('%s',updated_at)
|
||||
FROM models
|
||||
{where}
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
|
||||
""",
|
||||
tuple(bindings),
|
||||
)
|
||||
result = cursor.fetchall()
|
||||
where_clause: list[str] = []
|
||||
bindings: list[str] = []
|
||||
if model_name:
|
||||
where_clause.append("name=?")
|
||||
bindings.append(model_name)
|
||||
if base_model:
|
||||
where_clause.append("base=?")
|
||||
bindings.append(base_model)
|
||||
if model_type:
|
||||
where_clause.append("type=?")
|
||||
bindings.append(model_type)
|
||||
if model_format:
|
||||
where_clause.append("format=?")
|
||||
bindings.append(model_format)
|
||||
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
|
||||
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config, strftime('%s',updated_at)
|
||||
FROM models
|
||||
{where}
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
|
||||
""",
|
||||
tuple(bindings),
|
||||
)
|
||||
result = cursor.fetchall()
|
||||
|
||||
# Parse the model configs.
|
||||
results: list[AnyModelConfig] = []
|
||||
@@ -292,68 +313,69 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
|
||||
"""Return models with the indicated path."""
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE path=?;
|
||||
""",
|
||||
(str(path),),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE path=?;
|
||||
""",
|
||||
(str(path),),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
|
||||
return results
|
||||
|
||||
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
|
||||
"""Return models with the indicated hash."""
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
|
||||
return results
|
||||
|
||||
def list_models(
|
||||
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
||||
) -> PaginatedResults[ModelSummary]:
|
||||
"""Return a paginated summary listing of each model in the database."""
|
||||
with self._db.transaction() as cursor:
|
||||
assert isinstance(order_by, ModelRecordOrderBy)
|
||||
ordering = {
|
||||
ModelRecordOrderBy.Default: "type, base, name, format",
|
||||
ModelRecordOrderBy.Type: "type",
|
||||
ModelRecordOrderBy.Base: "base",
|
||||
ModelRecordOrderBy.Name: "name",
|
||||
ModelRecordOrderBy.Format: "format",
|
||||
}
|
||||
assert isinstance(order_by, ModelRecordOrderBy)
|
||||
ordering = {
|
||||
ModelRecordOrderBy.Default: "type, base, name, format",
|
||||
ModelRecordOrderBy.Type: "type",
|
||||
ModelRecordOrderBy.Base: "base",
|
||||
ModelRecordOrderBy.Name: "name",
|
||||
ModelRecordOrderBy.Format: "format",
|
||||
}
|
||||
|
||||
# Lock so that the database isn't updated while we're doing the two queries.
|
||||
# query1: get the total number of model configs
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
select count(*) from models;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
total = int(cursor.fetchone()[0])
|
||||
cursor = self._db.conn.cursor()
|
||||
|
||||
# query2: fetch key fields
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config
|
||||
FROM models
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
|
||||
LIMIT ?
|
||||
OFFSET ?;
|
||||
""",
|
||||
(
|
||||
per_page,
|
||||
page * per_page,
|
||||
),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
# Lock so that the database isn't updated while we're doing the two queries.
|
||||
# query1: get the total number of model configs
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
select count(*) from models;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
total = int(cursor.fetchone()[0])
|
||||
|
||||
# query2: fetch key fields
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config
|
||||
FROM models
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
|
||||
LIMIT ?
|
||||
OFFSET ?;
|
||||
""",
|
||||
(
|
||||
per_page,
|
||||
page * per_page,
|
||||
),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
items = [ModelSummary.model_validate(dict(x)) for x in rows]
|
||||
return PaginatedResults(page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.model_relationship_records.model_relationship_records_base import (
|
||||
ModelRelationshipRecordStorageBase,
|
||||
)
|
||||
@@ -7,49 +9,58 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
class SqliteModelRelationshipRecordStorage(ModelRelationshipRecordStorageBase):
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._conn = db.conn
|
||||
|
||||
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
|
||||
with self._db.transaction() as cursor:
|
||||
if model_key_1 == model_key_2:
|
||||
raise ValueError("Cannot relate a model to itself.")
|
||||
a, b = sorted([model_key_1, model_key_2])
|
||||
if model_key_1 == model_key_2:
|
||||
raise ValueError("Cannot relate a model to itself.")
|
||||
a, b = sorted([model_key_1, model_key_2])
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"INSERT OR IGNORE INTO model_relationships (model_key_1, model_key_2) VALUES (?, ?)",
|
||||
(a, b),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
|
||||
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
|
||||
with self._db.transaction() as cursor:
|
||||
a, b = sorted([model_key_1, model_key_2])
|
||||
a, b = sorted([model_key_1, model_key_2])
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"DELETE FROM model_relationships WHERE model_key_1 = ? AND model_key_2 = ?",
|
||||
(a, b),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
|
||||
def get_related_model_keys(self, model_key: str) -> list[str]:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ?
|
||||
UNION
|
||||
SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ?
|
||||
""",
|
||||
(model_key, model_key),
|
||||
)
|
||||
result = [row[0] for row in cursor.fetchall()]
|
||||
return result
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ?
|
||||
UNION
|
||||
SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ?
|
||||
""",
|
||||
(model_key, model_key),
|
||||
)
|
||||
return [row[0] for row in cursor.fetchall()]
|
||||
|
||||
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
|
||||
with self._db.transaction() as cursor:
|
||||
key_list = ",".join("?" for _ in model_keys)
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT model_key_2 FROM model_relationships WHERE model_key_1 IN ({key_list})
|
||||
UNION
|
||||
SELECT model_key_1 FROM model_relationships WHERE model_key_2 IN ({key_list})
|
||||
""",
|
||||
model_keys + model_keys,
|
||||
)
|
||||
result = [row[0] for row in cursor.fetchall()]
|
||||
return result
|
||||
cursor = self._conn.cursor()
|
||||
|
||||
key_list = ",".join("?" for _ in model_keys)
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT model_key_2 FROM model_relationships WHERE model_key_1 IN ({key_list})
|
||||
UNION
|
||||
SELECT model_key_1 FROM model_relationships WHERE model_key_2 IN ({key_list})
|
||||
""",
|
||||
model_keys + model_keys,
|
||||
)
|
||||
return [row[0] for row in cursor.fetchall()]
|
||||
|
||||
@@ -50,14 +50,15 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._conn = db.conn
|
||||
|
||||
def _set_in_progress_to_canceled(self) -> None:
|
||||
"""
|
||||
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
|
||||
This is necessary because the invoker may have been killed while processing a queue item.
|
||||
"""
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE session_queue
|
||||
@@ -65,79 +66,87 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
WHERE status = 'in_progress';
|
||||
"""
|
||||
)
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
|
||||
def _get_current_queue_size(self, queue_id: str) -> int:
|
||||
"""Gets the current number of pending queue items"""
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
return count
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
return cast(int, cursor.fetchone()[0])
|
||||
|
||||
def _get_highest_priority(self, queue_id: str) -> int:
|
||||
"""Gets the highest priority value in the queue"""
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT MAX(priority)
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
priority = cast(Union[int, None], cursor.fetchone()[0]) or 0
|
||||
return priority
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT MAX(priority)
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
return cast(Union[int, None], cursor.fetchone()[0]) or 0
|
||||
|
||||
async def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
|
||||
current_queue_size = self._get_current_queue_size(queue_id)
|
||||
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||
max_new_queue_items = max_queue_size - current_queue_size
|
||||
try:
|
||||
# TODO: how does this work in a multi-user scenario?
|
||||
current_queue_size = self._get_current_queue_size(queue_id)
|
||||
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||
max_new_queue_items = max_queue_size - current_queue_size
|
||||
|
||||
priority = 0
|
||||
if prepend:
|
||||
priority = self._get_highest_priority(queue_id) + 1
|
||||
priority = 0
|
||||
if prepend:
|
||||
priority = self._get_highest_priority(queue_id) + 1
|
||||
|
||||
requested_count = await asyncio.to_thread(
|
||||
calc_session_count,
|
||||
batch=batch,
|
||||
)
|
||||
values_to_insert = await asyncio.to_thread(
|
||||
prepare_values_to_insert,
|
||||
queue_id=queue_id,
|
||||
batch=batch,
|
||||
priority=priority,
|
||||
max_new_queue_items=max_new_queue_items,
|
||||
)
|
||||
enqueued_count = len(values_to_insert)
|
||||
requested_count = await asyncio.to_thread(
|
||||
calc_session_count,
|
||||
batch=batch,
|
||||
)
|
||||
values_to_insert = await asyncio.to_thread(
|
||||
prepare_values_to_insert,
|
||||
queue_id=queue_id,
|
||||
batch=batch,
|
||||
priority=priority,
|
||||
max_new_queue_items=max_new_queue_items,
|
||||
)
|
||||
enqueued_count = len(values_to_insert)
|
||||
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.executemany(
|
||||
"""--sql
|
||||
with self._conn:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.executemany(
|
||||
"""--sql
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
values_to_insert,
|
||||
)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
values_to_insert,
|
||||
)
|
||||
with self._conn:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT item_id
|
||||
FROM session_queue
|
||||
WHERE batch_id = ?
|
||||
ORDER BY item_id DESC;
|
||||
""",
|
||||
(batch.batch_id,),
|
||||
)
|
||||
item_ids = [row[0] for row in cursor.fetchall()]
|
||||
(batch.batch_id,),
|
||||
)
|
||||
item_ids = [row[0] for row in cursor.fetchall()]
|
||||
except Exception:
|
||||
raise
|
||||
enqueue_result = EnqueueBatchResult(
|
||||
queue_id=queue_id,
|
||||
requested=requested_count,
|
||||
@@ -150,19 +159,19 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return enqueue_result
|
||||
|
||||
def dequeue(self) -> Optional[SessionQueueItem]:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE status = 'pending'
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
LIMIT 1
|
||||
"""
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE status = 'pending'
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
LIMIT 1
|
||||
"""
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
if result is None:
|
||||
return None
|
||||
queue_item = SessionQueueItem.queue_item_from_dict(dict(result))
|
||||
@@ -170,40 +179,40 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return queue_item
|
||||
|
||||
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
created_at ASC
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
created_at ASC
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
if result is None:
|
||||
return None
|
||||
return SessionQueueItem.queue_item_from_dict(dict(result))
|
||||
|
||||
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'in_progress'
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'in_progress'
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
if result is None:
|
||||
return None
|
||||
return SessionQueueItem.queue_item_from_dict(dict(result))
|
||||
@@ -216,7 +225,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
error_message: Optional[str] = None,
|
||||
error_traceback: Optional[str] = None,
|
||||
) -> SessionQueueItem:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status FROM session_queue WHERE item_id = ?
|
||||
@@ -224,15 +234,12 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
(item_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
|
||||
current_status = row[0]
|
||||
|
||||
# Only update if not already finished (completed, failed or canceled)
|
||||
if current_status in ("completed", "failed", "canceled"):
|
||||
return self.get_queue_item(item_id)
|
||||
|
||||
with self._db.transaction() as cursor:
|
||||
if row is None:
|
||||
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
|
||||
current_status = row[0]
|
||||
# Only update if not already finished (completed, failed or canceled)
|
||||
if current_status in ("completed", "failed", "canceled"):
|
||||
return self.get_queue_item(item_id)
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE session_queue
|
||||
@@ -241,7 +248,10 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(status, error_type, error_message, error_traceback, item_id),
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
|
||||
@@ -249,34 +259,35 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return queue_item
|
||||
|
||||
def is_empty(self, queue_id: str) -> IsEmptyResult:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
is_empty = cast(int, cursor.fetchone()[0]) == 0
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
is_empty = cast(int, cursor.fetchone()[0]) == 0
|
||||
return IsEmptyResult(is_empty=is_empty)
|
||||
|
||||
def is_full(self, queue_id: str) -> IsFullResult:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||
is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||
is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
|
||||
return IsFullResult(is_full=is_full)
|
||||
|
||||
def clear(self, queue_id: str) -> ClearResult:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*)
|
||||
@@ -294,19 +305,24 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
self.__invoker.services.events.emit_queue_cleared(queue_id)
|
||||
return ClearResult(deleted=count)
|
||||
|
||||
def prune(self, queue_id: str) -> PruneResult:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND (
|
||||
queue_id = ?
|
||||
AND (
|
||||
status = 'completed'
|
||||
OR status = 'failed'
|
||||
OR status = 'canceled'
|
||||
)
|
||||
)
|
||||
"""
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
@@ -325,6 +341,10 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return PruneResult(deleted=count)
|
||||
|
||||
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
@@ -337,7 +357,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self.cancel_queue_item(item_id)
|
||||
except SessionQueueItemNotFoundError:
|
||||
pass
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE
|
||||
@@ -346,6 +367,10 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(item_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
|
||||
def complete_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
queue_item = self._set_queue_item_status(item_id=item_id, status="completed")
|
||||
@@ -368,7 +393,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return queue_item
|
||||
|
||||
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
placeholders = ", ".join(["?" for _ in batch_ids])
|
||||
where = f"""--sql
|
||||
@@ -399,14 +425,17 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
|
||||
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
|
||||
self._conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return CancelByBatchIDsResult(canceled=count)
|
||||
|
||||
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
where = """--sql
|
||||
WHERE
|
||||
@@ -436,12 +465,17 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
params,
|
||||
)
|
||||
if current_queue_item is not None and current_queue_item.destination == destination:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
self._conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.destination == destination:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return CancelByDestinationResult(canceled=count)
|
||||
|
||||
def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
if current_queue_item is not None and current_queue_item.destination == destination:
|
||||
self.cancel_queue_item(current_queue_item.item_id)
|
||||
@@ -467,10 +501,15 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
params,
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return DeleteByDestinationResult(deleted=count)
|
||||
|
||||
def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id == ?
|
||||
@@ -493,10 +532,15 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return DeleteAllExceptCurrentResult(deleted=count)
|
||||
|
||||
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
where = """--sql
|
||||
WHERE
|
||||
@@ -525,13 +569,18 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return CancelByQueueIDResult(canceled=count)
|
||||
|
||||
def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id == ?
|
||||
@@ -554,25 +603,30 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return CancelAllExceptCurrentResult(canceled=count)
|
||||
|
||||
def get_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT * FROM session_queue
|
||||
WHERE
|
||||
item_id = ?
|
||||
""",
|
||||
(item_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT * FROM session_queue
|
||||
WHERE
|
||||
item_id = ?
|
||||
""",
|
||||
(item_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
if result is None:
|
||||
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
|
||||
return SessionQueueItem.queue_item_from_dict(dict(result))
|
||||
|
||||
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
|
||||
# when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
|
||||
# during execution.
|
||||
@@ -585,6 +639,10 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(session_json, item_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return self.get_queue_item(item_id)
|
||||
|
||||
def list_queue_items(
|
||||
@@ -596,42 +654,42 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
status: Optional[QUEUE_ITEM_STATUS] = None,
|
||||
destination: Optional[str] = None,
|
||||
) -> CursorPaginatedResults[SessionQueueItem]:
|
||||
with self._db.transaction() as cursor_:
|
||||
item_id = cursor
|
||||
query = """--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
params: list[Union[str, int]] = [queue_id]
|
||||
|
||||
if status is not None:
|
||||
query += """--sql
|
||||
AND status = ?
|
||||
"""
|
||||
params.append(status)
|
||||
|
||||
if destination is not None:
|
||||
query += """---sql
|
||||
AND destination = ?
|
||||
"""
|
||||
params.append(destination)
|
||||
|
||||
if item_id is not None:
|
||||
query += """--sql
|
||||
AND (priority < ?) OR (priority = ? AND item_id > ?)
|
||||
"""
|
||||
params.extend([priority, priority, item_id])
|
||||
cursor_ = self._conn.cursor()
|
||||
item_id = cursor
|
||||
query = """--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
params: list[Union[str, int]] = [queue_id]
|
||||
|
||||
if status is not None:
|
||||
query += """--sql
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
LIMIT ?
|
||||
AND status = ?
|
||||
"""
|
||||
params.append(limit + 1)
|
||||
cursor_.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], cursor_.fetchall())
|
||||
params.append(status)
|
||||
|
||||
if destination is not None:
|
||||
query += """---sql
|
||||
AND destination = ?
|
||||
"""
|
||||
params.append(destination)
|
||||
|
||||
if item_id is not None:
|
||||
query += """--sql
|
||||
AND (priority < ?) OR (priority = ? AND item_id > ?)
|
||||
"""
|
||||
params.extend([priority, priority, item_id])
|
||||
|
||||
query += """--sql
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
LIMIT ?
|
||||
"""
|
||||
params.append(limit + 1)
|
||||
cursor_.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], cursor_.fetchall())
|
||||
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
|
||||
has_more = False
|
||||
if len(items) > limit:
|
||||
@@ -646,43 +704,43 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
destination: Optional[str] = None,
|
||||
) -> list[SessionQueueItem]:
|
||||
"""Gets all queue items that match the given parameters"""
|
||||
with self._db.transaction() as cursor:
|
||||
query = """--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
cursor_ = self._conn.cursor()
|
||||
query = """--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
params: list[Union[str, int]] = [queue_id]
|
||||
|
||||
if destination is not None:
|
||||
query += """---sql
|
||||
AND destination = ?
|
||||
"""
|
||||
params: list[Union[str, int]] = [queue_id]
|
||||
params.append(destination)
|
||||
|
||||
if destination is not None:
|
||||
query += """---sql
|
||||
AND destination = ?
|
||||
"""
|
||||
params.append(destination)
|
||||
|
||||
query += """--sql
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
;
|
||||
"""
|
||||
cursor.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
query += """--sql
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
;
|
||||
"""
|
||||
cursor_.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], cursor_.fetchall())
|
||||
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
|
||||
return items
|
||||
|
||||
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
|
||||
current_item = self.get_current(queue_id=queue_id)
|
||||
total = sum(row[1] or 0 for row in counts_result)
|
||||
@@ -701,19 +759,19 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
|
||||
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*), origin, destination
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND batch_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, batch_id),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*), origin, destination
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND batch_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, batch_id),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
total = sum(row[1] or 0 for row in result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in result}
|
||||
origin = result[0]["origin"] if result else None
|
||||
@@ -733,18 +791,18 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
|
||||
def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination:
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
AND destination = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, destination),
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
AND destination = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, destination),
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
|
||||
total = sum(row[1] or 0 for row in counts_result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
|
||||
@@ -762,7 +820,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult:
|
||||
"""Retries the given queue items"""
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
values_to_insert: list[ValueToInsertTuple] = []
|
||||
retried_item_ids: list[int] = []
|
||||
|
||||
@@ -813,6 +872,10 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
values_to_insert,
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
retry_result = RetryItemsResult(
|
||||
queue_id=queue_id,
|
||||
retried_item_ids=retried_item_ids,
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
|
||||
@@ -29,65 +26,46 @@ class SqliteDatabase:
|
||||
|
||||
def __init__(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
|
||||
"""Initializes the database. This is used internally by the class constructor."""
|
||||
self._logger = logger
|
||||
self._db_path = db_path
|
||||
self._verbose = verbose
|
||||
self._lock = threading.RLock()
|
||||
self.logger = logger
|
||||
self.db_path = db_path
|
||||
self.verbose = verbose
|
||||
|
||||
if not self._db_path:
|
||||
if not self.db_path:
|
||||
logger.info("Initializing in-memory database")
|
||||
else:
|
||||
self._db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._logger.info(f"Initializing database at {self._db_path}")
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.logger.info(f"Initializing database at {self.db_path}")
|
||||
|
||||
self._conn = sqlite3.connect(database=self._db_path or sqlite_memory, check_same_thread=False)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self.conn = sqlite3.connect(database=self.db_path or sqlite_memory, check_same_thread=False)
|
||||
self.conn.row_factory = sqlite3.Row
|
||||
|
||||
if self._verbose:
|
||||
self._conn.set_trace_callback(self._logger.debug)
|
||||
if self.verbose:
|
||||
self.conn.set_trace_callback(self.logger.debug)
|
||||
|
||||
# Enable foreign key constraints
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self.conn.execute("PRAGMA foreign_keys = ON;")
|
||||
|
||||
# Enable Write-Ahead Logging (WAL) mode for better concurrency
|
||||
self._conn.execute("PRAGMA journal_mode = WAL;")
|
||||
self.conn.execute("PRAGMA journal_mode = WAL;")
|
||||
|
||||
# Set a busy timeout to prevent database lockups during writes
|
||||
self._conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
|
||||
self.conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
|
||||
|
||||
def clean(self) -> None:
|
||||
"""
|
||||
Cleans the database by running the VACUUM command, reporting on the freed space.
|
||||
"""
|
||||
# No need to clean in-memory database
|
||||
if not self._db_path:
|
||||
if not self.db_path:
|
||||
return
|
||||
try:
|
||||
with self._conn as conn:
|
||||
initial_db_size = Path(self._db_path).stat().st_size
|
||||
conn.execute("VACUUM;")
|
||||
conn.commit()
|
||||
final_db_size = Path(self._db_path).stat().st_size
|
||||
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
|
||||
if freed_space_in_mb > 0:
|
||||
self._logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
|
||||
initial_db_size = Path(self.db_path).stat().st_size
|
||||
self.conn.execute("VACUUM;")
|
||||
self.conn.commit()
|
||||
final_db_size = Path(self.db_path).stat().st_size
|
||||
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
|
||||
if freed_space_in_mb > 0:
|
||||
self.logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
|
||||
except Exception as e:
|
||||
self._logger.error(f"Error cleaning database: {e}")
|
||||
self.logger.error(f"Error cleaning database: {e}")
|
||||
raise
|
||||
|
||||
@contextmanager
|
||||
def transaction(self) -> Generator[sqlite3.Cursor, None, None]:
|
||||
"""
|
||||
Thread-safe context manager for DB work.
|
||||
Acquires the RLock, yields a Cursor, then commits or rolls back.
|
||||
"""
|
||||
with self._lock:
|
||||
cursor = self._conn.cursor()
|
||||
try:
|
||||
yield cursor
|
||||
self._conn.commit()
|
||||
except:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
@@ -32,7 +32,7 @@ class SqliteMigrator:
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
self._db = db
|
||||
self._logger = db._logger
|
||||
self._logger = db.logger
|
||||
self._migration_set = MigrationSet()
|
||||
self._backup_path: Optional[Path] = None
|
||||
|
||||
@@ -45,7 +45,7 @@ class SqliteMigrator:
|
||||
"""Migrates the database to the latest version."""
|
||||
# This throws if there is a problem.
|
||||
self._migration_set.validate_migration_chain()
|
||||
cursor = self._db._conn.cursor()
|
||||
cursor = self._db.conn.cursor()
|
||||
self._create_migrations_table(cursor=cursor)
|
||||
|
||||
if self._migration_set.count == 0:
|
||||
@@ -59,13 +59,13 @@ class SqliteMigrator:
|
||||
self._logger.info("Database update needed")
|
||||
|
||||
# Make a backup of the db if it needs to be updated and is a file db
|
||||
if self._db._db_path is not None:
|
||||
if self._db.db_path is not None:
|
||||
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
self._backup_path = self._db._db_path.parent / f"{self._db._db_path.stem}_backup_{timestamp}.db"
|
||||
self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
|
||||
self._logger.info(f"Backing up database to {str(self._backup_path)}")
|
||||
# Use SQLite to do the backup
|
||||
with closing(sqlite3.connect(self._backup_path)) as backup_conn:
|
||||
self._db._conn.backup(backup_conn)
|
||||
self._db.conn.backup(backup_conn)
|
||||
else:
|
||||
self._logger.info("Using in-memory database, no backup needed")
|
||||
|
||||
@@ -81,7 +81,7 @@ class SqliteMigrator:
|
||||
try:
|
||||
# Using sqlite3.Connection as a context manager commits a the transaction on exit, or rolls it back if an
|
||||
# exception is raised.
|
||||
with self._db._conn as conn:
|
||||
with self._db.conn as conn:
|
||||
cursor = conn.cursor()
|
||||
if self._get_current_version(cursor) != migration.from_version:
|
||||
raise MigrationError(
|
||||
|
||||
@@ -17,7 +17,7 @@ from invokeai.app.util.misc import uuid_string
|
||||
class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._conn = db.conn
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
@@ -25,23 +25,24 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
|
||||
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
|
||||
"""Gets a style preset by ID."""
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM style_presets
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(style_preset_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM style_presets
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(style_preset_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
|
||||
return StylePresetRecordDTO.from_dict(dict(row))
|
||||
|
||||
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
|
||||
style_preset_id = uuid_string()
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO style_presets (
|
||||
@@ -59,11 +60,16 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
style_preset.type,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return self.get(style_preset_id)
|
||||
|
||||
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
|
||||
style_preset_ids = []
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
for style_preset in style_presets:
|
||||
style_preset_id = uuid_string()
|
||||
style_preset_ids.append(style_preset_id)
|
||||
@@ -84,11 +90,16 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
style_preset.type,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
|
||||
return None
|
||||
|
||||
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
# Change the name of a style preset
|
||||
if changes.name is not None:
|
||||
cursor.execute(
|
||||
@@ -111,10 +122,15 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
(changes.preset_data.model_dump_json(), style_preset_id),
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return self.get(style_preset_id)
|
||||
|
||||
def delete(self, style_preset_id: str) -> None:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE from style_presets
|
||||
@@ -122,41 +138,51 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
""",
|
||||
(style_preset_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return None
|
||||
|
||||
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
|
||||
with self._db.transaction() as cursor:
|
||||
main_query = """
|
||||
SELECT
|
||||
*
|
||||
FROM style_presets
|
||||
"""
|
||||
main_query = """
|
||||
SELECT
|
||||
*
|
||||
FROM style_presets
|
||||
"""
|
||||
|
||||
if type is not None:
|
||||
main_query += "WHERE type = ? "
|
||||
if type is not None:
|
||||
main_query += "WHERE type = ? "
|
||||
|
||||
main_query += "ORDER BY LOWER(name) ASC"
|
||||
main_query += "ORDER BY LOWER(name) ASC"
|
||||
|
||||
if type is not None:
|
||||
cursor.execute(main_query, (type,))
|
||||
else:
|
||||
cursor.execute(main_query)
|
||||
cursor = self._conn.cursor()
|
||||
if type is not None:
|
||||
cursor.execute(main_query, (type,))
|
||||
else:
|
||||
cursor.execute(main_query)
|
||||
|
||||
rows = cursor.fetchall()
|
||||
rows = cursor.fetchall()
|
||||
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
|
||||
|
||||
return style_presets
|
||||
|
||||
def _sync_default_style_presets(self) -> None:
|
||||
"""Syncs default style presets to the database. Internal use only."""
|
||||
with self._db.transaction() as cursor:
|
||||
# First delete all existing default style presets
|
||||
|
||||
# First delete all existing default style presets
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM style_presets
|
||||
WHERE type = "default";
|
||||
"""
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
# Next, parse and create the default style presets
|
||||
with open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
|
||||
presets = json.load(file)
|
||||
|
||||
@@ -25,7 +25,7 @@ SQL_TIME_FORMAT = "%Y-%m-%d %H:%M:%f"
|
||||
class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._conn = db.conn
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
@@ -33,16 +33,16 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
|
||||
def get(self, workflow_id: str) -> WorkflowRecordDTO:
|
||||
"""Gets a workflow by ID. Updates the opened_at column."""
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
|
||||
FROM workflow_library
|
||||
WHERE workflow_id = ?;
|
||||
""",
|
||||
(workflow_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
|
||||
FROM workflow_library
|
||||
WHERE workflow_id = ?;
|
||||
""",
|
||||
(workflow_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
|
||||
return WorkflowRecordDTO.from_dict(dict(row))
|
||||
@@ -51,8 +51,9 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
if workflow.meta.category is WorkflowCategory.Default:
|
||||
raise ValueError("Default workflows cannot be created via this method")
|
||||
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
workflow_with_id = Workflow(**workflow.model_dump(), id=uuid_string())
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO workflow_library (
|
||||
@@ -63,13 +64,18 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
""",
|
||||
(workflow_with_id.id, workflow_with_id.model_dump_json()),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return self.get(workflow_with_id.id)
|
||||
|
||||
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
|
||||
if workflow.meta.category is WorkflowCategory.Default:
|
||||
raise ValueError("Default workflows cannot be updated")
|
||||
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE workflow_library
|
||||
@@ -78,13 +84,18 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
""",
|
||||
(workflow.model_dump_json(), workflow.id),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return self.get(workflow.id)
|
||||
|
||||
def delete(self, workflow_id: str) -> None:
|
||||
if self.get(workflow_id).workflow.meta.category is WorkflowCategory.Default:
|
||||
raise ValueError("Default workflows cannot be deleted")
|
||||
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE from workflow_library
|
||||
@@ -92,6 +103,10 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
""",
|
||||
(workflow_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
return None
|
||||
|
||||
def get_many(
|
||||
@@ -106,108 +121,108 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
has_been_opened: Optional[bool] = None,
|
||||
is_published: Optional[bool] = None,
|
||||
) -> PaginatedResults[WorkflowRecordListItemDTO]:
|
||||
with self._db.transaction() as cursor:
|
||||
# sanitize!
|
||||
assert order_by in WorkflowRecordOrderBy
|
||||
assert direction in SQLiteDirection
|
||||
# sanitize!
|
||||
assert order_by in WorkflowRecordOrderBy
|
||||
assert direction in SQLiteDirection
|
||||
|
||||
# We will construct the query dynamically based on the query params
|
||||
# We will construct the query dynamically based on the query params
|
||||
|
||||
# The main query to get the workflows / counts
|
||||
main_query = """
|
||||
SELECT
|
||||
workflow_id,
|
||||
category,
|
||||
name,
|
||||
description,
|
||||
created_at,
|
||||
updated_at,
|
||||
opened_at,
|
||||
tags
|
||||
FROM workflow_library
|
||||
"""
|
||||
count_query = "SELECT COUNT(*) FROM workflow_library"
|
||||
# The main query to get the workflows / counts
|
||||
main_query = """
|
||||
SELECT
|
||||
workflow_id,
|
||||
category,
|
||||
name,
|
||||
description,
|
||||
created_at,
|
||||
updated_at,
|
||||
opened_at,
|
||||
tags
|
||||
FROM workflow_library
|
||||
"""
|
||||
count_query = "SELECT COUNT(*) FROM workflow_library"
|
||||
|
||||
# Start with an empty list of conditions and params
|
||||
conditions: list[str] = []
|
||||
params: list[str | int] = []
|
||||
# Start with an empty list of conditions and params
|
||||
conditions: list[str] = []
|
||||
params: list[str | int] = []
|
||||
|
||||
if categories:
|
||||
# Categories is a list of WorkflowCategory enum values, and a single string in the DB
|
||||
if categories:
|
||||
# Categories is a list of WorkflowCategory enum values, and a single string in the DB
|
||||
|
||||
# Ensure all categories are valid (is this necessary?)
|
||||
assert all(c in WorkflowCategory for c in categories)
|
||||
# Ensure all categories are valid (is this necessary?)
|
||||
assert all(c in WorkflowCategory for c in categories)
|
||||
|
||||
# Construct a placeholder string for the number of categories
|
||||
placeholders = ", ".join("?" for _ in categories)
|
||||
# Construct a placeholder string for the number of categories
|
||||
placeholders = ", ".join("?" for _ in categories)
|
||||
|
||||
# Construct the condition string & params
|
||||
category_condition = f"category IN ({placeholders})"
|
||||
category_params = [category.value for category in categories]
|
||||
# Construct the condition string & params
|
||||
category_condition = f"category IN ({placeholders})"
|
||||
category_params = [category.value for category in categories]
|
||||
|
||||
conditions.append(category_condition)
|
||||
params.extend(category_params)
|
||||
conditions.append(category_condition)
|
||||
params.extend(category_params)
|
||||
|
||||
if tags:
|
||||
# Tags is a list of strings, and a single string in the DB
|
||||
# The string in the DB has no guaranteed format
|
||||
if tags:
|
||||
# Tags is a list of strings, and a single string in the DB
|
||||
# The string in the DB has no guaranteed format
|
||||
|
||||
# Construct a list of conditions for each tag
|
||||
tags_conditions = ["tags LIKE ?" for _ in tags]
|
||||
tags_conditions_joined = " OR ".join(tags_conditions)
|
||||
tags_condition = f"({tags_conditions_joined})"
|
||||
# Construct a list of conditions for each tag
|
||||
tags_conditions = ["tags LIKE ?" for _ in tags]
|
||||
tags_conditions_joined = " OR ".join(tags_conditions)
|
||||
tags_condition = f"({tags_conditions_joined})"
|
||||
|
||||
# And the params for the tags, case-insensitive
|
||||
tags_params = [f"%{t.strip()}%" for t in tags]
|
||||
# And the params for the tags, case-insensitive
|
||||
tags_params = [f"%{t.strip()}%" for t in tags]
|
||||
|
||||
conditions.append(tags_condition)
|
||||
params.extend(tags_params)
|
||||
conditions.append(tags_condition)
|
||||
params.extend(tags_params)
|
||||
|
||||
if has_been_opened:
|
||||
conditions.append("opened_at IS NOT NULL")
|
||||
elif has_been_opened is False:
|
||||
conditions.append("opened_at IS NULL")
|
||||
if has_been_opened:
|
||||
conditions.append("opened_at IS NOT NULL")
|
||||
elif has_been_opened is False:
|
||||
conditions.append("opened_at IS NULL")
|
||||
|
||||
# Ignore whitespace in the query
|
||||
stripped_query = query.strip() if query else None
|
||||
if stripped_query:
|
||||
# Construct a wildcard query for the name, description, and tags
|
||||
wildcard_query = "%" + stripped_query + "%"
|
||||
query_condition = "(name LIKE ? OR description LIKE ? OR tags LIKE ?)"
|
||||
# Ignore whitespace in the query
|
||||
stripped_query = query.strip() if query else None
|
||||
if stripped_query:
|
||||
# Construct a wildcard query for the name, description, and tags
|
||||
wildcard_query = "%" + stripped_query + "%"
|
||||
query_condition = "(name LIKE ? OR description LIKE ? OR tags LIKE ?)"
|
||||
|
||||
conditions.append(query_condition)
|
||||
params.extend([wildcard_query, wildcard_query, wildcard_query])
|
||||
conditions.append(query_condition)
|
||||
params.extend([wildcard_query, wildcard_query, wildcard_query])
|
||||
|
||||
if conditions:
|
||||
# If there are conditions, add a WHERE clause and then join the conditions
|
||||
main_query += " WHERE "
|
||||
count_query += " WHERE "
|
||||
if conditions:
|
||||
# If there are conditions, add a WHERE clause and then join the conditions
|
||||
main_query += " WHERE "
|
||||
count_query += " WHERE "
|
||||
|
||||
all_conditions = " AND ".join(conditions)
|
||||
main_query += all_conditions
|
||||
count_query += all_conditions
|
||||
all_conditions = " AND ".join(conditions)
|
||||
main_query += all_conditions
|
||||
count_query += all_conditions
|
||||
|
||||
# After this point, the query and params differ for the main query and the count query
|
||||
main_params = params.copy()
|
||||
count_params = params.copy()
|
||||
# After this point, the query and params differ for the main query and the count query
|
||||
main_params = params.copy()
|
||||
count_params = params.copy()
|
||||
|
||||
# Main query also gets ORDER BY and LIMIT/OFFSET
|
||||
main_query += f" ORDER BY {order_by.value} {direction.value}"
|
||||
# Main query also gets ORDER BY and LIMIT/OFFSET
|
||||
main_query += f" ORDER BY {order_by.value} {direction.value}"
|
||||
|
||||
if per_page:
|
||||
main_query += " LIMIT ? OFFSET ?"
|
||||
main_params.extend([per_page, page * per_page])
|
||||
if per_page:
|
||||
main_query += " LIMIT ? OFFSET ?"
|
||||
main_params.extend([per_page, page * per_page])
|
||||
|
||||
# Put a ring on it
|
||||
main_query += ";"
|
||||
count_query += ";"
|
||||
# Put a ring on it
|
||||
main_query += ";"
|
||||
count_query += ";"
|
||||
|
||||
cursor.execute(main_query, main_params)
|
||||
rows = cursor.fetchall()
|
||||
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(main_query, main_params)
|
||||
rows = cursor.fetchall()
|
||||
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
|
||||
|
||||
cursor.execute(count_query, count_params)
|
||||
total = cursor.fetchone()[0]
|
||||
cursor.execute(count_query, count_params)
|
||||
total = cursor.fetchone()[0]
|
||||
|
||||
if per_page:
|
||||
pages = total // per_page + (total % per_page > 0)
|
||||
@@ -232,46 +247,46 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
if not tags:
|
||||
return {}
|
||||
|
||||
with self._db.transaction() as cursor:
|
||||
result: dict[str, int] = {}
|
||||
# Base conditions for categories and selected tags
|
||||
base_conditions: list[str] = []
|
||||
base_params: list[str | int] = []
|
||||
cursor = self._conn.cursor()
|
||||
result: dict[str, int] = {}
|
||||
# Base conditions for categories and selected tags
|
||||
base_conditions: list[str] = []
|
||||
base_params: list[str | int] = []
|
||||
|
||||
# Add category conditions
|
||||
if categories:
|
||||
assert all(c in WorkflowCategory for c in categories)
|
||||
placeholders = ", ".join("?" for _ in categories)
|
||||
base_conditions.append(f"category IN ({placeholders})")
|
||||
base_params.extend([category.value for category in categories])
|
||||
# Add category conditions
|
||||
if categories:
|
||||
assert all(c in WorkflowCategory for c in categories)
|
||||
placeholders = ", ".join("?" for _ in categories)
|
||||
base_conditions.append(f"category IN ({placeholders})")
|
||||
base_params.extend([category.value for category in categories])
|
||||
|
||||
if has_been_opened:
|
||||
base_conditions.append("opened_at IS NOT NULL")
|
||||
elif has_been_opened is False:
|
||||
base_conditions.append("opened_at IS NULL")
|
||||
if has_been_opened:
|
||||
base_conditions.append("opened_at IS NOT NULL")
|
||||
elif has_been_opened is False:
|
||||
base_conditions.append("opened_at IS NULL")
|
||||
|
||||
# For each tag to count, run a separate query
|
||||
for tag in tags:
|
||||
# Start with the base conditions
|
||||
conditions = base_conditions.copy()
|
||||
params = base_params.copy()
|
||||
# For each tag to count, run a separate query
|
||||
for tag in tags:
|
||||
# Start with the base conditions
|
||||
conditions = base_conditions.copy()
|
||||
params = base_params.copy()
|
||||
|
||||
# Add this specific tag condition
|
||||
conditions.append("tags LIKE ?")
|
||||
params.append(f"%{tag.strip()}%")
|
||||
# Add this specific tag condition
|
||||
conditions.append("tags LIKE ?")
|
||||
params.append(f"%{tag.strip()}%")
|
||||
|
||||
# Construct the full query
|
||||
stmt = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM workflow_library
|
||||
"""
|
||||
# Construct the full query
|
||||
stmt = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM workflow_library
|
||||
"""
|
||||
|
||||
if conditions:
|
||||
stmt += " WHERE " + " AND ".join(conditions)
|
||||
if conditions:
|
||||
stmt += " WHERE " + " AND ".join(conditions)
|
||||
|
||||
cursor.execute(stmt, params)
|
||||
count = cursor.fetchone()[0]
|
||||
result[tag] = count
|
||||
cursor.execute(stmt, params)
|
||||
count = cursor.fetchone()[0]
|
||||
result[tag] = count
|
||||
|
||||
return result
|
||||
|
||||
@@ -281,51 +296,52 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
has_been_opened: Optional[bool] = None,
|
||||
is_published: Optional[bool] = None,
|
||||
) -> dict[str, int]:
|
||||
with self._db.transaction() as cursor:
|
||||
result: dict[str, int] = {}
|
||||
# Base conditions for categories
|
||||
base_conditions: list[str] = []
|
||||
base_params: list[str | int] = []
|
||||
cursor = self._conn.cursor()
|
||||
result: dict[str, int] = {}
|
||||
# Base conditions for categories
|
||||
base_conditions: list[str] = []
|
||||
base_params: list[str | int] = []
|
||||
|
||||
# Add category conditions
|
||||
if categories:
|
||||
assert all(c in WorkflowCategory for c in categories)
|
||||
placeholders = ", ".join("?" for _ in categories)
|
||||
base_conditions.append(f"category IN ({placeholders})")
|
||||
base_params.extend([category.value for category in categories])
|
||||
# Add category conditions
|
||||
if categories:
|
||||
assert all(c in WorkflowCategory for c in categories)
|
||||
placeholders = ", ".join("?" for _ in categories)
|
||||
base_conditions.append(f"category IN ({placeholders})")
|
||||
base_params.extend([category.value for category in categories])
|
||||
|
||||
if has_been_opened:
|
||||
base_conditions.append("opened_at IS NOT NULL")
|
||||
elif has_been_opened is False:
|
||||
base_conditions.append("opened_at IS NULL")
|
||||
if has_been_opened:
|
||||
base_conditions.append("opened_at IS NOT NULL")
|
||||
elif has_been_opened is False:
|
||||
base_conditions.append("opened_at IS NULL")
|
||||
|
||||
# For each category to count, run a separate query
|
||||
for category in categories:
|
||||
# Start with the base conditions
|
||||
conditions = base_conditions.copy()
|
||||
params = base_params.copy()
|
||||
# For each category to count, run a separate query
|
||||
for category in categories:
|
||||
# Start with the base conditions
|
||||
conditions = base_conditions.copy()
|
||||
params = base_params.copy()
|
||||
|
||||
# Add this specific category condition
|
||||
conditions.append("category = ?")
|
||||
params.append(category.value)
|
||||
# Add this specific category condition
|
||||
conditions.append("category = ?")
|
||||
params.append(category.value)
|
||||
|
||||
# Construct the full query
|
||||
stmt = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM workflow_library
|
||||
"""
|
||||
# Construct the full query
|
||||
stmt = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM workflow_library
|
||||
"""
|
||||
|
||||
if conditions:
|
||||
stmt += " WHERE " + " AND ".join(conditions)
|
||||
if conditions:
|
||||
stmt += " WHERE " + " AND ".join(conditions)
|
||||
|
||||
cursor.execute(stmt, params)
|
||||
count = cursor.fetchone()[0]
|
||||
result[category.value] = count
|
||||
cursor.execute(stmt, params)
|
||||
count = cursor.fetchone()[0]
|
||||
result[category.value] = count
|
||||
|
||||
return result
|
||||
|
||||
def update_opened_at(self, workflow_id: str) -> None:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE workflow_library
|
||||
@@ -334,6 +350,10 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
""",
|
||||
(workflow_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
|
||||
def _sync_default_workflows(self) -> None:
|
||||
"""Syncs default workflows to the database. Internal use only."""
|
||||
@@ -348,7 +368,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
meaningless, as they are overwritten every time the server starts.
|
||||
"""
|
||||
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
workflows_from_file: list[Workflow] = []
|
||||
workflows_to_update: list[Workflow] = []
|
||||
workflows_to_add: list[Workflow] = []
|
||||
@@ -428,3 +449,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
""",
|
||||
(w.model_dump_json(), w.id),
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
|
||||
@@ -187,7 +187,7 @@ class ModelConfigBase(ABC, BaseModel):
|
||||
else:
|
||||
return config_cls.from_model_on_disk(mod, **overrides)
|
||||
|
||||
raise InvalidModelConfigException("Unable to determine model type")
|
||||
raise InvalidModelConfigException("No valid config found")
|
||||
|
||||
@classmethod
|
||||
def get_tag(cls) -> Tag:
|
||||
|
||||
@@ -63,7 +63,7 @@
|
||||
"framer-motion": "^11.10.0",
|
||||
"i18next": "^25.2.1",
|
||||
"i18next-http-backend": "^3.0.2",
|
||||
"idb-keyval": "6.2.1",
|
||||
"idb-keyval": "^6.2.2",
|
||||
"jsondiffpatch": "^0.7.3",
|
||||
"konva": "^9.3.20",
|
||||
"linkify-react": "^4.3.1",
|
||||
|
||||
10
invokeai/frontend/web/pnpm-lock.yaml
generated
10
invokeai/frontend/web/pnpm-lock.yaml
generated
@@ -81,8 +81,8 @@ importers:
|
||||
specifier: ^3.0.2
|
||||
version: 3.0.2
|
||||
idb-keyval:
|
||||
specifier: 6.2.1
|
||||
version: 6.2.1
|
||||
specifier: ^6.2.2
|
||||
version: 6.2.2
|
||||
jsondiffpatch:
|
||||
specifier: ^0.7.3
|
||||
version: 0.7.3
|
||||
@@ -2927,8 +2927,8 @@ packages:
|
||||
typescript:
|
||||
optional: true
|
||||
|
||||
idb-keyval@6.2.1:
|
||||
resolution: {integrity: sha512-8Sb3veuYCyrZL+VBt9LJfZjLUPWVvqn8tG28VqYNFCo43KHcKuq+b4EiXGeuaLAQWL2YmyDgMp2aSpH9JHsEQg==}
|
||||
idb-keyval@6.2.2:
|
||||
resolution: {integrity: sha512-yjD9nARJ/jb1g+CvD0tlhUHOrJ9Sy0P8T9MF3YaLlHnSRpwPfpTX0XIvpmw3gAJUmEu3FiICLBDPXVwyEvrleg==}
|
||||
|
||||
ieee754@1.2.1:
|
||||
resolution: {integrity: sha512-dcyqhDvX1C46lXZcVqCpK+FtMRQVdIMN6/Df5js2zouUsqG7I6sFxitIC+7KYK29KdXOLHdu9zL4sFnoVQnqaA==}
|
||||
@@ -7720,7 +7720,7 @@ snapshots:
|
||||
optionalDependencies:
|
||||
typescript: 5.8.3
|
||||
|
||||
idb-keyval@6.2.1: {}
|
||||
idb-keyval@6.2.2: {}
|
||||
|
||||
ieee754@1.2.1: {}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import { atom, computed } from 'nanostores';
|
||||
import type { RefObject } from 'react';
|
||||
import { useEffect } from 'react';
|
||||
import { objectKeys } from 'tsafe';
|
||||
import z from 'zod/v4';
|
||||
|
||||
/**
|
||||
* We need to manage focus regions to conditionally enable hotkeys:
|
||||
@@ -27,7 +28,10 @@ import { objectKeys } from 'tsafe';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
const REGION_NAMES = [
|
||||
/**
|
||||
* The names of the focus regions.
|
||||
*/
|
||||
const zFocusRegionName = z.enum([
|
||||
'launchpad',
|
||||
'viewer',
|
||||
'gallery',
|
||||
@@ -37,16 +41,13 @@ const REGION_NAMES = [
|
||||
'workflows',
|
||||
'progress',
|
||||
'settings',
|
||||
] as const;
|
||||
/**
|
||||
* The names of the focus regions.
|
||||
*/
|
||||
export type FocusRegionName = (typeof REGION_NAMES)[number];
|
||||
]);
|
||||
export type FocusRegionName = z.infer<typeof zFocusRegionName>;
|
||||
|
||||
/**
|
||||
* A map of focus regions to the elements that are part of that region.
|
||||
*/
|
||||
const REGION_TARGETS: Record<FocusRegionName, Set<HTMLElement>> = REGION_NAMES.reduce(
|
||||
const REGION_TARGETS: Record<FocusRegionName, Set<HTMLElement>> = zFocusRegionName.options.values().reduce(
|
||||
(acc, region) => {
|
||||
acc[region] = new Set<HTMLElement>();
|
||||
return acc;
|
||||
|
||||
@@ -77,7 +77,7 @@ export const QueueItemPreviewMini = memo(({ item, isSelected, index }: Props) =>
|
||||
onDoubleClick={onDoubleClick}
|
||||
>
|
||||
<QueueItemStatusLabel item={item} position="absolute" margin="auto" />
|
||||
{imageDTO && <DndImage imageDTO={imageDTO} onLoad={onLoad} asThumbnail position="absolute" />}
|
||||
{imageDTO && <DndImage imageDTO={imageDTO} onLoad={onLoad} asThumbnail />}
|
||||
{!imageLoaded && <QueueItemProgressImage itemId={item.item_id} position="absolute" />}
|
||||
<QueueItemNumber number={index + 1} position="absolute" top={0} left={1} />
|
||||
<QueueItemCircularProgress itemId={item.item_id} status={item.status} position="absolute" top={1} right={2} />
|
||||
|
||||
@@ -94,7 +94,6 @@ const useScrollableStagingArea = (rootRef: RefObject<HTMLDivElement>) => {
|
||||
const { viewport } = osInstance.elements();
|
||||
viewport.style.overflowX = `var(--os-viewport-overflow-x)`;
|
||||
viewport.style.overflowY = `var(--os-viewport-overflow-y)`;
|
||||
viewport.style.textAlign = 'center';
|
||||
},
|
||||
},
|
||||
options: {
|
||||
|
||||
@@ -143,7 +143,7 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
|
||||
DndDragPreviewSingleImageState | DndDragPreviewMultipleImageState | null
|
||||
>(null);
|
||||
// Must use callback ref - else chakra's Image fallback prop will break the ref & dnd
|
||||
const [element, ref] = useState<HTMLDivElement | null>(null);
|
||||
const [element, ref] = useState<HTMLImageElement | null>(null);
|
||||
const selectIsSelectedForCompare = useMemo(
|
||||
() => createSelector(selectGallerySlice, (gallery) => gallery.imageToCompare === imageDTO.image_name),
|
||||
[imageDTO.image_name]
|
||||
@@ -246,7 +246,6 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
|
||||
<>
|
||||
<Box sx={galleryImageContainerSX} data-is-dragging={isDragging} data-image-name={imageDTO.image_name}>
|
||||
<Flex
|
||||
ref={ref}
|
||||
role="button"
|
||||
className={GALLERY_IMAGE_CLASS}
|
||||
onMouseOver={onMouseOver}
|
||||
@@ -257,6 +256,7 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
|
||||
data-selected-for-compare={isSelectedForCompare}
|
||||
>
|
||||
<Image
|
||||
ref={ref}
|
||||
src={imageDTO.thumbnail_url}
|
||||
w={imageDTO.width}
|
||||
fallback={<GalleryImagePlaceholder />}
|
||||
|
||||
@@ -78,7 +78,7 @@ export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise<GraphBuilder
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: getPrefixedId('neg_cond'),
|
||||
prompt: prompts.negative,
|
||||
style: prompts.useMainPromptsForStyle ? prompts.negative : prompts.negativeStyle,
|
||||
style: prompts.negativeStyle,
|
||||
});
|
||||
const negCondCollect = g.addNode({
|
||||
type: 'collect',
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "6.0.1"
|
||||
__version__ = "6.1.0rc1"
|
||||
|
||||
@@ -191,14 +191,14 @@ def test_migrator_registers_migration(migrator: SqliteMigrator, migration_no_op:
|
||||
|
||||
|
||||
def test_migrator_creates_migrations_table(migrator: SqliteMigrator) -> None:
|
||||
cursor = migrator._db._conn.cursor()
|
||||
cursor = migrator._db.conn.cursor()
|
||||
migrator._create_migrations_table(cursor)
|
||||
cursor.execute("SELECT * FROM sqlite_master WHERE type='table' AND name='migrations';")
|
||||
assert cursor.fetchone() is not None
|
||||
|
||||
|
||||
def test_migrator_migration_sets_version(migrator: SqliteMigrator, migration_no_op: Migration) -> None:
|
||||
cursor = migrator._db._conn.cursor()
|
||||
cursor = migrator._db.conn.cursor()
|
||||
migrator._create_migrations_table(cursor)
|
||||
migrator.register_migration(migration_no_op)
|
||||
migrator.run_migrations()
|
||||
@@ -207,7 +207,7 @@ def test_migrator_migration_sets_version(migrator: SqliteMigrator, migration_no_
|
||||
|
||||
|
||||
def test_migrator_gets_current_version(migrator: SqliteMigrator, migration_no_op: Migration) -> None:
|
||||
cursor = migrator._db._conn.cursor()
|
||||
cursor = migrator._db.conn.cursor()
|
||||
assert migrator._get_current_version(cursor) == 0
|
||||
migrator._create_migrations_table(cursor)
|
||||
assert migrator._get_current_version(cursor) == 0
|
||||
@@ -217,7 +217,7 @@ def test_migrator_gets_current_version(migrator: SqliteMigrator, migration_no_op
|
||||
|
||||
|
||||
def test_migrator_runs_single_migration(migrator: SqliteMigrator, migration_create_test_table: Migration) -> None:
|
||||
cursor = migrator._db._conn.cursor()
|
||||
cursor = migrator._db.conn.cursor()
|
||||
migrator._create_migrations_table(cursor)
|
||||
migrator._run_migration(migration_create_test_table)
|
||||
assert migrator._get_current_version(cursor) == 1
|
||||
@@ -226,7 +226,7 @@ def test_migrator_runs_single_migration(migrator: SqliteMigrator, migration_crea
|
||||
|
||||
|
||||
def test_migrator_runs_all_migrations_in_memory(migrator: SqliteMigrator) -> None:
|
||||
cursor = migrator._db._conn.cursor()
|
||||
cursor = migrator._db.conn.cursor()
|
||||
migrations = [Migration(from_version=i, to_version=i + 1, callback=create_migrate(i)) for i in range(0, 3)]
|
||||
for migration in migrations:
|
||||
migrator.register_migration(migration)
|
||||
@@ -247,7 +247,7 @@ def test_migrator_runs_all_migrations_file(logger: Logger) -> None:
|
||||
original_db_cursor = original_db_conn.cursor()
|
||||
assert SqliteMigrator._get_current_version(original_db_cursor) == 3
|
||||
# Must manually close else we get an error on Windows
|
||||
db._conn.close()
|
||||
db.conn.close()
|
||||
|
||||
|
||||
def test_migrator_backs_up_db(logger: Logger) -> None:
|
||||
@@ -255,9 +255,9 @@ def test_migrator_backs_up_db(logger: Logger) -> None:
|
||||
original_db_path = Path(tempdir) / "invokeai.db"
|
||||
db = SqliteDatabase(db_path=original_db_path, logger=logger, verbose=False)
|
||||
# Write some data to the db to test for successful backup
|
||||
temp_cursor = db._conn.cursor()
|
||||
temp_cursor = db.conn.cursor()
|
||||
temp_cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
db._conn.commit()
|
||||
db.conn.commit()
|
||||
# Set up the migrator
|
||||
migrator = SqliteMigrator(db=db)
|
||||
migrations = [Migration(from_version=i, to_version=i + 1, callback=create_migrate(i)) for i in range(0, 3)]
|
||||
@@ -265,7 +265,7 @@ def test_migrator_backs_up_db(logger: Logger) -> None:
|
||||
migrator.register_migration(migration)
|
||||
migrator.run_migrations()
|
||||
# Must manually close else we get an error on Windows
|
||||
db._conn.close()
|
||||
db.conn.close()
|
||||
assert original_db_path.exists()
|
||||
# We should have a backup file when we migrated a file db
|
||||
assert migrator._backup_path
|
||||
@@ -279,7 +279,7 @@ def test_migrator_backs_up_db(logger: Logger) -> None:
|
||||
def test_migrator_makes_no_changes_on_failed_migration(
|
||||
migrator: SqliteMigrator, migration_no_op: Migration, failing_migrate_callback: MigrateCallback
|
||||
) -> None:
|
||||
cursor = migrator._db._conn.cursor()
|
||||
cursor = migrator._db.conn.cursor()
|
||||
migrator.register_migration(migration_no_op)
|
||||
migrator.run_migrations()
|
||||
assert migrator._get_current_version(cursor) == 1
|
||||
@@ -290,7 +290,7 @@ def test_migrator_makes_no_changes_on_failed_migration(
|
||||
|
||||
|
||||
def test_idempotent_migrations(migrator: SqliteMigrator, migration_create_test_table: Migration) -> None:
|
||||
cursor = migrator._db._conn.cursor()
|
||||
cursor = migrator._db.conn.cursor()
|
||||
migrator.register_migration(migration_create_test_table)
|
||||
migrator.run_migrations()
|
||||
# not throwing is sufficient
|
||||
|
||||
Reference in New Issue
Block a user