mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-10 07:38:04 -05:00
- Backend: Populate required array in credentials schema based on credentials_optional metadata - Backend: Skip input_default for optional credentials to prevent stale credential validation errors - Backend: Pass nodes_to_skip through execution chain to blocks - Backend: Filter Smart Decision Maker tools based on nodes_to_skip to prevent LLM from selecting unavailable tools - Frontend: Filter incomplete credentials before sending run requests - Frontend: Show (optional) indicator for non-required credentials - Frontend: Only show optional toggle in builder canvas, not run dialogs Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1044 lines
36 KiB
Python
1044 lines
36 KiB
Python
import asyncio
|
|
import logging
|
|
import threading
|
|
import time
|
|
from collections import defaultdict
|
|
from concurrent.futures import Future
|
|
from typing import Mapping, Optional, cast
|
|
|
|
from pydantic import BaseModel, JsonValue, ValidationError
|
|
|
|
from backend.data import execution as execution_db
|
|
from backend.data import graph as graph_db
|
|
from backend.data import user as user_db
|
|
from backend.data.block import (
|
|
Block,
|
|
BlockCostType,
|
|
BlockInput,
|
|
BlockOutputEntry,
|
|
BlockType,
|
|
get_block,
|
|
)
|
|
from backend.data.block_cost_config import BLOCK_COSTS
|
|
from backend.data.db import prisma
|
|
|
|
# Import dynamic field utilities from centralized location
|
|
from backend.data.dynamic_fields import merge_execution_input
|
|
from backend.data.execution import (
|
|
ExecutionContext,
|
|
ExecutionStatus,
|
|
GraphExecutionMeta,
|
|
GraphExecutionStats,
|
|
GraphExecutionWithNodes,
|
|
NodesInputMasks,
|
|
get_graph_execution,
|
|
)
|
|
from backend.data.graph import GraphModel, Node
|
|
from backend.data.model import USER_TIMEZONE_NOT_SET, CredentialsMetaInput
|
|
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
|
from backend.util.clients import (
|
|
get_async_execution_event_bus,
|
|
get_async_execution_queue,
|
|
get_database_manager_async_client,
|
|
get_integration_credentials_store,
|
|
)
|
|
from backend.util.exceptions import (
|
|
GraphNotFoundError,
|
|
GraphValidationError,
|
|
NotFoundError,
|
|
)
|
|
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
|
|
from backend.util.settings import Config
|
|
from backend.util.type import convert
|
|
|
|
config = Config()
|
|
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[GraphExecutorUtil]")
|
|
|
|
# ============ Resource Helpers ============ #
|
|
|
|
|
|
class LogMetadata(TruncatedLogger):
    """Truncating logger that tags every record with execution metadata.

    The user/graph/node identifiers are attached as structured metadata (used
    when structured logging is enabled) and otherwise encoded into a
    human-readable log prefix.
    """

    def __init__(
        self,
        logger: logging.Logger,
        user_id: str,
        graph_eid: str,
        graph_id: str,
        node_eid: str,
        node_id: str,
        block_name: str,
        max_length: int = 1000,
    ):
        metadata = {
            "component": "ExecutionManager",
            "user_id": user_id,
            "graph_eid": graph_eid,
            "graph_id": graph_id,
            "node_eid": node_eid,
            "node_id": node_id,
            "block_name": block_name,
        }
        # With structured logging the metadata dict carries the context, so a
        # short prefix suffices; otherwise pack the context into the prefix.
        # Fix: removed the stray "]" that previously followed {node_id} and
        # left the bracketed prefix unbalanced.
        prefix = (
            "[ExecutionManager]"
            if is_structured_logging_enabled()
            else f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}|geid:{graph_eid}|neid:{node_eid}|{block_name}]"  # noqa
        )
        super().__init__(
            logger,
            max_length=max_length,
            prefix=prefix,
            metadata=metadata,
        )
|
|
|
|
|
|
# ============ Execution Cost Helpers ============ #
|
|
|
|
|
|
def execution_usage_cost(execution_count: int) -> tuple[int, int]:
    """
    Calculate the cost of executing a graph based on the current number of node executions.

    Args:
        execution_count: Number of node executions

    Returns:
        Tuple of cost amount and the number of execution count that is included in the cost.
    """
    threshold = config.execution_cost_count_threshold
    # A flat fee is charged each time the running execution count lands
    # exactly on a multiple of the configured threshold.
    if execution_count % threshold == 0:
        return config.execution_cost_per_threshold, threshold
    return 0, threshold
