Files
InvokeAI/invokeai/app/services/session_queue/session_queue_common.py
psychedelicious 831fb814cc feat(app): add origin to session queue
The origin is an optional field indicating the queue item's origin. For example, "canvas" when the queue item originated from the canvas or "workflows" when the queue item originated from the workflows tab. If omitted, we assume the queue item originated from the API directly.

- Add migration to add the nullable column to the `session_queue` table.
- Update relevant event payloads with the new field.
- Add `cancel_by_origin` method to `session_queue` service and corresponding route. This is required for the canvas to bail out early when staging images.
- Add `origin` to both `SessionQueueItem` and `Batch` - it needs to be provided initially via the batch and then passed on to the queue item.
-
2024-08-27 11:01:21 +10:00

485 lines
18 KiB
Python

import datetime
import json
from enum import Enum
from itertools import chain, product
from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast
from pydantic import (
AliasChoices,
BaseModel,
ConfigDict,
Field,
StrictStr,
TypeAdapter,
field_validator,
model_validator,
)
from pydantic_core import to_jsonable_python
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError
from invokeai.app.services.workflow_records.workflow_records_common import (
WorkflowWithoutID,
WorkflowWithoutIDValidator,
)
from invokeai.app.util.metaenum import MetaEnum
from invokeai.app.util.misc import uuid_string
# region Errors
class BatchZippedLengthError(ValueError):
    """Raise when a batch has items of different lengths.

    Raised by ``Batch.validate_lengths`` when the data within one zipped group do not all
    hold the same number of items.
    """
class BatchItemsTypeError(ValueError):  # this cannot be a TypeError in pydantic v2
    """Raise when a batch has items of different types.

    Raised by ``Batch.validate_types`` when the items of a single batch datum are not all
    of exactly the same type.
    """
class BatchDuplicateNodeFieldError(ValueError):
    """Raise when a batch has duplicate node_path and field_name.

    Raised by ``Batch.validate_unique_field_mappings`` when the same ``(node_path, field_name)``
    pair appears more than once across the whole batch data collection.
    """
class TooManySessionsError(ValueError):
    """Raise when too many sessions are requested.

    NOTE(review): not raised anywhere in this module; presumably raised by the queue
    implementation when a batch would exceed the queue's capacity - confirm against callers.
    """
class SessionQueueItemNotFoundError(ValueError):
    """Raise when a queue item is not found.

    NOTE(review): not raised anywhere in this module; presumably raised by queue lookups
    elsewhere - confirm against callers.
    """
# endregion
# region Batch
# The scalar value types a batch item may have. Used for `NodeFieldValue.value` and `BatchDatum.items`.
# NOTE(review): `StrictStr` (rather than plain `str`) presumably keeps numbers from being coerced
# to strings when validating this union - confirm.
BatchDataType = Union[
    StrictStr,
    float,
    int,
]
class QueueItemOrigin(str, Enum, metaclass=MetaEnum):
    """The origin of a batch. For example, a batch can be created from the canvas or workflows tab."""

    # Members also subclass `str`, so each member compares equal to its plain string value.
    CANVAS = "canvas"
    WORKFLOWS = "workflows"
# One concrete substitution: a single value destined for one field of one node.
# `populate_graph` applies it via `setattr(node, field_name, value)`.
class NodeFieldValue(BaseModel):
    node_path: str = Field(description="The node into which this batch data item will be substituted.")
    field_name: str = Field(description="The field into which this batch data item will be substituted.")
    value: BatchDataType = Field(description="The value to substitute into the node/field.")
# A list of candidate values for one field of one node. `create_session_nfv_tuples` turns
# every entry of `items` into its own `NodeFieldValue` for that node/field.
class BatchDatum(BaseModel):
    node_path: str = Field(description="The node into which this batch data collection will be substituted.")
    field_name: str = Field(description="The field into which this batch data collection will be substituted.")
    # Defaults to an empty list when omitted.
    items: list[BatchDataType] = Field(
        default_factory=list, description="The list of items to substitute into the node/field."
    )
