mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
* Implemented ordering for expanded iterators * Update test_graph_execution_state.py Added a test for nested iterator execution ordering. (Failing at commit time!) * Filter invalid nested-iterator parent mappings in _prepare() When a graph has nested iterators, some "ready to run" node combinations do not actually belong together. Previously, the scheduler would still try to build nodes for those mismatched combinations, which could cause the same work to run more than once. This change skips any combination that is missing a valid iterator parent, so nested iterator expansions run once per intended item. * Fixed Collect node ordering * ruff * Removed ordering guarantees from test_node_graph.py * Fix iterator prep and type compatibility in graph execution Include iterator nodes in nx_graph_flat so iterators are prepared/expanded correctly. Fix connection type checks to allow subclass-to-base via issubclass. Harden iterator/collector validation to fail cleanly instead of crashing on missing edges. Remove unused nx_graph_with_data(). Added tests to verify proper functionality.
317 lines
11 KiB
Python
317 lines
11 KiB
Python
from typing import Optional
|
|
from unittest.mock import Mock
|
|
|
|
import pytest
|
|
|
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
|
from invokeai.app.invocations.collections import RangeInvocation
|
|
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
|
|
from invokeai.app.services.shared.graph import (
|
|
CollectInvocation,
|
|
Graph,
|
|
GraphExecutionState,
|
|
IterateInvocation,
|
|
)
|
|
|
|
# This import must happen before other invoke imports, or tests in other files (!!) break
|
|
from tests.test_nodes import (
|
|
PromptCollectionTestInvocation,
|
|
PromptTestInvocation,
|
|
TextToImageTestInvocation,
|
|
create_edge,
|
|
)
|
|
|
|
|
|
@pytest.fixture
def simple_graph() -> Graph:
    """Provide a two-node graph: node "1" (a prompt node) feeds its prompt into node "2"."""
    graph = Graph()
    prompt_node = PromptTestInvocation(id="1", prompt="Banana sushi")
    image_node = TextToImageTestInvocation(id="2")
    for node in (prompt_node, image_node):
        graph.add_node(node)

    # Connect the "prompt" output of node "1" to the "prompt" input of node "2".
    graph.add_edge(create_edge("1", "prompt", "2", "prompt"))
    return graph
|
|
|
|
|
|
def invoke_next(g: GraphExecutionState) -> tuple[Optional[BaseInvocation], Optional[BaseInvocationOutput]]:
    """Invoke whichever node ``g.next()`` hands out and record its output on ``g``.

    Returns:
        ``(node, output)`` for the node that was invoked, or ``(None, None)``
        when ``g.next()`` returns no node.
    """
    node = g.next()
    if node is not None:
        print(f"invoking {node.id}: {type(node)}")
        # The node runs against a mocked invocation context.
        output = node.invoke(Mock(InvocationContext))
        g.complete(node.id, output)
        return (node, output)

    return (None, None)
|
|
|
|
|
|
def test_graph_state_executes_in_order(simple_graph: Graph):
    """Source node "1" must run before "2", and the prompt must flow along the edge."""
    state = GraphExecutionState(graph=simple_graph)

    first_node, _ = invoke_next(state)
    second_node, _ = invoke_next(state)
    leftover = state.next()

    # Prepared (execution) node ids map back to the source graph ids.
    assert state.prepared_source_mapping[first_node.id] == "1"
    assert state.prepared_source_mapping[second_node.id] == "2"
    # Nothing is left to hand out after both nodes ran.
    assert leftover is None
    assert state.results[first_node.id].prompt == first_node.prompt
    assert second_node.prompt == first_node.prompt
|
|
|
|
|
|
def test_graph_is_complete(simple_graph: Graph):
    """After both nodes of the simple graph have been invoked, the state reports complete."""
    state = GraphExecutionState(graph=simple_graph)

    for _ in range(2):
        invoke_next(state)
    state.next()

    assert state.is_complete()
|
|
|
|
|
|
def test_graph_is_not_complete(simple_graph: Graph):
    """With one node invoked and the next only fetched (never completed), the state is incomplete."""
    state = GraphExecutionState(graph=simple_graph)

    invoke_next(state)
    # Fetch the following node without invoking or completing it.
    state.next()

    assert not state.is_complete()