|
|
|
|
|
|
def block_usage_cost(
    block: Block,
    input_data: BlockInput,
    data_size: float = 0,
    run_time: float = 0,
) -> tuple[int, BlockInput]:
    """
    Calculate the cost of using a block based on the input data and the block type.

    Args:
        block: Block object
        input_data: Input data for the block
        data_size: Size of the input data in bytes
        run_time: Execution time of the block in seconds

    Returns:
        Tuple of cost amount and cost filter
    """
    configured_costs = BLOCK_COSTS.get(type(block))
    if not configured_costs:
        # No pricing configured for this block type: it runs for free.
        return 0, {}

    for cost in configured_costs:
        # A cost entry only applies when its filter matches the input data.
        if not _is_cost_filter_match(cost.cost_filter, input_data):
            continue

        if cost.cost_type == BlockCostType.RUN:
            # Flat fee per run.
            return cost.cost_amount, cost.cost_filter
        if cost.cost_type == BlockCostType.SECOND:
            # Billed per second of runtime, truncated to whole credits.
            return int(run_time * cost.cost_amount), cost.cost_filter
        if cost.cost_type == BlockCostType.BYTE:
            # Billed per byte of input data, truncated to whole credits.
            return int(data_size * cost.cost_amount), cost.cost_filter

    # No matching cost entry: free.
    return 0, {}
|
|
|
|
|
|
def _is_cost_filter_match(cost_filter: BlockInput, input_data: BlockInput) -> bool:
    """
    Filter rules:
    - If cost_filter is an object, then check if cost_filter is the subset of input_data
    - Otherwise, check if cost_filter is equal to input_data.
    - Undefined, null, and empty string are considered as equal.
    """
    # Non-dict filters (or non-dict data) are compared by plain equality.
    if not (isinstance(cost_filter, dict) and isinstance(input_data, dict)):
        return cost_filter == input_data

    for key, expected in cost_filter.items():
        actual = input_data.get(key)
        # Both sides falsy (missing, None, "", 0) counts as a match.
        if not actual and not expected:
            continue
        # Otherwise the value must be present and recursively match.
        if actual and _is_cost_filter_match(expected, actual):
            continue
        return False
    return True
|
|
|
|
|
|
# ============ Execution Input Helpers ============ #
|
|
|
|
# Dynamic field utilities are now imported from backend.data.dynamic_fields
|
|
|
|
|
|
def validate_exec(
    node: Node,
    data: BlockInput,
    resolve_input: bool = True,
) -> tuple[BlockInput | None, str]:
    """
    Validate the input data for a node execution.

    Args:
        node: The node to execute.
        data: The input data for the node execution.
        resolve_input: Whether to resolve dynamic pins into dict/list/object.

    Returns:
        A tuple of the validated data and the block name.
        If the data is invalid, the first element will be None, and the second element
        will be an error message.
        If the data is valid, the first element will be the resolved input data, and
        the second element will be the block name.
    """
    node_block = get_block(node.block_id)
    if not node_block:
        return None, f"Block for {node.block_id} not found."
    schema = node_block.input_schema

    # Input data (without default values) should contain all required fields.
    # NOTE: this link check must run BEFORE defaults are merged in, so that
    # defaults can't mask an unpopulated inbound link.
    error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
    if missing_links := schema.get_missing_links(data, node.input_links):
        return None, f"{error_prefix} unpopulated links {missing_links}"

    # Merge input data with default values and resolve dynamic dict/list/object pins.
    # Explicitly provided values take precedence over node defaults.
    input_default = schema.get_input_defaults(node.input_default)
    data = {**input_default, **data}
    if resolve_input:
        data = merge_execution_input(data)

    # Convert non-matching data types to the expected input schema.
    # NOTE(review): `convert` presumably raises or coerces on failure —
    # conversion errors surface to the caller, not as a (None, msg) tuple.
    for name, data_type in schema.__annotations__.items():
        value = data.get(name)
        if (value is not None) and (type(value) is not data_type):
            data[name] = convert(value, data_type)

    # Input data post-merge should contain all required fields from the schema.
    if missing_input := schema.get_missing_input(data):
        return None, f"{error_prefix} missing input {missing_input}"

    # Last validation: Validate the input values against the schema.
    if error := schema.get_mismatch_error(data):
        error_message = f"{error_prefix} {error}"
        logger.warning(error_message)
        return None, error_message

    return data, node_block.name
