mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(app): expose a cursor, not a connection in db util
This commit is contained in:
@@ -21,8 +21,8 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
with self._db.conn() as conn:
|
||||
conn.execute(
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO board_images (board_id, image_name)
|
||||
VALUES (?, ?)
|
||||
@@ -35,8 +35,8 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
self,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
with self._db.conn() as conn:
|
||||
conn.execute(
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM board_images
|
||||
WHERE image_name = ?;
|
||||
@@ -50,8 +50,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
@@ -80,8 +79,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
categories: list[ImageCategory] | None,
|
||||
is_intermediate: bool | None,
|
||||
) -> list[str]:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
params: list[str | bool] = []
|
||||
|
||||
# Base query is a join between images and board_images
|
||||
@@ -137,8 +135,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Optional[str]:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT board_id
|
||||
@@ -153,8 +150,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
return cast(str, result[0])
|
||||
|
||||
def get_image_count_for_board(self, board_id: str) -> int:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*)
|
||||
|
||||
@@ -23,9 +23,8 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
self._db = db
|
||||
|
||||
def delete(self, board_id: str) -> None:
|
||||
with self._db.conn() as conn:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM boards
|
||||
@@ -40,10 +39,9 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
self,
|
||||
board_name: str,
|
||||
) -> BoardRecord:
|
||||
with self._db.conn() as conn:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
board_id = uuid_string()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO boards (board_id, board_name)
|
||||
@@ -59,9 +57,8 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
self,
|
||||
board_id: str,
|
||||
) -> BoardRecord:
|
||||
with self._db.conn() as conn:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
@@ -83,9 +80,8 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
board_id: str,
|
||||
changes: BoardChanges,
|
||||
) -> BoardRecord:
|
||||
with self._db.conn() as conn:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
# Change the name of a board
|
||||
if changes.board_name is not None:
|
||||
cursor.execute(
|
||||
@@ -131,9 +127,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
limit: int = 10,
|
||||
include_archived: bool = False,
|
||||
) -> OffsetPaginatedResults[BoardRecord]:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
with self._db.transaction() as cursor:
|
||||
# Build base query
|
||||
base_query = """
|
||||
SELECT *
|
||||
@@ -179,8 +173,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
def get_all(
|
||||
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
|
||||
) -> list[BoardRecord]:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
if order_by == BoardRecordOrderBy.Name:
|
||||
base_query = """
|
||||
SELECT *
|
||||
|
||||
@@ -27,9 +27,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
self._db = db
|
||||
|
||||
def get(self, image_name: str) -> ImageRecord:
|
||||
with self._db.conn() as conn:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT {IMAGE_DTO_COLS} FROM images
|
||||
@@ -48,9 +47,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
return deserialize_image_record(dict(result))
|
||||
|
||||
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
|
||||
with self._db.conn() as conn:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata FROM images
|
||||
@@ -76,9 +74,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
image_name: str,
|
||||
changes: ImageRecordChanges,
|
||||
) -> None:
|
||||
with self._db.conn() as conn:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
# Change the category of the image
|
||||
if changes.image_category is not None:
|
||||
cursor.execute(
|
||||
@@ -138,9 +135,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
with self._db.transaction() as cursor:
|
||||
# Manually build two queries - one for the count, one for the records
|
||||
count_query = """--sql
|
||||
SELECT COUNT(*)
|
||||
@@ -227,20 +222,20 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
# Build the list of images, deserializing each row
|
||||
cursor.execute(images_query, images_params)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
images = [deserialize_image_record(dict(r)) for r in result]
|
||||
|
||||
# Set up and execute the count query, without pagination
|
||||
count_query += query_conditions + ";"
|
||||
count_params = query_params.copy()
|
||||
cursor.execute(count_query, count_params)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
images = [deserialize_image_record(dict(r)) for r in result]
|
||||
|
||||
# Set up and execute the count query, without pagination
|
||||
count_query += query_conditions + ";"
|
||||
count_params = query_params.copy()
|
||||
cursor.execute(count_query, count_params)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
|
||||
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
|
||||
|
||||
def delete(self, image_name: str) -> None:
|
||||
with self._db.conn() as conn:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM images
|
||||
@@ -252,10 +247,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
raise ImageRecordDeleteException from e
|
||||
|
||||
def delete_many(self, image_names: list[str]) -> None:
|
||||
with self._db.conn() as conn:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
|
||||
placeholders = ",".join("?" for _ in image_names)
|
||||
|
||||
# Construct the SQLite query with the placeholders
|
||||
@@ -268,8 +261,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
raise ImageRecordDeleteException from e
|
||||
|
||||
def get_intermediates_count(self) -> int:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*) FROM images
|
||||
@@ -280,9 +272,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
return count
|
||||
|
||||
def delete_intermediates(self) -> list[str]:
|
||||
with self._db.conn() as conn:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT image_name FROM images
|
||||
@@ -315,9 +306,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
node_id: Optional[str] = None,
|
||||
metadata: Optional[str] = None,
|
||||
) -> datetime:
|
||||
with self._db.conn() as conn:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO images (
|
||||
@@ -366,8 +356,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
return created_at
|
||||
|
||||
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
@@ -398,9 +387,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> ImageNamesResult:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
with self._db.transaction() as cursor:
|
||||
# Build query conditions (reused for both starred count and image names queries)
|
||||
query_conditions = ""
|
||||
query_params: list[Union[int, str, bool]] = []
|
||||
|
||||
@@ -88,9 +88,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
|
||||
"""
|
||||
with self._db.conn() as conn:
|
||||
with self._db.transaction() as cursor:
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO models (
|
||||
@@ -127,8 +126,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
Can raise an UnknownModelException
|
||||
"""
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM models
|
||||
@@ -140,7 +138,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
raise UnknownModelException("model not found")
|
||||
|
||||
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
|
||||
with self._db.conn() as conn:
|
||||
with self._db.transaction() as cursor:
|
||||
record = self.get_model(key)
|
||||
|
||||
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
|
||||
@@ -149,7 +147,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
json_serialized = record.model_dump_json()
|
||||
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE models
|
||||
@@ -172,8 +169,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
Exceptions: UnknownModelException
|
||||
"""
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
@@ -188,8 +184,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
return model
|
||||
|
||||
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
@@ -209,8 +204,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
:param key: Unique key for the model to be deleted
|
||||
"""
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
select count(*) FROM models
|
||||
@@ -241,7 +235,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
If none of the optional filters are passed, will return all
|
||||
models in the database.
|
||||
"""
|
||||
with self._db.conn() as conn:
|
||||
with self._db.transaction() as cursor:
|
||||
assert isinstance(order_by, ModelRecordOrderBy)
|
||||
ordering = {
|
||||
ModelRecordOrderBy.Default: "type, base, name, format",
|
||||
@@ -267,7 +261,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
bindings.append(model_format)
|
||||
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
|
||||
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config, strftime('%s',updated_at)
|
||||
@@ -299,8 +292,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
|
||||
"""Return models with the indicated path."""
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
@@ -313,8 +305,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
|
||||
"""Return models with the indicated hash."""
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
@@ -329,7 +320,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
||||
) -> PaginatedResults[ModelSummary]:
|
||||
"""Return a paginated summary listing of each model in the database."""
|
||||
with self._db.conn() as conn:
|
||||
with self._db.transaction() as cursor:
|
||||
assert isinstance(order_by, ModelRecordOrderBy)
|
||||
ordering = {
|
||||
ModelRecordOrderBy.Default: "type, base, name, format",
|
||||
@@ -339,8 +330,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
ModelRecordOrderBy.Format: "format",
|
||||
}
|
||||
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Lock so that the database isn't updated while we're doing the two queries.
|
||||
# query1: get the total number of model configs
|
||||
cursor.execute(
|
||||
|
||||
@@ -10,28 +10,25 @@ class SqliteModelRelationshipRecordStorage(ModelRelationshipRecordStorageBase):
|
||||
self._db = db
|
||||
|
||||
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
|
||||
with self._db.conn() as conn:
|
||||
with self._db.transaction() as cursor:
|
||||
if model_key_1 == model_key_2:
|
||||
raise ValueError("Cannot relate a model to itself.")
|
||||
a, b = sorted([model_key_1, model_key_2])
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"INSERT OR IGNORE INTO model_relationships (model_key_1, model_key_2) VALUES (?, ?)",
|
||||
(a, b),
|
||||
)
|
||||
|
||||
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
|
||||
with self._db.conn() as conn:
|
||||
with self._db.transaction() as cursor:
|
||||
a, b = sorted([model_key_1, model_key_2])
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"DELETE FROM model_relationships WHERE model_key_1 = ? AND model_key_2 = ?",
|
||||
(a, b),
|
||||
)
|
||||
|
||||
def get_related_model_keys(self, model_key: str) -> list[str]:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ?
|
||||
@@ -44,9 +41,7 @@ class SqliteModelRelationshipRecordStorage(ModelRelationshipRecordStorageBase):
|
||||
return result
|
||||
|
||||
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
with self._db.transaction() as cursor:
|
||||
key_list = ",".join("?" for _ in model_keys)
|
||||
cursor.execute(
|
||||
f"""
|
||||
|
||||
@@ -57,8 +57,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
|
||||
This is necessary because the invoker may have been killed while processing a queue item.
|
||||
"""
|
||||
with self._db.conn() as conn:
|
||||
conn.execute(
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = 'canceled'
|
||||
@@ -68,8 +68,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
def _get_current_queue_size(self, queue_id: str) -> int:
|
||||
"""Gets the current number of pending queue items"""
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
@@ -85,8 +84,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
def _get_highest_priority(self, queue_id: str) -> int:
|
||||
"""Gets the highest priority value in the queue"""
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT MAX(priority)
|
||||
@@ -122,8 +120,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
enqueued_count = len(values_to_insert)
|
||||
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.executemany(
|
||||
"""--sql
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
|
||||
@@ -153,8 +150,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return enqueue_result
|
||||
|
||||
def dequeue(self) -> Optional[SessionQueueItem]:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
@@ -174,8 +170,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return queue_item
|
||||
|
||||
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
@@ -196,8 +191,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return SessionQueueItem.queue_item_from_dict(dict(result))
|
||||
|
||||
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
@@ -222,8 +216,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
error_message: Optional[str] = None,
|
||||
error_traceback: Optional[str] = None,
|
||||
) -> SessionQueueItem:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status FROM session_queue WHERE item_id = ?
|
||||
@@ -239,8 +232,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
if current_status in ("completed", "failed", "canceled"):
|
||||
return self.get_queue_item(item_id)
|
||||
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE session_queue
|
||||
@@ -257,8 +249,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return queue_item
|
||||
|
||||
def is_empty(self, queue_id: str) -> IsEmptyResult:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
@@ -271,8 +262,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return IsEmptyResult(is_empty=is_empty)
|
||||
|
||||
def is_full(self, queue_id: str) -> IsFullResult:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
@@ -286,8 +276,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return IsFullResult(is_full=is_full)
|
||||
|
||||
def clear(self, queue_id: str) -> ClearResult:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*)
|
||||
@@ -309,8 +298,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return ClearResult(deleted=count)
|
||||
|
||||
def prune(self, queue_id: str) -> PruneResult:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id = ?
|
||||
@@ -349,8 +337,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self.cancel_queue_item(item_id)
|
||||
except SessionQueueItemNotFoundError:
|
||||
pass
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE
|
||||
@@ -381,8 +368,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return queue_item
|
||||
|
||||
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
placeholders = ", ".join(["?" for _ in batch_ids])
|
||||
where = f"""--sql
|
||||
@@ -420,8 +406,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return CancelByBatchIDsResult(canceled=count)
|
||||
|
||||
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
where = """--sql
|
||||
WHERE
|
||||
@@ -456,8 +441,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return CancelByDestinationResult(canceled=count)
|
||||
|
||||
def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
if current_queue_item is not None and current_queue_item.destination == destination:
|
||||
self.cancel_queue_item(current_queue_item.item_id)
|
||||
@@ -486,8 +470,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return DeleteByDestinationResult(deleted=count)
|
||||
|
||||
def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id == ?
|
||||
@@ -513,8 +496,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return DeleteAllExceptCurrentResult(deleted=count)
|
||||
|
||||
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
where = """--sql
|
||||
WHERE
|
||||
@@ -549,8 +531,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return CancelByQueueIDResult(canceled=count)
|
||||
|
||||
def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id == ?
|
||||
@@ -576,8 +557,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return CancelAllExceptCurrentResult(canceled=count)
|
||||
|
||||
def get_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT * FROM session_queue
|
||||
@@ -592,8 +572,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return SessionQueueItem.queue_item_from_dict(dict(result))
|
||||
|
||||
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
|
||||
# when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
|
||||
# during execution.
|
||||
@@ -617,8 +596,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
status: Optional[QUEUE_ITEM_STATUS] = None,
|
||||
destination: Optional[str] = None,
|
||||
) -> CursorPaginatedResults[SessionQueueItem]:
|
||||
with self._db.conn() as conn:
|
||||
cursor_ = conn.cursor()
|
||||
with self._db.transaction() as cursor_:
|
||||
item_id = cursor
|
||||
query = """--sql
|
||||
SELECT *
|
||||
@@ -668,8 +646,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
destination: Optional[str] = None,
|
||||
) -> list[SessionQueueItem]:
|
||||
"""Gets all queue items that match the given parameters"""
|
||||
with self._db.conn() as conn:
|
||||
cursor_ = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
query = """--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
@@ -689,14 +666,13 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
item_id ASC
|
||||
;
|
||||
"""
|
||||
cursor_.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], cursor_.fetchall())
|
||||
cursor.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
|
||||
return items
|
||||
|
||||
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
@@ -725,8 +701,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
|
||||
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*), origin, destination
|
||||
@@ -758,8 +733,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
|
||||
def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
@@ -788,8 +762,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult:
|
||||
"""Retries the given queue items"""
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
values_to_insert: list[ValueToInsertTuple] = []
|
||||
retried_item_ids: list[int] = []
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ class SqliteDatabase:
|
||||
if not self._db_path:
|
||||
return
|
||||
try:
|
||||
with self.conn() as conn:
|
||||
with self._conn as conn:
|
||||
initial_db_size = Path(self._db_path).stat().st_size
|
||||
conn.execute("VACUUM;")
|
||||
conn.commit()
|
||||
@@ -76,17 +76,18 @@ class SqliteDatabase:
|
||||
raise
|
||||
|
||||
@contextmanager
|
||||
def conn(self) -> Generator[sqlite3.Connection]:
|
||||
def transaction(self) -> Generator[sqlite3.Cursor, None, None]:
|
||||
"""
|
||||
Thread-safe context manager for DB work.
|
||||
Acquires the RLock, yields the Connection, then commits or rolls back.
|
||||
Acquires the RLock, yields a Cursor, then commits or rolls back.
|
||||
"""
|
||||
self._lock.acquire()
|
||||
try:
|
||||
yield self._conn
|
||||
self._conn.commit()
|
||||
except:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
with self._lock:
|
||||
cursor = self._conn.cursor()
|
||||
try:
|
||||
yield cursor
|
||||
self._conn.commit()
|
||||
except:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
@@ -25,8 +25,7 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
|
||||
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
|
||||
"""Gets a style preset by ID."""
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
@@ -42,8 +41,7 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
|
||||
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
|
||||
style_preset_id = uuid_string()
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO style_presets (
|
||||
@@ -65,8 +63,7 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
|
||||
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
|
||||
style_preset_ids = []
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
for style_preset in style_presets:
|
||||
style_preset_id = uuid_string()
|
||||
style_preset_ids.append(style_preset_id)
|
||||
@@ -91,8 +88,7 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
return None
|
||||
|
||||
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
# Change the name of a style preset
|
||||
if changes.name is not None:
|
||||
cursor.execute(
|
||||
@@ -118,8 +114,7 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
return self.get(style_preset_id)
|
||||
|
||||
def delete(self, style_preset_id: str) -> None:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE from style_presets
|
||||
@@ -130,9 +125,7 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
return None
|
||||
|
||||
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
with self._db.transaction() as cursor:
|
||||
main_query = """
|
||||
SELECT
|
||||
*
|
||||
@@ -156,9 +149,8 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
|
||||
def _sync_default_style_presets(self) -> None:
|
||||
"""Syncs default style presets to the database. Internal use only."""
|
||||
with self._db.conn() as conn:
|
||||
with self._db.transaction() as cursor:
|
||||
# First delete all existing default style presets
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM style_presets
|
||||
|
||||
@@ -33,8 +33,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
|
||||
def get(self, workflow_id: str) -> WorkflowRecordDTO:
|
||||
"""Gets a workflow by ID. Updates the opened_at column."""
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
|
||||
@@ -52,9 +51,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
if workflow.meta.category is WorkflowCategory.Default:
|
||||
raise ValueError("Default workflows cannot be created via this method")
|
||||
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
with self._db.transaction() as cursor:
|
||||
workflow_with_id = Workflow(**workflow.model_dump(), id=uuid_string())
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
@@ -72,8 +69,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
if workflow.meta.category is WorkflowCategory.Default:
|
||||
raise ValueError("Default workflows cannot be updated")
|
||||
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE workflow_library
|
||||
@@ -88,8 +84,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
if self.get(workflow_id).workflow.meta.category is WorkflowCategory.Default:
|
||||
raise ValueError("Default workflows cannot be deleted")
|
||||
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE from workflow_library
|
||||
@@ -111,7 +106,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
has_been_opened: Optional[bool] = None,
|
||||
is_published: Optional[bool] = None,
|
||||
) -> PaginatedResults[WorkflowRecordListItemDTO]:
|
||||
with self._db.conn() as conn:
|
||||
with self._db.transaction() as cursor:
|
||||
# sanitize!
|
||||
assert order_by in WorkflowRecordOrderBy
|
||||
assert direction in SQLiteDirection
|
||||
@@ -207,7 +202,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
main_query += ";"
|
||||
count_query += ";"
|
||||
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(main_query, main_params)
|
||||
rows = cursor.fetchall()
|
||||
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
|
||||
@@ -238,8 +232,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
if not tags:
|
||||
return {}
|
||||
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
result: dict[str, int] = {}
|
||||
# Base conditions for categories and selected tags
|
||||
base_conditions: list[str] = []
|
||||
@@ -288,8 +281,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
has_been_opened: Optional[bool] = None,
|
||||
is_published: Optional[bool] = None,
|
||||
) -> dict[str, int]:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
result: dict[str, int] = {}
|
||||
# Base conditions for categories
|
||||
base_conditions: list[str] = []
|
||||
@@ -333,8 +325,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
return result
|
||||
|
||||
def update_opened_at(self, workflow_id: str) -> None:
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE workflow_library
|
||||
@@ -357,8 +348,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
meaningless, as they are overwritten every time the server starts.
|
||||
"""
|
||||
|
||||
with self._db.conn() as conn:
|
||||
cursor = conn.cursor()
|
||||
with self._db.transaction() as cursor:
|
||||
workflows_from_file: list[Workflow] = []
|
||||
workflows_to_update: list[Workflow] = []
|
||||
workflows_to_add: list[Workflow] = []
|
||||
|
||||
Reference in New Issue
Block a user