Compare commits


19 Commits

Author SHA1 Message Date
Zamil Majdy
b3d873a4b9 Make all green 2024-08-28 15:50:58 -05:00
Zamil Majdy
c9989c47ec Reformat 2024-08-28 15:37:20 -05:00
Zamil Majdy
cfe85e40d8 fix(rnd): Fix broken test and Input/Output block field renaming 2024-08-28 15:36:53 -05:00
Reinier van der Leer
0d8760931d update note on saveRunRequest / setSaveRunRequest 2024-08-28 18:37:12 +02:00
Reinier van der Leer
0c2c8085bd add note to @expose decorator 2024-08-28 17:59:12 +02:00
Reinier van der Leer
efdd0fb04c feat(builder): Add "Stop Run" button to FlowRunInfo on /monitor page 2024-08-27 18:05:46 +02:00
Reinier van der Leer
bbe82fc9c1 fix(builder): Fix parsing of NodeExecutionResult objects from WebSocket messages 2024-08-27 18:01:28 +02:00
Reinier van der Leer
7e0b5c3235 feat(builder): Add button, icon, and logic to stop an agent run on /build page 2024-08-27 14:16:34 +02:00
Reinier van der Leer
c21bdfeb47 feat(builder): Add support for multiple handlers per websocket method 2024-08-27 14:15:43 +02:00
Reinier van der Leer
c062786f80 fix .data.execution.get_graph_execution(..) -> fix stop execution endpoint 2024-08-27 14:05:21 +02:00
Reinier van der Leer
d63ab9a2f9 fix(server): Fix deadlock and simplify cancel mechanism in Executor.on_graph_execution
This took many hours of bonking my head against the wall, but in the end I found that

      `multiprocess.Event()` or `multiprocess.Event(ctx=executor._mp_context)`

doesn't work with the `ProcessPoolExecutor`, and instead I had to use

      `multiprocess.Manager().Event()`

This adds some overhead, but at least it works. The deadlocking issue occurred for all shared types, e.g. I also tried `Value` and `Array`.
2024-08-27 10:48:54 +02:00
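For context on the fix above, here is a minimal standalone sketch (stdlib `multiprocessing`, simplified from the actual change in `Executor`/`ExecutionManager`) of why the Manager-backed event is needed:

    import multiprocessing
    import time
    from concurrent.futures import ProcessPoolExecutor

    def worker(cancel_event) -> str:
        # Block until the parent process signals cancellation.
        cancel_event.wait()
        return "cancelled"

    if __name__ == "__main__":
        # A plain multiprocessing.Event() fails to pickle into the pool's
        # task queue; a Manager-backed Event is a picklable proxy object,
        # so it can cross the process boundary (at the cost of IPC overhead).
        manager = multiprocessing.Manager()
        cancel_event = manager.Event()

        with ProcessPoolExecutor(max_workers=1) as executor:
            future = executor.submit(worker, cancel_event)
            time.sleep(0.1)
            cancel_event.set()      # signal the worker from the parent
            print(future.result())  # -> cancelled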
Reinier van der Leer
fce6394a49 fix agent execution endpoint 2024-08-26 17:34:01 +02:00
Reinier van der Leer
13e7716424 feat(builder): Add stopGraphExecution(..) to AutoGPTServerAPI 2024-08-26 17:06:19 +02:00
Reinier van der Leer
2973567010 smol refactor for consistency & readability 2024-08-26 16:42:14 +02:00
Reinier van der Leer
b6c4fc4742 feat(server): Add POST /graphs/{graph_id}/executions/{graph_exec_id}/stop endpoint
- Add `stop_graph_execution` + route in `AgentServer`
- Add `get_graph_execution` function in `.data.execution`
- Fix return type of `ExecutionManager.add_execution(..)`
- Fix type issue with `@expose` decorator
2024-08-26 16:28:08 +02:00
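A hedged usage sketch for the new endpoint; the base URL, port, and auth header below are assumptions for illustration, only the path comes from this change:

    import requests

    BASE_URL = "http://localhost:8000"             # assumption: local dev server
    headers = {"Authorization": "Bearer <token>"}  # assumption: bearer auth

    graph_id = "<graph-id>"            # hypothetical IDs for illustration
    graph_exec_id = "<graph-exec-id>"

    # POST /graphs/{graph_id}/executions/{graph_exec_id}/stop
    resp = requests.post(
        f"{BASE_URL}/graphs/{graph_id}/executions/{graph_exec_id}/stop",
        headers=headers,
    )
    resp.raise_for_status()
    # The response body is the stopped run's node execution results
    # in their final state.
    print(resp.json())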
Reinier van der Leer
f9a3170296 Merge branch 'master' into reinier/open-1669-implement-stop-button-for-agent-runs 2024-08-26 15:41:17 +02:00
Reinier van der Leer
a74f76893e smol refactor for readability 2024-08-26 15:40:54 +02:00
Reinier van der Leer
e6aaf71f21 feat(server): Add cancel_execution method to ExecutionManager
- Add `ExecutionManager.cancel_execution(..)`
- Replace graph executor's `ProcessPoolExecutor` by `multiprocessing.pool.Pool`
  - Remove now-unnecessary `Executor.wait_future(..)` method
- Add termination mechanism to `Executor.on_graph_execution`
2024-08-26 15:07:48 +02:00
Reinier van der Leer
31129bd080 fix type issue with AppService.run_and_wait(..) 2024-08-26 14:59:23 +02:00
20 changed files with 518 additions and 181 deletions

View File

@@ -34,7 +34,12 @@ import ConnectionLine from "./ConnectionLine";
import { Control, ControlPanel } from "@/components/edit/control/ControlPanel";
import { SaveControl } from "@/components/edit/control/SaveControl";
import { BlocksControl } from "@/components/edit/control/BlocksControl";
import { IconPlay, IconRedo2, IconUndo2 } from "@/components/ui/icons";
import {
IconPlay,
IconRedo2,
IconSquare,
IconUndo2,
} from "@/components/ui/icons";
import useAgentGraph from "@/hooks/useAgentGraph";
// History tracking: the minimum distance a block must move before the change is logged
@@ -71,7 +76,9 @@ const FlowEditor: React.FC<{
availableNodes,
getOutputType,
requestSave,
requestSaveRun,
requestSaveAndRun,
requestStopRun,
isRunning,
nodes,
setNodes,
edges,
@@ -465,9 +472,9 @@ const FlowEditor: React.FC<{
onClick: handleRedo,
},
{
label: "Run",
icon: <IconPlay />,
onClick: requestSaveRun,
label: !isRunning ? "Run" : "Stop",
icon: !isRunning ? <IconPlay /> : <IconSquare />,
onClick: !isRunning ? requestSaveAndRun : requestStopRun,
},
];

View File

@@ -1,9 +1,10 @@
import React from "react";
import { GraphMeta } from "@/lib/autogpt-server-api";
import React, { useCallback } from "react";
import AutoGPTServerAPI, { GraphMeta } from "@/lib/autogpt-server-api";
import { FlowRun } from "@/lib/types";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import Link from "next/link";
import { buttonVariants } from "@/components/ui/button";
import { Button, buttonVariants } from "@/components/ui/button";
import { IconSquare } from "@/components/ui/icons";
import { Pencil2Icon } from "@radix-ui/react-icons";
import moment from "moment/moment";
import { FlowRunStatusBadge } from "@/components/monitor/FlowRunStatusBadge";
@@ -20,6 +21,11 @@ export const FlowRunInfo: React.FC<
);
}
const handleStopRun = useCallback(() => {
const api = new AutoGPTServerAPI();
api.stopGraphExecution(flow.id, flowRun.id);
}, [flow.id, flowRun.id]);
return (
<Card {...props}>
<CardHeader className="flex-row items-center justify-between space-x-3 space-y-0">
@@ -34,12 +40,19 @@ export const FlowRunInfo: React.FC<
Run ID: <code>{flowRun.id}</code>
</p>
</div>
<Link
className={buttonVariants({ variant: "outline" })}
href={`/build?flowID=${flow.id}`}
>
<Pencil2Icon className="mr-2" /> Edit Agent
</Link>
<div className="flex space-x-2">
{flowRun.status === "running" && (
<Button onClick={handleStopRun} variant="destructive">
<IconSquare className="mr-2" /> Stop Run
</Button>
)}
<Link
className={buttonVariants({ variant: "outline" })}
href={`/build?flowID=${flow.id}`}
>
<Pencil2Icon className="mr-2" /> Edit Agent
</Link>
</div>
</CardHeader>
<CardContent>
<p>

View File

@@ -405,6 +405,40 @@ export const IconPlay = createIcon((props) => (
</svg>
));
/**
* Square icon component.
*
* @component IconSquare
* @param {IconProps} props - The props object containing additional attributes and event handlers for the icon.
* @returns {JSX.Element} - The square icon.
*
* @example
* // Default usage
* <IconSquare />
*
* @example
* // With custom color and size (use sparingly, only when necessary)
* <IconSquare className="text-primary" size="lg" />
*
* @example
* // With custom size and onClick handler
* <IconSquare size="sm" onClick={handleOnClick} />
*/
export const IconSquare = createIcon((props) => (
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
{...props}
>
<rect width="18" height="18" x="3" y="3" rx="2" />
</svg>
));
/**
* Package2 icon component.
*

View File

@@ -31,22 +31,27 @@ export default function useAgentGraph(
const [updateQueue, setUpdateQueue] = useState<NodeExecutionResult[]>([]);
const processedUpdates = useRef<NodeExecutionResult[]>([]);
/**
* User `request` to save or save&run the agent
* User `request` to save or save&run the agent, or to stop the active run.
* `state` is used to track the request status:
* - none: no request
* - saving: request was sent to save the agent
* and nodes are pending sync to update their backend ids
* - running: request was sent to run the agent
* and frontend is enqueueing execution results
* - stopping: a request to stop the active run has been sent; response is pending
* - error: request failed
*
* As of now, state will be stuck at 'running' (if run requested)
* because there's no way to know when the execution is done
*/
const [saveRunRequest, setSaveRunRequest] = useState<{
request: "none" | "save" | "run";
state: "none" | "saving" | "running" | "error";
}>({
const [saveRunRequest, setSaveRunRequest] = useState<
| {
request: "none" | "save" | "run";
state: "none" | "saving" | "error";
}
| {
request: "run" | "stop";
state: "running" | "stopping" | "error";
activeExecutionID?: string;
}
>({
request: "none",
state: "none",
});
@@ -128,13 +133,14 @@ export default function useAgentGraph(
console.error("Error saving agent");
} else if (saveRunRequest.request === "run") {
console.error(`Error saving&running agent`);
} else if (saveRunRequest.request === "stop") {
console.error(`Error stopping agent`);
}
// Reset request
setSaveRunRequest((prev) => ({
...prev,
setSaveRunRequest({
request: "none",
state: "none",
}));
});
return;
}
// When saving request is done
@@ -145,11 +151,10 @@ export default function useAgentGraph(
) {
// Reset request if only save was requested
if (saveRunRequest.request === "save") {
setSaveRunRequest((prev) => ({
...prev,
setSaveRunRequest({
request: "none",
state: "none",
}));
});
// If run was requested, run the agent
} else if (saveRunRequest.request === "run") {
if (!validateNodes()) {
@@ -161,16 +166,64 @@ export default function useAgentGraph(
return;
}
api.subscribeToExecution(savedAgent.id);
api.executeGraph(savedAgent.id);
processedUpdates.current = [];
setSaveRunRequest({ request: "run", state: "running" });
api
.executeGraph(savedAgent.id)
.then((graphExecution) => {
setSaveRunRequest({
request: "run",
state: "running",
activeExecutionID: graphExecution.id,
});
setSaveRunRequest((prev) => ({
...prev,
request: "run",
state: "running",
}));
// Track execution until completed
const pendingNodeExecutions: Set<string> = new Set();
const cancelExecListener = api.onWebSocketMessage(
"execution_event",
(nodeResult) => {
// We are racing the server here, since we need the ID to filter events
if (nodeResult.graph_exec_id != graphExecution.id) {
return;
}
if (
nodeResult.status != "COMPLETED" &&
nodeResult.status != "FAILED"
) {
pendingNodeExecutions.add(nodeResult.node_exec_id);
} else {
pendingNodeExecutions.delete(nodeResult.node_exec_id);
}
if (pendingNodeExecutions.size == 0) {
// Assuming the first event is always a QUEUED node, and
// following nodes are QUEUED before all preceding nodes are COMPLETED,
// an empty set means the graph has finished running.
cancelExecListener();
setSaveRunRequest({ request: "none", state: "none" });
}
},
);
})
.catch(() => setSaveRunRequest({ request: "run", state: "error" }));
processedUpdates.current = [];
}
}
// Handle stop request
if (
saveRunRequest.request === "stop" &&
saveRunRequest.state != "stopping" &&
savedAgent &&
saveRunRequest.activeExecutionID
) {
setSaveRunRequest({
request: "stop",
state: "stopping",
activeExecutionID: saveRunRequest.activeExecutionID,
});
api
.stopGraphExecution(savedAgent.id, saveRunRequest.activeExecutionID)
.then(() => setSaveRunRequest({ request: "none", state: "none" }));
}
}, [saveRunRequest, savedAgent, nodesSyncedWithSavedAgent]);
// Check if node ids are synced with saved agent
@@ -648,7 +701,7 @@ export default function useAgentGraph(
[saveAgent],
);
const requestSaveRun = useCallback(() => {
const requestSaveAndRun = useCallback(() => {
saveAgent();
setSaveRunRequest({
request: "run",
@@ -656,6 +709,23 @@ export default function useAgentGraph(
});
}, [saveAgent]);
const requestStopRun = useCallback(() => {
if (saveRunRequest.state != "running") {
return;
}
if (!saveRunRequest.activeExecutionID) {
console.warn(
"Stop requested but execution ID is unknown; state:",
saveRunRequest,
);
}
setSaveRunRequest((prev) => ({
...prev,
request: "stop",
state: "running",
}));
}, [saveRunRequest]);
return {
agentName,
setAgentName,
@@ -665,7 +735,11 @@ export default function useAgentGraph(
availableNodes,
getOutputType,
requestSave,
requestSaveRun,
requestSaveAndRun,
requestStopRun,
isSaving: saveRunRequest.state == "saving",
isRunning: saveRunRequest.state == "running",
isStopping: saveRunRequest.state == "stopping",
nodes,
setNodes,
edges,

View File

@@ -15,7 +15,7 @@ export default class AutoGPTServerAPI {
private wsUrl: string;
private webSocket: WebSocket | null = null;
private wsConnecting: Promise<void> | null = null;
private wsMessageHandlers: { [key: string]: (data: any) => void } = {};
private wsMessageHandlers: Record<string, Set<(data: any) => void>> = {};
private supabaseClient = createClient();
constructor(
@@ -126,16 +126,19 @@ export default class AutoGPTServerAPI {
runID: string,
): Promise<NodeExecutionResult[]> {
return (await this._get(`/graphs/${graphID}/executions/${runID}`)).map(
(result: any) => ({
...result,
add_time: new Date(result.add_time),
queue_time: result.queue_time ? new Date(result.queue_time) : undefined,
start_time: result.start_time ? new Date(result.start_time) : undefined,
end_time: result.end_time ? new Date(result.end_time) : undefined,
}),
parseNodeExecutionResultTimestamps,
);
}
async stopGraphExecution(
graphID: string,
runID: string,
): Promise<NodeExecutionResult[]> {
return (
await this._request("POST", `/graphs/${graphID}/executions/${runID}/stop`)
).map(parseNodeExecutionResultTimestamps);
}
private async _get(path: string) {
return this._request("GET", path);
}
@@ -205,10 +208,13 @@ export default class AutoGPTServerAPI {
};
this.webSocket.onmessage = (event) => {
const message = JSON.parse(event.data);
if (this.wsMessageHandlers[message.method]) {
this.wsMessageHandlers[message.method](message.data);
const message: WebsocketMessage = JSON.parse(event.data);
if (message.method == "execution_event") {
message.data = parseNodeExecutionResultTimestamps(message.data);
}
this.wsMessageHandlers[message.method]?.forEach((handler) =>
handler(message.data),
);
};
} catch (error) {
console.error("Error connecting to WebSocket:", error);
@@ -248,8 +254,12 @@ export default class AutoGPTServerAPI {
onWebSocketMessage<M extends keyof WebsocketMessageTypeMap>(
method: M,
handler: (data: WebsocketMessageTypeMap[M]) => void,
) {
this.wsMessageHandlers[method] = handler;
): () => void {
this.wsMessageHandlers[method] ??= new Set();
this.wsMessageHandlers[method].add(handler);
// Return detacher
return () => this.wsMessageHandlers[method].delete(handler);
}
subscribeToExecution(graphId: string) {
@@ -272,3 +282,22 @@ type WebsocketMessageTypeMap = {
subscribe: { graph_id: string };
execution_event: NodeExecutionResult;
};
type WebsocketMessage = {
[M in keyof WebsocketMessageTypeMap]: {
method: M;
data: WebsocketMessageTypeMap[M];
};
}[keyof WebsocketMessageTypeMap];
/* *** HELPER FUNCTIONS *** */
function parseNodeExecutionResultTimestamps(result: any): NodeExecutionResult {
return {
...result,
add_time: new Date(result.add_time),
queue_time: result.queue_time ? new Date(result.queue_time) : undefined,
start_time: result.start_time ? new Date(result.start_time) : undefined,
end_time: result.end_time ? new Date(result.end_time) : undefined,
};
}
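The handler registry above (a `Set` of handlers per method, with the subscriber getting back a detach function) is a small observer pattern; an illustrative Python sketch of the same idea, not part of this diff:

    from collections import defaultdict
    from typing import Any, Callable

    class MessageBus:
        """Per-method handler sets with detachable subscriptions."""

        def __init__(self) -> None:
            self._handlers: dict[str, set[Callable[[Any], None]]] = defaultdict(set)

        def on(self, method: str, handler: Callable[[Any], None]) -> Callable[[], None]:
            self._handlers[method].add(handler)
            # Return a detacher, like the one used as `cancelExecListener`
            # in useAgentGraph above.
            return lambda: self._handlers[method].discard(handler)

        def dispatch(self, method: str, data: Any) -> None:
            for handler in list(self._handlers[method]):
                handler(data)

    bus = MessageBus()
    detach = bus.on("execution_event", lambda data: print("got:", data))
    bus.dispatch("execution_event", {"status": "COMPLETED"})  # prints once
    detach()
    bus.dispatch("execution_event", {"status": "COMPLETED"})  # no handlers left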

View File

@@ -75,33 +75,22 @@ class PrintingBlock(Block):
yield "status", "printed"
T = TypeVar("T")
class ObjectLookupBlock(Block):
class Input(BlockSchema):
input: Any = Field(description="Dictionary to lookup from")
key: str | int = Field(description="Key to lookup in the dictionary")
class ObjectLookupBaseInput(BlockSchema, Generic[T]):
input: T = Field(description="Dictionary to lookup from")
key: str | int = Field(description="Key to lookup in the dictionary")
class ObjectLookupBaseOutput(BlockSchema, Generic[T]):
output: T = Field(description="Value found for the given key")
missing: T = Field(description="Value of the input that is missing the key")
class ObjectLookupBase(Block, ABC, Generic[T]):
@abstractmethod
def block_id(self) -> str:
pass
def __init__(self, *args, **kwargs):
input_schema = ObjectLookupBaseInput[T]
output_schema = ObjectLookupBaseOutput[T]
class Output(BlockSchema):
output: Any = Field(description="Value found for the given key")
missing: Any = Field(description="Value of the input that is missing the key")
def __init__(self):
super().__init__(
id=self.block_id(),
id="b2g2c3d4-5e6f-7g8h-9i0j-k1l2m3n4o5p6",
description="Lookup the given key in the input dictionary/object/list and return the value.",
input_schema=input_schema,
output_schema=output_schema,
input_schema=ObjectLookupBlock.Input,
output_schema=ObjectLookupBlock.Output,
test_input=[
{"input": {"apple": 1, "banana": 2, "cherry": 3}, "key": "banana"},
{"input": {"x": 10, "y": 20, "z": 30}, "key": "w"},
@@ -118,11 +107,10 @@ class ObjectLookupBase(Block, ABC, Generic[T]):
("output", "key"),
("output", ["v1", "v3"]),
],
*args,
**kwargs,
categories={BlockCategory.BASIC},
)
def run(self, input_data: ObjectLookupBaseInput[T]) -> BlockOutput:
def run(self, input_data: Input) -> BlockOutput:
obj = input_data.input
key = input_data.key
@@ -143,15 +131,50 @@ class ObjectLookupBase(Block, ABC, Generic[T]):
yield "missing", input_data.input
class ObjectLookupBlock(ObjectLookupBase[Any]):
def __init__(self):
super().__init__(categories={BlockCategory.BASIC})
T = TypeVar("T")
class InputOutputBlockInput(BlockSchema, Generic[T]):
value: T = Field(description="The value to be passed as input/output.")
name: str = Field(description="The name of the input/output.")
class InputOutputBlockOutput(BlockSchema, Generic[T]):
value: T = Field(description="The value passed as input/output.")
class InputOutputBlockBase(Block, ABC, Generic[T]):
@abstractmethod
def block_id(self) -> str:
return "b2g2c3d4-5e6f-7g8h-9i0j-k1l2m3n4o5p6"
pass
def __init__(self, *args, **kwargs):
input_schema = InputOutputBlockInput[T]
output_schema = InputOutputBlockOutput[T]
super().__init__(
id=self.block_id(),
description="This block is used to define the input & output of a graph.",
input_schema=input_schema,
output_schema=output_schema,
test_input=[
{"value": {"apple": 1, "banana": 2, "cherry": 3}, "name": "input_1"},
{"value": MockObject(value="!!", key="key"), "name": "input_2"},
],
test_output=[
("value", {"apple": 1, "banana": 2, "cherry": 3}),
("value", MockObject(value="!!", key="key")),
],
static_output=True,
*args,
**kwargs,
)
def run(self, input_data: InputOutputBlockInput[T]) -> BlockOutput:
yield "value", input_data.value
class InputBlock(ObjectLookupBase[Any]):
class InputBlock(InputOutputBlockBase[Any]):
def __init__(self):
super().__init__(categories={BlockCategory.INPUT, BlockCategory.BASIC})
@@ -159,7 +182,7 @@ class InputBlock(ObjectLookupBase[Any]):
return "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
class OutputBlock(ObjectLookupBase[Any]):
class OutputBlock(InputOutputBlockBase[Any]):
def __init__(self):
super().__init__(categories={BlockCategory.OUTPUT, BlockCategory.BASIC})

View File

@@ -1,6 +1,6 @@
import time
from datetime import datetime, timedelta
from typing import Union
from typing import Any, Union
from autogpt_server.data.block import Block, BlockCategory, BlockOutput, BlockSchema
@@ -103,6 +103,7 @@ class CurrentDateAndTimeBlock(Block):
class TimerBlock(Block):
class Input(BlockSchema):
message: Any = "timer finished"
seconds: Union[int, str] = 0
minutes: Union[int, str] = 0
hours: Union[int, str] = 0
@@ -120,9 +121,11 @@ class TimerBlock(Block):
output_schema=TimerBlock.Output,
test_input=[
{"seconds": 1},
{"message": "Custom message"},
],
test_output=[
("message", "timer finished"),
("message", "Custom message"),
],
)
@@ -136,4 +139,4 @@ class TimerBlock(Block):
total_seconds = seconds + minutes * 60 + hours * 3600 + days * 86400
time.sleep(total_seconds)
yield "message", "timer finished"
yield "message", input_data.message

View File

@@ -9,7 +9,7 @@ from prisma.models import (
AgentNodeExecution,
AgentNodeExecutionInputOutput,
)
from prisma.types import AgentGraphExecutionWhereInput
from prisma.types import AgentGraphExecutionWhereInput, AgentNodeExecutionInclude
from pydantic import BaseModel
from autogpt_server.data.block import BlockData, BlockInput, CompletedBlockOutput
@@ -108,7 +108,7 @@ class ExecutionResult(BaseModel):
# --------------------- Model functions --------------------- #
EXECUTION_RESULT_INCLUDE = {
EXECUTION_RESULT_INCLUDE: AgentNodeExecutionInclude = {
"Input": True,
"Output": True,
"AgentNode": True,
@@ -120,7 +120,7 @@ async def create_graph_execution(
graph_id: str,
graph_version: int,
nodes_input: list[tuple[str, BlockInput]],
user_id: str,
user_id: str | None = None,
) -> tuple[str, list[ExecutionResult]]:
"""
Create a new AgentGraphExecution record.
@@ -148,9 +148,7 @@ async def create_graph_execution(
},
"userId": user_id,
},
include={
"AgentNodeExecutions": {"include": EXECUTION_RESULT_INCLUDE} # type: ignore
},
include={"AgentNodeExecutions": {"include": EXECUTION_RESULT_INCLUDE}}, # type: ignore
)
return result.id, [
@@ -263,7 +261,7 @@ async def update_execution_status(
res = await AgentNodeExecution.prisma().update(
where={"id": node_exec_id},
data=data, # type: ignore
include=EXECUTION_RESULT_INCLUDE, # type: ignore
include=EXECUTION_RESULT_INCLUDE,
)
if not res:
raise ValueError(f"Execution {node_exec_id} not found.")
@@ -271,6 +269,26 @@ async def update_execution_status(
return ExecutionResult.from_db(res)
async def get_graph_execution(
graph_exec_id: str, user_id: str
) -> AgentGraphExecution | None:
"""
Retrieve a specific graph execution by its ID.
Args:
graph_exec_id (str): The ID of the graph execution to retrieve.
user_id (str): The ID of the user to whom the graph (execution) belongs.
Returns:
AgentGraphExecution | None: The graph execution if found, None otherwise.
"""
execution = await AgentGraphExecution.prisma().find_first(
where={"id": graph_exec_id, "userId": user_id},
include={"AgentNodeExecutions": {"include": EXECUTION_RESULT_INCLUDE}}, # type: ignore
)
return execution
async def list_executions(graph_id: str, graph_version: int | None = None) -> list[str]:
where: AgentGraphExecutionWhereInput = {"agentGraphId": graph_id}
if graph_version is not None:
@@ -282,7 +300,7 @@ async def list_executions(graph_id: str, graph_version: int | None = None) -> li
async def get_execution_results(graph_exec_id: str) -> list[ExecutionResult]:
executions = await AgentNodeExecution.prisma().find_many(
where={"agentGraphExecutionId": graph_exec_id},
include=EXECUTION_RESULT_INCLUDE, # type: ignore
include=EXECUTION_RESULT_INCLUDE,
order=[
{"queuedTime": "asc"},
{"addedTime": "asc"}, # Fallback: Incomplete execs has no queuedTime.
@@ -379,7 +397,7 @@ async def get_latest_execution(node_id: str, graph_eid: str) -> ExecutionResult
"executionData": {"not": None},
},
order={"queuedTime": "desc"},
include=EXECUTION_RESULT_INCLUDE, # type: ignore
include=EXECUTION_RESULT_INCLUDE,
)
if not execution:
return None
@@ -395,6 +413,6 @@ async def get_incomplete_executions(
"agentGraphExecutionId": graph_eid,
"executionStatus": ExecutionStatus.INCOMPLETE,
},
include=EXECUTION_RESULT_INCLUDE, # type: ignore
include=EXECUTION_RESULT_INCLUDE,
)
return [ExecutionResult.from_db(execution) for execution in executions]

View File

@@ -294,6 +294,7 @@ async def get_graphs_meta(
Args:
filter_by: An optional filter to either select templates or active graphs.
user_id: An optional user ID to filter the graphs by.
Returns:
list[GraphMeta]: A list of objects representing the retrieved graph metadata.

View File

@@ -1,7 +1,10 @@
import asyncio
import logging
from concurrent.futures import Future, ProcessPoolExecutor, TimeoutError
import multiprocessing
import threading
from concurrent.futures import Future, ProcessPoolExecutor
from contextlib import contextmanager
from multiprocessing.pool import AsyncResult, Pool
from typing import TYPE_CHECKING, Any, Coroutine, Generator, TypeVar
if TYPE_CHECKING:
@@ -16,6 +19,7 @@ from autogpt_server.data.execution import (
GraphExecution,
NodeExecution,
create_graph_execution,
get_execution_results,
get_incomplete_executions,
get_latest_execution,
merge_execution_input,
@@ -350,76 +354,111 @@ class Executor:
@classmethod
def on_graph_executor_start(cls):
cls.pool_size = Config().num_node_workers
cls.executor = ProcessPoolExecutor(
max_workers=cls.pool_size,
initializer=cls.on_node_executor_start,
)
cls._init_node_executor_pool()
logger.warning(f"Graph executor started with max-{cls.pool_size} node workers.")
@classmethod
def on_graph_execution(cls, graph_data: GraphExecution):
def _init_node_executor_pool(cls):
cls.executor = Pool(
processes=cls.pool_size,
initializer=cls.on_node_executor_start,
)
@classmethod
def on_graph_execution(cls, graph_data: GraphExecution, cancel: threading.Event):
prefix = get_log_prefix(graph_data.graph_exec_id, "*")
logger.warning(f"{prefix} Start graph execution")
finished = False
def cancel_handler():
while not cancel.is_set():
cancel.wait(1)
if finished:
return
cls.executor.terminate()
logger.info(
f"{prefix} Terminated graph execution {graph_data.graph_exec_id}"
)
cls._init_node_executor_pool()
cancel_thread = threading.Thread(target=cancel_handler)
cancel_thread.start()
try:
queue = ExecutionQueue[NodeExecution]()
for node_exec in graph_data.start_node_execs:
queue.add(node_exec)
futures: dict[str, Future] = {}
running_executions: dict[str, AsyncResult] = {}
while not queue.empty():
execution = queue.get()
if cancel.is_set():
return
exec_data = queue.get()
# Avoid parallel execution of the same node.
fut = futures.get(execution.node_id)
if fut and not fut.done():
execution = running_executions.get(exec_data.node_id)
if execution and not execution.ready():
# TODO (performance improvement):
# Wait for the completion of the same node execution is blocking.
# To improve this we need a separate queue for each node.
# Re-enqueueing the data back to the queue will disrupt the order.
cls.wait_future(fut, timeout=None)
execution.wait()
futures[execution.node_id] = cls.executor.submit(
cls.on_node_execution, queue, execution
running_executions[exec_data.node_id] = cls.executor.apply_async(
cls.on_node_execution,
(queue, exec_data),
callback=lambda _: running_executions.pop(exec_data.node_id),
)
# Avoid terminating graph execution when some nodes are still running.
while queue.empty() and futures:
for node_id, future in list(futures.items()):
if future.done():
del futures[node_id]
elif queue.empty():
cls.wait_future(future)
while queue.empty() and running_executions:
for execution in list(running_executions.values()):
if cancel.is_set():
return
if not queue.empty():
break # yield to parent loop to execute new queue items
execution.wait(3)
logger.warning(f"{prefix} Finished graph execution")
except Exception as e:
logger.exception(f"{prefix} Failed graph execution: {e}")
@classmethod
def wait_future(cls, future: Future, timeout: int | None = 3):
try:
if not future.done():
future.result(timeout=timeout)
except TimeoutError:
# Avoid being blocked by a long-running node by not waiting for its completion.
pass
finally:
if not cancel.is_set():
finished = True
cancel.set()
cancel_thread.join()
class ExecutionManager(AppService):
def __init__(self):
self.pool_size = Config().num_graph_workers
self.queue = ExecutionQueue[GraphExecution]()
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
def run_service(self):
with ProcessPoolExecutor(
max_workers=self.pool_size,
initializer=Executor.on_graph_executor_start,
) as executor:
logger.warning(
sync_manager = multiprocessing.Manager()
logger.info(
f"Execution manager started with max-{self.pool_size} graph workers."
)
while True:
executor.submit(Executor.on_graph_execution, self.queue.get())
graph_exec_data = self.queue.get()
graph_exec_id = graph_exec_data.graph_exec_id
cancel_event = sync_manager.Event()
future = executor.submit(
Executor.on_graph_execution, graph_exec_data, cancel_event
)
self.active_graph_runs[graph_exec_id] = (future, cancel_event)
future.add_done_callback(
lambda _: self.active_graph_runs.pop(graph_exec_id)
)
@property
def agent_server_client(self) -> "AgentServer":
@@ -427,8 +466,8 @@ class ExecutionManager(AppService):
@expose
def add_execution(
self, graph_id: str, data: BlockInput, user_id: str
) -> dict[Any, Any]:
self, graph_id: str, data: BlockInput, user_id: str | None = None
) -> dict[str, Any]:
graph: Graph | None = self.run_and_wait(get_graph(graph_id, user_id=user_id))
if not graph:
raise Exception(f"Graph #{graph_id} not found.")
@@ -436,10 +475,11 @@ class ExecutionManager(AppService):
nodes_input = []
for node in graph.starting_nodes:
input_data = {}
if isinstance(get_block(node.block_id), InputBlock):
input_data = {"input": data}
else:
input_data = {}
name = node.input_default.get("name")
if name and name in data:
input_data = {"value": data[name]}
input_data, error = validate_exec(node, input_data)
if input_data is None:
@@ -479,4 +519,45 @@ class ExecutionManager(AppService):
)
self.queue.add(graph_exec)
return {"id": graph_exec_id}
return graph_exec.model_dump()
@expose
def cancel_execution(self, graph_exec_id: str) -> None:
"""
Mechanism:
1. Set the cancel event
2. Graph executor's cancel handler thread detects the event, terminates workers,
reinitializes worker pool, and returns.
3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
"""
if graph_exec_id not in self.active_graph_runs:
raise Exception(
f"Graph execution #{graph_exec_id} not active/running: "
"possibly already completed/cancelled."
)
future, cancel_event = self.active_graph_runs[graph_exec_id]
if cancel_event.is_set():
return
cancel_event.set()
future.result()
# Update the status of the unfinished node executions
node_execs = self.run_and_wait(get_execution_results(graph_exec_id))
for node_exec in node_execs:
if node_exec.status not in (
ExecutionStatus.COMPLETED,
ExecutionStatus.FAILED,
):
exec_update = self.run_and_wait(
update_execution_status(
node_exec.node_exec_id, ExecutionStatus.FAILED
)
)
self.run_and_wait(
upsert_execution_output(
node_exec.node_exec_id, "error", "TERMINATED"
)
)
self.agent_server_client.send_execution_update(exec_update.model_dump())
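Since the cancel mechanism is spread over several hunks, here is a condensed standalone sketch of the pattern it implements (watcher thread + Manager event + pool termination); names are simplified and the DB status updates are omitted:

    import multiprocessing
    import threading
    import time
    from multiprocessing.pool import Pool

    def slow_node(seconds: int) -> int:
        time.sleep(seconds)  # stand-in for a long-running node execution
        return seconds

    def run_graph(cancel: threading.Event) -> None:
        pool = Pool(processes=2)
        finished = False

        def cancel_handler():
            # Wake up periodically; if the run finished normally, just exit.
            # Otherwise terminate the pool, killing in-flight node workers.
            while not cancel.is_set():
                cancel.wait(1)
                if finished:
                    return
            pool.terminate()

        watcher = threading.Thread(target=cancel_handler)
        watcher.start()
        try:
            result = pool.apply_async(slow_node, (30,))
            while not result.ready():
                if cancel.is_set():
                    return  # cancelled: the watcher terminates the pool
                result.wait(1)
        finally:
            if not cancel.is_set():
                finished = True
                cancel.set()  # unblock the watcher so it can exit
            watcher.join()

    if __name__ == "__main__":
        manager = multiprocessing.Manager()
        cancel_event = manager.Event()

        run = threading.Thread(target=run_graph, args=(cancel_event,))
        run.start()
        time.sleep(2)
        cancel_event.set()  # what ExecutionManager.cancel_execution does
        run.join()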

View File

@@ -37,7 +37,8 @@ class ExecutionScheduler(AppService):
def __refresh_jobs_from_db(self, scheduler: BackgroundScheduler):
schedules = self.run_and_wait(model.get_active_schedules(self.last_check))
for schedule in schedules:
self.last_check = max(self.last_check, schedule.last_updated)
if schedule.last_updated:
self.last_check = max(self.last_check, schedule.last_updated)
if not schedule.is_enabled:
log(f"Removing recurring job {schedule.id}: {schedule.schedule}")

View File

@@ -22,14 +22,10 @@ from fastapi.responses import JSONResponse
import autogpt_server.server.ws_api
from autogpt_server.data import block, db
from autogpt_server.data import execution as execution_db
from autogpt_server.data import graph as graph_db
from autogpt_server.data import user as user_db
from autogpt_server.data.block import BlockInput, CompletedBlockOutput
from autogpt_server.data.execution import (
ExecutionResult,
get_execution_results,
list_executions,
)
from autogpt_server.data.user import DEFAULT_USER_ID, get_or_create_user
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
from autogpt_server.server.conn_manager import ConnectionManager
@@ -58,7 +54,7 @@ def get_user_id(payload: dict = Depends(auth_middleware)) -> str:
class AgentServer(AppService):
event_queue: asyncio.Queue[ExecutionResult] = asyncio.Queue()
event_queue: asyncio.Queue[execution_db.ExecutionResult] = asyncio.Queue()
manager = ConnectionManager()
mutex = KeyedMutex()
use_db = False
@@ -66,7 +62,7 @@ class AgentServer(AppService):
async def event_broadcaster(self):
while True:
event: ExecutionResult = await self.event_queue.get()
event: execution_db.ExecutionResult = await self.event_queue.get()
await self.manager.send_execution_result(event)
@asynccontextmanager
@@ -193,10 +189,15 @@ class AgentServer(AppService):
methods=["GET"],
)
router.add_api_route(
path="/graphs/{graph_id}/executions/{run_id}",
endpoint=self.get_run_execution_results,
path="/graphs/{graph_id}/executions/{graph_exec_id}",
endpoint=self.get_graph_run_node_execution_results,
methods=["GET"],
)
router.add_api_route(
path="/graphs/{graph_id}/executions/{graph_exec_id}/stop",
endpoint=self.stop_graph_run,
methods=["POST"],
)
router.add_api_route(
path="/graphs/{graph_id}/schedules",
endpoint=self.create_schedule, # type: ignore
@@ -508,15 +509,29 @@ class AgentServer(AppService):
graph_id: str,
node_input: dict[Any, Any],
user_id: Annotated[str, Depends(get_user_id)],
) -> dict[Any, Any]:
) -> dict[str, Any]: # FIXME: add proper return type
try:
return self.execution_manager_client.add_execution(
graph_exec = self.execution_manager_client.add_execution(
graph_id, node_input, user_id=user_id
)
return {"id": graph_exec["graph_exec_id"]}
except Exception as e:
msg = e.__str__().encode().decode("unicode_escape")
raise HTTPException(status_code=400, detail=msg)
async def stop_graph_run(
self, graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> list[execution_db.ExecutionResult]:
if not await execution_db.get_graph_execution(graph_exec_id, user_id):
raise HTTPException(
404, detail=f"Agent execution #{graph_exec_id} not found"
)
self.execution_manager_client.cancel_execution(graph_exec_id)
# Retrieve & return canceled graph execution in its final state
return await execution_db.get_execution_results(graph_exec_id)
@classmethod
async def list_graph_runs(
cls,
@@ -531,17 +546,20 @@ class AgentServer(AppService):
status_code=404, detail=f"Agent #{graph_id}{rev} not found."
)
return await list_executions(graph_id, graph_version)
return await execution_db.list_executions(graph_id, graph_version)
@classmethod
async def get_run_execution_results(
cls, graph_id: str, run_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> list[ExecutionResult]:
async def get_graph_run_node_execution_results(
cls,
graph_id: str,
graph_exec_id: str,
user_id: Annotated[str, Depends(get_user_id)],
) -> list[execution_db.ExecutionResult]:
graph = await graph_db.get_graph(graph_id, user_id=user_id)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
return await get_execution_results(run_id)
return await execution_db.get_execution_results(graph_exec_id)
async def create_schedule(
self,
@@ -579,7 +597,7 @@ class AgentServer(AppService):
@expose
def send_execution_update(self, execution_result_dict: dict[Any, Any]):
execution_result = ExecutionResult(**execution_result_dict)
execution_result = execution_db.ExecutionResult(**execution_result_dict)
self.run_and_wait(self.event_queue.put(execution_result))
@expose

View File

@@ -29,11 +29,11 @@ def create_test_graph() -> graph.Graph:
nodes = [
graph.Node(
block_id=InputBlock().id,
input_default={"key": "input_1"},
input_default={"name": "input_1"},
),
graph.Node(
block_id=InputBlock().id,
input_default={"key": "input_2"},
input_default={"name": "input_2"},
),
graph.Node(
block_id=TextFormatterBlock().id,
@@ -48,13 +48,13 @@ def create_test_graph() -> graph.Graph:
graph.Link(
source_id=nodes[0].id,
sink_id=nodes[2].id,
source_name="output",
source_name="value",
sink_name="values_#_a",
),
graph.Link(
source_id=nodes[1].id,
sink_id=nodes[2].id,
source_name="output",
source_name="value",
sink_name="values_#_b",
),
graph.Link(

View File

@@ -16,11 +16,21 @@ from autogpt_server.util.settings import Config
logger = logging.getLogger(__name__)
conn_retry = retry(stop=stop_after_delay(5), wait=wait_exponential(multiplier=0.1))
T = TypeVar("T")
C = TypeVar("C", bound=Callable)
pyro_host = Config().pyro_host
def expose(func: Callable) -> Callable:
def expose(func: C) -> C:
"""
Decorator to mark a method or class to be exposed for remote calls.
## ⚠️ Gotcha
The types on the exposed function signature are respected **as long as they are
fully picklable**. This is not the case for Pydantic models, so if you really need
to pass a model, try dumping the model and passing the resulting dict instead.
"""
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
@@ -29,7 +39,7 @@ def expose(func: Callable) -> Callable:
logger.exception(msg)
raise Exception(msg, e)
return pyro.expose(wrapper)
return pyro.expose(wrapper) # type: ignore
class PyroNameServer(AppProcess):
@@ -58,7 +68,7 @@ class AppService(AppProcess):
def __run_async(self, coro: Coroutine[T, Any, T]):
return asyncio.run_coroutine_threadsafe(coro, self.shared_event_loop)
def run_and_wait(self, coro: Coroutine[T, Any, T]) -> T:
def run_and_wait(self, coro: Coroutine[Any, Any, T]) -> T:
future = self.__run_async(coro)
return future.result()
@@ -100,7 +110,6 @@ def get_service_client(service_type: Type[AS]) -> AS:
service_name = service_type.service_name
class DynamicClient:
@conn_retry
def __init__(self):
ns = pyro.locate_ns()
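The gotcha noted in the `@expose` docstring above translates to the pattern used by `ExecutionManager.add_execution` (returning `graph_exec.model_dump()` rather than the model). An illustrative sketch with a hypothetical stand-in model:

    from pydantic import BaseModel

    class GraphExecution(BaseModel):  # illustrative stand-in, not the real model
        graph_exec_id: str
        graph_id: str

    # Service side: dump the model to a plain (picklable) dict before it
    # crosses the remote-call boundary.
    def exposed_add_execution() -> dict:
        graph_exec = GraphExecution(graph_exec_id="exec-1", graph_id="graph-1")
        return graph_exec.model_dump()

    # Caller side: rehydrate the model from the dict if needed.
    payload = exposed_add_execution()
    graph_exec = GraphExecution(**payload)
    assert graph_exec.graph_exec_id == "exec-1"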

View File

@@ -23,7 +23,6 @@ class SpinTestServer:
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
async def __aenter__(self):
self.name_server.__enter__()
self.setup_dependency_overrides()
self.agent_server.__enter__()
@@ -59,7 +58,7 @@ async def wait_execution(
timeout: int = 20,
) -> list:
async def is_execution_completed():
execs = await AgentServer().get_run_execution_results(
execs = await AgentServer().get_graph_run_node_execution_results(
graph_id, graph_exec_id, user_id
)
return (
@@ -74,7 +73,7 @@ async def wait_execution(
# Wait for the executions to complete
for i in range(timeout):
if await is_execution_completed():
return await AgentServer().get_run_execution_results(
return await AgentServer().get_graph_run_node_execution_results(
graph_id, graph_exec_id, user_id
)
time.sleep(1)

View File

@@ -0,0 +1,20 @@
-- RedefineTables
PRAGMA foreign_keys=OFF;
CREATE TABLE "new_AgentGraphExecutionSchedule" (
"id" TEXT NOT NULL PRIMARY KEY,
"agentGraphId" TEXT NOT NULL,
"agentGraphVersion" INTEGER NOT NULL DEFAULT 1,
"schedule" TEXT NOT NULL,
"isEnabled" BOOLEAN NOT NULL DEFAULT true,
"inputData" TEXT NOT NULL,
"lastUpdated" DATETIME NOT NULL,
"userId" TEXT,
CONSTRAINT "AgentGraphExecutionSchedule_agentGraphId_agentGraphVersion_fkey" FOREIGN KEY ("agentGraphId", "agentGraphVersion") REFERENCES "AgentGraph" ("id", "version") ON DELETE RESTRICT ON UPDATE CASCADE,
CONSTRAINT "AgentGraphExecutionSchedule_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User" ("id") ON DELETE SET NULL ON UPDATE CASCADE
);
INSERT INTO "new_AgentGraphExecutionSchedule" ("agentGraphId", "agentGraphVersion", "id", "inputData", "isEnabled", "lastUpdated", "schedule", "userId") SELECT "agentGraphId", "agentGraphVersion", "id", "inputData", "isEnabled", "lastUpdated", "schedule", "userId" FROM "AgentGraphExecutionSchedule";
DROP TABLE "AgentGraphExecutionSchedule";
ALTER TABLE "new_AgentGraphExecutionSchedule" RENAME TO "AgentGraphExecutionSchedule";
CREATE INDEX "AgentGraphExecutionSchedule_isEnabled_idx" ON "AgentGraphExecutionSchedule"("isEnabled");
PRAGMA foreign_key_check;
PRAGMA foreign_keys=ON;

View File

@@ -0,0 +1,8 @@
-- DropForeignKey
ALTER TABLE "AgentGraphExecutionSchedule" DROP CONSTRAINT "AgentGraphExecutionSchedule_userId_fkey";
-- AlterTable
ALTER TABLE "AgentGraphExecutionSchedule" ALTER COLUMN "userId" DROP NOT NULL;
-- AddForeignKey
ALTER TABLE "AgentGraphExecutionSchedule" ADD CONSTRAINT "AgentGraphExecutionSchedule_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE SET NULL ON UPDATE CASCADE;

View File

@@ -12,11 +12,11 @@ generator client {
// User model to mirror Auth provider users
model User {
id String @id // This should match the Supabase user ID
email String @unique
name String?
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
id String @id // This should match the Supabase user ID
email String @unique
name String?
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
// Relations
AgentGraphs AgentGraph[]
@@ -184,8 +184,8 @@ model AgentGraphExecutionSchedule {
lastUpdated DateTime @updatedAt
// Link to User model
userId String
user User @relation(fields: [userId], references: [id])
userId String?
user User? @relation(fields: [userId], references: [id])
@@index([isEnabled])
}

View File

@@ -11,11 +11,11 @@ generator client {
// User model to mirror Auth provider users
model User {
id String @id // This should match the Supabase user ID
email String @unique
name String?
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
id String @id // This should match the Supabase user ID
email String @unique
name String?
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
// Relations
AgentGraphs AgentGraph[]
@@ -183,8 +183,8 @@ model AgentGraphExecutionSchedule {
lastUpdated DateTime @updatedAt
// Link to User model
userId String
user User @relation(fields: [userId], references: [id])
userId String?
user User? @relation(fields: [userId], references: [id])
@@index([isEnabled])
}

View File

@@ -35,8 +35,7 @@ async def assert_sample_graph_executions(
test_user: User,
graph_exec_id: str,
):
input = {"input_1": "Hello", "input_2": "World"}
executions = await agent_server.get_run_execution_results(
executions = await agent_server.get_graph_run_node_execution_results(
test_graph.id, graph_exec_id, test_user.id
)
@@ -44,16 +43,16 @@ async def assert_sample_graph_executions(
exec = executions[0]
assert exec.status == execution.ExecutionStatus.COMPLETED
assert exec.graph_exec_id == graph_exec_id
assert exec.output_data == {"output": ["Hello"]}
assert exec.input_data == {"input": input, "key": "input_1"}
assert exec.output_data == {"value": ["Hello"]}
assert exec.input_data == {"value": "Hello", "name": "input_1"}
assert exec.node_id in [test_graph.nodes[0].id, test_graph.nodes[1].id]
# Executing ValueBlock
exec = executions[1]
assert exec.status == execution.ExecutionStatus.COMPLETED
assert exec.graph_exec_id == graph_exec_id
assert exec.output_data == {"output": ["World"]}
assert exec.input_data == {"input": input, "key": "input_2"}
assert exec.output_data == {"value": ["World"]}
assert exec.input_data == {"value": "World", "name": "input_2"}
assert exec.node_id in [test_graph.nodes[0].id, test_graph.nodes[1].id]
# Executing TextFormatterBlock
@@ -151,7 +150,7 @@ async def test_input_pin_always_waited(server):
server.agent_server, server.exec_manager, test_graph, test_user, {}, 3
)
executions = await server.agent_server.get_run_execution_results(
executions = await server.agent_server.get_graph_run_node_execution_results(
test_graph.id, graph_exec_id, test_user.id
)
assert len(executions) == 3
@@ -231,7 +230,7 @@ async def test_static_input_link_on_graph(server):
graph_exec_id = await execute_graph(
server.agent_server, server.exec_manager, test_graph, test_user, {}, 8
)
executions = await server.agent_server.get_run_execution_results(
executions = await server.agent_server.get_graph_run_node_execution_results(
test_graph.id, graph_exec_id, test_user.id
)
assert len(executions) == 8