refactor: address review feedback

- Use pydantic Field(ge=0, le=300) instead of custom validators
- Extract shared wait logic to execution_utils.py
- Use asyncio.wait_for for proper timeout handling
- Remove duplicated code in agent_output.py and run_agent.py

Note: Direct DB access will need adjustment for #12057 compatibility
This commit is contained in:
Otto
2026-02-17 13:39:59 +00:00
parent 92ddb57460
commit 901b5e8b75
3 changed files with 135 additions and 230 deletions

View File

@@ -1,26 +1,20 @@
"""Tool for retrieving agent execution outputs from user's library."""
import asyncio
import logging
import re
from datetime import datetime, timedelta, timezone
from typing import Any
from pydantic import BaseModel, field_validator
from pydantic import BaseModel, Field, field_validator
from backend.api.features.chat.model import ChatSession
from backend.api.features.library import db as library_db
from backend.api.features.library.model import LibraryAgent
from backend.data import execution as execution_db
from backend.data.execution import (
AsyncRedisExecutionEventBus,
ExecutionStatus,
GraphExecution,
GraphExecutionEvent,
GraphExecutionMeta,
)
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
from .base import BaseTool
from .execution_utils import TERMINAL_STATUSES, wait_for_execution
from .models import (
AgentOutputResponse,
ErrorResponse,
@@ -32,15 +26,6 @@ from .utils import fetch_graph_from_store_slug
logger = logging.getLogger(__name__)
# Terminal statuses that indicate execution is complete
TERMINAL_STATUSES = frozenset(
{
ExecutionStatus.COMPLETED,
ExecutionStatus.FAILED,
ExecutionStatus.TERMINATED,
}
)
class AgentOutputInput(BaseModel):
"""Input parameters for the agent_output tool."""
@@ -50,7 +35,7 @@ class AgentOutputInput(BaseModel):
store_slug: str = ""
execution_id: str = ""
run_time: str = "latest"
wait_if_running: int = 0 # Max seconds to wait if execution is still running
wait_if_running: int = Field(default=0, ge=0, le=300)
@field_validator(
"agent_name",
@@ -65,15 +50,6 @@ class AgentOutputInput(BaseModel):
"""Strip whitespace from string fields."""
return v.strip() if isinstance(v, str) else v
@field_validator("wait_if_running", mode="before")
@classmethod
def validate_wait(cls, v: Any) -> int:
"""Coerce wait_if_running to an int clamped to 0-300 seconds (None becomes 0)."""
# A missing value means "do not wait".
if v is None:
return 0
# int() raises for values it cannot convert (e.g. non-numeric strings).
val = int(v)
# Out-of-range values are clamped rather than rejected.
return max(0, min(val, 300)) # Cap at 5 minutes
def parse_time_expression(
time_expr: str | None,
@@ -524,7 +500,7 @@ class AgentOutputTool(BaseTool):
f"Execution {execution.id} is {execution.status}, "
f"waiting up to {wait_timeout}s for completion"
)
execution = await self._wait_for_execution_completion(
execution = await wait_for_execution(
user_id=user_id,
graph_id=agent.graph_id,
execution_id=execution.id,
@@ -532,87 +508,3 @@ class AgentOutputTool(BaseTool):
)
return self._build_response(agent, execution, available_executions, session_id)
async def _wait_for_execution_completion(
self,
user_id: str,
graph_id: str,
execution_id: str,
timeout_seconds: int,
) -> GraphExecution | None:
"""
Wait for an execution to reach a terminal status using Redis pubsub.
Args:
user_id: User ID
graph_id: Graph ID
execution_id: Execution ID to wait for
timeout_seconds: Max seconds to wait
Returns:
The execution as last read from the database, or None when a lookup
returns no execution. Errors raised while listening are logged and
followed by a final re-read rather than being propagated.
"""
# First check current status - maybe it's already done
execution = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=execution_id,
include_node_executions=False,
)
if not execution:
return None
# If already in terminal state, return immediately
if execution.status in TERMINAL_STATUSES:
logger.debug(
f"Execution {execution_id} already in terminal state: {execution.status}"
)
return execution
logger.info(
f"Waiting up to {timeout_seconds}s for execution {execution_id} "
f"(current status: {execution.status})"
)
# Subscribe to execution updates via Redis pubsub
event_bus = AsyncRedisExecutionEventBus()
channel_key = f"{user_id}/{graph_id}/{execution_id}"
try:
deadline = asyncio.get_event_loop().time() + timeout_seconds
# NOTE: the deadline is only evaluated inside the loop body, i.e. after an
# event has been received. While the channel is silent this loop cannot
# time out.
async for event in event_bus.listen_events(channel_key):
# Check if we've exceeded the timeout
remaining = deadline - asyncio.get_event_loop().time()
if remaining <= 0:
logger.info(f"Timeout waiting for execution {execution_id}")
break
# Only process GraphExecutionEvents (not NodeExecutionEvents)
if isinstance(event, GraphExecutionEvent):
logger.debug(f"Received execution update: {event.status}")
if event.status in TERMINAL_STATUSES:
# Fetch full execution with outputs
return await execution_db.get_graph_execution(
user_id=user_id,
execution_id=execution_id,
include_node_executions=False,
)
# NOTE(review): nothing visible in the try body applies a timeout, so this
# handler looks unreachable unless listen_events raises it itself - confirm.
except asyncio.TimeoutError:
logger.info(f"Timeout waiting for execution {execution_id}")
except Exception as e:
logger.error(f"Error waiting for execution: {e}", exc_info=True)
finally:
# Clean up pubsub connection
# Best effort: reads the bus's private _pubsub attribute and ignores any
# failure raised while closing it.
try:
if hasattr(event_bus, "_pubsub") and event_bus._pubsub:
await event_bus._pubsub.close()
except Exception:
pass
# Return current state on timeout, after a logged error, or if the event
# stream ends without a terminal status
return await execution_db.get_graph_execution(
user_id=user_id,
execution_id=execution_id,
include_node_executions=False,
)

View File

@@ -0,0 +1,124 @@
"""Shared utilities for execution waiting and status handling."""
import asyncio
import logging
from contextlib import aclosing
from typing import Any

from backend.data import execution as execution_db
from backend.data.execution import (
    AsyncRedisExecutionEventBus,
    ExecutionStatus,
    GraphExecution,
    GraphExecutionEvent,
)
logger = logging.getLogger(__name__)

# Statuses after which an execution is no longer waited on: the wait helpers in
# this module stop as soon as an execution (or an execution event) reports one.
TERMINAL_STATUSES = frozenset(
    (
        ExecutionStatus.COMPLETED,
        ExecutionStatus.FAILED,
        ExecutionStatus.TERMINATED,
    )
)
async def wait_for_execution(
    user_id: str,
    graph_id: str,
    execution_id: str,
    timeout_seconds: int,
) -> GraphExecution | None:
    """
    Wait for an execution to reach a terminal status using Redis pubsub.

    The listening step is wrapped in asyncio.wait_for, so the timeout is
    enforced even when no events are received. Whenever listening does not
    produce an execution (timeout, error, or the event stream ending early),
    the current state is re-read from the database, so callers only get None
    when the execution cannot be found.

    Args:
        user_id: User ID
        graph_id: Graph ID
        execution_id: Execution ID to wait for
        timeout_seconds: Max seconds to wait

    Returns:
        The execution with its current status, or None if not found
    """
    # First check current status - maybe it's already done
    execution = await execution_db.get_graph_execution(
        user_id=user_id,
        execution_id=execution_id,
        include_node_executions=False,
    )
    if not execution:
        return None

    # If already in terminal state, return immediately
    if execution.status in TERMINAL_STATUSES:
        logger.debug(
            f"Execution {execution_id} already in terminal state: {execution.status}"
        )
        return execution

    logger.info(
        f"Waiting up to {timeout_seconds}s for execution {execution_id} "
        f"(current status: {execution.status})"
    )

    # Subscribe to execution updates via Redis pubsub.
    # NOTE(review): the status check above runs before listening starts, so a
    # terminal event published in between would presumably be missed. The
    # re-read at the end still returns the final state, but only after the
    # full timeout has elapsed - confirm whether that delay is acceptable.
    event_bus = AsyncRedisExecutionEventBus()
    channel_key = f"{user_id}/{graph_id}/{execution_id}"

    try:
        # Use wait_for to enforce timeout on the entire listen operation
        result = await asyncio.wait_for(
            _listen_for_terminal_status(event_bus, channel_key, user_id, execution_id),
            timeout=timeout_seconds,
        )
    except asyncio.TimeoutError:
        logger.info(f"Timeout waiting for execution {execution_id}")
    except Exception as e:
        logger.error(f"Error waiting for execution: {e}", exc_info=True)
    else:
        if result is not None:
            return result
        # The listener returns None when the event stream ends or its own
        # lookup comes back empty. Returning that directly would tell callers
        # the execution does not exist, so fall through to the re-read below.
        logger.warning(
            f"Listener for execution {execution_id} returned no execution; "
            "re-reading current state"
        )

    # Return current state on timeout/error/early end of the event stream
    return await execution_db.get_graph_execution(
        user_id=user_id,
        execution_id=execution_id,
        include_node_executions=False,
    )
async def _listen_for_terminal_status(
    event_bus: AsyncRedisExecutionEventBus,
    channel_key: str,
    user_id: str,
    execution_id: str,
) -> GraphExecution | None:
    """
    Listen for execution events until a terminal status is reached.

    This is a helper that gets wrapped in asyncio.wait_for for timeout handling.

    Args:
        event_bus: Bus whose listen_events() stream is consumed
        channel_key: Channel to listen on
        user_id: User ID used for the final execution lookup
        execution_id: Execution ID used for the final execution lookup

    Returns:
        The execution re-read from the database once a GraphExecutionEvent
        with a terminal status arrives, or None if the event stream ends
        first (or that lookup returns nothing).
    """
    # aclosing() finalizes the event stream as soon as this block is left
    # (early return, error, or cancellation by wait_for) instead of leaving a
    # suspended generator to be cleaned up by the garbage collector.
    # Assumes listen_events() returns an async generator (needs aclose()).
    async with aclosing(event_bus.listen_events(channel_key)) as events:
        async for event in events:
            # Only process GraphExecutionEvents (not NodeExecutionEvents)
            if not isinstance(event, GraphExecutionEvent):
                continue
            logger.debug(f"Received execution update: {event.status}")
            if event.status in TERMINAL_STATUSES:
                # Fetch full execution with outputs
                return await execution_db.get_graph_execution(
                    user_id=user_id,
                    execution_id=execution_id,
                    include_node_executions=False,
                )
    # Should not reach here normally (generator should yield indefinitely)
    return None
def get_execution_outputs(execution: GraphExecution | None) -> dict[str, Any] | None:
    """Return the ``outputs`` of *execution*, or None when no execution is given."""
    return None if execution is None else execution.outputs

View File

@@ -1,6 +1,5 @@
"""Unified tool for agent operations with automatic state detection."""
import asyncio
import logging
from typing import Any
@@ -13,12 +12,7 @@ from backend.api.features.chat.tracking import (
track_agent_scheduled,
)
from backend.api.features.library import db as library_db
from backend.data import execution as execution_db
from backend.data.execution import (
AsyncRedisExecutionEventBus,
ExecutionStatus,
GraphExecutionEvent,
)
from backend.data.execution import ExecutionStatus
from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
from backend.data.user import get_user_by_id
@@ -31,6 +25,7 @@ from backend.util.timezone_utils import (
)
from .base import BaseTool
from .execution_utils import get_execution_outputs, wait_for_execution
from .helpers import get_inputs_from_schema
from .models import (
AgentDetails,
@@ -55,15 +50,6 @@ from .utils import (
logger = logging.getLogger(__name__)
config = ChatConfig()
# Terminal statuses that indicate execution is complete
TERMINAL_STATUSES = frozenset(
{
ExecutionStatus.COMPLETED,
ExecutionStatus.FAILED,
ExecutionStatus.TERMINATED,
}
)
# Constants for response messages
MSG_DO_NOT_RUN_AGAIN = "Do not run again unless explicitly requested."
MSG_DO_NOT_SCHEDULE_AGAIN = "Do not schedule again unless explicitly requested."
@@ -86,7 +72,7 @@ class RunAgentInput(BaseModel):
schedule_name: str = ""
cron: str = ""
timezone: str = "UTC"
wait_for_result: int = 0 # Max seconds to wait for execution to complete
wait_for_result: int = Field(default=0, ge=0, le=300)
@field_validator(
"username_agent_slug",
@@ -101,15 +87,6 @@ class RunAgentInput(BaseModel):
"""Strip whitespace from string fields."""
return v.strip() if isinstance(v, str) else v
@field_validator("wait_for_result", mode="before")
@classmethod
def validate_wait(cls, v: Any) -> int:
"""Coerce wait_for_result to an int clamped to 0-300 seconds (None becomes 0)."""
# A missing value means "do not wait".
if v is None:
return 0
# int() raises for values it cannot convert (e.g. non-numeric strings).
val = int(v)
# Out-of-range values are clamped rather than rejected.
return max(0, min(val, 300)) # Cap at 5 minutes
class RunAgentTool(BaseTool):
"""Unified tool for agent operations with automatic state detection.
@@ -510,12 +487,14 @@ class RunAgentTool(BaseTool):
logger.info(
f"Waiting up to {wait_for_result}s for execution {execution.id}"
)
final_status, outputs = await self._wait_for_execution_completion(
result = await wait_for_execution(
user_id=user_id,
graph_id=library_agent.graph_id,
execution_id=execution.id,
timeout_seconds=wait_for_result,
)
final_status = result.status if result else ExecutionStatus.FAILED
outputs = get_execution_outputs(result)
# Build message based on final status
if final_status == ExecutionStatus.COMPLETED:
@@ -570,96 +549,6 @@ class RunAgentTool(BaseTool):
library_agent_link=library_agent_link,
)
async def _wait_for_execution_completion(
self,
user_id: str,
graph_id: str,
execution_id: str,
timeout_seconds: int,
) -> tuple[ExecutionStatus, dict[str, Any] | None]:
"""
Wait for an execution to reach a terminal status using Redis pubsub.
Args:
user_id: User ID
graph_id: Graph ID
execution_id: Execution ID to wait for
timeout_seconds: Max seconds to wait
Returns:
Tuple of (final_status, outputs_dict_or_None). The status is FAILED
when the initial lookup finds no execution, and QUEUED when the final
re-read finds none.
"""
# First check current status - maybe it's already done
execution = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=execution_id,
include_node_executions=False,
)
if not execution:
return ExecutionStatus.FAILED, None
# If already in terminal state, return immediately
if execution.status in TERMINAL_STATUSES:
logger.debug(
f"Execution {execution_id} already in terminal state: {execution.status}"
)
return execution.status, execution.outputs
logger.info(
f"Waiting up to {timeout_seconds}s for execution {execution_id} "
f"(current status: {execution.status})"
)
# Subscribe to execution updates via Redis pubsub
event_bus = AsyncRedisExecutionEventBus()
channel_key = f"{user_id}/{graph_id}/{execution_id}"
try:
deadline = asyncio.get_event_loop().time() + timeout_seconds
# NOTE: the deadline is only evaluated inside the loop body, i.e. after an
# event has been received. While the channel is silent this loop cannot
# time out.
async for event in event_bus.listen_events(channel_key):
# Check if we've exceeded the timeout
remaining = deadline - asyncio.get_event_loop().time()
if remaining <= 0:
logger.info(f"Timeout waiting for execution {execution_id}")
break
# Only process GraphExecutionEvents (not NodeExecutionEvents)
if isinstance(event, GraphExecutionEvent):
logger.debug(f"Received execution update: {event.status}")
if event.status in TERMINAL_STATUSES:
# Fetch full execution with outputs
final_exec = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=execution_id,
include_node_executions=False,
)
if final_exec:
return final_exec.status, final_exec.outputs
# Lookup returned nothing: report the status carried by the event,
# without outputs.
return event.status, None
# NOTE(review): nothing visible in the try body applies a timeout, so this
# handler looks unreachable unless listen_events raises it itself - confirm.
except asyncio.TimeoutError:
logger.info(f"Timeout waiting for execution {execution_id}")
except Exception as e:
logger.error(f"Error waiting for execution: {e}", exc_info=True)
finally:
# Clean up pubsub connection
# Best effort: reads the bus's private _pubsub attribute and ignores any
# failure raised while closing it.
try:
if hasattr(event_bus, "_pubsub") and event_bus._pubsub:
await event_bus._pubsub.close()
except Exception:
pass
# Return current state on timeout, after a logged error, or if the event
# stream ends without a terminal status
execution = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=execution_id,
include_node_executions=False,
)
if execution:
return execution.status, execution.outputs
# No execution found on the final re-read: reported as QUEUED with no outputs.
return ExecutionStatus.QUEUED, None
async def _schedule_agent(
self,
user_id: str,