mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(rnd): Add InputBlock & OutputBlock (#7654)
### Background
We need an explicit block for providing input & output for the graph.
This will later allow us to build a subgraph with pre-declared input & output schema.
This will also allow us to set input for the node in the middle of the graph, and enable a graph to have output values.
### Changes 🏗️
* Add InputBlock & OutputBlock
* Add graph structure validation on the graph execution step that asserts the following property:
- All mandatory input pin, has to be connected or have a default value, except the `InputBlock` node.
- All links have to connect valid nodes, and the sink & source name using the valid block field.
This commit is contained in:
@@ -20,7 +20,19 @@ for module in modules:
|
||||
|
||||
# Load all Block instances from the available modules
|
||||
AVAILABLE_BLOCKS = {}
|
||||
for cls in Block.__subclasses__():
|
||||
|
||||
|
||||
def all_subclasses(clz):
|
||||
subclasses = clz.__subclasses__()
|
||||
for subclass in subclasses:
|
||||
subclasses += all_subclasses(subclass)
|
||||
return subclasses
|
||||
|
||||
|
||||
for cls in all_subclasses(Block):
|
||||
if not cls.__name__.endswith("Block"):
|
||||
continue
|
||||
|
||||
block = cls()
|
||||
|
||||
if not isinstance(block.id, str) or len(block.id) != 36:
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from typing import Any
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from autogpt_server.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from autogpt_server.util.mock import MockObject
|
||||
|
||||
|
||||
class ValueBlock(Block):
|
||||
@@ -86,28 +88,39 @@ class PrintingBlock(Block):
|
||||
yield "status", "printed"
|
||||
|
||||
|
||||
class ObjectLookupBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
input: Any = Field(description="Dictionary to lookup from")
|
||||
key: str | int = Field(description="Key to lookup in the dictionary")
|
||||
T = TypeVar("T")
|
||||
|
||||
class Output(BlockSchema):
|
||||
output: Any = Field(description="Value found for the given key")
|
||||
missing: Any = Field(description="Value of the input that missing the key")
|
||||
|
||||
def __init__(self):
|
||||
class ObjectLookupBaseInput(BlockSchema, Generic[T]):
|
||||
input: T = Field(description="Dictionary to lookup from")
|
||||
key: str | int = Field(description="Key to lookup in the dictionary")
|
||||
|
||||
|
||||
class ObjectLookupBaseOutput(BlockSchema, Generic[T]):
|
||||
output: T = Field(description="Value found for the given key")
|
||||
missing: T = Field(description="Value of the input that missing the key")
|
||||
|
||||
|
||||
class ObjectLookupBase(Block, ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def block_id(self) -> str:
|
||||
pass
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
input_schema = ObjectLookupBaseInput[T]
|
||||
output_schema = ObjectLookupBaseOutput[T]
|
||||
|
||||
super().__init__(
|
||||
id="b2g2c3d4-5e6f-7g8h-9i0j-k1l2m3n4o5p6",
|
||||
id=self.block_id(),
|
||||
description="Lookup the given key in the input dictionary/object/list and return the value.",
|
||||
categories={BlockCategory.BASIC},
|
||||
input_schema=ObjectLookupBlock.Input,
|
||||
output_schema=ObjectLookupBlock.Output,
|
||||
input_schema=input_schema,
|
||||
output_schema=output_schema,
|
||||
test_input=[
|
||||
{"input": {"apple": 1, "banana": 2, "cherry": 3}, "key": "banana"},
|
||||
{"input": {"x": 10, "y": 20, "z": 30}, "key": "w"},
|
||||
{"input": [1, 2, 3], "key": 1},
|
||||
{"input": [1, 2, 3], "key": 3},
|
||||
{"input": ObjectLookupBlock.Input(input="!!", key="key"), "key": "key"},
|
||||
{"input": MockObject(value="!!", key="key"), "key": "key"},
|
||||
{"input": [{"k1": "v1"}, {"k2": "v2"}, {"k1": "v3"}], "key": "k1"},
|
||||
],
|
||||
test_output=[
|
||||
@@ -118,9 +131,11 @@ class ObjectLookupBlock(Block):
|
||||
("output", "key"),
|
||||
("output", ["v1", "v3"]),
|
||||
],
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: ObjectLookupBaseInput[T]) -> BlockOutput:
|
||||
obj = input_data.input
|
||||
key = input_data.key
|
||||
|
||||
@@ -139,3 +154,30 @@ class ObjectLookupBlock(Block):
|
||||
yield "output", getattr(obj, key)
|
||||
else:
|
||||
yield "missing", input_data.input
|
||||
|
||||
|
||||
class ObjectLookupBlock(ObjectLookupBase[Any]):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(categories={BlockCategory.BASIC})
|
||||
|
||||
def block_id(self) -> str:
|
||||
return "b2g2c3d4-5e6f-7g8h-9i0j-k1l2m3n4o5p6"
|
||||
|
||||
|
||||
class InputBlock(ObjectLookupBase[Any]):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(categories={BlockCategory.BASIC, BlockCategory.INPUT_OUTPUT})
|
||||
|
||||
def block_id(self) -> str:
|
||||
return "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
|
||||
|
||||
|
||||
class OutputBlock(ObjectLookupBase[Any]):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(categories={BlockCategory.BASIC, BlockCategory.INPUT_OUTPUT})
|
||||
|
||||
def block_id(self) -> str:
|
||||
return "363ae599-353e-4804-937e-b2ee3cef3da4"
|
||||
|
||||
@@ -3,7 +3,6 @@ import re
|
||||
from typing import Type
|
||||
|
||||
from autogpt_server.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from autogpt_server.util.test import execute_block_test
|
||||
|
||||
|
||||
class BlockInstallationBlock(Block):
|
||||
@@ -57,6 +56,9 @@ class BlockInstallationBlock(Block):
|
||||
module = __import__(module_name, fromlist=[class_name])
|
||||
block_class: Type[Block] = getattr(module, class_name)
|
||||
block = block_class()
|
||||
|
||||
from autogpt_server.util.test import execute_block_test
|
||||
|
||||
execute_block_test(block)
|
||||
yield "success", "Block installed successfully."
|
||||
except Exception as e:
|
||||
|
||||
@@ -22,6 +22,7 @@ class BlockCategory(Enum):
|
||||
TEXT = "Block that processes text data."
|
||||
SEARCH = "Block that searches or extracts information from the internet."
|
||||
BASIC = "Block that performs basic operations."
|
||||
INPUT_OUTPUT = "Block that interacts with input/output of the graph."
|
||||
|
||||
def dict(self) -> dict[str, str]:
|
||||
return {"category": self.name, "description": self.value}
|
||||
|
||||
@@ -7,7 +7,8 @@ import prisma.types
|
||||
from prisma.models import AgentGraph, AgentNode, AgentNodeLink
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from autogpt_server.data.block import BlockInput
|
||||
from autogpt_server.blocks.basic import InputBlock, OutputBlock
|
||||
from autogpt_server.data.block import BlockInput, get_block
|
||||
from autogpt_server.data.db import BaseDbModel
|
||||
from autogpt_server.util import json
|
||||
|
||||
@@ -92,7 +93,68 @@ class Graph(GraphMeta):
|
||||
@property
|
||||
def starting_nodes(self) -> list[Node]:
|
||||
outbound_nodes = {link.sink_id for link in self.links}
|
||||
return [node for node in self.nodes if node.id not in outbound_nodes]
|
||||
input_nodes = {
|
||||
v.id for v in self.nodes if isinstance(get_block(v.block_id), InputBlock)
|
||||
}
|
||||
return [
|
||||
node
|
||||
for node in self.nodes
|
||||
if node.id not in outbound_nodes or node.id in input_nodes
|
||||
]
|
||||
|
||||
@property
|
||||
def ending_nodes(self) -> list[Node]:
|
||||
return [v for v in self.nodes if isinstance(get_block(v.block_id), OutputBlock)]
|
||||
|
||||
def validate_graph(self):
|
||||
|
||||
def sanitize(name):
|
||||
return name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
|
||||
|
||||
# Check if all required fields are filled or connected, except for InputBlock.
|
||||
for node in self.nodes:
|
||||
block = get_block(node.block_id)
|
||||
if block is None:
|
||||
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
|
||||
|
||||
provided_inputs = set(
|
||||
[sanitize(name) for name in node.input_default]
|
||||
+ [sanitize(link.sink_name) for link in node.input_links]
|
||||
)
|
||||
for name in block.input_schema.get_required_fields():
|
||||
if name not in provided_inputs and not isinstance(block, InputBlock):
|
||||
raise ValueError(
|
||||
f"Node {block.name} #{node.id} required input missing: `{name}`"
|
||||
)
|
||||
|
||||
# Check if all links are connected compatible pin data type.
|
||||
for link in self.links:
|
||||
source_id = link.source_id
|
||||
sink_id = link.sink_id
|
||||
suffix = f"Link {source_id}<->{sink_id}"
|
||||
|
||||
source_node = next((v for v in self.nodes if v.id == source_id), None)
|
||||
if not source_node:
|
||||
raise ValueError(f"{suffix}, {source_id} is invalid node.")
|
||||
sink_node = next((v for v in self.nodes if v.id == sink_id), None)
|
||||
if not sink_node:
|
||||
raise ValueError(f"{suffix}, {sink_id} is invalid node.")
|
||||
|
||||
source_block = get_block(source_node.block_id)
|
||||
if not source_block:
|
||||
raise ValueError(f"{suffix}, {source_node.block_id} is invalid block.")
|
||||
sink_block = get_block(sink_node.block_id)
|
||||
if not sink_block:
|
||||
raise ValueError(f"{suffix}, {sink_node.block_id} is invalid block.")
|
||||
|
||||
source_name = sanitize(link.source_name)
|
||||
if source_name not in source_block.output_schema.get_fields():
|
||||
raise ValueError(f"{suffix}, `{source_name}` is invalid output pin.")
|
||||
sink_name = sanitize(link.sink_name)
|
||||
if sink_name not in sink_block.input_schema.get_fields():
|
||||
raise ValueError(f"{suffix}, `{sink_name}` is invalid input pin.")
|
||||
|
||||
# TODO: Add type compatibility check here.
|
||||
|
||||
@staticmethod
|
||||
def from_db(graph: AgentGraph):
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Coroutine, Generator, TypeVar
|
||||
if TYPE_CHECKING:
|
||||
from autogpt_server.server.server import AgentServer
|
||||
|
||||
from autogpt_server.blocks.basic import InputBlock
|
||||
from autogpt_server.data import db
|
||||
from autogpt_server.data.block import Block, BlockData, BlockInput, get_block
|
||||
from autogpt_server.data.execution import (
|
||||
@@ -419,10 +420,16 @@ class ExecutionManager(AppService):
|
||||
graph: Graph | None = self.run_and_wait(get_graph(graph_id))
|
||||
if not graph:
|
||||
raise Exception(f"Graph #{graph_id} not found.")
|
||||
graph.validate_graph()
|
||||
|
||||
nodes_input = []
|
||||
for node in graph.starting_nodes:
|
||||
input_data, error = validate_exec(node, data)
|
||||
if isinstance(get_block(node.block_id), InputBlock):
|
||||
input_data = {"input": data}
|
||||
else:
|
||||
input_data = {}
|
||||
|
||||
input_data, error = validate_exec(node, input_data)
|
||||
if not input_data:
|
||||
raise Exception(error)
|
||||
else:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from autogpt_server.blocks.basic import PrintingBlock, ValueBlock
|
||||
from autogpt_server.blocks.basic import InputBlock, PrintingBlock
|
||||
from autogpt_server.blocks.text import TextFormatterBlock
|
||||
from autogpt_server.data import graph
|
||||
from autogpt_server.data.graph import create_graph
|
||||
@@ -14,12 +14,18 @@ def create_test_graph() -> graph.Graph:
|
||||
ValueBlock
|
||||
"""
|
||||
nodes = [
|
||||
graph.Node(block_id=ValueBlock().id),
|
||||
graph.Node(block_id=ValueBlock().id),
|
||||
graph.Node(
|
||||
block_id=InputBlock().id,
|
||||
input_default={"key": "input_1"},
|
||||
),
|
||||
graph.Node(
|
||||
block_id=InputBlock().id,
|
||||
input_default={"key": "input_2"},
|
||||
),
|
||||
graph.Node(
|
||||
block_id=TextFormatterBlock().id,
|
||||
input_default={
|
||||
"format": "{texts[0]},{texts[1]},{texts[2]}",
|
||||
"format": "{texts[0]}, {texts[1]}{texts[2]}",
|
||||
"texts_$_3": "!!!",
|
||||
},
|
||||
),
|
||||
@@ -58,7 +64,7 @@ async def sample_agent():
|
||||
async with SpinTestServer() as server:
|
||||
exec_man = server.exec_manager
|
||||
test_graph = await create_graph(create_test_graph())
|
||||
input_data = {"input": "test!!"}
|
||||
input_data = {"input_1": "Hello", "input_2": "World"}
|
||||
response = await server.agent_server.execute_graph(test_graph.id, input_data)
|
||||
print(response)
|
||||
result = await wait_execution(exec_man, test_graph.id, response["id"], 4, 10)
|
||||
|
||||
@@ -28,37 +28,37 @@ async def execute_graph(
|
||||
async def assert_sample_graph_executions(
|
||||
agent_server: AgentServer, test_graph: graph.Graph, graph_exec_id: str
|
||||
):
|
||||
text = "Hello, World!"
|
||||
input = {"input_1": "Hello", "input_2": "World"}
|
||||
executions = await agent_server.get_run_execution_results(
|
||||
test_graph.id, graph_exec_id
|
||||
)
|
||||
|
||||
# Executing ConstantBlock1
|
||||
# Executing ValueBlock
|
||||
exec = executions[0]
|
||||
assert exec.status == execution.ExecutionStatus.COMPLETED
|
||||
assert exec.graph_exec_id == graph_exec_id
|
||||
assert exec.output_data == {"output": ["Hello, World!"]}
|
||||
assert exec.input_data == {"input": text}
|
||||
assert exec.output_data == {"output": ["Hello"]}
|
||||
assert exec.input_data == {"input": input, "key": "input_1"}
|
||||
assert exec.node_id in [test_graph.nodes[0].id, test_graph.nodes[1].id]
|
||||
|
||||
# Executing ConstantBlock2
|
||||
# Executing ValueBlock
|
||||
exec = executions[1]
|
||||
assert exec.status == execution.ExecutionStatus.COMPLETED
|
||||
assert exec.graph_exec_id == graph_exec_id
|
||||
assert exec.output_data == {"output": ["Hello, World!"]}
|
||||
assert exec.input_data == {"input": text}
|
||||
assert exec.output_data == {"output": ["World"]}
|
||||
assert exec.input_data == {"input": input, "key": "input_2"}
|
||||
assert exec.node_id in [test_graph.nodes[0].id, test_graph.nodes[1].id]
|
||||
|
||||
# Executing TextFormatterBlock
|
||||
exec = executions[2]
|
||||
assert exec.status == execution.ExecutionStatus.COMPLETED
|
||||
assert exec.graph_exec_id == graph_exec_id
|
||||
assert exec.output_data == {"output": ["Hello, World!,Hello, World!,!!!"]}
|
||||
assert exec.output_data == {"output": ["Hello, World!!!"]}
|
||||
assert exec.input_data == {
|
||||
"format": "{texts[0]},{texts[1]},{texts[2]}",
|
||||
"texts": ["Hello, World!", "Hello, World!", "!!!"],
|
||||
"texts_$_1": "Hello, World!",
|
||||
"texts_$_2": "Hello, World!",
|
||||
"format": "{texts[0]}, {texts[1]}{texts[2]}",
|
||||
"texts": ["Hello", "World", "!!!"],
|
||||
"texts_$_1": "Hello",
|
||||
"texts_$_2": "World",
|
||||
"texts_$_3": "!!!",
|
||||
}
|
||||
assert exec.node_id == test_graph.nodes[2].id
|
||||
@@ -68,7 +68,7 @@ async def assert_sample_graph_executions(
|
||||
assert exec.status == execution.ExecutionStatus.COMPLETED
|
||||
assert exec.graph_exec_id == graph_exec_id
|
||||
assert exec.output_data == {"status": ["printed"]}
|
||||
assert exec.input_data == {"text": "Hello, World!,Hello, World!,!!!"}
|
||||
assert exec.input_data == {"text": "Hello, World!!!"}
|
||||
assert exec.node_id == test_graph.nodes[3].id
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ async def assert_sample_graph_executions(
|
||||
async def test_agent_execution(server):
|
||||
test_graph = create_test_graph()
|
||||
await graph.create_graph(test_graph)
|
||||
data = {"input": "Hello, World!"}
|
||||
data = {"input_1": "Hello", "input_2": "World"}
|
||||
graph_exec_id = await execute_graph(
|
||||
server.agent_server, server.exec_manager, test_graph, data, 4
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user