Mirror of https://github.com/Significant-Gravitas/AutoGPT.git, synced 2026-01-20 04:28:09 -05:00.

Compare commits: make-old-w...reinier/ws (11 commits)

| SHA1 |
|---|
| e2c2b48d77 |
| bd36acccb2 |
| d63ab9a2f9 |
| fce6394a49 |
| 13e7716424 |
| 2973567010 |
| b6c4fc4742 |
| f9a3170296 |
| a74f76893e |
| e6aaf71f21 |
| 31129bd080 |
**Frontend API client (TypeScript):**

```diff
@@ -126,16 +126,19 @@ export default class AutoGPTServerAPI {
     runID: string,
   ): Promise<NodeExecutionResult[]> {
     return (await this._get(`/graphs/${graphID}/executions/${runID}`)).map(
-      (result: any) => ({
-        ...result,
-        add_time: new Date(result.add_time),
-        queue_time: result.queue_time ? new Date(result.queue_time) : undefined,
-        start_time: result.start_time ? new Date(result.start_time) : undefined,
-        end_time: result.end_time ? new Date(result.end_time) : undefined,
-      }),
+      parseNodeExecutionResultTimestamps,
     );
   }
 
+  async stopGraphExecution(
+    graphID: string,
+    runID: string,
+  ): Promise<NodeExecutionResult[]> {
+    return (
+      await this._request("POST", `/graphs/${graphID}/executions/${runID}/stop`)
+    ).map(parseNodeExecutionResultTimestamps);
+  }
+
   private async _get(path: string) {
     return this._request("GET", path);
   }
@@ -272,3 +275,15 @@ type WebsocketMessageTypeMap = {
   subscribe: { graph_id: string };
   execution_event: NodeExecutionResult;
 };
+
+/* *** HELPER FUNCTIONS *** */
+
+function parseNodeExecutionResultTimestamps(result: any): NodeExecutionResult {
+  return {
+    ...result,
+    add_time: new Date(result.add_time),
+    queue_time: result.queue_time ? new Date(result.queue_time) : undefined,
+    start_time: result.start_time ? new Date(result.start_time) : undefined,
+    end_time: result.end_time ? new Date(result.end_time) : undefined,
+  };
+}
```
**Execution data layer (Python):**

```diff
@@ -271,6 +271,28 @@ async def update_execution_status(
     return ExecutionResult.from_db(res)
 
 
+async def get_graph_execution(
+    graph_exec_id: str, user_id: str
+) -> AgentGraphExecution | None:
+    """
+    Retrieve a specific graph execution by its ID.
+
+    Args:
+        graph_exec_id (str): The ID of the graph execution to retrieve.
+        user_id (str): The ID of the user to whom the graph (execution) belongs.
+
+    Returns:
+        AgentGraphExecution | None: The graph execution if found, None otherwise.
+    """
+    execution = await AgentGraphExecution.prisma().find_first(
+        where={"id": graph_exec_id, "userId": user_id},
+        include={
+            "agentNodeExecutions": {"include": EXECUTION_RESULT_INCLUDE}  # type: ignore
+        },
+    )
+    return execution
+
+
 async def list_executions(graph_id: str, graph_version: int | None = None) -> list[str]:
     where: AgentGraphExecutionWhereInput = {"agentGraphId": graph_id}
     if graph_version is not None:
```
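Because the lookup is scoped to both the execution ID and the user ID, it doubles as an ownership check. A minimal usage sketch (the IDs are hypothetical, and the `agentNodeExecutions` attribute name is assumed from the `include=` clause above):

```python
import asyncio

from autogpt_server.data.execution import get_graph_execution


async def main():
    # Hypothetical IDs; a mismatched user_id simply yields None,
    # so callers can treat None as "not found or not yours".
    execution = await get_graph_execution("graph-exec-id", "user-id")
    if execution is None:
        print("Not found, or owned by another user")
    else:
        # Relation field assumed to be populated by the include= above.
        print(f"{len(execution.agentNodeExecutions or [])} node executions")


asyncio.run(main())
```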
**Execution manager/executor (Python):**

```diff
@@ -1,7 +1,10 @@
 import asyncio
 import logging
-from concurrent.futures import Future, ProcessPoolExecutor, TimeoutError
+import threading
+from concurrent.futures import Future, ProcessPoolExecutor
 from contextlib import contextmanager
+from multiprocessing.pool import AsyncResult, Pool
+from multiprocessing.synchronize import Event
 from typing import TYPE_CHECKING, Any, Coroutine, Generator, TypeVar
 
 if TYPE_CHECKING:
@@ -16,6 +19,7 @@ from autogpt_server.data.execution import (
     GraphExecution,
     NodeExecution,
     create_graph_execution,
+    get_execution_results,
     get_incomplete_executions,
     get_latest_execution,
     merge_execution_input,
@@ -350,76 +354,110 @@ class Executor:
     @classmethod
     def on_graph_executor_start(cls):
         cls.pool_size = Config().num_node_workers
-        cls.executor = ProcessPoolExecutor(
-            max_workers=cls.pool_size,
-            initializer=cls.on_node_executor_start,
-        )
+        cls._init_node_executor_pool()
         logger.warning(f"Graph executor started with max-{cls.pool_size} node workers.")
 
     @classmethod
-    def on_graph_execution(cls, graph_data: GraphExecution):
+    def _init_node_executor_pool(cls):
+        cls.executor = Pool(
+            processes=cls.pool_size,
+            initializer=cls.on_node_executor_start,
+        )
+
+    @classmethod
+    def on_graph_execution(cls, graph_data: GraphExecution, cancel: Event):
         prefix = get_log_prefix(graph_data.graph_exec_id, "*")
         logger.warning(f"{prefix} Start graph execution")
 
+        finished = False
+
+        def cancel_handler():
+            while not cancel.is_set():
+                cancel.wait(1)
+                if finished:
+                    return
+            cls.executor.terminate()
+            logger.info(
+                f"{prefix} Terminated graph execution {graph_data.graph_exec_id}"
+            )
+            cls._init_node_executor_pool()
+
+        cancel_thread = threading.Thread(target=cancel_handler)
+        cancel_thread.start()
+
         try:
             queue = ExecutionQueue[NodeExecution]()
             for node_exec in graph_data.start_node_execs:
                 queue.add(node_exec)
 
-            futures: dict[str, Future] = {}
+            running_executions: dict[str, AsyncResult] = {}
             while not queue.empty():
-                execution = queue.get()
+                if cancel.is_set():
+                    return
+
+                exec_data = queue.get()
 
                 # Avoid parallel execution of the same node.
-                fut = futures.get(execution.node_id)
-                if fut and not fut.done():
+                execution = running_executions.get(exec_data.node_id)
+                if execution and not execution.ready():
                     # TODO (performance improvement):
                     # Wait for the completion of the same node execution is blocking.
                     # To improve this we need a separate queue for each node.
                     # Re-enqueueing the data back to the queue will disrupt the order.
-                    cls.wait_future(fut, timeout=None)
+                    execution.wait()
 
-                futures[execution.node_id] = cls.executor.submit(
-                    cls.on_node_execution, queue, execution
+                running_executions[exec_data.node_id] = cls.executor.apply_async(
+                    cls.on_node_execution,
+                    (queue, exec_data),
+                    callback=lambda _: running_executions.pop(exec_data.node_id),
                 )
 
                 # Avoid terminating graph execution when some nodes are still running.
-                while queue.empty() and futures:
-                    for node_id, future in list(futures.items()):
-                        if future.done():
-                            del futures[node_id]
-                        elif queue.empty():
-                            cls.wait_future(future)
+                while queue.empty() and running_executions:
+                    for execution in list(running_executions.values()):
+                        if cancel.is_set():
+                            return
+
+                        if not queue.empty():
+                            break  # yield to parent loop to execute new queue items
+
+                        execution.wait(3)
 
             logger.warning(f"{prefix} Finished graph execution")
         except Exception as e:
             logger.exception(f"{prefix} Failed graph execution: {e}")
-
-    @classmethod
-    def wait_future(cls, future: Future, timeout: int | None = 3):
-        try:
-            if not future.done():
-                future.result(timeout=timeout)
-        except TimeoutError:
-            # Avoid being blocked by long-running node, by not waiting its completion.
-            pass
+        finally:
+            if not cancel.is_set():
+                finished = True
+                cancel.set()
+            cancel_thread.join()
 
 
 class ExecutionManager(AppService):
     def __init__(self):
         self.pool_size = Config().num_graph_workers
         self.queue = ExecutionQueue[GraphExecution]()
+        self.active_graph_runs: dict[str, tuple[Future, Event]] = {}
 
     def run_service(self):
         with ProcessPoolExecutor(
             max_workers=self.pool_size,
             initializer=Executor.on_graph_executor_start,
         ) as executor:
-            logger.warning(
+            logger.info(
                 f"Execution manager started with max-{self.pool_size} graph workers."
             )
             while True:
-                executor.submit(Executor.on_graph_execution, self.queue.get())
+                graph_exec_data = self.queue.get()
+                graph_exec_id = graph_exec_data.graph_exec_id
+                cancel_event = executor._mp_context.Event()
+                future = executor.submit(
+                    Executor.on_graph_execution, graph_exec_data, cancel_event
+                )
+                self.active_graph_runs[graph_exec_id] = (future, cancel_event)
+                future.add_done_callback(
+                    lambda _: self.active_graph_runs.pop(graph_exec_id)
+                )
 
     @property
     def agent_server_client(self) -> "AgentServer":
```
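The cancellation mechanism combines three pieces: a multiprocessing `Event` that crosses the process boundary (a plain `threading.Event` would not, which is presumably why the pool's own `_mp_context` is used to create it), a watcher thread that polls the event, and `Pool.terminate()` to kill in-flight workers, after which the pool is rebuilt for reuse. A standalone sketch of the same pattern, not the project's code; the task and timings are illustrative:

```python
import multiprocessing
import threading
import time
from multiprocessing.pool import Pool
from multiprocessing.synchronize import Event


def slow_task(n: int) -> int:
    time.sleep(60)  # stands in for a long-running node execution
    return n


def run_with_cancel(cancel: Event) -> None:
    pool: Pool = multiprocessing.Pool(processes=2)
    finished = False

    def cancel_handler():
        # Wake up every second so a normal finish can stop the watcher.
        while not cancel.is_set():
            cancel.wait(1)
            if finished:
                return
        pool.terminate()  # kills in-flight workers; the real code then rebuilds the pool

    watcher = threading.Thread(target=cancel_handler)
    watcher.start()
    try:
        result = pool.apply_async(slow_task, (42,))
        result.wait(3)  # bounded wait, so cancellation is noticed promptly
    finally:
        if not cancel.is_set():
            finished = True
            cancel.set()  # unblock the watcher so join() returns
        watcher.join()


if __name__ == "__main__":
    cancel_event = multiprocessing.Event()
    threading.Timer(2, cancel_event.set).start()  # simulate a cancel request
    run_with_cancel(cancel_event)
```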
Continuing in the execution manager:

```diff
@@ -428,7 +466,7 @@ class ExecutionManager(AppService):
     @expose
     def add_execution(
         self, graph_id: str, data: BlockInput, user_id: str
-    ) -> dict[Any, Any]:
+    ) -> dict[str, Any]:
         graph: Graph | None = self.run_and_wait(get_graph(graph_id, user_id=user_id))
         if not graph:
             raise Exception(f"Graph #{graph_id} not found.")
@@ -479,4 +517,45 @@ class ExecutionManager(AppService):
         )
         self.queue.add(graph_exec)
 
-        return {"id": graph_exec_id}
+        return graph_exec.model_dump()
+
+    @expose
+    def cancel_execution(self, graph_exec_id: str) -> None:
+        """
+        Mechanism:
+        1. Set the cancel event
+        2. Graph executor's cancel handler thread detects the event, terminates workers,
+           reinitializes worker pool, and returns.
+        3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
+        """
+        if graph_exec_id not in self.active_graph_runs:
+            raise Exception(
+                f"Graph execution #{graph_exec_id} not active/running: "
+                "possibly already completed/cancelled."
+            )
+
+        future, cancel_event = self.active_graph_runs[graph_exec_id]
+        if cancel_event.is_set():
+            return
+
+        cancel_event.set()
+        future.result()
+
+        # Update the status of the unfinished node executions
+        node_execs = self.run_and_wait(get_execution_results(graph_exec_id))
+        for node_exec in node_execs:
+            if node_exec.status not in (
+                ExecutionStatus.COMPLETED,
+                ExecutionStatus.FAILED,
+            ):
+                exec_update = self.run_and_wait(
+                    update_execution_status(
+                        node_exec.node_exec_id, ExecutionStatus.FAILED
+                    )
+                )
+                self.run_and_wait(
+                    upsert_execution_output(
+                        node_exec.node_exec_id, "error", "TERMINATED"
+                    )
+                )
+                self.agent_server_client.send_execution_update(exec_update.model_dump())
```
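Since `cancel_execution` is `@expose`d, other processes can invoke it through the Pyro-based service client, just as `AgentServer` does via `execution_manager_client`. A hedged caller sketch (the `autogpt_server.util.service` module path is an assumption, and the execution ID is hypothetical):

```python
from autogpt_server.executor import ExecutionManager
from autogpt_server.util.service import get_service_client  # assumed module path

# Obtain a dynamic Pyro proxy for the running ExecutionManager service.
execution_manager = get_service_client(ExecutionManager)

# Blocks until the graph run is torn down; raises if the run is not active.
execution_manager.cancel_execution("graph-exec-id")  # hypothetical ID
```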
**Agent server (Python):**

```diff
@@ -22,14 +22,10 @@ from fastapi.responses import JSONResponse
 
 import autogpt_server.server.ws_api
 from autogpt_server.data import block, db
+from autogpt_server.data import execution as execution_db
 from autogpt_server.data import graph as graph_db
 from autogpt_server.data import user as user_db
 from autogpt_server.data.block import BlockInput, CompletedBlockOutput
-from autogpt_server.data.execution import (
-    ExecutionResult,
-    get_execution_results,
-    list_executions,
-)
 from autogpt_server.data.user import DEFAULT_USER_ID, get_or_create_user
 from autogpt_server.executor import ExecutionManager, ExecutionScheduler
 from autogpt_server.server.conn_manager import ConnectionManager
@@ -58,7 +54,7 @@ def get_user_id(payload: dict = Depends(auth_middleware)) -> str:
 
 
 class AgentServer(AppService):
-    event_queue: asyncio.Queue[ExecutionResult] = asyncio.Queue()
+    event_queue: asyncio.Queue[execution_db.ExecutionResult] = asyncio.Queue()
     manager = ConnectionManager()
     mutex = KeyedMutex()
     use_db = False
@@ -66,7 +62,7 @@ class AgentServer(AppService):
 
     async def event_broadcaster(self):
         while True:
-            event: ExecutionResult = await self.event_queue.get()
+            event: execution_db.ExecutionResult = await self.event_queue.get()
             await self.manager.send_execution_result(event)
 
     @asynccontextmanager
@@ -193,10 +189,15 @@ class AgentServer(AppService):
             methods=["GET"],
         )
         router.add_api_route(
-            path="/graphs/{graph_id}/executions/{run_id}",
-            endpoint=self.get_run_execution_results,
+            path="/graphs/{graph_id}/executions/{graph_exec_id}",
+            endpoint=self.get_graph_run_node_execution_results,
             methods=["GET"],
         )
         router.add_api_route(
+            path="/graphs/{graph_id}/executions/{graph_exec_id}/stop",
+            endpoint=self.stop_graph_run,
+            methods=["POST"],
+        )
+        router.add_api_route(
             path="/graphs/{graph_id}/schedules",
             endpoint=self.create_schedule, # type: ignore
@@ -508,15 +509,29 @@ class AgentServer(AppService):
         graph_id: str,
         node_input: dict[Any, Any],
         user_id: Annotated[str, Depends(get_user_id)],
-    ) -> dict[Any, Any]:
+    ) -> dict[str, Any]:  # FIXME: add proper return type
         try:
-            return self.execution_manager_client.add_execution(
+            graph_exec = self.execution_manager_client.add_execution(
                 graph_id, node_input, user_id=user_id
             )
+            return {"id": graph_exec["graph_exec_id"]}
         except Exception as e:
            msg = e.__str__().encode().decode("unicode_escape")
            raise HTTPException(status_code=400, detail=msg)
 
+    async def stop_graph_run(
+        self, graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
+    ) -> list[execution_db.ExecutionResult]:
+        if not await execution_db.get_graph_execution(graph_exec_id, user_id):
+            raise HTTPException(
+                404, detail=f"Agent execution #{graph_exec_id} not found"
+            )
+
+        self.execution_manager_client.cancel_execution(graph_exec_id)
+
+        # Retrieve & return canceled graph execution in its final state
+        return await execution_db.get_execution_results(graph_exec_id)
+
     @classmethod
     async def list_graph_runs(
         cls,
```
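End to end, the new route can be exercised over plain HTTP. A hedged sketch using `httpx` (the base URL, port, and auth header are assumptions; the IDs are hypothetical):

```python
import httpx

BASE = "http://localhost:8000/api"  # assumed host and path prefix

# POST .../stop cancels the run and returns the node execution results in
# their final state, including any "TERMINATED" error outputs.
resp = httpx.post(
    f"{BASE}/graphs/my-graph-id/executions/my-exec-id/stop",
    headers={"Authorization": "Bearer <token>"},  # assumed auth scheme
)
resp.raise_for_status()
for node_exec in resp.json():
    print(node_exec["node_exec_id"], node_exec["status"])
```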
Continuing in the agent server:

```diff
@@ -531,17 +546,20 @@ class AgentServer(AppService):
                 status_code=404, detail=f"Agent #{graph_id}{rev} not found."
             )
 
-        return await list_executions(graph_id, graph_version)
+        return await execution_db.list_executions(graph_id, graph_version)
 
     @classmethod
-    async def get_run_execution_results(
-        cls, graph_id: str, run_id: str, user_id: Annotated[str, Depends(get_user_id)]
-    ) -> list[ExecutionResult]:
+    async def get_graph_run_node_execution_results(
+        cls,
+        graph_id: str,
+        graph_exec_id: str,
+        user_id: Annotated[str, Depends(get_user_id)],
+    ) -> list[execution_db.ExecutionResult]:
         graph = await graph_db.get_graph(graph_id, user_id=user_id)
         if not graph:
             raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
 
-        return await get_execution_results(run_id)
+        return await execution_db.get_execution_results(graph_exec_id)
 
     async def create_schedule(
         self,
@@ -579,7 +597,7 @@ class AgentServer(AppService):
 
     @expose
     def send_execution_update(self, execution_result_dict: dict[Any, Any]):
-        execution_result = ExecutionResult(**execution_result_dict)
+        execution_result = execution_db.ExecutionResult(**execution_result_dict)
         self.run_and_wait(self.event_queue.put(execution_result))
 
     @expose
```
**Service/IPC utilities (Python):**

```diff
@@ -16,11 +16,12 @@ from autogpt_server.util.settings import Config
 logger = logging.getLogger(__name__)
 conn_retry = retry(stop=stop_after_delay(5), wait=wait_exponential(multiplier=0.1))
 T = TypeVar("T")
+C = TypeVar("C", bound=Callable)
 
 pyro_host = Config().pyro_host
 
 
-def expose(func: Callable) -> Callable:
+def expose(func: C) -> C:
     def wrapper(*args, **kwargs):
         try:
             return func(*args, **kwargs)
@@ -29,7 +30,7 @@ def expose(func: Callable) -> Callable:
             logger.exception(msg)
             raise Exception(msg, e)
 
-    return pyro.expose(wrapper)
+    return pyro.expose(wrapper)  # type: ignore
 
 
 class PyroNameServer(AppProcess):
```
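The switch from `Callable -> Callable` to a bound `TypeVar` matters for callers: the decorator now preserves the wrapped function's exact signature for type checkers, which is what lets remote calls like `cancel_execution(graph_exec_id)` stay typed. A self-contained sketch of the idea (not the project's code):

```python
from typing import Callable, TypeVar

C = TypeVar("C", bound=Callable)


def passthrough(func: C) -> C:
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    # wrapper's loose (*args, **kwargs) signature is why the real code also
    # needs a `# type: ignore` on its return.
    return wrapper  # type: ignore


def add(a: int, b: int) -> int:
    return a + b


decorated = passthrough(add)
print(decorated(1, 2))  # a checker still sees (a: int, b: int) -> int
```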
Continuing in the service utilities:

```diff
@@ -58,7 +59,7 @@ class AppService(AppProcess):
     def __run_async(self, coro: Coroutine[T, Any, T]):
         return asyncio.run_coroutine_threadsafe(coro, self.shared_event_loop)
 
-    def run_and_wait(self, coro: Coroutine[T, Any, T]) -> T:
+    def run_and_wait(self, coro: Coroutine[Any, Any, T]) -> T:
         future = self.__run_async(coro)
         return future.result()
 
```
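This fixes the generics: `Coroutine[YieldType, SendType, ReturnType]` only uses its third parameter to describe what `await` produces, so putting `T` in the first slot constrained nothing useful. A minimal sketch of the corrected annotation:

```python
import asyncio
from typing import Any, Coroutine, TypeVar

T = TypeVar("T")


def run_and_wait(coro: Coroutine[Any, Any, T]) -> T:
    # Simplified stand-in for AppService.run_and_wait: block until the
    # coroutine finishes and return its result, typed via the third slot.
    return asyncio.run(coro)


async def answer() -> int:
    return 42


value = run_and_wait(answer())  # inferred as int
print(value)
```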
And the service client factory in the same module:

```diff
@@ -100,7 +101,6 @@ def get_service_client(service_type: Type[AS]) -> AS:
     service_name = service_type.service_name
 
     class DynamicClient:
-
         @conn_retry
         def __init__(self):
             ns = pyro.locate_ns()
```
**Test utilities and tests (Python):**

```diff
@@ -23,7 +23,6 @@ class SpinTestServer:
         return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
 
     async def __aenter__(self):
-
         self.name_server.__enter__()
         self.setup_dependency_overrides()
         self.agent_server.__enter__()
@@ -59,7 +58,7 @@ async def wait_execution(
     timeout: int = 20,
 ) -> list:
     async def is_execution_completed():
-        execs = await AgentServer().get_run_execution_results(
+        execs = await AgentServer().get_graph_run_node_execution_results(
             graph_id, graph_exec_id, user_id
         )
         return (
@@ -74,7 +73,7 @@ async def wait_execution(
     # Wait for the executions to complete
     for i in range(timeout):
         if await is_execution_completed():
-            return await AgentServer().get_run_execution_results(
+            return await AgentServer().get_graph_run_node_execution_results(
                 graph_id, graph_exec_id, user_id
             )
         time.sleep(1)
@@ -36,7 +36,7 @@ async def assert_sample_graph_executions(
     graph_exec_id: str,
 ):
     input = {"input_1": "Hello", "input_2": "World"}
-    executions = await agent_server.get_run_execution_results(
+    executions = await agent_server.get_graph_run_node_execution_results(
         test_graph.id, graph_exec_id, test_user.id
     )
 
```