# Outer list: groups that `create_session_nfv_tuples` combines via Cartesian product.
# Inner list: the `BatchDatum` entries of one group, which are zipped together item-by-item.
BatchDataCollection: TypeAlias = list[list[BatchDatum]]
class Batch(BaseModel):
    # A request to enqueue sessions built from `graph`, optionally varied by the substitutions in `data`.
    # Documented with comments (not a class docstring) so the generated schema is left unchanged.
    batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
    origin: QueueItemOrigin | None = Field(default=None, description="The origin of this batch.")
    data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
    graph: Graph = Field(description="The graph to initialize the session with")
    workflow: Optional[WorkflowWithoutID] = Field(
        default=None, description="The workflow to initialize the session with"
    )
    runs: int = Field(
        default=1, ge=1, description="Int stating how many times to iterate through all possible batch indices"
    )

    @field_validator("data")
    @classmethod
    def validate_lengths(cls, v: Optional[BatchDataCollection]):
        """Ensure every datum within a zipped group holds the same number of items.

        Raises:
            BatchZippedLengthError: if any datum's item count differs from the first datum of its group.
        """
        if v is None:
            return v
        for batch_data_list in v:
            # An empty group has nothing to compare; its expected length is 0.
            expected_length = len(batch_data_list[0].items) if batch_data_list else 0
            if any(len(datum.items) != expected_length for datum in batch_data_list):
                raise BatchZippedLengthError("Zipped batch items must all have the same length")
        return v

    @field_validator("data")
    @classmethod
    def validate_types(cls, v: Optional[BatchDataCollection]):
        """Ensure all items of each individual datum are of exactly the same type.

        Raises:
            BatchItemsTypeError: if a datum mixes item types (exact `type()` identity, no subclassing).
        """
        if v is None:
            return v
        for batch_data_list in v:
            for datum in batch_data_list:
                # More than one distinct exact type means at least one item differs from the first.
                if len({type(item) for item in datum.items}) > 1:
                    raise BatchItemsTypeError("All items in a batch must have the same type")
        return v

    @field_validator("data")
    @classmethod
    def validate_unique_field_mappings(cls, v: Optional[BatchDataCollection]):
        """Ensure no (node_path, field_name) pair is targeted more than once in the whole collection.

        Raises:
            BatchDuplicateNodeFieldError: on the first repeated pair.
        """
        if v is None:
            return v
        seen: set[tuple[str, str]] = set()
        for batch_data_list in v:
            for datum in batch_data_list:
                pair = (datum.node_path, datum.field_name)
                if pair in seen:
                    raise BatchDuplicateNodeFieldError("Each batch data must have unique node_path and field_name")
                seen.add(pair)
        return v

    @model_validator(mode="after")
    def validate_batch_nodes_and_edges(self):
        """Ensure every datum targets a node that exists in `graph` and a field that exists on that node.

        Runs after field validation, so `self.data` and `self.graph` are already validated.

        Raises:
            NodeNotFoundError: if the node is missing from the graph, or the field is missing from the node.
        """
        if self.data is None:
            return self
        for batch_data_list in self.data:
            for batch_data in batch_data_list:
                try:
                    node = cast(BaseInvocation, self.graph.get_node(batch_data.node_path))
                except NodeNotFoundError as e:
                    # Re-raise with a clearer message while keeping the original error as the cause.
                    raise NodeNotFoundError(f"Node {batch_data.node_path} not found in graph") from e
                # `model_fields` is class-level metadata, so read it from the class rather than the instance.
                if batch_data.field_name not in type(node).model_fields:
                    raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}")
        return self

    @field_validator("graph")
    @classmethod
    def validate_graph(cls, v: Graph):
        """Run the graph's own `validate_self()` check and pass the graph through unchanged."""
        v.validate_self()
        return v

    # Overrides the schema's `required` list with exactly these two fields.
    model_config = ConfigDict(
        json_schema_extra={
            "required": [
                "graph",
                "runs",
            ]
        }
    )
# endregion Batch
# region Queue Items
# NOTE(review): not referenced in this module; presumably the queue id used when callers
# do not name one - confirm against callers.
DEFAULT_QUEUE_ID = "default"
# The closed set of values accepted for a queue item's `status` field.
QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"]
# Validates a JSON document into a list of `NodeFieldValue` models.
NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue])


def get_field_values(queue_item_dict: dict) -> Optional[list[NodeFieldValue]]:
    """Parse the JSON stored under `field_values` in a queue item dict.

    Returns None when the key is absent or its value is None; otherwise the validated list.
    """
    raw_field_values = queue_item_dict.get("field_values")
    if raw_field_values is None:
        return None
    return NodeFieldValueValidator.validate_json(raw_field_values)
# Validates a JSON document into a `GraphExecutionState`.
GraphExecutionStateValidator = TypeAdapter(GraphExecutionState)


