mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Merge branch 'add-iffy-moderation' of https://github.com/Significant-Gravitas/AutoGPT into add-iffy-moderation
This commit is contained in:
@@ -550,3 +550,17 @@ class AgentToggleInputBlock(AgentInputBlock):
|
||||
("result", False),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
IO_BLOCK_IDs = [
|
||||
AgentInputBlock().id,
|
||||
AgentOutputBlock().id,
|
||||
AgentShortTextInputBlock().id,
|
||||
AgentLongTextInputBlock().id,
|
||||
AgentNumberInputBlock().id,
|
||||
AgentDateInputBlock().id,
|
||||
AgentTimeInputBlock().id,
|
||||
AgentFileInputBlock().id,
|
||||
AgentDropdownInputBlock().id,
|
||||
AgentToggleInputBlock().id,
|
||||
]
|
||||
|
||||
@@ -12,6 +12,7 @@ from typing import (
|
||||
Literal,
|
||||
Optional,
|
||||
TypeVar,
|
||||
overload,
|
||||
)
|
||||
|
||||
from prisma import Json
|
||||
@@ -36,7 +37,11 @@ from backend.util.settings import Config
|
||||
|
||||
from .block import BlockData, BlockInput, BlockType, CompletedBlockOutput, get_block
|
||||
from .db import BaseDbModel
|
||||
from .includes import EXECUTION_RESULT_INCLUDE, GRAPH_EXECUTION_INCLUDE
|
||||
from .includes import (
|
||||
EXECUTION_RESULT_INCLUDE,
|
||||
GRAPH_EXECUTION_INCLUDE,
|
||||
GRAPH_EXECUTION_INCLUDE_WITH_NODES,
|
||||
)
|
||||
from .model import GraphExecutionStats, NodeExecutionStats
|
||||
from .queue import AsyncRedisEventBus, RedisEventBus
|
||||
|
||||
@@ -103,7 +108,6 @@ class GraphExecutionMeta(BaseDbModel):
|
||||
class GraphExecution(GraphExecutionMeta):
|
||||
inputs: BlockInput
|
||||
outputs: CompletedBlockOutput
|
||||
node_executions: list["NodeExecutionResult"]
|
||||
|
||||
@staticmethod
|
||||
def from_db(_graph_exec: AgentGraphExecution):
|
||||
@@ -158,6 +162,32 @@ class GraphExecution(GraphExecutionMeta):
|
||||
},
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
|
||||
class GraphExecutionWithNodes(GraphExecution):
|
||||
node_executions: list["NodeExecutionResult"]
|
||||
|
||||
@staticmethod
|
||||
def from_db(_graph_exec: AgentGraphExecution):
|
||||
if _graph_exec.AgentNodeExecutions is None:
|
||||
raise ValueError("Node executions must be included in query")
|
||||
|
||||
graph_exec_with_io = GraphExecution.from_db(_graph_exec)
|
||||
|
||||
node_executions = sorted(
|
||||
[
|
||||
NodeExecutionResult.from_db(ne, _graph_exec.userId)
|
||||
for ne in _graph_exec.AgentNodeExecutions
|
||||
],
|
||||
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
|
||||
)
|
||||
|
||||
return GraphExecutionWithNodes(
|
||||
**{
|
||||
field_name: getattr(graph_exec_with_io, field_name)
|
||||
for field_name in graph_exec_with_io.model_fields
|
||||
},
|
||||
node_executions=node_executions,
|
||||
)
|
||||
|
||||
@@ -250,12 +280,51 @@ async def get_graph_execution_meta(
|
||||
return GraphExecutionMeta.from_db(execution) if execution else None
|
||||
|
||||
|
||||
async def get_graph_execution(user_id: str, execution_id: str) -> GraphExecution | None:
|
||||
@overload
|
||||
async def get_graph_execution(
|
||||
user_id: str,
|
||||
execution_id: str,
|
||||
include_node_executions: Literal[True],
|
||||
) -> GraphExecutionWithNodes | None: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def get_graph_execution(
|
||||
user_id: str,
|
||||
execution_id: str,
|
||||
include_node_executions: Literal[False] = False,
|
||||
) -> GraphExecution | None: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def get_graph_execution(
|
||||
user_id: str,
|
||||
execution_id: str,
|
||||
include_node_executions: bool = False,
|
||||
) -> GraphExecution | GraphExecutionWithNodes | None: ...
|
||||
|
||||
|
||||
async def get_graph_execution(
|
||||
user_id: str,
|
||||
execution_id: str,
|
||||
include_node_executions: bool = False,
|
||||
) -> GraphExecution | GraphExecutionWithNodes | None:
|
||||
execution = await AgentGraphExecution.prisma().find_first(
|
||||
where={"id": execution_id, "isDeleted": False, "userId": user_id},
|
||||
include=GRAPH_EXECUTION_INCLUDE,
|
||||
include=(
|
||||
GRAPH_EXECUTION_INCLUDE_WITH_NODES
|
||||
if include_node_executions
|
||||
else GRAPH_EXECUTION_INCLUDE
|
||||
),
|
||||
)
|
||||
if not execution:
|
||||
return None
|
||||
|
||||
return (
|
||||
GraphExecutionWithNodes.from_db(execution)
|
||||
if include_node_executions
|
||||
else GraphExecution.from_db(execution)
|
||||
)
|
||||
return GraphExecution.from_db(execution) if execution else None
|
||||
|
||||
|
||||
async def create_graph_execution(
|
||||
@@ -264,7 +333,7 @@ async def create_graph_execution(
|
||||
nodes_input: list[tuple[str, BlockInput]],
|
||||
user_id: str,
|
||||
preset_id: str | None = None,
|
||||
) -> tuple[str, list[NodeExecutionResult]]:
|
||||
) -> GraphExecutionWithNodes:
|
||||
"""
|
||||
Create a new AgentGraphExecution record.
|
||||
Returns:
|
||||
@@ -294,13 +363,10 @@ async def create_graph_execution(
|
||||
"userId": user_id,
|
||||
"agentPresetId": preset_id,
|
||||
},
|
||||
include=GRAPH_EXECUTION_INCLUDE,
|
||||
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
|
||||
)
|
||||
|
||||
return result.id, [
|
||||
NodeExecutionResult.from_db(execution, result.userId)
|
||||
for execution in result.AgentNodeExecutions or []
|
||||
]
|
||||
return GraphExecutionWithNodes.from_db(result)
|
||||
|
||||
|
||||
async def upsert_execution_input(
|
||||
@@ -322,17 +388,20 @@ async def upsert_execution_input(
|
||||
node_exec_id: [Optional] The id of the AgentNodeExecution that has no `input_name` as input. If not provided, it will find the eligible incomplete AgentNodeExecution or create a new one.
|
||||
|
||||
Returns:
|
||||
* The id of the created or existing AgentNodeExecution.
|
||||
* Dict of node input data, key is the input name, value is the input data.
|
||||
str: The id of the created or existing AgentNodeExecution.
|
||||
dict[str, Any]: Node input data; key is the input name, value is the input data.
|
||||
"""
|
||||
existing_exec_query_filter: AgentNodeExecutionWhereInput = {
|
||||
"agentNodeId": node_id,
|
||||
"agentGraphExecutionId": graph_exec_id,
|
||||
"executionStatus": ExecutionStatus.INCOMPLETE,
|
||||
"Input": {"every": {"name": {"not": input_name}}},
|
||||
}
|
||||
if node_exec_id:
|
||||
existing_exec_query_filter["id"] = node_exec_id
|
||||
|
||||
existing_execution = await AgentNodeExecution.prisma().find_first(
|
||||
where={ # type: ignore
|
||||
**({"id": node_exec_id} if node_exec_id else {}),
|
||||
"agentNodeId": node_id,
|
||||
"agentGraphExecutionId": graph_exec_id,
|
||||
"executionStatus": ExecutionStatus.INCOMPLETE,
|
||||
"Input": {"every": {"name": {"not": input_name}}},
|
||||
},
|
||||
where=existing_exec_query_filter,
|
||||
order={"addedTime": "asc"},
|
||||
include={"Input": True},
|
||||
)
|
||||
@@ -388,25 +457,26 @@ async def upsert_execution_output(
|
||||
)
|
||||
|
||||
|
||||
async def update_graph_execution_start_time(graph_exec_id: str) -> GraphExecutionMeta:
|
||||
async def update_graph_execution_start_time(graph_exec_id: str) -> GraphExecution:
|
||||
res = await AgentGraphExecution.prisma().update(
|
||||
where={"id": graph_exec_id},
|
||||
data={
|
||||
"executionStatus": ExecutionStatus.RUNNING,
|
||||
"startedAt": datetime.now(tz=timezone.utc),
|
||||
},
|
||||
include=GRAPH_EXECUTION_INCLUDE,
|
||||
)
|
||||
if not res:
|
||||
raise ValueError(f"Graph execution #{graph_exec_id} not found")
|
||||
|
||||
return GraphExecutionMeta.from_db(res)
|
||||
return GraphExecution.from_db(res)
|
||||
|
||||
|
||||
async def update_graph_execution_stats(
|
||||
graph_exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
stats: GraphExecutionStats | None = None,
|
||||
) -> GraphExecutionMeta | None:
|
||||
) -> GraphExecution | None:
|
||||
data = stats.model_dump() if stats else {}
|
||||
if isinstance(data.get("error"), Exception):
|
||||
data["error"] = str(data["error"])
|
||||
@@ -422,9 +492,10 @@ async def update_graph_execution_stats(
|
||||
"executionStatus": status,
|
||||
"stats": Json(data),
|
||||
},
|
||||
include=GRAPH_EXECUTION_INCLUDE,
|
||||
)
|
||||
|
||||
return GraphExecutionMeta.from_db(res) if res else None
|
||||
return GraphExecution.from_db(res) if res else None
|
||||
|
||||
|
||||
async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionStats):
|
||||
@@ -774,7 +845,7 @@ class ExecutionEventType(str, Enum):
|
||||
NODE_EXEC_UPDATE = "node_execution_update"
|
||||
|
||||
|
||||
class GraphExecutionEvent(GraphExecutionMeta):
|
||||
class GraphExecutionEvent(GraphExecution):
|
||||
event_type: Literal[ExecutionEventType.GRAPH_EXEC_UPDATE] = (
|
||||
ExecutionEventType.GRAPH_EXEC_UPDATE
|
||||
)
|
||||
@@ -798,8 +869,8 @@ class RedisExecutionEventBus(RedisEventBus[ExecutionEvent]):
|
||||
def event_bus_name(self) -> str:
|
||||
return config.execution_event_bus_name
|
||||
|
||||
def publish(self, res: GraphExecutionMeta | NodeExecutionResult):
|
||||
if isinstance(res, GraphExecutionMeta):
|
||||
def publish(self, res: GraphExecution | NodeExecutionResult):
|
||||
if isinstance(res, GraphExecution):
|
||||
self.publish_graph_exec_update(res)
|
||||
else:
|
||||
self.publish_node_exec_update(res)
|
||||
@@ -808,7 +879,7 @@ class RedisExecutionEventBus(RedisEventBus[ExecutionEvent]):
|
||||
event = NodeExecutionEvent.model_validate(res.model_dump())
|
||||
self.publish_event(event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}")
|
||||
|
||||
def publish_graph_exec_update(self, res: GraphExecutionMeta):
|
||||
def publish_graph_exec_update(self, res: GraphExecution):
|
||||
event = GraphExecutionEvent.model_validate(res.model_dump())
|
||||
self.publish_event(event, f"{res.user_id}/{res.graph_id}/{res.id}")
|
||||
|
||||
|
||||
@@ -538,7 +538,6 @@ async def get_graph_metadata(graph_id: str, version: int | None = None) -> Graph
|
||||
|
||||
graph = await AgentGraph.prisma().find_first(
|
||||
where=where_clause,
|
||||
include=AGENT_GRAPH_INCLUDE,
|
||||
order={"version": "desc"},
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import prisma
|
||||
|
||||
from backend.blocks.io import IO_BLOCK_IDs
|
||||
|
||||
AGENT_NODE_INCLUDE: prisma.types.AgentNodeInclude = {
|
||||
"Input": True,
|
||||
"Output": True,
|
||||
@@ -20,7 +22,7 @@ EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
|
||||
|
||||
MAX_NODE_EXECUTIONS_FETCH = 1000
|
||||
|
||||
GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
|
||||
GRAPH_EXECUTION_INCLUDE_WITH_NODES: prisma.types.AgentGraphExecutionInclude = {
|
||||
"AgentNodeExecutions": {
|
||||
"include": {
|
||||
"Input": True,
|
||||
@@ -37,6 +39,17 @@ GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
|
||||
}
|
||||
}
|
||||
|
||||
GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
|
||||
"AgentNodeExecutions": {
|
||||
**GRAPH_EXECUTION_INCLUDE_WITH_NODES["AgentNodeExecutions"], # type: ignore
|
||||
"where": {
|
||||
"AgentNode": {
|
||||
"AgentBlock": {"id": {"in": IO_BLOCK_IDs}}, # type: ignore
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
|
||||
"AgentNodes": {"include": AGENT_NODE_INCLUDE} # type: ignore
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||
from backend.data.execution import (
|
||||
GraphExecutionMeta,
|
||||
GraphExecution,
|
||||
NodeExecutionResult,
|
||||
RedisExecutionEventBus,
|
||||
create_graph_execution,
|
||||
get_graph_execution,
|
||||
get_incomplete_node_executions,
|
||||
get_latest_node_execution,
|
||||
get_node_execution_results,
|
||||
@@ -65,11 +66,12 @@ class DatabaseManager(AppService):
|
||||
|
||||
@expose
|
||||
def send_execution_update(
|
||||
self, execution_result: GraphExecutionMeta | NodeExecutionResult
|
||||
self, execution_result: GraphExecution | NodeExecutionResult
|
||||
):
|
||||
self.execution_event_bus.publish(execution_result)
|
||||
|
||||
# Executions
|
||||
get_graph_execution = exposed_run_and_wait(get_graph_execution)
|
||||
create_graph_execution = exposed_run_and_wait(create_graph_execution)
|
||||
get_node_execution_results = exposed_run_and_wait(get_node_execution_results)
|
||||
get_incomplete_node_executions = exposed_run_and_wait(
|
||||
|
||||
@@ -150,7 +150,7 @@ def execute_node(
|
||||
node_exec_id = data.node_exec_id
|
||||
node_id = data.node_id
|
||||
|
||||
def update_execution(status: ExecutionStatus) -> NodeExecutionResult:
|
||||
def update_execution_status(status: ExecutionStatus) -> NodeExecutionResult:
|
||||
"""Sets status and fetches+broadcasts the latest state of the node execution"""
|
||||
exec_update = db_client.update_node_execution_status(node_exec_id, status)
|
||||
db_client.send_execution_update(exec_update)
|
||||
@@ -163,6 +163,17 @@ def execute_node(
|
||||
logger.error(f"Block {node.block_id} not found.")
|
||||
return
|
||||
|
||||
def push_output(output_name: str, output_data: Any) -> None:
|
||||
_push_node_execution_output(
|
||||
db_client=db_client,
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
node_exec_id=node_exec_id,
|
||||
block_id=node_block.id,
|
||||
output_name=output_name,
|
||||
output_data=output_data,
|
||||
)
|
||||
|
||||
log_metadata = LogMetadata(
|
||||
user_id=user_id,
|
||||
graph_eid=graph_exec_id,
|
||||
@@ -176,8 +187,8 @@ def execute_node(
|
||||
input_data, error = validate_exec(node, data.data, resolve_input=False)
|
||||
if input_data is None:
|
||||
log_metadata.error(f"Skip execution, input validation error: {error}")
|
||||
db_client.upsert_execution_output(node_exec_id, "error", error)
|
||||
update_execution(ExecutionStatus.FAILED)
|
||||
push_output("error", error)
|
||||
update_execution_status(ExecutionStatus.FAILED)
|
||||
return
|
||||
|
||||
# Re-shape the input data for agent block.
|
||||
@@ -190,7 +201,7 @@ def execute_node(
|
||||
input_data_str = json.dumps(input_data)
|
||||
input_size = len(input_data_str)
|
||||
log_metadata.info("Executed node with input", input=input_data_str)
|
||||
update_execution(ExecutionStatus.RUNNING)
|
||||
update_execution_status(ExecutionStatus.RUNNING)
|
||||
|
||||
# Inject extra execution arguments for the blocks via kwargs
|
||||
extra_exec_kwargs: dict = {
|
||||
@@ -221,7 +232,7 @@ def execute_node(
|
||||
output_data = json.convert_pydantic_to_json(output_data)
|
||||
output_size += len(json.dumps(output_data))
|
||||
log_metadata.info("Node produced output", **{output_name: output_data})
|
||||
db_client.upsert_execution_output(node_exec_id, output_name, output_data)
|
||||
push_output(output_name, output_data)
|
||||
outputs[output_name] = output_data
|
||||
for execution in _enqueue_next_nodes(
|
||||
db_client=db_client,
|
||||
@@ -234,13 +245,12 @@ def execute_node(
|
||||
):
|
||||
yield execution
|
||||
|
||||
# Update execution status and spend credits
|
||||
update_execution(ExecutionStatus.COMPLETED)
|
||||
update_execution_status(ExecutionStatus.COMPLETED)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
db_client.upsert_execution_output(node_exec_id, "error", error_msg)
|
||||
update_execution(ExecutionStatus.FAILED)
|
||||
push_output("error", error_msg)
|
||||
update_execution_status(ExecutionStatus.FAILED)
|
||||
|
||||
for execution in _enqueue_next_nodes(
|
||||
db_client=db_client,
|
||||
@@ -271,6 +281,35 @@ def execute_node(
|
||||
execution_stats.output_size = output_size
|
||||
|
||||
|
||||
def _push_node_execution_output(
|
||||
db_client: "DatabaseManager",
|
||||
user_id: str,
|
||||
graph_exec_id: str,
|
||||
node_exec_id: str,
|
||||
block_id: str,
|
||||
output_name: str,
|
||||
output_data: Any,
|
||||
):
|
||||
from backend.blocks.io import IO_BLOCK_IDs
|
||||
|
||||
db_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id,
|
||||
output_name=output_name,
|
||||
output_data=output_data,
|
||||
)
|
||||
|
||||
# Automatically push execution updates for all agent I/O
|
||||
if block_id in IO_BLOCK_IDs:
|
||||
graph_exec = db_client.get_graph_execution(
|
||||
user_id=user_id, execution_id=graph_exec_id
|
||||
)
|
||||
if not graph_exec:
|
||||
raise ValueError(
|
||||
f"Graph execution #{graph_exec_id} for user #{user_id} not found"
|
||||
)
|
||||
db_client.send_execution_update(graph_exec)
|
||||
|
||||
|
||||
def _enqueue_next_nodes(
|
||||
db_client: "DatabaseManager",
|
||||
node: Node,
|
||||
@@ -628,7 +667,10 @@ class Executor:
|
||||
node_eid="*",
|
||||
block_name="-",
|
||||
)
|
||||
cls.db_client.update_graph_execution_start_time(graph_exec.graph_exec_id)
|
||||
exec_meta = cls.db_client.update_graph_execution_start_time(
|
||||
graph_exec.graph_exec_id
|
||||
)
|
||||
cls.db_client.send_execution_update(exec_meta)
|
||||
timing_info, (exec_stats, status, error) = cls._on_graph_execution(
|
||||
graph_exec, cancel, log_metadata
|
||||
)
|
||||
@@ -636,12 +678,12 @@ class Executor:
|
||||
exec_stats.cputime = timing_info.cpu_time
|
||||
exec_stats.error = str(error)
|
||||
|
||||
if result := cls.db_client.update_graph_execution_stats(
|
||||
if graph_exec_result := cls.db_client.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
status=status,
|
||||
stats=exec_stats,
|
||||
):
|
||||
cls.db_client.send_execution_update(result)
|
||||
cls.db_client.send_execution_update(graph_exec_result)
|
||||
|
||||
cls._handle_agent_run_notif(graph_exec, exec_stats)
|
||||
|
||||
@@ -774,11 +816,19 @@ class Executor:
|
||||
execution_stats=exec_stats,
|
||||
)
|
||||
except InsufficientBalanceError as error:
|
||||
exec_id = exec_data.node_exec_id
|
||||
cls.db_client.upsert_execution_output(exec_id, "error", str(error))
|
||||
node_exec_id = exec_data.node_exec_id
|
||||
_push_node_execution_output(
|
||||
db_client=cls.db_client,
|
||||
user_id=graph_exec.user_id,
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
node_exec_id=node_exec_id,
|
||||
block_id=exec_data.block_id,
|
||||
output_name="error",
|
||||
output_data=str(error),
|
||||
)
|
||||
|
||||
exec_update = cls.db_client.update_node_execution_status(
|
||||
exec_id, ExecutionStatus.FAILED
|
||||
node_exec_id, ExecutionStatus.FAILED
|
||||
)
|
||||
cls.db_client.send_execution_update(exec_update)
|
||||
|
||||
@@ -1002,13 +1052,14 @@ class ExecutionManager(AppService):
|
||||
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
|
||||
)
|
||||
|
||||
graph_exec_id, node_execs = self.db_client.create_graph_execution(
|
||||
graph_exec = self.db_client.create_graph_execution(
|
||||
graph_id=graph_id,
|
||||
graph_version=graph.version,
|
||||
nodes_input=nodes_input,
|
||||
user_id=user_id,
|
||||
preset_id=preset_id,
|
||||
)
|
||||
self.db_client.send_execution_update(graph_exec)
|
||||
|
||||
# Right after creating the graph execution, we need to check if the content is safe
|
||||
if settings.config.behave_as != BehaveAs.LOCAL:
|
||||
@@ -1032,18 +1083,12 @@ class ExecutionManager(AppService):
|
||||
block_id=node_exec.block_id,
|
||||
data=node_exec.input_data,
|
||||
)
|
||||
)
|
||||
|
||||
graph_exec = GraphExecutionEntry(
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
graph_version=graph_version or 0,
|
||||
graph_exec_id=graph_exec_id,
|
||||
start_node_execs=starting_node_execs,
|
||||
for node_exec in graph_exec.node_executions
|
||||
],
|
||||
)
|
||||
self.queue.add(graph_exec)
|
||||
self.queue.add(graph_exec_entry)
|
||||
|
||||
return graph_exec
|
||||
return graph_exec_entry
|
||||
|
||||
@expose
|
||||
def cancel_execution(self, graph_exec_id: str) -> None:
|
||||
|
||||
@@ -32,22 +32,30 @@ class ConnectionManager:
|
||||
async def subscribe_graph_exec(
|
||||
self, *, user_id: str, graph_exec_id: str, websocket: WebSocket
|
||||
) -> str:
|
||||
key = _graph_exec_channel_key(user_id, graph_exec_id)
|
||||
if key not in self.subscriptions:
|
||||
self.subscriptions[key] = set()
|
||||
self.subscriptions[key].add(websocket)
|
||||
return key
|
||||
return await self._subscribe(
|
||||
_graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id), websocket
|
||||
)
|
||||
|
||||
async def unsubscribe(
|
||||
async def subscribe_graph_execs(
|
||||
self, *, user_id: str, graph_id: str, websocket: WebSocket
|
||||
) -> str:
|
||||
return await self._subscribe(
|
||||
_graph_execs_channel_key(user_id, graph_id=graph_id), websocket
|
||||
)
|
||||
|
||||
async def unsubscribe_graph_exec(
|
||||
self, *, user_id: str, graph_exec_id: str, websocket: WebSocket
|
||||
) -> str | None:
|
||||
key = _graph_exec_channel_key(user_id, graph_exec_id)
|
||||
if key in self.subscriptions:
|
||||
self.subscriptions[key].discard(websocket)
|
||||
if not self.subscriptions[key]:
|
||||
del self.subscriptions[key]
|
||||
return key
|
||||
return None
|
||||
return await self._unsubscribe(
|
||||
_graph_exec_channel_key(user_id, graph_exec_id=graph_exec_id), websocket
|
||||
)
|
||||
|
||||
async def unsubscribe_graph_execs(
|
||||
self, *, user_id: str, graph_id: str, websocket: WebSocket
|
||||
) -> str | None:
|
||||
return await self._unsubscribe(
|
||||
_graph_execs_channel_key(user_id, graph_id=graph_id), websocket
|
||||
)
|
||||
|
||||
async def send_execution_update(
|
||||
self, exec_event: GraphExecutionEvent | NodeExecutionEvent
|
||||
@@ -57,21 +65,51 @@ class ConnectionManager:
|
||||
if isinstance(exec_event, GraphExecutionEvent)
|
||||
else exec_event.graph_exec_id
|
||||
)
|
||||
key = _graph_exec_channel_key(exec_event.user_id, graph_exec_id)
|
||||
|
||||
n_sent = 0
|
||||
if key in self.subscriptions:
|
||||
|
||||
channels: set[str] = {
|
||||
# Send update to listeners for this graph execution
|
||||
_graph_exec_channel_key(exec_event.user_id, graph_exec_id=graph_exec_id)
|
||||
}
|
||||
if isinstance(exec_event, GraphExecutionEvent):
|
||||
# Send update to listeners for all executions of this graph
|
||||
channels.add(
|
||||
_graph_execs_channel_key(
|
||||
exec_event.user_id, graph_id=exec_event.graph_id
|
||||
)
|
||||
)
|
||||
|
||||
for channel in channels.intersection(self.subscriptions.keys()):
|
||||
message = WSMessage(
|
||||
method=_EVENT_TYPE_TO_METHOD_MAP[exec_event.event_type],
|
||||
channel=key,
|
||||
channel=channel,
|
||||
data=exec_event.model_dump(),
|
||||
).model_dump_json()
|
||||
for connection in self.subscriptions[key]:
|
||||
for connection in self.subscriptions[channel]:
|
||||
await connection.send_text(message)
|
||||
n_sent += 1
|
||||
|
||||
return n_sent
|
||||
|
||||
async def _subscribe(self, channel_key: str, websocket: WebSocket) -> str:
|
||||
if channel_key not in self.subscriptions:
|
||||
self.subscriptions[channel_key] = set()
|
||||
self.subscriptions[channel_key].add(websocket)
|
||||
return channel_key
|
||||
|
||||
def _graph_exec_channel_key(user_id: str, graph_exec_id: str) -> str:
|
||||
async def _unsubscribe(self, channel_key: str, websocket: WebSocket) -> str | None:
|
||||
if channel_key in self.subscriptions:
|
||||
self.subscriptions[channel_key].discard(websocket)
|
||||
if not self.subscriptions[channel_key]:
|
||||
del self.subscriptions[channel_key]
|
||||
return channel_key
|
||||
return None
|
||||
|
||||
|
||||
def _graph_exec_channel_key(user_id: str, *, graph_exec_id: str) -> str:
|
||||
return f"{user_id}|graph_exec#{graph_exec_id}"
|
||||
|
||||
|
||||
def _graph_execs_channel_key(user_id: str, *, graph_id: str) -> str:
|
||||
return f"{user_id}|graph#{graph_id}|executions"
|
||||
|
||||
@@ -9,6 +9,7 @@ from backend.data.graph import Graph
|
||||
|
||||
class WSMethod(enum.Enum):
|
||||
SUBSCRIBE_GRAPH_EXEC = "subscribe_graph_execution"
|
||||
SUBSCRIBE_GRAPH_EXECS = "subscribe_graph_executions"
|
||||
UNSUBSCRIBE = "unsubscribe"
|
||||
GRAPH_EXECUTION_EVENT = "graph_execution_event"
|
||||
NODE_EXECUTION_EVENT = "node_execution_event"
|
||||
@@ -28,6 +29,10 @@ class WSSubscribeGraphExecutionRequest(pydantic.BaseModel):
|
||||
graph_exec_id: str
|
||||
|
||||
|
||||
class WSSubscribeGraphExecutionsRequest(pydantic.BaseModel):
|
||||
graph_id: str
|
||||
|
||||
|
||||
class ExecuteGraphResponse(pydantic.BaseModel):
|
||||
graph_exec_id: str
|
||||
|
||||
|
||||
@@ -193,14 +193,6 @@ class AgentServer(backend.util.service.AppProcess):
|
||||
raise ValueError(f"Execution {graph_exec_id} not found")
|
||||
return execution.status
|
||||
|
||||
@staticmethod
|
||||
async def test_get_graph_run_results(
|
||||
graph_id: str, graph_exec_id: str, user_id: str
|
||||
):
|
||||
return await backend.server.routers.v1.get_graph_execution(
|
||||
graph_id, graph_exec_id, user_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def test_delete_graph(graph_id: str, user_id: str):
|
||||
await backend.server.v2.library.db.delete_library_agent_by_graph_id(
|
||||
|
||||
@@ -10,7 +10,7 @@ from autogpt_libs.auth.middleware import auth_middleware
|
||||
from autogpt_libs.feature_flag.client import feature_flag
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Request, Response
|
||||
from starlette.status import HTTP_204_NO_CONTENT
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
|
||||
from typing_extensions import Optional, TypedDict
|
||||
|
||||
import backend.data.block
|
||||
@@ -653,9 +653,17 @@ async def get_graph_execution(
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> execution_db.GraphExecution:
|
||||
) -> execution_db.GraphExecution | execution_db.GraphExecutionWithNodes:
|
||||
graph = await graph_db.get_graph(graph_id=graph_id, user_id=user_id)
|
||||
if not graph:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_404_NOT_FOUND, detail=f"Graph #{graph_id} not found"
|
||||
)
|
||||
|
||||
result = await execution_db.get_graph_execution(
|
||||
execution_id=graph_exec_id, user_id=user_id
|
||||
user_id=user_id,
|
||||
execution_id=graph_exec_id,
|
||||
include_node_executions=graph.user_id == user_id,
|
||||
)
|
||||
if not result or result.graph_id != graph_id:
|
||||
raise HTTPException(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Protocol
|
||||
|
||||
import uvicorn
|
||||
from autogpt_libs.auth import parse_jwt_token
|
||||
@@ -12,7 +13,12 @@ from backend.data import redis
|
||||
from backend.data.execution import AsyncRedisExecutionEventBus
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.server.conn_manager import ConnectionManager
|
||||
from backend.server.model import WSMessage, WSMethod, WSSubscribeGraphExecutionRequest
|
||||
from backend.server.model import (
|
||||
WSMessage,
|
||||
WSMethod,
|
||||
WSSubscribeGraphExecutionRequest,
|
||||
WSSubscribeGraphExecutionsRequest,
|
||||
)
|
||||
from backend.util.service import AppProcess, get_service_client
|
||||
from backend.util.settings import AppEnvironment, Config, Settings
|
||||
|
||||
@@ -81,6 +87,19 @@ async def authenticate_websocket(websocket: WebSocket) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
# ===================== Message Handlers ===================== #
|
||||
|
||||
|
||||
class WSMessageHandler(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
connection_manager: ConnectionManager,
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
message: WSMessage,
|
||||
): ...
|
||||
|
||||
|
||||
async def handle_subscribe(
|
||||
connection_manager: ConnectionManager,
|
||||
websocket: WebSocket,
|
||||
@@ -95,41 +114,53 @@ async def handle_subscribe(
|
||||
error="Subscription data missing",
|
||||
).model_dump_json()
|
||||
)
|
||||
else:
|
||||
return
|
||||
|
||||
# Verify that user has read access to graph
|
||||
# if not get_db_client().get_graph(
|
||||
# graph_id=sub_req.graph_id,
|
||||
# version=sub_req.graph_version,
|
||||
# user_id=user_id,
|
||||
# ):
|
||||
# await websocket.send_text(
|
||||
# WsMessage(
|
||||
# method=Methods.ERROR,
|
||||
# success=False,
|
||||
# error="Access denied",
|
||||
# ).model_dump_json()
|
||||
# )
|
||||
# return
|
||||
|
||||
if message.method == WSMethod.SUBSCRIBE_GRAPH_EXEC:
|
||||
sub_req = WSSubscribeGraphExecutionRequest.model_validate(message.data)
|
||||
|
||||
# Verify that user has read access to graph
|
||||
# if not get_db_client().get_graph(
|
||||
# graph_id=sub_req.graph_id,
|
||||
# version=sub_req.graph_version,
|
||||
# user_id=user_id,
|
||||
# ):
|
||||
# await websocket.send_text(
|
||||
# WsMessage(
|
||||
# method=Methods.ERROR,
|
||||
# success=False,
|
||||
# error="Access denied",
|
||||
# ).model_dump_json()
|
||||
# )
|
||||
# return
|
||||
|
||||
channel_key = await connection_manager.subscribe_graph_exec(
|
||||
user_id=user_id,
|
||||
graph_exec_id=sub_req.graph_exec_id,
|
||||
websocket=websocket,
|
||||
)
|
||||
logger.debug(
|
||||
f"New subscription for user #{user_id}, "
|
||||
f"graph execution #{sub_req.graph_exec_id}"
|
||||
|
||||
elif message.method == WSMethod.SUBSCRIBE_GRAPH_EXECS:
|
||||
sub_req = WSSubscribeGraphExecutionsRequest.model_validate(message.data)
|
||||
channel_key = await connection_manager.subscribe_graph_execs(
|
||||
user_id=user_id,
|
||||
graph_id=sub_req.graph_id,
|
||||
websocket=websocket,
|
||||
)
|
||||
await websocket.send_text(
|
||||
WSMessage(
|
||||
method=WSMethod.SUBSCRIBE_GRAPH_EXEC,
|
||||
success=True,
|
||||
channel=channel_key,
|
||||
).model_dump_json()
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{handle_subscribe.__name__} can't handle '{message.method}' messages"
|
||||
)
|
||||
|
||||
logger.debug(f"New subscription on channel {channel_key} for user #{user_id}")
|
||||
await websocket.send_text(
|
||||
WSMessage(
|
||||
method=message.method,
|
||||
success=True,
|
||||
channel=channel_key,
|
||||
).model_dump_json()
|
||||
)
|
||||
|
||||
|
||||
async def handle_unsubscribe(
|
||||
connection_manager: ConnectionManager,
|
||||
@@ -145,29 +176,49 @@ async def handle_unsubscribe(
|
||||
error="Subscription data missing",
|
||||
).model_dump_json()
|
||||
)
|
||||
else:
|
||||
unsub_req = WSSubscribeGraphExecutionRequest.model_validate(message.data)
|
||||
channel_key = await connection_manager.unsubscribe(
|
||||
user_id=user_id,
|
||||
graph_exec_id=unsub_req.graph_exec_id,
|
||||
websocket=websocket,
|
||||
)
|
||||
logger.debug(
|
||||
f"Removed subscription for user #{user_id}, "
|
||||
f"graph execution #{unsub_req.graph_exec_id}"
|
||||
)
|
||||
await websocket.send_text(
|
||||
WSMessage(
|
||||
method=WSMethod.UNSUBSCRIBE,
|
||||
success=True,
|
||||
channel=channel_key,
|
||||
).model_dump_json()
|
||||
)
|
||||
return
|
||||
|
||||
unsub_req = WSSubscribeGraphExecutionRequest.model_validate(message.data)
|
||||
channel_key = await connection_manager.unsubscribe_graph_exec(
|
||||
user_id=user_id,
|
||||
graph_exec_id=unsub_req.graph_exec_id,
|
||||
websocket=websocket,
|
||||
)
|
||||
|
||||
logger.debug(f"Removed subscription on channel {channel_key} for user #{user_id}")
|
||||
await websocket.send_text(
|
||||
WSMessage(
|
||||
method=WSMethod.UNSUBSCRIBE,
|
||||
success=True,
|
||||
channel=channel_key,
|
||||
).model_dump_json()
|
||||
)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def health():
|
||||
return {"status": "healthy"}
|
||||
async def handle_heartbeat(
|
||||
connection_manager: ConnectionManager,
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
message: WSMessage,
|
||||
):
|
||||
await websocket.send_json(
|
||||
{
|
||||
"method": WSMethod.HEARTBEAT.value,
|
||||
"data": "pong",
|
||||
"success": True,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
_MSG_HANDLERS: dict[WSMethod, WSMessageHandler] = {
|
||||
WSMethod.HEARTBEAT: handle_heartbeat,
|
||||
WSMethod.SUBSCRIBE_GRAPH_EXEC: handle_subscribe,
|
||||
WSMethod.SUBSCRIBE_GRAPH_EXECS: handle_subscribe,
|
||||
WSMethod.UNSUBSCRIBE: handle_unsubscribe,
|
||||
}
|
||||
|
||||
|
||||
# ===================== WebSocket Server ===================== #
|
||||
|
||||
|
||||
@app.websocket("/ws")
|
||||
@@ -183,28 +234,9 @@ async def websocket_router(
|
||||
data = await websocket.receive_text()
|
||||
message = WSMessage.model_validate_json(data)
|
||||
|
||||
if message.method == WSMethod.HEARTBEAT:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"method": WSMethod.HEARTBEAT.value,
|
||||
"data": "pong",
|
||||
"success": True,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
if message.method == WSMethod.SUBSCRIBE_GRAPH_EXEC:
|
||||
await handle_subscribe(
|
||||
connection_manager=manager,
|
||||
websocket=websocket,
|
||||
user_id=user_id,
|
||||
message=message,
|
||||
)
|
||||
continue
|
||||
|
||||
elif message.method == WSMethod.UNSUBSCRIBE:
|
||||
await handle_unsubscribe(
|
||||
if message.method in _MSG_HANDLERS:
|
||||
await _MSG_HANDLERS[message.method](
|
||||
connection_manager=manager,
|
||||
websocket=websocket,
|
||||
user_id=user_id,
|
||||
@@ -213,7 +245,7 @@ async def websocket_router(
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error while handling '{message.method}' message "
|
||||
f"Error while handling '{message.method.value}' message "
|
||||
f"for user #{user_id}: {e}"
|
||||
)
|
||||
continue
|
||||
@@ -239,6 +271,11 @@ async def websocket_router(
|
||||
logger.debug("WebSocket client disconnected")
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def health():
|
||||
return {"status": "healthy"}
|
||||
|
||||
|
||||
class WebsocketServer(AppProcess):
|
||||
def run(self):
|
||||
logger.info(f"CORS allow origins: {settings.config.backend_cors_allow_origins}")
|
||||
|
||||
@@ -259,7 +259,6 @@ async def block_autogen_agent():
|
||||
)
|
||||
print(response)
|
||||
result = await wait_execution(
|
||||
graph_id=test_graph.id,
|
||||
graph_exec_id=response.graph_exec_id,
|
||||
timeout=1200,
|
||||
user_id=test_user.id,
|
||||
|
||||
@@ -162,9 +162,7 @@ async def reddit_marketing_agent():
|
||||
node_input=input_data,
|
||||
)
|
||||
print(response)
|
||||
result = await wait_execution(
|
||||
test_user.id, test_graph.id, response.graph_exec_id, 120
|
||||
)
|
||||
result = await wait_execution(test_user.id, response.graph_exec_id, 120)
|
||||
print(result)
|
||||
|
||||
|
||||
|
||||
@@ -94,7 +94,7 @@ async def sample_agent():
|
||||
user_id=test_user.id,
|
||||
node_input=input_data,
|
||||
)
|
||||
await wait_execution(test_user.id, test_graph.id, response.graph_exec_id, 10)
|
||||
await wait_execution(test_user.id, response.graph_exec_id, 10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -5,7 +5,11 @@ from typing import Sequence, cast
|
||||
|
||||
from backend.data import db
|
||||
from backend.data.block import Block, BlockSchema, initialize_blocks
|
||||
from backend.data.execution import ExecutionStatus, NodeExecutionResult
|
||||
from backend.data.execution import (
|
||||
ExecutionStatus,
|
||||
NodeExecutionResult,
|
||||
get_graph_execution,
|
||||
)
|
||||
from backend.data.model import _BaseCredentials
|
||||
from backend.data.user import create_default_user
|
||||
from backend.executor import DatabaseManager, ExecutionManager, Scheduler
|
||||
@@ -60,7 +64,6 @@ class SpinTestServer:
|
||||
|
||||
async def wait_execution(
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
timeout: int = 30,
|
||||
) -> Sequence[NodeExecutionResult]:
|
||||
@@ -78,9 +81,12 @@ async def wait_execution(
|
||||
# Wait for the executions to complete
|
||||
for i in range(timeout):
|
||||
if await is_execution_completed():
|
||||
graph_exec = await AgentServer().test_get_graph_run_results(
|
||||
graph_id, graph_exec_id, user_id
|
||||
graph_exec = await get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=graph_exec_id,
|
||||
include_node_executions=True,
|
||||
)
|
||||
assert graph_exec, f"Graph execution #{graph_exec_id} not found"
|
||||
return graph_exec.node_executions
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@@ -46,24 +46,22 @@ async def execute_graph(
|
||||
|
||||
# Execution queue should be empty
|
||||
logger.info("Waiting for execution to complete...")
|
||||
result = await wait_execution(test_user.id, test_graph.id, graph_exec_id, 30)
|
||||
result = await wait_execution(test_user.id, graph_exec_id, 30)
|
||||
logger.info(f"Execution completed with {len(result)} results")
|
||||
assert len(result) == num_execs
|
||||
return graph_exec_id
|
||||
|
||||
|
||||
async def assert_sample_graph_executions(
|
||||
agent_server: AgentServer,
|
||||
test_graph: graph.Graph,
|
||||
test_user: User,
|
||||
graph_exec_id: str,
|
||||
):
|
||||
logger.info(f"Checking execution results for graph {test_graph.id}")
|
||||
graph_run = await agent_server.test_get_graph_run_results(
|
||||
test_graph.id,
|
||||
graph_exec_id,
|
||||
test_user.id,
|
||||
graph_run = await execution.get_graph_execution(
|
||||
test_user.id, graph_exec_id, include_node_executions=True
|
||||
)
|
||||
assert isinstance(graph_run, execution.GraphExecutionWithNodes)
|
||||
|
||||
output_list = [{"result": ["Hello"]}, {"result": ["World"]}]
|
||||
input_list = [
|
||||
@@ -142,9 +140,7 @@ async def test_agent_execution(server: SpinTestServer):
|
||||
data,
|
||||
4,
|
||||
)
|
||||
await assert_sample_graph_executions(
|
||||
server.agent_server, test_graph, test_user, graph_exec_id
|
||||
)
|
||||
await assert_sample_graph_executions(test_graph, test_user, graph_exec_id)
|
||||
logger.info("Completed test_agent_execution")
|
||||
|
||||
|
||||
@@ -203,9 +199,10 @@ async def test_input_pin_always_waited(server: SpinTestServer):
|
||||
)
|
||||
|
||||
logger.info("Checking execution results")
|
||||
graph_exec = await server.agent_server.test_get_graph_run_results(
|
||||
test_graph.id, graph_exec_id, test_user.id
|
||||
graph_exec = await execution.get_graph_execution(
|
||||
test_user.id, graph_exec_id, include_node_executions=True
|
||||
)
|
||||
assert isinstance(graph_exec, execution.GraphExecutionWithNodes)
|
||||
assert len(graph_exec.node_executions) == 3
|
||||
# FindInDictionaryBlock should wait for the input pin to be provided,
|
||||
# Hence executing extraction of "key" from {"key1": "value1", "key2": "value2"}
|
||||
@@ -286,9 +283,10 @@ async def test_static_input_link_on_graph(server: SpinTestServer):
|
||||
server.agent_server, test_graph, test_user, {}, 8
|
||||
)
|
||||
logger.info("Checking execution results")
|
||||
graph_exec = await server.agent_server.test_get_graph_run_results(
|
||||
test_graph.id, graph_exec_id, test_user.id
|
||||
graph_exec = await execution.get_graph_execution(
|
||||
test_user.id, graph_exec_id, include_node_executions=True
|
||||
)
|
||||
assert isinstance(graph_exec, execution.GraphExecutionWithNodes)
|
||||
assert len(graph_exec.node_executions) == 8
|
||||
# The last 3 executions will be a+b=4+5=9
|
||||
for i, exec_data in enumerate(graph_exec.node_executions[-3:]):
|
||||
@@ -385,7 +383,7 @@ async def test_execute_preset(server: SpinTestServer):
|
||||
graph_exec_id = result["id"]
|
||||
|
||||
# Wait for execution to complete
|
||||
executions = await wait_execution(test_user.id, test_graph.id, graph_exec_id)
|
||||
executions = await wait_execution(test_user.id, graph_exec_id)
|
||||
assert len(executions) == 4
|
||||
|
||||
# FindInDictionaryBlock should wait for the input pin to be provided,
|
||||
@@ -475,7 +473,7 @@ async def test_execute_preset_with_clash(server: SpinTestServer):
|
||||
graph_exec_id = result["id"]
|
||||
|
||||
# Wait for execution to complete
|
||||
executions = await wait_execution(test_user.id, test_graph.id, graph_exec_id)
|
||||
executions = await wait_execution(test_user.id, graph_exec_id)
|
||||
assert len(executions) == 4
|
||||
|
||||
# FindInDictionaryBlock should wait for the input pin to be provided,
|
||||
@@ -542,7 +540,5 @@ async def test_store_listing_graph(server: SpinTestServer):
|
||||
4,
|
||||
)
|
||||
|
||||
await assert_sample_graph_executions(
|
||||
server.agent_server, test_graph, alt_test_user, graph_exec_id
|
||||
)
|
||||
await assert_sample_graph_executions(test_graph, alt_test_user, graph_exec_id)
|
||||
logger.info("Completed test_agent_execution")
|
||||
|
||||
@@ -55,7 +55,7 @@ async def execute_graph(
|
||||
|
||||
# Execution queue should be empty
|
||||
logger.info("Waiting for execution to complete...")
|
||||
result = await wait_execution(test_user.id, test_graph.id, graph_exec_id, 30)
|
||||
result = await wait_execution(test_user.id, graph_exec_id, 30)
|
||||
logger.info("Execution completed with %d results", len(result))
|
||||
return graph_exec_id
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ async def test_unsubscribe(
|
||||
channel_key = "user-1|graph_exec#graph-exec-1"
|
||||
connection_manager.subscriptions[channel_key] = {mock_websocket}
|
||||
|
||||
await connection_manager.unsubscribe(
|
||||
await connection_manager.unsubscribe_graph_exec(
|
||||
user_id="user-1",
|
||||
graph_exec_id="graph-exec-1",
|
||||
websocket=mock_websocket,
|
||||
@@ -94,6 +94,14 @@ async def test_send_graph_execution_result(
|
||||
total_run_time=0.5,
|
||||
started_at=datetime.now(tz=timezone.utc),
|
||||
ended_at=datetime.now(tz=timezone.utc),
|
||||
inputs={
|
||||
"input_1": "some input value :)",
|
||||
"input_2": "some *other* input value",
|
||||
},
|
||||
outputs={
|
||||
"the_output": ["some output value"],
|
||||
"other_output": ["sike there was another output"],
|
||||
},
|
||||
)
|
||||
|
||||
await connection_manager.send_execution_update(result)
|
||||
|
||||
@@ -70,7 +70,7 @@ async def test_websocket_router_unsubscribe(
|
||||
).model_dump_json(),
|
||||
WebSocketDisconnect(),
|
||||
]
|
||||
mock_manager.unsubscribe.return_value = (
|
||||
mock_manager.unsubscribe_graph_exec.return_value = (
|
||||
f"{DEFAULT_USER_ID}|graph_exec#test-graph-exec-1"
|
||||
)
|
||||
|
||||
@@ -79,7 +79,7 @@ async def test_websocket_router_unsubscribe(
|
||||
)
|
||||
|
||||
mock_manager.connect_socket.assert_called_once_with(mock_websocket)
|
||||
mock_manager.unsubscribe.assert_called_once_with(
|
||||
mock_manager.unsubscribe_graph_exec.assert_called_once_with(
|
||||
user_id=DEFAULT_USER_ID,
|
||||
graph_exec_id="test-graph-exec-1",
|
||||
websocket=mock_websocket,
|
||||
@@ -168,7 +168,9 @@ async def test_handle_unsubscribe_success(
|
||||
message = WSMessage(
|
||||
method=WSMethod.UNSUBSCRIBE, data={"graph_exec_id": "test-graph-exec-id"}
|
||||
)
|
||||
mock_manager.unsubscribe.return_value = "user-1|graph_exec#test-graph-exec-id"
|
||||
mock_manager.unsubscribe_graph_exec.return_value = (
|
||||
"user-1|graph_exec#test-graph-exec-id"
|
||||
)
|
||||
|
||||
await handle_unsubscribe(
|
||||
connection_manager=cast(ConnectionManager, mock_manager),
|
||||
@@ -177,7 +179,7 @@ async def test_handle_unsubscribe_success(
|
||||
message=message,
|
||||
)
|
||||
|
||||
mock_manager.unsubscribe.assert_called_once_with(
|
||||
mock_manager.unsubscribe_graph_exec.assert_called_once_with(
|
||||
user_id="user-1",
|
||||
graph_exec_id="test-graph-exec-id",
|
||||
websocket=mock_websocket,
|
||||
@@ -200,7 +202,7 @@ async def test_handle_unsubscribe_missing_data(
|
||||
message=message,
|
||||
)
|
||||
|
||||
mock_manager.unsubscribe.assert_not_called()
|
||||
mock_manager._unsubscribe.assert_not_called()
|
||||
mock_websocket.send_text.assert_called_once()
|
||||
assert '"method":"error"' in mock_websocket.send_text.call_args[0][0]
|
||||
assert '"success":false' in mock_websocket.send_text.call_args[0][0]
|
||||
|
||||
@@ -89,20 +89,23 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
(graph && graph.version == _graph.version) || setGraph(_graph),
|
||||
);
|
||||
api.getGraphExecutions(agent.agent_id).then((agentRuns) => {
|
||||
const sortedRuns = agentRuns.toSorted(
|
||||
(a, b) => Number(b.started_at) - Number(a.started_at),
|
||||
);
|
||||
setAgentRuns(sortedRuns);
|
||||
setAgentRuns(agentRuns);
|
||||
|
||||
// Preload the corresponding graph versions
|
||||
new Set(sortedRuns.map((run) => run.graph_version)).forEach((version) =>
|
||||
new Set(agentRuns.map((run) => run.graph_version)).forEach((version) =>
|
||||
getGraphVersion(agent.agent_id, version),
|
||||
);
|
||||
|
||||
if (!selectedView.id && isFirstLoad && sortedRuns.length > 0) {
|
||||
if (!selectedView.id && isFirstLoad && agentRuns.length > 0) {
|
||||
// only for first load or first execution
|
||||
setIsFirstLoad(false);
|
||||
selectView({ type: "run", id: sortedRuns[0].id });
|
||||
|
||||
const latestRun = agentRuns.reduce((latest, current) => {
|
||||
if (latest.started_at && !current.started_at) return current;
|
||||
else if (!latest.started_at) return latest;
|
||||
return latest.started_at > current.started_at ? latest : current;
|
||||
}, agentRuns[0]);
|
||||
selectView({ type: "run", id: latestRun.id });
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -117,6 +120,39 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
fetchAgents();
|
||||
}, []);
|
||||
|
||||
// Subscribe to websocket updates for agent runs
|
||||
useEffect(() => {
|
||||
if (!agent) return;
|
||||
|
||||
// Subscribe to all executions for this agent
|
||||
api.subscribeToGraphExecutions(agent.agent_id);
|
||||
}, [api, agent]);
|
||||
|
||||
// Handle execution updates
|
||||
useEffect(() => {
|
||||
const detachExecUpdateHandler = api.onWebSocketMessage(
|
||||
"graph_execution_event",
|
||||
(data) => {
|
||||
setAgentRuns((prev) => {
|
||||
const index = prev.findIndex((run) => run.id === data.id);
|
||||
if (index === -1) {
|
||||
return [...prev, data];
|
||||
}
|
||||
const newRuns = [...prev];
|
||||
newRuns[index] = { ...newRuns[index], ...data };
|
||||
return newRuns;
|
||||
});
|
||||
if (data.id === selectedView.id) {
|
||||
setSelectedRun((prev) => ({ ...prev, ...data }));
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
return () => {
|
||||
detachExecUpdateHandler();
|
||||
};
|
||||
}, [api, selectedView.id]);
|
||||
|
||||
// load selectedRun based on selectedView
|
||||
useEffect(() => {
|
||||
if (selectedView.type != "run" || !selectedView.id || !agent) return;
|
||||
@@ -149,12 +185,6 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
fetchSchedules();
|
||||
}, [fetchSchedules]);
|
||||
|
||||
/* TODO: use websockets instead of polling - https://github.com/Significant-Gravitas/AutoGPT/issues/8782 */
|
||||
useEffect(() => {
|
||||
const intervalId = setInterval(() => fetchAgents(), 5000);
|
||||
return () => clearInterval(intervalId);
|
||||
}, [fetchAgents]);
|
||||
|
||||
// =========================== ACTIONS ============================
|
||||
|
||||
const deleteRun = useCallback(
|
||||
@@ -256,6 +286,7 @@ export default function AgentRunsPage(): React.ReactElement {
|
||||
graph={graphVersions[selectedRun.graph_version] ?? graph}
|
||||
run={selectedRun}
|
||||
agentActions={agentActions}
|
||||
onRun={(runID) => selectRun(runID)}
|
||||
deleteRun={() => setConfirmingDeleteAgentRun(selectedRun)}
|
||||
/>
|
||||
)
|
||||
|
||||
@@ -100,13 +100,13 @@ const Monitor = () => {
|
||||
...(selectedFlow
|
||||
? executions.filter((v) => v.graph_id == selectedFlow.agent_id)
|
||||
: executions),
|
||||
].sort((a, b) => Number(b.started_at) - Number(a.started_at))}
|
||||
].sort((a, b) => b.started_at.getTime() - a.started_at.getTime())}
|
||||
selectedRun={selectedRun}
|
||||
onSelectRun={(r) => setSelectedRun(r.id == selectedRun?.id ? null : r)}
|
||||
/>
|
||||
{(selectedRun && (
|
||||
<FlowRunInfo
|
||||
flow={
|
||||
agent={
|
||||
selectedFlow ||
|
||||
flows.find((f) => f.agent_id == selectedRun.graph_id)!
|
||||
}
|
||||
|
||||
@@ -4,8 +4,8 @@ import React, {
|
||||
forwardRef,
|
||||
useImperativeHandle,
|
||||
} from "react";
|
||||
import RunnerOutputUI, { BlockOutput } from "./runner-ui/RunnerOutputUI";
|
||||
import RunnerInputUI from "./runner-ui/RunnerInputUI";
|
||||
import RunnerOutputUI from "./runner-ui/RunnerOutputUI";
|
||||
import { Node } from "@xyflow/react";
|
||||
import { filterBlocksByType } from "@/lib/utils";
|
||||
import { BlockIORootSchema, BlockUIType } from "@/lib/autogpt-server-api/types";
|
||||
@@ -60,7 +60,11 @@ const RunnerUIWrapper = forwardRef<RunnerUIWrapperRef, RunnerUIWrapperProps>(
|
||||
const [isRunnerOutputOpen, setIsRunnerOutputOpen] = useState(false);
|
||||
const [scheduledInput, setScheduledInput] = useState(false);
|
||||
const [cronExpression, setCronExpression] = useState("");
|
||||
const getBlockInputsAndOutputs = useCallback(() => {
|
||||
|
||||
const getBlockInputsAndOutputs = useCallback((): {
|
||||
inputs: InputItem[];
|
||||
outputs: BlockOutput[];
|
||||
} => {
|
||||
const inputBlocks = filterBlocksByType(
|
||||
nodes,
|
||||
(node) => node.data.uiType === BlockUIType.INPUT,
|
||||
@@ -71,34 +75,37 @@ const RunnerUIWrapper = forwardRef<RunnerUIWrapperRef, RunnerUIWrapperProps>(
|
||||
(node) => node.data.uiType === BlockUIType.OUTPUT,
|
||||
);
|
||||
|
||||
const inputs = inputBlocks.map((node) => ({
|
||||
id: node.id,
|
||||
type: "input" as const,
|
||||
inputSchema: node.data.inputSchema as BlockIORootSchema,
|
||||
hardcodedValues: {
|
||||
name: (node.data.hardcodedValues as any).name || "",
|
||||
description: (node.data.hardcodedValues as any).description || "",
|
||||
value: (node.data.hardcodedValues as any).value,
|
||||
placeholder_values:
|
||||
(node.data.hardcodedValues as any).placeholder_values || [],
|
||||
limit_to_placeholder_values:
|
||||
(node.data.hardcodedValues as any).limit_to_placeholder_values ||
|
||||
false,
|
||||
},
|
||||
}));
|
||||
const inputs = inputBlocks.map(
|
||||
(node) =>
|
||||
({
|
||||
id: node.id,
|
||||
type: "input" as const,
|
||||
inputSchema: node.data.inputSchema as BlockIORootSchema,
|
||||
hardcodedValues: {
|
||||
name: (node.data.hardcodedValues as any).name || "",
|
||||
description: (node.data.hardcodedValues as any).description || "",
|
||||
value: (node.data.hardcodedValues as any).value,
|
||||
placeholder_values:
|
||||
(node.data.hardcodedValues as any).placeholder_values || [],
|
||||
limit_to_placeholder_values:
|
||||
(node.data.hardcodedValues as any)
|
||||
.limit_to_placeholder_values || false,
|
||||
},
|
||||
}) satisfies InputItem,
|
||||
);
|
||||
|
||||
const outputs = outputBlocks.map((node) => ({
|
||||
id: node.id,
|
||||
type: "output" as const,
|
||||
hardcodedValues: {
|
||||
name: (node.data.hardcodedValues as any).name || "Output",
|
||||
description:
|
||||
(node.data.hardcodedValues as any).description ||
|
||||
"Output from the agent",
|
||||
value: (node.data.hardcodedValues as any).value,
|
||||
},
|
||||
result: (node.data.executionResults as any)?.at(-1)?.data?.output,
|
||||
}));
|
||||
const outputs = outputBlocks.map(
|
||||
(node) =>
|
||||
({
|
||||
metadata: {
|
||||
name: (node.data.hardcodedValues as any).name || "Output",
|
||||
description:
|
||||
(node.data.hardcodedValues as any).description ||
|
||||
"Output from the agent",
|
||||
},
|
||||
result: (node.data.executionResults as any)?.at(-1)?.data?.output,
|
||||
}) satisfies BlockOutput,
|
||||
);
|
||||
|
||||
return { inputs, outputs };
|
||||
}, [nodes]);
|
||||
|
||||
@@ -5,6 +5,7 @@ import moment from "moment";
|
||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||
import {
|
||||
GraphExecution,
|
||||
GraphExecutionID,
|
||||
GraphExecutionMeta,
|
||||
GraphMeta,
|
||||
} from "@/lib/autogpt-server-api";
|
||||
@@ -25,11 +26,13 @@ export default function AgentRunDetailsView({
|
||||
graph,
|
||||
run,
|
||||
agentActions,
|
||||
onRun,
|
||||
deleteRun,
|
||||
}: {
|
||||
graph: GraphMeta;
|
||||
run: GraphExecution | GraphExecutionMeta;
|
||||
agentActions: ButtonAction[];
|
||||
onRun: (runID: GraphExecutionID) => void;
|
||||
deleteRun: () => void;
|
||||
}): React.ReactNode {
|
||||
const api = useBackendAPI();
|
||||
@@ -90,8 +93,9 @@ export default function AgentRunDetailsView({
|
||||
Object.entries(agentRunInputs).map(([k, v]) => [k, v.value]),
|
||||
),
|
||||
)
|
||||
.then(({ graph_exec_id }) => onRun(graph_exec_id))
|
||||
.catch(toastOnFail("execute agent")),
|
||||
[api, graph, agentRunInputs, toastOnFail],
|
||||
[api, graph, agentRunInputs, onRun, toastOnFail],
|
||||
);
|
||||
|
||||
const stopRun = useCallback(
|
||||
|
||||
@@ -104,24 +104,28 @@ export default function AgentRunsSelectorList({
|
||||
</Button>
|
||||
|
||||
{activeListTab === "runs"
|
||||
? agentRuns.map((run, i) => (
|
||||
<AgentRunSummaryCard
|
||||
className="h-28 w-72 lg:h-32 xl:w-80"
|
||||
key={i}
|
||||
status={agentRunStatusMap[run.status]}
|
||||
title={agent.name}
|
||||
timestamp={run.started_at}
|
||||
selected={selectedView.id === run.id}
|
||||
onClick={() => onSelectRun(run.id)}
|
||||
onDelete={() => onDeleteRun(run)}
|
||||
/>
|
||||
))
|
||||
: schedules
|
||||
.filter((schedule) => schedule.graph_id === agent.agent_id)
|
||||
.map((schedule, i) => (
|
||||
? agentRuns
|
||||
.toSorted(
|
||||
(a, b) => b.started_at.getTime() - a.started_at.getTime(),
|
||||
)
|
||||
.map((run) => (
|
||||
<AgentRunSummaryCard
|
||||
className="h-28 w-72 lg:h-32 xl:w-80"
|
||||
key={i}
|
||||
key={run.id}
|
||||
status={agentRunStatusMap[run.status]}
|
||||
title={agent.name}
|
||||
timestamp={run.started_at}
|
||||
selected={selectedView.id === run.id}
|
||||
onClick={() => onSelectRun(run.id)}
|
||||
onDelete={() => onDeleteRun(run)}
|
||||
/>
|
||||
))
|
||||
: schedules
|
||||
.filter((schedule) => schedule.graph_id === agent.agent_id)
|
||||
.map((schedule) => (
|
||||
<AgentRunSummaryCard
|
||||
className="h-28 w-72 lg:h-32 xl:w-80"
|
||||
key={schedule.id}
|
||||
status="scheduled"
|
||||
title={schedule.name}
|
||||
timestamp={schedule.next_run_time}
|
||||
|
||||
@@ -126,7 +126,8 @@ export const AgentFlowList = ({
|
||||
if (!a.lastRun) return 1;
|
||||
if (!b.lastRun) return -1;
|
||||
return (
|
||||
Number(b.lastRun.started_at) - Number(a.lastRun.started_at)
|
||||
b.lastRun.started_at.getTime() -
|
||||
a.lastRun.started_at.getTime()
|
||||
);
|
||||
})
|
||||
.map(({ flow, runCount, lastRun }) => (
|
||||
|
||||
@@ -17,59 +17,39 @@ import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||
|
||||
export const FlowRunInfo: React.FC<
|
||||
React.HTMLAttributes<HTMLDivElement> & {
|
||||
flow: LibraryAgent;
|
||||
agent: LibraryAgent;
|
||||
execution: GraphExecutionMeta;
|
||||
}
|
||||
> = ({ flow, execution, ...props }) => {
|
||||
> = ({ agent, execution, ...props }) => {
|
||||
const [isOutputOpen, setIsOutputOpen] = useState(false);
|
||||
const [blockOutputs, setBlockOutputs] = useState<BlockOutput[]>([]);
|
||||
const api = useBackendAPI();
|
||||
|
||||
const fetchBlockResults = useCallback(async () => {
|
||||
const executionResults = (
|
||||
await api.getGraphExecutionInfo(flow.agent_id, execution.id)
|
||||
).node_executions;
|
||||
|
||||
// Create a map of the latest COMPLETED execution results of output nodes by node_id
|
||||
const latestCompletedResults = executionResults
|
||||
.filter(
|
||||
(result) =>
|
||||
result.status === "COMPLETED" &&
|
||||
result.block_id === SpecialBlockID.OUTPUT,
|
||||
)
|
||||
.reduce((acc, result) => {
|
||||
const existing = acc.get(result.node_id);
|
||||
|
||||
// Compare dates if there's an existing result
|
||||
if (existing) {
|
||||
const existingDate = existing.end_time || existing.add_time;
|
||||
const currentDate = result.end_time || result.add_time;
|
||||
|
||||
if (currentDate > existingDate) {
|
||||
acc.set(result.node_id, result);
|
||||
}
|
||||
} else {
|
||||
acc.set(result.node_id, result);
|
||||
}
|
||||
|
||||
return acc;
|
||||
}, new Map<string, NodeExecutionResult>());
|
||||
const graph = await api.getGraph(agent.agent_id, agent.agent_version);
|
||||
const graphExecution = await api.getGraphExecutionInfo(
|
||||
agent.agent_id,
|
||||
execution.id,
|
||||
);
|
||||
|
||||
// Transform results to BlockOutput format
|
||||
setBlockOutputs(
|
||||
Array.from(latestCompletedResults.values()).map((result) => ({
|
||||
id: result.node_id,
|
||||
type: "output" as const,
|
||||
hardcodedValues: {
|
||||
name: result.input_data.name || "Output",
|
||||
description: result.input_data.description || "Output from the agent",
|
||||
value: result.input_data.value,
|
||||
},
|
||||
// Change this line to extract the array directly
|
||||
result: result.output_data?.output || undefined,
|
||||
})),
|
||||
Object.entries(graphExecution.outputs).flatMap(([key, values]) =>
|
||||
values.map(
|
||||
(value) =>
|
||||
({
|
||||
metadata: {
|
||||
name: graph.output_schema.properties[key].title || "Output",
|
||||
description:
|
||||
graph.output_schema.properties[key].description ||
|
||||
"Output from the agent",
|
||||
},
|
||||
result: value,
|
||||
}) satisfies BlockOutput,
|
||||
),
|
||||
),
|
||||
);
|
||||
}, [api, flow.agent_id, execution.id]);
|
||||
}, [api, agent.agent_id, agent.agent_version, execution.id]);
|
||||
|
||||
// Fetch graph and execution data
|
||||
useEffect(() => {
|
||||
@@ -77,15 +57,15 @@ export const FlowRunInfo: React.FC<
|
||||
fetchBlockResults();
|
||||
}, [isOutputOpen, fetchBlockResults]);
|
||||
|
||||
if (execution.graph_id != flow.agent_id) {
|
||||
if (execution.graph_id != agent.agent_id) {
|
||||
throw new Error(
|
||||
`FlowRunInfo can't be used with non-matching execution.graph_id and flow.id`,
|
||||
);
|
||||
}
|
||||
|
||||
const handleStopRun = useCallback(() => {
|
||||
api.stopGraphExecution(flow.agent_id, execution.id);
|
||||
}, [api, flow.agent_id, execution.id]);
|
||||
api.stopGraphExecution(agent.agent_id, execution.id);
|
||||
}, [api, agent.agent_id, execution.id]);
|
||||
|
||||
return (
|
||||
<>
|
||||
@@ -93,7 +73,7 @@ export const FlowRunInfo: React.FC<
|
||||
<CardHeader className="flex-row items-center justify-between space-x-3 space-y-0">
|
||||
<div>
|
||||
<CardTitle>
|
||||
{flow.name}{" "}
|
||||
{agent.name}{" "}
|
||||
<span className="font-light">v{execution.graph_version}</span>
|
||||
</CardTitle>
|
||||
</div>
|
||||
@@ -106,7 +86,7 @@ export const FlowRunInfo: React.FC<
|
||||
<Button onClick={() => setIsOutputOpen(true)} variant="outline">
|
||||
<ExitIcon className="mr-2" /> View Outputs
|
||||
</Button>
|
||||
{flow.can_access_graph && (
|
||||
{agent.can_access_graph && (
|
||||
<Link
|
||||
className={buttonVariants({ variant: "default" })}
|
||||
href={`/build?flowID=${execution.graph_id}&flowVersion=${execution.graph_version}&flowExecutionID=${execution.id}`}
|
||||
@@ -118,7 +98,7 @@ export const FlowRunInfo: React.FC<
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<p className="hidden">
|
||||
<strong>Agent ID:</strong> <code>{flow.agent_id}</code>
|
||||
<strong>Agent ID:</strong> <code>{agent.agent_id}</code>
|
||||
</p>
|
||||
<p className="hidden">
|
||||
<strong>Run ID:</strong> <code>{execution.id}</code>
|
||||
|
||||
@@ -29,7 +29,7 @@ export const FlowRunsStatus: React.FC<{
|
||||
: statsSince;
|
||||
const filteredFlowRuns =
|
||||
statsSinceTimestamp != null
|
||||
? executions.filter((fr) => Number(fr.started_at) > statsSinceTimestamp)
|
||||
? executions.filter((fr) => fr.started_at.getTime() > statsSinceTimestamp)
|
||||
: executions;
|
||||
|
||||
return (
|
||||
|
||||
@@ -99,7 +99,7 @@ export const FlowRunsTimeline = ({
|
||||
.filter((e) => e.graph_id == flow.agent_id)
|
||||
.map((e) => ({
|
||||
...e,
|
||||
time: Number(e.started_at) + e.total_run_time * 1000,
|
||||
time: e.started_at.getTime() + e.total_run_time * 1000,
|
||||
_duration: e.total_run_time,
|
||||
}))}
|
||||
name={flow.name}
|
||||
@@ -112,10 +112,14 @@ export const FlowRunsTimeline = ({
|
||||
type="linear"
|
||||
dataKey="_duration"
|
||||
data={[
|
||||
{ ...execution, time: Number(execution.started_at), _duration: 0 },
|
||||
{
|
||||
...execution,
|
||||
time: Number(execution.ended_at),
|
||||
time: execution.started_at.getTime(),
|
||||
_duration: 0,
|
||||
},
|
||||
{
|
||||
...execution,
|
||||
time: execution.ended_at.getTime(),
|
||||
_duration: execution.total_run_time,
|
||||
},
|
||||
]}
|
||||
|
||||
@@ -7,21 +7,19 @@ import {
|
||||
SheetDescription,
|
||||
} from "@/components/ui/sheet";
|
||||
import { ScrollArea } from "@/components/ui/scroll-area";
|
||||
import { BlockIORootSchema } from "@/lib/autogpt-server-api/types";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Textarea } from "@/components/ui/textarea";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Clipboard } from "lucide-react";
|
||||
import { useToast } from "@/components/ui/use-toast";
|
||||
|
||||
export interface BlockOutput {
|
||||
id: string;
|
||||
hardcodedValues: {
|
||||
export type BlockOutput = {
|
||||
metadata: {
|
||||
name: string;
|
||||
description: string;
|
||||
};
|
||||
result?: any;
|
||||
}
|
||||
};
|
||||
|
||||
interface OutputModalProps {
|
||||
isOpen: boolean;
|
||||
@@ -87,15 +85,15 @@ export function RunnerOutputUI({
|
||||
<ScrollArea className="h-full overflow-auto pr-4">
|
||||
<div className="space-y-4">
|
||||
{blockOutputs && blockOutputs.length > 0 ? (
|
||||
blockOutputs.map((block) => (
|
||||
<div key={block.id} className="space-y-1">
|
||||
blockOutputs.map((block, i) => (
|
||||
<div key={i} className="space-y-1">
|
||||
<Label className="text-base font-semibold">
|
||||
{block.hardcodedValues.name || "Unnamed Output"}
|
||||
{block.metadata.name || "Unnamed Output"}
|
||||
</Label>
|
||||
|
||||
{block.hardcodedValues.description && (
|
||||
{block.metadata.description && (
|
||||
<Label className="block text-sm text-gray-600">
|
||||
{block.hardcodedValues.description}
|
||||
{block.metadata.description}
|
||||
</Label>
|
||||
)}
|
||||
|
||||
@@ -106,7 +104,7 @@ export function RunnerOutputUI({
|
||||
size="icon"
|
||||
onClick={() =>
|
||||
copyOutput(
|
||||
block.hardcodedValues.name || "Unnamed Output",
|
||||
block.metadata.name || "Unnamed Output",
|
||||
block.result,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -631,7 +631,10 @@ export default function useAgentGraph(
|
||||
activeExecutionID: flowExecutionID,
|
||||
});
|
||||
}
|
||||
setUpdateQueue((prev) => [...prev, ...execution.node_executions]);
|
||||
setUpdateQueue((prev) => {
|
||||
if (!execution.node_executions) return prev;
|
||||
return [...prev, ...execution.node_executions];
|
||||
});
|
||||
|
||||
// Track execution until completed
|
||||
const pendingNodeExecutions: Set<string> = new Set();
|
||||
|
||||
@@ -253,11 +253,15 @@ export default class BackendAPI {
|
||||
}
|
||||
|
||||
getExecutions(): Promise<GraphExecutionMeta[]> {
|
||||
return this._get(`/executions`);
|
||||
return this._get(`/executions`).then((results) =>
|
||||
results.map(parseGraphExecutionTimestamps),
|
||||
);
|
||||
}
|
||||
|
||||
getGraphExecutions(graphID: GraphID): Promise<GraphExecutionMeta[]> {
|
||||
return this._get(`/graphs/${graphID}/executions`);
|
||||
return this._get(`/graphs/${graphID}/executions`).then((results) =>
|
||||
results.map(parseGraphExecutionTimestamps),
|
||||
);
|
||||
}
|
||||
|
||||
async getGraphExecutionInfo(
|
||||
@@ -265,10 +269,7 @@ export default class BackendAPI {
|
||||
runID: GraphExecutionID,
|
||||
): Promise<GraphExecution> {
|
||||
const result = await this._get(`/graphs/${graphID}/executions/${runID}`);
|
||||
result.node_executions = result.node_executions.map(
|
||||
parseNodeExecutionResultTimestamps,
|
||||
);
|
||||
return result;
|
||||
return parseGraphExecutionTimestamps<GraphExecution>(result);
|
||||
}
|
||||
|
||||
async stopGraphExecution(
|
||||
@@ -279,10 +280,7 @@ export default class BackendAPI {
|
||||
"POST",
|
||||
`/graphs/${graphID}/executions/${runID}/stop`,
|
||||
);
|
||||
result.node_executions = result.node_executions.map(
|
||||
parseNodeExecutionResultTimestamps,
|
||||
);
|
||||
return result;
|
||||
return parseGraphExecutionTimestamps<GraphExecution>(result);
|
||||
}
|
||||
|
||||
async deleteGraphExecution(runID: GraphExecutionID): Promise<void> {
|
||||
@@ -828,6 +826,12 @@ export default class BackendAPI {
|
||||
});
|
||||
}
|
||||
|
||||
subscribeToGraphExecutions(graphID: GraphID): Promise<void> {
|
||||
return this.sendWebSocketMessage("subscribe_graph_executions", {
|
||||
graph_id: graphID,
|
||||
});
|
||||
}
|
||||
|
||||
async sendWebSocketMessage<M extends keyof WebsocketMessageTypeMap>(
|
||||
method: M,
|
||||
data: WebsocketMessageTypeMap[M],
|
||||
@@ -904,7 +908,7 @@ export default class BackendAPI {
|
||||
if (message.method === "node_execution_event") {
|
||||
message.data = parseNodeExecutionResultTimestamps(message.data);
|
||||
} else if (message.method == "graph_execution_event") {
|
||||
message.data = parseGraphExecutionMetaTimestamps(message.data);
|
||||
message.data = parseGraphExecutionTimestamps(message.data);
|
||||
}
|
||||
this.wsMessageHandlers[message.method]?.forEach((handler) =>
|
||||
handler(message.data),
|
||||
@@ -973,7 +977,8 @@ type GraphCreateRequestBody = {
|
||||
|
||||
type WebsocketMessageTypeMap = {
|
||||
subscribe_graph_execution: { graph_exec_id: GraphExecutionID };
|
||||
graph_execution_event: GraphExecutionMeta;
|
||||
subscribe_graph_executions: { graph_id: GraphID };
|
||||
graph_execution_event: GraphExecution;
|
||||
node_execution_event: NodeExecutionResult;
|
||||
heartbeat: "ping" | "pong";
|
||||
};
|
||||
@@ -994,11 +999,16 @@ type _PydanticValidationError = {
|
||||
|
||||
/* *** HELPER FUNCTIONS *** */
|
||||
|
||||
function parseGraphExecutionMetaTimestamps(result: any): GraphExecutionMeta {
|
||||
return _parseObjectTimestamps<GraphExecutionMeta>(result, [
|
||||
"started_at",
|
||||
"ended_at",
|
||||
]);
|
||||
function parseGraphExecutionTimestamps<
|
||||
T extends GraphExecutionMeta | GraphExecution,
|
||||
>(result: any): T {
|
||||
const fixed = _parseObjectTimestamps<T>(result, ["started_at", "ended_at"]);
|
||||
if ("node_executions" in fixed && fixed.node_executions) {
|
||||
fixed.node_executions = fixed.node_executions.map(
|
||||
parseNodeExecutionResultTimestamps,
|
||||
);
|
||||
}
|
||||
return fixed;
|
||||
}
|
||||
|
||||
function parseNodeExecutionResultTimestamps(result: any): NodeExecutionResult {
|
||||
|
||||
@@ -228,7 +228,7 @@ export type LinkCreatable = Omit<Link, "id" | "is_static"> & {
|
||||
id?: string;
|
||||
};
|
||||
|
||||
/* Mirror of backend/data/graph.py:GraphExecutionMeta */
|
||||
/* Mirror of backend/data/execution.py:GraphExecutionMeta */
|
||||
export type GraphExecutionMeta = {
|
||||
id: GraphExecutionID;
|
||||
started_at: Date;
|
||||
@@ -244,11 +244,11 @@ export type GraphExecutionMeta = {
|
||||
|
||||
export type GraphExecutionID = Brand<string, "GraphExecutionID">;
|
||||
|
||||
/* Mirror of backend/data/graph.py:GraphExecution */
|
||||
/* Mirror of backend/data/execution.py:GraphExecution */
|
||||
export type GraphExecution = GraphExecutionMeta & {
|
||||
inputs: Record<string, any>;
|
||||
outputs: Record<string, Array<any>>;
|
||||
node_executions: NodeExecutionResult[];
|
||||
node_executions?: NodeExecutionResult[];
|
||||
};
|
||||
|
||||
export type GraphMeta = {
|
||||
@@ -297,7 +297,7 @@ export type GraphUpdateable = Omit<
|
||||
|
||||
export type GraphCreatable = Omit<GraphUpdateable, "id"> & { id?: string };
|
||||
|
||||
/* Mirror of backend/data/execution.py:ExecutionResult */
|
||||
/* Mirror of backend/data/execution.py:NodeExecutionResult */
|
||||
export type NodeExecutionResult = {
|
||||
graph_id: GraphID;
|
||||
graph_version: number;
|
||||
|
||||
Reference in New Issue
Block a user