From d407fd101e28161edbd3e04ed8911ace5448cfae Mon Sep 17 00:00:00 2001 From: Zamil Majdy Date: Tue, 23 Jul 2024 06:06:26 +0400 Subject: [PATCH] fix(rnd): Make Agent Server's pin connections become the mandatory source of input (#7539) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Background Input from the input pin is consumed only once, and the default input can always be used. So when you have an input pin overriding the default input, the value will be used only once and the following run will always fall back to the default input. This can mislead the user. Expected behaviour: the node should NOT RUN, making connected pins only use their connection(s) for sources of data. ### Changes 🏗️ * Make pin connection the mandatory source of input and not falling back to default value. * Fix the type flakiness on block input & output. Unify the typing for BlockInput & BlockOutput using the right alias to avoid wrong typing. * Add comment on alias * automated test on the new behaviour. --- .../autogpt_server/data/block.py | 7 +- .../autogpt_server/data/execution.py | 44 +++++---- .../autogpt_server/data/graph.py | 3 +- .../autogpt_server/data/schedule.py | 5 +- .../autogpt_server/executor/manager.py | 78 ++++++++++------ .../autogpt_server/executor/scheduler.py | 4 +- .../autogpt_server/server/server.py | 12 ++- rnd/autogpt_server/test/conftest.py | 9 ++ .../test/executor/test_manager.py | 92 ++++++++++++++++--- .../test/executor/test_scheduler.py | 36 ++++---- 10 files changed, 199 insertions(+), 91 deletions(-) create mode 100644 rnd/autogpt_server/test/conftest.py diff --git a/rnd/autogpt_server/autogpt_server/data/block.py b/rnd/autogpt_server/autogpt_server/data/block.py index 7e1982b600..4171566522 100644 --- a/rnd/autogpt_server/autogpt_server/data/block.py +++ b/rnd/autogpt_server/autogpt_server/data/block.py @@ -9,9 +9,10 @@ from pydantic import BaseModel from autogpt_server.util import json -BlockInput = dict[str, Any] -BlockData = tuple[str, Any] -BlockOutput = Generator[BlockData, None, None] +BlockData = tuple[str, Any] # Input & Output data should be a tuple of (name, data). +BlockInput = dict[str, Any] # Input: 1 input pin consumes 1 data. +BlockOutput = Generator[BlockData, None, None] # Output: 1 output pin produces n data. +CompletedBlockOutput = dict[str, list[Any]] # Completed stream, collected as a dict. class BlockCategory(Enum): diff --git a/rnd/autogpt_server/autogpt_server/data/execution.py b/rnd/autogpt_server/autogpt_server/data/execution.py index d7bbe6a843..09cd513392 100644 --- a/rnd/autogpt_server/autogpt_server/data/execution.py +++ b/rnd/autogpt_server/autogpt_server/data/execution.py @@ -12,6 +12,7 @@ from prisma.models import ( from prisma.types import AgentGraphExecutionWhereInput from pydantic import BaseModel +from autogpt_server.data.block import BlockData, BlockInput, CompletedBlockOutput from autogpt_server.util import json @@ -19,7 +20,7 @@ class NodeExecution(BaseModel): graph_exec_id: str node_exec_id: str node_id: str - data: dict[str, Any] + data: BlockInput class ExecutionStatus(str, Enum): @@ -57,8 +58,8 @@ class ExecutionResult(BaseModel): node_exec_id: str node_id: str status: ExecutionStatus - input_data: dict[str, Any] # 1 input pin should consume exactly 1 data. - output_data: dict[str, list[Any]] # but 1 output pin can produce multiple output. + input_data: BlockInput + output_data: CompletedBlockOutput add_time: datetime queue_time: datetime | None start_time: datetime | None @@ -66,11 +67,11 @@ class ExecutionResult(BaseModel): @staticmethod def from_db(execution: AgentNodeExecution): - input_data: dict[str, Any] = defaultdict() + input_data: BlockInput = defaultdict() for data in execution.Input or []: input_data[data.name] = json.loads(data.data) - output_data: dict[str, Any] = defaultdict(list) + output_data: CompletedBlockOutput = defaultdict(list) for data in execution.Output or []: output_data[data.name].append(json.loads(data.data)) @@ -103,7 +104,7 @@ EXECUTION_RESULT_INCLUDE = { async def create_graph_execution( - graph_id: str, graph_version: int, node_ids: list[str], data: dict[str, Any] + graph_id: str, graph_version: int, nodes_input: list[tuple[str, BlockInput]] ) -> tuple[str, list[ExecutionResult]]: """ Create a new AgentGraphExecution record. @@ -122,15 +123,17 @@ async def create_graph_execution( "Input": { "create": [ {"name": name, "data": json.dumps(data)} - for name, data in data.items() + for name, data in node_input.items() ] }, } - for node_id in node_ids + for node_id, node_input in nodes_input ] }, }, - include={"AgentNodeExecutions": True}, + include={ + "AgentNodeExecutions": {"include": EXECUTION_RESULT_INCLUDE} # type: ignore + }, ) return result.id, [ @@ -242,7 +245,7 @@ async def get_execution_results(graph_exec_id: str) -> list[ExecutionResult]: return res -async def get_node_execution_input(node_exec_id: str) -> dict[str, Any]: +async def get_node_execution_input(node_exec_id: str) -> BlockInput: """ Get execution node input data from the previous node execution result. @@ -256,11 +259,10 @@ async def get_node_execution_input(node_exec_id: str) -> dict[str, Any]: if not execution.AgentNode: raise ValueError(f"Node {execution.agentNodeId} not found.") - exec_input = json.loads(execution.AgentNode.constantInput) - for input_data in execution.Input or []: - exec_input[input_data.name] = json.loads(input_data.data) - - return merge_execution_input(exec_input) + return { + input_data.name: json.loads(input_data.data) + for input_data in execution.Input or [] + } LIST_SPLIT = "_$_" @@ -268,7 +270,7 @@ DICT_SPLIT = "_#_" OBJC_SPLIT = "_@_" -def parse_execution_output(output: tuple[str, Any], name: str) -> Any | None: +def parse_execution_output(output: BlockData, name: str) -> Any | None: # Allow extracting partial output data by name. output_name, output_data = output @@ -296,7 +298,15 @@ def parse_execution_output(output: tuple[str, Any], name: str) -> Any | None: return None -def merge_execution_input(data: dict[str, Any]) -> dict[str, Any]: +def merge_execution_input(data: BlockInput) -> BlockInput: + """ + Merge all dynamic input pins which described by the following pattern: + - _$_ for list input. + - _#_ for dict input. + - _@_ for object input. + This function will construct pins with the same name into a single list/dict/object. + """ + # Merge all input with _$_ into a single list. items = list(data.items()) list_input: list[Any] = [] diff --git a/rnd/autogpt_server/autogpt_server/data/graph.py b/rnd/autogpt_server/autogpt_server/data/graph.py index 9ddf5c0586..763a2dec16 100644 --- a/rnd/autogpt_server/autogpt_server/data/graph.py +++ b/rnd/autogpt_server/autogpt_server/data/graph.py @@ -7,6 +7,7 @@ import prisma.types from prisma.models import AgentGraph, AgentNode, AgentNodeLink from pydantic import PrivateAttr +from autogpt_server.data.block import BlockInput from autogpt_server.data.db import BaseDbModel from autogpt_server.util import json @@ -33,7 +34,7 @@ class Link(BaseDbModel): class Node(BaseDbModel): block_id: str - input_default: dict[str, Any] = {} # dict[input_name, default_value] + input_default: BlockInput = {} # dict[input_name, default_value] metadata: dict[str, Any] = {} _input_links: list[Link] = PrivateAttr(default=[]) diff --git a/rnd/autogpt_server/autogpt_server/data/schedule.py b/rnd/autogpt_server/autogpt_server/data/schedule.py index d6049e3359..bbec740dca 100644 --- a/rnd/autogpt_server/autogpt_server/data/schedule.py +++ b/rnd/autogpt_server/autogpt_server/data/schedule.py @@ -1,8 +1,9 @@ from datetime import datetime -from typing import Any, Optional +from typing import Optional from prisma.models import AgentGraphExecutionSchedule +from autogpt_server.data.block import BlockInput from autogpt_server.data.db import BaseDbModel from autogpt_server.util import json @@ -12,7 +13,7 @@ class ExecutionSchedule(BaseDbModel): graph_version: int schedule: str is_enabled: bool - input_data: dict[str, Any] + input_data: BlockInput last_updated: Optional[datetime] = None def __init__(self, is_enabled: Optional[bool] = None, **kwargs): diff --git a/rnd/autogpt_server/autogpt_server/executor/manager.py b/rnd/autogpt_server/autogpt_server/executor/manager.py index fc03e3b802..731297d087 100644 --- a/rnd/autogpt_server/autogpt_server/executor/manager.py +++ b/rnd/autogpt_server/autogpt_server/executor/manager.py @@ -7,7 +7,7 @@ if TYPE_CHECKING: from autogpt_server.server.server import AgentServer from autogpt_server.data import db -from autogpt_server.data.block import Block, get_block +from autogpt_server.data.block import Block, BlockData, BlockInput, get_block from autogpt_server.data.execution import ExecutionQueue, ExecutionStatus from autogpt_server.data.execution import NodeExecution as Execution from autogpt_server.data.execution import ( @@ -50,7 +50,6 @@ def execute_node( """ graph_exec_id = data.graph_exec_id node_exec_id = data.node_exec_id - exec_data = data.data node_id = data.node_id asyncio.set_event_loop(loop) @@ -73,8 +72,14 @@ def execute_node( logger.error(f"Block {node.block_id} not found.") return - # Execute the node + # Sanity check: validate the execution input. prefix = get_log_prefix(graph_exec_id, node_exec_id, node_block.name) + exec_data, error = validate_exec(node, data.data, resolve_input=False) + if not exec_data: + logger.error(f"{prefix} Skip execution, input validation error: {error}") + return + + # Execute the node logger.warning(f"{prefix} execute with input:\n`{exec_data}`") update_execution(ExecutionStatus.RUNNING) @@ -106,7 +111,7 @@ def enqueue_next_nodes( api_client: "AgentServer", loop: asyncio.AbstractEventLoop, node: Node, - output: tuple[str, Any], + output: BlockData, graph_exec_id: str, prefix: str, ) -> list[Execution]: @@ -142,10 +147,10 @@ def enqueue_next_nodes( ) next_node_input = wait(get_node_execution_input(next_node_exec_id)) - is_valid, validation_msg = validate_exec(next_node, next_node_input) + next_node_input, validation_msg = validate_exec(next_node, next_node_input) suffix = f"{next_output_name}~{next_input_name}#{next_node_id}:{validation_msg}" - if not is_valid: + if not next_node_input: logger.warning(f"{prefix} Skipped queueing {suffix}") return @@ -166,38 +171,54 @@ def enqueue_next_nodes( ] -def validate_exec(node: Node, data: dict[str, Any]) -> tuple[bool, str]: +def validate_exec( + node: Node, + data: BlockInput, + resolve_input: bool = True, +) -> tuple[BlockInput | None, str]: """ Validate the input data for a node execution. Args: node: The node to execute. data: The input data for the node execution. + resolve_input: Whether to resolve dynamic pins into dict/list/object. Returns: - A tuple of a boolean indicating if the data is valid, and a message if not. - Return the executed block name if the data is valid. + A tuple of the validated data and the block name. + If the data is invalid, the first element will be None, and the second element + will be an error message. + If the data is valid, the first element will be the resolved input data, and + the second element will be the block name. """ node_block: Block | None = get_block(node.block_id) # type: ignore if not node_block: - return False, f"Block for {node.block_id} not found." + return None, f"Block for {node.block_id} not found." - error_message = f"Input data missing for {node_block.name}:" - - input_fields_from_schema = node_block.input_schema.get_required_fields() - if not input_fields_from_schema.issubset(data): - return False, f"{error_message} {input_fields_from_schema - set(data)}" + error_prefix = f"Input data missing for {node_block.name}:" + # Input data (without default values) should contain all required fields. input_fields_from_nodes = {link.sink_name for link in node.input_links} if not input_fields_from_nodes.issubset(data): - return False, f"{error_message} {input_fields_from_nodes - set(data)}" + return None, f"{error_prefix} {input_fields_from_nodes - set(data)}" + # Merge input data with default values and resolve dynamic dict/list/object pins. + data = {**node.input_default, **data} + if resolve_input: + data = merge_execution_input(data) + + # Input data post-merge should contain all required fields from the schema. + input_fields_from_schema = node_block.input_schema.get_required_fields() + if not input_fields_from_schema.issubset(data): + return None, f"{error_prefix} {input_fields_from_schema - set(data)}" + + # Last validation: Validate the input values against the schema. if error := node_block.input_schema.validate_data(data): # type: ignore error_message = f"Input data doesn't match {node_block.name}: {error}" logger.error(error_message) - return False, error_message + return None, error_message - return True, node_block.name + return data, node_block.name def get_agent_server_client() -> "AgentServer": @@ -251,37 +272,34 @@ class ExecutionManager(AppService): return get_agent_server_client() @expose - def add_execution(self, graph_id: str, data: dict[str, Any]) -> dict[Any, Any]: + def add_execution(self, graph_id: str, data: BlockInput) -> dict[Any, Any]: graph: Graph | None = self.run_and_wait(get_graph(graph_id)) if not graph: raise Exception(f"Graph #{graph_id} not found.") - # Currently, there is no constraint on the number of root nodes in the graph. + nodes_input = [] for node in graph.starting_nodes: - input_data = merge_execution_input({**node.input_default, **data}) - valid, error = validate_exec(node, input_data) - if not valid: + input_data, error = validate_exec(node, data) + if not input_data: raise Exception(error) + else: + nodes_input.append((node.id, input_data)) graph_exec_id, node_execs = self.run_and_wait( create_graph_execution( graph_id=graph_id, graph_version=graph.version, - node_ids=[node.id for node in graph.starting_nodes], - data=data, + nodes_input=nodes_input, ) ) - executions: list[dict[str, Any]] = [] + executions: list[BlockInput] = [] for node_exec in node_execs: - input_data = self.run_and_wait( - get_node_execution_input(node_exec.node_exec_id) - ) self.add_node_execution( Execution( graph_exec_id=node_exec.graph_exec_id, node_exec_id=node_exec.node_exec_id, node_id=node_exec.node_id, - data=input_data, + data=node_exec.input_data, ) ) diff --git a/rnd/autogpt_server/autogpt_server/executor/scheduler.py b/rnd/autogpt_server/autogpt_server/executor/scheduler.py index 02c280559c..360f756d14 100644 --- a/rnd/autogpt_server/autogpt_server/executor/scheduler.py +++ b/rnd/autogpt_server/autogpt_server/executor/scheduler.py @@ -1,12 +1,12 @@ import logging import time from datetime import datetime -from typing import Any from apscheduler.schedulers.background import BackgroundScheduler from apscheduler.triggers.cron import CronTrigger from autogpt_server.data import schedule as model +from autogpt_server.data.block import BlockInput from autogpt_server.executor.manager import ExecutionManager from autogpt_server.util.service import AppService, expose, get_service_client @@ -68,7 +68,7 @@ class ExecutionScheduler(AppService): @expose def add_execution_schedule( - self, graph_id: str, graph_version: int, cron: str, input_data: dict[str, Any] + self, graph_id: str, graph_version: int, cron: str, input_data: BlockInput ) -> str: schedule = model.ExecutionSchedule( graph_id=graph_id, diff --git a/rnd/autogpt_server/autogpt_server/server/server.py b/rnd/autogpt_server/autogpt_server/server/server.py index a5a36dff68..3d35af8812 100644 --- a/rnd/autogpt_server/autogpt_server/server/server.py +++ b/rnd/autogpt_server/autogpt_server/server/server.py @@ -1,5 +1,6 @@ import asyncio import uuid +from collections import defaultdict from contextlib import asynccontextmanager from typing import Annotated, Any, Dict @@ -19,6 +20,7 @@ from fastapi.staticfiles import StaticFiles import autogpt_server.server.ws_api from autogpt_server.data import block, db, execution from autogpt_server.data import graph as graph_db +from autogpt_server.data.block import BlockInput, CompletedBlockOutput from autogpt_server.executor import ExecutionManager, ExecutionScheduler from autogpt_server.server.conn_manager import ConnectionManager from autogpt_server.server.model import ( @@ -386,12 +388,16 @@ class AgentServer(AppService): @classmethod def execute_graph_block( - cls, block_id: str, data: dict[str, Any] - ) -> list[dict[str, Any]]: + cls, block_id: str, data: BlockInput + ) -> CompletedBlockOutput: obj = block.get_block(block_id) # type: ignore if not obj: raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.") - return [{name: data} for name, data in obj.execute(data)] + + output = defaultdict(list) + for name, data in obj.execute(data): + output[name].append(data) + return output @classmethod async def get_graphs(cls) -> list[graph_db.GraphMeta]: diff --git a/rnd/autogpt_server/test/conftest.py b/rnd/autogpt_server/test/conftest.py new file mode 100644 index 0000000000..46e6fce450 --- /dev/null +++ b/rnd/autogpt_server/test/conftest.py @@ -0,0 +1,9 @@ +import pytest + +from autogpt_server.util.test import SpinTestServer + + +@pytest.fixture(scope="session") +async def server(): + async with SpinTestServer() as server: + yield server diff --git a/rnd/autogpt_server/test/executor/test_manager.py b/rnd/autogpt_server/test/executor/test_manager.py index ac59248166..42be39bae8 100644 --- a/rnd/autogpt_server/test/executor/test_manager.py +++ b/rnd/autogpt_server/test/executor/test_manager.py @@ -1,28 +1,30 @@ import pytest +from autogpt_server.blocks.basic import ObjectLookupBlock, ValueBlock from autogpt_server.data import execution, graph from autogpt_server.executor import ExecutionManager from autogpt_server.server import AgentServer from autogpt_server.usecases.sample import create_test_graph -from autogpt_server.util.test import SpinTestServer, wait_execution +from autogpt_server.util.test import wait_execution -async def execute_graph(test_manager: ExecutionManager, test_graph: graph.Graph) -> str: +async def execute_graph( + test_manager: ExecutionManager, + test_graph: graph.Graph, + input_data: dict[str, str], + num_execs: int = 4, +) -> str: # --- Test adding new executions --- # - text = "Hello, World!" - input_data = {"input": text} agent_server = AgentServer() response = await agent_server.execute_graph(test_graph.id, input_data) - executions = response["executions"] graph_exec_id = response["id"] - assert len(executions) == 2 # Execution queue should be empty - assert await wait_execution(test_manager, test_graph.id, graph_exec_id, 4) + assert await wait_execution(test_manager, test_graph.id, graph_exec_id, num_execs) return graph_exec_id -async def assert_executions(test_graph: graph.Graph, graph_exec_id: str): +async def assert_sample_graph_executions(test_graph: graph.Graph, graph_exec_id: str): text = "Hello, World!" agent_server = AgentServer() executions = await agent_server.get_run_execution_results( @@ -66,9 +68,71 @@ async def assert_executions(test_graph: graph.Graph, graph_exec_id: str): @pytest.mark.asyncio(scope="session") -async def test_agent_execution(): - async with SpinTestServer() as server: - test_graph = create_test_graph() - await graph.create_graph(test_graph) - graph_exec_id = await execute_graph(server.exec_manager, test_graph) - await assert_executions(test_graph, graph_exec_id) +async def test_agent_execution(server): + test_graph = create_test_graph() + await graph.create_graph(test_graph) + data = {"input": "Hello, World!"} + graph_exec_id = await execute_graph(server.exec_manager, test_graph, data, 4) + await assert_sample_graph_executions(test_graph, graph_exec_id) + + +@pytest.mark.asyncio(scope="session") +async def test_input_pin_always_waited(server): + """ + This test is asserting that the input pin should always be waited for the execution, + even when default value on that pin is defined, the value has to be ignored. + + Test scenario: + ValueBlock1 + \\ input + >------- ObjectLookupBlock | input_default: key: "", input: {} + // key + ValueBlock2 + """ + nodes = [ + graph.Node( + block_id=ValueBlock().id, + input_default={"input": {"key1": "value1", "key2": "value2"}}, + ), + graph.Node( + block_id=ValueBlock().id, + input_default={"input": "key2"}, + ), + graph.Node( + block_id=ObjectLookupBlock().id, + input_default={"key": "", "input": {}}, + ), + ] + links = [ + graph.Link( + source_id=nodes[0].id, + sink_id=nodes[2].id, + source_name="output", + sink_name="input", + ), + graph.Link( + source_id=nodes[1].id, + sink_id=nodes[2].id, + source_name="output", + sink_name="key", + ), + ] + test_graph = graph.Graph( + name="TestGraph", + description="Test graph", + nodes=nodes, + links=links, + ) + + test_graph = await graph.create_graph(test_graph) + graph_exec_id = await execute_graph(server.exec_manager, test_graph, {}, 3) + + agent_server = AgentServer() + executions = await agent_server.get_run_execution_results( + test_graph.id, graph_exec_id + ) + assert len(executions) == 3 + # ObjectLookupBlock should wait for the input pin to be provided, + # Hence executing extraction of "key" from {"key1": "value1", "key2": "value2"} + assert executions[2].status == execution.ExecutionStatus.COMPLETED + assert executions[2].output_data == {"output": ["value2"]} diff --git a/rnd/autogpt_server/test/executor/test_scheduler.py b/rnd/autogpt_server/test/executor/test_scheduler.py index a2ea925da9..354b73e4de 100644 --- a/rnd/autogpt_server/test/executor/test_scheduler.py +++ b/rnd/autogpt_server/test/executor/test_scheduler.py @@ -4,32 +4,30 @@ from autogpt_server.data import db, graph from autogpt_server.executor import ExecutionScheduler from autogpt_server.usecases.sample import create_test_graph from autogpt_server.util.service import get_service_client -from autogpt_server.util.test import SpinTestServer @pytest.mark.asyncio(scope="session") -async def test_agent_schedule(): +async def test_agent_schedule(server): await db.connect() test_graph = await graph.create_graph(create_test_graph()) - async with SpinTestServer(): - scheduler = get_service_client(ExecutionScheduler) + scheduler = get_service_client(ExecutionScheduler) - schedules = scheduler.get_execution_schedules(test_graph.id) - assert len(schedules) == 0 + schedules = scheduler.get_execution_schedules(test_graph.id) + assert len(schedules) == 0 - schedule_id = scheduler.add_execution_schedule( - graph_id=test_graph.id, - graph_version=1, - cron="0 0 * * *", - input_data={"input": "data"}, - ) - assert schedule_id + schedule_id = scheduler.add_execution_schedule( + graph_id=test_graph.id, + graph_version=1, + cron="0 0 * * *", + input_data={"input": "data"}, + ) + assert schedule_id - schedules = scheduler.get_execution_schedules(test_graph.id) - assert len(schedules) == 1 - assert schedules[schedule_id] == "0 0 * * *" + schedules = scheduler.get_execution_schedules(test_graph.id) + assert len(schedules) == 1 + assert schedules[schedule_id] == "0 0 * * *" - scheduler.update_schedule(schedule_id, is_enabled=False) - schedules = scheduler.get_execution_schedules(test_graph.id) - assert len(schedules) == 0 + scheduler.update_schedule(schedule_id, is_enabled=False) + schedules = scheduler.get_execution_schedules(test_graph.id) + assert len(schedules) == 0