Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-01-10 23:58:06 -05:00)
feat(server): Add cancel_execution method to ExecutionManager
- Add `ExecutionManager.cancel_execution(..)`
- Replace graph executor's `ProcessPoolExecutor` with `multiprocessing.pool.Pool`
- Remove the now-unnecessary `Executor.wait_future(..)` method
- Add a termination mechanism to `Executor.on_graph_execution`
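The change makes cancellation cooperative: for every graph run, the manager allocates a small shared-memory byte buffer, hands it to the graph-executor process, and flips its value to request a stop; the executor polls the flag between node executions. Swapping the node workers' `ProcessPoolExecutor` for `multiprocessing.pool.Pool` is what enables a hard stop, because `Pool.terminate()` can kill workers mid-task, while `ProcessPoolExecutor` can only shut down gracefully. Below is a minimal standalone sketch of the shared-flag pattern, with hypothetical names (`run_graph`, the step loop); it is not the AutoGPT code itself:

```python
import time
from multiprocessing import Process
from multiprocessing.sharedctypes import Array

STATUS_RUNNING = b"running"
STATUS_CANCELLED = b"cancelled"

def run_graph(status) -> None:
    # Executor side: poll the shared flag between units of work.
    for step in range(50):
        if status.value == STATUS_CANCELLED:
            print(f"stopped cooperatively at step {step}")
            return
        time.sleep(0.1)  # stand-in for executing one node

if __name__ == "__main__":
    # Manager side: a shared char buffer sized to hold either status value.
    status = Array("c", 16)
    status.value = STATUS_RUNNING
    worker = Process(target=run_graph, args=(status,))
    worker.start()
    time.sleep(0.5)
    status.value = STATUS_CANCELLED  # request cancellation
    worker.join()
```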
@@ -1,8 +1,10 @@
 import asyncio
 import logging
-from concurrent.futures import Future, ProcessPoolExecutor, TimeoutError
+from concurrent.futures import Future, ProcessPoolExecutor
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any, Coroutine, Generator, TypeVar
+from multiprocessing.pool import AsyncResult, Pool
+from multiprocessing.sharedctypes import Array, SynchronizedString
+from typing import TYPE_CHECKING, Any, Coroutine, Generator, TypeVar, cast
 
 if TYPE_CHECKING:
     from autogpt_server.server.server import AgentServer
@@ -16,6 +18,7 @@ from autogpt_server.data.execution import (
     GraphExecution,
     NodeExecution,
     create_graph_execution,
+    get_execution_results,
     get_incomplete_executions,
     get_latest_execution,
     merge_execution_input,
@@ -330,6 +333,9 @@ class Executor:
     9. Node executor enqueues the next executed nodes to the node execution queue.
     """
 
+    STATUS_RUNNING = b"running"
+    STATUS_CANCELLED = b"cancelled"
+
     @classmethod
     def on_node_executor_start(cls):
         cls.loop = asyncio.new_event_loop()
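The statuses are byte strings because they are stored in a raw `ctypes` char array rather than a Python `str`. One property of `sharedctypes.Array("c", ...)` worth keeping in mind (shown in a standalone snippet, not project code): the buffer capacity is fixed at creation, either from an explicit integer size or from the length of a bytes initializer, and assigning a longer value raises `ValueError`:

```python
from multiprocessing.sharedctypes import Array

buf = Array("c", b"running")   # capacity fixed at 7 bytes by the initializer
print(buf.value)               # b'running'
try:
    buf.value = b"cancelled"   # 9 bytes do not fit a 7-byte buffer
except ValueError as err:
    print("rejected:", err)

roomy = Array("c", 16)         # an explicit size holds either status value
roomy.value = b"cancelled"
print(roomy.value)             # b'cancelled'
```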
@@ -350,65 +356,76 @@ class Executor:
     @classmethod
     def on_graph_executor_start(cls):
         cls.pool_size = Config().num_node_workers
-        cls.executor = ProcessPoolExecutor(
-            max_workers=cls.pool_size,
-            initializer=cls.on_node_executor_start,
-        )
+        cls._init_node_executor_pool()
         logger.warning(f"Graph executor started with max-{cls.pool_size} node workers.")
 
     @classmethod
-    def on_graph_execution(cls, graph_data: GraphExecution):
+    def _init_node_executor_pool(cls):
+        cls.executor = Pool(
+            processes=cls.pool_size,
+            initializer=cls.on_node_executor_start,
+        )
+
+    @classmethod
+    def on_graph_execution(cls, graph_data: GraphExecution, status: SynchronizedString):
         prefix = get_log_prefix(graph_data.graph_exec_id, "*")
         logger.warning(f"{prefix} Start graph execution")
 
+        def terminate():
+            cls.executor.terminate()
+            logger.info(
+                f"{prefix} Terminated graph execution {exec_data.graph_exec_id}"
+            )
+            cls._init_node_executor_pool()
+
         try:
             queue = ExecutionQueue[NodeExecution]()
             for node_exec in graph_data.start_node_execs:
                 queue.add(node_exec)
 
-            futures: dict[str, Future] = {}
+            running_executions: dict[str, AsyncResult] = {}
             while not queue.empty():
-                execution = queue.get()
+                if status.value == cls.STATUS_CANCELLED:
+                    terminate()
+                    return
+
+                exec_data = queue.get()
 
                 # Avoid parallel execution of the same node.
-                fut = futures.get(execution.node_id)
-                if fut and not fut.done():
+                exec = running_executions.get(exec_data.node_id)
+                if exec and not exec.ready():
                     # TODO (performance improvement):
                     #   Wait for the completion of the same node execution is blocking.
                     #   To improve this we need a separate queue for each node.
                     #   Re-enqueueing the data back to the queue will disrupt the order.
-                    cls.wait_future(fut, timeout=None)
+                    exec.wait()
 
-                futures[execution.node_id] = cls.executor.submit(
-                    cls.on_node_execution, queue, execution
+                running_executions[exec_data.node_id] = cls.executor.apply_async(
+                    cls.on_node_execution, (queue, exec_data)
                 )
 
                 # Avoid terminating graph execution when some nodes are still running.
-                while queue.empty() and futures:
-                    for node_id, future in list(futures.items()):
-                        if future.done():
-                            del futures[node_id]
+                while queue.empty() and running_executions:
+                    for node_id, execution in list(running_executions.items()):
+                        if status.value == cls.STATUS_CANCELLED:
+                            terminate()
+                            return
+
+                        if execution.ready():
+                            del running_executions[node_id]
                         elif queue.empty():
-                            cls.wait_future(future)
+                            execution.wait(3)
 
             logger.warning(f"{prefix} Finished graph execution")
         except Exception as e:
             logger.exception(f"{prefix} Failed graph execution: {e}")
 
-    @classmethod
-    def wait_future(cls, future: Future, timeout: int | None = 3):
-        try:
-            if not future.done():
-                future.result(timeout=timeout)
-        except TimeoutError:
-            # Avoid being blocked by long-running node, by not waiting its completion.
-            pass
-
 
 class ExecutionManager(AppService):
     def __init__(self):
         self.pool_size = Config().num_graph_workers
         self.queue = ExecutionQueue[GraphExecution]()
+        self.active_graph_runs: dict[str, tuple[Future, SynchronizedString]] = {}
 
     def run_service(self):
         with ProcessPoolExecutor(
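The heart of the swap is visible in this hunk: `Future.done()`/`Future.result()` become `AsyncResult.ready()`/`AsyncResult.wait()`, which lets the former `wait_future` timeout helper disappear, and a cancelled run can now hard-stop its node workers with `Pool.terminate()` before rebuilding the pool. Here is a self-contained sketch of these `multiprocessing.pool` idioms, using a hypothetical `slow_square` task rather than the project's node execution:

```python
import time
from multiprocessing.pool import Pool

def slow_square(x: int) -> int:
    time.sleep(0.2)
    return x * x

if __name__ == "__main__":
    pool = Pool(processes=2)
    result = pool.apply_async(slow_square, (7,))  # counterpart of executor.submit(...)
    result.wait(3)           # block up to 3s, like execution.wait(3) in the diff
    if result.ready():       # counterpart of Future.done()
        print(result.get())  # 49
    # Hard-stop all workers, then build a fresh pool for the next run,
    # mirroring terminate() + _init_node_executor_pool() above.
    pool.terminate()
    pool.join()
    pool = Pool(processes=2)
    pool.close()
    pool.join()
```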
@@ -419,7 +436,15 @@ class ExecutionManager(AppService):
             f"Execution manager started with max-{self.pool_size} graph workers."
         )
         while True:
-            executor.submit(Executor.on_graph_execution, self.queue.get())
+            graph_exec_data = self.queue.get()
+            status = cast(SynchronizedString, Array("c", Executor.STATUS_RUNNING))
+            future = executor.submit(
+                Executor.on_graph_execution, graph_exec_data, status
+            )
+            self.active_graph_runs[graph_exec_data.graph_exec_id] = (future, status)
+            future.add_done_callback(
+                lambda _: self.active_graph_runs.pop(graph_exec_data.graph_exec_id)
+            )
 
     @property
     def agent_server_client(self) -> "AgentServer":
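Each dequeued graph run now gets its own flag. The `cast` is there for the type checker: `Array("c", ...)` with the default `lock=True` already returns a `SynchronizedString` at runtime, but its declared return type is presumably broader in the stubs, so the commit pins it with `typing.cast`. A quick standalone check of the runtime behavior:

```python
from multiprocessing.sharedctypes import Array, SynchronizedString
from typing import cast

status = cast(SynchronizedString, Array("c", 16))
print(type(status).__name__)  # SynchronizedString at runtime too
status.value = b"running"
```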
@@ -480,3 +505,48 @@ class ExecutionManager(AppService):
         self.queue.add(graph_exec)
 
         return {"id": graph_exec_id}
+
+    @expose
+    def cancel_execution(self, graph_exec_id: str) -> None:
+        """
+        Mechanism:
+        1. Set shared memory `status` to `"cancelled"`
+        2. Graph executor checks `status`.
+           If `"cancelled"`: terminates its workers, reinitializes its worker pool,
+           and returns.
+        3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
+        """
+        if graph_exec_id not in self.active_graph_runs:
+            raise Exception(
+                f"Graph execution #{graph_exec_id} not active/running: "
+                "possibly already completed/cancelled."
+            )
+
+        future, status = self.active_graph_runs[graph_exec_id]
+        if status.value == Executor.STATUS_CANCELLED:
+            return
+
+        if future.cancel():
+            del self.active_graph_runs[graph_exec_id]
+        else:
+            status.value = Executor.STATUS_CANCELLED
+            future.result()
+
+        # Update the status of the unfinished node executions
+        node_execs = self.run_and_wait(get_execution_results(graph_exec_id))
+        for node_exec in node_execs:
+            if node_exec.status not in (
+                ExecutionStatus.COMPLETED,
+                ExecutionStatus.FAILED,
+            ):
+                exec_update = self.run_and_wait(
+                    update_execution_status(
+                        node_exec.node_exec_id, ExecutionStatus.FAILED
+                    )
+                )
+                self.run_and_wait(
+                    upsert_execution_output(
+                        node_exec.node_exec_id, "error", "TERMINATED"
+                    )
+                )
+                self.agent_server_client.send_execution_update(exec_update.model_dump())
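`cancel_execution` covers two situations: a run still waiting in the manager's queue can be dropped outright with `Future.cancel()`, while a run that has already started must be signalled through the shared flag and then awaited via `future.result()` until it actually exits. A self-contained sketch of that branching, with a hypothetical `work` task (the outcomes are timing-dependent, hence the 'likely' comments):

```python
import time
from concurrent.futures import Future, ProcessPoolExecutor

def work() -> str:
    time.sleep(1)
    return "done"

def cancel_or_wait(future: Future) -> None:
    # Future.cancel() only succeeds while the task is still queued; once it
    # runs, the caller must ask the worker to stop (the role the shared
    # status flag plays above) and block on result() until it exits.
    if future.cancel():
        print("cancelled before it started")
    else:
        print("already running; waited for:", future.result())

if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=1) as pool:
        first = pool.submit(work)   # likely picked up immediately
        second = pool.submit(work)  # likely still queued behind `first`
        cancel_or_wait(second)      # likely cancels
        cancel_or_wait(first)       # likely waits for completion
```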