From baae998b5b690ca0878c44cbb383f11737ee7419 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 30 Jun 2025 17:03:12 +1000 Subject: [PATCH] tests(app): add failing test for collector edge case squash squash --- tests/test_node_graph.py | 28 ++++++++++++++++++++++++++++ tests/test_nodes.py | 27 ++++++++++++++++++++++++++- 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/tests/test_node_graph.py b/tests/test_node_graph.py index 0a4ce77538..5ba07ff7f0 100644 --- a/tests/test_node_graph.py +++ b/tests/test_node_graph.py @@ -20,9 +20,11 @@ from invokeai.app.invocations.primitives import ( from invokeai.app.invocations.upscale import ESRGANInvocation from invokeai.app.services.shared.graph import ( CollectInvocation, + CollectInvocationOutput, Edge, EdgeConnection, Graph, + GraphExecutionState, InvalidEdgeError, IterateInvocation, NodeAlreadyInGraphError, @@ -35,9 +37,12 @@ from tests.test_nodes import ( ListPassThroughInvocation, PolymorphicStringTestInvocation, PromptCollectionTestInvocation, + PromptCollectionTestInvocationOutput, PromptTestInvocation, PromptTestInvocationOutput, TextToImageTestInvocation, + get_single_output_from_session, + run_session_with_mock_context, ) @@ -758,3 +763,26 @@ def test_nodes_must_return_invocation_output(): class NoOutputInvocation(BaseInvocation): def invoke(self) -> str: return "foo" + +def test_collector_different_incomers(): + """Tests an edge case where a collector has incoming edges from invocations with differently-named output fields.""" + g = Graph() + # This node has a str type output field named "prompt" + n1 = PromptTestInvocation(id="1", prompt="Banana") + # This node has a str type output field named "value" + n2 = StringInvocation(id="2", value="Sushi") + n3 = CollectInvocation(id="3") + g.add_node(n1) + g.add_node(n2) + g.add_node(n3) + e1 = create_edge(n1.id, "prompt", n3.id, "item") + e2 = create_edge(n2.id, "value", n3.id, "item") + g.add_edge(e1) + g.add_edge(e2) + session = GraphExecutionState(graph=g) + # The bug resulted in an error like this when calling session.next(): + # Field types are incompatible (a0f9797b-1179-4200-81ae-6ef981660163.prompt -> ccc6af96-2a65-4bbe-a02f-4189bb4770ac.item) + run_session_with_mock_context(session) + output= get_single_output_from_session(session, n3.id) + assert isinstance(output, CollectInvocationOutput) + assert output.collection == ["Banana", "Sushi"] # Both inputs should be collected diff --git a/tests/test_nodes.py b/tests/test_nodes.py index ad354fde35..04ea5126f0 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -1,4 +1,5 @@ from typing import Any, Callable, Union +from unittest.mock import MagicMock from invokeai.app.invocations.baseinvocation import ( BaseInvocation, @@ -108,7 +109,7 @@ class PolymorphicStringTestInvocation(BaseInvocation): # Importing these must happen after test invocations are defined or they won't register from invokeai.app.services.events.events_base import EventServiceBase # noqa: E402 -from invokeai.app.services.shared.graph import Edge, EdgeConnection # noqa: E402 +from invokeai.app.services.shared.graph import Edge, EdgeConnection, GraphExecutionState # noqa: E402 def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge: @@ -155,3 +156,27 @@ def wait_until(condition: Callable[[], bool], timeout: int = 10, interval: float return time.sleep(interval) raise TimeoutError("Condition not met") + + +def run_session_with_mock_context(session: GraphExecutionState): + """Run the session with a mock context to simulate invocation execution. + + The graph may only contain invocations that operate on primitive types. Images, models, or any other types that + require a real context cannot be used in this mock execution. + """ + mock_context = MagicMock(spec=InvocationContext) + invocation = session.next() + while invocation is not None: + output = invocation.invoke(mock_context) + session.complete(invocation.id, output) + invocation = session.next() + + +def get_single_output_from_session(session: GraphExecutionState, node_id: str) -> BaseInvocationOutput: + assert len(session.source_prepared_mapping[node_id]) == 1, ( + "Expected exactly one prepared node for the given node_id" + ) + prepared_node_id = session.source_prepared_mapping[node_id].pop() + output = session.results[prepared_node_id] + assert isinstance(output, BaseInvocationOutput), "Expected output to be of type BaseInvocationOutput" + return output