feat(rnd): Refactor AgentServer Node Input/Output Relation & Block output interface (#7231)

### Background

The current implementation of AgentServer doesn't allow a single pin to be connected to multiple nodes. This is a problem when the output of a single node needs to be propagated to many nodes, or when multiple nodes may feed data into a single pin (first come, first served).

This infra change is also part of the preparation for changing the `block` interface to return a stream of output instead of a single output.  Treating blocks as streams requires this capability.

### Changes 🏗️

* Update block run interface from returning `(output_name, output_data)` to `Generator[(output_name, output_data)]`
* Removed the `agent` term from the API, replacing it with `graph` for consistency.
* Reintroduced `AgentNodeExecutionInputOutput`. The input and output of an `AgentNodeExecution` are now lists of `AgentNodeExecutionInputOutput` records describing the input and output data of that execution, giving each execution a one-to-many relation to its input/output data.
* Propagated the relation and block interface changes into the execution engine.
This commit is contained in:
Zamil Majdy
2024-06-26 14:41:55 +04:00
committed by GitHub
parent f04ddceacf
commit 26bcb26bb7
10 changed files with 499 additions and 338 deletions

View File

@@ -1,9 +1,8 @@
import json
import jsonschema
from abc import ABC, abstractmethod
from typing import Any, ClassVar
from typing import Any, Generator, ClassVar
import jsonschema
from prisma.models import AgentBlock
from pydantic import BaseModel
@@ -92,6 +91,12 @@ class BlockSchema(BaseModel):
except jsonschema.ValidationError as e:
return str(e)
def get_fields(self) -> set[str]:
    """
    Return the names of the fields declared in this schema.

    Returns:
        The keys of the schema's `properties` object, or an empty set when the
        schema declares no `properties` (the keyword is optional in JSON Schema).
    """
    # `.get` avoids a KeyError on schemas that omit `properties` entirely.
    return set(self.jsonschema.get("properties", {}))
BlockOutput = Generator[tuple[str, Any], None, None]
class Block(ABC, BaseModel):
@classmethod
@@ -126,13 +131,15 @@ class Block(ABC, BaseModel):
pass
@abstractmethod
def run(self, input_data: BlockData) -> tuple[str, Any]:
def run(self, input_data: BlockData) -> BlockOutput:
"""
Run the block with the given input data.
Args:
input_data: The input data with the structure of input_schema.
Returns:
The (output name, output data), matching the type in output_schema.
A Generator that yields (output_name, output_data).
output_name: One of the output name defined in Block's output_schema.
output_data: The data for the output_name, matching the defined schema.
"""
pass
@@ -149,20 +156,18 @@ class Block(ABC, BaseModel):
"outputSchema": self.output_schema.jsonschema,
}
def execute(self, input_data: BlockData) -> tuple[str, Any]:
def execute(self, input_data: BlockData) -> BlockOutput:
if error := self.input_schema.validate_data(input_data):
raise ValueError(
f"Unable to execute block with invalid input data: {error}"
)
output_name, output_data = self.run(input_data)
if error := self.output_schema.validate_field(output_name, output_data):
raise ValueError(
f"Unable to execute block with invalid output data: {error}"
)
return output_name, output_data
for output_name, output_data in self.run(input_data):
if error := self.output_schema.validate_field(output_name, output_data):
raise ValueError(
f"Unable to execute block with invalid output data: {error}"
)
yield output_name, output_data
# ===================== Inline-Block Implementations ===================== #
@@ -181,8 +186,8 @@ class ParrotBlock(Block):
}
)
def run(self, input_data: BlockData) -> tuple[str, Any]:
return "output", input_data["input"]
def run(self, input_data: BlockData) -> BlockOutput:
yield "output", input_data["input"]
class TextCombinerBlock(Block):
@@ -200,8 +205,8 @@ class TextCombinerBlock(Block):
}
)
def run(self, input_data: BlockData) -> tuple[str, Any]:
return "combined_text", input_data["format"].format(
def run(self, input_data: BlockData) -> BlockOutput:
yield "combined_text", input_data["format"].format(
text1=input_data["text1"],
text2=input_data["text2"],
)
@@ -220,8 +225,8 @@ class PrintingBlock(Block):
}
)
def run(self, input_data: BlockData) -> tuple[str, Any]:
return "status", "printed"
def run(self, input_data: BlockData) -> BlockOutput:
yield "status", "printed"
# ======================= Block Helper Functions ======================= #

View File

