fix(rnd): Make Agent Server's pin connections the mandatory source of input (#7539)

### Background

Input arriving on an input pin is consumed only once, while the default input can always be reused. So when an input pin's connection overrides the default input, the connected value is used exactly once and every subsequent run silently falls back to the default input. This can mislead the user.

Expected behaviour: the node should NOT RUN until its connected pins receive data; a connected pin must use only its connection(s) as the source of data, never the default value.
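
For illustration, here is a minimal sketch of the intended resolution rule (a hypothetical helper, not code from this PR; the real logic lives in `validate_exec` below):

```python
from typing import Any, Optional

def resolve_input(
    connected_pins: set[str],  # pins that have an incoming link
    received: dict[str, Any],  # data delivered by those links so far
    defaults: dict[str, Any],  # the node's input_default values
) -> Optional[dict[str, Any]]:
    # A connected pin may only be fed by its connection; if any connected
    # pin has not received data yet, the node must not run at all.
    if not connected_pins.issubset(received):
        return None
    # Defaults fill only the pins that have no connection.
    return {**defaults, **received}
```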

### Changes 🏗️

* Make pin connections the mandatory source of input, with no fallback to the default value.
* Fix type flakiness on block input & output: unify the typing for `BlockInput` & `BlockOutput` with the right aliases to avoid incorrect annotations.
* Add comments on the type aliases.
* Add an automated test for the new behaviour.
Author: Zamil Majdy
Date: 2024-07-23 06:06:26 +04:00
Committed by: GitHub
Parent: a911f9a5eb
Commit: d407fd101e

10 changed files with 199 additions and 91 deletions

View File

@@ -9,9 +9,10 @@ from pydantic import BaseModel
 from autogpt_server.util import json

-BlockInput = dict[str, Any]
-BlockData = tuple[str, Any]
-BlockOutput = Generator[BlockData, None, None]
+BlockData = tuple[str, Any]  # Input & Output data should be a tuple of (name, data).
+BlockInput = dict[str, Any]  # Input: 1 input pin consumes 1 data.
+BlockOutput = Generator[BlockData, None, None]  # Output: 1 output pin produces n data.
+CompletedBlockOutput = dict[str, list[Any]]  # Completed stream, collected as a dict.

 class BlockCategory(Enum):
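
As a rough illustration of how these aliases compose (a standalone sketch, not code from this diff):

```python
from collections import defaultdict
from typing import Any, Generator

BlockData = tuple[str, Any]                     # (pin_name, datum)
BlockInput = dict[str, Any]                     # one datum per input pin
BlockOutput = Generator[BlockData, None, None]  # a stream of (pin_name, datum)
CompletedBlockOutput = dict[str, list[Any]]     # the stream, grouped by pin

def run(input_data: BlockInput) -> BlockOutput:
    # One input pin consumes one datum; one output pin may produce many.
    yield "output", input_data["input"]
    yield "output", input_data["input"].upper()

def collect(stream: BlockOutput) -> CompletedBlockOutput:
    out: CompletedBlockOutput = defaultdict(list)
    for name, datum in stream:
        out[name].append(datum)
    return dict(out)

print(collect(run({"input": "hi"})))  # {'output': ['hi', 'HI']}
```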

View File

@@ -12,6 +12,7 @@ from prisma.models import (
 from prisma.types import AgentGraphExecutionWhereInput
 from pydantic import BaseModel

+from autogpt_server.data.block import BlockData, BlockInput, CompletedBlockOutput
 from autogpt_server.util import json
@@ -19,7 +20,7 @@ class NodeExecution(BaseModel):
     graph_exec_id: str
     node_exec_id: str
     node_id: str
-    data: dict[str, Any]
+    data: BlockInput


 class ExecutionStatus(str, Enum):
@@ -57,8 +58,8 @@ class ExecutionResult(BaseModel):
     node_exec_id: str
     node_id: str
     status: ExecutionStatus
-    input_data: dict[str, Any]  # 1 input pin should consume exactly 1 data.
-    output_data: dict[str, list[Any]]  # but 1 output pin can produce multiple output.
+    input_data: BlockInput
+    output_data: CompletedBlockOutput
     add_time: datetime
     queue_time: datetime | None
     start_time: datetime | None
@@ -66,11 +67,11 @@ class ExecutionResult(BaseModel):
     @staticmethod
     def from_db(execution: AgentNodeExecution):
-        input_data: dict[str, Any] = defaultdict()
+        input_data: BlockInput = defaultdict()
         for data in execution.Input or []:
             input_data[data.name] = json.loads(data.data)

-        output_data: dict[str, Any] = defaultdict(list)
+        output_data: CompletedBlockOutput = defaultdict(list)
         for data in execution.Output or []:
             output_data[data.name].append(json.loads(data.data))
@@ -103,7 +104,7 @@ EXECUTION_RESULT_INCLUDE = {
 async def create_graph_execution(
-    graph_id: str, graph_version: int, node_ids: list[str], data: dict[str, Any]
+    graph_id: str, graph_version: int, nodes_input: list[tuple[str, BlockInput]]
 ) -> tuple[str, list[ExecutionResult]]:
     """
     Create a new AgentGraphExecution record.
@@ -122,15 +123,17 @@
                         "Input": {
                             "create": [
                                 {"name": name, "data": json.dumps(data)}
-                                for name, data in data.items()
+                                for name, data in node_input.items()
                             ]
                         },
                     }
-                    for node_id in node_ids
+                    for node_id, node_input in nodes_input
                 ]
             },
         },
-        include={"AgentNodeExecutions": True},
+        include={
+            "AgentNodeExecutions": {"include": EXECUTION_RESULT_INCLUDE}  # type: ignore
+        },
     )

     return result.id, [
@@ -242,7 +245,7 @@ async def get_execution_results(graph_exec_id: str) -> list[ExecutionResult]:
     return res


-async def get_node_execution_input(node_exec_id: str) -> dict[str, Any]:
+async def get_node_execution_input(node_exec_id: str) -> BlockInput:
     """
     Get execution node input data from the previous node execution result.
@@ -256,11 +259,10 @@ async def get_node_execution_input(node_exec_id: str) -> dict[str, Any]:
     if not execution.AgentNode:
         raise ValueError(f"Node {execution.agentNodeId} not found.")

-    exec_input = json.loads(execution.AgentNode.constantInput)
-    for input_data in execution.Input or []:
-        exec_input[input_data.name] = json.loads(input_data.data)
-    return merge_execution_input(exec_input)
+    return {
+        input_data.name: json.loads(input_data.data)
+        for input_data in execution.Input or []
+    }


 LIST_SPLIT = "_$_"
@@ -268,7 +270,7 @@ DICT_SPLIT = "_#_"
 OBJC_SPLIT = "_@_"


-def parse_execution_output(output: tuple[str, Any], name: str) -> Any | None:
+def parse_execution_output(output: BlockData, name: str) -> Any | None:
     # Allow extracting partial output data by name.
     output_name, output_data = output
@@ -296,7 +298,15 @@ def parse_execution_output(output: tuple[str, Any], name: str) -> Any | None:
     return None


-def merge_execution_input(data: dict[str, Any]) -> dict[str, Any]:
+def merge_execution_input(data: BlockInput) -> BlockInput:
+    """
+    Merge all dynamic input pins, which are described by the following patterns:
+    - <input_name>_$_<index> for list input.
+    - <input_name>_#_<index> for dict input.
+    - <input_name>_@_<index> for object input.
+    This function constructs pins with the same name into a single list/dict/object.
+    """
     # Merge all input with <input_name>_$_<index> into a single list.
     items = list(data.items())
     list_input: list[Any] = []
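
Going by the docstring above, the merge behaves roughly like this (the expected values are inferred from the delimiter patterns, so treat them as an assumption rather than verified output):

```python
# Hypothetical usage; merge_execution_input is defined in the hunk above
# (autogpt_server.data.execution — the module path is inferred from the imports).
from autogpt_server.data.execution import merge_execution_input

# List pins: <name>_$_<index> entries should collapse into one list.
print(merge_execution_input({"items_$_0": "a", "items_$_1": "b"}))
# Expected (by the docstring): {"items": ["a", "b"]}

# Dict (_#_) and object (_@_) pins follow the same idea, collapsing
# same-named pins into a single dict or object respectively.
```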

View File

@@ -7,6 +7,7 @@ import prisma.types
 from prisma.models import AgentGraph, AgentNode, AgentNodeLink
 from pydantic import PrivateAttr

+from autogpt_server.data.block import BlockInput
 from autogpt_server.data.db import BaseDbModel
 from autogpt_server.util import json
@@ -33,7 +34,7 @@ class Link(BaseDbModel):
 class Node(BaseDbModel):
     block_id: str
-    input_default: dict[str, Any] = {}  # dict[input_name, default_value]
+    input_default: BlockInput = {}  # dict[input_name, default_value]
     metadata: dict[str, Any] = {}
     _input_links: list[Link] = PrivateAttr(default=[])

View File

@@ -1,8 +1,9 @@
 from datetime import datetime
-from typing import Any, Optional
+from typing import Optional

 from prisma.models import AgentGraphExecutionSchedule

+from autogpt_server.data.block import BlockInput
 from autogpt_server.data.db import BaseDbModel
 from autogpt_server.util import json
@@ -12,7 +13,7 @@ class ExecutionSchedule(BaseDbModel):
     graph_version: int
     schedule: str
     is_enabled: bool
-    input_data: dict[str, Any]
+    input_data: BlockInput
     last_updated: Optional[datetime] = None

     def __init__(self, is_enabled: Optional[bool] = None, **kwargs):

View File

@@ -7,7 +7,7 @@ if TYPE_CHECKING:
     from autogpt_server.server.server import AgentServer

 from autogpt_server.data import db
-from autogpt_server.data.block import Block, get_block
+from autogpt_server.data.block import Block, BlockData, BlockInput, get_block
 from autogpt_server.data.execution import ExecutionQueue, ExecutionStatus
 from autogpt_server.data.execution import NodeExecution as Execution
 from autogpt_server.data.execution import (
@@ -50,7 +50,6 @@
     """
     graph_exec_id = data.graph_exec_id
     node_exec_id = data.node_exec_id
-    exec_data = data.data
     node_id = data.node_id

     asyncio.set_event_loop(loop)
@@ -73,8 +72,14 @@ def execute_node(
         logger.error(f"Block {node.block_id} not found.")
         return

-    # Execute the node
+    # Sanity check: validate the execution input.
     prefix = get_log_prefix(graph_exec_id, node_exec_id, node_block.name)
+    exec_data, error = validate_exec(node, data.data, resolve_input=False)
+    if not exec_data:
+        logger.error(f"{prefix} Skip execution, input validation error: {error}")
+        return
+
+    # Execute the node
     logger.warning(f"{prefix} execute with input:\n`{exec_data}`")
     update_execution(ExecutionStatus.RUNNING)
@@ -106,7 +111,7 @@ def enqueue_next_nodes(
     api_client: "AgentServer",
     loop: asyncio.AbstractEventLoop,
     node: Node,
-    output: tuple[str, Any],
+    output: BlockData,
     graph_exec_id: str,
     prefix: str,
 ) -> list[Execution]:
@@ -142,10 +147,10 @@ def enqueue_next_nodes(
             )
             next_node_input = wait(get_node_execution_input(next_node_exec_id))
-            is_valid, validation_msg = validate_exec(next_node, next_node_input)
+            next_node_input, validation_msg = validate_exec(next_node, next_node_input)
             suffix = f"{next_output_name}~{next_input_name}#{next_node_id}:{validation_msg}"
-            if not is_valid:
+            if not next_node_input:
                 logger.warning(f"{prefix} Skipped queueing {suffix}")
                 return
@@ -166,38 +171,54 @@ def enqueue_next_nodes(
 ]


-def validate_exec(node: Node, data: dict[str, Any]) -> tuple[bool, str]:
+def validate_exec(
+    node: Node,
+    data: BlockInput,
+    resolve_input: bool = True,
+) -> tuple[BlockInput | None, str]:
     """
     Validate the input data for a node execution.

     Args:
         node: The node to execute.
         data: The input data for the node execution.
+        resolve_input: Whether to resolve dynamic pins into dict/list/object.

     Returns:
-        A tuple of a boolean indicating if the data is valid, and a message if not.
-        Return the executed block name if the data is valid.
+        A tuple of the validated data and the block name.
+        If the data is invalid, the first element will be None, and the second element
+        will be an error message.
+        If the data is valid, the first element will be the resolved input data, and
+        the second element will be the block name.
     """
     node_block: Block | None = get_block(node.block_id)  # type: ignore
     if not node_block:
-        return False, f"Block for {node.block_id} not found."
+        return None, f"Block for {node.block_id} not found."

-    error_message = f"Input data missing for {node_block.name}:"
-    input_fields_from_schema = node_block.input_schema.get_required_fields()
-    if not input_fields_from_schema.issubset(data):
-        return False, f"{error_message} {input_fields_from_schema - set(data)}"
+    error_prefix = f"Input data missing for {node_block.name}:"

+    # Input data (without default values) should contain all required fields.
     input_fields_from_nodes = {link.sink_name for link in node.input_links}
     if not input_fields_from_nodes.issubset(data):
-        return False, f"{error_message} {input_fields_from_nodes - set(data)}"
+        return None, f"{error_prefix} {input_fields_from_nodes - set(data)}"

+    # Merge input data with default values and resolve dynamic dict/list/object pins.
+    data = {**node.input_default, **data}
+    if resolve_input:
+        data = merge_execution_input(data)
+
+    # Input data post-merge should contain all required fields from the schema.
+    input_fields_from_schema = node_block.input_schema.get_required_fields()
+    if not input_fields_from_schema.issubset(data):
+        return None, f"{error_prefix} {input_fields_from_schema - set(data)}"
+
+    # Last validation: Validate the input values against the schema.
     if error := node_block.input_schema.validate_data(data):  # type: ignore
         error_message = f"Input data doesn't match {node_block.name}: {error}"
         logger.error(error_message)
-        return False, error_message
+        return None, error_message

-    return True, node_block.name
+    return data, node_block.name
def get_agent_server_client() -> "AgentServer":
@@ -251,37 +272,34 @@ class ExecutionManager(AppService):
         return get_agent_server_client()

     @expose
-    def add_execution(self, graph_id: str, data: dict[str, Any]) -> dict[Any, Any]:
+    def add_execution(self, graph_id: str, data: BlockInput) -> dict[Any, Any]:
         graph: Graph | None = self.run_and_wait(get_graph(graph_id))
         if not graph:
             raise Exception(f"Graph #{graph_id} not found.")

+        # Currently, there is no constraint on the number of root nodes in the graph.
+        nodes_input = []
         for node in graph.starting_nodes:
-            input_data = merge_execution_input({**node.input_default, **data})
-            valid, error = validate_exec(node, input_data)
-            if not valid:
+            input_data, error = validate_exec(node, data)
+            if not input_data:
                 raise Exception(error)
+            else:
+                nodes_input.append((node.id, input_data))

         graph_exec_id, node_execs = self.run_and_wait(
             create_graph_execution(
                 graph_id=graph_id,
                 graph_version=graph.version,
-                node_ids=[node.id for node in graph.starting_nodes],
-                data=data,
+                nodes_input=nodes_input,
             )
         )

-        executions: list[dict[str, Any]] = []
+        executions: list[BlockInput] = []
         for node_exec in node_execs:
-            input_data = self.run_and_wait(
-                get_node_execution_input(node_exec.node_exec_id)
-            )
             self.add_node_execution(
                 Execution(
                     graph_exec_id=node_exec.graph_exec_id,
                     node_exec_id=node_exec.node_exec_id,
                     node_id=node_exec.node_id,
-                    data=input_data,
+                    data=node_exec.input_data,
                 )
             )
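
A caller-side sketch of the new `validate_exec` contract (the imports mirror this diff; the surrounding function is illustrative only):

```python
from autogpt_server.data.graph import Node
from autogpt_server.executor.manager import validate_exec

def try_execute(node: Node, incoming: dict) -> None:
    # validate_exec now returns (resolved_input | None, message): on failure
    # the message is an error, on success it is the executed block's name.
    exec_data, msg = validate_exec(node, incoming)
    if exec_data is None:
        print(f"Skip execution, input validation error: {msg}")
        return
    print(f"Executing {msg} with input: {exec_data}")
```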

View File

@@ -1,12 +1,12 @@
 import logging
 import time
 from datetime import datetime
-from typing import Any

 from apscheduler.schedulers.background import BackgroundScheduler
 from apscheduler.triggers.cron import CronTrigger

 from autogpt_server.data import schedule as model
+from autogpt_server.data.block import BlockInput
 from autogpt_server.executor.manager import ExecutionManager
 from autogpt_server.util.service import AppService, expose, get_service_client
@@ -68,7 +68,7 @@ class ExecutionScheduler(AppService):
     @expose
     def add_execution_schedule(
-        self, graph_id: str, graph_version: int, cron: str, input_data: dict[str, Any]
+        self, graph_id: str, graph_version: int, cron: str, input_data: BlockInput
     ) -> str:
         schedule = model.ExecutionSchedule(
             graph_id=graph_id,

View File

@@ -1,5 +1,6 @@
 import asyncio
 import uuid
+from collections import defaultdict
 from contextlib import asynccontextmanager
 from typing import Annotated, Any, Dict
@@ -19,6 +20,7 @@ from fastapi.staticfiles import StaticFiles
 import autogpt_server.server.ws_api
 from autogpt_server.data import block, db, execution
 from autogpt_server.data import graph as graph_db
+from autogpt_server.data.block import BlockInput, CompletedBlockOutput
 from autogpt_server.executor import ExecutionManager, ExecutionScheduler
 from autogpt_server.server.conn_manager import ConnectionManager
 from autogpt_server.server.model import (
@@ -386,12 +388,16 @@ class AgentServer(AppService):
     @classmethod
     def execute_graph_block(
-        cls, block_id: str, data: dict[str, Any]
-    ) -> list[dict[str, Any]]:
+        cls, block_id: str, data: BlockInput
+    ) -> CompletedBlockOutput:
         obj = block.get_block(block_id)  # type: ignore
         if not obj:
             raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
-        return [{name: data} for name, data in obj.execute(data)]
+
+        output = defaultdict(list)
+        for name, data in obj.execute(data):
+            output[name].append(data)
+        return output

     @classmethod
     async def get_graphs(cls) -> list[graph_db.GraphMeta]:

View File

@@ -0,0 +1,9 @@
+import pytest
+
+from autogpt_server.util.test import SpinTestServer
+
+
+@pytest.fixture(scope="session")
+async def server():
+    async with SpinTestServer() as server:
+        yield server

View File

@@ -1,28 +1,30 @@
 import pytest

+from autogpt_server.blocks.basic import ObjectLookupBlock, ValueBlock
 from autogpt_server.data import execution, graph
 from autogpt_server.executor import ExecutionManager
 from autogpt_server.server import AgentServer
 from autogpt_server.usecases.sample import create_test_graph
-from autogpt_server.util.test import SpinTestServer, wait_execution
+from autogpt_server.util.test import wait_execution


-async def execute_graph(test_manager: ExecutionManager, test_graph: graph.Graph) -> str:
+async def execute_graph(
+    test_manager: ExecutionManager,
+    test_graph: graph.Graph,
+    input_data: dict[str, str],
+    num_execs: int = 4,
+) -> str:
     # --- Test adding new executions --- #
-    text = "Hello, World!"
-    input_data = {"input": text}
     agent_server = AgentServer()
     response = await agent_server.execute_graph(test_graph.id, input_data)
     executions = response["executions"]
     graph_exec_id = response["id"]
     assert len(executions) == 2

     # Execution queue should be empty
-    assert await wait_execution(test_manager, test_graph.id, graph_exec_id, 4)
+    assert await wait_execution(test_manager, test_graph.id, graph_exec_id, num_execs)
     return graph_exec_id


-async def assert_executions(test_graph: graph.Graph, graph_exec_id: str):
+async def assert_sample_graph_executions(test_graph: graph.Graph, graph_exec_id: str):
     text = "Hello, World!"
     agent_server = AgentServer()
     executions = await agent_server.get_run_execution_results(
@@ -66,9 +68,71 @@ async def assert_executions(test_graph: graph.Graph, graph_exec_id: str):
 @pytest.mark.asyncio(scope="session")
-async def test_agent_execution():
-    async with SpinTestServer() as server:
-        test_graph = create_test_graph()
-        await graph.create_graph(test_graph)
-        graph_exec_id = await execute_graph(server.exec_manager, test_graph)
-        await assert_executions(test_graph, graph_exec_id)
+async def test_agent_execution(server):
+    test_graph = create_test_graph()
+    await graph.create_graph(test_graph)
+    data = {"input": "Hello, World!"}
+    graph_exec_id = await execute_graph(server.exec_manager, test_graph, data, 4)
+    await assert_sample_graph_executions(test_graph, graph_exec_id)
+
+
+@pytest.mark.asyncio(scope="session")
+async def test_input_pin_always_waited(server):
+    """
+    This test asserts that the input pin is always waited on for execution:
+    even when a default value is defined on that pin, the default has to be ignored.
+
+    Test scenario:
+    ValueBlock1
+               \\ input
+                >------- ObjectLookupBlock | input_default: key: "", input: {}
+               // key
+    ValueBlock2
+    """
+    nodes = [
+        graph.Node(
+            block_id=ValueBlock().id,
+            input_default={"input": {"key1": "value1", "key2": "value2"}},
+        ),
+        graph.Node(
+            block_id=ValueBlock().id,
+            input_default={"input": "key2"},
+        ),
+        graph.Node(
+            block_id=ObjectLookupBlock().id,
+            input_default={"key": "", "input": {}},
+        ),
+    ]
+    links = [
+        graph.Link(
+            source_id=nodes[0].id,
+            sink_id=nodes[2].id,
+            source_name="output",
+            sink_name="input",
+        ),
+        graph.Link(
+            source_id=nodes[1].id,
+            sink_id=nodes[2].id,
+            source_name="output",
+            sink_name="key",
+        ),
+    ]
+    test_graph = graph.Graph(
+        name="TestGraph",
+        description="Test graph",
+        nodes=nodes,
+        links=links,
+    )
+    test_graph = await graph.create_graph(test_graph)
+    graph_exec_id = await execute_graph(server.exec_manager, test_graph, {}, 3)
+
+    agent_server = AgentServer()
+    executions = await agent_server.get_run_execution_results(
+        test_graph.id, graph_exec_id
+    )
+    assert len(executions) == 3
+    # ObjectLookupBlock should wait for the input pin to be provided,
+    # hence extracting "key" from {"key1": "value1", "key2": "value2"}.
+    assert executions[2].status == execution.ExecutionStatus.COMPLETED
+    assert executions[2].output_data == {"output": ["value2"]}

View File

@@ -4,32 +4,30 @@ from autogpt_server.data import db, graph
 from autogpt_server.executor import ExecutionScheduler
 from autogpt_server.usecases.sample import create_test_graph
 from autogpt_server.util.service import get_service_client
-from autogpt_server.util.test import SpinTestServer


 @pytest.mark.asyncio(scope="session")
-async def test_agent_schedule():
+async def test_agent_schedule(server):
     await db.connect()
     test_graph = await graph.create_graph(create_test_graph())
-    async with SpinTestServer():
-        scheduler = get_service_client(ExecutionScheduler)
+    scheduler = get_service_client(ExecutionScheduler)

-        schedules = scheduler.get_execution_schedules(test_graph.id)
-        assert len(schedules) == 0
+    schedules = scheduler.get_execution_schedules(test_graph.id)
+    assert len(schedules) == 0

-        schedule_id = scheduler.add_execution_schedule(
-            graph_id=test_graph.id,
-            graph_version=1,
-            cron="0 0 * * *",
-            input_data={"input": "data"},
-        )
-        assert schedule_id
+    schedule_id = scheduler.add_execution_schedule(
+        graph_id=test_graph.id,
+        graph_version=1,
+        cron="0 0 * * *",
+        input_data={"input": "data"},
+    )
+    assert schedule_id

-        schedules = scheduler.get_execution_schedules(test_graph.id)
-        assert len(schedules) == 1
-        assert schedules[schedule_id] == "0 0 * * *"
+    schedules = scheduler.get_execution_schedules(test_graph.id)
+    assert len(schedules) == 1
+    assert schedules[schedule_id] == "0 0 * * *"

-        scheduler.update_schedule(schedule_id, is_enabled=False)
-        schedules = scheduler.get_execution_schedules(test_graph.id)
-        assert len(schedules) == 0
+    scheduler.update_schedule(schedule_id, is_enabled=False)
+    schedules = scheduler.get_execution_schedules(test_graph.id)
+    assert len(schedules) == 0