def get_session(queue_item_dict: dict) -> GraphExecutionState:
    """Parse the JSON stored under `session` in a queue item dict.

    A missing key is treated as the empty JSON object "{}". Validation is non-strict.
    """
    raw_session = queue_item_dict.get("session", "{}")
    return GraphExecutionStateValidator.validate_json(raw_session, strict=False)
def get_workflow(queue_item_dict: dict) -> Optional[WorkflowWithoutID]:
    """Parse the JSON stored under `workflow` in a queue item dict.

    Returns None when the key is absent or its value is None; otherwise the workflow
    validated (non-strictly) from that JSON.
    """
    raw_workflow = queue_item_dict.get("workflow")
    if raw_workflow is None:
        return None
    return WorkflowWithoutIDValidator.validate_json(raw_workflow, strict=False)
class SessionQueueItemWithoutGraph(BaseModel):
    """Session queue item without the full graph. Used for serialization."""

    item_id: int = Field(description="The identifier of the session queue item")
    status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
    priority: int = Field(default=0, description="The priority of this queue item")
    batch_id: str = Field(description="The ID of the batch associated with this queue item")
    origin: QueueItemOrigin | None = Field(default=None, description="The origin of this queue item. ")
    session_id: str = Field(
        description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
    )
    error_type: Optional[str] = Field(default=None, description="The error type if this queue item errored")
    error_message: Optional[str] = Field(default=None, description="The error message if this queue item errored")
    # Accepts input under either the `error_traceback` key or the `error` key.
    error_traceback: Optional[str] = Field(
        default=None,
        description="The error traceback if this queue item errored",
        validation_alias=AliasChoices("error_traceback", "error"),
    )
    created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created")
    updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated")
    # No default is given for the next two fields: they must be supplied, but may be None.
    started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started")
    completed_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was completed")
    queue_id: str = Field(description="The id of the queue with which this item is associated")
    field_values: Optional[list[NodeFieldValue]] = Field(
        default=None, description="The field values that were used for this queue item"
    )

    @classmethod
    def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
        """Build a `SessionQueueItemDTO` from a raw queue item dict.

        `field_values` is expected as a JSON string (or absent/None) and is parsed manually
        before the model is constructed. The caller's dict is not modified.
        """
        # Shallow-copy instead of assigning into the argument, so the caller's dict keeps its
        # raw JSON value and can safely be parsed again.
        parsed = {**queue_item_dict, "field_values": get_field_values(queue_item_dict)}
        return SessionQueueItemDTO(**parsed)

    # Overrides the schema's `required` list. Each name appears exactly once, as JSON Schema requires.
    model_config = ConfigDict(
        json_schema_extra={
            "required": [
                "item_id",
                "status",
                "batch_id",
                "queue_id",
                "session_id",
                "priority",
                "created_at",
                "updated_at",
            ]
        }
    )
# Adds nothing to `SessionQueueItemWithoutGraph`; it is the type returned by
# `SessionQueueItemWithoutGraph.queue_item_dto_from_dict`.
class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
    pass
# Extends `SessionQueueItemWithoutGraph` with the full session and the optional workflow.
class SessionQueueItem(SessionQueueItemWithoutGraph):
    session: GraphExecutionState = Field(description="The fully-populated session to be executed")
    workflow: Optional[WorkflowWithoutID] = Field(
        default=None, description="The workflow associated with this queue item"
    )

    @classmethod
    def queue_item_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItem":
        """Build a `SessionQueueItem` from a raw queue item dict.

        `field_values`, `session` and `workflow` are expected as JSON strings and are parsed
        manually before the model is constructed. The caller's dict is not modified.
        """
        # Shallow-copy instead of assigning into the argument, so the caller's dict keeps its
        # raw JSON values and can safely be parsed again.
        parsed = {
            **queue_item_dict,
            "field_values": get_field_values(queue_item_dict),
            "session": get_session(queue_item_dict),
            "workflow": get_workflow(queue_item_dict),
        }
        return SessionQueueItem(**parsed)

    # Overrides the schema's `required` list. Each name appears exactly once, as JSON Schema requires.
    model_config = ConfigDict(
        json_schema_extra={
            "required": [
                "item_id",
                "status",
                "batch_id",
                "queue_id",
                "session_id",
                "session",
                "priority",
                "created_at",
                "updated_at",
            ]
        }
    )
