diff --git a/.github/workflows/autogpt-server-ci.yml b/.github/workflows/autogpt-server-ci.yml
index f71f1a8aee..6c78301fda 100644
--- a/.github/workflows/autogpt-server-ci.yml
+++ b/.github/workflows/autogpt-server-ci.yml
@@ -128,9 +128,14 @@ jobs:
 
       - name: Run pytest with coverage
         run: |
-          poetry run pytest -vv \
-            test
+          if [[ "${{ runner.debug }}" == "1" ]]; then
+            poetry run pytest -vv -o log_cli=true -o log_cli_level=DEBUG test
+          else
+            poetry run pytest -vv test
+          fi
         if: success() || (failure() && steps.lint.outcome == 'failure')
         env:
+          LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
           CI: true
           PLAIN_OUTPUT: True
diff --git a/rnd/autogpt_builder/src/components/Flow.tsx b/rnd/autogpt_builder/src/components/Flow.tsx
index 5b8bdbbc6d..edc6f696c2 100644
--- a/rnd/autogpt_builder/src/components/Flow.tsx
+++ b/rnd/autogpt_builder/src/components/Flow.tsx
@@ -34,7 +34,12 @@ import ConnectionLine from "./ConnectionLine";
 import { Control, ControlPanel } from "@/components/edit/control/ControlPanel";
 import { SaveControl } from "@/components/edit/control/SaveControl";
 import { BlocksControl } from "@/components/edit/control/BlocksControl";
-import { IconPlay, IconRedo2, IconUndo2 } from "@/components/ui/icons";
+import {
+  IconPlay,
+  IconRedo2,
+  IconSquare,
+  IconUndo2,
+} from "@/components/ui/icons";
 import { startTutorial } from "./tutorial";
 import useAgentGraph from "@/hooks/useAgentGraph";
 import { v4 as uuidv4 } from "uuid";
@@ -75,7 +80,9 @@ const FlowEditor: React.FC<{
     availableNodes,
     getOutputType,
     requestSave,
-    requestSaveRun,
+    requestSaveAndRun,
+    requestStopRun,
+    isRunning,
     nodes,
     setNodes,
     edges,
@@ -542,9 +549,9 @@ const FlowEditor: React.FC<{
       onClick: handleRedo,
     },
     {
-      label: "Run",
-      icon: <IconPlay />,
-      onClick: requestSaveRun,
+      label: !isRunning ? "Run" : "Stop",
+      icon: !isRunning ? <IconPlay /> : <IconSquare />,
+      onClick: !isRunning ? requestSaveAndRun : requestStopRun,
     },
   ];
 
diff --git a/rnd/autogpt_builder/src/components/monitor/FlowRunInfo.tsx b/rnd/autogpt_builder/src/components/monitor/FlowRunInfo.tsx
index 1247ba2d22..10be74e29b 100644
--- a/rnd/autogpt_builder/src/components/monitor/FlowRunInfo.tsx
+++ b/rnd/autogpt_builder/src/components/monitor/FlowRunInfo.tsx
@@ -1,9 +1,10 @@
-import React from "react";
-import { GraphMeta } from "@/lib/autogpt-server-api";
+import React, { useCallback } from "react";
+import AutoGPTServerAPI, { GraphMeta } from "@/lib/autogpt-server-api";
 import { FlowRun } from "@/lib/types";
 import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
 import Link from "next/link";
-import { buttonVariants } from "@/components/ui/button";
+import { Button, buttonVariants } from "@/components/ui/button";
+import { IconSquare } from "@/components/ui/icons";
 import { Pencil2Icon } from "@radix-ui/react-icons";
 import moment from "moment/moment";
 import { FlowRunStatusBadge } from "@/components/monitor/FlowRunStatusBadge";
@@ -20,6 +21,11 @@ export const FlowRunInfo: React.FC<
     );
   }
 
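+  // Request cancellation of this run on the server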
+  const handleStopRun = useCallback(() => {
+    const api = new AutoGPTServerAPI();
+    api.stopGraphExecution(flow.id, flowRun.id);
+  }, [flow.id, flowRun.id]);
+
   return (
@@ -34,12 +40,19 @@ export const FlowRunInfo: React.FC<
           Run ID: <code>{flowRun.id}</code>
         </p>
       </div>
-      <Link
-        className={buttonVariants({ variant: "outline" })}
-        href={`/build?flowID=${flow.id}`}
-      >
-        <Pencil2Icon className="mr-2" /> Edit Agent
-      </Link>
+      <div className="flex space-x-2">
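+        {/* Offer a Stop button alongside Edit while the run is active */}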
+        {flowRun.status === "running" && (
+          <Button onClick={handleStopRun} variant="destructive">
+            <IconSquare className="mr-2" /> Stop Run
+          </Button>
+        )}
+        <Link
+          className={buttonVariants({ variant: "outline" })}
+          href={`/build?flowID=${flow.id}`}
+        >
+          <Pencil2Icon className="mr-2" /> Edit Agent
+        </Link>
+      </div>
     </CardHeader>
 
diff --git a/rnd/autogpt_builder/src/components/ui/icons.tsx b/rnd/autogpt_builder/src/components/ui/icons.tsx
index 28cc978a6e..09a304d82a 100644
--- a/rnd/autogpt_builder/src/components/ui/icons.tsx
+++ b/rnd/autogpt_builder/src/components/ui/icons.tsx
@@ -405,6 +405,40 @@ export const IconPlay = createIcon((props) => (
   </svg>
 ));
 
+/**
+ * Square icon component.
+ *
+ * @component IconSquare
+ * @param {IconProps} props - The props object containing additional attributes and event handlers for the icon.
+ * @returns {JSX.Element} - The square icon.
+ *
+ * @example
+ * // Default usage; this is the standard usage
+ * <IconSquare />
+ *
+ * @example
+ * // With custom color and size; these should be used sparingly and only when necessary
+ * <IconSquare className="size-6 text-primary" />
+ *
+ * @example
+ * // With custom size and onClick handler
+ * <IconSquare className="size-6" onClick={handleOnClick} />
+ */
+export const IconSquare = createIcon((props) => (
+  <svg
+    xmlns="http://www.w3.org/2000/svg"
+    width="24"
+    height="24"
+    viewBox="0 0 24 24"
+    fill="none"
+    stroke="currentColor"
+    strokeWidth="2"
+    strokeLinecap="round"
+    strokeLinejoin="round"
+    {...props}
+  >
+    <rect width="18" height="18" x="3" y="3" rx="2" />
+  </svg>
+));
+
 /**
  * Package2 icon component.
  *
diff --git a/rnd/autogpt_builder/src/hooks/useAgentGraph.ts b/rnd/autogpt_builder/src/hooks/useAgentGraph.ts
index 241dedcb15..454b0d4f63 100644
--- a/rnd/autogpt_builder/src/hooks/useAgentGraph.ts
+++ b/rnd/autogpt_builder/src/hooks/useAgentGraph.ts
@@ -31,22 +31,27 @@ export default function useAgentGraph(
   const [updateQueue, setUpdateQueue] = useState<NodeExecutionResult[]>([]);
   const processedUpdates = useRef<NodeExecutionResult[]>([]);
   /**
-   * User `request` to save or save&run the agent
+   * User `request` to save or save&run the agent, or to stop the active run.
    * `state` is used to track the request status:
    * - none: no request
    * - saving: request was sent to save the agent
    *   and nodes are pending sync to update their backend ids
    * - running: request was sent to run the agent
    *   and frontend is enqueueing execution results
+   * - stopping: a request to stop the active run has been sent; response is pending
    * - error: request failed
-   *
-   * As of now, state will be stuck at 'running' (if run requested)
-   * because there's no way to know when the execution is done
    */
-  const [saveRunRequest, setSaveRunRequest] = useState<{
-    request: "none" | "save" | "run";
-    state: "none" | "saving" | "running" | "error";
-  }>({
+  const [saveRunRequest, setSaveRunRequest] = useState<
+    | {
+        request: "none" | "save" | "run";
+        state: "none" | "saving" | "error";
+      }
+    | {
+        request: "run" | "stop";
+        state: "running" | "stopping" | "error";
+        activeExecutionID?: string;
+      }
+  >({
     request: "none",
     state: "none",
   });
@@ -128,13 +133,14 @@ export default function useAgentGraph(
         console.error("Error saving agent");
       } else if (saveRunRequest.request === "run") {
         console.error(`Error saving&running agent`);
+      } else if (saveRunRequest.request === "stop") {
+        console.error(`Error stopping agent`);
       }
       // Reset request
-      setSaveRunRequest((prev) => ({
-        ...prev,
+      setSaveRunRequest({
         request: "none",
         state: "none",
-      }));
+      });
       return;
     }
     // When saving request is done
@@ -145,11 +151,10 @@ export default function useAgentGraph(
     ) {
       // Reset request if only save was requested
       if (saveRunRequest.request === "save") {
-        setSaveRunRequest((prev) => ({
-          ...prev,
+        setSaveRunRequest({
           request: "none",
           state: "none",
-        }));
+        });
         // If run was requested, run the agent
       } else if (saveRunRequest.request === "run") {
         if (!validateNodes()) {
@@ -161,16 +166,64 @@ export default function useAgentGraph(
           return;
         }
         api.subscribeToExecution(savedAgent.id);
-        api.executeGraph(savedAgent.id);
-        processedUpdates.current = processedUpdates.current = [];
+        setSaveRunRequest({ request: "run", state: "running" });
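+        // Kick off the run, then record its execution ID so a later stop request can target it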
+        api
+          .executeGraph(savedAgent.id)
+          .then((graphExecution) => {
+            setSaveRunRequest({
+              request: "run",
+              state: "running",
+              activeExecutionID: graphExecution.id,
+            });
 
-        setSaveRunRequest((prev) => ({
-          ...prev,
-          request: "run",
-          state: "running",
-        }));
+            // Track execution until completed
+            const pendingNodeExecutions: Set<string> = new Set();
+            const cancelExecListener = api.onWebSocketMessage(
+              "execution_event",
+              (nodeResult) => {
+                // We are racing the server here, since we need the ID to filter events
+                if (nodeResult.graph_exec_id != graphExecution.id) {
+                  return;
+                }
+                if (
+                  nodeResult.status != "COMPLETED" &&
+                  nodeResult.status != "FAILED"
+                ) {
+                  pendingNodeExecutions.add(nodeResult.node_exec_id);
+                } else {
+                  pendingNodeExecutions.delete(nodeResult.node_exec_id);
+                }
+                if (pendingNodeExecutions.size == 0) {
+                  // Assuming the first event is always a QUEUED node, and
+                  // following nodes are QUEUED before all preceding nodes are COMPLETED,
+                  // an empty set means the graph has finished running.
+                  cancelExecListener();
+                  setSaveRunRequest({ request: "none", state: "none" });
+                }
+              },
+            );
+          })
+          .catch(() => setSaveRunRequest({ request: "run", state: "error" }));
+
+        processedUpdates.current = processedUpdates.current = [];
       }
     }
+    // Handle stop request
+    if (
+      saveRunRequest.request === "stop" &&
+      saveRunRequest.state != "stopping" &&
+      savedAgent &&
+      saveRunRequest.activeExecutionID
+    ) {
+      setSaveRunRequest({
+        request: "stop",
+        state: "stopping",
+        activeExecutionID: saveRunRequest.activeExecutionID,
+      });
+      api
+        .stopGraphExecution(savedAgent.id, saveRunRequest.activeExecutionID)
+        .then(() => setSaveRunRequest({ request: "none", state: "none" }));
+    }
   }, [saveRunRequest, savedAgent, nodesSyncedWithSavedAgent]);
 
   // Check if node ids are synced with saved agent
@@ -657,7 +710,7 @@ export default function useAgentGraph(
     [saveAgent],
   );
 
-  const requestSaveRun = useCallback(() => {
+  const requestSaveAndRun = useCallback(() => {
     saveAgent();
     setSaveRunRequest({
       request: "run",
@@ -665,6 +718,23 @@ export default function useAgentGraph(
     });
   }, [saveAgent]);
 
+  const requestStopRun = useCallback(() => {
+    if (saveRunRequest.state != "running") {
+      return;
+    }
+    if (!saveRunRequest.activeExecutionID) {
+      console.warn(
+        "Stop requested but execution ID is unknown; state:",
+        saveRunRequest,
+      );
+    }
+    setSaveRunRequest((prev) => ({
+      ...prev,
+      request: "stop",
+      state: "running",
+    }));
+  }, [saveRunRequest]);
+
   return {
     agentName,
     setAgentName,
@@ -674,7 +744,11 @@ export default function useAgentGraph(
     availableNodes,
     getOutputType,
     requestSave,
-    requestSaveRun,
+    requestSaveAndRun,
+    requestStopRun,
+    isSaving: saveRunRequest.state == "saving",
+    isRunning: saveRunRequest.state == "running",
+    isStopping: saveRunRequest.state == "stopping",
     nodes,
     setNodes,
     edges,
diff --git a/rnd/autogpt_builder/src/lib/autogpt-server-api/client.ts b/rnd/autogpt_builder/src/lib/autogpt-server-api/client.ts
index 0948e842a9..56d950d9cd 100644
--- a/rnd/autogpt_builder/src/lib/autogpt-server-api/client.ts
+++ b/rnd/autogpt_builder/src/lib/autogpt-server-api/client.ts
@@ -15,7 +15,7 @@ export default class AutoGPTServerAPI {
   private wsUrl: string;
   private webSocket: WebSocket | null = null;
   private wsConnecting: Promise<void> | null = null;
-  private wsMessageHandlers: { [key: string]: (data: any) => void } = {};
+  private wsMessageHandlers: Record<string, Set<(data: any) => void>> = {};
   private supabaseClient = createClient();
 
   constructor(
@@ -128,16 +128,19 @@ export default class AutoGPTServerAPI {
     runID: string,
   ): Promise<NodeExecutionResult[]> {
    return (await 
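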
     return (await this._get(`/graphs/${graphID}/executions/${runID}`)).map(
-      (result: any) => ({
-        ...result,
-        add_time: new Date(result.add_time),
-        queue_time: result.queue_time ? new Date(result.queue_time) : undefined,
-        start_time: result.start_time ? new Date(result.start_time) : undefined,
-        end_time: result.end_time ? new Date(result.end_time) : undefined,
-      }),
+      parseNodeExecutionResultTimestamps,
     );
   }
 
+  /** Request cancellation of a running graph execution; resolves with the node executions in their final state. */
+  async stopGraphExecution(
+    graphID: string,
+    runID: string,
+  ): Promise<NodeExecutionResult[]> {
+    return (
+      await this._request("POST", `/graphs/${graphID}/executions/${runID}/stop`)
+    ).map(parseNodeExecutionResultTimestamps);
+  }
+
   private async _get(path: string) {
     return this._request("GET", path);
   }
@@ -207,10 +210,13 @@ export default class AutoGPTServerAPI {
         };
 
         this.webSocket.onmessage = (event) => {
-          const message = JSON.parse(event.data);
-          if (this.wsMessageHandlers[message.method]) {
-            this.wsMessageHandlers[message.method](message.data);
+          const message: WebsocketMessage = JSON.parse(event.data);
+          if (message.method == "execution_event") {
+            message.data = parseNodeExecutionResultTimestamps(message.data);
           }
+          this.wsMessageHandlers[message.method]?.forEach((handler) =>
+            handler(message.data),
+          );
         };
       } catch (error) {
         console.error("Error connecting to WebSocket:", error);
@@ -250,8 +256,12 @@ export default class AutoGPTServerAPI {
   onWebSocketMessage<M extends keyof WebsocketMessageTypeMap>(
     method: M,
     handler: (data: WebsocketMessageTypeMap[M]) => void,
-  ) {
-    this.wsMessageHandlers[method] = handler;
+  ): () => void {
+    this.wsMessageHandlers[method] ??= new Set();
+    this.wsMessageHandlers[method].add(handler);
+
+    // Return detacher
+    return () => this.wsMessageHandlers[method].delete(handler);
   }
 
   subscribeToExecution(graphId: string) {
@@ -274,3 +284,22 @@ type WebsocketMessageTypeMap = {
   subscribe: { graph_id: string };
   execution_event: NodeExecutionResult;
 };
+
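+// Discriminated union of all websocket message shapes: narrowing on `method`
+// also narrows the type of `data`.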
+ """ + execution = await AgentGraphExecution.prisma().find_first( + where={"id": graph_exec_id, "userId": user_id}, + include=GRAPH_EXECUTION_INCLUDE, + ) + return execution + + async def list_executions(graph_id: str, graph_version: int | None = None) -> list[str]: where: AgentGraphExecutionWhereInput = {"agentGraphId": graph_id} if graph_version is not None: diff --git a/rnd/autogpt_server/autogpt_server/executor/manager.py b/rnd/autogpt_server/autogpt_server/executor/manager.py index edde97b156..2f0113a69e 100644 --- a/rnd/autogpt_server/autogpt_server/executor/manager.py +++ b/rnd/autogpt_server/autogpt_server/executor/manager.py @@ -1,7 +1,10 @@ import asyncio import logging -from concurrent.futures import Future, ProcessPoolExecutor, TimeoutError +import multiprocessing +import threading +from concurrent.futures import Future, ProcessPoolExecutor from contextlib import contextmanager +from multiprocessing.pool import AsyncResult, Pool from typing import TYPE_CHECKING, Any, Coroutine, Generator, TypeVar if TYPE_CHECKING: @@ -16,6 +19,7 @@ from autogpt_server.data.execution import ( GraphExecution, NodeExecution, create_graph_execution, + get_execution_results, get_incomplete_executions, get_latest_execution, merge_execution_input, @@ -461,15 +465,19 @@ class Executor: cls.loop = asyncio.new_event_loop() cls.loop.run_until_complete(db.connect()) cls.pool_size = Config().num_node_workers - cls.executor = ProcessPoolExecutor( - max_workers=cls.pool_size, + cls._init_node_executor_pool() + logger.info(f"Graph executor started with max-{cls.pool_size} node workers.") + + @classmethod + def _init_node_executor_pool(cls): + cls.executor = Pool( + processes=cls.pool_size, initializer=cls.on_node_executor_start, ) - cls.logger.info(f"Graph executor started with max-{cls.pool_size} node workers") @classmethod @error_logged - def on_graph_execution(cls, data: GraphExecution): + def on_graph_execution(cls, data: GraphExecution, cancel: threading.Event): log_metadata = get_log_metadata( graph_eid=data.graph_exec_id, graph_id=data.graph_id, @@ -477,7 +485,7 @@ class Executor: node_eid="*", block_name="-", ) - timing_info, node_count = cls._on_graph_execution(data, log_metadata) + timing_info, node_count = cls._on_graph_execution(data, cancel, log_metadata) metric_graph_timing("walltime", timing_info.wall_time, tags=log_metadata) metric_graph_timing("cputime", timing_info.cpu_time, tags=log_metadata) metric_graph_count("nodecount", node_count, tags=log_metadata) @@ -495,7 +503,9 @@ class Executor: @classmethod @time_measured - def _on_graph_execution(cls, graph_data: GraphExecution, log_metadata: dict) -> int: + def _on_graph_execution( + cls, graph_data: GraphExecution, cancel: threading.Event, log_metadata: dict + ) -> int: cls.logger.info( "Start graph execution", extra={ @@ -504,38 +514,85 @@ class Executor: } }, ) - node_executed = 0 + n_node_executions = 0 + finished = False + + def cancel_handler(): + while not cancel.is_set(): + cancel.wait(1) + if finished: + return + cls.executor.terminate() + logger.info( + f"Terminated graph execution {graph_data.graph_exec_id}", + extra={"json_fields": {**log_metadata}}, + ) + cls._init_node_executor_pool() + + cancel_thread = threading.Thread(target=cancel_handler) + cancel_thread.start() try: queue = ExecutionQueue[NodeExecution]() for node_exec in graph_data.start_node_execs: queue.add(node_exec) - futures: dict[str, Future] = {} + running_executions: dict[str, AsyncResult] = {} + + def make_exec_callback(exec_data: NodeExecution): + node_id = 
+    @classmethod
+    def _init_node_executor_pool(cls):
+        cls.executor = Pool(
+            processes=cls.pool_size,
             initializer=cls.on_node_executor_start,
         )
-        cls.logger.info(f"Graph executor started with max-{cls.pool_size} node workers")
 
     @classmethod
     @error_logged
-    def on_graph_execution(cls, data: GraphExecution):
+    def on_graph_execution(cls, data: GraphExecution, cancel: threading.Event):
         log_metadata = get_log_metadata(
             graph_eid=data.graph_exec_id,
             graph_id=data.graph_id,
@@ -477,7 +485,7 @@ class Executor:
             node_eid="*",
             block_name="-",
         )
-        timing_info, node_count = cls._on_graph_execution(data, log_metadata)
+        timing_info, node_count = cls._on_graph_execution(data, cancel, log_metadata)
         metric_graph_timing("walltime", timing_info.wall_time, tags=log_metadata)
         metric_graph_timing("cputime", timing_info.cpu_time, tags=log_metadata)
         metric_graph_count("nodecount", node_count, tags=log_metadata)
@@ -495,7 +503,9 @@ class Executor:
 
     @classmethod
     @time_measured
-    def _on_graph_execution(cls, graph_data: GraphExecution, log_metadata: dict) -> int:
+    def _on_graph_execution(
+        cls, graph_data: GraphExecution, cancel: threading.Event, log_metadata: dict
+    ) -> int:
         cls.logger.info(
             "Start graph execution",
             extra={
@@ -504,38 +514,85 @@ class Executor:
                 }
             },
         )
-        node_executed = 0
+        n_node_executions = 0
+        finished = False
+
+        def cancel_handler():
+            while not cancel.is_set():
+                cancel.wait(1)
+            if finished:
+                return
+            cls.executor.terminate()
+            logger.info(
+                f"Terminated graph execution {graph_data.graph_exec_id}",
+                extra={"json_fields": {**log_metadata}},
+            )
+            cls._init_node_executor_pool()
+
+        cancel_thread = threading.Thread(target=cancel_handler)
+        cancel_thread.start()
 
         try:
             queue = ExecutionQueue[NodeExecution]()
             for node_exec in graph_data.start_node_execs:
                 queue.add(node_exec)
 
-            futures: dict[str, Future] = {}
+            running_executions: dict[str, AsyncResult] = {}
+
+            def make_exec_callback(exec_data: NodeExecution):
+                node_id = exec_data.node_id
+
+                def callback(_):
+                    running_executions.pop(node_id)
+                    nonlocal n_node_executions
+                    n_node_executions += 1
+
+                return callback
+
             while not queue.empty():
-                execution = queue.get()
+                if cancel.is_set():
+                    return n_node_executions
+
+                exec_data = queue.get()
 
                 # Avoid parallel execution of the same node.
-                fut = futures.get(execution.node_id)
-                if fut and not fut.done():
+                execution = running_executions.get(exec_data.node_id)
+                if execution and not execution.ready():
                     # TODO (performance improvement):
                     #   Wait for the completion of the same node execution is blocking.
                     #   To improve this we need a separate queue for each node.
                     #   Re-enqueueing the data back to the queue will disrupt the order.
-                    cls.wait_future(fut, timeout=None)
+                    execution.wait()
 
-                futures[execution.node_id] = cls.executor.submit(
-                    cls.on_node_execution, queue, execution
+                logger.debug(f"Dispatching execution of node {exec_data.node_id}")
+                running_executions[exec_data.node_id] = cls.executor.apply_async(
+                    cls.on_node_execution,
+                    (queue, exec_data),
+                    callback=make_exec_callback(exec_data),
                 )
 
                 # Avoid terminating graph execution when some nodes are still running.
-                while queue.empty() and futures:
-                    for node_id, future in list(futures.items()):
-                        if future.done():
-                            node_executed += 1
-                            del futures[node_id]
-                        elif queue.empty():
-                            cls.wait_future(future)
+                while queue.empty() and running_executions:
+                    logger.debug(
+                        "Queue empty; running nodes: "
+                        f"{list(running_executions.keys())}"
+                    )
+                    for node_id, execution in list(running_executions.items()):
+                        if cancel.is_set():
+                            return n_node_executions
+
+                        if not queue.empty():
+                            logger.debug(
+                                "Queue no longer empty! Returning to dispatching loop."
+                            )
+                            break  # yield to parent loop to execute new queue items
+
+                        logger.debug(f"Waiting on execution of node {node_id}")
+                        execution.wait(3)
+                        logger.debug(
+                            f"State of execution of node {node_id} after waiting: "
+                            f"{'DONE' if execution.ready() else 'RUNNING'}"
+                        )
 
             cls.logger.info(
                 "Finished graph execution",
@@ -546,7 +603,7 @@ class Executor:
                 },
             )
         except Exception as e:
-            cls.logger.exception(
+            logger.exception(
                 f"Failed graph execution: {e}",
                 extra={
                     "json_fields": {
@@ -554,24 +611,20 @@ class Executor:
                     }
                 },
             )
-
-        return node_executed
-
-    @classmethod
-    def wait_future(cls, future: Future, timeout: int | None = 3):
-        try:
-            if not future.done():
-                future.result(timeout=timeout)
-        except TimeoutError:
-            # Avoid being blocked by long-running node, by not waiting its completion.
-            pass
+        finally:
+            if not cancel.is_set():
+                finished = True
+                cancel.set()
+            cancel_thread.join()
+            return n_node_executions
 
 
 class ExecutionManager(AppService):
     def __init__(self):
+        self.use_redis = False
         self.pool_size = Config().num_graph_workers
         self.queue = ExecutionQueue[GraphExecution]()
-        self.use_redis = False
+        self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
 
     # def __del__(self):
     #     self.sync_manager.shutdown()
@@ -581,11 +634,21 @@ class ExecutionManager(AppService):
             max_workers=self.pool_size,
             initializer=Executor.on_graph_executor_start,
         ) as executor:
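+            # Events from a multiprocessing.Manager can be shared with the
+            # spawned executor processes, unlike a plain threading.Event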
+            sync_manager = multiprocessing.Manager()
             logger.info(
                 f"Execution manager started with max-{self.pool_size} graph workers."
             )
             while True:
-                executor.submit(Executor.on_graph_execution, self.queue.get())
+                graph_exec_data = self.queue.get()
+                graph_exec_id = graph_exec_data.graph_exec_id
+                cancel_event = sync_manager.Event()
+                future = executor.submit(
+                    Executor.on_graph_execution, graph_exec_data, cancel_event
+                )
+                self.active_graph_runs[graph_exec_id] = (future, cancel_event)
+                future.add_done_callback(
+                    lambda _: self.active_graph_runs.pop(graph_exec_id)
+                )
 
     @property
     def agent_server_client(self) -> "AgentServer":
@@ -594,7 +657,7 @@ class ExecutionManager(AppService):
     @expose
     def add_execution(
         self, graph_id: str, data: BlockInput, user_id: str
-    ) -> dict[Any, Any]:
+    ) -> dict[str, Any]:
         graph: Graph | None = self.run_and_wait(get_graph(graph_id, user_id=user_id))
         if not graph:
             raise Exception(f"Graph #{graph_id} not found.")
@@ -647,4 +710,45 @@ class ExecutionManager(AppService):
         )
 
         self.queue.add(graph_exec)
-        return {"id": graph_exec_id}
+        return graph_exec.model_dump()
+
+    @expose
+    def cancel_execution(self, graph_exec_id: str) -> None:
+        """
+        Mechanism:
+        1. Set the cancel event
+        2. Graph executor's cancel handler thread detects the event, terminates workers,
+           reinitializes worker pool, and returns.
+        3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
+        """
+        if graph_exec_id not in self.active_graph_runs:
+            raise Exception(
+                f"Graph execution #{graph_exec_id} not active/running: "
+                "possibly already completed/cancelled."
+            )
+
+        future, cancel_event = self.active_graph_runs[graph_exec_id]
+        if cancel_event.is_set():
+            return
+
+        cancel_event.set()
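+        # Block until the graph runner has wound down; only then is it safe
+        # to patch node statuses in the DB.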
+        future.result()
+
+        # Update the status of the unfinished node executions
+        node_execs = self.run_and_wait(get_execution_results(graph_exec_id))
+        for node_exec in node_execs:
+            if node_exec.status not in (
+                ExecutionStatus.COMPLETED,
+                ExecutionStatus.FAILED,
+            ):
+                self.run_and_wait(
+                    upsert_execution_output(
+                        node_exec.node_exec_id, "error", "TERMINATED"
+                    )
+                )
+                exec_update = self.run_and_wait(
+                    update_execution_status(
+                        node_exec.node_exec_id, ExecutionStatus.FAILED
+                    )
+                )
+                self.agent_server_client.send_execution_update(exec_update.model_dump())
diff --git a/rnd/autogpt_server/autogpt_server/server/rest_api.py b/rnd/autogpt_server/autogpt_server/server/rest_api.py
index a51418f49b..7231ae75ab 100644
--- a/rnd/autogpt_server/autogpt_server/server/rest_api.py
+++ b/rnd/autogpt_server/autogpt_server/server/rest_api.py
@@ -11,14 +11,10 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
 
 from autogpt_server.data import block, db
+from autogpt_server.data import execution as execution_db
 from autogpt_server.data import graph as graph_db
 from autogpt_server.data import user as user_db
 from autogpt_server.data.block import BlockInput, CompletedBlockOutput
-from autogpt_server.data.execution import (
-    ExecutionResult,
-    get_execution_results,
-    list_executions,
-)
 from autogpt_server.data.queue import AsyncEventQueue, AsyncRedisEventQueue
 from autogpt_server.data.user import get_or_create_user
 from autogpt_server.executor import ExecutionManager, ExecutionScheduler
@@ -170,10 +166,15 @@ class AgentServer(AppService):
             methods=["GET"],
         )
         router.add_api_route(
-            path="/graphs/{graph_id}/executions/{run_id}",
-            endpoint=self.get_run_execution_results,
+            path="/graphs/{graph_id}/executions/{graph_exec_id}",
+            endpoint=self.get_graph_run_node_execution_results,
             methods=["GET"],
         )
+        router.add_api_route(
+            path="/graphs/{graph_id}/executions/{graph_exec_id}/stop",
+            endpoint=self.stop_graph_run,
+            methods=["POST"],
+        )
         router.add_api_route(
             path="/graphs/{graph_id}/schedules",
             endpoint=self.create_schedule,
@@ -423,15 +424,29 @@ class AgentServer(AppService):
         graph_id: str,
         node_input: dict[Any, Any],
         user_id: Annotated[str, Depends(get_user_id)],
-    ) -> dict[Any, Any]:
+    ) -> dict[str, Any]:  # FIXME: add proper return type
         try:
-            return self.execution_manager_client.add_execution(
+            graph_exec = self.execution_manager_client.add_execution(
                 graph_id, node_input, user_id=user_id
             )
+            return {"id": graph_exec["graph_exec_id"]}
         except Exception as e:
             msg = e.__str__().encode().decode("unicode_escape")
             raise HTTPException(status_code=400, detail=msg)
 
+    async def stop_graph_run(
+        self, graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
+    ) -> list[execution_db.ExecutionResult]:
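+        # The execution must exist and belong to this user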
+        if not await execution_db.get_graph_execution(graph_exec_id, user_id):
+            raise HTTPException(
+                404, detail=f"Agent execution #{graph_exec_id} not found"
+            )
+
+        self.execution_manager_client.cancel_execution(graph_exec_id)
+
+        # Retrieve & return canceled graph execution in its final state
+        return await execution_db.get_execution_results(graph_exec_id)
+
     @classmethod
     async def get_graph_input_schema(
         cls,
@@ -458,17 +473,20 @@ class AgentServer(AppService):
                 status_code=404, detail=f"Agent #{graph_id}{rev} not found."
             )
 
-        return await list_executions(graph_id, graph_version)
+        return await execution_db.list_executions(graph_id, graph_version)
 
     @classmethod
-    async def get_run_execution_results(
-        cls, graph_id: str, run_id: str, user_id: Annotated[str, Depends(get_user_id)]
-    ) -> list[ExecutionResult]:
+    async def get_graph_run_node_execution_results(
+        cls,
+        graph_id: str,
+        graph_exec_id: str,
+        user_id: Annotated[str, Depends(get_user_id)],
+    ) -> list[execution_db.ExecutionResult]:
         graph = await graph_db.get_graph(graph_id, user_id=user_id)
         if not graph:
             raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
 
-        return await get_execution_results(run_id)
+        return await execution_db.get_execution_results(graph_exec_id)
 
     async def create_schedule(
         self,
@@ -506,7 +524,7 @@ class AgentServer(AppService):
 
     @expose
     def send_execution_update(self, execution_result_dict: dict[Any, Any]):
-        execution_result = ExecutionResult(**execution_result_dict)
+        execution_result = execution_db.ExecutionResult(**execution_result_dict)
         self.run_and_wait(self.event_queue.put(execution_result))
 
     @expose
diff --git a/rnd/autogpt_server/autogpt_server/util/test.py b/rnd/autogpt_server/autogpt_server/util/test.py
index 8349920404..05f9178c93 100644
--- a/rnd/autogpt_server/autogpt_server/util/test.py
+++ b/rnd/autogpt_server/autogpt_server/util/test.py
@@ -59,7 +59,6 @@ class SpinTestServer:
         return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
 
     async def __aenter__(self):
-        self.name_server.__enter__()
         self.setup_dependency_overrides()
         self.agent_server.__enter__()
 
@@ -95,7 +94,7 @@ async def wait_execution(
     timeout: int = 20,
 ) -> list:
     async def is_execution_completed():
-        execs = await AgentServer().get_run_execution_results(
+        execs = await AgentServer().get_graph_run_node_execution_results(
             graph_id, graph_exec_id, user_id
         )
         return (
@@ -110,7 +109,7 @@ async def wait_execution(
     # Wait for the executions to complete
     for i in range(timeout):
         if await is_execution_completed():
-            return await AgentServer().get_run_execution_results(
+            return await AgentServer().get_graph_run_node_execution_results(
                 graph_id, graph_exec_id, user_id
             )
         time.sleep(1)
diff --git a/rnd/autogpt_server/config.default.json b/rnd/autogpt_server/config.default.json
index b3afd56b5d..08b2b14bdd 100644
--- a/rnd/autogpt_server/config.default.json
+++ b/rnd/autogpt_server/config.default.json
@@ -1,4 +1,4 @@
 {
   "num_graph_workers": 10,
-  "num_node_workers": 10
+  "num_node_workers": 5
 }
diff --git a/rnd/autogpt_server/test/executor/test_manager.py b/rnd/autogpt_server/test/executor/test_manager.py
index 151b641bbc..51debc789a 100644
--- a/rnd/autogpt_server/test/executor/test_manager.py
+++ b/rnd/autogpt_server/test/executor/test_manager.py
@@ -35,7 +35,7 @@ async def assert_sample_graph_executions(
     test_user: User,
     graph_exec_id: str,
 ):
-    executions = await agent_server.get_run_execution_results(
+    executions = await agent_server.get_graph_run_node_execution_results(
         test_graph.id, graph_exec_id, test_user.id
     )
 
@@ -156,7 +156,7 @@ async def test_input_pin_always_waited(server: SpinTestServer):
         server.agent_server, server.exec_manager, test_graph, test_user, {}, 3
     )
 
-    executions = await server.agent_server.get_run_execution_results(
+    executions = await server.agent_server.get_graph_run_node_execution_results(
         test_graph.id, graph_exec_id, test_user.id
     )
     assert len(executions) == 3
@@ -236,7 +236,7 @@ async def test_static_input_link_on_graph(server: SpinTestServer):
     graph_exec_id = await execute_graph(
         server.agent_server, server.exec_manager, test_graph, test_user, {}, 8
     )
-    executions = await server.agent_server.get_run_execution_results(
+    executions = await server.agent_server.get_graph_run_node_execution_results(
         test_graph.id, graph_exec_id, test_user.id
     )
     assert len(executions) == 8
diff --git a/rnd/infra/helm/autogpt-server/values.dev.yaml b/rnd/infra/helm/autogpt-server/values.dev.yaml
index baff7fff15..2d292a8a03 100644
--- a/rnd/infra/helm/autogpt-server/values.dev.yaml
+++ b/rnd/infra/helm/autogpt-server/values.dev.yaml
@@ -82,6 +82,6 @@ env:
   APP_ENV: "dev"
   PYRO_HOST: "0.0.0.0"
   NUM_GRAPH_WORKERS: 100
-  NUM_NODE_WORKERS: 100
+  NUM_NODE_WORKERS: 5
   REDIS_HOST: "redis-dev-master.redis-dev.svc.cluster.local"
-  REDIS_PORT: "6379"
\ No newline at end of file
+  REDIS_PORT: "6379"