fix(rnd): Fix broken test and Input/Output block field renaming

This commit is contained in:
Zamil Majdy
2024-08-28 22:36:53 -05:00
parent c5615aa862
commit 98c909f99f
5 changed files with 76 additions and 50 deletions

View File

@@ -75,33 +75,22 @@ class PrintingBlock(Block):
yield "status", "printed"
T = TypeVar("T")
class ObjectLookupBaseInput(BlockSchema, Generic[T]):
input: T = Field(description="Dictionary to lookup from")
key: str | int = Field(description="Key to lookup in the dictionary")
class ObjectLookupBaseOutput(BlockSchema, Generic[T]):
output: T = Field(description="Value found for the given key")
missing: T = Field(description="Value of the input that missing the key")
class ObjectLookupBase(Block, ABC, Generic[T]):
@abstractmethod
def block_id(self) -> str:
pass
def __init__(self, *args, **kwargs):
input_schema = ObjectLookupBaseInput[T]
output_schema = ObjectLookupBaseOutput[T]
class ObjectLookupBlock(Block):
class Input(BlockSchema):
input: Any = Field(description="Dictionary to lookup from")
key: str | int = Field(description="Key to lookup in the dictionary")
class Output(BlockSchema):
output: Any = Field(description="Value found for the given key")
missing: Any = Field(description="Value of the input that missing the key")
def __init__(self):
super().__init__(
id=self.block_id(),
id="b2g2c3d4-5e6f-7g8h-9i0j-k1l2m3n4o5p6",
description="Lookup the given key in the input dictionary/object/list and return the value.",
input_schema=input_schema,
output_schema=output_schema,
input_schema=ObjectLookupBlock.Input,
output_schema=ObjectLookupBlock.Output,
test_input=[
{"input": {"apple": 1, "banana": 2, "cherry": 3}, "key": "banana"},
{"input": {"x": 10, "y": 20, "z": 30}, "key": "w"},
@@ -118,11 +107,10 @@ class ObjectLookupBase(Block, ABC, Generic[T]):
("output", "key"),
("output", ["v1", "v3"]),
],
*args,
**kwargs,
categories={BlockCategory.BASIC}
)
def run(self, input_data: ObjectLookupBaseInput[T]) -> BlockOutput:
def run(self, input_data: Input) -> BlockOutput:
obj = input_data.input
key = input_data.key
@@ -143,15 +131,50 @@ class ObjectLookupBase(Block, ABC, Generic[T]):
yield "missing", input_data.input
class ObjectLookupBlock(ObjectLookupBase[Any]):
def __init__(self):
super().__init__(categories={BlockCategory.BASIC})
T = TypeVar("T")
class InputOutputBlockInput(BlockSchema, Generic[T]):
value: T = Field(description="The value to be passed as input/output.")
name: str = Field(description="The name of the input/output.")
class InputOutputBlockOutput(BlockSchema, Generic[T]):
value: T = Field(description="The value passed as input/output.")
class InputOutputBlockBase(Block, ABC, Generic[T]):
@abstractmethod
def block_id(self) -> str:
return "b2g2c3d4-5e6f-7g8h-9i0j-k1l2m3n4o5p6"
pass
def __init__(self, *args, **kwargs):
input_schema = InputOutputBlockInput[T]
output_schema = InputOutputBlockOutput[T]
super().__init__(
id=self.block_id(),
description="This block is used to define the input & output of a graph.",
input_schema=input_schema,
output_schema=output_schema,
test_input=[
{"value": {"apple": 1, "banana": 2, "cherry": 3}, "name": "input_1"},
{"value": MockObject(value="!!", key="key"), "name": "input_2"},
],
test_output=[
("value", {"apple": 1, "banana": 2, "cherry": 3}),
("value", MockObject(value="!!", key="key")),
],
static_output=True,
*args,
**kwargs,
)
def run(self, input_data: InputOutputBlockInput[T]) -> BlockOutput:
yield "value", input_data.value
class InputBlock(ObjectLookupBase[Any]):
class InputBlock(InputOutputBlockBase[Any]):
def __init__(self):
super().__init__(categories={BlockCategory.INPUT, BlockCategory.BASIC})
@@ -159,7 +182,7 @@ class InputBlock(ObjectLookupBase[Any]):
return "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
class OutputBlock(ObjectLookupBase[Any]):
class OutputBlock(InputOutputBlockBase[Any]):
def __init__(self):
super().__init__(categories={BlockCategory.OUTPUT, BlockCategory.BASIC})

