mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-09 07:08:09 -05:00
refactor(backend): Introduced Graph Input & Output Schema, Merge GraphMeta & Graph, Remove subgraph functionality (#8526)
This commit is contained in:
@@ -148,9 +148,12 @@ class AgentInputBlock(Block):
|
||||
description="The value to be passed as input.",
|
||||
default=None,
|
||||
)
|
||||
description: str = SchemaField(
|
||||
title: str | None = SchemaField(
|
||||
description="The title of the input.", default=None, advanced=True
|
||||
)
|
||||
description: str | None = SchemaField(
|
||||
description="The description of the input.",
|
||||
default="",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
placeholder_values: List[Any] = SchemaField(
|
||||
@@ -163,6 +166,16 @@ class AgentInputBlock(Block):
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
advanced: bool = SchemaField(
|
||||
description="Whether to show the input in the advanced section, if the field is not required.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
secret: bool = SchemaField(
|
||||
description="Whether the input should be treated as a secret.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: Any = SchemaField(description="The value passed as input.")
|
||||
@@ -195,6 +208,7 @@ class AgentInputBlock(Block):
|
||||
],
|
||||
categories={BlockCategory.INPUT, BlockCategory.BASIC},
|
||||
block_type=BlockType.INPUT,
|
||||
static_output=True,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
@@ -205,28 +219,25 @@ class AgentOutputBlock(Block):
|
||||
"""
|
||||
Records the output of the graph for users to see.
|
||||
|
||||
Attributes:
|
||||
recorded_value: The value to be recorded as output.
|
||||
name: The name of the output.
|
||||
description: The description of the output.
|
||||
fmt_string: The format string to be used to format the recorded_value.
|
||||
|
||||
Outputs:
|
||||
output: The formatted recorded_value if fmt_string is provided and the recorded_value
|
||||
can be formatted, otherwise the raw recorded_value.
|
||||
|
||||
Behavior:
|
||||
If fmt_string is provided and the recorded_value is of a type that can be formatted,
|
||||
the block attempts to format the recorded_value using the fmt_string.
|
||||
If formatting fails or no fmt_string is provided, the raw recorded_value is output.
|
||||
If `format` is provided and the `value` is of a type that can be formatted,
|
||||
the block attempts to format the recorded_value using the `format`.
|
||||
If formatting fails or no `format` is provided, the raw `value` is output.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
value: Any = SchemaField(description="The value to be recorded as output.")
|
||||
value: Any = SchemaField(
|
||||
description="The value to be recorded as output.",
|
||||
default=None,
|
||||
advanced=False,
|
||||
)
|
||||
name: str = SchemaField(description="The name of the output.")
|
||||
description: str = SchemaField(
|
||||
title: str | None = SchemaField(
|
||||
description="The title of the input.", default=None, advanced=True
|
||||
)
|
||||
description: str | None = SchemaField(
|
||||
description="The description of the output.",
|
||||
default="",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
format: str = SchemaField(
|
||||
@@ -234,6 +245,16 @@ class AgentOutputBlock(Block):
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
advanced: bool = SchemaField(
|
||||
description="Whether to treat the output as advanced.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
secret: bool = SchemaField(
|
||||
description="Whether the output should be treated as a secret.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
output: Any = SchemaField(description="The value recorded as output.")
|
||||
@@ -271,6 +292,7 @@ class AgentOutputBlock(Block):
|
||||
],
|
||||
categories={BlockCategory.OUTPUT, BlockCategory.BASIC},
|
||||
block_type=BlockType.OUTPUT,
|
||||
static_output=True,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
|
||||
@@ -71,11 +71,18 @@ class ConditionBlock(Block):
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
value1 = input_data.value1
|
||||
operator = input_data.operator
|
||||
|
||||
value1 = input_data.value1
|
||||
if isinstance(value1, str):
|
||||
value1 = float(value1.strip())
|
||||
|
||||
value2 = input_data.value2
|
||||
if isinstance(value2, str):
|
||||
value2 = float(value2.strip())
|
||||
|
||||
yes_value = input_data.yes_value if input_data.yes_value is not None else value1
|
||||
no_value = input_data.no_value if input_data.no_value is not None else value1
|
||||
no_value = input_data.no_value if input_data.no_value is not None else value2
|
||||
|
||||
comparison_funcs = {
|
||||
ComparisonOperator.EQUAL: lambda a, b: a == b,
|
||||
@@ -86,17 +93,11 @@ class ConditionBlock(Block):
|
||||
ComparisonOperator.LESS_THAN_OR_EQUAL: lambda a, b: a <= b,
|
||||
}
|
||||
|
||||
try:
|
||||
result = comparison_funcs[operator](value1, value2)
|
||||
result = comparison_funcs[operator](value1, value2)
|
||||
|
||||
yield "result", result
|
||||
yield "result", result
|
||||
|
||||
if result:
|
||||
yield "yes_output", yes_value
|
||||
else:
|
||||
yield "no_output", no_value
|
||||
|
||||
except Exception:
|
||||
yield "result", None
|
||||
yield "yes_output", None
|
||||
yield "no_output", None
|
||||
if result:
|
||||
yield "yes_output", yes_value
|
||||
else:
|
||||
yield "no_output", no_value
|
||||
|
||||
@@ -9,14 +9,11 @@ from prisma.models import (
|
||||
AgentNodeExecution,
|
||||
AgentNodeExecutionInputOutput,
|
||||
)
|
||||
from prisma.types import (
|
||||
AgentGraphExecutionInclude,
|
||||
AgentGraphExecutionWhereInput,
|
||||
AgentNodeExecutionInclude,
|
||||
)
|
||||
from prisma.types import AgentGraphExecutionWhereInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import BlockData, BlockInput, CompletedBlockOutput
|
||||
from backend.data.includes import EXECUTION_RESULT_INCLUDE, GRAPH_EXECUTION_INCLUDE
|
||||
from backend.util import json, mock
|
||||
|
||||
|
||||
@@ -110,24 +107,6 @@ class ExecutionResult(BaseModel):
|
||||
|
||||
# --------------------- Model functions --------------------- #
|
||||
|
||||
EXECUTION_RESULT_INCLUDE: AgentNodeExecutionInclude = {
|
||||
"Input": True,
|
||||
"Output": True,
|
||||
"AgentNode": True,
|
||||
"AgentGraphExecution": True,
|
||||
}
|
||||
|
||||
GRAPH_EXECUTION_INCLUDE: AgentGraphExecutionInclude = {
|
||||
"AgentNodeExecutions": {
|
||||
"include": {
|
||||
"Input": True,
|
||||
"Output": True,
|
||||
"AgentNode": True,
|
||||
"AgentGraphExecution": True,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
async def create_graph_execution(
|
||||
graph_id: str,
|
||||
@@ -268,21 +247,9 @@ async def update_graph_execution_start_time(graph_exec_id: str):
|
||||
|
||||
async def update_graph_execution_stats(
|
||||
graph_exec_id: str,
|
||||
error: Exception | None,
|
||||
wall_time: float,
|
||||
cpu_time: float,
|
||||
node_count: int,
|
||||
stats: dict[str, Any],
|
||||
):
|
||||
status = ExecutionStatus.FAILED if error else ExecutionStatus.COMPLETED
|
||||
stats = (
|
||||
{
|
||||
"walltime": wall_time,
|
||||
"cputime": cpu_time,
|
||||
"nodecount": node_count,
|
||||
"error": str(error) if error else None,
|
||||
},
|
||||
)
|
||||
|
||||
status = ExecutionStatus.FAILED if stats.get("error") else ExecutionStatus.COMPLETED
|
||||
await AgentGraphExecution.prisma().update(
|
||||
where={"id": graph_exec_id},
|
||||
data={
|
||||
|
||||
@@ -3,29 +3,22 @@ import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, Type
|
||||
|
||||
import prisma.types
|
||||
from prisma.models import AgentGraph, AgentGraphExecution, AgentNode, AgentNodeLink
|
||||
from prisma.types import AgentGraphInclude
|
||||
from pydantic import BaseModel
|
||||
from pydantic_core import PydanticUndefinedType
|
||||
from prisma.types import AgentGraphWhereInput
|
||||
from pydantic.fields import computed_field
|
||||
|
||||
from backend.blocks.basic import AgentInputBlock, AgentOutputBlock, BlockType
|
||||
from backend.data.block import BlockInput, get_block, get_blocks
|
||||
from backend.blocks.basic import AgentInputBlock, AgentOutputBlock
|
||||
from backend.data.block import BlockInput, BlockType, get_block, get_blocks
|
||||
from backend.data.db import BaseDbModel, transaction
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
|
||||
from backend.util import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InputSchemaItem(BaseModel):
|
||||
node_id: str
|
||||
description: str | None = None
|
||||
title: str | None = None
|
||||
|
||||
|
||||
class Link(BaseDbModel):
|
||||
source_id: str
|
||||
sink_id: str
|
||||
@@ -70,7 +63,7 @@ class Node(BaseDbModel):
|
||||
return obj
|
||||
|
||||
|
||||
class ExecutionMeta(BaseDbModel):
|
||||
class GraphExecution(BaseDbModel):
|
||||
execution_id: str
|
||||
started_at: datetime
|
||||
ended_at: datetime
|
||||
@@ -79,20 +72,19 @@ class ExecutionMeta(BaseDbModel):
|
||||
status: ExecutionStatus
|
||||
|
||||
@staticmethod
|
||||
def from_agent_graph_execution(execution: AgentGraphExecution):
|
||||
def from_db(execution: AgentGraphExecution):
|
||||
now = datetime.now(timezone.utc)
|
||||
start_time = execution.startedAt or execution.createdAt
|
||||
end_time = execution.updatedAt or now
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
total_run_time = duration
|
||||
|
||||
total_run_time = 0
|
||||
if execution.AgentNodeExecutions:
|
||||
for node_execution in execution.AgentNodeExecutions:
|
||||
node_start = node_execution.startedTime or now
|
||||
node_end = node_execution.endedTime or now
|
||||
total_run_time += (node_end - node_start).total_seconds()
|
||||
if execution.stats:
|
||||
stats = json.loads(execution.stats)
|
||||
duration = stats.get("walltime", duration)
|
||||
total_run_time = stats.get("nodes_walltime", total_run_time)
|
||||
|
||||
return ExecutionMeta(
|
||||
return GraphExecution(
|
||||
id=execution.id,
|
||||
execution_id=execution.id,
|
||||
started_at=start_time,
|
||||
@@ -103,39 +95,70 @@ class ExecutionMeta(BaseDbModel):
|
||||
)
|
||||
|
||||
|
||||
class GraphMeta(BaseDbModel):
|
||||
class Graph(BaseDbModel):
|
||||
version: int = 1
|
||||
is_active: bool = True
|
||||
is_template: bool = False
|
||||
name: str
|
||||
description: str
|
||||
executions: list[ExecutionMeta] | None = None
|
||||
executions: list[GraphExecution] = []
|
||||
nodes: list[Node] = []
|
||||
links: list[Link] = []
|
||||
|
||||
@staticmethod
|
||||
def from_db(graph: AgentGraph):
|
||||
if graph.AgentGraphExecution:
|
||||
executions = [
|
||||
ExecutionMeta.from_agent_graph_execution(execution)
|
||||
for execution in graph.AgentGraphExecution
|
||||
]
|
||||
else:
|
||||
executions = None
|
||||
def _generate_schema(
|
||||
type_class: Type[AgentInputBlock.Input] | Type[AgentOutputBlock.Input],
|
||||
data: list[dict],
|
||||
) -> dict[str, Any]:
|
||||
props = []
|
||||
for p in data:
|
||||
try:
|
||||
props.append(type_class(**p))
|
||||
except Exception as e:
|
||||
logger.warning(f"Invalid {type_class}: {p}, {e}")
|
||||
|
||||
return GraphMeta(
|
||||
id=graph.id,
|
||||
version=graph.version,
|
||||
is_active=graph.isActive,
|
||||
is_template=graph.isTemplate,
|
||||
name=graph.name or "",
|
||||
description=graph.description or "",
|
||||
executions=executions,
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
p.name: {
|
||||
"secret": p.secret,
|
||||
"advanced": p.advanced,
|
||||
"title": p.title or p.name,
|
||||
**({"description": p.description} if p.description else {}),
|
||||
**({"default": p.value} if p.value is not None else {}),
|
||||
}
|
||||
for p in props
|
||||
},
|
||||
"required": [p.name for p in props if p.value is None],
|
||||
}
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
return self._generate_schema(
|
||||
AgentInputBlock.Input,
|
||||
[
|
||||
node.input_default
|
||||
for node in self.nodes
|
||||
if (b := get_block(node.block_id))
|
||||
and b.block_type == BlockType.INPUT
|
||||
and "name" in node.input_default
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class Graph(GraphMeta):
|
||||
nodes: list[Node]
|
||||
links: list[Link]
|
||||
subgraphs: dict[str, list[str]] = {} # subgraph_id -> [node_id]
|
||||
@computed_field
|
||||
@property
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
return self._generate_schema(
|
||||
AgentOutputBlock.Input,
|
||||
[
|
||||
node.input_default
|
||||
for node in self.nodes
|
||||
if (b := get_block(node.block_id))
|
||||
and b.block_type == BlockType.OUTPUT
|
||||
and "name" in node.input_default
|
||||
],
|
||||
)
|
||||
|
||||
@property
|
||||
def starting_nodes(self) -> list[Node]:
|
||||
@@ -143,7 +166,7 @@ class Graph(GraphMeta):
|
||||
input_nodes = {
|
||||
v.id
|
||||
for v in self.nodes
|
||||
if isinstance(get_block(v.block_id), AgentInputBlock)
|
||||
if (b := get_block(v.block_id)) and b.block_type == BlockType.INPUT
|
||||
}
|
||||
return [
|
||||
node
|
||||
@@ -151,28 +174,6 @@ class Graph(GraphMeta):
|
||||
if node.id not in outbound_nodes or node.id in input_nodes
|
||||
]
|
||||
|
||||
@property
|
||||
def ending_nodes(self) -> list[Node]:
|
||||
return [
|
||||
v for v in self.nodes if isinstance(get_block(v.block_id), AgentOutputBlock)
|
||||
]
|
||||
|
||||
@property
|
||||
def subgraph_map(self) -> dict[str, str]:
|
||||
"""
|
||||
Returns a mapping of node_id to subgraph_id.
|
||||
A node in the main graph will be mapped to the graph's id.
|
||||
"""
|
||||
subgraph_map = {
|
||||
node_id: subgraph_id
|
||||
for subgraph_id, node_ids in self.subgraphs.items()
|
||||
for node_id in node_ids
|
||||
}
|
||||
subgraph_map.update(
|
||||
{node.id: self.id for node in self.nodes if node.id not in subgraph_map}
|
||||
)
|
||||
return subgraph_map
|
||||
|
||||
def reassign_ids(self, reassign_graph_id: bool = False):
|
||||
"""
|
||||
Reassigns all IDs in the graph to new UUIDs.
|
||||
@@ -180,11 +181,7 @@ class Graph(GraphMeta):
|
||||
"""
|
||||
self.validate_graph()
|
||||
|
||||
id_map = {
|
||||
**{node.id: str(uuid.uuid4()) for node in self.nodes},
|
||||
**{subgraph_id: str(uuid.uuid4()) for subgraph_id in self.subgraphs},
|
||||
}
|
||||
|
||||
id_map = {node.id: str(uuid.uuid4()) for node in self.nodes}
|
||||
if reassign_graph_id:
|
||||
self.id = str(uuid.uuid4())
|
||||
|
||||
@@ -195,11 +192,6 @@ class Graph(GraphMeta):
|
||||
link.source_id = id_map[link.source_id]
|
||||
link.sink_id = id_map[link.sink_id]
|
||||
|
||||
self.subgraphs = {
|
||||
id_map[subgraph_id]: [id_map[node_id] for node_id in node_ids]
|
||||
for subgraph_id, node_ids in self.subgraphs.items()
|
||||
}
|
||||
|
||||
def validate_graph(self, for_run: bool = False):
|
||||
def sanitize(name):
|
||||
return name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
|
||||
@@ -227,6 +219,7 @@ class Graph(GraphMeta):
|
||||
raise ValueError(
|
||||
f"Node {block.name} #{node.id} required input missing: `{name}`"
|
||||
)
|
||||
|
||||
node_map = {v.id: v for v in self.nodes}
|
||||
|
||||
def is_static_output_block(nid: str) -> bool:
|
||||
@@ -234,18 +227,6 @@ class Graph(GraphMeta):
|
||||
b = get_block(bid)
|
||||
return b.static_output if b else False
|
||||
|
||||
def is_input_output_block(nid: str) -> bool:
|
||||
bid = node_map[nid].block_id
|
||||
b = get_block(bid)
|
||||
return isinstance(b, AgentInputBlock) or isinstance(b, AgentOutputBlock)
|
||||
|
||||
# subgraphs: all nodes in subgraph must be present in the graph.
|
||||
for subgraph_id, node_ids in self.subgraphs.items():
|
||||
for node_id in node_ids:
|
||||
if node_id not in node_map:
|
||||
raise ValueError(f"Subgraph {subgraph_id}'s node {node_id} invalid")
|
||||
subgraph_map = self.subgraph_map
|
||||
|
||||
# Links: links are connected and the connected pin data type are compatible.
|
||||
for link in self.links:
|
||||
source = (link.source_id, link.source_name)
|
||||
@@ -274,66 +255,27 @@ class Graph(GraphMeta):
|
||||
if sanitized_name not in fields:
|
||||
raise ValueError(f"{suffix}, `{name}` invalid, {fields}")
|
||||
|
||||
if (
|
||||
subgraph_map.get(link.source_id) != subgraph_map.get(link.sink_id)
|
||||
and not is_input_output_block(link.source_id)
|
||||
and not is_input_output_block(link.sink_id)
|
||||
):
|
||||
raise ValueError(f"{suffix}, Connecting nodes from different subgraph.")
|
||||
|
||||
if is_static_output_block(link.source_id):
|
||||
link.is_static = True # Each value block output should be static.
|
||||
|
||||
# TODO: Add type compatibility check here.
|
||||
|
||||
def get_input_schema(self) -> list[InputSchemaItem]:
|
||||
"""
|
||||
Walks the graph and returns all the inputs that are either not:
|
||||
- static
|
||||
- provided by parent node
|
||||
"""
|
||||
input_schema = []
|
||||
for node in self.nodes:
|
||||
block = get_block(node.block_id)
|
||||
if not block:
|
||||
continue
|
||||
|
||||
for input_name, input_schema_item in (
|
||||
block.input_schema.jsonschema().get("properties", {}).items()
|
||||
):
|
||||
# Check if the input is not static and not provided by a parent node
|
||||
if (
|
||||
input_name not in node.input_default
|
||||
and not any(
|
||||
link.sink_name == input_name for link in node.input_links
|
||||
)
|
||||
and isinstance(
|
||||
block.input_schema.model_fields.get(input_name).default,
|
||||
PydanticUndefinedType,
|
||||
)
|
||||
):
|
||||
input_schema.append(
|
||||
InputSchemaItem(
|
||||
node_id=node.id,
|
||||
description=input_schema_item.get("description"),
|
||||
title=input_schema_item.get("title"),
|
||||
)
|
||||
)
|
||||
|
||||
return input_schema
|
||||
|
||||
@staticmethod
|
||||
def from_db(graph: AgentGraph, hide_credentials: bool = False):
|
||||
nodes = [
|
||||
*(graph.AgentNodes or []),
|
||||
*(
|
||||
node
|
||||
for subgraph in graph.AgentSubGraphs or []
|
||||
for node in subgraph.AgentNodes or []
|
||||
),
|
||||
executions = [
|
||||
GraphExecution.from_db(execution)
|
||||
for execution in graph.AgentGraphExecution or []
|
||||
]
|
||||
nodes = graph.AgentNodes or []
|
||||
|
||||
return Graph(
|
||||
**GraphMeta.from_db(graph).model_dump(),
|
||||
id=graph.id,
|
||||
version=graph.version,
|
||||
is_active=graph.isActive,
|
||||
is_template=graph.isTemplate,
|
||||
name=graph.name or "",
|
||||
description=graph.description or "",
|
||||
executions=executions,
|
||||
nodes=[Graph._process_node(node, hide_credentials) for node in nodes],
|
||||
links=list(
|
||||
{
|
||||
@@ -342,10 +284,6 @@ class Graph(GraphMeta):
|
||||
for link in (node.Input or []) + (node.Output or [])
|
||||
}
|
||||
),
|
||||
subgraphs={
|
||||
subgraph.id: [node.id for node in subgraph.AgentNodes or []]
|
||||
for subgraph in graph.AgentSubGraphs or []
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -374,20 +312,6 @@ class Graph(GraphMeta):
|
||||
return result
|
||||
|
||||
|
||||
AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
|
||||
"Input": True,
|
||||
"Output": True,
|
||||
"AgentBlock": True,
|
||||
}
|
||||
|
||||
__SUBGRAPH_INCLUDE = {"AgentNodes": {"include": AGENT_NODE_INCLUDE}}
|
||||
|
||||
AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = {
|
||||
**__SUBGRAPH_INCLUDE,
|
||||
"AgentSubGraphs": {"include": __SUBGRAPH_INCLUDE}, # type: ignore
|
||||
}
|
||||
|
||||
|
||||
# --------------------- Model functions --------------------- #
|
||||
|
||||
|
||||
@@ -399,11 +323,11 @@ async def get_node(node_id: str) -> Node:
|
||||
return Node.from_db(node)
|
||||
|
||||
|
||||
async def get_graphs_meta(
|
||||
async def get_graphs(
|
||||
user_id: str,
|
||||
include_executions: bool = False,
|
||||
filter_by: Literal["active", "template"] | None = "active",
|
||||
) -> list[GraphMeta]:
|
||||
) -> list[Graph]:
|
||||
"""
|
||||
Retrieves graph metadata objects.
|
||||
Default behaviour is to get all currently active graphs.
|
||||
@@ -414,9 +338,9 @@ async def get_graphs_meta(
|
||||
user_id: The ID of the user that owns the graph.
|
||||
|
||||
Returns:
|
||||
list[GraphMeta]: A list of objects representing the retrieved graph metadata.
|
||||
list[Graph]: A list of objects representing the retrieved graph metadata.
|
||||
"""
|
||||
where_clause: prisma.types.AgentGraphWhereInput = {}
|
||||
where_clause: AgentGraphWhereInput = {}
|
||||
|
||||
if filter_by == "active":
|
||||
where_clause["isActive"] = True
|
||||
@@ -425,23 +349,17 @@ async def get_graphs_meta(
|
||||
|
||||
where_clause["userId"] = user_id
|
||||
|
||||
graph_include = AGENT_GRAPH_INCLUDE
|
||||
graph_include["AgentGraphExecution"] = include_executions
|
||||
|
||||
graphs = await AgentGraph.prisma().find_many(
|
||||
where=where_clause,
|
||||
distinct=["id"],
|
||||
order={"version": "desc"},
|
||||
include=(
|
||||
AgentGraphInclude(
|
||||
AgentGraphExecution={"include": {"AgentNodeExecutions": True}}
|
||||
)
|
||||
if include_executions
|
||||
else None
|
||||
),
|
||||
include=graph_include,
|
||||
)
|
||||
|
||||
if not graphs:
|
||||
return []
|
||||
|
||||
return [GraphMeta.from_db(graph) for graph in graphs]
|
||||
return [Graph.from_db(graph) for graph in graphs]
|
||||
|
||||
|
||||
async def get_graph(
|
||||
@@ -458,7 +376,7 @@ async def get_graph(
|
||||
|
||||
Returns `None` if the record is not found.
|
||||
"""
|
||||
where_clause: prisma.types.AgentGraphWhereInput = {
|
||||
where_clause: AgentGraphWhereInput = {
|
||||
"id": graph_id,
|
||||
"isTemplate": template,
|
||||
}
|
||||
@@ -467,7 +385,7 @@ async def get_graph(
|
||||
elif not template:
|
||||
where_clause["isActive"] = True
|
||||
|
||||
if user_id and not template:
|
||||
if user_id is not None and not template:
|
||||
where_clause["userId"] = user_id
|
||||
|
||||
graph = await AgentGraph.prisma().find_first(
|
||||
@@ -553,33 +471,13 @@ async def __create_graph(tx, graph: Graph, user_id: str):
|
||||
}
|
||||
)
|
||||
|
||||
await asyncio.gather(
|
||||
*[
|
||||
AgentGraph.prisma(tx).create(
|
||||
data={
|
||||
"id": subgraph_id,
|
||||
"agentGraphParentId": graph.id,
|
||||
"version": graph.version,
|
||||
"name": f"SubGraph of {graph.name}",
|
||||
"description": f"Sub-Graph of {graph.id}",
|
||||
"isTemplate": graph.is_template,
|
||||
"isActive": graph.is_active,
|
||||
"userId": user_id,
|
||||
}
|
||||
)
|
||||
for subgraph_id in graph.subgraphs
|
||||
]
|
||||
)
|
||||
|
||||
subgraph_map = graph.subgraph_map
|
||||
|
||||
await asyncio.gather(
|
||||
*[
|
||||
AgentNode.prisma(tx).create(
|
||||
{
|
||||
"id": node.id,
|
||||
"agentBlockId": node.block_id,
|
||||
"agentGraphId": subgraph_map.get(node.id, graph.id),
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
"constantInput": json.dumps(node.input_default),
|
||||
"metadata": json.dumps(node.metadata),
|
||||
|
||||
29
autogpt_platform/backend/backend/data/includes.py
Normal file
29
autogpt_platform/backend/backend/data/includes.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import prisma
|
||||
|
||||
AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
|
||||
"Input": True,
|
||||
"Output": True,
|
||||
"AgentBlock": True,
|
||||
}
|
||||
|
||||
AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = {
|
||||
"AgentNodes": {"include": AGENT_NODE_INCLUDE} # type: ignore
|
||||
}
|
||||
|
||||
EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
|
||||
"Input": True,
|
||||
"Output": True,
|
||||
"AgentNode": True,
|
||||
"AgentGraphExecution": True,
|
||||
}
|
||||
|
||||
GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
|
||||
"AgentNodeExecutions": {
|
||||
"include": {
|
||||
"Input": True,
|
||||
"Output": True,
|
||||
"AgentNode": True,
|
||||
"AgentGraphExecution": True,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -473,7 +473,7 @@ class Executor:
|
||||
cls,
|
||||
q: ExecutionQueue[NodeExecution],
|
||||
node_exec: NodeExecution,
|
||||
):
|
||||
) -> dict[str, Any]:
|
||||
log_metadata = LogMetadata(
|
||||
user_id=node_exec.user_id,
|
||||
graph_eid=node_exec.graph_exec_id,
|
||||
@@ -493,6 +493,7 @@ class Executor:
|
||||
cls.db_client.update_node_execution_stats(
|
||||
node_exec.node_exec_id, execution_stats
|
||||
)
|
||||
return execution_stats
|
||||
|
||||
@classmethod
|
||||
@time_measured
|
||||
@@ -556,16 +557,15 @@ class Executor:
|
||||
node_eid="*",
|
||||
block_name="-",
|
||||
)
|
||||
timing_info, (node_count, error) = cls._on_graph_execution(
|
||||
timing_info, (exec_stats, error) = cls._on_graph_execution(
|
||||
graph_exec, cancel, log_metadata
|
||||
)
|
||||
|
||||
exec_stats["walltime"] = timing_info.wall_time
|
||||
exec_stats["cputime"] = timing_info.cpu_time
|
||||
exec_stats["error"] = str(error) if error else None
|
||||
cls.db_client.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
error=error,
|
||||
wall_time=timing_info.wall_time,
|
||||
cpu_time=timing_info.cpu_time,
|
||||
node_count=node_count,
|
||||
stats=exec_stats,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -575,14 +575,18 @@ class Executor:
|
||||
graph_exec: GraphExecution,
|
||||
cancel: threading.Event,
|
||||
log_metadata: LogMetadata,
|
||||
) -> tuple[int, Exception | None]:
|
||||
) -> tuple[dict[str, Any], Exception | None]:
|
||||
"""
|
||||
Returns:
|
||||
The number of node executions completed.
|
||||
The execution statistics of the graph execution.
|
||||
The error that occurred during the execution.
|
||||
"""
|
||||
log_metadata.info(f"Start graph execution {graph_exec.graph_exec_id}")
|
||||
n_node_executions = 0
|
||||
exec_stats = {
|
||||
"nodes_walltime": 0,
|
||||
"nodes_cputime": 0,
|
||||
"node_count": 0,
|
||||
}
|
||||
error = None
|
||||
finished = False
|
||||
|
||||
@@ -608,17 +612,20 @@ class Executor:
|
||||
def make_exec_callback(exec_data: NodeExecution):
|
||||
node_id = exec_data.node_id
|
||||
|
||||
def callback(_):
|
||||
def callback(result: object):
|
||||
running_executions.pop(node_id)
|
||||
nonlocal n_node_executions
|
||||
n_node_executions += 1
|
||||
nonlocal exec_stats
|
||||
if isinstance(result, dict):
|
||||
exec_stats["node_count"] += 1
|
||||
exec_stats["nodes_cputime"] += result.get("cputime", 0)
|
||||
exec_stats["nodes_walltime"] += result.get("walltime", 0)
|
||||
|
||||
return callback
|
||||
|
||||
while not queue.empty():
|
||||
if cancel.is_set():
|
||||
error = RuntimeError("Execution is cancelled")
|
||||
return n_node_executions, error
|
||||
return exec_stats, error
|
||||
|
||||
exec_data = queue.get()
|
||||
|
||||
@@ -649,7 +656,7 @@ class Executor:
|
||||
for node_id, execution in list(running_executions.items()):
|
||||
if cancel.is_set():
|
||||
error = RuntimeError("Execution is cancelled")
|
||||
return n_node_executions, error
|
||||
return exec_stats, error
|
||||
|
||||
if not queue.empty():
|
||||
break # yield to parent loop to execute new queue items
|
||||
@@ -668,7 +675,7 @@ class Executor:
|
||||
finished = True
|
||||
cancel.set()
|
||||
cancel_thread.join()
|
||||
return n_node_executions, error
|
||||
return exec_stats, error
|
||||
|
||||
|
||||
class ExecutionManager(AppService):
|
||||
|
||||
@@ -121,8 +121,8 @@ class DeleteGraphResponse(TypedDict):
|
||||
async def get_graphs(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
with_runs: bool = False,
|
||||
) -> list[graph_db.GraphMeta]:
|
||||
return await graph_db.get_graphs_meta(
|
||||
) -> list[graph_db.Graph]:
|
||||
return await graph_db.get_graphs(
|
||||
include_executions=with_runs, filter_by="active", user_id=user_id
|
||||
)
|
||||
|
||||
@@ -290,22 +290,6 @@ async def stop_graph_run(
|
||||
return await execution_db.get_execution_results(graph_exec_id)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/input_schema",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_graph_input_schema(
|
||||
graph_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> list[graph_db.InputSchemaItem]:
|
||||
try:
|
||||
graph = await graph_db.get_graph(graph_id, user_id=user_id)
|
||||
return graph.get_input_schema() if graph else []
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/executions",
|
||||
tags=["graphs"],
|
||||
@@ -374,8 +358,8 @@ async def get_graph_run_status(
|
||||
)
|
||||
async def get_templates(
|
||||
user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> list[graph_db.GraphMeta]:
|
||||
return await graph_db.get_graphs_meta(filter_by="template", user_id=user_id)
|
||||
) -> list[graph_db.Graph]:
|
||||
return await graph_db.get_graphs(filter_by="template", user_id=user_id)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
/*
|
||||
Warnings:
|
||||
|
||||
- You are about to drop the column `agentGraphParentId` on the `AgentGraph` table. All the data in the column will be lost.
|
||||
|
||||
*/
|
||||
-- DropForeignKey
|
||||
ALTER TABLE "AgentGraph" DROP CONSTRAINT "AgentGraph_agentGraphParentId_version_fkey";
|
||||
|
||||
-- AlterTable
|
||||
ALTER TABLE "AgentGraph" DROP COLUMN "agentGraphParentId";
|
||||
@@ -0,0 +1,4 @@
|
||||
-- This migration converts the stats column from a list to an object.
|
||||
UPDATE "AgentGraphExecution"
|
||||
SET "stats" = (stats::jsonb -> 0)::text
|
||||
WHERE stats IS NOT NULL AND jsonb_typeof(stats::jsonb) = 'array';
|
||||
@@ -52,11 +52,6 @@ model AgentGraph {
|
||||
AgentGraphExecution AgentGraphExecution[]
|
||||
AgentGraphExecutionSchedule AgentGraphExecutionSchedule[]
|
||||
|
||||
// All sub-graphs are defined within this 1-level depth list (even if it's a nested graph).
|
||||
AgentSubGraphs AgentGraph[] @relation("AgentSubGraph")
|
||||
agentGraphParentId String?
|
||||
AgentGraphParent AgentGraph? @relation("AgentSubGraph", fields: [agentGraphParentId, version], references: [id, version], onDelete: Cascade)
|
||||
|
||||
@@id(name: "graphVersionId", [id, version])
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.basic import AgentInputBlock, StoreValueBlock
|
||||
from backend.blocks.basic import AgentInputBlock, AgentOutputBlock, StoreValueBlock
|
||||
from backend.data.block import BlockSchema
|
||||
from backend.data.graph import Graph, Link, Node
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.util.test import SpinTestServer
|
||||
@@ -15,9 +18,8 @@ async def test_graph_creation(server: SpinTestServer):
|
||||
Test the creation of a graph with nodes and links.
|
||||
|
||||
This test ensures that:
|
||||
1. Nodes from different subgraphs cannot be directly connected.
|
||||
2. A graph can be successfully created with valid connections.
|
||||
3. The created graph has the correct structure and properties.
|
||||
1. A graph can be successfully created with valid connections.
|
||||
2. The created graph has the correct structure and properties.
|
||||
|
||||
Args:
|
||||
server (SpinTestServer): The test server instance.
|
||||
@@ -37,23 +39,13 @@ async def test_graph_creation(server: SpinTestServer):
|
||||
links=[
|
||||
Link(
|
||||
source_id="node_1",
|
||||
sink_id="node_3",
|
||||
sink_id="node_2",
|
||||
source_name="output",
|
||||
sink_name="input",
|
||||
sink_name="name",
|
||||
),
|
||||
],
|
||||
subgraphs={"subgraph_1": ["node_2", "node_3"]},
|
||||
)
|
||||
create_graph = CreateGraph(graph=graph)
|
||||
|
||||
try:
|
||||
await server.agent_server.test_create_graph(create_graph, DEFAULT_USER_ID)
|
||||
assert False, "Should not be able to connect nodes from different subgraphs"
|
||||
except ValueError as e:
|
||||
assert "different subgraph" in str(e)
|
||||
|
||||
# Change node_1 <-> node_3 link to node_1 <-> node_2 (input for subgraph_1)
|
||||
graph.links[0].sink_id = "node_2"
|
||||
created_graph = await server.agent_server.test_create_graph(
|
||||
create_graph, DEFAULT_USER_ID
|
||||
)
|
||||
@@ -73,9 +65,6 @@ async def test_graph_creation(server: SpinTestServer):
|
||||
assert links[0].source_id in {nodes[0].id, nodes[1].id, nodes[2].id}
|
||||
assert links[0].sink_id in {nodes[0].id, nodes[1].id, nodes[2].id}
|
||||
|
||||
assert len(created_graph.subgraphs) == 1
|
||||
assert len(created_graph.subgraph_map) == len(created_graph.nodes) == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_get_input_schema(server: SpinTestServer):
|
||||
@@ -91,90 +80,54 @@ async def test_get_input_schema(server: SpinTestServer):
|
||||
server (SpinTestServer): The test server instance.
|
||||
"""
|
||||
value_block = StoreValueBlock().id
|
||||
input_block = AgentInputBlock().id
|
||||
output_block = AgentOutputBlock().id
|
||||
|
||||
graph = Graph(
|
||||
name="TestInputSchema",
|
||||
description="Test input schema",
|
||||
nodes=[
|
||||
Node(id="node_1", block_id=value_block),
|
||||
],
|
||||
links=[],
|
||||
)
|
||||
|
||||
create_graph = CreateGraph(graph=graph)
|
||||
created_graph = await server.agent_server.test_create_graph(
|
||||
create_graph, DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
input_schema = created_graph.get_input_schema()
|
||||
|
||||
assert len(input_schema) == 1
|
||||
|
||||
assert input_schema[0].title == "Input"
|
||||
assert input_schema[0].node_id == created_graph.nodes[0].id
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_get_input_schema_none_required(server: SpinTestServer):
|
||||
"""
|
||||
Test the get_input_schema method when no inputs are required.
|
||||
|
||||
This test ensures that:
|
||||
1. A graph can be created with a node that has a default input value.
|
||||
2. The input schema of the created graph is empty when all inputs have default values.
|
||||
|
||||
Args:
|
||||
server (SpinTestServer): The test server instance.
|
||||
"""
|
||||
value_block = StoreValueBlock().id
|
||||
|
||||
graph = Graph(
|
||||
name="TestInputSchema",
|
||||
description="Test input schema",
|
||||
nodes=[
|
||||
Node(id="node_1", block_id=value_block, input_default={"input": "value"}),
|
||||
],
|
||||
links=[],
|
||||
)
|
||||
|
||||
create_graph = CreateGraph(graph=graph)
|
||||
created_graph = await server.agent_server.test_create_graph(
|
||||
create_graph, DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
input_schema = created_graph.get_input_schema()
|
||||
|
||||
assert input_schema == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_get_input_schema_with_linked_blocks(server: SpinTestServer):
|
||||
"""
|
||||
Test the get_input_schema method with linked blocks.
|
||||
|
||||
This test ensures that:
|
||||
1. A graph can be created with multiple nodes and links between them.
|
||||
2. The input schema correctly identifies required inputs for linked blocks.
|
||||
3. Inputs that are satisfied by links are not included in the input schema.
|
||||
|
||||
Args:
|
||||
server (SpinTestServer): The test server instance.
|
||||
"""
|
||||
value_block = StoreValueBlock().id
|
||||
|
||||
graph = Graph(
|
||||
name="TestInputSchemaLinkedBlocks",
|
||||
description="Test input schema with linked blocks",
|
||||
nodes=[
|
||||
Node(id="node_1", block_id=value_block),
|
||||
Node(id="node_2", block_id=value_block),
|
||||
Node(
|
||||
id="node_0_a",
|
||||
block_id=input_block,
|
||||
input_default={"name": "in_key_a", "title": "Key A", "value": "A"},
|
||||
metadata={"id": "node_0_a"},
|
||||
),
|
||||
Node(
|
||||
id="node_0_b",
|
||||
block_id=input_block,
|
||||
input_default={"name": "in_key_b", "advanced": True},
|
||||
metadata={"id": "node_0_b"},
|
||||
),
|
||||
Node(id="node_1", block_id=value_block, metadata={"id": "node_1"}),
|
||||
Node(
|
||||
id="node_2",
|
||||
block_id=output_block,
|
||||
input_default={
|
||||
"name": "out_key",
|
||||
"description": "This is an output key",
|
||||
},
|
||||
metadata={"id": "node_2"},
|
||||
),
|
||||
],
|
||||
links=[
|
||||
Link(
|
||||
source_id="node_0_a",
|
||||
sink_id="node_1",
|
||||
source_name="result",
|
||||
sink_name="input",
|
||||
),
|
||||
Link(
|
||||
source_id="node_0_b",
|
||||
sink_id="node_1",
|
||||
source_name="result",
|
||||
sink_name="input",
|
||||
),
|
||||
Link(
|
||||
source_id="node_1",
|
||||
sink_id="node_2",
|
||||
source_name="output",
|
||||
sink_name="data",
|
||||
sink_name="value",
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -184,25 +137,21 @@ async def test_get_input_schema_with_linked_blocks(server: SpinTestServer):
|
||||
create_graph, DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
input_schema = created_graph.get_input_schema()
|
||||
class ExpectedInputSchema(BlockSchema):
|
||||
in_key_a: Any = SchemaField(title="Key A", default="A", advanced=False)
|
||||
in_key_b: Any = SchemaField(title="in_key_b", advanced=True)
|
||||
|
||||
assert len(input_schema) == 2
|
||||
class ExpectedOutputSchema(BlockSchema):
|
||||
out_key: Any = SchemaField(
|
||||
description="This is an output key",
|
||||
title="out_key",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
node_1_input = next(
|
||||
(item for item in input_schema if item.node_id == created_graph.nodes[0].id),
|
||||
None,
|
||||
)
|
||||
node_2_input = next(
|
||||
(item for item in input_schema if item.node_id == created_graph.nodes[1].id),
|
||||
None,
|
||||
)
|
||||
input_schema = created_graph.input_schema
|
||||
input_schema["title"] = "ExpectedInputSchema"
|
||||
assert input_schema == ExpectedInputSchema.jsonschema()
|
||||
|
||||
assert node_1_input is not None
|
||||
assert node_2_input is not None
|
||||
assert node_1_input.title == "Input"
|
||||
assert node_2_input.title == "Input"
|
||||
|
||||
assert not any(
|
||||
item.title == "data" and item.node_id == created_graph.nodes[1].id
|
||||
for item in input_schema
|
||||
)
|
||||
output_schema = created_graph.output_schema
|
||||
output_schema["title"] = "ExpectedOutputSchema"
|
||||
assert output_schema == ExpectedOutputSchema.jsonschema()
|
||||
|
||||
Reference in New Issue
Block a user