mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-17 18:21:46 -05:00
feat(copilot): add wait_for_result to run_agent tool
Adds wait_for_result parameter (0-300 seconds) to run_agent that blocks until execution completes or times out. - Uses Redis pubsub subscription via AsyncRedisExecutionEventBus (no polling) - Returns immediately if execution finishes within timeout - Returns current state + partial outputs on timeout - Outputs included in ExecutionStartedResponse when wait is used This allows LLMs to run agents and get results in a single tool call: run_agent(username_agent_slug='user/agent', wait_for_result=60)
This commit is contained in:
@@ -192,6 +192,7 @@ class ExecutionStartedResponse(ToolResponseBase):
|
||||
library_agent_id: str | None = None
|
||||
library_agent_link: str | None = None
|
||||
status: str = "QUEUED"
|
||||
outputs: dict[str, Any] | None = None # Populated when wait_for_result is used
|
||||
|
||||
|
||||
# Auth/error models
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Unified tool for agent operations with automatic state detection."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
@@ -12,6 +13,12 @@ from backend.api.features.chat.tracking import (
|
||||
track_agent_scheduled,
|
||||
)
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data.execution import (
|
||||
AsyncRedisExecutionEventBus,
|
||||
ExecutionStatus,
|
||||
GraphExecutionEvent,
|
||||
)
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.user import get_user_by_id
|
||||
@@ -48,6 +55,15 @@ from .utils import (
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
# Terminal statuses that indicate execution is complete
|
||||
TERMINAL_STATUSES = frozenset(
|
||||
{
|
||||
ExecutionStatus.COMPLETED,
|
||||
ExecutionStatus.FAILED,
|
||||
ExecutionStatus.TERMINATED,
|
||||
}
|
||||
)
|
||||
|
||||
# Constants for response messages
|
||||
MSG_DO_NOT_RUN_AGAIN = "Do not run again unless explicitly requested."
|
||||
MSG_DO_NOT_SCHEDULE_AGAIN = "Do not schedule again unless explicitly requested."
|
||||
@@ -70,6 +86,7 @@ class RunAgentInput(BaseModel):
|
||||
schedule_name: str = ""
|
||||
cron: str = ""
|
||||
timezone: str = "UTC"
|
||||
wait_for_result: int = 0 # Max seconds to wait for execution to complete
|
||||
|
||||
@field_validator(
|
||||
"username_agent_slug",
|
||||
@@ -84,6 +101,15 @@ class RunAgentInput(BaseModel):
|
||||
"""Strip whitespace from string fields."""
|
||||
return v.strip() if isinstance(v, str) else v
|
||||
|
||||
@field_validator("wait_for_result", mode="before")
|
||||
@classmethod
|
||||
def validate_wait(cls, v: Any) -> int:
|
||||
"""Ensure wait_for_result is within valid range (0-300 seconds)."""
|
||||
if v is None:
|
||||
return 0
|
||||
val = int(v)
|
||||
return max(0, min(val, 300)) # Cap at 5 minutes
|
||||
|
||||
|
||||
class RunAgentTool(BaseTool):
|
||||
"""Unified tool for agent operations with automatic state detection.
|
||||
@@ -151,6 +177,14 @@ class RunAgentTool(BaseTool):
|
||||
"type": "string",
|
||||
"description": "IANA timezone for schedule (default: UTC)",
|
||||
},
|
||||
"wait_for_result": {
|
||||
"type": "integer",
|
||||
"description": (
|
||||
"Max seconds to wait for execution to complete (0-300). "
|
||||
"If >0, blocks until the execution finishes or times out. "
|
||||
"Returns execution outputs when complete."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
@@ -347,6 +381,7 @@ class RunAgentTool(BaseTool):
|
||||
graph=graph,
|
||||
graph_credentials=graph_credentials,
|
||||
inputs=params.inputs,
|
||||
wait_for_result=params.wait_for_result,
|
||||
)
|
||||
|
||||
except NotFoundError as e:
|
||||
@@ -423,6 +458,96 @@ class RunAgentTool(BaseTool):
|
||||
trigger_info=trigger_info,
|
||||
)
|
||||
|
||||
async def _wait_for_execution_completion(
|
||||
self,
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
execution_id: str,
|
||||
timeout_seconds: int,
|
||||
) -> tuple[ExecutionStatus, dict[str, Any] | None]:
|
||||
"""
|
||||
Wait for an execution to reach a terminal status using Redis pubsub.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
graph_id: Graph ID
|
||||
execution_id: Execution ID to wait for
|
||||
timeout_seconds: Max seconds to wait
|
||||
|
||||
Returns:
|
||||
Tuple of (final_status, outputs_dict_or_None)
|
||||
"""
|
||||
# First check current status - maybe it's already done
|
||||
execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=execution_id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
if not execution:
|
||||
return ExecutionStatus.FAILED, None
|
||||
|
||||
# If already in terminal state, return immediately
|
||||
if execution.status in TERMINAL_STATUSES:
|
||||
logger.debug(
|
||||
f"Execution {execution_id} already in terminal state: {execution.status}"
|
||||
)
|
||||
return execution.status, execution.outputs
|
||||
|
||||
logger.info(
|
||||
f"Waiting up to {timeout_seconds}s for execution {execution_id} "
|
||||
f"(current status: {execution.status})"
|
||||
)
|
||||
|
||||
# Subscribe to execution updates via Redis pubsub
|
||||
event_bus = AsyncRedisExecutionEventBus()
|
||||
channel_key = f"{user_id}/{graph_id}/{execution_id}"
|
||||
|
||||
try:
|
||||
deadline = asyncio.get_event_loop().time() + timeout_seconds
|
||||
|
||||
async for event in event_bus.listen_events(channel_key):
|
||||
# Check if we've exceeded the timeout
|
||||
remaining = deadline - asyncio.get_event_loop().time()
|
||||
if remaining <= 0:
|
||||
logger.info(f"Timeout waiting for execution {execution_id}")
|
||||
break
|
||||
|
||||
# Only process GraphExecutionEvents (not NodeExecutionEvents)
|
||||
if isinstance(event, GraphExecutionEvent):
|
||||
logger.debug(f"Received execution update: {event.status}")
|
||||
if event.status in TERMINAL_STATUSES:
|
||||
# Fetch full execution with outputs
|
||||
final_exec = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=execution_id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
if final_exec:
|
||||
return final_exec.status, final_exec.outputs
|
||||
return event.status, None
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.info(f"Timeout waiting for execution {execution_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error waiting for execution: {e}", exc_info=True)
|
||||
finally:
|
||||
# Clean up pubsub connection
|
||||
try:
|
||||
if hasattr(event_bus, "_pubsub") and event_bus._pubsub:
|
||||
await event_bus._pubsub.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Return current state on timeout
|
||||
execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=execution_id,
|
||||
include_node_executions=False,
|
||||
)
|
||||
if execution:
|
||||
return execution.status, execution.outputs
|
||||
return ExecutionStatus.QUEUED, None
|
||||
|
||||
async def _run_agent(
|
||||
self,
|
||||
user_id: str,
|
||||
@@ -430,8 +555,9 @@ class RunAgentTool(BaseTool):
|
||||
graph: GraphModel,
|
||||
graph_credentials: dict[str, CredentialsMetaInput],
|
||||
inputs: dict[str, Any],
|
||||
wait_for_result: int = 0,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute an agent immediately."""
|
||||
"""Execute an agent immediately, optionally waiting for completion."""
|
||||
session_id = session.session_id
|
||||
|
||||
# Check rate limits
|
||||
@@ -468,6 +594,58 @@ class RunAgentTool(BaseTool):
|
||||
)
|
||||
|
||||
library_agent_link = f"/library/agents/{library_agent.id}"
|
||||
|
||||
# If wait_for_result is specified, wait for execution to complete
|
||||
if wait_for_result > 0:
|
||||
logger.info(
|
||||
f"Waiting up to {wait_for_result}s for execution {execution.id}"
|
||||
)
|
||||
final_status, outputs = await self._wait_for_execution_completion(
|
||||
user_id=user_id,
|
||||
graph_id=library_agent.graph_id,
|
||||
execution_id=execution.id,
|
||||
timeout_seconds=wait_for_result,
|
||||
)
|
||||
|
||||
# Build message based on final status
|
||||
if final_status == ExecutionStatus.COMPLETED:
|
||||
message = (
|
||||
f"Agent '{library_agent.name}' execution completed successfully. "
|
||||
f"{MSG_DO_NOT_RUN_AGAIN}"
|
||||
)
|
||||
elif final_status == ExecutionStatus.FAILED:
|
||||
message = (
|
||||
f"Agent '{library_agent.name}' execution failed. "
|
||||
f"View details at {library_agent_link}. "
|
||||
f"{MSG_DO_NOT_RUN_AGAIN}"
|
||||
)
|
||||
elif final_status == ExecutionStatus.TERMINATED:
|
||||
message = (
|
||||
f"Agent '{library_agent.name}' execution was terminated. "
|
||||
f"View details at {library_agent_link}. "
|
||||
f"{MSG_DO_NOT_RUN_AGAIN}"
|
||||
)
|
||||
else:
|
||||
message = (
|
||||
f"Agent '{library_agent.name}' execution is still {final_status.value} "
|
||||
f"(timed out after {wait_for_result}s). "
|
||||
f"View at {library_agent_link}. "
|
||||
f"{MSG_DO_NOT_RUN_AGAIN}"
|
||||
)
|
||||
|
||||
return ExecutionStartedResponse(
|
||||
message=message,
|
||||
session_id=session_id,
|
||||
execution_id=execution.id,
|
||||
graph_id=library_agent.graph_id,
|
||||
graph_name=library_agent.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=library_agent_link,
|
||||
status=final_status.value,
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
# Default: return immediately without waiting
|
||||
return ExecutionStartedResponse(
|
||||
message=(
|
||||
f"Agent '{library_agent.name}' execution started successfully. "
|
||||
|
||||
Reference in New Issue
Block a user