feat(rnd): Introduce Sub-Graph on Agent Server (#7693)

### Background

This change brings the capability to decompose a graph into sub-graphs. The objective of this feature is to allow a user to build a visually modular, and easier-to-understand graph. Also, allowing you to import a graph into your existing graph, without decluttering your existing graph.

This feature will require more implementation on the UI side, to allow the grouping of subgraph to be represented as a node in the builder.

### Changes 🏗️

Introduced a subgraph functionality with the following property:

* Sub-graph is simply a set of nodes that are grouped together, making it representable as a node.
* Sub-graph input & output pins/schema are the `InputBlock` / `OutputBlock` nodes present in the subgraph.
* The previous point implies that connecting two nodes from different sub-graphs, other than input/output nodes, is not allowed.
* Graph can be nested, but defined flatly, e.g.: graph is now only represented by three components: nodes, links, and subgraphs (a set of list of nodes). A nested subgraph is simply connecting a node inside a subgraph into another `InputBlock` node of another subgraph.
This commit is contained in:
Zamil Majdy
2024-08-05 13:48:14 +04:00
committed by GitHub
parent c7fdfa0f77
commit 4cf1dd30f1
9 changed files with 250 additions and 61 deletions

View File

@@ -1,4 +1,5 @@
import os
from contextlib import asynccontextmanager
from uuid import uuid4
from dotenv import load_dotenv
@@ -23,6 +24,12 @@ async def disconnect():
await prisma.disconnect()
@asynccontextmanager
async def transaction():
async with prisma.tx() as tx:
yield tx
class BaseDbModel(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))

View File

