mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-29 03:58:27 -05:00
The origin is an optional field indicating the queue item's origin. For example, "canvas" when the queue item originated from the canvas or "workflows" when the queue item originated from the workflows tab. If omitted, we assume the queue item originated from the API directly. - Add migration to add the nullable column to the `session_queue` table. - Update relevant event payloads with the new field. - Add `cancel_by_origin` method to `session_queue` service and corresponding route. This is required for the canvas to bail out early when staging images. - Add `origin` to both `SessionQueueItem` and `Batch` - it needs to be provided initially via the batch and then passed onto the queue item.
485 lines
18 KiB
Python
485 lines
18 KiB
Python
import datetime
|
|
import json
|
|
from enum import Enum
|
|
from itertools import chain, product
|
|
from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast
|
|
|
|
from pydantic import (
|
|
AliasChoices,
|
|
BaseModel,
|
|
ConfigDict,
|
|
Field,
|
|
StrictStr,
|
|
TypeAdapter,
|
|
field_validator,
|
|
model_validator,
|
|
)
|
|
from pydantic_core import to_jsonable_python
|
|
|
|
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
|
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError
|
|
from invokeai.app.services.workflow_records.workflow_records_common import (
|
|
WorkflowWithoutID,
|
|
WorkflowWithoutIDValidator,
|
|
)
|
|
from invokeai.app.util.metaenum import MetaEnum
|
|
from invokeai.app.util.misc import uuid_string
|
|
|
|
# region Errors
|
|
|
|
|
|
class BatchZippedLengthError(ValueError):
    """Signals that the zipped data of a batch do not all hold the same number of items."""
|
|
|
|
|
|
# NOTE: deliberately a ValueError - this cannot be a TypeError in pydantic v2.
class BatchItemsTypeError(ValueError):
    """Signals that the items of a batch datum are not all of one type."""
|
|
|
|
|
|
class BatchDuplicateNodeFieldError(ValueError):
    """Signals that a batch targets the same (node_path, field_name) pair more than once."""
|
|
|
|
|
|
class TooManySessionsError(ValueError):
    """Signals that more sessions were requested than are allowed."""
|
|
|
|
|
|
class SessionQueueItemNotFoundError(ValueError):
    """Signals that the requested queue item could not be found."""
|
|
|
|
|
|
# endregion
|
|
|
|
|
|
# region Batch
|
|
|
|
# The scalar types a batch may substitute into a node field. The member order is kept exactly
# as originally declared.
BatchDataType = Union[StrictStr, float, int]
|
|
|
|
|
|
# String-valued enum: each member's value is the plain string stored/transmitted for the origin.
class QueueItemOrigin(str, Enum, metaclass=MetaEnum):
    """The origin of a batch. For example, a batch can be created from the canvas or workflows tab."""

    CANVAS = "canvas"
    WORKFLOWS = "workflows"
|
|
|
|
|
|
class NodeFieldValue(BaseModel):
    # One concrete substitution: `value` is what gets written to `field_name` on the node at `node_path`.
    node_path: str = Field(description="The node into which this batch data item will be substituted.")
    field_name: str = Field(description="The field into which this batch data item will be substituted.")
    value: BatchDataType = Field(description="The value to substitute into the node/field.")
|
|
|
|
|
|
class BatchDatum(BaseModel):
    # A list of candidate values (`items`) for a single node field; `items` defaults to an empty list.
    node_path: str = Field(description="The node into which this batch data collection will be substituted.")
    field_name: str = Field(description="The field into which this batch data collection will be substituted.")
    items: list[BatchDataType] = Field(
        description="The list of items to substitute into the node/field.",
        default_factory=list,
    )
