import asyncio
import logging
import os
import threading
import time
import uuid
from collections import defaultdict
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast

import sentry_sdk
from pika.adapters.blocking_connection import BlockingChannel
from pika.spec import Basic, BasicProperties
from prometheus_client import Gauge, start_http_server
from redis.asyncio.lock import Lock as AsyncRedisLock

from backend.blocks import get_block
from backend.blocks._base import BlockSchema
from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.io import AgentOutputBlock
from backend.blocks.mcp.block import MCPToolBlock
from backend.data import redis_client as redis
from backend.data.block import BlockInput, BlockOutput, BlockOutputEntry
from backend.data.credit import UsageTransactionMetadata
from backend.data.dynamic_fields import parse_execution_output
from backend.data.execution import (
    ExecutionContext,
    ExecutionQueue,
    ExecutionStatus,
    GraphExecution,
    GraphExecutionEntry,
    NodeExecutionEntry,
    NodeExecutionResult,
    NodesInputMasks,
)
from backend.data.graph import Link, Node
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.notifications import (
    AgentRunData,
    LowBalanceData,
    NotificationEventModel,
    NotificationType,
    ZeroBalanceData,
)
from backend.data.rabbitmq import SyncRabbitMQ
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.notifications.notifications import queue_notification
from backend.util import json
from backend.util.clients import (
    get_async_execution_event_bus,
    get_database_manager_async_client,
    get_database_manager_client,
    get_execution_event_bus,
    get_notification_manager_client,
)
from backend.util.decorator import (
    async_error_logged,
    async_time_measured,
    error_logged,
    time_measured,
)
from backend.util.exceptions import InsufficientBalanceError, ModerationError
from backend.util.file import clean_exec_files
from backend.util.logging import TruncatedLogger, configure_logging
from backend.util.metrics import DiscordChannel
from backend.util.process import AppProcess, set_service_name
from backend.util.retry import (
    continuous_retry,
    func_retry,
    send_rate_limited_discord_alert,
)
from backend.util.settings import Settings

from .activity_status_generator import generate_activity_status_for_execution
from .automod.manager import automod_manager
from .cluster_lock import ClusterLock
from .utils import (
    GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
    GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
    GRAPH_EXECUTION_EXCHANGE,
    GRAPH_EXECUTION_QUEUE_NAME,
    GRAPH_EXECUTION_ROUTING_KEY,
    CancelExecutionEvent,
    ExecutionOutputEntry,
    LogMetadata,
    NodeExecutionProgress,
    block_usage_cost,
    create_execution_queue_config,
    execution_usage_cost,
    validate_exec,
)

if TYPE_CHECKING:
    # Imported for annotations only, to avoid a runtime import of the
    # database-manager clients from this module.
    from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient

_logger = logging.getLogger(__name__)
# Wrapper around the module logger that tags every message with "[GraphExecutor]".
logger = TruncatedLogger(_logger, prefix="[GraphExecutor]")
settings = Settings()

# Prometheus metrics describing graph-executor load.
active_runs_gauge = Gauge(
    "execution_manager_active_runs", "Number of active graph runs"
)
pool_size_gauge = Gauge(
    "execution_manager_pool_size", "Maximum number of graph workers"
)
utilization_gauge = Gauge(
    "execution_manager_utilization_ratio",
    "Ratio of active graph runs to max graph workers",
)

# Redis key prefix for tracking insufficient funds Discord notifications.
# We only send one notification per user per agent until they top up credits.
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX = "insufficient_funds_discord_notified"
# TTL for the notification flag (30 days) - acts as a fallback cleanup
INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS = 30 * 24 * 60 * 60


async def clear_insufficient_funds_notifications(user_id: str) -> int:
    """
    Clear all insufficient funds notification flags for a user.

    This should be called when a user tops up their credits, allowing
    Discord notifications to be sent again if they run out of funds.

    Every Redis key matching
    ``{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*`` is deleted.

    Args:
        user_id: The user ID to clear notifications for.

    Returns:
        The number of keys that were deleted. Any Redis error is logged as a
        warning and reported as 0 deleted keys rather than raised.
    """
    try:
        redis_client = await redis.get_redis_async()
        pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
        # SCAN (not KEYS) so a large keyspace is walked incrementally.
        keys = [key async for key in redis_client.scan_iter(match=pattern)]
        if keys:
            return await redis_client.delete(*keys)
        return 0
    except Exception as e:
        logger.warning(
            f"Failed to clear insufficient funds notification flags for user "
            f"{user_id}: {e}"
        )
        return 0


# Thread-local storage for ExecutionProcessor instances
# (one processor per worker thread; see init_worker / execute_graph).
_tls = threading.local()


def init_worker() -> None:
    """Create and start this thread's ExecutionProcessor.

    The processor is stored in thread-local storage so later ``execute_graph``
    calls on the same thread reuse it. Presumably registered as a worker-pool
    ``initializer`` elsewhere in this module -- TODO confirm.
    """
    _tls.processor = ExecutionProcessor()
    _tls.processor.on_graph_executor_start()


def execute_graph(
    graph_exec_entry: "GraphExecutionEntry",
    cancel_event: threading.Event,
    cluster_lock: ClusterLock,
):
    """Run a graph execution on the calling thread's ExecutionProcessor.

    ``init_worker`` must have run on this thread first, since the processor is
    read from thread-local storage.
    """
    processor: ExecutionProcessor = _tls.processor
    return processor.on_graph_execution(graph_exec_entry, cancel_event, cluster_lock)


T = TypeVar("T")


