Compare commits


11 Commits

Author SHA1 Message Date
Reinier van der Leer
e2c2b48d77 fix 2024-08-27 11:31:35 +02:00
Reinier van der Leer
bd36acccb2 revert to unmanaged multiprocessing.Event() to let others test/debug deadlock issue 2024-08-27 11:16:36 +02:00
Reinier van der Leer
d63ab9a2f9 fix(server): Fix deadlock and simplify cancel mechanism in Executor.on_graph_execution
This took many hours of bonking my head against the wall, but in the end I found that

      `multiprocessing.Event()` or `multiprocessing.Event(ctx=executor._mp_context)`

doesn't work with the `ProcessPoolExecutor`, and instead I had to use

      `multiprocessing.Manager().Event()`

This adds some overhead, but at least it works. The deadlocking issue occurred for all shared types; I also tried `Value` and `Array`.
2024-08-27 10:48:54 +02:00
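
The working pattern, as a minimal standalone sketch (illustrative only, not part of this changeset): `Manager().Event()` returns a picklable proxy, so it can be passed to `ProcessPoolExecutor` workers as a task argument, whereas a raw `multiprocessing.Event()` can normally only be shared by inheritance.

```python
import multiprocessing
from concurrent.futures import ProcessPoolExecutor


def wait_for_cancel(cancel) -> str:
    # Runs in the worker process; the proxy forwards calls to the manager.
    cancel.wait(timeout=5)
    return "cancelled" if cancel.is_set() else "timed out"


if __name__ == "__main__":
    manager = multiprocessing.Manager()
    cancel = manager.Event()  # proxy object; survives pickling into the pool
    with ProcessPoolExecutor(max_workers=1) as pool:
        future = pool.submit(wait_for_cancel, cancel)
        cancel.set()
        print(future.result())  # -> "cancelled"
```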
Reinier van der Leer
fce6394a49 fix agent execution endpoint 2024-08-26 17:34:01 +02:00
Reinier van der Leer
13e7716424 feat(builder): Add stopGraphExecution(..) to AutoGPTServerAPI 2024-08-26 17:06:19 +02:00
Reinier van der Leer
2973567010 smol refactor for consistency & readability 2024-08-26 16:42:14 +02:00
Reinier van der Leer
b6c4fc4742 feat(server): Add POST /graphs/{graph_id}/executions/{graph_exec_id}/stop endpoint
- Add `stop_graph_execution` + route in `AgentServer`
- Add `get_graph_execution` function in `.data.execution`
- Fix return type of `ExecutionManager.add_execution(..)`
- Fix type issue with `@expose` decorator
2024-08-26 16:28:08 +02:00
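
For illustration, a hypothetical client call against the new endpoint (the base URL and IDs are placeholders, not taken from the project's config):

```python
import requests

BASE_URL = "http://localhost:8000/api"  # assumed; depends on deployment
graph_id = "my-graph-id"                # placeholder
graph_exec_id = "my-exec-id"            # placeholder

# Cancels the run and returns the node execution results in their final state.
resp = requests.post(f"{BASE_URL}/graphs/{graph_id}/executions/{graph_exec_id}/stop")
resp.raise_for_status()
for node_exec in resp.json():
    print(node_exec["node_exec_id"], node_exec["status"])
```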
Reinier van der Leer
f9a3170296 Merge branch 'master' into reinier/open-1669-implement-stop-button-for-agent-runs 2024-08-26 15:41:17 +02:00
Reinier van der Leer
a74f76893e smol refactor for readability 2024-08-26 15:40:54 +02:00
Reinier van der Leer
e6aaf71f21 feat(server): Add cancel_execution method to ExecutionManager
- Add `ExecutionManager.cancel_execution(..)`
- Replace graph executor's `ProcessPoolExecutor` by `multiprocessing.pool.Pool`
  - Remove now-unnecessary `Executor.wait_future(..)` method
- Add termination mechanism to `Executor.on_graph_execution`
2024-08-26 15:07:48 +02:00
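
The swap from `ProcessPoolExecutor` to `multiprocessing.pool.Pool` matters because `Pool` exposes `terminate()`, which `ProcessPoolExecutor` lacks; a minimal standalone sketch (illustrative, not project code):

```python
import time
from multiprocessing.pool import Pool


def slow_node(i: int) -> int:
    time.sleep(60)  # stands in for a long-running node execution
    return i


if __name__ == "__main__":
    pool = Pool(processes=2)
    results = [pool.apply_async(slow_node, (i,)) for i in range(2)]
    time.sleep(1)      # let the workers pick up the tasks
    pool.terminate()   # kills the worker processes immediately
    pool.join()
    print("workers terminated; the pending results never complete")
```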
Reinier van der Leer
31129bd080 fix type issue with AppService.run_and_wait(..) 2024-08-26 14:59:23 +02:00
7 changed files with 197 additions and 64 deletions

View File

@@ -126,16 +126,19 @@ export default class AutoGPTServerAPI {
     runID: string,
   ): Promise<NodeExecutionResult[]> {
     return (await this._get(`/graphs/${graphID}/executions/${runID}`)).map(
-      (result: any) => ({
-        ...result,
-        add_time: new Date(result.add_time),
-        queue_time: result.queue_time ? new Date(result.queue_time) : undefined,
-        start_time: result.start_time ? new Date(result.start_time) : undefined,
-        end_time: result.end_time ? new Date(result.end_time) : undefined,
-      }),
+      parseNodeExecutionResultTimestamps,
     );
   }

+  async stopGraphExecution(
+    graphID: string,
+    runID: string,
+  ): Promise<NodeExecutionResult[]> {
+    return (
+      await this._request("POST", `/graphs/${graphID}/executions/${runID}/stop`)
+    ).map(parseNodeExecutionResultTimestamps);
+  }
+
   private async _get(path: string) {
     return this._request("GET", path);
   }

@@ -272,3 +275,15 @@ type WebsocketMessageTypeMap = {
   subscribe: { graph_id: string };
   execution_event: NodeExecutionResult;
 };
+
+/* *** HELPER FUNCTIONS *** */
+
+function parseNodeExecutionResultTimestamps(result: any): NodeExecutionResult {
+  return {
+    ...result,
+    add_time: new Date(result.add_time),
+    queue_time: result.queue_time ? new Date(result.queue_time) : undefined,
+    start_time: result.start_time ? new Date(result.start_time) : undefined,
+    end_time: result.end_time ? new Date(result.end_time) : undefined,
+  };
+}

View File

@@ -271,6 +271,28 @@ async def update_execution_status(
     return ExecutionResult.from_db(res)


+async def get_graph_execution(
+    graph_exec_id: str, user_id: str
+) -> AgentGraphExecution | None:
+    """
+    Retrieve a specific graph execution by its ID.
+
+    Args:
+        graph_exec_id (str): The ID of the graph execution to retrieve.
+        user_id (str): The ID of the user to whom the graph (execution) belongs.
+
+    Returns:
+        AgentGraphExecution | None: The graph execution if found, None otherwise.
+    """
+    execution = await AgentGraphExecution.prisma().find_first(
+        where={"id": graph_exec_id, "userId": user_id},
+        include={
+            "agentNodeExecutions": {"include": EXECUTION_RESULT_INCLUDE}  # type: ignore
+        },
+    )
+    return execution
+
+
 async def list_executions(graph_id: str, graph_version: int | None = None) -> list[str]:
     where: AgentGraphExecutionWhereInput = {"agentGraphId": graph_id}
     if graph_version is not None:

View File

@@ -1,7 +1,10 @@
 import asyncio
 import logging
-from concurrent.futures import Future, ProcessPoolExecutor, TimeoutError
+import threading
+from concurrent.futures import Future, ProcessPoolExecutor
 from contextlib import contextmanager
+from multiprocessing.pool import AsyncResult, Pool
+from multiprocessing.synchronize import Event
 from typing import TYPE_CHECKING, Any, Coroutine, Generator, TypeVar

 if TYPE_CHECKING:

@@ -16,6 +19,7 @@ from autogpt_server.data.execution import (
     GraphExecution,
     NodeExecution,
     create_graph_execution,
+    get_execution_results,
     get_incomplete_executions,
     get_latest_execution,
     merge_execution_input,

@@ -350,76 +354,110 @@ class Executor:
     @classmethod
     def on_graph_executor_start(cls):
         cls.pool_size = Config().num_node_workers
-        cls.executor = ProcessPoolExecutor(
-            max_workers=cls.pool_size,
-            initializer=cls.on_node_executor_start,
-        )
+        cls._init_node_executor_pool()
         logger.warning(f"Graph executor started with max-{cls.pool_size} node workers.")

     @classmethod
-    def on_graph_execution(cls, graph_data: GraphExecution):
+    def _init_node_executor_pool(cls):
+        cls.executor = Pool(
+            processes=cls.pool_size,
+            initializer=cls.on_node_executor_start,
+        )
+
+    @classmethod
+    def on_graph_execution(cls, graph_data: GraphExecution, cancel: Event):
         prefix = get_log_prefix(graph_data.graph_exec_id, "*")
         logger.warning(f"{prefix} Start graph execution")
+
+        finished = False
+
+        def cancel_handler():
+            while not cancel.is_set():
+                cancel.wait(1)
+                if finished:
+                    return
+            cls.executor.terminate()
+            logger.info(
+                f"{prefix} Terminated graph execution {graph_data.graph_exec_id}"
+            )
+            cls._init_node_executor_pool()
+
+        cancel_thread = threading.Thread(target=cancel_handler)
+        cancel_thread.start()
+
         try:
             queue = ExecutionQueue[NodeExecution]()
             for node_exec in graph_data.start_node_execs:
                 queue.add(node_exec)

-            futures: dict[str, Future] = {}
+            running_executions: dict[str, AsyncResult] = {}
             while not queue.empty():
-                execution = queue.get()
+                if cancel.is_set():
+                    return
+                exec_data = queue.get()

                 # Avoid parallel execution of the same node.
-                fut = futures.get(execution.node_id)
-                if fut and not fut.done():
+                execution = running_executions.get(exec_data.node_id)
+                if execution and not execution.ready():
                     # TODO (performance improvement):
                     #   Wait for the completion of the same node execution is blocking.
                     #   To improve this we need a separate queue for each node.
                     #   Re-enqueueing the data back to the queue will disrupt the order.
-                    cls.wait_future(fut, timeout=None)
+                    execution.wait()

-                futures[execution.node_id] = cls.executor.submit(
-                    cls.on_node_execution, queue, execution
+                running_executions[exec_data.node_id] = cls.executor.apply_async(
+                    cls.on_node_execution,
+                    (queue, exec_data),
+                    callback=lambda _: running_executions.pop(exec_data.node_id),
                 )

                 # Avoid terminating graph execution when some nodes are still running.
-                while queue.empty() and futures:
-                    for node_id, future in list(futures.items()):
-                        if future.done():
-                            del futures[node_id]
-                        elif queue.empty():
-                            cls.wait_future(future)
+                while queue.empty() and running_executions:
+                    for execution in list(running_executions.values()):
+                        if cancel.is_set():
+                            return
+                        if not queue.empty():
+                            break  # yield to parent loop to execute new queue items
+                        execution.wait(3)

             logger.warning(f"{prefix} Finished graph execution")
         except Exception as e:
             logger.exception(f"{prefix} Failed graph execution: {e}")
-
-    @classmethod
-    def wait_future(cls, future: Future, timeout: int | None = 3):
-        try:
-            if not future.done():
-                future.result(timeout=timeout)
-        except TimeoutError:
-            # Avoid being blocked by long-running node, by not waiting its completion.
-            pass
+        finally:
+            if not cancel.is_set():
+                finished = True
+                cancel.set()
+            cancel_thread.join()


 class ExecutionManager(AppService):
     def __init__(self):
         self.pool_size = Config().num_graph_workers
         self.queue = ExecutionQueue[GraphExecution]()
+        self.active_graph_runs: dict[str, tuple[Future, Event]] = {}

     def run_service(self):
         with ProcessPoolExecutor(
             max_workers=self.pool_size,
             initializer=Executor.on_graph_executor_start,
         ) as executor:
-            logger.warning(
+            logger.info(
                 f"Execution manager started with max-{self.pool_size} graph workers."
             )
             while True:
-                executor.submit(Executor.on_graph_execution, self.queue.get())
+                graph_exec_data = self.queue.get()
+                graph_exec_id = graph_exec_data.graph_exec_id
+                cancel_event = executor._mp_context.Event()
+                future = executor.submit(
+                    Executor.on_graph_execution, graph_exec_data, cancel_event
+                )
+                self.active_graph_runs[graph_exec_id] = (future, cancel_event)
+                future.add_done_callback(
+                    lambda _: self.active_graph_runs.pop(graph_exec_id)
+                )

     @property
     def agent_server_client(self) -> "AgentServer":

@@ -428,7 +466,7 @@ class ExecutionManager(AppService):
     @expose
     def add_execution(
         self, graph_id: str, data: BlockInput, user_id: str
-    ) -> dict[Any, Any]:
+    ) -> dict[str, Any]:
         graph: Graph | None = self.run_and_wait(get_graph(graph_id, user_id=user_id))
         if not graph:
             raise Exception(f"Graph #{graph_id} not found.")

@@ -479,4 +517,45 @@ class ExecutionManager(AppService):
         )

         self.queue.add(graph_exec)
-        return {"id": graph_exec_id}
+        return graph_exec.model_dump()
+
+    @expose
+    def cancel_execution(self, graph_exec_id: str) -> None:
+        """
+        Mechanism:
+        1. Set the cancel event
+        2. Graph executor's cancel handler thread detects the event, terminates workers,
+           reinitializes worker pool, and returns.
+        3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
+        """
+        if graph_exec_id not in self.active_graph_runs:
+            raise Exception(
+                f"Graph execution #{graph_exec_id} not active/running: "
+                "possibly already completed/cancelled."
+            )
+        future, cancel_event = self.active_graph_runs[graph_exec_id]
+        if cancel_event.is_set():
+            return
+
+        cancel_event.set()
+        future.result()
+
+        # Update the status of the unfinished node executions
+        node_execs = self.run_and_wait(get_execution_results(graph_exec_id))
+        for node_exec in node_execs:
+            if node_exec.status not in (
+                ExecutionStatus.COMPLETED,
+                ExecutionStatus.FAILED,
+            ):
+                exec_update = self.run_and_wait(
+                    update_execution_status(
+                        node_exec.node_exec_id, ExecutionStatus.FAILED
+                    )
+                )
+                self.run_and_wait(
+                    upsert_execution_output(
+                        node_exec.node_exec_id, "error", "TERMINATED"
+                    )
+                )
+                self.agent_server_client.send_execution_update(exec_update.model_dump())

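Condensed into a standalone sketch, the cancel mechanism from the diffs above works roughly like this (simplified: a `threading.Event` stands in for the multiprocessing event shared with `ExecutionManager`):

```python
import threading
from multiprocessing.pool import Pool


def run_graph(cancel: threading.Event) -> None:
    pool = Pool(processes=2)  # node executor pool
    finished = False

    def cancel_handler() -> None:
        while not cancel.is_set():
            cancel.wait(1)
            if finished:  # normal completion: exit without terminating
                return
        pool.terminate()  # cancellation: kill in-flight node executions

    watcher = threading.Thread(target=cancel_handler)
    watcher.start()
    try:
        ...  # drain the node execution queue, checking cancel.is_set()
    finally:
        if not cancel.is_set():
            finished = True
            cancel.set()  # unblock the watcher even on normal completion
        watcher.join()
```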
View File

@@ -22,14 +22,10 @@ from fastapi.responses import JSONResponse

 import autogpt_server.server.ws_api
 from autogpt_server.data import block, db
+from autogpt_server.data import execution as execution_db
 from autogpt_server.data import graph as graph_db
 from autogpt_server.data import user as user_db
 from autogpt_server.data.block import BlockInput, CompletedBlockOutput
-from autogpt_server.data.execution import (
-    ExecutionResult,
-    get_execution_results,
-    list_executions,
-)
 from autogpt_server.data.user import DEFAULT_USER_ID, get_or_create_user
 from autogpt_server.executor import ExecutionManager, ExecutionScheduler
 from autogpt_server.server.conn_manager import ConnectionManager

@@ -58,7 +54,7 @@ def get_user_id(payload: dict = Depends(auth_middleware)) -> str:

 class AgentServer(AppService):
-    event_queue: asyncio.Queue[ExecutionResult] = asyncio.Queue()
+    event_queue: asyncio.Queue[execution_db.ExecutionResult] = asyncio.Queue()
     manager = ConnectionManager()
     mutex = KeyedMutex()
     use_db = False

@@ -66,7 +62,7 @@ class AgentServer(AppService):
     async def event_broadcaster(self):
         while True:
-            event: ExecutionResult = await self.event_queue.get()
+            event: execution_db.ExecutionResult = await self.event_queue.get()
             await self.manager.send_execution_result(event)

     @asynccontextmanager

@@ -193,10 +189,15 @@ class AgentServer(AppService):
             methods=["GET"],
         )
         router.add_api_route(
-            path="/graphs/{graph_id}/executions/{run_id}",
-            endpoint=self.get_run_execution_results,
+            path="/graphs/{graph_id}/executions/{graph_exec_id}",
+            endpoint=self.get_graph_run_node_execution_results,
             methods=["GET"],
         )
+        router.add_api_route(
+            path="/graphs/{graph_id}/executions/{graph_exec_id}/stop",
+            endpoint=self.stop_graph_run,
+            methods=["POST"],
+        )
         router.add_api_route(
             path="/graphs/{graph_id}/schedules",
             endpoint=self.create_schedule,  # type: ignore

@@ -508,15 +509,29 @@ class AgentServer(AppService):
         graph_id: str,
         node_input: dict[Any, Any],
         user_id: Annotated[str, Depends(get_user_id)],
-    ) -> dict[Any, Any]:
+    ) -> dict[str, Any]:  # FIXME: add proper return type
         try:
-            return self.execution_manager_client.add_execution(
+            graph_exec = self.execution_manager_client.add_execution(
                 graph_id, node_input, user_id=user_id
             )
+            return {"id": graph_exec["graph_exec_id"]}
         except Exception as e:
             msg = e.__str__().encode().decode("unicode_escape")
             raise HTTPException(status_code=400, detail=msg)

+    async def stop_graph_run(
+        self, graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
+    ) -> list[execution_db.ExecutionResult]:
+        if not await execution_db.get_graph_execution(graph_exec_id, user_id):
+            raise HTTPException(
+                404, detail=f"Agent execution #{graph_exec_id} not found"
+            )
+
+        self.execution_manager_client.cancel_execution(graph_exec_id)
+
+        # Retrieve & return canceled graph execution in its final state
+        return await execution_db.get_execution_results(graph_exec_id)
+
     @classmethod
     async def list_graph_runs(
         cls,

@@ -531,17 +546,20 @@ class AgentServer(AppService):
                 status_code=404, detail=f"Agent #{graph_id}{rev} not found."
             )

-        return await list_executions(graph_id, graph_version)
+        return await execution_db.list_executions(graph_id, graph_version)

     @classmethod
-    async def get_run_execution_results(
-        cls, graph_id: str, run_id: str, user_id: Annotated[str, Depends(get_user_id)]
-    ) -> list[ExecutionResult]:
+    async def get_graph_run_node_execution_results(
+        cls,
+        graph_id: str,
+        graph_exec_id: str,
+        user_id: Annotated[str, Depends(get_user_id)],
+    ) -> list[execution_db.ExecutionResult]:
         graph = await graph_db.get_graph(graph_id, user_id=user_id)
         if not graph:
             raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
-        return await get_execution_results(run_id)
+        return await execution_db.get_execution_results(graph_exec_id)

     async def create_schedule(
         self,

@@ -579,7 +597,7 @@ class AgentServer(AppService):
     @expose
     def send_execution_update(self, execution_result_dict: dict[Any, Any]):
-        execution_result = ExecutionResult(**execution_result_dict)
+        execution_result = execution_db.ExecutionResult(**execution_result_dict)
         self.run_and_wait(self.event_queue.put(execution_result))

     @expose

View File

@@ -16,11 +16,12 @@ from autogpt_server.util.settings import Config

 logger = logging.getLogger(__name__)
 conn_retry = retry(stop=stop_after_delay(5), wait=wait_exponential(multiplier=0.1))

 T = TypeVar("T")
+C = TypeVar("C", bound=Callable)

 pyro_host = Config().pyro_host


-def expose(func: Callable) -> Callable:
+def expose(func: C) -> C:
     def wrapper(*args, **kwargs):
         try:
             return func(*args, **kwargs)

@@ -29,7 +30,7 @@ def expose(func: Callable) -> Callable:
             logger.exception(msg)
             raise Exception(msg, e)

-    return pyro.expose(wrapper)
+    return pyro.expose(wrapper)  # type: ignore


 class PyroNameServer(AppProcess):

@@ -58,7 +59,7 @@ class AppService(AppProcess):
     def __run_async(self, coro: Coroutine[T, Any, T]):
         return asyncio.run_coroutine_threadsafe(coro, self.shared_event_loop)

-    def run_and_wait(self, coro: Coroutine[T, Any, T]) -> T:
+    def run_and_wait(self, coro: Coroutine[Any, Any, T]) -> T:
         future = self.__run_async(coro)
         return future.result()

@@ -100,7 +101,6 @@ def get_service_client(service_type: Type[AS]) -> AS:
     service_name = service_type.service_name

     class DynamicClient:
-
         @conn_retry
         def __init__(self):
             ns = pyro.locate_ns()

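A minimal sketch of what the typing change to `expose` buys (illustrative, not project code): binding the TypeVar to `Callable` and typing the decorator as `C -> C` lets type checkers keep the decorated function's original signature instead of collapsing it to a bare `Callable`.

```python
from typing import Callable, TypeVar

C = TypeVar("C", bound=Callable)


def expose(func: C) -> C:
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper  # type: ignore  # wrapper isn't literally of type C


@expose
def add_execution(graph_id: str) -> dict:
    return {"id": graph_id}

# A type checker now sees add_execution with its original signature,
# e.g. reveal_type(add_execution) -> "def (graph_id: str) -> dict"
```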
View File

@@ -23,7 +23,6 @@ class SpinTestServer:
         return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"

     async def __aenter__(self):
         self.name_server.__enter__()
-
         self.setup_dependency_overrides()
         self.agent_server.__enter__()

@@ -59,7 +58,7 @@ async def wait_execution(
     timeout: int = 20,
 ) -> list:
     async def is_execution_completed():
-        execs = await AgentServer().get_run_execution_results(
+        execs = await AgentServer().get_graph_run_node_execution_results(
            graph_id, graph_exec_id, user_id
        )
         return (

@@ -74,7 +73,7 @@ async def wait_execution(
     # Wait for the executions to complete
     for i in range(timeout):
         if await is_execution_completed():
-            return await AgentServer().get_run_execution_results(
+            return await AgentServer().get_graph_run_node_execution_results(
                 graph_id, graph_exec_id, user_id
             )
         time.sleep(1)

View File

@@ -36,7 +36,7 @@ async def assert_sample_graph_executions(
     graph_exec_id: str,
 ):
     input = {"input_1": "Hello", "input_2": "World"}
-    executions = await agent_server.get_run_execution_results(
+    executions = await agent_server.get_graph_run_node_execution_results(
         test_graph.id, graph_exec_id, test_user.id
     )