Upgrade pydantic and fastapi to latest.
- pydantic~=2.4.2
- fastapi~=0.103.2
- fastapi-events~=0.9.1
**Big Changes**
There are a number of logic changes needed to support pydantic v2. Most changes are very simple, like using the new methods to serialize and deserialize models, but a few are more complex.
**Invocations**
The biggest change relates to invocation creation, instantiation and validation.
Because pydantic v2 moves all validation logic into the Rust `pydantic-core` package, we may no longer directly stick our fingers into the validation pie.
Previously, we (ab)used models and fields to allow invocation fields to be optional at instantiation, but required when `invoke()` is called. We directly manipulated the fields and invocation models when calling `invoke()`.
With pydantic v2, this is much more involved. Changes to the Python wrapper do not propagate down to the Rust validation logic - you have to rebuild the model. This causes problems with concurrent access to the invocation classes, and rebuilding is not a free operation.
This logic has been totally refactored so that we no longer need to change the model at all. The details are in `baseinvocation.py`, in the `InputField` function and the `BaseInvocation.invoke_internal()` method.
In the end, this implementation is cleaner.
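To illustrate the pattern (a minimal sketch with hypothetical names - the real logic in `baseinvocation.py` handles many more cases): fields that are required at invoke-time are declared optional, tagged via `json_schema_extra`, and checked in `invoke_internal()` without ever mutating the model.
```py
from typing import Any, Optional

from pydantic import BaseModel, Field


class ExampleInvocation(BaseModel):
    # Optional at instantiation, but flagged as required once invoke() is called
    image_name: Optional[str] = Field(default=None, json_schema_extra={"orig_required": True})

    def invoke(self) -> Any:
        ...

    def invoke_internal(self) -> Any:
        # Check "required at invoke-time" fields without rebuilding the model
        for name, field in self.model_fields.items():
            extra = field.json_schema_extra
            if isinstance(extra, dict) and extra.get("orig_required") and getattr(self, name) is None:
                raise ValueError(f"Missing required field: {name}")
        return self.invoke()
```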
**Invocation Fields**
In pydantic v2, you can no longer directly add or remove fields from a model.
Previously, we did this to add the `type` field to invocations.
**Invocation Decorators**
With pydantic v2, we instead use the imperative `create_model()` API to create a new model with the additional field. This is done in `baseinvocation.py` in the `invocation()` wrapper.
A similar technique is used for `invocation_output()`.
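As a rough sketch of the technique (simplified - the `type` value here is illustrative, standing in for whatever the decorator receives):
```py
from typing import Literal

from pydantic import BaseModel, Field, create_model


class ResizeInvocation(BaseModel):
    width: int = Field(default=512)


# Clone the model, adding a literal `type` field - conceptually what the
# `invocation()` decorator does, since fields can't be added in-place anymore
ResizeInvocationWithType = create_model(
    "ResizeInvocation",
    __base__=ResizeInvocation,
    type=(Literal["resize"], "resize"),
)

assert ResizeInvocationWithType().type == "resize"
```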
**Minor Changes**
There are a number of minor changes around the pydantic v2 models API.
**Protected `model_` Namespace**
All models' pydantic-provided methods and attributes are prefixed with `model_` and this is considered a protected namespace. This causes some conflict, because "model" means something to us, and we have a ton of pydantic models with attributes starting with "model_".
Fortunately, there are no direct conflicts. However, in any pydantic model where we define an attribute or method that starts with "model_", we must set the protected namespaces to an empty tuple.
```py
class IPAdapterModelField(BaseModel):
    model_name: str = Field(description="Name of the IP-Adapter model")
    base_model: BaseModelType = Field(description="Base model")

    model_config = ConfigDict(protected_namespaces=())
```
**Model Serialization**
Pydantic models no longer have `Model.dict()` or `Model.json()`.
Instead, we use `Model.model_dump()` or `Model.model_dump_json()`.
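A minimal before/after on a toy model:
```py
from pydantic import BaseModel


class Point(BaseModel):
    x: int
    y: int


p = Point(x=1, y=2)
assert p.model_dump() == {"x": 1, "y": 2}  # was p.dict()
assert p.model_dump_json() == '{"x":1,"y":2}'  # was p.json()
```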
**Model Deserialization**
Pydantic models no longer have `Model.parse_obj()` or `Model.parse_raw()`, and there are no `parse_raw_as()` or `parse_obj_as()` functions.
Instead, you need to create a `TypeAdapter` object to parse python objects or JSON into a model.
```py
adapter_graph = TypeAdapter(Graph)
deserialized_graph_from_json = adapter_graph.validate_json(graph_json)
deserialized_graph_from_dict = adapter_graph.validate_python(graph_dict)
```
**Field Customisation**
Pydantic `Field`s no longer accept arbitrary args.
Now, you must put all additional arbitrary args in a `json_schema_extra` arg on the field.
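For example (`ui_order` here is just an illustrative extra, not a real constraint):
```py
from pydantic import BaseModel, Field


class ExampleModel(BaseModel):
    # pydantic v1 allowed: Field(description="...", ui_order=1)
    # pydantic v2 requires extras to live in json_schema_extra:
    value: int = Field(default=0, description="A value", json_schema_extra={"ui_order": 1})
```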
**Schema Customisation**
FastAPI and pydantic schema generation now follows the OpenAPI version 3.1 spec.
This necessitates two changes:
- Our schema customization logic has been revised
- Schema parsing to build node templates has been revised
The specifics aren't important, but this does present additional surface area for bugs.
**Performance Improvements**
Pydantic v2 is a full rewrite with a Rust backend. This offers a substantial performance improvement (pydantic claims 5x to 50x, depending on the task). We'll notice this the most during serialization and deserialization of sessions/graphs, which happens very, very often - a couple of times per node.
I haven't done any benchmarks, but anecdotally, graph execution is much faster. Also, very large graphs - like those with massive iterators - are much, much faster.
import datetime
import json
from itertools import chain, product
from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast

from pydantic import BaseModel, ConfigDict, Field, StrictStr, TypeAdapter, field_validator, model_validator
from pydantic_core import to_jsonable_python

from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError
from invokeai.app.util.misc import uuid_string

# region Errors


class BatchZippedLengthError(ValueError):
    """Raise when a batch has items of different lengths."""


class BatchItemsTypeError(ValueError):  # this cannot be a TypeError in pydantic v2
    """Raise when a batch has items of different types."""


class BatchDuplicateNodeFieldError(ValueError):
    """Raise when a batch has duplicate node_path and field_name."""


class TooManySessionsError(ValueError):
    """Raise when too many sessions are requested."""


class SessionQueueItemNotFoundError(ValueError):
    """Raise when a queue item is not found."""


# endregion


# region Batch

BatchDataType = Union[
    StrictStr,
    float,
    int,
]


class NodeFieldValue(BaseModel):
    node_path: str = Field(description="The node into which this batch data item will be substituted.")
    field_name: str = Field(description="The field into which this batch data item will be substituted.")
    value: BatchDataType = Field(description="The value to substitute into the node/field.")


class BatchDatum(BaseModel):
    node_path: str = Field(description="The node into which this batch data collection will be substituted.")
    field_name: str = Field(description="The field into which this batch data collection will be substituted.")
    items: list[BatchDataType] = Field(
        default_factory=list, description="The list of items to substitute into the node/field."
    )


BatchDataCollection: TypeAlias = list[list[BatchDatum]]


class Batch(BaseModel):
    batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
    data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
    graph: Graph = Field(description="The graph to initialize the session with")
    runs: int = Field(
        default=1, ge=1, description="Int stating how many times to iterate through all possible batch indices"
    )

    @field_validator("data")
    def validate_lengths(cls, v: Optional[BatchDataCollection]):
        if v is None:
            return v
        for batch_data_list in v:
            first_item_length = len(batch_data_list[0].items) if batch_data_list and batch_data_list[0].items else 0
            for i in batch_data_list:
                if len(i.items) != first_item_length:
                    raise BatchZippedLengthError("Zipped batch items must all have the same length")
        return v

    @field_validator("data")
    def validate_types(cls, v: Optional[BatchDataCollection]):
        if v is None:
            return v
        for batch_data_list in v:
            for datum in batch_data_list:
                # Get the type of the first item in the list
                first_item_type = type(datum.items[0]) if datum.items else None
                for item in datum.items:
                    if type(item) is not first_item_type:
                        raise BatchItemsTypeError("All items in a batch must have the same type")
        return v

    @field_validator("data")
    def validate_unique_field_mappings(cls, v: Optional[BatchDataCollection]):
        if v is None:
            return v
        paths: set[tuple[str, str]] = set()
        for batch_data_list in v:
            for datum in batch_data_list:
                pair = (datum.node_path, datum.field_name)
                if pair in paths:
                    raise BatchDuplicateNodeFieldError("Each batch data must have unique node_id and field_name")
                paths.add(pair)
        return v

    @model_validator(mode="after")
    def validate_batch_nodes_and_edges(cls, values):
        batch_data_collection = cast(Optional[BatchDataCollection], values.data)
        if batch_data_collection is None:
            return values
        graph = cast(Graph, values.graph)
        for batch_data_list in batch_data_collection:
            for batch_data in batch_data_list:
                try:
                    node = cast(BaseInvocation, graph.get_node(batch_data.node_path))
                except NodeNotFoundError:
                    raise NodeNotFoundError(f"Node {batch_data.node_path} not found in graph")
                if batch_data.field_name not in node.model_fields:
                    raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}")
        return values

    @field_validator("graph")
    def validate_graph(cls, v: Graph):
        v.validate_self()
        return v

    model_config = ConfigDict(
        json_schema_extra=dict(
            required=[
                "graph",
                "runs",
            ]
        )
    )


# endregion Batch


# region Queue Items

DEFAULT_QUEUE_ID = "default"

QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"]

adapter_NodeFieldValue = TypeAdapter(list[NodeFieldValue])


def get_field_values(queue_item_dict: dict) -> Optional[list[NodeFieldValue]]:
    field_values_raw = queue_item_dict.get("field_values", None)
    return adapter_NodeFieldValue.validate_json(field_values_raw) if field_values_raw is not None else None


adapter_GraphExecutionState = TypeAdapter(GraphExecutionState)


def get_session(queue_item_dict: dict) -> GraphExecutionState:
    session_raw = queue_item_dict.get("session", "{}")
    session = adapter_GraphExecutionState.validate_json(session_raw, strict=False)
    return session


class SessionQueueItemWithoutGraph(BaseModel):
    """Session queue item without the full graph. Used for serialization."""

    item_id: int = Field(description="The identifier of the session queue item")
    status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
    priority: int = Field(default=0, description="The priority of this queue item")
    batch_id: str = Field(description="The ID of the batch associated with this queue item")
    session_id: str = Field(
        description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
    )
    error: Optional[str] = Field(default=None, description="The error message if this queue item errored")
    created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created")
    updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated")
    started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started")
    completed_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was completed")
    queue_id: str = Field(description="The id of the queue with which this item is associated")
    field_values: Optional[list[NodeFieldValue]] = Field(
        default=None, description="The field values that were used for this queue item"
    )

    @classmethod
    def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
        # must parse these manually
        queue_item_dict["field_values"] = get_field_values(queue_item_dict)
        return SessionQueueItemDTO(**queue_item_dict)

    model_config = ConfigDict(
        json_schema_extra=dict(
            required=[
                "item_id",
                "status",
                "batch_id",
                "queue_id",
                "session_id",
                "priority",
                "created_at",
                "updated_at",
            ]
        )
    )


class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
    pass


class SessionQueueItem(SessionQueueItemWithoutGraph):
    session: GraphExecutionState = Field(description="The fully-populated session to be executed")

    @classmethod
    def queue_item_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItem":
        # must parse these manually
        queue_item_dict["field_values"] = get_field_values(queue_item_dict)
        queue_item_dict["session"] = get_session(queue_item_dict)
        return SessionQueueItem(**queue_item_dict)

    model_config = ConfigDict(
        json_schema_extra=dict(
            required=[
                "item_id",
                "status",
                "batch_id",
                "queue_id",
                "session_id",
                "session",
                "priority",
                "created_at",
                "updated_at",
            ]
        )
    )


# endregion Queue Items

# region Query Results


class SessionQueueStatus(BaseModel):
    queue_id: str = Field(..., description="The ID of the queue")
    item_id: Optional[int] = Field(description="The current queue item id")
    batch_id: Optional[str] = Field(description="The current queue item's batch id")
    session_id: Optional[str] = Field(description="The current queue item's session id")
    pending: int = Field(..., description="Number of queue items with status 'pending'")
    in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
    completed: int = Field(..., description="Number of queue items with status 'complete'")
    failed: int = Field(..., description="Number of queue items with status 'error'")
    canceled: int = Field(..., description="Number of queue items with status 'canceled'")
    total: int = Field(..., description="Total number of queue items")


class BatchStatus(BaseModel):
    queue_id: str = Field(..., description="The ID of the queue")
    batch_id: str = Field(..., description="The ID of the batch")
    pending: int = Field(..., description="Number of queue items with status 'pending'")
    in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
    completed: int = Field(..., description="Number of queue items with status 'complete'")
    failed: int = Field(..., description="Number of queue items with status 'error'")
    canceled: int = Field(..., description="Number of queue items with status 'canceled'")
    total: int = Field(..., description="Total number of queue items")


class EnqueueBatchResult(BaseModel):
    queue_id: str = Field(description="The ID of the queue")
    enqueued: int = Field(description="The total number of queue items enqueued")
    requested: int = Field(description="The total number of queue items requested to be enqueued")
    batch: Batch = Field(description="The batch that was enqueued")
    priority: int = Field(description="The priority of the enqueued batch")


class EnqueueGraphResult(BaseModel):
    enqueued: int = Field(description="The total number of queue items enqueued")
    requested: int = Field(description="The total number of queue items requested to be enqueued")
    batch: Batch = Field(description="The batch that was enqueued")
    priority: int = Field(description="The priority of the enqueued batch")
    queue_item: SessionQueueItemDTO = Field(description="The queue item that was enqueued")


class ClearResult(BaseModel):
    """Result of clearing the session queue"""

    deleted: int = Field(..., description="Number of queue items deleted")


class PruneResult(ClearResult):
    """Result of pruning the session queue"""

    pass


class CancelByBatchIDsResult(BaseModel):
    """Result of canceling by list of batch ids"""

    canceled: int = Field(..., description="Number of queue items canceled")


class CancelByQueueIDResult(CancelByBatchIDsResult):
    """Result of canceling by queue id"""

    pass


class IsEmptyResult(BaseModel):
    """Result of checking if the session queue is empty"""

    is_empty: bool = Field(..., description="Whether the session queue is empty")


class IsFullResult(BaseModel):
    """Result of checking if the session queue is full"""

    is_full: bool = Field(..., description="Whether the session queue is full")


# endregion Query Results


# region Util


def populate_graph(graph: Graph, node_field_values: Iterable[NodeFieldValue]) -> Graph:
    """
    Populates the given graph with the given batch data items.
    """
    graph_clone = graph.model_copy(deep=True)
    for item in node_field_values:
        node = graph_clone.get_node(item.node_path)
        if node is None:
            continue
        setattr(node, item.field_name, item.value)
        graph_clone.update_node(item.node_path, node)
    return graph_clone


def create_session_nfv_tuples(
    batch: Batch, maximum: int
) -> Generator[tuple[GraphExecutionState, list[NodeFieldValue]], None, None]:
    """
    Create all graph permutations from the given batch data and graph. Yields tuples
    of the form (graph, batch_data_items) where batch_data_items is the list of NodeFieldValues
    that was applied to the graph.
    """

    # TODO: Should this be a class method on Batch?

    data: list[list[tuple[NodeFieldValue]]] = []
    batch_data_collection = batch.data if batch.data is not None else []
    for batch_datum_list in batch_data_collection:
        # each batch_datum_list needs to be converted to NodeFieldValues and then zipped

        node_field_values_to_zip: list[list[NodeFieldValue]] = []
        for batch_datum in batch_datum_list:
            node_field_values = [
                NodeFieldValue(node_path=batch_datum.node_path, field_name=batch_datum.field_name, value=item)
                for item in batch_datum.items
            ]
            node_field_values_to_zip.append(node_field_values)
        data.append(list(zip(*node_field_values_to_zip)))  # type: ignore [arg-type]

    # create generator to yield session, nfv tuples
    count = 0
    for _ in range(batch.runs):
        for d in product(*data):
            if count >= maximum:
                return
            flat_node_field_values = list(chain.from_iterable(d))
            graph = populate_graph(batch.graph, flat_node_field_values)
            yield (GraphExecutionState(graph=graph), flat_node_field_values)
            count += 1


def calc_session_count(batch: Batch) -> int:
    """
    Calculates the number of sessions that would be created by the batch, without incurring
    the overhead of actually generating them. Adapted from `create_session_nfv_tuples()`.
    """
    # TODO: Should this be a class method on Batch?
    if not batch.data:
        return batch.runs
    data = []
    for batch_datum_list in batch.data:
        to_zip = []
        for batch_datum in batch_datum_list:
            batch_data_items = range(len(batch_datum.items))
            to_zip.append(batch_data_items)
        data.append(list(zip(*to_zip)))
    data_product = list(product(*data))
    return len(data_product) * batch.runs


class SessionQueueValueToInsert(NamedTuple):
    """A tuple of values to insert into the session_queue table"""

    queue_id: str  # queue_id
    session: str  # session json
    session_id: str  # session_id
    batch_id: str  # batch_id
    field_values: Optional[str]  # field_values json
    priority: int  # priority


ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]


def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new_queue_items: int) -> ValuesToInsert:
    values_to_insert: ValuesToInsert = []
    for session, field_values in create_session_nfv_tuples(batch, max_new_queue_items):
        # sessions must have unique id
        session.id = uuid_string()
        values_to_insert.append(
            SessionQueueValueToInsert(
                queue_id,  # queue_id
                session.model_dump_json(warnings=False, exclude_none=True),  # session (json)
                session.id,  # session_id
                batch.batch_id,  # batch_id
                # must use to_jsonable_python bc field_values is a list of models
                json.dumps(field_values, default=to_jsonable_python) if field_values else None,  # field_values (json)
                priority,  # priority
            )
        )
    return values_to_insert


# endregion Util

Batch.model_rebuild(force=True)
SessionQueueItem.model_rebuild(force=True)