|
|
|
|
|
|
# Outer list: groups of batch data. Inner list: the BatchDatum entries belonging to one group.
BatchDataCollection: TypeAlias = list[list[BatchDatum]]
|
|
|
|
|
|
class Batch(BaseModel):
    # A graph plus optional substitution data. `data` is a list of groups of BatchDatum; the
    # validators below enforce that the datums in one group have equal item counts, that each
    # datum's items share one type, that no (node_path, field_name) pair is targeted twice, and
    # that every targeted node/field exists in `graph`.
    batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
    origin: QueueItemOrigin | None = Field(default=None, description="The origin of this batch.")
    data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
    graph: Graph = Field(description="The graph to initialize the session with")
    workflow: Optional[WorkflowWithoutID] = Field(
        default=None, description="The workflow to initialize the session with"
    )
    runs: int = Field(
        default=1, ge=1, description="Int stating how many times to iterate through all possible batch indices"
    )

    @field_validator("data")
    @classmethod
    def validate_lengths(cls, v: Optional[BatchDataCollection]):
        """Ensure all datums within one group carry the same number of items.

        Raises:
            BatchZippedLengthError: If any datum's item count differs from the first datum of its group.
        """
        if v is None:
            return v
        for batch_data_list in v:
            # An empty group has nothing to compare; its expected length is 0.
            expected_length = len(batch_data_list[0].items) if batch_data_list else 0
            for datum in batch_data_list:
                if len(datum.items) != expected_length:
                    raise BatchZippedLengthError("Zipped batch items must all have the same length")
        return v

    @field_validator("data")
    @classmethod
    def validate_types(cls, v: Optional[BatchDataCollection]):
        """Ensure the items of each datum all have exactly the same type.

        Raises:
            BatchItemsTypeError: If an item's type differs from the type of the datum's first item.
        """
        if v is None:
            return v
        for batch_data_list in v:
            for datum in batch_data_list:
                # The type of the first item is the reference for the whole list.
                first_item_type = type(datum.items[0]) if datum.items else None
                for item in datum.items:
                    # Exact type identity is intended here (bool vs int, int vs float are different).
                    if type(item) is not first_item_type:
                        raise BatchItemsTypeError("All items in a batch must have the same type")
        return v

    @field_validator("data")
    @classmethod
    def validate_unique_field_mappings(cls, v: Optional[BatchDataCollection]):
        """Ensure no (node_path, field_name) pair appears more than once across the whole collection.

        Raises:
            BatchDuplicateNodeFieldError: If a pair is targeted by more than one datum.
        """
        if v is None:
            return v
        paths: set[tuple[str, str]] = set()
        for batch_data_list in v:
            for datum in batch_data_list:
                pair = (datum.node_path, datum.field_name)
                if pair in paths:
                    raise BatchDuplicateNodeFieldError("Each batch data must have unique node_id and field_name")
                paths.add(pair)
        return v

    @model_validator(mode="after")
    def validate_batch_nodes_and_edges(self):
        """Ensure every datum targets a node that exists in the graph and a field that exists on that node.

        Written as an instance method, which is the documented form for after-mode model validators.

        Raises:
            NodeNotFoundError: If the node is not in the graph, or the field is not declared on the node.
        """
        batch_data_collection = cast(Optional[BatchDataCollection], self.data)
        if batch_data_collection is None:
            return self
        graph = cast(Graph, self.graph)
        for batch_data_list in batch_data_collection:
            for batch_data in batch_data_list:
                try:
                    node = cast(BaseInvocation, graph.get_node(batch_data.node_path))
                except NodeNotFoundError as e:
                    # Re-raise with a clearer message, keeping the original error as the cause.
                    raise NodeNotFoundError(f"Node {batch_data.node_path} not found in graph") from e
                # Read the declared fields from the class; accessing them via the instance is deprecated.
                if batch_data.field_name not in type(node).model_fields:
                    raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}")
        return self

    @field_validator("graph")
    @classmethod
    def validate_graph(cls, v: Graph):
        """Run the graph's own self-validation and pass the graph through unchanged."""
        v.validate_self()
        return v

    model_config = ConfigDict(
        json_schema_extra={
            "required": [
                "graph",
                "runs",
            ]
        }
    )
|
|
|
|
|
|
# endregion Batch
|
|
|
|
|
|
# region Queue Items
|
|
|
|
# Identifier of the default queue.
DEFAULT_QUEUE_ID = "default"

# The complete set of status strings a queue item may carry.
QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"]

# Reusable adapter that validates a list of NodeFieldValue models.
NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue])
|
|
|
|
|
|
def get_field_values(queue_item_dict: dict) -> Optional[list[NodeFieldValue]]:
    """Parse the JSON stored under "field_values" into NodeFieldValue models.

    Returns None when the key is absent or its value is None.
    """
    raw_field_values = queue_item_dict.get("field_values")
    if raw_field_values is None:
        return None
    return NodeFieldValueValidator.validate_json(raw_field_values)
