InvokeAI/invokeai/app/services/session_queue/session_queue_common.py
Lincoln Stein 33ec16deb4 Feature: Shared/private workflows and image boards in multiuser mode (#9018)
* feat: Per-user workflow libraries in multiuser mode (#114)

* Add per-user workflow isolation: migration 28, service updates, router ownership checks, is_public endpoint, schema regeneration, frontend UI

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>

* feat: add shared workflow checkbox to Details panel, auto-tag, gate edit/delete, fix tests

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: lstein <111189+lstein@users.noreply.github.com>

* Restrict model sync to admin users only (#118)

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: lstein <111189+lstein@users.noreply.github.com>

* feat: distinct splash screens for admin/non-admin users in multiuser mode (#116)

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: lstein <111189+lstein@users.noreply.github.com>

* Disable Save when editing another user's shared workflow in multiuser mode (#120)

* Disable Save when editing another user's shared workflow in multiuser mode

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: lstein <111189+lstein@users.noreply.github.com>

* chore(app): ruff

* Add board visibility (private/shared/public) feature with tests and UI

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>

* Enforce read-only access for non-owners of shared/public boards in UI

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>

* Fix remaining board access enforcement: invoke icon, drag-out, change-board filter, archive

Co-authored-by: lstein <111189+lstein@users.noreply.github.com>

* fix: allow drag from shared boards to non-board targets (viewer, ref image, etc.)

Previously, images in shared boards owned by another user could not be
dragged at all — the draggable setup was completely skipped in
GalleryImage.tsx when canWriteImages was false. This blocked ALL drop
targets including the viewer, reference image pane, and canvas.

Now images are always draggable. The board-move restriction is enforced
in the dnd target isValid functions instead:
- addImageToBoardDndTarget: rejects moves from shared boards the user
  doesn't own (unless admin or board is public)
- removeImageFromBoardDndTarget: same check

Other drop targets (viewer, reference images, canvas, comparison, etc.)
remain fully functional for shared board images.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(security): add auth requirement to all sensitive routes in multiuser mode

* chore(backend): ruff

* fix(backend): improve user isolation for session queue and recall parameters

 - Sanitize session queue items for non-owners: every field except item_id, queue_id, status, and the timestamps is redacted.
 - Recall parameters are now user-scoped.
 - Queue status endpoints now report user-scoped activity rather than global activity.
 - Tests added:

  TestSessionQueueSanitization (4 tests):
  1. test_owner_sees_all_fields - Owner sees complete queue item data
  2. test_admin_sees_all_fields - Admin sees complete queue item data
  3. test_non_owner_sees_only_status_timestamps_errors -
     Non-owner sees only item_id, queue_id, status, and timestamps; everything else is redacted
  4. test_sanitization_does_not_mutate_original - Sanitization doesn't modify the original object

  TestRecallParametersIsolation (2 tests):

  5. test_user1_write_does_not_leak_to_user2 - User1's recall params are not visible in user2's client state
  6. test_two_users_independent_state - Both users can write recall params independently without overwriting each other

fix(backend): queue status endpoints report user-scoped stats rather than global stats
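
The sanitization rule described above can be sketched as follows. This is a minimal illustration only: `QueueItemView`, `sanitize_for`, and the `"redacted"` placeholder are hypothetical names chosen for this example, and the field set is a trimmed-down assumption rather than the real `SessionQueueItem`.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class QueueItemView:
    """Hypothetical, trimmed-down stand-in for a queue item."""

    item_id: int
    queue_id: str
    status: str
    created_at: str
    user_id: str
    batch_id: Optional[str] = None
    error_message: Optional[str] = None


def sanitize_for(item: QueueItemView, viewer_id: str, is_admin: bool) -> QueueItemView:
    """Return the item as-is for its owner or an admin, else a redacted copy."""
    if is_admin or viewer_id == item.user_id:
        return item
    # replace() builds a new object, so the original item is never mutated.
    # Which fields are blanked, and the placeholder value, are assumptions.
    return replace(item, user_id="redacted", batch_id=None, error_message=None)


item = QueueItemView(1, "default", "pending", "2026-01-01T00:00:00", "alice", "b1", "boom")
seen_by_bob = sanitize_for(item, viewer_id="bob", is_admin=False)
```

Returning a copy rather than editing in place mirrors the intent of test_sanitization_does_not_mutate_original: the owner's view of the same item stays complete.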

* fix(workflow): do not filter default workflows in multiuser mode

  Problem: When categories=['user', 'default'] (or no category filter)
  and user_id was set for multiuser scoping, the SQL query became
     WHERE category IN ('user', 'default') AND user_id = ?,
     which  excluded default workflows (owned by "system").

  Fix: Changed user_id = ? to (user_id = ? OR category = 'default') in
  all 6 occurrences across workflow_records_sqlite.py — in get_many,
  counts_by_category, counts_by_tag, and get_all_tags. Default
  workflows are now always visible regardless of user scoping.

  Tests added (2):
  - test_default_workflows_visible_when_listing_user_and_default — categories=['user','default'] includes both
  - test_default_workflows_visible_when_no_category_filter — no filter still shows defaults
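
A minimal sketch of the scoping change described above, run against an in-memory SQLite table. The table layout, the `name` column, and the `list_workflows` helper are assumptions made for illustration; only the `(user_id = ? OR category = 'default')` clause is taken from the fix itself.

```python
import sqlite3


def list_workflows(conn: sqlite3.Connection, user_id: str, categories: list[str]) -> list[str]:
    """Return the names of workflows visible to user_id (hypothetical helper)."""
    placeholders = ", ".join("?" for _ in categories)
    # Scope rows to the caller, but always let default workflows through.
    query = (
        f"SELECT name FROM workflows "
        f"WHERE category IN ({placeholders}) "
        f"AND (user_id = ? OR category = 'default') "
        f"ORDER BY name"
    )
    return [row[0] for row in conn.execute(query, [*categories, user_id])]


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE workflows (name TEXT, category TEXT, user_id TEXT)")
conn.executemany(
    "INSERT INTO workflows VALUES (?, ?, ?)",
    [
        ("alice-flow", "user", "alice"),
        ("bob-flow", "user", "bob"),
        ("starter-flow", "default", "system"),
    ],
)
visible = list_workflows(conn, "alice", ["user", "default"])
```

With the older `AND user_id = ?` form, the row owned by "system" would be filtered out for every regular user, which is the behavior the fix removes.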

* fix(multiuser): scope queue/recall/intermediates endpoints to current user

Several read-only and event-emitting endpoints were leaking aggregate
cross-user activity in multiuser mode:

- recall_parameters_updated event was broadcast to every queue
  subscriber. Added user_id to the event and routed it to the owner +
  admin rooms only.
- get_queue_status, get_batch_status, counts_by_destination and
  get_intermediates_count now scope counts to the calling user
  (admins still see global state). Removed the now-redundant
  user_pending/user_in_progress fields and simplified QueueCountBadge.
- get_queue_status hides current item_id/session_id/batch_id when the
  current item belongs to another user.

Also fixes test_session_queue_sanitization assertions that lagged
behind the recently expanded redaction set.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* chore(backend): ruff

* fix(multiuser): reject anonymous websockets and scope queue item events

Close three cross-user leaks in the websocket layer:

- _handle_connect() now rejects connections without a valid JWT in
  multiuser mode (previously fell through to user_id="system"), so
  anonymous clients can no longer subscribe to queue rooms and observe
  other users' activity. In single-user mode it still accepts as system
  admin.
- _handle_sub_queue() no longer silently falls back to the system user
  for an unknown sid in multiuser mode; it refuses the subscription.
- QueueItemStatusChangedEvent and BatchEnqueuedEvent are now routed to
  user:{user_id} + admin rooms instead of the full queue room. Both
  events carry unsanitized user_id, batch_id, origin, destination,
  session_id, and error metadata and must not be broadcast.
- BatchEnqueuedEvent gains a user_id field; emit_batch_enqueued and
  enqueue_batch thread it through.

New TestWebSocketAuth suite covers connect accept/reject for both
modes, sub_queue refusal, and private routing of the queue item and
batch events (plus a QueueClearedEvent sanity check).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(multiuser): verify user record on websocket connect

A deleted or deactivated user with an unexpired JWT could still open a
websocket and subscribe to queue rooms. Now _handle_connect() checks the
backing user record (exists + is_active) in multiuser mode, mirroring
the REST auth path in auth_dependencies.py. Fails closed if the user
service is unavailable.

Tests: added deleted-user and inactive-user rejection tests; updated
valid-token test to create the user in the database first.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(multiuser): close bulk download cross-user exfiltration path

Backend:
- POST /download now validates image read access (per-image) and board
  read access (per-board) before queuing the download.
- GET /download/{name} is intentionally unauthenticated because the
  browser triggers it via <a download> which cannot carry Authorization
  headers. Access control relies on POST-time checks, UUID filename
  unguessability, private socket event routing, and single-fetch deletion.
- Added _assert_board_read_access() helper to images router.
- Threaded user_id through bulk download handler, base class, event
  emission, and BulkDownloadEventBase so events carry the initiator.
- Bulk download service now tracks download ownership via _download_owners
  dict (cleaned up on delete).
- Socket bulk_download room subscription restricted to authenticated
  sockets in multiuser mode.
- Added error-catching in FastAPIEventService._dispatch_from_queue to
  prevent silent event dispatch failures.

Frontend:
- Fixed pre-existing race condition where the "Preparing Download" toast
  from the POST response overwrote the "Ready to Download" toast from the
  socket event (background task completes in ~17ms, so the socket event
  can arrive before Redux processes the HTTP response). Toast IDs are now
  distinct: "preparing:{name}" vs "{name}".
- bulk_download_complete/error handlers now dismiss the preparing toast.

Tests (8 new):
- Bulk download by image names rejected for non-owner (403)
- Bulk download by image names allowed for owner (202)
- Bulk download from private board rejected (403)
- Bulk download from shared board allowed (202)
- Admin can bulk download any images (202)
- Bulk download events carry user_id
- Bulk download event emitted to download room
- GET /download unauthenticated returns 404 for unknown files

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(multiuser): enforce board visibility on image listing endpoints

GET /api/v1/images?board_id=... and GET /api/v1/images/names?board_id=...
passed board_id directly to the SQL layer without checking board
visibility. The SQL only applied user_id filtering for board_id="none"
(uncategorized images), so any authenticated user who knew a private
board ID could enumerate its images.

Both endpoints now call _assert_board_read_access() before querying,
returning 403 unless the caller is the board owner, an admin, or the
board is Shared/Public.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* chore(backend): ruff

* fix(multiuser): require image ownership when adding images to boards

add_image_to_board and add_images_to_board only checked write access to
the destination board, never verifying that the caller owned the source
image.  An attacker could add a victim's image to their own board, then
exploit the board-ownership fallback in _assert_image_owner to gain
delete/patch/star/unstar rights on the image.

Both endpoints now call _assert_image_direct_owner which requires direct
image ownership (image_records.user_id) or admin — board ownership is
intentionally not sufficient, preventing the escalation chain.

Also fixed a pre-existing bug where HTTPException from the inner loop in
add_images_to_board was caught by the outer except-Exception and returned
as 500 instead of propagating the correct status code.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* chore(backend): ruff

* fix(multiuser): validate image access in recall parameter resolution

The recall endpoint loaded image files and ran ControlNet preprocessors
on any image_name supplied in control_layers or ip_adapters without
checking that the caller could read the image.  An attacker who knew
another user's image UUID could extract dimensions and, for supported
preprocessors, mint a derived processed image they could then fetch.

Added _assert_recall_image_access() which validates read access for every
image referenced in the request before any resolution or processing
occurs.  Access is granted to the image owner, admins, or when the image
sits on a Shared/Public board.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(multiuser): require admin auth on model install job endpoints

list_model_installs, get_model_install_job, pause, resume,
restart_failed, and restart_file were unauthenticated — any caller who
could reach the API could view sensitive install job fields (source,
local_path, error_traceback) and interfere with installation state.

All six endpoints now require AdminUserOrDefault, consistent with the
neighboring cancel and prune routes.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(multiuser): close bulk download exfiltration and additional review findings

Bulk download capability token exfiltration:
- Socket events now route to user:{user_id} + admin rooms instead of the
  shared 'default' room (the earlier toast race that blocked this approach
  was fixed in a prior commit).
- GET /download/{name} re-requires CurrentUserOrDefault and enforces
  ownership via get_owner().
- Frontend download handler replaced <a download> (which cannot carry auth
  headers) with fetch() + Authorization header + programmatic blob download.

Additional fixes from reviewer tests:
- Public boards now grant write access in _assert_board_write_access and
  mutation rights in _assert_image_owner (BoardVisibility.Public).
- Uncategorized image listing (GET /boards/none/image_names) now filters
  to the caller's images only, preventing cross-user enumeration.
- board_images router uses board_image_records.get_board_for_image()
  instead of images.get_dto() to avoid dependency on image_files service.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(multiuser): add user_id scoping to workflow SQL mutations

Defense-in-depth: the route layer already checks ownership before
calling update/delete/update_is_public/update_opened_at, but the SQL
statements did not include AND user_id = ?, so a bypass of the route
check would allow cross-user mutations.

All four methods now accept an optional user_id parameter.  When
provided, the SQL WHERE clause is scoped to that user.  The route layer
passes current_user.user_id for non-admin callers and None for admins.
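
The optional scoping described above might look roughly like this. It is a sketch under stated assumptions: the `delete_workflow` helper, the two-column table, and the use of `rowcount` to report the outcome are hypothetical and not taken from workflow_records_sqlite.py.

```python
import sqlite3
from typing import Optional


def delete_workflow(conn: sqlite3.Connection, workflow_id: str, user_id: Optional[str] = None) -> int:
    """Delete a workflow and return the number of rows removed (hypothetical helper).

    When user_id is given, the row is only deleted if that user owns it.
    """
    query = "DELETE FROM workflows WHERE workflow_id = ?"
    params: list[str] = [workflow_id]
    if user_id is not None:
        # Non-admin caller: repeat the ownership check in SQL.
        query += " AND user_id = ?"
        params.append(user_id)
    return conn.execute(query, params).rowcount


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE workflows (workflow_id TEXT PRIMARY KEY, user_id TEXT)")
conn.execute("INSERT INTO workflows VALUES ('wf-1', 'alice')")

blocked = delete_workflow(conn, "wf-1", user_id="bob")  # bob does not own wf-1
as_admin = delete_workflow(conn, "wf-1", user_id=None)  # admins pass None
```

Even if a route-level ownership check were skipped, the scoped statement would match no rows for a non-owner.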

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* fix(multiuser): allow non-owner uploads to public boards

upload_image() blocked non-owner uploads even to public boards.  The
board write check now allows uploads when board_visibility is Public,
consistent with the public-board semantics in _assert_board_write_access
and _assert_image_owner.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: lstein <111189+lstein@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Jonathan <34005131+JPPhoto@users.noreply.github.com>
2026-04-13 17:27:20 -04:00

650 lines
27 KiB
Python

import datetime
import json
from itertools import chain, product
from typing import Generator, Literal, Optional, TypeAlias, Union

from pydantic import (
    AliasChoices,
    BaseModel,
    ConfigDict,
    Field,
    StrictStr,
    TypeAdapter,
    field_validator,
    model_validator,
)
from pydantic_core import to_jsonable_python

from invokeai.app.invocations.fields import ImageField
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError
from invokeai.app.services.workflow_records.workflow_records_common import (
    WorkflowWithoutID,
    WorkflowWithoutIDValidator,
)
from invokeai.app.util.misc import uuid_string

# region Errors


class BatchZippedLengthError(ValueError):
    """Raise when a batch has items of different lengths."""


class BatchItemsTypeError(ValueError):  # this cannot be a TypeError in pydantic v2
    """Raise when a batch has items of different types."""


class BatchDuplicateNodeFieldError(ValueError):
    """Raise when a batch has duplicate node_path and field_name."""


class TooManySessionsError(ValueError):
    """Raise when too many sessions are requested."""


class SessionQueueItemNotFoundError(ValueError):
    """Raise when a queue item is not found."""


# endregion

# region Batch

BatchDataType = Union[StrictStr, float, int, ImageField]


class NodeFieldValue(BaseModel):
    node_path: str = Field(description="The node into which this batch data item will be substituted.")
    field_name: str = Field(description="The field into which this batch data item will be substituted.")
    value: BatchDataType = Field(description="The value to substitute into the node/field.")


class BatchDatum(BaseModel):
    node_path: str = Field(description="The node into which this batch data collection will be substituted.")
    field_name: str = Field(description="The field into which this batch data collection will be substituted.")
    items: list[BatchDataType] = Field(
        default_factory=list, description="The list of items to substitute into the node/field."
    )


BatchDataCollection: TypeAlias = list[list[BatchDatum]]


class Batch(BaseModel):
    batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
    origin: str | None = Field(
        default=None,
        description="The origin of this queue item. This data is used by the frontend to determine how to handle results.",
    )
    destination: str | None = Field(
        default=None,
        description="The destination of this queue item. This data is used by the frontend to determine how to handle results.",
    )
    data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
    graph: Graph = Field(description="The graph to initialize the session with")
    workflow: Optional[WorkflowWithoutID] = Field(
        default=None, description="The workflow to initialize the session with"
    )
    runs: int = Field(
        default=1, ge=1, description="Int stating how many times to iterate through all possible batch indices"
    )

    @field_validator("data")
    def validate_lengths(cls, v: Optional[BatchDataCollection]):
        if v is None:
            return v
        for batch_data_list in v:
            first_item_length = len(batch_data_list[0].items) if batch_data_list and batch_data_list[0].items else 0
            for i in batch_data_list:
                if len(i.items) != first_item_length:
                    raise BatchZippedLengthError("Zipped batch items must all have the same length")
        return v

    @field_validator("data")
    def validate_types(cls, v: Optional[BatchDataCollection]):
        if v is None:
            return v
        for batch_data_list in v:
            for datum in batch_data_list:
                if not datum.items:
                    continue

                # Special handling for numbers - they can be mixed
                # TODO(psyche): Update BatchDatum to have a `type` field to specify the type of the items, then we can have strict float and int fields
                if all(isinstance(item, (int, float)) for item in datum.items):
                    continue

                # Get the type of the first item in the list
                first_item_type = type(datum.items[0])
                for item in datum.items:
                    if type(item) is not first_item_type:
                        raise BatchItemsTypeError("All items in a batch must have the same type")
        return v

    @field_validator("data")
    def validate_unique_field_mappings(cls, v: Optional[BatchDataCollection]):
        if v is None:
            return v
        paths: set[tuple[str, str]] = set()
        for batch_data_list in v:
            for datum in batch_data_list:
                pair = (datum.node_path, datum.field_name)
                if pair in paths:
                    raise BatchDuplicateNodeFieldError("Each batch data must have unique node_path and field_name")
                paths.add(pair)
        return v

    @model_validator(mode="after")
    def validate_batch_nodes_and_edges(self):
        if self.data is None:
            return self
        for batch_data_list in self.data:
            for batch_data in batch_data_list:
                try:
                    node = self.graph.get_node(batch_data.node_path)
                except NodeNotFoundError:
                    raise NodeNotFoundError(f"Node {batch_data.node_path} not found in graph")
                if batch_data.field_name not in type(node).model_fields:
                    raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}")
        return self

    @field_validator("graph")
    def validate_graph(cls, v: Graph):
        v.validate_self()
        return v

    model_config = ConfigDict(
        json_schema_extra={
            "required": [
                "graph",
                "runs",
            ]
        }
    )


# endregion Batch


# region Queue Items

DEFAULT_QUEUE_ID = "default"
SYSTEM_USER_ID = "system"  # Default user_id for system-generated queue items

QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"]


class ItemIdsResult(BaseModel):
    """Response containing ordered item ids with metadata for optimistic updates."""

    item_ids: list[int] = Field(description="Ordered list of item ids")
    total_count: int = Field(description="Total number of queue items matching the query")


NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue])


def get_field_values(queue_item_dict: dict) -> Optional[list[NodeFieldValue]]:
    field_values_raw = queue_item_dict.get("field_values", None)
    return NodeFieldValueValidator.validate_json(field_values_raw) if field_values_raw is not None else None


GraphExecutionStateValidator = TypeAdapter(GraphExecutionState)


def get_session(queue_item_dict: dict) -> GraphExecutionState:
    session_raw = queue_item_dict.get("session", "{}")
    session = GraphExecutionStateValidator.validate_json(session_raw, strict=False)
    return session


def get_workflow(queue_item_dict: dict) -> Optional[WorkflowWithoutID]:
    workflow_raw = queue_item_dict.get("workflow", None)
    if workflow_raw is not None:
        workflow = WorkflowWithoutIDValidator.validate_json(workflow_raw, strict=False)
        return workflow
    return None


class FieldIdentifier(BaseModel):
    kind: Literal["input", "output"] = Field(description="The kind of field")
    node_id: str = Field(description="The ID of the node")
    field_name: str = Field(description="The name of the field")
    user_label: str | None = Field(description="The user label of the field, if any")


class SessionQueueItem(BaseModel):
    """Session queue item without the full graph. Used for serialization."""

    item_id: int = Field(description="The identifier of the session queue item")
    status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
    priority: int = Field(default=0, description="The priority of this queue item")
    batch_id: str = Field(description="The ID of the batch associated with this queue item")
    origin: str | None = Field(
        default=None,
        description="The origin of this queue item. This data is used by the frontend to determine how to handle results.",
    )
    destination: str | None = Field(
        default=None,
        description="The destination of this queue item. This data is used by the frontend to determine how to handle results.",
    )
    session_id: str = Field(
        description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
    )
    error_type: Optional[str] = Field(default=None, description="The error type if this queue item errored")
    error_message: Optional[str] = Field(default=None, description="The error message if this queue item errored")
    error_traceback: Optional[str] = Field(
        default=None,
        description="The error traceback if this queue item errored",
        validation_alias=AliasChoices("error_traceback", "error"),
    )
    created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created")
    updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated")
    started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started")
    completed_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was completed")
    queue_id: str = Field(description="The id of the queue with which this item is associated")
    user_id: str = Field(default="system", description="The id of the user who created this queue item")
    user_display_name: Optional[str] = Field(
        default=None, description="The display name of the user who created this queue item, if available"
    )
    user_email: Optional[str] = Field(
        default=None, description="The email of the user who created this queue item, if available"
    )
    field_values: Optional[list[NodeFieldValue]] = Field(
        default=None, description="The field values that were used for this queue item"
    )
    retried_from_item_id: Optional[int] = Field(
        default=None, description="The item_id of the queue item that this item was retried from"
    )
    session: GraphExecutionState = Field(description="The fully-populated session to be executed")
    workflow: Optional[WorkflowWithoutID] = Field(
        default=None, description="The workflow associated with this queue item"
    )

    @classmethod
    def queue_item_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItem":
        # must parse these manually
        queue_item_dict["field_values"] = get_field_values(queue_item_dict)
        queue_item_dict["session"] = get_session(queue_item_dict)
        queue_item_dict["workflow"] = get_workflow(queue_item_dict)
        return SessionQueueItem(**queue_item_dict)

    model_config = ConfigDict(
        json_schema_extra={
            "required": [
                "item_id",
                "status",
                "batch_id",
                "queue_id",
                "session_id",
                "session",
                "priority",
                "created_at",
                "updated_at",
            ]
        }
    )


# endregion Queue Items


# region Query Results


class SessionQueueStatus(BaseModel):
    queue_id: str = Field(..., description="The ID of the queue")
    item_id: Optional[int] = Field(description="The current queue item id")
    batch_id: Optional[str] = Field(description="The current queue item's batch id")
    session_id: Optional[str] = Field(description="The current queue item's session id")
    pending: int = Field(..., description="Number of queue items with status 'pending'")
    in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
    completed: int = Field(..., description="Number of queue items with status 'completed'")
    failed: int = Field(..., description="Number of queue items with status 'failed'")
    canceled: int = Field(..., description="Number of queue items with status 'canceled'")
    total: int = Field(..., description="Total number of queue items")


class SessionQueueCountsByDestination(BaseModel):
    queue_id: str = Field(..., description="The ID of the queue")
    destination: str = Field(..., description="The destination of queue items included in this status")
    pending: int = Field(..., description="Number of queue items with status 'pending' for the destination")
    in_progress: int = Field(..., description="Number of queue items with status 'in_progress' for the destination")
    completed: int = Field(..., description="Number of queue items with status 'completed' for the destination")
    failed: int = Field(..., description="Number of queue items with status 'failed' for the destination")
    canceled: int = Field(..., description="Number of queue items with status 'canceled' for the destination")
    total: int = Field(..., description="Total number of queue items for the destination")


class BatchStatus(BaseModel):
    queue_id: str = Field(..., description="The ID of the queue")
    batch_id: str = Field(..., description="The ID of the batch")
    origin: str | None = Field(..., description="The origin of the batch")
    destination: str | None = Field(..., description="The destination of the batch")
    pending: int = Field(..., description="Number of queue items with status 'pending'")
    in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
    completed: int = Field(..., description="Number of queue items with status 'completed'")
    failed: int = Field(..., description="Number of queue items with status 'failed'")
    canceled: int = Field(..., description="Number of queue items with status 'canceled'")
    total: int = Field(..., description="Total number of queue items")


class EnqueueBatchResult(BaseModel):
    queue_id: str = Field(description="The ID of the queue")
    enqueued: int = Field(description="The total number of queue items enqueued")
    requested: int = Field(description="The total number of queue items requested to be enqueued")
    batch: Batch = Field(description="The batch that was enqueued")
    priority: int = Field(description="The priority of the enqueued batch")
    item_ids: list[int] = Field(description="The IDs of the queue items that were enqueued")


class RetryItemsResult(BaseModel):
    queue_id: str = Field(description="The ID of the queue")
    retried_item_ids: list[int] = Field(description="The IDs of the queue items that were retried")


class ClearResult(BaseModel):
    """Result of clearing the session queue"""

    deleted: int = Field(..., description="Number of queue items deleted")


class PruneResult(ClearResult):
    """Result of pruning the session queue"""

    pass


class CancelByBatchIDsResult(BaseModel):
    """Result of canceling by list of batch ids"""

    canceled: int = Field(..., description="Number of queue items canceled")


class CancelByDestinationResult(CancelByBatchIDsResult):
    """Result of canceling by a destination"""

    pass


class DeleteByDestinationResult(BaseModel):
    """Result of deleting by a destination"""

    deleted: int = Field(..., description="Number of queue items deleted")


class DeleteAllExceptCurrentResult(DeleteByDestinationResult):
    """Result of deleting all except current"""

    pass


class CancelByQueueIDResult(CancelByBatchIDsResult):
    """Result of canceling by queue id"""

    pass


class CancelAllExceptCurrentResult(CancelByBatchIDsResult):
    """Result of canceling all except current"""

    pass


class IsEmptyResult(BaseModel):
    """Result of checking if the session queue is empty"""

    is_empty: bool = Field(..., description="Whether the session queue is empty")


class IsFullResult(BaseModel):
    """Result of checking if the session queue is full"""

    is_full: bool = Field(..., description="Whether the session queue is full")


# endregion Query Results

# region Util
def create_session_nfv_tuples(batch: Batch, maximum: int) -> Generator[tuple[str, str, str], None, None]:
"""
Given a batch and a maximum number of sessions to create, generate a tuple of session_id, session_json, and
field_values_json for each session.
The batch has a "source" graph and a data property. The data property is a list of lists of BatchDatum objects.
Each BatchDatum has a field identifier (e.g. a node id and field name), and a list of values to substitute into
the field.
This structure allows us to create a new graph for every possible permutation of BatchDatum objects:
- Each BatchDatum can be "expanded" into a dict of node-field-value tuples - one for each item in the BatchDatum.
- Zip each inner list of expanded BatchDatum objects together. Call this a "batch_data_list".
- Take the cartesian product of all zipped batch_data_lists, resulting in a list of permutations of BatchDatum
- Take the cartesian product of all zipped batch_data_lists, resulting in a list of lists of BatchDatum objects.
Each inner list now represents the substitution values for a single permutation (session).
- For each permutation, substitute the values into the graph
This function is optimized for performance, as it is used to generate a large number of sessions at once.
Args:
batch: The batch to generate sessions from
maximum: The maximum number of sessions to generate
Returns:
A generator that yields tuples of session_id, session_json, and field_values_json for each session. The
generator will stop early if the maximum number of sessions is reached.
"""
# TODO: Should this be a class method on Batch?
data: list[list[tuple[dict]]] = []
batch_data_collection = batch.data if batch.data is not None else []
for batch_datum_list in batch_data_collection:
node_field_values_to_zip: list[list[dict]] = []
# Expand each BatchDatum into a list of dicts - one for each item in the BatchDatum
for batch_datum in batch_datum_list:
node_field_values = [
# Note: A tuple here is slightly faster than a dict, but we need the object in dict form to be inserted
# in the session_queue table anyways. So, overall creating NFVs as dicts is faster.
{"node_path": batch_datum.node_path, "field_name": batch_datum.field_name, "value": item}
for item in batch_datum.items
]
node_field_values_to_zip.append(node_field_values)
# Zip the dicts together to create a list of dicts for each permutation
data.append(list(zip(*node_field_values_to_zip, strict=True))) # type: ignore [arg-type]
# We serialize the graph and session once, then mutate the graph dict in place for each session.
#
# This sounds scary, but it's actually fine.
#
# The batch prep logic injects field values into the same fields for each generated session.
#
# For example, after the product operation, we'll end up with a list of node-field-value tuples like this:
# [
# (
# {"node_path": "1", "field_name": "a", "value": 1},
# {"node_path": "2", "field_name": "b", "value": 2},
# {"node_path": "3", "field_name": "c", "value": 3},
# ),
# (
# {"node_path": "1", "field_name": "a", "value": 4},
# {"node_path": "2", "field_name": "b", "value": 5},
# {"node_path": "3", "field_name": "c", "value": 6},
# )
# ]
#
# Note that each tuple has the same length, and each tuple substitutes values in for exactly the same node fields.
# No matter the complexity of the batch, this property holds true.
#
# This means each permutation's substitution can be done in-place on the same graph dict, because it overwrites the
# previous mutation. We only need to serialize the graph once, and then we can mutate it in place for each session.
#
# Previously, we had created new Graph objects for each session, but this was very slow for large (1k+ session)
# batches. We then tried dumping the graph to dict and using deep-copy to create a new dict for each session,
# but this was also slow.
#
# Overall, we achieved a 100x speedup by mutating the graph dict in place for each session over creating new Graph
# objects for each session.
#
# We will also mutate the session dict in place, setting a new ID for each session and setting the mutated graph
# dict as the session's graph.
# Dump the batch's graph to a dict once
graph_as_dict = batch.graph.model_dump(warnings=False, exclude_none=True)
# We must provide a Graph object when creating the "dummy" session dict, but we don't actually use it. It will be
# overwritten for each session by the mutated graph_as_dict.
session_dict = GraphExecutionState(graph=Graph()).model_dump(warnings=False, exclude_none=True)
# Now we can create a generator that yields the session_id, session_json, and field_values_json for each session.
count = 0
# Each batch may have multiple runs, so we need to generate the same number of sessions for each run. The total is
# still limited by the maximum number of sessions.
for _ in range(batch.runs):
for d in product(*data):
if count >= maximum:
# We've reached the maximum number of sessions we may generate
return
# Flatten the list of lists of dicts into a single list of dicts
# TODO(psyche): Is there a more efficient way to do this?
flat_node_field_values = list(chain.from_iterable(d))
# Need a fresh ID for each session
session_id = uuid_string()
# Mutate the session dict in place
session_dict["id"] = session_id
# Substitute the values into the graph
for nfv in flat_node_field_values:
graph_as_dict["nodes"][nfv["node_path"]][nfv["field_name"]] = nfv["value"]
# Mutate the session dict in place
session_dict["graph"] = graph_as_dict
# Serialize the session and field values
# Note the use of pydantic's to_jsonable_python to handle serialization of any python object, including sets.
session_json = json.dumps(session_dict, default=to_jsonable_python)
field_values_json = json.dumps(flat_node_field_values, default=to_jsonable_python)
# Yield the session_id, session_json, and field_values_json
yield (session_id, session_json, field_values_json)
# Increment the count so we know when to stop
count += 1
def calc_session_count(batch: Batch) -> int:
"""
Calculates the number of sessions that would be created by the batch, without incurring the overhead of actually
creating them, as is done in `create_session_nfv_tuples()`.
The count is used to communicate to the user how many sessions were _requested_ to be created, as opposed to how
many were _actually_ created (which may be less due to the maximum number of sessions).
"""
# TODO: Should this be a class method on Batch?
if not batch.data:
return batch.runs
data = []
for batch_datum_list in batch.data:
to_zip = []
for batch_datum in batch_datum_list:
batch_data_items = range(len(batch_datum.items))
to_zip.append(batch_data_items)
data.append(list(zip(*to_zip, strict=True)))
data_product = list(product(*data))
return len(data_product) * batch.runs
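# Illustrative sketch only (hypothetical values, not taken from a real batch): the zip-then-product counting
# used in `calc_session_count()` can be reproduced with plain lists. The names `group_1`, `group_2` and `runs`
# are assumptions made for this example.
#
#   from itertools import product
#
#   group_1 = list(zip([0, 1], [0, 1], strict=True))  # two zipped BatchDatum with 2 items each -> 2 tuples
#   group_2 = list(zip([0, 1, 2], strict=True))  # one BatchDatum with 3 items -> 3 tuples
#   runs = 2
#   assert len(list(product(group_1, group_2))) * runs == 12  # 2 * 3 permutations, repeated for each run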
ValueToInsertTuple: TypeAlias = tuple[
str, # queue_id
str, # session (as stringified JSON)
str, # session_id
str, # batch_id
str | None, # field_values (optional, as stringified JSON)
int, # priority
str | None, # workflow (optional, as stringified JSON)
str | None, # origin (optional)
str | None, # destination (optional)
int | None, # retried_from_item_id (optional, this is always None for new items)
str, # user_id
]
"""A type alias for the tuple of values to insert into the session queue table.
**If you change this, be sure to update the `enqueue_batch` and `retry_items_by_id` methods in the session queue service!**
"""
def prepare_values_to_insert(
queue_id: str, batch: Batch, priority: int, max_new_queue_items: int, user_id: str = "system"
) -> list[ValueToInsertTuple]:
"""
Given a batch, prepare the values to insert into the session queue table. The list of tuples can be used with an
`executemany` statement to insert multiple rows at once.
Args:
queue_id: The ID of the queue to insert the items into
batch: The batch to prepare the values for
priority: The priority of the queue items
max_new_queue_items: The maximum number of queue items to insert
user_id: The user ID who is creating these queue items
Returns:
A list of tuples to insert into the session queue table. Each tuple contains the following values:
- queue_id
- session (as stringified JSON)
- session_id
- batch_id
- field_values (optional, as stringified JSON)
- priority
- workflow (optional, as stringified JSON)
- origin (optional)
- destination (optional)
- retried_from_item_id (optional, this is always None for new items)
- user_id
"""
# A tuple is a fast and memory-efficient way to store the values to insert. Previously, we used a NamedTuple, but
# measured a ~5% performance improvement by using a normal tuple instead. For very large batches (10k+ items),
# this difference becomes noticeable.
#
# So, despite the inferior DX with normal tuples, we use one here for performance reasons.
values_to_insert: list[ValueToInsertTuple] = []
# pydantic's to_jsonable_python handles serialization of any python object, including sets, which json.dumps does
# not support by default. Apparently there are sets somewhere in the graph.
# The same workflow is used for all sessions in the batch - serialize it once
workflow_json = json.dumps(batch.workflow, default=to_jsonable_python) if batch.workflow else None
for session_id, session_json, field_values_json in create_session_nfv_tuples(batch, max_new_queue_items):
values_to_insert.append(
(
queue_id,
session_json,
session_id,
batch.batch_id,
field_values_json,
priority,
workflow_json,
batch.origin,
batch.destination,
None,
user_id,
)
)
return values_to_insert
# endregion Util
Batch.model_rebuild(force=True)
SessionQueueItem.model_rebuild(force=True)