feat(rnd): Add InputBlock & OutputBlock (#7654)

### Background

We need an explicit block for providing input & output for the graph.
This will later allow us to build a subgraph with pre-declared input & output schema.
This will also allow us to set input for the node in the middle of the graph, and enable a graph to have output values.

### Changes 🏗️

* Add InputBlock & OutputBlock
* Add graph structure validation  on the graph execution step that asserts the following property:
    - All mandatory input pin, has to be connected or have a default value, except the `InputBlock` node.
    - All links have to connect valid nodes, and the sink & source name using the valid block field.
This commit is contained in:
Zamil Majdy
2024-08-02 10:33:34 +04:00
committed by GitHub
parent e773329391
commit 973822d973
8 changed files with 171 additions and 39 deletions

View File

@@ -20,7 +20,19 @@ for module in modules:
# Load all Block instances from the available modules
AVAILABLE_BLOCKS = {}
for cls in Block.__subclasses__():
def all_subclasses(clz):
subclasses = clz.__subclasses__()
for subclass in subclasses:
subclasses += all_subclasses(subclass)
return subclasses
for cls in all_subclasses(Block):
if not cls.__name__.endswith("Block"):
continue
block = cls()
if not isinstance(block.id, str) or len(block.id) != 36:

View File

@@ -1,8 +1,10 @@
from typing import Any
from abc import ABC, abstractmethod
from typing import Any, Generic, TypeVar
from pydantic import Field
from autogpt_server.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from autogpt_server.util.mock import MockObject
class ValueBlock(Block):
@@ -86,28 +88,39 @@ class PrintingBlock(Block):
yield "status", "printed"
class ObjectLookupBlock(Block):
class Input(BlockSchema):
input: Any = Field(description="Dictionary to lookup from")
key: str | int = Field(description="Key to lookup in the dictionary")
T = TypeVar("T")
class Output(BlockSchema):
output: Any = Field(description="Value found for the given key")
missing: Any = Field(description="Value of the input that missing the key")
def __init__(self):
class ObjectLookupBaseInput(BlockSchema, Generic[T]):
input: T = Field(description="Dictionary to lookup from")
key: str | int = Field(description="Key to lookup in the dictionary")
class ObjectLookupBaseOutput(BlockSchema, Generic[T]):
output: T = Field(description="Value found for the given key")
missing: T = Field(description="Value of the input that missing the key")
class ObjectLookupBase(Block, ABC, Generic[T]):
@abstractmethod
def block_id(self) -> str:
pass
def __init__(self, *args, **kwargs):
input_schema = ObjectLookupBaseInput[T]
output_schema = ObjectLookupBaseOutput[T]
super().__init__(
id="b2g2c3d4-5e6f-7g8h-9i0j-k1l2m3n4o5p6",
id=self.block_id(),
description="Lookup the given key in the input dictionary/object/list and return the value.",
categories={BlockCategory.BASIC},
input_schema=ObjectLookupBlock.Input,
output_schema=ObjectLookupBlock.Output,
input_schema=input_schema,
output_schema=output_schema,
test_input=[
{"input": {"apple": 1, "banana": 2, "cherry": 3}, "key": "banana"},
{"input": {"x": 10, "y": 20, "z": 30}, "key": "w"},
{"input": [1, 2, 3], "key": 1},
{"input": [1, 2, 3], "key": 3},
{"input": ObjectLookupBlock.Input(input="!!", key="key"), "key": "key"},
{"input": MockObject(value="!!", key="key"), "key": "key"},
{"input": [{"k1": "v1"}, {"k2": "v2"}, {"k1": "v3"}], "key": "k1"},
],
test_output=[
@@ -118,9 +131,11 @@ class ObjectLookupBlock(Block):
("output", "key"),
("output", ["v1", "v3"]),
],
*args,
**kwargs,
)
def run(self, input_data: Input) -> BlockOutput:
def run(self, input_data: ObjectLookupBaseInput[T]) -> BlockOutput:
obj = input_data.input
key = input_data.key
@@ -139,3 +154,30 @@ class ObjectLookupBlock(Block):
yield "output", getattr(obj, key)
else:
yield "missing", input_data.input
class ObjectLookupBlock(ObjectLookupBase[Any]):
def __init__(self):
super().__init__(categories={BlockCategory.BASIC})
def block_id(self) -> str:
return "b2g2c3d4-5e6f-7g8h-9i0j-k1l2m3n4o5p6"
class InputBlock(ObjectLookupBase[Any]):
def __init__(self):
super().__init__(categories={BlockCategory.BASIC, BlockCategory.INPUT_OUTPUT})
def block_id(self) -> str:
return "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
class OutputBlock(ObjectLookupBase[Any]):
def __init__(self):
super().__init__(categories={BlockCategory.BASIC, BlockCategory.INPUT_OUTPUT})
def block_id(self) -> str:
return "363ae599-353e-4804-937e-b2ee3cef3da4"

View File

@@ -3,7 +3,6 @@ import re
from typing import Type
from autogpt_server.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from autogpt_server.util.test import execute_block_test
class BlockInstallationBlock(Block):
@@ -57,6 +56,9 @@ class BlockInstallationBlock(Block):
module = __import__(module_name, fromlist=[class_name])
block_class: Type[Block] = getattr(module, class_name)
block = block_class()
from autogpt_server.util.test import execute_block_test
execute_block_test(block)
yield "success", "Block installed successfully."
except Exception as e:

View File

@@ -22,6 +22,7 @@ class BlockCategory(Enum):
TEXT = "Block that processes text data."
SEARCH = "Block that searches or extracts information from the internet."
BASIC = "Block that performs basic operations."
INPUT_OUTPUT = "Block that interacts with input/output of the graph."
def dict(self) -> dict[str, str]:
return {"category": self.name, "description": self.value}

View File

@@ -7,7 +7,8 @@ import prisma.types
from prisma.models import AgentGraph, AgentNode, AgentNodeLink
from pydantic import PrivateAttr
from autogpt_server.data.block import BlockInput
from autogpt_server.blocks.basic import InputBlock, OutputBlock
from autogpt_server.data.block import BlockInput, get_block
from autogpt_server.data.db import BaseDbModel
from autogpt_server.util import json
@@ -92,7 +93,68 @@ class Graph(GraphMeta):
@property
def starting_nodes(self) -> list[Node]:
outbound_nodes = {link.sink_id for link in self.links}
return [node for node in self.nodes if node.id not in outbound_nodes]
input_nodes = {
v.id for v in self.nodes if isinstance(get_block(v.block_id), InputBlock)
}
return [
node
for node in self.nodes
if node.id not in outbound_nodes or node.id in input_nodes
]
@property
def ending_nodes(self) -> list[Node]:
return [v for v in self.nodes if isinstance(get_block(v.block_id), OutputBlock)]
def validate_graph(self):
def sanitize(name):
return name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
# Check if all required fields are filled or connected, except for InputBlock.
for node in self.nodes:
block = get_block(node.block_id)
if block is None:
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
provided_inputs = set(
[sanitize(name) for name in node.input_default]
+ [sanitize(link.sink_name) for link in node.input_links]
)
for name in block.input_schema.get_required_fields():
if name not in provided_inputs and not isinstance(block, InputBlock):
raise ValueError(
f"Node {block.name} #{node.id} required input missing: `{name}`"
)
# Check if all links are connected compatible pin data type.
for link in self.links:
source_id = link.source_id
sink_id = link.sink_id
suffix = f"Link {source_id}<->{sink_id}"
source_node = next((v for v in self.nodes if v.id == source_id), None)
if not source_node:
raise ValueError(f"{suffix}, {source_id} is invalid node.")
sink_node = next((v for v in self.nodes if v.id == sink_id), None)
if not sink_node:
raise ValueError(f"{suffix}, {sink_id} is invalid node.")
source_block = get_block(source_node.block_id)
if not source_block:
raise ValueError(f"{suffix}, {source_node.block_id} is invalid block.")
sink_block = get_block(sink_node.block_id)
if not sink_block:
raise ValueError(f"{suffix}, {sink_node.block_id} is invalid block.")
source_name = sanitize(link.source_name)
if source_name not in source_block.output_schema.get_fields():
raise ValueError(f"{suffix}, `{source_name}` is invalid output pin.")
sink_name = sanitize(link.sink_name)
if sink_name not in sink_block.input_schema.get_fields():
raise ValueError(f"{suffix}, `{sink_name}` is invalid input pin.")
# TODO: Add type compatibility check here.
@staticmethod
def from_db(graph: AgentGraph):

View File

@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Coroutine, Generator, TypeVar
if TYPE_CHECKING:
from autogpt_server.server.server import AgentServer
from autogpt_server.blocks.basic import InputBlock
from autogpt_server.data import db
from autogpt_server.data.block import Block, BlockData, BlockInput, get_block
from autogpt_server.data.execution import (
@@ -419,10 +420,16 @@ class ExecutionManager(AppService):
graph: Graph | None = self.run_and_wait(get_graph(graph_id))
if not graph:
raise Exception(f"Graph #{graph_id} not found.")
graph.validate_graph()
nodes_input = []
for node in graph.starting_nodes:
input_data, error = validate_exec(node, data)
if isinstance(get_block(node.block_id), InputBlock):
input_data = {"input": data}
else:
input_data = {}
input_data, error = validate_exec(node, input_data)
if not input_data:
raise Exception(error)
else:

View File

@@ -1,4 +1,4 @@
from autogpt_server.blocks.basic import PrintingBlock, ValueBlock
from autogpt_server.blocks.basic import InputBlock, PrintingBlock
from autogpt_server.blocks.text import TextFormatterBlock
from autogpt_server.data import graph
from autogpt_server.data.graph import create_graph
@@ -14,12 +14,18 @@ def create_test_graph() -> graph.Graph:
ValueBlock
"""
nodes = [
graph.Node(block_id=ValueBlock().id),
graph.Node(block_id=ValueBlock().id),
graph.Node(
block_id=InputBlock().id,
input_default={"key": "input_1"},
),
graph.Node(
block_id=InputBlock().id,
input_default={"key": "input_2"},
),
graph.Node(
block_id=TextFormatterBlock().id,
input_default={
"format": "{texts[0]},{texts[1]},{texts[2]}",
"format": "{texts[0]}, {texts[1]}{texts[2]}",
"texts_$_3": "!!!",
},
),
@@ -58,7 +64,7 @@ async def sample_agent():
async with SpinTestServer() as server:
exec_man = server.exec_manager
test_graph = await create_graph(create_test_graph())
input_data = {"input": "test!!"}
input_data = {"input_1": "Hello", "input_2": "World"}
response = await server.agent_server.execute_graph(test_graph.id, input_data)
print(response)
result = await wait_execution(exec_man, test_graph.id, response["id"], 4, 10)

View File

@@ -28,37 +28,37 @@ async def execute_graph(
async def assert_sample_graph_executions(
agent_server: AgentServer, test_graph: graph.Graph, graph_exec_id: str
):
text = "Hello, World!"
input = {"input_1": "Hello", "input_2": "World"}
executions = await agent_server.get_run_execution_results(
test_graph.id, graph_exec_id
)
# Executing ConstantBlock1
# Executing ValueBlock
exec = executions[0]
assert exec.status == execution.ExecutionStatus.COMPLETED
assert exec.graph_exec_id == graph_exec_id
assert exec.output_data == {"output": ["Hello, World!"]}
assert exec.input_data == {"input": text}
assert exec.output_data == {"output": ["Hello"]}
assert exec.input_data == {"input": input, "key": "input_1"}
assert exec.node_id in [test_graph.nodes[0].id, test_graph.nodes[1].id]
# Executing ConstantBlock2
# Executing ValueBlock
exec = executions[1]
assert exec.status == execution.ExecutionStatus.COMPLETED
assert exec.graph_exec_id == graph_exec_id
assert exec.output_data == {"output": ["Hello, World!"]}
assert exec.input_data == {"input": text}
assert exec.output_data == {"output": ["World"]}
assert exec.input_data == {"input": input, "key": "input_2"}
assert exec.node_id in [test_graph.nodes[0].id, test_graph.nodes[1].id]
# Executing TextFormatterBlock
exec = executions[2]
assert exec.status == execution.ExecutionStatus.COMPLETED
assert exec.graph_exec_id == graph_exec_id
assert exec.output_data == {"output": ["Hello, World!,Hello, World!,!!!"]}
assert exec.output_data == {"output": ["Hello, World!!!"]}
assert exec.input_data == {
"format": "{texts[0]},{texts[1]},{texts[2]}",
"texts": ["Hello, World!", "Hello, World!", "!!!"],
"texts_$_1": "Hello, World!",
"texts_$_2": "Hello, World!",
"format": "{texts[0]}, {texts[1]}{texts[2]}",
"texts": ["Hello", "World", "!!!"],
"texts_$_1": "Hello",
"texts_$_2": "World",
"texts_$_3": "!!!",
}
assert exec.node_id == test_graph.nodes[2].id
@@ -68,7 +68,7 @@ async def assert_sample_graph_executions(
assert exec.status == execution.ExecutionStatus.COMPLETED
assert exec.graph_exec_id == graph_exec_id
assert exec.output_data == {"status": ["printed"]}
assert exec.input_data == {"text": "Hello, World!,Hello, World!,!!!"}
assert exec.input_data == {"text": "Hello, World!!!"}
assert exec.node_id == test_graph.nodes[3].id
@@ -76,7 +76,7 @@ async def assert_sample_graph_executions(
async def test_agent_execution(server):
test_graph = create_test_graph()
await graph.create_graph(test_graph)
data = {"input": "Hello, World!"}
data = {"input_1": "Hello", "input_2": "World"}
graph_exec_id = await execute_graph(
server.agent_server, server.exec_manager, test_graph, data, 4
)