AutoGPT/autogpt_platform/backend/backend/executor/utils.py
Nicholas Tindle 7668c17d9c feat(platform): add User Workspace for persistent CoPilot file storage (#11867)
Implements persistent User Workspace storage for CoPilot, enabling
blocks to save and retrieve files across sessions. Files are stored in
session-scoped virtual paths (`/sessions/{session_id}/`).

Fixes SECRT-1833

### Changes 🏗️

**Database & Storage:**
- Add `UserWorkspace` and `UserWorkspaceFile` Prisma models
- Implement `WorkspaceStorageBackend` abstraction (GCS for cloud, local
filesystem for self-hosted)
- Add `workspace_id` and `session_id` fields to `ExecutionContext`

**Backend API:**
- Add REST endpoints: `GET/POST /api/workspace/files`, `GET/DELETE
/api/workspace/files/{id}`, `GET /api/workspace/files/{id}/download`
- Add CoPilot tools: `list_workspace_files`, `read_workspace_file`,
`write_workspace_file`
- Integrate workspace storage into `store_media_file()`, which now returns
`workspace://file-id` references (see the sketch below)
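
For illustration, a minimal sketch of how a block might call the updated helper; the `file` parameter name and `return_format` value below are assumptions, not the exact API:

```python
# Illustrative sketch only; names marked "assumed" are not confirmed by this PR.
from backend.util.file import store_media_file  # assumed import path

async def persist_generated_image(image_b64: str, execution_context) -> str:
    # store_media_file() now requires execution_context and return_format;
    # inside a CoPilot session it persists the file to the user's workspace
    # and returns a stable reference instead of inline base64.
    return await store_media_file(
        file=image_b64,                       # assumed parameter name
        execution_context=execution_context,  # carries workspace_id / session_id
        return_format="workspace_ref",        # assumed flag value
    )  # -> "workspace://<file-id>"
```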

**Block Updates:**
- Refactor all file-handling blocks to use unified `ExecutionContext`
parameter
- Update media-generating blocks to persist outputs to workspace
(AIImageGenerator, AIImageCustomizer, FluxKontext, TalkingHead, FAL
video, Bannerbear, etc.)

**Frontend:**
- Render `workspace://` image references in chat via proxy endpoint
- Add "AI cannot see this image" overlay indicator

**CoPilot Context Mapping:**
- Session = Agent (graph_id) = Run (graph_exec_id)
- Files scoped to `/sessions/{session_id}/`
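
A minimal sketch of the path scoping this mapping implies (the helper below is hypothetical, not code from this PR):

```python
# Hypothetical helper: illustrates the scoping rule only.
def session_file_path(session_id: str, filename: str) -> str:
    # One CoPilot session maps to one agent (graph_id) and one run
    # (graph_exec_id); all of its files live under the session prefix.
    return f"/sessions/{session_id}/{filename}"

assert session_file_path("abc123", "image.png") == "/sessions/abc123/image.png"
```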

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [ ] I have tested my changes according to the test plan:
  - [ ] Create CoPilot session, generate image with AIImageGeneratorBlock
  - [ ] Verify image returns `workspace://file-id` (not base64)
  - [ ] Verify image renders in chat with visibility indicator
  - [ ] Verify workspace files persist across sessions
  - [ ] Test list/read/write workspace files via CoPilot tools
  - [ ] Test local storage backend for self-hosted deployments

#### For configuration changes:
- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

🤖 Generated with [Claude Code](https://claude.ai/code)

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Medium Risk**
> Introduces a new persistent file-storage surface area (DB tables,
storage backends, download API, and chat tools) and rewires
`store_media_file()`/block execution context across many blocks, so
regressions could impact file handling, access control, or storage
costs.
> 
> **Overview**
> Adds a **persistent per-user Workspace** (new
`UserWorkspace`/`UserWorkspaceFile` models plus `WorkspaceManager` +
`WorkspaceStorageBackend` with GCS/local implementations) and wires it
into the API via a new `/api/workspace/files/{file_id}/download` route
(including header-sanitized `Content-Disposition`) and shutdown
lifecycle hooks.
> 
> Extends `ExecutionContext` to carry execution identity +
`workspace_id`/`session_id`, updates executor tooling to clone
node-specific contexts, and updates `run_block` (CoPilot) to create a
session-scoped workspace and synthetic graph/run/node IDs.
> 
> Refactors `store_media_file()` to require `execution_context` +
`return_format` and to support `workspace://` references; migrates many
media/file-handling blocks and related tests to the new API and to
persist generated media as `workspace://...` (or fall back to data URIs
outside CoPilot), and adds CoPilot chat tools for
listing/reading/writing/deleting workspace files with safeguards against
context bloat.
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
6abc70f793. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2026-01-29 05:49:47 +00:00


import asyncio
import logging
import threading
import time
from collections import defaultdict
from concurrent.futures import Future
from typing import Mapping, Optional, cast
from pydantic import BaseModel, JsonValue, ValidationError
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data import human_review as human_review_db
from backend.data import onboarding as onboarding_db
from backend.data import user as user_db
from backend.data.block import (
Block,
BlockCostType,
BlockInput,
BlockOutputEntry,
BlockType,
get_block,
)
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.db import prisma
# Import dynamic field utilities from centralized location
from backend.data.dynamic_fields import merge_execution_input
from backend.data.execution import (
ExecutionContext,
ExecutionStatus,
GraphExecutionMeta,
GraphExecutionStats,
GraphExecutionWithNodes,
NodesInputMasks,
)
from backend.data.graph import GraphModel, Node
from backend.data.model import USER_TIMEZONE_NOT_SET, CredentialsMetaInput
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.util.clients import (
get_async_execution_event_bus,
get_async_execution_queue,
get_database_manager_async_client,
get_integration_credentials_store,
)
from backend.util.exceptions import (
GraphNotFoundError,
GraphValidationError,
NotFoundError,
)
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
from backend.util.settings import Config
from backend.util.type import convert
config = Config()
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[GraphExecutorUtil]")
# ============ Resource Helpers ============ #
class LogMetadata(TruncatedLogger):
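"""Truncating logger that attaches execution identity (user, graph, node, block) to every message, as structured metadata or as an inline prefix."""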
def __init__(
self,
logger: logging.Logger,
user_id: str,
graph_eid: str,
graph_id: str,
node_eid: str,
node_id: str,
block_name: str,
max_length: int = 1000,
):
metadata = {
"component": "ExecutionManager",
"user_id": user_id,
"graph_eid": graph_eid,
"graph_id": graph_id,
"node_eid": node_eid,
"node_id": node_id,
"block_name": block_name,
}
prefix = (
"[ExecutionManager]"
if is_structured_logging_enabled()
else f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}]|geid:{graph_eid}|neid:{node_eid}|{block_name}]" # noqa
)
super().__init__(
logger,
max_length=max_length,
prefix=prefix,
metadata=metadata,
)
# ============ Execution Cost Helpers ============ #
def execution_usage_cost(execution_count: int) -> tuple[int, int]:
"""
Calculate the cost of executing a graph based on the current number of node executions.
Args:
execution_count: Number of node executions
Returns:
Tuple of the cost amount and the execution-count threshold covered by that cost.
"""
return (
(
config.execution_cost_per_threshold
if execution_count % config.execution_cost_count_threshold == 0
else 0
),
config.execution_cost_count_threshold,
)
def block_usage_cost(
block: Block,
input_data: BlockInput,
data_size: float = 0,
run_time: float = 0,
) -> tuple[int, BlockInput]:
"""
Calculate the cost of using a block based on the input data and the block type.
Args:
block: Block object
input_data: Input data for the block
data_size: Size of the input data in bytes
run_time: Execution time of the block in seconds
Returns:
Tuple of cost amount and cost filter
"""
block_costs = BLOCK_COSTS.get(type(block))
if not block_costs:
return 0, {}
for block_cost in block_costs:
if not _is_cost_filter_match(block_cost.cost_filter, input_data):
continue
if block_cost.cost_type == BlockCostType.RUN:
return block_cost.cost_amount, block_cost.cost_filter
if block_cost.cost_type == BlockCostType.SECOND:
return (
int(run_time * block_cost.cost_amount),
block_cost.cost_filter,
)
if block_cost.cost_type == BlockCostType.BYTE:
return (
int(data_size * block_cost.cost_amount),
block_cost.cost_filter,
)
return 0, {}
def _is_cost_filter_match(cost_filter: BlockInput, input_data: BlockInput) -> bool:
"""
Filter rules:
- If cost_filter is an object, then check if cost_filter is the subset of input_data
- Otherwise, check if cost_filter is equal to input_data.
- Undefined, null, and empty string are considered as equal.
"""
if not isinstance(cost_filter, dict) or not isinstance(input_data, dict):
return cost_filter == input_data
return all(
(not input_data.get(k) and not v)
or (input_data.get(k) and _is_cost_filter_match(v, input_data[k]))
for k, v in cost_filter.items()
)
# ============ Execution Input Helpers ============ #
# Dynamic field utilities are now imported from backend.data.dynamic_fields
def validate_exec(
node: Node,
data: BlockInput,
resolve_input: bool = True,
) -> tuple[BlockInput | None, str]:
"""
Validate the input data for a node execution.
Args:
node: The node to execute.
data: The input data for the node execution.
resolve_input: Whether to resolve dynamic pins into dict/list/object.
Returns:
A tuple of the validated data and the block name.
If the data is invalid, the first element will be None, and the second element
will be an error message.
If the data is valid, the first element will be the resolved input data, and
the second element will be the block name.
"""
node_block = get_block(node.block_id)
if not node_block:
return None, f"Block for {node.block_id} not found."
schema = node_block.input_schema
# Input data (without default values) should contain all required fields.
error_prefix = f"Input data missing or mismatch for `{node_block.name}`:"
if missing_links := schema.get_missing_links(data, node.input_links):
return None, f"{error_prefix} unpopulated links {missing_links}"
# Merge input data with default values and resolve dynamic dict/list/object pins.
input_default = schema.get_input_defaults(node.input_default)
data = {**input_default, **data}
if resolve_input:
data = merge_execution_input(data)
# Convert non-matching data types to the expected input schema.
for name, data_type in schema.__annotations__.items():
value = data.get(name)
if (value is not None) and (type(value) is not data_type):
data[name] = convert(value, data_type)
# Input data post-merge should contain all required fields from the schema.
if missing_input := schema.get_missing_input(data):
return None, f"{error_prefix} missing input {missing_input}"
# Last validation: Validate the input values against the schema.
if error := schema.get_mismatch_error(data):
error_message = f"{error_prefix} {error}"
logger.warning(error_message)
return None, error_message
return data, node_block.name
async def _validate_node_input_credentials(
graph: GraphModel,
user_id: str,
nodes_input_masks: Optional[NodesInputMasks] = None,
) -> tuple[dict[str, dict[str, str]], set[str]]:
"""
Checks all credentials for all nodes of the graph and returns structured errors
and a set of nodes that should be skipped due to optional missing credentials.
Returns:
tuple[
dict[node_id, dict[field_name, error_message]]: Credential validation errors per node,
set[node_id]: Nodes that should be skipped (optional credentials not configured)
]
"""
credential_errors: dict[str, dict[str, str]] = defaultdict(dict)
nodes_to_skip: set[str] = set()
for node in graph.nodes:
block = node.block
# Find any fields of type CredentialsMetaInput
credentials_fields = block.input_schema.get_credentials_fields()
if not credentials_fields:
continue
# Track if any credential field is missing for this node
has_missing_credentials = False
for field_name, credentials_meta_type in credentials_fields.items():
try:
# Check nodes_input_masks first, then input_default
field_value = None
if (
nodes_input_masks
and (node_input_mask := nodes_input_masks.get(node.id))
and field_name in node_input_mask
):
field_value = node_input_mask[field_name]
elif field_name in node.input_default:
# For optional credentials, don't use input_default - treat as missing
# This prevents stale credential IDs from failing validation
if node.credentials_optional:
field_value = None
else:
field_value = node.input_default[field_name]
# Check if credentials are missing (None, empty, or not present)
if field_value is None or (
isinstance(field_value, dict) and not field_value.get("id")
):
has_missing_credentials = True
# If node has credentials_optional flag, mark for skipping instead of error
if node.credentials_optional:
continue # Don't add error, will be marked for skip after loop
else:
credential_errors[node.id][
field_name
] = "These credentials are required"
continue
credentials_meta = credentials_meta_type.model_validate(field_value)
except ValidationError as e:
# Validation error means credentials were provided but invalid
# This should always be an error, even if optional
credential_errors[node.id][field_name] = f"Invalid credentials: {e}"
continue
try:
# Fetch the corresponding Credentials and perform sanity checks
credentials = await get_integration_credentials_store().get_creds_by_id(
user_id, credentials_meta.id
)
except Exception as e:
# Handle any errors fetching credentials
# If credentials were explicitly configured but unavailable, it's an error
credential_errors[node.id][
field_name
] = f"Credentials not available: {e}"
continue
if not credentials:
credential_errors[node.id][
field_name
] = f"Unknown credentials #{credentials_meta.id}"
continue
if (
credentials.provider != credentials_meta.provider
or credentials.type != credentials_meta.type
):
logger.warning(
f"Invalid credentials #{credentials.id} for node #{node.id}: "
"type/provider mismatch: "
f"{credentials_meta.type}<>{credentials.type};"
f"{credentials_meta.provider}<>{credentials.provider}"
)
credential_errors[node.id][
field_name
] = "Invalid credentials: type/provider mismatch"
continue
# If node has optional credentials and any are missing, mark for skipping
# But only if there are no other errors for this node
if (
has_missing_credentials
and node.credentials_optional
and node.id not in credential_errors
):
nodes_to_skip.add(node.id)
logger.info(
f"Node #{node.id} will be skipped: optional credentials not configured"
)
return credential_errors, nodes_to_skip
def make_node_credentials_input_map(
graph: GraphModel,
graph_credentials_input: Mapping[str, CredentialsMetaInput],
) -> NodesInputMasks:
"""
Maps credentials for an execution to the correct nodes.
Params:
graph: The graph to be executed.
graph_credentials_input: A (graph_input_name, credentials_meta) map.
Returns:
dict[node_id, dict[field_name, CredentialsMetaRaw]]: Node credentials input map.
"""
result: dict[str, dict[str, JsonValue]] = {}
# Get aggregated credentials fields for the graph
graph_cred_inputs = graph.aggregate_credentials_inputs()
for graph_input_name, (_, compatible_node_fields) in graph_cred_inputs.items():
# Best-effort map: skip missing items
if graph_input_name not in graph_credentials_input:
continue
# Use passed-in credentials for all compatible node input fields
for node_id, node_field_name in compatible_node_fields:
if node_id not in result:
result[node_id] = {}
result[node_id][node_field_name] = graph_credentials_input[
graph_input_name
].model_dump(exclude_none=True)
return result
async def validate_graph_with_credentials(
graph: GraphModel,
user_id: str,
nodes_input_masks: Optional[NodesInputMasks] = None,
) -> tuple[Mapping[str, Mapping[str, str]], set[str]]:
"""
Validate graph including credentials and return structured errors per node,
along with a set of nodes that should be skipped due to optional missing credentials.
Returns:
tuple[
dict[node_id, dict[field_name, error_message]]: Validation errors per node,
set[node_id]: Nodes that should be skipped (optional credentials not configured)
]
"""
# Get input validation errors
node_input_errors = GraphModel.validate_graph_get_errors(
graph, for_run=True, nodes_input_masks=nodes_input_masks
)
# Get credential input/availability/validation errors and nodes to skip
node_credential_input_errors, nodes_to_skip = (
await _validate_node_input_credentials(graph, user_id, nodes_input_masks)
)
# Merge credential errors with structural errors
for node_id, field_errors in node_credential_input_errors.items():
if node_id not in node_input_errors:
node_input_errors[node_id] = {}
node_input_errors[node_id].update(field_errors)
return node_input_errors, nodes_to_skip
async def _construct_starting_node_execution_input(
graph: GraphModel,
user_id: str,
graph_inputs: BlockInput,
nodes_input_masks: Optional[NodesInputMasks] = None,
) -> tuple[list[tuple[str, BlockInput]], set[str]]:
"""
Validates and prepares the input data for executing a graph.
This function checks the graph for starting nodes, validates the input data
against the schema, and resolves dynamic input pins into a single list,
dictionary, or object.
Args:
graph (GraphModel): The graph model to execute.
user_id (str): The ID of the user executing the graph.
graph_inputs (BlockInput): The input data for the graph execution.
nodes_input_masks: Optional per-node input overrides, including credentials.
Returns:
tuple[
list[tuple[str, BlockInput]]: A list of tuples, each containing the node ID
and the corresponding input data for that node.
set[str]: Node IDs that should be skipped (optional credentials not configured)
]
"""
# Use new validation function that includes credentials
validation_errors, nodes_to_skip = await validate_graph_with_credentials(
graph, user_id, nodes_input_masks
)
n_error_nodes = len(validation_errors)
n_errors = sum(len(errors) for errors in validation_errors.values())
if validation_errors:
raise GraphValidationError(
f"Graph validation failed: {n_errors} issues on {n_error_nodes} nodes",
node_errors=validation_errors,
)
nodes_input = []
for node in graph.starting_nodes:
input_data = {}
block = node.block
# Note block should never be executed.
if block.block_type == BlockType.NOTE:
continue
# Extract request input data, and assign it to the input pin.
if block.block_type == BlockType.INPUT:
input_name = cast(str | None, node.input_default.get("name"))
if input_name and input_name in graph_inputs:
input_data = {"value": graph_inputs[input_name]}
# Apply node input overrides
if nodes_input_masks and (node_input_mask := nodes_input_masks.get(node.id)):
input_data.update(node_input_mask)
input_data, error = validate_exec(node, input_data)
if input_data is None:
raise ValueError(error)
else:
nodes_input.append((node.id, input_data))
if not nodes_input:
raise ValueError(
"No starting nodes found for the graph, make sure an AgentInput or blocks with no inbound links are present as starting nodes."
)
return nodes_input, nodes_to_skip
async def validate_and_construct_node_execution_input(
graph_id: str,
user_id: str,
graph_inputs: BlockInput,
graph_version: Optional[int] = None,
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
nodes_input_masks: Optional[NodesInputMasks] = None,
is_sub_graph: bool = False,
) -> tuple[GraphModel, list[tuple[str, BlockInput]], NodesInputMasks, set[str]]:
"""
Public wrapper that handles graph fetching, credential mapping, and validation+construction.
This centralizes the logic used by both scheduler validation and actual execution.
Args:
graph_id: The ID of the graph to validate/construct.
user_id: The ID of the user.
graph_inputs: The input data for the graph execution.
graph_version: The version of the graph to use.
graph_credentials_inputs: Credentials inputs to use.
nodes_input_masks: Node inputs to use.
Returns:
GraphModel: Full graph object for the given `graph_id`.
list[tuple[node_id, BlockInput]]: Starting node IDs with corresponding inputs.
dict[str, BlockInput]: Node input masks including all passed-in credentials.
set[str]: Node IDs that should be skipped (optional credentials not configured).
Raises:
NotFoundError: If the graph is not found.
GraphValidationError: If the graph has validation issues.
ValueError: If there are other validation errors.
"""
if prisma.is_connected():
gdb = graph_db
else:
gdb = get_database_manager_async_client()
graph: GraphModel | None = await gdb.get_graph(
graph_id=graph_id,
user_id=user_id,
version=graph_version,
include_subgraphs=True,
# Execution/access permission is checked by validate_graph_execution_permissions
skip_access_check=True,
)
if not graph:
raise GraphNotFoundError(f"Graph #{graph_id} not found.")
# Validate that the user has permission to execute this graph
# This checks both library membership and execution permissions,
# raising specific exceptions for appropriate error handling.
await gdb.validate_graph_execution_permissions(
user_id=user_id,
graph_id=graph.id,
graph_version=graph.version,
is_sub_graph=is_sub_graph,
)
nodes_input_masks = _merge_nodes_input_masks(
(
make_node_credentials_input_map(graph, graph_credentials_inputs)
if graph_credentials_inputs
else {}
),
nodes_input_masks or {},
)
starting_nodes_input, nodes_to_skip = (
await _construct_starting_node_execution_input(
graph=graph,
user_id=user_id,
graph_inputs=graph_inputs,
nodes_input_masks=nodes_input_masks,
)
)
return graph, starting_nodes_input, nodes_input_masks, nodes_to_skip
def _merge_nodes_input_masks(
overrides_map_1: NodesInputMasks,
overrides_map_2: NodesInputMasks,
) -> NodesInputMasks:
"""Perform a per-node merge of input overrides"""
result = dict(overrides_map_1)
for node_id, overrides2 in overrides_map_2.items():
if node_id in result:
result[node_id] = {**result[node_id], **overrides2}
else:
result[node_id] = overrides2
return result
# ============ Execution Queue Helpers ============ #
GRAPH_EXECUTION_EXCHANGE = Exchange(
name="graph_execution",
type=ExchangeType.DIRECT,
durable=True,
auto_delete=False,
)
GRAPH_EXECUTION_QUEUE_NAME = "graph_execution_queue"
GRAPH_EXECUTION_ROUTING_KEY = "graph_execution.run"
GRAPH_EXECUTION_CANCEL_EXCHANGE = Exchange(
name="graph_execution_cancel",
type=ExchangeType.FANOUT,
durable=True,
auto_delete=True,
)
GRAPH_EXECUTION_CANCEL_QUEUE_NAME = "graph_execution_cancel_queue"
# Graceful shutdown timeout constants
# Agent executions can run for up to 1 day, so we need a graceful shutdown period
# that allows long-running executions to complete naturally
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS = 24 * 60 * 60 # 1 day to complete active executions
def create_execution_queue_config() -> RabbitMQConfig:
"""
Define two exchanges and queues:
- 'graph_execution' (DIRECT) for run tasks.
- 'graph_execution_cancel' (FANOUT) for cancel requests.
"""
run_queue = Queue(
name=GRAPH_EXECUTION_QUEUE_NAME,
exchange=GRAPH_EXECUTION_EXCHANGE,
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
durable=True,
auto_delete=False,
arguments={
# x-consumer-timeout (1 day, in milliseconds)
# Problem: RabbitMQ's default 30-minute consumer timeout kills long-running
# graph executions ("Consumer acknowledgement timed out after 1800000 ms").
# Solution: Raise the timeout to the graceful-shutdown window (1 day) so
# long graph executions can complete; the heartbeat mechanism handles dead
# consumer detection.
# Use case: Graph executions that take hours to complete (AI model training, etc.)
"x-consumer-timeout": GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS * 1000,
},
)
cancel_queue = Queue(
name=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
exchange=GRAPH_EXECUTION_CANCEL_EXCHANGE,
routing_key="", # not used for FANOUT
durable=True,
auto_delete=False,
)
return RabbitMQConfig(
vhost="/",
exchanges=[GRAPH_EXECUTION_EXCHANGE, GRAPH_EXECUTION_CANCEL_EXCHANGE],
queues=[run_queue, cancel_queue],
)
class CancelExecutionEvent(BaseModel):
graph_exec_id: str
async def _get_child_executions(parent_exec_id: str) -> list["GraphExecutionMeta"]:
"""
Get all child executions of a parent execution using the execution_db pattern.
Args:
parent_exec_id: Parent graph execution ID
Returns:
List of child graph executions
"""
if prisma.is_connected():
edb = execution_db
else:
edb = get_database_manager_async_client()
return await edb.get_child_graph_executions(parent_exec_id)
async def stop_graph_execution(
user_id: str,
graph_exec_id: str,
wait_timeout: float = 15.0,
cascade: bool = True,
):
"""
Stop a graph execution and optionally all its child executions.
Mechanism:
1. Set the cancel event for this execution
2. If cascade=True, recursively stop all child executions
3. Graph executor's cancel handler thread detects the event, terminates workers,
reinitializes worker pool, and returns.
4. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
Args:
user_id: User ID who owns the execution
graph_exec_id: Graph execution ID to stop
wait_timeout: Maximum time to wait for execution to stop (seconds)
cascade: If True, recursively stop all child executions
"""
queue_client = await get_async_execution_queue()
db = execution_db if prisma.is_connected() else get_database_manager_async_client()
# First, find and stop all child executions if cascading
if cascade:
children = await _get_child_executions(graph_exec_id)
logger.info(
f"Stopping {len(children)} child executions of execution {graph_exec_id}"
)
# Stop all children in parallel (recursively, with cascading enabled)
if children:
await asyncio.gather(
*[
stop_graph_execution(
user_id=user_id,
graph_exec_id=child.id,
wait_timeout=wait_timeout,
cascade=True, # Recursively cascade to grandchildren
)
for child in children
],
return_exceptions=True, # Don't fail parent stop if child stop fails
)
# Now stop this execution
await queue_client.publish_message(
routing_key="",
message=CancelExecutionEvent(graph_exec_id=graph_exec_id).model_dump_json(),
exchange=GRAPH_EXECUTION_CANCEL_EXCHANGE,
)
if not wait_timeout:
return
start_time = time.time()
while time.time() - start_time < wait_timeout:
graph_exec = await db.get_graph_execution_meta(
execution_id=graph_exec_id, user_id=user_id
)
if not graph_exec:
raise NotFoundError(f"Graph execution #{graph_exec_id} not found.")
if graph_exec.status in [
ExecutionStatus.TERMINATED,
ExecutionStatus.COMPLETED,
ExecutionStatus.FAILED,
]:
# If graph execution is terminated/completed/failed, cancellation is complete
await get_async_execution_event_bus().publish(graph_exec)
return
if graph_exec.status in [
ExecutionStatus.QUEUED,
ExecutionStatus.INCOMPLETE,
ExecutionStatus.REVIEW,
]:
# If the graph is queued/incomplete/paused for review, terminate immediately
# No need to wait for executor since it's not actively running
# If graph is in REVIEW status, clean up pending reviews before terminating
if graph_exec.status == ExecutionStatus.REVIEW:
# Use human_review_db if Prisma connected, else database manager
review_db = (
human_review_db
if prisma.is_connected()
else get_database_manager_async_client()
)
# Mark all pending reviews as rejected/cancelled
cancelled_count = await review_db.cancel_pending_reviews_for_execution(
graph_exec_id, user_id
)
logger.info(
f"Cancelled {cancelled_count} pending review(s) for stopped execution {graph_exec_id}"
)
graph_exec.status = ExecutionStatus.TERMINATED
await asyncio.gather(
# Update graph execution status
db.update_graph_execution_stats(
graph_exec_id=graph_exec.id,
status=ExecutionStatus.TERMINATED,
),
# Publish graph execution event
get_async_execution_event_bus().publish(graph_exec),
)
return
if graph_exec.status == ExecutionStatus.RUNNING:
await asyncio.sleep(0.1)
raise TimeoutError(
f"Graph execution #{graph_exec_id} will need to take longer than {wait_timeout} seconds to stop. "
f"You can check the status of the execution in the UI or try again later."
)
async def add_graph_execution(
graph_id: str,
user_id: str,
inputs: Optional[BlockInput] = None,
preset_id: Optional[str] = None,
graph_version: Optional[int] = None,
graph_credentials_inputs: Optional[Mapping[str, CredentialsMetaInput]] = None,
nodes_input_masks: Optional[NodesInputMasks] = None,
execution_context: Optional[ExecutionContext] = None,
graph_exec_id: Optional[str] = None,
) -> GraphExecutionWithNodes:
"""
Adds a graph execution to the queue and returns the execution entry.
Args:
graph_id: The ID of the graph to execute.
user_id: The ID of the user executing the graph.
inputs: The input data for the graph execution.
preset_id: The ID of the preset to use.
graph_version: The version of the graph to execute.
graph_credentials_inputs: Credentials inputs to use in the execution.
Keys should map to the keys generated by `GraphModel.aggregate_credentials_inputs`.
nodes_input_masks: Node inputs to use in the execution.
execution_context: Execution context for the run; carries the parent execution ID for nested executions.
graph_exec_id: If provided, resume this existing execution instead of creating a new one.
Returns:
GraphExecutionWithNodes: The created (or resumed) graph execution with its node executions.
Raises:
ValueError: If the graph is not found or if there are validation errors.
NotFoundError: If graph_exec_id is provided but execution is not found.
"""
if prisma.is_connected():
edb = execution_db
udb = user_db
gdb = graph_db
odb = onboarding_db
else:
edb = udb = gdb = odb = get_database_manager_async_client()
# Get or create the graph execution
if graph_exec_id:
# Resume existing execution
graph_exec = await edb.get_graph_execution(
user_id=user_id,
execution_id=graph_exec_id,
include_node_executions=True,
)
if not graph_exec:
raise NotFoundError(f"Graph execution #{graph_exec_id} not found.")
# Use existing execution's compiled input masks
compiled_nodes_input_masks = graph_exec.nodes_input_masks or {}
# For resumed executions, nodes_to_skip was already determined at creation time
# TODO: Consider storing nodes_to_skip in DB if we need to preserve it across resumes
nodes_to_skip: set[str] = set()
logger.info(f"Resuming graph execution #{graph_exec.id} for graph #{graph_id}")
else:
parent_exec_id = (
execution_context.parent_execution_id if execution_context else None
)
# Create new execution
graph, starting_nodes_input, compiled_nodes_input_masks, nodes_to_skip = (
await validate_and_construct_node_execution_input(
graph_id=graph_id,
user_id=user_id,
graph_inputs=inputs or {},
graph_version=graph_version,
graph_credentials_inputs=graph_credentials_inputs,
nodes_input_masks=nodes_input_masks,
is_sub_graph=parent_exec_id is not None,
)
)
graph_exec = await edb.create_graph_execution(
user_id=user_id,
graph_id=graph_id,
graph_version=graph.version,
inputs=inputs or {},
credential_inputs=graph_credentials_inputs,
nodes_input_masks=nodes_input_masks,
starting_nodes_input=starting_nodes_input,
preset_id=preset_id,
parent_graph_exec_id=parent_exec_id,
)
logger.info(
f"Created graph execution #{graph_exec.id} for graph "
f"#{graph_id} with {len(starting_nodes_input)} starting nodes"
)
# Generate execution context if it's not provided
if execution_context is None:
user = await udb.get_user_by_id(user_id)
settings = await gdb.get_graph_settings(user_id=user_id, graph_id=graph_id)
execution_context = ExecutionContext(
# Execution identity
user_id=user_id,
graph_id=graph_id,
graph_exec_id=graph_exec.id,
graph_version=graph_exec.graph_version,
# Safety settings
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
# User settings
user_timezone=(
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
),
# Execution hierarchy
root_execution_id=graph_exec.id,
)
try:
graph_exec_entry = graph_exec.to_graph_execution_entry(
compiled_nodes_input_masks=compiled_nodes_input_masks,
nodes_to_skip=nodes_to_skip,
execution_context=execution_context,
)
logger.info(f"Queueing execution {graph_exec.id}")
# Update execution status to QUEUED BEFORE publishing to prevent race condition
# where two concurrent requests could both publish the same execution
updated_exec = await edb.update_graph_execution_stats(
graph_exec_id=graph_exec.id,
status=ExecutionStatus.QUEUED,
)
# Verify the status update succeeded (prevents duplicate queueing in race conditions)
# If another request already updated the status, this execution will not be QUEUED
if not updated_exec or updated_exec.status != ExecutionStatus.QUEUED:
logger.warning(
f"Skipping queue publish for execution {graph_exec.id} - "
f"status update failed or execution already queued by another request"
)
return graph_exec
graph_exec.status = ExecutionStatus.QUEUED
# Publish to execution queue for executor to pick up
# This happens AFTER status update to ensure only one request publishes
exec_queue = await get_async_execution_queue()
await exec_queue.publish_message(
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
message=graph_exec_entry.model_dump_json(),
exchange=GRAPH_EXECUTION_EXCHANGE,
)
logger.info(f"Published execution {graph_exec.id} to RabbitMQ queue")
except BaseException as e:
err = str(e) or type(e).__name__
if not graph_exec:
logger.error(f"Unable to execute graph #{graph_id} failed: {err}")
raise
logger.error(
f"Unable to publish graph #{graph_id} exec #{graph_exec.id}: {err}"
)
await edb.update_node_execution_status_batch(
[node_exec.node_exec_id for node_exec in graph_exec.node_executions],
ExecutionStatus.FAILED,
)
await edb.update_graph_execution_stats(
graph_exec_id=graph_exec.id,
status=ExecutionStatus.FAILED,
stats=GraphExecutionStats(error=err),
)
raise
try:
await get_async_execution_event_bus().publish(graph_exec)
logger.info(f"Published update for execution #{graph_exec.id} to event bus")
except Exception as e:
logger.error(
f"Failed to publish execution event for graph exec #{graph_exec.id}: {e}"
)
try:
await odb.increment_onboarding_runs(user_id)
logger.info(
f"Incremented user #{user_id} onboarding runs for exec #{graph_exec.id}"
)
except Exception as e:
logger.error(f"Failed to increment onboarding runs for user #{user_id}: {e}")
return graph_exec
# ============ Execution Output Helpers ============ #
class ExecutionOutputEntry(BaseModel):
node: Node
node_exec_id: str
data: BlockOutputEntry
class NodeExecutionProgress:
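"""
Tracks in-flight node executions (Futures) and their streamed outputs.
Executions are serviced in insertion (FIFO) order: outputs for the oldest
execution are drained first, and a finished task is only discarded once all
of its buffered outputs have been consumed.
"""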
def __init__(self):
self.output: dict[str, list[ExecutionOutputEntry]] = defaultdict(list)
self.tasks: dict[str, Future] = {}
self._lock = threading.Lock()
def add_task(self, node_exec_id: str, task: Future):
self.tasks[node_exec_id] = task
def add_output(self, output: ExecutionOutputEntry):
with self._lock:
self.output[output.node_exec_id].append(output)
def pop_output(self) -> ExecutionOutputEntry | None:
exec_id = self._next_exec()
if not exec_id:
return None
if self._pop_done_task(exec_id):
return self.pop_output()
with self._lock:
if next_output := self.output[exec_id]:
return next_output.pop(0)
return None
def is_done(self, wait_time: float = 0.0) -> bool:
exec_id = self._next_exec()
if not exec_id:
return True
if self._pop_done_task(exec_id):
return self.is_done(wait_time)
if wait_time <= 0:
return False
try:
self.tasks[exec_id].result(wait_time)
except TimeoutError:
pass
except Exception as e:
logger.error(
f"Task for exec ID {exec_id} failed with error: {e.__class__.__name__} {str(e)}"
)
return self.is_done(0)
def stop(self) -> list[str]:
"""
Stops all tasks and clears the output.
This is useful for cleaning up when the execution is cancelled or terminated.
Returns a list of execution IDs that were stopped.
"""
cancelled_ids = []
for task_id, task in self.tasks.items():
if task.done():
continue
task.cancel()
cancelled_ids.append(task_id)
return cancelled_ids
def wait_for_done(self, timeout: float = 5.0):
"""
Wait for all cancelled tasks to complete cancellation.
Args:
timeout: Maximum time to wait for cancellation in seconds
"""
start_time = time.time()
while time.time() - start_time < timeout:
while self.pop_output():
pass
if self.is_done():
return
time.sleep(0.1) # Small delay to avoid busy waiting
raise TimeoutError(
f"Timeout waiting for cancellation of tasks: {list(self.tasks.keys())}"
)
def _pop_done_task(self, exec_id: str) -> bool:
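# Discard the task for `exec_id` once it is finished AND its buffered outputs
# are fully drained. Returns True when the task is gone (or never existed),
# telling the caller to advance to the next execution.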
task = self.tasks.get(exec_id)
if not task:
return True
if not task.done():
return False
with self._lock:
if self.output[exec_id]:
return False
self.tasks.pop(exec_id)
return True
def _next_exec(self) -> str | None:
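# dicts preserve insertion order, so the first key is the oldest pending execution.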
if not self.tasks:
return None
return next(iter(self.tasks.keys()))