# endregion Queue Items
# region Query Results
# Per-queue status summary: identifiers of the current item plus item counts per status.
# The status names in the descriptions match the literals of `QUEUE_ITEM_STATUS`.
class SessionQueueStatus(BaseModel):
    queue_id: str = Field(..., description="The ID of the queue")
    # No default is given for the next three fields: they must be supplied, but may be None.
    item_id: Optional[int] = Field(description="The current queue item id")
    batch_id: Optional[str] = Field(description="The current queue item's batch id")
    session_id: Optional[str] = Field(description="The current queue item's session id")
    pending: int = Field(..., description="Number of queue items with status 'pending'")
    in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
    completed: int = Field(..., description="Number of queue items with status 'completed'")
    failed: int = Field(..., description="Number of queue items with status 'failed'")
    canceled: int = Field(..., description="Number of queue items with status 'canceled'")
    total: int = Field(..., description="Total number of queue items")
# Per-batch status summary: the batch's identifiers and origin plus item counts per status.
# The status names in the descriptions match the literals of `QUEUE_ITEM_STATUS`.
class BatchStatus(BaseModel):
    queue_id: str = Field(..., description="The ID of the queue")
    batch_id: str = Field(..., description="The ID of the batch")
    # Required (no default) but may be None.
    origin: QueueItemOrigin | None = Field(..., description="The origin of the batch")
    pending: int = Field(..., description="Number of queue items with status 'pending'")
    in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
    completed: int = Field(..., description="Number of queue items with status 'completed'")
    failed: int = Field(..., description="Number of queue items with status 'failed'")
    canceled: int = Field(..., description="Number of queue items with status 'canceled'")
    total: int = Field(..., description="Total number of queue items")
# Result of enqueuing a batch. All fields are required (none has a default).
# `enqueued` and `requested` are reported separately, so the two counts may differ.
class EnqueueBatchResult(BaseModel):
    queue_id: str = Field(description="The ID of the queue")
    enqueued: int = Field(description="The total number of queue items enqueued")
    requested: int = Field(description="The total number of queue items requested to be enqueued")
    batch: Batch = Field(description="The batch that was enqueued")
    priority: int = Field(description="The priority of the enqueued batch")
class ClearResult(BaseModel):
    """Result of clearing the session queue"""

    # Required field (`...` means no default); also inherited by `PruneResult`.
    deleted: int = Field(..., description="Number of queue items deleted")
class PruneResult(ClearResult):
    """Result of pruning the session queue"""

    # Adds no fields of its own; reuses the `deleted` count inherited from `ClearResult`.
    pass
class CancelByBatchIDsResult(BaseModel):
    """Result of canceling by list of batch ids"""

    # Required field (`...` means no default); also inherited by `CancelByQueueIDResult`.
    canceled: int = Field(..., description="Number of queue items canceled")
class CancelByOriginResult(BaseModel):
    """Result of canceling by origin"""

    # Required field (`...` means no default).
    canceled: int = Field(..., description="Number of queue items canceled")
class CancelByQueueIDResult(CancelByBatchIDsResult):
    """Result of canceling by queue id"""

    # Adds no fields of its own; reuses the `canceled` count inherited from `CancelByBatchIDsResult`.
    pass
class IsEmptyResult(BaseModel):
    """Result of checking if the session queue is empty"""

    # Required field (`...` means no default).
    is_empty: bool = Field(..., description="Whether the session queue is empty")
class IsFullResult(BaseModel):
    """Result of checking if the session queue is full"""

    # Required field (`...` means no default).
    is_full: bool = Field(..., description="Whether the session queue is full")
# endregion Query Results
# region Util
def populate_graph(graph: Graph, node_field_values: Iterable[NodeFieldValue]) -> Graph:
    """
    Return a deep copy of `graph` with each node field value applied to it.

    For every item, the node at `item.node_path` is fetched from the copy, its `item.field_name`
    attribute is set to `item.value`, and the node is written back with `update_node`.
    The graph passed in is left untouched.
    """
    populated = graph.model_copy(deep=True)
    for nfv in node_field_values:
        target_node = populated.get_node(nfv.node_path)
        # Nothing to substitute into if the lookup yields no node.
        if target_node is not None:
            setattr(target_node, nfv.field_name, nfv.value)
            populated.update_node(nfv.node_path, target_node)
    return populated
