Files
InvokeAI/tests/test_graph_execution_state.py
Jonathan d4104be0b8 graph.py refactoring and If node optimization (#9030)
* test: add if-node execution coverage

* feat: short-circuit if-node branch execution

* test: cover iterated if-node pruning

* style: apply ruff fixes for if-node work

* refactor: track prepared exec node metadata

* fix: defer iterated if branches until resolution

* refactor: extract prepared exec registry

* refactor: extract if branch scheduler

* refactor: extract execution materializer

* refactor: extract execution scheduler

* refactor: extract execution runtime

* refactor: clarify if branch resolution

* refactor: clarify execution materialization

* docs: describe graph execution helpers

* refactor: clarify execution runtime

* refactor: clarify execution scheduling

* refactor: clarify iteration node selection

* docs: describe execution materializer flow

* refactor: clarify collector validation

* refactor: clarify iterator validation

* refactor: clarify graph validation flow

* docs: update shared graph design overview

* chore: typegen

* fix: harden if-node scheduler edge cases
2026-04-09 21:38:40 -04:00

767 lines
31 KiB
Python

from typing import Optional
from unittest.mock import Mock
import pytest
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from invokeai.app.invocations.collections import RangeInvocation
from invokeai.app.invocations.logic import IfInvocation, IfInvocationOutput
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
from invokeai.app.invocations.primitives import BooleanCollectionInvocation, BooleanInvocation
from invokeai.app.services.shared.graph import (
CollectInvocation,
Graph,
GraphExecutionState,
IterateInvocation,
)
# This import must happen before other invoke imports, or tests in other files(!!) break
from tests.test_nodes import (
AnyTypeTestInvocation,
PromptCollectionTestInvocation,
PromptTestInvocation,
TextToImageTestInvocation,
create_edge,
)
@pytest.fixture
def simple_graph() -> Graph:
    """Two-node graph: prompt node "1" feeds its ``prompt`` into text-to-image node "2"."""
    graph = Graph()
    for node in (
        PromptTestInvocation(id="1", prompt="Banana sushi"),
        TextToImageTestInvocation(id="2"),
    ):
        graph.add_node(node)
    graph.add_edge(create_edge("1", "prompt", "2", "prompt"))
    return graph
def invoke_next(g: GraphExecutionState) -> tuple[Optional[BaseInvocation], Optional[BaseInvocationOutput]]:
    """Run the next ready node of ``g`` against a mocked context and record its output.

    Returns ``(invocation, output)`` for the node that ran, or ``(None, None)`` when
    ``g.next()`` has nothing to hand out.
    """
    invocation = g.next()
    if invocation is not None:
        print(f"invoking {invocation.id}: {type(invocation)}")
        output = invocation.invoke(Mock(InvocationContext))
        g.complete(invocation.id, output)
        return (invocation, output)
    return (None, None)
def execute_all_nodes(g: GraphExecutionState) -> list[str]:
    """Drive ``g`` until no node is ready and return the source node id of each run, in order."""
    executed_source_ids: list[str] = []
    invocation, _output = invoke_next(g)
    while invocation is not None:
        # Map the prepared (execution-graph) id back to the id used when building the graph.
        executed_source_ids.append(g.prepared_source_mapping[invocation.id])
        invocation, _output = invoke_next(g)
    return executed_source_ids
def test_graph_state_executes_in_order(simple_graph: Graph):
    """Node "1" runs before node "2", after which nothing is left to schedule."""
    state = GraphExecutionState(graph=simple_graph)
    first_node, _ = invoke_next(state)
    second_node, _ = invoke_next(state)
    leftover = state.next()
    assert state.prepared_source_mapping[first_node.id] == "1"
    assert state.prepared_source_mapping[second_node.id] == "2"
    assert leftover is None
    # The prompt node's stored result echoes its prompt, and that prompt reached the downstream node.
    assert state.results[first_node.id].prompt == first_node.prompt
    assert second_node.prompt == first_node.prompt
def test_graph_is_complete(simple_graph: Graph):
    """After both nodes complete (and one extra poll), the state reports completion."""
    state = GraphExecutionState(graph=simple_graph)
    for _ in range(2):
        invoke_next(state)
    state.next()
    assert state.is_complete()
def test_graph_is_not_complete(simple_graph: Graph):
    """A node that was handed out but never completed keeps the state incomplete."""
    state = GraphExecutionState(graph=simple_graph)
    invoke_next(state)  # only the first node is completed
    state.next()  # the second node is handed out, but complete() is never called for it
    assert not state.is_complete()
# TODO: test completion with iterators/subgraphs
def test_graph_state_expands_iterator():
    """range(0, 3) -> iterate -> (* 10) -> (+ 1) yields {1, 11, 21} across the expanded add nodes."""
    graph = Graph()
    nodes = [
        RangeInvocation(id="0", start=0, stop=3, step=1),
        IterateInvocation(id="1"),
        MultiplyInvocation(id="2", b=10),
        AddInvocation(id="3", b=1),
    ]
    for node in nodes:
        graph.add_node(node)
    edges = [
        ("0", "collection", "1", "collection"),
        ("1", "item", "2", "a"),
        ("2", "value", "3", "a"),
    ]
    for edge in edges:
        graph.add_edge(create_edge(*edge))

    state = GraphExecutionState(graph=graph)
    while not state.is_complete():
        invoke_next(state)

    add_values = {state.results[prepared_id].value for prepared_id in state.source_prepared_mapping["3"]}
    assert add_values == {1, 11, 21}
def test_graph_state_collects():
    """Iterated prompts are gathered by the collector, which is the sixth node to execute."""
    prompts = ["Banana sushi", "Cat sushi"]
    graph = Graph()
    graph.add_node(PromptCollectionTestInvocation(id="1", collection=list(prompts)))
    graph.add_node(IterateInvocation(id="2"))
    graph.add_node(PromptTestInvocation(id="3"))
    graph.add_node(CollectInvocation(id="4"))
    for src, src_field, dst, dst_field in (
        ("1", "collection", "2", "collection"),
        ("2", "item", "3", "prompt"),
        ("3", "prompt", "4", "item"),
    ):
        graph.add_edge(create_edge(src, src_field, dst, dst_field))

    state = GraphExecutionState(graph=graph)
    # The collection node, both iterate expansions and both prompt expansions run first.
    for _ in range(5):
        invoke_next(state)
    collector, _output = invoke_next(state)

    assert isinstance(collector, CollectInvocation)
    assert sorted(state.results[collector.id].collection) == sorted(prompts)
def test_graph_state_prepares_eagerly():
    """Every node that can be prepared up front is prepared by the first ``next()`` call.

    The iterator and its dependent stay unprepared, because the upstream collection has not run yet.
    """
    graph = Graph()
    prompts = ["Banana sushi", "Cat sushi"]
    graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(prompts)))
    graph.add_node(IterateInvocation(id="iterate"))
    graph.add_node(PromptTestInvocation(id="prompt_iterated"))
    graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
    graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))

    # An independent chain with no iterator involvement, so all three links are preparable immediately.
    graph.add_node(PromptTestInvocation(id="prompt_chain_1", prompt="Dinosaur sushi"))
    graph.add_node(PromptTestInvocation(id="prompt_chain_2"))
    graph.add_node(PromptTestInvocation(id="prompt_chain_3"))
    graph.add_edge(create_edge("prompt_chain_1", "prompt", "prompt_chain_2", "prompt"))
    graph.add_edge(create_edge("prompt_chain_2", "prompt", "prompt_chain_3", "prompt"))

    state = GraphExecutionState(graph=graph)
    state.next()

    for source_id in ("prompt_collection", "prompt_chain_1", "prompt_chain_2", "prompt_chain_3"):
        assert source_id in state.source_prepared_mapping
    for source_id in ("iterate", "prompt_iterated"):
        assert source_id not in state.source_prepared_mapping