|
|
|
|
|
|
async def _validate_node_input_credentials(
    graph: GraphModel,
    user_id: str,
    nodes_input_masks: Optional[NodesInputMasks] = None,
) -> tuple[dict[str, dict[str, str]], set[str]]:
    """
    Checks all credentials for all nodes of the graph and returns structured errors
    and a set of nodes that should be skipped due to optional missing credentials.

    Args:
        graph: The graph whose nodes' credentials fields are validated.
        user_id: Owner of the credentials store to resolve credential IDs against.
        nodes_input_masks: Per-node input overrides; these take precedence over
            each node's `input_default` when looking up credential values.

    Returns:
        tuple[
            dict[node_id, dict[field_name, error_message]]: Credential validation errors per node,
            set[node_id]: Nodes that should be skipped (optional credentials not configured)
        ]
    """
    credential_errors: dict[str, dict[str, str]] = defaultdict(dict)
    nodes_to_skip: set[str] = set()

    for node in graph.nodes:
        block = node.block

        # Find any fields of type CredentialsMetaInput
        credentials_fields = block.input_schema.get_credentials_fields()
        if not credentials_fields:
            continue

        # Track if any credential field is missing for this node
        has_missing_credentials = False

        for field_name, credentials_meta_type in credentials_fields.items():
            try:
                # Check nodes_input_masks first, then input_default
                field_value = None
                if (
                    nodes_input_masks
                    and (node_input_mask := nodes_input_masks.get(node.id))
                    and field_name in node_input_mask
                ):
                    field_value = node_input_mask[field_name]
                elif field_name in node.input_default:
                    # For optional credentials, don't use input_default - treat as missing
                    # This prevents stale credential IDs from failing validation
                    if node.credentials_optional:
                        field_value = None
                    else:
                        field_value = node.input_default[field_name]

                # Check if credentials are missing (None, empty, or not present)
                if field_value is None or (
                    isinstance(field_value, dict) and not field_value.get("id")
                ):
                    has_missing_credentials = True
                    # If node has credentials_optional flag, mark for skipping instead of error
                    if node.credentials_optional:
                        continue  # Don't add error, will be marked for skip after loop
                    else:
                        credential_errors[node.id][
                            field_name
                        ] = "These credentials are required"
                        continue

                # Parse the raw value into a typed credentials meta object.
                credentials_meta = credentials_meta_type.model_validate(field_value)

            except ValidationError as e:
                # Validation error means credentials were provided but invalid
                # This should always be an error, even if optional
                credential_errors[node.id][field_name] = f"Invalid credentials: {e}"
                continue

            try:
                # Fetch the corresponding Credentials and perform sanity checks
                credentials = await get_integration_credentials_store().get_creds_by_id(
                    user_id, credentials_meta.id
                )
            except Exception as e:
                # Handle any errors fetching credentials
                # If credentials were explicitly configured but unavailable, it's an error
                credential_errors[node.id][
                    field_name
                ] = f"Credentials not available: {e}"
                continue

            if not credentials:
                credential_errors[node.id][
                    field_name
                ] = f"Unknown credentials #{credentials_meta.id}"
                continue

            # Sanity check: the stored credentials must agree with what the
            # node's meta input claims about provider and credential type.
            if (
                credentials.provider != credentials_meta.provider
                or credentials.type != credentials_meta.type
            ):
                logger.warning(
                    f"Invalid credentials #{credentials.id} for node #{node.id}: "
                    "type/provider mismatch: "
                    f"{credentials_meta.type}<>{credentials.type};"
                    f"{credentials_meta.provider}<>{credentials.provider}"
                )
                credential_errors[node.id][
                    field_name
                ] = "Invalid credentials: type/provider mismatch"
                continue

        # If node has optional credentials and any are missing, mark for skipping
        # But only if there are no other errors for this node
        # (membership test on the defaultdict does not create an entry)
        if (
            has_missing_credentials
            and node.credentials_optional
            and node.id not in credential_errors
        ):
            nodes_to_skip.add(node.id)
            logger.info(
                f"Node #{node.id} will be skipped: optional credentials not configured"
            )

    return credential_errors, nodes_to_skip
|
|
|
|
|
|
def make_node_credentials_input_map(
    graph: GraphModel,
    graph_credentials_input: Mapping[str, CredentialsMetaInput],
) -> NodesInputMasks:
    """
    Maps credentials for an execution to the correct nodes.

    Params:
        graph: The graph to be executed.
        graph_credentials_input: A (graph_input_name, credentials_meta) map.

    Returns:
        dict[node_id, dict[field_name, CredentialsMetaRaw]]: Node credentials input map.
    """
    node_credentials: dict[str, dict[str, JsonValue]] = {}

    # Each aggregated graph-level credentials input covers one or more
    # compatible (node_id, field_name) pairs.
    for input_name, (_, node_fields) in graph.aggregate_credentials_inputs().items():
        # Best-effort map: skip inputs that weren't provided.
        if input_name not in graph_credentials_input:
            continue
        credentials_meta = graph_credentials_input[input_name]

        # Fan the provided credentials out to every compatible node field,
        # serializing a fresh dict per field.
        for node_id, field_name in node_fields:
            node_credentials.setdefault(node_id, {})[field_name] = (
                credentials_meta.model_dump(exclude_none=True)
            )

    return node_credentials
