diff --git a/.github/workflows/autogpt-server-ci.yml b/.github/workflows/autogpt-server-ci.yml
index f71f1a8aee..6c78301fda 100644
--- a/.github/workflows/autogpt-server-ci.yml
+++ b/.github/workflows/autogpt-server-ci.yml
@@ -128,9 +128,14 @@ jobs:
- name: Run pytest with coverage
run: |
- poetry run pytest -vv \
- test
+ if [[ "${{ runner.debug }}" == "1" ]]; then
+ poetry run pytest -vv -o log_cli=true -o log_cli_level=DEBUG test
+ else
+ poetry run pytest -vv test
+ fi
if: success() || (failure() && steps.lint.outcome == 'failure')
+ env:
+ LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
env:
CI: true
PLAIN_OUTPUT: True
diff --git a/rnd/autogpt_builder/src/components/Flow.tsx b/rnd/autogpt_builder/src/components/Flow.tsx
index 5b8bdbbc6d..edc6f696c2 100644
--- a/rnd/autogpt_builder/src/components/Flow.tsx
+++ b/rnd/autogpt_builder/src/components/Flow.tsx
@@ -34,7 +34,12 @@ import ConnectionLine from "./ConnectionLine";
import { Control, ControlPanel } from "@/components/edit/control/ControlPanel";
import { SaveControl } from "@/components/edit/control/SaveControl";
import { BlocksControl } from "@/components/edit/control/BlocksControl";
-import { IconPlay, IconRedo2, IconUndo2 } from "@/components/ui/icons";
+import {
+ IconPlay,
+ IconRedo2,
+ IconSquare,
+ IconUndo2,
+} from "@/components/ui/icons";
import { startTutorial } from "./tutorial";
import useAgentGraph from "@/hooks/useAgentGraph";
import { v4 as uuidv4 } from "uuid";
@@ -75,7 +80,9 @@ const FlowEditor: React.FC<{
availableNodes,
getOutputType,
requestSave,
- requestSaveRun,
+ requestSaveAndRun,
+ requestStopRun,
+ isRunning,
nodes,
setNodes,
edges,
@@ -542,9 +549,9 @@ const FlowEditor: React.FC<{
onClick: handleRedo,
},
{
- label: "Run",
- icon: ,
- onClick: requestSaveRun,
+ label: !isRunning ? "Run" : "Stop",
+ icon: !isRunning ? : ,
+ onClick: !isRunning ? requestSaveAndRun : requestStopRun,
},
];
diff --git a/rnd/autogpt_builder/src/components/monitor/FlowRunInfo.tsx b/rnd/autogpt_builder/src/components/monitor/FlowRunInfo.tsx
index 1247ba2d22..10be74e29b 100644
--- a/rnd/autogpt_builder/src/components/monitor/FlowRunInfo.tsx
+++ b/rnd/autogpt_builder/src/components/monitor/FlowRunInfo.tsx
@@ -1,9 +1,10 @@
-import React from "react";
-import { GraphMeta } from "@/lib/autogpt-server-api";
+import React, { useCallback } from "react";
+import AutoGPTServerAPI, { GraphMeta } from "@/lib/autogpt-server-api";
import { FlowRun } from "@/lib/types";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import Link from "next/link";
-import { buttonVariants } from "@/components/ui/button";
+import { Button, buttonVariants } from "@/components/ui/button";
+import { IconSquare } from "@/components/ui/icons";
import { Pencil2Icon } from "@radix-ui/react-icons";
import moment from "moment/moment";
import { FlowRunStatusBadge } from "@/components/monitor/FlowRunStatusBadge";
@@ -20,6 +21,11 @@ export const FlowRunInfo: React.FC<
);
}
+ const handleStopRun = useCallback(() => {
+ const api = new AutoGPTServerAPI();
+ api.stopGraphExecution(flow.id, flowRun.id);
+ }, [flow.id, flowRun.id]);
+
return (
@@ -34,12 +40,19 @@ export const FlowRunInfo: React.FC<
Run ID: {flowRun.id}
-
diff --git a/rnd/autogpt_builder/src/components/ui/icons.tsx b/rnd/autogpt_builder/src/components/ui/icons.tsx
index 28cc978a6e..09a304d82a 100644
--- a/rnd/autogpt_builder/src/components/ui/icons.tsx
+++ b/rnd/autogpt_builder/src/components/ui/icons.tsx
@@ -405,6 +405,40 @@ export const IconPlay = createIcon((props) => (
));
+/**
+ * Square icon component.
+ *
+ * @component IconSquare
+ * @param {IconProps} props - The props object containing additional attributes and event handlers for the icon.
+ * @returns {JSX.Element} - The square icon.
+ *
+ * @example
+ * // Default usage; this is the standard usage
+ * <IconSquare />
+ *
+ * @example
+ * // With custom color and size; these should be used sparingly and only when necessary
+ * <IconSquare className="text-primary h-8 w-8" />
+ *
+ * @example
+ * // With custom size and onClick handler
+ * <IconSquare className="h-10 w-10" onClick={handleOnClick} />
+ */
+export const IconSquare = createIcon((props) => (
+  <svg
+    xmlns="http://www.w3.org/2000/svg"
+    width="24"
+    height="24"
+    viewBox="0 0 24 24"
+    fill="none"
+    stroke="currentColor"
+    strokeWidth="2"
+    strokeLinecap="round"
+    strokeLinejoin="round"
+    {...props}
+  >
+    <rect width="18" height="18" x="3" y="3" rx="2" />
+  </svg>
+));
+
/**
* Package2 icon component.
*
diff --git a/rnd/autogpt_builder/src/hooks/useAgentGraph.ts b/rnd/autogpt_builder/src/hooks/useAgentGraph.ts
index 241dedcb15..454b0d4f63 100644
--- a/rnd/autogpt_builder/src/hooks/useAgentGraph.ts
+++ b/rnd/autogpt_builder/src/hooks/useAgentGraph.ts
@@ -31,22 +31,27 @@ export default function useAgentGraph(
const [updateQueue, setUpdateQueue] = useState<NodeExecutionResult[]>([]);
const processedUpdates = useRef<NodeExecutionResult[]>([]);
/**
- * User `request` to save or save&run the agent
+ * User `request` to save or save&run the agent, or to stop the active run.
* `state` is used to track the request status:
* - none: no request
* - saving: request was sent to save the agent
* and nodes are pending sync to update their backend ids
* - running: request was sent to run the agent
* and frontend is enqueueing execution results
+ * - stopping: a request to stop the active run has been sent; response is pending
* - error: request failed
- *
- * As of now, state will be stuck at 'running' (if run requested)
- * because there's no way to know when the execution is done
*/
- const [saveRunRequest, setSaveRunRequest] = useState<{
- request: "none" | "save" | "run";
- state: "none" | "saving" | "running" | "error";
- }>({
+ const [saveRunRequest, setSaveRunRequest] = useState<
+ | {
+ request: "none" | "save" | "run";
+ state: "none" | "saving" | "error";
+ }
+ | {
+ request: "run" | "stop";
+ state: "running" | "stopping" | "error";
+ activeExecutionID?: string;
+ }
+ >({
request: "none",
state: "none",
});
@@ -128,13 +133,14 @@ export default function useAgentGraph(
console.error("Error saving agent");
} else if (saveRunRequest.request === "run") {
console.error(`Error saving&running agent`);
+ } else if (saveRunRequest.request === "stop") {
+ console.error(`Error stopping agent`);
}
// Reset request
- setSaveRunRequest((prev) => ({
- ...prev,
+ setSaveRunRequest({
request: "none",
state: "none",
- }));
+ });
return;
}
// When saving request is done
@@ -145,11 +151,10 @@ export default function useAgentGraph(
) {
// Reset request if only save was requested
if (saveRunRequest.request === "save") {
- setSaveRunRequest((prev) => ({
- ...prev,
+ setSaveRunRequest({
request: "none",
state: "none",
- }));
+ });
// If run was requested, run the agent
} else if (saveRunRequest.request === "run") {
if (!validateNodes()) {
@@ -161,16 +166,64 @@ export default function useAgentGraph(
return;
}
api.subscribeToExecution(savedAgent.id);
- api.executeGraph(savedAgent.id);
- processedUpdates.current = processedUpdates.current = [];
+ setSaveRunRequest({ request: "run", state: "running" });
+ api
+ .executeGraph(savedAgent.id)
+ .then((graphExecution) => {
+ setSaveRunRequest({
+ request: "run",
+ state: "running",
+ activeExecutionID: graphExecution.id,
+ });
- setSaveRunRequest((prev) => ({
- ...prev,
- request: "run",
- state: "running",
- }));
+ // Track execution until completed
+ const pendingNodeExecutions: Set<string> = new Set();
+ const cancelExecListener = api.onWebSocketMessage(
+ "execution_event",
+ (nodeResult) => {
+ // We are racing the server here, since we need the ID to filter events
+ if (nodeResult.graph_exec_id != graphExecution.id) {
+ return;
+ }
+ if (
+ nodeResult.status != "COMPLETED" &&
+ nodeResult.status != "FAILED"
+ ) {
+ pendingNodeExecutions.add(nodeResult.node_exec_id);
+ } else {
+ pendingNodeExecutions.delete(nodeResult.node_exec_id);
+ }
+ if (pendingNodeExecutions.size == 0) {
+ // Assuming the first event is always a QUEUED node, and
+ // following nodes are QUEUED before all preceding nodes are COMPLETED,
+ // an empty set means the graph has finished running.
+ cancelExecListener();
+ setSaveRunRequest({ request: "none", state: "none" });
+ }
+ },
+ );
+ })
+ .catch(() => setSaveRunRequest({ request: "run", state: "error" }));
+
+ processedUpdates.current = [];
}
}
+ // Handle stop request
+ if (
+ saveRunRequest.request === "stop" &&
+ saveRunRequest.state != "stopping" &&
+ savedAgent &&
+ saveRunRequest.activeExecutionID
+ ) {
+ setSaveRunRequest({
+ request: "stop",
+ state: "stopping",
+ activeExecutionID: saveRunRequest.activeExecutionID,
+ });
+ api
+ .stopGraphExecution(savedAgent.id, saveRunRequest.activeExecutionID)
+ .then(() => setSaveRunRequest({ request: "none", state: "none" }));
+ }
}, [saveRunRequest, savedAgent, nodesSyncedWithSavedAgent]);
// Check if node ids are synced with saved agent
@@ -657,7 +710,7 @@ export default function useAgentGraph(
[saveAgent],
);
- const requestSaveRun = useCallback(() => {
+ const requestSaveAndRun = useCallback(() => {
saveAgent();
setSaveRunRequest({
request: "run",
@@ -665,6 +718,23 @@ export default function useAgentGraph(
});
}, [saveAgent]);
+ const requestStopRun = useCallback(() => {
+ if (saveRunRequest.state != "running") {
+ return;
+ }
+ if (!saveRunRequest.activeExecutionID) {
+ console.warn(
+ "Stop requested but execution ID is unknown; state:",
+ saveRunRequest,
+ );
+ }
+ setSaveRunRequest((prev) => ({
+ ...prev,
+ request: "stop",
+ state: "running",
+ }));
+ }, [saveRunRequest]);
+
return {
agentName,
setAgentName,
@@ -674,7 +744,11 @@ export default function useAgentGraph(
availableNodes,
getOutputType,
requestSave,
- requestSaveRun,
+ requestSaveAndRun,
+ requestStopRun,
+ isSaving: saveRunRequest.state == "saving",
+ isRunning: saveRunRequest.state == "running",
+ isStopping: saveRunRequest.state == "stopping",
nodes,
setNodes,
edges,
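
Note on the hunk above: the two-branch union ties `activeExecutionID` to the phases where a run or stop is actually in flight. As a reading aid, here is the lifecycle of a run that the user stops midway — a sketch distilled from the hook's effects, not code from this PR:

```ts
// `SaveRunRequest` mirrors the union declared in useAgentGraph.ts.
type SaveRunRequest =
  | { request: "none" | "save" | "run"; state: "none" | "saving" | "error" }
  | {
      request: "run" | "stop";
      state: "running" | "stopping" | "error";
      activeExecutionID?: string;
    };

const stoppedRunLifecycle: SaveRunRequest[] = [
  { request: "run", state: "saving" }, // requestSaveAndRun(): save the agent first
  { request: "run", state: "running" }, // nodes synced; executeGraph() dispatched
  { request: "run", state: "running", activeExecutionID: "<id>" }, // server replied with the execution ID
  { request: "stop", state: "running", activeExecutionID: "<id>" }, // requestStopRun()
  { request: "stop", state: "stopping", activeExecutionID: "<id>" }, // stop request sent to the API
  { request: "none", state: "none" }, // stopGraphExecution() resolved
];
```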
diff --git a/rnd/autogpt_builder/src/lib/autogpt-server-api/client.ts b/rnd/autogpt_builder/src/lib/autogpt-server-api/client.ts
index 0948e842a9..56d950d9cd 100644
--- a/rnd/autogpt_builder/src/lib/autogpt-server-api/client.ts
+++ b/rnd/autogpt_builder/src/lib/autogpt-server-api/client.ts
@@ -15,7 +15,7 @@ export default class AutoGPTServerAPI {
private wsUrl: string;
private webSocket: WebSocket | null = null;
private wsConnecting: Promise<void> | null = null;
- private wsMessageHandlers: { [key: string]: (data: any) => void } = {};
+ private wsMessageHandlers: Record<string, Set<(data: any) => void>> = {};
private supabaseClient = createClient();
constructor(
@@ -128,16 +128,19 @@ export default class AutoGPTServerAPI {
runID: string,
): Promise<NodeExecutionResult[]> {
return (await this._get(`/graphs/${graphID}/executions/${runID}`)).map(
- (result: any) => ({
- ...result,
- add_time: new Date(result.add_time),
- queue_time: result.queue_time ? new Date(result.queue_time) : undefined,
- start_time: result.start_time ? new Date(result.start_time) : undefined,
- end_time: result.end_time ? new Date(result.end_time) : undefined,
- }),
+ parseNodeExecutionResultTimestamps,
);
}
+ async stopGraphExecution(
+ graphID: string,
+ runID: string,
+ ): Promise<NodeExecutionResult[]> {
+ return (
+ await this._request("POST", `/graphs/${graphID}/executions/${runID}/stop`)
+ ).map(parseNodeExecutionResultTimestamps);
+ }
+
private async _get(path: string) {
return this._request("GET", path);
}
@@ -207,10 +210,13 @@ export default class AutoGPTServerAPI {
};
this.webSocket.onmessage = (event) => {
- const message = JSON.parse(event.data);
- if (this.wsMessageHandlers[message.method]) {
- this.wsMessageHandlers[message.method](message.data);
+ const message: WebsocketMessage = JSON.parse(event.data);
+ if (message.method == "execution_event") {
+ message.data = parseNodeExecutionResultTimestamps(message.data);
}
+ this.wsMessageHandlers[message.method]?.forEach((handler) =>
+ handler(message.data),
+ );
};
} catch (error) {
console.error("Error connecting to WebSocket:", error);
@@ -250,8 +256,12 @@ export default class AutoGPTServerAPI {
onWebSocketMessage<M extends keyof WebsocketMessageTypeMap>(
method: M,
handler: (data: WebsocketMessageTypeMap[M]) => void,
- ) {
- this.wsMessageHandlers[method] = handler;
+ ): () => void {
+ this.wsMessageHandlers[method] ??= new Set();
+ this.wsMessageHandlers[method].add(handler);
+
+ // Return detacher
+ return () => this.wsMessageHandlers[method].delete(handler);
}
subscribeToExecution(graphId: string) {
@@ -274,3 +284,22 @@ type WebsocketMessageTypeMap = {
subscribe: { graph_id: string };
execution_event: NodeExecutionResult;
};
+
+type WebsocketMessage = {
+ [M in keyof WebsocketMessageTypeMap]: {
+ method: M;
+ data: WebsocketMessageTypeMap[M];
+ };
+}[keyof WebsocketMessageTypeMap];
+
+/* *** HELPER FUNCTIONS *** */
+
+function parseNodeExecutionResultTimestamps(result: any): NodeExecutionResult {
+ return {
+ ...result,
+ add_time: new Date(result.add_time),
+ queue_time: result.queue_time ? new Date(result.queue_time) : undefined,
+ start_time: result.start_time ? new Date(result.start_time) : undefined,
+ end_time: result.end_time ? new Date(result.end_time) : undefined,
+ };
+}
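
Since `wsMessageHandlers` now keeps a `Set` of handlers per message type and `onWebSocketMessage` returns a detacher, multiple consumers can listen to the same method independently — this is what lets `useAgentGraph` attach a temporary completion tracker without clobbering other listeners. A minimal usage sketch (both handlers here are hypothetical, not from this PR):

```ts
import AutoGPTServerAPI from "@/lib/autogpt-server-api";

const api = new AutoGPTServerAPI();

const detachLogger = api.onWebSocketMessage("execution_event", (result) => {
  console.debug(`node ${result.node_exec_id} -> ${result.status}`);
});

const detachTracker = api.onWebSocketMessage("execution_event", (result) => {
  if (result.status === "COMPLETED" || result.status === "FAILED") {
    detachTracker(); // removes only this handler; the logger stays attached
  }
});

// Later, e.g. on component unmount:
detachLogger();
```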
diff --git a/rnd/autogpt_server/autogpt_server/data/execution.py b/rnd/autogpt_server/autogpt_server/data/execution.py
index d5abd9f316..0fd54feaa3 100644
--- a/rnd/autogpt_server/autogpt_server/data/execution.py
+++ b/rnd/autogpt_server/autogpt_server/data/execution.py
@@ -300,6 +300,26 @@ async def update_execution_status(
return ExecutionResult.from_db(res)
+async def get_graph_execution(
+ graph_exec_id: str, user_id: str
+) -> AgentGraphExecution | None:
+ """
+ Retrieve a specific graph execution by its ID.
+
+ Args:
+ graph_exec_id (str): The ID of the graph execution to retrieve.
+ user_id (str): The ID of the user to whom the graph (execution) belongs.
+
+ Returns:
+ AgentGraphExecution | None: The graph execution if found, None otherwise.
+ """
+ execution = await AgentGraphExecution.prisma().find_first(
+ where={"id": graph_exec_id, "userId": user_id},
+ include=GRAPH_EXECUTION_INCLUDE,
+ )
+ return execution
+
+
async def list_executions(graph_id: str, graph_version: int | None = None) -> list[str]:
where: AgentGraphExecutionWhereInput = {"agentGraphId": graph_id}
if graph_version is not None:
diff --git a/rnd/autogpt_server/autogpt_server/executor/manager.py b/rnd/autogpt_server/autogpt_server/executor/manager.py
index edde97b156..2f0113a69e 100644
--- a/rnd/autogpt_server/autogpt_server/executor/manager.py
+++ b/rnd/autogpt_server/autogpt_server/executor/manager.py
@@ -1,7 +1,10 @@
import asyncio
import logging
-from concurrent.futures import Future, ProcessPoolExecutor, TimeoutError
+import multiprocessing
+import threading
+from concurrent.futures import Future, ProcessPoolExecutor
from contextlib import contextmanager
+from multiprocessing.pool import AsyncResult, Pool
from typing import TYPE_CHECKING, Any, Coroutine, Generator, TypeVar
if TYPE_CHECKING:
@@ -16,6 +19,7 @@ from autogpt_server.data.execution import (
GraphExecution,
NodeExecution,
create_graph_execution,
+ get_execution_results,
get_incomplete_executions,
get_latest_execution,
merge_execution_input,
@@ -461,15 +465,19 @@ class Executor:
cls.loop = asyncio.new_event_loop()
cls.loop.run_until_complete(db.connect())
cls.pool_size = Config().num_node_workers
- cls.executor = ProcessPoolExecutor(
- max_workers=cls.pool_size,
+ cls._init_node_executor_pool()
+ logger.info(f"Graph executor started with max-{cls.pool_size} node workers.")
+
+ @classmethod
+ def _init_node_executor_pool(cls):
+ cls.executor = Pool(
+ processes=cls.pool_size,
initializer=cls.on_node_executor_start,
)
- cls.logger.info(f"Graph executor started with max-{cls.pool_size} node workers")
@classmethod
@error_logged
- def on_graph_execution(cls, data: GraphExecution):
+ def on_graph_execution(cls, data: GraphExecution, cancel: threading.Event):
log_metadata = get_log_metadata(
graph_eid=data.graph_exec_id,
graph_id=data.graph_id,
@@ -477,7 +485,7 @@ class Executor:
node_eid="*",
block_name="-",
)
- timing_info, node_count = cls._on_graph_execution(data, log_metadata)
+ timing_info, node_count = cls._on_graph_execution(data, cancel, log_metadata)
metric_graph_timing("walltime", timing_info.wall_time, tags=log_metadata)
metric_graph_timing("cputime", timing_info.cpu_time, tags=log_metadata)
metric_graph_count("nodecount", node_count, tags=log_metadata)
@@ -495,7 +503,9 @@ class Executor:
@classmethod
@time_measured
- def _on_graph_execution(cls, graph_data: GraphExecution, log_metadata: dict) -> int:
+ def _on_graph_execution(
+ cls, graph_data: GraphExecution, cancel: threading.Event, log_metadata: dict
+ ) -> int:
cls.logger.info(
"Start graph execution",
extra={
@@ -504,38 +514,85 @@ class Executor:
}
},
)
- node_executed = 0
+ n_node_executions = 0
+ finished = False
+
+ def cancel_handler():
+ while not cancel.is_set():
+ cancel.wait(1)
+ if finished:
+ return
+ cls.executor.terminate()
+ logger.info(
+ f"Terminated graph execution {graph_data.graph_exec_id}",
+ extra={"json_fields": {**log_metadata}},
+ )
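+ # A Pool cannot be reused after terminate(), so build a fresh node-worker
+ # pool for any subsequent graph runs handled by this executor process.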
+ cls._init_node_executor_pool()
+
+ cancel_thread = threading.Thread(target=cancel_handler)
+ cancel_thread.start()
try:
queue = ExecutionQueue[NodeExecution]()
for node_exec in graph_data.start_node_execs:
queue.add(node_exec)
- futures: dict[str, Future] = {}
+ running_executions: dict[str, AsyncResult] = {}
+
+ def make_exec_callback(exec_data: NodeExecution):
+ node_id = exec_data.node_id
+
+ def callback(_):
+ running_executions.pop(node_id)
+ nonlocal n_node_executions
+ n_node_executions += 1
+
+ return callback
+
while not queue.empty():
- execution = queue.get()
+ if cancel.is_set():
+ return n_node_executions
+
+ exec_data = queue.get()
# Avoid parallel execution of the same node.
- fut = futures.get(execution.node_id)
- if fut and not fut.done():
+ execution = running_executions.get(exec_data.node_id)
+ if execution and not execution.ready():
# TODO (performance improvement):
# Waiting for the completion of the same node execution is blocking.
# To improve this we need a separate queue for each node.
# Re-enqueueing the data back to the queue will disrupt the order.
- cls.wait_future(fut, timeout=None)
+ execution.wait()
- futures[execution.node_id] = cls.executor.submit(
- cls.on_node_execution, queue, execution
+ logger.debug(f"Dispatching execution of node {exec_data.node_id}")
+ running_executions[exec_data.node_id] = cls.executor.apply_async(
+ cls.on_node_execution,
+ (queue, exec_data),
+ callback=make_exec_callback(exec_data),
)
# Avoid terminating graph execution when some nodes are still running.
- while queue.empty() and futures:
- for node_id, future in list(futures.items()):
- if future.done():
- node_executed += 1
- del futures[node_id]
- elif queue.empty():
- cls.wait_future(future)
+ while queue.empty() and running_executions:
+ logger.debug(
+ "Queue empty; running nodes: "
+ f"{list(running_executions.keys())}"
+ )
+ for node_id, execution in list(running_executions.items()):
+ if cancel.is_set():
+ return n_node_executions
+
+ if not queue.empty():
+ logger.debug(
+ "Queue no longer empty! Returning to dispatching loop."
+ )
+ break # yield to parent loop to execute new queue items
+
+ logger.debug(f"Waiting on execution of node {node_id}")
+ execution.wait(3)
+ logger.debug(
+ f"State of execution of node {node_id} after waiting: "
+ f"{'DONE' if execution.ready() else 'RUNNING'}"
+ )
cls.logger.info(
"Finished graph execution",
@@ -546,7 +603,7 @@ class Executor:
},
)
except Exception as e:
- cls.logger.exception(
+ logger.exception(
f"Failed graph execution: {e}",
extra={
"json_fields": {
@@ -554,24 +611,20 @@ class Executor:
}
},
)
-
- return node_executed
-
- @classmethod
- def wait_future(cls, future: Future, timeout: int | None = 3):
- try:
- if not future.done():
- future.result(timeout=timeout)
- except TimeoutError:
- # Avoid being blocked by long-running node, by not waiting its completion.
- pass
+ finally:
+ if not cancel.is_set():
+ finished = True
+ cancel.set()
+ cancel_thread.join()
+ return n_node_executions
class ExecutionManager(AppService):
def __init__(self):
+ self.use_redis = False
self.pool_size = Config().num_graph_workers
self.queue = ExecutionQueue[GraphExecution]()
- self.use_redis = False
+ self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
# def __del__(self):
# self.sync_manager.shutdown()
@@ -581,11 +634,21 @@ class ExecutionManager(AppService):
max_workers=self.pool_size,
initializer=Executor.on_graph_executor_start,
) as executor:
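+ # Manager-backed Events are picklable proxies, so they can be handed to
+ # pool worker processes (a bare multiprocessing.Event can only be inherited).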
+ sync_manager = multiprocessing.Manager()
logger.info(
f"Execution manager started with max-{self.pool_size} graph workers."
)
while True:
- executor.submit(Executor.on_graph_execution, self.queue.get())
+ graph_exec_data = self.queue.get()
+ graph_exec_id = graph_exec_data.graph_exec_id
+ cancel_event = sync_manager.Event()
+ future = executor.submit(
+ Executor.on_graph_execution, graph_exec_data, cancel_event
+ )
+ self.active_graph_runs[graph_exec_id] = (future, cancel_event)
+ future.add_done_callback(
+     # Bind graph_exec_id as a default argument: a plain closure would
+     # late-bind the loop variable and pop the wrong entry on completion.
+     lambda _, exec_id=graph_exec_id: self.active_graph_runs.pop(exec_id, None)
+ )
@property
def agent_server_client(self) -> "AgentServer":
@@ -594,7 +657,7 @@ class ExecutionManager(AppService):
@expose
def add_execution(
self, graph_id: str, data: BlockInput, user_id: str
- ) -> dict[Any, Any]:
+ ) -> dict[str, Any]:
graph: Graph | None = self.run_and_wait(get_graph(graph_id, user_id=user_id))
if not graph:
raise Exception(f"Graph #{graph_id} not found.")
@@ -647,4 +710,45 @@ class ExecutionManager(AppService):
)
self.queue.add(graph_exec)
- return {"id": graph_exec_id}
+ return graph_exec.model_dump()
+
+ @expose
+ def cancel_execution(self, graph_exec_id: str) -> None:
+ """
+ Mechanism:
+ 1. Set the cancel event
+ 2. Graph executor's cancel handler thread detects the event, terminates workers,
+ reinitializes worker pool, and returns.
+ 3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
+ """
+ if graph_exec_id not in self.active_graph_runs:
+ raise Exception(
+ f"Graph execution #{graph_exec_id} not active/running: "
+ "possibly already completed/cancelled."
+ )
+
+ future, cancel_event = self.active_graph_runs[graph_exec_id]
+ if cancel_event.is_set():
+ return
+
+ cancel_event.set()
+ future.result()
+
+ # Update the status of the unfinished node executions
+ node_execs = self.run_and_wait(get_execution_results(graph_exec_id))
+ for node_exec in node_execs:
+ if node_exec.status not in (
+ ExecutionStatus.COMPLETED,
+ ExecutionStatus.FAILED,
+ ):
+ self.run_and_wait(
+ upsert_execution_output(
+ node_exec.node_exec_id, "error", "TERMINATED"
+ )
+ )
+ exec_update = self.run_and_wait(
+ update_execution_status(
+ node_exec.node_exec_id, ExecutionStatus.FAILED
+ )
+ )
+ self.agent_server_client.send_execution_update(exec_update.model_dump())
diff --git a/rnd/autogpt_server/autogpt_server/server/rest_api.py b/rnd/autogpt_server/autogpt_server/server/rest_api.py
index a51418f49b..7231ae75ab 100644
--- a/rnd/autogpt_server/autogpt_server/server/rest_api.py
+++ b/rnd/autogpt_server/autogpt_server/server/rest_api.py
@@ -11,14 +11,10 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from autogpt_server.data import block, db
+from autogpt_server.data import execution as execution_db
from autogpt_server.data import graph as graph_db
from autogpt_server.data import user as user_db
from autogpt_server.data.block import BlockInput, CompletedBlockOutput
-from autogpt_server.data.execution import (
- ExecutionResult,
- get_execution_results,
- list_executions,
-)
from autogpt_server.data.queue import AsyncEventQueue, AsyncRedisEventQueue
from autogpt_server.data.user import get_or_create_user
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
@@ -170,10 +166,15 @@ class AgentServer(AppService):
methods=["GET"],
)
router.add_api_route(
- path="/graphs/{graph_id}/executions/{run_id}",
- endpoint=self.get_run_execution_results,
+ path="/graphs/{graph_id}/executions/{graph_exec_id}",
+ endpoint=self.get_graph_run_node_execution_results,
methods=["GET"],
)
+ router.add_api_route(
+ path="/graphs/{graph_id}/executions/{graph_exec_id}/stop",
+ endpoint=self.stop_graph_run,
+ methods=["POST"],
+ )
router.add_api_route(
path="/graphs/{graph_id}/schedules",
endpoint=self.create_schedule,
@@ -423,15 +424,29 @@ class AgentServer(AppService):
graph_id: str,
node_input: dict[Any, Any],
user_id: Annotated[str, Depends(get_user_id)],
- ) -> dict[Any, Any]:
+ ) -> dict[str, Any]: # FIXME: add proper return type
try:
- return self.execution_manager_client.add_execution(
+ graph_exec = self.execution_manager_client.add_execution(
graph_id, node_input, user_id=user_id
)
+ return {"id": graph_exec["graph_exec_id"]}
except Exception as e:
msg = e.__str__().encode().decode("unicode_escape")
raise HTTPException(status_code=400, detail=msg)
+ async def stop_graph_run(
+ self, graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
+ ) -> list[execution_db.ExecutionResult]:
+ if not await execution_db.get_graph_execution(graph_exec_id, user_id):
+ raise HTTPException(
+ 404, detail=f"Agent execution #{graph_exec_id} not found"
+ )
+
+ self.execution_manager_client.cancel_execution(graph_exec_id)
+
+ # Retrieve & return canceled graph execution in its final state
+ return await execution_db.get_execution_results(graph_exec_id)
+
@classmethod
async def get_graph_input_schema(
cls,
@@ -458,17 +473,20 @@ class AgentServer(AppService):
status_code=404, detail=f"Agent #{graph_id}{rev} not found."
)
- return await list_executions(graph_id, graph_version)
+ return await execution_db.list_executions(graph_id, graph_version)
@classmethod
- async def get_run_execution_results(
- cls, graph_id: str, run_id: str, user_id: Annotated[str, Depends(get_user_id)]
- ) -> list[ExecutionResult]:
+ async def get_graph_run_node_execution_results(
+ cls,
+ graph_id: str,
+ graph_exec_id: str,
+ user_id: Annotated[str, Depends(get_user_id)],
+ ) -> list[execution_db.ExecutionResult]:
graph = await graph_db.get_graph(graph_id, user_id=user_id)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
- return await get_execution_results(run_id)
+ return await execution_db.get_execution_results(graph_exec_id)
async def create_schedule(
self,
@@ -506,7 +524,7 @@ class AgentServer(AppService):
@expose
def send_execution_update(self, execution_result_dict: dict[Any, Any]):
- execution_result = ExecutionResult(**execution_result_dict)
+ execution_result = execution_db.ExecutionResult(**execution_result_dict)
self.run_and_wait(self.event_queue.put(execution_result))
@expose
diff --git a/rnd/autogpt_server/autogpt_server/util/test.py b/rnd/autogpt_server/autogpt_server/util/test.py
index 8349920404..05f9178c93 100644
--- a/rnd/autogpt_server/autogpt_server/util/test.py
+++ b/rnd/autogpt_server/autogpt_server/util/test.py
@@ -59,7 +59,6 @@ class SpinTestServer:
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
async def __aenter__(self):
-
self.name_server.__enter__()
self.setup_dependency_overrides()
self.agent_server.__enter__()
@@ -95,7 +94,7 @@ async def wait_execution(
timeout: int = 20,
) -> list:
async def is_execution_completed():
- execs = await AgentServer().get_run_execution_results(
+ execs = await AgentServer().get_graph_run_node_execution_results(
graph_id, graph_exec_id, user_id
)
return (
@@ -110,7 +109,7 @@ async def wait_execution(
# Wait for the executions to complete
for i in range(timeout):
if await is_execution_completed():
- return await AgentServer().get_run_execution_results(
+ return await AgentServer().get_graph_run_node_execution_results(
graph_id, graph_exec_id, user_id
)
time.sleep(1)
diff --git a/rnd/autogpt_server/config.default.json b/rnd/autogpt_server/config.default.json
index b3afd56b5d..08b2b14bdd 100644
--- a/rnd/autogpt_server/config.default.json
+++ b/rnd/autogpt_server/config.default.json
@@ -1,4 +1,4 @@
{
"num_graph_workers": 10,
- "num_node_workers": 10
+ "num_node_workers": 5
}
diff --git a/rnd/autogpt_server/test/executor/test_manager.py b/rnd/autogpt_server/test/executor/test_manager.py
index 151b641bbc..51debc789a 100644
--- a/rnd/autogpt_server/test/executor/test_manager.py
+++ b/rnd/autogpt_server/test/executor/test_manager.py
@@ -35,7 +35,7 @@ async def assert_sample_graph_executions(
test_user: User,
graph_exec_id: str,
):
- executions = await agent_server.get_run_execution_results(
+ executions = await agent_server.get_graph_run_node_execution_results(
test_graph.id, graph_exec_id, test_user.id
)
@@ -156,7 +156,7 @@ async def test_input_pin_always_waited(server: SpinTestServer):
server.agent_server, server.exec_manager, test_graph, test_user, {}, 3
)
- executions = await server.agent_server.get_run_execution_results(
+ executions = await server.agent_server.get_graph_run_node_execution_results(
test_graph.id, graph_exec_id, test_user.id
)
assert len(executions) == 3
@@ -236,7 +236,7 @@ async def test_static_input_link_on_graph(server: SpinTestServer):
graph_exec_id = await execute_graph(
server.agent_server, server.exec_manager, test_graph, test_user, {}, 8
)
- executions = await server.agent_server.get_run_execution_results(
+ executions = await server.agent_server.get_graph_run_node_execution_results(
test_graph.id, graph_exec_id, test_user.id
)
assert len(executions) == 8
diff --git a/rnd/infra/helm/autogpt-server/values.dev.yaml b/rnd/infra/helm/autogpt-server/values.dev.yaml
index baff7fff15..2d292a8a03 100644
--- a/rnd/infra/helm/autogpt-server/values.dev.yaml
+++ b/rnd/infra/helm/autogpt-server/values.dev.yaml
@@ -82,6 +82,6 @@ env:
APP_ENV: "dev"
PYRO_HOST: "0.0.0.0"
NUM_GRAPH_WORKERS: 100
- NUM_NODE_WORKERS: 100
+ NUM_NODE_WORKERS: 5
REDIS_HOST: "redis-dev-master.redis-dev.svc.cluster.local"
- REDIS_PORT: "6379"
\ No newline at end of file
+ REDIS_PORT: "6379"