def test_graph_executes_depth_first():
    """Run an iterate -> prompt -> successor pipeline to completion and validate the run.

    Despite the name, this test does not pin a strictly depth-first order. It checks that every
    node in the materialized execution graph ran exactly once, and that for every execution-graph
    edge the source node ran before the destination node.
    """

    def assert_topo_order_and_all_executed(state: GraphExecutionState, order: list[str]):
        """
        Validates:
        1) Every materialized exec node executed exactly once.
        2) Execution order respects all exec-graph dependencies (u→v ⇒ u before v).

        ``order`` must contain execution-graph node ids in run order.
        """
        exec_nodes = set(state.execution_graph.nodes.keys())
        pos = {nid: i for i, nid in enumerate(order)}
        # Duplicates collapse in ``pos``. Check for them first so a duplicate is reported as a
        # duplicate rather than as a confusing coverage mismatch.
        assert len(pos) == len(order), "Duplicate execution detected"
        # Coverage: every exec node ran.
        assert set(pos.keys()) == exec_nodes, (
            f"Executed {len(pos)} of {len(exec_nodes)} nodes. Missing: {sorted(exec_nodes - set(pos))[:10]}"
        )
        # Topological order: parents before children.
        for e in state.execution_graph.edges:
            u = e.source.node_id
            v = e.destination.node_id
            assert pos[u] < pos[v], f"child {v} ran before parent {u}"

    graph = Graph()
    test_prompts = ["Banana sushi", "Cat sushi"]
    graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
    graph.add_node(IterateInvocation(id="iterate"))
    graph.add_node(PromptTestInvocation(id="prompt_iterated"))
    graph.add_node(PromptTestInvocation(id="prompt_successor"))
    graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
    graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
    graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt"))

    g = GraphExecutionState(graph=graph)
    order: list[str] = []
    while True:
        n = g.next()
        if n is None:
            break
        o = n.invoke(Mock(InvocationContext))
        g.complete(n.id, o)
        # Record execution-graph ids (not source ids) so they line up with execution_graph edges.
        order.append(n.id)
    assert_topo_order_and_all_executed(g, order)
