feat(server, builder): Implement "STOP" button for graph runs (#7892)
- feat(builder): Add "Stop Run" buttons to monitor and builder
  - Implement additional state management in `useAgentGraph` hook
    - Add "stop" request mechanism
    - Implement execution status tracking using WebSockets
    - Add `isSaving`, `isRunning`, `isStopping` outputs
    - Add `requestStopRun` method
    - Rename `requestSaveRun` to `requestSaveAndRun` for clarity
  - Add needed functionality for the above to `AutoGPTServerAPI` client
    - Add `stopGraphExecution` method
    - Add support for multiple handlers per WebSocket method
    - Fix parsing of timestamps in `execution_event` WebSocket messages
  - Add `IconSquare` from Lucide to `@/components/ui/icons`
- feat(server): Add `POST /graphs/{graph_id}/executions/{graph_exec_id}/stop` route
  - Add `stop_graph_run` method to `AgentServer`
- feat(server): Add `cancel_execution` method to `ExecutionManager` (see the sketch after this commit message)
  - Replace node executor `ProcessPoolExecutor` with `multiprocessing.Pool` (which has a `terminate()` method)
    - Remove the now-unnecessary `Executor.wait_future(..)` method
  - Add `get_graph_execution(..)` in `.data.execution`
- fix(server): Reduce the number of node executors to 5 per graph executor
  - This is necessary because `multiprocessing.Pool` spawns its workers on init, rather than on demand like `ProcessPoolExecutor` does
- dx(server): Improve debug logging in `ExecutionManager`
- ci(server): Add debug logging mode to CI Pytest step
### Other improvements
Server:
- Improve the output type of `ExecutionManager.add_execution(..)`
- Rename a few things in `.server.rest_api` for consistency
Frontend:
- Improve typing in the `AutoGPTServerAPI` client
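
At its core, the stop mechanism pairs a `multiprocessing.Manager().Event()` shared between `ExecutionManager` and each graph-executor job with a watcher thread that calls `Pool.terminate()` and rebuilds the node-worker pool. A minimal, self-contained sketch of that shape (all names here are illustrative, not the actual AutoGPT code; the real graph executor runs in a subprocess rather than a thread):

```python
import multiprocessing
import threading
import time
from multiprocessing.pool import Pool


def work(seconds: int) -> int:
    # Stand-in for a long-running node execution.
    time.sleep(seconds)
    return seconds


def run_graph(cancel) -> None:
    # Stand-in for Executor._on_graph_execution: owns a pool of node workers.
    pool = Pool(processes=2)
    finished = False

    def cancel_handler():
        # Like the commit's cancel handler thread: sleep until the event
        # fires, then hard-kill the node workers (ProcessPoolExecutor has
        # no equivalent of terminate(), hence the switch to Pool).
        while not cancel.is_set():
            cancel.wait(1)
        if finished:
            return
        pool.terminate()
        print("node worker pool terminated")

    watcher = threading.Thread(target=cancel_handler)
    watcher.start()
    try:
        pool.apply_async(work, (60,))  # dispatch a "node"
        time.sleep(5)  # pretend to dispatch/track the rest of the run here
    finally:
        if not cancel.is_set():
            finished = True
            cancel.set()  # unblock the watcher on normal completion
        watcher.join()


if __name__ == "__main__":
    # A Manager-backed Event can cross process boundaries, unlike a plain
    # threading.Event; this mirrors what ExecutionManager.run_service() does.
    manager = multiprocessing.Manager()
    cancel_event = manager.Event()
    graph_run = threading.Thread(target=run_graph, args=(cancel_event,))
    graph_run.start()
    time.sleep(1)
    cancel_event.set()  # the essence of ExecutionManager.cancel_execution()
    graph_run.join()
```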
Committed by GitHub · parent 11827835a0 · commit 8fd22bcfd7

.github/workflows/autogpt-server-ci.yml (vendored, 9 lines changed):
@@ -128,9 +128,14 @@ jobs:
       - name: Run pytest with coverage
         run: |
-          poetry run pytest -vv \
-            test
+          if [[ "${{ runner.debug }}" == "1" ]]; then
+            poetry run pytest -vv -o log_cli=true -o log_cli_level=DEBUG test
+          else
+            poetry run pytest -vv test
+          fi
         if: success() || (failure() && steps.lint.outcome == 'failure')
+        env:
+          LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
         env:
           CI: true
           PLAIN_OUTPUT: True
@@ -34,7 +34,12 @@ import ConnectionLine from "./ConnectionLine";
 import { Control, ControlPanel } from "@/components/edit/control/ControlPanel";
 import { SaveControl } from "@/components/edit/control/SaveControl";
 import { BlocksControl } from "@/components/edit/control/BlocksControl";
-import { IconPlay, IconRedo2, IconUndo2 } from "@/components/ui/icons";
+import {
+  IconPlay,
+  IconRedo2,
+  IconSquare,
+  IconUndo2,
+} from "@/components/ui/icons";
 import { startTutorial } from "./tutorial";
 import useAgentGraph from "@/hooks/useAgentGraph";
 import { v4 as uuidv4 } from "uuid";
@@ -75,7 +80,9 @@ const FlowEditor: React.FC<{
     availableNodes,
     getOutputType,
     requestSave,
-    requestSaveRun,
+    requestSaveAndRun,
+    requestStopRun,
+    isRunning,
     nodes,
     setNodes,
     edges,
@@ -542,9 +549,9 @@ const FlowEditor: React.FC<{
       onClick: handleRedo,
     },
     {
-      label: "Run",
-      icon: <IconPlay />,
-      onClick: requestSaveRun,
+      label: !isRunning ? "Run" : "Stop",
+      icon: !isRunning ? <IconPlay /> : <IconSquare />,
+      onClick: !isRunning ? requestSaveAndRun : requestStopRun,
     },
   ];
@@ -1,9 +1,10 @@
-import React from "react";
-import { GraphMeta } from "@/lib/autogpt-server-api";
+import React, { useCallback } from "react";
+import AutoGPTServerAPI, { GraphMeta } from "@/lib/autogpt-server-api";
 import { FlowRun } from "@/lib/types";
 import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
 import Link from "next/link";
-import { buttonVariants } from "@/components/ui/button";
+import { Button, buttonVariants } from "@/components/ui/button";
+import { IconSquare } from "@/components/ui/icons";
 import { Pencil2Icon } from "@radix-ui/react-icons";
 import moment from "moment/moment";
 import { FlowRunStatusBadge } from "@/components/monitor/FlowRunStatusBadge";
@@ -20,6 +21,11 @@ export const FlowRunInfo: React.FC<
     );
   }
 
+  const handleStopRun = useCallback(() => {
+    const api = new AutoGPTServerAPI();
+    api.stopGraphExecution(flow.id, flowRun.id);
+  }, [flow.id, flowRun.id]);
+
   return (
     <Card {...props}>
       <CardHeader className="flex-row items-center justify-between space-x-3 space-y-0">
@@ -34,12 +40,19 @@ export const FlowRunInfo: React.FC<
           Run ID: <code>{flowRun.id}</code>
         </p>
       </div>
-      <Link
-        className={buttonVariants({ variant: "outline" })}
-        href={`/build?flowID=${flow.id}`}
-      >
-        <Pencil2Icon className="mr-2" /> Edit Agent
-      </Link>
+      <div className="flex space-x-2">
+        {flowRun.status === "running" && (
+          <Button onClick={handleStopRun} variant="destructive">
+            <IconSquare className="mr-2" /> Stop Run
+          </Button>
+        )}
+        <Link
+          className={buttonVariants({ variant: "outline" })}
+          href={`/build?flowID=${flow.id}`}
+        >
+          <Pencil2Icon className="mr-2" /> Edit Agent
+        </Link>
+      </div>
     </CardHeader>
     <CardContent>
       <p>
@@ -405,6 +405,40 @@ export const IconPlay = createIcon((props) => (
   </svg>
 ));
 
+/**
+ * Square icon component.
+ *
+ * @component IconSquare
+ * @param {IconProps} props - The props object containing additional attributes and event handlers for the icon.
+ * @returns {JSX.Element} - The square icon.
+ *
+ * @example
+ * // Default usage this is the standard usage
+ * <IconSquare />
+ *
+ * @example
+ * // With custom color and size these should be used sparingly and only when necessary
+ * <IconSquare className="text-primary" size="lg" />
+ *
+ * @example
+ * // With custom size and onClick handler
+ * <IconSquare size="sm" onClick={handleOnClick} />
+ */
+export const IconSquare = createIcon((props) => (
+  <svg
+    xmlns="http://www.w3.org/2000/svg"
+    viewBox="0 0 24 24"
+    fill="none"
+    stroke="currentColor"
+    strokeWidth="2"
+    strokeLinecap="round"
+    strokeLinejoin="round"
+    {...props}
+  >
+    <rect width="18" height="18" x="3" y="3" rx="2" />
+  </svg>
+));
+
 /**
  * Package2 icon component.
  *
@@ -31,22 +31,27 @@ export default function useAgentGraph(
   const [updateQueue, setUpdateQueue] = useState<NodeExecutionResult[]>([]);
   const processedUpdates = useRef<NodeExecutionResult[]>([]);
   /**
-   * User `request` to save or save&run the agent
+   * User `request` to save or save&run the agent, or to stop the active run.
    * `state` is used to track the request status:
    * - none: no request
    * - saving: request was sent to save the agent
    *   and nodes are pending sync to update their backend ids
    * - running: request was sent to run the agent
    *   and frontend is enqueueing execution results
+   * - stopping: a request to stop the active run has been sent; response is pending
    * - error: request failed
-   *
-   * As of now, state will be stuck at 'running' (if run requested)
-   * because there's no way to know when the execution is done
    */
-  const [saveRunRequest, setSaveRunRequest] = useState<{
-    request: "none" | "save" | "run";
-    state: "none" | "saving" | "running" | "error";
-  }>({
+  const [saveRunRequest, setSaveRunRequest] = useState<
+    | {
+        request: "none" | "save" | "run";
+        state: "none" | "saving" | "error";
+      }
+    | {
+        request: "run" | "stop";
+        state: "running" | "stopping" | "error";
+        activeExecutionID?: string;
+      }
+  >({
     request: "none",
    state: "none",
  });
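
Splitting `saveRunRequest` into a union makes invalid combinations (e.g. `request: "stop"` while `state: "saving"`) unrepresentable. The same idea expressed in Python's type system, as a rough, illustrative equivalent (not project code):

```python
from dataclasses import dataclass
from typing import Literal


@dataclass
class IdleOrSaving:
    request: Literal["none", "save", "run"]
    state: Literal["none", "saving", "error"]


@dataclass
class RunningOrStopping:
    request: Literal["run", "stop"]
    state: Literal["running", "stopping", "error"]
    active_execution_id: str | None = None


SaveRunRequest = IdleOrSaving | RunningOrStopping

# "stop" is only expressible in the running branch, so a type checker
# rejects impossible states like request="stop" with state="saving".
current: SaveRunRequest = RunningOrStopping(
    request="stop", state="stopping", active_execution_id="exec-123"
)
```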
@@ -128,13 +133,14 @@ export default function useAgentGraph(
           console.error("Error saving agent");
         } else if (saveRunRequest.request === "run") {
           console.error(`Error saving&running agent`);
+        } else if (saveRunRequest.request === "stop") {
+          console.error(`Error stopping agent`);
         }
         // Reset request
-        setSaveRunRequest((prev) => ({
-          ...prev,
+        setSaveRunRequest({
           request: "none",
           state: "none",
-        }));
+        });
       });
       return;
     }
     // When saving request is done
@@ -145,11 +151,10 @@ export default function useAgentGraph(
     ) {
       // Reset request if only save was requested
       if (saveRunRequest.request === "save") {
-        setSaveRunRequest((prev) => ({
-          ...prev,
+        setSaveRunRequest({
           request: "none",
           state: "none",
-        }));
+        });
         // If run was requested, run the agent
       } else if (saveRunRequest.request === "run") {
         if (!validateNodes()) {
@@ -161,16 +166,64 @@ export default function useAgentGraph(
           return;
         }
         api.subscribeToExecution(savedAgent.id);
-        api.executeGraph(savedAgent.id);
-        processedUpdates.current = processedUpdates.current = [];
+        setSaveRunRequest({ request: "run", state: "running" });
+        api
+          .executeGraph(savedAgent.id)
+          .then((graphExecution) => {
+            setSaveRunRequest({
+              request: "run",
+              state: "running",
+              activeExecutionID: graphExecution.id,
+            });
 
-        setSaveRunRequest((prev) => ({
-          ...prev,
-          request: "run",
-          state: "running",
-        }));
+            // Track execution until completed
+            const pendingNodeExecutions: Set<string> = new Set();
+            const cancelExecListener = api.onWebSocketMessage(
+              "execution_event",
+              (nodeResult) => {
+                // We are racing the server here, since we need the ID to filter events
+                if (nodeResult.graph_exec_id != graphExecution.id) {
+                  return;
+                }
+                if (
+                  nodeResult.status != "COMPLETED" &&
+                  nodeResult.status != "FAILED"
+                ) {
+                  pendingNodeExecutions.add(nodeResult.node_exec_id);
+                } else {
+                  pendingNodeExecutions.delete(nodeResult.node_exec_id);
+                }
+                if (pendingNodeExecutions.size == 0) {
+                  // Assuming the first event is always a QUEUED node, and
+                  // following nodes are QUEUED before all preceding nodes are COMPLETED,
+                  // an empty set means the graph has finished running.
+                  cancelExecListener();
+                  setSaveRunRequest({ request: "none", state: "none" });
+                }
+              },
+            );
+          })
+          .catch(() => setSaveRunRequest({ request: "run", state: "error" }));
+
+        processedUpdates.current = processedUpdates.current = [];
       }
     }
+    // Handle stop request
+    if (
+      saveRunRequest.request === "stop" &&
+      saveRunRequest.state != "stopping" &&
+      savedAgent &&
+      saveRunRequest.activeExecutionID
+    ) {
+      setSaveRunRequest({
+        request: "stop",
+        state: "stopping",
+        activeExecutionID: saveRunRequest.activeExecutionID,
+      });
+      api
+        .stopGraphExecution(savedAgent.id, saveRunRequest.activeExecutionID)
+        .then(() => setSaveRunRequest({ request: "none", state: "none" }));
+    }
   }, [saveRunRequest, savedAgent, nodesSyncedWithSavedAgent]);
 
   // Check if node ids are synced with saved agent
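
The completion heuristic above (an empty `pendingNodeExecutions` set means the run is done) is easy to exercise in isolation. A small Python sketch of the same bookkeeping, with hypothetical event shapes:

```python
# A node-execution event either marks a node as pending (e.g. QUEUED/RUNNING)
# or settles it (COMPLETED/FAILED). When the pending set empties, the graph
# run is considered finished.
TERMINAL = {"COMPLETED", "FAILED"}


def run_finished(events: list[dict]) -> bool:
    pending: set[str] = set()
    for event in events:
        if event["status"] not in TERMINAL:
            pending.add(event["node_exec_id"])
        else:
            pending.discard(event["node_exec_id"])
        if not pending:
            return True  # assumes the first event is always a QUEUED node
    return False


assert run_finished([
    {"node_exec_id": "a", "status": "QUEUED"},
    {"node_exec_id": "b", "status": "QUEUED"},  # b queued before a completes
    {"node_exec_id": "a", "status": "COMPLETED"},
    {"node_exec_id": "b", "status": "COMPLETED"},
])
```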
@@ -657,7 +710,7 @@ export default function useAgentGraph(
     [saveAgent],
   );
 
-  const requestSaveRun = useCallback(() => {
+  const requestSaveAndRun = useCallback(() => {
     saveAgent();
     setSaveRunRequest({
       request: "run",
@@ -665,6 +718,23 @@ export default function useAgentGraph(
     });
   }, [saveAgent]);
 
+  const requestStopRun = useCallback(() => {
+    if (saveRunRequest.state != "running") {
+      return;
+    }
+    if (!saveRunRequest.activeExecutionID) {
+      console.warn(
+        "Stop requested but execution ID is unknown; state:",
+        saveRunRequest,
+      );
+    }
+    setSaveRunRequest((prev) => ({
+      ...prev,
+      request: "stop",
+      state: "running",
+    }));
+  }, [saveRunRequest]);
+
   return {
     agentName,
     setAgentName,
@@ -674,7 +744,11 @@ export default function useAgentGraph(
     availableNodes,
     getOutputType,
     requestSave,
-    requestSaveRun,
+    requestSaveAndRun,
+    requestStopRun,
+    isSaving: saveRunRequest.state == "saving",
+    isRunning: saveRunRequest.state == "running",
+    isStopping: saveRunRequest.state == "stopping",
     nodes,
     setNodes,
     edges,
@@ -15,7 +15,7 @@ export default class AutoGPTServerAPI {
   private wsUrl: string;
   private webSocket: WebSocket | null = null;
   private wsConnecting: Promise<void> | null = null;
-  private wsMessageHandlers: { [key: string]: (data: any) => void } = {};
+  private wsMessageHandlers: Record<string, Set<(data: any) => void>> = {};
   private supabaseClient = createClient();
 
   constructor(
@@ -128,16 +128,19 @@ export default class AutoGPTServerAPI {
     runID: string,
   ): Promise<NodeExecutionResult[]> {
     return (await this._get(`/graphs/${graphID}/executions/${runID}`)).map(
-      (result: any) => ({
-        ...result,
-        add_time: new Date(result.add_time),
-        queue_time: result.queue_time ? new Date(result.queue_time) : undefined,
-        start_time: result.start_time ? new Date(result.start_time) : undefined,
-        end_time: result.end_time ? new Date(result.end_time) : undefined,
-      }),
+      parseNodeExecutionResultTimestamps,
     );
   }
 
+  async stopGraphExecution(
+    graphID: string,
+    runID: string,
+  ): Promise<NodeExecutionResult[]> {
+    return (
+      await this._request("POST", `/graphs/${graphID}/executions/${runID}/stop`)
+    ).map(parseNodeExecutionResultTimestamps);
+  }
+
   private async _get(path: string) {
     return this._request("GET", path);
   }
@@ -207,10 +210,13 @@ export default class AutoGPTServerAPI {
       };
 
       this.webSocket.onmessage = (event) => {
-        const message = JSON.parse(event.data);
-        if (this.wsMessageHandlers[message.method]) {
-          this.wsMessageHandlers[message.method](message.data);
+        const message: WebsocketMessage = JSON.parse(event.data);
+        if (message.method == "execution_event") {
+          message.data = parseNodeExecutionResultTimestamps(message.data);
         }
+        this.wsMessageHandlers[message.method]?.forEach((handler) =>
+          handler(message.data),
+        );
       };
     } catch (error) {
       console.error("Error connecting to WebSocket:", error);
@@ -250,8 +256,12 @@ export default class AutoGPTServerAPI {
   onWebSocketMessage<M extends keyof WebsocketMessageTypeMap>(
     method: M,
     handler: (data: WebsocketMessageTypeMap[M]) => void,
-  ) {
-    this.wsMessageHandlers[method] = handler;
+  ): () => void {
+    this.wsMessageHandlers[method] ??= new Set();
+    this.wsMessageHandlers[method].add(handler);
+
+    // Return detacher
+    return () => this.wsMessageHandlers[method].delete(handler);
   }
 
   subscribeToExecution(graphId: string) {
@@ -274,3 +284,22 @@ type WebsocketMessageTypeMap = {
   subscribe: { graph_id: string };
   execution_event: NodeExecutionResult;
 };
+
+type WebsocketMessage = {
+  [M in keyof WebsocketMessageTypeMap]: {
+    method: M;
+    data: WebsocketMessageTypeMap[M];
+  };
+}[keyof WebsocketMessageTypeMap];
+
+/* *** HELPER FUNCTIONS *** */
+
+function parseNodeExecutionResultTimestamps(result: any): NodeExecutionResult {
+  return {
+    ...result,
+    add_time: new Date(result.add_time),
+    queue_time: result.queue_time ? new Date(result.queue_time) : undefined,
+    start_time: result.start_time ? new Date(result.start_time) : undefined,
+    end_time: result.end_time ? new Date(result.end_time) : undefined,
+  };
+}
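
The `Record<string, Set<handler>>` change above is a classic observer registry with a detach function. The same shape in Python, for comparison (illustrative only, not part of the codebase):

```python
from collections import defaultdict
from typing import Any, Callable

# Per-method sets of handlers, like wsMessageHandlers after this commit.
handlers: dict[str, set[Callable[[Any], None]]] = defaultdict(set)


def on_message(method: str, handler: Callable[[Any], None]) -> Callable[[], None]:
    """Register a handler and return a detacher, like onWebSocketMessage()."""
    handlers[method].add(handler)
    return lambda: handlers[method].discard(handler)


def dispatch(method: str, data: Any) -> None:
    # Equivalent of the ?.forEach(...) dispatch in the onmessage handler.
    for handler in handlers.get(method, ()):  # no-op if nothing is registered
        handler(data)


detach = on_message("execution_event", lambda data: print("got:", data))
dispatch("execution_event", {"status": "QUEUED"})  # got: {'status': 'QUEUED'}
detach()
dispatch("execution_event", {"status": "RUNNING"})  # nothing printed
```

Returning the detacher is what lets the run-tracking code above unsubscribe itself (`cancelExecListener()`) once the graph finishes.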
@@ -300,6 +300,26 @@ async def update_execution_status(
     return ExecutionResult.from_db(res)
 
 
+async def get_graph_execution(
+    graph_exec_id: str, user_id: str
+) -> AgentGraphExecution | None:
+    """
+    Retrieve a specific graph execution by its ID.
+
+    Args:
+        graph_exec_id (str): The ID of the graph execution to retrieve.
+        user_id (str): The ID of the user to whom the graph (execution) belongs.
+
+    Returns:
+        AgentGraphExecution | None: The graph execution if found, None otherwise.
+    """
+    execution = await AgentGraphExecution.prisma().find_first(
+        where={"id": graph_exec_id, "userId": user_id},
+        include=GRAPH_EXECUTION_INCLUDE,
+    )
+    return execution
+
+
 async def list_executions(graph_id: str, graph_version: int | None = None) -> list[str]:
     where: AgentGraphExecutionWhereInput = {"agentGraphId": graph_id}
     if graph_version is not None:
@@ -1,7 +1,10 @@
 import asyncio
 import logging
-from concurrent.futures import Future, ProcessPoolExecutor, TimeoutError
+import multiprocessing
+import threading
+from concurrent.futures import Future, ProcessPoolExecutor
 from contextlib import contextmanager
+from multiprocessing.pool import AsyncResult, Pool
 from typing import TYPE_CHECKING, Any, Coroutine, Generator, TypeVar
 
 if TYPE_CHECKING:
@@ -16,6 +19,7 @@ from autogpt_server.data.execution import (
     GraphExecution,
     NodeExecution,
     create_graph_execution,
+    get_execution_results,
     get_incomplete_executions,
     get_latest_execution,
     merge_execution_input,
@@ -461,15 +465,19 @@ class Executor:
         cls.loop = asyncio.new_event_loop()
         cls.loop.run_until_complete(db.connect())
         cls.pool_size = Config().num_node_workers
-        cls.executor = ProcessPoolExecutor(
-            max_workers=cls.pool_size,
+        cls._init_node_executor_pool()
+        logger.info(f"Graph executor started with max-{cls.pool_size} node workers.")
+
+    @classmethod
+    def _init_node_executor_pool(cls):
+        cls.executor = Pool(
+            processes=cls.pool_size,
             initializer=cls.on_node_executor_start,
         )
-        cls.logger.info(f"Graph executor started with max-{cls.pool_size} node workers")
 
     @classmethod
     @error_logged
-    def on_graph_execution(cls, data: GraphExecution):
+    def on_graph_execution(cls, data: GraphExecution, cancel: threading.Event):
         log_metadata = get_log_metadata(
             graph_eid=data.graph_exec_id,
             graph_id=data.graph_id,
@@ -477,7 +485,7 @@ class Executor:
             node_eid="*",
             block_name="-",
         )
-        timing_info, node_count = cls._on_graph_execution(data, log_metadata)
+        timing_info, node_count = cls._on_graph_execution(data, cancel, log_metadata)
         metric_graph_timing("walltime", timing_info.wall_time, tags=log_metadata)
         metric_graph_timing("cputime", timing_info.cpu_time, tags=log_metadata)
         metric_graph_count("nodecount", node_count, tags=log_metadata)
@@ -495,7 +503,9 @@ class Executor:
 
     @classmethod
     @time_measured
-    def _on_graph_execution(cls, graph_data: GraphExecution, log_metadata: dict) -> int:
+    def _on_graph_execution(
+        cls, graph_data: GraphExecution, cancel: threading.Event, log_metadata: dict
+    ) -> int:
         cls.logger.info(
             "Start graph execution",
             extra={
@@ -504,38 +514,85 @@ class Executor:
                 }
             },
         )
-        node_executed = 0
+        n_node_executions = 0
+        finished = False
+
+        def cancel_handler():
+            while not cancel.is_set():
+                cancel.wait(1)
+            if finished:
+                return
+            cls.executor.terminate()
+            logger.info(
+                f"Terminated graph execution {graph_data.graph_exec_id}",
+                extra={"json_fields": {**log_metadata}},
+            )
+            cls._init_node_executor_pool()
+
+        cancel_thread = threading.Thread(target=cancel_handler)
+        cancel_thread.start()
+
         try:
             queue = ExecutionQueue[NodeExecution]()
             for node_exec in graph_data.start_node_execs:
                 queue.add(node_exec)
 
-            futures: dict[str, Future] = {}
+            running_executions: dict[str, AsyncResult] = {}
+
+            def make_exec_callback(exec_data: NodeExecution):
+                node_id = exec_data.node_id
+
+                def callback(_):
+                    running_executions.pop(node_id)
+                    nonlocal n_node_executions
+                    n_node_executions += 1
+
+                return callback
 
             while not queue.empty():
-                execution = queue.get()
+                if cancel.is_set():
+                    return n_node_executions
+
+                exec_data = queue.get()
 
                 # Avoid parallel execution of the same node.
-                fut = futures.get(execution.node_id)
-                if fut and not fut.done():
+                execution = running_executions.get(exec_data.node_id)
+                if execution and not execution.ready():
                     # TODO (performance improvement):
                     # Wait for the completion of the same node execution is blocking.
                     # To improve this we need a separate queue for each node.
                     # Re-enqueueing the data back to the queue will disrupt the order.
-                    cls.wait_future(fut, timeout=None)
+                    execution.wait()
 
-                futures[execution.node_id] = cls.executor.submit(
-                    cls.on_node_execution, queue, execution
+                logger.debug(f"Dispatching execution of node {exec_data.node_id}")
+                running_executions[exec_data.node_id] = cls.executor.apply_async(
+                    cls.on_node_execution,
+                    (queue, exec_data),
+                    callback=make_exec_callback(exec_data),
                 )
 
                 # Avoid terminating graph execution when some nodes are still running.
-                while queue.empty() and futures:
-                    for node_id, future in list(futures.items()):
-                        if future.done():
-                            node_executed += 1
-                            del futures[node_id]
-                        elif queue.empty():
-                            cls.wait_future(future)
+                while queue.empty() and running_executions:
+                    logger.debug(
+                        "Queue empty; running nodes: "
+                        f"{list(running_executions.keys())}"
+                    )
+                    for node_id, execution in list(running_executions.items()):
+                        if cancel.is_set():
+                            return n_node_executions
+
+                        if not queue.empty():
+                            logger.debug(
+                                "Queue no longer empty! Returning to dispatching loop."
+                            )
+                            break  # yield to parent loop to execute new queue items
+
+                        logger.debug(f"Waiting on execution of node {node_id}")
+                        execution.wait(3)
+                        logger.debug(
+                            f"State of execution of node {node_id} after waiting: "
+                            f"{'DONE' if execution.ready() else 'RUNNING'}"
+                        )
 
             cls.logger.info(
                 "Finished graph execution",
@@ -546,7 +603,7 @@ class Executor:
                 },
             )
         except Exception as e:
-            cls.logger.exception(
+            logger.exception(
                 f"Failed graph execution: {e}",
                 extra={
                     "json_fields": {
@@ -554,24 +611,20 @@ class Executor:
                     }
                 },
             )
-
-        return node_executed
-
-    @classmethod
-    def wait_future(cls, future: Future, timeout: int | None = 3):
-        try:
-            if not future.done():
-                future.result(timeout=timeout)
-        except TimeoutError:
-            # Avoid being blocked by long-running node, by not waiting its completion.
-            pass
+        finally:
+            if not cancel.is_set():
+                finished = True
+                cancel.set()
+            cancel_thread.join()
+            return n_node_executions
 
 
 class ExecutionManager(AppService):
     def __init__(self):
-        self.use_redis = False
         self.pool_size = Config().num_graph_workers
         self.queue = ExecutionQueue[GraphExecution]()
+        self.use_redis = False
+        self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
+
+        # def __del__(self):
+        #     self.sync_manager.shutdown()
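
The switch from `executor.submit()`/`Future` to `apply_async()`/`AsyncResult` in `_on_graph_execution` above is mostly mechanical, but the APIs differ: `AsyncResult` has no `done()`, only `ready()` and a blocking `wait(timeout)`. A compact, standalone sketch of that pattern (illustrative names, not project code):

```python
import time
from multiprocessing.pool import AsyncResult, Pool


def square(n: int) -> int:
    time.sleep(0.1)  # stand-in for a node execution
    return n * n


if __name__ == "__main__":
    done_count = 0

    def count_done(_result) -> None:
        # Runs in a pool-internal thread when a task finishes,
        # like make_exec_callback() in the diff above.
        global done_count
        done_count += 1

    with Pool(processes=2) as pool:
        running: dict[int, AsyncResult] = {}
        for n in range(4):
            running[n] = pool.apply_async(square, (n,), callback=count_done)

        for n, result in running.items():
            if not result.ready():
                result.wait(3)  # wait up to 3s, like execution.wait(3) above
            print(n, "->", result.get())

    print("callbacks fired:", done_count)
```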
@@ -581,11 +634,21 @@ class ExecutionManager(AppService):
             max_workers=self.pool_size,
             initializer=Executor.on_graph_executor_start,
         ) as executor:
+            sync_manager = multiprocessing.Manager()
             logger.info(
                 f"Execution manager started with max-{self.pool_size} graph workers."
             )
             while True:
-                executor.submit(Executor.on_graph_execution, self.queue.get())
+                graph_exec_data = self.queue.get()
+                graph_exec_id = graph_exec_data.graph_exec_id
+                cancel_event = sync_manager.Event()
+                future = executor.submit(
+                    Executor.on_graph_execution, graph_exec_data, cancel_event
+                )
+                self.active_graph_runs[graph_exec_id] = (future, cancel_event)
+                future.add_done_callback(
+                    lambda _: self.active_graph_runs.pop(graph_exec_id)
+                )
 
     @property
     def agent_server_client(self) -> "AgentServer":
@@ -594,7 +657,7 @@ class ExecutionManager(AppService):
     @expose
     def add_execution(
         self, graph_id: str, data: BlockInput, user_id: str
-    ) -> dict[Any, Any]:
+    ) -> dict[str, Any]:
        graph: Graph | None = self.run_and_wait(get_graph(graph_id, user_id=user_id))
        if not graph:
            raise Exception(f"Graph #{graph_id} not found.")
@@ -647,4 +710,45 @@ class ExecutionManager(AppService):
         )
         self.queue.add(graph_exec)
 
-        return {"id": graph_exec_id}
+        return graph_exec.model_dump()
+
+    @expose
+    def cancel_execution(self, graph_exec_id: str) -> None:
+        """
+        Mechanism:
+        1. Set the cancel event
+        2. Graph executor's cancel handler thread detects the event, terminates workers,
+           reinitializes worker pool, and returns.
+        3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
+        """
+        if graph_exec_id not in self.active_graph_runs:
+            raise Exception(
+                f"Graph execution #{graph_exec_id} not active/running: "
+                "possibly already completed/cancelled."
+            )
+
+        future, cancel_event = self.active_graph_runs[graph_exec_id]
+        if cancel_event.is_set():
+            return
+
+        cancel_event.set()
+        future.result()
+
+        # Update the status of the unfinished node executions
+        node_execs = self.run_and_wait(get_execution_results(graph_exec_id))
+        for node_exec in node_execs:
+            if node_exec.status not in (
+                ExecutionStatus.COMPLETED,
+                ExecutionStatus.FAILED,
+            ):
+                self.run_and_wait(
+                    upsert_execution_output(
+                        node_exec.node_exec_id, "error", "TERMINATED"
+                    )
+                )
+                exec_update = self.run_and_wait(
+                    update_execution_status(
+                        node_exec.node_exec_id, ExecutionStatus.FAILED
+                    )
+                )
+                self.agent_server_client.send_execution_update(exec_update.model_dump())
@@ -11,14 +11,10 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
 
 from autogpt_server.data import block, db
+from autogpt_server.data import execution as execution_db
 from autogpt_server.data import graph as graph_db
+from autogpt_server.data import user as user_db
 from autogpt_server.data.block import BlockInput, CompletedBlockOutput
-from autogpt_server.data.execution import (
-    ExecutionResult,
-    get_execution_results,
-    list_executions,
-)
 from autogpt_server.data.queue import AsyncEventQueue, AsyncRedisEventQueue
-from autogpt_server.data.user import get_or_create_user
 from autogpt_server.executor import ExecutionManager, ExecutionScheduler
@@ -170,10 +166,15 @@ class AgentServer(AppService):
             methods=["GET"],
         )
         router.add_api_route(
-            path="/graphs/{graph_id}/executions/{run_id}",
-            endpoint=self.get_run_execution_results,
+            path="/graphs/{graph_id}/executions/{graph_exec_id}",
+            endpoint=self.get_graph_run_node_execution_results,
             methods=["GET"],
         )
+        router.add_api_route(
+            path="/graphs/{graph_id}/executions/{graph_exec_id}/stop",
+            endpoint=self.stop_graph_run,
+            methods=["POST"],
+        )
         router.add_api_route(
             path="/graphs/{graph_id}/schedules",
             endpoint=self.create_schedule,
@@ -423,15 +424,29 @@ class AgentServer(AppService):
         graph_id: str,
         node_input: dict[Any, Any],
         user_id: Annotated[str, Depends(get_user_id)],
-    ) -> dict[Any, Any]:
+    ) -> dict[str, Any]:  # FIXME: add proper return type
         try:
-            return self.execution_manager_client.add_execution(
+            graph_exec = self.execution_manager_client.add_execution(
                 graph_id, node_input, user_id=user_id
             )
+            return {"id": graph_exec["graph_exec_id"]}
         except Exception as e:
             msg = e.__str__().encode().decode("unicode_escape")
             raise HTTPException(status_code=400, detail=msg)
 
+    async def stop_graph_run(
+        self, graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
+    ) -> list[execution_db.ExecutionResult]:
+        if not await execution_db.get_graph_execution(graph_exec_id, user_id):
+            raise HTTPException(
+                404, detail=f"Agent execution #{graph_exec_id} not found"
+            )
+
+        self.execution_manager_client.cancel_execution(graph_exec_id)
+
+        # Retrieve & return canceled graph execution in its final state
+        return await execution_db.get_execution_results(graph_exec_id)
+
     @classmethod
     async def get_graph_input_schema(
         cls,
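
Client-side, stopping a run is a single POST to the new route. A hypothetical smoke test with `requests` (the base URL, `/api` prefix, and auth header are placeholders, not confirmed by this diff):

```python
import requests

# Placeholders: adjust to your deployment and auth setup.
BASE_URL = "http://localhost:8000/api"
HEADERS = {"Authorization": "Bearer <token>"}

graph_id = "some-graph-id"
graph_exec_id = "some-graph-exec-id"

# POST /graphs/{graph_id}/executions/{graph_exec_id}/stop
resp = requests.post(
    f"{BASE_URL}/graphs/{graph_id}/executions/{graph_exec_id}/stop",
    headers=HEADERS,
)
resp.raise_for_status()

# The endpoint returns the node execution results in their final state;
# per cancel_execution(), cancelled nodes carry an `error` output of "TERMINATED".
for node_exec in resp.json():
    print(node_exec["node_exec_id"], node_exec["status"])
```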
@@ -458,17 +473,20 @@ class AgentServer(AppService):
                 status_code=404, detail=f"Agent #{graph_id}{rev} not found."
             )
 
-        return await list_executions(graph_id, graph_version)
+        return await execution_db.list_executions(graph_id, graph_version)
 
     @classmethod
-    async def get_run_execution_results(
-        cls, graph_id: str, run_id: str, user_id: Annotated[str, Depends(get_user_id)]
-    ) -> list[ExecutionResult]:
+    async def get_graph_run_node_execution_results(
+        cls,
+        graph_id: str,
+        graph_exec_id: str,
+        user_id: Annotated[str, Depends(get_user_id)],
+    ) -> list[execution_db.ExecutionResult]:
         graph = await graph_db.get_graph(graph_id, user_id=user_id)
         if not graph:
             raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
 
-        return await get_execution_results(run_id)
+        return await execution_db.get_execution_results(graph_exec_id)
 
     async def create_schedule(
         self,
@@ -506,7 +524,7 @@ class AgentServer(AppService):
 
     @expose
     def send_execution_update(self, execution_result_dict: dict[Any, Any]):
-        execution_result = ExecutionResult(**execution_result_dict)
+        execution_result = execution_db.ExecutionResult(**execution_result_dict)
         self.run_and_wait(self.event_queue.put(execution_result))
 
     @expose
@@ -59,7 +59,6 @@ class SpinTestServer:
         return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
 
     async def __aenter__(self):
-
         self.name_server.__enter__()
         self.setup_dependency_overrides()
         self.agent_server.__enter__()
@@ -95,7 +94,7 @@ async def wait_execution(
     timeout: int = 20,
 ) -> list:
     async def is_execution_completed():
-        execs = await AgentServer().get_run_execution_results(
+        execs = await AgentServer().get_graph_run_node_execution_results(
             graph_id, graph_exec_id, user_id
         )
         return (
@@ -110,7 +109,7 @@ async def wait_execution(
     # Wait for the executions to complete
     for i in range(timeout):
         if await is_execution_completed():
-            return await AgentServer().get_run_execution_results(
+            return await AgentServer().get_graph_run_node_execution_results(
                 graph_id, graph_exec_id, user_id
             )
         time.sleep(1)
@@ -1,4 +1,4 @@
 {
     "num_graph_workers": 10,
-    "num_node_workers": 10
+    "num_node_workers": 5
 }
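
The reduction to 5 node workers follows directly from `multiprocessing.Pool` spawning all of its workers eagerly on init, while `ProcessPoolExecutor` starts them as tasks arrive. This is easy to observe (a quick demonstration; exact child counts vary by Python version and platform):

```python
import multiprocessing
import time
from concurrent.futures import ProcessPoolExecutor
from multiprocessing.pool import Pool

if __name__ == "__main__":
    # multiprocessing.Pool forks/spawns all workers immediately on init...
    pool = Pool(processes=10)
    time.sleep(0.5)
    print("Pool children:", len(multiprocessing.active_children()))  # ~10
    pool.terminate()
    pool.join()

    # ...while ProcessPoolExecutor starts workers lazily, on demand.
    executor = ProcessPoolExecutor(max_workers=10)
    time.sleep(0.5)
    print("Executor children:", len(multiprocessing.active_children()))  # 0
    executor.submit(print, "hello").result()
    print("after one task:", len(multiprocessing.active_children()))  # >= 1
    executor.shutdown()
```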
@@ -35,7 +35,7 @@ async def assert_sample_graph_executions(
     test_user: User,
     graph_exec_id: str,
 ):
-    executions = await agent_server.get_run_execution_results(
+    executions = await agent_server.get_graph_run_node_execution_results(
         test_graph.id, graph_exec_id, test_user.id
     )
@@ -156,7 +156,7 @@ async def test_input_pin_always_waited(server: SpinTestServer):
         server.agent_server, server.exec_manager, test_graph, test_user, {}, 3
     )
 
-    executions = await server.agent_server.get_run_execution_results(
+    executions = await server.agent_server.get_graph_run_node_execution_results(
         test_graph.id, graph_exec_id, test_user.id
     )
     assert len(executions) == 3
@@ -236,7 +236,7 @@ async def test_static_input_link_on_graph(server: SpinTestServer):
     graph_exec_id = await execute_graph(
         server.agent_server, server.exec_manager, test_graph, test_user, {}, 8
     )
-    executions = await server.agent_server.get_run_execution_results(
+    executions = await server.agent_server.get_graph_run_node_execution_results(
         test_graph.id, graph_exec_id, test_user.id
     )
     assert len(executions) == 8
@@ -82,6 +82,6 @@ env:
   APP_ENV: "dev"
   PYRO_HOST: "0.0.0.0"
   NUM_GRAPH_WORKERS: 100
-  NUM_NODE_WORKERS: 100
+  NUM_NODE_WORKERS: 5
   REDIS_HOST: "redis-dev-master.redis-dev.svc.cluster.local"
   REDIS_PORT: "6379"