mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-17 18:21:46 -05:00
feat(copilot): add wait_if_running to view_agent_output tool
Adds ability to wait for execution completion instead of just returning current state. Uses Redis pubsub subscription for real-time updates. Changes: - Add wait_if_running parameter (0-300 seconds) to AgentOutputInput - Add _wait_for_execution_completion method using AsyncRedisExecutionEventBus - Subscribe to execution updates via Redis pubsub channel - Return immediately if execution already in terminal state - Return current state on timeout - Update response messages to indicate running/incomplete status SECRT-2003
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
"""Tool for retrieving agent execution outputs from user's library."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
@@ -11,7 +12,13 @@ from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.library.model import LibraryAgent
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
|
||||
from backend.data.execution import (
|
||||
AsyncRedisExecutionEventBus,
|
||||
ExecutionStatus,
|
||||
GraphExecution,
|
||||
GraphExecutionEvent,
|
||||
GraphExecutionMeta,
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
@@ -25,6 +32,15 @@ from .utils import fetch_graph_from_store_slug
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Terminal statuses that indicate execution is complete
|
||||
TERMINAL_STATUSES = frozenset(
|
||||
{
|
||||
ExecutionStatus.COMPLETED,
|
||||
ExecutionStatus.FAILED,
|
||||
ExecutionStatus.TERMINATED,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class AgentOutputInput(BaseModel):
|
||||
"""Input parameters for the agent_output tool."""
|
||||
@@ -34,6 +50,7 @@ class AgentOutputInput(BaseModel):
|
||||
store_slug: str = ""
|
||||
execution_id: str = ""
|
||||
run_time: str = "latest"
|
||||
wait_if_running: int = 0 # Max seconds to wait if execution is still running
|
||||
|
||||
@field_validator(
|
||||
"agent_name",
|
||||
@@ -48,6 +65,15 @@ class AgentOutputInput(BaseModel):
|
||||
"""Strip whitespace from string fields."""
|
||||
return v.strip() if isinstance(v, str) else v
|
||||
|
||||
@field_validator("wait_if_running", mode="before")
|
||||
@classmethod
|
||||
def validate_wait(cls, v: Any) -> int:
|
||||
"""Ensure wait_if_running is a non-negative integer."""
|
||||
if v is None:
|
||||
return 0
|
||||
val = int(v)
|
||||
return max(0, min(val, 300)) # Cap at 5 minutes
|
||||
|
||||
|
||||
def parse_time_expression(
|
||||
time_expr: str | None,
|
||||
@@ -117,6 +143,11 @@ class AgentOutputTool(BaseTool):
|
||||
Select which run to retrieve using:
|
||||
- execution_id: Specific execution ID
|
||||
- run_time: 'latest' (default), 'yesterday', 'last week', or ISO date 'YYYY-MM-DD'
|
||||
|
||||
Wait for completion (optional):
|
||||
- wait_if_running: Max seconds to wait if execution is still running (0-300).
|
||||
If the execution is running/queued, waits up to this many seconds for it to complete.
|
||||
Returns current status on timeout. If already finished, returns immediately.
|
||||
"""
|
||||
|
||||
@property
|
||||
@@ -146,6 +177,13 @@ class AgentOutputTool(BaseTool):
|
||||
"Time filter: 'latest', 'yesterday', 'last week', or 'YYYY-MM-DD'"
|
||||
),
|
||||
},
|
||||
"wait_if_running": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Max seconds to wait if execution is still running (0-300). "
|
||||
"If running, waits for completion. Returns current state on timeout."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
@@ -154,6 +192,90 @@ class AgentOutputTool(BaseTool):
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _wait_for_execution_completion(
|
||||
self,
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
execution_id: str,
|
||||
timeout_seconds: int,
|
||||
) -> GraphExecution | None:
|
||||
"""
|
||||
Wait for an execution to reach a terminal status using Redis pubsub.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
graph_id: Graph ID
|
||||
execution_id: Execution ID to wait for
|
||||
timeout_seconds: Max seconds to wait
|
||||
|
||||
Returns:
|
||||
The execution with current status, or None on error
|
||||
"""
|
||||
# First check current status - maybe it's already done
|
||||
execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=execution_id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
if not execution:
|
||||
return None
|
||||
|
||||
# If already in terminal state, return immediately
|
||||
if execution.status in TERMINAL_STATUSES:
|
||||
logger.debug(
|
||||
f"Execution {execution_id} already in terminal state: {execution.status}"
|
||||
)
|
||||
return execution
|
||||
|
||||
logger.info(
|
||||
f"Waiting up to {timeout_seconds}s for execution {execution_id} "
|
||||
f"(current status: {execution.status})"
|
||||
)
|
||||
|
||||
# Subscribe to execution updates via Redis pubsub
|
||||
event_bus = AsyncRedisExecutionEventBus()
|
||||
channel_key = f"{user_id}/{graph_id}/{execution_id}"
|
||||
|
||||
try:
|
||||
deadline = asyncio.get_event_loop().time() + timeout_seconds
|
||||
|
||||
async for event in event_bus.listen_events(channel_key):
|
||||
# Check if we've exceeded the timeout
|
||||
remaining = deadline - asyncio.get_event_loop().time()
|
||||
if remaining <= 0:
|
||||
logger.info(f"Timeout waiting for execution {execution_id}")
|
||||
break
|
||||
|
||||
# Only process GraphExecutionEvents (not NodeExecutionEvents)
|
||||
if isinstance(event, GraphExecutionEvent):
|
||||
logger.debug(f"Received execution update: {event.status}")
|
||||
if event.status in TERMINAL_STATUSES:
|
||||
# Fetch full execution with outputs
|
||||
return await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=execution_id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.info(f"Timeout waiting for execution {execution_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error waiting for execution: {e}", exc_info=True)
|
||||
finally:
|
||||
# Clean up pubsub connection
|
||||
try:
|
||||
if hasattr(event_bus, "_pubsub") and event_bus._pubsub:
|
||||
await event_bus._pubsub.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Return current state on timeout
|
||||
return await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=execution_id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
|
||||
async def _resolve_agent(
|
||||
self,
|
||||
user_id: str,
|
||||
@@ -223,10 +345,14 @@ class AgentOutputTool(BaseTool):
|
||||
execution_id: str | None,
|
||||
time_start: datetime | None,
|
||||
time_end: datetime | None,
|
||||
include_running: bool = False,
|
||||
) -> tuple[GraphExecution | None, list[GraphExecutionMeta], str | None]:
|
||||
"""
|
||||
Fetch execution(s) based on filters.
|
||||
Returns (single_execution, available_executions_meta, error_message).
|
||||
|
||||
Args:
|
||||
include_running: If True, also look for running/queued executions (for waiting)
|
||||
"""
|
||||
# If specific execution_id provided, fetch it directly
|
||||
if execution_id:
|
||||
@@ -239,11 +365,22 @@ class AgentOutputTool(BaseTool):
|
||||
return None, [], f"Execution '{execution_id}' not found"
|
||||
return execution, [], None
|
||||
|
||||
# Get completed executions with time filters
|
||||
# Determine which statuses to query
|
||||
statuses = [ExecutionStatus.COMPLETED]
|
||||
if include_running:
|
||||
statuses.extend(
|
||||
[
|
||||
ExecutionStatus.RUNNING,
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
]
|
||||
)
|
||||
|
||||
# Get executions with time filters
|
||||
executions = await execution_db.get_graph_executions(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
statuses=[ExecutionStatus.COMPLETED],
|
||||
statuses=statuses,
|
||||
created_time_gte=time_start,
|
||||
created_time_lte=time_end,
|
||||
limit=10,
|
||||
@@ -310,10 +447,28 @@ class AgentOutputTool(BaseTool):
|
||||
for e in available_executions[:5]
|
||||
]
|
||||
|
||||
message = f"Found execution outputs for agent '{agent.name}'"
|
||||
# Build appropriate message based on execution status
|
||||
if execution.status == ExecutionStatus.COMPLETED:
|
||||
message = f"Found execution outputs for agent '{agent.name}'"
|
||||
elif execution.status == ExecutionStatus.FAILED:
|
||||
message = f"Execution for agent '{agent.name}' failed"
|
||||
elif execution.status == ExecutionStatus.TERMINATED:
|
||||
message = f"Execution for agent '{agent.name}' was terminated"
|
||||
elif execution.status in (
|
||||
ExecutionStatus.RUNNING,
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
):
|
||||
message = (
|
||||
f"Execution for agent '{agent.name}' is still {execution.status.value}. "
|
||||
"Results may be incomplete. Use wait_if_running to wait for completion."
|
||||
)
|
||||
else:
|
||||
message = f"Found execution for agent '{agent.name}' (status: {execution.status.value})"
|
||||
|
||||
if len(available_executions) > 1:
|
||||
message += (
|
||||
f". Showing latest of {len(available_executions)} matching executions."
|
||||
f" Showing latest of {len(available_executions)} matching executions."
|
||||
)
|
||||
|
||||
return AgentOutputResponse(
|
||||
@@ -428,13 +583,17 @@ class AgentOutputTool(BaseTool):
|
||||
# Parse time expression
|
||||
time_start, time_end = parse_time_expression(input_data.run_time)
|
||||
|
||||
# Fetch execution(s)
|
||||
# Check if we should wait for running executions
|
||||
wait_timeout = input_data.wait_if_running
|
||||
|
||||
# Fetch execution(s) - include running if we're going to wait
|
||||
execution, available_executions, exec_error = await self._get_execution(
|
||||
user_id=user_id,
|
||||
graph_id=agent.graph_id,
|
||||
execution_id=input_data.execution_id or None,
|
||||
time_start=time_start,
|
||||
time_end=time_end,
|
||||
include_running=wait_timeout > 0,
|
||||
)
|
||||
|
||||
if exec_error:
|
||||
@@ -443,4 +602,17 @@ class AgentOutputTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# If we have an execution that's still running and we should wait
|
||||
if execution and wait_timeout > 0 and execution.status not in TERMINAL_STATUSES:
|
||||
logger.info(
|
||||
f"Execution {execution.id} is {execution.status}, "
|
||||
f"waiting up to {wait_timeout}s for completion"
|
||||
)
|
||||
execution = await self._wait_for_execution_completion(
|
||||
user_id=user_id,
|
||||
graph_id=agent.graph_id,
|
||||
execution_id=execution.id,
|
||||
timeout_seconds=wait_timeout,
|
||||
)
|
||||
|
||||
return self._build_response(agent, execution, available_executions, session_id)
|
||||
|
||||
Reference in New Issue
Block a user