feat(server): Add endpoint to calculate required graph inputs (#7965)

This commit is contained in:
Swifty
2024-09-04 09:45:45 +02:00
committed by GitHub
parent c2a79d2f10
commit 80df44a978
3 changed files with 204 additions and 1 deletions

View File

@@ -6,7 +6,8 @@ from typing import Any, Literal
import prisma.types
from prisma.models import AgentGraph, AgentNode, AgentNodeLink
from pydantic import PrivateAttr
from pydantic import BaseModel, PrivateAttr
from pydantic_core import PydanticUndefinedType
from autogpt_server.blocks.basic import InputBlock, OutputBlock
from autogpt_server.data.block import BlockInput, get_block, get_blocks
@@ -17,6 +18,12 @@ from autogpt_server.util import json
logger = logging.getLogger(__name__)
class InputSchemaItem(BaseModel):
node_id: str
description: str | None = None
title: str | None = None
class Link(BaseDbModel):
source_id: str
sink_id: str
@@ -235,6 +242,43 @@ class Graph(GraphMeta):
# TODO: Add type compatibility check here.
def get_input_schema(self) -> list[InputSchemaItem]:
"""
Walks the graph and returns all the inputs that are either not:
- static
- provided by parent node
"""
input_schema = []
for node in self.nodes:
block = get_block(node.block_id)
if not block:
continue
for input_name, input_schema_item in (
block.input_schema.jsonschema().get("properties", {}).items()
):
# Check if the input is not static and not provided by a parent node
if (
input_name not in node.input_default
and not any(
link.sink_name == input_name for link in node.input_links
)
and isinstance(
block.input_schema.model_fields.get(input_name).default,
PydanticUndefinedType,
)
):
input_schema.append(
InputSchemaItem(
node_id=node.id,
description=input_schema_item.get("description"),
title=input_schema_item.get("title"),
)
)
return input_schema
@staticmethod
def from_db(graph: AgentGraph):
nodes = [

View File

@@ -154,6 +154,11 @@ class AgentServer(AppService):
endpoint=self.set_graph_active_version,
methods=["PUT"],
)
router.add_api_route(
path="/graphs/{graph_id}/input_schema",
endpoint=self.get_graph_input_schema,
methods=["GET"],
)
router.add_api_route(
path="/graphs/{graph_id}/execute",
endpoint=self.execute_graph,
@@ -427,6 +432,18 @@ class AgentServer(AppService):
msg = e.__str__().encode().decode("unicode_escape")
raise HTTPException(status_code=400, detail=msg)
@classmethod
async def get_graph_input_schema(
cls,
graph_id: str,
user_id: Annotated[str, Depends(get_user_id)],
) -> list[graph_db.InputSchemaItem]:
try:
graph = await graph_db.get_graph(graph_id, user_id=user_id)
return graph.get_input_schema() if graph else []
except Exception:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
@classmethod
async def list_graph_runs(
cls,

View File

@@ -11,6 +11,17 @@ from autogpt_server.util.test import SpinTestServer
@pytest.mark.asyncio(scope="session")
async def test_graph_creation(server: SpinTestServer):
"""
Test the creation of a graph with nodes and links.
This test ensures that:
1. Nodes from different subgraphs cannot be directly connected.
2. A graph can be successfully created with valid connections.
3. The created graph has the correct structure and properties.
Args:
server (SpinTestServer): The test server instance.
"""
await create_default_user("false")
value_block = StoreValueBlock().id
@@ -66,3 +77,134 @@ async def test_graph_creation(server: SpinTestServer):
assert len(created_graph.subgraphs) == 1
assert len(created_graph.subgraph_map) == len(created_graph.nodes) == 3
@pytest.mark.asyncio(scope="session")
async def test_get_input_schema(server: SpinTestServer):
"""
Test the get_input_schema method of a created graph.
This test ensures that:
1. A graph can be created with a single node.
2. The input schema of the created graph is correctly generated.
3. The input schema contains the expected input name and node id.
Args:
server (SpinTestServer): The test server instance.
"""
value_block = StoreValueBlock().id
graph = Graph(
name="TestInputSchema",
description="Test input schema",
nodes=[
Node(id="node_1", block_id=value_block),
],
links=[],
)
create_graph = CreateGraph(graph=graph)
created_graph = await server.agent_server.create_graph(
create_graph, False, DEFAULT_USER_ID
)
input_schema = created_graph.get_input_schema()
assert len(input_schema) == 1
assert input_schema[0].title == "Input"
assert input_schema[0].node_id == created_graph.nodes[0].id
@pytest.mark.asyncio(scope="session")
async def test_get_input_schema_none_required(server: SpinTestServer):
"""
Test the get_input_schema method when no inputs are required.
This test ensures that:
1. A graph can be created with a node that has a default input value.
2. The input schema of the created graph is empty when all inputs have default values.
Args:
server (SpinTestServer): The test server instance.
"""
value_block = StoreValueBlock().id
graph = Graph(
name="TestInputSchema",
description="Test input schema",
nodes=[
Node(id="node_1", block_id=value_block, input_default={"input": "value"}),
],
links=[],
)
create_graph = CreateGraph(graph=graph)
created_graph = await server.agent_server.create_graph(
create_graph, False, DEFAULT_USER_ID
)
input_schema = created_graph.get_input_schema()
assert input_schema == []
@pytest.mark.asyncio(scope="session")
async def test_get_input_schema_with_linked_blocks(server: SpinTestServer):
"""
Test the get_input_schema method with linked blocks.
This test ensures that:
1. A graph can be created with multiple nodes and links between them.
2. The input schema correctly identifies required inputs for linked blocks.
3. Inputs that are satisfied by links are not included in the input schema.
Args:
server (SpinTestServer): The test server instance.
"""
value_block = StoreValueBlock().id
graph = Graph(
name="TestInputSchemaLinkedBlocks",
description="Test input schema with linked blocks",
nodes=[
Node(id="node_1", block_id=value_block),
Node(id="node_2", block_id=value_block),
],
links=[
Link(
source_id="node_1",
sink_id="node_2",
source_name="output",
sink_name="data",
),
],
)
create_graph = CreateGraph(graph=graph)
created_graph = await server.agent_server.create_graph(
create_graph, False, DEFAULT_USER_ID
)
input_schema = created_graph.get_input_schema()
assert len(input_schema) == 2
node_1_input = next(
(item for item in input_schema if item.node_id == created_graph.nodes[0].id),
None,
)
node_2_input = next(
(item for item in input_schema if item.node_id == created_graph.nodes[1].id),
None,
)
assert node_1_input is not None
assert node_2_input is not None
assert node_1_input.title == "Input"
assert node_2_input.title == "Input"
assert not any(
item.title == "data" and item.node_id == created_graph.nodes[1].id
for item in input_schema
)