refactor(backend): Introduce Graph Input & Output Schema, Merge GraphMeta & Graph, Remove Subgraph Functionality (#8526)

Zamil Majdy
2024-11-07 09:30:51 +07:00
committed by GitHub
parent af9ea5bc31
commit 86c544177e
11 changed files with 285 additions and 418 deletions

View File

@@ -148,9 +148,12 @@ class AgentInputBlock(Block):
description="The value to be passed as input.",
default=None,
)
description: str = SchemaField(
title: str | None = SchemaField(
description="The title of the input.", default=None, advanced=True
)
description: str | None = SchemaField(
description="The description of the input.",
default="",
default=None,
advanced=True,
)
placeholder_values: List[Any] = SchemaField(
@@ -163,6 +166,16 @@ class AgentInputBlock(Block):
default=False,
advanced=True,
)
advanced: bool = SchemaField(
description="Whether to show the input in the advanced section, if the field is not required.",
default=False,
advanced=True,
)
secret: bool = SchemaField(
description="Whether the input should be treated as a secret.",
default=False,
advanced=True,
)
class Output(BlockSchema):
result: Any = SchemaField(description="The value passed as input.")
@@ -195,6 +208,7 @@ class AgentInputBlock(Block):
],
categories={BlockCategory.INPUT, BlockCategory.BASIC},
block_type=BlockType.INPUT,
static_output=True,
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
@@ -205,28 +219,25 @@ class AgentOutputBlock(Block):
"""
Records the output of the graph for users to see.
Attributes:
recorded_value: The value to be recorded as output.
name: The name of the output.
description: The description of the output.
fmt_string: The format string to be used to format the recorded_value.
Outputs:
output: The formatted recorded_value if fmt_string is provided and the recorded_value
can be formatted, otherwise the raw recorded_value.
Behavior:
If fmt_string is provided and the recorded_value is of a type that can be formatted,
the block attempts to format the recorded_value using the fmt_string.
If formatting fails or no fmt_string is provided, the raw recorded_value is output.
If `format` is provided and the `value` is of a type that can be formatted,
the block attempts to format the `value` using the `format` string.
If formatting fails or no `format` is provided, the raw `value` is output.
"""
class Input(BlockSchema):
value: Any = SchemaField(description="The value to be recorded as output.")
value: Any = SchemaField(
description="The value to be recorded as output.",
default=None,
advanced=False,
)
name: str = SchemaField(description="The name of the output.")
description: str = SchemaField(
title: str | None = SchemaField(
description="The title of the input.", default=None, advanced=True
)
description: str | None = SchemaField(
description="The description of the output.",
default="",
default=None,
advanced=True,
)
format: str = SchemaField(
@@ -234,6 +245,16 @@ class AgentOutputBlock(Block):
default="",
advanced=True,
)
advanced: bool = SchemaField(
description="Whether to treat the output as advanced.",
default=False,
advanced=True,
)
secret: bool = SchemaField(
description="Whether the output should be treated as a secret.",
default=False,
advanced=True,
)
class Output(BlockSchema):
output: Any = SchemaField(description="The value recorded as output.")
@@ -271,6 +292,7 @@ class AgentOutputBlock(Block):
],
categories={BlockCategory.OUTPUT, BlockCategory.BASIC},
block_type=BlockType.OUTPUT,
static_output=True,
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
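For orientation, here is a minimal sketch (mirroring the test fixtures at the bottom of this diff, not code from this file) of how a graph node populates the new `title`, `advanced`, and `secret` fields through `input_default`:

```python
from backend.blocks.basic import AgentInputBlock
from backend.data.graph import Node

# Hypothetical node declaration; the keys mirror AgentInputBlock.Input above.
input_node = Node(
    id="node_0",
    block_id=AgentInputBlock().id,
    input_default={
        "name": "user_prompt",                     # key in the generated input_schema
        "title": "User Prompt",                    # new optional display title
        "description": "Prompt to run the agent",  # now defaults to None, not ""
        "value": "Hello",                          # omit to make the input required
        "advanced": False,                         # new: relegate to the advanced section
        "secret": False,                           # new: mask the value in the UI
    },
)
```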

View File

@@ -71,11 +71,18 @@ class ConditionBlock(Block):
)
def run(self, input_data: Input, **kwargs) -> BlockOutput:
value1 = input_data.value1
operator = input_data.operator
value1 = input_data.value1
if isinstance(value1, str):
value1 = float(value1.strip())
value2 = input_data.value2
if isinstance(value2, str):
value2 = float(value2.strip())
yes_value = input_data.yes_value if input_data.yes_value is not None else value1
no_value = input_data.no_value if input_data.no_value is not None else value1
no_value = input_data.no_value if input_data.no_value is not None else value2
comparison_funcs = {
ComparisonOperator.EQUAL: lambda a, b: a == b,
@@ -86,17 +93,11 @@ class ConditionBlock(Block):
ComparisonOperator.LESS_THAN_OR_EQUAL: lambda a, b: a <= b,
}
try:
result = comparison_funcs[operator](value1, value2)
result = comparison_funcs[operator](value1, value2)
yield "result", result
yield "result", result
if result:
yield "yes_output", yes_value
else:
yield "no_output", no_value
except Exception:
yield "result", None
yield "yes_output", None
yield "no_output", None
if result:
yield "yes_output", yes_value
else:
yield "no_output", no_value

View File

@@ -9,14 +9,11 @@ from prisma.models import (
AgentNodeExecution,
AgentNodeExecutionInputOutput,
)
from prisma.types import (
AgentGraphExecutionInclude,
AgentGraphExecutionWhereInput,
AgentNodeExecutionInclude,
)
from prisma.types import AgentGraphExecutionWhereInput
from pydantic import BaseModel
from backend.data.block import BlockData, BlockInput, CompletedBlockOutput
from backend.data.includes import EXECUTION_RESULT_INCLUDE, GRAPH_EXECUTION_INCLUDE
from backend.util import json, mock
@@ -110,24 +107,6 @@ class ExecutionResult(BaseModel):
# --------------------- Model functions --------------------- #
EXECUTION_RESULT_INCLUDE: AgentNodeExecutionInclude = {
"Input": True,
"Output": True,
"AgentNode": True,
"AgentGraphExecution": True,
}
GRAPH_EXECUTION_INCLUDE: AgentGraphExecutionInclude = {
"AgentNodeExecutions": {
"include": {
"Input": True,
"Output": True,
"AgentNode": True,
"AgentGraphExecution": True,
}
}
}
async def create_graph_execution(
graph_id: str,
@@ -268,21 +247,9 @@ async def update_graph_execution_start_time(graph_exec_id: str):
async def update_graph_execution_stats(
graph_exec_id: str,
error: Exception | None,
wall_time: float,
cpu_time: float,
node_count: int,
stats: dict[str, Any],
):
status = ExecutionStatus.FAILED if error else ExecutionStatus.COMPLETED
stats = (
{
"walltime": wall_time,
"cputime": cpu_time,
"nodecount": node_count,
"error": str(error) if error else None,
},
)
status = ExecutionStatus.FAILED if stats.get("error") else ExecutionStatus.COMPLETED
await AgentGraphExecution.prisma().update(
where={"id": graph_exec_id},
data={
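The stats payload is now passed as a single dict instead of separate scalar arguments, with FAILED status derived from `stats["error"]`. A sketch of the expected shape, assuming the new signature takes only `graph_exec_id` and `stats` as at the executor call site further down (keys taken from that diff; values illustrative):

```python
from typing import Any

from backend.data.execution import update_graph_execution_stats

async def record_run_stats(graph_exec_id: str) -> None:
    # Illustrative payload; keys mirror what the executor assembles below.
    stats: dict[str, Any] = {
        "walltime": 12.3,        # wall-clock seconds for the whole graph run
        "cputime": 4.5,          # CPU seconds for the whole graph run
        "nodes_walltime": 11.8,  # summed wall time across node executions
        "nodes_cputime": 4.1,    # summed CPU time across node executions
        "node_count": 7,         # node executions completed
        "error": None,           # str(error) on failure; drives FAILED status
    }
    await update_graph_execution_stats(graph_exec_id=graph_exec_id, stats=stats)
```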

View File

@@ -3,29 +3,22 @@ import logging
import uuid
from collections import defaultdict
from datetime import datetime, timezone
from typing import Any, Literal
from typing import Any, Literal, Type
import prisma.types
from prisma.models import AgentGraph, AgentGraphExecution, AgentNode, AgentNodeLink
from prisma.types import AgentGraphInclude
from pydantic import BaseModel
from pydantic_core import PydanticUndefinedType
from prisma.types import AgentGraphWhereInput
from pydantic.fields import computed_field
from backend.blocks.basic import AgentInputBlock, AgentOutputBlock, BlockType
from backend.data.block import BlockInput, get_block, get_blocks
from backend.blocks.basic import AgentInputBlock, AgentOutputBlock
from backend.data.block import BlockInput, BlockType, get_block, get_blocks
from backend.data.db import BaseDbModel, transaction
from backend.data.execution import ExecutionStatus
from backend.data.includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
from backend.util import json
logger = logging.getLogger(__name__)
class InputSchemaItem(BaseModel):
node_id: str
description: str | None = None
title: str | None = None
class Link(BaseDbModel):
source_id: str
sink_id: str
@@ -70,7 +63,7 @@ class Node(BaseDbModel):
return obj
class ExecutionMeta(BaseDbModel):
class GraphExecution(BaseDbModel):
execution_id: str
started_at: datetime
ended_at: datetime
@@ -79,20 +72,19 @@ class ExecutionMeta(BaseDbModel):
status: ExecutionStatus
@staticmethod
def from_agent_graph_execution(execution: AgentGraphExecution):
def from_db(execution: AgentGraphExecution):
now = datetime.now(timezone.utc)
start_time = execution.startedAt or execution.createdAt
end_time = execution.updatedAt or now
duration = (end_time - start_time).total_seconds()
total_run_time = duration
total_run_time = 0
if execution.AgentNodeExecutions:
for node_execution in execution.AgentNodeExecutions:
node_start = node_execution.startedTime or now
node_end = node_execution.endedTime or now
total_run_time += (node_end - node_start).total_seconds()
if execution.stats:
stats = json.loads(execution.stats)
duration = stats.get("walltime", duration)
total_run_time = stats.get("nodes_walltime", total_run_time)
return ExecutionMeta(
return GraphExecution(
id=execution.id,
execution_id=execution.id,
started_at=start_time,
@@ -103,39 +95,70 @@ class ExecutionMeta(BaseDbModel):
)
class GraphMeta(BaseDbModel):
class Graph(BaseDbModel):
version: int = 1
is_active: bool = True
is_template: bool = False
name: str
description: str
executions: list[ExecutionMeta] | None = None
executions: list[GraphExecution] = []
nodes: list[Node] = []
links: list[Link] = []
@staticmethod
def from_db(graph: AgentGraph):
if graph.AgentGraphExecution:
executions = [
ExecutionMeta.from_agent_graph_execution(execution)
for execution in graph.AgentGraphExecution
]
else:
executions = None
def _generate_schema(
type_class: Type[AgentInputBlock.Input] | Type[AgentOutputBlock.Input],
data: list[dict],
) -> dict[str, Any]:
props = []
for p in data:
try:
props.append(type_class(**p))
except Exception as e:
logger.warning(f"Invalid {type_class}: {p}, {e}")
return GraphMeta(
id=graph.id,
version=graph.version,
is_active=graph.isActive,
is_template=graph.isTemplate,
name=graph.name or "",
description=graph.description or "",
executions=executions,
return {
"type": "object",
"properties": {
p.name: {
"secret": p.secret,
"advanced": p.advanced,
"title": p.title or p.name,
**({"description": p.description} if p.description else {}),
**({"default": p.value} if p.value is not None else {}),
}
for p in props
},
"required": [p.name for p in props if p.value is None],
}
@computed_field
@property
def input_schema(self) -> dict[str, Any]:
return self._generate_schema(
AgentInputBlock.Input,
[
node.input_default
for node in self.nodes
if (b := get_block(node.block_id))
and b.block_type == BlockType.INPUT
and "name" in node.input_default
],
)
class Graph(GraphMeta):
nodes: list[Node]
links: list[Link]
subgraphs: dict[str, list[str]] = {} # subgraph_id -> [node_id]
@computed_field
@property
def output_schema(self) -> dict[str, Any]:
return self._generate_schema(
AgentOutputBlock.Input,
[
node.input_default
for node in self.nodes
if (b := get_block(node.block_id))
and b.block_type == BlockType.OUTPUT
and "name" in node.input_default
],
)
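Concretely, for the input node used in the updated tests below (`input_default={"name": "in_key_a", "title": "Key A", "value": "A"}`), `graph.input_schema` evaluates to roughly the following (an illustration of `_generate_schema`'s output, not a snippet from this file):

```python
{
    "type": "object",
    "properties": {
        "in_key_a": {
            "secret": False,    # defaults from AgentInputBlock.Input
            "advanced": False,
            "title": "Key A",
            "default": "A",     # included because a value was provided
        },
    },
    "required": [],             # empty: every input here has a default value
}
```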
@property
def starting_nodes(self) -> list[Node]:
@@ -143,7 +166,7 @@ class Graph(GraphMeta):
input_nodes = {
v.id
for v in self.nodes
if isinstance(get_block(v.block_id), AgentInputBlock)
if (b := get_block(v.block_id)) and b.block_type == BlockType.INPUT
}
return [
node
@@ -151,28 +174,6 @@ class Graph(GraphMeta):
if node.id not in outbound_nodes or node.id in input_nodes
]
@property
def ending_nodes(self) -> list[Node]:
return [
v for v in self.nodes if isinstance(get_block(v.block_id), AgentOutputBlock)
]
@property
def subgraph_map(self) -> dict[str, str]:
"""
Returns a mapping of node_id to subgraph_id.
A node in the main graph will be mapped to the graph's id.
"""
subgraph_map = {
node_id: subgraph_id
for subgraph_id, node_ids in self.subgraphs.items()
for node_id in node_ids
}
subgraph_map.update(
{node.id: self.id for node in self.nodes if node.id not in subgraph_map}
)
return subgraph_map
def reassign_ids(self, reassign_graph_id: bool = False):
"""
Reassigns all IDs in the graph to new UUIDs.
@@ -180,11 +181,7 @@ class Graph(GraphMeta):
"""
self.validate_graph()
id_map = {
**{node.id: str(uuid.uuid4()) for node in self.nodes},
**{subgraph_id: str(uuid.uuid4()) for subgraph_id in self.subgraphs},
}
id_map = {node.id: str(uuid.uuid4()) for node in self.nodes}
if reassign_graph_id:
self.id = str(uuid.uuid4())
@@ -195,11 +192,6 @@ class Graph(GraphMeta):
link.source_id = id_map[link.source_id]
link.sink_id = id_map[link.sink_id]
self.subgraphs = {
id_map[subgraph_id]: [id_map[node_id] for node_id in node_ids]
for subgraph_id, node_ids in self.subgraphs.items()
}
def validate_graph(self, for_run: bool = False):
def sanitize(name):
return name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
@@ -227,6 +219,7 @@ class Graph(GraphMeta):
raise ValueError(
f"Node {block.name} #{node.id} required input missing: `{name}`"
)
node_map = {v.id: v for v in self.nodes}
def is_static_output_block(nid: str) -> bool:
@@ -234,18 +227,6 @@ class Graph(GraphMeta):
b = get_block(bid)
return b.static_output if b else False
def is_input_output_block(nid: str) -> bool:
bid = node_map[nid].block_id
b = get_block(bid)
return isinstance(b, AgentInputBlock) or isinstance(b, AgentOutputBlock)
# subgraphs: all nodes in subgraph must be present in the graph.
for subgraph_id, node_ids in self.subgraphs.items():
for node_id in node_ids:
if node_id not in node_map:
raise ValueError(f"Subgraph {subgraph_id}'s node {node_id} invalid")
subgraph_map = self.subgraph_map
# Links: links are connected and the connected pin data type are compatible.
for link in self.links:
source = (link.source_id, link.source_name)
@@ -274,66 +255,27 @@ class Graph(GraphMeta):
if sanitized_name not in fields:
raise ValueError(f"{suffix}, `{name}` invalid, {fields}")
if (
subgraph_map.get(link.source_id) != subgraph_map.get(link.sink_id)
and not is_input_output_block(link.source_id)
and not is_input_output_block(link.sink_id)
):
raise ValueError(f"{suffix}, Connecting nodes from different subgraph.")
if is_static_output_block(link.source_id):
link.is_static = True # Each value block output should be static.
# TODO: Add type compatibility check here.
def get_input_schema(self) -> list[InputSchemaItem]:
"""
Walks the graph and returns all the inputs that are either not:
- static
- provided by parent node
"""
input_schema = []
for node in self.nodes:
block = get_block(node.block_id)
if not block:
continue
for input_name, input_schema_item in (
block.input_schema.jsonschema().get("properties", {}).items()
):
# Check if the input is not static and not provided by a parent node
if (
input_name not in node.input_default
and not any(
link.sink_name == input_name for link in node.input_links
)
and isinstance(
block.input_schema.model_fields.get(input_name).default,
PydanticUndefinedType,
)
):
input_schema.append(
InputSchemaItem(
node_id=node.id,
description=input_schema_item.get("description"),
title=input_schema_item.get("title"),
)
)
return input_schema
@staticmethod
def from_db(graph: AgentGraph, hide_credentials: bool = False):
nodes = [
*(graph.AgentNodes or []),
*(
node
for subgraph in graph.AgentSubGraphs or []
for node in subgraph.AgentNodes or []
),
executions = [
GraphExecution.from_db(execution)
for execution in graph.AgentGraphExecution or []
]
nodes = graph.AgentNodes or []
return Graph(
**GraphMeta.from_db(graph).model_dump(),
id=graph.id,
version=graph.version,
is_active=graph.isActive,
is_template=graph.isTemplate,
name=graph.name or "",
description=graph.description or "",
executions=executions,
nodes=[Graph._process_node(node, hide_credentials) for node in nodes],
links=list(
{
@@ -342,10 +284,6 @@ class Graph(GraphMeta):
for link in (node.Input or []) + (node.Output or [])
}
),
subgraphs={
subgraph.id: [node.id for node in subgraph.AgentNodes or []]
for subgraph in graph.AgentSubGraphs or []
},
)
@staticmethod
@@ -374,20 +312,6 @@ class Graph(GraphMeta):
return result
AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
"Input": True,
"Output": True,
"AgentBlock": True,
}
__SUBGRAPH_INCLUDE = {"AgentNodes": {"include": AGENT_NODE_INCLUDE}}
AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = {
**__SUBGRAPH_INCLUDE,
"AgentSubGraphs": {"include": __SUBGRAPH_INCLUDE}, # type: ignore
}
# --------------------- Model functions --------------------- #
@@ -399,11 +323,11 @@ async def get_node(node_id: str) -> Node:
return Node.from_db(node)
async def get_graphs_meta(
async def get_graphs(
user_id: str,
include_executions: bool = False,
filter_by: Literal["active", "template"] | None = "active",
) -> list[GraphMeta]:
) -> list[Graph]:
"""
Retrieves graph objects.
Default behaviour is to get all currently active graphs.
@@ -414,9 +338,9 @@ async def get_graphs_meta(
user_id: The ID of the user that owns the graph.
Returns:
list[GraphMeta]: A list of objects representing the retrieved graph metadata.
list[Graph]: A list of objects representing the retrieved graphs.
"""
where_clause: prisma.types.AgentGraphWhereInput = {}
where_clause: AgentGraphWhereInput = {}
if filter_by == "active":
where_clause["isActive"] = True
@@ -425,23 +349,17 @@ async def get_graphs_meta(
where_clause["userId"] = user_id
graph_include = AGENT_GRAPH_INCLUDE
graph_include["AgentGraphExecution"] = include_executions
graphs = await AgentGraph.prisma().find_many(
where=where_clause,
distinct=["id"],
order={"version": "desc"},
include=(
AgentGraphInclude(
AgentGraphExecution={"include": {"AgentNodeExecutions": True}}
)
if include_executions
else None
),
include=graph_include,
)
if not graphs:
return []
return [GraphMeta.from_db(graph) for graph in graphs]
return [Graph.from_db(graph) for graph in graphs]
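Callers that previously used `get_graphs_meta` now get full `Graph` objects back, including the computed `input_schema`/`output_schema` fields. An illustrative usage sketch (assuming the imports resolve as in this diff):

```python
from backend.data.graph import get_graphs

async def list_active_graph_inputs(user_id: str) -> dict[str, list[str]]:
    # Full Graph objects come back, so the computed schemas are available
    # without a second query or the old get_input_schema() walk.
    graphs = await get_graphs(user_id=user_id, filter_by="active")
    return {g.name: list(g.input_schema["properties"]) for g in graphs}
```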
async def get_graph(
@@ -458,7 +376,7 @@ async def get_graph(
Returns `None` if the record is not found.
"""
where_clause: prisma.types.AgentGraphWhereInput = {
where_clause: AgentGraphWhereInput = {
"id": graph_id,
"isTemplate": template,
}
@@ -467,7 +385,7 @@ async def get_graph(
elif not template:
where_clause["isActive"] = True
if user_id and not template:
if user_id is not None and not template:
where_clause["userId"] = user_id
graph = await AgentGraph.prisma().find_first(
@@ -553,33 +471,13 @@ async def __create_graph(tx, graph: Graph, user_id: str):
}
)
await asyncio.gather(
*[
AgentGraph.prisma(tx).create(
data={
"id": subgraph_id,
"agentGraphParentId": graph.id,
"version": graph.version,
"name": f"SubGraph of {graph.name}",
"description": f"Sub-Graph of {graph.id}",
"isTemplate": graph.is_template,
"isActive": graph.is_active,
"userId": user_id,
}
)
for subgraph_id in graph.subgraphs
]
)
subgraph_map = graph.subgraph_map
await asyncio.gather(
*[
AgentNode.prisma(tx).create(
{
"id": node.id,
"agentBlockId": node.block_id,
"agentGraphId": subgraph_map.get(node.id, graph.id),
"agentGraphId": graph.id,
"agentGraphVersion": graph.version,
"constantInput": json.dumps(node.input_default),
"metadata": json.dumps(node.metadata),

View File

@@ -0,0 +1,29 @@
import prisma
AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
"Input": True,
"Output": True,
"AgentBlock": True,
}
AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = {
"AgentNodes": {"include": AGENT_NODE_INCLUDE} # type: ignore
}
EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
"Input": True,
"Output": True,
"AgentNode": True,
"AgentGraphExecution": True,
}
GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
"AgentNodeExecutions": {
"include": {
"Input": True,
"Output": True,
"AgentNode": True,
"AgentGraphExecution": True,
}
}
}
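These are plain Prisma include dicts, hoisted into one module so graph.py and execution.py stop redefining them; they are passed straight through to queries. An illustrative query, mirroring `get_graph` above:

```python
from prisma.models import AgentGraph

from backend.data.includes import AGENT_GRAPH_INCLUDE

async def load_active_graph(graph_id: str):
    # The include dict pulls nodes plus their Input/Output links and block rows.
    return await AgentGraph.prisma().find_first(
        where={"id": graph_id, "isActive": True},
        include=AGENT_GRAPH_INCLUDE,
    )
```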

View File

@@ -473,7 +473,7 @@ class Executor:
cls,
q: ExecutionQueue[NodeExecution],
node_exec: NodeExecution,
):
) -> dict[str, Any]:
log_metadata = LogMetadata(
user_id=node_exec.user_id,
graph_eid=node_exec.graph_exec_id,
@@ -493,6 +493,7 @@ class Executor:
cls.db_client.update_node_execution_stats(
node_exec.node_exec_id, execution_stats
)
return execution_stats
@classmethod
@time_measured
@@ -556,16 +557,15 @@ class Executor:
node_eid="*",
block_name="-",
)
timing_info, (node_count, error) = cls._on_graph_execution(
timing_info, (exec_stats, error) = cls._on_graph_execution(
graph_exec, cancel, log_metadata
)
exec_stats["walltime"] = timing_info.wall_time
exec_stats["cputime"] = timing_info.cpu_time
exec_stats["error"] = str(error) if error else None
cls.db_client.update_graph_execution_stats(
graph_exec_id=graph_exec.graph_exec_id,
error=error,
wall_time=timing_info.wall_time,
cpu_time=timing_info.cpu_time,
node_count=node_count,
stats=exec_stats,
)
@classmethod
@@ -575,14 +575,18 @@ class Executor:
graph_exec: GraphExecution,
cancel: threading.Event,
log_metadata: LogMetadata,
) -> tuple[int, Exception | None]:
) -> tuple[dict[str, Any], Exception | None]:
"""
Returns:
The number of node executions completed.
The execution statistics of the graph execution.
The error that occurred during the execution.
"""
log_metadata.info(f"Start graph execution {graph_exec.graph_exec_id}")
n_node_executions = 0
exec_stats = {
"nodes_walltime": 0,
"nodes_cputime": 0,
"node_count": 0,
}
error = None
finished = False
@@ -608,17 +612,20 @@ class Executor:
def make_exec_callback(exec_data: NodeExecution):
node_id = exec_data.node_id
def callback(_):
def callback(result: object):
running_executions.pop(node_id)
nonlocal n_node_executions
n_node_executions += 1
nonlocal exec_stats
if isinstance(result, dict):
exec_stats["node_count"] += 1
exec_stats["nodes_cputime"] += result.get("cputime", 0)
exec_stats["nodes_walltime"] += result.get("walltime", 0)
return callback
while not queue.empty():
if cancel.is_set():
error = RuntimeError("Execution is cancelled")
return n_node_executions, error
return exec_stats, error
exec_data = queue.get()
@@ -649,7 +656,7 @@ class Executor:
for node_id, execution in list(running_executions.items()):
if cancel.is_set():
error = RuntimeError("Execution is cancelled")
return n_node_executions, error
return exec_stats, error
if not queue.empty():
break # yield to parent loop to execute new queue items
@@ -668,7 +675,7 @@ class Executor:
finished = True
cancel.set()
cancel_thread.join()
return n_node_executions, error
return exec_stats, error
class ExecutionManager(AppService):
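Pulled out of the diff, the new stats flow in one self-contained sketch: `_on_node_execution` now returns its per-node stats dict, and the completion callback folds each one into the graph-level `exec_stats` that ultimately reaches `update_graph_execution_stats`:

```python
from typing import Any

exec_stats: dict[str, Any] = {"nodes_walltime": 0, "nodes_cputime": 0, "node_count": 0}

def on_node_finished(result: object) -> None:
    # Mirrors the callback above: only dict results (node stats) are counted.
    if isinstance(result, dict):
        exec_stats["node_count"] += 1
        exec_stats["nodes_cputime"] += result.get("cputime", 0)
        exec_stats["nodes_walltime"] += result.get("walltime", 0)

on_node_finished({"walltime": 1.2, "cputime": 0.4})  # one node completes
assert exec_stats == {"nodes_walltime": 1.2, "nodes_cputime": 0.4, "node_count": 1}
```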

View File

@@ -121,8 +121,8 @@ class DeleteGraphResponse(TypedDict):
async def get_graphs(
user_id: Annotated[str, Depends(get_user_id)],
with_runs: bool = False,
) -> list[graph_db.GraphMeta]:
return await graph_db.get_graphs_meta(
) -> list[graph_db.Graph]:
return await graph_db.get_graphs(
include_executions=with_runs, filter_by="active", user_id=user_id
)
@@ -290,22 +290,6 @@ async def stop_graph_run(
return await execution_db.get_execution_results(graph_exec_id)
@v1_router.get(
path="/graphs/{graph_id}/input_schema",
tags=["graphs"],
dependencies=[Depends(auth_middleware)],
)
async def get_graph_input_schema(
graph_id: str,
user_id: Annotated[str, Depends(get_user_id)],
) -> list[graph_db.InputSchemaItem]:
try:
graph = await graph_db.get_graph(graph_id, user_id=user_id)
return graph.get_input_schema() if graph else []
except Exception:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
@v1_router.get(
path="/graphs/{graph_id}/executions",
tags=["graphs"],
@@ -374,8 +358,8 @@ async def get_graph_run_status(
)
async def get_templates(
user_id: Annotated[str, Depends(get_user_id)]
) -> list[graph_db.GraphMeta]:
return await graph_db.get_graphs_meta(filter_by="template", user_id=user_id)
) -> list[graph_db.Graph]:
return await graph_db.get_graphs(filter_by="template", user_id=user_id)
@v1_router.get(
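With the `/graphs/{graph_id}/input_schema` endpoint removed, clients read the schema off the serialized graph itself. A hypothetical client-side sketch (endpoint path and auth header assumed from the surrounding routes, not confirmed by this diff):

```python
import httpx

async def fetch_input_schema(base_url: str, graph_id: str, token: str) -> dict:
    async with httpx.AsyncClient() as client:
        resp = await client.get(
            f"{base_url}/graphs/{graph_id}",
            headers={"Authorization": f"Bearer {token}"},
        )
        resp.raise_for_status()
        # input_schema is a computed field, so it serializes with the graph.
        return resp.json()["input_schema"]
```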

View File

@@ -0,0 +1,11 @@
/*
Warnings:
- You are about to drop the column `agentGraphParentId` on the `AgentGraph` table. All the data in the column will be lost.
*/
-- DropForeignKey
ALTER TABLE "AgentGraph" DROP CONSTRAINT "AgentGraph_agentGraphParentId_version_fkey";
-- AlterTable
ALTER TABLE "AgentGraph" DROP COLUMN "agentGraphParentId";

View File

@@ -0,0 +1,4 @@
-- This migration converts the stats column from a list to an object.
UPDATE "AgentGraphExecution"
SET "stats" = (stats::jsonb -> 0)::text
WHERE stats IS NOT NULL AND jsonb_typeof(stats::jsonb) = 'array';
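Background for this migration: the removed `update_graph_execution_stats` built its stats as a one-element tuple (note the stray trailing comma in the deleted code above), which serialized to a JSON array; this unwraps the existing rows. A pure-Python sketch of the same transformation, for intuition:

```python
import json

def migrate_stats(raw: str | None) -> str | None:
    # stats::jsonb -> 0 takes the first element of an array-typed value;
    # non-array (already-object) values are left untouched.
    if raw is None:
        return raw
    parsed = json.loads(raw)
    return json.dumps(parsed[0]) if isinstance(parsed, list) else raw

assert migrate_stats('[{"walltime": 1.0}]') == '{"walltime": 1.0}'
```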

View File

@@ -52,11 +52,6 @@ model AgentGraph {
AgentGraphExecution AgentGraphExecution[]
AgentGraphExecutionSchedule AgentGraphExecutionSchedule[]
// All sub-graphs are defined within this 1-level depth list (even if it's a nested graph).
AgentSubGraphs AgentGraph[] @relation("AgentSubGraph")
agentGraphParentId String?
AgentGraphParent AgentGraph? @relation("AgentSubGraph", fields: [agentGraphParentId, version], references: [id, version], onDelete: Cascade)
@@id(name: "graphVersionId", [id, version])
}

View File

@@ -1,9 +1,12 @@
from typing import Any
from uuid import UUID
import pytest
from backend.blocks.basic import AgentInputBlock, StoreValueBlock
from backend.blocks.basic import AgentInputBlock, AgentOutputBlock, StoreValueBlock
from backend.data.block import BlockSchema
from backend.data.graph import Graph, Link, Node
from backend.data.model import SchemaField
from backend.data.user import DEFAULT_USER_ID
from backend.server.model import CreateGraph
from backend.util.test import SpinTestServer
@@ -15,9 +18,8 @@ async def test_graph_creation(server: SpinTestServer):
Test the creation of a graph with nodes and links.
This test ensures that:
1. Nodes from different subgraphs cannot be directly connected.
2. A graph can be successfully created with valid connections.
3. The created graph has the correct structure and properties.
1. A graph can be successfully created with valid connections.
2. The created graph has the correct structure and properties.
Args:
server (SpinTestServer): The test server instance.
@@ -37,23 +39,13 @@ async def test_graph_creation(server: SpinTestServer):
links=[
Link(
source_id="node_1",
sink_id="node_3",
sink_id="node_2",
source_name="output",
sink_name="input",
sink_name="name",
),
],
subgraphs={"subgraph_1": ["node_2", "node_3"]},
)
create_graph = CreateGraph(graph=graph)
try:
await server.agent_server.test_create_graph(create_graph, DEFAULT_USER_ID)
assert False, "Should not be able to connect nodes from different subgraphs"
except ValueError as e:
assert "different subgraph" in str(e)
# Change node_1 <-> node_3 link to node_1 <-> node_2 (input for subgraph_1)
graph.links[0].sink_id = "node_2"
created_graph = await server.agent_server.test_create_graph(
create_graph, DEFAULT_USER_ID
)
@@ -73,9 +65,6 @@ async def test_graph_creation(server: SpinTestServer):
assert links[0].source_id in {nodes[0].id, nodes[1].id, nodes[2].id}
assert links[0].sink_id in {nodes[0].id, nodes[1].id, nodes[2].id}
assert len(created_graph.subgraphs) == 1
assert len(created_graph.subgraph_map) == len(created_graph.nodes) == 3
@pytest.mark.asyncio(scope="session")
async def test_get_input_schema(server: SpinTestServer):
@@ -91,90 +80,54 @@ async def test_get_input_schema(server: SpinTestServer):
server (SpinTestServer): The test server instance.
"""
value_block = StoreValueBlock().id
input_block = AgentInputBlock().id
output_block = AgentOutputBlock().id
graph = Graph(
name="TestInputSchema",
description="Test input schema",
nodes=[
Node(id="node_1", block_id=value_block),
],
links=[],
)
create_graph = CreateGraph(graph=graph)
created_graph = await server.agent_server.test_create_graph(
create_graph, DEFAULT_USER_ID
)
input_schema = created_graph.get_input_schema()
assert len(input_schema) == 1
assert input_schema[0].title == "Input"
assert input_schema[0].node_id == created_graph.nodes[0].id
@pytest.mark.asyncio(scope="session")
async def test_get_input_schema_none_required(server: SpinTestServer):
"""
Test the get_input_schema method when no inputs are required.
This test ensures that:
1. A graph can be created with a node that has a default input value.
2. The input schema of the created graph is empty when all inputs have default values.
Args:
server (SpinTestServer): The test server instance.
"""
value_block = StoreValueBlock().id
graph = Graph(
name="TestInputSchema",
description="Test input schema",
nodes=[
Node(id="node_1", block_id=value_block, input_default={"input": "value"}),
],
links=[],
)
create_graph = CreateGraph(graph=graph)
created_graph = await server.agent_server.test_create_graph(
create_graph, DEFAULT_USER_ID
)
input_schema = created_graph.get_input_schema()
assert input_schema == []
@pytest.mark.asyncio(scope="session")
async def test_get_input_schema_with_linked_blocks(server: SpinTestServer):
"""
Test the get_input_schema method with linked blocks.
This test ensures that:
1. A graph can be created with multiple nodes and links between them.
2. The input schema correctly identifies required inputs for linked blocks.
3. Inputs that are satisfied by links are not included in the input schema.
Args:
server (SpinTestServer): The test server instance.
"""
value_block = StoreValueBlock().id
graph = Graph(
name="TestInputSchemaLinkedBlocks",
description="Test input schema with linked blocks",
nodes=[
Node(id="node_1", block_id=value_block),
Node(id="node_2", block_id=value_block),
Node(
id="node_0_a",
block_id=input_block,
input_default={"name": "in_key_a", "title": "Key A", "value": "A"},
metadata={"id": "node_0_a"},
),
Node(
id="node_0_b",
block_id=input_block,
input_default={"name": "in_key_b", "advanced": True},
metadata={"id": "node_0_b"},
),
Node(id="node_1", block_id=value_block, metadata={"id": "node_1"}),
Node(
id="node_2",
block_id=output_block,
input_default={
"name": "out_key",
"description": "This is an output key",
},
metadata={"id": "node_2"},
),
],
links=[
Link(
source_id="node_0_a",
sink_id="node_1",
source_name="result",
sink_name="input",
),
Link(
source_id="node_0_b",
sink_id="node_1",
source_name="result",
sink_name="input",
),
Link(
source_id="node_1",
sink_id="node_2",
source_name="output",
sink_name="data",
sink_name="value",
),
],
)
@@ -184,25 +137,21 @@ async def test_get_input_schema_with_linked_blocks(server: SpinTestServer):
create_graph, DEFAULT_USER_ID
)
input_schema = created_graph.get_input_schema()
class ExpectedInputSchema(BlockSchema):
in_key_a: Any = SchemaField(title="Key A", default="A", advanced=False)
in_key_b: Any = SchemaField(title="in_key_b", advanced=True)
assert len(input_schema) == 2
class ExpectedOutputSchema(BlockSchema):
out_key: Any = SchemaField(
description="This is an output key",
title="out_key",
advanced=False,
)
node_1_input = next(
(item for item in input_schema if item.node_id == created_graph.nodes[0].id),
None,
)
node_2_input = next(
(item for item in input_schema if item.node_id == created_graph.nodes[1].id),
None,
)
input_schema = created_graph.input_schema
input_schema["title"] = "ExpectedInputSchema"
assert input_schema == ExpectedInputSchema.jsonschema()
assert node_1_input is not None
assert node_2_input is not None
assert node_1_input.title == "Input"
assert node_2_input.title == "Input"
assert not any(
item.title == "data" and item.node_id == created_graph.nodes[1].id
for item in input_schema
)
output_schema = created_graph.output_schema
output_schema["title"] = "ExpectedOutputSchema"
assert output_schema == ExpectedOutputSchema.jsonschema()