Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-02-08 13:55:06 -05:00)
Compare commits
19 Commits
fix/execut...
zamilmajdy
| Author | SHA1 | Date |
|---|---|---|
| zamilmajdy | b3d873a4b9 | |
| zamilmajdy | c9989c47ec | |
| zamilmajdy | cfe85e40d8 | |
| zamilmajdy | 0d8760931d | |
| zamilmajdy | 0c2c8085bd | |
| zamilmajdy | efdd0fb04c | |
| zamilmajdy | bbe82fc9c1 | |
| zamilmajdy | 7e0b5c3235 | |
| zamilmajdy | c21bdfeb47 | |
| zamilmajdy | c062786f80 | |
| zamilmajdy | d63ab9a2f9 | |
| zamilmajdy | fce6394a49 | |
| zamilmajdy | 13e7716424 | |
| zamilmajdy | 2973567010 | |
| zamilmajdy | b6c4fc4742 | |
| zamilmajdy | f9a3170296 | |
| zamilmajdy | a74f76893e | |
| zamilmajdy | e6aaf71f21 | |
| zamilmajdy | 31129bd080 | |
@@ -34,7 +34,12 @@ import ConnectionLine from "./ConnectionLine";
 import { Control, ControlPanel } from "@/components/edit/control/ControlPanel";
 import { SaveControl } from "@/components/edit/control/SaveControl";
 import { BlocksControl } from "@/components/edit/control/BlocksControl";
-import { IconPlay, IconRedo2, IconUndo2 } from "@/components/ui/icons";
+import {
+  IconPlay,
+  IconRedo2,
+  IconSquare,
+  IconUndo2,
+} from "@/components/ui/icons";
 import useAgentGraph from "@/hooks/useAgentGraph";
 
 // This is for the history: the minimum distance a block must move before it is logged
@@ -71,7 +76,9 @@ const FlowEditor: React.FC<{
   availableNodes,
   getOutputType,
   requestSave,
-  requestSaveRun,
+  requestSaveAndRun,
+  requestStopRun,
+  isRunning,
   nodes,
   setNodes,
   edges,
@@ -465,9 +472,9 @@ const FlowEditor: React.FC<{
       onClick: handleRedo,
     },
     {
-      label: "Run",
-      icon: <IconPlay />,
-      onClick: requestSaveRun,
+      label: !isRunning ? "Run" : "Stop",
+      icon: !isRunning ? <IconPlay /> : <IconSquare />,
+      onClick: !isRunning ? requestSaveAndRun : requestStopRun,
     },
   ];

@@ -1,9 +1,10 @@
-import React from "react";
-import { GraphMeta } from "@/lib/autogpt-server-api";
+import React, { useCallback } from "react";
+import AutoGPTServerAPI, { GraphMeta } from "@/lib/autogpt-server-api";
 import { FlowRun } from "@/lib/types";
 import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
 import Link from "next/link";
-import { buttonVariants } from "@/components/ui/button";
+import { Button, buttonVariants } from "@/components/ui/button";
+import { IconSquare } from "@/components/ui/icons";
 import { Pencil2Icon } from "@radix-ui/react-icons";
 import moment from "moment/moment";
 import { FlowRunStatusBadge } from "@/components/monitor/FlowRunStatusBadge";
@@ -20,6 +21,11 @@ export const FlowRunInfo: React.FC<
     );
   }
 
+  const handleStopRun = useCallback(() => {
+    const api = new AutoGPTServerAPI();
+    api.stopGraphExecution(flow.id, flowRun.id);
+  }, [flow.id, flowRun.id]);
+
   return (
     <Card {...props}>
       <CardHeader className="flex-row items-center justify-between space-x-3 space-y-0">
@@ -34,12 +40,19 @@ export const FlowRunInfo: React.FC<
             Run ID: <code>{flowRun.id}</code>
           </p>
         </div>
-        <Link
-          className={buttonVariants({ variant: "outline" })}
-          href={`/build?flowID=${flow.id}`}
-        >
-          <Pencil2Icon className="mr-2" /> Edit Agent
-        </Link>
+        <div className="flex space-x-2">
+          {flowRun.status === "running" && (
+            <Button onClick={handleStopRun} variant="destructive">
+              <IconSquare className="mr-2" /> Stop Run
+            </Button>
+          )}
+          <Link
+            className={buttonVariants({ variant: "outline" })}
+            href={`/build?flowID=${flow.id}`}
+          >
+            <Pencil2Icon className="mr-2" /> Edit Agent
+          </Link>
+        </div>
       </CardHeader>
       <CardContent>
         <p>

@@ -405,6 +405,40 @@ export const IconPlay = createIcon((props) => (
   </svg>
 ));
 
+/**
+ * Square icon component.
+ *
+ * @component IconSquare
+ * @param {IconProps} props - The props object containing additional attributes and event handlers for the icon.
+ * @returns {JSX.Element} - The square icon.
+ *
+ * @example
+ * // Default usage
+ * <IconSquare />
+ *
+ * @example
+ * // With custom color and size (use sparingly, and only when necessary)
+ * <IconSquare className="text-primary" size="lg" />
+ *
+ * @example
+ * // With custom size and onClick handler
+ * <IconSquare size="sm" onClick={handleOnClick} />
+ */
+export const IconSquare = createIcon((props) => (
+  <svg
+    xmlns="http://www.w3.org/2000/svg"
+    viewBox="0 0 24 24"
+    fill="none"
+    stroke="currentColor"
+    strokeWidth="2"
+    strokeLinecap="round"
+    strokeLinejoin="round"
+    {...props}
+  >
+    <rect width="18" height="18" x="3" y="3" rx="2" />
+  </svg>
+));
+
 /**
  * Package2 icon component.
  *

@@ -31,22 +31,27 @@ export default function useAgentGraph(
   const [updateQueue, setUpdateQueue] = useState<NodeExecutionResult[]>([]);
   const processedUpdates = useRef<NodeExecutionResult[]>([]);
   /**
-   * User `request` to save or save&run the agent
+   * User `request` to save or save&run the agent, or to stop the active run.
    * `state` is used to track the request status:
    * - none: no request
    * - saving: request was sent to save the agent
    *   and nodes are pending sync to update their backend ids
    * - running: request was sent to run the agent
    *   and frontend is enqueueing execution results
+   * - stopping: a request to stop the active run has been sent; response is pending
    * - error: request failed
-   *
-   * As of now, state will be stuck at 'running' (if run requested)
-   * because there's no way to know when the execution is done
    */
-  const [saveRunRequest, setSaveRunRequest] = useState<{
-    request: "none" | "save" | "run";
-    state: "none" | "saving" | "running" | "error";
-  }>({
+  const [saveRunRequest, setSaveRunRequest] = useState<
+    | {
+        request: "none" | "save" | "run";
+        state: "none" | "saving" | "error";
+      }
+    | {
+        request: "run" | "stop";
+        state: "running" | "stopping" | "error";
+        activeExecutionID?: string;
+      }
+  >({
    request: "none",
    state: "none",
  });
@@ -128,13 +133,14 @@ export default function useAgentGraph(
        console.error("Error saving agent");
      } else if (saveRunRequest.request === "run") {
        console.error(`Error saving&running agent`);
+     } else if (saveRunRequest.request === "stop") {
+       console.error(`Error stopping agent`);
      }
      // Reset request
-     setSaveRunRequest((prev) => ({
-       ...prev,
+     setSaveRunRequest({
        request: "none",
        state: "none",
-     }));
+     });
      return;
    }
    // When saving request is done
@@ -145,11 +151,10 @@ export default function useAgentGraph(
    ) {
      // Reset request if only save was requested
      if (saveRunRequest.request === "save") {
-       setSaveRunRequest((prev) => ({
-         ...prev,
+       setSaveRunRequest({
          request: "none",
          state: "none",
-       }));
+       });
      // If run was requested, run the agent
      } else if (saveRunRequest.request === "run") {
        if (!validateNodes()) {
@@ -161,16 +166,64 @@ export default function useAgentGraph(
          return;
        }
        api.subscribeToExecution(savedAgent.id);
-       api.executeGraph(savedAgent.id);
-       processedUpdates.current = processedUpdates.current = [];
-       setSaveRunRequest((prev) => ({
-         ...prev,
-         request: "run",
-         state: "running",
-       }));
+       setSaveRunRequest({ request: "run", state: "running" });
+       api
+         .executeGraph(savedAgent.id)
+         .then((graphExecution) => {
+           setSaveRunRequest({
+             request: "run",
+             state: "running",
+             activeExecutionID: graphExecution.id,
+           });
+
+           // Track execution until completed
+           const pendingNodeExecutions: Set<string> = new Set();
+           const cancelExecListener = api.onWebSocketMessage(
+             "execution_event",
+             (nodeResult) => {
+               // We are racing the server here, since we need the ID to filter events
+               if (nodeResult.graph_exec_id != graphExecution.id) {
+                 return;
+               }
+               if (
+                 nodeResult.status != "COMPLETED" &&
+                 nodeResult.status != "FAILED"
+               ) {
+                 pendingNodeExecutions.add(nodeResult.node_exec_id);
+               } else {
+                 pendingNodeExecutions.delete(nodeResult.node_exec_id);
+               }
+               if (pendingNodeExecutions.size == 0) {
+                 // Assuming the first event is always a QUEUED node, and
+                 // following nodes are QUEUED before all preceding nodes are COMPLETED,
+                 // an empty set means the graph has finished running.
+                 cancelExecListener();
+                 setSaveRunRequest({ request: "none", state: "none" });
+               }
+             },
+           );
+         })
+         .catch(() => setSaveRunRequest({ request: "run", state: "error" }));
+
+       processedUpdates.current = [];
      }
    }
+   // Handle stop request
+   if (
+     saveRunRequest.request === "stop" &&
+     saveRunRequest.state != "stopping" &&
+     savedAgent &&
+     saveRunRequest.activeExecutionID
+   ) {
+     setSaveRunRequest({
+       request: "stop",
+       state: "stopping",
+       activeExecutionID: saveRunRequest.activeExecutionID,
+     });
+     api
+       .stopGraphExecution(savedAgent.id, saveRunRequest.activeExecutionID)
+       .then(() => setSaveRunRequest({ request: "none", state: "none" }));
+   }
  }, [saveRunRequest, savedAgent, nodesSyncedWithSavedAgent]);
 
  // Check if node ids are synced with saved agent
@@ -648,7 +701,7 @@ export default function useAgentGraph(
    [saveAgent],
  );
 
-  const requestSaveRun = useCallback(() => {
+  const requestSaveAndRun = useCallback(() => {
    saveAgent();
    setSaveRunRequest({
      request: "run",
@@ -656,6 +709,23 @@ export default function useAgentGraph(
    });
  }, [saveAgent]);
 
+  const requestStopRun = useCallback(() => {
+    if (saveRunRequest.state != "running") {
+      return;
+    }
+    if (!saveRunRequest.activeExecutionID) {
+      console.warn(
+        "Stop requested but execution ID is unknown; state:",
+        saveRunRequest,
+      );
+    }
+    setSaveRunRequest((prev) => ({
+      ...prev,
+      request: "stop",
+      state: "running",
+    }));
+  }, [saveRunRequest]);
+
  return {
    agentName,
    setAgentName,
@@ -665,7 +735,11 @@ export default function useAgentGraph(
    availableNodes,
    getOutputType,
    requestSave,
-    requestSaveRun,
+    requestSaveAndRun,
+    requestStopRun,
+    isSaving: saveRunRequest.state == "saving",
+    isRunning: saveRunRequest.state == "running",
+    isStopping: saveRunRequest.state == "stopping",
    nodes,
    setNodes,
    edges,

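The run-tracking listener added above encodes its completion rule in the inline comment: a run counts as finished when the set of pending node executions empties. Here is a standalone sketch of just that heuristic, with `ExecEvent` and `onEvent` as hypothetical stand-ins for `api.onWebSocketMessage("execution_event", ...)` and its payload:

```ts
// Minimal sketch of the completion heuristic used in the hook above.
// `ExecEvent` and `onEvent` are assumed stand-ins, not real API names.
type ExecEvent = {
  graph_exec_id: string;
  node_exec_id: string;
  status: "QUEUED" | "RUNNING" | "COMPLETED" | "FAILED";
};

function trackRunCompletion(
  executionID: string,
  onEvent: (handler: (e: ExecEvent) => void) => () => void,
  onDone: () => void,
): void {
  const pending = new Set<string>();
  const detach = onEvent((e) => {
    if (e.graph_exec_id !== executionID) return; // ignore other runs
    if (e.status === "COMPLETED" || e.status === "FAILED") {
      pending.delete(e.node_exec_id);
    } else {
      pending.add(e.node_exec_id);
    }
    // Same assumption as the hook: successor nodes are QUEUED before all of
    // their predecessors COMPLETE, so an empty set means the run is done.
    if (pending.size === 0) {
      detach();
      onDone();
    }
  });
}
```

The sketch inherits the hook's stated caveat: if the very first event observed for a run were already COMPLETED, the set would empty immediately and the run would be declared finished early.
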
@@ -15,7 +15,7 @@ export default class AutoGPTServerAPI {
   private wsUrl: string;
   private webSocket: WebSocket | null = null;
   private wsConnecting: Promise<void> | null = null;
-  private wsMessageHandlers: { [key: string]: (data: any) => void } = {};
+  private wsMessageHandlers: Record<string, Set<(data: any) => void>> = {};
   private supabaseClient = createClient();
 
   constructor(
@@ -126,16 +126,19 @@ export default class AutoGPTServerAPI {
     runID: string,
   ): Promise<NodeExecutionResult[]> {
     return (await this._get(`/graphs/${graphID}/executions/${runID}`)).map(
-      (result: any) => ({
-        ...result,
-        add_time: new Date(result.add_time),
-        queue_time: result.queue_time ? new Date(result.queue_time) : undefined,
-        start_time: result.start_time ? new Date(result.start_time) : undefined,
-        end_time: result.end_time ? new Date(result.end_time) : undefined,
-      }),
+      parseNodeExecutionResultTimestamps,
     );
   }
 
+  async stopGraphExecution(
+    graphID: string,
+    runID: string,
+  ): Promise<NodeExecutionResult[]> {
+    return (
+      await this._request("POST", `/graphs/${graphID}/executions/${runID}/stop`)
+    ).map(parseNodeExecutionResultTimestamps);
+  }
+
   private async _get(path: string) {
     return this._request("GET", path);
   }
@@ -205,10 +208,13 @@ export default class AutoGPTServerAPI {
     };
 
     this.webSocket.onmessage = (event) => {
-      const message = JSON.parse(event.data);
-      if (this.wsMessageHandlers[message.method]) {
-        this.wsMessageHandlers[message.method](message.data);
+      const message: WebsocketMessage = JSON.parse(event.data);
+      if (message.method == "execution_event") {
+        message.data = parseNodeExecutionResultTimestamps(message.data);
       }
+      this.wsMessageHandlers[message.method]?.forEach((handler) =>
+        handler(message.data),
+      );
     };
   } catch (error) {
     console.error("Error connecting to WebSocket:", error);
@@ -248,8 +254,12 @@ export default class AutoGPTServerAPI {
   onWebSocketMessage<M extends keyof WebsocketMessageTypeMap>(
     method: M,
     handler: (data: WebsocketMessageTypeMap[M]) => void,
-  ) {
-    this.wsMessageHandlers[method] = handler;
+  ): () => void {
+    this.wsMessageHandlers[method] ??= new Set();
+    this.wsMessageHandlers[method].add(handler);
+
+    // Return detacher
+    return () => this.wsMessageHandlers[method].delete(handler);
   }
 
   subscribeToExecution(graphId: string) {
@@ -272,3 +282,22 @@ type WebsocketMessageTypeMap = {
   subscribe: { graph_id: string };
   execution_event: NodeExecutionResult;
 };
+
+type WebsocketMessage = {
+  [M in keyof WebsocketMessageTypeMap]: {
+    method: M;
+    data: WebsocketMessageTypeMap[M];
+  };
+}[keyof WebsocketMessageTypeMap];
+
+/* *** HELPER FUNCTIONS *** */
+
+function parseNodeExecutionResultTimestamps(result: any): NodeExecutionResult {
+  return {
+    ...result,
+    add_time: new Date(result.add_time),
+    queue_time: result.queue_time ? new Date(result.queue_time) : undefined,
+    start_time: result.start_time ? new Date(result.start_time) : undefined,
+    end_time: result.end_time ? new Date(result.end_time) : undefined,
+  };
+}

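Changing `wsMessageHandlers` from a single function per method to a `Set` of handlers is what makes the returned detacher possible, and it also means a second subscriber no longer silently overwrites the first. A minimal sketch of the same registry pattern, with placeholder payload types standing in for `WebsocketMessageTypeMap`:

```ts
// Set-based handler registry with a detacher, mirroring onWebSocketMessage
// above; the message map here uses placeholder payload types.
type MessageTypeMap = {
  subscribe: { graph_id: string };
  execution_event: { node_exec_id: string; status: string };
};

class HandlerRegistry {
  private handlers: Record<string, Set<(data: any) => void>> = {};

  on<M extends keyof MessageTypeMap>(
    method: M,
    handler: (data: MessageTypeMap[M]) => void,
  ): () => void {
    this.handlers[method] ??= new Set();
    this.handlers[method].add(handler);
    // Returning a detacher lets callers unsubscribe without a separate API
    return () => this.handlers[method].delete(handler);
  }

  dispatch<M extends keyof MessageTypeMap>(method: M, data: MessageTypeMap[M]) {
    this.handlers[method]?.forEach((h) => h(data));
  }
}

// Usage: attach, receive, detach
const registry = new HandlerRegistry();
const off = registry.on("execution_event", (e) => console.log(e.status));
registry.dispatch("execution_event", { node_exec_id: "n1", status: "COMPLETED" });
off();
```
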
@@ -75,33 +75,22 @@ class PrintingBlock(Block):
         yield "status", "printed"
 
 
-T = TypeVar("T")
-
-
-class ObjectLookupBaseInput(BlockSchema, Generic[T]):
-    input: T = Field(description="Dictionary to lookup from")
-    key: str | int = Field(description="Key to lookup in the dictionary")
-
-
-class ObjectLookupBaseOutput(BlockSchema, Generic[T]):
-    output: T = Field(description="Value found for the given key")
-    missing: T = Field(description="Value of the input that is missing the key")
-
-
-class ObjectLookupBase(Block, ABC, Generic[T]):
-    @abstractmethod
-    def block_id(self) -> str:
-        pass
-
-    def __init__(self, *args, **kwargs):
-        input_schema = ObjectLookupBaseInput[T]
-        output_schema = ObjectLookupBaseOutput[T]
-
+class ObjectLookupBlock(Block):
+    class Input(BlockSchema):
+        input: Any = Field(description="Dictionary to lookup from")
+        key: str | int = Field(description="Key to lookup in the dictionary")
+
+    class Output(BlockSchema):
+        output: Any = Field(description="Value found for the given key")
+        missing: Any = Field(description="Value of the input that is missing the key")
+
+    def __init__(self):
         super().__init__(
-            id=self.block_id(),
+            id="b2g2c3d4-5e6f-7g8h-9i0j-k1l2m3n4o5p6",
             description="Lookup the given key in the input dictionary/object/list and return the value.",
-            input_schema=input_schema,
-            output_schema=output_schema,
+            input_schema=ObjectLookupBlock.Input,
+            output_schema=ObjectLookupBlock.Output,
             test_input=[
                 {"input": {"apple": 1, "banana": 2, "cherry": 3}, "key": "banana"},
                 {"input": {"x": 10, "y": 20, "z": 30}, "key": "w"},
@@ -118,11 +107,10 @@ class ObjectLookupBase(Block, ABC, Generic[T]):
                 ("output", "key"),
                 ("output", ["v1", "v3"]),
             ],
-            *args,
-            **kwargs,
+            categories={BlockCategory.BASIC},
         )
 
-    def run(self, input_data: ObjectLookupBaseInput[T]) -> BlockOutput:
+    def run(self, input_data: Input) -> BlockOutput:
         obj = input_data.input
         key = input_data.key
 
@@ -143,15 +131,50 @@ class ObjectLookupBase(Block, ABC, Generic[T]):
             yield "missing", input_data.input
 
 
-class ObjectLookupBlock(ObjectLookupBase[Any]):
-    def __init__(self):
-        super().__init__(categories={BlockCategory.BASIC})
-
-    def block_id(self) -> str:
-        return "b2g2c3d4-5e6f-7g8h-9i0j-k1l2m3n4o5p6"
+T = TypeVar("T")
+
+
+class InputOutputBlockInput(BlockSchema, Generic[T]):
+    value: T = Field(description="The value to be passed as input/output.")
+    name: str = Field(description="The name of the input/output.")
+
+
+class InputOutputBlockOutput(BlockSchema, Generic[T]):
+    value: T = Field(description="The value passed as input/output.")
+
+
+class InputOutputBlockBase(Block, ABC, Generic[T]):
+    @abstractmethod
+    def block_id(self) -> str:
+        pass
+
+    def __init__(self, *args, **kwargs):
+        input_schema = InputOutputBlockInput[T]
+        output_schema = InputOutputBlockOutput[T]
+
+        super().__init__(
+            id=self.block_id(),
+            description="This block is used to define the input & output of a graph.",
+            input_schema=input_schema,
+            output_schema=output_schema,
+            test_input=[
+                {"value": {"apple": 1, "banana": 2, "cherry": 3}, "name": "input_1"},
+                {"value": MockObject(value="!!", key="key"), "name": "input_2"},
+            ],
+            test_output=[
+                ("value", {"apple": 1, "banana": 2, "cherry": 3}),
+                ("value", MockObject(value="!!", key="key")),
+            ],
+            static_output=True,
+            *args,
+            **kwargs,
+        )
+
+    def run(self, input_data: InputOutputBlockInput[T]) -> BlockOutput:
+        yield "value", input_data.value
 
 
-class InputBlock(ObjectLookupBase[Any]):
+class InputBlock(InputOutputBlockBase[Any]):
     def __init__(self):
         super().__init__(categories={BlockCategory.INPUT, BlockCategory.BASIC})
 
@@ -159,7 +182,7 @@ class InputBlock(ObjectLookupBase[Any]):
         return "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b"
 
 
-class OutputBlock(ObjectLookupBase[Any]):
+class OutputBlock(InputOutputBlockBase[Any]):
     def __init__(self):
         super().__init__(categories={BlockCategory.OUTPUT, BlockCategory.BASIC})

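The `InputOutputBlockBase` introduced here combines two ideas: input/output schemas generic over the payload type, and an abstract `block_id()` that each concrete subclass pins down. A rough TypeScript rendering of the same shape, purely illustrative (none of these names exist in either codebase):

```ts
// Hypothetical TypeScript analogue of the generic input/output block base.
interface InputOutputInput<T> {
  value: T; // the value to pass through
  name: string; // the name of the input/output pin
}

abstract class InputOutputBlockBase<T> {
  // Each concrete block supplies its own stable ID, like block_id() above
  abstract blockId(): string;

  // Pass-through: emit the value under the "value" pin, like run() above
  *run(input: InputOutputInput<T>): Generator<[string, T]> {
    yield ["value", input.value];
  }
}

class InputBlock extends InputOutputBlockBase<unknown> {
  blockId(): string {
    return "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b";
  }
}
```
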
@@ -1,6 +1,6 @@
 import time
 from datetime import datetime, timedelta
-from typing import Union
+from typing import Any, Union
 
 from autogpt_server.data.block import Block, BlockCategory, BlockOutput, BlockSchema
 
@@ -103,6 +103,7 @@ class CurrentDateAndTimeBlock(Block):
 
 class TimerBlock(Block):
     class Input(BlockSchema):
+        message: Any = "timer finished"
         seconds: Union[int, str] = 0
         minutes: Union[int, str] = 0
         hours: Union[int, str] = 0
@@ -120,9 +121,11 @@ class TimerBlock(Block):
             output_schema=TimerBlock.Output,
             test_input=[
                 {"seconds": 1},
+                {"message": "Custom message"},
             ],
             test_output=[
                 ("message", "timer finished"),
+                ("message", "Custom message"),
             ],
         )
 
@@ -136,4 +139,4 @@ class TimerBlock(Block):
         total_seconds = seconds + minutes * 60 + hours * 3600 + days * 86400
 
         time.sleep(total_seconds)
-        yield "message", "timer finished"
+        yield "message", input_data.message

@@ -9,7 +9,7 @@ from prisma.models import (
     AgentNodeExecution,
     AgentNodeExecutionInputOutput,
 )
-from prisma.types import AgentGraphExecutionWhereInput
+from prisma.types import AgentGraphExecutionWhereInput, AgentNodeExecutionInclude
 from pydantic import BaseModel
 
 from autogpt_server.data.block import BlockData, BlockInput, CompletedBlockOutput
@@ -108,7 +108,7 @@ class ExecutionResult(BaseModel):
 
 # --------------------- Model functions --------------------- #
 
-EXECUTION_RESULT_INCLUDE = {
+EXECUTION_RESULT_INCLUDE: AgentNodeExecutionInclude = {
     "Input": True,
     "Output": True,
     "AgentNode": True,
@@ -120,7 +120,7 @@ async def create_graph_execution(
     graph_id: str,
     graph_version: int,
     nodes_input: list[tuple[str, BlockInput]],
-    user_id: str,
+    user_id: str | None = None,
 ) -> tuple[str, list[ExecutionResult]]:
     """
     Create a new AgentGraphExecution record.
@@ -148,9 +148,7 @@ async def create_graph_execution(
             },
             "userId": user_id,
         },
-        include={
-            "AgentNodeExecutions": {"include": EXECUTION_RESULT_INCLUDE}  # type: ignore
-        },
+        include={"AgentNodeExecutions": {"include": EXECUTION_RESULT_INCLUDE}},  # type: ignore
     )
 
     return result.id, [
@@ -263,7 +261,7 @@ async def update_execution_status(
     res = await AgentNodeExecution.prisma().update(
         where={"id": node_exec_id},
         data=data,  # type: ignore
-        include=EXECUTION_RESULT_INCLUDE,  # type: ignore
+        include=EXECUTION_RESULT_INCLUDE,
     )
     if not res:
         raise ValueError(f"Execution {node_exec_id} not found.")
@@ -271,6 +269,26 @@ async def update_execution_status(
     return ExecutionResult.from_db(res)
 
 
+async def get_graph_execution(
+    graph_exec_id: str, user_id: str
+) -> AgentGraphExecution | None:
+    """
+    Retrieve a specific graph execution by its ID.
+
+    Args:
+        graph_exec_id (str): The ID of the graph execution to retrieve.
+        user_id (str): The ID of the user to whom the graph (execution) belongs.
+
+    Returns:
+        AgentGraphExecution | None: The graph execution if found, None otherwise.
+    """
+    execution = await AgentGraphExecution.prisma().find_first(
+        where={"id": graph_exec_id, "userId": user_id},
+        include={"AgentNodeExecutions": {"include": EXECUTION_RESULT_INCLUDE}},  # type: ignore
+    )
+    return execution
+
+
 async def list_executions(graph_id: str, graph_version: int | None = None) -> list[str]:
     where: AgentGraphExecutionWhereInput = {"agentGraphId": graph_id}
     if graph_version is not None:
@@ -282,7 +300,7 @@ async def list_executions(graph_id: str, graph_version: int | None = None) -> list[str]:
 async def get_execution_results(graph_exec_id: str) -> list[ExecutionResult]:
     executions = await AgentNodeExecution.prisma().find_many(
         where={"agentGraphExecutionId": graph_exec_id},
-        include=EXECUTION_RESULT_INCLUDE,  # type: ignore
+        include=EXECUTION_RESULT_INCLUDE,
         order=[
             {"queuedTime": "asc"},
             {"addedTime": "asc"},  # Fallback: incomplete executions have no queuedTime.
@@ -379,7 +397,7 @@ async def get_latest_execution(node_id: str, graph_eid: str) -> ExecutionResult
             "executionData": {"not": None},
         },
         order={"queuedTime": "desc"},
-        include=EXECUTION_RESULT_INCLUDE,  # type: ignore
+        include=EXECUTION_RESULT_INCLUDE,
     )
     if not execution:
         return None
@@ -395,6 +413,6 @@ async def get_incomplete_executions(
             "agentGraphExecutionId": graph_eid,
             "executionStatus": ExecutionStatus.INCOMPLETE,
         },
-        include=EXECUTION_RESULT_INCLUDE,  # type: ignore
+        include=EXECUTION_RESULT_INCLUDE,
     )
     return [ExecutionResult.from_db(execution) for execution in executions]

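Annotating `EXECUTION_RESULT_INCLUDE` as `AgentNodeExecutionInclude` gives every query one shared, typed `include` constant, which is why most of the `# type: ignore` comments could be dropped. For comparison, the same pattern in Prisma's TypeScript client, where `include` validators are native; the model and field names below are assumed to mirror the schema above, not taken from a real TS client in this repo:

```ts
// Sketch of a shared typed `include` constant with the Prisma TS client,
// assuming an AgentNodeExecution model like the one in the Python schema.
import { Prisma, PrismaClient } from "@prisma/client";

const prisma = new PrismaClient();

// One typed constant shared by every query, replacing per-call ignores
const EXECUTION_RESULT_INCLUDE = {
  Input: true,
  Output: true,
  AgentNode: true,
} satisfies Prisma.AgentNodeExecutionInclude;

async function getExecutionResults(graphExecId: string) {
  return prisma.agentNodeExecution.findMany({
    where: { agentGraphExecutionId: graphExecId },
    include: EXECUTION_RESULT_INCLUDE,
  });
}
```
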
@@ -294,6 +294,7 @@ async def get_graphs_meta(
 
     Args:
         filter_by: An optional filter to either select templates or active graphs.
+        user_id: An optional user ID to filter the graphs by.
 
     Returns:
         list[GraphMeta]: A list of objects representing the retrieved graph metadata.

@@ -1,7 +1,10 @@
 import asyncio
 import logging
-from concurrent.futures import Future, ProcessPoolExecutor, TimeoutError
+import multiprocessing
+import threading
+from concurrent.futures import Future, ProcessPoolExecutor
 from contextlib import contextmanager
+from multiprocessing.pool import AsyncResult, Pool
 from typing import TYPE_CHECKING, Any, Coroutine, Generator, TypeVar
 
 if TYPE_CHECKING:
@@ -16,6 +19,7 @@ from autogpt_server.data.execution import (
     GraphExecution,
     NodeExecution,
     create_graph_execution,
+    get_execution_results,
     get_incomplete_executions,
     get_latest_execution,
     merge_execution_input,
@@ -350,76 +354,111 @@ class Executor:
     @classmethod
     def on_graph_executor_start(cls):
         cls.pool_size = Config().num_node_workers
-        cls.executor = ProcessPoolExecutor(
-            max_workers=cls.pool_size,
-            initializer=cls.on_node_executor_start,
-        )
+        cls._init_node_executor_pool()
         logger.warning(f"Graph executor started with max-{cls.pool_size} node workers.")
 
     @classmethod
-    def on_graph_execution(cls, graph_data: GraphExecution):
+    def _init_node_executor_pool(cls):
+        cls.executor = Pool(
+            processes=cls.pool_size,
+            initializer=cls.on_node_executor_start,
+        )
+
+    @classmethod
+    def on_graph_execution(cls, graph_data: GraphExecution, cancel: threading.Event):
         prefix = get_log_prefix(graph_data.graph_exec_id, "*")
         logger.warning(f"{prefix} Start graph execution")
 
+        finished = False
+
+        def cancel_handler():
+            while not cancel.is_set():
+                cancel.wait(1)
+            if finished:
+                return
+            cls.executor.terminate()
+            logger.info(
+                f"{prefix} Terminated graph execution {graph_data.graph_exec_id}"
+            )
+            cls._init_node_executor_pool()
+
+        cancel_thread = threading.Thread(target=cancel_handler)
+        cancel_thread.start()
+
         try:
             queue = ExecutionQueue[NodeExecution]()
             for node_exec in graph_data.start_node_execs:
                 queue.add(node_exec)
 
-            futures: dict[str, Future] = {}
+            running_executions: dict[str, AsyncResult] = {}
             while not queue.empty():
-                execution = queue.get()
+                if cancel.is_set():
+                    return
+
+                exec_data = queue.get()
 
                 # Avoid parallel execution of the same node.
-                fut = futures.get(execution.node_id)
-                if fut and not fut.done():
+                execution = running_executions.get(exec_data.node_id)
+                if execution and not execution.ready():
                     # TODO (performance improvement):
                     # Waiting for the completion of the same node execution is blocking.
                     # To improve this we need a separate queue for each node.
                     # Re-enqueueing the data back to the queue will disrupt the order.
-                    cls.wait_future(fut, timeout=None)
+                    execution.wait()
 
-                futures[execution.node_id] = cls.executor.submit(
-                    cls.on_node_execution, queue, execution
+                running_executions[exec_data.node_id] = cls.executor.apply_async(
+                    cls.on_node_execution,
+                    (queue, exec_data),
+                    callback=lambda _: running_executions.pop(exec_data.node_id),
                 )
 
                 # Avoid terminating graph execution when some nodes are still running.
-                while queue.empty() and futures:
-                    for node_id, future in list(futures.items()):
-                        if future.done():
-                            del futures[node_id]
-                        elif queue.empty():
-                            cls.wait_future(future)
+                while queue.empty() and running_executions:
+                    for execution in list(running_executions.values()):
+                        if cancel.is_set():
+                            return
+
+                        if not queue.empty():
+                            break  # yield to parent loop to execute new queue items
+
+                        execution.wait(3)
 
             logger.warning(f"{prefix} Finished graph execution")
         except Exception as e:
             logger.exception(f"{prefix} Failed graph execution: {e}")
-
-    @classmethod
-    def wait_future(cls, future: Future, timeout: int | None = 3):
-        try:
-            if not future.done():
-                future.result(timeout=timeout)
-        except TimeoutError:
-            # Avoid being blocked by long-running node, by not waiting its completion.
-            pass
+        finally:
+            if not cancel.is_set():
+                finished = True
+                cancel.set()
+            cancel_thread.join()
 
 
 class ExecutionManager(AppService):
     def __init__(self):
         self.pool_size = Config().num_graph_workers
         self.queue = ExecutionQueue[GraphExecution]()
+        self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
 
     def run_service(self):
         with ProcessPoolExecutor(
             max_workers=self.pool_size,
             initializer=Executor.on_graph_executor_start,
         ) as executor:
-            logger.warning(
+            sync_manager = multiprocessing.Manager()
+            logger.info(
                 f"Execution manager started with max-{self.pool_size} graph workers."
             )
             while True:
-                executor.submit(Executor.on_graph_execution, self.queue.get())
+                graph_exec_data = self.queue.get()
+                graph_exec_id = graph_exec_data.graph_exec_id
+                cancel_event = sync_manager.Event()
+                future = executor.submit(
+                    Executor.on_graph_execution, graph_exec_data, cancel_event
+                )
+                self.active_graph_runs[graph_exec_id] = (future, cancel_event)
+                future.add_done_callback(
+                    lambda _: self.active_graph_runs.pop(graph_exec_id)
+                )
 
     @property
     def agent_server_client(self) -> "AgentServer":
@@ -427,8 +466,8 @@ class ExecutionManager(AppService):
 
     @expose
     def add_execution(
-        self, graph_id: str, data: BlockInput, user_id: str
-    ) -> dict[Any, Any]:
+        self, graph_id: str, data: BlockInput, user_id: str | None = None
+    ) -> dict[str, Any]:
         graph: Graph | None = self.run_and_wait(get_graph(graph_id, user_id=user_id))
         if not graph:
             raise Exception(f"Graph #{graph_id} not found.")
@@ -436,10 +475,11 @@ class ExecutionManager(AppService):
 
         nodes_input = []
         for node in graph.starting_nodes:
-            input_data = {}
             if isinstance(get_block(node.block_id), InputBlock):
-                input_data = {"input": data}
+                input_data = {}
+                name = node.input_default.get("name")
+                if name and name in data:
+                    input_data = {"value": data[name]}
+            else:
+                input_data = {}
 
             input_data, error = validate_exec(node, input_data)
             if input_data is None:
@@ -479,4 +519,45 @@ class ExecutionManager(AppService):
         )
         self.queue.add(graph_exec)
 
-        return {"id": graph_exec_id}
+        return graph_exec.model_dump()
+
+    @expose
+    def cancel_execution(self, graph_exec_id: str) -> None:
+        """
+        Mechanism:
+        1. Set the cancel event
+        2. Graph executor's cancel handler thread detects the event, terminates workers,
+           reinitializes the worker pool, and returns.
+        3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
+        """
+        if graph_exec_id not in self.active_graph_runs:
+            raise Exception(
+                f"Graph execution #{graph_exec_id} not active/running: "
+                "possibly already completed/cancelled."
+            )
+
+        future, cancel_event = self.active_graph_runs[graph_exec_id]
+        if cancel_event.is_set():
+            return
+
+        cancel_event.set()
+        future.result()
+
+        # Update the status of the unfinished node executions
+        node_execs = self.run_and_wait(get_execution_results(graph_exec_id))
+        for node_exec in node_execs:
+            if node_exec.status not in (
+                ExecutionStatus.COMPLETED,
+                ExecutionStatus.FAILED,
+            ):
+                exec_update = self.run_and_wait(
+                    update_execution_status(
+                        node_exec.node_exec_id, ExecutionStatus.FAILED
+                    )
+                )
+                self.run_and_wait(
+                    upsert_execution_output(
+                        node_exec.node_exec_id, "error", "TERMINATED"
+                    )
+                )
+                self.agent_server_client.send_execution_update(exec_update.model_dump())

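The docstring of `cancel_execution` describes a three-step handshake: signal, cooperative teardown inside the executor, then a DB status sweep. An analogous pattern sketched in TypeScript, with an `AbortController` standing in for the shared `multiprocessing` event and the run registry mirroring `active_graph_runs`; all names below are illustrative, not part of the codebase:

```ts
// Illustrative TS analogue of the cancel-event mechanism described above.
type ActiveRun = { done: Promise<void>; controller: AbortController };

const activeGraphRuns = new Map<string, ActiveRun>();

function startRun(
  graphExecId: string,
  work: (signal: AbortSignal) => Promise<void>,
): void {
  const controller = new AbortController();
  // Like future.add_done_callback above, drop the run from the registry on completion
  const done = work(controller.signal).finally(() =>
    activeGraphRuns.delete(graphExecId),
  );
  activeGraphRuns.set(graphExecId, { done, controller });
}

async function cancelExecution(graphExecId: string): Promise<void> {
  const run = activeGraphRuns.get(graphExecId);
  if (!run) {
    throw new Error(`Graph execution #${graphExecId} not active/running`);
  }
  if (run.controller.signal.aborted) return; // cancellation already underway
  // 1. Signal cancellation; 2. the worker observes the signal and stops;
  // 3. the caller can then mark unfinished node executions as FAILED.
  run.controller.abort();
  await run.done;
}
```
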
@@ -37,7 +37,8 @@ class ExecutionScheduler(AppService):
     def __refresh_jobs_from_db(self, scheduler: BackgroundScheduler):
         schedules = self.run_and_wait(model.get_active_schedules(self.last_check))
         for schedule in schedules:
-            self.last_check = max(self.last_check, schedule.last_updated)
+            if schedule.last_updated:
+                self.last_check = max(self.last_check, schedule.last_updated)
 
             if not schedule.is_enabled:
                 log(f"Removing recurring job {schedule.id}: {schedule.schedule}")

@@ -22,14 +22,10 @@ from fastapi.responses import JSONResponse
 
 import autogpt_server.server.ws_api
 from autogpt_server.data import block, db
+from autogpt_server.data import execution as execution_db
 from autogpt_server.data import graph as graph_db
 from autogpt_server.data import user as user_db
 from autogpt_server.data.block import BlockInput, CompletedBlockOutput
-from autogpt_server.data.execution import (
-    ExecutionResult,
-    get_execution_results,
-    list_executions,
-)
 from autogpt_server.data.user import DEFAULT_USER_ID, get_or_create_user
 from autogpt_server.executor import ExecutionManager, ExecutionScheduler
 from autogpt_server.server.conn_manager import ConnectionManager
@@ -58,7 +54,7 @@ def get_user_id(payload: dict = Depends(auth_middleware)) -> str:
 
 
 class AgentServer(AppService):
-    event_queue: asyncio.Queue[ExecutionResult] = asyncio.Queue()
+    event_queue: asyncio.Queue[execution_db.ExecutionResult] = asyncio.Queue()
     manager = ConnectionManager()
     mutex = KeyedMutex()
     use_db = False
@@ -66,7 +62,7 @@ class AgentServer(AppService):
 
     async def event_broadcaster(self):
         while True:
-            event: ExecutionResult = await self.event_queue.get()
+            event: execution_db.ExecutionResult = await self.event_queue.get()
             await self.manager.send_execution_result(event)
 
     @asynccontextmanager
@@ -193,10 +189,15 @@ class AgentServer(AppService):
             methods=["GET"],
         )
         router.add_api_route(
-            path="/graphs/{graph_id}/executions/{run_id}",
-            endpoint=self.get_run_execution_results,
+            path="/graphs/{graph_id}/executions/{graph_exec_id}",
+            endpoint=self.get_graph_run_node_execution_results,
             methods=["GET"],
         )
+        router.add_api_route(
+            path="/graphs/{graph_id}/executions/{graph_exec_id}/stop",
+            endpoint=self.stop_graph_run,
+            methods=["POST"],
+        )
         router.add_api_route(
             path="/graphs/{graph_id}/schedules",
             endpoint=self.create_schedule,  # type: ignore
@@ -508,15 +509,29 @@ class AgentServer(AppService):
         graph_id: str,
         node_input: dict[Any, Any],
         user_id: Annotated[str, Depends(get_user_id)],
-    ) -> dict[Any, Any]:
+    ) -> dict[str, Any]:  # FIXME: add proper return type
         try:
-            return self.execution_manager_client.add_execution(
+            graph_exec = self.execution_manager_client.add_execution(
                 graph_id, node_input, user_id=user_id
             )
+            return {"id": graph_exec["graph_exec_id"]}
         except Exception as e:
             msg = e.__str__().encode().decode("unicode_escape")
             raise HTTPException(status_code=400, detail=msg)
 
+    async def stop_graph_run(
+        self, graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
+    ) -> list[execution_db.ExecutionResult]:
+        if not await execution_db.get_graph_execution(graph_exec_id, user_id):
+            raise HTTPException(
+                404, detail=f"Agent execution #{graph_exec_id} not found"
+            )
+
+        self.execution_manager_client.cancel_execution(graph_exec_id)
+
+        # Retrieve & return canceled graph execution in its final state
+        return await execution_db.get_execution_results(graph_exec_id)
+
     @classmethod
     async def list_graph_runs(
         cls,
@@ -531,17 +546,20 @@ class AgentServer(AppService):
                 status_code=404, detail=f"Agent #{graph_id}{rev} not found."
             )
 
-        return await list_executions(graph_id, graph_version)
+        return await execution_db.list_executions(graph_id, graph_version)
 
     @classmethod
-    async def get_run_execution_results(
-        cls, graph_id: str, run_id: str, user_id: Annotated[str, Depends(get_user_id)]
-    ) -> list[ExecutionResult]:
+    async def get_graph_run_node_execution_results(
+        cls,
+        graph_id: str,
+        graph_exec_id: str,
+        user_id: Annotated[str, Depends(get_user_id)],
+    ) -> list[execution_db.ExecutionResult]:
         graph = await graph_db.get_graph(graph_id, user_id=user_id)
         if not graph:
             raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
 
-        return await get_execution_results(run_id)
+        return await execution_db.get_execution_results(graph_exec_id)
 
     async def create_schedule(
         self,
@@ -579,7 +597,7 @@ class AgentServer(AppService):
 
     @expose
     def send_execution_update(self, execution_result_dict: dict[Any, Any]):
-        execution_result = ExecutionResult(**execution_result_dict)
+        execution_result = execution_db.ExecutionResult(**execution_result_dict)
         self.run_and_wait(self.event_queue.put(execution_result))
 
     @expose

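Together with the route registration above, stopping a run from a client is one authenticated POST that responds with the run's node executions in their final state. A hedged sketch of that call, assuming a `BASE_URL`, bearer-token auth, and the `NodeExecutionResult` type from the frontend API module:

```ts
// Sketch of calling the new stop route directly; BASE_URL and the auth
// header are placeholders, not values taken from the codebase.
import type { NodeExecutionResult } from "@/lib/autogpt-server-api";

const BASE_URL = "http://localhost:8000/api";

async function stopGraphRun(
  graphID: string,
  graphExecID: string,
  token: string,
): Promise<NodeExecutionResult[]> {
  // POST /graphs/{graph_id}/executions/{graph_exec_id}/stop returns the
  // run's node execution results in their final (terminated) state
  const res = await fetch(
    `${BASE_URL}/graphs/${graphID}/executions/${graphExecID}/stop`,
    { method: "POST", headers: { Authorization: `Bearer ${token}` } },
  );
  if (!res.ok) throw new Error(`Stop failed: ${res.status}`); // e.g. 404 if not found or not owned
  return res.json();
}
```
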
@@ -29,11 +29,11 @@ def create_test_graph() -> graph.Graph:
     nodes = [
         graph.Node(
             block_id=InputBlock().id,
-            input_default={"key": "input_1"},
+            input_default={"name": "input_1"},
         ),
         graph.Node(
             block_id=InputBlock().id,
-            input_default={"key": "input_2"},
+            input_default={"name": "input_2"},
         ),
         graph.Node(
             block_id=TextFormatterBlock().id,
@@ -48,13 +48,13 @@ def create_test_graph() -> graph.Graph:
         graph.Link(
             source_id=nodes[0].id,
             sink_id=nodes[2].id,
-            source_name="output",
+            source_name="value",
             sink_name="values_#_a",
         ),
         graph.Link(
             source_id=nodes[1].id,
             sink_id=nodes[2].id,
-            source_name="output",
+            source_name="value",
             sink_name="values_#_b",
         ),
         graph.Link(

@@ -16,11 +16,21 @@ from autogpt_server.util.settings import Config
 logger = logging.getLogger(__name__)
 conn_retry = retry(stop=stop_after_delay(5), wait=wait_exponential(multiplier=0.1))
 T = TypeVar("T")
+C = TypeVar("C", bound=Callable)
 
 pyro_host = Config().pyro_host
 
 
-def expose(func: Callable) -> Callable:
+def expose(func: C) -> C:
+    """
+    Decorator to mark a method or class to be exposed for remote calls.
+
+    ## ⚠️ Gotcha
+    The types on the exposed function signature are respected **as long as they are
+    fully picklable**. This is not the case for Pydantic models, so if you really need
+    to pass a model, try dumping the model and passing the resulting dict instead.
+    """
+
     def wrapper(*args, **kwargs):
         try:
             return func(*args, **kwargs)
@@ -29,7 +39,7 @@ def expose(func: C) -> C:
             logger.exception(msg)
             raise Exception(msg, e)
 
-    return pyro.expose(wrapper)
+    return pyro.expose(wrapper)  # type: ignore
 
 
 class PyroNameServer(AppProcess):
@@ -58,7 +68,7 @@ class AppService(AppProcess):
     def __run_async(self, coro: Coroutine[T, Any, T]):
         return asyncio.run_coroutine_threadsafe(coro, self.shared_event_loop)
 
-    def run_and_wait(self, coro: Coroutine[T, Any, T]) -> T:
+    def run_and_wait(self, coro: Coroutine[Any, Any, T]) -> T:
         future = self.__run_async(coro)
         return future.result()
 
@@ -100,7 +110,6 @@ def get_service_client(service_type: Type[AS]) -> AS:
     service_name = service_type.service_name
 
     class DynamicClient:
-
         @conn_retry
         def __init__(self):
             ns = pyro.locate_ns()

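Typing `expose` as `(func: C) -> C` with `C = TypeVar("C", bound=Callable)` preserves the wrapped function's signature for callers, at the cost of one `# type: ignore` where the wrapper is returned. The identical trick rendered in TypeScript, purely illustrative (the Pyro-specific behavior is elided):

```ts
// A wrapper that logs and rethrows while preserving the wrapped signature,
// mirroring the generic typing of expose() above.
function expose<C extends (...args: any[]) => any>(func: C): C {
  const wrapper = (...args: Parameters<C>): ReturnType<C> => {
    try {
      return func(...args);
    } catch (e) {
      console.error(`Error in ${func.name}:`, e);
      throw e;
    }
  };
  // This cast plays the role of Python's `# type: ignore`: the wrapper is
  // known to match C's call signature, but the compiler cannot prove it.
  return wrapper as unknown as C;
}
```
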
@@ -23,7 +23,6 @@ class SpinTestServer:
         return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
 
     async def __aenter__(self):
-
         self.name_server.__enter__()
         self.setup_dependency_overrides()
         self.agent_server.__enter__()
@@ -59,7 +58,7 @@ async def wait_execution(
     timeout: int = 20,
 ) -> list:
     async def is_execution_completed():
-        execs = await AgentServer().get_run_execution_results(
+        execs = await AgentServer().get_graph_run_node_execution_results(
             graph_id, graph_exec_id, user_id
         )
         return (
@@ -74,7 +73,7 @@ async def wait_execution(
     # Wait for the executions to complete
     for i in range(timeout):
         if await is_execution_completed():
-            return await AgentServer().get_run_execution_results(
+            return await AgentServer().get_graph_run_node_execution_results(
                 graph_id, graph_exec_id, user_id
             )
         time.sleep(1)

@@ -0,0 +1,20 @@
+-- RedefineTables
+PRAGMA foreign_keys=OFF;
+CREATE TABLE "new_AgentGraphExecutionSchedule" (
+    "id" TEXT NOT NULL PRIMARY KEY,
+    "agentGraphId" TEXT NOT NULL,
+    "agentGraphVersion" INTEGER NOT NULL DEFAULT 1,
+    "schedule" TEXT NOT NULL,
+    "isEnabled" BOOLEAN NOT NULL DEFAULT true,
+    "inputData" TEXT NOT NULL,
+    "lastUpdated" DATETIME NOT NULL,
+    "userId" TEXT,
+    CONSTRAINT "AgentGraphExecutionSchedule_agentGraphId_agentGraphVersion_fkey" FOREIGN KEY ("agentGraphId", "agentGraphVersion") REFERENCES "AgentGraph" ("id", "version") ON DELETE RESTRICT ON UPDATE CASCADE,
+    CONSTRAINT "AgentGraphExecutionSchedule_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User" ("id") ON DELETE SET NULL ON UPDATE CASCADE
+);
+INSERT INTO "new_AgentGraphExecutionSchedule" ("agentGraphId", "agentGraphVersion", "id", "inputData", "isEnabled", "lastUpdated", "schedule", "userId") SELECT "agentGraphId", "agentGraphVersion", "id", "inputData", "isEnabled", "lastUpdated", "schedule", "userId" FROM "AgentGraphExecutionSchedule";
+DROP TABLE "AgentGraphExecutionSchedule";
+ALTER TABLE "new_AgentGraphExecutionSchedule" RENAME TO "AgentGraphExecutionSchedule";
+CREATE INDEX "AgentGraphExecutionSchedule_isEnabled_idx" ON "AgentGraphExecutionSchedule"("isEnabled");
+PRAGMA foreign_key_check;
+PRAGMA foreign_keys=ON;

@@ -0,0 +1,8 @@
+-- DropForeignKey
+ALTER TABLE "AgentGraphExecutionSchedule" DROP CONSTRAINT "AgentGraphExecutionSchedule_userId_fkey";
+
+-- AlterTable
+ALTER TABLE "AgentGraphExecutionSchedule" ALTER COLUMN "userId" DROP NOT NULL;
+
+-- AddForeignKey
+ALTER TABLE "AgentGraphExecutionSchedule" ADD CONSTRAINT "AgentGraphExecutionSchedule_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE SET NULL ON UPDATE CASCADE;

@@ -12,11 +12,11 @@ generator client {
 
 // User model to mirror Auth provider users
 model User {
-  id        String   @id // This should match the Supabase user ID
-  email     String   @unique
-  name      String?
-  createdAt DateTime @default(now())
-  updatedAt DateTime @updatedAt
+  id          String   @id // This should match the Supabase user ID
+  email       String   @unique
+  name        String?
+  createdAt   DateTime @default(now())
+  updatedAt   DateTime @updatedAt
 
   // Relations
   AgentGraphs AgentGraph[]
@@ -184,8 +184,8 @@ model AgentGraphExecutionSchedule {
   lastUpdated DateTime @updatedAt
 
   // Link to User model
-  userId String
-  user   User   @relation(fields: [userId], references: [id])
+  userId String?
+  user   User?  @relation(fields: [userId], references: [id])
 
   @@index([isEnabled])
 }

@@ -11,11 +11,11 @@ generator client {
 
 // User model to mirror Auth provider users
 model User {
-  id        String   @id // This should match the Supabase user ID
-  email     String   @unique
-  name      String?
-  createdAt DateTime @default(now())
-  updatedAt DateTime @updatedAt
+  id          String   @id // This should match the Supabase user ID
+  email       String   @unique
+  name        String?
+  createdAt   DateTime @default(now())
+  updatedAt   DateTime @updatedAt
 
   // Relations
   AgentGraphs AgentGraph[]
@@ -183,8 +183,8 @@ model AgentGraphExecutionSchedule {
   lastUpdated DateTime @updatedAt
 
   // Link to User model
-  userId String
-  user   User   @relation(fields: [userId], references: [id])
+  userId String?
+  user   User?  @relation(fields: [userId], references: [id])
 
   @@index([isEnabled])
 }

@@ -35,8 +35,7 @@ async def assert_sample_graph_executions(
     test_user: User,
     graph_exec_id: str,
 ):
-    input = {"input_1": "Hello", "input_2": "World"}
-    executions = await agent_server.get_run_execution_results(
+    executions = await agent_server.get_graph_run_node_execution_results(
         test_graph.id, graph_exec_id, test_user.id
     )
 
@@ -44,16 +43,16 @@ async def assert_sample_graph_executions(
     exec = executions[0]
     assert exec.status == execution.ExecutionStatus.COMPLETED
     assert exec.graph_exec_id == graph_exec_id
-    assert exec.output_data == {"output": ["Hello"]}
-    assert exec.input_data == {"input": input, "key": "input_1"}
+    assert exec.output_data == {"value": ["Hello"]}
+    assert exec.input_data == {"value": "Hello", "name": "input_1"}
     assert exec.node_id in [test_graph.nodes[0].id, test_graph.nodes[1].id]
 
     # Executing ValueBlock
     exec = executions[1]
     assert exec.status == execution.ExecutionStatus.COMPLETED
     assert exec.graph_exec_id == graph_exec_id
-    assert exec.output_data == {"output": ["World"]}
-    assert exec.input_data == {"input": input, "key": "input_2"}
+    assert exec.output_data == {"value": ["World"]}
+    assert exec.input_data == {"value": "World", "name": "input_2"}
     assert exec.node_id in [test_graph.nodes[0].id, test_graph.nodes[1].id]
 
     # Executing TextFormatterBlock
@@ -151,7 +150,7 @@ async def test_input_pin_always_waited(server):
         server.agent_server, server.exec_manager, test_graph, test_user, {}, 3
     )
 
-    executions = await server.agent_server.get_run_execution_results(
+    executions = await server.agent_server.get_graph_run_node_execution_results(
         test_graph.id, graph_exec_id, test_user.id
     )
     assert len(executions) == 3
@@ -231,7 +230,7 @@ async def test_static_input_link_on_graph(server):
     graph_exec_id = await execute_graph(
         server.agent_server, server.exec_manager, test_graph, test_user, {}, 8
     )
-    executions = await server.agent_server.get_run_execution_results(
+    executions = await server.agent_server.get_graph_run_node_execution_results(
         test_graph.id, graph_exec_id, test_user.id
     )
     assert len(executions) == 8