Compare commits

...

1 Commits

Author SHA1 Message Date
SwiftyOS
fe05cae794 added custom node rendering 2024-09-09 11:42:50 +02:00
8 changed files with 222 additions and 106 deletions

View File

@@ -14,6 +14,7 @@ import {
BlockIORootSchema, BlockIORootSchema,
Category, Category,
NodeExecutionResult, NodeExecutionResult,
BlockUIType,
} from "@/lib/autogpt-server-api/types"; } from "@/lib/autogpt-server-api/types";
import { beautifyString, cn, setNestedProperty } from "@/lib/utils"; import { beautifyString, cn, setNestedProperty } from "@/lib/utils";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
@@ -59,6 +60,7 @@ export type CustomNodeData = {
backend_id?: string; backend_id?: string;
errors?: { [key: string]: string }; errors?: { [key: string]: string };
isOutputStatic?: boolean; isOutputStatic?: boolean;
uiType: BlockUIType;
}; };
export type CustomNode = Node<CustomNodeData, "custom">; export type CustomNode = Node<CustomNodeData, "custom">;
@@ -118,8 +120,16 @@ export function CustomNode({ data, id, width, height }: NodeProps<CustomNode>) {
setIsAdvancedOpen(checked); setIsAdvancedOpen(checked);
}; };
const generateOutputHandles = (schema: BlockIORootSchema) => { const generateOutputHandles = (
if (!schema?.properties) return null; schema: BlockIORootSchema,
nodeType: BlockUIType,
) => {
if (
!schema?.properties ||
nodeType === BlockUIType.OUTPUT ||
nodeType === BlockUIType.NOTE
)
return null;
const keys = Object.keys(schema.properties); const keys = Object.keys(schema.properties);
return keys.map((key) => ( return keys.map((key) => (
<div key={key}> <div key={key}>
@@ -133,6 +143,136 @@ export function CustomNode({ data, id, width, height }: NodeProps<CustomNode>) {
)); ));
}; };
const generateInputHandles = (
schema: BlockIORootSchema,
nodeType: BlockUIType,
) => {
if (!schema?.properties) return null;
let keys = Object.entries(schema.properties);
switch (nodeType) {
case BlockUIType.INPUT:
// For INPUT blocks, only show the 'value' property
return keys.map(([propKey, propSchema]) => {
const isRequired = data.inputSchema.required?.includes(propKey);
const isConnected = isHandleConnected(propKey);
const isAdvanced = propSchema.advanced;
return (
<div key={propKey}>
<span className="text-m green -mb-1 text-gray-900">
{propSchema.title || beautifyString(propKey)}
</span>
<div key={propKey} onMouseOver={() => {}}>
{!isConnected && (
<NodeGenericInputField
className="mb-2 mt-1"
propKey={propKey}
propSchema={propSchema}
currentValue={getValue(propKey)}
connections={data.connections}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
errors={data.errors ?? {}}
displayName={propSchema.title || beautifyString(propKey)}
/>
)}
</div>
</div>
);
});
case BlockUIType.NOTE:
// For NOTE blocks, don't render any input handles
return keys.map(([propKey, propSchema]) => {
const isConnected = isHandleConnected(propKey);
return (
<div key={propKey}>
<NodeGenericInputField
className="mb-2 mt-1"
propKey={propKey}
propSchema={propSchema}
currentValue={getValue(propKey)}
connections={data.connections}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
errors={data.errors ?? {}}
displayName={propSchema.title || beautifyString(propKey)}
/>
</div>
);
});
case BlockUIType.OUTPUT:
// For OUTPUT blocks, only show the 'recorded_value' property
return keys.map(([propKey, propSchema]) => {
const isRequired = data.inputSchema.required?.includes(propKey);
const isConnected = isHandleConnected(propKey);
const isAdvanced = propSchema.advanced;
return (
<div key={propKey} onMouseOver={() => {}}>
{propKey !== "value" ? (
<span className="text-m green -mb-1 text-gray-900">
{propSchema.title || beautifyString(propKey)}
</span>
) : (
<NodeHandle
keyName={propKey}
isConnected={isConnected}
isRequired={isRequired}
schema={propSchema}
side="left"
/>
)}
{!isConnected && (
<NodeGenericInputField
className="mb-2 mt-1"
propKey={propKey}
propSchema={propSchema}
currentValue={getValue(propKey)}
connections={data.connections}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
errors={data.errors ?? {}}
displayName={propSchema.title || beautifyString(propKey)}
/>
)}
</div>
);
});
default:
return keys.map(([propKey, propSchema]) => {
const isRequired = data.inputSchema.required?.includes(propKey);
const isConnected = isHandleConnected(propKey);
const isAdvanced = propSchema.advanced;
return (
(isRequired || isAdvancedOpen || isConnected || !isAdvanced) && (
<div key={propKey} onMouseOver={() => {}}>
<NodeHandle
keyName={propKey}
isConnected={isConnected}
isRequired={isRequired}
schema={propSchema}
side="left"
/>
{!isConnected && (
<NodeGenericInputField
className="mb-2 mt-1"
propKey={propKey}
propSchema={propSchema}
currentValue={getValue(propKey)}
connections={data.connections}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
errors={data.errors ?? {}}
displayName={propSchema.title || beautifyString(propKey)}
/>
)}
</div>
)
);
});
}
};
const handleInputChange = (path: string, value: any) => { const handleInputChange = (path: string, value: any) => {
const keys = parseKeys(path); const keys = parseKeys(path);
const newValues = JSON.parse(JSON.stringify(data.hardcodedValues)); const newValues = JSON.parse(JSON.stringify(data.hardcodedValues));
@@ -420,47 +560,11 @@ export function CustomNode({ data, id, width, height }: NodeProps<CustomNode>) {
<div className="flex items-start justify-between gap-2 p-3"> <div className="flex items-start justify-between gap-2 p-3">
<div> <div>
{data.inputSchema && {data.inputSchema &&
Object.entries(data.inputSchema.properties).map( generateInputHandles(data.inputSchema, data.uiType)}
([propKey, propSchema]) => {
const isRequired = data.inputSchema.required?.includes(propKey);
const isConnected = isHandleConnected(propKey);
const isAdvanced = propSchema.advanced;
return (
(isRequired ||
isAdvancedOpen ||
isConnected ||
!isAdvanced) && (
<div key={propKey} onMouseOver={() => {}}>
<NodeHandle
keyName={propKey}
isConnected={isConnected}
isRequired={isRequired}
schema={propSchema}
side="left"
/>
{!isConnected && (
<NodeGenericInputField
className="mb-2 mt-1"
propKey={propKey}
propSchema={propSchema}
currentValue={getValue(propKey)}
connections={data.connections}
handleInputChange={handleInputChange}
handleInputClick={handleInputClick}
errors={data.errors ?? {}}
displayName={
propSchema.title || beautifyString(propKey)
}
/>
)}
</div>
)
);
},
)}
</div> </div>
<div className="flex-none"> <div className="flex-none">
{data.outputSchema && generateOutputHandles(data.outputSchema)} {data.outputSchema &&
generateOutputHandles(data.outputSchema, data.uiType)}
</div> </div>
</div> </div>
{isOutputOpen && ( {isOutputOpen && (
@@ -486,10 +590,11 @@ export function CustomNode({ data, id, width, height }: NodeProps<CustomNode>) {
)} )}
</div> </div>
)} )}
{data.uiType !== BlockUIType.NOTE && (
<div className="mt-2.5 flex items-center pb-4 pl-4"> <div className="mt-2.5 flex items-center pb-4 pl-4">
<Switch checked={isOutputOpen} onCheckedChange={toggleOutput} /> <Switch checked={isOutputOpen} onCheckedChange={toggleOutput} />
<span className="m-1 mr-4">Output</span> <span className="m-1 mr-4">Output</span>
{hasAdvancedFields && ( {hasAdvancedFields && data.uiType === BlockUIType.STANDARD && (
<> <>
<Switch onCheckedChange={toggleAdvancedSettings} /> <Switch onCheckedChange={toggleAdvancedSettings} />
<span className="m-1">Advanced</span> <span className="m-1">Advanced</span>
@@ -505,6 +610,7 @@ export function CustomNode({ data, id, width, height }: NodeProps<CustomNode>) {
</Badge> </Badge>
)} )}
</div> </div>
)}
<InputModalComponent <InputModalComponent
title={activeKey ? `Enter ${beautifyString(activeKey)}` : undefined} title={activeKey ? `Enter ${beautifyString(activeKey)}` : undefined}
isOpen={isModalOpen} isOpen={isModalOpen}

View File

@@ -417,6 +417,7 @@ const FlowEditor: React.FC<{
isOutputOpen: false, isOutputOpen: false,
block_id: blockId, block_id: blockId,
isOutputStatic: nodeSchema.staticOutput, isOutputStatic: nodeSchema.staticOutput,
uiType: nodeSchema.uiType,
}, },
}; };