|
|
|
|
|
|
# TODO: test completion with iterators/subgraphs
|
|
|
|
|
|
def test_graph_state_expands_iterator():
    """An iterator over range(0, 3) expands downstream nodes once per item.

    Pipeline: range [0, 1, 2] -> iterate -> multiply by 10 -> add 1, so the
    prepared copies of node "3" must produce exactly {1, 11, 21}.
    """
    graph = Graph()
    graph.add_node(RangeInvocation(id="0", start=0, stop=3, step=1))
    graph.add_node(IterateInvocation(id="1"))
    graph.add_node(MultiplyInvocation(id="2", b=10))
    graph.add_node(AddInvocation(id="3", b=1))
    graph.add_edge(create_edge("0", "collection", "1", "collection"))
    graph.add_edge(create_edge("1", "item", "2", "a"))
    graph.add_edge(create_edge("2", "value", "3", "a"))

    g = GraphExecutionState(graph=graph)
    while not g.is_complete():
        node, _ = invoke_next(g)
        # If the scheduler hands out nothing while the graph is still unfinished,
        # fail right away instead of spinning in this loop forever.
        assert node is not None, "Scheduler returned no node before the graph completed"

    # Every prepared (expanded) copy of source node "3" contributes one result.
    prepared_add_nodes = g.source_prepared_mapping["3"]
    results = {g.results[n].value for n in prepared_add_nodes}
    expected = {1, 11, 21}
    assert results == expected
|
|
|
|
|
|
def test_graph_state_collects():
    """Items produced under an iterator are gathered back by a collect node."""
    test_prompts = ["Banana sushi", "Cat sushi"]

    graph = Graph()
    graph.add_node(PromptCollectionTestInvocation(id="1", collection=list(test_prompts)))
    graph.add_node(IterateInvocation(id="2"))
    graph.add_node(PromptTestInvocation(id="3"))
    graph.add_node(CollectInvocation(id="4"))
    for edge in (
        create_edge("1", "collection", "2", "collection"),
        create_edge("2", "item", "3", "prompt"),
        create_edge("3", "prompt", "4", "item"),
    ):
        graph.add_edge(edge)

    state = GraphExecutionState(graph=graph)
    # Five invocations run before the one under test; their results are not inspected here.
    for _ in range(5):
        invoke_next(state)
    collector, _ = invoke_next(state)

    # The sixth invocation is expected to be the collector.
    assert isinstance(collector, CollectInvocation)

    # Compare order-insensitively: only the collected contents are checked.
    assert sorted(state.results[collector.id].collection) == sorted(test_prompts)
|
|
|
|
|
|
def test_graph_state_prepares_eagerly():
    """A single ``next()`` call prepares every node that can already be prepared.

    The iterator and the node behind it must not be prepared yet, since the
    collection feeding the iterator has not been produced.
    """
    test_prompts = ["Banana sushi", "Cat sushi"]

    graph = Graph()

    # Branch that depends on an iterator.
    graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
    graph.add_node(IterateInvocation(id="iterate"))
    graph.add_node(PromptTestInvocation(id="prompt_iterated"))
    graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
    graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))

    # Independent chain with no iterator involved.
    graph.add_node(PromptTestInvocation(id="prompt_chain_1", prompt="Dinosaur sushi"))
    graph.add_node(PromptTestInvocation(id="prompt_chain_2"))
    graph.add_node(PromptTestInvocation(id="prompt_chain_3"))
    graph.add_edge(create_edge("prompt_chain_1", "prompt", "prompt_chain_2", "prompt"))
    graph.add_edge(create_edge("prompt_chain_2", "prompt", "prompt_chain_3", "prompt"))

    state = GraphExecutionState(graph=graph)
    state.next()

    prepared = state.source_prepared_mapping
    for node_id in ("prompt_collection", "prompt_chain_1", "prompt_chain_2", "prompt_chain_3"):
        assert node_id in prepared
    for node_id in ("iterate", "prompt_iterated"):
        assert node_id not in prepared
|
|
|
|
|
|
def test_graph_executes_depth_first():
    """Run an iterated branch to exhaustion and validate the resulting execution order.

    The assertions check that every execution-graph node ran exactly once and
    that each node ran after all of its execution-graph parents.
    """

    def assert_topo_order_and_all_executed(state: GraphExecutionState, order: list[str]):
        """Check coverage and dependency order of an execution run.

        Args:
            state: The execution state after the run finished.
            order: Execution-graph node ids in the order they were run.
        """
        exec_nodes = set(state.execution_graph.nodes)

        # Coverage: position of each executed node; duplicates collapse onto one key.
        pos = {nid: idx for idx, nid in enumerate(order)}
        assert set(pos) == exec_nodes, (
            f"Executed {len(pos)} of {len(exec_nodes)} nodes. Missing: {sorted(exec_nodes - set(pos))[:10]}"
        )
        assert len(pos) == len(order), "Duplicate execution detected"

        # Dependency order: an edge u -> v requires u to have run before v.
        for edge in state.execution_graph.edges:
            u = edge.source.node_id
            v = edge.destination.node_id
            assert pos[u] < pos[v], f"child {v} ran before parent {u}"

    test_prompts = ["Banana sushi", "Cat sushi"]

    graph = Graph()
    graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
    graph.add_node(IterateInvocation(id="iterate"))
    graph.add_node(PromptTestInvocation(id="prompt_iterated"))
    graph.add_node(PromptTestInvocation(id="prompt_successor"))
    graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
    graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
    graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt"))

    g = GraphExecutionState(graph=graph)
    order: list[str] = []

    # Drain the scheduler, recording the id of each node as it completes.
    while (node := g.next()) is not None:
        output = node.invoke(Mock(InvocationContext))
        g.complete(node.id, output)
        order.append(node.id)

    assert_topo_order_and_all_executed(g, order)
