feat(app): expose a cursor, not a connection in db util

This commit is contained in:
psychedelicious
2025-07-11 08:12:26 +10:00
parent a19aa3b032
commit fc71849c24
9 changed files with 106 additions and 190 deletions

View File

@@ -21,8 +21,8 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
board_id: str,
image_name: str,
) -> None:
with self._db.conn() as conn:
conn.execute(
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
INSERT INTO board_images (board_id, image_name)
VALUES (?, ?)
@@ -35,8 +35,8 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
self,
image_name: str,
) -> None:
with self._db.conn() as conn:
conn.execute(
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE FROM board_images
WHERE image_name = ?;
@@ -50,8 +50,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[ImageRecord]:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT images.*
@@ -80,8 +79,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
categories: list[ImageCategory] | None,
is_intermediate: bool | None,
) -> list[str]:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
params: list[str | bool] = []
# Base query is a join between images and board_images
@@ -137,8 +135,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
self,
image_name: str,
) -> Optional[str]:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT board_id
@@ -153,8 +150,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
return cast(str, result[0])
def get_image_count_for_board(self, board_id: str) -> int:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT COUNT(*)

View File

@@ -23,9 +23,8 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
self._db = db
def delete(self, board_id: str) -> None:
with self._db.conn() as conn:
with self._db.transaction() as cursor:
try:
cursor = conn.cursor()
cursor.execute(
"""--sql
DELETE FROM boards
@@ -40,10 +39,9 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
self,
board_name: str,
) -> BoardRecord:
with self._db.conn() as conn:
with self._db.transaction() as cursor:
try:
board_id = uuid_string()
cursor = conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO boards (board_id, board_name)
@@ -59,9 +57,8 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
self,
board_id: str,
) -> BoardRecord:
with self._db.conn() as conn:
with self._db.transaction() as cursor:
try:
cursor = conn.cursor()
cursor.execute(
"""--sql
SELECT *
@@ -83,9 +80,8 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
board_id: str,
changes: BoardChanges,
) -> BoardRecord:
with self._db.conn() as conn:
with self._db.transaction() as cursor:
try:
cursor = conn.cursor()
# Change the name of a board
if changes.board_name is not None:
cursor.execute(
@@ -131,9 +127,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
limit: int = 10,
include_archived: bool = False,
) -> OffsetPaginatedResults[BoardRecord]:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
# Build base query
base_query = """
SELECT *
@@ -179,8 +173,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
def get_all(
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
) -> list[BoardRecord]:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
if order_by == BoardRecordOrderBy.Name:
base_query = """
SELECT *

View File