@@ -1,23 +1,27 @@
import json
from collections import defaultdict
from datetime import datetime
from enum import Enum
from multiprocessing import Queue
from multiprocessing import Manager
from typing import Any
from prisma.models import AgentNodeExecution
from autogpt_server.data.db import BaseDbModel
from prisma.models import (
AgentGraphExecution,
AgentNodeExecution,
AgentNodeExecutionInputOutput,
)
from pydantic import BaseModel
class Execution(BaseDbModel):
"""Data model for an execution of an Agent"""
run_id: str
class NodeExecution(BaseModel):
graph_exec_id: str
node_exec_id: str
node_id: str
data: dict[str, Any]
class ExecutionStatus(str, Enum):
INCOMPLETE = "INCOMPLETE"
QUEUED = "QUEUED"
RUNNING = "RUNNING"
COMPLETED = "COMPLETED"
@@ -31,103 +35,205 @@ class ExecutionQueue:
"""
def __init__(self):
self.queue: Queue[Execution] = Queue()
self.queue = Manager().Queue()
def add(self, execution: Execution) -> Execution:
def add(self, execution: NodeExecution) -> NodeExecution:
self.queue.put(execution)
return execution
def get(self) -> Execution:
def get(self) -> NodeExecution:
return self.queue.get()
def empty(self) -> bool:
return self.queue.empty()
class ExecutionResult(BaseDbModel):
run_id: str
execution_id: str
class ExecutionResult(BaseModel):
graph_exec_id: str
node_exec_id: str
node_id: str
status: ExecutionStatus
input_data: dict[str, Any]
output_name: str
output_data: Any
creation_time: datetime
input_data: dict[str, Any] # 1 input pin should consume exactly 1 data.
output_data: dict[str, list[Any]] # but 1 output pin can produce multiple output.
add_time: datetime
queue_time: datetime | None
start_time: datetime | None
end_time: datetime | None
@staticmethod
def from_db(execution: AgentNodeExecution):
input_data = defaultdict()
for data in execution.Input or []:
input_data[data.name] = json.loads(data.data)
output_data = defaultdict(list)
for data in execution.Output or []:
output_data[data.name].append(json.loads(data.data))
return ExecutionResult(
run_id=execution.executionId,
graph_exec_id=execution.agentGraphExecutionId,
node_exec_id=execution.id,
node_id=execution.agentNodeId,
execution_id=execution.id,
status=ExecutionStatus(execution.executionStatus),
input_data=json.loads(execution.inputData or "{}"),
output_name=execution.outputName or "",
output_data=json.loads(execution.outputData or "{}"),
creation_time=execution.creationTime,
start_time=execution.startTime,
end_time=execution.endTime,
input_data=input_data,
output_data=output_data,
add_time=execution.addedTime,
queue_time=execution.queuedTime,
start_time=execution.startedTime,
end_time=execution.endedTime,
)
# --------------------- Model functions --------------------- #
async def enqueue_execution(execution: Execution) -> None:
await AgentNodeExecution.prisma().create(
async def create_graph_execution(
graph_id: str,
node_ids: list[str],
data: dict[str, Any]
) -> tuple[str, list[ExecutionResult]]:
"""
Create a new AgentGraphExecution record.
Returns:
The id of the AgentGraphExecution and the list of ExecutionResult for each node.
"""
result = await AgentGraphExecution.prisma().create(
data={
"id": execution.id,
"executionId": execution.run_id,
"agentNodeId": execution.node_id,
"executionStatus": ExecutionStatus.QUEUED,
"inputData": json.dumps(execution.data),
"creationTime": datetime.now(),
"agentGraphId": graph_id,
"AgentNodeExecutions": {
"create": [ # type: ignore
{
"agentNodeId": node_id,
"executionStatus": ExecutionStatus.INCOMPLETE,
"Input": {
"create": [
{"name": name, "data": json.dumps(data)}
for name, data in data.items()
]
},
}
for node_id in node_ids
]
},
},
include={"AgentNodeExecutions": True}
)
return result.id, [
ExecutionResult.from_db(execution)
for execution in result.AgentNodeExecutions or []
]
async def upsert_execution_input(
    node_id: str,
    graph_exec_id: str,
    input_name: str,
    data: Any,
) -> str:
    """
    Attach one input value to a node execution, creating the execution if needed.

    Searches for the earliest-added AgentNodeExecution of `node_id` within
    `graph_exec_id` none of whose inputs is named `input_name`. When one is
    found, a new AgentNodeExecutionInputOutput row is linked to it as an input.
    Otherwise a new INCOMPLETE AgentNodeExecution is created with this value as
    its only input.

    Returns:
        The id of the AgentNodeExecution that received the input.
    """
    pending = await AgentNodeExecution.prisma().find_first(
        where={  # type: ignore
            "agentNodeId": node_id,
            "agentGraphExecutionId": graph_exec_id,
            "Input": {"every": {"name": {"not": input_name}}},
        },
        order={"addedTime": "asc"},
    )
    serialized = json.dumps(data)

    if not pending:
        # No execution is still waiting on this pin: start a fresh one.
        print(f"Creating new execution for input {input_name}={data}")
        created = await AgentNodeExecution.prisma().create(
            data={
                "agentNodeId": node_id,
                "agentGraphExecutionId": graph_exec_id,
                "executionStatus": ExecutionStatus.INCOMPLETE,
                "Input": {"create": {"name": input_name, "data": serialized}},
            }
        )
        return created.id

    print(f"Adding input {input_name}={data} to execution #{pending.id}")
    await AgentNodeExecutionInputOutput.prisma().create(
        data={
            "name": input_name,
            "data": serialized,
            "referencedByInputExecId": pending.id,
        }
    )
    return pending.id
async def upsert_execution_output(
    node_exec_id: str,
    output_name: str,
    output_data: Any,
) -> None:
    """
    Record one output value produced by a node execution.

    Creates an AgentNodeExecutionInputOutput row named `output_name`, holding
    `output_data` serialized as JSON, and links it to the execution
    `node_exec_id` through `referencedByOutputExecId`.
    """
    row = {
        "name": output_name,
        "data": json.dumps(output_data),
        "referencedByOutputExecId": node_exec_id,
    }
    await AgentNodeExecutionInputOutput.prisma().create(data=row)  # type: ignore
async def start_execution(exec_id: str) -> None:
await AgentNodeExecution.prisma().update(
where={"id": exec_id},
data={
"executionStatus": ExecutionStatus.RUNNING,
"startTime": datetime.now(),
},
async def update_execution_status(node_exec_id: str, status: ExecutionStatus) -> None:
    """
    Set the status of a node execution and stamp the matching timestamp.

    The timestamp column written depends on the new status:
    QUEUED -> `queuedTime`, RUNNING -> `startedTime`,
    FAILED / COMPLETED -> `endedTime`. Other statuses update only
    `executionStatus`.

    Args:
        node_exec_id: The id of the AgentNodeExecution to update.
        status: The new execution status.

    Raises:
        ValueError: If no AgentNodeExecution with `node_exec_id` exists.
    """
    now = datetime.now()
    data: dict[str, Any] = {"executionStatus": status}
    if status == ExecutionStatus.QUEUED:
        data["queuedTime"] = now
    elif status == ExecutionStatus.RUNNING:
        data["startedTime"] = now
    elif status in (ExecutionStatus.FAILED, ExecutionStatus.COMPLETED):
        data["endedTime"] = now

    # `update` returns the updated record, or None when nothing matched the
    # `where` clause. It never returns a row count, so test for a missing
    # result rather than comparing against 0.
    res = await AgentNodeExecution.prisma().update(
        where={"id": node_exec_id},
        data=data,  # type: ignore
    )
    if not res:
        raise ValueError(f"Execution {node_exec_id} not found.")
async def complete_execution(exec_id: str, output: tuple[str, Any]) -> None:
output_name, output_data = output
await AgentNodeExecution.prisma().update(
where={"id": exec_id},
data={
"executionStatus": ExecutionStatus.COMPLETED,
"outputName": output_name,
"outputData": json.dumps(output_data),
"endTime": datetime.now(),
},
)
async def fail_execution(exec_id: str, error: Exception) -> None:
await AgentNodeExecution.prisma().update(
where={"id": exec_id},
data={
"executionStatus": ExecutionStatus.FAILED,
"outputName": "error",
"outputData": str(error),
"endTime": datetime.now(),
},
)
async def get_executions(run_id: str) -> list[ExecutionResult]:
async def get_executions(graph_exec_id: str) -> list[ExecutionResult]:
executions = await AgentNodeExecution.prisma().find_many(
where={"executionId": run_id},
order={"startTime": "asc"},
where={"agentGraphExecutionId": graph_exec_id},
include={"Input": True, "Output": True},
order={"addedTime": "asc"},
)
res = [ExecutionResult.from_db(execution) for execution in executions]
return res
async def get_node_execution_input(node_exec_id: str) -> dict[str, Any]:
    """
    Build the full input of a node execution.

    Starts from the node's constant input (`AgentNode.constantInput`, stored as
    JSON) and overlays every input recorded on the execution, so a recorded
    input takes precedence over a constant with the same name.

    Returns:
        A dictionary mapping each input name to its decoded value.

    Raises:
        ValueError: If the execution's AgentNode relation is missing.
    """
    record = await AgentNodeExecution.prisma().find_unique_or_raise(
        where={"id": node_exec_id},
        include={"Input": True, "AgentNode": True},
    )
    agent_node = record.AgentNode
    if not agent_node:
        raise ValueError(f"Node {record.agentNodeId} not found.")

    merged = json.loads(agent_node.constantInput)
    for pin in record.Input or []:
        merged[pin.name] = json.loads(pin.data)
    return merged

View File

@@ -1,20 +1,30 @@
import asyncio
import json
import uuid
from typing import Any
from prisma.models import AgentGraph, AgentNode, AgentNodeExecution, AgentNodeLink
from prisma.models import AgentGraph, AgentNode, AgentNodeLink
from pydantic import BaseModel
from autogpt_server.data.db import BaseDbModel
class Link(BaseModel):
    """A pin name paired with a node id, describing one end of a connection."""

    name: str
    node_id: str

    def __init__(self, name: str, node_id: str):
        # Take positional arguments so a link can be written as Link(name, node_id).
        super().__init__(name=name, node_id=node_id)

    def __iter__(self):
        # Iterate as (name, node_id) so a link unpacks like a 2-tuple.
        pair = (self.name, self.node_id)
        return iter(pair)
class Node(BaseDbModel):
block_id: str
input_default: dict[str, Any] = {} # dict[input_name, default_value]
input_nodes: dict[str, str] = {} # dict[input_name, node_id]
# TODO: Make it `dict[str, list[str]]`, output can be connected to multiple blocks.
# Other option is to use an edge-list, but it will complicate the rest code.
output_nodes: dict[str, str] = {} # dict[output_name, node_id]
input_nodes: list[Link] = [] # dict[input_name, node_id]
output_nodes: list[Link] = [] # dict[output_name, node_id]
metadata: dict[str, Any] = {}
@staticmethod
@@ -26,14 +36,20 @@ class Node(BaseDbModel):
id=node.id,
block_id=node.AgentBlock.id,
input_default=json.loads(node.constantInput),
input_nodes={v.sinkName: v.agentNodeSourceId for v in node.Input or []},
output_nodes={v.sourceName: v.agentNodeSinkId for v in node.Output or []},
input_nodes=[
Link(v.sinkName, v.agentNodeSourceId)
for v in node.Input or []
],
output_nodes=[
Link(v.sourceName, v.agentNodeSinkId)
for v in node.Output or []
],
metadata=json.loads(node.metadata),
)
def connect(self, node: "Node", source_name: str, sink_name: str):
self.output_nodes[source_name] = node.id
node.input_nodes[sink_name] = self.id
self.output_nodes.append(Link(source_name, node.id))
node.input_nodes.append(Link(sink_name, self.id))
class Graph(BaseDbModel):
@@ -85,41 +101,7 @@ async def get_graph(graph_id: str) -> Graph | None:
return Graph.from_db(graph) if graph else None
async def get_node_input(node: Node, exec_id: str) -> dict[str, Any]:
"""
Get execution node input data from the previous node execution result.
Args:
node: The execution node.
exec_id: The execution ID.
Returns:
dictionary of input data, key is the input name, value is the input data.
"""
query = await AgentNodeExecution.prisma().find_many(
where={ # type: ignore
"executionId": exec_id,
"agentNodeId": {"in": list(node.input_nodes.values())},
"executionStatus": "COMPLETED",
},
distinct=["agentNodeId"], # type: ignore
order={"creationTime": "desc"},
)
latest_executions: dict[str, AgentNodeExecution] = {
execution.agentNodeId: execution for execution in query
}
return {
**node.input_default,
**{
name: json.loads(latest_executions[node_id].outputData or "{}")
for name, node_id in node.input_nodes.items()
if node_id in latest_executions and latest_executions[node_id].outputData
},
}
async def create_graph(graph: Graph) -> Graph:
await AgentGraph.prisma().create(
data={
"id": graph.id,
@@ -142,12 +124,12 @@ async def create_graph(graph: Graph) -> Graph:
edge_source_names = {
(source_node.id, sink_node_id): output_name
for source_node in graph.nodes
for output_name, sink_node_id in source_node.output_nodes.items()
for output_name, sink_node_id in source_node.output_nodes
}
edge_sink_names = {
(source_node_id, sink_node.id): input_name
for sink_node in graph.nodes
for input_name, source_node_id in sink_node.input_nodes.items()
for input_name, source_node_id in sink_node.input_nodes
}
# TODO: replace bulk creation using create_many

View File

@@ -2,14 +2,13 @@ import json
from datetime import datetime
from typing import Optional, Any
from prisma.models import AgentExecutionSchedule
from prisma.models import AgentGraphExecutionSchedule
from autogpt_server.data.db import BaseDbModel
class ExecutionSchedule(BaseDbModel):
id: str
agent_id: str
graph_id: str
schedule: str
is_enabled: bool
input_data: dict[str, Any]
@@ -25,10 +24,10 @@ class ExecutionSchedule(BaseDbModel):
super().__init__(is_enabled=is_enabled, **kwargs)
@staticmethod
def from_db(schedule: AgentExecutionSchedule):
def from_db(schedule: AgentGraphExecutionSchedule):
return ExecutionSchedule(
id=schedule.id,
agent_id=schedule.agentGraphId,
graph_id=schedule.agentGraphId,
schedule=schedule.schedule,
is_enabled=schedule.isEnabled,
last_updated=schedule.lastUpdated.replace(tzinfo=None),
@@ -37,7 +36,7 @@ class ExecutionSchedule(BaseDbModel):
async def get_active_schedules(last_fetch_time: datetime) -> list[ExecutionSchedule]:
query = AgentExecutionSchedule.prisma().find_many(
query = AgentGraphExecutionSchedule.prisma().find_many(
where={
"isEnabled": True,
"lastUpdated": {"gt": last_fetch_time}
@@ -51,17 +50,17 @@ async def get_active_schedules(last_fetch_time: datetime) -> list[ExecutionSched
async def disable_schedule(schedule_id: str):
await AgentExecutionSchedule.prisma().update(
await AgentGraphExecutionSchedule.prisma().update(
where={"id": schedule_id},
data={"isEnabled": False}
)
async def get_schedules(agent_id: str) -> list[ExecutionSchedule]:
query = AgentExecutionSchedule.prisma().find_many(
async def get_schedules(graph_id: str) -> list[ExecutionSchedule]:
query = AgentGraphExecutionSchedule.prisma().find_many(
where={
"isEnabled": True,
"agentGraphId": agent_id,
"agentGraphId": graph_id,
},
)
return [
@@ -70,20 +69,21 @@ async def get_schedules(agent_id: str) -> list[ExecutionSchedule]:
]
async def add_schedule(schedule: ExecutionSchedule):
await AgentExecutionSchedule.prisma().create(
async def add_schedule(schedule: ExecutionSchedule) -> ExecutionSchedule:
obj = await AgentGraphExecutionSchedule.prisma().create(
data={
"id": schedule.id,
"agentGraphId": schedule.agent_id,
"agentGraphId": schedule.graph_id,
"schedule": schedule.schedule,
"isEnabled": schedule.is_enabled,
"inputData": json.dumps(schedule.input_data),
}
)
return ExecutionSchedule.from_db(obj)
async def update_schedule(schedule_id: str, is_enabled: bool):
await AgentExecutionSchedule.prisma().update(
await AgentGraphExecutionSchedule.prisma().update(
where={"id": schedule_id},
data={"isEnabled": is_enabled}
)

View File

@@ -1,30 +1,35 @@
import asyncio
import logging
import uuid
from concurrent.futures import ProcessPoolExecutor
from typing import Optional, Any
from typing import Any, Coroutine, Generator, TypeVar
from autogpt_server.data import db
from autogpt_server.data.block import Block, get_block
from autogpt_server.data.graph import Node, get_node, get_node_input, get_graph
from autogpt_server.data.execution import (
Execution,
get_node_execution_input,
create_graph_execution,
update_execution_status as execution_update,
upsert_execution_output,
upsert_execution_input,
NodeExecution as Execution,
ExecutionStatus,
ExecutionQueue,
enqueue_execution,
complete_execution,
fail_execution,
start_execution,
)
from autogpt_server.data.graph import Node, get_node, get_graph
from autogpt_server.util.service import AppService, expose
logger = logging.getLogger(__name__)
def get_log_prefix(run_id: str, exec_id: str, block_name: str = "-"):
return f"[ExecutionManager] [graph-{run_id}|node-{exec_id}|{block_name}]"
def get_log_prefix(graph_eid: str, node_eid: str, block_name: str = "-"):
    """
    Build the log-line prefix that identifies one node execution.

    Args:
        graph_eid: Id of the graph execution.
        node_eid: Id of the node execution.
        block_name: Name of the block being run; "-" when not given.
    """
    scope = f"graph-{graph_eid}|node-{node_eid}|{block_name}"
    return f"[ExecutionManager] [{scope}]"
def execute_node(loop: asyncio.AbstractEventLoop, data: Execution) -> Execution | None:
T = TypeVar("T")
ExecutionStream = Generator[Execution, None, None]
def execute_node(loop: asyncio.AbstractEventLoop, data: Execution) -> ExecutionStream:
"""
Execute a node in the graph. This will trigger a block execution on a node,
persist the execution result, and return the subsequent node to be executed.
@@ -36,57 +41,102 @@ def execute_node(loop: asyncio.AbstractEventLoop, data: Execution) -> Execution
Returns:
The subsequent node to be enqueued, or None if there is no subsequent node.
"""
run_id = data.run_id
exec_id = data.id
graph_exec_id = data.graph_exec_id
node_exec_id = data.node_exec_id
exec_data = data.data
node_id = data.node_id
asyncio.set_event_loop(loop)
wait = lambda f: loop.run_until_complete(f)
node: Optional[Node] = wait(get_node(node_id))
def wait(f: Coroutine[T, Any, T]) -> T:
return loop.run_until_complete(f)
node = wait(get_node(node_id))
if not node:
logger.error(f"Node {node_id} not found.")
return None
return
node_block: Optional[Block] = wait(get_block(node.block_id))
node_block = wait(get_block(node.block_id))
if not node_block:
logger.error(f"Block {node.block_id} not found.")
return None
return
# Execute the node
prefix = get_log_prefix(run_id, exec_id, node_block.name)
prefix = get_log_prefix(graph_exec_id, node_exec_id, node_block.name)
logger.warning(f"{prefix} execute with input:\n`{exec_data}`")
wait(start_execution(exec_id))
wait(execution_update(node_exec_id, ExecutionStatus.RUNNING))
try:
output_name, output_data = node_block.execute(exec_data)
logger.warning(f"{prefix} executed with output [{output_name}]:`{output_data}`")
wait(complete_execution(exec_id, (output_name, output_data)))
for output_name, output_data in node_block.execute(exec_data):
logger.warning(f"{prefix} Executed, output [{output_name}]:`{output_data}`")
wait(execution_update(node_exec_id, ExecutionStatus.COMPLETED))
wait(upsert_execution_output(node_exec_id, output_name, output_data))
for execution in enqueue_next_nodes(
loop, node, output_name, output_data, graph_exec_id
):
yield execution
except Exception as e:
logger.exception(f"{prefix} failed with error: %s", e)
wait(fail_execution(exec_id, e))
wait(execution_update(node_exec_id, ExecutionStatus.FAILED))
wait(upsert_execution_output(node_exec_id, "error", str(e)))
raise e
def enqueue_next_nodes(
loop: asyncio.AbstractEventLoop,
node: Node,
output_name: str,
output_data: Any,
graph_exec_id: str,
) -> list[Execution]:
def wait(f: Coroutine[T, Any, T]) -> T:
return loop.run_until_complete(f)
prefix = get_log_prefix(graph_exec_id, node.id)
node_id = node.id
# Try to enqueue next eligible nodes
if output_name not in node.output_nodes:
next_node_ids = [nid for name, nid in node.output_nodes if name == output_name]
if not next_node_ids:
logger.error(f"{prefix} Output [{output_name}] has no subsequent node.")
return None
return []
next_node_id = node.output_nodes[output_name]
next_node: Optional[Node] = wait(get_node(next_node_id))
if not next_node:
logger.error(f"{prefix} Error, next node {next_node_id} not found.")
return None
def validate_node_execution(next_node_id: str):
next_node = wait(get_node(next_node_id))
if not next_node:
logger.error(f"{prefix} Error, next node {next_node_id} not found.")
return
next_node_input: dict[str, Any] = wait(get_node_input(next_node, run_id))
is_valid, validation_resp = wait(validate_exec(next_node, next_node_input))
if not is_valid:
logger.warning(f"{prefix} Skipped {next_node_id}: {validation_resp}")
return None
next_node_input_name = next(
name for name, nid in next_node.input_nodes if nid == node_id
)
next_node_exec_id = wait(upsert_execution_input(
node_id=next_node_id,
graph_exec_id=graph_exec_id,
input_name=next_node_input_name,
data=output_data
))
logger.warning(f"{prefix} Enqueue next node {next_node_id}-{validation_resp}")
return Execution(run_id=run_id, node_id=next_node_id, data=next_node_input)
next_node_input = wait(get_node_execution_input(next_node_exec_id))
is_valid, validation_resp = wait(validate_exec(next_node, next_node_input))
if not is_valid:
logger.warning(f"{prefix} Skipped {next_node_id}: {validation_resp}")
return
logger.warning(f"{prefix} Enqueue next node {next_node_id}-{validation_resp}")
return Execution(
graph_exec_id=graph_exec_id,
node_exec_id=next_node_exec_id,
node_id=next_node_id,
data=next_node_input
)
executions = []
for nid in next_node_ids:
if execution := validate_node_execution(nid):
executions.append(execution)
return executions
async def validate_exec(node: Node, data: dict[str, Any]) -> tuple[bool, str]:
@@ -105,10 +155,12 @@ async def validate_exec(node: Node, data: dict[str, Any]) -> tuple[bool, str]:
if not node_block:
return False, f"Block for {node.block_id} not found."
if not set(node.input_nodes).issubset(data):
return False, f"Input data missing: {set(node.input_nodes) - set(data)}"
input_fields = node_block.input_schema.get_fields()
if not input_fields.issubset(data):
return False, f"Input data missing: {input_fields - set(data)}"
if error := node_block.input_schema.validate_data(data):
logger.error("Input value doesn't match schema: %s", error)
return False, f"Input data doesn't match {node_block.name}: {error}"
return True, node_block.name
@@ -123,16 +175,16 @@ class Executor:
cls.loop.run_until_complete(db.connect())
@classmethod
def on_start_execution(cls, data: Execution) -> Optional[Execution | None]:
"""
A synchronous version of `execute_node`, to be used in the ProcessPoolExecutor.
"""
prefix = get_log_prefix(data.run_id, data.id)
def on_start_execution(cls, q: ExecutionQueue, data: Execution) -> bool:
prefix = get_log_prefix(data.graph_exec_id, data.node_exec_id)
try:
logger.warning(f"{prefix} Start execution")
return execute_node(cls.loop, data)
for execution in execute_node(cls.loop, data):
q.add(execution)
return True
except Exception as e:
logger.error(f"{prefix} Error: {e}")
logger.exception(f"{prefix} Error: {e}")
return False
class ExecutionManager(AppService):
@@ -142,59 +194,63 @@ class ExecutionManager(AppService):
self.queue = ExecutionQueue()
def run_service(self):
def on_complete_execution(f: asyncio.Future[Execution | None]):
exception = f.exception()
if exception:
logger.exception("Error during execution!! %s", exception)
return exception
execution = f.result()
if execution:
return self.add_node_execution(execution)
return None
with ProcessPoolExecutor(
max_workers=self.pool_size,
initializer=Executor.on_executor_start,
) as executor:
logger.warning(f"Execution manager started with {self.pool_size} workers.")
while True:
future = executor.submit(
executor.submit(
Executor.on_start_execution,
self.queue.get()
self.queue,
self.queue.get(),
)
future.add_done_callback(on_complete_execution) # type: ignore
@expose
def add_execution(self, graph_id: str, data: dict[str, Any]) -> dict:
run_id = str(uuid.uuid4())
agent = self.run_and_wait(get_graph(graph_id))
if not agent:
raise Exception(f"Agent #{graph_id} not found.")
graph = self.run_and_wait(get_graph(graph_id))
if not graph:
raise Exception(f"Graph #{graph_id} not found.")
# Currently, there is no constraint on the number of root nodes in the graph.
for node in agent.starting_nodes:
valid, error = self.run_and_wait(validate_exec(node, data))
for node in graph.starting_nodes:
input_data = {**node.input_default, **data}
valid, error = self.run_and_wait(validate_exec(node, input_data))
if not valid:
raise Exception(error)
graph_exec_id, node_execs = self.run_and_wait(create_graph_execution(
graph_id=graph_id,
node_ids=[node.id for node in graph.starting_nodes],
data=data
))
executions = []
for node in agent.starting_nodes:
exec_id = self.add_node_execution(
Execution(run_id=run_id, node_id=node.id, data=data)
for node_exec in node_execs:
input_data = self.run_and_wait(
get_node_execution_input(node_exec.node_exec_id)
)
self.add_node_execution(
Execution(
graph_exec_id=node_exec.graph_exec_id,
node_exec_id=node_exec.node_exec_id,
node_id=node_exec.node_id,
data=input_data,
)
)
executions.append({
"exec_id": exec_id,
"node_id": node.id,
"id": node_exec.node_exec_id,
"node_id": node_exec.node_id,
})
return {
"run_id": run_id,
"id": graph_exec_id,
"executions": executions,
}
def add_node_execution(self, execution: Execution) -> Execution:
self.run_and_wait(enqueue_execution(execution))
self.run_and_wait(execution_update(
execution.node_exec_id,
ExecutionStatus.QUEUED
))
return self.queue.add(execution)

View File

@@ -45,20 +45,20 @@ class ExecutionScheduler(AppService):
log(f"Adding recurring job {schedule.id}: {schedule.schedule}")
scheduler.add_job(
self.__execute_agent,
self.__execute_graph,
CronTrigger.from_crontab(schedule.schedule),
id=schedule.id,
args=[schedule.agent_id, schedule.input_data],
args=[schedule.graph_id, schedule.input_data],
replace_existing=True,
)
def __execute_agent(self, agent_id: str, input_data: dict):
def __execute_graph(self, graph_id: str, input_data: dict):
try:
log(f"Executing recurring job for agent #{agent_id}")
log(f"Executing recurring job for graph #{graph_id}")
execution_manager = self.execution_manager_client
execution_manager.add_execution(agent_id, input_data)
execution_manager.add_execution(graph_id, input_data)
except Exception as e:
logger.error(f"Error executing agent {agent_id}: {e}")
logger.exception(f"Error executing graph {graph_id}: {e}")
@expose
def update_schedule(self, schedule_id: str, is_enabled: bool) -> str:
@@ -66,17 +66,16 @@ class ExecutionScheduler(AppService):
return schedule_id
@expose
def add_execution_schedule(self, agent_id: str, cron: str, input_data: dict) -> str:
def add_execution_schedule(self, graph_id: str, cron: str, input_data: dict) -> str:
schedule = model.ExecutionSchedule(
agent_id=agent_id,
graph_id=graph_id,
schedule=cron,
input_data=input_data,
)
self.run_and_wait(model.add_schedule(schedule))
return schedule.id
return self.run_and_wait(model.add_schedule(schedule)).id
@expose
def get_execution_schedules(self, agent_id: str) -> dict[str, str]:
query = model.get_schedules(agent_id)
def get_execution_schedules(self, graph_id: str) -> dict[str, str]:
query = model.get_schedules(graph_id)
schedules: list[model.ExecutionSchedule] = self.run_and_wait(query)
return {v.id: v.schedule for v in schedules}

View File

@@ -4,7 +4,14 @@ import uvicorn
from contextlib import asynccontextmanager
from fastapi import APIRouter, FastAPI, HTTPException
from autogpt_server.data import db, execution, graph, block
from autogpt_server.data import db, execution, block
from autogpt_server.data.graph import (
create_graph,
get_graph,
get_graph_ids,
Graph,
Link,
)
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
from autogpt_server.util.process import AppProcess
from autogpt_server.util.service import get_service_client
@@ -34,46 +41,46 @@ class AgentServer(AppProcess):
router = APIRouter()
router.add_api_route(
path="/blocks",
endpoint=self.get_agent_blocks,
endpoint=self.get_graph_blocks,
methods=["GET"],
)
router.add_api_route(
path="/agents",
endpoint=self.get_agents,
path="/graphs",
endpoint=self.get_graphs,
methods=["GET"],
)
router.add_api_route(
path="/agents/{agent_id}",
endpoint=self.get_agent,
path="/graphs/{graph_id}",
endpoint=self.get_graph,
methods=["GET"],
)
router.add_api_route(
path="/agents",
endpoint=self.create_agent,
path="/graphs",
endpoint=self.create_new_graph,
methods=["POST"],
)
router.add_api_route(
path="/agents/{agent_id}/execute",
endpoint=self.execute_agent,
path="/graphs/{graph_id}/execute",
endpoint=self.execute_graph,
methods=["POST"],
)
router.add_api_route(
path="/agents/{agent_id}/executions/{run_id}",
path="/graphs/{graph_id}/executions/{run_id}",
endpoint=self.get_executions,
methods=["GET"],
)
router.add_api_route(
path="/agents/{agent_id}/schedules",
endpoint=self.schedule_agent,
path="/graphs/{graph_id}/schedules",
endpoint=self.create_schedule,
methods=["POST"],
)
router.add_api_route(
path="/agents/{agent_id}/schedules",
path="/graphs/{graph_id}/schedules",
endpoint=self.get_execution_schedules,
methods=["GET"],
)
router.add_api_route(
path="/agents/schedules/{schedule_id}",
path="/graphs/schedules/{schedule_id}",
endpoint=self.update_schedule,
methods=["PUT"],
)
@@ -89,52 +96,51 @@ class AgentServer(AppProcess):
def execution_scheduler_client(self) -> ExecutionScheduler:
return get_service_client(ExecutionScheduler)
async def get_agent_blocks(self) -> list[dict]:
async def get_graph_blocks(self) -> list[dict]:
return [v.to_dict() for v in await block.get_blocks()]
async def get_agents(self) -> list[str]:
return await graph.get_graph_ids()
async def get_graphs(self) -> list[str]:
return await get_graph_ids()
async def get_agent(self, agent_id: str) -> graph.Graph:
agent = await graph.get_graph(agent_id)
if not agent:
raise HTTPException(status_code=404, detail=f"Agent #{agent_id} not found.")
# GET /graphs/{graph_id} handler: look up one graph by id.
# Returns the graph when the lookup yields a truthy value; otherwise raises
# HTTPException with status 404 and a "Graph #<id> not found." detail.
async def get_graph(self, graph_id: str) -> Graph:
# Bare-name call, so this resolves to a get_graph visible at module scope
# rather than to this method (the import is not visible in this chunk).
graph = await get_graph(graph_id)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
return graph
return agent
async def create_agent(self, agent: graph.Graph) -> graph.Graph:
agent.id = str(uuid.uuid4())
id_map = {node.id: str(uuid.uuid4()) for node in agent.nodes}
for node in agent.nodes:
async def create_new_graph(self, graph: Graph) -> Graph:
# TODO: replace the uuid generation here with DB-generated uuids.
graph.id = str(uuid.uuid4())
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
for node in graph.nodes:
node.id = id_map[node.id]
node.input_nodes = {k: id_map[v] for k, v in node.input_nodes.items()}
node.output_nodes = {k: id_map[v] for k, v in node.output_nodes.items()}
node.input_nodes = [Link(k, id_map[v]) for k, v in node.input_nodes]
node.output_nodes = [Link(k, id_map[v]) for k, v in node.output_nodes]
return await graph.create_graph(agent)
return await create_graph(graph)
async def execute_agent(self, agent_id: str, node_input: dict) -> dict:
async def execute_graph(self, graph_id: str, node_input: dict) -> dict:
try:
return self.execution_manager_client.add_execution(agent_id, node_input)
return self.execution_manager_client.add_execution(graph_id, node_input)
except Exception as e:
msg = e.__str__().encode().decode('unicode_escape')
raise HTTPException(status_code=400, detail=msg)
async def get_executions(
self, agent_id: str, run_id: str) -> list[execution.ExecutionResult]:
agent = await graph.get_graph(agent_id)
if not agent:
raise HTTPException(status_code=404, detail=f"Agent #{agent_id} not found.")
self, graph_id: str, run_id: str) -> list[execution.ExecutionResult]:
graph = await get_graph(graph_id)
if not graph:
raise HTTPException(status_code=404, detail=f"Agent #{graph_id} not found.")
return await execution.get_executions(run_id)
async def schedule_agent(self, agent_id: str, cron: str, input_data: dict) -> dict:
agent = await graph.get_graph(agent_id)
if not agent:
raise HTTPException(status_code=404, detail=f"Agent #{agent_id} not found.")
async def create_schedule(self, graph_id: str, cron: str, input_data: dict) -> dict:
graph = await get_graph(graph_id)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
execution_scheduler = self.execution_scheduler_client
return {
"id": execution_scheduler.add_execution_schedule(agent_id, cron, input_data)
"id": execution_scheduler.add_execution_schedule(graph_id, cron, input_data)
}
def update_schedule(self, schedule_id: str, input_data: dict) -> dict:
@@ -143,6 +149,6 @@ class AgentServer(AppProcess):
execution_scheduler.update_schedule(schedule_id, is_enabled)
return {"id": schedule_id}
def get_execution_schedules(self, agent_id: str) -> dict[str, str]:
def get_execution_schedules(self, graph_id: str) -> dict[str, str]:
execution_scheduler = self.execution_scheduler_client
return execution_scheduler.get_execution_schedules(agent_id)
return execution_scheduler.get_execution_schedules(graph_id)

View File

@@ -15,6 +15,7 @@ from autogpt_server.util.process import AppProcess
logger = logging.getLogger(__name__)
conn_retry = retry(stop=stop_after_delay(5), wait=wait_exponential(multiplier=0.1))
T = TypeVar("T")
def expose(func: Callable) -> Callable:
@@ -23,7 +24,7 @@ def expose(func: Callable) -> Callable:
return func(*args, **kwargs)
except Exception as e:
msg = f"Error in {func.__name__}: {e.__str__()}"
logger.error(msg)
logger.exception(msg)
raise Exception(msg, e)
return pyro.expose(wrapper)
@@ -51,10 +52,10 @@ class AppService(AppProcess):
while True:
time.sleep(10)
def run_async(self, coro: Coroutine):
def run_async(self, coro: Coroutine[T, Any, T]):
return asyncio.run_coroutine_threadsafe(coro, self.shared_event_loop)
def run_and_wait(self, coro: Coroutine):
def run_and_wait(self, coro: Coroutine[T, Any, T]) -> T:
future = self.run_async(coro)
return future.result()

View File

@@ -11,23 +11,24 @@ generator client {
// This model describes the Agent Graph/Flow (Multi Agent System).
model AgentGraph {
id String @id
id String @id @default(uuid())
name String?
description String?
AgentNodes AgentNode[] @relation("AgentGraphNodes")
AgentExecutionSchedule AgentExecutionSchedule[]
AgentNodes AgentNode[]
AgentGraphExecution AgentGraphExecution[]
AgentGraphExecutionSchedule AgentGraphExecutionSchedule[]
}
// This model describes a single node in the Agent Graph/Flow (Multi Agent System).
model AgentNode {
id String @id
id String @id @default(uuid())
agentBlockId String
AgentBlock AgentBlock @relation(fields: [agentBlockId], references: [id])
agentGraphId String
AgentGraph AgentGraph @relation("AgentGraphNodes", fields: [agentGraphId], references: [id])
AgentGraph AgentGraph @relation(fields: [agentGraphId], references: [id])
// List of consumed input, that the parent node should provide.
Input AgentNodeLink[] @relation("AgentNodeSink")
@@ -46,7 +47,7 @@ model AgentNode {
// This model describes the link between two AgentNodes.
model AgentNodeLink {
id String @id
id String @id @default(uuid())
// Output of a node is connected to the source of the link.
agentNodeSourceId String
@@ -61,7 +62,7 @@ model AgentNodeLink {
// This model describes a component that will be executed by the AgentNode.
model AgentBlock {
id String @id
id String @id @default(uuid())
name String @unique
// We allow a block to have multiple types of input & output.
@@ -73,45 +74,55 @@ model AgentBlock {
ReferencedByAgentNode AgentNode[]
}
// This model describes the execution of an AgentGraph.
// AgentNodeExecution rows point back to a row of this model through their
// agentGraphExecutionId field.
model AgentGraphExecution {
// Primary key; defaults to a generated UUID.
id String @id @default(uuid())
// Foreign key to the AgentGraph that this execution belongs to.
agentGraphId String
AgentGraph AgentGraph @relation(fields: [agentGraphId], references: [id])
// Back-reference: the node executions that reference this graph execution.
AgentNodeExecutions AgentNodeExecution[]
}
// This model describes the execution of an AgentNode.
model AgentNodeExecution {
id String @id
executionId String
id String @id @default(uuid())
agentGraphExecutionId String
AgentGraphExecution AgentGraphExecution @relation(fields: [agentGraphExecutionId], references: [id])
agentNodeId String
AgentNode AgentNode @relation(fields: [agentNodeId], references: [id])
inputData String?
inputFiles FileDefinition[] @relation("InputFiles")
outputName String?
outputData String?
outputFiles FileDefinition[] @relation("OutputFiles")
Input AgentNodeExecutionInputOutput[] @relation("AgentNodeExecutionInput")
Output AgentNodeExecutionInputOutput[] @relation("AgentNodeExecutionOutput")
// sqlite does not support enum
// enum Status { QUEUED, RUNNING, SUCCESS, FAILED }
// enum Status { INCOMPLETE, QUEUED, RUNNING, SUCCESS, FAILED }
executionStatus String
creationTime DateTime
startTime DateTime?
endTime DateTime?
addedTime DateTime @default(now())
queuedTime DateTime?
startedTime DateTime?
endedTime DateTime?
}
// This model describes a file that can be used as input/output of an AgentNodeExecution.
model FileDefinition {
id String @id
path String
metadata String? // JSON serialized object
mimeType String?
size Int?
hash String?
encoding String?
// This model describes a single named input or output value of an AgentNodeExecution.
model AgentNodeExecutionInputOutput {
id String @id @default(uuid())
name String
data String
time DateTime @default(now())
// Prisma requires explicit back-references.
ReferencedByInputFiles AgentNodeExecution[] @relation("InputFiles")
ReferencedByOutputFiles AgentNodeExecution[] @relation("OutputFiles")
referencedByInputExecId String?
ReferencedByInputExec AgentNodeExecution? @relation("AgentNodeExecutionInput", fields: [referencedByInputExecId], references: [id])
referencedByOutputExecId String?
ReferencedByOutputExec AgentNodeExecution? @relation("AgentNodeExecutionOutput", fields: [referencedByOutputExecId], references: [id])
}
// This model describes the recurring execution schedule of an Agent.
model AgentExecutionSchedule {
model AgentGraphExecutionSchedule {
id String @id
agentGraphId String

View File

@@ -45,18 +45,18 @@ async def create_test_graph() -> graph.Graph:
return test_graph
async def execute_agent(test_manager: ExecutionManager, test_graph: graph.Graph):
async def execute_graph(test_manager: ExecutionManager, test_graph: graph.Graph):
# --- Test adding new executions --- #
text = "Hello, World!"
input_data = {"input": text}
agent_server = AgentServer()
response = await agent_server.execute_agent(test_graph.id, input_data)
response = await agent_server.execute_graph(test_graph.id, input_data)
executions = response["executions"]
run_id = response["run_id"]
graph_exec_id = response["id"]
assert len(executions) == 2
async def is_execution_completed():
execs = await agent_server.get_executions(test_graph.id, run_id)
execs = await agent_server.get_executions(test_graph.id, graph_exec_id)
return test_manager.queue.empty() and len(execs) == 4
# Wait for the executions to complete
@@ -67,34 +67,30 @@ async def execute_agent(test_manager: ExecutionManager, test_graph: graph.Graph)
# Execution queue should be empty
assert await is_execution_completed()
executions = await agent_server.get_executions(test_graph.id, run_id)
executions = await agent_server.get_executions(test_graph.id, graph_exec_id)
# Executing ParrotBlock1
exec = executions[0]
assert exec.status == execution.ExecutionStatus.COMPLETED
assert exec.run_id == run_id
assert exec.output_name == "output"
assert exec.output_data == "Hello, World!"
assert exec.input_data == input_data
assert exec.graph_exec_id == graph_exec_id
assert exec.output_data == {"output": ["Hello, World!"]}
assert exec.input_data == {"input": text}
assert exec.node_id == test_graph.nodes[0].id
# Executing ParrotBlock2
exec = executions[1]
assert exec.status == execution.ExecutionStatus.COMPLETED
assert exec.run_id == run_id
assert exec.output_name == "output"
assert exec.output_data == "Hello, World!"
assert exec.input_data == input_data
assert exec.graph_exec_id == graph_exec_id
assert exec.output_data == {"output": ["Hello, World!"]}
assert exec.input_data == {"input": text}
assert exec.node_id == test_graph.nodes[1].id
# Executing TextCombinerBlock
exec = executions[2]
assert exec.status == execution.ExecutionStatus.COMPLETED
assert exec.run_id == run_id
assert exec.output_name == "combined_text"
assert exec.output_data == "Hello, World!,Hello, World!"
assert exec.graph_exec_id == graph_exec_id
assert exec.output_data == {"combined_text": ["Hello, World!,Hello, World!"]}
assert exec.input_data == {
"format": "{text1},{text2}",
"text1": "Hello, World!",
"text2": "Hello, World!",
}
@@ -103,9 +99,8 @@ async def execute_agent(test_manager: ExecutionManager, test_graph: graph.Graph)
# Executing PrintingBlock
exec = executions[3]
assert exec.status == execution.ExecutionStatus.COMPLETED
assert exec.run_id == run_id
assert exec.output_name == "status"
assert exec.output_data == "printed"
assert exec.graph_exec_id == graph_exec_id
assert exec.output_data == {"status": ["printed"]}
assert exec.input_data == {"text": "Hello, World!,Hello, World!"}
assert exec.node_id == test_graph.nodes[3].id
@@ -116,4 +111,4 @@ async def test_agent_execution():
with ExecutionManager(1) as test_manager:
await db.connect()
test_graph = await create_test_graph()
await execute_agent(test_manager, test_graph)
await execute_graph(test_manager, test_graph)