def test_graph_scheduler_drains_active_class_before_switching():
    """All three nodes are ready at once. With ready order [Prompt, TextToImage], both prompt
    nodes run (in insertion order) before the scheduler switches to the image node."""
    graph = Graph()
    graph.add_node(PromptTestInvocation(id="prompt_a", prompt="a"))
    graph.add_node(PromptTestInvocation(id="prompt_b", prompt="b"))
    graph.add_node(TextToImageTestInvocation(id="image"))
    g = GraphExecutionState(graph=graph)
    g.set_ready_order([PromptTestInvocation, TextToImageTestInvocation])
    first = invoke_next(g)[0]
    second = invoke_next(g)[0]
    third = invoke_next(g)[0]
    # Guard every step so an early-exhausted scheduler fails as an assertion, not an AttributeError.
    assert first is not None
    assert second is not None
    assert third is not None
    assert g.prepared_source_mapping[first.id] == "prompt_a"
    assert g.prepared_source_mapping[second.id] == "prompt_b"
    assert g.prepared_source_mapping[third.id] == "image"
def test_graph_scheduler_skips_stale_ready_entries():
    """A completed node's id pushed back onto its class ready-queue is skipped, not re-run."""
    graph = Graph()
    for source_id, text in (("prompt_a", "a"), ("prompt_b", "b")):
        graph.add_node(PromptTestInvocation(id=source_id, prompt=text))
    state = GraphExecutionState(graph=graph)
    state.set_ready_order([PromptTestInvocation])

    already_run, _ = invoke_next(state)
    assert already_run is not None
    # Re-insert the completed node's id at the front of its queue to simulate a stale entry.
    state._queue_for(PromptTestInvocation.__name__).appendleft(already_run.id)

    following = state.next()
    assert following is not None
    assert following.id != already_run.id
    assert state.prepared_source_mapping[following.id] == "prompt_b"
def test_graph_scheduler_falls_back_to_non_priority_ready_classes():
    """A ready node whose class is missing from the ready order is still scheduled."""
    graph = Graph()
    graph.add_node(TextToImageTestInvocation(id="image"))
    state = GraphExecutionState(graph=graph)
    state.set_ready_order([PromptTestInvocation])  # TextToImage is deliberately not listed
    scheduled = state.next()
    assert scheduled is not None
    assert state.prepared_source_mapping[scheduled.id] == "image"
# Because this tests deterministic ordering, we run it multiple times
@pytest.mark.parametrize("execution_number", range(5))
def test_graph_iterate_execution_order(execution_number: int):
    """Tests that iterate nodes execution is ordered by the order of the collection"""
    graph = Graph()
    test_prompts = ["Banana sushi", "Cat sushi", "Strawberry Sushi", "Dinosaur Sushi"]
    graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
    graph.add_node(IterateInvocation(id="iterate"))
    graph.add_node(PromptTestInvocation(id="prompt_iterated"))
    graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
    graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
    g = GraphExecutionState(graph=graph)

    # The collection node is the only node ready at the start.
    invoke_next(g)
    # The next outputs must carry the collection's items, in collection order.
    for expected_item in test_prompts:
        _node, output = invoke_next(g)
        assert output is not None
        assert output.item == expected_item
    # One further step after the iterations; its result is not asserted.
    invoke_next(g)
