tests(app): add failing test for collector edge case

squash

squash
This commit is contained in:
psychedelicious
2025-06-30 17:03:12 +10:00
parent 4077ffe595
commit baae998b5b
2 changed files with 54 additions and 1 deletions

View File

@@ -20,9 +20,11 @@ from invokeai.app.invocations.primitives import (
from invokeai.app.invocations.upscale import ESRGANInvocation
from invokeai.app.services.shared.graph import (
CollectInvocation,
CollectInvocationOutput,
Edge,
EdgeConnection,
Graph,
GraphExecutionState,
InvalidEdgeError,
IterateInvocation,
NodeAlreadyInGraphError,
@@ -35,9 +37,12 @@ from tests.test_nodes import (
ListPassThroughInvocation,
PolymorphicStringTestInvocation,
PromptCollectionTestInvocation,
PromptCollectionTestInvocationOutput,
PromptTestInvocation,
PromptTestInvocationOutput,
TextToImageTestInvocation,
get_single_output_from_session,
run_session_with_mock_context,
)
@@ -758,3 +763,26 @@ def test_nodes_must_return_invocation_output():
class NoOutputInvocation(BaseInvocation):
def invoke(self) -> str:
return "foo"
def test_collector_different_incomers():
"""Tests an edge case where a collector has incoming edges from invocations with differently-named output fields."""
g = Graph()
# This node has a str type output field named "prompt"
n1 = PromptTestInvocation(id="1", prompt="Banana")
# This node has a str type output field named "value"
n2 = StringInvocation(id="2", value="Sushi")
n3 = CollectInvocation(id="3")
g.add_node(n1)
g.add_node(n2)
g.add_node(n3)
e1 = create_edge(n1.id, "prompt", n3.id, "item")
e2 = create_edge(n2.id, "value", n3.id, "item")
g.add_edge(e1)
g.add_edge(e2)
session = GraphExecutionState(graph=g)
# The bug resulted in an error like this when calling session.next():
# Field types are incompatible (a0f9797b-1179-4200-81ae-6ef981660163.prompt -> ccc6af96-2a65-4bbe-a02f-4189bb4770ac.item)
run_session_with_mock_context(session)
output= get_single_output_from_session(session, n3.id)
assert isinstance(output, CollectInvocationOutput)
assert output.collection == ["Banana", "Sushi"] # Both inputs should be collected

View File

@@ -1,4 +1,5 @@
from typing import Any, Callable, Union
from unittest.mock import MagicMock
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
@@ -108,7 +109,7 @@ class PolymorphicStringTestInvocation(BaseInvocation):
# Importing these must happen after test invocations are defined or they won't register
from invokeai.app.services.events.events_base import EventServiceBase # noqa: E402
from invokeai.app.services.shared.graph import Edge, EdgeConnection # noqa: E402
from invokeai.app.services.shared.graph import Edge, EdgeConnection, GraphExecutionState # noqa: E402
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge:
@@ -155,3 +156,27 @@ def wait_until(condition: Callable[[], bool], timeout: int = 10, interval: float
return
time.sleep(interval)
raise TimeoutError("Condition not met")
def run_session_with_mock_context(session: GraphExecutionState):
"""Run the session with a mock context to simulate invocation execution.
The graph may only contain invocations that operate on primitive types. Images, models, or any other types that
require a real context cannot be used in this mock execution.
"""
mock_context = MagicMock(spec=InvocationContext)
invocation = session.next()
while invocation is not None:
output = invocation.invoke(mock_context)
session.complete(invocation.id, output)
invocation = session.next()
def get_single_output_from_session(session: GraphExecutionState, node_id: str) -> BaseInvocationOutput:
assert len(session.source_prepared_mapping[node_id]) == 1, (
"Expected exactly one prepared node for the given node_id"
)
prepared_node_id = session.source_prepared_mapping[node_id].pop()
output = session.results[prepared_node_id]
assert isinstance(output, BaseInvocationOutput), "Expected output to be of type BaseInvocationOutput"
return output