@@ -27,9 +27,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
self._db = db
def get(self, image_name: str) -> ImageRecord:
with self._db.conn() as conn:
with self._db.transaction() as cursor:
try:
cursor = conn.cursor()
cursor.execute(
f"""--sql
SELECT {IMAGE_DTO_COLS} FROM images
@@ -48,9 +47,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
return deserialize_image_record(dict(result))
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
with self._db.conn() as conn:
with self._db.transaction() as cursor:
try:
cursor = conn.cursor()
cursor.execute(
"""--sql
SELECT metadata FROM images
@@ -76,9 +74,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
image_name: str,
changes: ImageRecordChanges,
) -> None:
with self._db.conn() as conn:
with self._db.transaction() as cursor:
try:
cursor = conn.cursor()
# Change the category of the image
if changes.image_category is not None:
cursor.execute(
@@ -138,9 +135,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
# Manually build two queries - one for the count, one for the records
count_query = """--sql
SELECT COUNT(*)
@@ -227,20 +222,20 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Build the list of images, deserializing each row
cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
count_params = query_params.copy()
cursor.execute(count_query, count_params)
count = cast(int, cursor.fetchone()[0])
images = [deserialize_image_record(dict(r)) for r in result]
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
count_params = query_params.copy()
cursor.execute(count_query, count_params)
count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
def delete(self, image_name: str) -> None:
with self._db.conn() as conn:
with self._db.transaction() as cursor:
try:
cursor = conn.cursor()
cursor.execute(
"""--sql
DELETE FROM images
@@ -252,10 +247,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
raise ImageRecordDeleteException from e
def delete_many(self, image_names: list[str]) -> None:
with self._db.conn() as conn:
with self._db.transaction() as cursor:
try:
cursor = conn.cursor()
placeholders = ",".join("?" for _ in image_names)
# Construct the SQLite query with the placeholders
@@ -268,8 +261,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
raise ImageRecordDeleteException from e
def get_intermediates_count(self) -> int:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images
@@ -280,9 +272,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
return count
def delete_intermediates(self) -> list[str]:
with self._db.conn() as conn:
with self._db.transaction() as cursor:
try:
cursor = conn.cursor()
cursor.execute(
"""--sql
SELECT image_name FROM images
@@ -315,9 +306,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id: Optional[str] = None,
metadata: Optional[str] = None,
) -> datetime:
with self._db.conn() as conn:
with self._db.transaction() as cursor:
try:
cursor = conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO images (
@@ -366,8 +356,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
return created_at
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT images.*
@@ -398,9 +387,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageNamesResult:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
# Build query conditions (reused for both starred count and image names queries)
query_conditions = ""
query_params: list[Union[int, str, bool]] = []

View File

@@ -88,9 +88,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
with self._db.conn() as conn:
with self._db.transaction() as cursor:
try:
cursor = conn.cursor()
cursor.execute(
"""--sql
INSERT INTO models (
@@ -127,8 +126,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise an UnknownModelException
"""
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE FROM models
@@ -140,7 +138,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
raise UnknownModelException("model not found")
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
with self._db.conn() as conn:
with self._db.transaction() as cursor:
record = self.get_model(key)
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
@@ -149,7 +147,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
json_serialized = record.model_dump_json()
cursor = conn.cursor()
cursor.execute(
"""--sql
UPDATE models
@@ -172,8 +169,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Exceptions: UnknownModelException
"""
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
@@ -188,8 +184,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
return model
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
@@ -209,8 +204,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
:param key: Unique key for the model to be deleted
"""
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
select count(*) FROM models
@@ -241,7 +235,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
If none of the optional filters are passed, will return all
models in the database.
"""
with self._db.conn() as conn:
with self._db.transaction() as cursor:
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
@@ -267,7 +261,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
bindings.append(model_format)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
cursor = conn.cursor()
cursor.execute(
f"""--sql
SELECT config, strftime('%s',updated_at)
@@ -299,8 +292,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
"""Return models with the indicated path."""
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
@@ -313,8 +305,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
"""Return models with the indicated hash."""
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
@@ -329,7 +320,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
) -> PaginatedResults[ModelSummary]:
"""Return a paginated summary listing of each model in the database."""
with self._db.conn() as conn:
with self._db.transaction() as cursor:
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
@@ -339,8 +330,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
ModelRecordOrderBy.Format: "format",
}
cursor = conn.cursor()
# Lock so that the database isn't updated while we're doing the two queries.
# query1: get the total number of model configs
cursor.execute(

View File

@@ -10,28 +10,25 @@ class SqliteModelRelationshipRecordStorage(ModelRelationshipRecordStorageBase):
self._db = db
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
with self._db.conn() as conn:
with self._db.transaction() as cursor:
if model_key_1 == model_key_2:
raise ValueError("Cannot relate a model to itself.")
a, b = sorted([model_key_1, model_key_2])
cursor = conn.cursor()
cursor.execute(
"INSERT OR IGNORE INTO model_relationships (model_key_1, model_key_2) VALUES (?, ?)",
(a, b),
)
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
with self._db.conn() as conn:
with self._db.transaction() as cursor:
a, b = sorted([model_key_1, model_key_2])
cursor = conn.cursor()
cursor.execute(
"DELETE FROM model_relationships WHERE model_key_1 = ? AND model_key_2 = ?",
(a, b),
)
def get_related_model_keys(self, model_key: str) -> list[str]:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ?
@@ -44,9 +41,7 @@ class SqliteModelRelationshipRecordStorage(ModelRelationshipRecordStorageBase):
return result
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
key_list = ",".join("?" for _ in model_keys)
cursor.execute(
f"""

View File

@@ -57,8 +57,8 @@ class SqliteSessionQueue(SessionQueueBase):
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
This is necessary because the invoker may have been killed while processing a queue item.
"""
with self._db.conn() as conn:
conn.execute(
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
UPDATE session_queue
SET status = 'canceled'
@@ -68,8 +68,7 @@ class SqliteSessionQueue(SessionQueueBase):
def _get_current_queue_size(self, queue_id: str) -> int:
"""Gets the current number of pending queue items"""
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT count(*)
@@ -85,8 +84,7 @@ class SqliteSessionQueue(SessionQueueBase):
def _get_highest_priority(self, queue_id: str) -> int:
"""Gets the highest priority value in the queue"""
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT MAX(priority)
@@ -122,8 +120,7 @@ class SqliteSessionQueue(SessionQueueBase):
)
enqueued_count = len(values_to_insert)
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
@@ -153,8 +150,7 @@ class SqliteSessionQueue(SessionQueueBase):
return enqueue_result
def dequeue(self) -> Optional[SessionQueueItem]:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
@@ -174,8 +170,7 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
@@ -196,8 +191,7 @@ class SqliteSessionQueue(SessionQueueBase):
return SessionQueueItem.queue_item_from_dict(dict(result))
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
@@ -222,8 +216,7 @@ class SqliteSessionQueue(SessionQueueBase):
error_message: Optional[str] = None,
error_traceback: Optional[str] = None,
) -> SessionQueueItem:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status FROM session_queue WHERE item_id = ?
@@ -239,8 +232,7 @@ class SqliteSessionQueue(SessionQueueBase):
if current_status in ("completed", "failed", "canceled"):
return self.get_queue_item(item_id)
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
UPDATE session_queue
@@ -257,8 +249,7 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def is_empty(self, queue_id: str) -> IsEmptyResult:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT count(*)
@@ -271,8 +262,7 @@ class SqliteSessionQueue(SessionQueueBase):
return IsEmptyResult(is_empty=is_empty)
def is_full(self, queue_id: str) -> IsFullResult:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT count(*)
@@ -286,8 +276,7 @@ class SqliteSessionQueue(SessionQueueBase):
return IsFullResult(is_full=is_full)
def clear(self, queue_id: str) -> ClearResult:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT COUNT(*)
@@ -309,8 +298,7 @@ class SqliteSessionQueue(SessionQueueBase):
return ClearResult(deleted=count)
def prune(self, queue_id: str) -> PruneResult:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
where = """--sql
WHERE
queue_id = ?
@@ -349,8 +337,7 @@ class SqliteSessionQueue(SessionQueueBase):
self.cancel_queue_item(item_id)
except SessionQueueItemNotFoundError:
pass
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE
@@ -381,8 +368,7 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
placeholders = ", ".join(["?" for _ in batch_ids])
where = f"""--sql
@@ -420,8 +406,7 @@ class SqliteSessionQueue(SessionQueueBase):
return CancelByBatchIDsResult(canceled=count)
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
where = """--sql
WHERE
@@ -456,8 +441,7 @@ class SqliteSessionQueue(SessionQueueBase):
return CancelByDestinationResult(canceled=count)
def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
if current_queue_item is not None and current_queue_item.destination == destination:
self.cancel_queue_item(current_queue_item.item_id)
@@ -486,8 +470,7 @@ class SqliteSessionQueue(SessionQueueBase):
return DeleteByDestinationResult(deleted=count)
def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
where = """--sql
WHERE
queue_id == ?
@@ -513,8 +496,7 @@ class SqliteSessionQueue(SessionQueueBase):
return DeleteAllExceptCurrentResult(deleted=count)
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
where = """--sql
WHERE
@@ -549,8 +531,7 @@ class SqliteSessionQueue(SessionQueueBase):
return CancelByQueueIDResult(canceled=count)
def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
where = """--sql
WHERE
queue_id == ?
@@ -576,8 +557,7 @@ class SqliteSessionQueue(SessionQueueBase):
return CancelAllExceptCurrentResult(canceled=count)
def get_queue_item(self, item_id: int) -> SessionQueueItem:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT * FROM session_queue
@@ -592,8 +572,7 @@ class SqliteSessionQueue(SessionQueueBase):
return SessionQueueItem.queue_item_from_dict(dict(result))
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
# when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
# during execution.
@@ -617,8 +596,7 @@ class SqliteSessionQueue(SessionQueueBase):
status: Optional[QUEUE_ITEM_STATUS] = None,
destination: Optional[str] = None,
) -> CursorPaginatedResults[SessionQueueItem]:
with self._db.conn() as conn:
cursor_ = conn.cursor()
with self._db.transaction() as cursor_:
item_id = cursor
query = """--sql
SELECT *
@@ -668,8 +646,7 @@ class SqliteSessionQueue(SessionQueueBase):
destination: Optional[str] = None,
) -> list[SessionQueueItem]:
"""Gets all queue items that match the given parameters"""
with self._db.conn() as conn:
cursor_ = conn.cursor()
with self._db.transaction() as cursor:
query = """--sql
SELECT *
FROM session_queue
@@ -689,14 +666,13 @@ class SqliteSessionQueue(SessionQueueBase):
item_id ASC
;
"""
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
cursor.execute(query, params)
results = cast(list[sqlite3.Row], cursor.fetchall())
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
return items
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status, count(*)
@@ -725,8 +701,7 @@ class SqliteSessionQueue(SessionQueueBase):
)
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status, count(*), origin, destination
@@ -758,8 +733,7 @@ class SqliteSessionQueue(SessionQueueBase):
)
def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status, count(*)
@@ -788,8 +762,7 @@ class SqliteSessionQueue(SessionQueueBase):
def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult:
"""Retries the given queue items"""
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
values_to_insert: list[ValueToInsertTuple] = []
retried_item_ids: list[int] = []

View File

@@ -63,7 +63,7 @@ class SqliteDatabase:
if not self._db_path:
return
try:
with self.conn() as conn:
with self._conn as conn:
initial_db_size = Path(self._db_path).stat().st_size
conn.execute("VACUUM;")
conn.commit()
@@ -76,17 +76,18 @@ class SqliteDatabase:
raise
@contextmanager
def conn(self) -> Generator[sqlite3.Connection]:
def transaction(self) -> Generator[sqlite3.Cursor, None, None]:
"""
Thread-safe context manager for DB work.
Acquires the RLock, yields the Connection, then commits or rolls back.
Acquires the RLock, yields a Cursor, then commits or rolls back.
"""
self._lock.acquire()
try:
yield self._conn
self._conn.commit()
except:
self._conn.rollback()
raise
finally:
self._lock.release()
with self._lock:
cursor = self._conn.cursor()
try:
yield cursor
self._conn.commit()
except:
self._conn.rollback()
raise
finally:
cursor.close()

View File

@@ -25,8 +25,7 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
"""Gets a style preset by ID."""
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
@@ -42,8 +41,7 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
style_preset_id = uuid_string()
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
INSERT OR IGNORE INTO style_presets (
@@ -65,8 +63,7 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
style_preset_ids = []
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
for style_preset in style_presets:
style_preset_id = uuid_string()
style_preset_ids.append(style_preset_id)
@@ -91,8 +88,7 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
return None
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
# Change the name of a style preset
if changes.name is not None:
cursor.execute(
@@ -118,8 +114,7 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
return self.get(style_preset_id)
def delete(self, style_preset_id: str) -> None:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE from style_presets
@@ -130,9 +125,7 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
return None
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
main_query = """
SELECT
*
@@ -156,9 +149,8 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def _sync_default_style_presets(self) -> None:
"""Syncs default style presets to the database. Internal use only."""
with self._db.conn() as conn:
with self._db.transaction() as cursor:
# First delete all existing default style presets
cursor = conn.cursor()
cursor.execute(
"""--sql
DELETE FROM style_presets

View File

@@ -33,8 +33,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
def get(self, workflow_id: str) -> WorkflowRecordDTO:
"""Gets a workflow by ID. Updates the opened_at column."""
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
@@ -52,9 +51,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
if workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be created via this method")
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
workflow_with_id = Workflow(**workflow.model_dump(), id=uuid_string())
cursor.execute(
"""--sql
@@ -72,8 +69,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
if workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be updated")
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
UPDATE workflow_library
@@ -88,8 +84,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
if self.get(workflow_id).workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be deleted")
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE from workflow_library
@@ -111,7 +106,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> PaginatedResults[WorkflowRecordListItemDTO]:
with self._db.conn() as conn:
with self._db.transaction() as cursor:
# sanitize!
assert order_by in WorkflowRecordOrderBy
assert direction in SQLiteDirection
@@ -207,7 +202,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
main_query += ";"
count_query += ";"
cursor = conn.cursor()
cursor.execute(main_query, main_params)
rows = cursor.fetchall()
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
@@ -238,8 +232,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
if not tags:
return {}
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
result: dict[str, int] = {}
# Base conditions for categories and selected tags
base_conditions: list[str] = []
@@ -288,8 +281,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> dict[str, int]:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
result: dict[str, int] = {}
# Base conditions for categories
base_conditions: list[str] = []
@@ -333,8 +325,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
return result
def update_opened_at(self, workflow_id: str) -> None:
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
cursor.execute(
f"""--sql
UPDATE workflow_library
@@ -357,8 +348,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
meaningless, as they are overwritten every time the server starts.
"""
with self._db.conn() as conn:
cursor = conn.cursor()
with self._db.transaction() as cursor:
workflows_from_file: list[Workflow] = []
workflows_to_update: list[Workflow] = []
workflows_to_add: list[Workflow] = []