# Because this tests deterministic ordering, we run it multiple times
@pytest.mark.parametrize("execution_number", range(5))
def test_graph_nested_iterate_execution_order(execution_number: int):
    """
    Validates best-effort in-order execution for nodes expanded under nested iterators.

    The outer iterator walks [0, 1]. For each outer item ``o`` the inner range is
    ``[o * 10, o * 10 + 2)``, so the ``sum`` node should report values in lexicographic
    (outer_index, inner_index) order, subject to readiness: 0, 1, 10, 11.
    """
    graph = Graph()
    # Outer iterator: [0, 1]
    graph.add_node(RangeInvocation(id="outer_range", start=0, stop=2, step=1))
    graph.add_node(IterateInvocation(id="outer_iter"))
    # Inner range bounds derive from the outer item: start = item * 10, stop = start + 2.
    graph.add_node(MultiplyInvocation(id="mul10", b=10))
    graph.add_node(AddInvocation(id="stop_plus2", b=2))
    # The start/stop values given here are placeholders; incoming edges overwrite both.
    graph.add_node(RangeInvocation(id="inner_range", start=0, stop=1, step=1))
    graph.add_node(IterateInvocation(id="inner_iter"))
    # Pass-through observer of each inner item (item + 0).
    graph.add_node(AddInvocation(id="sum", b=0))

    edge_specs = [
        ("outer_range", "collection", "outer_iter", "collection"),
        ("outer_iter", "item", "mul10", "a"),
        ("mul10", "value", "stop_plus2", "a"),
        ("mul10", "value", "inner_range", "start"),
        ("stop_plus2", "value", "inner_range", "stop"),
        ("inner_range", "collection", "inner_iter", "collection"),
        ("inner_iter", "item", "sum", "a"),
    ]
    for spec in edge_specs:
        graph.add_edge(create_edge(*spec))

    state = GraphExecutionState(graph=graph)
    observed: list[int] = []
    node, output = invoke_next(state)
    while node is not None:
        if state.prepared_source_mapping[node.id] == "sum":
            observed.append(output.value)
        node, output = invoke_next(state)
    assert observed == [0, 1, 10, 11]
def test_graph_validate_self_iterator_without_collection_input_raises_invalid_edge_error():
    """An iterator node with no collection input must fail validation with InvalidEdgeError.

    Regression guard: validation used to crash with an IndexError in this situation.
    """
    from invokeai.app.services.shared.graph import InvalidEdgeError

    lone_iterator_graph = Graph()
    lone_iterator_graph.add_node(IterateInvocation(id="iterate"))
    with pytest.raises(InvalidEdgeError):
        lone_iterator_graph.validate_self()
def test_graph_validate_self_collector_without_item_inputs_raises_invalid_edge_error():
    """A collector node with no item inputs must fail validation with InvalidEdgeError.

    Regression guard: validation could previously crash (e.g. with StopIteration) here.
    """
    from invokeai.app.services.shared.graph import InvalidEdgeError

    lone_collector_graph = Graph()
    lone_collector_graph.add_node(CollectInvocation(id="collect"))
    with pytest.raises(InvalidEdgeError):
        lone_collector_graph.validate_self()
def test_if_invocation_selects_true_input_value():
    """With condition=True the If node outputs its true_input."""
    node = IfInvocation(id="if", condition=True, true_input="true", false_input="false")
    assert node.invoke(Mock(InvocationContext)).value == "true"
def test_if_invocation_outputs_none_when_selected_input_is_missing():
    """condition=False selects false_input, which is left unset, so the output value is None."""
    node = IfInvocation(id="if", condition=False, true_input="true")
    result = node.invoke(Mock(InvocationContext))
    assert result.value is None
def test_if_invocation_output_allows_missing_value_on_deserialization():
    """An ``if_output`` payload without a ``value`` key still validates and yields ``None``."""
    assert IfInvocationOutput.model_validate({"type": "if_output"}).value is None