|
|
|
|
|
|
async def validate_graph_with_credentials(
    graph: GraphModel,
    user_id: str,
    nodes_input_masks: Optional[NodesInputMasks] = None,
) -> tuple[Mapping[str, Mapping[str, str]], set[str]]:
    """
    Validate graph including credentials and return structured errors per node,
    along with a set of nodes that should be skipped due to optional missing credentials.

    Returns:
        tuple[
            dict[node_id, dict[field_name, error_message]]: Validation errors per node,
            set[node_id]: Nodes that should be skipped (optional credentials not configured)
        ]
    """
    # Structural/input validation errors, keyed by node ID.
    errors_by_node = GraphModel.validate_graph_get_errors(
        graph, for_run=True, nodes_input_masks=nodes_input_masks
    )

    # Credential input/availability/validation errors plus the set of nodes
    # whose optional credentials are simply not configured.
    credential_errors, nodes_to_skip = await _validate_node_input_credentials(
        graph, user_id, nodes_input_masks
    )

    # Fold credential errors into the structural errors per node/field.
    for node_id, field_errors in credential_errors.items():
        errors_by_node.setdefault(node_id, {}).update(field_errors)

    return errors_by_node, nodes_to_skip
|
|
|
|
|
|
async def _construct_starting_node_execution_input(
    graph: GraphModel,
    user_id: str,
    graph_inputs: BlockInput,
    nodes_input_masks: Optional[NodesInputMasks] = None,
) -> tuple[list[tuple[str, BlockInput]], set[str]]:
    """
    Validates and prepares the input data for executing a graph.
    This function checks the graph for starting nodes, validates the input data
    against the schema, and resolves dynamic input pins into a single list,
    dictionary, or object.

    Args:
        graph (GraphModel): The graph model to execute.
        user_id (str): The ID of the user executing the graph.
        graph_inputs (BlockInput): The input data for the graph execution.
        nodes_input_masks: `dict[node_id, dict[input_name, value]]` per-node overrides.

    Returns:
        tuple[
            list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID
                and the corresponding input data for that node.
            set[str]: Node IDs that should be skipped (optional credentials not configured)
        ]

    Raises:
        GraphValidationError: If graph/credentials validation produced any errors.
        ValueError: If a starting node's input is invalid or no starting nodes exist.
    """
    # Use new validation function that includes credentials
    validation_errors, nodes_to_skip = await validate_graph_with_credentials(
        graph, user_id, nodes_input_masks
    )
    n_error_nodes = len(validation_errors)
    n_errors = sum(len(errors) for errors in validation_errors.values())
    if validation_errors:
        raise GraphValidationError(
            f"Graph validation failed: {n_errors} issues on {n_error_nodes} nodes",
            node_errors=validation_errors,
        )

    nodes_input = []
    for node in graph.starting_nodes:
        input_data = {}
        block = node.block

        # Note block should never be executed.
        if block.block_type == BlockType.NOTE:
            continue

        # Extract request input data, and assign it to the input pin.
        if block.block_type == BlockType.INPUT:
            input_name = cast(str | None, node.input_default.get("name"))
            if input_name and input_name in graph_inputs:
                input_data = {"value": graph_inputs[input_name]}

        # Apply node input overrides (masks win over the request-derived value).
        if nodes_input_masks and (node_input_mask := nodes_input_masks.get(node.id)):
            input_data.update(node_input_mask)

        # Validate against the block schema; any invalid starting input aborts
        # the whole construction.
        input_data, error = validate_exec(node, input_data)
        if input_data is None:
            raise ValueError(error)
        else:
            nodes_input.append((node.id, input_data))

    if not nodes_input:
        raise ValueError(
            "No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
        )

    return nodes_input, nodes_to_skip
