Merge branch 'add-iffy-moderation' of https://github.com/Significant-Gravitas/AutoGPT into add-iffy-moderation

This commit is contained in:
Bentlybro
2025-03-28 17:05:50 +00:00
32 changed files with 613 additions and 338 deletions

View File

@@ -550,3 +550,17 @@ class AgentToggleInputBlock(AgentInputBlock):
("result", False),
],
)
IO_BLOCK_IDs = [
AgentInputBlock().id,
AgentOutputBlock().id,
AgentShortTextInputBlock().id,
AgentLongTextInputBlock().id,
AgentNumberInputBlock().id,
AgentDateInputBlock().id,
AgentTimeInputBlock().id,
AgentFileInputBlock().id,
AgentDropdownInputBlock().id,
AgentToggleInputBlock().id,
]

View File

@@ -12,6 +12,7 @@ from typing import (
Literal,
Optional,
TypeVar,
overload,
)
from prisma import Json
@@ -36,7 +37,11 @@ from backend.util.settings import Config
from .block import BlockData, BlockInput, BlockType, CompletedBlockOutput, get_block
from .db import BaseDbModel
from .includes import EXECUTION_RESULT_INCLUDE, GRAPH_EXECUTION_INCLUDE
from .includes import (
EXECUTION_RESULT_INCLUDE,
GRAPH_EXECUTION_INCLUDE,
GRAPH_EXECUTION_INCLUDE_WITH_NODES,
)
from .model import GraphExecutionStats, NodeExecutionStats
from .queue import AsyncRedisEventBus, RedisEventBus
@@ -103,7 +108,6 @@ class GraphExecutionMeta(BaseDbModel):
class GraphExecution(GraphExecutionMeta):
inputs: BlockInput
outputs: CompletedBlockOutput
node_executions: list["NodeExecutionResult"]
@staticmethod
def from_db(_graph_exec: AgentGraphExecution):
@@ -158,6 +162,32 @@ class GraphExecution(GraphExecutionMeta):
},
inputs=inputs,
outputs=outputs,
)
class GraphExecutionWithNodes(GraphExecution):
node_executions: list["NodeExecutionResult"]
@staticmethod
def from_db(_graph_exec: AgentGraphExecution):
if _graph_exec.AgentNodeExecutions is None:
raise ValueError("Node executions must be included in query")
graph_exec_with_io = GraphExecution.from_db(_graph_exec)
node_executions = sorted(
[
NodeExecutionResult.from_db(ne, _graph_exec.userId)
for ne in _graph_exec.AgentNodeExecutions
],
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
)
return GraphExecutionWithNodes(
**{
field_name: getattr(graph_exec_with_io, field_name)
for field_name in graph_exec_with_io.model_fields
},
node_executions=node_executions,
)
@@ -250,12 +280,51 @@ async def get_graph_execution_meta(
return GraphExecutionMeta.from_db(execution) if execution else None
async def get_graph_execution(user_id: str, execution_id: str) -> GraphExecution | None:
@overload
async def get_graph_execution(
user_id: str,
execution_id: str,
include_node_executions: Literal[True],
) -> GraphExecutionWithNodes | None: ...
@overload
async def get_graph_execution(
user_id: str,
execution_id: str,
include_node_executions: Literal[False] = False,
) -> GraphExecution | None: ...
@overload
async def get_graph_execution(
user_id: str,
execution_id: str,
include_node_executions: bool = False,
) -> GraphExecution | GraphExecutionWithNodes | None: ...
async def get_graph_execution(
user_id: str,
execution_id: str,
include_node_executions: bool = False,
) -> GraphExecution | GraphExecutionWithNodes | None:
execution = await AgentGraphExecution.prisma().find_first(
where={"id": execution_id, "isDeleted": False, "userId": user_id},
include=GRAPH_EXECUTION_INCLUDE,
include=(
GRAPH_EXECUTION_INCLUDE_WITH_NODES
if include_node_executions
else GRAPH_EXECUTION_INCLUDE
),
)
if not execution:
return None
return (
GraphExecutionWithNodes.from_db(execution)
if include_node_executions
else GraphExecution.from_db(execution)
)
return GraphExecution.from_db(execution) if execution else None
async def create_graph_execution(
@@ -264,7 +333,7 @@ async def create_graph_execution(
nodes_input: list[tuple[str, BlockInput]],
user_id: str,
preset_id: str | None = None,
) -> tuple[str, list[NodeExecutionResult]]:
) -> GraphExecutionWithNodes:
"""
Create a new AgentGraphExecution record.
Returns:
@@ -294,13 +363,10 @@ async def create_graph_execution(
"userId": user_id,
"agentPresetId": preset_id,
},
include=GRAPH_EXECUTION_INCLUDE,
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
)
return result.id, [
NodeExecutionResult.from_db(execution, result.userId)
for execution in result.AgentNodeExecutions or []
]
return GraphExecutionWithNodes.from_db(result)
async def upsert_execution_input(
@@ -322,17 +388,20 @@ async def upsert_execution_input(
node_exec_id: [Optional] The id of the AgentNodeExecution that has no `input_name` as input. If not provided, it will find the eligible incomplete AgentNodeExecution or create a new one.
Returns:
* The id of the created or existing AgentNodeExecution.
* Dict of node input data, key is the input name, value is the input data.
str: The id of the created or existing AgentNodeExecution.
dict[str, Any]: Node input data; key is the input name, value is the input data.
"""
existing_exec_query_filter: AgentNodeExecutionWhereInput = {
"agentNodeId": node_id,
"agentGraphExecutionId": graph_exec_id,
"executionStatus": ExecutionStatus.INCOMPLETE,
"Input": {"every": {"name": {"not": input_name}}},
}
if node_exec_id:
existing_exec_query_filter["id"] = node_exec_id
existing_execution = await AgentNodeExecution.prisma().find_first(
where={ # type: ignore
**({"id": node_exec_id} if node_exec_id else {}),
"agentNodeId": node_id,
"agentGraphExecutionId": graph_exec_id,
"executionStatus": ExecutionStatus.INCOMPLETE,
"Input": {"every": {"name": {"not": input_name}}},
},
where=existing_exec_query_filter,
order={"addedTime": "asc"},
include={"Input": True},
)
@@ -388,25 +457,26 @@ async def upsert_execution_output(
)
async def update_graph_execution_start_time(graph_exec_id: str) -> GraphExecutionMeta:
async def update_graph_execution_start_time(graph_exec_id: str) -> GraphExecution:
res = await AgentGraphExecution.prisma().update(
where={"id": graph_exec_id},
data={
"executionStatus": ExecutionStatus.RUNNING,
"startedAt": datetime.now(tz=timezone.utc),
},
include=GRAPH_EXECUTION_INCLUDE,
)
if not res:
raise ValueError(f"Graph execution #{graph_exec_id} not found")
return GraphExecutionMeta.from_db(res)
return GraphExecution.from_db(res)
async def update_graph_execution_stats(
graph_exec_id: str,
status: ExecutionStatus,
stats: GraphExecutionStats | None = None,
) -> GraphExecutionMeta | None:
) -> GraphExecution | None:
data = stats.model_dump() if stats else {}
if isinstance(data.get("error"), Exception):
data["error"] = str(data["error"])
@@ -422,9 +492,10 @@ async def update_graph_execution_stats(
"executionStatus": status,
"stats": Json(data),
},
include=GRAPH_EXECUTION_INCLUDE,
)
return GraphExecutionMeta.from_db(res) if res else None
return GraphExecution.from_db(res) if res else None
async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionStats):
@@ -774,7 +845,7 @@ class ExecutionEventType(str, Enum):
NODE_EXEC_UPDATE = "node_execution_update"
class GraphExecutionEvent(GraphExecutionMeta):
class GraphExecutionEvent(GraphExecution):
event_type: Literal[ExecutionEventType.GRAPH_EXEC_UPDATE] = (
ExecutionEventType.GRAPH_EXEC_UPDATE
)
@@ -798,8 +869,8 @@ class RedisExecutionEventBus(RedisEventBus[ExecutionEvent]):
def event_bus_name(self) -> str:
return config.execution_event_bus_name
def publish(self, res: GraphExecutionMeta | NodeExecutionResult):
if isinstance(res, GraphExecutionMeta):
def publish(self, res: GraphExecution | NodeExecutionResult):
if isinstance(res, GraphExecution):
self.publish_graph_exec_update(res)
else:
self.publish_node_exec_update(res)
@@ -808,7 +879,7 @@ class RedisExecutionEventBus(RedisEventBus[ExecutionEvent]):
event = NodeExecutionEvent.model_validate(res.model_dump())
self.publish_event(event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}")
def publish_graph_exec_update(self, res: GraphExecutionMeta):
def publish_graph_exec_update(self, res: GraphExecution):
event = GraphExecutionEvent.model_validate(res.model_dump())
self.publish_event(event, f"{res.user_id}/{res.graph_id}/{res.id}")

