mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-09 23:28:07 -05:00
fix(rnd): Make Agent Server's pin connections become the mandatory source of input (#7539)
### Background
Input from the input pin is consumed only once, and the default input can always be used. So when you have an input pin overriding the default input, the value will be used only once and the following run will always fall back to the default input. This can mislead the user.
Expected behaviour: the node should NOT RUN, making connected pins only use their connection(s) for sources of data.
### Changes 🏗️
* Make pin connection the mandatory source of input and not falling back to default value.
* Fix the type flakiness on block input & output. Unify the typing for BlockInput & BlockOutput using the right alias to avoid wrong typing.
* Add comment on alias
* automated test on the new behaviour.
This commit is contained in:
@@ -9,9 +9,10 @@ from pydantic import BaseModel
|
||||
|
||||
from autogpt_server.util import json
|
||||
|
||||
BlockInput = dict[str, Any]
|
||||
BlockData = tuple[str, Any]
|
||||
BlockOutput = Generator[BlockData, None, None]
|
||||
BlockData = tuple[str, Any] # Input & Output data should be a tuple of (name, data).
|
||||
BlockInput = dict[str, Any] # Input: 1 input pin consumes 1 data.
|
||||
BlockOutput = Generator[BlockData, None, None] # Output: 1 output pin produces n data.
|
||||
CompletedBlockOutput = dict[str, list[Any]] # Completed stream, collected as a dict.
|
||||
|
||||
|
||||
class BlockCategory(Enum):
|
||||
|
||||
@@ -12,6 +12,7 @@ from prisma.models import (
|
||||
from prisma.types import AgentGraphExecutionWhereInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
from autogpt_server.data.block import BlockData, BlockInput, CompletedBlockOutput
|
||||
from autogpt_server.util import json
|
||||
|
||||
|
||||
@@ -19,7 +20,7 @@ class NodeExecution(BaseModel):
|
||||
graph_exec_id: str
|
||||
node_exec_id: str
|
||||
node_id: str
|
||||
data: dict[str, Any]
|
||||
data: BlockInput
|
||||
|
||||
|
||||
class ExecutionStatus(str, Enum):
|
||||
@@ -57,8 +58,8 @@ class ExecutionResult(BaseModel):
|
||||
node_exec_id: str
|
||||
node_id: str
|
||||
status: ExecutionStatus
|
||||
input_data: dict[str, Any] # 1 input pin should consume exactly 1 data.
|
||||
output_data: dict[str, list[Any]] # but 1 output pin can produce multiple output.
|
||||
input_data: BlockInput
|
||||
output_data: CompletedBlockOutput
|
||||
add_time: datetime
|
||||
queue_time: datetime | None
|
||||
start_time: datetime | None
|
||||
@@ -66,11 +67,11 @@ class ExecutionResult(BaseModel):
|
||||
|
||||
@staticmethod
|
||||
def from_db(execution: AgentNodeExecution):
|
||||
input_data: dict[str, Any] = defaultdict()
|
||||
input_data: BlockInput = defaultdict()
|
||||
for data in execution.Input or []:
|
||||
input_data[data.name] = json.loads(data.data)
|
||||
|
||||
output_data: dict[str, Any] = defaultdict(list)
|
||||
output_data: CompletedBlockOutput = defaultdict(list)
|
||||
for data in execution.Output or []:
|
||||
output_data[data.name].append(json.loads(data.data))
|
||||
|
||||
@@ -103,7 +104,7 @@ EXECUTION_RESULT_INCLUDE = {
|
||||
|
||||
|
||||
async def create_graph_execution(
|
||||
graph_id: str, graph_version: int, node_ids: list[str], data: dict[str, Any]
|
||||
graph_id: str, graph_version: int, nodes_input: list[tuple[str, BlockInput]]
|
||||
) -> tuple[str, list[ExecutionResult]]:
|
||||
"""
|
||||
Create a new AgentGraphExecution record.
|
||||
@@ -122,15 +123,17 @@ async def create_graph_execution(
|
||||
"Input": {
|
||||
"create": [
|
||||
{"name": name, "data": json.dumps(data)}
|
||||
for name, data in data.items()
|
||||
for name, data in node_input.items()
|
||||
]
|
||||
},
|
||||
}
|
||||
for node_id in node_ids
|
||||
for node_id, node_input in nodes_input
|
||||
]
|
||||
},
|
||||
},
|
||||
include={"AgentNodeExecutions": True},
|
||||
include={
|
||||
"AgentNodeExecutions": {"include": EXECUTION_RESULT_INCLUDE} # type: ignore
|
||||
},
|
||||
)
|
||||
|
||||
return result.id, [
|
||||
@@ -242,7 +245,7 @@ async def get_execution_results(graph_exec_id: str) -> list[ExecutionResult]:
|
||||
return res
|
||||
|
||||
|
||||
async def get_node_execution_input(node_exec_id: str) -> dict[str, Any]:
|
||||
async def get_node_execution_input(node_exec_id: str) -> BlockInput:
|
||||
"""
|
||||
Get execution node input data from the previous node execution result.
|
||||
|
||||
@@ -256,11 +259,10 @@ async def get_node_execution_input(node_exec_id: str) -> dict[str, Any]:
|
||||
if not execution.AgentNode:
|
||||
raise ValueError(f"Node {execution.agentNodeId} not found.")
|
||||
|
||||
exec_input = json.loads(execution.AgentNode.constantInput)
|
||||
for input_data in execution.Input or []:
|
||||
exec_input[input_data.name] = json.loads(input_data.data)
|
||||
|
||||
return merge_execution_input(exec_input)
|
||||
return {
|
||||
input_data.name: json.loads(input_data.data)
|
||||
for input_data in execution.Input or []
|
||||
}
|
||||
|
||||
|
||||
LIST_SPLIT = "_$_"
|
||||
@@ -268,7 +270,7 @@ DICT_SPLIT = "_#_"
|
||||
OBJC_SPLIT = "_@_"
|
||||
|
||||
|
||||
def parse_execution_output(output: tuple[str, Any], name: str) -> Any | None:
|
||||
def parse_execution_output(output: BlockData, name: str) -> Any | None:
|
||||
# Allow extracting partial output data by name.
|
||||
output_name, output_data = output
|
||||
|
||||
@@ -296,7 +298,15 @@ def parse_execution_output(output: tuple[str, Any], name: str) -> Any | None:
|
||||
return None
|
||||
|
||||
|
||||
def merge_execution_input(data: dict[str, Any]) -> dict[str, Any]:
|
||||
def merge_execution_input(data: BlockInput) -> BlockInput:
|
||||
"""
|
||||
Merge all dynamic input pins which described by the following pattern:
|
||||
- <input_name>_$_<index> for list input.
|
||||
- <input_name>_#_<index> for dict input.
|
||||
- <input_name>_@_<index> for object input.
|
||||
This function will construct pins with the same name into a single list/dict/object.
|
||||
"""
|
||||
|
||||
# Merge all input with <input_name>_$_<index> into a single list.
|
||||
items = list(data.items())
|
||||
list_input: list[Any] = []
|
||||
|
||||
@@ -7,6 +7,7 @@ import prisma.types
|
||||
from prisma.models import AgentGraph, AgentNode, AgentNodeLink
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from autogpt_server.data.block import BlockInput
|
||||
from autogpt_server.data.db import BaseDbModel
|
||||
from autogpt_server.util import json
|
||||
|
||||
@@ -33,7 +34,7 @@ class Link(BaseDbModel):
|
||||
|
||||
class Node(BaseDbModel):
|
||||
block_id: str
|
||||
input_default: dict[str, Any] = {} # dict[input_name, default_value]
|
||||
input_default: BlockInput = {} # dict[input_name, default_value]
|
||||
metadata: dict[str, Any] = {}
|
||||
|
||||
_input_links: list[Link] = PrivateAttr(default=[])
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
from prisma.models import AgentGraphExecutionSchedule
|
||||
|
||||
from autogpt_server.data.block import BlockInput
|
||||
from autogpt_server.data.db import BaseDbModel
|
||||
from autogpt_server.util import json
|
||||
|
||||
@@ -12,7 +13,7 @@ class ExecutionSchedule(BaseDbModel):
|
||||
graph_version: int
|
||||
schedule: str
|
||||
is_enabled: bool
|
||||
input_data: dict[str, Any]
|
||||
input_data: BlockInput
|
||||
last_updated: Optional[datetime] = None
|
||||
|
||||
def __init__(self, is_enabled: Optional[bool] = None, **kwargs):
|
||||
|
||||
@@ -7,7 +7,7 @@ if TYPE_CHECKING:
|
||||
from autogpt_server.server.server import AgentServer
|
||||
|
||||
from autogpt_server.data import db
|
||||
from autogpt_server.data.block import Block, get_block
|
||||
from autogpt_server.data.block import Block, BlockData, BlockInput, get_block
|
||||
from autogpt_server.data.execution import ExecutionQueue, ExecutionStatus
|
||||
from autogpt_server.data.execution import NodeExecution as Execution
|
||||
from autogpt_server.data.execution import (
|
||||
@@ -50,7 +50,6 @@ def execute_node(
|
||||
"""
|
||||
graph_exec_id = data.graph_exec_id
|
||||
node_exec_id = data.node_exec_id
|
||||
exec_data = data.data
|
||||
node_id = data.node_id
|
||||
|
||||
asyncio.set_event_loop(loop)
|
||||
@@ -73,8 +72,14 @@ def execute_node(
|
||||
logger.error(f"Block {node.block_id} not found.")
|
||||
return
|
||||
|
||||
# Execute the node
|
||||
# Sanity check: validate the execution input.
|
||||
prefix = get_log_prefix(graph_exec_id, node_exec_id, node_block.name)
|
||||
exec_data, error = validate_exec(node, data.data, resolve_input=False)
|
||||
if not exec_data:
|
||||
logger.error(f"{prefix} Skip execution, input validation error: {error}")
|
||||
return
|
||||
|
||||
# Execute the node
|
||||
logger.warning(f"{prefix} execute with input:\n`{exec_data}`")
|
||||
update_execution(ExecutionStatus.RUNNING)
|
||||
|
||||
@@ -106,7 +111,7 @@ def enqueue_next_nodes(
|
||||
api_client: "AgentServer",
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
node: Node,
|
||||
output: tuple[str, Any],
|
||||
output: BlockData,
|
||||
graph_exec_id: str,
|
||||
prefix: str,
|
||||
) -> list[Execution]:
|
||||
@@ -142,10 +147,10 @@ def enqueue_next_nodes(
|
||||
)
|
||||
|
||||
next_node_input = wait(get_node_execution_input(next_node_exec_id))
|
||||
is_valid, validation_msg = validate_exec(next_node, next_node_input)
|
||||
next_node_input, validation_msg = validate_exec(next_node, next_node_input)
|
||||
suffix = f"{next_output_name}~{next_input_name}#{next_node_id}:{validation_msg}"
|
||||
|
||||
if not is_valid:
|
||||
if not next_node_input:
|
||||
logger.warning(f"{prefix} Skipped queueing {suffix}")
|
||||
return
|
||||
|
||||
@@ -166,38 +171,54 @@ def enqueue_next_nodes(
|
||||
]
|
||||
|
||||
|
||||
def validate_exec(node: Node, data: dict[str, Any]) -> tuple[bool, str]:
|
||||
def validate_exec(
|
||||
node: Node,
|
||||
data: BlockInput,
|
||||
resolve_input: bool = True,
|
||||
) -> tuple[BlockInput | None, str]:
|
||||
"""
|
||||
Validate the input data for a node execution.
|
||||
|
||||
Args:
|
||||
node: The node to execute.
|
||||
data: The input data for the node execution.
|
||||
resolve_input: Whether to resolve dynamic pins into dict/list/object.
|
||||
|
||||
Returns:
|
||||
A tuple of a boolean indicating if the data is valid, and a message if not.
|
||||
Return the executed block name if the data is valid.
|
||||
A tuple of the validated data and the block name.
|
||||
If the data is invalid, the first element will be None, and the second element
|
||||
will be an error message.
|
||||
If the data is valid, the first element will be the resolved input data, and
|
||||
the second element will be the block name.
|
||||
"""
|
||||
node_block: Block | None = get_block(node.block_id) # type: ignore
|
||||
if not node_block:
|
||||
return False, f"Block for {node.block_id} not found."
|
||||
return None, f"Block for {node.block_id} not found."
|
||||
|
||||
error_message = f"Input data missing for {node_block.name}:"
|
||||
|
||||
input_fields_from_schema = node_block.input_schema.get_required_fields()
|
||||
if not input_fields_from_schema.issubset(data):
|
||||
return False, f"{error_message} {input_fields_from_schema - set(data)}"
|
||||
error_prefix = f"Input data missing for {node_block.name}:"
|
||||
|
||||
# Input data (without default values) should contain all required fields.
|
||||
input_fields_from_nodes = {link.sink_name for link in node.input_links}
|
||||
if not input_fields_from_nodes.issubset(data):
|
||||
return False, f"{error_message} {input_fields_from_nodes - set(data)}"
|
||||
return None, f"{error_prefix} {input_fields_from_nodes - set(data)}"
|
||||
|
||||
# Merge input data with default values and resolve dynamic dict/list/object pins.
|
||||
data = {**node.input_default, **data}
|
||||
if resolve_input:
|
||||
data = merge_execution_input(data)
|
||||
|
||||
# Input data post-merge should contain all required fields from the schema.
|
||||
input_fields_from_schema = node_block.input_schema.get_required_fields()
|
||||
if not input_fields_from_schema.issubset(data):
|
||||
return None, f"{error_prefix} {input_fields_from_schema - set(data)}"
|
||||
|
||||
# Last validation: Validate the input values against the schema.
|
||||
if error := node_block.input_schema.validate_data(data): # type: ignore
|
||||
error_message = f"Input data doesn't match {node_block.name}: {error}"
|
||||
logger.error(error_message)
|
||||
return False, error_message
|
||||
return None, error_message
|
||||
|
||||
return True, node_block.name
|
||||
return data, node_block.name
|
||||
|
||||
|
||||
def get_agent_server_client() -> "AgentServer":
|
||||
@@ -251,37 +272,34 @@ class ExecutionManager(AppService):
|
||||
return get_agent_server_client()
|
||||
|
||||
@expose
|
||||
def add_execution(self, graph_id: str, data: dict[str, Any]) -> dict[Any, Any]:
|
||||
def add_execution(self, graph_id: str, data: BlockInput) -> dict[Any, Any]:
|
||||
graph: Graph | None = self.run_and_wait(get_graph(graph_id))
|
||||
if not graph:
|
||||
raise Exception(f"Graph #{graph_id} not found.")
|
||||
|
||||
# Currently, there is no constraint on the number of root nodes in the graph.
|
||||
nodes_input = []
|
||||
for node in graph.starting_nodes:
|
||||
input_data = merge_execution_input({**node.input_default, **data})
|
||||
valid, error = validate_exec(node, input_data)
|
||||
if not valid:
|
||||
input_data, error = validate_exec(node, data)
|
||||
if not input_data:
|
||||
raise Exception(error)
|
||||
else:
|
||||
nodes_input.append((node.id, input_data))
|
||||
|
||||
graph_exec_id, node_execs = self.run_and_wait(
|
||||
create_graph_execution(
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
node_ids=[node.id for node in graph.starting_nodes],
|
||||
data=data,
|
||||
nodes_input=nodes_input,
|
||||
)
|
||||
)
|
||||
executions: list[dict[str, Any]] = []
|
||||
executions: list[BlockInput] = []
|
||||
for node_exec in node_execs:
|
||||
input_data = self.run_and_wait(
|
||||
get_node_execution_input(node_exec.node_exec_id)
|
||||
)
|
||||
self.add_node_execution(
|
||||
Execution(
|
||||
graph_exec_id=node_exec.graph_exec_id,
|
||||
node_exec_id=node_exec.node_exec_id,
|
||||
node_id=node_exec.node_id,
|
||||
data=input_data,
|
||||
data=node_exec.input_data,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
|
||||
from autogpt_server.data import schedule as model
|
||||
from autogpt_server.data.block import BlockInput
|
||||
from autogpt_server.executor.manager import ExecutionManager
|
||||
from autogpt_server.util.service import AppService, expose, get_service_client
|
||||
|
||||
@@ -68,7 +68,7 @@ class ExecutionScheduler(AppService):
|
||||
|
||||
@expose
|
||||
def add_execution_schedule(
|
||||
self, graph_id: str, graph_version: int, cron: str, input_data: dict[str, Any]
|
||||
self, graph_id: str, graph_version: int, cron: str, input_data: BlockInput
|
||||
) -> str:
|
||||
schedule = model.ExecutionSchedule(
|
||||
graph_id=graph_id,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Annotated, Any, Dict
|
||||
|
||||
@@ -19,6 +20,7 @@ from fastapi.staticfiles import StaticFiles
|
||||
import autogpt_server.server.ws_api
|
||||
from autogpt_server.data import block, db, execution
|
||||
from autogpt_server.data import graph as graph_db
|
||||
from autogpt_server.data.block import BlockInput, CompletedBlockOutput
|
||||
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
|
||||
from autogpt_server.server.conn_manager import ConnectionManager
|
||||
from autogpt_server.server.model import (
|
||||
@@ -386,12 +388,16 @@ class AgentServer(AppService):
|
||||
|
||||
@classmethod
|
||||
def execute_graph_block(
|
||||
cls, block_id: str, data: dict[str, Any]
|
||||
) -> list[dict[str, Any]]:
|
||||
cls, block_id: str, data: BlockInput
|
||||
) -> CompletedBlockOutput:
|
||||
obj = block.get_block(block_id) # type: ignore
|
||||
if not obj:
|
||||
raise HTTPException(status_code=404, detail=f"Block #{block_id} not found.")
|
||||
return [{name: data} for name, data in obj.execute(data)]
|
||||
|
||||
output = defaultdict(list)
|
||||
for name, data in obj.execute(data):
|
||||
output[name].append(data)
|
||||
return output
|
||||
|
||||
@classmethod
|
||||
async def get_graphs(cls) -> list[graph_db.GraphMeta]:
|
||||
|
||||
9
rnd/autogpt_server/test/conftest.py
Normal file
9
rnd/autogpt_server/test/conftest.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import pytest
|
||||
|
||||
from autogpt_server.util.test import SpinTestServer
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def server():
|
||||
async with SpinTestServer() as server:
|
||||
yield server
|
||||
@@ -1,28 +1,30 @@
|
||||
import pytest
|
||||
|
||||
from autogpt_server.blocks.basic import ObjectLookupBlock, ValueBlock
|
||||
from autogpt_server.data import execution, graph
|
||||
from autogpt_server.executor import ExecutionManager
|
||||
from autogpt_server.server import AgentServer
|
||||
from autogpt_server.usecases.sample import create_test_graph
|
||||
from autogpt_server.util.test import SpinTestServer, wait_execution
|
||||
from autogpt_server.util.test import wait_execution
|
||||
|
||||
|
||||
async def execute_graph(test_manager: ExecutionManager, test_graph: graph.Graph) -> str:
|
||||
async def execute_graph(
|
||||
test_manager: ExecutionManager,
|
||||
test_graph: graph.Graph,
|
||||
input_data: dict[str, str],
|
||||
num_execs: int = 4,
|
||||
) -> str:
|
||||
# --- Test adding new executions --- #
|
||||
text = "Hello, World!"
|
||||
input_data = {"input": text}
|
||||
agent_server = AgentServer()
|
||||
response = await agent_server.execute_graph(test_graph.id, input_data)
|
||||
executions = response["executions"]
|
||||
graph_exec_id = response["id"]
|
||||
assert len(executions) == 2
|
||||
|
||||
# Execution queue should be empty
|
||||
assert await wait_execution(test_manager, test_graph.id, graph_exec_id, 4)
|
||||
assert await wait_execution(test_manager, test_graph.id, graph_exec_id, num_execs)
|
||||
return graph_exec_id
|
||||
|
||||
|
||||
async def assert_executions(test_graph: graph.Graph, graph_exec_id: str):
|
||||
async def assert_sample_graph_executions(test_graph: graph.Graph, graph_exec_id: str):
|
||||
text = "Hello, World!"
|
||||
agent_server = AgentServer()
|
||||
executions = await agent_server.get_run_execution_results(
|
||||
@@ -66,9 +68,71 @@ async def assert_executions(test_graph: graph.Graph, graph_exec_id: str):
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_agent_execution():
|
||||
async with SpinTestServer() as server:
|
||||
test_graph = create_test_graph()
|
||||
await graph.create_graph(test_graph)
|
||||
graph_exec_id = await execute_graph(server.exec_manager, test_graph)
|
||||
await assert_executions(test_graph, graph_exec_id)
|
||||
async def test_agent_execution(server):
|
||||
test_graph = create_test_graph()
|
||||
await graph.create_graph(test_graph)
|
||||
data = {"input": "Hello, World!"}
|
||||
graph_exec_id = await execute_graph(server.exec_manager, test_graph, data, 4)
|
||||
await assert_sample_graph_executions(test_graph, graph_exec_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_input_pin_always_waited(server):
|
||||
"""
|
||||
This test is asserting that the input pin should always be waited for the execution,
|
||||
even when default value on that pin is defined, the value has to be ignored.
|
||||
|
||||
Test scenario:
|
||||
ValueBlock1
|
||||
\\ input
|
||||
>------- ObjectLookupBlock | input_default: key: "", input: {}
|
||||
// key
|
||||
ValueBlock2
|
||||
"""
|
||||
nodes = [
|
||||
graph.Node(
|
||||
block_id=ValueBlock().id,
|
||||
input_default={"input": {"key1": "value1", "key2": "value2"}},
|
||||
),
|
||||
graph.Node(
|
||||
block_id=ValueBlock().id,
|
||||
input_default={"input": "key2"},
|
||||
),
|
||||
graph.Node(
|
||||
block_id=ObjectLookupBlock().id,
|
||||
input_default={"key": "", "input": {}},
|
||||
),
|
||||
]
|
||||
links = [
|
||||
graph.Link(
|
||||
source_id=nodes[0].id,
|
||||
sink_id=nodes[2].id,
|
||||
source_name="output",
|
||||
sink_name="input",
|
||||
),
|
||||
graph.Link(
|
||||
source_id=nodes[1].id,
|
||||
sink_id=nodes[2].id,
|
||||
source_name="output",
|
||||
sink_name="key",
|
||||
),
|
||||
]
|
||||
test_graph = graph.Graph(
|
||||
name="TestGraph",
|
||||
description="Test graph",
|
||||
nodes=nodes,
|
||||
links=links,
|
||||
)
|
||||
|
||||
test_graph = await graph.create_graph(test_graph)
|
||||
graph_exec_id = await execute_graph(server.exec_manager, test_graph, {}, 3)
|
||||
|
||||
agent_server = AgentServer()
|
||||
executions = await agent_server.get_run_execution_results(
|
||||
test_graph.id, graph_exec_id
|
||||
)
|
||||
assert len(executions) == 3
|
||||
# ObjectLookupBlock should wait for the input pin to be provided,
|
||||
# Hence executing extraction of "key" from {"key1": "value1", "key2": "value2"}
|
||||
assert executions[2].status == execution.ExecutionStatus.COMPLETED
|
||||
assert executions[2].output_data == {"output": ["value2"]}
|
||||
|
||||
@@ -4,32 +4,30 @@ from autogpt_server.data import db, graph
|
||||
from autogpt_server.executor import ExecutionScheduler
|
||||
from autogpt_server.usecases.sample import create_test_graph
|
||||
from autogpt_server.util.service import get_service_client
|
||||
from autogpt_server.util.test import SpinTestServer
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_agent_schedule():
|
||||
async def test_agent_schedule(server):
|
||||
await db.connect()
|
||||
test_graph = await graph.create_graph(create_test_graph())
|
||||
|
||||
async with SpinTestServer():
|
||||
scheduler = get_service_client(ExecutionScheduler)
|
||||
scheduler = get_service_client(ExecutionScheduler)
|
||||
|
||||
schedules = scheduler.get_execution_schedules(test_graph.id)
|
||||
assert len(schedules) == 0
|
||||
schedules = scheduler.get_execution_schedules(test_graph.id)
|
||||
assert len(schedules) == 0
|
||||
|
||||
schedule_id = scheduler.add_execution_schedule(
|
||||
graph_id=test_graph.id,
|
||||
graph_version=1,
|
||||
cron="0 0 * * *",
|
||||
input_data={"input": "data"},
|
||||
)
|
||||
assert schedule_id
|
||||
schedule_id = scheduler.add_execution_schedule(
|
||||
graph_id=test_graph.id,
|
||||
graph_version=1,
|
||||
cron="0 0 * * *",
|
||||
input_data={"input": "data"},
|
||||
)
|
||||
assert schedule_id
|
||||
|
||||
schedules = scheduler.get_execution_schedules(test_graph.id)
|
||||
assert len(schedules) == 1
|
||||
assert schedules[schedule_id] == "0 0 * * *"
|
||||
schedules = scheduler.get_execution_schedules(test_graph.id)
|
||||
assert len(schedules) == 1
|
||||
assert schedules[schedule_id] == "0 0 * * *"
|
||||
|
||||
scheduler.update_schedule(schedule_id, is_enabled=False)
|
||||
schedules = scheduler.get_execution_schedules(test_graph.id)
|
||||
assert len(schedules) == 0
|
||||
scheduler.update_schedule(schedule_id, is_enabled=False)
|
||||
schedules = scheduler.get_execution_schedules(test_graph.id)
|
||||
assert len(schedules) == 0
|
||||
|
||||
Reference in New Issue
Block a user