mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(rnd): Fix broken test and Input/Output block field renaming
This commit is contained in:
@@ -75,33 +75,22 @@ class PrintingBlock(Block):
|
||||
yield "status", "printed"
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ObjectLookupBaseInput(BlockSchema, Generic[T]):
|
||||
input: T = Field(description="Dictionary to lookup from")
|
||||
key: str | int = Field(description="Key to lookup in the dictionary")
|
||||
|
||||
|
||||
class ObjectLookupBaseOutput(BlockSchema, Generic[T]):
|
||||
output: T = Field(description="Value found for the given key")
|
||||
missing: T = Field(description="Value of the input that missing the key")
|
||||
|
||||
|
||||
class ObjectLookupBase(Block, ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def block_id(self) -> str:
|
||||
pass
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
input_schema = ObjectLookupBaseInput[T]
|
||||
output_schema = ObjectLookupBaseOutput[T]
|
||||
class ObjectLookupBlock(Block):
|
||||
|
||||
class Input(BlockSchema):
|
||||
input: Any = Field(description="Dictionary to lookup from")
|
||||
key: str | int = Field(description="Key to lookup in the dictionary")
|
||||
|
||||
class Output(BlockSchema):
|
||||
output: Any = Field(description="Value found for the given key")
|
||||
missing: Any = Field(description="Value of the input that missing the key")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id=self.block_id(),
|
||||
id="b2g2c3d4-5e6f-7g8h-9i0j-k1l2m3n4o5p6",
|
||||
description="Lookup the given key in the input dictionary/object/list and return the value.",
|
||||
input_schema=input_schema,
|
||||
output_schema=output_schema,
|
||||
input_schema=ObjectLookupBlock.Input,
|
||||
output_schema=ObjectLookupBlock.Output,
|
||||
test_input=[
|
||||
{"input": {"apple": 1, "banana": 2, "cherry": 3}, "key": "banana"},
|
||||
{"input": {"x": 10, "y": 20, "z": 30}, "key": "w"},
|
||||
@@ -118,11 +107,10 @@ class ObjectLookupBase(Block, ABC, Generic[T]):
|
||||
("output", "key"),
|
||||
("output", ["v1", "v3"]),
|
||||
],
|
||||
*args,
|
||||
**kwargs,
|
||||
categories={BlockCategory.BASIC}
|
||||
)
|
||||
|
||||
def run(self, input_data: ObjectLookupBaseInput[T]) -> BlockOutput:
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
obj = input_data.input
|
||||
key = input_data.key
|
||||
|
||||
@@ -143,15 +131,50 @@ class ObjectLookupBase(Block, ABC, Generic[T]):
|
||||
yield "missing", input_data.input
|
||||
|
||||
|
||||
class ObjectLookupBlock(ObjectLookupBase[Any]):
|
||||
def __init__(self):
|
||||
super().__init__(categories={BlockCategory.BASIC})
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class InputOutputBlockInput(BlockSchema, Generic[T]):
|
||||
value: T = Field(description="The value to be passed as input/output.")
|
||||
name: str = Field(description="The name of the input/output.")
|
||||
|
||||
|
||||
class InputOutputBlockOutput(BlockSchema, Generic[T]):
|
||||
value: T = Field(description="The value passed as input/output.")
|
||||
|
||||
|
||||
class InputOutputBlockBase(Block, ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def block_id(self) -> str:
|
||||
return "b2g2c3d4-5e6f-7g8h-9i0j-k1l2m3n4o5p6"
|
||||
pass
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
input_schema = InputOutputBlockInput[T]
|
||||
output_schema = InputOutputBlockOutput[T]
|
||||
|
||||
super().__init__(
|
||||
id=self.block_id(),
|
||||
description="This block is used to define the input & output of a graph.",
|
||||
input_schema=input_schema,
|
||||
output_schema=output_schema,
|
||||
test_input=[
|
||||
{"value": {"apple": 1, "banana": 2, "cherry": 3}, "name": "input_1"},
|
||||
{"value": MockObject(value="!!", key="key"), "name": "input_2"},
|
||||
],
|
||||
test_output=[
|
||||
("value", {"apple": 1, "banana": 2, "cherry": 3}),
|
||||
("value", MockObject(value="!!", key="key")),
|
||||
],
|
||||
static_output=True,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def run(self, input_data: InputOutputBlockInput[T]) -> BlockOutput:
|
||||
yield "value", input_data.value
|
||||
|
||||
|
||||
class InputBlock(ObjectLookupBase[Any]):
|
||||
class InputBlock(InputOutputBlockBase[Any]):
|
||||
def __init__(self):
|
||||
super().__init__(categories={BlockCategory.INPUT, BlockCategory.BASIC})
|
||||
|
||||
@@ -159,7 +182,7 @@ class InputBlock(ObjectLookupBase[Any]):
|
||||
return "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
|
||||
|
||||
|
||||
class OutputBlock(ObjectLookupBase[Any]):
|
||||
class OutputBlock(InputOutputBlockBase[Any]):
|
||||
def __init__(self):
|
||||
super().__init__(categories={BlockCategory.OUTPUT, BlockCategory.BASIC})
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Union
|
||||
from typing import Any, Union
|
||||
|
||||
from autogpt_server.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
|
||||
@@ -103,6 +103,7 @@ class CurrentDateAndTimeBlock(Block):
|
||||
|
||||
class TimerBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
message: Any = "timer finished"
|
||||
seconds: Union[int, str] = 0
|
||||
minutes: Union[int, str] = 0
|
||||
hours: Union[int, str] = 0
|
||||
@@ -120,9 +121,11 @@ class TimerBlock(Block):
|
||||
output_schema=TimerBlock.Output,
|
||||
test_input=[
|
||||
{"seconds": 1},
|
||||
{"message": "Custom message"},
|
||||
],
|
||||
test_output=[
|
||||
("message", "timer finished"),
|
||||
("message", "Custom message"),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -136,4 +139,4 @@ class TimerBlock(Block):
|
||||
total_seconds = seconds + minutes * 60 + hours * 3600 + days * 86400
|
||||
|
||||
time.sleep(total_seconds)
|
||||
yield "message", "timer finished"
|
||||
yield "message", input_data.message
|
||||
|
||||
@@ -436,11 +436,12 @@ class ExecutionManager(AppService):
|
||||
|
||||
nodes_input = []
|
||||
for node in graph.starting_nodes:
|
||||
input_data = {}
|
||||
if isinstance(get_block(node.block_id), InputBlock):
|
||||
input_data = {"input": data}
|
||||
else:
|
||||
input_data = {}
|
||||
|
||||
name = node.input_default.get("name")
|
||||
if name and name in data:
|
||||
input_data = {"value": data[name]}
|
||||
|
||||
input_data, error = validate_exec(node, input_data)
|
||||
if input_data is None:
|
||||
raise Exception(error)
|
||||
|
||||
@@ -29,11 +29,11 @@ def create_test_graph() -> graph.Graph:
|
||||
nodes = [
|
||||
graph.Node(
|
||||
block_id=InputBlock().id,
|
||||
input_default={"key": "input_1"},
|
||||
input_default={"name": "input_1"},
|
||||
),
|
||||
graph.Node(
|
||||
block_id=InputBlock().id,
|
||||
input_default={"key": "input_2"},
|
||||
input_default={"name": "input_2"},
|
||||
),
|
||||
graph.Node(
|
||||
block_id=TextFormatterBlock().id,
|
||||
@@ -48,13 +48,13 @@ def create_test_graph() -> graph.Graph:
|
||||
graph.Link(
|
||||
source_id=nodes[0].id,
|
||||
sink_id=nodes[2].id,
|
||||
source_name="output",
|
||||
source_name="value",
|
||||
sink_name="values_#_a",
|
||||
),
|
||||
graph.Link(
|
||||
source_id=nodes[1].id,
|
||||
sink_id=nodes[2].id,
|
||||
source_name="output",
|
||||
source_name="value",
|
||||
sink_name="values_#_b",
|
||||
),
|
||||
graph.Link(
|
||||
|
||||
@@ -35,7 +35,6 @@ async def assert_sample_graph_executions(
|
||||
test_user: User,
|
||||
graph_exec_id: str,
|
||||
):
|
||||
input = {"input_1": "Hello", "input_2": "World"}
|
||||
executions = await agent_server.get_run_execution_results(
|
||||
test_graph.id, graph_exec_id, test_user.id
|
||||
)
|
||||
@@ -44,16 +43,16 @@ async def assert_sample_graph_executions(
|
||||
exec = executions[0]
|
||||
assert exec.status == execution.ExecutionStatus.COMPLETED
|
||||
assert exec.graph_exec_id == graph_exec_id
|
||||
assert exec.output_data == {"output": ["Hello"]}
|
||||
assert exec.input_data == {"input": input, "key": "input_1"}
|
||||
assert exec.output_data == {"value": ["Hello"]}
|
||||
assert exec.input_data == {"value": "Hello", "name": "input_1"}
|
||||
assert exec.node_id in [test_graph.nodes[0].id, test_graph.nodes[1].id]
|
||||
|
||||
# Executing ValueBlock
|
||||
exec = executions[1]
|
||||
assert exec.status == execution.ExecutionStatus.COMPLETED
|
||||
assert exec.graph_exec_id == graph_exec_id
|
||||
assert exec.output_data == {"output": ["World"]}
|
||||
assert exec.input_data == {"input": input, "key": "input_2"}
|
||||
assert exec.output_data == {"value": ["World"]}
|
||||
assert exec.input_data == {"value": "World", "name": "input_2"}
|
||||
assert exec.node_id in [test_graph.nodes[0].id, test_graph.nodes[1].id]
|
||||
|
||||
# Executing TextFormatterBlock
|
||||
@@ -151,7 +150,7 @@ async def test_input_pin_always_waited(server):
|
||||
server.agent_server, server.exec_manager, test_graph, test_user, {}, 3
|
||||
)
|
||||
|
||||
executions = await server.agent_server.get_run_execution_results(
|
||||
executions = await server.agent_server.get_graph_run_node_execution_results(
|
||||
test_graph.id, graph_exec_id, test_user.id
|
||||
)
|
||||
assert len(executions) == 3
|
||||
@@ -231,7 +230,7 @@ async def test_static_input_link_on_graph(server):
|
||||
graph_exec_id = await execute_graph(
|
||||
server.agent_server, server.exec_manager, test_graph, test_user, {}, 8
|
||||
)
|
||||
executions = await server.agent_server.get_run_execution_results(
|
||||
executions = await server.agent_server.get_graph_run_node_execution_results(
|
||||
test_graph.id, graph_exec_id, test_user.id
|
||||
)
|
||||
assert len(executions) == 8
|
||||
|
||||
Reference in New Issue
Block a user