def test_if_invocation_output_connects_to_downstream_input():
    """The If node's ``value`` output can be wired into a downstream prompt input."""
    graph = Graph()
    graph.add_node(IfInvocation(id="if", condition=True, true_input="connected value", false_input="unused"))
    graph.add_node(PromptTestInvocation(id="prompt"))
    graph.add_edge(create_edge("if", "value", "prompt", "prompt"))

    state = GraphExecutionState(graph=graph)
    while not state.is_complete():
        invoke_next(state)

    prompt_ids = state.source_prepared_mapping["prompt"]
    assert len(prompt_ids) == 1
    (only_prompt_id,) = prompt_ids
    assert state.results[only_prompt_id].prompt == "connected value"
@pytest.mark.xfail(strict=True, reason="Legacy eager If-node execution should no longer occur")
def test_if_graph_current_behavior_executes_both_branches_and_shared_ancestors():
    """Describes the legacy behavior in which both If branches (and shared ancestors) ran.

    Expected to fail. ``strict=True`` turns an unexpected pass, meaning a return to eager
    execution, into a suite failure.
    """
    graph = Graph()
    graph.add_node(BooleanInvocation(id="condition", value=True))
    graph.add_node(PromptTestInvocation(id="shared", prompt="shared value"))
    for source_id in ("true_mid", "true_leaf", "false_mid", "false_leaf", "side_consumer"):
        graph.add_node(PromptTestInvocation(id=source_id))
    graph.add_node(IfInvocation(id="if"))
    graph.add_node(PromptTestInvocation(id="selected_output"))

    wiring = [
        ("condition", "value", "if", "condition"),
        ("shared", "prompt", "true_mid", "prompt"),
        ("true_mid", "prompt", "true_leaf", "prompt"),
        ("true_leaf", "prompt", "if", "true_input"),
        ("shared", "prompt", "false_mid", "prompt"),
        ("false_mid", "prompt", "false_leaf", "prompt"),
        ("false_leaf", "prompt", "if", "false_input"),
        ("shared", "prompt", "side_consumer", "prompt"),
        ("if", "value", "selected_output", "prompt"),
    ]
    for wire in wiring:
        graph.add_edge(create_edge(*wire))

    state = GraphExecutionState(graph=graph)
    ran = execute_all_nodes(state)

    assert set(ran) == {
        "condition",
        "shared",
        "true_mid",
        "true_leaf",
        "false_mid",
        "false_leaf",
        "side_consumer",
        "if",
        "selected_output",
    }
    assert ran.count("false_mid") == 1
    assert ran.count("false_leaf") == 1
    selected_id = next(iter(state.source_prepared_mapping["selected_output"]))
    assert state.results[selected_id].prompt == "shared value"
@pytest.mark.xfail(strict=True, reason="Legacy eager If-node execution should no longer occur")
def test_if_graph_current_behavior_executes_both_simple_branches():
    """Describes the legacy behavior in which both simple If branches ran.

    Expected to fail. ``strict=True`` flags any regression back to eager execution.
    """
    graph = Graph()
    graph.add_node(BooleanInvocation(id="condition", value=True))
    for source_id, text in (("true_value", "true branch"), ("false_value", "false branch")):
        graph.add_node(PromptTestInvocation(id=source_id, prompt=text))
    graph.add_node(IfInvocation(id="if"))
    graph.add_node(PromptTestInvocation(id="selected_output"))
    for wire in (
        ("condition", "value", "if", "condition"),
        ("true_value", "prompt", "if", "true_input"),
        ("false_value", "prompt", "if", "false_input"),
        ("if", "value", "selected_output", "prompt"),
    ):
        graph.add_edge(create_edge(*wire))

    state = GraphExecutionState(graph=graph)
    ran = execute_all_nodes(state)

    assert set(ran) == {"condition", "true_value", "false_value", "if", "selected_output"}
    selected_id = next(iter(state.source_prepared_mapping["selected_output"]))
    assert state.results[selected_id].prompt == "true branch"
def test_if_graph_optimized_behavior_executes_only_selected_simple_branch():
    """With condition=True, only the true-branch input of the If node executes."""
    graph = Graph()
    graph.add_node(BooleanInvocation(id="condition", value=True))
    for source_id, text in (("true_value", "true branch"), ("false_value", "false branch")):
        graph.add_node(PromptTestInvocation(id=source_id, prompt=text))
    graph.add_node(IfInvocation(id="if"))
    graph.add_node(PromptTestInvocation(id="selected_output"))
    for wire in (
        ("condition", "value", "if", "condition"),
        ("true_value", "prompt", "if", "true_input"),
        ("false_value", "prompt", "if", "false_input"),
        ("if", "value", "selected_output", "prompt"),
    ):
        graph.add_edge(create_edge(*wire))

    ran = execute_all_nodes(GraphExecutionState(graph=graph))

    assert set(ran) == {"condition", "true_value", "if", "selected_output"}
    assert "false_value" not in ran