|
|
|
|
|
|
async def validate_and_construct_node_execution_input(
    graph_id: str,
    user_id: str,
    graph_inputs: BlockInput,
    graph_version: Optional[int] = None,
    graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
    nodes_input_masks: Optional[NodesInputMasks] = None,
    is_sub_graph: bool = False,
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks, set[str]]:
    """
    Public wrapper that handles graph fetching, credential mapping, and validation+construction.
    This centralizes the logic used by both scheduler validation and actual execution.

    Args:
        graph_id: The ID of the graph to validate/construct.
        user_id: The ID of the user.
        graph_inputs: The input data for the graph execution.
        graph_version: The version of the graph to use.
        graph_credentials_inputs: Credentials inputs to use.
        nodes_input_masks: Node inputs to use.
        is_sub_graph: Whether this execution is a nested (sub-graph) execution;
            affects the permission check.

    Returns:
        GraphModel: Full graph object for the given `graph_id`.
        list[tuple[node_id, BlockInput]]: Starting node IDs with corresponding inputs.
        dict[str, BlockInput]: Node input masks including all passed-in credentials.
        set[str]: Node IDs that should be skipped (optional credentials not configured).

    Raises:
        GraphNotFoundError: If the graph is not found.
        GraphValidationError: If the graph has validation issues.
        ValueError: If there are other validation errors.
    """
    # Prefer direct DB access when this process holds a Prisma connection;
    # otherwise go through the database manager service client.
    if prisma.is_connected():
        gdb = graph_db
    else:
        gdb = get_database_manager_async_client()

    graph: GraphModel | None = await gdb.get_graph(
        graph_id=graph_id,
        user_id=user_id,
        version=graph_version,
        include_subgraphs=True,
        # Execution/access permission is checked by validate_graph_execution_permissions
        skip_access_check=True,
    )
    if not graph:
        raise GraphNotFoundError(f"Graph #{graph_id} not found.")

    # Validate that the user has permission to execute this graph
    # This checks both library membership and execution permissions,
    # raising specific exceptions for appropriate error handling.
    await gdb.validate_graph_execution_permissions(
        user_id=user_id,
        graph_id=graph.id,
        graph_version=graph.version,
        is_sub_graph=is_sub_graph,
    )

    # Credentials mapped from graph-level inputs are merged with the explicit
    # per-node masks; explicit masks win on conflicts (second argument).
    nodes_input_masks = _merge_nodes_input_masks(
        (
            make_node_credentials_input_map(graph, graph_credentials_inputs)
            if graph_credentials_inputs
            else {}
        ),
        nodes_input_masks or {},
    )

    starting_nodes_input, nodes_to_skip = (
        await _construct_starting_node_execution_input(
            graph=graph,
            user_id=user_id,
            graph_inputs=graph_inputs,
            nodes_input_masks=nodes_input_masks,
        )
    )

    return graph, starting_nodes_input, nodes_input_masks, nodes_to_skip
|
|
|
|
|
|
def _merge_nodes_input_masks(
    overrides_map_1: NodesInputMasks,
    overrides_map_2: NodesInputMasks,
) -> NodesInputMasks:
    """Perform a per-node merge of input overrides"""
    merged = dict(overrides_map_1)
    for node_id, overrides in overrides_map_2.items():
        base = merged.get(node_id)
        # Per-field merge when both maps cover the same node; the second
        # map's fields win on conflicts.
        merged[node_id] = {**base, **overrides} if base is not None else overrides
    return merged
|
|
|
|
|
|
# ============ Execution Queue Helpers ============ #

# DIRECT exchange: run requests are routed by key to the single run queue.
GRAPH_EXECUTION_EXCHANGE = Exchange(
    name="graph_execution",
    type=ExchangeType.DIRECT,
    durable=True,
    auto_delete=False,
)
GRAPH_EXECUTION_QUEUE_NAME = "graph_execution_queue"
GRAPH_EXECUTION_ROUTING_KEY = "graph_execution.run"

# FANOUT exchange: cancel events are broadcast to every bound consumer.
GRAPH_EXECUTION_CANCEL_EXCHANGE = Exchange(
    name="graph_execution_cancel",
    type=ExchangeType.FANOUT,
    durable=True,
    auto_delete=True,
)
GRAPH_EXECUTION_CANCEL_QUEUE_NAME = "graph_execution_cancel_queue"

# Graceful shutdown timeout constants
# Agent executions can run for up to 1 day, so we need a graceful shutdown period
# that allows long-running executions to complete naturally
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS = 24 * 60 * 60  # 1 day to complete active executions
|
|
|
|
|
|
def create_execution_queue_config() -> RabbitMQConfig:
    """
    Define two exchanges and queues:
    - 'graph_execution' (DIRECT) for run tasks.
    - 'graph_execution_cancel' (FANOUT) for cancel requests.
    """
    # x-consumer-timeout (1 week)
    # Problem: Default 30-minute consumer timeout kills long-running graph executions
    # Original error: "Consumer acknowledgement timed out after 1800000 ms (30 minutes)"
    # Solution: Disable consumer timeout entirely - let graphs run indefinitely
    # Safety: Heartbeat mechanism now handles dead consumer detection instead
    # Use case: Graph executions that take hours to complete (AI model training, etc.)
    run_queue_arguments = {
        "x-consumer-timeout": GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS * 1000,
    }
    return RabbitMQConfig(
        vhost="/",
        exchanges=[GRAPH_EXECUTION_EXCHANGE, GRAPH_EXECUTION_CANCEL_EXCHANGE],
        queues=[
            Queue(
                name=GRAPH_EXECUTION_QUEUE_NAME,
                exchange=GRAPH_EXECUTION_EXCHANGE,
                routing_key=GRAPH_EXECUTION_ROUTING_KEY,
                durable=True,
                auto_delete=False,
                arguments=run_queue_arguments,
            ),
            Queue(
                name=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
                exchange=GRAPH_EXECUTION_CANCEL_EXCHANGE,
                routing_key="",  # not used for FANOUT
                durable=True,
                auto_delete=False,
            ),
        ],
    )
|
|
|
|
|
|
class CancelExecutionEvent(BaseModel):
    """Message published on the cancel exchange to request stopping an execution."""

    # ID of the graph execution to cancel.
    graph_exec_id: str