async def execute_node(
    node: Node,
    data: NodeExecutionEntry,
    execution_processor: "ExecutionProcessor",
    execution_stats: NodeExecutionStats | None = None,
    nodes_input_masks: Optional[NodesInputMasks] = None,
    nodes_to_skip: Optional[set[str]] = None,
) -> BlockOutput:
    """
    Execute a single node's block and stream its outputs.

    Validates and reshapes the node input, acquires the credentials the block
    needs (holding a Redis lock per credential while the block runs), runs the
    block, and yields every ``(output_name, output_data)`` pair it produces.

    Args:
        node: The graph node to execute.
        data: The node execution entry. ``data.inputs`` is overwritten with the
            reshaped input before the block runs.
        execution_processor: The processor running this graph; supplies the
            credentials manager and is forwarded to the block as a kwarg.
        execution_stats: If given, merged with the block's own execution stats
            and annotated with serialized input/output sizes once the block
            finishes (successfully or not).
        nodes_input_masks: Input overrides, forwarded to AgentExecutorBlock.
        nodes_to_skip: Node IDs to skip, forwarded to the block as a kwarg.

    Yields:
        ``(output_name, output_data)`` tuples. When input validation fails, a
        single ``("error", message)`` tuple is yielded and nothing is executed.

    Raises:
        ValueError: If the block is disabled, or a credential it needs cannot
            be resolved.
""" user_id = data.user_id graph_exec_id = data.graph_exec_id graph_id = data.graph_id graph_version = data.graph_version node_exec_id = data.node_exec_id node_id = data.node_id node_block = node.block execution_context = data.execution_context creds_manager = execution_processor.creds_manager log_metadata = LogMetadata( logger=_logger, user_id=user_id, graph_eid=graph_exec_id, graph_id=graph_id, node_eid=node_exec_id, node_id=node_id, block_name=node_block.name, ) if node_block.disabled: raise ValueError(f"Block {node_block.id} is disabled and cannot be executed") # Sanity check: validate the execution input. input_data, error = validate_exec(node, data.inputs, resolve_input=False) if input_data is None: log_metadata.error(f"Skip execution, input validation error: {error}") yield "error", error return # Re-shape the input data for agent block. # AgentExecutorBlock specially separate the node input_data & its input_default. if isinstance(node_block, AgentExecutorBlock): _input_data = AgentExecutorBlock.Input(**node.input_default) _input_data.inputs = input_data if nodes_input_masks: _input_data.nodes_input_masks = nodes_input_masks _input_data.user_id = user_id input_data = _input_data.model_dump() elif isinstance(node_block, MCPToolBlock): _mcp_data = MCPToolBlock.Input(**node.input_default) # Dynamic tool fields are flattened to top-level by validate_exec # (via get_input_defaults). Collect them back into tool_arguments. 
tool_schema = _mcp_data.tool_input_schema tool_props = set(tool_schema.get("properties", {}).keys()) merged_args = {**_mcp_data.tool_arguments} for key in tool_props: if key in input_data: merged_args[key] = input_data[key] _mcp_data.tool_arguments = merged_args input_data = _mcp_data.model_dump() data.inputs = input_data # Execute the node input_data_str = json.dumps(input_data) input_size = len(input_data_str) log_metadata.debug("Executed node with input", input=input_data_str) # Create node-specific execution context to avoid race conditions # (multiple nodes can execute concurrently and would otherwise mutate shared state) execution_context = execution_context.model_copy( update={"node_id": node_id, "node_exec_id": node_exec_id} ) # Inject extra execution arguments for the blocks via kwargs # Keep individual kwargs for backwards compatibility with existing blocks extra_exec_kwargs: dict = { "graph_id": graph_id, "graph_version": graph_version, "node_id": node_id, "graph_exec_id": graph_exec_id, "node_exec_id": node_exec_id, "user_id": user_id, "execution_context": execution_context, "execution_processor": execution_processor, "nodes_to_skip": nodes_to_skip or set(), } # Last-minute fetch credentials + acquire a system-wide read-write lock to prevent # changes during execution. ⚠️ This means a set of credentials can only be used by # one (running) block at a time; simultaneous execution of blocks using same # credentials is not supported. creds_locks: list[AsyncRedisLock] = [] input_model = cast(type[BlockSchema], node_block.input_schema) # Handle regular credentials fields for field_name, input_type in input_model.get_credentials_fields().items(): field_value = input_data.get(field_name) if not field_value or ( isinstance(field_value, dict) and not field_value.get("id") ): # No credentials configured — nullify so JSON schema validation # doesn't choke on the empty default `{}`. 
input_data[field_name] = None continue # Block runs without credentials credentials_meta = input_type(**field_value) # Write normalized values back so JSON schema validation also passes # (model_validator may have fixed legacy formats like "ProviderName.MCP") input_data[field_name] = credentials_meta.model_dump(mode="json") try: credentials, lock = await creds_manager.acquire( user_id, credentials_meta.id ) except ValueError: # Credential was deleted or doesn't exist. # If the field has a default, run without credentials. if input_model.model_fields[field_name].default is not None: log_metadata.warning( f"Credentials #{credentials_meta.id} not found, " "running without (field has default)" ) input_data[field_name] = None continue raise creds_locks.append(lock) extra_exec_kwargs[field_name] = credentials # Handle auto-generated credentials (e.g., from GoogleDriveFileInput) for kwarg_name, info in input_model.get_auto_credentials_fields().items(): field_name = info["field_name"] field_data = input_data.get(field_name) if field_data and isinstance(field_data, dict): # Check if _credentials_id key exists in the field data if "_credentials_id" in field_data: cred_id = field_data["_credentials_id"] if cred_id: # Credential ID provided - acquire credentials provider = info.get("config", {}).get( "provider", "external service" ) file_name = field_data.get("name", "selected file") try: credentials, lock = await creds_manager.acquire( user_id, cred_id ) creds_locks.append(lock) extra_exec_kwargs[kwarg_name] = credentials except ValueError: # Credential was deleted or doesn't exist raise ValueError( f"Authentication expired for '{file_name}' in field '{field_name}'. " f"The saved {provider.capitalize()} credentials no longer exist. " f"Please re-select the file to re-authenticate." 
) # else: _credentials_id is explicitly None, skip credentials (for chained data) else: # _credentials_id key missing entirely - this is an error provider = info.get("config", {}).get("provider", "external service") file_name = field_data.get("name", "selected file") raise ValueError( f"Authentication missing for '{file_name}' in field '{field_name}'. " f"Please re-select the file to authenticate with {provider.capitalize()}." ) output_size = 0 # sentry tracking nonsense to get user counts for blocks because isolation scopes don't work :( scope = sentry_sdk.get_current_scope() # save the tags original_user = scope._user original_tags = dict(scope._tags) if scope._tags else {} # Set user ID for error tracking scope.set_user({"id": user_id}) scope.set_tag("graph_id", graph_id) scope.set_tag("node_id", node_id) scope.set_tag("block_name", node_block.name) scope.set_tag("block_id", node_block.id) for k, v in execution_context.model_dump().items(): scope.set_tag(f"execution_context.{k}", v) try: async for output_name, output_data in node_block.execute( input_data, **extra_exec_kwargs ): output_data = json.to_dict(output_data) output_size += len(json.dumps(output_data)) log_metadata.debug("Node produced output", **{output_name: output_data}) yield output_name, output_data except Exception as ex: # Capture exception WITH context still set before restoring scope sentry_sdk.capture_exception(error=ex, scope=scope) sentry_sdk.flush() # Ensure it's sent before we restore scope # Re-raise to maintain normal error flow raise finally: # Ensure all credentials are released even if execution fails for creds_lock in creds_locks: if ( creds_lock and (await creds_lock.locked()) and (await creds_lock.owned()) ): try: await creds_lock.release() except Exception as e: log_metadata.error(f"Failed to release credentials lock: {e}") # Update execution stats if execution_stats is not None: execution_stats += node_block.execution_stats execution_stats.input_size = input_size 
execution_stats.output_size = output_size # Restore scope AFTER error has been captured scope._user = original_user scope._tags = original_tags async def _enqueue_next_nodes( db_client: "DatabaseManagerAsyncClient", node: Node, output: BlockOutputEntry, user_id: str, graph_exec_id: str, graph_id: str, graph_version: int, log_metadata: LogMetadata, nodes_input_masks: Optional[NodesInputMasks], execution_context: ExecutionContext, ) -> list[NodeExecutionEntry]: async def add_enqueued_execution( node_exec_id: str, node_id: str, block_id: str, data: BlockInput ) -> NodeExecutionEntry: await async_update_node_execution_status( db_client=db_client, exec_id=node_exec_id, status=ExecutionStatus.QUEUED, execution_data=data, ) return NodeExecutionEntry( user_id=user_id, graph_exec_id=graph_exec_id, graph_id=graph_id, graph_version=graph_version, node_exec_id=node_exec_id, node_id=node_id, block_id=block_id, inputs=data, execution_context=execution_context, ) async def register_next_executions(node_link: Link) -> list[NodeExecutionEntry]: try: return await _register_next_executions(node_link) except Exception as e: log_metadata.exception(f"Failed to register next executions: {e}") return [] async def _register_next_executions(node_link: Link) -> list[NodeExecutionEntry]: enqueued_executions = [] next_output_name = node_link.source_name next_input_name = node_link.sink_name next_node_id = node_link.sink_id output_name, _ = output next_data = parse_execution_output( output, next_output_name, next_node_id, next_input_name ) if next_data is None and output_name != next_output_name: return enqueued_executions next_node = await db_client.get_node(next_node_id) # Multiple node can register the same next node, we need this to be atomic # To avoid same execution to be enqueued multiple times, # Or the same input to be consumed multiple times. 
async with synchronized(f"upsert_input-{next_node_id}-{graph_exec_id}"): # Add output data to the earliest incomplete execution, or create a new one. next_node_exec, next_node_input = await db_client.upsert_execution_input( node_id=next_node_id, graph_exec_id=graph_exec_id, input_name=next_input_name, input_data=next_data, ) next_node_exec_id = next_node_exec.node_exec_id await send_async_execution_update(next_node_exec) # Complete missing static input pins data using the last execution input. static_link_names = { link.sink_name for link in next_node.input_links if link.is_static and link.sink_name not in next_node_input } if static_link_names and ( latest_execution := await db_client.get_latest_node_execution( next_node_id, graph_exec_id ) ): for name in static_link_names: next_node_input[name] = latest_execution.input_data.get(name) # Apply node input overrides node_input_mask = None if nodes_input_masks and ( node_input_mask := nodes_input_masks.get(next_node.id) ): next_node_input.update(node_input_mask) # Validate the input data for the next node. next_node_input, validation_msg = validate_exec(next_node, next_node_input) suffix = f"{next_output_name}>{next_input_name}~{next_node_exec_id}:{validation_msg}" # Incomplete input data, skip queueing the execution. if not next_node_input: log_metadata.info(f"Skipped queueing {suffix}") return enqueued_executions # Input is complete, enqueue the execution. log_metadata.info(f"Enqueued {suffix}") enqueued_executions.append( await add_enqueued_execution( node_exec_id=next_node_exec_id, node_id=next_node_id, block_id=next_node.block_id, data=next_node_input, ) ) # Next execution stops here if the link is not static. if not node_link.is_static: return enqueued_executions # If link is static, there could be some incomplete executions waiting for it. # Load and complete the input missing input data, and try to re-enqueue them. 
for iexec in await db_client.get_node_executions( node_id=next_node_id, graph_exec_id=graph_exec_id, statuses=[ExecutionStatus.INCOMPLETE], ): idata = iexec.input_data ineid = iexec.node_exec_id static_link_names = { link.sink_name for link in next_node.input_links if link.is_static and link.sink_name not in idata } for input_name in static_link_names: idata[input_name] = next_node_input[input_name] # Apply node input overrides if node_input_mask: idata.update(node_input_mask) idata, msg = validate_exec(next_node, idata) suffix = f"{next_output_name}>{next_input_name}~{ineid}:{msg}" if not idata: log_metadata.info(f"Enqueueing static-link skipped: {suffix}") continue log_metadata.info(f"Enqueueing static-link execution {suffix}") enqueued_executions.append( await add_enqueued_execution( node_exec_id=iexec.node_exec_id, node_id=next_node_id, block_id=next_node.block_id, data=idata, ) ) return enqueued_executions return [ execution for link in node.output_links for execution in await register_next_executions(link) ] class ExecutionProcessor: """ This class contains event handlers for the process pool executor events. The main events are: on_graph_executor_start: Initialize the process that executes the graph. on_graph_execution: Execution logic for a graph. on_node_execution: Execution logic for a node. The execution flow: 1. Graph execution request is added to the queue. 2. Graph executor loop picks the request from the queue. 3. Graph executor loop submits the graph execution request to the executor pool. [on_graph_execution] 4. Graph executor initialize the node execution queue. 5. Graph executor adds the starting nodes to the node execution queue. 6. Graph executor waits for all nodes to be executed. [on_node_execution] 7. Node executor picks the node execution request from the queue. 8. Node executor executes the node. 9. Node executor enqueues the next executed nodes to the node execution queue. 
""" @async_error_logged(swallow=True) async def on_node_execution( self, node_exec: NodeExecutionEntry, node_exec_progress: NodeExecutionProgress, nodes_input_masks: Optional[NodesInputMasks], graph_stats_pair: tuple[GraphExecutionStats, threading.Lock], nodes_to_skip: Optional[set[str]] = None, ) -> NodeExecutionStats: log_metadata = LogMetadata( logger=_logger, user_id=node_exec.user_id, graph_eid=node_exec.graph_exec_id, graph_id=node_exec.graph_id, node_eid=node_exec.node_exec_id, node_id=node_exec.node_id, block_name=b.name if (b := get_block(node_exec.block_id)) else "-", ) db_client = get_db_async_client() node = await db_client.get_node(node_exec.node_id) execution_stats = NodeExecutionStats() timing_info, status = await self._on_node_execution( node=node, node_exec=node_exec, node_exec_progress=node_exec_progress, stats=execution_stats, db_client=db_client, log_metadata=log_metadata, nodes_input_masks=nodes_input_masks, nodes_to_skip=nodes_to_skip, ) if isinstance(status, BaseException): raise status execution_stats.walltime = timing_info.wall_time execution_stats.cputime = timing_info.cpu_time graph_stats, graph_stats_lock = graph_stats_pair with graph_stats_lock: graph_stats.node_count += 1 + execution_stats.extra_steps graph_stats.nodes_cputime += execution_stats.cputime graph_stats.nodes_walltime += execution_stats.walltime graph_stats.cost += execution_stats.extra_cost if isinstance(execution_stats.error, Exception): graph_stats.node_error_count += 1 node_error = execution_stats.error node_stats = execution_stats.model_dump() if node_error and not isinstance(node_error, str): node_stats["error"] = str(node_error) or node_stats.__class__.__name__ await async_update_node_execution_status( db_client=db_client, exec_id=node_exec.node_exec_id, status=status, stats=node_stats, ) await async_update_graph_execution_state( db_client=db_client, graph_exec_id=node_exec.graph_exec_id, stats=graph_stats, ) return execution_stats @async_time_measured async def 
_on_node_execution( self, node: Node, node_exec: NodeExecutionEntry, node_exec_progress: NodeExecutionProgress, stats: NodeExecutionStats, db_client: "DatabaseManagerAsyncClient", log_metadata: LogMetadata, nodes_input_masks: Optional[NodesInputMasks] = None, nodes_to_skip: Optional[set[str]] = None, ) -> ExecutionStatus: status = ExecutionStatus.RUNNING async def persist_output(output_name: str, output_data: Any) -> None: await db_client.upsert_execution_output( node_exec_id=node_exec.node_exec_id, output_name=output_name, output_data=output_data, ) if exec_update := await db_client.get_node_execution( node_exec.node_exec_id ): await send_async_execution_update(exec_update) node_exec_progress.add_output( ExecutionOutputEntry( node=node, node_exec_id=node_exec.node_exec_id, data=(output_name, output_data), ) ) try: log_metadata.info(f"Start node execution {node_exec.node_exec_id}") await async_update_node_execution_status( db_client=db_client, exec_id=node_exec.node_exec_id, status=ExecutionStatus.RUNNING, ) async for output_name, output_data in execute_node( node=node, data=node_exec, execution_processor=self, execution_stats=stats, nodes_input_masks=nodes_input_masks, nodes_to_skip=nodes_to_skip, ): await persist_output(output_name, output_data) log_metadata.info(f"Finished node execution {node_exec.node_exec_id}") status = ExecutionStatus.COMPLETED except BaseException as e: stats.error = e if isinstance(e, ValueError): # Avoid user error being marked as an actual error. log_metadata.info( f"Expected failure on node execution {node_exec.node_exec_id}: {type(e).__name__} - {e}" ) status = ExecutionStatus.FAILED elif isinstance(e, Exception): # If the exception is not a ValueError, it is unexpected. 
log_metadata.exception( f"Unexpected failure on node execution {node_exec.node_exec_id}: {type(e).__name__} - {e}" ) status = ExecutionStatus.FAILED else: # CancelledError or SystemExit log_metadata.warning( f"Interruption error on node execution {node_exec.node_exec_id}: {type(e).__name__}" ) status = ExecutionStatus.TERMINATED finally: if status == ExecutionStatus.FAILED and stats.error is not None: await persist_output( "error", str(stats.error) or type(stats.error).__name__ ) return status @func_retry def on_graph_executor_start(self): configure_logging() set_service_name("GraphExecutor") self.tid = threading.get_ident() self.creds_manager = IntegrationCredentialsManager() self.node_execution_loop = asyncio.new_event_loop() self.node_evaluation_loop = asyncio.new_event_loop() self.node_execution_thread = threading.Thread( target=self.node_execution_loop.run_forever, daemon=True ) self.node_evaluation_thread = threading.Thread( target=self.node_evaluation_loop.run_forever, daemon=True ) self.node_execution_thread.start() self.node_evaluation_thread.start() logger.info(f"[GraphExecutor] {self.tid} started") @error_logged(swallow=False) def on_graph_execution( self, graph_exec: GraphExecutionEntry, cancel: threading.Event, cluster_lock: ClusterLock, ): log_metadata = LogMetadata( logger=_logger, user_id=graph_exec.user_id, graph_eid=graph_exec.graph_exec_id, graph_id=graph_exec.graph_id, node_id="*", node_eid="*", block_name="-", ) db_client = get_db_client() exec_meta = db_client.get_graph_execution_meta( user_id=graph_exec.user_id, execution_id=graph_exec.graph_exec_id, ) if exec_meta is None: log_metadata.warning( f"Skipped graph execution #{graph_exec.graph_exec_id}, the graph execution is not found." 
) return if exec_meta.status in [ExecutionStatus.QUEUED, ExecutionStatus.INCOMPLETE]: log_metadata.info(f"⚙️ Starting graph execution #{graph_exec.graph_exec_id}") exec_meta.status = ExecutionStatus.RUNNING send_execution_update( db_client.update_graph_execution_start_time(graph_exec.graph_exec_id) ) elif exec_meta.status == ExecutionStatus.RUNNING: log_metadata.info( f"⚙️ Graph execution #{graph_exec.graph_exec_id} is already running, continuing where it left off." ) elif exec_meta.status == ExecutionStatus.REVIEW: exec_meta.status = ExecutionStatus.RUNNING log_metadata.info( f"⚙️ Graph execution #{graph_exec.graph_exec_id} was waiting for review, resuming execution." ) update_graph_execution_state( db_client=db_client, graph_exec_id=graph_exec.graph_exec_id, status=ExecutionStatus.RUNNING, ) elif exec_meta.status == ExecutionStatus.FAILED: exec_meta.status = ExecutionStatus.RUNNING log_metadata.info( f"⚙️ Graph execution #{graph_exec.graph_exec_id} was disturbed, continuing where it left off." ) update_graph_execution_state( db_client=db_client, graph_exec_id=graph_exec.graph_exec_id, status=ExecutionStatus.RUNNING, ) else: log_metadata.warning( f"Skipped graph execution {graph_exec.graph_exec_id}, the graph execution status is `{exec_meta.status}`." 
) return if exec_meta.stats is None: exec_stats = GraphExecutionStats() else: exec_stats = exec_meta.stats.to_db() timing_info, status = self._on_graph_execution( graph_exec=graph_exec, cancel=cancel, log_metadata=log_metadata, execution_stats=exec_stats, cluster_lock=cluster_lock, ) exec_stats.walltime += timing_info.wall_time exec_stats.cputime += timing_info.cpu_time try: # Failure handling if isinstance(status, BaseException): raise status exec_meta.status = status if status in [ExecutionStatus.COMPLETED, ExecutionStatus.FAILED]: activity_response = asyncio.run_coroutine_threadsafe( generate_activity_status_for_execution( graph_exec_id=graph_exec.graph_exec_id, graph_id=graph_exec.graph_id, graph_version=graph_exec.graph_version, execution_stats=exec_stats, db_client=get_db_async_client(), user_id=graph_exec.user_id, execution_status=status, ), self.node_execution_loop, ).result(timeout=60.0) else: activity_response = None if activity_response is not None: exec_stats.activity_status = activity_response["activity_status"] exec_stats.correctness_score = activity_response["correctness_score"] log_metadata.info( f"Generated activity status: {activity_response['activity_status']} " f"(correctness: {activity_response['correctness_score']:.2f})" ) else: log_metadata.debug( "Activity status generation disabled, not setting fields" ) finally: # Communication handling self._handle_agent_run_notif(db_client, graph_exec, exec_stats) update_graph_execution_state( db_client=db_client, graph_exec_id=graph_exec.graph_exec_id, status=exec_meta.status, stats=exec_stats, ) def _charge_usage( self, node_exec: NodeExecutionEntry, execution_count: int, ) -> tuple[int, int]: total_cost = 0 remaining_balance = 0 db_client = get_db_client() block = get_block(node_exec.block_id) if not block: logger.error(f"Block {node_exec.block_id} not found.") return total_cost, 0 cost, matching_filter = block_usage_cost( block=block, input_data=node_exec.inputs ) if cost > 0: remaining_balance = 
db_client.spend_credits( user_id=node_exec.user_id, cost=cost, metadata=UsageTransactionMetadata( graph_exec_id=node_exec.graph_exec_id, graph_id=node_exec.graph_id, node_exec_id=node_exec.node_exec_id, node_id=node_exec.node_id, block_id=node_exec.block_id, block=block.name, input=matching_filter, reason=f"Ran block {node_exec.block_id} {block.name}", ), ) total_cost += cost cost, usage_count = execution_usage_cost(execution_count) if cost > 0: remaining_balance = db_client.spend_credits( user_id=node_exec.user_id, cost=cost, metadata=UsageTransactionMetadata( graph_exec_id=node_exec.graph_exec_id, graph_id=node_exec.graph_id, input={ "execution_count": usage_count, "charge": "Execution Cost", }, reason=f"Execution Cost for {usage_count} blocks of ex_id:{node_exec.graph_exec_id} g_id:{node_exec.graph_id}", ), ) total_cost += cost return total_cost, remaining_balance @time_measured def _on_graph_execution( self, graph_exec: GraphExecutionEntry, cancel: threading.Event, log_metadata: LogMetadata, execution_stats: GraphExecutionStats, cluster_lock: ClusterLock, ) -> ExecutionStatus: """ Returns: dict: The execution statistics of the graph execution. ExecutionStatus: The final status of the graph execution. Exception | None: The error that occurred during the execution, if any. 
""" execution_status: ExecutionStatus = ExecutionStatus.RUNNING error: Exception | None = None db_client = get_db_client() execution_stats_lock = threading.Lock() # State holders ---------------------------------------------------- self.running_node_execution: dict[str, NodeExecutionProgress] = defaultdict( NodeExecutionProgress ) self.running_node_evaluation: dict[str, Future] = {} self.execution_stats = execution_stats self.execution_stats_lock = execution_stats_lock execution_queue = ExecutionQueue[NodeExecutionEntry]() running_node_execution = self.running_node_execution running_node_evaluation = self.running_node_evaluation try: if db_client.get_credits(graph_exec.user_id) <= 0: raise InsufficientBalanceError( user_id=graph_exec.user_id, message="You have no credits left to run an agent.", balance=0, amount=1, ) # Input moderation try: if moderation_error := asyncio.run_coroutine_threadsafe( automod_manager.moderate_graph_execution_inputs( db_client=get_db_async_client(), graph_exec=graph_exec, ), self.node_evaluation_loop, ).result(timeout=30.0): raise moderation_error except asyncio.TimeoutError: log_metadata.warning( f"Input moderation timed out for graph execution {graph_exec.graph_exec_id}, bypassing moderation and continuing execution" ) # Continue execution without moderation # ------------------------------------------------------------ # Pre‑populate queue --------------------------------------- # ------------------------------------------------------------ for node_exec in db_client.get_node_executions( graph_exec.graph_exec_id, statuses=[ ExecutionStatus.RUNNING, ExecutionStatus.QUEUED, ExecutionStatus.TERMINATED, ExecutionStatus.REVIEW, ], ): node_entry = node_exec.to_node_execution_entry( graph_exec.execution_context ) execution_queue.add(node_entry) # ------------------------------------------------------------ # Main dispatch / polling loop ----------------------------- # ------------------------------------------------------------ while not 
execution_queue.empty(): if cancel.is_set(): break queued_node_exec = execution_queue.get() # Check if this node should be skipped due to optional credentials if queued_node_exec.node_id in graph_exec.nodes_to_skip: log_metadata.info( f"Skipping node execution {queued_node_exec.node_exec_id} " f"for node {queued_node_exec.node_id} - optional credentials not configured" ) # Mark the node as completed without executing # No outputs will be produced, so downstream nodes won't trigger update_node_execution_status( db_client=db_client, exec_id=queued_node_exec.node_exec_id, status=ExecutionStatus.COMPLETED, ) continue log_metadata.debug( f"Dispatching node execution {queued_node_exec.node_exec_id} " f"for node {queued_node_exec.node_id}", ) # Charge usage (may raise) ------------------------------ try: cost, remaining_balance = self._charge_usage( node_exec=queued_node_exec, execution_count=increment_execution_count(graph_exec.user_id), ) with execution_stats_lock: execution_stats.cost += cost # Check if we crossed the low balance threshold self._handle_low_balance( db_client=db_client, user_id=graph_exec.user_id, current_balance=remaining_balance, transaction_cost=cost, ) except InsufficientBalanceError as balance_error: error = balance_error # Set error to trigger FAILED status node_exec_id = queued_node_exec.node_exec_id db_client.upsert_execution_output( node_exec_id=node_exec_id, output_name="error", output_data=str(error), ) update_node_execution_status( db_client=db_client, exec_id=node_exec_id, status=ExecutionStatus.FAILED, ) self._handle_insufficient_funds_notif( db_client, graph_exec.user_id, graph_exec.graph_id, error, ) # Gracefully stop the execution loop break # Add input overrides ----------------------------- node_id = queued_node_exec.node_id if (nodes_input_masks := graph_exec.nodes_input_masks) and ( node_input_mask := nodes_input_masks.get(node_id) ): queued_node_exec.inputs.update(node_input_mask) # Kick off async node execution 
------------------------- node_execution_task = asyncio.run_coroutine_threadsafe( self.on_node_execution( node_exec=queued_node_exec, node_exec_progress=running_node_execution[node_id], nodes_input_masks=nodes_input_masks, graph_stats_pair=( execution_stats, execution_stats_lock, ), nodes_to_skip=graph_exec.nodes_to_skip, ), self.node_execution_loop, ) running_node_execution[node_id].add_task( node_exec_id=queued_node_exec.node_exec_id, task=node_execution_task, ) # Poll until queue refills or all inflight work done ---- while execution_queue.empty() and ( running_node_execution or running_node_evaluation ): if cancel.is_set(): break # -------------------------------------------------- # Handle inflight evaluations --------------------- # -------------------------------------------------- node_output_found = False for node_id, inflight_exec in list(running_node_execution.items()): if cancel.is_set(): break # node evaluation future ----------------- if inflight_eval := running_node_evaluation.get(node_id): if not inflight_eval.done(): continue try: inflight_eval.result(timeout=0) running_node_evaluation.pop(node_id) except Exception as e: log_metadata.error(f"Node eval #{node_id} failed: {e}") # node execution future --------------------------- if inflight_exec.is_done(): running_node_execution.pop(node_id) continue if output := inflight_exec.pop_output(): node_output_found = True running_node_evaluation[node_id] = ( asyncio.run_coroutine_threadsafe( self._process_node_output( output=output, node_id=node_id, graph_exec=graph_exec, log_metadata=log_metadata, nodes_input_masks=nodes_input_masks, execution_queue=execution_queue, ), self.node_evaluation_loop, ) ) if ( not node_output_found and execution_queue.empty() and (running_node_execution or running_node_evaluation) ): cluster_lock.refresh() time.sleep(0.1) # loop done -------------------------------------------------- # Output moderation try: if moderation_error := asyncio.run_coroutine_threadsafe( 
automod_manager.moderate_graph_execution_outputs( db_client=get_db_async_client(), graph_exec_id=graph_exec.graph_exec_id, user_id=graph_exec.user_id, graph_id=graph_exec.graph_id, ), self.node_evaluation_loop, ).result(timeout=30.0): raise moderation_error except asyncio.TimeoutError: log_metadata.warning( f"Output moderation timed out for graph execution {graph_exec.graph_exec_id}, bypassing moderation and continuing execution" ) # Continue execution without moderation # Determine final execution status based on whether there was an error or termination if cancel.is_set(): execution_status = ExecutionStatus.TERMINATED elif error is not None: execution_status = ExecutionStatus.FAILED else: if db_client.has_pending_reviews_for_graph_exec( graph_exec.graph_exec_id ): execution_status = ExecutionStatus.REVIEW else: execution_status = ExecutionStatus.COMPLETED if error: execution_stats.error = str(error) or type(error).__name__ return execution_status except BaseException as e: error = ( e if isinstance(e, Exception) else Exception(f"{e.__class__.__name__}: {e}") ) if not execution_stats.error: execution_stats.error = str(error) known_errors = (InsufficientBalanceError, ModerationError) if isinstance(error, known_errors): return ExecutionStatus.FAILED execution_status = ExecutionStatus.FAILED log_metadata.exception( f"Failed graph execution {graph_exec.graph_exec_id}: {error}" ) # Send rate-limited Discord alert for unknown/unexpected errors send_rate_limited_discord_alert( "graph_execution", error, "unknown_error", f"🚨 **Unknown Graph Execution Error**\n" f"User: {graph_exec.user_id}\n" f"Graph ID: {graph_exec.graph_id}\n" f"Execution ID: {graph_exec.graph_exec_id}\n" f"Error Type: {type(error).__name__}\n" f"Error: {str(error)[:200]}{'...' 
if len(str(error)) > 200 else ''}\n", ) raise finally: self._cleanup_graph_execution( execution_queue=execution_queue, running_node_execution=running_node_execution, running_node_evaluation=running_node_evaluation, execution_status=execution_status, error=error, graph_exec_id=graph_exec.graph_exec_id, log_metadata=log_metadata, db_client=db_client, ) @error_logged(swallow=True) def _cleanup_graph_execution( self, execution_queue: ExecutionQueue[NodeExecutionEntry], running_node_execution: dict[str, "NodeExecutionProgress"], running_node_evaluation: dict[str, Future], execution_status: ExecutionStatus, error: Exception | None, graph_exec_id: str, log_metadata: LogMetadata, db_client: "DatabaseManagerClient", ) -> None: """ Clean up running node executions and evaluations when graph execution ends. This method is decorated with @error_logged(swallow=True) to ensure cleanup never fails in the finally block. """ # Cancel and wait for all node executions to complete for node_id, inflight_exec in running_node_execution.items(): if inflight_exec.is_done(): continue log_metadata.info(f"Stopping node execution {node_id}") inflight_exec.stop() for node_id, inflight_exec in running_node_execution.items(): try: inflight_exec.wait_for_done(timeout=3600.0) except TimeoutError: log_metadata.exception( f"Node execution #{node_id} did not stop in time, " "it may be stuck or taking too long." ) # Wait the remaining inflight evaluations to finish for node_id, inflight_eval in running_node_evaluation.items(): try: inflight_eval.result(timeout=3600.0) except TimeoutError: log_metadata.exception( f"Node evaluation #{node_id} did not stop in time, " "it may be stuck or taking too long." 
) while queued_execution := execution_queue.get_or_none(): update_node_execution_status( db_client=db_client, exec_id=queued_execution.node_exec_id, status=execution_status, stats={"error": str(error)} if error else None, ) clean_exec_files(graph_exec_id) @async_error_logged(swallow=True) async def _process_node_output( self, output: ExecutionOutputEntry, node_id: str, graph_exec: GraphExecutionEntry, log_metadata: LogMetadata, nodes_input_masks: Optional[NodesInputMasks], execution_queue: ExecutionQueue[NodeExecutionEntry], ) -> None: """Process a node's output, update its status, and enqueue next nodes. Args: output: The execution output entry to process node_id: The ID of the node that produced the output graph_exec: The graph execution entry log_metadata: Logger metadata for consistent logging nodes_input_masks: Optional map of node input overrides execution_queue: Queue to add next executions to """ db_client = get_db_async_client() log_metadata.debug(f"Enqueue nodes for {node_id}: {output}") for next_execution in await _enqueue_next_nodes( db_client=db_client, node=output.node, output=output.data, user_id=graph_exec.user_id, graph_exec_id=graph_exec.graph_exec_id, graph_id=graph_exec.graph_id, graph_version=graph_exec.graph_version, log_metadata=log_metadata, nodes_input_masks=nodes_input_masks, execution_context=graph_exec.execution_context, ): execution_queue.add(next_execution) def _handle_agent_run_notif( self, db_client: "DatabaseManagerClient", graph_exec: GraphExecutionEntry, exec_stats: GraphExecutionStats, ): metadata = db_client.get_graph_metadata( graph_exec.graph_id, graph_exec.graph_version ) outputs = db_client.get_node_executions( graph_exec.graph_exec_id, block_ids=[AgentOutputBlock().id], ) named_outputs = [ { key: value[0] if key == "name" else value for key, value in output.output_data.items() } for output in outputs ] queue_notification( NotificationEventModel( user_id=graph_exec.user_id, type=NotificationType.AGENT_RUN, data=AgentRunData( 
outputs=named_outputs, agent_name=metadata.name if metadata else "Unknown Agent", credits_used=exec_stats.cost, execution_time=exec_stats.walltime, graph_id=graph_exec.graph_id, node_count=exec_stats.node_count, ), ) ) def _handle_insufficient_funds_notif( self, db_client: "DatabaseManagerClient", user_id: str, graph_id: str, e: InsufficientBalanceError, ): # Check if we've already sent a notification for this user+agent combo. # We only send one notification per user per agent until they top up credits. redis_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}" try: redis_client = redis.get_redis() # SET NX returns True only if the key was newly set (didn't exist) is_new_notification = redis_client.set( redis_key, "1", nx=True, ex=INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS, ) if not is_new_notification: # Already notified for this user+agent, skip all notifications logger.debug( f"Skipping duplicate insufficient funds notification for " f"user={user_id}, graph={graph_id}" ) return except Exception as redis_error: # If Redis fails, log and continue to send the notification # (better to occasionally duplicate than to never notify) logger.warning( f"Failed to check/set insufficient funds notification flag in Redis: " f"{redis_error}" ) shortfall = abs(e.amount) - e.balance metadata = db_client.get_graph_metadata(graph_id) base_url = ( settings.config.frontend_base_url or settings.config.platform_base_url ) # Queue user email notification queue_notification( NotificationEventModel( user_id=user_id, type=NotificationType.ZERO_BALANCE, data=ZeroBalanceData( current_balance=e.balance, billing_page_link=f"{base_url}/profile/credits", shortfall=shortfall, agent_name=metadata.name if metadata else "Unknown Agent", ), ) ) # Send Discord system alert try: user_email = db_client.get_user_email_by_id(user_id) alert_message = ( f"❌ **Insufficient Funds Alert**\n" f"User: {user_email or user_id}\n" f"Agent: {metadata.name if metadata else 'Unknown Agent'}\n" f"Current 
balance: ${e.balance / 100:.2f}\n" f"Attempted cost: ${abs(e.amount) / 100:.2f}\n" f"Shortfall: ${abs(shortfall) / 100:.2f}\n" f"[View User Details]({base_url}/admin/spending?search={user_email})" ) get_notification_manager_client().discord_system_alert( alert_message, DiscordChannel.PRODUCT ) except Exception as alert_error: logger.error( f"Failed to send insufficient funds Discord alert: {alert_error}" ) def _handle_low_balance( self, db_client: "DatabaseManagerClient", user_id: str, current_balance: int, transaction_cost: int, ): """Check and handle low balance scenarios after a transaction""" LOW_BALANCE_THRESHOLD = settings.config.low_balance_threshold balance_before = current_balance + transaction_cost if ( current_balance < LOW_BALANCE_THRESHOLD and balance_before >= LOW_BALANCE_THRESHOLD ): base_url = ( settings.config.frontend_base_url or settings.config.platform_base_url ) queue_notification( NotificationEventModel( user_id=user_id, type=NotificationType.LOW_BALANCE, data=LowBalanceData( current_balance=current_balance, billing_page_link=f"{base_url}/profile/credits", ), ) ) try: user_email = db_client.get_user_email_by_id(user_id) alert_message = ( f"⚠️ **Low Balance Alert**\n" f"User: {user_email or user_id}\n" f"Balance dropped below ${LOW_BALANCE_THRESHOLD / 100:.2f}\n" f"Current balance: ${current_balance / 100:.2f}\n" f"Transaction cost: ${transaction_cost / 100:.2f}\n" f"[View User Details]({base_url}/admin/spending?search={user_email})" ) get_notification_manager_client().discord_system_alert( alert_message, DiscordChannel.PRODUCT ) except Exception as e: logger.error(f"Failed to send low balance Discord alert: {e}") class ExecutionManager(AppProcess): def __init__(self): super().__init__() self.pool_size = settings.config.num_graph_workers self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {} self.executor_id = str(uuid.uuid4()) self._executor = None self._stop_consuming = None self._cancel_thread = None self._cancel_client = 
None self._run_thread = None self._run_client = None self._execution_locks = {} @property def cancel_thread(self) -> threading.Thread: if self._cancel_thread is None: self._cancel_thread = threading.Thread( target=lambda: self._consume_execution_cancel(), daemon=True, ) return self._cancel_thread @property def run_thread(self) -> threading.Thread: if self._run_thread is None: self._run_thread = threading.Thread( target=lambda: self._consume_execution_run(), daemon=True, ) return self._run_thread @property def stop_consuming(self) -> threading.Event: if self._stop_consuming is None: self._stop_consuming = threading.Event() return self._stop_consuming @property def executor(self) -> ThreadPoolExecutor: if self._executor is None: self._executor = ThreadPoolExecutor( max_workers=self.pool_size, initializer=init_worker, ) return self._executor @property def cancel_client(self) -> SyncRabbitMQ: if self._cancel_client is None: self._cancel_client = SyncRabbitMQ(create_execution_queue_config()) return self._cancel_client @property def run_client(self) -> SyncRabbitMQ: if self._run_client is None: self._run_client = SyncRabbitMQ(create_execution_queue_config()) return self._run_client def run(self): logger.info( f"[{self.service_name}] 🆔 Pod assigned executor_id: {self.executor_id}" ) logger.info(f"[{self.service_name}] ⏳ Spawn max-{self.pool_size} workers...") pool_size_gauge.set(self.pool_size) self._update_prompt_metrics() start_http_server(settings.config.execution_manager_port) self.cancel_thread.start() self.run_thread.start() while True: time.sleep(1e5) @continuous_retry() def _consume_execution_cancel(self): if self.stop_consuming.is_set() and not self.active_graph_runs: logger.info( f"[{self.service_name}] Stop reconnecting cancel consumer since the service is cleaned up." 
) return # Check if channel is closed and force reconnection if needed if not self.cancel_client.is_ready: self.cancel_client.disconnect() self.cancel_client.connect() cancel_channel = self.cancel_client.get_channel() cancel_channel.basic_consume( queue=GRAPH_EXECUTION_CANCEL_QUEUE_NAME, on_message_callback=self._handle_cancel_message, auto_ack=True, ) logger.info(f"[{self.service_name}] ⏳ Starting cancel message consumer...") cancel_channel.start_consuming() if not self.stop_consuming.is_set() or self.active_graph_runs: raise RuntimeError( f"[{self.service_name}] ❌ cancel message consumer is stopped: {cancel_channel}" ) logger.info( f"[{self.service_name}] ✅ Cancel message consumer stopped gracefully" ) @continuous_retry() def _consume_execution_run(self): # Long-running executions are handled by: # 1. Long consumer timeout (x-consumer-timeout) allows long running agent # 2. Enhanced connection settings (5 retries, 1s delay) for quick reconnection # 3. Process monitoring ensures failed executors release messages back to queue if self.stop_consuming.is_set(): logger.info( f"[{self.service_name}] Stop reconnecting execution consumer since the service is cleaned up." 
) return # Check if channel is closed and force reconnection if needed if not self.run_client.is_ready: self.run_client.disconnect() self.run_client.connect() run_channel = self.run_client.get_channel() run_channel.basic_qos(prefetch_count=self.pool_size) # Configure consumer for long-running graph executions # auto_ack=False: Don't acknowledge messages until execution completes (prevents data loss) run_channel.basic_consume( queue=GRAPH_EXECUTION_QUEUE_NAME, on_message_callback=self._handle_run_message, auto_ack=False, consumer_tag="graph_execution_consumer", ) run_channel.confirm_delivery() logger.info(f"[{self.service_name}] ⏳ Starting to consume run messages...") run_channel.start_consuming() if not self.stop_consuming.is_set(): raise RuntimeError( f"[{self.service_name}] ❌ run message consumer is stopped: {run_channel}" ) logger.info(f"[{self.service_name}] ✅ Run message consumer stopped gracefully") @error_logged(swallow=True) def _handle_cancel_message( self, _channel: BlockingChannel, _method: Basic.Deliver, _properties: BasicProperties, body: bytes, ): """ Called whenever we receive a CANCEL message from the queue. (With auto_ack=True, message is considered 'acked' automatically.) """ request = CancelExecutionEvent.model_validate_json(body) graph_exec_id = request.graph_exec_id if not graph_exec_id: logger.warning( f"[{self.service_name}] Cancel message missing 'graph_exec_id'" ) return if graph_exec_id not in self.active_graph_runs: logger.debug( f"[{self.service_name}] Cancel received for {graph_exec_id} but not active." 
) return _, cancel_event = self.active_graph_runs[graph_exec_id] logger.info(f"[{self.service_name}] Received cancel for {graph_exec_id}") if not cancel_event.is_set(): cancel_event.set() else: logger.debug( f"[{self.service_name}] Cancel already set for {graph_exec_id}" ) def _handle_run_message( self, _channel: BlockingChannel, method: Basic.Deliver, _properties: BasicProperties, body: bytes, ): delivery_tag = method.delivery_tag @func_retry def _ack_message(reject: bool, requeue: bool): """ Acknowledge or reject the message based on execution status. Args: reject: Whether to reject the message requeue: Whether to requeue the message """ # Connection can be lost, so always get a fresh channel channel = self.run_client.get_channel() if reject: if requeue and settings.config.requeue_by_republishing: # Send rejected message to back of queue using republishing def _republish_to_back(): try: # First republish to back of queue self.run_client.publish_message( routing_key=GRAPH_EXECUTION_ROUTING_KEY, message=body.decode(), # publish_message expects string, not bytes exchange=GRAPH_EXECUTION_EXCHANGE, ) # Then reject without requeue (message already republished) channel.basic_nack(delivery_tag, requeue=False) logger.info("Message requeued to back of queue") except Exception as e: logger.error( f"[{self.service_name}] Failed to requeue message to back: {e}" ) # Fall back to traditional requeue on failure channel.basic_nack(delivery_tag, requeue=True) channel.connection.add_callback_threadsafe(_republish_to_back) else: # Traditional requeue (goes to front) or no requeue channel.connection.add_callback_threadsafe( lambda: channel.basic_nack(delivery_tag, requeue=requeue) ) else: channel.connection.add_callback_threadsafe( lambda: channel.basic_ack(delivery_tag) ) # Check if we're shutting down - reject new messages but keep connection alive if self.stop_consuming.is_set(): logger.info( f"[{self.service_name}] Rejecting new execution during shutdown" ) 
_ack_message(reject=True, requeue=True) return # Check if we can accept more runs self._cleanup_completed_runs() if len(self.active_graph_runs) >= self.pool_size: _ack_message(reject=True, requeue=True) return try: graph_exec_entry = GraphExecutionEntry.model_validate_json(body) except Exception as e: logger.error( f"[{self.service_name}] Could not parse run message: {e}, body={body}" ) _ack_message(reject=True, requeue=False) return graph_exec_id = graph_exec_entry.graph_exec_id user_id = graph_exec_entry.user_id graph_id = graph_exec_entry.graph_id root_exec_id = graph_exec_entry.execution_context.root_execution_id parent_exec_id = graph_exec_entry.execution_context.parent_execution_id logger.info( f"[{self.service_name}] Received RUN for graph_exec_id={graph_exec_id}, user_id={user_id}, executor_id={self.executor_id}" + (f", root={root_exec_id}" if root_exec_id else "") + (f", parent={parent_exec_id}" if parent_exec_id else "") ) # Check if root execution is already terminated (prevents orphaned child executions) if root_exec_id and root_exec_id != graph_exec_id: parent_exec = get_db_client().get_graph_execution_meta( execution_id=root_exec_id, user_id=user_id, ) if parent_exec and parent_exec.status == ExecutionStatus.TERMINATED: logger.info( f"[{self.service_name}] Skipping execution {graph_exec_id} - parent {root_exec_id} is TERMINATED" ) # Mark this child as terminated since parent was stopped get_db_client().update_graph_execution_stats( graph_exec_id=graph_exec_id, status=ExecutionStatus.TERMINATED, ) _ack_message(reject=False, requeue=False) return # Check user rate limit before processing try: # Only check executions from the last 24 hours for performance current_running_count = get_db_client().get_graph_executions_count( user_id=user_id, graph_id=graph_id, statuses=[ExecutionStatus.RUNNING], created_time_gte=datetime.now(timezone.utc) - timedelta(hours=24), ) if ( current_running_count >= settings.config.max_concurrent_graph_executions_per_user ): 
logger.warning( f"[{self.service_name}] Rate limit exceeded for user {user_id} on graph {graph_id}: " f"{current_running_count}/{settings.config.max_concurrent_graph_executions_per_user} running executions" ) _ack_message(reject=True, requeue=True) return except Exception as e: logger.error( f"[{self.service_name}] Failed to check rate limit for user {user_id}: {e}, proceeding with execution" ) # If rate limit check fails, proceed to avoid blocking executions # Check for local duplicate execution first if graph_exec_id in self.active_graph_runs: logger.warning( f"[{self.service_name}] Graph {graph_exec_id} already running locally; rejecting duplicate." ) _ack_message(reject=True, requeue=True) return # Try to acquire cluster-wide execution lock cluster_lock = ClusterLock( redis=redis.get_redis(), key=f"exec_lock:{graph_exec_id}", owner_id=self.executor_id, timeout=settings.config.cluster_lock_timeout, ) current_owner = cluster_lock.try_acquire() if current_owner != self.executor_id: # Either someone else has it or Redis is unavailable if current_owner is not None: logger.warning( f"[{self.service_name}] Graph {graph_exec_id} already running on pod {current_owner}, current executor_id={self.executor_id}" ) _ack_message(reject=True, requeue=False) else: logger.warning( f"[{self.service_name}] Could not acquire lock for {graph_exec_id} - Redis unavailable" ) _ack_message(reject=True, requeue=True) return # Wrap entire block after successful lock acquisition try: self._execution_locks[graph_exec_id] = cluster_lock logger.info( f"[{self.service_name}] Successfully acquired cluster lock for {graph_exec_id}, executor_id={self.executor_id}" ) cancel_event = threading.Event() future = self.executor.submit( execute_graph, graph_exec_entry, cancel_event, cluster_lock ) self.active_graph_runs[graph_exec_id] = (future, cancel_event) except Exception as e: logger.warning( f"[{self.service_name}] Failed to setup execution for {graph_exec_id}: {type(e).__name__}: {e}" ) # Release 
cluster lock before requeue cluster_lock.release() if graph_exec_id in self._execution_locks: del self._execution_locks[graph_exec_id] _ack_message(reject=True, requeue=True) return self._update_prompt_metrics() def _on_run_done(f: Future): logger.info(f"[{self.service_name}] Run completed for {graph_exec_id}") try: if exec_error := f.exception(): logger.error( f"[{self.service_name}] Execution for {graph_exec_id} failed: {type(exec_error)} {exec_error}" ) _ack_message(reject=True, requeue=True) else: _ack_message(reject=False, requeue=False) except BaseException as e: logger.exception( f"[{self.service_name}] Error in run completion callback: {e}" ) finally: # Release the cluster-wide execution lock if graph_exec_id in self._execution_locks: logger.info( f"[{self.service_name}] Releasing cluster lock for {graph_exec_id}, executor_id={self.executor_id}" ) self._execution_locks[graph_exec_id].release() del self._execution_locks[graph_exec_id] self._cleanup_completed_runs() future.add_done_callback(_on_run_done) def _cleanup_completed_runs(self) -> list[str]: """Remove completed futures from active_graph_runs and update metrics""" completed_runs = [] for graph_exec_id, (future, _) in self.active_graph_runs.items(): if future.done(): completed_runs.append(graph_exec_id) for geid in completed_runs: logger.info(f"[{self.service_name}] ✅ Cleaned up completed run {geid}") self.active_graph_runs.pop(geid, None) self._update_prompt_metrics() return completed_runs def _update_prompt_metrics(self): active_count = len(self.active_graph_runs) active_runs_gauge.set(active_count) if self.stop_consuming.is_set(): utilization_gauge.set(1.0) else: utilization_gauge.set(active_count / self.pool_size) def _stop_message_consumers( self, thread: threading.Thread, client: SyncRabbitMQ, prefix: str ): try: channel = client.get_channel() channel.connection.add_callback_threadsafe(lambda: channel.stop_consuming()) try: thread.join(timeout=300) except TimeoutError: logger.error( f"{prefix} 
⚠️ Run thread did not finish in time, forcing disconnect" ) client.disconnect() logger.info(f"{prefix} ✅ Run client disconnected") except Exception as e: logger.error(f"{prefix} ⚠️ Error disconnecting run client: {type(e)} {e}") def cleanup(self): """Override cleanup to implement graceful shutdown with active execution waiting.""" prefix = f"[{self.service_name}][on_graph_executor_stop {os.getpid()}]" logger.info(f"{prefix} 🧹 Starting graceful shutdown...") # Signal the consumer thread to stop (thread-safe) try: self.stop_consuming.set() run_channel = self.run_client.get_channel() run_channel.connection.add_callback_threadsafe( lambda: run_channel.stop_consuming() ) logger.info(f"{prefix} ✅ Exec consumer has been signaled to stop") except Exception as e: logger.error(f"{prefix} ⚠️ Error signaling consumer to stop: {type(e)} {e}") # Wait for active executions to complete if self.active_graph_runs: logger.info( f"{prefix} ⏳ Waiting for {len(self.active_graph_runs)} active executions to complete..." 
) max_wait = GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS wait_interval = 5 waited = 0 while waited < max_wait: self._cleanup_completed_runs() if not self.active_graph_runs: logger.info(f"{prefix} ✅ All active executions completed") break else: ids = [k.split("-")[0] for k in self.active_graph_runs.keys()] logger.info( f"{prefix} ⏳ Still waiting for {len(self.active_graph_runs)} executions: {ids}" ) for graph_exec_id in self.active_graph_runs: if lock := self._execution_locks.get(graph_exec_id): lock.refresh() time.sleep(wait_interval) waited += wait_interval if self.active_graph_runs: logger.error( f"{prefix} ⚠️ {len(self.active_graph_runs)} executions still running after {max_wait}s" ) else: logger.info(f"{prefix} ✅ All executions completed gracefully") # Shutdown the executor try: self.executor.shutdown(cancel_futures=True, wait=False) logger.info(f"{prefix} ✅ Executor shutdown completed") except Exception as e: logger.error(f"{prefix} ⚠️ Error during executor shutdown: {type(e)} {e}") # Release remaining execution locks try: for lock in self._execution_locks.values(): lock.release() self._execution_locks.clear() logger.info(f"{prefix} ✅ Released execution locks") except Exception as e: logger.warning(f"{prefix} ⚠️ Failed to release all locks: {e}") # Disconnect the run execution consumer self._stop_message_consumers( self.run_thread, self.run_client, prefix + " [run-consumer]", ) self._stop_message_consumers( self.cancel_thread, self.cancel_client, prefix + " [cancel-consumer]", ) logger.info(f"{prefix} ✅ Finished GraphExec cleanup") super().cleanup() # ------- UTILITIES ------- # def get_db_client() -> "DatabaseManagerClient": return get_database_manager_client() def get_db_async_client() -> "DatabaseManagerAsyncClient": return get_database_manager_async_client() @func_retry async def send_async_execution_update( entry: GraphExecution | NodeExecutionResult | None, ) -> None: if entry is None: return await get_async_execution_event_bus().publish(entry) @func_retry def 
send_execution_update(entry: GraphExecution | NodeExecutionResult | None): if entry is None: return return get_execution_event_bus().publish(entry) async def async_update_node_execution_status( db_client: "DatabaseManagerAsyncClient", exec_id: str, status: ExecutionStatus, execution_data: BlockInput | None = None, stats: dict[str, Any] | None = None, ) -> NodeExecutionResult: """Sets status and fetches+broadcasts the latest state of the node execution""" exec_update = await db_client.update_node_execution_status( exec_id, status, execution_data, stats ) await send_async_execution_update(exec_update) return exec_update def update_node_execution_status( db_client: "DatabaseManagerClient", exec_id: str, status: ExecutionStatus, execution_data: BlockInput | None = None, stats: dict[str, Any] | None = None, ) -> NodeExecutionResult: """Sets status and fetches+broadcasts the latest state of the node execution""" exec_update = db_client.update_node_execution_status( exec_id, status, execution_data, stats ) send_execution_update(exec_update) return exec_update async def async_update_graph_execution_state( db_client: "DatabaseManagerAsyncClient", graph_exec_id: str, status: ExecutionStatus | None = None, stats: GraphExecutionStats | None = None, ) -> GraphExecution | None: """Sets status and fetches+broadcasts the latest state of the graph execution""" graph_update = await db_client.update_graph_execution_stats( graph_exec_id, status, stats ) if graph_update: await send_async_execution_update(graph_update) else: logger.error(f"Failed to update graph execution stats for {graph_exec_id}") return graph_update def update_graph_execution_state( db_client: "DatabaseManagerClient", graph_exec_id: str, status: ExecutionStatus | None = None, stats: GraphExecutionStats | None = None, ) -> GraphExecution | None: """Sets status and fetches+broadcasts the latest state of the graph execution""" graph_update = db_client.update_graph_execution_stats(graph_exec_id, status, stats) if 
graph_update: send_execution_update(graph_update) else: logger.error(f"Failed to update graph execution stats for {graph_exec_id}") return graph_update @asynccontextmanager async def synchronized(key: str, timeout: int = settings.config.cluster_lock_timeout): r = await redis.get_redis_async() lock: AsyncRedisLock = r.lock(f"lock:{key}", timeout=timeout) try: await lock.acquire() yield finally: if await lock.locked() and await lock.owned(): try: await lock.release() except Exception as e: logger.warning(f"Failed to release lock for key {key}: {e}") def increment_execution_count(user_id: str) -> int: """ Increment the execution count for a given user, this will be used to charge the user for the execution cost. """ r = redis.get_redis() k = f"uec:{user_id}" # User Execution Count global key counter = cast(int, r.incr(k)) if counter == 1: r.expire(k, settings.config.execution_counter_expiration_time) return counter