|
|
|
|
|
|
# Reusable adapter that validates GraphExecutionState (session) payloads.
GraphExecutionStateValidator = TypeAdapter(GraphExecutionState)
|
|
|
|
|
|
def get_session(queue_item_dict: dict) -> GraphExecutionState:
    """Parse the JSON stored under "session" into a GraphExecutionState.

    When the key is absent, the empty JSON object "{}" is validated instead. Validation runs
    in non-strict mode.
    """
    raw_session = queue_item_dict.get("session", "{}")
    return GraphExecutionStateValidator.validate_json(raw_session, strict=False)
|
|
|
|
|
|
def get_workflow(queue_item_dict: dict) -> Optional[WorkflowWithoutID]:
    """Parse the JSON stored under "workflow" into a WorkflowWithoutID.

    Returns None when the key is absent or its value is None. Validation runs in non-strict mode.
    """
    raw_workflow = queue_item_dict.get("workflow")
    if raw_workflow is None:
        return None
    return WorkflowWithoutIDValidator.validate_json(raw_workflow, strict=False)
|
|
|
|
|
|
class SessionQueueItemWithoutGraph(BaseModel):
    """Session queue item without the full graph. Used for serialization."""

    item_id: int = Field(description="The identifier of the session queue item")
    status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
    priority: int = Field(default=0, description="The priority of this queue item")
    batch_id: str = Field(description="The ID of the batch associated with this queue item")
    origin: QueueItemOrigin | None = Field(default=None, description="The origin of this queue item. ")
    session_id: str = Field(
        description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
    )
    error_type: Optional[str] = Field(default=None, description="The error type if this queue item errored")
    error_message: Optional[str] = Field(default=None, description="The error message if this queue item errored")
    # Accepts input under either "error_traceback" or "error".
    error_traceback: Optional[str] = Field(
        default=None,
        description="The error traceback if this queue item errored",
        validation_alias=AliasChoices("error_traceback", "error"),
    )
    created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created")
    updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated")
    # No default: these two must be supplied, though the value may be None.
    started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started")
    completed_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was completed")
    queue_id: str = Field(description="The id of the queue with which this item is associated")
    field_values: Optional[list[NodeFieldValue]] = Field(
        default=None, description="The field values that were used for this queue item"
    )

    @classmethod
    def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
        """Build a SessionQueueItemDTO from a dict whose "field_values" entry is still raw JSON.

        The input dict is not modified; the parsed field values are merged into a copy.
        """
        # field_values is stored as JSON and must be parsed manually before model construction.
        parsed = {**queue_item_dict, "field_values": get_field_values(queue_item_dict)}
        return SessionQueueItemDTO(**parsed)

    model_config = ConfigDict(
        json_schema_extra={
            # Each property name is listed once; a JSON Schema "required" array must hold unique entries.
            "required": [
                "item_id",
                "status",
                "batch_id",
                "queue_id",
                "session_id",
                "priority",
                "created_at",
                "updated_at",
            ]
        }
    )
|
|
|
|
|
|
class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
    # Adds nothing to SessionQueueItemWithoutGraph; exists as a distinct named type.
    pass
|
|
|
|
|
|
class SessionQueueItem(SessionQueueItemWithoutGraph):
    # Extends the graph-less queue item with the full session and the optional workflow.
    session: GraphExecutionState = Field(description="The fully-populated session to be executed")
    workflow: Optional[WorkflowWithoutID] = Field(
        default=None, description="The workflow associated with this queue item"
    )

    @classmethod
    def queue_item_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItem":
        """Build a SessionQueueItem from a dict whose JSON columns are still raw strings.

        "field_values", "session" and "workflow" are parsed from the input; the input dict itself
        is not modified - the parsed values are merged into a copy.
        """
        # These three entries are stored as JSON and must be parsed manually.
        parsed = {
            **queue_item_dict,
            "field_values": get_field_values(queue_item_dict),
            "session": get_session(queue_item_dict),
            "workflow": get_workflow(queue_item_dict),
        }
        return SessionQueueItem(**parsed)

    model_config = ConfigDict(
        json_schema_extra={
            # Each property name is listed once; a JSON Schema "required" array must hold unique entries.
            "required": [
                "item_id",
                "status",
                "batch_id",
                "queue_id",
                "session_id",
                "session",
                "priority",
                "created_at",
                "updated_at",
            ]
        }
    )
