diff --git a/rnd/autogpt_server/autogpt_server/data/execution.py b/rnd/autogpt_server/autogpt_server/data/execution.py index 78eb120333..76ee133f61 100644 --- a/rnd/autogpt_server/autogpt_server/data/execution.py +++ b/rnd/autogpt_server/autogpt_server/data/execution.py @@ -271,6 +271,28 @@ async def update_execution_status( return ExecutionResult.from_db(res) +async def get_graph_execution( + graph_exec_id: str, user_id: str +) -> AgentGraphExecution | None: + """ + Retrieve a specific graph execution by its ID. + + Args: + graph_exec_id (str): The ID of the graph execution to retrieve. + user_id (str): The ID of the user to whom the graph (execution) belongs. + + Returns: + AgentGraphExecution | None: The graph execution if found, None otherwise. + """ + execution = await AgentGraphExecution.prisma().find_first( + where={"id": graph_exec_id, "userId": user_id}, + include={ + "agentNodeExecutions": {"include": EXECUTION_RESULT_INCLUDE} # type: ignore + }, + ) + return execution + + async def list_executions(graph_id: str, graph_version: int | None = None) -> list[str]: where: AgentGraphExecutionWhereInput = {"agentGraphId": graph_id} if graph_version is not None: diff --git a/rnd/autogpt_server/autogpt_server/executor/manager.py b/rnd/autogpt_server/autogpt_server/executor/manager.py index 8f59a3f72b..e2177bbe10 100644 --- a/rnd/autogpt_server/autogpt_server/executor/manager.py +++ b/rnd/autogpt_server/autogpt_server/executor/manager.py @@ -457,7 +457,7 @@ class ExecutionManager(AppService): @expose def add_execution( self, graph_id: str, data: BlockInput, user_id: str - ) -> dict[Any, Any]: + ) -> GraphExecution: graph: Graph | None = self.run_and_wait(get_graph(graph_id, user_id=user_id)) if not graph: raise Exception(f"Graph #{graph_id} not found.") @@ -508,7 +508,7 @@ class ExecutionManager(AppService): ) self.queue.add(graph_exec) - return {"id": graph_exec_id} + return graph_exec @expose def cancel_execution(self, graph_exec_id: str) -> None: diff --git a/rnd/autogpt_server/autogpt_server/server/server.py b/rnd/autogpt_server/autogpt_server/server/server.py index 7a5fdafcc7..581e3c2329 100644 --- a/rnd/autogpt_server/autogpt_server/server/server.py +++ b/rnd/autogpt_server/autogpt_server/server/server.py @@ -22,6 +22,7 @@ from fastapi.responses import JSONResponse import autogpt_server.server.ws_api from autogpt_server.data import block, db +from autogpt_server.data import execution as execution_db from autogpt_server.data import graph as graph_db from autogpt_server.data import user as user_db from autogpt_server.data.block import BlockInput, CompletedBlockOutput @@ -197,6 +198,11 @@ class AgentServer(AppService): endpoint=self.get_run_execution_results, methods=["GET"], ) + router.add_api_route( + path="/graphs/{graph_id}/executions/{graph_exec_id}/stop", + endpoint=self.stop_graph_execution, + methods=["POST"], + ) router.add_api_route( path="/graphs/{graph_id}/schedules", endpoint=self.create_schedule, # type: ignore @@ -508,15 +514,32 @@ class AgentServer(AppService): graph_id: str, node_input: dict[Any, Any], user_id: Annotated[str, Depends(get_user_id)], - ) -> dict[Any, Any]: + ) -> dict[Any, Any]: # FIXME: add proper return type try: - return self.execution_manager_client.add_execution( + graph_exec = self.execution_manager_client.add_execution( graph_id, node_input, user_id=user_id ) + return {"id": graph_exec.graph_exec_id} except Exception as e: msg = e.__str__().encode().decode("unicode_escape") raise HTTPException(status_code=400, detail=msg) + async def stop_graph_execution( + self, graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)] + ) -> execution_db.AgentGraphExecution: + graph_exec = await execution_db.get_graph_execution(graph_exec_id, user_id) + if not graph_exec: + raise HTTPException( + 404, detail=f"Agent execution #{graph_exec_id} not found" + ) + + self.execution_manager_client.cancel_execution(graph_exec_id) + + # Retrieve & return canceled graph execution in its final state + graph_exec = await execution_db.get_graph_execution(graph_exec_id, user_id) + assert graph_exec + return graph_exec + @classmethod async def list_graph_runs( cls, diff --git a/rnd/autogpt_server/autogpt_server/util/service.py b/rnd/autogpt_server/autogpt_server/util/service.py index e57d918ad7..bb55d1e5ce 100644 --- a/rnd/autogpt_server/autogpt_server/util/service.py +++ b/rnd/autogpt_server/autogpt_server/util/service.py @@ -16,11 +16,12 @@ from autogpt_server.util.settings import Config logger = logging.getLogger(__name__) conn_retry = retry(stop=stop_after_delay(5), wait=wait_exponential(multiplier=0.1)) T = TypeVar("T") +C = TypeVar("C", bound=Callable) pyro_host = Config().pyro_host -def expose(func: Callable) -> Callable: +def expose(func: C) -> C: def wrapper(*args, **kwargs): try: return func(*args, **kwargs) @@ -29,7 +30,7 @@ def expose(func: Callable) -> Callable: logger.exception(msg) raise Exception(msg, e) - return pyro.expose(wrapper) + return pyro.expose(wrapper) # type: ignore class PyroNameServer(AppProcess):