mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
* Add chained collect node * test(frontend): align parseSchema fixtures with collect v1.1 and normalize undefined fields in assertions * fix(nodes): block collect-to-collect links when inferred item types differ --------- Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
1384 lines
57 KiB
Python
1384 lines
57 KiB
Python
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
|
|
|
import copy
|
|
import itertools
|
|
from collections import deque
|
|
from typing import Any, Deque, Iterable, Optional, Type, TypeVar, Union, get_args, get_origin
|
|
|
|
import networkx as nx
|
|
from pydantic import (
|
|
BaseModel,
|
|
ConfigDict,
|
|
GetCoreSchemaHandler,
|
|
GetJsonSchemaHandler,
|
|
PrivateAttr,
|
|
ValidationError,
|
|
field_validator,
|
|
)
|
|
from pydantic.fields import Field
|
|
from pydantic.json_schema import JsonSchemaValue
|
|
from pydantic_core import core_schema
|
|
|
|
# Importing * is bad karma but needed here for node detection
|
|
from invokeai.app.invocations import * # noqa: F401 F403
|
|
from invokeai.app.invocations.baseinvocation import (
|
|
BaseInvocation,
|
|
BaseInvocationOutput,
|
|
InvocationRegistry,
|
|
invocation,
|
|
invocation_output,
|
|
)
|
|
from invokeai.app.invocations.fields import Input, InputField, OutputField, UIType
|
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
|
from invokeai.app.util.misc import uuid_string
|
|
|
|
# in 3.10 this would be "from types import NoneType"
|
|
NoneType = type(None)
|
|
|
|
# Port name constants
|
|
ITEM_FIELD = "item"
|
|
COLLECTION_FIELD = "collection"
|
|
|
|
|
|
class EdgeConnection(BaseModel):
|
|
node_id: str = Field(description="The id of the node for this edge connection")
|
|
field: str = Field(description="The field for this connection")
|
|
|
|
def __eq__(self, other):
|
|
return (
|
|
isinstance(other, self.__class__)
|
|
and getattr(other, "node_id", None) == self.node_id
|
|
and getattr(other, "field", None) == self.field
|
|
)
|
|
|
|
def __hash__(self):
|
|
return hash(f"{self.node_id}.{self.field}")
|
|
|
|
|
|
class Edge(BaseModel):
|
|
source: EdgeConnection = Field(description="The connection for the edge's from node and field")
|
|
destination: EdgeConnection = Field(description="The connection for the edge's to node and field")
|
|
|
|
def __str__(self):
|
|
return f"{self.source.node_id}.{self.source.field} -> {self.destination.node_id}.{self.destination.field}"
|
|
|
|
|
|
def get_output_field_type(node: BaseInvocation, field: str) -> Any:
|
|
# TODO(psyche): This is awkward - if field_info is None, it means the field is not defined in the output, which
|
|
# really should raise. The consumers of this utility expect it to never raise, and return None instead. Fixing this
|
|
# would require some fairly significant changes and I don't want risk breaking anything.
|
|
try:
|
|
invocation_class = type(node)
|
|
invocation_output_class = invocation_class.get_output_annotation()
|
|
field_info = invocation_output_class.model_fields.get(field)
|
|
assert field_info is not None, f"Output field '{field}' not found in {invocation_output_class.get_type()}"
|
|
output_field_type = field_info.annotation
|
|
return output_field_type
|
|
except Exception:
|
|
return None
|
|
|
|
|
|
def get_input_field_type(node: BaseInvocation, field: str) -> Any:
|
|
# TODO(psyche): This is awkward - if field_info is None, it means the field is not defined in the output, which
|
|
# really should raise. The consumers of this utility expect it to never raise, and return None instead. Fixing this
|
|
# would require some fairly significant changes and I don't want risk breaking anything.
|
|
try:
|
|
invocation_class = type(node)
|
|
field_info = invocation_class.model_fields.get(field)
|
|
assert field_info is not None, f"Input field '{field}' not found in {invocation_class.get_type()}"
|
|
input_field_type = field_info.annotation
|
|
return input_field_type
|
|
except Exception:
|
|
return None
|
|
|
|
|
|
def is_union_subtype(t1, t2):
|
|
t1_args = get_args(t1)
|
|
t2_args = get_args(t2)
|
|
if not t1_args:
|
|
# t1 is a single type
|
|
return t1 in t2_args
|
|
else:
|
|
# t1 is a Union, check that all of its types are in t2_args
|
|
return all(arg in t2_args for arg in t1_args)
|
|
|
|
|
|
def is_list_or_contains_list(t):
|
|
t_args = get_args(t)
|
|
|
|
# If the type is a List
|
|
if get_origin(t) is list:
|
|
return True
|
|
|
|
# If the type is a Union
|
|
elif t_args:
|
|
# Check if any of the types in the Union is a List
|
|
for arg in t_args:
|
|
if get_origin(arg) is list:
|
|
return True
|
|
return False
|
|
|
|
|
|
def is_any(t: Any) -> bool:
|
|
return t == Any or Any in get_args(t)
|
|
|
|
|
|
def extract_collection_item_types(t: Any) -> set[Any]:
|
|
"""Extracts list item types from a collection annotation, including unions containing list branches."""
|
|
if is_any(t):
|
|
return {Any}
|
|
|
|
if get_origin(t) is list:
|
|
return {arg for arg in get_args(t) if arg != NoneType}
|
|
|
|
item_types: set[Any] = set()
|
|
for arg in get_args(t):
|
|
if is_any(arg):
|
|
item_types.add(Any)
|
|
elif get_origin(arg) is list:
|
|
item_types.update(item_arg for item_arg in get_args(arg) if item_arg != NoneType)
|
|
return item_types
|
|
|
|
|
|
def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
|
|
if not from_type or not to_type:
|
|
return False
|
|
|
|
# Ports are compatible
|
|
if from_type == to_type or is_any(from_type) or is_any(to_type):
|
|
return True
|
|
|
|
if from_type in get_args(to_type):
|
|
return True
|
|
|
|
if to_type in get_args(from_type):
|
|
return True
|
|
|
|
# allow int -> float, pydantic will cast for us
|
|
if from_type is int and to_type is float:
|
|
return True
|
|
|
|
# allow int|float -> str, pydantic will cast for us
|
|
if (from_type is int or from_type is float) and to_type is str:
|
|
return True
|
|
|
|
# Prefer issubclass when both are real classes
|
|
try:
|
|
if isinstance(from_type, type) and isinstance(to_type, type):
|
|
return issubclass(from_type, to_type)
|
|
except TypeError:
|
|
pass
|
|
|
|
# Union-to-Union (or Union-to-non-Union) handling
|
|
return is_union_subtype(from_type, to_type)
|
|
|
|
|
|
def are_connections_compatible(
|
|
from_node: BaseInvocation, from_field: str, to_node: BaseInvocation, to_field: str
|
|
) -> bool:
|
|
"""Determines if a connection between fields of two nodes is compatible."""
|
|
|
|
# TODO: handle iterators and collectors
|
|
from_type = get_output_field_type(from_node, from_field)
|
|
to_type = get_input_field_type(to_node, to_field)
|
|
|
|
return are_connection_types_compatible(from_type, to_type)
|
|
|
|
|
|
T = TypeVar("T")
|
|
|
|
|
|
def copydeep(obj: T) -> T:
|
|
"""Deep-copies an object. If it is a pydantic model, use the model's copy method."""
|
|
if isinstance(obj, BaseModel):
|
|
return obj.model_copy(deep=True)
|
|
return copy.deepcopy(obj)
|
|
|
|
|
|
class NodeAlreadyInGraphError(ValueError):
|
|
pass
|
|
|
|
|
|
class InvalidEdgeError(ValueError):
|
|
pass
|
|
|
|
|
|
class NodeNotFoundError(ValueError):
|
|
pass
|
|
|
|
|
|
class NodeAlreadyExecutedError(ValueError):
|
|
pass
|
|
|
|
|
|
class DuplicateNodeIdError(ValueError):
|
|
pass
|
|
|
|
|
|
class NodeFieldNotFoundError(ValueError):
|
|
pass
|
|
|
|
|
|
class NodeIdMismatchError(ValueError):
|
|
pass
|
|
|
|
|
|
class CyclicalGraphError(ValueError):
|
|
pass
|
|
|
|
|
|
class UnknownGraphValidationError(ValueError):
|
|
pass
|
|
|
|
|
|
class NodeInputError(ValueError):
|
|
"""Raised when a node fails preparation. This occurs when a node's inputs are being set from its incomers, but an
|
|
input fails validation.
|
|
|
|
Attributes:
|
|
node: The node that failed preparation. Note: only successfully set fields will be accurate. Review the error to
|
|
determine which field caused the failure.
|
|
"""
|
|
|
|
def __init__(self, node: BaseInvocation, e: ValidationError):
|
|
self.original_error = e
|
|
self.node = node
|
|
# When preparing a node, we set each input one-at-a-time. We may thus safely assume that the first error
|
|
# represents the first input that failed.
|
|
self.failed_input = loc_to_dot_sep(e.errors()[0]["loc"])
|
|
super().__init__(f"Node {node.id} has invalid incoming input for {self.failed_input}")
|
|
|
|
|
|
def loc_to_dot_sep(loc: tuple[Union[str, int], ...]) -> str:
|
|
"""Helper to pretty-print pydantic error locations as dot-separated strings.
|
|
Taken from https://docs.pydantic.dev/latest/errors/errors/#customize-error-messages
|
|
"""
|
|
path = ""
|
|
for i, x in enumerate(loc):
|
|
if isinstance(x, str):
|
|
if i > 0:
|
|
path += "."
|
|
path += x
|
|
else:
|
|
path += f"[{x}]"
|
|
return path
|
|
|
|
|
|
@invocation_output("iterate_output")
|
|
class IterateInvocationOutput(BaseInvocationOutput):
|
|
"""Used to connect iteration outputs. Will be expanded to a specific output."""
|
|
|
|
item: Any = OutputField(
|
|
description="The item being iterated over", title="Collection Item", ui_type=UIType._CollectionItem
|
|
)
|
|
index: int = OutputField(description="The index of the item", title="Index")
|
|
total: int = OutputField(description="The total number of items", title="Total")
|
|
|
|
|
|
# TODO: Fill this out and move to invocations
|
|
@invocation("iterate", version="1.1.0")
|
|
class IterateInvocation(BaseInvocation):
|
|
"""Iterates over a list of items"""
|
|
|
|
collection: list[Any] = InputField(
|
|
description="The list of items to iterate over", default=[], ui_type=UIType._Collection
|
|
)
|
|
index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True)
|
|
|
|
def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
|
|
"""Produces the outputs as values"""
|
|
return IterateInvocationOutput(item=self.collection[self.index], index=self.index, total=len(self.collection))
|
|
|
|
|
|
@invocation_output("collect_output")
|
|
class CollectInvocationOutput(BaseInvocationOutput):
|
|
collection: list[Any] = OutputField(
|
|
description="The collection of input items", title="Collection", ui_type=UIType._Collection
|
|
)
|
|
|
|
|
|
@invocation("collect", version="1.1.0")
|
|
class CollectInvocation(BaseInvocation):
|
|
"""Collects values into a collection"""
|
|
|
|
item: Optional[Any] = InputField(
|
|
default=None,
|
|
description="The item to collect (all inputs must be of the same type)",
|
|
ui_type=UIType._CollectionItem,
|
|
title="Collection Item",
|
|
input=Input.Connection,
|
|
)
|
|
collection: list[Any] = InputField(
|
|
description="An optional collection to append to",
|
|
default=[],
|
|
ui_type=UIType._Collection,
|
|
input=Input.Connection,
|
|
)
|
|
|
|
def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
|
|
"""Invoke with provided services and return outputs."""
|
|
return CollectInvocationOutput(collection=copy.copy(self.collection))
|
|
|
|
|
|
class AnyInvocation(BaseInvocation):
|
|
@classmethod
|
|
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
|
def validate_invocation(v: Any) -> "AnyInvocation":
|
|
return InvocationRegistry.get_invocation_typeadapter().validate_python(v)
|
|
|
|
return core_schema.no_info_plain_validator_function(validate_invocation)
|
|
|
|
@classmethod
|
|
def __get_pydantic_json_schema__(
|
|
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
|
|
) -> JsonSchemaValue:
|
|
# Nodes are too powerful, we have to make our own OpenAPI schema manually
|
|
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
|
|
oneOf: list[dict[str, str]] = []
|
|
names = [i.__name__ for i in InvocationRegistry.get_invocation_classes()]
|
|
for name in sorted(names):
|
|
oneOf.append({"$ref": f"#/components/schemas/{name}"})
|
|
return {"oneOf": oneOf}
|
|
|
|
|
|
class AnyInvocationOutput(BaseInvocationOutput):
|
|
@classmethod
|
|
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
|
|
def validate_invocation_output(v: Any) -> "AnyInvocationOutput":
|
|
return InvocationRegistry.get_output_typeadapter().validate_python(v)
|
|
|
|
return core_schema.no_info_plain_validator_function(validate_invocation_output)
|
|
|
|
@classmethod
|
|
def __get_pydantic_json_schema__(
|
|
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
|
|
) -> JsonSchemaValue:
|
|
# Nodes are too powerful, we have to make our own OpenAPI schema manually
|
|
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
|
|
|
|
oneOf: list[dict[str, str]] = []
|
|
names = [i.__name__ for i in InvocationRegistry.get_output_classes()]
|
|
for name in sorted(names):
|
|
oneOf.append({"$ref": f"#/components/schemas/{name}"})
|
|
return {"oneOf": oneOf}
|
|
|
|
|
|
class Graph(BaseModel):
|
|
id: str = Field(description="The id of this graph", default_factory=uuid_string)
|
|
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
|
nodes: dict[str, AnyInvocation] = Field(description="The nodes in this graph", default_factory=dict)
|
|
edges: list[Edge] = Field(
|
|
description="The connections between nodes and their fields in this graph",
|
|
default_factory=list,
|
|
)
|
|
|
|
def add_node(self, node: BaseInvocation) -> None:
|
|
"""Adds a node to a graph
|
|
|
|
:raises NodeAlreadyInGraphError: the node is already present in the graph.
|
|
"""
|
|
|
|
if node.id in self.nodes:
|
|
raise NodeAlreadyInGraphError()
|
|
|
|
self.nodes[node.id] = node
|
|
|
|
def delete_node(self, node_id: str) -> None:
|
|
"""Deletes a node from a graph"""
|
|
|
|
try:
|
|
# Delete edges for this node
|
|
input_edges = self._get_input_edges(node_id)
|
|
output_edges = self._get_output_edges(node_id)
|
|
|
|
for edge in input_edges:
|
|
self.delete_edge(edge)
|
|
|
|
for edge in output_edges:
|
|
self.delete_edge(edge)
|
|
|
|
del self.nodes[node_id]
|
|
|
|
except NodeNotFoundError:
|
|
pass # Ignore, not doesn't exist (should this throw?)
|
|
|
|
def add_edge(self, edge: Edge) -> None:
|
|
"""Adds an edge to a graph
|
|
|
|
:raises InvalidEdgeError: the provided edge is invalid.
|
|
"""
|
|
|
|
self._validate_edge(edge)
|
|
if edge not in self.edges:
|
|
self.edges.append(edge)
|
|
else:
|
|
raise InvalidEdgeError()
|
|
|
|
def delete_edge(self, edge: Edge) -> None:
|
|
"""Deletes an edge from a graph"""
|
|
|
|
try:
|
|
self.edges.remove(edge)
|
|
except ValueError:
|
|
pass
|
|
|
|
def validate_self(self) -> None:
|
|
"""
|
|
Validates the graph.
|
|
|
|
Raises an exception if the graph is invalid:
|
|
- `DuplicateNodeIdError`
|
|
- `NodeIdMismatchError`
|
|
- `InvalidSubGraphError`
|
|
- `NodeNotFoundError`
|
|
- `NodeFieldNotFoundError`
|
|
- `CyclicalGraphError`
|
|
- `InvalidEdgeError`
|
|
"""
|
|
|
|
# Validate that all node ids are unique
|
|
node_ids = [n.id for n in self.nodes.values()]
|
|
seen = set()
|
|
duplicate_node_ids = {nid for nid in node_ids if (nid in seen) or seen.add(nid)}
|
|
if duplicate_node_ids:
|
|
raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")
|
|
|
|
# Validate that all node ids match the keys in the nodes dict
|
|
for k, v in self.nodes.items():
|
|
if k != v.id:
|
|
raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}")
|
|
|
|
# Validate that all edges match nodes and fields in the graph
|
|
for edge in self.edges:
|
|
source_node = self.nodes.get(edge.source.node_id, None)
|
|
if source_node is None:
|
|
raise NodeNotFoundError(f"Edge source node {edge.source.node_id} does not exist in the graph")
|
|
|
|
destination_node = self.nodes.get(edge.destination.node_id, None)
|
|
if destination_node is None:
|
|
raise NodeNotFoundError(f"Edge destination node {edge.destination.node_id} does not exist in the graph")
|
|
|
|
# output fields are not on the node object directly, they are on the output type
|
|
if edge.source.field not in source_node.get_output_annotation().model_fields:
|
|
raise NodeFieldNotFoundError(
|
|
f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}"
|
|
)
|
|
|
|
# input fields are on the node
|
|
if edge.destination.field not in type(destination_node).model_fields:
|
|
raise NodeFieldNotFoundError(
|
|
f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}"
|
|
)
|
|
|
|
# Validate there are no cycles
|
|
g = self.nx_graph_flat()
|
|
if not nx.is_directed_acyclic_graph(g):
|
|
raise CyclicalGraphError("Graph contains cycles")
|
|
|
|
# Validate all edge connections are valid
|
|
for edge in self.edges:
|
|
if not are_connections_compatible(
|
|
self.get_node(edge.source.node_id),
|
|
edge.source.field,
|
|
self.get_node(edge.destination.node_id),
|
|
edge.destination.field,
|
|
):
|
|
raise InvalidEdgeError(f"Edge source and target types do not match ({edge})")
|
|
|
|
# Validate all iterators & collectors
|
|
# TODO: may need to validate all iterators & collectors in subgraphs so edge connections in parent graphs will be available
|
|
for node in self.nodes.values():
|
|
if isinstance(node, IterateInvocation):
|
|
err = self._is_iterator_connection_valid(node.id)
|
|
if err is not None:
|
|
raise InvalidEdgeError(f"Invalid iterator node ({node.id}): {err}")
|
|
if isinstance(node, CollectInvocation):
|
|
err = self._is_collector_connection_valid(node.id)
|
|
if err is not None:
|
|
raise InvalidEdgeError(f"Invalid collector node ({node.id}): {err}")
|
|
|
|
return None
|
|
|
|
def is_valid(self) -> bool:
|
|
"""
|
|
Checks if the graph is valid.
|
|
|
|
Raises `UnknownGraphValidationError` if there is a problem validating the graph (not a validation error).
|
|
"""
|
|
try:
|
|
self.validate_self()
|
|
return True
|
|
except (
|
|
DuplicateNodeIdError,
|
|
NodeIdMismatchError,
|
|
NodeNotFoundError,
|
|
NodeFieldNotFoundError,
|
|
CyclicalGraphError,
|
|
InvalidEdgeError,
|
|
):
|
|
return False
|
|
except Exception as e:
|
|
raise UnknownGraphValidationError(f"Problem validating graph {e}") from e
|
|
|
|
def _is_destination_field_Any(self, edge: Edge) -> bool:
|
|
"""Checks if the destination field for an edge is of type typing.Any"""
|
|
return get_input_field_type(self.get_node(edge.destination.node_id), edge.destination.field) == Any
|
|
|
|
def _is_destination_field_list_of_Any(self, edge: Edge) -> bool:
|
|
"""Checks if the destination field for an edge is of type typing.Any"""
|
|
return get_input_field_type(self.get_node(edge.destination.node_id), edge.destination.field) == list[Any]
|
|
|
|
def _validate_edge(self, edge: Edge):
|
|
"""Validates that a new edge doesn't create a cycle in the graph"""
|
|
|
|
# Validate that the nodes exist
|
|
try:
|
|
from_node = self.get_node(edge.source.node_id)
|
|
to_node = self.get_node(edge.destination.node_id)
|
|
except NodeNotFoundError:
|
|
raise InvalidEdgeError(f"One or both nodes don't exist ({edge})")
|
|
|
|
# Validate that an edge to this node+field doesn't already exist
|
|
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
|
|
if len(input_edges) > 0 and (
|
|
not isinstance(to_node, CollectInvocation) or edge.destination.field != ITEM_FIELD
|
|
):
|
|
raise InvalidEdgeError(f"Edge already exists ({edge})")
|
|
|
|
# Validate that no cycles would be created
|
|
g = self.nx_graph_flat()
|
|
g.add_edge(edge.source.node_id, edge.destination.node_id)
|
|
if not nx.is_directed_acyclic_graph(g):
|
|
raise InvalidEdgeError(f"Edge creates a cycle in the graph ({edge})")
|
|
|
|
# Validate that the field types are compatible
|
|
if not are_connections_compatible(from_node, edge.source.field, to_node, edge.destination.field):
|
|
raise InvalidEdgeError(f"Field types are incompatible ({edge})")
|
|
|
|
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
|
|
if isinstance(to_node, IterateInvocation) and edge.destination.field == COLLECTION_FIELD:
|
|
err = self._is_iterator_connection_valid(edge.destination.node_id, new_input=edge.source)
|
|
if err is not None:
|
|
raise InvalidEdgeError(f"Iterator input type does not match iterator output type ({edge}): {err}")
|
|
|
|
# Validate if iterator input type matches output type (if this edge results in both being set)
|
|
if isinstance(from_node, IterateInvocation) and edge.source.field == ITEM_FIELD:
|
|
err = self._is_iterator_connection_valid(edge.source.node_id, new_output=edge.destination)
|
|
if err is not None:
|
|
raise InvalidEdgeError(f"Iterator output type does not match iterator input type ({edge}): {err}")
|
|
|
|
# Validate if collector input type matches output type (if this edge results in both being set)
|
|
if isinstance(to_node, CollectInvocation) and edge.destination.field in (ITEM_FIELD, COLLECTION_FIELD):
|
|
err = self._is_collector_connection_valid(
|
|
edge.destination.node_id, new_input=edge.source, new_input_field=edge.destination.field
|
|
)
|
|
if err is not None:
|
|
raise InvalidEdgeError(f"Collector output type does not match collector input type ({edge}): {err}")
|
|
|
|
# Validate if collector output type matches input type (if this edge results in both being set) - skip if the destination field is not Any or list[Any]
|
|
if (
|
|
isinstance(from_node, CollectInvocation)
|
|
and edge.source.field == COLLECTION_FIELD
|
|
and not self._is_destination_field_list_of_Any(edge)
|
|
and not self._is_destination_field_Any(edge)
|
|
):
|
|
err = self._is_collector_connection_valid(edge.source.node_id, new_output=edge.destination)
|
|
if err is not None:
|
|
raise InvalidEdgeError(f"Collector input type does not match collector output type ({edge}): {err}")
|
|
|
|
def has_node(self, node_id: str) -> bool:
|
|
"""Determines whether or not a node exists in the graph."""
|
|
try:
|
|
_ = self.get_node(node_id)
|
|
return True
|
|
except NodeNotFoundError:
|
|
return False
|
|
|
|
def get_node(self, node_id: str) -> BaseInvocation:
|
|
"""Gets a node from the graph."""
|
|
try:
|
|
return self.nodes[node_id]
|
|
except KeyError as e:
|
|
raise NodeNotFoundError(f"Node {node_id} not found in graph") from e
|
|
|
|
def update_node(self, node_id: str, new_node: BaseInvocation) -> None:
|
|
"""Updates a node in the graph."""
|
|
node = self.nodes[node_id]
|
|
|
|
# Ensure the node type matches the new node
|
|
if type(node) is not type(new_node):
|
|
raise TypeError(f"Node {node_id} is type {type(node)} but new node is type {type(new_node)}")
|
|
|
|
# Ensure the new id is either the same or is not in the graph
|
|
if new_node.id != node.id and self.has_node(new_node.id):
|
|
raise NodeAlreadyInGraphError(f"Node with id {new_node.id} already exists in graph")
|
|
|
|
# Set the new node in the graph
|
|
self.nodes[new_node.id] = new_node
|
|
if new_node.id != node.id:
|
|
input_edges = self._get_input_edges(node_id)
|
|
output_edges = self._get_output_edges(node_id)
|
|
|
|
# Delete node and all edges
|
|
self.delete_node(node_id)
|
|
|
|
# Create new edges for each input and output
|
|
for edge in input_edges:
|
|
self.add_edge(
|
|
Edge(
|
|
source=edge.source,
|
|
destination=EdgeConnection(node_id=new_node.id, field=edge.destination.field),
|
|
)
|
|
)
|
|
|
|
for edge in output_edges:
|
|
self.add_edge(
|
|
Edge(
|
|
source=EdgeConnection(node_id=new_node.id, field=edge.source.field),
|
|
destination=edge.destination,
|
|
)
|
|
)
|
|
|
|
def _get_input_edges(self, node_id: str, field: Optional[str] = None) -> list[Edge]:
|
|
"""Gets all input edges for a node. If field is provided, only edges to that field are returned."""
|
|
|
|
edges = [e for e in self.edges if e.destination.node_id == node_id]
|
|
|
|
if field is None:
|
|
return edges
|
|
|
|
filtered_edges = [e for e in edges if e.destination.field == field]
|
|
|
|
return filtered_edges
|
|
|
|
def _get_output_edges(self, node_id: str, field: Optional[str] = None) -> list[Edge]:
|
|
"""Gets all output edges for a node. If field is provided, only edges from that field are returned."""
|
|
edges = [e for e in self.edges if e.source.node_id == node_id]
|
|
|
|
if field is None:
|
|
return edges
|
|
|
|
filtered_edges = [e for e in edges if e.source.field == field]
|
|
|
|
return filtered_edges
|
|
|
|
def _is_iterator_connection_valid(
|
|
self,
|
|
node_id: str,
|
|
new_input: Optional[EdgeConnection] = None,
|
|
new_output: Optional[EdgeConnection] = None,
|
|
) -> str | None:
|
|
inputs = [e.source for e in self._get_input_edges(node_id, COLLECTION_FIELD)]
|
|
outputs = [e.destination for e in self._get_output_edges(node_id, ITEM_FIELD)]
|
|
|
|
if new_input is not None:
|
|
inputs.append(new_input)
|
|
if new_output is not None:
|
|
outputs.append(new_output)
|
|
|
|
if len(inputs) == 0:
|
|
return "Iterator must have a collection input edge"
|
|
|
|
# Only one input is allowed for iterators
|
|
if len(inputs) > 1:
|
|
return "Iterator may only have one input edge"
|
|
|
|
input_node = self.get_node(inputs[0].node_id)
|
|
|
|
# Get input and output fields (the fields linked to the iterator's input/output)
|
|
input_field_type = get_output_field_type(input_node, inputs[0].field)
|
|
output_field_types = [get_input_field_type(self.get_node(e.node_id), e.field) for e in outputs]
|
|
|
|
# Input type must be a list
|
|
if get_origin(input_field_type) is not list:
|
|
return "Iterator input must be a collection"
|
|
|
|
# Validate that all outputs match the input type
|
|
input_field_item_type = get_args(input_field_type)[0]
|
|
if not all((are_connection_types_compatible(input_field_item_type, t) for t in output_field_types)):
|
|
return "Iterator outputs must connect to an input with a matching type"
|
|
|
|
# Collector input type must match all iterator output types
|
|
if isinstance(input_node, CollectInvocation):
|
|
input_root_type = self._get_collector_input_root_type(input_node.id)
|
|
if input_root_type is None:
|
|
return "Iterator input collector must have at least one item or collection input edge"
|
|
if not all((are_connection_types_compatible(input_root_type, t) for t in output_field_types)):
|
|
return "Iterator collection type must match all iterator output types"
|
|
|
|
return None
|
|
|
|
def _resolve_collector_input_types(self, node_id: str, visited: Optional[set[str]] = None) -> set[Any]:
|
|
"""Resolves possible item types for a collector's inputs, recursively following chained collectors."""
|
|
visited = visited or set()
|
|
if node_id in visited:
|
|
return set()
|
|
visited.add(node_id)
|
|
|
|
input_types: set[Any] = set()
|
|
|
|
for edge in self._get_input_edges(node_id, ITEM_FIELD):
|
|
input_field_type = get_output_field_type(self.get_node(edge.source.node_id), edge.source.field)
|
|
resolved_types = [input_field_type] if get_origin(input_field_type) is None else get_args(input_field_type)
|
|
input_types.update(t for t in resolved_types if t != NoneType)
|
|
|
|
for edge in self._get_input_edges(node_id, COLLECTION_FIELD):
|
|
source_node = self.get_node(edge.source.node_id)
|
|
if isinstance(source_node, CollectInvocation) and edge.source.field == COLLECTION_FIELD:
|
|
input_types.update(self._resolve_collector_input_types(source_node.id, visited.copy()))
|
|
continue
|
|
|
|
input_field_type = get_output_field_type(source_node, edge.source.field)
|
|
input_types.update(extract_collection_item_types(input_field_type))
|
|
|
|
return input_types
|
|
|
|
def _get_collector_input_root_type(self, node_id: str) -> Any | None:
|
|
input_types = self._resolve_collector_input_types(node_id)
|
|
non_any_input_types = {t for t in input_types if t != Any}
|
|
if len(non_any_input_types) == 0 and Any in input_types:
|
|
return Any
|
|
if len(non_any_input_types) == 0:
|
|
return None
|
|
|
|
type_tree = nx.DiGraph()
|
|
type_tree.add_nodes_from(non_any_input_types)
|
|
type_tree.add_edges_from([e for e in itertools.permutations(non_any_input_types, 2) if issubclass(e[1], e[0])])
|
|
type_degrees = type_tree.in_degree(type_tree.nodes)
|
|
root_types = [t[0] for t in type_degrees if t[1] == 0] # type: ignore
|
|
if len(root_types) != 1:
|
|
return Any
|
|
return root_types[0]
|
|
|
|
def _is_collector_connection_valid(
|
|
self,
|
|
node_id: str,
|
|
new_input: Optional[EdgeConnection] = None,
|
|
new_input_field: Optional[str] = None,
|
|
new_output: Optional[EdgeConnection] = None,
|
|
) -> str | None:
|
|
item_inputs = [e.source for e in self._get_input_edges(node_id, ITEM_FIELD)]
|
|
collection_inputs = [e.source for e in self._get_input_edges(node_id, COLLECTION_FIELD)]
|
|
outputs = [e.destination for e in self._get_output_edges(node_id, COLLECTION_FIELD)]
|
|
|
|
if new_input is not None:
|
|
field = new_input_field or ITEM_FIELD
|
|
if field == ITEM_FIELD:
|
|
item_inputs.append(new_input)
|
|
elif field == COLLECTION_FIELD:
|
|
collection_inputs.append(new_input)
|
|
if new_output is not None:
|
|
outputs.append(new_output)
|
|
|
|
if len(item_inputs) == 0 and len(collection_inputs) == 0:
|
|
return "Collector must have at least one item or collection input edge"
|
|
|
|
# Get input and output fields (the fields linked to the collector's input/output)
|
|
item_input_field_types = [get_output_field_type(self.get_node(e.node_id), e.field) for e in item_inputs]
|
|
collection_input_field_types = [
|
|
get_output_field_type(self.get_node(e.node_id), e.field) for e in collection_inputs
|
|
]
|
|
output_field_types = [get_input_field_type(self.get_node(e.node_id), e.field) for e in outputs]
|
|
|
|
if not all((is_list_or_contains_list(t) or is_any(t) for t in collection_input_field_types)):
|
|
return "Collector collection input must be a collection"
|
|
|
|
# Validate that all inputs are derived from or match a single type
|
|
input_field_types = {
|
|
resolved_type
|
|
for input_field_type in item_input_field_types
|
|
for resolved_type in (
|
|
[input_field_type] if get_origin(input_field_type) is None else get_args(input_field_type)
|
|
)
|
|
if resolved_type != NoneType
|
|
} # Get unique types
|
|
|
|
for input_conn, input_field_type in zip(collection_inputs, collection_input_field_types, strict=False):
|
|
source_node = self.get_node(input_conn.node_id)
|
|
if isinstance(source_node, CollectInvocation) and input_conn.field == COLLECTION_FIELD:
|
|
input_field_types.update(self._resolve_collector_input_types(source_node.id))
|
|
continue
|
|
input_field_types.update(extract_collection_item_types(input_field_type))
|
|
|
|
non_any_input_field_types = {t for t in input_field_types if t != Any}
|
|
type_tree = nx.DiGraph()
|
|
type_tree.add_nodes_from(non_any_input_field_types)
|
|
type_tree.add_edges_from(
|
|
[e for e in itertools.permutations(non_any_input_field_types, 2) if issubclass(e[1], e[0])]
|
|
)
|
|
type_degrees = type_tree.in_degree(type_tree.nodes)
|
|
root_types = [t[0] for t in type_degrees if t[1] == 0] # type: ignore
|
|
if len(root_types) > 1:
|
|
return "Collector input collection items must be of a single type"
|
|
|
|
# Get the input root type (if known)
|
|
input_root_type = root_types[0] if len(root_types) == 1 else None
|
|
|
|
# Verify that all outputs are lists
|
|
if not all(is_list_or_contains_list(t) or is_any(t) for t in output_field_types):
|
|
return "Collector output must connect to a collection input"
|
|
|
|
# Verify that all outputs match the input type (are a base class or the same class)
|
|
if input_root_type is not None:
|
|
if not all(
|
|
is_any(t)
|
|
or is_union_subtype(input_root_type, get_args(t)[0])
|
|
or issubclass(input_root_type, get_args(t)[0])
|
|
for t in output_field_types
|
|
):
|
|
return "Collector outputs must connect to a collection input with a matching type"
|
|
elif any(not is_any(t) and get_args(t)[0] != Any for t in output_field_types):
|
|
return "Collector outputs must connect to a collection input with a matching type"
|
|
|
|
# If this collector outputs to another collector's collection input, validate against the downstream
|
|
# collector's resolved input type (if available).
|
|
for output in outputs:
|
|
output_node = self.get_node(output.node_id)
|
|
if not isinstance(output_node, CollectInvocation) or output.field != COLLECTION_FIELD:
|
|
continue
|
|
output_root_type = self._get_collector_input_root_type(output_node.id)
|
|
if output_root_type is None:
|
|
continue
|
|
if input_root_type is None:
|
|
if output_root_type != Any:
|
|
return "Collector outputs must connect to a collection input with a matching type"
|
|
continue
|
|
if not are_connection_types_compatible(input_root_type, output_root_type):
|
|
return "Collector outputs must connect to a collection input with a matching type"
|
|
|
|
return None
|
|
|
|
def nx_graph(self) -> nx.DiGraph:
|
|
"""Returns a NetworkX DiGraph representing the layout of this graph"""
|
|
# TODO: Cache this?
|
|
g = nx.DiGraph()
|
|
g.add_nodes_from(list(self.nodes.keys()))
|
|
g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
|
|
return g
|
|
|
|
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None) -> nx.DiGraph:
|
|
"""Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)"""
|
|
g = nx_graph or nx.DiGraph()
|
|
|
|
# Add all nodes from this graph except graph/iteration nodes
|
|
g.add_nodes_from([n.id for n in self.nodes.values()])
|
|
|
|
unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges}
|
|
g.add_edges_from(unique_edges)
|
|
return g
|
|
|
|
|
|
class GraphExecutionState(BaseModel):
|
|
"""Tracks the state of a graph execution"""
|
|
|
|
id: str = Field(description="The id of the execution state", default_factory=uuid_string)
|
|
# TODO: Store a reference to the graph instead of the actual graph?
|
|
graph: Graph = Field(description="The graph being executed")
|
|
|
|
# The graph of materialized nodes
|
|
execution_graph: Graph = Field(
|
|
description="The expanded graph of activated and executed nodes",
|
|
default_factory=Graph,
|
|
)
|
|
|
|
# Nodes that have been executed
|
|
executed: set[str] = Field(description="The set of node ids that have been executed", default_factory=set)
|
|
executed_history: list[str] = Field(
|
|
description="The list of node ids that have been executed, in order of execution",
|
|
default_factory=list,
|
|
)
|
|
|
|
# The results of executed nodes
|
|
results: dict[str, AnyInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
|
|
|
|
# Errors raised when executing nodes
|
|
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
|
|
|
|
# Map of prepared/executed nodes to their original nodes
|
|
prepared_source_mapping: dict[str, str] = Field(
|
|
description="The map of prepared nodes to original graph nodes",
|
|
default_factory=dict,
|
|
)
|
|
|
|
# Map of original nodes to prepared nodes
|
|
source_prepared_mapping: dict[str, set[str]] = Field(
|
|
description="The map of original graph nodes to prepared nodes",
|
|
default_factory=dict,
|
|
)
|
|
# Ready queues grouped by node class name (internal only)
|
|
_ready_queues: dict[str, Deque[str]] = PrivateAttr(default_factory=dict)
|
|
# Current class being drained; stays until its queue empties
|
|
_active_class: Optional[str] = PrivateAttr(default=None)
|
|
# Optional priority; others follow in name order
|
|
ready_order: list[str] = Field(default_factory=list)
|
|
indegree: dict[str, int] = Field(default_factory=dict, description="Remaining unmet input count for exec nodes")
|
|
_iteration_path_cache: dict[str, tuple[int, ...]] = PrivateAttr(default_factory=dict)
|
|
|
|
def _type_key(self, node_obj: BaseInvocation) -> str:
|
|
return node_obj.__class__.__name__
|
|
|
|
def _get_iteration_path(self, exec_node_id: str) -> tuple[int, ...]:
|
|
"""Best-effort outer->inner iteration indices for an execution node, stopping at collectors."""
|
|
cached = self._iteration_path_cache.get(exec_node_id)
|
|
if cached is not None:
|
|
return cached
|
|
|
|
# Only prepared execution nodes participate; otherwise treat as non-iterated.
|
|
source_node_id = self.prepared_source_mapping.get(exec_node_id)
|
|
if source_node_id is None:
|
|
self._iteration_path_cache[exec_node_id] = ()
|
|
return ()
|
|
|
|
# Source-graph iterator ancestry, with edges into collectors removed so iteration context doesn't leak.
|
|
it_g = self._iterator_graph(self.graph.nx_graph())
|
|
iterator_sources = [
|
|
n for n in nx.ancestors(it_g, source_node_id) if isinstance(self.graph.get_node(n), IterateInvocation)
|
|
]
|
|
|
|
# Order iterators outer->inner via topo order of the iterator graph.
|
|
topo = list(nx.topological_sort(it_g))
|
|
topo_index = {n: i for i, n in enumerate(topo)}
|
|
iterator_sources.sort(key=lambda n: topo_index.get(n, 0))
|
|
|
|
# Map iterator source nodes to the prepared iterator exec nodes that are ancestors of exec_node_id.
|
|
eg = self.execution_graph.nx_graph()
|
|
path: list[int] = []
|
|
for it_src in iterator_sources:
|
|
prepared = self.source_prepared_mapping.get(it_src)
|
|
if not prepared:
|
|
continue
|
|
it_exec = next((p for p in prepared if nx.has_path(eg, p, exec_node_id)), None)
|
|
if it_exec is None:
|
|
continue
|
|
it_node = self.execution_graph.nodes.get(it_exec)
|
|
if isinstance(it_node, IterateInvocation):
|
|
path.append(it_node.index)
|
|
|
|
# If this exec node is itself an iterator, include its own index as the innermost element.
|
|
node_obj = self.execution_graph.nodes.get(exec_node_id)
|
|
if isinstance(node_obj, IterateInvocation):
|
|
path.append(node_obj.index)
|
|
|
|
result = tuple(path)
|
|
self._iteration_path_cache[exec_node_id] = result
|
|
return result
|
|
|
|
def _queue_for(self, cls_name: str) -> Deque[str]:
|
|
q = self._ready_queues.get(cls_name)
|
|
if q is None:
|
|
q = deque()
|
|
self._ready_queues[cls_name] = q
|
|
return q
|
|
|
|
def set_ready_order(self, order: Iterable[Type[BaseInvocation] | str]) -> None:
|
|
names: list[str] = []
|
|
for x in order:
|
|
names.append(x.__name__ if hasattr(x, "__name__") else str(x))
|
|
self.ready_order = names
|
|
|
|
def _enqueue_if_ready(self, nid: str) -> None:
|
|
"""Push nid to its class queue if unmet inputs == 0."""
|
|
# Invariants: exec node exists and has an indegree entry
|
|
if nid not in self.execution_graph.nodes:
|
|
raise KeyError(f"exec node {nid} missing from execution_graph")
|
|
if nid not in self.indegree:
|
|
raise KeyError(f"indegree missing for exec node {nid}")
|
|
if self.indegree[nid] != 0 or nid in self.executed:
|
|
return
|
|
node_obj = self.execution_graph.nodes[nid]
|
|
q = self._queue_for(self._type_key(node_obj))
|
|
nid_path = self._get_iteration_path(nid)
|
|
# Insert in lexicographic outer->inner order; preserve FIFO for equal paths.
|
|
for i, existing in enumerate(q):
|
|
if self._get_iteration_path(existing) > nid_path:
|
|
q.insert(i, nid)
|
|
break
|
|
else:
|
|
q.append(nid)
|
|
|
|
model_config = ConfigDict(
|
|
json_schema_extra={
|
|
"required": [
|
|
"id",
|
|
"graph",
|
|
"execution_graph",
|
|
"executed",
|
|
"executed_history",
|
|
"results",
|
|
"errors",
|
|
"prepared_source_mapping",
|
|
"source_prepared_mapping",
|
|
]
|
|
}
|
|
)
|
|
|
|
@field_validator("graph")
|
|
def graph_is_valid(cls, v: Graph):
|
|
"""Validates that the graph is valid"""
|
|
v.validate_self()
|
|
return v
|
|
|
|
def next(self) -> Optional[BaseInvocation]:
|
|
"""Gets the next node ready to execute."""
|
|
|
|
# TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes
|
|
# possibly with a timeout?
|
|
|
|
# If there are no prepared nodes, prepare some nodes
|
|
next_node = self._get_next_node()
|
|
if next_node is None:
|
|
base_g = self.graph.nx_graph_flat()
|
|
prepared_id = self._prepare(base_g)
|
|
|
|
# Prepare as many nodes as we can
|
|
while prepared_id is not None:
|
|
prepared_id = self._prepare(base_g)
|
|
if next_node is None:
|
|
next_node = self._get_next_node()
|
|
|
|
# Get values from edges
|
|
if next_node is not None:
|
|
try:
|
|
self._prepare_inputs(next_node)
|
|
except ValidationError as e:
|
|
raise NodeInputError(next_node, e)
|
|
|
|
# If next is still none, there's no next node, return None
|
|
return next_node
|
|
|
|
def complete(self, node_id: str, output: BaseInvocationOutput) -> None:
|
|
"""Marks a node as complete"""
|
|
|
|
if node_id not in self.execution_graph.nodes:
|
|
return # TODO: log error?
|
|
|
|
# Mark node as executed
|
|
self.executed.add(node_id)
|
|
self.results[node_id] = output
|
|
|
|
# Check if source node is complete (all prepared nodes are complete)
|
|
source_node = self.prepared_source_mapping[node_id]
|
|
prepared_nodes = self.source_prepared_mapping[source_node]
|
|
|
|
if all(n in self.executed for n in prepared_nodes):
|
|
self.executed.add(source_node)
|
|
self.executed_history.append(source_node)
|
|
|
|
# Decrement children indegree and enqueue when ready
|
|
for e in self.execution_graph._get_output_edges(node_id):
|
|
child = e.destination.node_id
|
|
if child not in self.indegree:
|
|
raise KeyError(f"indegree missing for exec node {child}")
|
|
# Only decrement if there's something to satisfy
|
|
if self.indegree[child] == 0:
|
|
raise RuntimeError(f"indegree underflow for {child} from parent {node_id}")
|
|
self.indegree[child] -= 1
|
|
if self.indegree[child] == 0:
|
|
self._enqueue_if_ready(child)
|
|
|
|
def set_node_error(self, node_id: str, error: str):
|
|
"""Marks a node as errored"""
|
|
self.errors[node_id] = error
|
|
|
|
def is_complete(self) -> bool:
|
|
"""Returns true if the graph is complete"""
|
|
node_ids = set(self.graph.nx_graph_flat().nodes)
|
|
return self.has_error() or all((k in self.executed for k in node_ids))
|
|
|
|
def has_error(self) -> bool:
|
|
"""Returns true if the graph has any errors"""
|
|
return len(self.errors) > 0
|
|
|
|
def _create_execution_node(self, node_id: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
|
|
"""Prepares an iteration node and connects all edges, returning the new node id"""
|
|
|
|
node = self.graph.get_node(node_id)
|
|
|
|
self_iteration_count = -1
|
|
|
|
# If this is an iterator node, we must create a copy for each iteration
|
|
if isinstance(node, IterateInvocation):
|
|
# Get input collection edge (should error if there are no inputs)
|
|
input_collection_edge = next(iter(self.graph._get_input_edges(node_id, COLLECTION_FIELD)))
|
|
input_collection_prepared_node_id = next(
|
|
n[1] for n in iteration_node_map if n[0] == input_collection_edge.source.node_id
|
|
)
|
|
input_collection_prepared_node_output = self.results[input_collection_prepared_node_id]
|
|
input_collection = getattr(input_collection_prepared_node_output, input_collection_edge.source.field)
|
|
self_iteration_count = len(input_collection)
|
|
|
|
new_nodes: list[str] = []
|
|
if self_iteration_count == 0:
|
|
# TODO: should this raise a warning? It might just happen if an empty collection is input, and should be valid.
|
|
return new_nodes
|
|
|
|
# Get all input edges
|
|
input_edges = self.graph._get_input_edges(node_id)
|
|
|
|
# Create new edges for this iteration
|
|
# For collect nodes, this may contain multiple inputs to the same field
|
|
new_edges: list[Edge] = []
|
|
for edge in input_edges:
|
|
for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge.source.node_id):
|
|
new_edge = Edge(
|
|
source=EdgeConnection(node_id=input_node_id, field=edge.source.field),
|
|
destination=EdgeConnection(node_id="", field=edge.destination.field),
|
|
)
|
|
new_edges.append(new_edge)
|
|
|
|
# Create a new node (or one for each iteration of this iterator)
|
|
for i in range(self_iteration_count) if self_iteration_count > 0 else [-1]:
|
|
# Create a new node
|
|
new_node = node.model_copy(deep=True)
|
|
|
|
# Create the node id (use a random uuid)
|
|
new_node.id = uuid_string()
|
|
|
|
# Set the iteration index for iteration invocations
|
|
if isinstance(new_node, IterateInvocation):
|
|
new_node.index = i
|
|
|
|
# Add to execution graph
|
|
self.execution_graph.add_node(new_node)
|
|
self.prepared_source_mapping[new_node.id] = node_id
|
|
if node_id not in self.source_prepared_mapping:
|
|
self.source_prepared_mapping[node_id] = set()
|
|
self.source_prepared_mapping[node_id].add(new_node.id)
|
|
|
|
# Add new edges to execution graph
|
|
for edge in new_edges:
|
|
new_edge = Edge(
|
|
source=edge.source,
|
|
destination=EdgeConnection(node_id=new_node.id, field=edge.destination.field),
|
|
)
|
|
self.execution_graph.add_edge(new_edge)
|
|
|
|
# Initialize indegree as unmet inputs only and enqueue if ready
|
|
inputs = self.execution_graph._get_input_edges(new_node.id)
|
|
unmet = sum(1 for e in inputs if e.source.node_id not in self.executed)
|
|
self.indegree[new_node.id] = unmet
|
|
self._enqueue_if_ready(new_node.id)
|
|
|
|
new_nodes.append(new_node.id)
|
|
|
|
return new_nodes
|
|
|
|
def _iterator_graph(self, base: Optional[nx.DiGraph] = None) -> nx.DiGraph:
|
|
"""Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
|
|
g = base.copy() if base is not None else self.graph.nx_graph_flat()
|
|
collectors = (n for n in self.graph.nodes if isinstance(self.graph.get_node(n), CollectInvocation))
|
|
for c in collectors:
|
|
g.remove_edges_from(list(g.in_edges(c)))
|
|
return g
|
|
|
|
def _get_node_iterators(self, node_id: str, it_graph: Optional[nx.DiGraph] = None) -> list[str]:
|
|
"""Gets iterators for a node"""
|
|
g = it_graph or self._iterator_graph()
|
|
return [n for n in nx.ancestors(g, node_id) if isinstance(self.graph.get_node(n), IterateInvocation)]
|
|
|
|
def _prepare(self, base_g: Optional[nx.DiGraph] = None) -> Optional[str]:
|
|
# Get flattened source graph
|
|
g = base_g or self.graph.nx_graph_flat()
|
|
|
|
# Find next node that:
|
|
# - was not already prepared
|
|
# - is not an iterate node whose inputs have not been executed
|
|
# - does not have an unexecuted iterate ancestor
|
|
sorted_nodes = nx.topological_sort(g)
|
|
|
|
def unprepared(n: str) -> bool:
|
|
return n not in self.source_prepared_mapping
|
|
|
|
def iter_inputs_ready(n: str) -> bool:
|
|
if not isinstance(self.graph.get_node(n), IterateInvocation):
|
|
return True
|
|
return all(u in self.executed for u, _ in g.in_edges(n))
|
|
|
|
def no_unexecuted_iter_ancestors(n: str) -> bool:
|
|
return not any(
|
|
isinstance(self.graph.get_node(a), IterateInvocation) and a not in self.executed
|
|
for a in nx.ancestors(g, n)
|
|
)
|
|
|
|
next_node_id = next(
|
|
(n for n in sorted_nodes if unprepared(n) and iter_inputs_ready(n) and no_unexecuted_iter_ancestors(n)),
|
|
None,
|
|
)
|
|
|
|
if next_node_id is None:
|
|
return None
|
|
|
|
# Get all parents of the next node
|
|
next_node_parents = [u for u, _ in g.in_edges(next_node_id)]
|
|
|
|
# Create execution nodes
|
|
next_node = self.graph.get_node(next_node_id)
|
|
new_node_ids = []
|
|
if isinstance(next_node, CollectInvocation):
|
|
# Collapse all iterator input mappings and create a single execution node for the collect invocation
|
|
all_iteration_mappings = []
|
|
for source_node_id in next_node_parents:
|
|
prepared_nodes = self.source_prepared_mapping[source_node_id]
|
|
all_iteration_mappings.extend([(source_node_id, p) for p in prepared_nodes])
|
|
|
|
create_results = self._create_execution_node(next_node_id, all_iteration_mappings)
|
|
if create_results is not None:
|
|
new_node_ids.extend(create_results)
|
|
else: # Iterators or normal nodes
|
|
# Get all iterator combinations for this node
|
|
# Will produce a list of lists of prepared iterator nodes, from which results can be iterated
|
|
it_g = self._iterator_graph(g)
|
|
iterator_nodes = self._get_node_iterators(next_node_id, it_g)
|
|
iterator_nodes_prepared = [list(self.source_prepared_mapping[n]) for n in iterator_nodes]
|
|
iterator_node_prepared_combinations = list(itertools.product(*iterator_nodes_prepared))
|
|
|
|
# Select the correct prepared parents for each iteration
|
|
# For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
|
|
eg = self.execution_graph.nx_graph_flat()
|
|
prepared_parent_mappings = [
|
|
[(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents]
|
|
for it in iterator_node_prepared_combinations
|
|
] # type: ignore
|
|
prepared_parent_mappings = [m for m in prepared_parent_mappings if all(p[1] is not None for p in m)]
|
|
|
|
# Create execution node for each iteration
|
|
for iteration_mappings in prepared_parent_mappings:
|
|
create_results = self._create_execution_node(next_node_id, iteration_mappings) # type: ignore
|
|
if create_results is not None:
|
|
new_node_ids.extend(create_results)
|
|
|
|
return next(iter(new_node_ids), None)
|
|
|
|
def _get_iteration_node(
|
|
self,
|
|
source_node_id: str,
|
|
graph: nx.DiGraph,
|
|
execution_graph: nx.DiGraph,
|
|
prepared_iterator_nodes: list[str],
|
|
) -> Optional[str]:
|
|
"""Gets the prepared version of the specified source node that matches every iteration specified"""
|
|
prepared_nodes = self.source_prepared_mapping[source_node_id]
|
|
if len(prepared_nodes) == 1:
|
|
return next(iter(prepared_nodes))
|
|
|
|
# Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source)
|
|
iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes]
|
|
parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_id)]
|
|
|
|
# If the requested node is an iterator, only accept it if it is compatible with all parent iterators
|
|
prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None)
|
|
if prepared_iterator is not None:
|
|
if all(nx.has_path(execution_graph, pit[0], prepared_iterator) for pit in parent_iterators):
|
|
return prepared_iterator
|
|
return None
|
|
|
|
return next(
|
|
(n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)),
|
|
None,
|
|
)
|
|
|
|
def _get_next_node(self) -> Optional[BaseInvocation]:
|
|
"""Gets the next ready node: FIFO within class, drain class before switching."""
|
|
# 1) Continue draining the active class
|
|
if self._active_class:
|
|
q = self._ready_queues.get(self._active_class)
|
|
while q:
|
|
nid = q.popleft()
|
|
if nid not in self.executed:
|
|
return self.execution_graph.nodes[nid]
|
|
# emptied: release active class
|
|
self._active_class = None
|
|
|
|
# 2) Pick next class by priority, then by class name
|
|
seen = set(self.ready_order)
|
|
for cls_name in self.ready_order:
|
|
q = self._ready_queues.get(cls_name)
|
|
if q:
|
|
self._active_class = cls_name
|
|
# recurse to drain newly set active class
|
|
return self._get_next_node()
|
|
for cls_name in sorted(k for k in self._ready_queues.keys() if k not in seen):
|
|
q = self._ready_queues[cls_name]
|
|
if q:
|
|
self._active_class = cls_name
|
|
return self._get_next_node()
|
|
return None
|
|
|
|
def _prepare_inputs(self, node: BaseInvocation):
|
|
input_edges = self.execution_graph._get_input_edges(node.id)
|
|
# Inputs must be deep-copied, else if a node mutates the object, other nodes that get the same input
|
|
# will see the mutation.
|
|
if isinstance(node, CollectInvocation):
|
|
item_edges = [e for e in input_edges if e.destination.field == ITEM_FIELD]
|
|
item_edges.sort(key=lambda e: (self._get_iteration_path(e.source.node_id), e.source.node_id))
|
|
collection_edges = [e for e in input_edges if e.destination.field == COLLECTION_FIELD]
|
|
collection_edges.sort(key=lambda e: (self._get_iteration_path(e.source.node_id), e.source.node_id))
|
|
|
|
output_collection = []
|
|
for edge in collection_edges:
|
|
source_value = copydeep(getattr(self.results[edge.source.node_id], edge.source.field))
|
|
if isinstance(source_value, list):
|
|
output_collection.extend(source_value)
|
|
else:
|
|
output_collection.append(source_value)
|
|
output_collection.extend(
|
|
copydeep(getattr(self.results[e.source.node_id], e.source.field)) for e in item_edges
|
|
)
|
|
node.collection = output_collection
|
|
else:
|
|
for edge in input_edges:
|
|
setattr(
|
|
node,
|
|
edge.destination.field,
|
|
copydeep(getattr(self.results[edge.source.node_id], edge.source.field)),
|
|
)
|
|
|
|
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
|
|
def _is_edge_valid(self, edge: Edge) -> bool:
|
|
try:
|
|
self.graph._validate_edge(edge)
|
|
except InvalidEdgeError:
|
|
return False
|
|
|
|
# Invalid if destination has already been prepared or executed
|
|
if edge.destination.node_id in self.source_prepared_mapping:
|
|
return False
|
|
|
|
# Otherwise, the edge is valid
|
|
return True
|
|
|
|
def _is_node_updatable(self, node_id: str) -> bool:
|
|
# The node is updatable as long as it hasn't been prepared or executed
|
|
return node_id not in self.source_prepared_mapping
|
|
|
|
def add_node(self, node: BaseInvocation) -> None:
|
|
self.graph.add_node(node)
|
|
|
|
def update_node(self, node_id: str, new_node: BaseInvocation) -> None:
|
|
if not self._is_node_updatable(node_id):
|
|
raise NodeAlreadyExecutedError(
|
|
f"Node {node_id} has already been prepared or executed and cannot be updated"
|
|
)
|
|
self.graph.update_node(node_id, new_node)
|
|
|
|
def delete_node(self, node_id: str) -> None:
|
|
if not self._is_node_updatable(node_id):
|
|
raise NodeAlreadyExecutedError(
|
|
f"Node {node_id} has already been prepared or executed and cannot be deleted"
|
|
)
|
|
self.graph.delete_node(node_id)
|
|
|
|
def add_edge(self, edge: Edge) -> None:
|
|
if not self._is_node_updatable(edge.destination.node_id):
|
|
raise NodeAlreadyExecutedError(
|
|
f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot be linked to"
|
|
)
|
|
self.graph.add_edge(edge)
|
|
|
|
def delete_edge(self, edge: Edge) -> None:
|
|
if not self._is_node_updatable(edge.destination.node_id):
|
|
raise NodeAlreadyExecutedError(
|
|
f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot have a source edge deleted"
|
|
)
|
|
self.graph.delete_edge(edge)
|