|
|
|
|
|
|
# endregion Queue Items
|
|
|
|
# region Query Results
|
|
|
|
|
|
class SessionQueueStatus(BaseModel):
    # Per-status item counts for one queue, plus identifiers of the current queue item.
    # The status names in the descriptions match the literals in QUEUE_ITEM_STATUS.
    queue_id: str = Field(..., description="The ID of the queue")
    # No default: the three "current item" fields must be supplied, though the value may be None.
    item_id: Optional[int] = Field(description="The current queue item id")
    batch_id: Optional[str] = Field(description="The current queue item's batch id")
    session_id: Optional[str] = Field(description="The current queue item's session id")
    pending: int = Field(..., description="Number of queue items with status 'pending'")
    in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
    completed: int = Field(..., description="Number of queue items with status 'completed'")
    failed: int = Field(..., description="Number of queue items with status 'failed'")
    canceled: int = Field(..., description="Number of queue items with status 'canceled'")
    total: int = Field(..., description="Total number of queue items")
|
|
|
|
|
|
class BatchStatus(BaseModel):
    # Per-status item counts for one batch within a queue.
    # The status names in the descriptions match the literals in QUEUE_ITEM_STATUS.
    queue_id: str = Field(..., description="The ID of the queue")
    batch_id: str = Field(..., description="The ID of the batch")
    # Must be supplied, though the value may be None.
    origin: QueueItemOrigin | None = Field(..., description="The origin of the batch")
    pending: int = Field(..., description="Number of queue items with status 'pending'")
    in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
    completed: int = Field(..., description="Number of queue items with status 'completed'")
    failed: int = Field(..., description="Number of queue items with status 'failed'")
    canceled: int = Field(..., description="Number of queue items with status 'canceled'")
    total: int = Field(..., description="Total number of queue items")
|
|
|
|
|
|
class EnqueueBatchResult(BaseModel):
    # Outcome of enqueuing a batch; `requested` and `enqueued` are reported separately.
    queue_id: str = Field(description="The ID of the queue")
    enqueued: int = Field(description="The total number of queue items enqueued")
    requested: int = Field(description="The total number of queue items requested to be enqueued")
    batch: Batch = Field(description="The batch that was enqueued")
    priority: int = Field(description="The priority of the enqueued batch")
|
|
|
|
|
|
class ClearResult(BaseModel):
    """Result of clearing the session queue"""

    # Required field (no default).
    deleted: int = Field(description="Number of queue items deleted")
|
|
|
|
|
|
class PruneResult(ClearResult):
    """Result of pruning the session queue"""

    # Same shape as ClearResult; the docstring alone serves as the class body.
|
|
|
|
|
|
class CancelByBatchIDsResult(BaseModel):
    """Result of canceling by list of batch ids"""

    # Required field (no default).
    canceled: int = Field(description="Number of queue items canceled")
|
|
|
|
|
|
class CancelByOriginResult(BaseModel):
    """Result of canceling by origin"""

    canceled: int = Field(..., description="Number of queue items canceled")
|
|
|
|
|
|
class CancelByQueueIDResult(CancelByBatchIDsResult):
    """Result of canceling by queue id"""

    # Same shape as CancelByBatchIDsResult; the docstring alone serves as the class body.
|
|
|
|
|
|
class IsEmptyResult(BaseModel):
    """Result of checking if the session queue is empty"""

    # Required field (no default).
    is_empty: bool = Field(description="Whether the session queue is empty")
|
|
|
|
|
|
class IsFullResult(BaseModel):
    """Result of checking if the session queue is full"""

    # Required field (no default).
    is_full: bool = Field(description="Whether the session queue is full")
|
|
|
|
|
|
# endregion Query Results
|
|
|
|
|
|
# region Util
|
|
|
|
|
|
def populate_graph(graph: Graph, node_field_values: Iterable[NodeFieldValue]) -> Graph:
    """
    Returns a deep copy of the graph with each given field value written onto its target node.

    The graph passed in is left untouched. An entry whose node lookup yields None is skipped.
    """
    populated = graph.model_copy(deep=True)
    for nfv in node_field_values:
        target = populated.get_node(nfv.node_path)
        if target is not None:
            setattr(target, nfv.field_name, nfv.value)
            # Hand the modified node back to the copy under the same path.
            populated.update_node(nfv.node_path, target)
    return populated