View File

@@ -538,7 +538,6 @@ async def get_graph_metadata(graph_id: str, version: int | None = None) -> Graph
graph = await AgentGraph.prisma().find_first(
where=where_clause,
include=AGENT_GRAPH_INCLUDE,
order={"version": "desc"},
)

View File

@@ -1,5 +1,7 @@
import prisma
from backend.blocks.io import IO_BLOCK_IDs
AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
"Input": True,
"Output": True,
@@ -20,7 +22,7 @@ EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
MAX_NODE_EXECUTIONS_FETCH = 1000
GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
GRAPH_EXECUTION_INCLUDE_WITH_NODES: prisma.types.AgentGraphExecutionInclude = {
"AgentNodeExecutions": {
"include": {
"Input": True,
@@ -37,6 +39,17 @@ GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
}
}
GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
"AgentNodeExecutions": {
**GRAPH_EXECUTION_INCLUDE_WITH_NODES["AgentNodeExecutions"], # type: ignore
"where": {
"AgentNode": {
"AgentBlock": {"id": {"in": IO_BLOCK_IDs}}, # type: ignore
},
},
}
}
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
"AgentNodes": {"include": AGENT_NODE_INCLUDE} # type: ignore

View File

@@ -1,9 +1,10 @@
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
GraphExecutionMeta,
GraphExecution,
NodeExecutionResult,
RedisExecutionEventBus,
create_graph_execution,
get_graph_execution,
get_incomplete_node_executions,
get_latest_node_execution,
get_node_execution_results,
@@ -65,11 +66,12 @@ class DatabaseManager(AppService):
@expose
def send_execution_update(
self, execution_result: GraphExecutionMeta | NodeExecutionResult
self, execution_result: GraphExecution | NodeExecutionResult
):
self.execution_event_bus.publish(execution_result)
# Executions
get_graph_execution = exposed_run_and_wait(get_graph_execution)
create_graph_execution = exposed_run_and_wait(create_graph_execution)
get_node_execution_results = exposed_run_and_wait(get_node_execution_results)
get_incomplete_node_executions = exposed_run_and_wait(

View File

@@ -150,7 +150,7 @@ def execute_node(
node_exec_id = data.node_exec_id
node_id = data.node_id
def update_execution(status: ExecutionStatus) -> NodeExecutionResult:
def update_execution_status(status: ExecutionStatus) -> NodeExecutionResult:
"""Sets status and fetches+broadcasts the latest state of the node execution"""
exec_update = db_client.update_node_execution_status(node_exec_id, status)
db_client.send_execution_update(exec_update)
@@ -163,6 +163,17 @@ def execute_node(
logger.error(f"Block {node.block_id} not found.")
return
def push_output(output_name: str, output_data: Any) -> None:
_push_node_execution_output(
db_client=db_client,
user_id=user_id,
graph_exec_id=graph_exec_id,
node_exec_id=node_exec_id,
block_id=node_block.id,
output_name=output_name,
output_data=output_data,
)
log_metadata = LogMetadata(
user_id=user_id,
graph_eid=graph_exec_id,
@@ -176,8 +187,8 @@ def execute_node(
input_data, error = validate_exec(node, data.data, resolve_input=False)
if input_data is None:
log_metadata.error(f"Skip execution, input validation error: {error}")
db_client.upsert_execution_output(node_exec_id, "error", error)
update_execution(ExecutionStatus.FAILED)
push_output("error", error)
update_execution_status(ExecutionStatus.FAILED)
return
# Re-shape the input data for agent block.
@@ -190,7 +201,7 @@ def execute_node(
input_data_str = json.dumps(input_data)
input_size = len(input_data_str)
log_metadata.info("Executed node with input", input=input_data_str)
update_execution(ExecutionStatus.RUNNING)
update_execution_status(ExecutionStatus.RUNNING)
# Inject extra execution arguments for the blocks via kwargs
extra_exec_kwargs: dict = {
@@ -221,7 +232,7 @@ def execute_node(
output_data = json.convert_pydantic_to_json(output_data)
output_size += len(json.dumps(output_data))
log_metadata.info("Node produced output", **{output_name: output_data})
db_client.upsert_execution_output(node_exec_id, output_name, output_data)
push_output(output_name, output_data)
outputs[output_name] = output_data
for execution in _enqueue_next_nodes(
db_client=db_client,
@@ -234,13 +245,12 @@ def execute_node(
):
yield execution
# Update execution status and spend credits
update_execution(ExecutionStatus.COMPLETED)
update_execution_status(ExecutionStatus.COMPLETED)
except Exception as e:
error_msg = str(e)
db_client.upsert_execution_output(node_exec_id, "error", error_msg)
update_execution(ExecutionStatus.FAILED)
push_output("error", error_msg)
update_execution_status(ExecutionStatus.FAILED)
for execution in _enqueue_next_nodes(
db_client=db_client,
@@ -271,6 +281,35 @@ def execute_node(
execution_stats.output_size = output_size
def _push_node_execution_output(
db_client: "DatabaseManager",
user_id: str,
graph_exec_id: str,
node_exec_id: str,
block_id: str,
output_name: str,
output_data: Any,
):
from backend.blocks.io import IO_BLOCK_IDs
db_client.upsert_execution_output(
node_exec_id=node_exec_id,
output_name=output_name,
output_data=output_data,
)
# Automatically push execution updates for all agent I/O
if block_id in IO_BLOCK_IDs:
graph_exec = db_client.get_graph_execution(
user_id=user_id, execution_id=graph_exec_id
)
if not graph_exec:
raise ValueError(
f"Graph execution #{graph_exec_id} for user #{user_id} not found"
)
db_client.send_execution_update(graph_exec)
def _enqueue_next_nodes(
db_client: "DatabaseManager",
node: Node,
@@ -628,7 +667,10 @@ class Executor:
node_eid="*",
block_name="-",
)
cls.db_client.update_graph_execution_start_time(graph_exec.graph_exec_id)
exec_meta = cls.db_client.update_graph_execution_start_time(
graph_exec.graph_exec_id
)
cls.db_client.send_execution_update(exec_meta)
timing_info, (exec_stats, status, error) = cls._on_graph_execution(
graph_exec, cancel, log_metadata
)
@@ -636,12 +678,12 @@ class Executor:
exec_stats.cputime = timing_info.cpu_time
exec_stats.error = str(error)
if result := cls.db_client.update_graph_execution_stats(
if graph_exec_result := cls.db_client.update_graph_execution_stats(
graph_exec_id=graph_exec.graph_exec_id,
status=status,
stats=exec_stats,
):
cls.db_client.send_execution_update(result)
cls.db_client.send_execution_update(graph_exec_result)
cls._handle_agent_run_notif(graph_exec, exec_stats)
@@ -774,11 +816,19 @@ class Executor:
execution_stats=exec_stats,
)
except InsufficientBalanceError as error:
exec_id = exec_data.node_exec_id
cls.db_client.upsert_execution_output(exec_id, "error", str(error))
node_exec_id = exec_data.node_exec_id
_push_node_execution_output(
db_client=cls.db_client,
user_id=graph_exec.user_id,
graph_exec_id=graph_exec.graph_exec_id,
node_exec_id=node_exec_id,
block_id=exec_data.block_id,
output_name="error",
output_data=str(error),
)
exec_update = cls.db_client.update_node_execution_status(
exec_id, ExecutionStatus.FAILED
node_exec_id, ExecutionStatus.FAILED
)
cls.db_client.send_execution_update(exec_update)
@@ -1002,13 +1052,14 @@ class ExecutionManager(AppService):
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
)
graph_exec_id, node_execs = self.db_client.create_graph_execution(
graph_exec = self.db_client.create_graph_execution(
graph_id=graph_id,
graph_version=graph.version,
nodes_input=nodes_input,
user_id=user_id,
preset_id=preset_id,
)
self.db_client.send_execution_update(graph_exec)
# Right after creating the graph execution, we need to check if the content is safe
if settings.config.behave_as != BehaveAs.LOCAL:
@@ -1032,18 +1083,12 @@ class ExecutionManager(AppService):
block_id=node_exec.block_id,
data=node_exec.input_data,
)
)
graph_exec = GraphExecutionEntry(
user_id=user_id,
graph_id=graph_id,
graph_version=graph_version or 0,
graph_exec_id=graph_exec_id,
start_node_execs=starting_node_execs,
for node_exec in graph_exec.node_executions
],
)
self.queue.add(graph_exec)
self.queue.add(graph_exec_entry)
return graph_exec
return graph_exec_entry
@expose
def cancel_execution(self, graph_exec_id: str) -> None:

View File

@@ -32,22 +32,30 @@ class ConnectionManager:
async def subscribe_graph_exec(
self, *, user_id: str, graph_exec_id: str, websocket: WebSocket
) -> str:
key = _graph_exec_channel_key(user_id, graph_exec_id)
if key not in self.subscriptions:
self.subscriptions[key] = set()
self.subscriptions[key].add(websocket)
return key
return await self._subscribe(
_graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id), websocket
)
async def unsubscribe(
async def subscribe_graph_execs(
self, *, user_id: str, graph_id: str, websocket: WebSocket
) -> str:
return await self._subscribe(
_graph_execs_channel_key(user_id, graph_id=graph_id), websocket
)
async def unsubscribe_graph_exec(
self, *, user_id: str, graph_exec_id: str, websocket: WebSocket
) -> str | None:
key = _graph_exec_channel_key(user_id, graph_exec_id)
if key in self.subscriptions:
self.subscriptions[key].discard(websocket)
if not self.subscriptions[key]:
del self.subscriptions[key]
return key
return None
return await self._unsubscribe(
_graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id), websocket
)
async def unsubscribe_graph_execs(
self, *, user_id: str, graph_id: str, websocket: WebSocket
) -> str | None:
return await self._unsubscribe(
_graph_execs_channel_key(user_id, graph_id=graph_id), websocket
)
async def send_execution_update(
self, exec_event: GraphExecutionEvent | NodeExecutionEvent
@@ -57,21 +65,51 @@ class ConnectionManager:
if isinstance(exec_event, GraphExecutionEvent)
else exec_event.graph_exec_id
)
key = _graph_exec_channel_key(exec_event.user_id, graph_exec_id)
n_sent = 0
if key in self.subscriptions:
channels: set[str] = {
# Send update to listeners for this graph execution
_graph_exec_channel_key(exec_event.user_id, graph_exec_id=graph_exec_id)
}
if isinstance(exec_event, GraphExecutionEvent):
# Send update to listeners for all executions of this graph
channels.add(
_graph_execs_channel_key(
exec_event.user_id, graph_id=exec_event.graph_id
)
)
for channel in channels.intersection(self.subscriptions.keys()):
message = WSMessage(
method=_EVENT_TYPE_TO_METHOD_MAP[exec_event.event_type],
channel=key,
channel=channel,
data=exec_event.model_dump(),
).model_dump_json()
for connection in self.subscriptions[key]:
for connection in self.subscriptions[channel]:
await connection.send_text(message)
n_sent += 1
return n_sent
async def _subscribe(self, channel_key: str, websocket: WebSocket) -> str:
if channel_key not in self.subscriptions:
self.subscriptions[channel_key] = set()
self.subscriptions[channel_key].add(websocket)
return channel_key
def _graph_exec_channel_key(user_id: str, graph_exec_id: str) -> str:
async def _unsubscribe(self, channel_key: str, websocket: WebSocket) -> str | None:
if channel_key in self.subscriptions:
self.subscriptions[channel_key].discard(websocket)
if not self.subscriptions[channel_key]:
del self.subscriptions[channel_key]
return channel_key
return None
def _graph_exec_channel_key(user_id: str, *, graph_exec_id: str) -> str:
return f"{user_id}|graph_exec#{graph_exec_id}"
def _graph_execs_channel_key(user_id: str, *, graph_id: str) -> str:
return f"{user_id}|graph#{graph_id}|executions"

View File

@@ -9,6 +9,7 @@ from backend.data.graph import Graph
class WSMethod(enum.Enum):
SUBSCRIBE_GRAPH_EXEC = "subscribe_graph_execution"
SUBSCRIBE_GRAPH_EXECS = "subscribe_graph_executions"
UNSUBSCRIBE = "unsubscribe"
GRAPH_EXECUTION_EVENT = "graph_execution_event"
NODE_EXECUTION_EVENT = "node_execution_event"
@@ -28,6 +29,10 @@ class WSSubscribeGraphExecutionRequest(pydantic.BaseModel):
graph_exec_id: str
class WSSubscribeGraphExecutionsRequest(pydantic.BaseModel):
graph_id: str
class ExecuteGraphResponse(pydantic.BaseModel):
graph_exec_id: str

View File

@@ -193,14 +193,6 @@ class AgentServer(backend.util.service.AppProcess):
raise ValueError(f"Execution {graph_exec_id} not found")
return execution.status
@staticmethod
async def test_get_graph_run_results(
graph_id: str, graph_exec_id: str, user_id: str
):
return await backend.server.routers.v1.get_graph_execution(
graph_id, graph_exec_id, user_id
)
@staticmethod
async def test_delete_graph(graph_id: str, user_id: str):
await backend.server.v2.library.db.delete_library_agent_by_graph_id(

View File

@@ -10,7 +10,7 @@ from autogpt_libs.auth.middleware import auth_middleware
from autogpt_libs.feature_flag.client import feature_flag
from autogpt_libs.utils.cache import thread_cached
from fastapi import APIRouter, Body, Depends, HTTPException, Request, Response
from starlette.status import HTTP_204_NO_CONTENT
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
from typing_extensions import Optional, TypedDict
import backend.data.block
@@ -653,9 +653,17 @@ async def get_graph_execution(
graph_id: str,
graph_exec_id: str,
user_id: Annotated[str, Depends(get_user_id)],
) -> execution_db.GraphExecution:
) -> execution_db.GraphExecution | execution_db.GraphExecutionWithNodes:
graph = await graph_db.get_graph(graph_id=graph_id, user_id=user_id)
if not graph:
raise HTTPException(
status_code=HTTP_404_NOT_FOUND, detail=f"Graph #{graph_id} not found"
)
result = await execution_db.get_graph_execution(
execution_id=graph_exec_id, user_id=user_id
user_id=user_id,
execution_id=graph_exec_id,
include_node_executions=graph.user_id == user_id,
)
if not result or result.graph_id != graph_id:
raise HTTPException(

View File

@@ -1,6 +1,7 @@
import asyncio
import logging
from contextlib import asynccontextmanager
from typing import Protocol
import uvicorn
from autogpt_libs.auth import parse_jwt_token
@@ -12,7 +13,12 @@ from backend.data import redis
from backend.data.execution import AsyncRedisExecutionEventBus
from backend.data.user import DEFAULT_USER_ID
from backend.server.conn_manager import ConnectionManager
from backend.server.model import WSMessage, WSMethod, WSSubscribeGraphExecutionRequest
from backend.server.model import (
WSMessage,
WSMethod,
WSSubscribeGraphExecutionRequest,
WSSubscribeGraphExecutionsRequest,
)
from backend.util.service import AppProcess, get_service_client
from backend.util.settings import AppEnvironment, Config, Settings
@@ -81,6 +87,19 @@ async def authenticate_websocket(websocket: WebSocket) -> str:
return ""
# ===================== Message Handlers ===================== #
class WSMessageHandler(Protocol):
async def __call__(
self,
connection_manager: ConnectionManager,
websocket: WebSocket,
user_id: str,
message: WSMessage,
): ...
async def handle_subscribe(
connection_manager: ConnectionManager,
websocket: WebSocket,
@@ -95,41 +114,53 @@ async def handle_subscribe(
error="Subscription data missing",
).model_dump_json()
)
else:
return
# Verify that user has read access to graph
# if not get_db_client().get_graph(
# graph_id=sub_req.graph_id,
# version=sub_req.graph_version,
# user_id=user_id,
# ):
# await websocket.send_text(
# WsMessage(
# method=Methods.ERROR,
# success=False,
# error="Access denied",
# ).model_dump_json()
# )
# return
if message.method == WSMethod.SUBSCRIBE_GRAPH_EXEC:
sub_req = WSSubscribeGraphExecutionRequest.model_validate(message.data)
# Verify that user has read access to graph
# if not get_db_client().get_graph(
# graph_id=sub_req.graph_id,
# version=sub_req.graph_version,
# user_id=user_id,
# ):
# await websocket.send_text(
# WsMessage(
# method=Methods.ERROR,
# success=False,
# error="Access denied",
# ).model_dump_json()
# )
# return
channel_key = await connection_manager.subscribe_graph_exec(
user_id=user_id,
graph_exec_id=sub_req.graph_exec_id,
websocket=websocket,
)
logger.debug(
f"New subscription for user #{user_id}, "
f"graph execution #{sub_req.graph_exec_id}"
elif message.method == WSMethod.SUBSCRIBE_GRAPH_EXECS:
sub_req = WSSubscribeGraphExecutionsRequest.model_validate(message.data)
channel_key = await connection_manager.subscribe_graph_execs(
user_id=user_id,
graph_id=sub_req.graph_id,
websocket=websocket,
)
await websocket.send_text(
WSMessage(
method=WSMethod.SUBSCRIBE_GRAPH_EXEC,
success=True,
channel=channel_key,
).model_dump_json()
else:
raise ValueError(
f"{handle_subscribe.__name__} can't handle '{message.method}' messages"
)
logger.debug(f"New subscription on channel {channel_key} for user #{user_id}")
await websocket.send_text(
WSMessage(
method=message.method,
success=True,
channel=channel_key,
).model_dump_json()
)
async def handle_unsubscribe(
connection_manager: ConnectionManager,
@@ -145,29 +176,49 @@ async def handle_unsubscribe(
error="Subscription data missing",
).model_dump_json()
)
else:
unsub_req = WSSubscribeGraphExecutionRequest.model_validate(message.data)
channel_key = await connection_manager.unsubscribe(
user_id=user_id,
graph_exec_id=unsub_req.graph_exec_id,
websocket=websocket,
)
logger.debug(
f"Removed subscription for user #{user_id}, "
f"graph execution #{unsub_req.graph_exec_id}"
)
await websocket.send_text(
WSMessage(
method=WSMethod.UNSUBSCRIBE,
success=True,
channel=channel_key,
).model_dump_json()
)
return
unsub_req = WSSubscribeGraphExecutionRequest.model_validate(message.data)
channel_key = await connection_manager.unsubscribe_graph_exec(
user_id=user_id,
graph_exec_id=unsub_req.graph_exec_id,
websocket=websocket,
)
logger.debug(f"Removed subscription on channel {channel_key} for user #{user_id}")
await websocket.send_text(
WSMessage(
method=WSMethod.UNSUBSCRIBE,
success=True,
channel=channel_key,
).model_dump_json()
)
@app.get("/")
async def health():
return {"status": "healthy"}
async def handle_heartbeat(
connection_manager: ConnectionManager,
websocket: WebSocket,
user_id: str,
message: WSMessage,
):
await websocket.send_json(
{
"method": WSMethod.HEARTBEAT.value,
"data": "pong",
"success": True,
}
)
_MSG_HANDLERS: dict[WSMethod, WSMessageHandler] = {
WSMethod.HEARTBEAT: handle_heartbeat,
WSMethod.SUBSCRIBE_GRAPH_EXEC: handle_subscribe,
WSMethod.SUBSCRIBE_GRAPH_EXECS: handle_subscribe,
WSMethod.UNSUBSCRIBE: handle_unsubscribe,
}
# ===================== WebSocket Server ===================== #
@app.websocket("/ws")
@@ -183,28 +234,9 @@ async def websocket_router(
data = await websocket.receive_text()
message = WSMessage.model_validate_json(data)
if message.method == WSMethod.HEARTBEAT:
await websocket.send_json(
{
"method": WSMethod.HEARTBEAT.value,
"data": "pong",
"success": True,
}
)
continue
try:
if message.method == WSMethod.SUBSCRIBE_GRAPH_EXEC:
await handle_subscribe(
connection_manager=manager,
websocket=websocket,
user_id=user_id,
message=message,
)
continue
elif message.method == WSMethod.UNSUBSCRIBE:
await handle_unsubscribe(
if message.method in _MSG_HANDLERS:
await _MSG_HANDLERS[message.method](
connection_manager=manager,
websocket=websocket,
user_id=user_id,
@@ -213,7 +245,7 @@ async def websocket_router(
continue
except Exception as e:
logger.error(
f"Error while handling '{message.method}' message "
f"Error while handling '{message.method.value}' message "
f"for user #{user_id}: {e}"
)
continue
@@ -239,6 +271,11 @@ async def websocket_router(
logger.debug("WebSocket client disconnected")
@app.get("/")
async def health():
return {"status": "healthy"}
class WebsocketServer(AppProcess):
def run(self):
logger.info(f"CORS allow origins: {settings.config.backend_cors_allow_origins}")

View File

@@ -259,7 +259,6 @@ async def block_autogen_agent():
)
print(response)
result = await wait_execution(
graph_id=test_graph.id,
graph_exec_id=response.graph_exec_id,
timeout=1200,
user_id=test_user.id,

View File

@@ -162,9 +162,7 @@ async def reddit_marketing_agent():
node_input=input_data,
)
print(response)
result = await wait_execution(
test_user.id, test_graph.id, response.graph_exec_id, 120
)
result = await wait_execution(test_user.id, response.graph_exec_id, 120)
print(result)

View File

@@ -94,7 +94,7 @@ async def sample_agent():
user_id=test_user.id,
node_input=input_data,
)
await wait_execution(test_user.id, test_graph.id, response.graph_exec_id, 10)
await wait_execution(test_user.id, response.graph_exec_id, 10)
if __name__ == "__main__":

View File

@@ -5,7 +5,11 @@ from typing import Sequence, cast
from backend.data import db
from backend.data.block import Block, BlockSchema, initialize_blocks
from backend.data.execution import ExecutionStatus, NodeExecutionResult
from backend.data.execution import (
ExecutionStatus,
NodeExecutionResult,
get_graph_execution,
)
from backend.data.model import _BaseCredentials
from backend.data.user import create_default_user
from backend.executor import DatabaseManager, ExecutionManager, Scheduler
@@ -60,7 +64,6 @@ class SpinTestServer:
async def wait_execution(
user_id: str,
graph_id: str,
graph_exec_id: str,
timeout: int = 30,
) -> Sequence[NodeExecutionResult]:
@@ -78,9 +81,12 @@ async def wait_execution(
# Wait for the executions to complete
for i in range(timeout):
if await is_execution_completed():
graph_exec = await AgentServer().test_get_graph_run_results(
graph_id, graph_exec_id, user_id
graph_exec = await get_graph_execution(
user_id=user_id,
execution_id=graph_exec_id,
include_node_executions=True,
)
assert graph_exec, f"Graph execution #{graph_exec_id} not found"
return graph_exec.node_executions
time.sleep(1)

View File

@@ -46,24 +46,22 @@ async def execute_graph(
# Execution queue should be empty
logger.info("Waiting for execution to complete...")
result = await wait_execution(test_user.id, test_graph.id, graph_exec_id, 30)
result = await wait_execution(test_user.id, graph_exec_id, 30)
logger.info(f"Execution completed with {len(result)} results")
assert len(result) == num_execs
return graph_exec_id
async def assert_sample_graph_executions(
agent_server: AgentServer,
test_graph: graph.Graph,
test_user: User,
graph_exec_id: str,
):
logger.info(f"Checking execution results for graph {test_graph.id}")
graph_run = await agent_server.test_get_graph_run_results(
test_graph.id,
graph_exec_id,
test_user.id,
graph_run = await execution.get_graph_execution(
test_user.id, graph_exec_id, include_node_executions=True
)
assert isinstance(graph_run, execution.GraphExecutionWithNodes)
output_list = [{"result": ["Hello"]}, {"result": ["World"]}]
input_list = [
@@ -142,9 +140,7 @@ async def test_agent_execution(server: SpinTestServer):
data,
4,
)
await assert_sample_graph_executions(
server.agent_server, test_graph, test_user, graph_exec_id
)
await assert_sample_graph_executions(test_graph, test_user, graph_exec_id)
logger.info("Completed test_agent_execution")
@@ -203,9 +199,10 @@ async def test_input_pin_always_waited(server: SpinTestServer):
)
logger.info("Checking execution results")
graph_exec = await server.agent_server.test_get_graph_run_results(
test_graph.id, graph_exec_id, test_user.id
graph_exec = await execution.get_graph_execution(
test_user.id, graph_exec_id, include_node_executions=True
)
assert isinstance(graph_exec, execution.GraphExecutionWithNodes)
assert len(graph_exec.node_executions) == 3
# FindInDictionaryBlock should wait for the input pin to be provided,
# Hence executing extraction of "key" from {"key1": "value1", "key2": "value2"}
@@ -286,9 +283,10 @@ async def test_static_input_link_on_graph(server: SpinTestServer):
server.agent_server, test_graph, test_user, {}, 8
)
logger.info("Checking execution results")
graph_exec = await server.agent_server.test_get_graph_run_results(
test_graph.id, graph_exec_id, test_user.id
graph_exec = await execution.get_graph_execution(
test_user.id, graph_exec_id, include_node_executions=True
)
assert isinstance(graph_exec, execution.GraphExecutionWithNodes)
assert len(graph_exec.node_executions) == 8
# The last 3 executions will be a+b=4+5=9
for i, exec_data in enumerate(graph_exec.node_executions[-3:]):
@@ -385,7 +383,7 @@ async def test_execute_preset(server: SpinTestServer):
graph_exec_id = result["id"]
# Wait for execution to complete
executions = await wait_execution(test_user.id, test_graph.id, graph_exec_id)
executions = await wait_execution(test_user.id, graph_exec_id)
assert len(executions) == 4
# FindInDictionaryBlock should wait for the input pin to be provided,
@@ -475,7 +473,7 @@ async def test_execute_preset_with_clash(server: SpinTestServer):
graph_exec_id = result["id"]
# Wait for execution to complete
executions = await wait_execution(test_user.id, test_graph.id, graph_exec_id)
executions = await wait_execution(test_user.id, graph_exec_id)
assert len(executions) == 4
# FindInDictionaryBlock should wait for the input pin to be provided,
@@ -542,7 +540,5 @@ async def test_store_listing_graph(server: SpinTestServer):
4,
)
await assert_sample_graph_executions(
server.agent_server, test_graph, alt_test_user, graph_exec_id
)
await assert_sample_graph_executions(test_graph, alt_test_user, graph_exec_id)
logger.info("Completed test_agent_execution")

View File

@@ -55,7 +55,7 @@ async def execute_graph(
# Execution queue should be empty
logger.info("Waiting for execution to complete...")
result = await wait_execution(test_user.id, test_graph.id, graph_exec_id, 30)
result = await wait_execution(test_user.id, graph_exec_id, 30)
logger.info("Execution completed with %d results", len(result))
return graph_exec_id

View File

@@ -68,7 +68,7 @@ async def test_unsubscribe(
channel_key = "user-1|graph_exec#graph-exec-1"
connection_manager.subscriptions[channel_key] = {mock_websocket}
await connection_manager.unsubscribe(
await connection_manager.unsubscribe_graph_exec(
user_id="user-1",
graph_exec_id="graph-exec-1",
websocket=mock_websocket,
@@ -94,6 +94,14 @@ async def test_send_graph_execution_result(
total_run_time=0.5,
started_at=datetime.now(tz=timezone.utc),
ended_at=datetime.now(tz=timezone.utc),
inputs={
"input_1": "some input value :)",
"input_2": "some *other* input value",
},
outputs={
"the_output": ["some output value"],
"other_output": ["sike there was another output"],
},
)
await connection_manager.send_execution_update(result)

View File

@@ -70,7 +70,7 @@ async def test_websocket_router_unsubscribe(
).model_dump_json(),
WebSocketDisconnect(),
]
mock_manager.unsubscribe.return_value = (
mock_manager.unsubscribe_graph_exec.return_value = (
f"{DEFAULT_USER_ID}|graph_exec#test-graph-exec-1"
)
@@ -79,7 +79,7 @@ async def test_websocket_router_unsubscribe(
)
mock_manager.connect_socket.assert_called_once_with(mock_websocket)
mock_manager.unsubscribe.assert_called_once_with(
mock_manager.unsubscribe_graph_exec.assert_called_once_with(
user_id=DEFAULT_USER_ID,
graph_exec_id="test-graph-exec-1",
websocket=mock_websocket,
@@ -168,7 +168,9 @@ async def test_handle_unsubscribe_success(
message = WSMessage(
method=WSMethod.UNSUBSCRIBE, data={"graph_exec_id": "test-graph-exec-id"}
)
mock_manager.unsubscribe.return_value = "user-1|graph_exec#test-graph-exec-id"
mock_manager.unsubscribe_graph_exec.return_value = (
"user-1|graph_exec#test-graph-exec-id"
)
await handle_unsubscribe(
connection_manager=cast(ConnectionManager, mock_manager),
@@ -177,7 +179,7 @@ async def test_handle_unsubscribe_success(
message=message,
)
mock_manager.unsubscribe.assert_called_once_with(
mock_manager.unsubscribe_graph_exec.assert_called_once_with(
user_id="user-1",
graph_exec_id="test-graph-exec-id",
websocket=mock_websocket,
@@ -200,7 +202,7 @@ async def test_handle_unsubscribe_missing_data(
message=message,
)
mock_manager.unsubscribe.assert_not_called()
mock_manager._unsubscribe.assert_not_called()
mock_websocket.send_text.assert_called_once()
assert '"method":"error"' in mock_websocket.send_text.call_args[0][0]
assert '"success":false' in mock_websocket.send_text.call_args[0][0]

View File

@@ -89,20 +89,23 @@ export default function AgentRunsPage(): React.ReactElement {
(graph && graph.version == _graph.version) || setGraph(_graph),
);
api.getGraphExecutions(agent.agent_id).then((agentRuns) => {
const sortedRuns = agentRuns.toSorted(
(a, b) => Number(b.started_at) - Number(a.started_at),
);
setAgentRuns(sortedRuns);
setAgentRuns(agentRuns);
// Preload the corresponding graph versions
new Set(sortedRuns.map((run) => run.graph_version)).forEach((version) =>
new Set(agentRuns.map((run) => run.graph_version)).forEach((version) =>
getGraphVersion(agent.agent_id, version),
);
if (!selectedView.id && isFirstLoad && sortedRuns.length > 0) {
if (!selectedView.id && isFirstLoad && agentRuns.length > 0) {
// only for first load or first execution
setIsFirstLoad(false);
selectView({ type: "run", id: sortedRuns[0].id });
const latestRun = agentRuns.reduce((latest, current) => {
if (latest.started_at && !current.started_at) return current;
else if (!latest.started_at) return latest;
return latest.started_at > current.started_at ? latest : current;
}, agentRuns[0]);
selectView({ type: "run", id: latestRun.id });
}
});
});
@@ -117,6 +120,39 @@ export default function AgentRunsPage(): React.ReactElement {
fetchAgents();
}, []);
// Subscribe to websocket updates for agent runs
useEffect(() => {
if (!agent) return;
// Subscribe to all executions for this agent
api.subscribeToGraphExecutions(agent.agent_id);
}, [api, agent]);
// Handle execution updates
useEffect(() => {
const detachExecUpdateHandler = api.onWebSocketMessage(
"graph_execution_event",
(data) => {
setAgentRuns((prev) => {
const index = prev.findIndex((run) => run.id === data.id);
if (index === -1) {
return [...prev, data];
}
const newRuns = [...prev];
newRuns[index] = { ...newRuns[index], ...data };
return newRuns;
});
if (data.id === selectedView.id) {
setSelectedRun((prev) => ({ ...prev, ...data }));
}
},
);
return () => {
detachExecUpdateHandler();
};
}, [api, selectedView.id]);
// load selectedRun based on selectedView
useEffect(() => {
if (selectedView.type != "run" || !selectedView.id || !agent) return;
@@ -149,12 +185,6 @@ export default function AgentRunsPage(): React.ReactElement {
fetchSchedules();
}, [fetchSchedules]);
/* TODO: use websockets instead of polling - https://github.com/Significant-Gravitas/AutoGPT/issues/8782 */
useEffect(() => {
const intervalId = setInterval(() => fetchAgents(), 5000);
return () => clearInterval(intervalId);
}, [fetchAgents]);
// =========================== ACTIONS ============================
const deleteRun = useCallback(
@@ -256,6 +286,7 @@ export default function AgentRunsPage(): React.ReactElement {
graph={graphVersions[selectedRun.graph_version] ?? graph}
run={selectedRun}
agentActions={agentActions}
onRun={(runID) => selectRun(runID)}
deleteRun={() => setConfirmingDeleteAgentRun(selectedRun)}
/>
)

View File

@@ -100,13 +100,13 @@ const Monitor = () => {
...(selectedFlow
? executions.filter((v) => v.graph_id == selectedFlow.agent_id)
: executions),
].sort((a, b) => Number(b.started_at) - Number(a.started_at))}
].sort((a, b) => b.started_at.getTime() - a.started_at.getTime())}
selectedRun={selectedRun}
onSelectRun={(r) => setSelectedRun(r.id == selectedRun?.id ? null : r)}
/>
{(selectedRun && (
<FlowRunInfo
flow={
agent={
selectedFlow ||
flows.find((f) => f.agent_id == selectedRun.graph_id)!
}

View File

@@ -4,8 +4,8 @@ import React, {
forwardRef,
useImperativeHandle,
} from "react";
import RunnerOutputUI, { BlockOutput } from "./runner-ui/RunnerOutputUI";
import RunnerInputUI from "./runner-ui/RunnerInputUI";
import RunnerOutputUI from "./runner-ui/RunnerOutputUI";
import { Node } from "@xyflow/react";
import { filterBlocksByType } from "@/lib/utils";
import { BlockIORootSchema, BlockUIType } from "@/lib/autogpt-server-api/types";
@@ -60,7 +60,11 @@ const RunnerUIWrapper = forwardRef<RunnerUIWrapperRef, RunnerUIWrapperProps>(
const [isRunnerOutputOpen, setIsRunnerOutputOpen] = useState(false);
const [scheduledInput, setScheduledInput] = useState(false);
const [cronExpression, setCronExpression] = useState("");
const getBlockInputsAndOutputs = useCallback(() => {
const getBlockInputsAndOutputs = useCallback((): {
inputs: InputItem[];
outputs: BlockOutput[];
} => {
const inputBlocks = filterBlocksByType(
nodes,
(node) => node.data.uiType === BlockUIType.INPUT,
@@ -71,34 +75,37 @@ const RunnerUIWrapper = forwardRef<RunnerUIWrapperRef, RunnerUIWrapperProps>(
(node) => node.data.uiType === BlockUIType.OUTPUT,
);
const inputs = inputBlocks.map((node) => ({
id: node.id,
type: "input" as const,
inputSchema: node.data.inputSchema as BlockIORootSchema,
hardcodedValues: {
name: (node.data.hardcodedValues as any).name || "",
description: (node.data.hardcodedValues as any).description || "",
value: (node.data.hardcodedValues as any).value,
placeholder_values:
(node.data.hardcodedValues as any).placeholder_values || [],
limit_to_placeholder_values:
(node.data.hardcodedValues as any).limit_to_placeholder_values ||
false,
},
}));
const inputs = inputBlocks.map(
(node) =>
({
id: node.id,
type: "input" as const,
inputSchema: node.data.inputSchema as BlockIORootSchema,
hardcodedValues: {
name: (node.data.hardcodedValues as any).name || "",
description: (node.data.hardcodedValues as any).description || "",
value: (node.data.hardcodedValues as any).value,
placeholder_values:
(node.data.hardcodedValues as any).placeholder_values || [],
limit_to_placeholder_values:
(node.data.hardcodedValues as any)
.limit_to_placeholder_values || false,
},
}) satisfies InputItem,
);
const outputs = outputBlocks.map((node) => ({
id: node.id,
type: "output" as const,
hardcodedValues: {
name: (node.data.hardcodedValues as any).name || "Output",
description:
(node.data.hardcodedValues as any).description ||
"Output from the agent",
value: (node.data.hardcodedValues as any).value,
},
result: (node.data.executionResults as any)?.at(-1)?.data?.output,
}));
const outputs = outputBlocks.map(
(node) =>
({
metadata: {
name: (node.data.hardcodedValues as any).name || "Output",
description:
(node.data.hardcodedValues as any).description ||
"Output from the agent",
},
result: (node.data.executionResults as any)?.at(-1)?.data?.output,
}) satisfies BlockOutput,
);
return { inputs, outputs };
}, [nodes]);

View File

@@ -5,6 +5,7 @@ import moment from "moment";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import {
GraphExecution,
GraphExecutionID,
GraphExecutionMeta,
GraphMeta,
} from "@/lib/autogpt-server-api";
@@ -25,11 +26,13 @@ export default function AgentRunDetailsView({
graph,
run,
agentActions,
onRun,
deleteRun,
}: {
graph: GraphMeta;
run: GraphExecution | GraphExecutionMeta;
agentActions: ButtonAction[];
onRun: (runID: GraphExecutionID) => void;
deleteRun: () => void;
}): React.ReactNode {
const api = useBackendAPI();
@@ -90,8 +93,9 @@ export default function AgentRunDetailsView({
Object.entries(agentRunInputs).map(([k, v]) => [k, v.value]),
),
)
.then(({ graph_exec_id }) => onRun(graph_exec_id))
.catch(toastOnFail("execute agent")),
[api, graph, agentRunInputs, toastOnFail],
[api, graph, agentRunInputs, onRun, toastOnFail],
);
const stopRun = useCallback(

View File

@@ -104,24 +104,28 @@ export default function AgentRunsSelectorList({
</Button>
{activeListTab === "runs"
? agentRuns.map((run, i) => (
<AgentRunSummaryCard
className="h-28 w-72 lg:h-32 xl:w-80"
key={i}
status={agentRunStatusMap[run.status]}
title={agent.name}
timestamp={run.started_at}
selected={selectedView.id === run.id}
onClick={() => onSelectRun(run.id)}
onDelete={() => onDeleteRun(run)}
/>
))
: schedules
.filter((schedule) => schedule.graph_id === agent.agent_id)
.map((schedule, i) => (
? agentRuns
.toSorted(
(a, b) => b.started_at.getTime() - a.started_at.getTime(),
)
.map((run) => (
<AgentRunSummaryCard
className="h-28 w-72 lg:h-32 xl:w-80"
key={i}
key={run.id}
status={agentRunStatusMap[run.status]}
title={agent.name}
timestamp={run.started_at}
selected={selectedView.id === run.id}
onClick={() => onSelectRun(run.id)}
onDelete={() => onDeleteRun(run)}
/>
))
: schedules
.filter((schedule) => schedule.graph_id === agent.agent_id)
.map((schedule) => (
<AgentRunSummaryCard
className="h-28 w-72 lg:h-32 xl:w-80"
key={schedule.id}
status="scheduled"
title={schedule.name}
timestamp={schedule.next_run_time}

View File

@@ -126,7 +126,8 @@ export const AgentFlowList = ({
if (!a.lastRun) return 1;
if (!b.lastRun) return -1;
return (
Number(b.lastRun.started_at) - Number(a.lastRun.started_at)
b.lastRun.started_at.getTime() -
a.lastRun.started_at.getTime()
);
})
.map(({ flow, runCount, lastRun }) => (

View File

@@ -17,59 +17,39 @@ import { useBackendAPI } from "@/lib/autogpt-server-api/context";
export const FlowRunInfo: React.FC<
React.HTMLAttributes<HTMLDivElement> & {
flow: LibraryAgent;
agent: LibraryAgent;
execution: GraphExecutionMeta;
}
> = ({ flow, execution, ...props }) => {
> = ({ agent, execution, ...props }) => {
const [isOutputOpen, setIsOutputOpen] = useState(false);
const [blockOutputs, setBlockOutputs] = useState<BlockOutput[]>([]);
const api = useBackendAPI();
const fetchBlockResults = useCallback(async () => {
const executionResults = (
await api.getGraphExecutionInfo(flow.agent_id, execution.id)
).node_executions;
// Create a map of the latest COMPLETED execution results of output nodes by node_id
const latestCompletedResults = executionResults
.filter(
(result) =>
result.status === "COMPLETED" &&
result.block_id === SpecialBlockID.OUTPUT,
)
.reduce((acc, result) => {
const existing = acc.get(result.node_id);
// Compare dates if there's an existing result
if (existing) {
const existingDate = existing.end_time || existing.add_time;
const currentDate = result.end_time || result.add_time;
if (currentDate > existingDate) {
acc.set(result.node_id, result);
}
} else {
acc.set(result.node_id, result);
}
return acc;
}, new Map<string, NodeExecutionResult>());
const graph = await api.getGraph(agent.agent_id, agent.agent_version);
const graphExecution = await api.getGraphExecutionInfo(
agent.agent_id,
execution.id,
);
// Transform results to BlockOutput format
setBlockOutputs(
Array.from(latestCompletedResults.values()).map((result) => ({
id: result.node_id,
type: "output" as const,
hardcodedValues: {
name: result.input_data.name || "Output",
description: result.input_data.description || "Output from the agent",
value: result.input_data.value,
},
// Change this line to extract the array directly
result: result.output_data?.output || undefined,
})),
Object.entries(graphExecution.outputs).flatMap(([key, values]) =>
values.map(
(value) =>
({
metadata: {
name: graph.output_schema.properties[key].title || "Output",
description:
graph.output_schema.properties[key].description ||
"Output from the agent",
},
result: value,
}) satisfies BlockOutput,
),
),
);
}, [api, flow.agent_id, execution.id]);
}, [api, agent.agent_id, agent.agent_version, execution.id]);
// Fetch graph and execution data
useEffect(() => {
@@ -77,15 +57,15 @@ export const FlowRunInfo: React.FC<
fetchBlockResults();
}, [isOutputOpen, fetchBlockResults]);
if (execution.graph_id != flow.agent_id) {
if (execution.graph_id != agent.agent_id) {
throw new Error(
`FlowRunInfo can't be used with non-matching execution.graph_id and flow.id`,
);
}
const handleStopRun = useCallback(() => {
api.stopGraphExecution(flow.agent_id, execution.id);
}, [api, flow.agent_id, execution.id]);
api.stopGraphExecution(agent.agent_id, execution.id);
}, [api, agent.agent_id, execution.id]);
return (
<>
@@ -93,7 +73,7 @@ export const FlowRunInfo: React.FC<
<CardHeader className="flex-row items-center justify-between space-x-3 space-y-0">
<div>
<CardTitle>
{flow.name}{" "}
{agent.name}{" "}
<span className="font-light">v{execution.graph_version}</span>
</CardTitle>
</div>
@@ -106,7 +86,7 @@ export const FlowRunInfo: React.FC<
<Button onClick={() => setIsOutputOpen(true)} variant="outline">
<ExitIcon className="mr-2" /> View Outputs
</Button>
{flow.can_access_graph && (
{agent.can_access_graph && (
<Link
className={buttonVariants({ variant: "default" })}
href={`/build?flowID=${execution.graph_id}&flowVersion=${execution.graph_version}&flowExecutionID=${execution.id}`}
@@ -118,7 +98,7 @@ export const FlowRunInfo: React.FC<
</CardHeader>
<CardContent>
<p className="hidden">
<strong>Agent ID:</strong> <code>{flow.agent_id}</code>
<strong>Agent ID:</strong> <code>{agent.agent_id}</code>
</p>
<p className="hidden">
<strong>Run ID:</strong> <code>{execution.id}</code>

View File

@@ -29,7 +29,7 @@ export const FlowRunsStatus: React.FC<{
: statsSince;
const filteredFlowRuns =
statsSinceTimestamp != null
? executions.filter((fr) => Number(fr.started_at) > statsSinceTimestamp)
? executions.filter((fr) => fr.started_at.getTime() > statsSinceTimestamp)
: executions;
return (

View File

@@ -99,7 +99,7 @@ export const FlowRunsTimeline = ({
.filter((e) => e.graph_id == flow.agent_id)
.map((e) => ({
...e,
time: Number(e.started_at) + e.total_run_time * 1000,
time: e.started_at.getTime() + e.total_run_time * 1000,
_duration: e.total_run_time,
}))}
name={flow.name}
@@ -112,10 +112,14 @@ export const FlowRunsTimeline = ({
type="linear"
dataKey="_duration"
data={[
{ ...execution, time: Number(execution.started_at), _duration: 0 },
{
...execution,
time: Number(execution.ended_at),
time: execution.started_at.getTime(),
_duration: 0,
},
{
...execution,
time: execution.ended_at.getTime(),
_duration: execution.total_run_time,
},
]}

View File

@@ -7,21 +7,19 @@ import {
SheetDescription,
} from "@/components/ui/sheet";
import { ScrollArea } from "@/components/ui/scroll-area";
import { BlockIORootSchema } from "@/lib/autogpt-server-api/types";
import { Label } from "@/components/ui/label";
import { Textarea } from "@/components/ui/textarea";
import { Button } from "@/components/ui/button";
import { Clipboard } from "lucide-react";
import { useToast } from "@/components/ui/use-toast";
export interface BlockOutput {
id: string;
hardcodedValues: {
export type BlockOutput = {
metadata: {
name: string;
description: string;
};
result?: any;
}
};
interface OutputModalProps {
isOpen: boolean;
@@ -87,15 +85,15 @@ export function RunnerOutputUI({
<ScrollArea className="h-full overflow-auto pr-4">
<div className="space-y-4">
{blockOutputs && blockOutputs.length > 0 ? (
blockOutputs.map((block) => (
<div key={block.id} className="space-y-1">
blockOutputs.map((block, i) => (
<div key={i} className="space-y-1">
<Label className="text-base font-semibold">
{block.hardcodedValues.name || "Unnamed Output"}
{block.metadata.name || "Unnamed Output"}
</Label>
{block.hardcodedValues.description && (
{block.metadata.description && (
<Label className="block text-sm text-gray-600">
{block.hardcodedValues.description}
{block.metadata.description}
</Label>
)}
@@ -106,7 +104,7 @@ export function RunnerOutputUI({
size="icon"
onClick={() =>
copyOutput(
block.hardcodedValues.name || "Unnamed Output",
block.metadata.name || "Unnamed Output",
block.result,
)
}

View File

@@ -631,7 +631,10 @@ export default function useAgentGraph(
activeExecutionID: flowExecutionID,
});
}
setUpdateQueue((prev) => [...prev, ...execution.node_executions]);
setUpdateQueue((prev) => {
if (!execution.node_executions) return prev;
return [...prev, ...execution.node_executions];
});
// Track execution until completed
const pendingNodeExecutions: Set<string> = new Set();

View File

@@ -253,11 +253,15 @@ export default class BackendAPI {
}
getExecutions(): Promise<GraphExecutionMeta[]> {
return this._get(`/executions`);
return this._get(`/executions`).then((results) =>
results.map(parseGraphExecutionTimestamps),
);
}
getGraphExecutions(graphID: GraphID): Promise<GraphExecutionMeta[]> {
return this._get(`/graphs/${graphID}/executions`);
return this._get(`/graphs/${graphID}/executions`).then((results) =>
results.map(parseGraphExecutionTimestamps),
);
}
async getGraphExecutionInfo(
@@ -265,10 +269,7 @@ export default class BackendAPI {
runID: GraphExecutionID,
): Promise<GraphExecution> {
const result = await this._get(`/graphs/${graphID}/executions/${runID}`);
result.node_executions = result.node_executions.map(
parseNodeExecutionResultTimestamps,
);
return result;
return parseGraphExecutionTimestamps<GraphExecution>(result);
}
async stopGraphExecution(
@@ -279,10 +280,7 @@ export default class BackendAPI {
"POST",
`/graphs/${graphID}/executions/${runID}/stop`,
);
result.node_executions = result.node_executions.map(
parseNodeExecutionResultTimestamps,
);
return result;
return parseGraphExecutionTimestamps<GraphExecution>(result);
}
async deleteGraphExecution(runID: GraphExecutionID): Promise<void> {
@@ -828,6 +826,12 @@ export default class BackendAPI {
});
}
subscribeToGraphExecutions(graphID: GraphID): Promise<void> {
return this.sendWebSocketMessage("subscribe_graph_executions", {
graph_id: graphID,
});
}
async sendWebSocketMessage<M extends keyof WebsocketMessageTypeMap>(
method: M,
data: WebsocketMessageTypeMap[M],
@@ -904,7 +908,7 @@ export default class BackendAPI {
if (message.method === "node_execution_event") {
message.data = parseNodeExecutionResultTimestamps(message.data);
} else if (message.method == "graph_execution_event") {
message.data = parseGraphExecutionMetaTimestamps(message.data);
message.data = parseGraphExecutionTimestamps(message.data);
}
this.wsMessageHandlers[message.method]?.forEach((handler) =>
handler(message.data),
@@ -973,7 +977,8 @@ type GraphCreateRequestBody = {
type WebsocketMessageTypeMap = {
subscribe_graph_execution: { graph_exec_id: GraphExecutionID };
graph_execution_event: GraphExecutionMeta;
subscribe_graph_executions: { graph_id: GraphID };
graph_execution_event: GraphExecution;
node_execution_event: NodeExecutionResult;
heartbeat: "ping" | "pong";
};
@@ -994,11 +999,16 @@ type _PydanticValidationError = {
/* *** HELPER FUNCTIONS *** */
function parseGraphExecutionMetaTimestamps(result: any): GraphExecutionMeta {
return _parseObjectTimestamps<GraphExecutionMeta>(result, [
"started_at",
"ended_at",
]);
function parseGraphExecutionTimestamps<
T extends GraphExecutionMeta | GraphExecution,
>(result: any): T {
const fixed = _parseObjectTimestamps<T>(result, ["started_at", "ended_at"]);
if ("node_executions" in fixed && fixed.node_executions) {
fixed.node_executions = fixed.node_executions.map(
parseNodeExecutionResultTimestamps,
);
}
return fixed;
}
function parseNodeExecutionResultTimestamps(result: any): NodeExecutionResult {

View File

@@ -228,7 +228,7 @@ export type LinkCreatable = Omit<Link, "id" | "is_static"> & {
id?: string;
};
/* Mirror of backend/data/graph.py:GraphExecutionMeta */
/* Mirror of backend/data/execution.py:GraphExecutionMeta */
export type GraphExecutionMeta = {
id: GraphExecutionID;
started_at: Date;
@@ -244,11 +244,11 @@ export type GraphExecutionMeta = {
export type GraphExecutionID = Brand<string, "GraphExecutionID">;
/* Mirror of backend/data/graph.py:GraphExecution */
/* Mirror of backend/data/execution.py:GraphExecution */
export type GraphExecution = GraphExecutionMeta & {
inputs: Record<string, any>;
outputs: Record<string, Array<any>>;
node_executions: NodeExecutionResult[];
node_executions?: NodeExecutionResult[];
};
export type GraphMeta = {
@@ -297,7 +297,7 @@ export type GraphUpdateable = Omit<
export type GraphCreatable = Omit<GraphUpdateable, "id"> & { id?: string };
/* Mirror of backend/data/execution.py:ExecutionResult */
/* Mirror of backend/data/execution.py:NodeExecutionResult */
export type NodeExecutionResult = {
graph_id: GraphID;
graph_version: number;