@@ -9,7 +9,7 @@ from pydantic import PrivateAttr
from autogpt_server.blocks.basic import InputBlock, OutputBlock
from autogpt_server.data.block import BlockInput, get_block
from autogpt_server.data.db import BaseDbModel
from autogpt_server.data.db import BaseDbModel, transaction
from autogpt_server.util import json
@@ -89,6 +89,7 @@ class GraphMeta(BaseDbModel):
class Graph(GraphMeta):
nodes: list[Node]
links: list[Link]
subgraphs: dict[str, list[str]] = {} # subgraph_id -> [node_id]
@property
def starting_nodes(self) -> list[Node]:
@@ -106,17 +107,63 @@ class Graph(GraphMeta):
def ending_nodes(self) -> list[Node]:
return [v for v in self.nodes if isinstance(get_block(v.block_id), OutputBlock)]
def validate_graph(self):
@property
def subgraph_map(self) -> dict[str, str]:
"""
Returns a mapping of node_id to subgraph_id.
A node in the main graph will be mapped to the graph's id.
"""
subgraph_map = {
node_id: subgraph_id
for subgraph_id, node_ids in self.subgraphs.items()
for node_id in node_ids
}
subgraph_map.update(
{node.id: self.id for node in self.nodes if node.id not in subgraph_map}
)
return subgraph_map
def reassign_ids(self):
"""
Reassigns all IDs in the graph to new UUIDs.
This method can be used before storing a new graph to the database.
"""
self.validate_graph()
id_map = {
self.id: str(uuid.uuid4()),
**{node.id: str(uuid.uuid4()) for node in self.nodes},
**{subgraph_id: str(uuid.uuid4()) for subgraph_id in self.subgraphs},
}
self.id = id_map[self.id]
for node in self.nodes:
node.id = id_map[node.id]
for link in self.links:
link.source_id = id_map[link.source_id]
link.sink_id = id_map[link.sink_id]
self.subgraphs = {
id_map[subgraph_id]: [id_map[node_id] for node_id in node_ids]
for subgraph_id, node_ids in self.subgraphs.items()
}
def validate_graph(self, for_run: bool = False):
def sanitize(name):
return name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
# Check if all required fields are filled or connected, except for InputBlock.
# Nodes: required fields are filled or connected, except for InputBlock.
for node in self.nodes:
block = get_block(node.block_id)
if block is None:
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
if not for_run:
continue # Skip input completion validation, unless when executing.
provided_inputs = set(
[sanitize(name) for name in node.input_default]
+ [sanitize(link.sink_name) for link in node.input_links]
@@ -126,57 +173,92 @@ class Graph(GraphMeta):
raise ValueError(
f"Node {block.name} #{node.id} required input missing: `{name}`"
)
node_map = {v.id: v for v in self.nodes}
# Check if all links are connected compatible pin data type.
def is_input_output_block(nid: str) -> bool:
bid = node_map[nid].block_id
b = get_block(bid)
return isinstance(b, InputBlock) or isinstance(b, OutputBlock)
# subgraphs: all nodes in subgraph must be present in the graph.
for subgraph_id, node_ids in self.subgraphs.items():
for node_id in node_ids:
if node_id not in node_map:
raise ValueError(f"Subgraph {subgraph_id}'s node {node_id} invalid")
subgraph_map = self.subgraph_map
# Links: links are connected and the connected pin data type are compatible.
for link in self.links:
source_id = link.source_id
sink_id = link.sink_id
suffix = f"Link {source_id}<->{sink_id}"
source = (link.source_id, link.source_name)
sink = (link.sink_id, link.sink_name)
suffix = f"Link {source} <-> {sink}"
source_node = next((v for v in self.nodes if v.id == source_id), None)
if not source_node:
raise ValueError(f"{suffix}, {source_id} is invalid node.")
sink_node = next((v for v in self.nodes if v.id == sink_id), None)
if not sink_node:
raise ValueError(f"{suffix}, {sink_id} is invalid node.")
for i, (node_id, name) in enumerate([source, sink]):
node = node_map.get(node_id)
if not node:
raise ValueError(f"{suffix}, {node_id} is invalid node.")
source_block = get_block(source_node.block_id)
if not source_block:
raise ValueError(f"{suffix}, {source_node.block_id} is invalid block.")
sink_block = get_block(sink_node.block_id)
if not sink_block:
raise ValueError(f"{suffix}, {sink_node.block_id} is invalid block.")
block = get_block(node.block_id)
if not block:
raise ValueError(f"{suffix}, {node.block_id} is invalid block.")
source_name = sanitize(link.source_name)
if source_name not in source_block.output_schema.get_fields():
raise ValueError(f"{suffix}, `{source_name}` is invalid output pin.")
sink_name = sanitize(link.sink_name)
if sink_name not in sink_block.input_schema.get_fields():
raise ValueError(f"{suffix}, `{sink_name}` is invalid input pin.")
sanitized_name = sanitize(name)
if i == 0:
fields = block.output_schema.get_fields()
else:
fields = block.input_schema.get_fields()
if sanitized_name not in fields:
raise ValueError(f"{suffix}, `{name}` invalid, fields: {fields}")
if (
subgraph_map.get(link.source_id) != subgraph_map.get(link.sink_id)
and not is_input_output_block(link.source_id)
and not is_input_output_block(link.sink_id)
):
raise ValueError(f"{suffix}, Connecting nodes from different subgraph.")
# TODO: Add type compatibility check here.
@staticmethod
def from_db(graph: AgentGraph):
nodes = [
*(graph.AgentNodes or []),
*(
node
for subgraph in graph.AgentSubGraphs or []
for node in subgraph.AgentNodes or []
),
]
return Graph(
**GraphMeta.from_db(graph).model_dump(),
nodes=[Node.from_db(node) for node in graph.AgentNodes or []],
nodes=[Node.from_db(node) for node in nodes],
links=list(
{
Link.from_db(link)
for node in graph.AgentNodes or []
for node in nodes
for link in (node.Input or []) + (node.Output or [])
}
),
subgraphs={
subgraph.id: [node.id for node in subgraph.AgentNodes or []]
for subgraph in graph.AgentSubGraphs or []
},
)
EXECUTION_NODE_INCLUDE = {
AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
"Input": True,
"Output": True,
"AgentBlock": True,
}
__SUBGRAPH_INCLUDE = {"AgentNodes": {"include": AGENT_NODE_INCLUDE}}
AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = {
**__SUBGRAPH_INCLUDE,
"AgentSubGraphs": {"include": __SUBGRAPH_INCLUDE}, # type: ignore
}
# --------------------- Model functions --------------------- #
@@ -184,7 +266,7 @@ EXECUTION_NODE_INCLUDE = {
async def get_node(node_id: str) -> Node | None:
node = await AgentNode.prisma().find_unique_or_raise(
where={"id": node_id},
include=EXECUTION_NODE_INCLUDE, # type: ignore
include=AGENT_NODE_INCLUDE,
)
return Node.from_db(node) if node else None
@@ -242,7 +324,7 @@ async def get_graph(
graph = await AgentGraph.prisma().find_first(
where=where_clause,
include={"AgentNodes": {"include": EXECUTION_NODE_INCLUDE}}, # type: ignore
include=AGENT_GRAPH_INCLUDE,
order={"version": "desc"},
)
return Graph.from_db(graph) if graph else None
@@ -267,7 +349,7 @@ async def get_graph_all_versions(graph_id: str) -> list[Graph]:
graph_versions = await AgentGraph.prisma().find_many(
where={"id": graph_id},
order={"version": "desc"},
include={"AgentNodes": {"include": EXECUTION_NODE_INCLUDE}}, # type: ignore
include=AGENT_GRAPH_INCLUDE,
)
if not graph_versions:
@@ -277,7 +359,17 @@ async def get_graph_all_versions(graph_id: str) -> list[Graph]:
async def create_graph(graph: Graph) -> Graph:
await AgentGraph.prisma().create(
async with transaction() as tx:
await __create_graph(tx, graph)
if created_graph := await get_graph(graph.id, graph.version, graph.is_template):
return created_graph
raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")
async def __create_graph(tx, graph: Graph):
await AgentGraph.prisma(tx).create(
data={
"id": graph.id,
"version": graph.version,
@@ -290,11 +382,30 @@ async def create_graph(graph: Graph) -> Graph:
await asyncio.gather(
*[
AgentNode.prisma().create(
AgentGraph.prisma(tx).create(
data={
"id": subgraph_id,
"agentGraphParentId": graph.id,
"version": graph.version,
"name": f"SubGraph of {graph.name}",
"description": f"Sub-Graph of {graph.id}",
"isTemplate": graph.is_template,
"isActive": graph.is_active,
}
)
for subgraph_id in graph.subgraphs
]
)
subgraph_map = graph.subgraph_map
await asyncio.gather(
*[
AgentNode.prisma(tx).create(
{
"id": node.id,
"agentBlockId": node.block_id,
"agentGraphId": graph.id,
"agentGraphId": subgraph_map.get(node.id, graph.id),
"agentGraphVersion": graph.version,
"constantInput": json.dumps(node.input_default),
"metadata": json.dumps(node.metadata),
@@ -306,7 +417,7 @@ async def create_graph(graph: Graph) -> Graph:
await asyncio.gather(
*[
AgentNodeLink.prisma().create(
AgentNodeLink.prisma(tx).create(
{
"id": str(uuid.uuid4()),
"sourceName": link.source_name,
@@ -320,13 +431,6 @@ async def create_graph(graph: Graph) -> Graph:
]
)
if created_graph := await get_graph(
graph.id, graph.version, template=graph.is_template
):
return created_graph
raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")
# --------------------- Helper functions --------------------- #

View File

@@ -420,7 +420,7 @@ class ExecutionManager(AppService):
graph: Graph | None = self.run_and_wait(get_graph(graph_id))
if not graph:
raise Exception(f"Graph #{graph_id} not found.")
graph.validate_graph()
graph.validate_graph(for_run=True)
nodes_input = []
for node in graph.starting_nodes:

View File

@@ -1,5 +1,4 @@
import asyncio
import uuid
from collections import defaultdict
from contextlib import asynccontextmanager
from typing import Annotated, Any, Dict
@@ -468,15 +467,7 @@ class AgentServer(AppService):
graph.is_template = is_template
graph.is_active = not is_template
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
for node in graph.nodes:
node.id = id_map[node.id]
for link in graph.links:
link.source_id = id_map[link.source_id]
link.sink_id = id_map[link.sink_id]
graph.reassign_ids()
return await graph_db.create_graph(graph)
@@ -501,14 +492,7 @@ class AgentServer(AppService):
400, detail="Changing is_template on an existing graph is forbidden"
)
graph.is_active = not graph.is_template
# Assign new UUIDs to all nodes and links
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
for node in graph.nodes:
node.id = id_map[node.id]
for link in graph.links:
link.source_id = id_map[link.source_id]
link.sink_id = id_map[link.sink_id]
graph.reassign_ids()
new_graph_version = await graph_db.create_graph(graph)

View File

@@ -0,0 +1,19 @@
-- RedefineTables
PRAGMA foreign_keys=OFF;
CREATE TABLE "new_AgentGraph" (
"id" TEXT NOT NULL,
"version" INTEGER NOT NULL DEFAULT 1,
"name" TEXT,
"description" TEXT,
"isActive" BOOLEAN NOT NULL DEFAULT true,
"isTemplate" BOOLEAN NOT NULL DEFAULT false,
"agentGraphParentId" TEXT,
PRIMARY KEY ("id", "version"),
CONSTRAINT "AgentGraph_agentGraphParentId_version_fkey" FOREIGN KEY ("agentGraphParentId", "version") REFERENCES "AgentGraph" ("id", "version") ON DELETE RESTRICT ON UPDATE CASCADE
);
INSERT INTO "new_AgentGraph" ("description", "id", "isActive", "isTemplate", "name", "version") SELECT "description", "id", "isActive", "isTemplate", "name", "version" FROM "AgentGraph";
DROP TABLE "AgentGraph";
ALTER TABLE "new_AgentGraph" RENAME TO "AgentGraph";
PRAGMA foreign_key_check;
PRAGMA foreign_keys=ON;

View File

@@ -0,0 +1,5 @@
-- AlterTable
ALTER TABLE "AgentGraph" ADD COLUMN "agentGraphParentId" TEXT;
-- AddForeignKey
ALTER TABLE "AgentGraph" ADD CONSTRAINT "AgentGraph_agentGraphParentId_version_fkey" FOREIGN KEY ("agentGraphParentId", "version") REFERENCES "AgentGraph"("id", "version") ON DELETE RESTRICT ON UPDATE CASCADE;

View File

@@ -24,6 +24,11 @@ model AgentGraph {
AgentGraphExecution AgentGraphExecution[]
AgentGraphExecutionSchedule AgentGraphExecutionSchedule[]
// All sub-graphs are defined within this 1-level depth list (even if it's a nested graph).
AgentSubGraphs AgentGraph[] @relation("AgentSubGraph")
agentGraphParentId String?
AgentGraphParent AgentGraph? @relation("AgentSubGraph", fields: [agentGraphParentId, version], references: [id, version])
@@id(name: "graphVersionId", [id, version])
}

View File

@@ -23,6 +23,11 @@ model AgentGraph {
AgentGraphExecution AgentGraphExecution[]
AgentGraphExecutionSchedule AgentGraphExecutionSchedule[]
// All sub-graphs are defined within this 1-level depth list (even if it's a nested graph).
AgentSubGraphs AgentGraph[] @relation("AgentSubGraph")
agentGraphParentId String?
AgentGraphParent AgentGraph? @relation("AgentSubGraph", fields: [agentGraphParentId, version], references: [id, version])
@@id(name: "graphVersionId", [id, version])
}

View File

@@ -0,0 +1,60 @@
from uuid import UUID
import pytest
from autogpt_server.blocks.basic import InputBlock, ValueBlock
from autogpt_server.data.graph import Graph, Link, Node
from autogpt_server.server.model import CreateGraph
@pytest.mark.asyncio(scope="session")
async def test_graph_creation(server):
value_block = ValueBlock().id
input_block = InputBlock().id
graph = Graph(
id="test_graph",
name="TestGraph",
description="Test graph",
nodes=[
Node(id="node_1", block_id=value_block),
Node(id="node_2", block_id=input_block),
Node(id="node_3", block_id=value_block),
],
links=[
Link(
source_id="node_1",
sink_id="node_3",
source_name="output",
sink_name="input",
),
],
subgraphs={"subgraph_1": ["node_2", "node_3"]},
)
create_graph = CreateGraph(graph=graph)
try:
await server.agent_server.create_graph(create_graph, False)
assert False, "Should not be able to connect nodes from different subgraphs"
except ValueError as e:
assert "different subgraph" in str(e)
# Change node_1 <-> node_3 link to node_1 <-> node_2 (input for subgraph_1)
graph.links[0].sink_id = "node_2"
created_graph = await server.agent_server.create_graph(create_graph, False)
assert UUID(created_graph.id)
assert created_graph.name == "TestGraph"
assert len(created_graph.nodes) == 3
assert UUID(created_graph.nodes[0].id)
assert UUID(created_graph.nodes[1].id)
assert UUID(created_graph.nodes[2].id)
nodes = created_graph.nodes
links = created_graph.links
assert len(links) == 1
assert {nodes[0].id, nodes[1].id} == {links[0].source_id, links[0].sink_id}
assert len(created_graph.subgraphs) == 1
assert len(created_graph.subgraph_map) == len(created_graph.nodes) == 3