mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(app): optimize boards queries
Use SQL instead of python to retrieve image count, asset count and board cover image. This reduces the number of SQL queries needed to list all boards. Previously, we did `1 + 2 * board_count` queries:: - 1 query to get the list of board records - 1 query per board to get its total count - 1 query per board to get its cover image Then, on the frontend, we made two additional network requests to get each board's counts: - 1 request (== 1 SQL query) for image count - 1 request (== 1 SQL query) for asset count All of this information is now retrieved in a single SQL query, and provided via single network request. As part of this change, `BoardRecord` now includes `image_count`, `asset_count` and `cover_image_name`. This makes `BoardDTO` redundant, but removing it is a deeper change...
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -26,21 +26,25 @@ class BoardRecord(BaseModelExcludeNull):
|
||||
"""Whether or not the board is archived."""
|
||||
is_private: Optional[bool] = Field(default=None, description="Whether the board is private.")
|
||||
"""Whether the board is private."""
|
||||
image_count: int = Field(description="The number of images in the board.")
|
||||
asset_count: int = Field(description="The number of assets in the board.")
|
||||
|
||||
|
||||
def deserialize_board_record(board_dict: dict) -> BoardRecord:
|
||||
def deserialize_board_record(board_dict: dict[str, Any]) -> BoardRecord:
|
||||
"""Deserializes a board record."""
|
||||
|
||||
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||
|
||||
board_id = board_dict.get("board_id", "unknown")
|
||||
board_name = board_dict.get("board_name", "unknown")
|
||||
cover_image_name = board_dict.get("cover_image_name", "unknown")
|
||||
cover_image_name = board_dict.get("cover_image_name", None)
|
||||
created_at = board_dict.get("created_at", get_iso_timestamp())
|
||||
updated_at = board_dict.get("updated_at", get_iso_timestamp())
|
||||
deleted_at = board_dict.get("deleted_at", get_iso_timestamp())
|
||||
archived = board_dict.get("archived", False)
|
||||
is_private = board_dict.get("is_private", False)
|
||||
image_count = board_dict.get("image_count", 0)
|
||||
asset_count = board_dict.get("asset_count", 0)
|
||||
|
||||
return BoardRecord(
|
||||
board_id=board_id,
|
||||
@@ -51,6 +55,8 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
|
||||
deleted_at=deleted_at,
|
||||
archived=archived,
|
||||
is_private=is_private,
|
||||
image_count=image_count,
|
||||
asset_count=asset_count,
|
||||
)
|
||||
|
||||
|
||||
@@ -63,21 +69,21 @@ class BoardChanges(BaseModel, extra="forbid"):
|
||||
class BoardRecordNotFoundException(Exception):
|
||||
"""Raised when an board record is not found."""
|
||||
|
||||
def __init__(self, message="Board record not found"):
|
||||
def __init__(self, message: str = "Board record not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BoardRecordSaveException(Exception):
|
||||
"""Raised when an board record cannot be saved."""
|
||||
|
||||
def __init__(self, message="Board record not saved"):
|
||||
def __init__(self, message: str = "Board record not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BoardRecordDeleteException(Exception):
|
||||
"""Raised when an board record cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Board record not deleted"):
|
||||
def __init__(self, message: str = "Board record not deleted"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
|
||||
@@ -16,6 +16,114 @@ from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
_BASE_BOARD_RECORD_QUERY = """
|
||||
-- This query retrieves board records, joining with the board_images and images tables to get image counts and cover image names.
|
||||
-- It is not a complete query, as it is missing a GROUP BY or WHERE clause (and is unterminated).
|
||||
SELECT b.board_id,
|
||||
b.board_name,
|
||||
b.created_at,
|
||||
b.updated_at,
|
||||
b.archived,
|
||||
-- Count the number of images in the board, alias image_count
|
||||
COUNT(
|
||||
CASE
|
||||
WHEN i.image_category in ('general') -- "Images" are images in the 'general' category
|
||||
AND i.is_intermediate = 0 THEN 1 -- Intermediates are not counted
|
||||
END
|
||||
) AS image_count,
|
||||
-- Count the number of assets in the board, alias asset_count
|
||||
COUNT(
|
||||
CASE
|
||||
WHEN i.image_category in ('control', 'mask', 'user', 'other') -- "Assets" are images in any of the other categories ('control', 'mask', 'user', 'other')
|
||||
AND i.is_intermediate = 0 THEN 1 -- Intermediates are not counted
|
||||
END
|
||||
) AS asset_count,
|
||||
-- Get the name of the the most recent image in the board, alias cover_image_name
|
||||
(
|
||||
SELECT bi.image_name
|
||||
FROM board_images bi
|
||||
JOIN images i ON bi.image_name = i.image_name
|
||||
WHERE bi.board_id = b.board_id
|
||||
AND i.is_intermediate = 0 -- Intermediates cannot be cover images
|
||||
ORDER BY i.created_at DESC -- Sort by created_at to get the most recent image
|
||||
LIMIT 1
|
||||
) AS cover_image_name
|
||||
FROM boards b
|
||||
LEFT JOIN board_images bi ON b.board_id = bi.board_id
|
||||
LEFT JOIN images i ON bi.image_name = i.image_name
|
||||
"""
|
||||
|
||||
|
||||
def get_paginated_list_board_records_queries(include_archived: bool) -> str:
|
||||
"""Gets a query to retrieve a paginated list of board records. The query has placeholders for limit and offset.
|
||||
|
||||
Args:
|
||||
include_archived: Whether to include archived board records in the results.
|
||||
|
||||
Returns:
|
||||
A query to retrieve a paginated list of board records.
|
||||
"""
|
||||
|
||||
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
|
||||
|
||||
# The GROUP BY must be added _after_ the WHERE clause!
|
||||
query = f"""
|
||||
{_BASE_BOARD_RECORD_QUERY}
|
||||
{archived_condition}
|
||||
GROUP BY b.board_id,
|
||||
b.board_name,
|
||||
b.created_at,
|
||||
b.updated_at
|
||||
ORDER BY b.created_at DESC
|
||||
LIMIT ? OFFSET ?;
|
||||
"""
|
||||
|
||||
return query
|
||||
|
||||
|
||||
def get_total_boards_count_query(include_archived: bool) -> str:
|
||||
"""Gets a query to retrieve the total count of board records.
|
||||
|
||||
Args:
|
||||
include_archived: Whether to include archived board records in the count.
|
||||
|
||||
Returns:
|
||||
A query to retrieve the total count of board records.
|
||||
"""
|
||||
|
||||
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
|
||||
|
||||
return f"SELECT COUNT(*) FROM boards {archived_condition};"
|
||||
|
||||
|
||||
def get_list_all_board_records_query(include_archived: bool) -> str:
|
||||
"""Gets a query to retrieve all board records.
|
||||
|
||||
Args:
|
||||
include_archived: Whether to include archived board records in the results.
|
||||
|
||||
Returns:
|
||||
A query to retrieve all board records.
|
||||
"""
|
||||
|
||||
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
|
||||
|
||||
return f"""
|
||||
{_BASE_BOARD_RECORD_QUERY}
|
||||
{archived_condition}
|
||||
GROUP BY b.board_id,
|
||||
b.board_name,
|
||||
b.created_at,
|
||||
b.updated_at
|
||||
ORDER BY b.created_at DESC;
|
||||
"""
|
||||
|
||||
|
||||
def get_board_record_query() -> str:
|
||||
"""Gets a query to retrieve a board record. The query has a placeholder for the board_id."""
|
||||
|
||||
return f"{_BASE_BOARD_RECORD_QUERY} WHERE b.board_id = ?;"
|
||||
|
||||
|
||||
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
_conn: sqlite3.Connection
|
||||
@@ -77,11 +185,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM boards
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
get_board_record_query(),
|
||||
(board_id,),
|
||||
)
|
||||
|
||||
@@ -93,7 +197,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
raise BoardRecordNotFoundException
|
||||
return BoardRecord(**dict(result))
|
||||
return deserialize_board_record(dict(result))
|
||||
|
||||
def update(
|
||||
self,
|
||||
@@ -150,45 +254,15 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
# Build base query
|
||||
base_query = """
|
||||
SELECT *
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ? OFFSET ?;
|
||||
"""
|
||||
main_query = get_paginated_list_board_records_queries(include_archived=include_archived)
|
||||
|
||||
# Determine archived filter condition
|
||||
if include_archived:
|
||||
archived_filter = ""
|
||||
else:
|
||||
archived_filter = "WHERE archived = 0"
|
||||
|
||||
final_query = base_query.format(archived_filter=archived_filter)
|
||||
|
||||
# Execute query to fetch boards
|
||||
self._cursor.execute(final_query, (limit, offset))
|
||||
self._cursor.execute(main_query, (limit, offset))
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
boards = [deserialize_board_record(dict(r)) for r in result]
|
||||
|
||||
# Determine count query
|
||||
if include_archived:
|
||||
count_query = """
|
||||
SELECT COUNT(*)
|
||||
FROM boards;
|
||||
"""
|
||||
else:
|
||||
count_query = """
|
||||
SELECT COUNT(*)
|
||||
FROM boards
|
||||
WHERE archived = 0;
|
||||
"""
|
||||
|
||||
# Execute count query
|
||||
self._cursor.execute(count_query)
|
||||
|
||||
total_query = get_total_boards_count_query(include_archived=include_archived)
|
||||
self._cursor.execute(total_query)
|
||||
count = cast(int, self._cursor.fetchone()[0])
|
||||
|
||||
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
|
||||
@@ -202,26 +276,10 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
def get_all(self, include_archived: bool = False) -> list[BoardRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
base_query = """
|
||||
SELECT *
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
ORDER BY created_at DESC
|
||||
"""
|
||||
|
||||
if include_archived:
|
||||
archived_filter = ""
|
||||
else:
|
||||
archived_filter = "WHERE archived = 0"
|
||||
|
||||
final_query = base_query.format(archived_filter=archived_filter)
|
||||
|
||||
self._cursor.execute(final_query)
|
||||
|
||||
query = get_list_all_board_records_query(include_archived=include_archived)
|
||||
self._cursor.execute(query)
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
boards = [deserialize_board_record(dict(r)) for r in result]
|
||||
|
||||
return boards
|
||||
|
||||
except sqlite3.Error as e:
|
||||
|
||||
@@ -1,23 +1,8 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.services.board_records.board_records_common import BoardRecord
|
||||
|
||||
|
||||
# TODO(psyche): BoardDTO is now identical to BoardRecord. We should consider removing it.
|
||||
class BoardDTO(BoardRecord):
|
||||
"""Deserialized board record with cover image URL and image count."""
|
||||
"""Deserialized board record."""
|
||||
|
||||
cover_image_name: Optional[str] = Field(description="The name of the board's cover image.")
|
||||
"""The URL of the thumbnail of the most recent image in the board."""
|
||||
image_count: int = Field(description="The number of images in the board.")
|
||||
"""The number of images in the board."""
|
||||
|
||||
|
||||
def board_record_to_dto(board_record: BoardRecord, cover_image_name: Optional[str], image_count: int) -> BoardDTO:
|
||||
"""Converts a board record to a board DTO."""
|
||||
return BoardDTO(
|
||||
**board_record.model_dump(exclude={"cover_image_name"}),
|
||||
cover_image_name=cover_image_name,
|
||||
image_count=image_count,
|
||||
)
|
||||
pass
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from invokeai.app.services.board_records.board_records_common import BoardChanges
|
||||
from invokeai.app.services.boards.boards_base import BoardServiceABC
|
||||
from invokeai.app.services.boards.boards_common import BoardDTO, board_record_to_dto
|
||||
from invokeai.app.services.boards.boards_common import BoardDTO
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
|
||||
@@ -16,17 +16,11 @@ class BoardService(BoardServiceABC):
|
||||
board_name: str,
|
||||
) -> BoardDTO:
|
||||
board_record = self.__invoker.services.board_records.save(board_name)
|
||||
return board_record_to_dto(board_record, None, 0)
|
||||
return BoardDTO.model_validate(board_record.model_dump())
|
||||
|
||||
def get_dto(self, board_id: str) -> BoardDTO:
|
||||
board_record = self.__invoker.services.board_records.get(board_id)
|
||||
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id)
|
||||
return board_record_to_dto(board_record, cover_image_name, image_count)
|
||||
return BoardDTO.model_validate(board_record.model_dump())
|
||||
|
||||
def update(
|
||||
self,
|
||||
@@ -34,14 +28,7 @@ class BoardService(BoardServiceABC):
|
||||
changes: BoardChanges,
|
||||
) -> BoardDTO:
|
||||
board_record = self.__invoker.services.board_records.update(board_id, changes)
|
||||
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
|
||||
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id)
|
||||
return board_record_to_dto(board_record, cover_image_name, image_count)
|
||||
return BoardDTO.model_validate(board_record.model_dump())
|
||||
|
||||
def delete(self, board_id: str) -> None:
|
||||
self.__invoker.services.board_records.delete(board_id)
|
||||
@@ -50,30 +37,10 @@ class BoardService(BoardServiceABC):
|
||||
self, offset: int = 0, limit: int = 10, include_archived: bool = False
|
||||
) -> OffsetPaginatedResults[BoardDTO]:
|
||||
board_records = self.__invoker.services.board_records.get_many(offset, limit, include_archived)
|
||||
board_dtos = []
|
||||
for r in board_records.items:
|
||||
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
|
||||
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
|
||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
||||
|
||||
board_dtos = [BoardDTO.model_validate(r.model_dump()) for r in board_records.items]
|
||||
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
|
||||
|
||||
def get_all(self, include_archived: bool = False) -> list[BoardDTO]:
|
||||
board_records = self.__invoker.services.board_records.get_all(include_archived)
|
||||
board_dtos = []
|
||||
for r in board_records:
|
||||
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
|
||||
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
|
||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
||||
|
||||
board_dtos = [BoardDTO.model_validate(r.model_dump()) for r in board_records]
|
||||
return board_dtos
|
||||
|
||||
@@ -127,7 +127,16 @@ def test_generate_id_with_board_id(monkeypatch: Any, mock_invoker: Invoker):
|
||||
|
||||
def mock_board_get(*args, **kwargs):
|
||||
return BoardRecord(
|
||||
board_id="12345", board_name="test_board_name", created_at="None", updated_at="None", archived=False
|
||||
board_id="12345",
|
||||
board_name="test_board_name",
|
||||
created_at="None",
|
||||
updated_at="None",
|
||||
archived=False,
|
||||
asset_count=0,
|
||||
image_count=0,
|
||||
cover_image_name="asdf.png",
|
||||
deleted_at=None,
|
||||
is_private=False,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_board_get)
|
||||
@@ -156,7 +165,16 @@ def test_handler_board_id(tmp_path: Path, monkeypatch: Any, mock_image_dto: Imag
|
||||
|
||||
def mock_board_get(*args, **kwargs):
|
||||
return BoardRecord(
|
||||
board_id="12345", board_name="test_board_name", created_at="None", updated_at="None", archived=False
|
||||
board_id="12345",
|
||||
board_name="test_board_name",
|
||||
created_at="None",
|
||||
updated_at="None",
|
||||
archived=False,
|
||||
asset_count=0,
|
||||
image_count=0,
|
||||
cover_image_name="asdf.png",
|
||||
deleted_at=None,
|
||||
is_private=False,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(mock_invoker.services.board_records, "get", mock_board_get)
|
||||
|
||||
Reference in New Issue
Block a user