def test_if_graph_optimized_behavior_records_skipped_branch_in_execution_history():
    """The pruned branch node still appears in ``executed_history``, exactly once."""
    graph = Graph()
    graph.add_node(BooleanInvocation(id="condition", value=True))
    for source_id, text in (("true_value", "true branch"), ("false_value", "false branch")):
        graph.add_node(PromptTestInvocation(id=source_id, prompt=text))
    graph.add_node(IfInvocation(id="if"))
    graph.add_node(PromptTestInvocation(id="selected_output"))
    for wire in (
        ("condition", "value", "if", "condition"),
        ("true_value", "prompt", "if", "true_input"),
        ("false_value", "prompt", "if", "false_input"),
        ("if", "value", "selected_output", "prompt"),
    ):
        graph.add_edge(create_edge(*wire))

    state = GraphExecutionState(graph=graph)
    execute_all_nodes(state)

    history = state.executed_history
    assert set(history) == {"condition", "true_value", "false_value", "if", "selected_output"}
    assert history.count("false_value") == 1
def test_if_graph_optimized_behavior_skips_unselected_branch_but_keeps_shared_ancestors():
    """The false branch is pruned, but ``shared`` still runs: it feeds the selected branch and ``side_consumer``."""
    graph = Graph()
    graph.add_node(BooleanInvocation(id="condition", value=True))
    graph.add_node(PromptTestInvocation(id="shared", prompt="shared value"))
    for source_id in ("true_mid", "true_leaf", "false_mid", "false_leaf", "side_consumer"):
        graph.add_node(PromptTestInvocation(id=source_id))
    graph.add_node(IfInvocation(id="if"))
    graph.add_node(PromptTestInvocation(id="selected_output"))

    wiring = [
        ("condition", "value", "if", "condition"),
        ("shared", "prompt", "true_mid", "prompt"),
        ("true_mid", "prompt", "true_leaf", "prompt"),
        ("true_leaf", "prompt", "if", "true_input"),
        ("shared", "prompt", "false_mid", "prompt"),
        ("false_mid", "prompt", "false_leaf", "prompt"),
        ("false_leaf", "prompt", "if", "false_input"),
        ("shared", "prompt", "side_consumer", "prompt"),
        ("if", "value", "selected_output", "prompt"),
    ]
    for wire in wiring:
        graph.add_edge(create_edge(*wire))

    ran = execute_all_nodes(GraphExecutionState(graph=graph))

    assert set(ran) == {
        "condition",
        "shared",
        "true_mid",
        "true_leaf",
        "side_consumer",
        "if",
        "selected_output",
    }
    for pruned in ("false_mid", "false_leaf"):
        assert pruned not in ran
def test_if_graph_optimized_behavior_skips_distant_unselected_ancestors_only_when_exclusive():
    """With condition=False, the true branch is pruned only where it is exclusive.

    ``true_shared_mid`` still runs because ``shared_observer`` consumes it, while
    ``true_exclusive_leaf`` (used only by the If node) is skipped.
    """
    graph = Graph()
    graph.add_node(BooleanInvocation(id="condition", value=False))
    graph.add_node(PromptTestInvocation(id="shared_root", prompt="shared value"))
    for source_id in ("true_shared_mid", "true_exclusive_leaf", "false_mid", "false_leaf", "shared_observer"):
        graph.add_node(PromptTestInvocation(id=source_id))
    graph.add_node(IfInvocation(id="if"))
    graph.add_node(PromptTestInvocation(id="selected_output"))

    wiring = [
        ("condition", "value", "if", "condition"),
        ("shared_root", "prompt", "true_shared_mid", "prompt"),
        ("true_shared_mid", "prompt", "true_exclusive_leaf", "prompt"),
        ("true_exclusive_leaf", "prompt", "if", "true_input"),
        ("shared_root", "prompt", "false_mid", "prompt"),
        ("false_mid", "prompt", "false_leaf", "prompt"),
        ("false_leaf", "prompt", "if", "false_input"),
        ("true_shared_mid", "prompt", "shared_observer", "prompt"),
        ("if", "value", "selected_output", "prompt"),
    ]
    for wire in wiring:
        graph.add_edge(create_edge(*wire))

    ran = execute_all_nodes(GraphExecutionState(graph=graph))

    assert set(ran) == {
        "condition",
        "shared_root",
        "true_shared_mid",
        "false_mid",
        "false_leaf",
        "shared_observer",
        "if",
        "selected_output",
    }
    assert "true_exclusive_leaf" not in ran
