mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-24 19:18:04 -05:00
* chore: bump pydantic to 2.5.2

  This release fixes pydantic/pydantic#8175 and allows us to use `JsonValue`.

* fix(ui): exclude public/en.json from prettier config
* fix(workflow_records): fix SQLite workflow insertion to ignore duplicates
* feat(backend): update workflows handling

  Update workflows handling for Workflow Library.

  **Updated Workflow Storage**

  "Embedded Workflows" are workflows associated with images, and are now only stored in the image files. "Library Workflows" are not associated with images, and are stored only in the DB.

  This works out nicely. We have always saved workflows to files, but recently began saving them to the DB in addition to image files. When that happened, we stopped reading workflows from files, so all the workflows that only existed in images were inaccessible. With this change, access to those workflows is restored, and no workflows are lost.

  **Updated Workflow Handling in Nodes**

  Prior to this change, workflows were embedded in images by passing the whole workflow JSON to a special workflow field on a node. In the node's `invoke()` function, the node was able to access this workflow and save it with the image. This (inaccurately) models workflows as a property of an image and is rather awkward technically.

  A workflow is now a property of a batch/session queue item. It is available in the InvocationContext and therefore available to all nodes during `invoke()`.

  **Database Migrations**

  Added a `SQLiteMigrator` class to handle database migrations. Migrations were needed to accommodate the DB-related changes in this PR. See the code for details.

  The `images`, `workflows` and `session_queue` tables required migrations for this PR, and are using the new migrator. Other tables/services are still creating tables themselves. A follow-up PR will adapt them to use the migrator.

  **Other/Support Changes**

  - Add a `has_workflow` column to the `images` table to indicate that the image has an embedded workflow.
  - Add handling for retrieving the workflow from an image in Python. The image file must be fetched, the workflow extracted, and then sent to the client, avoiding the need for the browser to parse the image file. With the `has_workflow` column, the UI knows if there is a workflow to be fetched, and only fetches when the user requests to load the workflow.
  - Add a route to get the workflow from an image.
  - Add CRUD service/routes for the library workflows.
  - Remove the `workflow_images` table and services (no longer needed now that embedded workflows are not in the DB).

* feat(ui): updated workflow handling (WIP)

  Client-side updates for the backend workflow changes. Includes a roughed-out workflow library UI.

* feat: revert SQLiteMigrator class

  Will pursue this in a separate PR.

* feat(nodes): do not overwrite custom node module names

  Use a different, simpler method to detect if a node is custom.

* feat(nodes): restore WithWorkflow as no-op class

  This class is deprecated and no longer needed. Set its workflow attr value to None (meaning it is now a no-op), and issue a warning when an invocation subclasses it.

* fix(nodes): fix get_workflow from queue item dict func
* feat(backend): add WorkflowRecordListItemDTO

  This is the id, name, description, created at and updated at workflow columns/attrs. Used to display lists of workflows.

* chore(ui): typegen
* feat(ui): add workflow loading, deleting to workflow library UI
* feat(ui): workflow library pagination button styles
* wip
* feat: workflow library WIP

  - Save to library
  - Duplicate
  - Filter/sort
  - UI/queries

* feat: workflow library - system graphs - wip
* feat(backend): sync system workflows to db
* fix: merge conflicts
* feat: simplify default workflows

  - Rename "system" -> "default"
  - Simplify syncing logic
  - Update UI to match

* feat(workflows): update default workflows

  - Update TextToImage_SD15
  - Add TextToImage_SDXL
  - Add README

* feat(ui): refine workflow list UI
* fix(workflow_records): typo
* fix(tests): fix tests
* feat(ui): clean up workflow library hooks
* fix(db): fix mis-ordered db cleanup step

  It was happening before pruning queue items - it should happen afterwards, else you have to restart the app again to free the disk space made available by the pruning.

* feat(ui): tweak reset workflow editor translations
* feat(ui): split out workflow redux state

  The `nodes` slice is a rather complicated slice. Removing `workflow` makes it a bit more reasonable. Also helps to flatten state out a bit.

* docs: update default workflows README
* fix: tidy up unused files, unrelated changes
* fix(backend): revert unrelated service organisational changes
* feat(backend): workflow_records.get_many arg "filter_text" -> "query"
* feat(ui): use custom hook in current image buttons

  Already in use elsewhere, forgot to use it here.

* fix(ui): remove commented-out property
* fix(ui): fix workflow loading

  - Different handling for loading from library vs external
  - Fix bug where only nodes and edges loaded

* fix(ui): fix save/save-as workflow naming
* fix(ui): fix circular dependency
* fix(db): fix bug with releasing without lock in db.clean()
* fix(db): remove extraneous lock
* chore: bump ruff
* fix(workflow_records): default `category` to `WorkflowCategory.User`

  This allows old workflows to validate when reading them from the db or image files.

* hide workflow library buttons if feature is disabled

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
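As a quick illustration of the queue-item-centric workflow handling described above, here is a minimal sketch (hypothetical row values; `get_workflow` is defined in the module below, and a real workflow JSON string has more fields than a name):

    # hypothetical raw session_queue row, as a dict
    queue_item_dict = {"workflow": None}          # row with no embedded workflow
    assert get_workflow(queue_item_dict) is None  # a JSON string here would be
                                                  # parsed into a WorkflowWithoutID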
450 lines
17 KiB
Python
import datetime
import json
from itertools import chain, product
from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast

from pydantic import BaseModel, ConfigDict, Field, StrictStr, TypeAdapter, field_validator, model_validator
from pydantic_core import to_jsonable_python

from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError
from invokeai.app.services.workflow_records.workflow_records_common import (
    WorkflowWithoutID,
    WorkflowWithoutIDValidator,
)
from invokeai.app.util.misc import uuid_string

# region Errors


class BatchZippedLengthError(ValueError):
    """Raise when a batch has items of different lengths."""


class BatchItemsTypeError(ValueError):  # this cannot be a TypeError in pydantic v2
    """Raise when a batch has items of different types."""


class BatchDuplicateNodeFieldError(ValueError):
    """Raise when a batch has duplicate node_path and field_name."""


class TooManySessionsError(ValueError):
    """Raise when too many sessions are requested."""


class SessionQueueItemNotFoundError(ValueError):
    """Raise when a queue item is not found."""


# endregion

# region Batch

BatchDataType = Union[
    StrictStr,
    float,
    int,
]


class NodeFieldValue(BaseModel):
    node_path: str = Field(description="The node into which this batch data item will be substituted.")
    field_name: str = Field(description="The field into which this batch data item will be substituted.")
    value: BatchDataType = Field(description="The value to substitute into the node/field.")


class BatchDatum(BaseModel):
    node_path: str = Field(description="The node into which this batch data collection will be substituted.")
    field_name: str = Field(description="The field into which this batch data collection will be substituted.")
    items: list[BatchDataType] = Field(
        default_factory=list, description="The list of items to substitute into the node/field."
    )


BatchDataCollection: TypeAlias = list[list[BatchDatum]]


class Batch(BaseModel):
    batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
    data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
    graph: Graph = Field(description="The graph to initialize the session with")
    workflow: Optional[WorkflowWithoutID] = Field(
        default=None, description="The workflow to initialize the session with"
    )
    runs: int = Field(
        default=1, ge=1, description="Int stating how many times to iterate through all possible batch indices"
    )

    @field_validator("data")
    def validate_lengths(cls, v: Optional[BatchDataCollection]):
        if v is None:
            return v
        for batch_data_list in v:
            first_item_length = len(batch_data_list[0].items) if batch_data_list and batch_data_list[0].items else 0
            for i in batch_data_list:
                if len(i.items) != first_item_length:
                    raise BatchZippedLengthError("Zipped batch items must all have the same length")
        return v

    @field_validator("data")
    def validate_types(cls, v: Optional[BatchDataCollection]):
        if v is None:
            return v
        for batch_data_list in v:
            for datum in batch_data_list:
                # Get the type of the first item in the list
                first_item_type = type(datum.items[0]) if datum.items else None
                for item in datum.items:
                    if type(item) is not first_item_type:
                        raise BatchItemsTypeError("All items in a batch must have the same type")
        return v

    @field_validator("data")
    def validate_unique_field_mappings(cls, v: Optional[BatchDataCollection]):
        if v is None:
            return v
        paths: set[tuple[str, str]] = set()
        for batch_data_list in v:
            for datum in batch_data_list:
                pair = (datum.node_path, datum.field_name)
                if pair in paths:
                    raise BatchDuplicateNodeFieldError("Each batch data must have unique node_path and field_name")
                paths.add(pair)
        return v

    @model_validator(mode="after")
    def validate_batch_nodes_and_edges(cls, values):
        batch_data_collection = cast(Optional[BatchDataCollection], values.data)
        if batch_data_collection is None:
            return values
        graph = cast(Graph, values.graph)
        for batch_data_list in batch_data_collection:
            for batch_data in batch_data_list:
                try:
                    node = cast(BaseInvocation, graph.get_node(batch_data.node_path))
                except NodeNotFoundError:
                    raise NodeNotFoundError(f"Node {batch_data.node_path} not found in graph")
                if batch_data.field_name not in node.model_fields:
                    raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}")
        return values

    @field_validator("graph")
    def validate_graph(cls, v: Graph):
        v.validate_self()
        return v

    model_config = ConfigDict(
        json_schema_extra={
            "required": [
                "graph",
                "runs",
            ]
        }
    )


# endregion Batch
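
# Illustrative sketch (not executed; `my_graph` is a hypothetical Graph with
# "prompt" and "noise" nodes): within each inner list of `data`, BatchDatum
# items are zipped together, and the inner lists are combined as a cartesian
# product, so this batch expands to 2 (zipped prompt/style pairs) x 2 (seeds)
# x 1 (runs) = 4 sessions:
#
#     batch = Batch(
#         graph=my_graph,
#         data=[
#             [
#                 BatchDatum(node_path="prompt", field_name="prompt", items=["cat", "dog"]),
#                 BatchDatum(node_path="prompt", field_name="style", items=["photo", "anime"]),
#             ],
#             [BatchDatum(node_path="noise", field_name="seed", items=[1, 2])],
#         ],
#         runs=1,
#     )
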
# region Queue Items

DEFAULT_QUEUE_ID = "default"

QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"]

NodeFieldValueValidator = TypeAdapter(list[NodeFieldValue])


def get_field_values(queue_item_dict: dict) -> Optional[list[NodeFieldValue]]:
    field_values_raw = queue_item_dict.get("field_values", None)
    return NodeFieldValueValidator.validate_json(field_values_raw) if field_values_raw is not None else None


GraphExecutionStateValidator = TypeAdapter(GraphExecutionState)


def get_session(queue_item_dict: dict) -> GraphExecutionState:
    session_raw = queue_item_dict.get("session", "{}")
    session = GraphExecutionStateValidator.validate_json(session_raw, strict=False)
    return session


def get_workflow(queue_item_dict: dict) -> Optional[WorkflowWithoutID]:
    workflow_raw = queue_item_dict.get("workflow", None)
    if workflow_raw is not None:
        workflow = WorkflowWithoutIDValidator.validate_json(workflow_raw, strict=False)
        return workflow
    return None


class SessionQueueItemWithoutGraph(BaseModel):
    """Session queue item without the full graph. Used for serialization."""

    item_id: int = Field(description="The identifier of the session queue item")
    status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
    priority: int = Field(default=0, description="The priority of this queue item")
    batch_id: str = Field(description="The ID of the batch associated with this queue item")
    session_id: str = Field(
        description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
    )
    error: Optional[str] = Field(default=None, description="The error message if this queue item errored")
    created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created")
    updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated")
    started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started")
    completed_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was completed")
    queue_id: str = Field(description="The id of the queue with which this item is associated")
    field_values: Optional[list[NodeFieldValue]] = Field(
        default=None, description="The field values that were used for this queue item"
    )

    @classmethod
    def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
        # must parse these manually
        queue_item_dict["field_values"] = get_field_values(queue_item_dict)
        return SessionQueueItemDTO(**queue_item_dict)

    model_config = ConfigDict(
        json_schema_extra={
            "required": [
                "item_id",
                "status",
                "batch_id",
                "queue_id",
                "session_id",
                "priority",
                "created_at",
                "updated_at",
            ]
        }
    )


class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
    pass


class SessionQueueItem(SessionQueueItemWithoutGraph):
    session: GraphExecutionState = Field(description="The fully-populated session to be executed")
    workflow: Optional[WorkflowWithoutID] = Field(
        default=None, description="The workflow associated with this queue item"
    )

    @classmethod
    def queue_item_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItem":
        # must parse these manually
        queue_item_dict["field_values"] = get_field_values(queue_item_dict)
        queue_item_dict["session"] = get_session(queue_item_dict)
        queue_item_dict["workflow"] = get_workflow(queue_item_dict)
        return SessionQueueItem(**queue_item_dict)

    model_config = ConfigDict(
        json_schema_extra={
            "required": [
                "item_id",
                "status",
                "batch_id",
                "queue_id",
                "session_id",
                "session",
                "priority",
                "created_at",
                "updated_at",
            ]
        }
    )
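
# Illustrative sketch (hypothetical cursor/row): hydrating a queue item from a
# raw session_queue row, where session, field_values and workflow are stored as
# JSON strings and parsed manually by the helpers above:
#
#     row = dict(cursor.fetchone())  # sqlite3.Row -> dict
#     queue_item = SessionQueueItem.queue_item_from_dict(row)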


# endregion Queue Items

# region Query Results


class SessionQueueStatus(BaseModel):
    queue_id: str = Field(..., description="The ID of the queue")
    item_id: Optional[int] = Field(description="The current queue item id")
    batch_id: Optional[str] = Field(description="The current queue item's batch id")
    session_id: Optional[str] = Field(description="The current queue item's session id")
    pending: int = Field(..., description="Number of queue items with status 'pending'")
    in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
    completed: int = Field(..., description="Number of queue items with status 'completed'")
    failed: int = Field(..., description="Number of queue items with status 'failed'")
    canceled: int = Field(..., description="Number of queue items with status 'canceled'")
    total: int = Field(..., description="Total number of queue items")


class BatchStatus(BaseModel):
    queue_id: str = Field(..., description="The ID of the queue")
    batch_id: str = Field(..., description="The ID of the batch")
    pending: int = Field(..., description="Number of queue items with status 'pending'")
    in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
    completed: int = Field(..., description="Number of queue items with status 'completed'")
    failed: int = Field(..., description="Number of queue items with status 'failed'")
    canceled: int = Field(..., description="Number of queue items with status 'canceled'")
    total: int = Field(..., description="Total number of queue items")


class EnqueueBatchResult(BaseModel):
    queue_id: str = Field(description="The ID of the queue")
    enqueued: int = Field(description="The total number of queue items enqueued")
    requested: int = Field(description="The total number of queue items requested to be enqueued")
    batch: Batch = Field(description="The batch that was enqueued")
    priority: int = Field(description="The priority of the enqueued batch")


class ClearResult(BaseModel):
    """Result of clearing the session queue"""

    deleted: int = Field(..., description="Number of queue items deleted")


class PruneResult(ClearResult):
    """Result of pruning the session queue"""

    pass


class CancelByBatchIDsResult(BaseModel):
    """Result of canceling by list of batch ids"""

    canceled: int = Field(..., description="Number of queue items canceled")


class CancelByQueueIDResult(CancelByBatchIDsResult):
    """Result of canceling by queue id"""

    pass


class IsEmptyResult(BaseModel):
    """Result of checking if the session queue is empty"""

    is_empty: bool = Field(..., description="Whether the session queue is empty")


class IsFullResult(BaseModel):
    """Result of checking if the session queue is full"""

    is_full: bool = Field(..., description="Whether the session queue is full")


# endregion Query Results

# region Util


def populate_graph(graph: Graph, node_field_values: Iterable[NodeFieldValue]) -> Graph:
    """
    Populates the given graph with the given batch data items.
    """
    graph_clone = graph.model_copy(deep=True)
    for item in node_field_values:
        node = graph_clone.get_node(item.node_path)
        if node is None:
            continue
        setattr(node, item.field_name, item.value)
        graph_clone.update_node(item.node_path, node)
    return graph_clone
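
# Illustrative sketch (hypothetical node/field names): substituting one value
# into a deep copy of a graph; the original graph is not mutated.
#
#     nfv = NodeFieldValue(node_path="prompt", field_name="prompt", value="a cat")
#     new_graph = populate_graph(graph, [nfv])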


def create_session_nfv_tuples(
    batch: Batch, maximum: int
) -> Generator[tuple[GraphExecutionState, list[NodeFieldValue], Optional[WorkflowWithoutID]], None, None]:
    """
    Create all graph permutations from the given batch data and graph. Yields tuples
    of the form (session, field_values, workflow), where field_values is the list of
    NodeFieldValues that was applied to the session's graph.
    """

    # TODO: Should this be a class method on Batch?

    data: list[list[tuple[NodeFieldValue, ...]]] = []
    batch_data_collection = batch.data if batch.data is not None else []
    for batch_datum_list in batch_data_collection:
        # each batch_datum_list needs to be converted to NodeFieldValues and then zipped
        node_field_values_to_zip: list[list[NodeFieldValue]] = []
        for batch_datum in batch_datum_list:
            node_field_values = [
                NodeFieldValue(node_path=batch_datum.node_path, field_name=batch_datum.field_name, value=item)
                for item in batch_datum.items
            ]
            node_field_values_to_zip.append(node_field_values)
        data.append(list(zip(*node_field_values_to_zip, strict=True)))  # type: ignore [arg-type]

    # create generator to yield session,nfv tuples
    count = 0
    for _ in range(batch.runs):
        for d in product(*data):
            if count >= maximum:
                return
            flat_node_field_values = list(chain.from_iterable(d))
            graph = populate_graph(batch.graph, flat_node_field_values)
            yield (GraphExecutionState(graph=graph), flat_node_field_values, batch.workflow)
            count += 1
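
# Illustrative sketch: each yielded tuple is (session, field_values, workflow).
# For the example batch sketched after the Batch region above (2 zipped pairs x
# 2 seeds, runs=1), iteration yields 4 tuples, capped by `maximum`:
#
#     for session, field_values, workflow in create_session_nfv_tuples(batch, maximum=100):
#         ...  # session is a fresh GraphExecutionState with the values applied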


def calc_session_count(batch: Batch) -> int:
    """
    Calculates the number of sessions that would be created by the batch, without incurring
    the overhead of actually generating them. Adapted from `create_session_nfv_tuples()`.
    """
    # TODO: Should this be a class method on Batch?
    if not batch.data:
        return batch.runs
    data = []
    for batch_datum_list in batch.data:
        to_zip = []
        for batch_datum in batch_datum_list:
            batch_data_items = range(len(batch_datum.items))
            to_zip.append(batch_data_items)
        data.append(list(zip(*to_zip, strict=True)))
    data_product = list(product(*data))
    return len(data_product) * batch.runs
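
# Worked example (illustrative): for a batch with one zipped list of length 2,
# a second list of length 3, and runs=2, the count is
# len(product) * runs = (2 * 3) * 2 = 12, without materializing any sessions.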


class SessionQueueValueToInsert(NamedTuple):
    """A tuple of values to insert into the session_queue table"""

    # Careful with the ordering of this - it must match the insert statement
    queue_id: str  # queue_id
    session: str  # session json
    session_id: str  # session_id
    batch_id: str  # batch_id
    field_values: Optional[str]  # field_values json
    priority: int  # priority
    workflow: Optional[str]  # workflow json


ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
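
# The field order above must line up positionally with the insert statement's
# columns, e.g. (hypothetical SQL, for illustration only):
#
#     INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow)
#     VALUES (?, ?, ?, ?, ?, ?, ?)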


def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new_queue_items: int) -> ValuesToInsert:
    values_to_insert: ValuesToInsert = []
    for session, field_values, workflow in create_session_nfv_tuples(batch, max_new_queue_items):
        # sessions must have unique id
        session.id = uuid_string()
        values_to_insert.append(
            SessionQueueValueToInsert(
                queue_id,  # queue_id
                session.model_dump_json(warnings=False, exclude_none=True),  # session (json)
                session.id,  # session_id
                batch.batch_id,  # batch_id
                # must use a pydantic-aware encoder because field_values is a list of models
                json.dumps(field_values, default=to_jsonable_python) if field_values else None,  # field_values (json)
                priority,  # priority
                json.dumps(workflow, default=to_jsonable_python) if workflow else None,  # workflow (json)
            )
        )
    return values_to_insert


# endregion Util

Batch.model_rebuild(force=True)
SessionQueueItem.model_rebuild(force=True)