View File

@@ -1,6 +1,6 @@
import time
from datetime import datetime, timedelta
from typing import Union
from typing import Any, Union
from autogpt_server.data.block import Block, BlockCategory, BlockOutput, BlockSchema
@@ -103,6 +103,7 @@ class CurrentDateAndTimeBlock(Block):
class TimerBlock(Block):
class Input(BlockSchema):
message: Any = "timer finished"
seconds: Union[int, str] = 0
minutes: Union[int, str] = 0
hours: Union[int, str] = 0
@@ -120,9 +121,11 @@ class TimerBlock(Block):
output_schema=TimerBlock.Output,
test_input=[
{"seconds": 1},
{"message": "Custom message"},
],
test_output=[
("message", "timer finished"),
("message", "Custom message"),
],
)
@@ -136,4 +139,4 @@ class TimerBlock(Block):
total_seconds = seconds + minutes * 60 + hours * 3600 + days * 86400
time.sleep(total_seconds)
yield "message", "timer finished"
yield "message", input_data.message

View File

@@ -436,11 +436,12 @@ class ExecutionManager(AppService):
nodes_input = []
for node in graph.starting_nodes:
input_data = {}
if isinstance(get_block(node.block_id), InputBlock):
input_data = {"input": data}
else:
input_data = {}
name = node.input_default.get("name")
if name and name in data:
input_data = {"value": data[name]}
input_data, error = validate_exec(node, input_data)
if input_data is None:
raise Exception(error)

View File

@@ -29,11 +29,11 @@ def create_test_graph() -> graph.Graph:
nodes = [
graph.Node(
block_id=InputBlock().id,
input_default={"key": "input_1"},
input_default={"name": "input_1"},
),
graph.Node(
block_id=InputBlock().id,
input_default={"key": "input_2"},
input_default={"name": "input_2"},
),
graph.Node(
block_id=TextFormatterBlock().id,
@@ -48,13 +48,13 @@ def create_test_graph() -> graph.Graph:
graph.Link(
source_id=nodes[0].id,
sink_id=nodes[2].id,
source_name="output",
source_name="value",
sink_name="values_#_a",
),
graph.Link(
source_id=nodes[1].id,
sink_id=nodes[2].id,
source_name="output",
source_name="value",
sink_name="values_#_b",
),
graph.Link(

View File

@@ -35,7 +35,6 @@ async def assert_sample_graph_executions(
test_user: User,
graph_exec_id: str,
):
input = {"input_1": "Hello", "input_2": "World"}
executions = await agent_server.get_run_execution_results(
test_graph.id, graph_exec_id, test_user.id
)
@@ -44,16 +43,16 @@ async def assert_sample_graph_executions(
exec = executions[0]
assert exec.status == execution.ExecutionStatus.COMPLETED
assert exec.graph_exec_id == graph_exec_id
assert exec.output_data == {"output": ["Hello"]}
assert exec.input_data == {"input": input, "key": "input_1"}
assert exec.output_data == {"value": ["Hello"]}
assert exec.input_data == {"value": "Hello", "name": "input_1"}
assert exec.node_id in [test_graph.nodes[0].id, test_graph.nodes[1].id]
# Executing ValueBlock
exec = executions[1]
assert exec.status == execution.ExecutionStatus.COMPLETED
assert exec.graph_exec_id == graph_exec_id
assert exec.output_data == {"output": ["World"]}
assert exec.input_data == {"input": input, "key": "input_2"}
assert exec.output_data == {"value": ["World"]}
assert exec.input_data == {"value": "World", "name": "input_2"}
assert exec.node_id in [test_graph.nodes[0].id, test_graph.nodes[1].id]
# Executing TextFormatterBlock
@@ -151,7 +150,7 @@ async def test_input_pin_always_waited(server):
server.agent_server, server.exec_manager, test_graph, test_user, {}, 3
)
executions = await server.agent_server.get_run_execution_results(
executions = await server.agent_server.get_graph_run_node_execution_results(
test_graph.id, graph_exec_id, test_user.id
)
assert len(executions) == 3
@@ -231,7 +230,7 @@ async def test_static_input_link_on_graph(server):
graph_exec_id = await execute_graph(
server.agent_server, server.exec_manager, test_graph, test_user, {}, 8
)
executions = await server.agent_server.get_run_execution_results(
executions = await server.agent_server.get_graph_run_node_execution_results(
test_graph.id, graph_exec_id, test_user.id
)
assert len(executions) == 8