|
|
|
|
|
|
def create_session_nfv_tuples(
|
|
batch: Batch, maximum: int
|
|
) -> Generator[tuple[GraphExecutionState, list[NodeFieldValue], Optional[WorkflowWithoutID]], None, None]:
|
|
"""
|
|
Create all graph permutations from the given batch data and graph. Yields tuples
|
|
of the form (graph, batch_data_items) where batch_data_items is the list of BatchDataItems
|
|
that was applied to the graph.
|
|
"""
|
|
|
|
# TODO: Should this be a class method on Batch?
|
|
|
|
data: list[list[tuple[NodeFieldValue]]] = []
|
|
batch_data_collection = batch.data if batch.data is not None else []
|
|
for batch_datum_list in batch_data_collection:
|
|
# each batch_datum_list needs to be convered to NodeFieldValues and then zipped
|
|
|
|
node_field_values_to_zip: list[list[NodeFieldValue]] = []
|
|
for batch_datum in batch_datum_list:
|
|
node_field_values = [
|
|
NodeFieldValue(node_path=batch_datum.node_path, field_name=batch_datum.field_name, value=item)
|
|
for item in batch_datum.items
|
|
]
|
|
node_field_values_to_zip.append(node_field_values)
|
|
data.append(list(zip(*node_field_values_to_zip, strict=True))) # type: ignore [arg-type]
|
|
|
|
# create generator to yield session,nfv tuples
|
|
count = 0
|
|
for _ in range(batch.runs):
|
|
for d in product(*data):
|
|
if count >= maximum:
|
|
return
|
|
flat_node_field_values = list(chain.from_iterable(d))
|
|
graph = populate_graph(batch.graph, flat_node_field_values)
|
|
yield (GraphExecutionState(graph=graph), flat_node_field_values, batch.workflow)
|
|
count += 1
|
|
|
|
|
|
def calc_session_count(batch: Batch) -> int:
|
|
"""
|
|
Calculates the number of sessions that would be created by the batch, without incurring
|
|
the overhead of actually generating them. Adapted from `create_sessions().
|
|
"""
|
|
# TODO: Should this be a class method on Batch?
|
|
if not batch.data:
|
|
return batch.runs
|
|
data = []
|
|
for batch_datum_list in batch.data:
|
|
to_zip = []
|
|
for batch_datum in batch_datum_list:
|
|
batch_data_items = range(len(batch_datum.items))
|
|
to_zip.append(batch_data_items)
|
|
data.append(list(zip(*to_zip, strict=True)))
|
|
data_product = list(product(*data))
|
|
return len(data_product) * batch.runs
|
|
|
|
|
|
class SessionQueueValueToInsert(NamedTuple):
    """One row's worth of values for an insert into the session_queue table."""

    # WARNING: field order is significant - it must match the insert statement.
    queue_id: str
    session: str  # JSON text
    session_id: str
    batch_id: str
    field_values: Optional[str]  # JSON text, or None
    priority: int
    workflow: Optional[str]  # JSON text, or None
    origin: Optional[QueueItemOrigin]
|
|
|
|
|
|
# A list of rows ready to be inserted into the session_queue table.
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
|
|
|
|
|
|
def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new_queue_items: int) -> ValuesToInsert:
    """
    Builds the rows to insert for a batch: one per session generated from the batch, capped at
    `max_new_queue_items`.

    Each generated session is assigned a new ID. The session, field values and workflow are
    serialized to JSON; empty field values and a missing workflow are stored as None.
    """
    rows: ValuesToInsert = []
    for session, field_values, workflow in create_session_nfv_tuples(batch, max_new_queue_items):
        # Every queue item must reference a uniquely identified session.
        session.id = uuid_string()
        session_json = session.model_dump_json(warnings=False, exclude_none=True)
        # to_jsonable_python lets json.dumps serialize the pydantic models in these values.
        field_values_json = json.dumps(field_values, default=to_jsonable_python) if field_values else None
        workflow_json = json.dumps(workflow, default=to_jsonable_python) if workflow else None
        rows.append(
            SessionQueueValueToInsert(
                queue_id=queue_id,
                session=session_json,
                session_id=session.id,
                batch_id=batch.batch_id,
                field_values=field_values_json,
                priority=priority,
                workflow=workflow_json,
                origin=batch.origin,
            )
        )
    return rows
|
|
|
|
|
|
# endregion Util
|
|
|
|
# Force a rebuild of these two models once everything in this module has been defined.
Batch.model_rebuild(force=True)
SessionQueueItem.model_rebuild(force=True)
|