|
|
|
|
|
async def _get_child_executions(parent_exec_id: str) -> list["GraphExecutionMeta"]:
    """
    Get all child executions of a parent execution using the execution_db pattern.

    Args:
        parent_exec_id: Parent graph execution ID

    Returns:
        List of child graph executions
    """
    # Fix: removed a redundant function-local `from backend.data.db import
    # prisma` — `prisma` is already imported at module level.
    # Prefer direct DB access when connected; otherwise use the database
    # manager service client.
    if prisma.is_connected():
        edb = execution_db
    else:
        edb = get_database_manager_async_client()

    return await edb.get_child_graph_executions(parent_exec_id)
|
|
|
|
|
|
async def stop_graph_execution(
    user_id: str,
    graph_exec_id: str,
    wait_timeout: float = 15.0,
    cascade: bool = True,
):
    """
    Stop a graph execution and optionally all its child executions.

    Mechanism:
    1. Set the cancel event for this execution
    2. If cascade=True, recursively stop all child executions
    3. Graph executor's cancel handler thread detects the event, terminates workers,
       reinitializes worker pool, and returns.
    4. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.

    Args:
        user_id: User ID who owns the execution
        graph_exec_id: Graph execution ID to stop
        wait_timeout: Maximum time to wait for execution to stop (seconds)
        cascade: If True, recursively stop all child executions

    Raises:
        NotFoundError: If the execution does not exist (while polling).
        TimeoutError: If the execution has not stopped within `wait_timeout`.
    """
    queue_client = await get_async_execution_queue()
    db = execution_db if prisma.is_connected() else get_database_manager_async_client()

    # First, find and stop all child executions if cascading
    if cascade:
        children = await _get_child_executions(graph_exec_id)
        logger.info(
            f"Stopping {len(children)} child executions of execution {graph_exec_id}"
        )

        # Stop all children in parallel (recursively, with cascading enabled)
        if children:
            await asyncio.gather(
                *[
                    stop_graph_execution(
                        user_id=user_id,
                        graph_exec_id=child.id,
                        wait_timeout=wait_timeout,
                        cascade=True,  # Recursively cascade to grandchildren
                    )
                    for child in children
                ],
                return_exceptions=True,  # Don't fail parent stop if child stop fails
            )

    # Now stop this execution: broadcast the cancel event on the FANOUT
    # exchange so every executor instance sees it.
    await queue_client.publish_message(
        routing_key="",
        message=CancelExecutionEvent(graph_exec_id=graph_exec_id).model_dump_json(),
        exchange=GRAPH_EXECUTION_CANCEL_EXCHANGE,
    )

    # Fire-and-forget mode: caller doesn't want to wait for confirmation.
    if not wait_timeout:
        return

    # Poll the execution status until it reaches a terminal state or times out.
    start_time = time.time()
    while time.time() - start_time < wait_timeout:
        graph_exec = await db.get_graph_execution_meta(
            execution_id=graph_exec_id, user_id=user_id
        )

        if not graph_exec:
            raise NotFoundError(f"Graph execution #{graph_exec_id} not found.")

        if graph_exec.status in [
            ExecutionStatus.TERMINATED,
            ExecutionStatus.COMPLETED,
            ExecutionStatus.FAILED,
        ]:
            # If graph execution is terminated/completed/failed, cancellation is complete
            await get_async_execution_event_bus().publish(graph_exec)
            return

        if graph_exec.status in [
            ExecutionStatus.QUEUED,
            ExecutionStatus.INCOMPLETE,
        ]:
            # If the graph is still on the queue, we can prevent them from being executed
            # by setting the status to TERMINATED.
            graph_exec.status = ExecutionStatus.TERMINATED

            await asyncio.gather(
                # Update graph execution status
                db.update_graph_execution_stats(
                    graph_exec_id=graph_exec.id,
                    status=ExecutionStatus.TERMINATED,
                ),
                # Publish graph execution event
                get_async_execution_event_bus().publish(graph_exec),
            )
            return

        # NOTE(review): only RUNNING sleeps between polls; any other
        # unhandled status would busy-loop without delay — confirm no such
        # status is reachable here.
        if graph_exec.status == ExecutionStatus.RUNNING:
            await asyncio.sleep(0.1)

    raise TimeoutError(
        f"Graph execution #{graph_exec_id} will need to take longer than {wait_timeout} seconds to stop. "
        f"You can check the status of the execution in the UI or try again later."
    )
