Compare commits

...

1 Commits

Author SHA1 Message Date
SwiftyOS
fe05cae794 added custom node rendering 2024-09-09 11:42:50 +02:00
8 changed files with 222 additions and 106 deletions

View File

@@ -14,6 +14,7 @@ import {
BlockIORootSchema,
Category,
NodeExecutionResult,
BlockUIType,
} from "@/lib/autogpt-server-api/types";
import { beautifyString, cn, setNestedProperty } from "@/lib/utils";
import { Button } from "@/components/ui/button";
@@ -59,6 +60,7 @@ export type CustomNodeData = {
backend_id?: string;
errors?: { [key: string]: string };
isOutputStatic?: boolean;
uiType: BlockUIType;
};
export type CustomNode = Node<CustomNodeData, "custom">;
@@ -118,8 +120,16 @@ export function CustomNode({ data, id, width, height }: NodeProps<CustomNode>) {
setIsAdvancedOpen(checked);
};
const generateOutputHandles = (schema: BlockIORootSchema) => {
if (!schema?.properties) return null;
const generateOutputHandles = (
schema: BlockIORootSchema,
nodeType: BlockUIType,
) => {
if (
!schema?.properties ||
nodeType === BlockUIType.OUTPUT ||
nodeType === BlockUIType.NOTE
)
return null;
const keys = Object.keys(schema.properties);
return keys.map((key) => (
<div key={key}>
@@ -133,6 +143,136 @@ export function CustomNode({ data, id, width, height }: NodeProps<CustomNode>) {
));
};
const generateInputHandles = (
schema: BlockIORootSchema,
nodeType: BlockUIType,
) => {
if (!schema?.properties) return null;
let keys = Object.entries(schema.properties);
switch (nodeType) {
case BlockUIType.INPUT:
// For INPUT blocks, only show the 'value' property
return keys.map(([propKey, propSchema]) => {
const isRequired = data.inputSchema.required?.includes(propKey);
const isConnected = isHandleConnected(propKey);
const isAdvanced = propSchema.advanced;
return (
<div key={propKey}>
<span className="text-m green -mb-1 text-gray-900">
{propSchema.title || beautifyString(propKey)}
</span>
<div key={propKey} onMouseOver={() => {}}>
{!isConnected && (
<NodeGenericInputField
className="mb-2 mt-1"
propKey={propKey}
propSchema={propSchema}
currentValue={getValue(propKey)}
connections={data.connections}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
errors={data.errors ?? {}}
displayName={propSchema.title || beautifyString(propKey)}
/>
)}
</div>
</div>
);
});
case BlockUIType.NOTE:
// For NOTE blocks, don't render any input handles
return keys.map(([propKey, propSchema]) => {
const isConnected = isHandleConnected(propKey);
return (
<div key={propKey}>
<NodeGenericInputField
className="mb-2 mt-1"
propKey={propKey}
propSchema={propSchema}
currentValue={getValue(propKey)}
connections={data.connections}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
errors={data.errors ?? {}}
displayName={propSchema.title || beautifyString(propKey)}
/>
</div>
);
});
case BlockUIType.OUTPUT:
// For OUTPUT blocks, only show the 'recorded_value' property
return keys.map(([propKey, propSchema]) => {
const isRequired = data.inputSchema.required?.includes(propKey);
const isConnected = isHandleConnected(propKey);
const isAdvanced = propSchema.advanced;
return (
<div key={propKey} onMouseOver={() => {}}>
{propKey !== "value" ? (
<span className="text-m green -mb-1 text-gray-900">
{propSchema.title || beautifyString(propKey)}
</span>
) : (
<NodeHandle
keyName={propKey}
isConnected={isConnected}
isRequired={isRequired}
schema={propSchema}
side="left"
/>
)}
{!isConnected && (
<NodeGenericInputField
className="mb-2 mt-1"
propKey={propKey}
propSchema={propSchema}
currentValue={getValue(propKey)}
connections={data.connections}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
errors={data.errors ?? {}}
displayName={propSchema.title || beautifyString(propKey)}
/>
)}
</div>
);
});
default:
return keys.map(([propKey, propSchema]) => {
const isRequired = data.inputSchema.required?.includes(propKey);
const isConnected = isHandleConnected(propKey);
const isAdvanced = propSchema.advanced;
return (
(isRequired || isAdvancedOpen || isConnected || !isAdvanced) && (
<div key={propKey} onMouseOver={() => {}}>
<NodeHandle
keyName={propKey}
isConnected={isConnected}
isRequired={isRequired}
schema={propSchema}
side="left"
/>
{!isConnected && (
<NodeGenericInputField
className="mb-2 mt-1"
propKey={propKey}
propSchema={propSchema}
currentValue={getValue(propKey)}
connections={data.connections}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
errors={data.errors ?? {}}
displayName={propSchema.title || beautifyString(propKey)}
/>
)}
</div>
)
);
});
}
};
const handleInputChange = (path: string, value: any) => {
const keys = parseKeys(path);
const newValues = JSON.parse(JSON.stringify(data.hardcodedValues));
@@ -420,47 +560,11 @@ export function CustomNode({ data, id, width, height }: NodeProps<CustomNode>) {
<div className="flex items-start justify-between gap-2 p-3">
<div>
{data.inputSchema &&
Object.entries(data.inputSchema.properties).map(
([propKey, propSchema]) => {
const isRequired = data.inputSchema.required?.includes(propKey);
const isConnected = isHandleConnected(propKey);
const isAdvanced = propSchema.advanced;
return (
(isRequired ||
isAdvancedOpen ||
isConnected ||
!isAdvanced) && (
<div key={propKey} onMouseOver={() => {}}>
<NodeHandle
keyName={propKey}
isConnected={isConnected}
isRequired={isRequired}
schema={propSchema}
side="left"
/>
{!isConnected && (
<NodeGenericInputField
className="mb-2 mt-1"
propKey={propKey}
propSchema={propSchema}
currentValue={getValue(propKey)}
connections={data.connections}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
errors={data.errors ?? {}}
displayName={
propSchema.title || beautifyString(propKey)
}
/>
)}
</div>
)
);
},
)}
generateInputHandles(data.inputSchema, data.uiType)}
</div>
<div className="flex-none">
{data.outputSchema && generateOutputHandles(data.outputSchema)}
{data.outputSchema &&
generateOutputHandles(data.outputSchema, data.uiType)}
</div>
</div>
{isOutputOpen && (
@@ -486,25 +590,27 @@ export function CustomNode({ data, id, width, height }: NodeProps<CustomNode>) {
)}
</div>
)}
<div className="mt-2.5 flex items-center pb-4 pl-4">
<Switch checked={isOutputOpen} onCheckedChange={toggleOutput} />
<span className="m-1 mr-4">Output</span>
{hasAdvancedFields && (
<>
<Switch onCheckedChange={toggleAdvancedSettings} />
<span className="m-1">Advanced</span>
</>
)}
{data.status && (
<Badge
variant="outline"
data-id={`badge-${id}-${data.status}`}
className={cn(data.status.toLowerCase(), "ml-auto mr-5")}
>
{data.status}
</Badge>
)}
</div>
{data.uiType !== BlockUIType.NOTE && (
<div className="mt-2.5 flex items-center pb-4 pl-4">
<Switch checked={isOutputOpen} onCheckedChange={toggleOutput} />
<span className="m-1 mr-4">Output</span>
{hasAdvancedFields && data.uiType === BlockUIType.STANDARD && (
<>
<Switch onCheckedChange={toggleAdvancedSettings} />
<span className="m-1">Advanced</span>
</>
)}
{data.status && (
<Badge
variant="outline"
data-id={`badge-${id}-${data.status}`}
className={cn(data.status.toLowerCase(), "ml-auto mr-5")}
>
{data.status}
</Badge>
)}
</div>
)}
<InputModalComponent
title={activeKey ? `Enter ${beautifyString(activeKey)}` : undefined}
isOpen={isModalOpen}

View File

@@ -417,6 +417,7 @@ const FlowEditor: React.FC<{
isOutputOpen: false,
block_id: blockId,
isOutputStatic: nodeSchema.staticOutput,
uiType: nodeSchema.uiType,
},
};

View File

@@ -13,6 +13,7 @@ export type Block = {
inputSchema: BlockIORootSchema;
outputSchema: BlockIORootSchema;
staticOutput: boolean;
uiType: BlockUIType;
};
export type BlockIORootSchema = {
@@ -182,3 +183,10 @@ export type User = {
id: string;
email: string;
};
export enum BlockUIType {
STANDARD = "Standard",
INPUT = "Input",
OUTPUT = "Output",
NOTE = "Note",
}

View File

@@ -1,5 +1,8 @@
from typing import Any, List
from abc import ABC, abstractmethod
import re
from typing import Any, Generic, List, TypeVar
from jinja2 import BaseLoader, Environment
from pydantic import Field
from autogpt_server.data.block import (
@@ -12,6 +15,8 @@ from autogpt_server.data.block import (
from autogpt_server.data.model import SchemaField
from autogpt_server.util.mock import MockObject
jinja = Environment(loader=BaseLoader())
class StoreValueBlock(Block):
"""
@@ -136,7 +141,7 @@ class FindInDictionaryBlock(Block):
yield "missing", input_data.input
class InputBlock(Block):
class AgentInputBlock(Block):
"""
This block is used to provide input to the graph.
@@ -164,8 +169,8 @@ class InputBlock(Block):
super().__init__(
id="c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
description="This block is used to provide input to the graph.",
input_schema=InputBlock.Input,
output_schema=InputBlock.Output,
input_schema=AgentInputBlock.Input,
output_schema=AgentInputBlock.Output,
test_input=[
{
"value": "Hello, World!",
@@ -194,7 +199,7 @@ class InputBlock(Block):
yield "result", input_data.value
class OutputBlock(Block):
class AgentOutputBlock(Block):
"""
Records the output of the graph for users to see.
@@ -215,12 +220,10 @@ class OutputBlock(Block):
"""
class Input(BlockSchema):
recorded_value: Any = SchemaField(
description="The value to be recorded as output."
)
value: Any = SchemaField(description="The value to be recorded as output.")
name: str = SchemaField(description="The name of the output.")
description: str = SchemaField(description="The description of the output.")
fmt_string: str = SchemaField(
format: str = SchemaField(
description="The format string to be used to format the recorded_value."
)
@@ -238,30 +241,30 @@ class OutputBlock(Block):
"This block is key for capturing and presenting final results or "
"important intermediate outputs of the graph execution."
),
input_schema=OutputBlock.Input,
output_schema=OutputBlock.Output,
input_schema=AgentOutputBlock.Input,
output_schema=AgentOutputBlock.Output,
test_input=[
{
"recorded_value": "Hello, World!",
"value": "Hello, World!",
"name": "output_1",
"description": "This is a test output.",
"fmt_string": "{value}",
"format": "{value}!!",
},
{
"recorded_value": 42,
"value": 42,
"name": "output_2",
"description": "This is another test output.",
"fmt_string": "{value}",
"format": "{value}",
},
{
"recorded_value": MockObject(value="!!", key="key"),
"value": MockObject(value="!!", key="key"),
"name": "output_3",
"description": "This is a test output with a mock object.",
"fmt_string": "{value}",
"format": "{value}",
},
],
test_output=[
("output", "Hello, World!"),
("output", "Hello, World!!!"),
("output", 42),
("output", MockObject(value="!!", key="key")),
],
@@ -274,13 +277,15 @@ class OutputBlock(Block):
Attempts to format the recorded_value using the fmt_string if provided.
If formatting fails or no fmt_string is given, returns the original recorded_value.
"""
if input_data.fmt_string:
if input_data.format:
try:
yield "output", input_data.fmt_string.format(input_data.recorded_value)
fmt = re.sub(r"(?<!{){[ a-zA-Z0-9_]+}", r"{\g<0>}", input_data.format)
template = jinja.from_string(fmt)
yield "output", template.render(input_data.value)
except Exception:
yield "output", input_data.recorded_value
yield "output", f"Error: {input_data.value}"
else:
yield "output", input_data.recorded_value
yield "output", input_data.value
class AddToDictionaryBlock(Block):

View File

@@ -9,7 +9,7 @@ from prisma.models import AgentGraph, AgentNode, AgentNodeLink
from pydantic import BaseModel, PrivateAttr
from pydantic_core import PydanticUndefinedType
from autogpt_server.blocks.basic import InputBlock, OutputBlock
from autogpt_server.blocks.basic import AgentInputBlock, AgentOutputBlock
from autogpt_server.data.block import BlockInput, get_block, get_blocks
from autogpt_server.data.db import BaseDbModel, transaction
from autogpt_server.data.user import DEFAULT_USER_ID
@@ -106,7 +106,9 @@ class Graph(GraphMeta):
def starting_nodes(self) -> list[Node]:
outbound_nodes = {link.sink_id for link in self.links}
input_nodes = {
v.id for v in self.nodes if isinstance(get_block(v.block_id), InputBlock)
v.id
for v in self.nodes
if isinstance(get_block(v.block_id), AgentInputBlock)
}
return [
node
@@ -116,7 +118,9 @@ class Graph(GraphMeta):
@property
def ending_nodes(self) -> list[Node]:
return [v for v in self.nodes if isinstance(get_block(v.block_id), OutputBlock)]
return [
v for v in self.nodes if isinstance(get_block(v.block_id), AgentOutputBlock)
]
@property
def subgraph_map(self) -> dict[str, str]:
@@ -179,7 +183,9 @@ class Graph(GraphMeta):
+ [sanitize(link.sink_name) for link in node.input_links]
)
for name in block.input_schema.get_required_fields():
if name not in provided_inputs and not isinstance(block, InputBlock):
if name not in provided_inputs and not isinstance(
block, AgentInputBlock
):
raise ValueError(
f"Node {block.name} #{node.id} required input missing: `{name}`"
)
@@ -193,7 +199,7 @@ class Graph(GraphMeta):
def is_input_output_block(nid: str) -> bool:
bid = node_map[nid].block_id
b = get_block(bid)
return isinstance(b, InputBlock) or isinstance(b, OutputBlock)
return isinstance(b, AgentInputBlock) or isinstance(b, AgentOutputBlock)
# subgraphs: all nodes in subgraph must be present in the graph.
for subgraph_id, node_ids in self.subgraphs.items():

View File

@@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Any, Coroutine, Generator, TypeVar
if TYPE_CHECKING:
from autogpt_server.server.rest_api import AgentServer
from autogpt_server.blocks.basic import InputBlock
from autogpt_server.blocks.basic import AgentInputBlock
from autogpt_server.data import db
from autogpt_server.data.block import Block, BlockData, BlockInput, get_block
from autogpt_server.data.execution import (
@@ -698,7 +698,7 @@ class ExecutionManager(AppService):
nodes_input = []
for node in graph.starting_nodes:
input_data = {}
if isinstance(get_block(node.block_id), InputBlock):
if isinstance(get_block(node.block_id), AgentInputBlock):
name = node.input_default.get("name")
if name and name in data:
input_data = {"value": data[name]}

View File

@@ -1,6 +1,6 @@
from prisma.models import User
from autogpt_server.blocks.basic import InputBlock, PrintToConsoleBlock
from autogpt_server.blocks.basic import AgentInputBlock, PrintToConsoleBlock
from autogpt_server.blocks.text import FillTextTemplateBlock
from autogpt_server.data import graph
from autogpt_server.data.graph import create_graph
@@ -28,22 +28,12 @@ def create_test_graph() -> graph.Graph:
"""
nodes = [
graph.Node(
block_id=InputBlock().id,
input_default={
"name": "input_1",
"description": "First input value",
"placeholder_values": [],
"limit_to_placeholder_values": False,
},
block_id=AgentInputBlock().id,
input_default={"name": "input_1"},
),
graph.Node(
block_id=InputBlock().id,
input_default={
"name": "input_2",
"description": "Second input value",
"placeholder_values": [],
"limit_to_placeholder_values": False,
},
block_id=AgentInputBlock().id,
input_default={"name": "input_2"},
),
graph.Node(
block_id=FillTextTemplateBlock().id,

View File

@@ -2,7 +2,7 @@ from uuid import UUID
import pytest
from autogpt_server.blocks.basic import InputBlock, StoreValueBlock
from autogpt_server.blocks.basic import AgentInputBlock, StoreValueBlock
from autogpt_server.data.graph import Graph, Link, Node
from autogpt_server.data.user import DEFAULT_USER_ID, create_default_user
from autogpt_server.server.model import CreateGraph
@@ -25,7 +25,7 @@ async def test_graph_creation(server: SpinTestServer):
await create_default_user("false")
value_block = StoreValueBlock().id
input_block = InputBlock().id
input_block = AgentInputBlock().id
graph = Graph(
id="test_graph",