def create_session_nfv_tuples(
    batch: Batch, maximum: int
) -> Generator[tuple[GraphExecutionState, list[NodeFieldValue], Optional[WorkflowWithoutID]], None, None]:
    """
    Lazily create every session permutation described by the batch, stopping after `maximum` sessions.

    Within each group of the batch data, the data are zipped together item-by-item; the groups are
    then combined via Cartesian product, and the whole product is repeated `batch.runs` times.

    Yields tuples of the form (session, node_field_values, workflow), where `session` is a new
    `GraphExecutionState` wrapping the populated graph, `node_field_values` is the flat list of
    NodeFieldValues applied to that graph, and `workflow` is `batch.workflow`.
    """
    # TODO: Should this be a class method on Batch?
    zipped_groups: list[list[tuple[NodeFieldValue, ...]]] = []
    for datum_group in batch.data or []:
        # Expand each datum into one NodeFieldValue per item, then zip the data of this group together.
        per_datum_values = [
            [
                NodeFieldValue(node_path=datum.node_path, field_name=datum.field_name, value=item)
                for item in datum.items
            ]
            for datum in datum_group
        ]
        zipped_groups.append(list(zip(*per_datum_values, strict=True)))

    yielded = 0
    for _ in range(batch.runs):
        for combination in product(*zipped_groups):
            if yielded >= maximum:
                return
            # `combination` holds one zipped tuple per group; flatten it into a single list.
            applied_values = list(chain.from_iterable(combination))
            populated_graph = populate_graph(batch.graph, applied_values)
            yield (GraphExecutionState(graph=populated_graph), applied_values, batch.workflow)
            yielded += 1
def calc_session_count(batch: Batch) -> int:
    """
    Calculates the number of sessions that would be created by the batch, without incurring
    the overhead of actually generating them. Mirrors the counting logic of `create_session_nfv_tuples()`.

    The count is `batch.runs` multiplied by the zipped length of every group in `batch.data`.
    A batch without data yields `batch.runs` sessions; a group with no data (or with empty
    item lists) makes the count 0.

    Raises:
        ValueError: if the data within one group have differing item counts (strict zip).
    """
    # TODO: Should this be a class method on Batch?
    if not batch.data:
        return batch.runs

    # Strict-zip every group first, so a length mismatch anywhere raises before anything is counted.
    zipped_lengths: list[int] = []
    for batch_datum_list in batch.data:
        index_ranges = [range(len(batch_datum.items)) for batch_datum in batch_datum_list]
        zipped_lengths.append(sum(1 for _ in zip(*index_ranges, strict=True)))

    # The size of a Cartesian product is the product of its factors' sizes, so multiply the
    # lengths instead of materializing every combination in memory.
    session_count = batch.runs
    for zipped_length in zipped_lengths:
        session_count *= zipped_length
    return session_count
class SessionQueueValueToInsert(NamedTuple):
    """A tuple of values to insert into the session_queue table"""

    # Careful with the ordering of this - it must match the insert statement
    # (the insert statement itself is not in this module).
    queue_id: str  # queue_id
    session: str  # session json
    session_id: str  # session_id
    batch_id: str  # batch_id
    field_values: Optional[str]  # field_values json
    priority: int  # priority
    workflow: Optional[str]  # workflow json
    origin: QueueItemOrigin | None  # origin; None when the batch has no origin


# The list of rows produced by `prepare_values_to_insert`.
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new_queue_items: int) -> ValuesToInsert:
    """
    Build one `SessionQueueValueToInsert` row per session generated from the batch.

    At most `max_new_queue_items` rows are produced. Each session gets a fresh unique id,
    and the session, field values and workflow are serialized to JSON strings
    (field values and workflow become None when falsy).
    """
    rows: ValuesToInsert = []
    for session, node_field_values, workflow in create_session_nfv_tuples(batch, max_new_queue_items):
        # every session needs its own unique id
        session.id = uuid_string()
        # `to_jsonable_python` is the fallback encoder because these values contain pydantic models
        field_values_json = json.dumps(node_field_values, default=to_jsonable_python) if node_field_values else None
        workflow_json = json.dumps(workflow, default=to_jsonable_python) if workflow else None
        row = SessionQueueValueToInsert(
            queue_id=queue_id,
            session=session.model_dump_json(warnings=False, exclude_none=True),
            session_id=session.id,
            batch_id=batch.batch_id,
            field_values=field_values_json,
            priority=priority,
            workflow=workflow_json,
            origin=batch.origin,
        )
        rows.append(row)
    return rows
# endregion Util
# Force pydantic to rebuild both models at import time, after everything above is defined.
# NOTE(review): presumably needed so type references inside these models resolve - confirm.
Batch.model_rebuild(force=True)
SessionQueueItem.model_rebuild(force=True)