feat(server): Add cancel_execution method to ExecutionManager

- Add `ExecutionManager.cancel_execution(..)`
- Replace graph executor's `ProcessPoolExecutor` with `multiprocessing.pool.Pool`
  - Remove now-unnecessary `Executor.wait_future(..)` method
- Add termination mechanism to `Executor.on_graph_execution` (see the sketch below)
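
A minimal, self-contained sketch of the handshake this change implements — simplified names, with a plain `Process` standing in for the real graph-executor pool; everything below is illustrative, not part of the diff:

```python
import multiprocessing as mp
import time
from multiprocessing.sharedctypes import Array

STATUS_RUNNING = b"running"
STATUS_CANCELLED = b"cancelled"

def graph_worker(status):
    # Stand-in for Executor.on_graph_execution: poll the shared status
    # between units of work and bail out as soon as it reads "cancelled".
    for step in range(100):
        if status.value == STATUS_CANCELLED:
            print(f"cancelled at step {step}")
            return
        time.sleep(0.1)  # pretend to run a node
    print("finished normally")

if __name__ == "__main__":
    # Size the shared byte array for the longest status value.
    status = Array("c", len(STATUS_CANCELLED))
    status.value = STATUS_RUNNING

    proc = mp.Process(target=graph_worker, args=(status,))
    proc.start()
    time.sleep(0.5)
    status.value = STATUS_CANCELLED  # what cancel_execution does
    proc.join()
```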
commit e6aaf71f21
parent 31129bd080
Author: Reinier van der Leer
Date:   2024-08-26 15:07:48 +02:00


@@ -1,8 +1,10 @@
 import asyncio
 import logging
-from concurrent.futures import Future, ProcessPoolExecutor, TimeoutError
+from concurrent.futures import Future, ProcessPoolExecutor
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any, Coroutine, Generator, TypeVar
+from multiprocessing.pool import AsyncResult, Pool
+from multiprocessing.sharedctypes import Array, SynchronizedString
+from typing import TYPE_CHECKING, Any, Coroutine, Generator, TypeVar, cast
 
 if TYPE_CHECKING:
     from autogpt_server.server.server import AgentServer
@@ -16,6 +18,7 @@ from autogpt_server.data.execution import (
     GraphExecution,
     NodeExecution,
     create_graph_execution,
+    get_execution_results,
     get_incomplete_executions,
     get_latest_execution,
     merge_execution_input,
@@ -330,6 +333,9 @@ class Executor:
     9. Node executor enqueues the next executed nodes to the node execution queue.
     """
 
+    STATUS_RUNNING = b"running"
+    STATUS_CANCELLED = b"cancelled"
+
     @classmethod
     def on_node_executor_start(cls):
         cls.loop = asyncio.new_event_loop()
@@ -350,65 +356,76 @@ class Executor:
     @classmethod
     def on_graph_executor_start(cls):
         cls.pool_size = Config().num_node_workers
-        cls.executor = ProcessPoolExecutor(
-            max_workers=cls.pool_size,
-            initializer=cls.on_node_executor_start,
-        )
+        cls._init_node_executor_pool()
         logger.warning(f"Graph executor started with max-{cls.pool_size} node workers.")
 
     @classmethod
-    def on_graph_execution(cls, graph_data: GraphExecution):
+    def _init_node_executor_pool(cls):
+        cls.executor = Pool(
+            processes=cls.pool_size,
+            initializer=cls.on_node_executor_start,
+        )
+
+    @classmethod
+    def on_graph_execution(cls, graph_data: GraphExecution, status: SynchronizedString):
         prefix = get_log_prefix(graph_data.graph_exec_id, "*")
         logger.warning(f"{prefix} Start graph execution")
 
+        def terminate():
+            # Kill all running node executions and rebuild the worker pool
+            cls.executor.terminate()
+            logger.info(
+                f"{prefix} Terminated graph execution {graph_data.graph_exec_id}"
+            )
+            cls._init_node_executor_pool()
+
         try:
             queue = ExecutionQueue[NodeExecution]()
             for node_exec in graph_data.start_node_execs:
                 queue.add(node_exec)
 
-            futures: dict[str, Future] = {}
+            running_executions: dict[str, AsyncResult] = {}
             while not queue.empty():
-                execution = queue.get()
+                if status.value == cls.STATUS_CANCELLED:
+                    terminate()
+                    return
+
+                exec_data = queue.get()
 
                 # Avoid parallel execution of the same node.
-                fut = futures.get(execution.node_id)
-                if fut and not fut.done():
+                exec = running_executions.get(exec_data.node_id)
+                if exec and not exec.ready():
                     # TODO (performance improvement):
                     # Waiting for the completion of the same node execution is blocking.
                     # To improve this we need a separate queue for each node.
                     # Re-enqueueing the data back to the queue will disrupt the order.
-                    cls.wait_future(fut, timeout=None)
+                    exec.wait()
 
-                futures[execution.node_id] = cls.executor.submit(
-                    cls.on_node_execution, queue, execution
+                running_executions[exec_data.node_id] = cls.executor.apply_async(
+                    cls.on_node_execution, (queue, exec_data)
                 )
 
                 # Avoid terminating graph execution when some nodes are still running.
-                while queue.empty() and futures:
-                    for node_id, future in list(futures.items()):
-                        if future.done():
-                            del futures[node_id]
+                while queue.empty() and running_executions:
+                    for node_id, execution in list(running_executions.items()):
+                        if status.value == cls.STATUS_CANCELLED:
+                            terminate()
+                            return
+                        if execution.ready():
+                            del running_executions[node_id]
                         elif queue.empty():
-                            cls.wait_future(future)
+                            execution.wait(3)
 
             logger.warning(f"{prefix} Finished graph execution")
         except Exception as e:
             logger.exception(f"{prefix} Failed graph execution: {e}")
 
-    @classmethod
-    def wait_future(cls, future: Future, timeout: int | None = 3):
-        try:
-            if not future.done():
-                future.result(timeout=timeout)
-        except TimeoutError:
-            # Avoid being blocked by long-running node, by not waiting its completion.
-            pass
-
 
 class ExecutionManager(AppService):
     def __init__(self):
         self.pool_size = Config().num_graph_workers
         self.queue = ExecutionQueue[GraphExecution]()
+        self.active_graph_runs: dict[str, tuple[Future, SynchronizedString]] = {}
 
     def run_service(self):
         with ProcessPoolExecutor(
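
The switch to `multiprocessing.pool.Pool` is what makes the `terminate()` helper above possible: `Pool.terminate()` stops running workers outright, whereas `ProcessPoolExecutor.shutdown()` can only drop *pending* futures and must wait out anything already executing. A standalone sketch of that difference (illustrative only):

```python
import time
from multiprocessing.pool import Pool

def hang(_):
    time.sleep(3600)  # a stuck node execution

if __name__ == "__main__":
    pool = Pool(processes=2)
    pool.apply_async(hang, (None,))
    time.sleep(0.5)
    pool.terminate()  # kills the workers outright; ProcessPoolExecutor has no equivalent
    pool.join()
    print("workers gone; a fresh Pool can be created for the next run")
```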
@@ -419,7 +436,24 @@ class ExecutionManager(AppService):
             f"Execution manager started with max-{self.pool_size} graph workers."
         )
         while True:
-            executor.submit(Executor.on_graph_execution, self.queue.get())
+            graph_exec_data = self.queue.get()
+            # Size the shared status buffer for the longest status value,
+            # then start in the "running" state.
+            status = cast(
+                SynchronizedString, Array("c", len(Executor.STATUS_CANCELLED))
+            )
+            status.value = Executor.STATUS_RUNNING
+            future = executor.submit(
+                Executor.on_graph_execution, graph_exec_data, status
+            )
+            self.active_graph_runs[graph_exec_data.graph_exec_id] = (future, status)
+            # Bind the run id as a default argument: a bare closure over the
+            # loop variable would pop whichever run was enqueued last.
+            future.add_done_callback(
+                lambda _, run_id=graph_exec_data.graph_exec_id: (
+                    self.active_graph_runs.pop(run_id, None)
+                )
+            )
 
     @property
     def agent_server_client(self) -> "AgentServer":
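
One sizing note on the shared status buffer created above: a `c_char` array rejects values longer than the array itself, so allocating it from `STATUS_RUNNING` (7 bytes) would make the later assignment of `b"cancelled"` (9 bytes) raise `ValueError`; hence the allocation from the longer `STATUS_CANCELLED`. A quick standalone check of those semantics:

```python
from multiprocessing.sharedctypes import Array

status = Array("c", len(b"cancelled"))  # nine shared bytes, zero-initialized
status.value = b"running"    # shorter values are stored null-terminated
assert status.value == b"running"
status.value = b"cancelled"  # an exact-length value also fits
assert status.value == b"cancelled"
```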
@@ -480,2 +514,49 @@ class ExecutionManager(AppService):
         self.queue.add(graph_exec)
         return {"id": graph_exec_id}
+
+    @expose
+    def cancel_execution(self, graph_exec_id: str) -> None:
+        """
+        Mechanism:
+        1. Set the shared memory `status` to `"cancelled"`.
+        2. The graph executor checks `status`; if it reads `"cancelled"`, it
+           terminates its node workers, reinitializes its worker pool, and returns.
+        3. Update the execution statuses in the DB and set the `error` outputs
+           to `"TERMINATED"`.
+        """
+        if graph_exec_id not in self.active_graph_runs:
+            raise Exception(
+                f"Graph execution #{graph_exec_id} not active/running: "
+                "possibly already completed/cancelled."
+            )
+
+        future, status = self.active_graph_runs[graph_exec_id]
+        if status.value == Executor.STATUS_CANCELLED:
+            return
+
+        if future.cancel():
+            # The run never started; just drop the bookkeeping entry.
+            self.active_graph_runs.pop(graph_exec_id, None)
+        else:
+            # The run is in progress: signal it to stop and wait for the exit.
+            status.value = Executor.STATUS_CANCELLED
+            future.result()
+
+        # Update the status of the unfinished node executions
+        node_execs = self.run_and_wait(get_execution_results(graph_exec_id))
+        for node_exec in node_execs:
+            if node_exec.status not in (
+                ExecutionStatus.COMPLETED,
+                ExecutionStatus.FAILED,
+            ):
+                exec_update = self.run_and_wait(
+                    update_execution_status(
+                        node_exec.node_exec_id, ExecutionStatus.FAILED
+                    )
+                )
+                self.run_and_wait(
+                    upsert_execution_output(
+                        node_exec.node_exec_id, "error", "TERMINATED"
+                    )
+                )
+                self.agent_server_client.send_execution_update(exec_update.model_dump())
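
`cancel_execution` tries `future.cancel()` first, which only succeeds for runs still sitting in the executor's queue; a run that has already started cannot be cancelled that way and needs the cooperative status flag instead. A standalone illustration of that distinction (names here are illustrative):

```python
import time
from concurrent.futures import ProcessPoolExecutor

def work(n):
    time.sleep(1)
    return n

if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=1) as pool:
        running = pool.submit(work, 1)  # picked up immediately
        queued = pool.submit(work, 2)   # waits behind it
        time.sleep(0.2)
        print(queued.cancel())   # True: never started, silently dropped
        print(running.cancel())  # False: already running, so cancel_execution
                                 # falls back to flipping the shared status
        print(running.result())  # 1
```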