|
|
|
|
|
|
async def add_graph_execution(
    graph_id: str,
    user_id: str,
    inputs: Optional[BlockInput] = None,
    preset_id: Optional[str] = None,
    graph_version: Optional[int] = None,
    graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
    nodes_input_masks: Optional[NodesInputMasks] = None,
    execution_context: Optional[ExecutionContext] = None,
    graph_exec_id: Optional[str] = None,
) -> GraphExecutionWithNodes:
    """
    Adds a graph execution to the queue and returns the execution entry.

    Args:
        graph_id: The ID of the graph to execute.
        user_id: The ID of the user executing the graph.
        inputs: The input data for the graph execution.
        preset_id: The ID of the preset to use.
        graph_version: The version of the graph to execute.
        graph_credentials_inputs: Credentials inputs to use in the execution.
            Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.
        nodes_input_masks: Node inputs to use in the execution.
        execution_context: Execution context; its `parent_execution_id` marks
            this as a nested (sub-graph) execution.
        graph_exec_id: If provided, resume this existing execution instead of creating a new one.
    Returns:
        GraphExecutionWithNodes: The entry for the graph execution.
    Raises:
        ValueError: If the graph is not found or if there are validation errors.
        NotFoundError: If graph_exec_id is provided but execution is not found.
    """
    # Prefer direct DB modules when this process holds a Prisma connection;
    # otherwise all three handles go through the database manager client.
    if prisma.is_connected():
        edb = execution_db
        udb = user_db
        gdb = graph_db
    else:
        edb = udb = gdb = get_database_manager_async_client()

    # Get or create the graph execution
    if graph_exec_id:
        # Resume existing execution
        graph_exec = await get_graph_execution(
            user_id=user_id,
            execution_id=graph_exec_id,
            include_node_executions=True,
        )

        if not graph_exec:
            raise NotFoundError(f"Graph execution #{graph_exec_id} not found.")

        # Use existing execution's compiled input masks
        compiled_nodes_input_masks = graph_exec.nodes_input_masks or {}
        # For resumed executions, nodes_to_skip was already determined at creation time
        # TODO: Consider storing nodes_to_skip in DB if we need to preserve it across resumes
        nodes_to_skip: set[str] = set()

        logger.info(f"Resuming graph execution #{graph_exec.id} for graph #{graph_id}")
    else:
        parent_exec_id = (
            execution_context.parent_execution_id if execution_context else None
        )

        # Create new execution: validates the graph (incl. credentials) and
        # builds the starting-node inputs.
        graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip = (
            await validate_and_construct_node_execution_input(
                graph_id=graph_id,
                user_id=user_id,
                graph_inputs=inputs or {},
                graph_version=graph_version,
                graph_credentials_inputs=graph_credentials_inputs,
                nodes_input_masks=nodes_input_masks,
                is_sub_graph=parent_exec_id is not None,
            )
        )

        graph_exec = await edb.create_graph_execution(
            user_id=user_id,
            graph_id=graph_id,
            graph_version=graph.version,
            inputs=inputs or {},
            credential_inputs=graph_credentials_inputs,
            nodes_input_masks=nodes_input_masks,
            starting_nodes_input=starting_nodes_input,
            preset_id=preset_id,
            parent_graph_exec_id=parent_exec_id,
        )

        logger.info(
            f"Created graph execution #{graph_exec.id} for graph "
            f"#{graph_id} with {len(starting_nodes_input)} starting nodes"
        )

    # Generate execution context if it's not provided
    if execution_context is None:
        user = await udb.get_user_by_id(user_id)
        settings = await gdb.get_graph_settings(user_id=user_id, graph_id=graph_id)

        execution_context = ExecutionContext(
            # Safe mode defaults to True when the graph has no explicit setting.
            safe_mode=(
                settings.human_in_the_loop_safe_mode
                if settings.human_in_the_loop_safe_mode is not None
                else True
            ),
            user_timezone=(
                user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
            ),
            root_execution_id=graph_exec.id,
        )

    try:
        graph_exec_entry = graph_exec.to_graph_execution_entry(
            compiled_nodes_input_masks=compiled_nodes_input_masks,
            nodes_to_skip=nodes_to_skip,
            execution_context=execution_context,
        )
        logger.info(f"Publishing execution {graph_exec.id} to execution queue")

        exec_queue = await get_async_execution_queue()
        await exec_queue.publish_message(
            routing_key=GRAPH_EXECUTION_ROUTING_KEY,
            message=graph_exec_entry.model_dump_json(),
            exchange=GRAPH_EXECUTION_EXCHANGE,
        )
        logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")

        # Only mark QUEUED after the message is actually on the queue.
        graph_exec.status = ExecutionStatus.QUEUED
        await edb.update_graph_execution_stats(
            graph_exec_id=graph_exec.id,
            status=graph_exec.status,
        )
        await get_async_execution_event_bus().publish(graph_exec)

        return graph_exec
    except BaseException as e:
        # BaseException: also compensate on cancellation/system exit, then
        # re-raise — the execution record must not be left dangling as queued.
        err = str(e) or type(e).__name__
        if not graph_exec:
            logger.error(f"Unable to execute graph #{graph_id} failed: {err}")
            raise

        logger.error(
            f"Unable to publish graph #{graph_id} exec #{graph_exec.id}: {err}"
        )
        # Mark all node executions and the graph execution as FAILED.
        await edb.update_node_execution_status_batch(
            [node_exec.node_exec_id for node_exec in graph_exec.node_executions],
            ExecutionStatus.FAILED,
        )
        await edb.update_graph_execution_stats(
            graph_exec_id=graph_exec.id,
            status=ExecutionStatus.FAILED,
            stats=GraphExecutionStats(error=err),
        )
        raise