def test_if_graph_optimized_behavior_allows_selected_missing_branch_input():
    """The selected (false) branch has no input wired, so the If emits None and the true branch is skipped."""
    graph = Graph()
    graph.add_node(BooleanInvocation(id="condition", value=False))
    graph.add_node(PromptTestInvocation(id="true_value", prompt="true branch"))
    graph.add_node(IfInvocation(id="if"))
    graph.add_node(AnyTypeTestInvocation(id="selected_output"))
    for wire in (
        ("condition", "value", "if", "condition"),
        ("true_value", "prompt", "if", "true_input"),
        ("if", "value", "selected_output", "value"),
    ):
        graph.add_edge(create_edge(*wire))

    state = GraphExecutionState(graph=graph)
    ran = execute_all_nodes(state)

    selected_id = next(iter(state.source_prepared_mapping["selected_output"]))
    assert state.results[selected_id].value is None
    assert set(ran) == {"condition", "if", "selected_output"}
    assert "true_value" not in ran
def test_if_graph_optimized_behavior_does_not_cross_defer_independent_ifs():
    """Two unrelated If nodes each prune only their own unselected branch."""
    graph = Graph()
    graph.add_node(BooleanInvocation(id="condition_a", value=True))
    graph.add_node(BooleanInvocation(id="condition_b", value=False))
    for source_id, text in (
        ("true_a", "true a"),
        ("false_a", "false a"),
        ("true_b", "true b"),
        ("false_b", "false b"),
    ):
        graph.add_node(PromptTestInvocation(id=source_id, prompt=text))
    graph.add_node(IfInvocation(id="if_a"))
    graph.add_node(IfInvocation(id="if_b"))
    graph.add_node(CollectInvocation(id="collect"))

    # Wire each If node to its own condition and branches, then fan both into the collector.
    for suffix in ("a", "b"):
        graph.add_edge(create_edge(f"condition_{suffix}", "value", f"if_{suffix}", "condition"))
        graph.add_edge(create_edge(f"true_{suffix}", "prompt", f"if_{suffix}", "true_input"))
        graph.add_edge(create_edge(f"false_{suffix}", "prompt", f"if_{suffix}", "false_input"))
    for suffix in ("a", "b"):
        graph.add_edge(create_edge(f"if_{suffix}", "value", "collect", "item"))

    state = GraphExecutionState(graph=graph)
    ran = execute_all_nodes(state)

    collect_id = next(iter(state.source_prepared_mapping["collect"]))
    assert sorted(state.results[collect_id].collection) == ["false b", "true a"]
    assert set(ran) == {
        "condition_a",
        "condition_b",
        "true_a",
        "false_b",
        "if_a",
        "if_b",
        "collect",
    }
    for pruned in ("false_a", "true_b"):
        assert pruned not in ran
def test_if_graph_optimized_behavior_supports_nested_ifs():
    """An inner If feeding an outer If's true_input: only the chosen leaf along the selected path runs."""
    graph = Graph()
    graph.add_node(BooleanInvocation(id="outer_condition", value=True))
    graph.add_node(BooleanInvocation(id="inner_condition", value=False))
    for source_id, text in (
        ("outer_false", "outer false"),
        ("inner_true", "inner true"),
        ("inner_false", "inner false"),
    ):
        graph.add_node(PromptTestInvocation(id=source_id, prompt=text))
    graph.add_node(IfInvocation(id="inner_if"))
    graph.add_node(IfInvocation(id="outer_if"))
    graph.add_node(PromptTestInvocation(id="selected_output"))

    wiring = [
        ("inner_condition", "value", "inner_if", "condition"),
        ("inner_true", "prompt", "inner_if", "true_input"),
        ("inner_false", "prompt", "inner_if", "false_input"),
        ("outer_condition", "value", "outer_if", "condition"),
        ("inner_if", "value", "outer_if", "true_input"),
        ("outer_false", "prompt", "outer_if", "false_input"),
        ("outer_if", "value", "selected_output", "prompt"),
    ]
    for wire in wiring:
        graph.add_edge(create_edge(*wire))

    state = GraphExecutionState(graph=graph)
    ran = execute_all_nodes(state)

    selected_id = next(iter(state.source_prepared_mapping["selected_output"]))
    assert state.results[selected_id].prompt == "inner false"
    assert set(ran) == {
        "outer_condition",
        "inner_condition",
        "inner_false",
        "inner_if",
        "outer_if",
        "selected_output",
    }
    for pruned in ("inner_true", "outer_false"):
        assert pruned not in ran
