mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Merge branch 'dev' into master
This commit is contained in:
4
.github/workflows/codeql.yml
vendored
4
.github/workflows/codeql.yml
vendored
@@ -13,9 +13,9 @@ name: "CodeQL"
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "master", "release-*" ]
|
||||
branches: [ "master", "release-*", "dev" ]
|
||||
pull_request:
|
||||
branches: [ "master", "release-*" ]
|
||||
branches: [ "master", "release-*", "dev" ]
|
||||
schedule:
|
||||
- cron: '15 4 * * 0'
|
||||
|
||||
|
||||
@@ -148,9 +148,12 @@ class AgentInputBlock(Block):
|
||||
description="The value to be passed as input.",
|
||||
default=None,
|
||||
)
|
||||
description: str = SchemaField(
|
||||
title: str | None = SchemaField(
|
||||
description="The title of the input.", default=None, advanced=True
|
||||
)
|
||||
description: str | None = SchemaField(
|
||||
description="The description of the input.",
|
||||
default="",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
placeholder_values: List[Any] = SchemaField(
|
||||
@@ -163,6 +166,16 @@ class AgentInputBlock(Block):
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
advanced: bool = SchemaField(
|
||||
description="Whether to show the input in the advanced section, if the field is not required.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
secret: bool = SchemaField(
|
||||
description="Whether the input should be treated as a secret.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
result: Any = SchemaField(description="The value passed as input.")
|
||||
@@ -195,6 +208,7 @@ class AgentInputBlock(Block):
|
||||
],
|
||||
categories={BlockCategory.INPUT, BlockCategory.BASIC},
|
||||
block_type=BlockType.INPUT,
|
||||
static_output=True,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
@@ -205,28 +219,25 @@ class AgentOutputBlock(Block):
|
||||
"""
|
||||
Records the output of the graph for users to see.
|
||||
|
||||
Attributes:
|
||||
recorded_value: The value to be recorded as output.
|
||||
name: The name of the output.
|
||||
description: The description of the output.
|
||||
fmt_string: The format string to be used to format the recorded_value.
|
||||
|
||||
Outputs:
|
||||
output: The formatted recorded_value if fmt_string is provided and the recorded_value
|
||||
can be formatted, otherwise the raw recorded_value.
|
||||
|
||||
Behavior:
|
||||
If fmt_string is provided and the recorded_value is of a type that can be formatted,
|
||||
the block attempts to format the recorded_value using the fmt_string.
|
||||
If formatting fails or no fmt_string is provided, the raw recorded_value is output.
|
||||
If `format` is provided and the `value` is of a type that can be formatted,
|
||||
the block attempts to format the recorded_value using the `format`.
|
||||
If formatting fails or no `format` is provided, the raw `value` is output.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
value: Any = SchemaField(description="The value to be recorded as output.")
|
||||
value: Any = SchemaField(
|
||||
description="The value to be recorded as output.",
|
||||
default=None,
|
||||
advanced=False,
|
||||
)
|
||||
name: str = SchemaField(description="The name of the output.")
|
||||
description: str = SchemaField(
|
||||
title: str | None = SchemaField(
|
||||
description="The title of the input.", default=None, advanced=True
|
||||
)
|
||||
description: str | None = SchemaField(
|
||||
description="The description of the output.",
|
||||
default="",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
format: str = SchemaField(
|
||||
@@ -234,6 +245,16 @@ class AgentOutputBlock(Block):
|
||||
default="",
|
||||
advanced=True,
|
||||
)
|
||||
advanced: bool = SchemaField(
|
||||
description="Whether to treat the output as advanced.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
secret: bool = SchemaField(
|
||||
description="Whether the output should be treated as a secret.",
|
||||
default=False,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
output: Any = SchemaField(description="The value recorded as output.")
|
||||
@@ -271,6 +292,7 @@ class AgentOutputBlock(Block):
|
||||
],
|
||||
categories={BlockCategory.OUTPUT, BlockCategory.BASIC},
|
||||
block_type=BlockType.OUTPUT,
|
||||
static_output=True,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
|
||||
@@ -71,11 +71,18 @@ class ConditionBlock(Block):
|
||||
)
|
||||
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
value1 = input_data.value1
|
||||
operator = input_data.operator
|
||||
|
||||
value1 = input_data.value1
|
||||
if isinstance(value1, str):
|
||||
value1 = float(value1.strip())
|
||||
|
||||
value2 = input_data.value2
|
||||
if isinstance(value2, str):
|
||||
value2 = float(value2.strip())
|
||||
|
||||
yes_value = input_data.yes_value if input_data.yes_value is not None else value1
|
||||
no_value = input_data.no_value if input_data.no_value is not None else value1
|
||||
no_value = input_data.no_value if input_data.no_value is not None else value2
|
||||
|
||||
comparison_funcs = {
|
||||
ComparisonOperator.EQUAL: lambda a, b: a == b,
|
||||
@@ -86,17 +93,11 @@ class ConditionBlock(Block):
|
||||
ComparisonOperator.LESS_THAN_OR_EQUAL: lambda a, b: a <= b,
|
||||
}
|
||||
|
||||
try:
|
||||
result = comparison_funcs[operator](value1, value2)
|
||||
result = comparison_funcs[operator](value1, value2)
|
||||
|
||||
yield "result", result
|
||||
yield "result", result
|
||||
|
||||
if result:
|
||||
yield "yes_output", yes_value
|
||||
else:
|
||||
yield "no_output", no_value
|
||||
|
||||
except Exception:
|
||||
yield "result", None
|
||||
yield "yes_output", None
|
||||
yield "no_output", None
|
||||
if result:
|
||||
yield "yes_output", yes_value
|
||||
else:
|
||||
yield "no_output", no_value
|
||||
|
||||
@@ -9,14 +9,11 @@ from prisma.models import (
|
||||
AgentNodeExecution,
|
||||
AgentNodeExecutionInputOutput,
|
||||
)
|
||||
from prisma.types import (
|
||||
AgentGraphExecutionInclude,
|
||||
AgentGraphExecutionWhereInput,
|
||||
AgentNodeExecutionInclude,
|
||||
)
|
||||
from prisma.types import AgentGraphExecutionWhereInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import BlockData, BlockInput, CompletedBlockOutput
|
||||
from backend.data.includes import EXECUTION_RESULT_INCLUDE, GRAPH_EXECUTION_INCLUDE
|
||||
from backend.util import json, mock
|
||||
|
||||
|
||||
@@ -110,24 +107,6 @@ class ExecutionResult(BaseModel):
|
||||
|
||||
# --------------------- Model functions --------------------- #
|
||||
|
||||
EXECUTION_RESULT_INCLUDE: AgentNodeExecutionInclude = {
|
||||
"Input": True,
|
||||
"Output": True,
|
||||
"AgentNode": True,
|
||||
"AgentGraphExecution": True,
|
||||
}
|
||||
|
||||
GRAPH_EXECUTION_INCLUDE: AgentGraphExecutionInclude = {
|
||||
"AgentNodeExecutions": {
|
||||
"include": {
|
||||
"Input": True,
|
||||
"Output": True,
|
||||
"AgentNode": True,
|
||||
"AgentGraphExecution": True,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
async def create_graph_execution(
|
||||
graph_id: str,
|
||||
@@ -268,21 +247,9 @@ async def update_graph_execution_start_time(graph_exec_id: str):
|
||||
|
||||
async def update_graph_execution_stats(
|
||||
graph_exec_id: str,
|
||||
error: Exception | None,
|
||||
wall_time: float,
|
||||
cpu_time: float,
|
||||
node_count: int,
|
||||
stats: dict[str, Any],
|
||||
):
|
||||
status = ExecutionStatus.FAILED if error else ExecutionStatus.COMPLETED
|
||||
stats = (
|
||||
{
|
||||
"walltime": wall_time,
|
||||
"cputime": cpu_time,
|
||||
"nodecount": node_count,
|
||||
"error": str(error) if error else None,
|
||||
},
|
||||
)
|
||||
|
||||
status = ExecutionStatus.FAILED if stats.get("error") else ExecutionStatus.COMPLETED
|
||||
await AgentGraphExecution.prisma().update(
|
||||
where={"id": graph_exec_id},
|
||||
data={
|
||||
|
||||
@@ -1,30 +1,24 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, Type
|
||||
|
||||
import prisma.types
|
||||
from prisma.models import AgentGraph, AgentGraphExecution, AgentNode, AgentNodeLink
|
||||
from prisma.types import AgentGraphInclude
|
||||
from pydantic import BaseModel
|
||||
from pydantic_core import PydanticUndefinedType
|
||||
from prisma.types import AgentGraphWhereInput
|
||||
from pydantic.fields import computed_field
|
||||
|
||||
from backend.blocks.basic import AgentInputBlock, AgentOutputBlock, BlockType
|
||||
from backend.data.block import BlockInput, get_block, get_blocks
|
||||
from backend.blocks.basic import AgentInputBlock, AgentOutputBlock
|
||||
from backend.data.block import BlockInput, BlockType, get_block, get_blocks
|
||||
from backend.data.db import BaseDbModel, transaction
|
||||
from backend.data.execution import ExecutionStatus
|
||||
from backend.data.includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
|
||||
from backend.util import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InputSchemaItem(BaseModel):
|
||||
node_id: str
|
||||
description: str | None = None
|
||||
title: str | None = None
|
||||
|
||||
|
||||
class Link(BaseDbModel):
|
||||
source_id: str
|
||||
sink_id: str
|
||||
@@ -69,7 +63,7 @@ class Node(BaseDbModel):
|
||||
return obj
|
||||
|
||||
|
||||
class ExecutionMeta(BaseDbModel):
|
||||
class GraphExecution(BaseDbModel):
|
||||
execution_id: str
|
||||
started_at: datetime
|
||||
ended_at: datetime
|
||||
@@ -78,20 +72,19 @@ class ExecutionMeta(BaseDbModel):
|
||||
status: ExecutionStatus
|
||||
|
||||
@staticmethod
|
||||
def from_agent_graph_execution(execution: AgentGraphExecution):
|
||||
def from_db(execution: AgentGraphExecution):
|
||||
now = datetime.now(timezone.utc)
|
||||
start_time = execution.startedAt or execution.createdAt
|
||||
end_time = execution.updatedAt or now
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
total_run_time = duration
|
||||
|
||||
total_run_time = 0
|
||||
if execution.AgentNodeExecutions:
|
||||
for node_execution in execution.AgentNodeExecutions:
|
||||
node_start = node_execution.startedTime or now
|
||||
node_end = node_execution.endedTime or now
|
||||
total_run_time += (node_end - node_start).total_seconds()
|
||||
if execution.stats:
|
||||
stats = json.loads(execution.stats)
|
||||
duration = stats.get("walltime", duration)
|
||||
total_run_time = stats.get("nodes_walltime", total_run_time)
|
||||
|
||||
return ExecutionMeta(
|
||||
return GraphExecution(
|
||||
id=execution.id,
|
||||
execution_id=execution.id,
|
||||
started_at=start_time,
|
||||
@@ -102,39 +95,70 @@ class ExecutionMeta(BaseDbModel):
|
||||
)
|
||||
|
||||
|
||||
class GraphMeta(BaseDbModel):
|
||||
class Graph(BaseDbModel):
|
||||
version: int = 1
|
||||
is_active: bool = True
|
||||
is_template: bool = False
|
||||
name: str
|
||||
description: str
|
||||
executions: list[ExecutionMeta] | None = None
|
||||
executions: list[GraphExecution] = []
|
||||
nodes: list[Node] = []
|
||||
links: list[Link] = []
|
||||
|
||||
@staticmethod
|
||||
def from_db(graph: AgentGraph):
|
||||
if graph.AgentGraphExecution:
|
||||
executions = [
|
||||
ExecutionMeta.from_agent_graph_execution(execution)
|
||||
for execution in graph.AgentGraphExecution
|
||||
]
|
||||
else:
|
||||
executions = None
|
||||
def _generate_schema(
|
||||
type_class: Type[AgentInputBlock.Input] | Type[AgentOutputBlock.Input],
|
||||
data: list[dict],
|
||||
) -> dict[str, Any]:
|
||||
props = []
|
||||
for p in data:
|
||||
try:
|
||||
props.append(type_class(**p))
|
||||
except Exception as e:
|
||||
logger.warning(f"Invalid {type_class}: {p}, {e}")
|
||||
|
||||
return GraphMeta(
|
||||
id=graph.id,
|
||||
version=graph.version,
|
||||
is_active=graph.isActive,
|
||||
is_template=graph.isTemplate,
|
||||
name=graph.name or "",
|
||||
description=graph.description or "",
|
||||
executions=executions,
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
p.name: {
|
||||
"secret": p.secret,
|
||||
"advanced": p.advanced,
|
||||
"title": p.title or p.name,
|
||||
**({"description": p.description} if p.description else {}),
|
||||
**({"default": p.value} if p.value is not None else {}),
|
||||
}
|
||||
for p in props
|
||||
},
|
||||
"required": [p.name for p in props if p.value is None],
|
||||
}
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def input_schema(self) -> dict[str, Any]:
|
||||
return self._generate_schema(
|
||||
AgentInputBlock.Input,
|
||||
[
|
||||
node.input_default
|
||||
for node in self.nodes
|
||||
if (b := get_block(node.block_id))
|
||||
and b.block_type == BlockType.INPUT
|
||||
and "name" in node.input_default
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class Graph(GraphMeta):
|
||||
nodes: list[Node]
|
||||
links: list[Link]
|
||||
subgraphs: dict[str, list[str]] = {} # subgraph_id -> [node_id]
|
||||
@computed_field
|
||||
@property
|
||||
def output_schema(self) -> dict[str, Any]:
|
||||
return self._generate_schema(
|
||||
AgentOutputBlock.Input,
|
||||
[
|
||||
node.input_default
|
||||
for node in self.nodes
|
||||
if (b := get_block(node.block_id))
|
||||
and b.block_type == BlockType.OUTPUT
|
||||
and "name" in node.input_default
|
||||
],
|
||||
)
|
||||
|
||||
@property
|
||||
def starting_nodes(self) -> list[Node]:
|
||||
@@ -142,7 +166,7 @@ class Graph(GraphMeta):
|
||||
input_nodes = {
|
||||
v.id
|
||||
for v in self.nodes
|
||||
if isinstance(get_block(v.block_id), AgentInputBlock)
|
||||
if (b := get_block(v.block_id)) and b.block_type == BlockType.INPUT
|
||||
}
|
||||
return [
|
||||
node
|
||||
@@ -150,28 +174,6 @@ class Graph(GraphMeta):
|
||||
if node.id not in outbound_nodes or node.id in input_nodes
|
||||
]
|
||||
|
||||
@property
|
||||
def ending_nodes(self) -> list[Node]:
|
||||
return [
|
||||
v for v in self.nodes if isinstance(get_block(v.block_id), AgentOutputBlock)
|
||||
]
|
||||
|
||||
@property
|
||||
def subgraph_map(self) -> dict[str, str]:
|
||||
"""
|
||||
Returns a mapping of node_id to subgraph_id.
|
||||
A node in the main graph will be mapped to the graph's id.
|
||||
"""
|
||||
subgraph_map = {
|
||||
node_id: subgraph_id
|
||||
for subgraph_id, node_ids in self.subgraphs.items()
|
||||
for node_id in node_ids
|
||||
}
|
||||
subgraph_map.update(
|
||||
{node.id: self.id for node in self.nodes if node.id not in subgraph_map}
|
||||
)
|
||||
return subgraph_map
|
||||
|
||||
def reassign_ids(self, reassign_graph_id: bool = False):
|
||||
"""
|
||||
Reassigns all IDs in the graph to new UUIDs.
|
||||
@@ -179,11 +181,7 @@ class Graph(GraphMeta):
|
||||
"""
|
||||
self.validate_graph()
|
||||
|
||||
id_map = {
|
||||
**{node.id: str(uuid.uuid4()) for node in self.nodes},
|
||||
**{subgraph_id: str(uuid.uuid4()) for subgraph_id in self.subgraphs},
|
||||
}
|
||||
|
||||
id_map = {node.id: str(uuid.uuid4()) for node in self.nodes}
|
||||
if reassign_graph_id:
|
||||
self.id = str(uuid.uuid4())
|
||||
|
||||
@@ -194,16 +192,15 @@ class Graph(GraphMeta):
|
||||
link.source_id = id_map[link.source_id]
|
||||
link.sink_id = id_map[link.sink_id]
|
||||
|
||||
self.subgraphs = {
|
||||
id_map[subgraph_id]: [id_map[node_id] for node_id in node_ids]
|
||||
for subgraph_id, node_ids in self.subgraphs.items()
|
||||
}
|
||||
|
||||
def validate_graph(self, for_run: bool = False):
|
||||
def sanitize(name):
|
||||
return name.split("_#_")[0].split("_@_")[0].split("_$_")[0]
|
||||
|
||||
# Nodes: required fields are filled or connected, except for InputBlock.
|
||||
input_links = defaultdict(list)
|
||||
for link in self.links:
|
||||
input_links[link.sink_id].append(link)
|
||||
|
||||
# Nodes: required fields are filled or connected
|
||||
for node in self.nodes:
|
||||
block = get_block(node.block_id)
|
||||
if block is None:
|
||||
@@ -211,7 +208,7 @@ class Graph(GraphMeta):
|
||||
|
||||
provided_inputs = set(
|
||||
[sanitize(name) for name in node.input_default]
|
||||
+ [sanitize(link.sink_name) for link in node.input_links]
|
||||
+ [sanitize(link.sink_name) for link in input_links.get(node.id, [])]
|
||||
)
|
||||
for name in block.input_schema.get_required_fields():
|
||||
if name not in provided_inputs and (
|
||||
@@ -222,6 +219,7 @@ class Graph(GraphMeta):
|
||||
raise ValueError(
|
||||
f"Node {block.name} #{node.id} required input missing: `{name}`"
|
||||
)
|
||||
|
||||
node_map = {v.id: v for v in self.nodes}
|
||||
|
||||
def is_static_output_block(nid: str) -> bool:
|
||||
@@ -229,18 +227,6 @@ class Graph(GraphMeta):
|
||||
b = get_block(bid)
|
||||
return b.static_output if b else False
|
||||
|
||||
def is_input_output_block(nid: str) -> bool:
|
||||
bid = node_map[nid].block_id
|
||||
b = get_block(bid)
|
||||
return isinstance(b, AgentInputBlock) or isinstance(b, AgentOutputBlock)
|
||||
|
||||
# subgraphs: all nodes in subgraph must be present in the graph.
|
||||
for subgraph_id, node_ids in self.subgraphs.items():
|
||||
for node_id in node_ids:
|
||||
if node_id not in node_map:
|
||||
raise ValueError(f"Subgraph {subgraph_id}'s node {node_id} invalid")
|
||||
subgraph_map = self.subgraph_map
|
||||
|
||||
# Links: links are connected and the connected pin data type are compatible.
|
||||
for link in self.links:
|
||||
source = (link.source_id, link.source_name)
|
||||
@@ -269,66 +255,27 @@ class Graph(GraphMeta):
|
||||
if sanitized_name not in fields:
|
||||
raise ValueError(f"{suffix}, `{name}` invalid, {fields}")
|
||||
|
||||
if (
|
||||
subgraph_map.get(link.source_id) != subgraph_map.get(link.sink_id)
|
||||
and not is_input_output_block(link.source_id)
|
||||
and not is_input_output_block(link.sink_id)
|
||||
):
|
||||
raise ValueError(f"{suffix}, Connecting nodes from different subgraph.")
|
||||
|
||||
if is_static_output_block(link.source_id):
|
||||
link.is_static = True # Each value block output should be static.
|
||||
|
||||
# TODO: Add type compatibility check here.
|
||||
|
||||
def get_input_schema(self) -> list[InputSchemaItem]:
|
||||
"""
|
||||
Walks the graph and returns all the inputs that are either not:
|
||||
- static
|
||||
- provided by parent node
|
||||
"""
|
||||
input_schema = []
|
||||
for node in self.nodes:
|
||||
block = get_block(node.block_id)
|
||||
if not block:
|
||||
continue
|
||||
|
||||
for input_name, input_schema_item in (
|
||||
block.input_schema.jsonschema().get("properties", {}).items()
|
||||
):
|
||||
# Check if the input is not static and not provided by a parent node
|
||||
if (
|
||||
input_name not in node.input_default
|
||||
and not any(
|
||||
link.sink_name == input_name for link in node.input_links
|
||||
)
|
||||
and isinstance(
|
||||
block.input_schema.model_fields.get(input_name).default,
|
||||
PydanticUndefinedType,
|
||||
)
|
||||
):
|
||||
input_schema.append(
|
||||
InputSchemaItem(
|
||||
node_id=node.id,
|
||||
description=input_schema_item.get("description"),
|
||||
title=input_schema_item.get("title"),
|
||||
)
|
||||
)
|
||||
|
||||
return input_schema
|
||||
|
||||
@staticmethod
|
||||
def from_db(graph: AgentGraph, hide_credentials: bool = False):
|
||||
nodes = [
|
||||
*(graph.AgentNodes or []),
|
||||
*(
|
||||
node
|
||||
for subgraph in graph.AgentSubGraphs or []
|
||||
for node in subgraph.AgentNodes or []
|
||||
),
|
||||
executions = [
|
||||
GraphExecution.from_db(execution)
|
||||
for execution in graph.AgentGraphExecution or []
|
||||
]
|
||||
nodes = graph.AgentNodes or []
|
||||
|
||||
return Graph(
|
||||
**GraphMeta.from_db(graph).model_dump(),
|
||||
id=graph.id,
|
||||
version=graph.version,
|
||||
is_active=graph.isActive,
|
||||
is_template=graph.isTemplate,
|
||||
name=graph.name or "",
|
||||
description=graph.description or "",
|
||||
executions=executions,
|
||||
nodes=[Graph._process_node(node, hide_credentials) for node in nodes],
|
||||
links=list(
|
||||
{
|
||||
@@ -337,10 +284,6 @@ class Graph(GraphMeta):
|
||||
for link in (node.Input or []) + (node.Output or [])
|
||||
}
|
||||
),
|
||||
subgraphs={
|
||||
subgraph.id: [node.id for node in subgraph.AgentNodes or []]
|
||||
for subgraph in graph.AgentSubGraphs or []
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -369,20 +312,6 @@ class Graph(GraphMeta):
|
||||
return result
|
||||
|
||||
|
||||
AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
|
||||
"Input": True,
|
||||
"Output": True,
|
||||
"AgentBlock": True,
|
||||
}
|
||||
|
||||
__SUBGRAPH_INCLUDE = {"AgentNodes": {"include": AGENT_NODE_INCLUDE}}
|
||||
|
||||
AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = {
|
||||
**__SUBGRAPH_INCLUDE,
|
||||
"AgentSubGraphs": {"include": __SUBGRAPH_INCLUDE}, # type: ignore
|
||||
}
|
||||
|
||||
|
||||
# --------------------- Model functions --------------------- #
|
||||
|
||||
|
||||
@@ -394,11 +323,11 @@ async def get_node(node_id: str) -> Node:
|
||||
return Node.from_db(node)
|
||||
|
||||
|
||||
async def get_graphs_meta(
|
||||
async def get_graphs(
|
||||
user_id: str,
|
||||
include_executions: bool = False,
|
||||
filter_by: Literal["active", "template"] | None = "active",
|
||||
) -> list[GraphMeta]:
|
||||
) -> list[Graph]:
|
||||
"""
|
||||
Retrieves graph metadata objects.
|
||||
Default behaviour is to get all currently active graphs.
|
||||
@@ -409,9 +338,9 @@ async def get_graphs_meta(
|
||||
user_id: The ID of the user that owns the graph.
|
||||
|
||||
Returns:
|
||||
list[GraphMeta]: A list of objects representing the retrieved graph metadata.
|
||||
list[Graph]: A list of objects representing the retrieved graph metadata.
|
||||
"""
|
||||
where_clause: prisma.types.AgentGraphWhereInput = {}
|
||||
where_clause: AgentGraphWhereInput = {}
|
||||
|
||||
if filter_by == "active":
|
||||
where_clause["isActive"] = True
|
||||
@@ -420,23 +349,17 @@ async def get_graphs_meta(
|
||||
|
||||
where_clause["userId"] = user_id
|
||||
|
||||
graph_include = AGENT_GRAPH_INCLUDE
|
||||
graph_include["AgentGraphExecution"] = include_executions
|
||||
|
||||
graphs = await AgentGraph.prisma().find_many(
|
||||
where=where_clause,
|
||||
distinct=["id"],
|
||||
order={"version": "desc"},
|
||||
include=(
|
||||
AgentGraphInclude(
|
||||
AgentGraphExecution={"include": {"AgentNodeExecutions": True}}
|
||||
)
|
||||
if include_executions
|
||||
else None
|
||||
),
|
||||
include=graph_include,
|
||||
)
|
||||
|
||||
if not graphs:
|
||||
return []
|
||||
|
||||
return [GraphMeta.from_db(graph) for graph in graphs]
|
||||
return [Graph.from_db(graph) for graph in graphs]
|
||||
|
||||
|
||||
async def get_graph(
|
||||
@@ -453,7 +376,7 @@ async def get_graph(
|
||||
|
||||
Returns `None` if the record is not found.
|
||||
"""
|
||||
where_clause: prisma.types.AgentGraphWhereInput = {
|
||||
where_clause: AgentGraphWhereInput = {
|
||||
"id": graph_id,
|
||||
"isTemplate": template,
|
||||
}
|
||||
@@ -462,7 +385,7 @@ async def get_graph(
|
||||
elif not template:
|
||||
where_clause["isActive"] = True
|
||||
|
||||
if user_id and not template:
|
||||
if user_id is not None and not template:
|
||||
where_clause["userId"] = user_id
|
||||
|
||||
graph = await AgentGraph.prisma().find_first(
|
||||
@@ -548,33 +471,13 @@ async def __create_graph(tx, graph: Graph, user_id: str):
|
||||
}
|
||||
)
|
||||
|
||||
await asyncio.gather(
|
||||
*[
|
||||
AgentGraph.prisma(tx).create(
|
||||
data={
|
||||
"id": subgraph_id,
|
||||
"agentGraphParentId": graph.id,
|
||||
"version": graph.version,
|
||||
"name": f"SubGraph of {graph.name}",
|
||||
"description": f"Sub-Graph of {graph.id}",
|
||||
"isTemplate": graph.is_template,
|
||||
"isActive": graph.is_active,
|
||||
"userId": user_id,
|
||||
}
|
||||
)
|
||||
for subgraph_id in graph.subgraphs
|
||||
]
|
||||
)
|
||||
|
||||
subgraph_map = graph.subgraph_map
|
||||
|
||||
await asyncio.gather(
|
||||
*[
|
||||
AgentNode.prisma(tx).create(
|
||||
{
|
||||
"id": node.id,
|
||||
"agentBlockId": node.block_id,
|
||||
"agentGraphId": subgraph_map.get(node.id, graph.id),
|
||||
"agentGraphId": graph.id,
|
||||
"agentGraphVersion": graph.version,
|
||||
"constantInput": json.dumps(node.input_default),
|
||||
"metadata": json.dumps(node.metadata),
|
||||
|
||||
29
autogpt_platform/backend/backend/data/includes.py
Normal file
29
autogpt_platform/backend/backend/data/includes.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import prisma
|
||||
|
||||
AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
|
||||
"Input": True,
|
||||
"Output": True,
|
||||
"AgentBlock": True,
|
||||
}
|
||||
|
||||
AGENT_GRAPH_INCLUDE: prisma.types.AgentGraphInclude = {
|
||||
"AgentNodes": {"include": AGENT_NODE_INCLUDE} # type: ignore
|
||||
}
|
||||
|
||||
EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
|
||||
"Input": True,
|
||||
"Output": True,
|
||||
"AgentNode": True,
|
||||
"AgentGraphExecution": True,
|
||||
}
|
||||
|
||||
GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
|
||||
"AgentNodeExecutions": {
|
||||
"include": {
|
||||
"Input": True,
|
||||
"Output": True,
|
||||
"AgentNode": True,
|
||||
"AgentGraphExecution": True,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -36,7 +36,12 @@ from backend.util import json
|
||||
from backend.util.decorator import error_logged, time_measured
|
||||
from backend.util.logging import configure_logging
|
||||
from backend.util.process import set_service_name
|
||||
from backend.util.service import AppService, expose, get_service_client
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
close_service_client,
|
||||
expose,
|
||||
get_service_client,
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
from backend.util.type import convert
|
||||
|
||||
@@ -452,6 +457,8 @@ class Executor:
|
||||
cls.creds_manager.release_all_locks()
|
||||
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
|
||||
redis.disconnect()
|
||||
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB manager...")
|
||||
close_service_client(cls.db_client)
|
||||
logger.info(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")
|
||||
|
||||
@classmethod
|
||||
@@ -473,7 +480,7 @@ class Executor:
|
||||
cls,
|
||||
q: ExecutionQueue[NodeExecution],
|
||||
node_exec: NodeExecution,
|
||||
):
|
||||
) -> dict[str, Any]:
|
||||
log_metadata = LogMetadata(
|
||||
user_id=node_exec.user_id,
|
||||
graph_eid=node_exec.graph_exec_id,
|
||||
@@ -493,6 +500,7 @@ class Executor:
|
||||
cls.db_client.update_node_execution_stats(
|
||||
node_exec.node_exec_id, execution_stats
|
||||
)
|
||||
return execution_stats
|
||||
|
||||
@classmethod
|
||||
@time_measured
|
||||
@@ -536,6 +544,8 @@ class Executor:
|
||||
prefix = f"[on_graph_executor_stop {cls.pid}]"
|
||||
logger.info(f"{prefix} ⏳ Terminating node executor pool...")
|
||||
cls.executor.terminate()
|
||||
logger.info(f"{prefix} ⏳ Disconnecting DB manager...")
|
||||
close_service_client(cls.db_client)
|
||||
logger.info(f"{prefix} ✅ Finished cleanup")
|
||||
|
||||
@classmethod
|
||||
@@ -556,16 +566,15 @@ class Executor:
|
||||
node_eid="*",
|
||||
block_name="-",
|
||||
)
|
||||
timing_info, (node_count, error) = cls._on_graph_execution(
|
||||
timing_info, (exec_stats, error) = cls._on_graph_execution(
|
||||
graph_exec, cancel, log_metadata
|
||||
)
|
||||
|
||||
exec_stats["walltime"] = timing_info.wall_time
|
||||
exec_stats["cputime"] = timing_info.cpu_time
|
||||
exec_stats["error"] = str(error) if error else None
|
||||
cls.db_client.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
error=error,
|
||||
wall_time=timing_info.wall_time,
|
||||
cpu_time=timing_info.cpu_time,
|
||||
node_count=node_count,
|
||||
stats=exec_stats,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -575,14 +584,18 @@ class Executor:
|
||||
graph_exec: GraphExecution,
|
||||
cancel: threading.Event,
|
||||
log_metadata: LogMetadata,
|
||||
) -> tuple[int, Exception | None]:
|
||||
) -> tuple[dict[str, Any], Exception | None]:
|
||||
"""
|
||||
Returns:
|
||||
The number of node executions completed.
|
||||
The execution statistics of the graph execution.
|
||||
The error that occurred during the execution.
|
||||
"""
|
||||
log_metadata.info(f"Start graph execution {graph_exec.graph_exec_id}")
|
||||
n_node_executions = 0
|
||||
exec_stats = {
|
||||
"nodes_walltime": 0,
|
||||
"nodes_cputime": 0,
|
||||
"node_count": 0,
|
||||
}
|
||||
error = None
|
||||
finished = False
|
||||
|
||||
@@ -608,17 +621,20 @@ class Executor:
|
||||
def make_exec_callback(exec_data: NodeExecution):
|
||||
node_id = exec_data.node_id
|
||||
|
||||
def callback(_):
|
||||
def callback(result: object):
|
||||
running_executions.pop(node_id)
|
||||
nonlocal n_node_executions
|
||||
n_node_executions += 1
|
||||
nonlocal exec_stats
|
||||
if isinstance(result, dict):
|
||||
exec_stats["node_count"] += 1
|
||||
exec_stats["nodes_cputime"] += result.get("cputime", 0)
|
||||
exec_stats["nodes_walltime"] += result.get("walltime", 0)
|
||||
|
||||
return callback
|
||||
|
||||
while not queue.empty():
|
||||
if cancel.is_set():
|
||||
error = RuntimeError("Execution is cancelled")
|
||||
return n_node_executions, error
|
||||
return exec_stats, error
|
||||
|
||||
exec_data = queue.get()
|
||||
|
||||
@@ -649,7 +665,7 @@ class Executor:
|
||||
for node_id, execution in list(running_executions.items()):
|
||||
if cancel.is_set():
|
||||
error = RuntimeError("Execution is cancelled")
|
||||
return n_node_executions, error
|
||||
return exec_stats, error
|
||||
|
||||
if not queue.empty():
|
||||
break # yield to parent loop to execute new queue items
|
||||
@@ -668,7 +684,7 @@ class Executor:
|
||||
finished = True
|
||||
cancel.set()
|
||||
cancel_thread.join()
|
||||
return n_node_executions, error
|
||||
return exec_stats, error
|
||||
|
||||
|
||||
class ExecutionManager(AppService):
|
||||
|
||||
@@ -121,8 +121,8 @@ class DeleteGraphResponse(TypedDict):
|
||||
async def get_graphs(
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
with_runs: bool = False,
|
||||
) -> list[graph_db.GraphMeta]:
|
||||
return await graph_db.get_graphs_meta(
|
||||
) -> list[graph_db.Graph]:
|
||||
return await graph_db.get_graphs(
|
||||
include_executions=with_runs, filter_by="active", user_id=user_id
|
||||
)
|
||||
|
||||
@@ -290,22 +290,6 @@ async def stop_graph_run(
|
||||
return await execution_db.get_execution_results(graph_exec_id)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/input_schema",
|
||||
tags=["graphs"],
|
||||
dependencies=[Depends(auth_middleware)],
|
||||
)
|
||||
async def get_graph_input_schema(
|
||||
graph_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> list[graph_db.InputSchemaItem]:
|
||||
try:
|
||||
graph = await graph_db.get_graph(graph_id, user_id=user_id)
|
||||
return graph.get_input_schema() if graph else []
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
path="/graphs/{graph_id}/executions",
|
||||
tags=["graphs"],
|
||||
@@ -374,8 +358,8 @@ async def get_graph_run_status(
|
||||
)
|
||||
async def get_templates(
|
||||
user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> list[graph_db.GraphMeta]:
|
||||
return await graph_db.get_graphs_meta(filter_by="template", user_id=user_id)
|
||||
) -> list[graph_db.Graph]:
|
||||
return await graph_db.get_graphs(filter_by="template", user_id=user_id)
|
||||
|
||||
|
||||
@v1_router.get(
|
||||
|
||||
@@ -30,6 +30,7 @@ from typing import (
|
||||
import Pyro5.api
|
||||
from pydantic import BaseModel
|
||||
from Pyro5 import api as pyro
|
||||
from Pyro5 import config as pyro_config
|
||||
|
||||
from backend.data import db, redis
|
||||
from backend.util.process import AppProcess
|
||||
@@ -40,7 +41,10 @@ logger = logging.getLogger(__name__)
|
||||
T = TypeVar("T")
|
||||
C = TypeVar("C", bound=Callable)
|
||||
|
||||
pyro_host = Config().pyro_host
|
||||
config = Config()
|
||||
pyro_host = config.pyro_host
|
||||
pyro_config.MAX_RETRIES = config.pyro_client_comm_retry # type: ignore
|
||||
pyro_config.COMMTIMEOUT = config.pyro_client_comm_timeout # type: ignore
|
||||
|
||||
|
||||
def expose(func: C) -> C:
|
||||
@@ -166,8 +170,14 @@ class AppService(AppProcess, ABC):
|
||||
|
||||
@conn_retry("Pyro", "Starting Pyro Service")
|
||||
def __start_pyro(self):
|
||||
host = Config().pyro_host
|
||||
daemon = Pyro5.api.Daemon(host=host, port=self.get_port())
|
||||
conf = Config()
|
||||
maximum_connection_thread_count = max(
|
||||
Pyro5.config.THREADPOOL_SIZE,
|
||||
conf.num_node_workers * conf.num_graph_workers,
|
||||
)
|
||||
|
||||
Pyro5.config.THREADPOOL_SIZE = maximum_connection_thread_count # type: ignore
|
||||
daemon = Pyro5.api.Daemon(host=conf.pyro_host, port=self.get_port())
|
||||
self.uri = daemon.register(self, objectId=self.service_name)
|
||||
logger.info(f"[{self.service_name}] Connected to Pyro; URI = {self.uri}")
|
||||
daemon.requestLoop()
|
||||
@@ -182,10 +192,21 @@ class AppService(AppProcess, ABC):
|
||||
AS = TypeVar("AS", bound=AppService)
|
||||
|
||||
|
||||
class PyroClient:
|
||||
proxy: Pyro5.api.Proxy
|
||||
|
||||
|
||||
def close_service_client(client: AppService) -> None:
|
||||
if isinstance(client, PyroClient):
|
||||
client.proxy._pyroRelease()
|
||||
else:
|
||||
raise RuntimeError(f"Client {client.__class__} is not a Pyro client.")
|
||||
|
||||
|
||||
def get_service_client(service_type: Type[AS]) -> AS:
|
||||
service_name = service_type.service_name
|
||||
|
||||
class DynamicClient:
|
||||
class DynamicClient(PyroClient):
|
||||
@conn_retry("Pyro", f"Connecting to [{service_name}]")
|
||||
def __init__(self):
|
||||
host = os.environ.get(f"{service_name.upper()}_HOST", "localhost")
|
||||
|
||||
@@ -69,6 +69,14 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
default="localhost",
|
||||
description="The default hostname of the Pyro server.",
|
||||
)
|
||||
pyro_client_comm_timeout: float = Field(
|
||||
default=15,
|
||||
description="The default timeout in seconds, for Pyro client connections.",
|
||||
)
|
||||
pyro_client_comm_retry: int = Field(
|
||||
default=3,
|
||||
description="The default number of retries for Pyro client connections.",
|
||||
)
|
||||
enable_auth: bool = Field(
|
||||
default=True,
|
||||
description="If authentication is enabled or not",
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
/*
|
||||
Warnings:
|
||||
|
||||
- You are about to drop the column `agentGraphParentId` on the `AgentGraph` table. All the data in the column will be lost.
|
||||
|
||||
*/
|
||||
-- DropForeignKey
|
||||
ALTER TABLE "AgentGraph" DROP CONSTRAINT "AgentGraph_agentGraphParentId_version_fkey";
|
||||
|
||||
-- AlterTable
|
||||
ALTER TABLE "AgentGraph" DROP COLUMN "agentGraphParentId";
|
||||
@@ -0,0 +1,4 @@
|
||||
-- This migration converts the stats column from a list to an object.
|
||||
UPDATE "AgentGraphExecution"
|
||||
SET "stats" = (stats::jsonb -> 0)::text
|
||||
WHERE stats IS NOT NULL AND jsonb_typeof(stats::jsonb) = 'array';
|
||||
@@ -52,11 +52,6 @@ model AgentGraph {
|
||||
AgentGraphExecution AgentGraphExecution[]
|
||||
AgentGraphExecutionSchedule AgentGraphExecutionSchedule[]
|
||||
|
||||
// All sub-graphs are defined within this 1-level depth list (even if it's a nested graph).
|
||||
AgentSubGraphs AgentGraph[] @relation("AgentSubGraph")
|
||||
agentGraphParentId String?
|
||||
AgentGraphParent AgentGraph? @relation("AgentSubGraph", fields: [agentGraphParentId, version], references: [id, version], onDelete: Cascade)
|
||||
|
||||
@@id(name: "graphVersionId", [id, version])
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.basic import AgentInputBlock, StoreValueBlock
|
||||
from backend.blocks.basic import AgentInputBlock, AgentOutputBlock, StoreValueBlock
|
||||
from backend.data.block import BlockSchema
|
||||
from backend.data.graph import Graph, Link, Node
|
||||
from backend.data.model import SchemaField
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.util.test import SpinTestServer
|
||||
@@ -15,9 +18,8 @@ async def test_graph_creation(server: SpinTestServer):
|
||||
Test the creation of a graph with nodes and links.
|
||||
|
||||
This test ensures that:
|
||||
1. Nodes from different subgraphs cannot be directly connected.
|
||||
2. A graph can be successfully created with valid connections.
|
||||
3. The created graph has the correct structure and properties.
|
||||
1. A graph can be successfully created with valid connections.
|
||||
2. The created graph has the correct structure and properties.
|
||||
|
||||
Args:
|
||||
server (SpinTestServer): The test server instance.
|
||||
@@ -37,23 +39,13 @@ async def test_graph_creation(server: SpinTestServer):
|
||||
links=[
|
||||
Link(
|
||||
source_id="node_1",
|
||||
sink_id="node_3",
|
||||
sink_id="node_2",
|
||||
source_name="output",
|
||||
sink_name="input",
|
||||
sink_name="name",
|
||||
),
|
||||
],
|
||||
subgraphs={"subgraph_1": ["node_2", "node_3"]},
|
||||
)
|
||||
create_graph = CreateGraph(graph=graph)
|
||||
|
||||
try:
|
||||
await server.agent_server.test_create_graph(create_graph, DEFAULT_USER_ID)
|
||||
assert False, "Should not be able to connect nodes from different subgraphs"
|
||||
except ValueError as e:
|
||||
assert "different subgraph" in str(e)
|
||||
|
||||
# Change node_1 <-> node_3 link to node_1 <-> node_2 (input for subgraph_1)
|
||||
graph.links[0].sink_id = "node_2"
|
||||
created_graph = await server.agent_server.test_create_graph(
|
||||
create_graph, DEFAULT_USER_ID
|
||||
)
|
||||
@@ -73,9 +65,6 @@ async def test_graph_creation(server: SpinTestServer):
|
||||
assert links[0].source_id in {nodes[0].id, nodes[1].id, nodes[2].id}
|
||||
assert links[0].sink_id in {nodes[0].id, nodes[1].id, nodes[2].id}
|
||||
|
||||
assert len(created_graph.subgraphs) == 1
|
||||
assert len(created_graph.subgraph_map) == len(created_graph.nodes) == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_get_input_schema(server: SpinTestServer):
|
||||
@@ -91,90 +80,54 @@ async def test_get_input_schema(server: SpinTestServer):
|
||||
server (SpinTestServer): The test server instance.
|
||||
"""
|
||||
value_block = StoreValueBlock().id
|
||||
input_block = AgentInputBlock().id
|
||||
output_block = AgentOutputBlock().id
|
||||
|
||||
graph = Graph(
|
||||
name="TestInputSchema",
|
||||
description="Test input schema",
|
||||
nodes=[
|
||||
Node(id="node_1", block_id=value_block),
|
||||
],
|
||||
links=[],
|
||||
)
|
||||
|
||||
create_graph = CreateGraph(graph=graph)
|
||||
created_graph = await server.agent_server.test_create_graph(
|
||||
create_graph, DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
input_schema = created_graph.get_input_schema()
|
||||
|
||||
assert len(input_schema) == 1
|
||||
|
||||
assert input_schema[0].title == "Input"
|
||||
assert input_schema[0].node_id == created_graph.nodes[0].id
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_get_input_schema_none_required(server: SpinTestServer):
|
||||
"""
|
||||
Test the get_input_schema method when no inputs are required.
|
||||
|
||||
This test ensures that:
|
||||
1. A graph can be created with a node that has a default input value.
|
||||
2. The input schema of the created graph is empty when all inputs have default values.
|
||||
|
||||
Args:
|
||||
server (SpinTestServer): The test server instance.
|
||||
"""
|
||||
value_block = StoreValueBlock().id
|
||||
|
||||
graph = Graph(
|
||||
name="TestInputSchema",
|
||||
description="Test input schema",
|
||||
nodes=[
|
||||
Node(id="node_1", block_id=value_block, input_default={"input": "value"}),
|
||||
],
|
||||
links=[],
|
||||
)
|
||||
|
||||
create_graph = CreateGraph(graph=graph)
|
||||
created_graph = await server.agent_server.test_create_graph(
|
||||
create_graph, DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
input_schema = created_graph.get_input_schema()
|
||||
|
||||
assert input_schema == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_get_input_schema_with_linked_blocks(server: SpinTestServer):
|
||||
"""
|
||||
Test the get_input_schema method with linked blocks.
|
||||
|
||||
This test ensures that:
|
||||
1. A graph can be created with multiple nodes and links between them.
|
||||
2. The input schema correctly identifies required inputs for linked blocks.
|
||||
3. Inputs that are satisfied by links are not included in the input schema.
|
||||
|
||||
Args:
|
||||
server (SpinTestServer): The test server instance.
|
||||
"""
|
||||
value_block = StoreValueBlock().id
|
||||
|
||||
graph = Graph(
|
||||
name="TestInputSchemaLinkedBlocks",
|
||||
description="Test input schema with linked blocks",
|
||||
nodes=[
|
||||
Node(id="node_1", block_id=value_block),
|
||||
Node(id="node_2", block_id=value_block),
|
||||
Node(
|
||||
id="node_0_a",
|
||||
block_id=input_block,
|
||||
input_default={"name": "in_key_a", "title": "Key A", "value": "A"},
|
||||
metadata={"id": "node_0_a"},
|
||||
),
|
||||
Node(
|
||||
id="node_0_b",
|
||||
block_id=input_block,
|
||||
input_default={"name": "in_key_b", "advanced": True},
|
||||
metadata={"id": "node_0_b"},
|
||||
),
|
||||
Node(id="node_1", block_id=value_block, metadata={"id": "node_1"}),
|
||||
Node(
|
||||
id="node_2",
|
||||
block_id=output_block,
|
||||
input_default={
|
||||
"name": "out_key",
|
||||
"description": "This is an output key",
|
||||
},
|
||||
metadata={"id": "node_2"},
|
||||
),
|
||||
],
|
||||
links=[
|
||||
Link(
|
||||
source_id="node_0_a",
|
||||
sink_id="node_1",
|
||||
source_name="result",
|
||||
sink_name="input",
|
||||
),
|
||||
Link(
|
||||
source_id="node_0_b",
|
||||
sink_id="node_1",
|
||||
source_name="result",
|
||||
sink_name="input",
|
||||
),
|
||||
Link(
|
||||
source_id="node_1",
|
||||
sink_id="node_2",
|
||||
source_name="output",
|
||||
sink_name="data",
|
||||
sink_name="value",
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -184,25 +137,21 @@ async def test_get_input_schema_with_linked_blocks(server: SpinTestServer):
|
||||
create_graph, DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
input_schema = created_graph.get_input_schema()
|
||||
class ExpectedInputSchema(BlockSchema):
|
||||
in_key_a: Any = SchemaField(title="Key A", default="A", advanced=False)
|
||||
in_key_b: Any = SchemaField(title="in_key_b", advanced=True)
|
||||
|
||||
assert len(input_schema) == 2
|
||||
class ExpectedOutputSchema(BlockSchema):
|
||||
out_key: Any = SchemaField(
|
||||
description="This is an output key",
|
||||
title="out_key",
|
||||
advanced=False,
|
||||
)
|
||||
|
||||
node_1_input = next(
|
||||
(item for item in input_schema if item.node_id == created_graph.nodes[0].id),
|
||||
None,
|
||||
)
|
||||
node_2_input = next(
|
||||
(item for item in input_schema if item.node_id == created_graph.nodes[1].id),
|
||||
None,
|
||||
)
|
||||
input_schema = created_graph.input_schema
|
||||
input_schema["title"] = "ExpectedInputSchema"
|
||||
assert input_schema == ExpectedInputSchema.jsonschema()
|
||||
|
||||
assert node_1_input is not None
|
||||
assert node_2_input is not None
|
||||
assert node_1_input.title == "Input"
|
||||
assert node_2_input.title == "Input"
|
||||
|
||||
assert not any(
|
||||
item.title == "data" and item.node_id == created_graph.nodes[1].id
|
||||
for item in input_schema
|
||||
)
|
||||
output_schema = created_graph.output_schema
|
||||
output_schema["title"] = "ExpectedOutputSchema"
|
||||
assert output_schema == ExpectedOutputSchema.jsonschema()
|
||||
|
||||
@@ -626,10 +626,13 @@ export function CustomNode({
|
||||
|
||||
<div className="flex w-full flex-col">
|
||||
<div className="flex flex-row items-center justify-between">
|
||||
<div className="font-roboto px-3 text-lg font-semibold">
|
||||
<div className="font-roboto flex items-center px-3 text-lg font-semibold">
|
||||
{beautifyString(
|
||||
data.blockType?.replace(/Block$/, "") || data.title,
|
||||
)}
|
||||
<div className="px-2 text-xs text-gray-500">
|
||||
#{id.split("-")[0]}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{blockCost && (
|
||||
|
||||
@@ -18,7 +18,7 @@ export async function NavBar() {
|
||||
const { user } = await getServerUser();
|
||||
|
||||
return (
|
||||
<header className="sticky top-0 z-50 mx-4 flex h-16 items-center gap-4 border border-gray-300 bg-background p-3 md:rounded-b-2xl md:px-6 md:shadow">
|
||||
<header className="sticky top-0 z-50 mx-4 flex h-16 select-none items-center gap-4 border border-gray-300 bg-background p-3 md:rounded-b-2xl md:px-6 md:shadow">
|
||||
<div className="flex flex-1 items-center gap-4">
|
||||
<Sheet>
|
||||
<SheetTrigger asChild>
|
||||
|
||||
@@ -32,7 +32,7 @@ const PrimaryActionBar: React.FC<PrimaryActionBarProps> = ({
|
||||
const runButtonOnClick = !isRunning ? onClickRunAgent : requestStopRun;
|
||||
|
||||
return (
|
||||
<div className="absolute bottom-0 left-1/2 z-50 flex w-fit -translate-x-1/2 transform items-center justify-center p-4">
|
||||
<div className="absolute bottom-0 left-1/2 z-50 flex w-fit -translate-x-1/2 transform select-none items-center justify-center p-4">
|
||||
<div className={`flex gap-4`}>
|
||||
<Tooltip key="ViewOutputs" delayDuration={500}>
|
||||
<TooltipTrigger asChild>
|
||||
|
||||
@@ -48,7 +48,7 @@ const TallyPopupSimple = () => {
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="fixed bottom-1 right-6 z-50 hidden items-center gap-4 p-3 transition-all duration-300 ease-in-out md:flex">
|
||||
<div className="fixed bottom-1 right-6 z-50 hidden select-none items-center gap-4 p-3 transition-all duration-300 ease-in-out md:flex">
|
||||
<Button
|
||||
variant="default"
|
||||
onClick={resetTutorial}
|
||||
|
||||
@@ -352,9 +352,14 @@ const NodeKeyValueInput: FC<{
|
||||
const defaultEntries = new Map(
|
||||
Object.entries(entries ?? schema.default ?? {}),
|
||||
);
|
||||
const prefix = `${selfKey}_#_`;
|
||||
connections
|
||||
.filter((c) => c.targetHandle.startsWith(prefix) && c.target === nodeId)
|
||||
.map((c) => c.targetHandle.slice(prefix.length))
|
||||
.forEach((k) => !defaultEntries.has(k) && defaultEntries.set(k, ""));
|
||||
|
||||
return Array.from(defaultEntries, ([key, value]) => ({ key, value }));
|
||||
}, [entries, schema.default]);
|
||||
}, [entries, schema.default, connections, nodeId, selfKey]);
|
||||
|
||||
const [keyValuePairs, setKeyValuePairs] = useState<
|
||||
{ key: string; value: string | number | null }[]
|
||||
@@ -497,6 +502,18 @@ const NodeArrayInput: FC<{
|
||||
displayName,
|
||||
}) => {
|
||||
entries ??= schema.default ?? [];
|
||||
|
||||
const prefix = `${selfKey}_$_`;
|
||||
connections
|
||||
.filter((c) => c.targetHandle.startsWith(prefix) && c.target === nodeId)
|
||||
.map((c) => parseInt(c.targetHandle.slice(prefix.length)))
|
||||
.filter((c) => !isNaN(c))
|
||||
.forEach(
|
||||
(c) =>
|
||||
entries.length <= c &&
|
||||
entries.push(...Array(c - entries.length + 1).fill("")),
|
||||
);
|
||||
|
||||
const isItemObject = "items" in schema && "properties" in schema.items!;
|
||||
const error =
|
||||
typeof errors[selfKey] === "string" ? errors[selfKey] : undefined;
|
||||
|
||||
@@ -487,7 +487,16 @@ export default function useAgentGraph(
|
||||
},
|
||||
);
|
||||
})
|
||||
.catch(() => setSaveRunRequest({ request: "run", state: "error" }));
|
||||
.catch((error) => {
|
||||
const errorMessage =
|
||||
error instanceof Error ? error.message : String(error);
|
||||
toast({
|
||||
variant: "destructive",
|
||||
title: "Error saving agent",
|
||||
description: errorMessage,
|
||||
});
|
||||
setSaveRunRequest({ request: "run", state: "error" });
|
||||
});
|
||||
|
||||
processedUpdates.current = processedUpdates.current = [];
|
||||
}
|
||||
@@ -510,6 +519,7 @@ export default function useAgentGraph(
|
||||
}
|
||||
}, [
|
||||
api,
|
||||
toast,
|
||||
saveRunRequest,
|
||||
savedAgent,
|
||||
nodesSyncedWithSavedAgent,
|
||||
@@ -590,8 +600,7 @@ export default function useAgentGraph(
|
||||
[availableNodes],
|
||||
);
|
||||
|
||||
const _saveAgent = (
|
||||
() =>
|
||||
const _saveAgent = useCallback(
|
||||
async (asTemplate: boolean = false) => {
|
||||
//FIXME frontend ids should be resolved better (e.g. returned from the server)
|
||||
// currently this relays on block_id and position
|
||||
@@ -745,8 +754,20 @@ export default function useAgentGraph(
|
||||
},
|
||||
}));
|
||||
});
|
||||
}
|
||||
)();
|
||||
},
|
||||
[
|
||||
api,
|
||||
nodes,
|
||||
edges,
|
||||
pathname,
|
||||
router,
|
||||
searchParams,
|
||||
savedAgent,
|
||||
agentName,
|
||||
agentDescription,
|
||||
prepareNodeInputData,
|
||||
],
|
||||
);
|
||||
|
||||
const saveAgent = useCallback(
|
||||
async (asTemplate: boolean = false) => {
|
||||
@@ -763,18 +784,7 @@ export default function useAgentGraph(
|
||||
});
|
||||
}
|
||||
},
|
||||
[
|
||||
api,
|
||||
nodes,
|
||||
edges,
|
||||
pathname,
|
||||
router,
|
||||
searchParams,
|
||||
savedAgent,
|
||||
agentName,
|
||||
agentDescription,
|
||||
prepareNodeInputData,
|
||||
],
|
||||
[_saveAgent, toast],
|
||||
);
|
||||
|
||||
const requestSave = useCallback(
|
||||
|
||||
@@ -111,6 +111,8 @@ env:
|
||||
AGENTSERVER_HOST: "autogpt-server.prod-agpt.svc.cluster.local"
|
||||
EXECUTIONMANAGER_HOST: "autogpt-server-executor.prod-agpt.svc.cluster.local"
|
||||
DATABASEMANAGER_HOST: "autogpt-server-executor.prod-agpt.svc.cluster.local"
|
||||
PYRO_CLIENT_COMM_TIMEOUT: 15
|
||||
PYRO_CLIENT_COMM_RETRY: 3
|
||||
|
||||
secrets:
|
||||
REPLICATE_API_KEY: "AgCPCgcYb+tE8/k45Y7/my4G2jWPCuEMTXJIn1fG1q4x4ZJPFzb43m7Uqtwn23NkmUZ5Qvh8BXedrtHwxapuYzw/P6c7xK66xfLKRbTWtYk4twS3sxPb+pt1FXY4USEjj5yeIFduybkqhE2QfnGoyrbDZ4Bz3AIgnrRD0Ee5m9u5yNZTPmJqZZqg4MRdUBCxCWIJBkW6DCE9nCPAQeNPD6e+lZ1j+/LocT2HX/ZlcsPXCxbn6wkxoyLqA0vUKSG9azS6oLvn0/3Cb01ozG8S2OEAqWIImFqhKGMfGqL6jSZWln43cmQdMTzSzM+HiprA9JHjZqGK7wOV9HZvSR+58IXoJGPBEIM7jIg5KqPjpZY4KFZBp5OiiRRYu+nCbuD+KsY/7ogjPHjbi1rpR8TrtXdzWNmwsTTmjytB/KEqeUpLWOEPgArFPyrNTS5/nmREH7r9jNEhfIRdTlS3IVGGXp/VN8napbNND1GDyzowvF771neq7/zTmfCRCJ4J0gwPNKM5rzOuRW+caEf2qOFBKIldVa/J0PFg5bAgpGL6jhpXHj0Q/+j1s3FA/D2ZebZTPIpKe40It3sWsS/0Qjhbj1GMbL4yUWvGpBSUTk7kZazkaVND1LbhjC+4AolTQdIU4MgW0bkmDn5ZI4a9/dHyLS3lFeYNSQ6vnbz+Id7zB3O0D6/FH8nfAUGL8V+J3eFKMp+G67z+XYH6WGABaNicz41zFBDF5hRax+k/ZziPPlFY0kDc3cAB6pLc"
|
||||
|
||||
Reference in New Issue
Block a user