feat(server, builder): Implement "STOP" button for graph runs (#7892)

- feat(builder): Add "Stop Run" buttons to monitor and builder
  - Implement additional state management in `useAgentGraph` hook
    - Add "stop" request mechanism
    - Implement execution status tracking using WebSockets
    - Add `isSaving`, `isRunning`, `isStopping` outputs
    - Add `requestStopRun` method
      - Rename `requestSaveRun` to `requestSaveAndRun` for clarity
  - Add needed functionality for the above to `AutoGPTServerAPI` client
    - Add `stopGraphExecution` method
    - Add support for multiple handlers per WebSocket method
  - Fix parsing of timestamps in `execution_event` WebSocket messages (see the protocol sketch after this list)
  - Add `IconSquare` from Lucide to `@/components/ui/icons`
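
A hedged sketch of the WebSocket protocol this relies on, in Python for illustration only (the real client is the TypeScript `AutoGPTServerAPI` class; the URL is an assumption, while the message envelope, `subscribe` payload, and timestamp fields are taken from the diff below):

```python
import asyncio
import json
from datetime import datetime

import websockets  # third-party client library, used here for illustration

WS_URL = "ws://localhost:8001/ws"  # assumption: local dev WebSocket server

async def watch_execution(graph_id: str) -> None:
    async with websockets.connect(WS_URL) as ws:
        # Subscribe to execution events for one graph
        await ws.send(json.dumps({"method": "subscribe", "data": {"graph_id": graph_id}}))
        async for raw in ws:
            msg = json.loads(raw)
            if msg["method"] != "execution_event":
                continue
            result = msg["data"]
            # Timestamps arrive as strings; parse them like
            # parseNodeExecutionResultTimestamps does on the frontend
            # (assumes ISO-8601 strings; Python 3.11+ parses them leniently)
            add_time = datetime.fromisoformat(result["add_time"])
            print(result["node_exec_id"], result["status"], add_time)

asyncio.run(watch_execution("my-graph-id"))  # hypothetical graph ID
```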

- feat(server): Add `POST /graphs/{graph_id}/executions/{graph_exec_id}/stop` route
  - Add `stop_graph_run` method to `AgentServer`
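
A minimal usage sketch of the new route (base URL and port are assumptions; per `stop_graph_run` below, the response is the stopped run's node execution results in their final state, with unfinished nodes marked `FAILED` and an `error` output of `"TERMINATED"`):

```python
import httpx

API_BASE = "http://localhost:8000/api"  # assumption: local dev server

def stop_run(graph_id: str, graph_exec_id: str) -> list[dict]:
    resp = httpx.post(f"{API_BASE}/graphs/{graph_id}/executions/{graph_exec_id}/stop")
    resp.raise_for_status()  # 404 if the execution doesn't exist or isn't yours
    return resp.json()
```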

- feat(server): Add `cancel_execution` method to `ExecutionManager`
  - Replace the node executor's `ProcessPoolExecutor` with `multiprocessing.Pool`, which has a `terminate()` method (see the sketch below)
    - Remove the now-unnecessary `Executor.wait_future(..)` method
  - Add `get_graph_execution(..)` in `.data.execution`
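
A minimal, self-contained sketch of the cancellation mechanism (hypothetical names; the real implementation lives in `ExecutionManager`/`Executor` and uses a `multiprocessing.Manager().Event()`, since the event has to cross the graph-executor process boundary):

```python
import multiprocessing
import threading
import time

def slow_node(n: int) -> int:
    time.sleep(30)  # stands in for a long-running node execution
    return n

def main() -> None:
    cancel = threading.Event()
    pool = multiprocessing.Pool(processes=2)

    def cancel_handler() -> None:
        # Watcher thread: once the event is set, kill all node workers at once
        while not cancel.is_set():
            cancel.wait(1)
        pool.terminate()  # the method ProcessPoolExecutor lacks
        print("graph execution terminated")

    threading.Thread(target=cancel_handler, daemon=True).start()

    for i in range(4):
        pool.apply_async(slow_node, (i,))

    time.sleep(2)
    cancel.set()  # e.g. triggered by the new /stop route
    time.sleep(1)

if __name__ == "__main__":
    main()
```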

- fix(server): Reduce number of node executors to 5 per graph executor
  This is necessary because `multiprocessing.Pool` spawns all of its workers on init, rather than on demand as `ProcessPoolExecutor` does
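
The difference is easy to demonstrate by peeking at the pools' private, CPython-specific internals (illustration only). With this change, each of the up-to-10 graph workers eagerly holds 5 node workers, i.e. up to 50 resident processes, instead of the 100 that the old setting of 10 would have spawned:

```python
import multiprocessing
from concurrent.futures import ProcessPoolExecutor

if __name__ == "__main__":
    pool = multiprocessing.Pool(processes=5)
    print(len(pool._pool))  # 5 -- Pool forks all workers immediately on init

    executor = ProcessPoolExecutor(max_workers=5)
    print(len(executor._processes))  # 0 -- workers only spawn as tasks are submitted
```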

- dx(server): Improve debug logging in `ExecutionManager`
- ci(server): Add debug logging mode to CI Pytest step

### Other improvements
Server:
- Improve output type of `ExecutionManager.add_execution(..)`
- Rename a few things in `.server.rest_api` for consistency

Front end:
- Improve typing in `AutoGPTServerAPI` client
Commit 8fd22bcfd7 (parent 11827835a0)
Author: Reinier van der Leer
Date: 2024-09-05 14:42:28 +02:00
Committed by: GitHub
13 changed files with 417 additions and 114 deletions

View File

@@ -128,9 +128,14 @@ jobs:
- name: Run pytest with coverage
run: |
- poetry run pytest -vv \
-   test
+ if [[ "${{ runner.debug }}" == "1" ]]; then
+   poetry run pytest -vv -o log_cli=true -o log_cli_level=DEBUG test
+ else
+   poetry run pytest -vv test
+ fi
if: success() || (failure() && steps.lint.outcome == 'failure')
+ env:
+   LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
env:
CI: true
PLAIN_OUTPUT: True

View File

@@ -34,7 +34,12 @@ import ConnectionLine from "./ConnectionLine";
import { Control, ControlPanel } from "@/components/edit/control/ControlPanel";
import { SaveControl } from "@/components/edit/control/SaveControl";
import { BlocksControl } from "@/components/edit/control/BlocksControl";
- import { IconPlay, IconRedo2, IconUndo2 } from "@/components/ui/icons";
+ import {
+ IconPlay,
+ IconRedo2,
+ IconSquare,
+ IconUndo2,
+ } from "@/components/ui/icons";
import { startTutorial } from "./tutorial";
import useAgentGraph from "@/hooks/useAgentGraph";
import { v4 as uuidv4 } from "uuid";
@@ -75,7 +80,9 @@ const FlowEditor: React.FC<{
availableNodes,
getOutputType,
requestSave,
- requestSaveRun,
+ requestSaveAndRun,
+ requestStopRun,
+ isRunning,
nodes,
setNodes,
edges,
@@ -542,9 +549,9 @@ const FlowEditor: React.FC<{
onClick: handleRedo,
},
{
label: "Run",
icon: <IconPlay />,
onClick: requestSaveRun,
label: !isRunning ? "Run" : "Stop",
icon: !isRunning ? <IconPlay /> : <IconSquare />,
onClick: !isRunning ? requestSaveAndRun : requestStopRun,
},
];

View File

@@ -1,9 +1,10 @@
- import React from "react";
- import { GraphMeta } from "@/lib/autogpt-server-api";
+ import React, { useCallback } from "react";
+ import AutoGPTServerAPI, { GraphMeta } from "@/lib/autogpt-server-api";
import { FlowRun } from "@/lib/types";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import Link from "next/link";
- import { buttonVariants } from "@/components/ui/button";
+ import { Button, buttonVariants } from "@/components/ui/button";
+ import { IconSquare } from "@/components/ui/icons";
import { Pencil2Icon } from "@radix-ui/react-icons";
import moment from "moment/moment";
import { FlowRunStatusBadge } from "@/components/monitor/FlowRunStatusBadge";
@@ -20,6 +21,11 @@ export const FlowRunInfo: React.FC<
);
}
const handleStopRun = useCallback(() => {
const api = new AutoGPTServerAPI();
api.stopGraphExecution(flow.id, flowRun.id);
}, [flow.id, flowRun.id]);
return (
<Card {...props}>
<CardHeader className="flex-row items-center justify-between space-x-3 space-y-0">
@@ -34,12 +40,19 @@ export const FlowRunInfo: React.FC<
Run ID: <code>{flowRun.id}</code>
</p>
</div>
- <Link
- className={buttonVariants({ variant: "outline" })}
- href={`/build?flowID=${flow.id}`}
- >
- <Pencil2Icon className="mr-2" /> Edit Agent
- </Link>
+ <div className="flex space-x-2">
+ {flowRun.status === "running" && (
+ <Button onClick={handleStopRun} variant="destructive">
+ <IconSquare className="mr-2" /> Stop Run
+ </Button>
+ )}
+ <Link
+ className={buttonVariants({ variant: "outline" })}
+ href={`/build?flowID=${flow.id}`}
+ >
+ <Pencil2Icon className="mr-2" /> Edit Agent
+ </Link>
+ </div>
</CardHeader>
<CardContent>
<p>

View File

@@ -405,6 +405,40 @@ export const IconPlay = createIcon((props) => (
</svg>
));
/**
* Square icon component.
*
* @component IconSquare
* @param {IconProps} props - The props object containing additional attributes and event handlers for the icon.
* @returns {JSX.Element} - The square icon.
*
* @example
* // Default usage
* <IconSquare />
*
* @example
* // With custom color and size (use sparingly, only when necessary)
* <IconSquare className="text-primary" size="lg" />
*
* @example
* // With custom size and onClick handler
* <IconSquare size="sm" onClick={handleOnClick} />
*/
export const IconSquare = createIcon((props) => (
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
{...props}
>
<rect width="18" height="18" x="3" y="3" rx="2" />
</svg>
));
/**
* Package2 icon component.
*

View File

@@ -31,22 +31,27 @@ export default function useAgentGraph(
const [updateQueue, setUpdateQueue] = useState<NodeExecutionResult[]>([]);
const processedUpdates = useRef<NodeExecutionResult[]>([]);
/**
- * User `request` to save or save&run the agent
+ * User `request` to save or save&run the agent, or to stop the active run.
* `state` is used to track the request status:
* - none: no request
* - saving: request was sent to save the agent
* and nodes are pending sync to update their backend ids
* - running: request was sent to run the agent
* and frontend is enqueueing execution results
+ * - stopping: a request to stop the active run has been sent; response is pending
* - error: request failed
- *
- * As of now, state will be stuck at 'running' (if run requested)
- * because there's no way to know when the execution is done
*/
- const [saveRunRequest, setSaveRunRequest] = useState<{
- request: "none" | "save" | "run";
- state: "none" | "saving" | "running" | "error";
- }>({
+ const [saveRunRequest, setSaveRunRequest] = useState<
+ | {
+ request: "none" | "save" | "run";
+ state: "none" | "saving" | "error";
+ }
+ | {
+ request: "run" | "stop";
+ state: "running" | "stopping" | "error";
+ activeExecutionID?: string;
+ }
+ >({
request: "none",
state: "none",
});
@@ -128,13 +133,14 @@ export default function useAgentGraph(
console.error("Error saving agent");
} else if (saveRunRequest.request === "run") {
console.error(`Error saving&running agent`);
} else if (saveRunRequest.request === "stop") {
console.error(`Error stopping agent`);
}
// Reset request
- setSaveRunRequest((prev) => ({
- ...prev,
+ setSaveRunRequest({
request: "none",
state: "none",
- }));
+ });
return;
}
// When saving request is done
@@ -145,11 +151,10 @@ export default function useAgentGraph(
) {
// Reset request if only save was requested
if (saveRunRequest.request === "save") {
- setSaveRunRequest((prev) => ({
- ...prev,
+ setSaveRunRequest({
request: "none",
state: "none",
- }));
+ });
// If run was requested, run the agent
} else if (saveRunRequest.request === "run") {
if (!validateNodes()) {
@@ -161,16 +166,64 @@ export default function useAgentGraph(
return;
}
api.subscribeToExecution(savedAgent.id);
- api.executeGraph(savedAgent.id);
- processedUpdates.current = processedUpdates.current = [];
- setSaveRunRequest((prev) => ({
- ...prev,
- request: "run",
- state: "running",
- }));
+ setSaveRunRequest({ request: "run", state: "running" });
+ api
+ .executeGraph(savedAgent.id)
+ .then((graphExecution) => {
+ setSaveRunRequest({
+ request: "run",
+ state: "running",
+ activeExecutionID: graphExecution.id,
+ });
+ // Track execution until completed
+ const pendingNodeExecutions: Set<string> = new Set();
+ const cancelExecListener = api.onWebSocketMessage(
+ "execution_event",
+ (nodeResult) => {
+ // We are racing the server here, since we need the ID to filter events
+ if (nodeResult.graph_exec_id != graphExecution.id) {
+ return;
+ }
+ if (
+ nodeResult.status != "COMPLETED" &&
+ nodeResult.status != "FAILED"
+ ) {
+ pendingNodeExecutions.add(nodeResult.node_exec_id);
+ } else {
+ pendingNodeExecutions.delete(nodeResult.node_exec_id);
+ }
+ if (pendingNodeExecutions.size == 0) {
+ // Assuming the first event is always a QUEUED node, and
+ // following nodes are QUEUED before all preceding nodes are COMPLETED,
+ // an empty set means the graph has finished running.
+ cancelExecListener();
+ setSaveRunRequest({ request: "none", state: "none" });
+ }
+ },
+ );
+ })
+ .catch(() => setSaveRunRequest({ request: "run", state: "error" }));
+ processedUpdates.current = processedUpdates.current = [];
}
}
// Handle stop request
if (
saveRunRequest.request === "stop" &&
saveRunRequest.state != "stopping" &&
savedAgent &&
saveRunRequest.activeExecutionID
) {
setSaveRunRequest({
request: "stop",
state: "stopping",
activeExecutionID: saveRunRequest.activeExecutionID,
});
api
.stopGraphExecution(savedAgent.id, saveRunRequest.activeExecutionID)
.then(() => setSaveRunRequest({ request: "none", state: "none" }));
}
}, [saveRunRequest, savedAgent, nodesSyncedWithSavedAgent]);
// Check if node ids are synced with saved agent
@@ -657,7 +710,7 @@ export default function useAgentGraph(
[saveAgent],
);
- const requestSaveRun = useCallback(() => {
+ const requestSaveAndRun = useCallback(() => {
saveAgent();
setSaveRunRequest({
request: "run",
@@ -665,6 +718,23 @@ export default function useAgentGraph(
});
}, [saveAgent]);
const requestStopRun = useCallback(() => {
if (saveRunRequest.state != "running") {
return;
}
if (!saveRunRequest.activeExecutionID) {
console.warn(
"Stop requested but execution ID is unknown; state:",
saveRunRequest,
);
}
setSaveRunRequest((prev) => ({
...prev,
request: "stop",
state: "running",
}));
}, [saveRunRequest]);
return {
agentName,
setAgentName,
@@ -674,7 +744,11 @@ export default function useAgentGraph(
availableNodes,
getOutputType,
requestSave,
- requestSaveRun,
+ requestSaveAndRun,
+ requestStopRun,
+ isSaving: saveRunRequest.state == "saving",
+ isRunning: saveRunRequest.state == "running",
+ isStopping: saveRunRequest.state == "stopping",
nodes,
setNodes,
edges,

View File

@@ -15,7 +15,7 @@ export default class AutoGPTServerAPI {
private wsUrl: string;
private webSocket: WebSocket | null = null;
private wsConnecting: Promise<void> | null = null;
- private wsMessageHandlers: { [key: string]: (data: any) => void } = {};
+ private wsMessageHandlers: Record<string, Set<(data: any) => void>> = {};
private supabaseClient = createClient();
constructor(
@@ -128,16 +128,19 @@ export default class AutoGPTServerAPI {
runID: string,
): Promise<NodeExecutionResult[]> {
return (await this._get(`/graphs/${graphID}/executions/${runID}`)).map(
- (result: any) => ({
- ...result,
- add_time: new Date(result.add_time),
- queue_time: result.queue_time ? new Date(result.queue_time) : undefined,
- start_time: result.start_time ? new Date(result.start_time) : undefined,
- end_time: result.end_time ? new Date(result.end_time) : undefined,
- }),
+ parseNodeExecutionResultTimestamps,
);
}
async stopGraphExecution(
graphID: string,
runID: string,
): Promise<NodeExecutionResult[]> {
return (
await this._request("POST", `/graphs/${graphID}/executions/${runID}/stop`)
).map(parseNodeExecutionResultTimestamps);
}
private async _get(path: string) {
return this._request("GET", path);
}
@@ -207,10 +210,13 @@ export default class AutoGPTServerAPI {
};
this.webSocket.onmessage = (event) => {
- const message = JSON.parse(event.data);
- if (this.wsMessageHandlers[message.method]) {
- this.wsMessageHandlers[message.method](message.data);
- }
+ const message: WebsocketMessage = JSON.parse(event.data);
+ if (message.method == "execution_event") {
+ message.data = parseNodeExecutionResultTimestamps(message.data);
+ }
+ this.wsMessageHandlers[message.method]?.forEach((handler) =>
+ handler(message.data),
+ );
};
} catch (error) {
console.error("Error connecting to WebSocket:", error);
@@ -250,8 +256,12 @@ export default class AutoGPTServerAPI {
onWebSocketMessage<M extends keyof WebsocketMessageTypeMap>(
method: M,
handler: (data: WebsocketMessageTypeMap[M]) => void,
- ) {
- this.wsMessageHandlers[method] = handler;
+ ): () => void {
+ this.wsMessageHandlers[method] ??= new Set();
+ this.wsMessageHandlers[method].add(handler);
+ // Return detacher
+ return () => this.wsMessageHandlers[method].delete(handler);
}
subscribeToExecution(graphId: string) {
@@ -274,3 +284,22 @@ type WebsocketMessageTypeMap = {
subscribe: { graph_id: string };
execution_event: NodeExecutionResult;
};
type WebsocketMessage = {
[M in keyof WebsocketMessageTypeMap]: {
method: M;
data: WebsocketMessageTypeMap[M];
};
}[keyof WebsocketMessageTypeMap];
/* *** HELPER FUNCTIONS *** */
function parseNodeExecutionResultTimestamps(result: any): NodeExecutionResult {
return {
...result,
add_time: new Date(result.add_time),
queue_time: result.queue_time ? new Date(result.queue_time) : undefined,
start_time: result.start_time ? new Date(result.start_time) : undefined,
end_time: result.end_time ? new Date(result.end_time) : undefined,
};
}

View File

@@ -300,6 +300,26 @@ async def update_execution_status(
return ExecutionResult.from_db(res)
async def get_graph_execution(
graph_exec_id: str, user_id: str
) -> AgentGraphExecution | None:
"""
Retrieve a specific graph execution by its ID.
Args:
graph_exec_id (str): The ID of the graph execution to retrieve.
user_id (str): The ID of the user to whom the graph (execution) belongs.
Returns:
AgentGraphExecution | None: The graph execution if found, None otherwise.
"""
execution = await AgentGraphExecution.prisma().find_first(
where={"id": graph_exec_id, "userId": user_id},
include=GRAPH_EXECUTION_INCLUDE,
)
return execution
async def list_executions(graph_id: str, graph_version: int | None = None) -> list[str]:
where: AgentGraphExecutionWhereInput = {"agentGraphId": graph_id}
if graph_version is not None:

View File

@@ -1,7 +1,10 @@
import asyncio
import logging
- from concurrent.futures import Future, ProcessPoolExecutor, TimeoutError
+ import multiprocessing
+ import threading
+ from concurrent.futures import Future, ProcessPoolExecutor
from contextlib import contextmanager
+ from multiprocessing.pool import AsyncResult, Pool
from typing import TYPE_CHECKING, Any, Coroutine, Generator, TypeVar
if TYPE_CHECKING:
@@ -16,6 +19,7 @@ from autogpt_server.data.execution import (
GraphExecution,
NodeExecution,
create_graph_execution,
get_execution_results,
get_incomplete_executions,
get_latest_execution,
merge_execution_input,
@@ -461,15 +465,19 @@ class Executor:
cls.loop = asyncio.new_event_loop()
cls.loop.run_until_complete(db.connect())
cls.pool_size = Config().num_node_workers
- cls.executor = ProcessPoolExecutor(
- max_workers=cls.pool_size,
- initializer=cls.on_node_executor_start,
- )
- cls.logger.info(f"Graph executor started with max-{cls.pool_size} node workers")
+ cls._init_node_executor_pool()
+ logger.info(f"Graph executor started with max-{cls.pool_size} node workers.")
+ @classmethod
+ def _init_node_executor_pool(cls):
+ cls.executor = Pool(
+ processes=cls.pool_size,
+ initializer=cls.on_node_executor_start,
+ )
@classmethod
@error_logged
- def on_graph_execution(cls, data: GraphExecution):
+ def on_graph_execution(cls, data: GraphExecution, cancel: threading.Event):
log_metadata = get_log_metadata(
graph_eid=data.graph_exec_id,
graph_id=data.graph_id,
@@ -477,7 +485,7 @@ class Executor:
node_eid="*",
block_name="-",
)
- timing_info, node_count = cls._on_graph_execution(data, log_metadata)
+ timing_info, node_count = cls._on_graph_execution(data, cancel, log_metadata)
metric_graph_timing("walltime", timing_info.wall_time, tags=log_metadata)
metric_graph_timing("cputime", timing_info.cpu_time, tags=log_metadata)
metric_graph_count("nodecount", node_count, tags=log_metadata)
@@ -495,7 +503,9 @@ class Executor:
@classmethod
@time_measured
- def _on_graph_execution(cls, graph_data: GraphExecution, log_metadata: dict) -> int:
+ def _on_graph_execution(
+ cls, graph_data: GraphExecution, cancel: threading.Event, log_metadata: dict
+ ) -> int:
cls.logger.info(
"Start graph execution",
extra={
@@ -504,38 +514,85 @@ class Executor:
}
},
)
- node_executed = 0
+ n_node_executions = 0
+ finished = False
+ def cancel_handler():
+ while not cancel.is_set():
+ cancel.wait(1)
+ if finished:
+ return
+ cls.executor.terminate()
+ logger.info(
+ f"Terminated graph execution {graph_data.graph_exec_id}",
+ extra={"json_fields": {**log_metadata}},
+ )
+ cls._init_node_executor_pool()
+ cancel_thread = threading.Thread(target=cancel_handler)
+ cancel_thread.start()
try:
queue = ExecutionQueue[NodeExecution]()
for node_exec in graph_data.start_node_execs:
queue.add(node_exec)
- futures: dict[str, Future] = {}
+ running_executions: dict[str, AsyncResult] = {}
+ def make_exec_callback(exec_data: NodeExecution):
+ node_id = exec_data.node_id
+ def callback(_):
+ running_executions.pop(node_id)
+ nonlocal n_node_executions
+ n_node_executions += 1
+ return callback
while not queue.empty():
- execution = queue.get()
+ if cancel.is_set():
+ return n_node_executions
+ exec_data = queue.get()
# Avoid parallel execution of the same node.
- fut = futures.get(execution.node_id)
- if fut and not fut.done():
+ execution = running_executions.get(exec_data.node_id)
+ if execution and not execution.ready():
# TODO (performance improvement):
# Waiting for the completion of the same node execution is blocking.
# To improve this we need a separate queue for each node.
# Re-enqueueing the data back to the queue will disrupt the order.
- cls.wait_future(fut, timeout=None)
+ execution.wait()
- futures[execution.node_id] = cls.executor.submit(
- cls.on_node_execution, queue, execution
- )
+ logger.debug(f"Dispatching execution of node {exec_data.node_id}")
+ running_executions[exec_data.node_id] = cls.executor.apply_async(
+ cls.on_node_execution,
+ (queue, exec_data),
+ callback=make_exec_callback(exec_data),
+ )
# Avoid terminating graph execution when some nodes are still running.
- while queue.empty() and futures:
- for node_id, future in list(futures.items()):
- if future.done():
- node_executed += 1
- del futures[node_id]
- elif queue.empty():
- cls.wait_future(future)
+ while queue.empty() and running_executions:
+ logger.debug(
+ "Queue empty; running nodes: "
+ f"{list(running_executions.keys())}"
+ )
+ for node_id, execution in list(running_executions.items()):
+ if cancel.is_set():
+ return n_node_executions
+ if not queue.empty():
+ logger.debug(
+ "Queue no longer empty! Returning to dispatching loop."
+ )
+ break  # yield to parent loop to execute new queue items
+ logger.debug(f"Waiting on execution of node {node_id}")
+ execution.wait(3)
+ logger.debug(
+ f"State of execution of node {node_id} after waiting: "
+ f"{'DONE' if execution.ready() else 'RUNNING'}"
+ )
cls.logger.info(
"Finished graph execution",
@@ -546,7 +603,7 @@ class Executor:
},
)
except Exception as e:
- cls.logger.exception(
+ logger.exception(
f"Failed graph execution: {e}",
extra={
"json_fields": {
@@ -554,24 +611,20 @@ class Executor:
}
},
)
- return node_executed
+ finally:
+ if not cancel.is_set():
+ finished = True
+ cancel.set()
+ cancel_thread.join()
+ return n_node_executions
- @classmethod
- def wait_future(cls, future: Future, timeout: int | None = 3):
- try:
- if not future.done():
- future.result(timeout=timeout)
- except TimeoutError:
- # Avoid being blocked by long-running node, by not waiting its completion.
- pass
class ExecutionManager(AppService):
def __init__(self):
- self.use_redis = False
self.pool_size = Config().num_graph_workers
self.queue = ExecutionQueue[GraphExecution]()
+ self.use_redis = False
+ self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
# def __del__(self):
# self.sync_manager.shutdown()
@@ -581,11 +634,21 @@ class ExecutionManager(AppService):
max_workers=self.pool_size,
initializer=Executor.on_graph_executor_start,
) as executor:
+ sync_manager = multiprocessing.Manager()
logger.info(
f"Execution manager started with max-{self.pool_size} graph workers."
)
while True:
- executor.submit(Executor.on_graph_execution, self.queue.get())
+ graph_exec_data = self.queue.get()
+ graph_exec_id = graph_exec_data.graph_exec_id
+ cancel_event = sync_manager.Event()
+ future = executor.submit(
+ Executor.on_graph_execution, graph_exec_data, cancel_event
+ )
+ self.active_graph_runs[graph_exec_id] = (future, cancel_event)
+ future.add_done_callback(
+ lambda _: self.active_graph_runs.pop(graph_exec_id)
+ )
@property
def agent_server_client(self) -> "AgentServer":
@@ -594,7 +657,7 @@ class ExecutionManager(AppService):
@expose
def add_execution(
self, graph_id: str, data: BlockInput, user_id: str
- ) -> dict[Any, Any]:
+ ) -> dict[str, Any]:
graph: Graph | None = self.run_and_wait(get_graph(graph_id, user_id=user_id))
if not graph:
raise Exception(f"Graph #{graph_id} not found.")
@@ -647,4 +710,45 @@ class ExecutionManager(AppService):
)
self.queue.add(graph_exec)
return {"id": graph_exec_id}
return graph_exec.model_dump()
@expose
def cancel_execution(self, graph_exec_id: str) -> None:
"""
Mechanism:
1. Set the cancel event
2. Graph executor's cancel handler thread detects the event, terminates workers,
reinitializes worker pool, and returns.
3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
"""
if graph_exec_id not in self.active_graph_runs:
raise Exception(
f"Graph execution #{graph_exec_id} not active/running: "
"possibly already completed/cancelled."
)
future, cancel_event = self.active_graph_runs[graph_exec_id]
if cancel_event.is_set():
return
cancel_event.set()
future.result()
# Update the status of the unfinished node executions
node_execs = self.run_and_wait(get_execution_results(graph_exec_id))
for node_exec in node_execs:
if node_exec.status not in (
ExecutionStatus.COMPLETED,
ExecutionStatus.FAILED,
):
self.run_and_wait(
upsert_execution_output(
node_exec.node_exec_id, "error", "TERMINATED"
)
)
exec_update = self.run_and_wait(
update_execution_status(
node_exec.node_exec_id, ExecutionStatus.FAILED
)
)
self.agent_server_client.send_execution_update(exec_update.model_dump())

View File

@@ -11,14 +11,10 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from autogpt_server.data import block, db
+ from autogpt_server.data import execution as execution_db
from autogpt_server.data import graph as graph_db
from autogpt_server.data import user as user_db
from autogpt_server.data.block import BlockInput, CompletedBlockOutput
- from autogpt_server.data.execution import (
- ExecutionResult,
- get_execution_results,
- list_executions,
- )
from autogpt_server.data.queue import AsyncEventQueue, AsyncRedisEventQueue
from autogpt_server.data.user import get_or_create_user
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
@@ -170,10 +166,15 @@ class AgentServer(AppService):
methods=["GET"],
)
router.add_api_route(
path="/graphs/{graph_id}/executions/{run_id}",
endpoint=self.get_run_execution_results,
path="/graphs/{graph_id}/executions/{graph_exec_id}",
endpoint=self.get_graph_run_node_execution_results,
methods=["GET"],
)
router.add_api_route(
path="/graphs/{graph_id}/executions/{graph_exec_id}/stop",
endpoint=self.stop_graph_run,
methods=["POST"],
)
router.add_api_route(
path="/graphs/{graph_id}/schedules",
endpoint=self.create_schedule,
@@ -423,15 +424,29 @@ class AgentServer(AppService):
graph_id: str,
node_input: dict[Any, Any],
user_id: Annotated[str, Depends(get_user_id)],
- ) -> dict[Any, Any]:
+ ) -> dict[str, Any]:  # FIXME: add proper return type
try:
- return self.execution_manager_client.add_execution(
+ graph_exec = self.execution_manager_client.add_execution(
graph_id, node_input, user_id=user_id
)
+ return {"id": graph_exec["graph_exec_id"]}
except Exception as e:
msg = e.__str__().encode().decode("unicode_escape")
raise HTTPException(status_code=400, detail=msg)
async def stop_graph_run(
self, graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> list[execution_db.ExecutionResult]:
if not await execution_db.get_graph_execution(graph_exec_id, user_id):
raise HTTPException(
404, detail=f"Agent execution #{graph_exec_id} not found"
)
self.execution_manager_client.cancel_execution(graph_exec_id)
# Retrieve & return canceled graph execution in its final state
return await execution_db.get_execution_results(graph_exec_id)
@classmethod
async def get_graph_input_schema(
cls,
@@ -458,17 +473,20 @@ class AgentServer(AppService):
status_code=404, detail=f"Agent #{graph_id}{rev} not found."
)
- return await list_executions(graph_id, graph_version)
+ return await execution_db.list_executions(graph_id, graph_version)
@classmethod
- async def get_run_execution_results(
- cls, graph_id: str, run_id: str, user_id: Annotated[str, Depends(get_user_id)]
- ) -> list[ExecutionResult]:
+ async def get_graph_run_node_execution_results(
+ cls,
+ graph_id: str,
+ graph_exec_id: str,
+ user_id: Annotated[str, Depends(get_user_id)],
+ ) -> list[execution_db.ExecutionResult]:
graph = await graph_db.get_graph(graph_id, user_id=user_id)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
- return await get_execution_results(run_id)
+ return await execution_db.get_execution_results(graph_exec_id)
async def create_schedule(
self,
@@ -506,7 +524,7 @@ class AgentServer(AppService):
@expose
def send_execution_update(self, execution_result_dict: dict[Any, Any]):
- execution_result = ExecutionResult(**execution_result_dict)
+ execution_result = execution_db.ExecutionResult(**execution_result_dict)
self.run_and_wait(self.event_queue.put(execution_result))
@expose

View File

@@ -59,7 +59,6 @@ class SpinTestServer:
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
async def __aenter__(self):
self.name_server.__enter__()
- self.setup_dependency_overrides()
self.agent_server.__enter__()
@@ -95,7 +94,7 @@ async def wait_execution(
timeout: int = 20,
) -> list:
async def is_execution_completed():
- execs = await AgentServer().get_run_execution_results(
+ execs = await AgentServer().get_graph_run_node_execution_results(
graph_id, graph_exec_id, user_id
)
return (
@@ -110,7 +109,7 @@ async def wait_execution(
# Wait for the executions to complete
for i in range(timeout):
if await is_execution_completed():
- return await AgentServer().get_run_execution_results(
+ return await AgentServer().get_graph_run_node_execution_results(
graph_id, graph_exec_id, user_id
)
time.sleep(1)

View File

@@ -1,4 +1,4 @@
{
"num_graph_workers": 10,
"num_node_workers": 10
"num_node_workers": 5
}

View File

@@ -35,7 +35,7 @@ async def assert_sample_graph_executions(
test_user: User,
graph_exec_id: str,
):
- executions = await agent_server.get_run_execution_results(
+ executions = await agent_server.get_graph_run_node_execution_results(
test_graph.id, graph_exec_id, test_user.id
)
@@ -156,7 +156,7 @@ async def test_input_pin_always_waited(server: SpinTestServer):
server.agent_server, server.exec_manager, test_graph, test_user, {}, 3
)
- executions = await server.agent_server.get_run_execution_results(
+ executions = await server.agent_server.get_graph_run_node_execution_results(
test_graph.id, graph_exec_id, test_user.id
)
assert len(executions) == 3
@@ -236,7 +236,7 @@ async def test_static_input_link_on_graph(server: SpinTestServer):
graph_exec_id = await execute_graph(
server.agent_server, server.exec_manager, test_graph, test_user, {}, 8
)
- executions = await server.agent_server.get_run_execution_results(
+ executions = await server.agent_server.get_graph_run_node_execution_results(
test_graph.id, graph_exec_id, test_user.id
)
assert len(executions) == 8

View File

@@ -82,6 +82,6 @@ env:
APP_ENV: "dev"
PYRO_HOST: "0.0.0.0"
NUM_GRAPH_WORKERS: 100
- NUM_NODE_WORKERS: 100
+ NUM_NODE_WORKERS: 5
REDIS_HOST: "redis-dev-master.redis-dev.svc.cluster.local"
REDIS_PORT: "6379"
REDIS_PORT: "6379"