diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_output.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_output.py
index aa43ac751d..b3bc613fd3 100644
--- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_output.py
+++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_output.py
@@ -1,26 +1,20 @@
 """Tool for retrieving agent execution outputs from user's library."""
 
-import asyncio
 import logging
 import re
 from datetime import datetime, timedelta, timezone
 from typing import Any
 
-from pydantic import BaseModel, field_validator
+from pydantic import BaseModel, Field, field_validator
 
 from backend.api.features.chat.model import ChatSession
 from backend.api.features.library import db as library_db
 from backend.api.features.library.model import LibraryAgent
 from backend.data import execution as execution_db
-from backend.data.execution import (
-    AsyncRedisExecutionEventBus,
-    ExecutionStatus,
-    GraphExecution,
-    GraphExecutionEvent,
-    GraphExecutionMeta,
-)
+from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
 
 from .base import BaseTool
+from .execution_utils import TERMINAL_STATUSES, wait_for_execution
 from .models import (
     AgentOutputResponse,
     ErrorResponse,
@@ -32,15 +26,6 @@ from .utils import fetch_graph_from_store_slug
 
 logger = logging.getLogger(__name__)
 
-# Terminal statuses that indicate execution is complete
-TERMINAL_STATUSES = frozenset(
-    {
-        ExecutionStatus.COMPLETED,
-        ExecutionStatus.FAILED,
-        ExecutionStatus.TERMINATED,
-    }
-)
-
 
 class AgentOutputInput(BaseModel):
     """Input parameters for the agent_output tool."""
@@ -50,7 +35,7 @@ class AgentOutputInput(BaseModel):
     store_slug: str = ""
     execution_id: str = ""
     run_time: str = "latest"
-    wait_if_running: int = 0  # Max seconds to wait if execution is still running
+    wait_if_running: int = Field(default=0, ge=0, le=300)
 
     @field_validator(
         "agent_name",
@@ -65,15 +50,6 @@ class AgentOutputInput(BaseModel):
         """Strip whitespace from string fields."""
         return v.strip() if isinstance(v, str) else v
 
-    @field_validator("wait_if_running", mode="before")
-    @classmethod
-    def validate_wait(cls, v: Any) -> int:
-        """Ensure wait_if_running is a non-negative integer."""
-        if v is None:
-            return 0
-        val = int(v)
-        return max(0, min(val, 300))  # Cap at 5 minutes
-
 
 def parse_time_expression(
     time_expr: str | None,
@@ -524,7 +500,7 @@ class AgentOutputTool(BaseTool):
                 f"Execution {execution.id} is {execution.status}, "
                 f"waiting up to {wait_timeout}s for completion"
             )
-            execution = await self._wait_for_execution_completion(
+            execution = await wait_for_execution(
                 user_id=user_id,
                 graph_id=agent.graph_id,
                 execution_id=execution.id,
@@ -532,87 +508,3 @@ class AgentOutputTool(BaseTool):
             )
 
         return self._build_response(agent, execution, available_executions, session_id)
-
-    async def _wait_for_execution_completion(
-        self,
-        user_id: str,
-        graph_id: str,
-        execution_id: str,
-        timeout_seconds: int,
-    ) -> GraphExecution | None:
-        """
-        Wait for an execution to reach a terminal status using Redis pubsub.
-
-        Args:
-            user_id: User ID
-            graph_id: Graph ID
-            execution_id: Execution ID to wait for
-            timeout_seconds: Max seconds to wait
-
-        Returns:
-            The execution with current status, or None on error
-        """
-        # First check current status - maybe it's already done
-        execution = await execution_db.get_graph_execution(
-            user_id=user_id,
-            execution_id=execution_id,
-            include_node_executions=False,
-        )
-        if not execution:
-            return None
-
-        # If already in terminal state, return immediately
-        if execution.status in TERMINAL_STATUSES:
-            logger.debug(
-                f"Execution {execution_id} already in terminal state: {execution.status}"
-            )
-            return execution
-
-        logger.info(
-            f"Waiting up to {timeout_seconds}s for execution {execution_id} "
-            f"(current status: {execution.status})"
-        )
-
-        # Subscribe to execution updates via Redis pubsub
-        event_bus = AsyncRedisExecutionEventBus()
-        channel_key = f"{user_id}/{graph_id}/{execution_id}"
-
-        try:
-            deadline = asyncio.get_event_loop().time() + timeout_seconds
-
-            async for event in event_bus.listen_events(channel_key):
-                # Check if we've exceeded the timeout
-                remaining = deadline - asyncio.get_event_loop().time()
-                if remaining <= 0:
-                    logger.info(f"Timeout waiting for execution {execution_id}")
-                    break
-
-                # Only process GraphExecutionEvents (not NodeExecutionEvents)
-                if isinstance(event, GraphExecutionEvent):
-                    logger.debug(f"Received execution update: {event.status}")
-                    if event.status in TERMINAL_STATUSES:
-                        # Fetch full execution with outputs
-                        return await execution_db.get_graph_execution(
-                            user_id=user_id,
-                            execution_id=execution_id,
-                            include_node_executions=False,
-                        )
-
-        except asyncio.TimeoutError:
-            logger.info(f"Timeout waiting for execution {execution_id}")
-        except Exception as e:
-            logger.error(f"Error waiting for execution: {e}", exc_info=True)
-        finally:
-            # Clean up pubsub connection
-            try:
-                if hasattr(event_bus, "_pubsub") and event_bus._pubsub:
-                    await event_bus._pubsub.close()
-            except Exception:
-                pass
-
-        # Return current state on timeout
-        return await execution_db.get_graph_execution(
-            user_id=user_id,
-            execution_id=execution_id,
-            include_node_executions=False,
-        )
diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/execution_utils.py b/autogpt_platform/backend/backend/api/features/chat/tools/execution_utils.py
new file mode 100644
index 0000000000..59cb5b34d1
--- /dev/null
+++ b/autogpt_platform/backend/backend/api/features/chat/tools/execution_utils.py
@@ -0,0 +1,124 @@
+"""Shared utilities for execution waiting and status handling."""
+
+import asyncio
+import logging
+from typing import Any
+
+from backend.data import execution as execution_db
+from backend.data.execution import (
+    AsyncRedisExecutionEventBus,
+    ExecutionStatus,
+    GraphExecution,
+    GraphExecutionEvent,
+)
+
+logger = logging.getLogger(__name__)
+
+# Terminal statuses that indicate execution is complete
+TERMINAL_STATUSES = frozenset(
+    {
+        ExecutionStatus.COMPLETED,
+        ExecutionStatus.FAILED,
+        ExecutionStatus.TERMINATED,
+    }
+)
+
+
+async def wait_for_execution(
+    user_id: str,
+    graph_id: str,
+    execution_id: str,
+    timeout_seconds: int,
+) -> GraphExecution | None:
+    """
+    Wait for an execution to reach a terminal status using Redis pubsub.
+
+    Uses asyncio.wait_for to ensure timeout is respected even when no events
+    are received.
+
+    Args:
+        user_id: User ID
+        graph_id: Graph ID
+        execution_id: Execution ID to wait for
+        timeout_seconds: Max seconds to wait
+
+    Returns:
+        The execution with current status, or None if not found
+    """
+    # First check current status - maybe it's already done
+    execution = await execution_db.get_graph_execution(
+        user_id=user_id,
+        execution_id=execution_id,
+        include_node_executions=False,
+    )
+    if not execution:
+        return None
+
+    # If already in terminal state, return immediately
+    if execution.status in TERMINAL_STATUSES:
+        logger.debug(
+            f"Execution {execution_id} already in terminal state: {execution.status}"
+        )
+        return execution
+
+    logger.info(
+        f"Waiting up to {timeout_seconds}s for execution {execution_id} "
+        f"(current status: {execution.status})"
+    )
+
+    # Subscribe to execution updates via Redis pubsub
+    event_bus = AsyncRedisExecutionEventBus()
+    channel_key = f"{user_id}/{graph_id}/{execution_id}"
+
+    try:
+        # Use wait_for to enforce timeout on the entire listen operation
+        result = await asyncio.wait_for(
+            _listen_for_terminal_status(event_bus, channel_key, user_id, execution_id),
+            timeout=timeout_seconds,
+        )
+        return result
+    except asyncio.TimeoutError:
+        logger.info(f"Timeout waiting for execution {execution_id}")
+    except Exception as e:
+        logger.error(f"Error waiting for execution: {e}", exc_info=True)
+
+    # Return current state on timeout/error
+    return await execution_db.get_graph_execution(
+        user_id=user_id,
+        execution_id=execution_id,
+        include_node_executions=False,
+    )
+
+
+async def _listen_for_terminal_status(
+    event_bus: AsyncRedisExecutionEventBus,
+    channel_key: str,
+    user_id: str,
+    execution_id: str,
+) -> GraphExecution | None:
+    """
+    Listen for execution events until a terminal status is reached.
+
+    This is a helper that gets wrapped in asyncio.wait_for for timeout handling.
+    """
+    async for event in event_bus.listen_events(channel_key):
+        # Only process GraphExecutionEvents (not NodeExecutionEvents)
+        if isinstance(event, GraphExecutionEvent):
+            logger.debug(f"Received execution update: {event.status}")
+            if event.status in TERMINAL_STATUSES:
+                # Fetch full execution with outputs
+                return await execution_db.get_graph_execution(
+                    user_id=user_id,
+                    execution_id=execution_id,
+                    include_node_executions=False,
+                )
+
+    # Should not reach here normally (generator should yield indefinitely)
+    return None
+
+
+def get_execution_outputs(execution: GraphExecution | None) -> dict[str, Any] | None:
+    """Extract outputs from an execution, or return None."""
+    if execution is None:
+        return None
+    return execution.outputs
diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py
index 34a9a00439..5d57f7e006 100644
--- a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py
+++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py
@@ -1,6 +1,5 @@
 """Unified tool for agent operations with automatic state detection."""
 
-import asyncio
 import logging
 from typing import Any
 
@@ -13,12 +12,7 @@ from backend.api.features.chat.tracking import (
     track_agent_scheduled,
 )
 from backend.api.features.library import db as library_db
-from backend.data import execution as execution_db
-from backend.data.execution import (
-    AsyncRedisExecutionEventBus,
-    ExecutionStatus,
-    GraphExecutionEvent,
-)
+from backend.data.execution import ExecutionStatus
 from backend.data.graph import GraphModel
 from backend.data.model import CredentialsMetaInput
 from backend.data.user import get_user_by_id
@@ -31,6 +25,7 @@ from backend.util.timezone_utils import (
 )
 
 from .base import BaseTool
+from .execution_utils import get_execution_outputs, wait_for_execution
 from .helpers import get_inputs_from_schema
 from .models import (
     AgentDetails,
@@ -55,15 +50,6 @@ from .utils import (
 logger = logging.getLogger(__name__)
 config = ChatConfig()
 
-# Terminal statuses that indicate execution is complete
-TERMINAL_STATUSES = frozenset(
-    {
-        ExecutionStatus.COMPLETED,
-        ExecutionStatus.FAILED,
-        ExecutionStatus.TERMINATED,
-    }
-)
-
 # Constants for response messages
 MSG_DO_NOT_RUN_AGAIN = "Do not run again unless explicitly requested."
 MSG_DO_NOT_SCHEDULE_AGAIN = "Do not schedule again unless explicitly requested."
@@ -86,7 +72,7 @@ class RunAgentInput(BaseModel):
     schedule_name: str = ""
     cron: str = ""
     timezone: str = "UTC"
-    wait_for_result: int = 0  # Max seconds to wait for execution to complete
+    wait_for_result: int = Field(default=0, ge=0, le=300)
 
     @field_validator(
         "username_agent_slug",
@@ -101,15 +87,6 @@ class RunAgentInput(BaseModel):
         """Strip whitespace from string fields."""
         return v.strip() if isinstance(v, str) else v
 
-    @field_validator("wait_for_result", mode="before")
-    @classmethod
-    def validate_wait(cls, v: Any) -> int:
-        """Ensure wait_for_result is within valid range (0-300 seconds)."""
-        if v is None:
-            return 0
-        val = int(v)
-        return max(0, min(val, 300))  # Cap at 5 minutes
-
 
 class RunAgentTool(BaseTool):
     """Unified tool for agent operations with automatic state detection.
@@ -510,12 +487,14 @@ class RunAgentTool(BaseTool):
             logger.info(
                 f"Waiting up to {wait_for_result}s for execution {execution.id}"
             )
-            final_status, outputs = await self._wait_for_execution_completion(
+            result = await wait_for_execution(
                 user_id=user_id,
                 graph_id=library_agent.graph_id,
                 execution_id=execution.id,
                 timeout_seconds=wait_for_result,
             )
+            final_status = result.status if result else ExecutionStatus.FAILED
+            outputs = get_execution_outputs(result)
 
             # Build message based on final status
             if final_status == ExecutionStatus.COMPLETED:
@@ -570,96 +549,6 @@ class RunAgentTool(BaseTool):
             library_agent_link=library_agent_link,
         )
 
-    async def _wait_for_execution_completion(
-        self,
-        user_id: str,
-        graph_id: str,
-        execution_id: str,
-        timeout_seconds: int,
-    ) -> tuple[ExecutionStatus, dict[str, Any] | None]:
-        """
-        Wait for an execution to reach a terminal status using Redis pubsub.
-
-        Args:
-            user_id: User ID
-            graph_id: Graph ID
-            execution_id: Execution ID to wait for
-            timeout_seconds: Max seconds to wait
-
-        Returns:
-            Tuple of (final_status, outputs_dict_or_None)
-        """
-        # First check current status - maybe it's already done
-        execution = await execution_db.get_graph_execution(
-            user_id=user_id,
-            execution_id=execution_id,
-            include_node_executions=False,
-        )
-        if not execution:
-            return ExecutionStatus.FAILED, None
-
-        # If already in terminal state, return immediately
-        if execution.status in TERMINAL_STATUSES:
-            logger.debug(
-                f"Execution {execution_id} already in terminal state: {execution.status}"
-            )
-            return execution.status, execution.outputs
-
-        logger.info(
-            f"Waiting up to {timeout_seconds}s for execution {execution_id} "
-            f"(current status: {execution.status})"
-        )
-
-        # Subscribe to execution updates via Redis pubsub
-        event_bus = AsyncRedisExecutionEventBus()
-        channel_key = f"{user_id}/{graph_id}/{execution_id}"
-
-        try:
-            deadline = asyncio.get_event_loop().time() + timeout_seconds
-
-            async for event in event_bus.listen_events(channel_key):
-                # Check if we've exceeded the timeout
-                remaining = deadline - asyncio.get_event_loop().time()
-                if remaining <= 0:
-                    logger.info(f"Timeout waiting for execution {execution_id}")
-                    break
-
-                # Only process GraphExecutionEvents (not NodeExecutionEvents)
-                if isinstance(event, GraphExecutionEvent):
-                    logger.debug(f"Received execution update: {event.status}")
-                    if event.status in TERMINAL_STATUSES:
-                        # Fetch full execution with outputs
-                        final_exec = await execution_db.get_graph_execution(
-                            user_id=user_id,
-                            execution_id=execution_id,
-                            include_node_executions=False,
-                        )
-                        if final_exec:
-                            return final_exec.status, final_exec.outputs
-                        return event.status, None
-
-        except asyncio.TimeoutError:
-            logger.info(f"Timeout waiting for execution {execution_id}")
-        except Exception as e:
-            logger.error(f"Error waiting for execution: {e}", exc_info=True)
-        finally:
-            # Clean up pubsub connection
-            try:
-                if hasattr(event_bus, "_pubsub") and event_bus._pubsub:
-                    await event_bus._pubsub.close()
-            except Exception:
-                pass
-
-        # Return current state on timeout
-        execution = await execution_db.get_graph_execution(
-            user_id=user_id,
-            execution_id=execution_id,
-            include_node_executions=False,
-        )
-        if execution:
-            return execution.status, execution.outputs
-        return ExecutionStatus.QUEUED, None
-
     async def _schedule_agent(
         self,
         user_id: str,
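Reviewer note: a minimal usage sketch (not part of the patch) of the shared helper introduced above, mirroring how run_agent.py consumes it after this change. The function names, argument names, and return types come from the diff; the wrapping coroutine, the placeholder IDs, and the 30-second timeout are illustrative assumptions.

    # Sketch only -- not applied by this patch. IDs are placeholders; the 30s
    # timeout is an arbitrary value within the 0-300 range enforced by Field.
    from backend.api.features.chat.tools.execution_utils import (
        get_execution_outputs,
        wait_for_execution,
    )
    from backend.data.execution import ExecutionStatus


    async def example_wait(user_id: str, graph_id: str, execution_id: str) -> None:
        # Blocks for at most 30s; returns the execution in its current (possibly
        # non-terminal) state on timeout, or None if the execution does not exist.
        execution = await wait_for_execution(
            user_id=user_id,
            graph_id=graph_id,
            execution_id=execution_id,
            timeout_seconds=30,
        )
        status = execution.status if execution else ExecutionStatus.FAILED
        outputs = get_execution_outputs(execution)
        if status == ExecutionStatus.COMPLETED:
            print(f"Outputs: {outputs}")
        else:
            print(f"Execution ended with status: {status}")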