|
|
|
|
|
|
# ============ Execution Output Helpers ============ #


class ExecutionOutputEntry(BaseModel):
    """A single output emitted by one node execution."""

    # The node that produced this output.
    node: Node
    # ID of the node execution that produced this output.
    node_exec_id: str
    # The output payload; see BlockOutputEntry for its shape.
    data: BlockOutputEntry
|
|
|
|
|
|
class NodeExecutionProgress:
    """Tracks in-flight node-execution Futures and their streamed outputs.

    Tasks are consumed in insertion order: `pop_output`/`is_done` always look
    at the oldest still-registered task first, pruning tasks that have finished
    and have no pending outputs. Output access is guarded by a lock so
    producers on other threads can append safely.
    """

    def __init__(self):
        # node_exec_id -> queued, not-yet-consumed outputs.
        self.output: dict[str, list[ExecutionOutputEntry]] = defaultdict(list)
        # node_exec_id -> Future running that node execution (insertion-ordered).
        self.tasks: dict[str, Future] = {}
        # Guards self.output against concurrent add/pop.
        self._lock = threading.Lock()

    def add_task(self, node_exec_id: str, task: Future):
        """Register the Future executing the given node execution."""
        self.tasks[node_exec_id] = task

    def add_output(self, output: ExecutionOutputEntry):
        """Queue an output produced by a node execution (thread-safe)."""
        with self._lock:
            self.output[output.node_exec_id].append(output)

    def pop_output(self) -> ExecutionOutputEntry | None:
        """Return the next output of the oldest pending task, or None.

        Finished tasks with no remaining outputs are pruned, after which the
        lookup recurses to the next task in line.
        """
        exec_id = self._next_exec()
        if not exec_id:
            return None

        if self._pop_done_task(exec_id):
            # That task was finished and drained; try the next one.
            return self.pop_output()

        with self._lock:
            if next_output := self.output[exec_id]:
                return next_output.pop(0)

        return None

    def is_done(self, wait_time: float = 0.0) -> bool:
        """Return True when no tasks remain, optionally waiting up to wait_time.

        Waits on the oldest pending task's result; a task failure is logged
        and treated as completion.
        """
        exec_id = self._next_exec()
        if not exec_id:
            return True

        if self._pop_done_task(exec_id):
            return self.is_done(wait_time)

        if wait_time <= 0:
            return False

        try:
            self.tasks[exec_id].result(wait_time)
        except TimeoutError:
            # Task still running after the wait; fall through and re-check.
            pass
        except Exception as e:
            logger.error(
                f"Task for exec ID {exec_id} failed with error: {e.__class__.__name__} {str(e)}"
            )
            pass
        # Re-evaluate without waiting again.
        return self.is_done(0)

    def stop(self) -> list[str]:
        """
        Stops all tasks and clears the output.
        This is useful for cleaning up when the execution is cancelled or terminated.
        Returns a list of execution IDs that were stopped.
        """
        cancelled_ids = []
        for task_id, task in self.tasks.items():
            if task.done():
                continue
            task.cancel()
            cancelled_ids.append(task_id)
        return cancelled_ids

    def wait_for_done(self, timeout: float = 5.0):
        """
        Wait for all cancelled tasks to complete cancellation.

        Args:
            timeout: Maximum time to wait for cancellation in seconds

        Raises:
            TimeoutError: If tasks are still pending when the timeout elapses.
        """
        start_time = time.time()

        while time.time() - start_time < timeout:
            # Drain remaining outputs so finished tasks can be pruned.
            while self.pop_output():
                pass

            if self.is_done():
                return

            time.sleep(0.1)  # Small delay to avoid busy waiting

        raise TimeoutError(
            f"Timeout waiting for cancellation of tasks: {list(self.tasks.keys())}"
        )

    def _pop_done_task(self, exec_id: str) -> bool:
        """Remove exec_id's task if it is finished and fully drained.

        Returns True when the task is gone (or was never registered).
        """
        task = self.tasks.get(exec_id)
        if not task:
            return True

        if not task.done():
            return False

        with self._lock:
            if self.output[exec_id]:
                # Keep the task around until its queued outputs are consumed.
                return False

            self.tasks.pop(exec_id)
            return True

    def _next_exec(self) -> str | None:
        """Return the oldest registered node_exec_id, or None when empty."""
        if not self.tasks:
            return None
        return next(iter(self.tasks.keys()))