View File

@@ -13,6 +13,7 @@ export type Block = {
inputSchema: BlockIORootSchema; inputSchema: BlockIORootSchema;
outputSchema: BlockIORootSchema; outputSchema: BlockIORootSchema;
staticOutput: boolean; staticOutput: boolean;
uiType: BlockUIType;
}; };
export type BlockIORootSchema = { export type BlockIORootSchema = {
@@ -182,3 +183,10 @@ export type User = {
id: string; id: string;
email: string; email: string;
}; };
export enum BlockUIType {
STANDARD = "Standard",
INPUT = "Input",
OUTPUT = "Output",
NOTE = "Note",
}

View File

@@ -1,5 +1,8 @@
from typing import Any, List from abc import ABC, abstractmethod
import re
from typing import Any, Generic, List, TypeVar
from jinja2 import BaseLoader, Environment
from pydantic import Field from pydantic import Field
from autogpt_server.data.block import ( from autogpt_server.data.block import (
@@ -12,6 +15,8 @@ from autogpt_server.data.block import (
from autogpt_server.data.model import SchemaField from autogpt_server.data.model import SchemaField
from autogpt_server.util.mock import MockObject from autogpt_server.util.mock import MockObject
jinja = Environment(loader=BaseLoader())
class StoreValueBlock(Block): class StoreValueBlock(Block):
""" """
@@ -136,7 +141,7 @@ class FindInDictionaryBlock(Block):
yield "missing", input_data.input yield "missing", input_data.input
class InputBlock(Block): class AgentInputBlock(Block):
""" """
This block is used to provide input to the graph. This block is used to provide input to the graph.
@@ -164,8 +169,8 @@ class InputBlock(Block):
super().__init__( super().__init__(
id="c0a8e994-ebf1-4a9c-a4d8-89d09c86741b", id="c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
description="This block is used to provide input to the graph.", description="This block is used to provide input to the graph.",
input_schema=InputBlock.Input, input_schema=AgentInputBlock.Input,
output_schema=InputBlock.Output, output_schema=AgentInputBlock.Output,
test_input=[ test_input=[
{ {
"value": "Hello, World!", "value": "Hello, World!",
@@ -194,7 +199,7 @@ class InputBlock(Block):
yield "result", input_data.value yield "result", input_data.value
class OutputBlock(Block): class AgentOutputBlock(Block):
""" """
Records the output of the graph for users to see. Records the output of the graph for users to see.
@@ -215,12 +220,10 @@ class OutputBlock(Block):
""" """
class Input(BlockSchema): class Input(BlockSchema):
recorded_value: Any = SchemaField( value: Any = SchemaField(description="The value to be recorded as output.")
description="The value to be recorded as output."
)
name: str = SchemaField(description="The name of the output.") name: str = SchemaField(description="The name of the output.")
description: str = SchemaField(description="The description of the output.") description: str = SchemaField(description="The description of the output.")
fmt_string: str = SchemaField( format: str = SchemaField(
description="The format string to be used to format the recorded_value." description="The format string to be used to format the recorded_value."
) )
@@ -238,30 +241,30 @@ class OutputBlock(Block):
"This block is key for capturing and presenting final results or " "This block is key for capturing and presenting final results or "
"important intermediate outputs of the graph execution." "important intermediate outputs of the graph execution."
), ),
input_schema=OutputBlock.Input, input_schema=AgentOutputBlock.Input,
output_schema=OutputBlock.Output, output_schema=AgentOutputBlock.Output,
test_input=[ test_input=[
{ {
"recorded_value": "Hello, World!", "value": "Hello, World!",
"name": "output_1", "name": "output_1",
"description": "This is a test output.", "description": "This is a test output.",
"fmt_string": "{value}", "format": "{value}!!",
}, },
{ {
"recorded_value": 42, "value": 42,
"name": "output_2", "name": "output_2",
"description": "This is another test output.", "description": "This is another test output.",
"fmt_string": "{value}", "format": "{value}",
}, },
{ {
"recorded_value": MockObject(value="!!", key="key"), "value": MockObject(value="!!", key="key"),
"name": "output_3", "name": "output_3",
"description": "This is a test output with a mock object.", "description": "This is a test output with a mock object.",
"fmt_string": "{value}", "format": "{value}",
}, },
], ],
test_output=[ test_output=[
("output", "Hello, World!"), ("output", "Hello, World!!!"),
("output", 42), ("output", 42),
("output", MockObject(value="!!", key="key")), ("output", MockObject(value="!!", key="key")),
], ],
@@ -274,13 +277,15 @@ class OutputBlock(Block):
Attempts to format the recorded_value using the fmt_string if provided. Attempts to format the recorded_value using the fmt_string if provided.
If formatting fails or no fmt_string is given, returns the original recorded_value. If formatting fails or no fmt_string is given, returns the original recorded_value.
""" """
if input_data.fmt_string: if input_data.format:
try: try:
yield "output", input_data.fmt_string.format(input_data.recorded_value) fmt = re.sub(r"(?<!{){[ a-zA-Z0-9_]+}", r"{\g<0>}", input_data.format)
template = jinja.from_string(fmt)
yield "output", template.render(input_data.value)
except Exception: except Exception:
yield "output", input_data.recorded_value yield "output", f"Error: {input_data.value}"
else: else:
yield "output", input_data.recorded_value yield "output", input_data.value
class AddToDictionaryBlock(Block): class AddToDictionaryBlock(Block):

View File

@@ -9,7 +9,7 @@ from prisma.models import AgentGraph, AgentNode, AgentNodeLink
from pydantic import BaseModel, PrivateAttr from pydantic import BaseModel, PrivateAttr
from pydantic_core import PydanticUndefinedType from pydantic_core import PydanticUndefinedType
from autogpt_server.blocks.basic import InputBlock, OutputBlock from autogpt_server.blocks.basic import AgentInputBlock, AgentOutputBlock
from autogpt_server.data.block import BlockInput, get_block, get_blocks from autogpt_server.data.block import BlockInput, get_block, get_blocks
from autogpt_server.data.db import BaseDbModel, transaction from autogpt_server.data.db import BaseDbModel, transaction
from autogpt_server.data.user import DEFAULT_USER_ID from autogpt_server.data.user import DEFAULT_USER_ID
@@ -106,7 +106,9 @@ class Graph(GraphMeta):
def starting_nodes(self) -> list[Node]: def starting_nodes(self) -> list[Node]:
outbound_nodes = {link.sink_id for link in self.links} outbound_nodes = {link.sink_id for link in self.links}
input_nodes = { input_nodes = {
v.id for v in self.nodes if isinstance(get_block(v.block_id), InputBlock) v.id
for v in self.nodes
if isinstance(get_block(v.block_id), AgentInputBlock)
} }
return [ return [
node node
@@ -116,7 +118,9 @@ class Graph(GraphMeta):
@property @property
def ending_nodes(self) -> list[Node]: def ending_nodes(self) -> list[Node]:
return [v for v in self.nodes if isinstance(get_block(v.block_id), OutputBlock)] return [
v for v in self.nodes if isinstance(get_block(v.block_id), AgentOutputBlock)
]
@property @property
def subgraph_map(self) -> dict[str, str]: def subgraph_map(self) -> dict[str, str]:
@@ -179,7 +183,9 @@ class Graph(GraphMeta):
+ [sanitize(link.sink_name) for link in node.input_links] + [sanitize(link.sink_name) for link in node.input_links]
) )
for name in block.input_schema.get_required_fields(): for name in block.input_schema.get_required_fields():
if name not in provided_inputs and not isinstance(block, InputBlock): if name not in provided_inputs and not isinstance(
block, AgentInputBlock
):
raise ValueError( raise ValueError(
f"Node {block.name} #{node.id} required input missing: `{name}`" f"Node {block.name} #{node.id} required input missing: `{name}`"
) )
@@ -193,7 +199,7 @@ class Graph(GraphMeta):
def is_input_output_block(nid: str) -> bool: def is_input_output_block(nid: str) -> bool:
bid = node_map[nid].block_id bid = node_map[nid].block_id
b = get_block(bid) b = get_block(bid)
return isinstance(b, InputBlock) or isinstance(b, OutputBlock) return isinstance(b, AgentInputBlock) or isinstance(b, AgentOutputBlock)
# subgraphs: all nodes in subgraph must be present in the graph. # subgraphs: all nodes in subgraph must be present in the graph.
for subgraph_id, node_ids in self.subgraphs.items(): for subgraph_id, node_ids in self.subgraphs.items():

View File

@@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Any, Coroutine, Generator, TypeVar
if TYPE_CHECKING: if TYPE_CHECKING:
from autogpt_server.server.rest_api import AgentServer from autogpt_server.server.rest_api import AgentServer
from autogpt_server.blocks.basic import InputBlock from autogpt_server.blocks.basic import AgentInputBlock
from autogpt_server.data import db from autogpt_server.data import db
from autogpt_server.data.block import Block, BlockData, BlockInput, get_block from autogpt_server.data.block import Block, BlockData, BlockInput, get_block
from autogpt_server.data.execution import ( from autogpt_server.data.execution import (
@@ -698,7 +698,7 @@ class ExecutionManager(AppService):
nodes_input = [] nodes_input = []
for node in graph.starting_nodes: for node in graph.starting_nodes:
input_data = {} input_data = {}
if isinstance(get_block(node.block_id), InputBlock): if isinstance(get_block(node.block_id), AgentInputBlock):
name = node.input_default.get("name") name = node.input_default.get("name")
if name and name in data: if name and name in data:
input_data = {"value": data[name]} input_data = {"value": data[name]}

View File

@@ -1,6 +1,6 @@
from prisma.models import User from prisma.models import User
from autogpt_server.blocks.basic import InputBlock, PrintToConsoleBlock from autogpt_server.blocks.basic import AgentInputBlock, PrintToConsoleBlock
from autogpt_server.blocks.text import FillTextTemplateBlock from autogpt_server.blocks.text import FillTextTemplateBlock
from autogpt_server.data import graph from autogpt_server.data import graph
from autogpt_server.data.graph import create_graph from autogpt_server.data.graph import create_graph
@@ -28,22 +28,12 @@ def create_test_graph() -> graph.Graph:
""" """
nodes = [ nodes = [
graph.Node( graph.Node(
block_id=InputBlock().id, block_id=AgentInputBlock().id,
input_default={ input_default={"name": "input_1"},
"name": "input_1",
"description": "First input value",
"placeholder_values": [],
"limit_to_placeholder_values": False,
},
), ),
graph.Node( graph.Node(
block_id=InputBlock().id, block_id=AgentInputBlock().id,
input_default={ input_default={"name": "input_2"},
"name": "input_2",
"description": "Second input value",
"placeholder_values": [],
"limit_to_placeholder_values": False,
},
), ),
graph.Node( graph.Node(
block_id=FillTextTemplateBlock().id, block_id=FillTextTemplateBlock().id,

View File

@@ -2,7 +2,7 @@ from uuid import UUID
import pytest import pytest
from autogpt_server.blocks.basic import InputBlock, StoreValueBlock from autogpt_server.blocks.basic import AgentInputBlock, StoreValueBlock
from autogpt_server.data.graph import Graph, Link, Node from autogpt_server.data.graph import Graph, Link, Node
from autogpt_server.data.user import DEFAULT_USER_ID, create_default_user from autogpt_server.data.user import DEFAULT_USER_ID, create_default_user
from autogpt_server.server.model import CreateGraph from autogpt_server.server.model import CreateGraph
@@ -25,7 +25,7 @@ async def test_graph_creation(server: SpinTestServer):
await create_default_user("false") await create_default_user("false")
value_block = StoreValueBlock().id value_block = StoreValueBlock().id
input_block = InputBlock().id input_block = AgentInputBlock().id
graph = Graph( graph = Graph(
id="test_graph", id="test_graph",