Files
AutoGPT/autogpt_platform/backend/backend/executor/manager.py
Nicholas Tindle 7668c17d9c feat(platform): add User Workspace for persistent CoPilot file storage (#11867)
Implements persistent User Workspace storage for CoPilot, enabling
blocks to save and retrieve files across sessions. Files are stored under
session-scoped virtual paths (`/sessions/{session_id}/`).

Fixes SECRT-1833

### Changes 🏗️

**Database & Storage:**
- Add `UserWorkspace` and `UserWorkspaceFile` Prisma models
- Implement `WorkspaceStorageBackend` abstraction (GCS for cloud, local filesystem for self-hosted); see the sketch after this list
- Add `workspace_id` and `session_id` fields to `ExecutionContext`
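
A minimal sketch of what such a backend abstraction might look like. The class and method names below are illustrative assumptions based on this description, not the actual interface:

```python
from abc import ABC, abstractmethod
from pathlib import Path


class WorkspaceStorageBackend(ABC):
    """Illustrative storage abstraction: one implementation per deployment target."""

    @abstractmethod
    async def write(self, workspace_id: str, path: str, content: bytes) -> str:
        """Store `content` at the virtual `path`; return a storage key."""

    @abstractmethod
    async def read(self, workspace_id: str, path: str) -> bytes:
        """Return the stored bytes for the virtual `path`."""


class LocalWorkspaceStorage(WorkspaceStorageBackend):
    """Hypothetical local-filesystem backend for self-hosted deployments."""

    def __init__(self, root: Path):
        self.root = root

    async def write(self, workspace_id: str, path: str, content: bytes) -> str:
        target = self.root / workspace_id / path.lstrip("/")
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_bytes(content)
        return str(target)

    async def read(self, workspace_id: str, path: str) -> bytes:
        return (self.root / workspace_id / path.lstrip("/")).read_bytes()
```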

**Backend API:**
- Add REST endpoints: `GET/POST /api/workspace/files`, `GET/DELETE /api/workspace/files/{id}`, `GET /api/workspace/files/{id}/download` (usage sketch below)
- Add CoPilot tools: `list_workspace_files`, `read_workspace_file`,
`write_workspace_file`
- Integrate workspace storage into `store_media_file()` - returns
`workspace://file-id` references
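
For example, listing workspace files and downloading one might look like this. This is a hedged sketch: the base URL, auth scheme, and response shape are assumptions, not confirmed by this PR:

```python
import httpx

BASE_URL = "http://localhost:8006"  # assumed backend URL
HEADERS = {"Authorization": "Bearer <token>"}  # assumed auth scheme

with httpx.Client(base_url=BASE_URL, headers=HEADERS) as client:
    # List the files in the user's workspace
    files = client.get("/api/workspace/files").json()

    # Download the first file's content via the download endpoint
    if files:
        file_id = files[0]["id"]  # assumed response field
        content = client.get(f"/api/workspace/files/{file_id}/download").content
```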

**Block Updates:**
- Refactor all file-handling blocks to use a unified `ExecutionContext` parameter (see the sketch after this list)
- Update media-generating blocks to persist outputs to workspace
(AIImageGenerator, AIImageCustomizer, FluxKontext, TalkingHead, FAL
video, Bannerbear, etc.)
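
Inside a block, persisting a generated file might look roughly like this. The exact import path and signature of `store_media_file()` are assumptions based on the summary above; `_generate_image` is a hypothetical helper:

```python
from backend.util.file import store_media_file  # assumed location


class AIImageGeneratorBlock:  # simplified stand-in for the real block
    async def run(self, input_data, *, execution_context, **kwargs):
        image = await self._generate_image(input_data.prompt)  # hypothetical helper
        # Inside CoPilot this is expected to return a `workspace://file-id`
        # reference; outside CoPilot it falls back to a data URI.
        ref = await store_media_file(
            file=image,
            execution_context=execution_context,
            return_format="reference",  # assumed parameter value
        )
        yield "image_url", ref
```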

**Frontend:**
- Render `workspace://` image references in chat via proxy endpoint
- Add "AI cannot see this image" overlay indicator

**CoPilot Context Mapping:**
- Session = Agent (graph_id) = Run (graph_exec_id)
- Files scoped to `/sessions/{session_id}/`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [ ] I have tested my changes according to the test plan:
  - [ ] Create CoPilot session, generate image with AIImageGeneratorBlock
  - [ ] Verify image returns `workspace://file-id` (not base64)
  - [ ] Verify image renders in chat with visibility indicator
  - [ ] Verify workspace files persist across sessions
  - [ ] Test list/read/write workspace files via CoPilot tools
  - [ ] Test local storage backend for self-hosted deployments

#### For configuration changes:
- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

🤖 Generated with [Claude Code](https://claude.ai/code)

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Medium Risk**
> Introduces a new persistent file-storage surface area (DB tables, storage backends, download API, and chat tools) and rewires `store_media_file()`/block execution context across many blocks, so regressions could impact file handling, access control, or storage costs.
> 
> **Overview**
> Adds a **persistent per-user Workspace** (new `UserWorkspace`/`UserWorkspaceFile` models plus `WorkspaceManager` + `WorkspaceStorageBackend` with GCS/local implementations) and wires it into the API via a new `/api/workspace/files/{file_id}/download` route (including header-sanitized `Content-Disposition`) and shutdown lifecycle hooks.
> 
> Extends `ExecutionContext` to carry execution identity + `workspace_id`/`session_id`, updates executor tooling to clone node-specific contexts, and updates `run_block` (CoPilot) to create a session-scoped workspace and synthetic graph/run/node IDs.
> 
> Refactors `store_media_file()` to require `execution_context` + `return_format` and to support `workspace://` references; migrates many media/file-handling blocks and related tests to the new API and to persist generated media as `workspace://...` (or fall back to data URIs outside CoPilot), and adds CoPilot chat tools for listing/reading/writing/deleting workspace files, with safeguards against context bloat.
> 
> <sup>Written by [Cursor Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit 6abc70f793. This will update automatically on new commits. Configure [here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
2026-01-29 05:49:47 +00:00

import asyncio
import logging
import os
import threading
import time
import uuid
from collections import defaultdict
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
import sentry_sdk
from pika.adapters.blocking_connection import BlockingChannel
from pika.spec import Basic, BasicProperties
from prometheus_client import Gauge, start_http_server
from redis.asyncio.lock import Lock as AsyncRedisLock
from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.io import AgentOutputBlock
from backend.data import redis_client as redis
from backend.data.block import (
BlockInput,
BlockOutput,
BlockOutputEntry,
BlockSchema,
get_block,
)
from backend.data.credit import UsageTransactionMetadata
from backend.data.dynamic_fields import parse_execution_output
from backend.data.execution import (
ExecutionContext,
ExecutionQueue,
ExecutionStatus,
GraphExecution,
GraphExecutionEntry,
NodeExecutionEntry,
NodeExecutionResult,
NodesInputMasks,
)
from backend.data.graph import Link, Node
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.notifications import (
AgentRunData,
LowBalanceData,
NotificationEventModel,
NotificationType,
ZeroBalanceData,
)
from backend.data.rabbitmq import SyncRabbitMQ
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.notifications.notifications import queue_notification
from backend.util import json
from backend.util.clients import (
get_async_execution_event_bus,
get_database_manager_async_client,
get_database_manager_client,
get_execution_event_bus,
get_notification_manager_client,
)
from backend.util.decorator import (
async_error_logged,
async_time_measured,
error_logged,
time_measured,
)
from backend.util.exceptions import InsufficientBalanceError, ModerationError
from backend.util.file import clean_exec_files
from backend.util.logging import TruncatedLogger, configure_logging
from backend.util.metrics import DiscordChannel
from backend.util.process import AppProcess, set_service_name
from backend.util.retry import (
continuous_retry,
func_retry,
send_rate_limited_discord_alert,
)
from backend.util.settings import Settings
from .activity_status_generator import generate_activity_status_for_execution
from .automod.manager import automod_manager
from .cluster_lock import ClusterLock
from .utils import (
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
GRAPH_EXECUTION_EXCHANGE,
GRAPH_EXECUTION_QUEUE_NAME,
GRAPH_EXECUTION_ROUTING_KEY,
CancelExecutionEvent,
ExecutionOutputEntry,
LogMetadata,
NodeExecutionProgress,
block_usage_cost,
create_execution_queue_config,
execution_usage_cost,
validate_exec,
)
if TYPE_CHECKING:
from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
_logger = logging.getLogger(__name__)
logger = TruncatedLogger(_logger, prefix="[GraphExecutor]")
settings = Settings()
active_runs_gauge = Gauge(
"execution_manager_active_runs", "Number of active graph runs"
)
pool_size_gauge = Gauge(
"execution_manager_pool_size", "Maximum number of graph workers"
)
utilization_gauge = Gauge(
"execution_manager_utilization_ratio",
"Ratio of active graph runs to max graph workers",
)
# Redis key prefix for tracking insufficient funds Discord notifications.
# We only send one notification per user per agent until they top up credits.
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX = "insufficient_funds_discord_notified"
# TTL for the notification flag (30 days) - acts as a fallback cleanup
INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS = 30 * 24 * 60 * 60
async def clear_insufficient_funds_notifications(user_id: str) -> int:
"""
Clear all insufficient funds notification flags for a user.
This should be called when a user tops up their credits, allowing
Discord notifications to be sent again if they run out of funds.
Args:
user_id: The user ID to clear notifications for.
Returns:
The number of keys that were deleted.
"""
try:
redis_client = await redis.get_redis_async()
pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
keys = [key async for key in redis_client.scan_iter(match=pattern)]
if keys:
return await redis_client.delete(*keys)
return 0
except Exception as e:
logger.warning(
f"Failed to clear insufficient funds notification flags for user "
f"{user_id}: {e}"
)
return 0
# Thread-local storage for ExecutionProcessor instances
_tls = threading.local()
def init_worker():
"""Initialize ExecutionProcessor instance in thread-local storage"""
_tls.processor = ExecutionProcessor()
_tls.processor.on_graph_executor_start()
def execute_graph(
graph_exec_entry: "GraphExecutionEntry",
cancel_event: threading.Event,
cluster_lock: ClusterLock,
):
"""Execute graph using thread-local ExecutionProcessor instance"""
processor: ExecutionProcessor = _tls.processor
return processor.on_graph_execution(graph_exec_entry, cancel_event, cluster_lock)
T = TypeVar("T")
async def execute_node(
node: Node,
data: NodeExecutionEntry,
execution_processor: "ExecutionProcessor",
execution_stats: NodeExecutionStats | None = None,
nodes_input_masks: Optional[NodesInputMasks] = None,
nodes_to_skip: Optional[set[str]] = None,
) -> BlockOutput:
"""
Execute a node in the graph. This will trigger a block execution on a node,
persist the execution result, and return the subsequent node to be executed.
Args:
db_client: The client to send execution updates to the server.
creds_manager: The manager to acquire and release credentials.
data: The execution data for executing the current node.
execution_stats: The execution statistics to be updated.
Returns:
The subsequent node to be enqueued, or None if there is no subsequent node.
"""
user_id = data.user_id
graph_exec_id = data.graph_exec_id
graph_id = data.graph_id
graph_version = data.graph_version
node_exec_id = data.node_exec_id
node_id = data.node_id
node_block = node.block
execution_context = data.execution_context
creds_manager = execution_processor.creds_manager
log_metadata = LogMetadata(
logger=_logger,
user_id=user_id,
graph_eid=graph_exec_id,
graph_id=graph_id,
node_eid=node_exec_id,
node_id=node_id,
block_name=node_block.name,
)
# Sanity check: validate the execution input.
input_data, error = validate_exec(node, data.inputs, resolve_input=False)
if input_data is None:
log_metadata.error(f"Skip execution, input validation error: {error}")
yield "error", error
return
    # Re-shape the input data for the agent block.
    # AgentExecutorBlock specially separates the node input_data & its input_default.
if isinstance(node_block, AgentExecutorBlock):
_input_data = AgentExecutorBlock.Input(**node.input_default)
_input_data.inputs = input_data
if nodes_input_masks:
_input_data.nodes_input_masks = nodes_input_masks
_input_data.user_id = user_id
input_data = _input_data.model_dump()
data.inputs = input_data
# Execute the node
input_data_str = json.dumps(input_data)
input_size = len(input_data_str)
log_metadata.debug("Executed node with input", input=input_data_str)
# Create node-specific execution context to avoid race conditions
# (multiple nodes can execute concurrently and would otherwise mutate shared state)
execution_context = execution_context.model_copy(
update={"node_id": node_id, "node_exec_id": node_exec_id}
)
# Inject extra execution arguments for the blocks via kwargs
# Keep individual kwargs for backwards compatibility with existing blocks
extra_exec_kwargs: dict = {
"graph_id": graph_id,
"graph_version": graph_version,
"node_id": node_id,
"graph_exec_id": graph_exec_id,
"node_exec_id": node_exec_id,
"user_id": user_id,
"execution_context": execution_context,
"execution_processor": execution_processor,
"nodes_to_skip": nodes_to_skip or set(),
}
    # Fetch credentials at the last minute + acquire a system-wide read-write lock
    # to prevent changes during execution. ⚠️ This means a set of credentials can
    # only be used by one (running) block at a time; simultaneous execution of
    # blocks using the same credentials is not supported.
creds_locks: list[AsyncRedisLock] = []
input_model = cast(type[BlockSchema], node_block.input_schema)
# Handle regular credentials fields
for field_name, input_type in input_model.get_credentials_fields().items():
credentials_meta = input_type(**input_data[field_name])
credentials, lock = await creds_manager.acquire(user_id, credentials_meta.id)
creds_locks.append(lock)
extra_exec_kwargs[field_name] = credentials
# Handle auto-generated credentials (e.g., from GoogleDriveFileInput)
for kwarg_name, info in input_model.get_auto_credentials_fields().items():
field_name = info["field_name"]
field_data = input_data.get(field_name)
if field_data and isinstance(field_data, dict):
# Check if _credentials_id key exists in the field data
if "_credentials_id" in field_data:
cred_id = field_data["_credentials_id"]
if cred_id:
# Credential ID provided - acquire credentials
provider = info.get("config", {}).get(
"provider", "external service"
)
file_name = field_data.get("name", "selected file")
try:
credentials, lock = await creds_manager.acquire(
user_id, cred_id
)
creds_locks.append(lock)
extra_exec_kwargs[kwarg_name] = credentials
except ValueError:
# Credential was deleted or doesn't exist
raise ValueError(
f"Authentication expired for '{file_name}' in field '{field_name}'. "
f"The saved {provider.capitalize()} credentials no longer exist. "
f"Please re-select the file to re-authenticate."
)
# else: _credentials_id is explicitly None, skip credentials (for chained data)
else:
# _credentials_id key missing entirely - this is an error
provider = info.get("config", {}).get("provider", "external service")
file_name = field_data.get("name", "selected file")
raise ValueError(
f"Authentication missing for '{file_name}' in field '{field_name}'. "
f"Please re-select the file to authenticate with {provider.capitalize()}."
)
output_size = 0
    # Sentry tracking to get per-user error counts for blocks, since isolation scopes don't work here.
scope = sentry_sdk.get_current_scope()
# save the tags
original_user = scope._user
original_tags = dict(scope._tags) if scope._tags else {}
# Set user ID for error tracking
scope.set_user({"id": user_id})
scope.set_tag("graph_id", graph_id)
scope.set_tag("node_id", node_id)
scope.set_tag("block_name", node_block.name)
scope.set_tag("block_id", node_block.id)
for k, v in execution_context.model_dump().items():
scope.set_tag(f"execution_context.{k}", v)
try:
async for output_name, output_data in node_block.execute(
input_data, **extra_exec_kwargs
):
output_data = json.to_dict(output_data)
output_size += len(json.dumps(output_data))
log_metadata.debug("Node produced output", **{output_name: output_data})
yield output_name, output_data
except Exception as ex:
# Capture exception WITH context still set before restoring scope
sentry_sdk.capture_exception(error=ex, scope=scope)
sentry_sdk.flush() # Ensure it's sent before we restore scope
# Re-raise to maintain normal error flow
raise
finally:
# Ensure all credentials are released even if execution fails
for creds_lock in creds_locks:
if (
creds_lock
and (await creds_lock.locked())
and (await creds_lock.owned())
):
try:
await creds_lock.release()
except Exception as e:
log_metadata.error(f"Failed to release credentials lock: {e}")
# Update execution stats
if execution_stats is not None:
execution_stats += node_block.execution_stats
execution_stats.input_size = input_size
execution_stats.output_size = output_size
# Restore scope AFTER error has been captured
scope._user = original_user
scope._tags = original_tags
async def _enqueue_next_nodes(
db_client: "DatabaseManagerAsyncClient",
node: Node,
output: BlockOutputEntry,
user_id: str,
graph_exec_id: str,
graph_id: str,
graph_version: int,
log_metadata: LogMetadata,
nodes_input_masks: Optional[NodesInputMasks],
execution_context: ExecutionContext,
) -> list[NodeExecutionEntry]:
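    """
    Evaluate a node's output against its outgoing links and enqueue any
    downstream node executions whose inputs are now complete, including
    incomplete executions unblocked via static links.

    Returns:
        The list of node executions that became ready to run.
    """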
async def add_enqueued_execution(
node_exec_id: str, node_id: str, block_id: str, data: BlockInput
) -> NodeExecutionEntry:
await async_update_node_execution_status(
db_client=db_client,
exec_id=node_exec_id,
status=ExecutionStatus.QUEUED,
execution_data=data,
)
return NodeExecutionEntry(
user_id=user_id,
graph_exec_id=graph_exec_id,
graph_id=graph_id,
graph_version=graph_version,
node_exec_id=node_exec_id,
node_id=node_id,
block_id=block_id,
inputs=data,
execution_context=execution_context,
)
async def register_next_executions(node_link: Link) -> list[NodeExecutionEntry]:
try:
return await _register_next_executions(node_link)
except Exception as e:
log_metadata.exception(f"Failed to register next executions: {e}")
return []
async def _register_next_executions(node_link: Link) -> list[NodeExecutionEntry]:
enqueued_executions = []
next_output_name = node_link.source_name
next_input_name = node_link.sink_name
next_node_id = node_link.sink_id
output_name, _ = output
next_data = parse_execution_output(
output, next_output_name, next_node_id, next_input_name
)
if next_data is None and output_name != next_output_name:
return enqueued_executions
next_node = await db_client.get_node(next_node_id)
        # Multiple nodes can register the same next node, so this must be atomic
        # to avoid the same execution being enqueued multiple times,
        # or the same input being consumed multiple times.
async with synchronized(f"upsert_input-{next_node_id}-{graph_exec_id}"):
# Add output data to the earliest incomplete execution, or create a new one.
next_node_exec, next_node_input = await db_client.upsert_execution_input(
node_id=next_node_id,
graph_exec_id=graph_exec_id,
input_name=next_input_name,
input_data=next_data,
)
next_node_exec_id = next_node_exec.node_exec_id
await send_async_execution_update(next_node_exec)
            # Complete missing static input pin data using the last execution input.
static_link_names = {
link.sink_name
for link in next_node.input_links
if link.is_static and link.sink_name not in next_node_input
}
if static_link_names and (
latest_execution := await db_client.get_latest_node_execution(
next_node_id, graph_exec_id
)
):
for name in static_link_names:
next_node_input[name] = latest_execution.input_data.get(name)
# Apply node input overrides
node_input_mask = None
if nodes_input_masks and (
node_input_mask := nodes_input_masks.get(next_node.id)
):
next_node_input.update(node_input_mask)
# Validate the input data for the next node.
next_node_input, validation_msg = validate_exec(next_node, next_node_input)
suffix = f"{next_output_name}>{next_input_name}~{next_node_exec_id}:{validation_msg}"
# Incomplete input data, skip queueing the execution.
if not next_node_input:
log_metadata.info(f"Skipped queueing {suffix}")
return enqueued_executions
# Input is complete, enqueue the execution.
log_metadata.info(f"Enqueued {suffix}")
enqueued_executions.append(
await add_enqueued_execution(
node_exec_id=next_node_exec_id,
node_id=next_node_id,
block_id=next_node.block_id,
data=next_node_input,
)
)
# Next execution stops here if the link is not static.
if not node_link.is_static:
return enqueued_executions
# If link is static, there could be some incomplete executions waiting for it.
            # Load them, complete the missing input data, and try to re-enqueue them.
for iexec in await db_client.get_node_executions(
node_id=next_node_id,
graph_exec_id=graph_exec_id,
statuses=[ExecutionStatus.INCOMPLETE],
):
idata = iexec.input_data
ineid = iexec.node_exec_id
static_link_names = {
link.sink_name
for link in next_node.input_links
if link.is_static and link.sink_name not in idata
}
for input_name in static_link_names:
idata[input_name] = next_node_input[input_name]
# Apply node input overrides
if node_input_mask:
idata.update(node_input_mask)
idata, msg = validate_exec(next_node, idata)
suffix = f"{next_output_name}>{next_input_name}~{ineid}:{msg}"
if not idata:
log_metadata.info(f"Enqueueing static-link skipped: {suffix}")
continue
log_metadata.info(f"Enqueueing static-link execution {suffix}")
enqueued_executions.append(
await add_enqueued_execution(
node_exec_id=iexec.node_exec_id,
node_id=next_node_id,
block_id=next_node.block_id,
data=idata,
)
)
return enqueued_executions
return [
execution
for link in node.output_links
for execution in await register_next_executions(link)
]
class ExecutionProcessor:
"""
This class contains event handlers for the process pool executor events.
The main events are:
on_graph_executor_start: Initialize the process that executes the graph.
on_graph_execution: Execution logic for a graph.
on_node_execution: Execution logic for a node.
The execution flow:
1. Graph execution request is added to the queue.
2. Graph executor loop picks the request from the queue.
3. Graph executor loop submits the graph execution request to the executor pool.
[on_graph_execution]
4. Graph executor initialize the node execution queue.
5. Graph executor adds the starting nodes to the node execution queue.
6. Graph executor waits for all nodes to be executed.
[on_node_execution]
7. Node executor picks the node execution request from the queue.
8. Node executor executes the node.
9. Node executor enqueues the next executed nodes to the node execution queue.
"""
@async_error_logged(swallow=True)
async def on_node_execution(
self,
node_exec: NodeExecutionEntry,
node_exec_progress: NodeExecutionProgress,
nodes_input_masks: Optional[NodesInputMasks],
graph_stats_pair: tuple[GraphExecutionStats, threading.Lock],
nodes_to_skip: Optional[set[str]] = None,
) -> NodeExecutionStats:
log_metadata = LogMetadata(
logger=_logger,
user_id=node_exec.user_id,
graph_eid=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
node_eid=node_exec.node_exec_id,
node_id=node_exec.node_id,
block_name=b.name if (b := get_block(node_exec.block_id)) else "-",
)
db_client = get_db_async_client()
node = await db_client.get_node(node_exec.node_id)
execution_stats = NodeExecutionStats()
timing_info, status = await self._on_node_execution(
node=node,
node_exec=node_exec,
node_exec_progress=node_exec_progress,
stats=execution_stats,
db_client=db_client,
log_metadata=log_metadata,
nodes_input_masks=nodes_input_masks,
nodes_to_skip=nodes_to_skip,
)
if isinstance(status, BaseException):
raise status
execution_stats.walltime = timing_info.wall_time
execution_stats.cputime = timing_info.cpu_time
graph_stats, graph_stats_lock = graph_stats_pair
with graph_stats_lock:
graph_stats.node_count += 1 + execution_stats.extra_steps
graph_stats.nodes_cputime += execution_stats.cputime
graph_stats.nodes_walltime += execution_stats.walltime
graph_stats.cost += execution_stats.extra_cost
if isinstance(execution_stats.error, Exception):
graph_stats.node_error_count += 1
node_error = execution_stats.error
node_stats = execution_stats.model_dump()
if node_error and not isinstance(node_error, str):
node_stats["error"] = str(node_error) or node_stats.__class__.__name__
await async_update_node_execution_status(
db_client=db_client,
exec_id=node_exec.node_exec_id,
status=status,
stats=node_stats,
)
await async_update_graph_execution_state(
db_client=db_client,
graph_exec_id=node_exec.graph_exec_id,
stats=graph_stats,
)
return execution_stats
@async_time_measured
async def _on_node_execution(
self,
node: Node,
node_exec: NodeExecutionEntry,
node_exec_progress: NodeExecutionProgress,
stats: NodeExecutionStats,
db_client: "DatabaseManagerAsyncClient",
log_metadata: LogMetadata,
nodes_input_masks: Optional[NodesInputMasks] = None,
nodes_to_skip: Optional[set[str]] = None,
) -> ExecutionStatus:
status = ExecutionStatus.RUNNING
async def persist_output(output_name: str, output_data: Any) -> None:
await db_client.upsert_execution_output(
node_exec_id=node_exec.node_exec_id,
output_name=output_name,
output_data=output_data,
)
if exec_update := await db_client.get_node_execution(
node_exec.node_exec_id
):
await send_async_execution_update(exec_update)
node_exec_progress.add_output(
ExecutionOutputEntry(
node=node,
node_exec_id=node_exec.node_exec_id,
data=(output_name, output_data),
)
)
try:
log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
await async_update_node_execution_status(
db_client=db_client,
exec_id=node_exec.node_exec_id,
status=ExecutionStatus.RUNNING,
)
async for output_name, output_data in execute_node(
node=node,
data=node_exec,
execution_processor=self,
execution_stats=stats,
nodes_input_masks=nodes_input_masks,
nodes_to_skip=nodes_to_skip,
):
await persist_output(output_name, output_data)
log_metadata.info(f"Finished node execution {node_exec.node_exec_id}")
status = ExecutionStatus.COMPLETED
except BaseException as e:
stats.error = e
if isinstance(e, ValueError):
# Avoid user error being marked as an actual error.
log_metadata.info(
f"Expected failure on node execution {node_exec.node_exec_id}: {type(e).__name__} - {e}"
)
status = ExecutionStatus.FAILED
elif isinstance(e, Exception):
# If the exception is not a ValueError, it is unexpected.
log_metadata.exception(
f"Unexpected failure on node execution {node_exec.node_exec_id}: {type(e).__name__} - {e}"
)
status = ExecutionStatus.FAILED
else:
# CancelledError or SystemExit
log_metadata.warning(
f"Interruption error on node execution {node_exec.node_exec_id}: {type(e).__name__}"
)
status = ExecutionStatus.TERMINATED
finally:
if status == ExecutionStatus.FAILED and stats.error is not None:
await persist_output(
"error", str(stats.error) or type(stats.error).__name__
)
return status
@func_retry
def on_graph_executor_start(self):
configure_logging()
set_service_name("GraphExecutor")
self.tid = threading.get_ident()
self.creds_manager = IntegrationCredentialsManager()
self.node_execution_loop = asyncio.new_event_loop()
self.node_evaluation_loop = asyncio.new_event_loop()
self.node_execution_thread = threading.Thread(
target=self.node_execution_loop.run_forever, daemon=True
)
self.node_evaluation_thread = threading.Thread(
target=self.node_evaluation_loop.run_forever, daemon=True
)
self.node_execution_thread.start()
self.node_evaluation_thread.start()
logger.info(f"[GraphExecutor] {self.tid} started")
@error_logged(swallow=False)
def on_graph_execution(
self,
graph_exec: GraphExecutionEntry,
cancel: threading.Event,
cluster_lock: ClusterLock,
):
log_metadata = LogMetadata(
logger=_logger,
user_id=graph_exec.user_id,
graph_eid=graph_exec.graph_exec_id,
graph_id=graph_exec.graph_id,
node_id="*",
node_eid="*",
block_name="-",
)
db_client = get_db_client()
exec_meta = db_client.get_graph_execution_meta(
user_id=graph_exec.user_id,
execution_id=graph_exec.graph_exec_id,
)
if exec_meta is None:
log_metadata.warning(
f"Skipped graph execution #{graph_exec.graph_exec_id}, the graph execution is not found."
)
return
if exec_meta.status in [ExecutionStatus.QUEUED, ExecutionStatus.INCOMPLETE]:
log_metadata.info(f"⚙️ Starting graph execution #{graph_exec.graph_exec_id}")
exec_meta.status = ExecutionStatus.RUNNING
send_execution_update(
db_client.update_graph_execution_start_time(graph_exec.graph_exec_id)
)
elif exec_meta.status == ExecutionStatus.RUNNING:
log_metadata.info(
f"⚙️ Graph execution #{graph_exec.graph_exec_id} is already running, continuing where it left off."
)
elif exec_meta.status == ExecutionStatus.REVIEW:
exec_meta.status = ExecutionStatus.RUNNING
log_metadata.info(
f"⚙️ Graph execution #{graph_exec.graph_exec_id} was waiting for review, resuming execution."
)
update_graph_execution_state(
db_client=db_client,
graph_exec_id=graph_exec.graph_exec_id,
status=ExecutionStatus.RUNNING,
)
elif exec_meta.status == ExecutionStatus.FAILED:
exec_meta.status = ExecutionStatus.RUNNING
            log_metadata.info(
                f"⚙️ Graph execution #{graph_exec.graph_exec_id} was interrupted, continuing where it left off."
            )
update_graph_execution_state(
db_client=db_client,
graph_exec_id=graph_exec.graph_exec_id,
status=ExecutionStatus.RUNNING,
)
else:
log_metadata.warning(
f"Skipped graph execution {graph_exec.graph_exec_id}, the graph execution status is `{exec_meta.status}`."
)
return
if exec_meta.stats is None:
exec_stats = GraphExecutionStats()
else:
exec_stats = exec_meta.stats.to_db()
timing_info, status = self._on_graph_execution(
graph_exec=graph_exec,
cancel=cancel,
log_metadata=log_metadata,
execution_stats=exec_stats,
cluster_lock=cluster_lock,
)
exec_stats.walltime += timing_info.wall_time
exec_stats.cputime += timing_info.cpu_time
try:
# Failure handling
if isinstance(status, BaseException):
raise status
exec_meta.status = status
if status in [ExecutionStatus.COMPLETED, ExecutionStatus.FAILED]:
activity_response = asyncio.run_coroutine_threadsafe(
generate_activity_status_for_execution(
graph_exec_id=graph_exec.graph_exec_id,
graph_id=graph_exec.graph_id,
graph_version=graph_exec.graph_version,
execution_stats=exec_stats,
db_client=get_db_async_client(),
user_id=graph_exec.user_id,
execution_status=status,
),
self.node_execution_loop,
).result(timeout=60.0)
else:
activity_response = None
if activity_response is not None:
exec_stats.activity_status = activity_response["activity_status"]
exec_stats.correctness_score = activity_response["correctness_score"]
log_metadata.info(
f"Generated activity status: {activity_response['activity_status']} "
f"(correctness: {activity_response['correctness_score']:.2f})"
)
else:
log_metadata.debug(
"Activity status generation disabled, not setting fields"
)
finally:
# Communication handling
self._handle_agent_run_notif(db_client, graph_exec, exec_stats)
update_graph_execution_state(
db_client=db_client,
graph_exec_id=graph_exec.graph_exec_id,
status=exec_meta.status,
stats=exec_stats,
)
def _charge_usage(
self,
node_exec: NodeExecutionEntry,
execution_count: int,
) -> tuple[int, int]:
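        """
        Charge the user for a single block execution: the block's own usage
        cost plus the per-execution-count cost.

        Returns:
            Tuple of (total cost charged, remaining balance after charging).
        """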
total_cost = 0
remaining_balance = 0
db_client = get_db_client()
block = get_block(node_exec.block_id)
if not block:
logger.error(f"Block {node_exec.block_id} not found.")
return total_cost, 0
cost, matching_filter = block_usage_cost(
block=block, input_data=node_exec.inputs
)
if cost > 0:
remaining_balance = db_client.spend_credits(
user_id=node_exec.user_id,
cost=cost,
metadata=UsageTransactionMetadata(
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
node_exec_id=node_exec.node_exec_id,
node_id=node_exec.node_id,
block_id=node_exec.block_id,
block=block.name,
input=matching_filter,
reason=f"Ran block {node_exec.block_id} {block.name}",
),
)
total_cost += cost
cost, usage_count = execution_usage_cost(execution_count)
if cost > 0:
remaining_balance = db_client.spend_credits(
user_id=node_exec.user_id,
cost=cost,
metadata=UsageTransactionMetadata(
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
input={
"execution_count": usage_count,
"charge": "Execution Cost",
},
reason=f"Execution Cost for {usage_count} blocks of ex_id:{node_exec.graph_exec_id} g_id:{node_exec.graph_id}",
),
)
total_cost += cost
return total_cost, remaining_balance
@time_measured
def _on_graph_execution(
self,
graph_exec: GraphExecutionEntry,
cancel: threading.Event,
log_metadata: LogMetadata,
execution_stats: GraphExecutionStats,
cluster_lock: ClusterLock,
) -> ExecutionStatus:
"""
Returns:
dict: The execution statistics of the graph execution.
ExecutionStatus: The final status of the graph execution.
Exception | None: The error that occurred during the execution, if any.
"""
execution_status: ExecutionStatus = ExecutionStatus.RUNNING
error: Exception | None = None
db_client = get_db_client()
execution_stats_lock = threading.Lock()
# State holders ----------------------------------------------------
self.running_node_execution: dict[str, NodeExecutionProgress] = defaultdict(
NodeExecutionProgress
)
self.running_node_evaluation: dict[str, Future] = {}
self.execution_stats = execution_stats
self.execution_stats_lock = execution_stats_lock
execution_queue = ExecutionQueue[NodeExecutionEntry]()
running_node_execution = self.running_node_execution
running_node_evaluation = self.running_node_evaluation
try:
if db_client.get_credits(graph_exec.user_id) <= 0:
raise InsufficientBalanceError(
user_id=graph_exec.user_id,
message="You have no credits left to run an agent.",
balance=0,
amount=1,
)
# Input moderation
try:
if moderation_error := asyncio.run_coroutine_threadsafe(
automod_manager.moderate_graph_execution_inputs(
db_client=get_db_async_client(),
graph_exec=graph_exec,
),
self.node_evaluation_loop,
).result(timeout=30.0):
raise moderation_error
except asyncio.TimeoutError:
log_metadata.warning(
f"Input moderation timed out for graph execution {graph_exec.graph_exec_id}, bypassing moderation and continuing execution"
)
# Continue execution without moderation
# ------------------------------------------------------------
# Prepopulate queue ---------------------------------------
# ------------------------------------------------------------
for node_exec in db_client.get_node_executions(
graph_exec.graph_exec_id,
statuses=[
ExecutionStatus.RUNNING,
ExecutionStatus.QUEUED,
ExecutionStatus.TERMINATED,
ExecutionStatus.REVIEW,
],
):
node_entry = node_exec.to_node_execution_entry(
graph_exec.execution_context
)
execution_queue.add(node_entry)
# ------------------------------------------------------------
# Main dispatch / polling loop -----------------------------
# ------------------------------------------------------------
while not execution_queue.empty():
if cancel.is_set():
break
queued_node_exec = execution_queue.get()
# Check if this node should be skipped due to optional credentials
if queued_node_exec.node_id in graph_exec.nodes_to_skip:
log_metadata.info(
f"Skipping node execution {queued_node_exec.node_exec_id} "
f"for node {queued_node_exec.node_id} - optional credentials not configured"
)
# Mark the node as completed without executing
# No outputs will be produced, so downstream nodes won't trigger
update_node_execution_status(
db_client=db_client,
exec_id=queued_node_exec.node_exec_id,
status=ExecutionStatus.COMPLETED,
)
continue
log_metadata.debug(
f"Dispatching node execution {queued_node_exec.node_exec_id} "
f"for node {queued_node_exec.node_id}",
)
# Charge usage (may raise) ------------------------------
try:
cost, remaining_balance = self._charge_usage(
node_exec=queued_node_exec,
execution_count=increment_execution_count(graph_exec.user_id),
)
with execution_stats_lock:
execution_stats.cost += cost
# Check if we crossed the low balance threshold
self._handle_low_balance(
db_client=db_client,
user_id=graph_exec.user_id,
current_balance=remaining_balance,
transaction_cost=cost,
)
except InsufficientBalanceError as balance_error:
error = balance_error # Set error to trigger FAILED status
node_exec_id = queued_node_exec.node_exec_id
db_client.upsert_execution_output(
node_exec_id=node_exec_id,
output_name="error",
output_data=str(error),
)
update_node_execution_status(
db_client=db_client,
exec_id=node_exec_id,
status=ExecutionStatus.FAILED,
)
self._handle_insufficient_funds_notif(
db_client,
graph_exec.user_id,
graph_exec.graph_id,
error,
)
# Gracefully stop the execution loop
break
# Add input overrides -----------------------------
node_id = queued_node_exec.node_id
if (nodes_input_masks := graph_exec.nodes_input_masks) and (
node_input_mask := nodes_input_masks.get(node_id)
):
queued_node_exec.inputs.update(node_input_mask)
# Kick off async node execution -------------------------
node_execution_task = asyncio.run_coroutine_threadsafe(
self.on_node_execution(
node_exec=queued_node_exec,
node_exec_progress=running_node_execution[node_id],
nodes_input_masks=nodes_input_masks,
graph_stats_pair=(
execution_stats,
execution_stats_lock,
),
nodes_to_skip=graph_exec.nodes_to_skip,
),
self.node_execution_loop,
)
running_node_execution[node_id].add_task(
node_exec_id=queued_node_exec.node_exec_id,
task=node_execution_task,
)
# Poll until queue refills or all inflight work done ----
while execution_queue.empty() and (
running_node_execution or running_node_evaluation
):
if cancel.is_set():
break
# --------------------------------------------------
# Handle inflight evaluations ---------------------
# --------------------------------------------------
node_output_found = False
for node_id, inflight_exec in list(running_node_execution.items()):
if cancel.is_set():
break
# node evaluation future -----------------
if inflight_eval := running_node_evaluation.get(node_id):
if not inflight_eval.done():
continue
try:
inflight_eval.result(timeout=0)
running_node_evaluation.pop(node_id)
except Exception as e:
log_metadata.error(f"Node eval #{node_id} failed: {e}")
# node execution future ---------------------------
if inflight_exec.is_done():
running_node_execution.pop(node_id)
continue
if output := inflight_exec.pop_output():
node_output_found = True
running_node_evaluation[node_id] = (
asyncio.run_coroutine_threadsafe(
self._process_node_output(
output=output,
node_id=node_id,
graph_exec=graph_exec,
log_metadata=log_metadata,
nodes_input_masks=nodes_input_masks,
execution_queue=execution_queue,
),
self.node_evaluation_loop,
)
)
if (
not node_output_found
and execution_queue.empty()
and (running_node_execution or running_node_evaluation)
):
cluster_lock.refresh()
time.sleep(0.1)
# loop done --------------------------------------------------
# Output moderation
try:
if moderation_error := asyncio.run_coroutine_threadsafe(
automod_manager.moderate_graph_execution_outputs(
db_client=get_db_async_client(),
graph_exec_id=graph_exec.graph_exec_id,
user_id=graph_exec.user_id,
graph_id=graph_exec.graph_id,
),
self.node_evaluation_loop,
).result(timeout=30.0):
raise moderation_error
except asyncio.TimeoutError:
log_metadata.warning(
f"Output moderation timed out for graph execution {graph_exec.graph_exec_id}, bypassing moderation and continuing execution"
)
# Continue execution without moderation
# Determine final execution status based on whether there was an error or termination
if cancel.is_set():
execution_status = ExecutionStatus.TERMINATED
elif error is not None:
execution_status = ExecutionStatus.FAILED
else:
if db_client.has_pending_reviews_for_graph_exec(
graph_exec.graph_exec_id
):
execution_status = ExecutionStatus.REVIEW
else:
execution_status = ExecutionStatus.COMPLETED
if error:
execution_stats.error = str(error) or type(error).__name__
return execution_status
except BaseException as e:
error = (
e
if isinstance(e, Exception)
else Exception(f"{e.__class__.__name__}: {e}")
)
if not execution_stats.error:
execution_stats.error = str(error)
known_errors = (InsufficientBalanceError, ModerationError)
if isinstance(error, known_errors):
return ExecutionStatus.FAILED
execution_status = ExecutionStatus.FAILED
log_metadata.exception(
f"Failed graph execution {graph_exec.graph_exec_id}: {error}"
)
# Send rate-limited Discord alert for unknown/unexpected errors
send_rate_limited_discord_alert(
"graph_execution",
error,
"unknown_error",
f"🚨 **Unknown Graph Execution Error**\n"
f"User: {graph_exec.user_id}\n"
f"Graph ID: {graph_exec.graph_id}\n"
f"Execution ID: {graph_exec.graph_exec_id}\n"
f"Error Type: {type(error).__name__}\n"
f"Error: {str(error)[:200]}{'...' if len(str(error)) > 200 else ''}\n",
)
raise
finally:
self._cleanup_graph_execution(
execution_queue=execution_queue,
running_node_execution=running_node_execution,
running_node_evaluation=running_node_evaluation,
execution_status=execution_status,
error=error,
graph_exec_id=graph_exec.graph_exec_id,
log_metadata=log_metadata,
db_client=db_client,
)
@error_logged(swallow=True)
def _cleanup_graph_execution(
self,
execution_queue: ExecutionQueue[NodeExecutionEntry],
running_node_execution: dict[str, "NodeExecutionProgress"],
running_node_evaluation: dict[str, Future],
execution_status: ExecutionStatus,
error: Exception | None,
graph_exec_id: str,
log_metadata: LogMetadata,
db_client: "DatabaseManagerClient",
) -> None:
"""
Clean up running node executions and evaluations when graph execution ends.
This method is decorated with @error_logged(swallow=True) to ensure cleanup
never fails in the finally block.
"""
# Cancel and wait for all node executions to complete
for node_id, inflight_exec in running_node_execution.items():
if inflight_exec.is_done():
continue
log_metadata.info(f"Stopping node execution {node_id}")
inflight_exec.stop()
for node_id, inflight_exec in running_node_execution.items():
try:
inflight_exec.wait_for_done(timeout=3600.0)
except TimeoutError:
log_metadata.exception(
f"Node execution #{node_id} did not stop in time, "
"it may be stuck or taking too long."
)
        # Wait for the remaining inflight evaluations to finish
for node_id, inflight_eval in running_node_evaluation.items():
try:
inflight_eval.result(timeout=3600.0)
except TimeoutError:
log_metadata.exception(
f"Node evaluation #{node_id} did not stop in time, "
"it may be stuck or taking too long."
)
while queued_execution := execution_queue.get_or_none():
update_node_execution_status(
db_client=db_client,
exec_id=queued_execution.node_exec_id,
status=execution_status,
stats={"error": str(error)} if error else None,
)
clean_exec_files(graph_exec_id)
@async_error_logged(swallow=True)
async def _process_node_output(
self,
output: ExecutionOutputEntry,
node_id: str,
graph_exec: GraphExecutionEntry,
log_metadata: LogMetadata,
nodes_input_masks: Optional[NodesInputMasks],
execution_queue: ExecutionQueue[NodeExecutionEntry],
) -> None:
"""Process a node's output, update its status, and enqueue next nodes.
Args:
output: The execution output entry to process
node_id: The ID of the node that produced the output
graph_exec: The graph execution entry
log_metadata: Logger metadata for consistent logging
nodes_input_masks: Optional map of node input overrides
execution_queue: Queue to add next executions to
"""
db_client = get_db_async_client()
log_metadata.debug(f"Enqueue nodes for {node_id}: {output}")
for next_execution in await _enqueue_next_nodes(
db_client=db_client,
node=output.node,
output=output.data,
user_id=graph_exec.user_id,
graph_exec_id=graph_exec.graph_exec_id,
graph_id=graph_exec.graph_id,
graph_version=graph_exec.graph_version,
log_metadata=log_metadata,
nodes_input_masks=nodes_input_masks,
execution_context=graph_exec.execution_context,
):
execution_queue.add(next_execution)
def _handle_agent_run_notif(
self,
db_client: "DatabaseManagerClient",
graph_exec: GraphExecutionEntry,
exec_stats: GraphExecutionStats,
):
metadata = db_client.get_graph_metadata(
graph_exec.graph_id, graph_exec.graph_version
)
outputs = db_client.get_node_executions(
graph_exec.graph_exec_id,
block_ids=[AgentOutputBlock().id],
)
named_outputs = [
{
key: value[0] if key == "name" else value
for key, value in output.output_data.items()
}
for output in outputs
]
queue_notification(
NotificationEventModel(
user_id=graph_exec.user_id,
type=NotificationType.AGENT_RUN,
data=AgentRunData(
outputs=named_outputs,
agent_name=metadata.name if metadata else "Unknown Agent",
credits_used=exec_stats.cost,
execution_time=exec_stats.walltime,
graph_id=graph_exec.graph_id,
node_count=exec_stats.node_count,
),
)
)
def _handle_insufficient_funds_notif(
self,
db_client: "DatabaseManagerClient",
user_id: str,
graph_id: str,
e: InsufficientBalanceError,
):
# Check if we've already sent a notification for this user+agent combo.
# We only send one notification per user per agent until they top up credits.
redis_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
try:
redis_client = redis.get_redis()
# SET NX returns True only if the key was newly set (didn't exist)
is_new_notification = redis_client.set(
redis_key,
"1",
nx=True,
ex=INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS,
)
if not is_new_notification:
# Already notified for this user+agent, skip all notifications
logger.debug(
f"Skipping duplicate insufficient funds notification for "
f"user={user_id}, graph={graph_id}"
)
return
except Exception as redis_error:
# If Redis fails, log and continue to send the notification
# (better to occasionally duplicate than to never notify)
logger.warning(
f"Failed to check/set insufficient funds notification flag in Redis: "
f"{redis_error}"
)
shortfall = abs(e.amount) - e.balance
metadata = db_client.get_graph_metadata(graph_id)
base_url = (
settings.config.frontend_base_url or settings.config.platform_base_url
)
# Queue user email notification
queue_notification(
NotificationEventModel(
user_id=user_id,
type=NotificationType.ZERO_BALANCE,
data=ZeroBalanceData(
current_balance=e.balance,
billing_page_link=f"{base_url}/profile/credits",
shortfall=shortfall,
agent_name=metadata.name if metadata else "Unknown Agent",
),
)
)
# Send Discord system alert
try:
user_email = db_client.get_user_email_by_id(user_id)
alert_message = (
f"❌ **Insufficient Funds Alert**\n"
f"User: {user_email or user_id}\n"
f"Agent: {metadata.name if metadata else 'Unknown Agent'}\n"
f"Current balance: ${e.balance / 100:.2f}\n"
f"Attempted cost: ${abs(e.amount) / 100:.2f}\n"
f"Shortfall: ${abs(shortfall) / 100:.2f}\n"
f"[View User Details]({base_url}/admin/spending?search={user_email})"
)
get_notification_manager_client().discord_system_alert(
alert_message, DiscordChannel.PRODUCT
)
except Exception as alert_error:
logger.error(
f"Failed to send insufficient funds Discord alert: {alert_error}"
)
def _handle_low_balance(
self,
db_client: "DatabaseManagerClient",
user_id: str,
current_balance: int,
transaction_cost: int,
):
"""Check and handle low balance scenarios after a transaction"""
LOW_BALANCE_THRESHOLD = settings.config.low_balance_threshold
balance_before = current_balance + transaction_cost
if (
current_balance < LOW_BALANCE_THRESHOLD
and balance_before >= LOW_BALANCE_THRESHOLD
):
base_url = (
settings.config.frontend_base_url or settings.config.platform_base_url
)
queue_notification(
NotificationEventModel(
user_id=user_id,
type=NotificationType.LOW_BALANCE,
data=LowBalanceData(
current_balance=current_balance,
billing_page_link=f"{base_url}/profile/credits",
),
)
)
try:
user_email = db_client.get_user_email_by_id(user_id)
alert_message = (
f"⚠️ **Low Balance Alert**\n"
f"User: {user_email or user_id}\n"
f"Balance dropped below ${LOW_BALANCE_THRESHOLD / 100:.2f}\n"
f"Current balance: ${current_balance / 100:.2f}\n"
f"Transaction cost: ${transaction_cost / 100:.2f}\n"
f"[View User Details]({base_url}/admin/spending?search={user_email})"
)
get_notification_manager_client().discord_system_alert(
alert_message, DiscordChannel.PRODUCT
)
except Exception as e:
logger.error(f"Failed to send low balance Discord alert: {e}")
class ExecutionManager(AppProcess):
def __init__(self):
super().__init__()
self.pool_size = settings.config.num_graph_workers
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
self.executor_id = str(uuid.uuid4())
self._executor = None
self._stop_consuming = None
self._cancel_thread = None
self._cancel_client = None
self._run_thread = None
self._run_client = None
self._execution_locks = {}
@property
def cancel_thread(self) -> threading.Thread:
if self._cancel_thread is None:
self._cancel_thread = threading.Thread(
target=lambda: self._consume_execution_cancel(),
daemon=True,
)
return self._cancel_thread
@property
def run_thread(self) -> threading.Thread:
if self._run_thread is None:
self._run_thread = threading.Thread(
target=lambda: self._consume_execution_run(),
daemon=True,
)
return self._run_thread
@property
def stop_consuming(self) -> threading.Event:
if self._stop_consuming is None:
self._stop_consuming = threading.Event()
return self._stop_consuming
@property
def executor(self) -> ThreadPoolExecutor:
if self._executor is None:
self._executor = ThreadPoolExecutor(
max_workers=self.pool_size,
initializer=init_worker,
)
return self._executor
@property
def cancel_client(self) -> SyncRabbitMQ:
if self._cancel_client is None:
self._cancel_client = SyncRabbitMQ(create_execution_queue_config())
return self._cancel_client
@property
def run_client(self) -> SyncRabbitMQ:
if self._run_client is None:
self._run_client = SyncRabbitMQ(create_execution_queue_config())
return self._run_client
def run(self):
logger.info(
f"[{self.service_name}] 🆔 Pod assigned executor_id: {self.executor_id}"
)
logger.info(f"[{self.service_name}] ⏳ Spawn max-{self.pool_size} workers...")
pool_size_gauge.set(self.pool_size)
self._update_prompt_metrics()
start_http_server(settings.config.execution_manager_port)
self.cancel_thread.start()
self.run_thread.start()
while True:
time.sleep(1e5)
@continuous_retry()
def _consume_execution_cancel(self):
if self.stop_consuming.is_set() and not self.active_graph_runs:
logger.info(
f"[{self.service_name}] Stop reconnecting cancel consumer since the service is cleaned up."
)
return
# Check if channel is closed and force reconnection if needed
if not self.cancel_client.is_ready:
self.cancel_client.disconnect()
self.cancel_client.connect()
cancel_channel = self.cancel_client.get_channel()
cancel_channel.basic_consume(
queue=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
on_message_callback=self._handle_cancel_message,
auto_ack=True,
)
logger.info(f"[{self.service_name}] ⏳ Starting cancel message consumer...")
cancel_channel.start_consuming()
if not self.stop_consuming.is_set() or self.active_graph_runs:
raise RuntimeError(
f"[{self.service_name}] ❌ cancel message consumer is stopped: {cancel_channel}"
)
logger.info(
f"[{self.service_name}] ✅ Cancel message consumer stopped gracefully"
)
@continuous_retry()
def _consume_execution_run(self):
# Long-running executions are handled by:
        # 1. A long consumer timeout (x-consumer-timeout) allows long-running agents
# 2. Enhanced connection settings (5 retries, 1s delay) for quick reconnection
# 3. Process monitoring ensures failed executors release messages back to queue
if self.stop_consuming.is_set():
logger.info(
f"[{self.service_name}] Stop reconnecting execution consumer since the service is cleaned up."
)
return
# Check if channel is closed and force reconnection if needed
if not self.run_client.is_ready:
self.run_client.disconnect()
self.run_client.connect()
run_channel = self.run_client.get_channel()
run_channel.basic_qos(prefetch_count=self.pool_size)
# Configure consumer for long-running graph executions
# auto_ack=False: Don't acknowledge messages until execution completes (prevents data loss)
run_channel.basic_consume(
queue=GRAPH_EXECUTION_QUEUE_NAME,
on_message_callback=self._handle_run_message,
auto_ack=False,
consumer_tag="graph_execution_consumer",
)
run_channel.confirm_delivery()
logger.info(f"[{self.service_name}] ⏳ Starting to consume run messages...")
run_channel.start_consuming()
if not self.stop_consuming.is_set():
raise RuntimeError(
f"[{self.service_name}] ❌ run message consumer is stopped: {run_channel}"
)
logger.info(f"[{self.service_name}] ✅ Run message consumer stopped gracefully")
@error_logged(swallow=True)
def _handle_cancel_message(
self,
_channel: BlockingChannel,
_method: Basic.Deliver,
_properties: BasicProperties,
body: bytes,
):
"""
Called whenever we receive a CANCEL message from the queue.
(With auto_ack=True, message is considered 'acked' automatically.)
"""
request = CancelExecutionEvent.model_validate_json(body)
graph_exec_id = request.graph_exec_id
if not graph_exec_id:
logger.warning(
f"[{self.service_name}] Cancel message missing 'graph_exec_id'"
)
return
if graph_exec_id not in self.active_graph_runs:
logger.debug(
f"[{self.service_name}] Cancel received for {graph_exec_id} but not active."
)
return
_, cancel_event = self.active_graph_runs[graph_exec_id]
logger.info(f"[{self.service_name}] Received cancel for {graph_exec_id}")
if not cancel_event.is_set():
cancel_event.set()
else:
logger.debug(
f"[{self.service_name}] Cancel already set for {graph_exec_id}"
)
def _handle_run_message(
self,
_channel: BlockingChannel,
method: Basic.Deliver,
_properties: BasicProperties,
body: bytes,
):
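        """
        Called for each RUN message consumed from the queue. Validates the
        message, enforces rate limits and the cluster-wide execution lock,
        then submits the graph execution to the worker pool. The message is
        only acked once execution finishes (auto_ack=False); it is rejected
        and requeued during shutdown or when the pool is full.
        """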
delivery_tag = method.delivery_tag
@func_retry
def _ack_message(reject: bool, requeue: bool):
"""
Acknowledge or reject the message based on execution status.
Args:
reject: Whether to reject the message
requeue: Whether to requeue the message
"""
# Connection can be lost, so always get a fresh channel
channel = self.run_client.get_channel()
if reject:
if requeue and settings.config.requeue_by_republishing:
# Send rejected message to back of queue using republishing
def _republish_to_back():
try:
# First republish to back of queue
self.run_client.publish_message(
routing_key=GRAPH_EXECUTION_ROUTING_KEY,
message=body.decode(), # publish_message expects string, not bytes
exchange=GRAPH_EXECUTION_EXCHANGE,
)
# Then reject without requeue (message already republished)
channel.basic_nack(delivery_tag, requeue=False)
logger.info("Message requeued to back of queue")
except Exception as e:
logger.error(
f"[{self.service_name}] Failed to requeue message to back: {e}"
)
# Fall back to traditional requeue on failure
channel.basic_nack(delivery_tag, requeue=True)
channel.connection.add_callback_threadsafe(_republish_to_back)
else:
# Traditional requeue (goes to front) or no requeue
channel.connection.add_callback_threadsafe(
lambda: channel.basic_nack(delivery_tag, requeue=requeue)
)
else:
channel.connection.add_callback_threadsafe(
lambda: channel.basic_ack(delivery_tag)
)
# Check if we're shutting down - reject new messages but keep connection alive
if self.stop_consuming.is_set():
logger.info(
f"[{self.service_name}] Rejecting new execution during shutdown"
)
_ack_message(reject=True, requeue=True)
return
# Check if we can accept more runs
self._cleanup_completed_runs()
if len(self.active_graph_runs) >= self.pool_size:
_ack_message(reject=True, requeue=True)
return
try:
graph_exec_entry = GraphExecutionEntry.model_validate_json(body)
except Exception as e:
logger.error(
f"[{self.service_name}] Could not parse run message: {e}, body={body}"
)
_ack_message(reject=True, requeue=False)
return
graph_exec_id = graph_exec_entry.graph_exec_id
user_id = graph_exec_entry.user_id
graph_id = graph_exec_entry.graph_id
root_exec_id = graph_exec_entry.execution_context.root_execution_id
parent_exec_id = graph_exec_entry.execution_context.parent_execution_id
logger.info(
f"[{self.service_name}] Received RUN for graph_exec_id={graph_exec_id}, user_id={user_id}, executor_id={self.executor_id}"
+ (f", root={root_exec_id}" if root_exec_id else "")
+ (f", parent={parent_exec_id}" if parent_exec_id else "")
)
# Check if root execution is already terminated (prevents orphaned child executions)
if root_exec_id and root_exec_id != graph_exec_id:
parent_exec = get_db_client().get_graph_execution_meta(
execution_id=root_exec_id,
user_id=user_id,
)
if parent_exec and parent_exec.status == ExecutionStatus.TERMINATED:
logger.info(
f"[{self.service_name}] Skipping execution {graph_exec_id} - parent {root_exec_id} is TERMINATED"
)
# Mark this child as terminated since parent was stopped
get_db_client().update_graph_execution_stats(
graph_exec_id=graph_exec_id,
status=ExecutionStatus.TERMINATED,
)
_ack_message(reject=False, requeue=False)
return
# Check user rate limit before processing
try:
# Only check executions from the last 24 hours for performance
current_running_count = get_db_client().get_graph_executions_count(
user_id=user_id,
graph_id=graph_id,
statuses=[ExecutionStatus.RUNNING],
created_time_gte=datetime.now(timezone.utc) - timedelta(hours=24),
)
if (
current_running_count
>= settings.config.max_concurrent_graph_executions_per_user
):
logger.warning(
f"[{self.service_name}] Rate limit exceeded for user {user_id} on graph {graph_id}: "
f"{current_running_count}/{settings.config.max_concurrent_graph_executions_per_user} running executions"
)
_ack_message(reject=True, requeue=True)
return
except Exception as e:
logger.error(
f"[{self.service_name}] Failed to check rate limit for user {user_id}: {e}, proceeding with execution"
)
# If rate limit check fails, proceed to avoid blocking executions
# Check for local duplicate execution first
if graph_exec_id in self.active_graph_runs:
logger.warning(
f"[{self.service_name}] Graph {graph_exec_id} already running locally; rejecting duplicate."
)
_ack_message(reject=True, requeue=True)
return
# Try to acquire cluster-wide execution lock
cluster_lock = ClusterLock(
redis=redis.get_redis(),
key=f"exec_lock:{graph_exec_id}",
owner_id=self.executor_id,
timeout=settings.config.cluster_lock_timeout,
)
current_owner = cluster_lock.try_acquire()
if current_owner != self.executor_id:
# Either someone else has it or Redis is unavailable
if current_owner is not None:
logger.warning(
f"[{self.service_name}] Graph {graph_exec_id} already running on pod {current_owner}, current executor_id={self.executor_id}"
)
_ack_message(reject=True, requeue=False)
else:
logger.warning(
f"[{self.service_name}] Could not acquire lock for {graph_exec_id} - Redis unavailable"
)
_ack_message(reject=True, requeue=True)
return
# From here on, any failure to start the run must release the lock and requeue the message
try:
self._execution_locks[graph_exec_id] = cluster_lock
logger.info(
f"[{self.service_name}] Successfully acquired cluster lock for {graph_exec_id}, executor_id={self.executor_id}"
)
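# Register the run in active_graph_runs so duplicates are rejected and cancel requests can find it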
cancel_event = threading.Event()
future = self.executor.submit(
execute_graph, graph_exec_entry, cancel_event, cluster_lock
)
self.active_graph_runs[graph_exec_id] = (future, cancel_event)
except Exception as e:
logger.warning(
f"[{self.service_name}] Failed to setup execution for {graph_exec_id}: {type(e).__name__}: {e}"
)
# Release cluster lock before requeue
cluster_lock.release()
if graph_exec_id in self._execution_locks:
del self._execution_locks[graph_exec_id]
_ack_message(reject=True, requeue=True)
return
self._update_prompt_metrics()
def _on_run_done(f: Future):
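"""Ack or requeue the message based on the run's outcome, then release the cluster lock."""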
logger.info(f"[{self.service_name}] Run completed for {graph_exec_id}")
try:
if exec_error := f.exception():
logger.error(
f"[{self.service_name}] Execution for {graph_exec_id} failed: {type(exec_error)} {exec_error}"
)
_ack_message(reject=True, requeue=True)
else:
_ack_message(reject=False, requeue=False)
except BaseException as e:
logger.exception(
f"[{self.service_name}] Error in run completion callback: {e}"
)
finally:
# Release the cluster-wide execution lock
if graph_exec_id in self._execution_locks:
logger.info(
f"[{self.service_name}] Releasing cluster lock for {graph_exec_id}, executor_id={self.executor_id}"
)
self._execution_locks[graph_exec_id].release()
del self._execution_locks[graph_exec_id]
self._cleanup_completed_runs()
future.add_done_callback(_on_run_done)
def _cleanup_completed_runs(self) -> list[str]:
"""Remove completed futures from active_graph_runs and update metrics"""
completed_runs = []
for graph_exec_id, (future, _) in self.active_graph_runs.items():
if future.done():
completed_runs.append(graph_exec_id)
for geid in completed_runs:
logger.info(f"[{self.service_name}] ✅ Cleaned up completed run {geid}")
self.active_graph_runs.pop(geid, None)
self._update_prompt_metrics()
return completed_runs
def _update_prompt_metrics(self):
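"""Refresh the Prometheus gauges for active run count and pool utilization."""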
active_count = len(self.active_graph_runs)
active_runs_gauge.set(active_count)
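# While draining for shutdown, report full utilization since no new work is accepted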
if self.stop_consuming.is_set():
utilization_gauge.set(1.0)
else:
utilization_gauge.set(active_count / self.pool_size)
def _stop_message_consumers(
self, thread: threading.Thread, client: SyncRabbitMQ, prefix: str
):
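"""Ask a consumer thread to stop via a connection-threadsafe callback, then disconnect its client."""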
try:
channel = client.get_channel()
channel.connection.add_callback_threadsafe(lambda: channel.stop_consuming())
# Thread.join() does not raise on timeout, so detect a hung thread via is_alive()
thread.join(timeout=300)
if thread.is_alive():
logger.error(
f"{prefix} ⚠️ Run thread did not finish in time, forcing disconnect"
)
client.disconnect()
logger.info(f"{prefix} ✅ Run client disconnected")
except Exception as e:
logger.error(f"{prefix} ⚠️ Error disconnecting run client: {type(e)} {e}")
def cleanup(self):
"""Override cleanup to implement graceful shutdown with active execution waiting."""
prefix = f"[{self.service_name}][on_graph_executor_stop {os.getpid()}]"
logger.info(f"{prefix} 🧹 Starting graceful shutdown...")
# Signal the consumer thread to stop (thread-safe)
try:
self.stop_consuming.set()
run_channel = self.run_client.get_channel()
run_channel.connection.add_callback_threadsafe(
lambda: run_channel.stop_consuming()
)
logger.info(f"{prefix} ✅ Exec consumer has been signaled to stop")
except Exception as e:
logger.error(f"{prefix} ⚠️ Error signaling consumer to stop: {type(e)} {e}")
# Wait for active executions to complete
if self.active_graph_runs:
logger.info(
f"{prefix} ⏳ Waiting for {len(self.active_graph_runs)} active executions to complete..."
)
max_wait = GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS
wait_interval = 5
waited = 0
while waited < max_wait:
self._cleanup_completed_runs()
if not self.active_graph_runs:
logger.info(f"{prefix} ✅ All active executions completed")
break
else:
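# Log only the first segment of each exec id to keep the line short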
ids = [k.split("-")[0] for k in self.active_graph_runs.keys()]
logger.info(
f"{prefix} ⏳ Still waiting for {len(self.active_graph_runs)} executions: {ids}"
)
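# Refresh the cluster locks of in-flight runs so they don't expire mid-shutdown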
for graph_exec_id in self.active_graph_runs:
if lock := self._execution_locks.get(graph_exec_id):
lock.refresh()
time.sleep(wait_interval)
waited += wait_interval
if self.active_graph_runs:
logger.error(
f"{prefix} ⚠️ {len(self.active_graph_runs)} executions still running after {max_wait}s"
)
else:
logger.info(f"{prefix} ✅ All executions completed gracefully")
# Shutdown the executor
try:
self.executor.shutdown(cancel_futures=True, wait=False)
logger.info(f"{prefix} ✅ Executor shutdown completed")
except Exception as e:
logger.error(f"{prefix} ⚠️ Error during executor shutdown: {type(e)} {e}")
# Release remaining execution locks
try:
for lock in self._execution_locks.values():
lock.release()
self._execution_locks.clear()
logger.info(f"{prefix} ✅ Released execution locks")
except Exception as e:
logger.warning(f"{prefix} ⚠️ Failed to release all locks: {e}")
# Stop and disconnect the run and cancel message consumers
self._stop_message_consumers(
self.run_thread,
self.run_client,
prefix + " [run-consumer]",
)
self._stop_message_consumers(
self.cancel_thread,
self.cancel_client,
prefix + " [cancel-consumer]",
)
logger.info(f"{prefix} ✅ Finished GraphExec cleanup")
super().cleanup()
# ------- UTILITIES ------- #
def get_db_client() -> "DatabaseManagerClient":
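"""Return the DatabaseManager sync client."""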
return get_database_manager_client()
def get_db_async_client() -> "DatabaseManagerAsyncClient":
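"""Return the DatabaseManager async client."""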
return get_database_manager_async_client()
@func_retry
async def send_async_execution_update(
entry: GraphExecution | NodeExecutionResult | None,
) -> None:
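"""Publish an execution update to the async event bus; a None entry is a no-op."""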
if entry is None:
return
await get_async_execution_event_bus().publish(entry)
@func_retry
def send_execution_update(entry: GraphExecution | NodeExecutionResult | None):
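"""Publish an execution update to the sync event bus; a None entry is a no-op."""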
if entry is None:
return
return get_execution_event_bus().publish(entry)
async def async_update_node_execution_status(
db_client: "DatabaseManagerAsyncClient",
exec_id: str,
status: ExecutionStatus,
execution_data: BlockInput | None = None,
stats: dict[str, Any] | None = None,
) -> NodeExecutionResult:
"""Sets status and fetches+broadcasts the latest state of the node execution"""
exec_update = await db_client.update_node_execution_status(
exec_id, status, execution_data, stats
)
await send_async_execution_update(exec_update)
return exec_update
def update_node_execution_status(
db_client: "DatabaseManagerClient",
exec_id: str,
status: ExecutionStatus,
execution_data: BlockInput | None = None,
stats: dict[str, Any] | None = None,
) -> NodeExecutionResult:
"""Sets status and fetches+broadcasts the latest state of the node execution"""
exec_update = db_client.update_node_execution_status(
exec_id, status, execution_data, stats
)
send_execution_update(exec_update)
return exec_update
async def async_update_graph_execution_state(
db_client: "DatabaseManagerAsyncClient",
graph_exec_id: str,
status: ExecutionStatus | None = None,
stats: GraphExecutionStats | None = None,
) -> GraphExecution | None:
"""Sets status and fetches+broadcasts the latest state of the graph execution"""
graph_update = await db_client.update_graph_execution_stats(
graph_exec_id, status, stats
)
if graph_update:
await send_async_execution_update(graph_update)
else:
logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
return graph_update
def update_graph_execution_state(
db_client: "DatabaseManagerClient",
graph_exec_id: str,
status: ExecutionStatus | None = None,
stats: GraphExecutionStats | None = None,
) -> GraphExecution | None:
"""Sets status and fetches+broadcasts the latest state of the graph execution"""
graph_update = db_client.update_graph_execution_stats(graph_exec_id, status, stats)
if graph_update:
send_execution_update(graph_update)
else:
logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
return graph_update
@asynccontextmanager
async def synchronized(key: str, timeout: int = settings.config.cluster_lock_timeout):
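"""
Cluster-wide critical section backed by a Redis lock on `lock:{key}`.

Minimal usage sketch (the key name is illustrative):
    async with synchronized(f"spend:{user_id}"):
        ...  # at most one holder across the cluster at a time
"""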
r = await redis.get_redis_async()
lock: AsyncRedisLock = r.lock(f"lock:{key}", timeout=timeout)
try:
await lock.acquire()
yield
finally:
if await lock.locked() and await lock.owned():
try:
await lock.release()
except Exception as e:
logger.warning(f"Failed to release lock for key {key}: {e}")
def increment_execution_count(user_id: str) -> int:
"""
Increment the execution count for a given user;
this is used to charge the user for the execution cost.
"""
r = redis.get_redis()
k = f"uec:{user_id}" # User Execution Count global key
counter = cast(int, r.incr(k))
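# INCR created the key on the first call; attach a TTL so the counter resets after the window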
if counter == 1:
r.expire(k, settings.config.execution_counter_expiration_time)
return counter