mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
tests(app): add failing test for collector edge case
squash squash
This commit is contained in:
@@ -20,9 +20,11 @@ from invokeai.app.invocations.primitives import (
|
||||
from invokeai.app.invocations.upscale import ESRGANInvocation
|
||||
from invokeai.app.services.shared.graph import (
|
||||
CollectInvocation,
|
||||
CollectInvocationOutput,
|
||||
Edge,
|
||||
EdgeConnection,
|
||||
Graph,
|
||||
GraphExecutionState,
|
||||
InvalidEdgeError,
|
||||
IterateInvocation,
|
||||
NodeAlreadyInGraphError,
|
||||
@@ -35,9 +37,12 @@ from tests.test_nodes import (
|
||||
ListPassThroughInvocation,
|
||||
PolymorphicStringTestInvocation,
|
||||
PromptCollectionTestInvocation,
|
||||
PromptCollectionTestInvocationOutput,
|
||||
PromptTestInvocation,
|
||||
PromptTestInvocationOutput,
|
||||
TextToImageTestInvocation,
|
||||
get_single_output_from_session,
|
||||
run_session_with_mock_context,
|
||||
)
|
||||
|
||||
|
||||
@@ -758,3 +763,26 @@ def test_nodes_must_return_invocation_output():
|
||||
class NoOutputInvocation(BaseInvocation):
|
||||
def invoke(self) -> str:
|
||||
return "foo"
|
||||
|
||||
def test_collector_different_incomers():
|
||||
"""Tests an edge case where a collector has incoming edges from invocations with differently-named output fields."""
|
||||
g = Graph()
|
||||
# This node has a str type output field named "prompt"
|
||||
n1 = PromptTestInvocation(id="1", prompt="Banana")
|
||||
# This node has a str type output field named "value"
|
||||
n2 = StringInvocation(id="2", value="Sushi")
|
||||
n3 = CollectInvocation(id="3")
|
||||
g.add_node(n1)
|
||||
g.add_node(n2)
|
||||
g.add_node(n3)
|
||||
e1 = create_edge(n1.id, "prompt", n3.id, "item")
|
||||
e2 = create_edge(n2.id, "value", n3.id, "item")
|
||||
g.add_edge(e1)
|
||||
g.add_edge(e2)
|
||||
session = GraphExecutionState(graph=g)
|
||||
# The bug resulted in an error like this when calling session.next():
|
||||
# Field types are incompatible (a0f9797b-1179-4200-81ae-6ef981660163.prompt -> ccc6af96-2a65-4bbe-a02f-4189bb4770ac.item)
|
||||
run_session_with_mock_context(session)
|
||||
output= get_single_output_from_session(session, n3.id)
|
||||
assert isinstance(output, CollectInvocationOutput)
|
||||
assert output.collection == ["Banana", "Sushi"] # Both inputs should be collected
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Any, Callable, Union
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
@@ -108,7 +109,7 @@ class PolymorphicStringTestInvocation(BaseInvocation):
|
||||
|
||||
# Importing these must happen after test invocations are defined or they won't register
|
||||
from invokeai.app.services.events.events_base import EventServiceBase # noqa: E402
|
||||
from invokeai.app.services.shared.graph import Edge, EdgeConnection # noqa: E402
|
||||
from invokeai.app.services.shared.graph import Edge, EdgeConnection, GraphExecutionState # noqa: E402
|
||||
|
||||
|
||||
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge:
|
||||
@@ -155,3 +156,27 @@ def wait_until(condition: Callable[[], bool], timeout: int = 10, interval: float
|
||||
return
|
||||
time.sleep(interval)
|
||||
raise TimeoutError("Condition not met")
|
||||
|
||||
|
||||
def run_session_with_mock_context(session: GraphExecutionState):
|
||||
"""Run the session with a mock context to simulate invocation execution.
|
||||
|
||||
The graph may only contain invocations that operate on primitive types. Images, models, or any other types that
|
||||
require a real context cannot be used in this mock execution.
|
||||
"""
|
||||
mock_context = MagicMock(spec=InvocationContext)
|
||||
invocation = session.next()
|
||||
while invocation is not None:
|
||||
output = invocation.invoke(mock_context)
|
||||
session.complete(invocation.id, output)
|
||||
invocation = session.next()
|
||||
|
||||
|
||||
def get_single_output_from_session(session: GraphExecutionState, node_id: str) -> BaseInvocationOutput:
|
||||
assert len(session.source_prepared_mapping[node_id]) == 1, (
|
||||
"Expected exactly one prepared node for the given node_id"
|
||||
)
|
||||
prepared_node_id = session.source_prepared_mapping[node_id].pop()
|
||||
output = session.results[prepared_node_id]
|
||||
assert isinstance(output, BaseInvocationOutput), "Expected output to be of type BaseInvocationOutput"
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user