def test_if_graph_optimized_behavior_prunes_branches_per_iteration():
    """Branch pruning is decided separately for each iteration's condition ([True, False, True])."""
    graph = Graph()
    graph.add_node(BooleanCollectionInvocation(id="conditions", collection=[True, False, True]))
    graph.add_node(IterateInvocation(id="condition_iter"))
    graph.add_node(AnyTypeTestInvocation(id="true_branch"))
    graph.add_node(AnyTypeTestInvocation(id="false_branch"))
    graph.add_node(IfInvocation(id="if"))
    graph.add_node(CollectInvocation(id="collect"))

    wiring = [
        ("conditions", "collection", "condition_iter", "collection"),
        ("condition_iter", "item", "if", "condition"),
        ("condition_iter", "item", "true_branch", "value"),
        ("true_branch", "value", "if", "true_input"),
        ("condition_iter", "item", "false_branch", "value"),
        ("false_branch", "value", "if", "false_input"),
        ("if", "value", "collect", "item"),
    ]
    for wire in wiring:
        graph.add_edge(create_edge(*wire))

    state = GraphExecutionState(graph=graph)
    ran = execute_all_nodes(state)

    collect_id = next(iter(state.source_prepared_mapping["collect"]))
    assert state.results[collect_id].collection == [True, False, True]
    expected_counts = {"condition_iter": 3, "true_branch": 2, "false_branch": 1, "if": 3}
    for source_id, expected in expected_counts.items():
        assert ran.count(source_id) == expected
def test_if_graph_optimized_behavior_keeps_shared_live_consumers_per_iteration():
    """``shared_branch`` runs on every iteration because ``observer`` consumes it, even when the true branch is pruned."""
    graph = Graph()
    graph.add_node(BooleanCollectionInvocation(id="conditions", collection=[True, False, False]))
    graph.add_node(IterateInvocation(id="condition_iter"))
    for source_id in ("shared_branch", "true_leaf", "false_branch", "observer"):
        graph.add_node(AnyTypeTestInvocation(id=source_id))
    graph.add_node(IfInvocation(id="if"))
    graph.add_node(CollectInvocation(id="selected_collect"))
    graph.add_node(CollectInvocation(id="observer_collect"))

    wiring = [
        ("conditions", "collection", "condition_iter", "collection"),
        ("condition_iter", "item", "if", "condition"),
        ("condition_iter", "item", "shared_branch", "value"),
        ("shared_branch", "value", "true_leaf", "value"),
        ("true_leaf", "value", "if", "true_input"),
        ("condition_iter", "item", "false_branch", "value"),
        ("false_branch", "value", "if", "false_input"),
        ("shared_branch", "value", "observer", "value"),
        ("if", "value", "selected_collect", "item"),
        ("observer", "value", "observer_collect", "item"),
    ]
    for wire in wiring:
        graph.add_edge(create_edge(*wire))

    state = GraphExecutionState(graph=graph)
    ran = execute_all_nodes(state)

    for collector_id in ("selected_collect", "observer_collect"):
        prepared_id = next(iter(state.source_prepared_mapping[collector_id]))
        assert state.results[prepared_id].collection == [True, False, False]
    expected_counts = {
        "condition_iter": 3,
        "shared_branch": 3,
        "observer": 3,
        "true_leaf": 1,
        "false_branch": 2,
    }
    for source_id, expected in expected_counts.items():
        assert ran.count(source_id) == expected
def test_are_connection_types_compatible_accepts_subclass_to_base():
    """An output typed as a subclass may connect to an input typed as its base class.

    Regression guard: non-Union targets used to reject such valid subclass connections.
    """
    from invokeai.app.services.shared.graph import are_connection_types_compatible

    class Base:
        pass

    class Child(Base):
        pass

    compatible = are_connection_types_compatible(Child, Base)
    assert compatible is True