mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(rnd): Introduce Sub-Graph on Agent Server (#7693)
### Background
This change brings the capability to decompose a graph into sub-graphs. The objective of this feature is to allow a user to build a visually modular, and easier-to-understand graph. Also, allowing you to import a graph into your existing graph, without decluttering your existing graph.
This feature will require more implementation on the UI side, to allow the grouping of subgraph to be represented as a node in the builder.
### Changes 🏗️
Introduced a subgraph functionality with the following property:
* Sub-graph is simply a set of nodes that are grouped together, making it representable as a node.
* Sub-graph input & output pins/schema are the `InputBlock` / `OutputBlock` nodes present in the subgraph.
* The previous point implies that connecting two nodes from different sub-graphs, other than input/output nodes, is not allowed.
* Graph can be nested, but defined flatly, e.g.: graph is now only represented by three components: nodes, links, and subgraphs (a set of list of nodes). A nested subgraph is simply connecting a node inside a subgraph into another `InputBlock` node of another subgraph.
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from uuid import uuid4
|
||||
|
||||
from dotenv import load_dotenv
|
||||
@@ -23,6 +24,12 @@ async def disconnect():
|
||||
await prisma.disconnect()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def transaction():
|
||||
async with prisma.tx() as tx:
|
||||
yield tx
|
||||
|
||||
|
||||
class BaseDbModel(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from pydantic import PrivateAttr
|
||||
|
||||
from autogpt_server.blocks.basic import InputBlock, OutputBlock
|
||||
from autogpt_server.data.block import BlockInput, get_block
|
||||
from autogpt_server.data.db import BaseDbModel
|
||||
from autogpt_server.data.db import BaseDbModel, transaction
|
||||
from autogpt_server.util import json
|
||||
|
||||
|
||||
@@ -89,6 +89,7 @@ class GraphMeta(BaseDbModel):
|
||||
class Graph(GraphMeta):
|
||||
nodes: list[Node]
|
||||
links: list[Link]
|
||||
subgraphs: dict[str, list[str]] = {} # subgraph_id -> [node_id]
|
||||
|
||||
@property
|
||||
def starting_nodes(self) -> list[Node]:
|
||||
@@ -106,17 +107,63 @@ class Graph(GraphMeta):
|
||||
def ending_nodes(self) -> list[Node]:
|
||||
return [v for v in self.nodes if isinstance(get_block(v.block_id), OutputBlock)]
|
||||
|
||||
def validate_graph(self):
|
||||
@property
|
||||
def subgraph_map(self) -> dict[str, str]:
|
||||
"""
|
||||
Returns a mapping of node_id to subgraph_id.
|
||||
A node in the main graph will be mapped to the graph's id.
|
||||
"""
|
||||
subgraph_map = {
|
||||
node_id: subgraph_id
|
||||
for subgraph_id, node_ids in self.subgraphs.items()
|
||||
for node_id in node_ids
|
||||
}
|
||||
subgraph_map.update(
|
||||
{node.id: self.id for node in self.nodes if node.id not in subgraph_map}
|
||||
)
|
||||
return subgraph_map
|
||||
|
||||
def reassign_ids(self):
|
||||
"""
|
||||
Reassigns all IDs in the graph to new UUIDs.
|
||||
This method can be used before storing a new graph to the database.
|
||||
"""
|
||||
self.validate_graph()
|
||||
|
||||
id_map = {
|
||||
self.id: str(uuid.uuid4()),
|
||||
**{node.id: str(uuid.uuid4()) for node in self.nodes},
|
||||
**{subgraph_id: str(uuid.uuid4()) for subgraph_id in self.subgraphs},
|
||||
}
|
||||
|
||||
self.id = id_map[self.id]
|
||||
|
||||
for node in self.nodes:
|
||||
node.id = id_map[node.id]
|
||||
|
||||
for link in self.links:
|
||||
link.source_id = id_map[link.source_id]
|
||||
link.sink_id = id_map[link.sink_id]
|
||||
|
||||
self.subgraphs = {
|
||||
id_map[subgraph_id]: [id_map[node_id] for node_id in node_ids]
|
||||
for subgraph_id, node_ids in self.subgraphs.items()
|
||||
}
|
||||
|
||||
def validate_graph(self, for_run: bool = False):
|
||||
|
||||
def sanitize(name):
|
||||
return name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
|
||||
|
||||
# Check if all required fields are filled or connected, except for InputBlock.
|
||||
# Nodes: required fields are filled or connected, except for InputBlock.
|
||||
for node in self.nodes:
|
||||
block = get_block(node.block_id)
|
||||
if block is None:
|
||||
raise ValueError(f"Invalid block {node.block_id} for node #{node.id}")
|
||||
|
||||
if not for_run:
|
||||
continue # Skip input completion validation, unless when executing.
|
||||
|
||||
provided_inputs = set(
|
||||
[sanitize(name) for name in node.input_default]
|
||||
+ [sanitize(link.sink_name) for link in node.input_links]
|
||||
@@ -126,57 +173,92 @@ class Graph(GraphMeta):
|
||||
raise ValueError(
|
||||
f"Node {block.name} #{node.id} required input missing: `{name}`"
|
||||
)
|
||||
node_map = {v.id: v for v in self.nodes}
|
||||
|
||||
# Check if all links are connected compatible pin data type.
|
||||
def is_input_output_block(nid: str) -> bool:
|
||||
bid = node_map[nid].block_id
|
||||
b = get_block(bid)
|
||||
return isinstance(b, InputBlock) or isinstance(b, OutputBlock)
|
||||
|
||||
# subgraphs: all nodes in subgraph must be present in the graph.
|
||||
for subgraph_id, node_ids in self.subgraphs.items():
|
||||
for node_id in node_ids:
|
||||
if node_id not in node_map:
|
||||
raise ValueError(f"Subgraph {subgraph_id}'s node {node_id} invalid")
|
||||
subgraph_map = self.subgraph_map
|
||||
|
||||
# Links: links are connected and the connected pin data type are compatible.
|
||||
for link in self.links:
|
||||
source_id = link.source_id
|
||||
sink_id = link.sink_id
|
||||
suffix = f"Link {source_id}<->{sink_id}"
|
||||
source = (link.source_id, link.source_name)
|
||||
sink = (link.sink_id, link.sink_name)
|
||||
suffix = f"Link {source} <-> {sink}"
|
||||
|
||||
source_node = next((v for v in self.nodes if v.id == source_id), None)
|
||||
if not source_node:
|
||||
raise ValueError(f"{suffix}, {source_id} is invalid node.")
|
||||
sink_node = next((v for v in self.nodes if v.id == sink_id), None)
|
||||
if not sink_node:
|
||||
raise ValueError(f"{suffix}, {sink_id} is invalid node.")
|
||||
for i, (node_id, name) in enumerate([source, sink]):
|
||||
node = node_map.get(node_id)
|
||||
if not node:
|
||||
raise ValueError(f"{suffix}, {node_id} is invalid node.")
|
||||
|
||||
source_block = get_block(source_node.block_id)
|
||||
if not source_block:
|
||||
raise ValueError(f"{suffix}, {source_node.block_id} is invalid block.")
|
||||
sink_block = get_block(sink_node.block_id)
|
||||
if not sink_block:
|
||||
raise ValueError(f"{suffix}, {sink_node.block_id} is invalid block.")
|
||||
block = get_block(node.block_id)
|
||||
if not block:
|
||||
raise ValueError(f"{suffix}, {node.block_id} is invalid block.")
|
||||
|
||||
source_name = sanitize(link.source_name)
|
||||
if source_name not in source_block.output_schema.get_fields():
|
||||
raise ValueError(f"{suffix}, `{source_name}` is invalid output pin.")
|
||||
sink_name = sanitize(link.sink_name)
|
||||
if sink_name not in sink_block.input_schema.get_fields():
|
||||
raise ValueError(f"{suffix}, `{sink_name}` is invalid input pin.")
|
||||
sanitized_name = sanitize(name)
|
||||
if i == 0:
|
||||
fields = block.output_schema.get_fields()
|
||||
else:
|
||||
fields = block.input_schema.get_fields()
|
||||
if sanitized_name not in fields:
|
||||
raise ValueError(f"{suffix}, `{name}` invalid, fields: {fields}")
|
||||
|
||||
if (
|
||||
subgraph_map.get(link.source_id) != subgraph_map.get(link.sink_id)
|
||||
and not is_input_output_block(link.source_id)
|
||||
and not is_input_output_block(link.sink_id)
|
||||
):
|
||||
raise ValueError(f"{suffix}, Connecting nodes from different subgraph.")
|
||||
|
||||
# TODO: Add type compatibility check here.
|
||||
|
||||
@staticmethod
|
||||
def from_db(graph: AgentGraph):
|
||||
nodes = [
|
||||
*(graph.AgentNodes or []),
|
||||
*(
|
||||
node
|
||||
for subgraph in graph.AgentSubGraphs or []
|
||||
for node in subgraph.AgentNodes or []
|
||||
),
|
||||
]
|
||||
return Graph(
|
||||
**GraphMeta.from_db(graph).model_dump(),
|
||||
nodes=[Node.from_db(node) for node in graph.AgentNodes or []],
|
||||
nodes=[Node.from_db(node) for node in nodes],
|
||||
links=list(
|
||||
{
|
||||
Link.from_db(link)
|
||||
for node in graph.AgentNodes or []
|
||||
for node in nodes
|
||||
for link in (node.Input or []) + (node.Output or [])
|
||||
}
|
||||
),
|
||||
subgraphs={
|
||||
subgraph.id: [node.id for node in subgraph.AgentNodes or []]
|
||||
for subgraph in graph.AgentSubGraphs or []
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
EXECUTION_NODE_INCLUDE = {
|
||||
AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
|
||||
"Input": True,
|
||||
"Output": True,
|
||||
"AgentBlock": True,
|
||||
}
|
||||
|
||||
__SUBGRAPH_INCLUDE = {"AgentNodes": {"include": AGENT_NODE_INCLUDE}}
|
||||
|
||||
AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = {
|
||||
**__SUBGRAPH_INCLUDE,
|
||||
"AgentSubGraphs": {"include": __SUBGRAPH_INCLUDE}, # type: ignore
|
||||
}
|
||||
|
||||
|
||||
# --------------------- Model functions --------------------- #
|
||||
|
||||
@@ -184,7 +266,7 @@ EXECUTION_NODE_INCLUDE = {
|
||||
async def get_node(node_id: str) -> Node | None:
|
||||
node = await AgentNode.prisma().find_unique_or_raise(
|
||||
where={"id": node_id},
|
||||
include=EXECUTION_NODE_INCLUDE, # type: ignore
|
||||
include=AGENT_NODE_INCLUDE,
|
||||
)
|
||||
return Node.from_db(node) if node else None
|
||||
|
||||
@@ -242,7 +324,7 @@ async def get_graph(
|
||||
|
||||
graph = await AgentGraph.prisma().find_first(
|
||||
where=where_clause,
|
||||
include={"AgentNodes": {"include": EXECUTION_NODE_INCLUDE}}, # type: ignore
|
||||
include=AGENT_GRAPH_INCLUDE,
|
||||
order={"version": "desc"},
|
||||
)
|
||||
return Graph.from_db(graph) if graph else None
|
||||
@@ -267,7 +349,7 @@ async def get_graph_all_versions(graph_id: str) -> list[Graph]:
|
||||
graph_versions = await AgentGraph.prisma().find_many(
|
||||
where={"id": graph_id},
|
||||
order={"version": "desc"},
|
||||
include={"AgentNodes": {"include": EXECUTION_NODE_INCLUDE}}, # type: ignore
|
||||
include=AGENT_GRAPH_INCLUDE,
|
||||
)
|
||||
|
||||
if not graph_versions:
|
||||
@@ -277,7 +359,17 @@ async def get_graph_all_versions(graph_id: str) -> list[Graph]:
|
||||
|
||||
|
||||
async def create_graph(graph: Graph) -> Graph:
|
||||
await AgentGraph.prisma().create(
|
||||
async with transaction() as tx:
|
||||
await __create_graph(tx, graph)
|
||||
|
||||
if created_graph := await get_graph(graph.id, graph.version, graph.is_template):
|
||||
return created_graph
|
||||
|
||||
raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")
|
||||
|
||||
|
||||
async def __create_graph(tx, graph: Graph):
|
||||
await AgentGraph.prisma(tx).create(
|
||||
data={
|
||||
"id": graph.id,
|
||||
"version": graph.version,
|
||||
@@ -290,11 +382,30 @@ async def create_graph(graph: Graph) -> Graph:
|
||||
|
||||
await asyncio.gather(
|
||||
*[
|
||||
AgentNode.prisma().create(
|
||||
AgentGraph.prisma(tx).create(
|
||||
data={
|
||||
"id": subgraph_id,
|
||||
"agentGraphParentId": graph.id,
|
||||
"version": graph.version,
|
||||
"name": f"SubGraph of {graph.name}",
|
||||
"description": f"Sub-Graph of {graph.id}",
|
||||
"isTemplate": graph.is_template,
|
||||
"isActive": graph.is_active,
|
||||
}
|
||||
)
|
||||
for subgraph_id in graph.subgraphs
|
||||
]
|
||||
)
|
||||
|
||||
subgraph_map = graph.subgraph_map
|
||||
|
||||
await asyncio.gather(
|
||||
*[
|
||||
AgentNode.prisma(tx).create(
|
||||
{
|
||||
"id": node.id,
|
||||
"agentBlockId": node.block_id,
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphId": subgraph_map.get(node.id, graph.id),
|
||||
"agentGraphVersion": graph.version,
|
||||
"constantInput": json.dumps(node.input_default),
|
||||
"metadata": json.dumps(node.metadata),
|
||||
@@ -306,7 +417,7 @@ async def create_graph(graph: Graph) -> Graph:
|
||||
|
||||
await asyncio.gather(
|
||||
*[
|
||||
AgentNodeLink.prisma().create(
|
||||
AgentNodeLink.prisma(tx).create(
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"sourceName": link.source_name,
|
||||
@@ -320,13 +431,6 @@ async def create_graph(graph: Graph) -> Graph:
|
||||
]
|
||||
)
|
||||
|
||||
if created_graph := await get_graph(
|
||||
graph.id, graph.version, template=graph.is_template
|
||||
):
|
||||
return created_graph
|
||||
|
||||
raise ValueError(f"Created graph {graph.id} v{graph.version} is not in DB")
|
||||
|
||||
|
||||
# --------------------- Helper functions --------------------- #
|
||||
|
||||
|
||||
@@ -420,7 +420,7 @@ class ExecutionManager(AppService):
|
||||
graph: Graph | None = self.run_and_wait(get_graph(graph_id))
|
||||
if not graph:
|
||||
raise Exception(f"Graph #{graph_id} not found.")
|
||||
graph.validate_graph()
|
||||
graph.validate_graph(for_run=True)
|
||||
|
||||
nodes_input = []
|
||||
for node in graph.starting_nodes:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Annotated, Any, Dict
|
||||
@@ -468,15 +467,7 @@ class AgentServer(AppService):
|
||||
|
||||
graph.is_template = is_template
|
||||
graph.is_active = not is_template
|
||||
|
||||
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
|
||||
|
||||
for node in graph.nodes:
|
||||
node.id = id_map[node.id]
|
||||
|
||||
for link in graph.links:
|
||||
link.source_id = id_map[link.source_id]
|
||||
link.sink_id = id_map[link.sink_id]
|
||||
graph.reassign_ids()
|
||||
|
||||
return await graph_db.create_graph(graph)
|
||||
|
||||
@@ -501,14 +492,7 @@ class AgentServer(AppService):
|
||||
400, detail="Changing is_template on an existing graph is forbidden"
|
||||
)
|
||||
graph.is_active = not graph.is_template
|
||||
|
||||
# Assign new UUIDs to all nodes and links
|
||||
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
|
||||
for node in graph.nodes:
|
||||
node.id = id_map[node.id]
|
||||
for link in graph.links:
|
||||
link.source_id = id_map[link.source_id]
|
||||
link.sink_id = id_map[link.sink_id]
|
||||
graph.reassign_ids()
|
||||
|
||||
new_graph_version = await graph_db.create_graph(graph)
|
||||
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
-- RedefineTables
|
||||
PRAGMA foreign_keys=OFF;
|
||||
CREATE TABLE "new_AgentGraph" (
|
||||
"id" TEXT NOT NULL,
|
||||
"version" INTEGER NOT NULL DEFAULT 1,
|
||||
"name" TEXT,
|
||||
"description" TEXT,
|
||||
"isActive" BOOLEAN NOT NULL DEFAULT true,
|
||||
"isTemplate" BOOLEAN NOT NULL DEFAULT false,
|
||||
"agentGraphParentId" TEXT,
|
||||
|
||||
PRIMARY KEY ("id", "version"),
|
||||
CONSTRAINT "AgentGraph_agentGraphParentId_version_fkey" FOREIGN KEY ("agentGraphParentId", "version") REFERENCES "AgentGraph" ("id", "version") ON DELETE RESTRICT ON UPDATE CASCADE
|
||||
);
|
||||
INSERT INTO "new_AgentGraph" ("description", "id", "isActive", "isTemplate", "name", "version") SELECT "description", "id", "isActive", "isTemplate", "name", "version" FROM "AgentGraph";
|
||||
DROP TABLE "AgentGraph";
|
||||
ALTER TABLE "new_AgentGraph" RENAME TO "AgentGraph";
|
||||
PRAGMA foreign_key_check;
|
||||
PRAGMA foreign_keys=ON;
|
||||
@@ -0,0 +1,5 @@
|
||||
-- AlterTable
|
||||
ALTER TABLE "AgentGraph" ADD COLUMN "agentGraphParentId" TEXT;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "AgentGraph" ADD CONSTRAINT "AgentGraph_agentGraphParentId_version_fkey" FOREIGN KEY ("agentGraphParentId", "version") REFERENCES "AgentGraph"("id", "version") ON DELETE RESTRICT ON UPDATE CASCADE;
|
||||
@@ -24,6 +24,11 @@ model AgentGraph {
|
||||
AgentGraphExecution AgentGraphExecution[]
|
||||
AgentGraphExecutionSchedule AgentGraphExecutionSchedule[]
|
||||
|
||||
// All sub-graphs are defined within this 1-level depth list (even if it's a nested graph).
|
||||
AgentSubGraphs AgentGraph[] @relation("AgentSubGraph")
|
||||
agentGraphParentId String?
|
||||
AgentGraphParent AgentGraph? @relation("AgentSubGraph", fields: [agentGraphParentId, version], references: [id, version])
|
||||
|
||||
@@id(name: "graphVersionId", [id, version])
|
||||
}
|
||||
|
||||
|
||||
@@ -23,6 +23,11 @@ model AgentGraph {
|
||||
AgentGraphExecution AgentGraphExecution[]
|
||||
AgentGraphExecutionSchedule AgentGraphExecutionSchedule[]
|
||||
|
||||
// All sub-graphs are defined within this 1-level depth list (even if it's a nested graph).
|
||||
AgentSubGraphs AgentGraph[] @relation("AgentSubGraph")
|
||||
agentGraphParentId String?
|
||||
AgentGraphParent AgentGraph? @relation("AgentSubGraph", fields: [agentGraphParentId, version], references: [id, version])
|
||||
|
||||
@@id(name: "graphVersionId", [id, version])
|
||||
}
|
||||
|
||||
|
||||
60
rnd/autogpt_server/test/data/graph.py
Normal file
60
rnd/autogpt_server/test/data/graph.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
|
||||
from autogpt_server.blocks.basic import InputBlock, ValueBlock
|
||||
from autogpt_server.data.graph import Graph, Link, Node
|
||||
from autogpt_server.server.model import CreateGraph
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_graph_creation(server):
|
||||
value_block = ValueBlock().id
|
||||
input_block = InputBlock().id
|
||||
|
||||
graph = Graph(
|
||||
id="test_graph",
|
||||
name="TestGraph",
|
||||
description="Test graph",
|
||||
nodes=[
|
||||
Node(id="node_1", block_id=value_block),
|
||||
Node(id="node_2", block_id=input_block),
|
||||
Node(id="node_3", block_id=value_block),
|
||||
],
|
||||
links=[
|
||||
Link(
|
||||
source_id="node_1",
|
||||
sink_id="node_3",
|
||||
source_name="output",
|
||||
sink_name="input",
|
||||
),
|
||||
],
|
||||
subgraphs={"subgraph_1": ["node_2", "node_3"]},
|
||||
)
|
||||
create_graph = CreateGraph(graph=graph)
|
||||
|
||||
try:
|
||||
await server.agent_server.create_graph(create_graph, False)
|
||||
assert False, "Should not be able to connect nodes from different subgraphs"
|
||||
except ValueError as e:
|
||||
assert "different subgraph" in str(e)
|
||||
|
||||
# Change node_1 <-> node_3 link to node_1 <-> node_2 (input for subgraph_1)
|
||||
graph.links[0].sink_id = "node_2"
|
||||
created_graph = await server.agent_server.create_graph(create_graph, False)
|
||||
|
||||
assert UUID(created_graph.id)
|
||||
assert created_graph.name == "TestGraph"
|
||||
|
||||
assert len(created_graph.nodes) == 3
|
||||
assert UUID(created_graph.nodes[0].id)
|
||||
assert UUID(created_graph.nodes[1].id)
|
||||
assert UUID(created_graph.nodes[2].id)
|
||||
|
||||
nodes = created_graph.nodes
|
||||
links = created_graph.links
|
||||
assert len(links) == 1
|
||||
assert {nodes[0].id, nodes[1].id} == {links[0].source_id, links[0].sink_id}
|
||||
|
||||
assert len(created_graph.subgraphs) == 1
|
||||
assert len(created_graph.subgraph_map) == len(created_graph.nodes) == 3
|
||||
Reference in New Issue
Block a user