|
|
|
|
|
|
# Because this tests deterministic ordering, we run it multiple times
|
|
@pytest.mark.parametrize("execution_number", range(5))
def test_graph_iterate_execution_order(execution_number: int):
    """Expanded iterate nodes must run in the same order as the source collection."""
    test_prompts = ["Banana sushi", "Cat sushi", "Strawberry Sushi", "Dinosaur Sushi"]

    graph = Graph()
    graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
    graph.add_node(IterateInvocation(id="iterate"))
    graph.add_node(PromptTestInvocation(id="prompt_iterated"))
    graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
    graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))

    state = GraphExecutionState(graph=graph)

    # The first invocation's output is not inspected.
    invoke_next(state)

    # The next four invocations must yield the prompts in collection order.
    for expected_item in test_prompts:
        _, output = invoke_next(state)
        assert output.item == expected_item

    # One further invocation is run; its result is not checked.
    invoke_next(state)
|
|
|
|
|
|
# Because this tests deterministic ordering, we run it multiple times
|
|
@pytest.mark.parametrize("execution_number", range(5))
|
|
def test_graph_nested_iterate_execution_order(execution_number: int):
    """Nodes expanded under nested iterators run ordered by (outer_index, inner_index).

    The outer iterator walks [0, 1]. For each outer item the inner range is
    [outer * 10, outer * 10 + 2), so the observed "sum" values identify both
    the outer and the inner position.
    """
    graph = Graph()

    # Outer iterator over [0, 1].
    graph.add_node(RangeInvocation(id="outer_range", start=0, stop=2, step=1))
    graph.add_node(IterateInvocation(id="outer_iter"))

    # Inner range bounds come from the outer item via edges:
    #   start = outer_item * 10, stop = start + 2 (two inner items per outer item).
    graph.add_node(MultiplyInvocation(id="mul10", b=10))
    graph.add_node(AddInvocation(id="stop_plus2", b=2))
    graph.add_node(RangeInvocation(id="inner_range", start=0, stop=1, step=1))
    graph.add_node(IterateInvocation(id="inner_iter"))

    # Pass-through node (adds 0) used to observe each inner item.
    graph.add_node(AddInvocation(id="sum", b=0))

    for edge in (
        create_edge("outer_range", "collection", "outer_iter", "collection"),
        create_edge("outer_iter", "item", "mul10", "a"),
        create_edge("mul10", "value", "stop_plus2", "a"),
        create_edge("mul10", "value", "inner_range", "start"),
        create_edge("stop_plus2", "value", "inner_range", "stop"),
        create_edge("inner_range", "collection", "inner_iter", "collection"),
        create_edge("inner_iter", "item", "sum", "a"),
    ):
        graph.add_edge(edge)

    state = GraphExecutionState(graph=graph)
    sum_values: list[int] = []

    # Drain the graph, keeping only the outputs of prepared copies of "sum".
    node, output = invoke_next(state)
    while node is not None:
        if state.prepared_source_mapping[node.id] == "sum":
            sum_values.append(output.value)
        node, output = invoke_next(state)

    assert sum_values == [0, 1, 10, 11]
|
|
|
|
|
|
def test_graph_validate_self_iterator_without_collection_input_raises_invalid_edge_error():
    """An iterator node with nothing wired to its collection input must be rejected.

    Regression guard: validation is expected to raise InvalidEdgeError rather
    than crash with an IndexError.
    """
    from invokeai.app.services.shared.graph import InvalidEdgeError

    lone_iterator = IterateInvocation(id="iterate")
    graph = Graph()
    graph.add_node(lone_iterator)

    with pytest.raises(InvalidEdgeError):
        graph.validate_self()
|
|
|
|
|
|
def test_graph_validate_self_collector_without_item_inputs_raises_invalid_edge_error():
    """A collector node with nothing wired to its item input must be rejected.

    Regression guard: validation is expected to raise InvalidEdgeError rather
    than crash (e.g. with StopIteration).
    """
    from invokeai.app.services.shared.graph import InvalidEdgeError

    lone_collector = CollectInvocation(id="collect")
    graph = Graph()
    graph.add_node(lone_collector)

    with pytest.raises(InvalidEdgeError):
        graph.validate_self()
|
|
|
|
|
|
def test_are_connection_types_compatible_accepts_subclass_to_base():
    """An output typed as a subclass must be accepted by an input typed as its base class.

    Regression guard for non-Union target types rejecting valid subclass connections.
    """
    from invokeai.app.services.shared.graph import are_connection_types_compatible

    class Base:
        pass

    class Child(Base):
        pass

    # Arguments are passed as (Child, Base), i.e. subclass first.
    compatible = are_connection_types_compatible(Child, Base)
    assert compatible is True
|