Implements persistent User Workspace storage for CoPilot, enabling
blocks to save and retrieve files across sessions. Files are stored in
session-scoped virtual paths (`/sessions/{session_id}/`).
Fixes SECRT-1833
### Changes 🏗️
**Database & Storage:**
- Add `UserWorkspace` and `UserWorkspaceFile` Prisma models
- Implement `WorkspaceStorageBackend` abstraction (GCS for cloud, local
filesystem for self-hosted)
- Add `workspace_id` and `session_id` fields to `ExecutionContext`
**Backend API:**
- Add REST endpoints: `GET/POST /api/workspace/files`, `GET/DELETE
/api/workspace/files/{id}`, `GET /api/workspace/files/{id}/download`
- Add CoPilot tools: `list_workspace_files`, `read_workspace_file`,
`write_workspace_file`
- Integrate workspace storage into `store_media_file()`, which now returns
`workspace://file-id` references
**Block Updates:**
- Refactor all file-handling blocks to use a unified `ExecutionContext`
parameter
- Update media-generating blocks to persist outputs to workspace
(AIImageGenerator, AIImageCustomizer, FluxKontext, TalkingHead, FAL
video, Bannerbear, etc.)
**Frontend:**
- Render `workspace://` image references in chat via proxy endpoint
- Add "AI cannot see this image" overlay indicator
**CoPilot Context Mapping:**
- Each CoPilot session maps to one agent (`graph_id`) and one run (`graph_exec_id`)
- Workspace files are scoped to `/sessions/{session_id}/` (see the usage sketch below)
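
A rough client-side sketch of the new endpoints, to make the surface area above concrete. The routes are the ones listed under **Backend API**; the host, auth header, and JSON response shape are illustrative assumptions, not the final API:

```python
import asyncio

import httpx

BASE_URL = "https://platform.example.com"           # placeholder host
AUTH_HEADERS = {"Authorization": "Bearer <token>"}  # placeholder auth


async def demo_workspace_files() -> None:
    """List, fetch, and delete a workspace file via the new REST endpoints."""
    async with httpx.AsyncClient(base_url=BASE_URL, headers=AUTH_HEADERS) as client:
        listing = await client.get("/api/workspace/files")
        listing.raise_for_status()
        files = listing.json()  # assumed shape: [{"id": ..., "name": ...}, ...]
        if not files:
            return
        file_id = files[0]["id"]
        (await client.get(f"/api/workspace/files/{file_id}")).raise_for_status()
        (await client.delete(f"/api/workspace/files/{file_id}")).raise_for_status()


if __name__ == "__main__":
    asyncio.run(demo_workspace_files())
```

Blocks themselves do not call these routes directly: they go through `store_media_file()`, receive a `workspace://file-id` reference back, and the chat frontend resolves that reference through the download/proxy route.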
### Checklist 📋
#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [ ] I have tested my changes according to the test plan:
- [ ] Create CoPilot session, generate image with AIImageGeneratorBlock
- [ ] Verify image returns `workspace://file-id` (not base64)
- [ ] Verify image renders in chat with visibility indicator
- [ ] Verify workspace files persist across sessions
- [ ] Test list/read/write workspace files via CoPilot tools
- [ ] Test local storage backend for self-hosted deployments
#### For configuration changes:
- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)
🤖 Generated with [Claude Code](https://claude.ai/code)
<!-- CURSOR_SUMMARY -->
---
> [!NOTE]
> **Medium Risk**
> Introduces a new persistent file-storage surface area (DB tables,
storage backends, download API, and chat tools) and rewires
`store_media_file()`/block execution context across many blocks, so
regressions could impact file handling, access control, or storage
costs.
>
> **Overview**
> Adds a **persistent per-user Workspace** (new
`UserWorkspace`/`UserWorkspaceFile` models plus `WorkspaceManager` +
`WorkspaceStorageBackend` with GCS/local implementations) and wires it
into the API via a new `/api/workspace/files/{file_id}/download` route
(including header-sanitized `Content-Disposition`) and shutdown
lifecycle hooks.
>
> Extends `ExecutionContext` to carry execution identity +
`workspace_id`/`session_id`, updates executor tooling to clone
node-specific contexts, and updates `run_block` (CoPilot) to create a
session-scoped workspace and synthetic graph/run/node IDs.
>
> Refactors `store_media_file()` to require `execution_context` +
`return_format` and to support `workspace://` references; migrates many
media/file-handling blocks and related tests to the new API and to
persist generated media as `workspace://...` (or fall back to data URIs
outside CoPilot), and adds CoPilot chat tools for
listing/reading/writing/deleting workspace files with safeguards against
context bloat.
>
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
6abc70f793. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->
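
For reviewers tracing how the chat UI turns block output into an image URL: a minimal sketch of resolving the `workspace://` references emitted by media blocks into the new download route. The reference format and the route come from the changes above; the host and helper name are illustrative:

```python
WORKSPACE_SCHEME = "workspace://"


def workspace_ref_to_download_url(
    ref: str, base_url: str = "https://platform.example.com"  # placeholder host
) -> str:
    """Turn 'workspace://<file-id>' into its /api/workspace/files/{id}/download URL."""
    if not ref.startswith(WORKSPACE_SCHEME):
        raise ValueError(f"Not a workspace reference: {ref!r}")
    file_id = ref[len(WORKSPACE_SCHEME):]
    return f"{base_url}/api/workspace/files/{file_id}/download"


print(workspace_ref_to_download_url("workspace://123e4567"))
```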
---------
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
import asyncio
import logging
import os
import threading
import time
import uuid
from collections import defaultdict
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast

import sentry_sdk
from pika.adapters.blocking_connection import BlockingChannel
from pika.spec import Basic, BasicProperties
from prometheus_client import Gauge, start_http_server
from redis.asyncio.lock import Lock as AsyncRedisLock

from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.io import AgentOutputBlock
from backend.data import redis_client as redis
from backend.data.block import (
    BlockInput,
    BlockOutput,
    BlockOutputEntry,
    BlockSchema,
    get_block,
)
from backend.data.credit import UsageTransactionMetadata
from backend.data.dynamic_fields import parse_execution_output
from backend.data.execution import (
    ExecutionContext,
    ExecutionQueue,
    ExecutionStatus,
    GraphExecution,
    GraphExecutionEntry,
    NodeExecutionEntry,
    NodeExecutionResult,
    NodesInputMasks,
)
from backend.data.graph import Link, Node
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.notifications import (
    AgentRunData,
    LowBalanceData,
    NotificationEventModel,
    NotificationType,
    ZeroBalanceData,
)
from backend.data.rabbitmq import SyncRabbitMQ
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.notifications.notifications import queue_notification
from backend.util import json
from backend.util.clients import (
    get_async_execution_event_bus,
    get_database_manager_async_client,
    get_database_manager_client,
    get_execution_event_bus,
    get_notification_manager_client,
)
from backend.util.decorator import (
    async_error_logged,
    async_time_measured,
    error_logged,
    time_measured,
)
from backend.util.exceptions import InsufficientBalanceError, ModerationError
from backend.util.file import clean_exec_files
from backend.util.logging import TruncatedLogger, configure_logging
from backend.util.metrics import DiscordChannel
from backend.util.process import AppProcess, set_service_name
from backend.util.retry import (
    continuous_retry,
    func_retry,
    send_rate_limited_discord_alert,
)
from backend.util.settings import Settings

from .activity_status_generator import generate_activity_status_for_execution
from .automod.manager import automod_manager
from .cluster_lock import ClusterLock
from .utils import (
    GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
    GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
    GRAPH_EXECUTION_EXCHANGE,
    GRAPH_EXECUTION_QUEUE_NAME,
    GRAPH_EXECUTION_ROUTING_KEY,
    CancelExecutionEvent,
    ExecutionOutputEntry,
    LogMetadata,
    NodeExecutionProgress,
    block_usage_cost,
    create_execution_queue_config,
    execution_usage_cost,
    validate_exec,
)

if TYPE_CHECKING:
    from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient


_logger = logging.getLogger(__name__)
logger = TruncatedLogger(_logger, prefix="[GraphExecutor]")
settings = Settings()

active_runs_gauge = Gauge(
    "execution_manager_active_runs", "Number of active graph runs"
)
pool_size_gauge = Gauge(
    "execution_manager_pool_size", "Maximum number of graph workers"
)
utilization_gauge = Gauge(
    "execution_manager_utilization_ratio",
    "Ratio of active graph runs to max graph workers",
)

# Redis key prefix for tracking insufficient funds Discord notifications.
# We only send one notification per user per agent until they top up credits.
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX = "insufficient_funds_discord_notified"
# TTL for the notification flag (30 days) - acts as a fallback cleanup
INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS = 30 * 24 * 60 * 60


async def clear_insufficient_funds_notifications(user_id: str) -> int:
    """
    Clear all insufficient funds notification flags for a user.

    This should be called when a user tops up their credits, allowing
    Discord notifications to be sent again if they run out of funds.

    Args:
        user_id: The user ID to clear notifications for.

    Returns:
        The number of keys that were deleted.
    """
    try:
        redis_client = await redis.get_redis_async()
        pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
        keys = [key async for key in redis_client.scan_iter(match=pattern)]
        if keys:
            return await redis_client.delete(*keys)
        return 0
    except Exception as e:
        logger.warning(
            f"Failed to clear insufficient funds notification flags for user "
            f"{user_id}: {e}"
        )
        return 0


# Thread-local storage for ExecutionProcessor instances
_tls = threading.local()


def init_worker():
    """Initialize ExecutionProcessor instance in thread-local storage"""
    _tls.processor = ExecutionProcessor()
    _tls.processor.on_graph_executor_start()


def execute_graph(
    graph_exec_entry: "GraphExecutionEntry",
    cancel_event: threading.Event,
    cluster_lock: ClusterLock,
):
    """Execute graph using thread-local ExecutionProcessor instance"""
    processor: ExecutionProcessor = _tls.processor
    return processor.on_graph_execution(graph_exec_entry, cancel_event, cluster_lock)


T = TypeVar("T")


async def execute_node(
    node: Node,
    data: NodeExecutionEntry,
    execution_processor: "ExecutionProcessor",
    execution_stats: NodeExecutionStats | None = None,
    nodes_input_masks: Optional[NodesInputMasks] = None,
    nodes_to_skip: Optional[set[str]] = None,
) -> BlockOutput:
    """
    Execute a node in the graph. This will trigger a block execution on a node,
    persist the execution result, and return the subsequent node to be executed.

    Args:
        node: The node to execute.
        data: The execution data for executing the current node.
        execution_processor: The processor providing the credentials manager.
        execution_stats: The execution statistics to be updated.

    Returns:
        The subsequent node to be enqueued, or None if there is no subsequent node.
    """
    user_id = data.user_id
    graph_exec_id = data.graph_exec_id
    graph_id = data.graph_id
    graph_version = data.graph_version
    node_exec_id = data.node_exec_id
    node_id = data.node_id
    node_block = node.block
    execution_context = data.execution_context
    creds_manager = execution_processor.creds_manager

    log_metadata = LogMetadata(
        logger=_logger,
        user_id=user_id,
        graph_eid=graph_exec_id,
        graph_id=graph_id,
        node_eid=node_exec_id,
        node_id=node_id,
        block_name=node_block.name,
    )

    # Sanity check: validate the execution input.
    input_data, error = validate_exec(node, data.inputs, resolve_input=False)
    if input_data is None:
        log_metadata.error(f"Skip execution, input validation error: {error}")
        yield "error", error
        return

    # Re-shape the input data for the agent block.
    # AgentExecutorBlock specially separates the node input_data & its input_default.
    if isinstance(node_block, AgentExecutorBlock):
        _input_data = AgentExecutorBlock.Input(**node.input_default)
        _input_data.inputs = input_data
        if nodes_input_masks:
            _input_data.nodes_input_masks = nodes_input_masks
        _input_data.user_id = user_id
        input_data = _input_data.model_dump()
        data.inputs = input_data

    # Execute the node
    input_data_str = json.dumps(input_data)
    input_size = len(input_data_str)
    log_metadata.debug("Executed node with input", input=input_data_str)

    # Create node-specific execution context to avoid race conditions
    # (multiple nodes can execute concurrently and would otherwise mutate shared state)
    execution_context = execution_context.model_copy(
        update={"node_id": node_id, "node_exec_id": node_exec_id}
    )

    # Inject extra execution arguments for the blocks via kwargs
    # Keep individual kwargs for backwards compatibility with existing blocks
    extra_exec_kwargs: dict = {
        "graph_id": graph_id,
        "graph_version": graph_version,
        "node_id": node_id,
        "graph_exec_id": graph_exec_id,
        "node_exec_id": node_exec_id,
        "user_id": user_id,
        "execution_context": execution_context,
        "execution_processor": execution_processor,
        "nodes_to_skip": nodes_to_skip or set(),
    }

    # Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
    # changes during execution. ⚠️ This means a set of credentials can only be used by
    # one (running) block at a time; simultaneous execution of blocks using the same
    # credentials is not supported.
    creds_locks: list[AsyncRedisLock] = []
    input_model = cast(type[BlockSchema], node_block.input_schema)

    # Handle regular credentials fields
    for field_name, input_type in input_model.get_credentials_fields().items():
        credentials_meta = input_type(**input_data[field_name])
        credentials, lock = await creds_manager.acquire(user_id, credentials_meta.id)
        creds_locks.append(lock)
        extra_exec_kwargs[field_name] = credentials

    # Handle auto-generated credentials (e.g., from GoogleDriveFileInput)
    for kwarg_name, info in input_model.get_auto_credentials_fields().items():
        field_name = info["field_name"]
        field_data = input_data.get(field_name)
        if field_data and isinstance(field_data, dict):
            # Check if _credentials_id key exists in the field data
            if "_credentials_id" in field_data:
                cred_id = field_data["_credentials_id"]
                if cred_id:
                    # Credential ID provided - acquire credentials
                    provider = info.get("config", {}).get(
                        "provider", "external service"
                    )
                    file_name = field_data.get("name", "selected file")
                    try:
                        credentials, lock = await creds_manager.acquire(
                            user_id, cred_id
                        )
                        creds_locks.append(lock)
                        extra_exec_kwargs[kwarg_name] = credentials
                    except ValueError:
                        # Credential was deleted or doesn't exist
                        raise ValueError(
                            f"Authentication expired for '{file_name}' in field '{field_name}'. "
                            f"The saved {provider.capitalize()} credentials no longer exist. "
                            f"Please re-select the file to re-authenticate."
                        )
                # else: _credentials_id is explicitly None, skip credentials (for chained data)
            else:
                # _credentials_id key missing entirely - this is an error
                provider = info.get("config", {}).get("provider", "external service")
                file_name = field_data.get("name", "selected file")
                raise ValueError(
                    f"Authentication missing for '{file_name}' in field '{field_name}'. "
                    f"Please re-select the file to authenticate with {provider.capitalize()}."
                )

    output_size = 0

    # sentry tracking nonsense to get user counts for blocks because isolation scopes don't work :(
    scope = sentry_sdk.get_current_scope()

    # save the tags
    original_user = scope._user
    original_tags = dict(scope._tags) if scope._tags else {}
    # Set user ID for error tracking
    scope.set_user({"id": user_id})

    scope.set_tag("graph_id", graph_id)
    scope.set_tag("node_id", node_id)
    scope.set_tag("block_name", node_block.name)
    scope.set_tag("block_id", node_block.id)
    for k, v in execution_context.model_dump().items():
        scope.set_tag(f"execution_context.{k}", v)

    try:
        async for output_name, output_data in node_block.execute(
            input_data, **extra_exec_kwargs
        ):
            output_data = json.to_dict(output_data)
            output_size += len(json.dumps(output_data))
            log_metadata.debug("Node produced output", **{output_name: output_data})
            yield output_name, output_data
    except Exception as ex:
        # Capture exception WITH context still set before restoring scope
        sentry_sdk.capture_exception(error=ex, scope=scope)
        sentry_sdk.flush()  # Ensure it's sent before we restore scope
        # Re-raise to maintain normal error flow
        raise
    finally:
        # Ensure all credentials are released even if execution fails
        for creds_lock in creds_locks:
            if (
                creds_lock
                and (await creds_lock.locked())
                and (await creds_lock.owned())
            ):
                try:
                    await creds_lock.release()
                except Exception as e:
                    log_metadata.error(f"Failed to release credentials lock: {e}")

        # Update execution stats
        if execution_stats is not None:
            execution_stats += node_block.execution_stats
            execution_stats.input_size = input_size
            execution_stats.output_size = output_size

        # Restore scope AFTER error has been captured
        scope._user = original_user
        scope._tags = original_tags


async def _enqueue_next_nodes(
    db_client: "DatabaseManagerAsyncClient",
    node: Node,
    output: BlockOutputEntry,
    user_id: str,
    graph_exec_id: str,
    graph_id: str,
    graph_version: int,
    log_metadata: LogMetadata,
    nodes_input_masks: Optional[NodesInputMasks],
    execution_context: ExecutionContext,
) -> list[NodeExecutionEntry]:
    async def add_enqueued_execution(
        node_exec_id: str, node_id: str, block_id: str, data: BlockInput
    ) -> NodeExecutionEntry:
        await async_update_node_execution_status(
            db_client=db_client,
            exec_id=node_exec_id,
            status=ExecutionStatus.QUEUED,
            execution_data=data,
        )
        return NodeExecutionEntry(
            user_id=user_id,
            graph_exec_id=graph_exec_id,
            graph_id=graph_id,
            graph_version=graph_version,
            node_exec_id=node_exec_id,
            node_id=node_id,
            block_id=block_id,
            inputs=data,
            execution_context=execution_context,
        )

    async def register_next_executions(node_link: Link) -> list[NodeExecutionEntry]:
        try:
            return await _register_next_executions(node_link)
        except Exception as e:
            log_metadata.exception(f"Failed to register next executions: {e}")
            return []

    async def _register_next_executions(node_link: Link) -> list[NodeExecutionEntry]:
        enqueued_executions = []
        next_output_name = node_link.source_name
        next_input_name = node_link.sink_name
        next_node_id = node_link.sink_id

        output_name, _ = output
        next_data = parse_execution_output(
            output, next_output_name, next_node_id, next_input_name
        )
        if next_data is None and output_name != next_output_name:
            return enqueued_executions
        next_node = await db_client.get_node(next_node_id)

        # Multiple nodes can register the same next node, so this has to be atomic
        # to avoid the same execution being enqueued multiple times,
        # or the same input being consumed multiple times.
        async with synchronized(f"upsert_input-{next_node_id}-{graph_exec_id}"):
            # Add output data to the earliest incomplete execution, or create a new one.
            next_node_exec, next_node_input = await db_client.upsert_execution_input(
                node_id=next_node_id,
                graph_exec_id=graph_exec_id,
                input_name=next_input_name,
                input_data=next_data,
            )
            next_node_exec_id = next_node_exec.node_exec_id
            await send_async_execution_update(next_node_exec)

            # Complete missing static input pins data using the last execution input.
            static_link_names = {
                link.sink_name
                for link in next_node.input_links
                if link.is_static and link.sink_name not in next_node_input
            }
            if static_link_names and (
                latest_execution := await db_client.get_latest_node_execution(
                    next_node_id, graph_exec_id
                )
            ):
                for name in static_link_names:
                    next_node_input[name] = latest_execution.input_data.get(name)

            # Apply node input overrides
            node_input_mask = None
            if nodes_input_masks and (
                node_input_mask := nodes_input_masks.get(next_node.id)
            ):
                next_node_input.update(node_input_mask)

            # Validate the input data for the next node.
            next_node_input, validation_msg = validate_exec(next_node, next_node_input)
            suffix = f"{next_output_name}>{next_input_name}~{next_node_exec_id}:{validation_msg}"

            # Incomplete input data, skip queueing the execution.
            if not next_node_input:
                log_metadata.info(f"Skipped queueing {suffix}")
                return enqueued_executions

            # Input is complete, enqueue the execution.
            log_metadata.info(f"Enqueued {suffix}")
            enqueued_executions.append(
                await add_enqueued_execution(
                    node_exec_id=next_node_exec_id,
                    node_id=next_node_id,
                    block_id=next_node.block_id,
                    data=next_node_input,
                )
            )

            # Next execution stops here if the link is not static.
            if not node_link.is_static:
                return enqueued_executions

            # If the link is static, there could be some incomplete executions waiting for it.
            # Load them, complete the missing input data, and try to re-enqueue them.
            for iexec in await db_client.get_node_executions(
                node_id=next_node_id,
                graph_exec_id=graph_exec_id,
                statuses=[ExecutionStatus.INCOMPLETE],
            ):
                idata = iexec.input_data
                ineid = iexec.node_exec_id

                static_link_names = {
                    link.sink_name
                    for link in next_node.input_links
                    if link.is_static and link.sink_name not in idata
                }
                for input_name in static_link_names:
                    idata[input_name] = next_node_input[input_name]

                # Apply node input overrides
                if node_input_mask:
                    idata.update(node_input_mask)

                idata, msg = validate_exec(next_node, idata)
                suffix = f"{next_output_name}>{next_input_name}~{ineid}:{msg}"
                if not idata:
                    log_metadata.info(f"Enqueueing static-link skipped: {suffix}")
                    continue
                log_metadata.info(f"Enqueueing static-link execution {suffix}")
                enqueued_executions.append(
                    await add_enqueued_execution(
                        node_exec_id=iexec.node_exec_id,
                        node_id=next_node_id,
                        block_id=next_node.block_id,
                        data=idata,
                    )
                )
            return enqueued_executions

    return [
        execution
        for link in node.output_links
        for execution in await register_next_executions(link)
    ]


class ExecutionProcessor:
|
||
"""
|
||
This class contains event handlers for the process pool executor events.
|
||
|
||
The main events are:
|
||
on_graph_executor_start: Initialize the process that executes the graph.
|
||
on_graph_execution: Execution logic for a graph.
|
||
on_node_execution: Execution logic for a node.
|
||
|
||
The execution flow:
|
||
1. Graph execution request is added to the queue.
|
||
2. Graph executor loop picks the request from the queue.
|
||
3. Graph executor loop submits the graph execution request to the executor pool.
|
||
[on_graph_execution]
|
||
4. Graph executor initialize the node execution queue.
|
||
5. Graph executor adds the starting nodes to the node execution queue.
|
||
6. Graph executor waits for all nodes to be executed.
|
||
[on_node_execution]
|
||
7. Node executor picks the node execution request from the queue.
|
||
8. Node executor executes the node.
|
||
9. Node executor enqueues the next executed nodes to the node execution queue.
|
||
"""
|
||
|
||
@async_error_logged(swallow=True)
|
||
async def on_node_execution(
|
||
self,
|
||
node_exec: NodeExecutionEntry,
|
||
node_exec_progress: NodeExecutionProgress,
|
||
nodes_input_masks: Optional[NodesInputMasks],
|
||
graph_stats_pair: tuple[GraphExecutionStats, threading.Lock],
|
||
nodes_to_skip: Optional[set[str]] = None,
|
||
) -> NodeExecutionStats:
|
||
log_metadata = LogMetadata(
|
||
logger=_logger,
|
||
user_id=node_exec.user_id,
|
||
graph_eid=node_exec.graph_exec_id,
|
||
graph_id=node_exec.graph_id,
|
||
node_eid=node_exec.node_exec_id,
|
||
node_id=node_exec.node_id,
|
||
block_name=b.name if (b := get_block(node_exec.block_id)) else "-",
|
||
)
|
||
db_client = get_db_async_client()
|
||
node = await db_client.get_node(node_exec.node_id)
|
||
execution_stats = NodeExecutionStats()
|
||
|
||
timing_info, status = await self._on_node_execution(
|
||
node=node,
|
||
node_exec=node_exec,
|
||
node_exec_progress=node_exec_progress,
|
||
stats=execution_stats,
|
||
db_client=db_client,
|
||
log_metadata=log_metadata,
|
||
nodes_input_masks=nodes_input_masks,
|
||
nodes_to_skip=nodes_to_skip,
|
||
)
|
||
if isinstance(status, BaseException):
|
||
raise status
|
||
|
||
execution_stats.walltime = timing_info.wall_time
|
||
execution_stats.cputime = timing_info.cpu_time
|
||
|
||
graph_stats, graph_stats_lock = graph_stats_pair
|
||
with graph_stats_lock:
|
||
graph_stats.node_count += 1 + execution_stats.extra_steps
|
||
graph_stats.nodes_cputime += execution_stats.cputime
|
||
graph_stats.nodes_walltime += execution_stats.walltime
|
||
graph_stats.cost += execution_stats.extra_cost
|
||
if isinstance(execution_stats.error, Exception):
|
||
graph_stats.node_error_count += 1
|
||
|
||
node_error = execution_stats.error
|
||
node_stats = execution_stats.model_dump()
|
||
if node_error and not isinstance(node_error, str):
|
||
node_stats["error"] = str(node_error) or node_stats.__class__.__name__
|
||
|
||
await async_update_node_execution_status(
|
||
db_client=db_client,
|
||
exec_id=node_exec.node_exec_id,
|
||
status=status,
|
||
stats=node_stats,
|
||
)
|
||
await async_update_graph_execution_state(
|
||
db_client=db_client,
|
||
graph_exec_id=node_exec.graph_exec_id,
|
||
stats=graph_stats,
|
||
)
|
||
|
||
return execution_stats
|
||
|
||
@async_time_measured
|
||
async def _on_node_execution(
|
||
self,
|
||
node: Node,
|
||
node_exec: NodeExecutionEntry,
|
||
node_exec_progress: NodeExecutionProgress,
|
||
stats: NodeExecutionStats,
|
||
db_client: "DatabaseManagerAsyncClient",
|
||
log_metadata: LogMetadata,
|
||
nodes_input_masks: Optional[NodesInputMasks] = None,
|
||
nodes_to_skip: Optional[set[str]] = None,
|
||
) -> ExecutionStatus:
|
||
status = ExecutionStatus.RUNNING
|
||
|
||
async def persist_output(output_name: str, output_data: Any) -> None:
|
||
await db_client.upsert_execution_output(
|
||
node_exec_id=node_exec.node_exec_id,
|
||
output_name=output_name,
|
||
output_data=output_data,
|
||
)
|
||
if exec_update := await db_client.get_node_execution(
|
||
node_exec.node_exec_id
|
||
):
|
||
await send_async_execution_update(exec_update)
|
||
|
||
node_exec_progress.add_output(
|
||
ExecutionOutputEntry(
|
||
node=node,
|
||
node_exec_id=node_exec.node_exec_id,
|
||
data=(output_name, output_data),
|
||
)
|
||
)
|
||
|
||
try:
|
||
log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
|
||
await async_update_node_execution_status(
|
||
db_client=db_client,
|
||
exec_id=node_exec.node_exec_id,
|
||
status=ExecutionStatus.RUNNING,
|
||
)
|
||
|
||
async for output_name, output_data in execute_node(
|
||
node=node,
|
||
data=node_exec,
|
||
execution_processor=self,
|
||
execution_stats=stats,
|
||
nodes_input_masks=nodes_input_masks,
|
||
nodes_to_skip=nodes_to_skip,
|
||
):
|
||
await persist_output(output_name, output_data)
|
||
|
||
log_metadata.info(f"Finished node execution {node_exec.node_exec_id}")
|
||
status = ExecutionStatus.COMPLETED
|
||
|
||
except BaseException as e:
|
||
stats.error = e
|
||
|
||
if isinstance(e, ValueError):
|
||
# Avoid user error being marked as an actual error.
|
||
log_metadata.info(
|
||
f"Expected failure on node execution {node_exec.node_exec_id}: {type(e).__name__} - {e}"
|
||
)
|
||
status = ExecutionStatus.FAILED
|
||
elif isinstance(e, Exception):
|
||
# If the exception is not a ValueError, it is unexpected.
|
||
log_metadata.exception(
|
||
f"Unexpected failure on node execution {node_exec.node_exec_id}: {type(e).__name__} - {e}"
|
||
)
|
||
status = ExecutionStatus.FAILED
|
||
else:
|
||
# CancelledError or SystemExit
|
||
log_metadata.warning(
|
||
f"Interruption error on node execution {node_exec.node_exec_id}: {type(e).__name__}"
|
||
)
|
||
status = ExecutionStatus.TERMINATED
|
||
|
||
finally:
|
||
if status == ExecutionStatus.FAILED and stats.error is not None:
|
||
await persist_output(
|
||
"error", str(stats.error) or type(stats.error).__name__
|
||
)
|
||
return status
|
||
|
||
@func_retry
|
||
def on_graph_executor_start(self):
|
||
configure_logging()
|
||
set_service_name("GraphExecutor")
|
||
self.tid = threading.get_ident()
|
||
self.creds_manager = IntegrationCredentialsManager()
|
||
self.node_execution_loop = asyncio.new_event_loop()
|
||
self.node_evaluation_loop = asyncio.new_event_loop()
|
||
self.node_execution_thread = threading.Thread(
|
||
target=self.node_execution_loop.run_forever, daemon=True
|
||
)
|
||
self.node_evaluation_thread = threading.Thread(
|
||
target=self.node_evaluation_loop.run_forever, daemon=True
|
||
)
|
||
self.node_execution_thread.start()
|
||
self.node_evaluation_thread.start()
|
||
logger.info(f"[GraphExecutor] {self.tid} started")
|
||
|
||
@error_logged(swallow=False)
|
||
def on_graph_execution(
|
||
self,
|
||
graph_exec: GraphExecutionEntry,
|
||
cancel: threading.Event,
|
||
cluster_lock: ClusterLock,
|
||
):
|
||
log_metadata = LogMetadata(
|
||
logger=_logger,
|
||
user_id=graph_exec.user_id,
|
||
graph_eid=graph_exec.graph_exec_id,
|
||
graph_id=graph_exec.graph_id,
|
||
node_id="*",
|
||
node_eid="*",
|
||
block_name="-",
|
||
)
|
||
db_client = get_db_client()
|
||
|
||
exec_meta = db_client.get_graph_execution_meta(
|
||
user_id=graph_exec.user_id,
|
||
execution_id=graph_exec.graph_exec_id,
|
||
)
|
||
if exec_meta is None:
|
||
log_metadata.warning(
|
||
f"Skipped graph execution #{graph_exec.graph_exec_id}, the graph execution is not found."
|
||
)
|
||
return
|
||
|
||
if exec_meta.status in [ExecutionStatus.QUEUED, ExecutionStatus.INCOMPLETE]:
|
||
log_metadata.info(f"⚙️ Starting graph execution #{graph_exec.graph_exec_id}")
|
||
exec_meta.status = ExecutionStatus.RUNNING
|
||
send_execution_update(
|
||
db_client.update_graph_execution_start_time(graph_exec.graph_exec_id)
|
||
)
|
||
elif exec_meta.status == ExecutionStatus.RUNNING:
|
||
log_metadata.info(
|
||
f"⚙️ Graph execution #{graph_exec.graph_exec_id} is already running, continuing where it left off."
|
||
)
|
||
elif exec_meta.status == ExecutionStatus.REVIEW:
|
||
exec_meta.status = ExecutionStatus.RUNNING
|
||
log_metadata.info(
|
||
f"⚙️ Graph execution #{graph_exec.graph_exec_id} was waiting for review, resuming execution."
|
||
)
|
||
update_graph_execution_state(
|
||
db_client=db_client,
|
||
graph_exec_id=graph_exec.graph_exec_id,
|
||
status=ExecutionStatus.RUNNING,
|
||
)
|
||
elif exec_meta.status == ExecutionStatus.FAILED:
|
||
exec_meta.status = ExecutionStatus.RUNNING
|
||
log_metadata.info(
|
||
f"⚙️ Graph execution #{graph_exec.graph_exec_id} was disturbed, continuing where it left off."
|
||
)
|
||
update_graph_execution_state(
|
||
db_client=db_client,
|
||
graph_exec_id=graph_exec.graph_exec_id,
|
||
status=ExecutionStatus.RUNNING,
|
||
)
|
||
else:
|
||
log_metadata.warning(
|
||
f"Skipped graph execution {graph_exec.graph_exec_id}, the graph execution status is `{exec_meta.status}`."
|
||
)
|
||
return
|
||
|
||
if exec_meta.stats is None:
|
||
exec_stats = GraphExecutionStats()
|
||
else:
|
||
exec_stats = exec_meta.stats.to_db()
|
||
|
||
timing_info, status = self._on_graph_execution(
|
||
graph_exec=graph_exec,
|
||
cancel=cancel,
|
||
log_metadata=log_metadata,
|
||
execution_stats=exec_stats,
|
||
cluster_lock=cluster_lock,
|
||
)
|
||
exec_stats.walltime += timing_info.wall_time
|
||
exec_stats.cputime += timing_info.cpu_time
|
||
|
||
try:
|
||
# Failure handling
|
||
if isinstance(status, BaseException):
|
||
raise status
|
||
exec_meta.status = status
|
||
|
||
if status in [ExecutionStatus.COMPLETED, ExecutionStatus.FAILED]:
|
||
activity_response = asyncio.run_coroutine_threadsafe(
|
||
generate_activity_status_for_execution(
|
||
graph_exec_id=graph_exec.graph_exec_id,
|
||
graph_id=graph_exec.graph_id,
|
||
graph_version=graph_exec.graph_version,
|
||
execution_stats=exec_stats,
|
||
db_client=get_db_async_client(),
|
||
user_id=graph_exec.user_id,
|
||
execution_status=status,
|
||
),
|
||
self.node_execution_loop,
|
||
).result(timeout=60.0)
|
||
else:
|
||
activity_response = None
|
||
if activity_response is not None:
|
||
exec_stats.activity_status = activity_response["activity_status"]
|
||
exec_stats.correctness_score = activity_response["correctness_score"]
|
||
log_metadata.info(
|
||
f"Generated activity status: {activity_response['activity_status']} "
|
||
f"(correctness: {activity_response['correctness_score']:.2f})"
|
||
)
|
||
else:
|
||
log_metadata.debug(
|
||
"Activity status generation disabled, not setting fields"
|
||
)
|
||
finally:
|
||
# Communication handling
|
||
self._handle_agent_run_notif(db_client, graph_exec, exec_stats)
|
||
|
||
update_graph_execution_state(
|
||
db_client=db_client,
|
||
graph_exec_id=graph_exec.graph_exec_id,
|
||
status=exec_meta.status,
|
||
stats=exec_stats,
|
||
)
|
||
|
||
def _charge_usage(
|
||
self,
|
||
node_exec: NodeExecutionEntry,
|
||
execution_count: int,
|
||
) -> tuple[int, int]:
|
||
total_cost = 0
|
||
remaining_balance = 0
|
||
db_client = get_db_client()
|
||
block = get_block(node_exec.block_id)
|
||
if not block:
|
||
logger.error(f"Block {node_exec.block_id} not found.")
|
||
return total_cost, 0
|
||
|
||
cost, matching_filter = block_usage_cost(
|
||
block=block, input_data=node_exec.inputs
|
||
)
|
||
if cost > 0:
|
||
remaining_balance = db_client.spend_credits(
|
||
user_id=node_exec.user_id,
|
||
cost=cost,
|
||
metadata=UsageTransactionMetadata(
|
||
graph_exec_id=node_exec.graph_exec_id,
|
||
graph_id=node_exec.graph_id,
|
||
node_exec_id=node_exec.node_exec_id,
|
||
node_id=node_exec.node_id,
|
||
block_id=node_exec.block_id,
|
||
block=block.name,
|
||
input=matching_filter,
|
||
reason=f"Ran block {node_exec.block_id} {block.name}",
|
||
),
|
||
)
|
||
total_cost += cost
|
||
|
||
cost, usage_count = execution_usage_cost(execution_count)
|
||
if cost > 0:
|
||
remaining_balance = db_client.spend_credits(
|
||
user_id=node_exec.user_id,
|
||
cost=cost,
|
||
metadata=UsageTransactionMetadata(
|
||
graph_exec_id=node_exec.graph_exec_id,
|
||
graph_id=node_exec.graph_id,
|
||
input={
|
||
"execution_count": usage_count,
|
||
"charge": "Execution Cost",
|
||
},
|
||
reason=f"Execution Cost for {usage_count} blocks of ex_id:{node_exec.graph_exec_id} g_id:{node_exec.graph_id}",
|
||
),
|
||
)
|
||
total_cost += cost
|
||
|
||
return total_cost, remaining_balance
|
||
|
||
@time_measured
|
||
def _on_graph_execution(
|
||
self,
|
||
graph_exec: GraphExecutionEntry,
|
||
cancel: threading.Event,
|
||
log_metadata: LogMetadata,
|
||
execution_stats: GraphExecutionStats,
|
||
cluster_lock: ClusterLock,
|
||
) -> ExecutionStatus:
|
||
"""
|
||
Returns:
|
||
dict: The execution statistics of the graph execution.
|
||
ExecutionStatus: The final status of the graph execution.
|
||
Exception | None: The error that occurred during the execution, if any.
|
||
"""
|
||
execution_status: ExecutionStatus = ExecutionStatus.RUNNING
|
||
error: Exception | None = None
|
||
db_client = get_db_client()
|
||
execution_stats_lock = threading.Lock()
|
||
|
||
# State holders ----------------------------------------------------
|
||
self.running_node_execution: dict[str, NodeExecutionProgress] = defaultdict(
|
||
NodeExecutionProgress
|
||
)
|
||
self.running_node_evaluation: dict[str, Future] = {}
|
||
self.execution_stats = execution_stats
|
||
self.execution_stats_lock = execution_stats_lock
|
||
execution_queue = ExecutionQueue[NodeExecutionEntry]()
|
||
|
||
running_node_execution = self.running_node_execution
|
||
running_node_evaluation = self.running_node_evaluation
|
||
|
||
try:
|
||
if db_client.get_credits(graph_exec.user_id) <= 0:
|
||
raise InsufficientBalanceError(
|
||
user_id=graph_exec.user_id,
|
||
message="You have no credits left to run an agent.",
|
||
balance=0,
|
||
amount=1,
|
||
)
|
||
|
||
# Input moderation
|
||
try:
|
||
if moderation_error := asyncio.run_coroutine_threadsafe(
|
||
automod_manager.moderate_graph_execution_inputs(
|
||
db_client=get_db_async_client(),
|
||
graph_exec=graph_exec,
|
||
),
|
||
self.node_evaluation_loop,
|
||
).result(timeout=30.0):
|
||
raise moderation_error
|
||
except asyncio.TimeoutError:
|
||
log_metadata.warning(
|
||
f"Input moderation timed out for graph execution {graph_exec.graph_exec_id}, bypassing moderation and continuing execution"
|
||
)
|
||
# Continue execution without moderation
|
||
|
||
# ------------------------------------------------------------
|
||
# Pre‑populate queue ---------------------------------------
|
||
# ------------------------------------------------------------
|
||
for node_exec in db_client.get_node_executions(
|
||
graph_exec.graph_exec_id,
|
||
statuses=[
|
||
ExecutionStatus.RUNNING,
|
||
ExecutionStatus.QUEUED,
|
||
ExecutionStatus.TERMINATED,
|
||
ExecutionStatus.REVIEW,
|
||
],
|
||
):
|
||
node_entry = node_exec.to_node_execution_entry(
|
||
graph_exec.execution_context
|
||
)
|
||
execution_queue.add(node_entry)
|
||
|
||
# ------------------------------------------------------------
|
||
# Main dispatch / polling loop -----------------------------
|
||
# ------------------------------------------------------------
|
||
|
||
while not execution_queue.empty():
|
||
if cancel.is_set():
|
||
break
|
||
|
||
queued_node_exec = execution_queue.get()
|
||
|
||
# Check if this node should be skipped due to optional credentials
|
||
if queued_node_exec.node_id in graph_exec.nodes_to_skip:
|
||
log_metadata.info(
|
||
f"Skipping node execution {queued_node_exec.node_exec_id} "
|
||
f"for node {queued_node_exec.node_id} - optional credentials not configured"
|
||
)
|
||
# Mark the node as completed without executing
|
||
# No outputs will be produced, so downstream nodes won't trigger
|
||
update_node_execution_status(
|
||
db_client=db_client,
|
||
exec_id=queued_node_exec.node_exec_id,
|
||
status=ExecutionStatus.COMPLETED,
|
||
)
|
||
continue
|
||
|
||
log_metadata.debug(
|
||
f"Dispatching node execution {queued_node_exec.node_exec_id} "
|
||
f"for node {queued_node_exec.node_id}",
|
||
)
|
||
|
||
# Charge usage (may raise) ------------------------------
|
||
try:
|
||
cost, remaining_balance = self._charge_usage(
|
||
node_exec=queued_node_exec,
|
||
execution_count=increment_execution_count(graph_exec.user_id),
|
||
)
|
||
with execution_stats_lock:
|
||
execution_stats.cost += cost
|
||
# Check if we crossed the low balance threshold
|
||
self._handle_low_balance(
|
||
db_client=db_client,
|
||
user_id=graph_exec.user_id,
|
||
current_balance=remaining_balance,
|
||
transaction_cost=cost,
|
||
)
|
||
except InsufficientBalanceError as balance_error:
|
||
error = balance_error # Set error to trigger FAILED status
|
||
node_exec_id = queued_node_exec.node_exec_id
|
||
db_client.upsert_execution_output(
|
||
node_exec_id=node_exec_id,
|
||
output_name="error",
|
||
output_data=str(error),
|
||
)
|
||
update_node_execution_status(
|
||
db_client=db_client,
|
||
exec_id=node_exec_id,
|
||
status=ExecutionStatus.FAILED,
|
||
)
|
||
|
||
self._handle_insufficient_funds_notif(
|
||
db_client,
|
||
graph_exec.user_id,
|
||
graph_exec.graph_id,
|
||
error,
|
||
)
|
||
# Gracefully stop the execution loop
|
||
break
|
||
|
||
# Add input overrides -----------------------------
|
||
node_id = queued_node_exec.node_id
|
||
if (nodes_input_masks := graph_exec.nodes_input_masks) and (
|
||
node_input_mask := nodes_input_masks.get(node_id)
|
||
):
|
||
queued_node_exec.inputs.update(node_input_mask)
|
||
|
||
# Kick off async node execution -------------------------
|
||
node_execution_task = asyncio.run_coroutine_threadsafe(
|
||
self.on_node_execution(
|
||
node_exec=queued_node_exec,
|
||
node_exec_progress=running_node_execution[node_id],
|
||
nodes_input_masks=nodes_input_masks,
|
||
graph_stats_pair=(
|
||
execution_stats,
|
||
execution_stats_lock,
|
||
),
|
||
nodes_to_skip=graph_exec.nodes_to_skip,
|
||
),
|
||
self.node_execution_loop,
|
||
)
|
||
running_node_execution[node_id].add_task(
|
||
node_exec_id=queued_node_exec.node_exec_id,
|
||
task=node_execution_task,
|
||
)
|
||
|
||
# Poll until queue refills or all inflight work done ----
|
||
while execution_queue.empty() and (
|
||
running_node_execution or running_node_evaluation
|
||
):
|
||
if cancel.is_set():
|
||
break
|
||
|
||
# --------------------------------------------------
|
||
# Handle inflight evaluations ---------------------
|
||
# --------------------------------------------------
|
||
node_output_found = False
|
||
for node_id, inflight_exec in list(running_node_execution.items()):
|
||
if cancel.is_set():
|
||
break
|
||
|
||
# node evaluation future -----------------
|
||
if inflight_eval := running_node_evaluation.get(node_id):
|
||
if not inflight_eval.done():
|
||
continue
|
||
try:
|
||
inflight_eval.result(timeout=0)
|
||
running_node_evaluation.pop(node_id)
|
||
except Exception as e:
|
||
log_metadata.error(f"Node eval #{node_id} failed: {e}")
|
||
|
||
# node execution future ---------------------------
|
||
if inflight_exec.is_done():
|
||
running_node_execution.pop(node_id)
|
||
continue
|
||
|
||
if output := inflight_exec.pop_output():
|
||
node_output_found = True
|
||
running_node_evaluation[node_id] = (
|
||
asyncio.run_coroutine_threadsafe(
|
||
self._process_node_output(
|
||
output=output,
|
||
node_id=node_id,
|
||
graph_exec=graph_exec,
|
||
log_metadata=log_metadata,
|
||
nodes_input_masks=nodes_input_masks,
|
||
execution_queue=execution_queue,
|
||
),
|
||
self.node_evaluation_loop,
|
||
)
|
||
)
|
||
if (
|
||
not node_output_found
|
||
and execution_queue.empty()
|
||
and (running_node_execution or running_node_evaluation)
|
||
):
|
||
cluster_lock.refresh()
|
||
time.sleep(0.1)
|
||
|
||
# loop done --------------------------------------------------
|
||
|
||
# Output moderation
|
||
try:
|
||
if moderation_error := asyncio.run_coroutine_threadsafe(
|
||
automod_manager.moderate_graph_execution_outputs(
|
||
db_client=get_db_async_client(),
|
||
graph_exec_id=graph_exec.graph_exec_id,
|
||
user_id=graph_exec.user_id,
|
||
graph_id=graph_exec.graph_id,
|
||
),
|
||
self.node_evaluation_loop,
|
||
).result(timeout=30.0):
|
||
raise moderation_error
|
||
except asyncio.TimeoutError:
|
||
log_metadata.warning(
|
||
f"Output moderation timed out for graph execution {graph_exec.graph_exec_id}, bypassing moderation and continuing execution"
|
||
)
|
||
# Continue execution without moderation
|
||
|
||
# Determine final execution status based on whether there was an error or termination
|
||
if cancel.is_set():
|
||
execution_status = ExecutionStatus.TERMINATED
|
||
elif error is not None:
|
||
execution_status = ExecutionStatus.FAILED
|
||
else:
|
||
if db_client.has_pending_reviews_for_graph_exec(
|
||
graph_exec.graph_exec_id
|
||
):
|
||
execution_status = ExecutionStatus.REVIEW
|
||
else:
|
||
execution_status = ExecutionStatus.COMPLETED
|
||
|
||
if error:
|
||
execution_stats.error = str(error) or type(error).__name__
|
||
|
||
return execution_status
|
||
|
||
except BaseException as e:
|
||
error = (
|
||
e
|
||
if isinstance(e, Exception)
|
||
else Exception(f"{e.__class__.__name__}: {e}")
|
||
)
|
||
if not execution_stats.error:
|
||
execution_stats.error = str(error)
|
||
|
||
known_errors = (InsufficientBalanceError, ModerationError)
|
||
if isinstance(error, known_errors):
|
||
return ExecutionStatus.FAILED
|
||
|
||
execution_status = ExecutionStatus.FAILED
|
||
log_metadata.exception(
|
||
f"Failed graph execution {graph_exec.graph_exec_id}: {error}"
|
||
)
|
||
|
||
# Send rate-limited Discord alert for unknown/unexpected errors
|
||
send_rate_limited_discord_alert(
|
||
"graph_execution",
|
||
error,
|
||
"unknown_error",
|
||
f"🚨 **Unknown Graph Execution Error**\n"
|
||
f"User: {graph_exec.user_id}\n"
|
||
f"Graph ID: {graph_exec.graph_id}\n"
|
||
f"Execution ID: {graph_exec.graph_exec_id}\n"
|
||
f"Error Type: {type(error).__name__}\n"
|
||
f"Error: {str(error)[:200]}{'...' if len(str(error)) > 200 else ''}\n",
|
||
)
|
||
|
||
raise
|
||
|
||
finally:
|
||
self._cleanup_graph_execution(
|
||
execution_queue=execution_queue,
|
||
running_node_execution=running_node_execution,
|
||
running_node_evaluation=running_node_evaluation,
|
||
execution_status=execution_status,
|
||
error=error,
|
||
graph_exec_id=graph_exec.graph_exec_id,
|
||
log_metadata=log_metadata,
|
||
db_client=db_client,
|
||
)
|
||
|
||
@error_logged(swallow=True)
|
||
def _cleanup_graph_execution(
|
||
self,
|
||
execution_queue: ExecutionQueue[NodeExecutionEntry],
|
||
running_node_execution: dict[str, "NodeExecutionProgress"],
|
||
running_node_evaluation: dict[str, Future],
|
||
execution_status: ExecutionStatus,
|
||
error: Exception | None,
|
||
graph_exec_id: str,
|
||
log_metadata: LogMetadata,
|
||
db_client: "DatabaseManagerClient",
|
||
) -> None:
|
||
"""
|
||
Clean up running node executions and evaluations when graph execution ends.
|
||
This method is decorated with @error_logged(swallow=True) to ensure cleanup
|
||
never fails in the finally block.
|
||
"""
|
||
# Cancel and wait for all node executions to complete
|
||
for node_id, inflight_exec in running_node_execution.items():
|
||
if inflight_exec.is_done():
|
||
continue
|
||
log_metadata.info(f"Stopping node execution {node_id}")
|
||
inflight_exec.stop()
|
||
|
||
for node_id, inflight_exec in running_node_execution.items():
|
||
try:
|
||
inflight_exec.wait_for_done(timeout=3600.0)
|
||
except TimeoutError:
|
||
log_metadata.exception(
|
||
f"Node execution #{node_id} did not stop in time, "
|
||
"it may be stuck or taking too long."
|
||
)
|
||
|
||
# Wait the remaining inflight evaluations to finish
|
||
for node_id, inflight_eval in running_node_evaluation.items():
|
||
try:
|
||
inflight_eval.result(timeout=3600.0)
|
||
except TimeoutError:
|
||
log_metadata.exception(
|
||
f"Node evaluation #{node_id} did not stop in time, "
|
||
"it may be stuck or taking too long."
|
||
)
|
||
|
||
while queued_execution := execution_queue.get_or_none():
|
||
update_node_execution_status(
|
||
db_client=db_client,
|
||
exec_id=queued_execution.node_exec_id,
|
||
status=execution_status,
|
||
stats={"error": str(error)} if error else None,
|
||
)
|
||
|
||
clean_exec_files(graph_exec_id)
|
||
|
||
@async_error_logged(swallow=True)
|
||
async def _process_node_output(
|
||
self,
|
||
output: ExecutionOutputEntry,
|
||
node_id: str,
|
||
graph_exec: GraphExecutionEntry,
|
||
log_metadata: LogMetadata,
|
||
nodes_input_masks: Optional[NodesInputMasks],
|
||
execution_queue: ExecutionQueue[NodeExecutionEntry],
|
||
) -> None:
|
||
"""Process a node's output, update its status, and enqueue next nodes.
|
||
|
||
Args:
|
||
output: The execution output entry to process
|
||
node_id: The ID of the node that produced the output
|
||
graph_exec: The graph execution entry
|
||
log_metadata: Logger metadata for consistent logging
|
||
nodes_input_masks: Optional map of node input overrides
|
||
execution_queue: Queue to add next executions to
|
||
"""
|
||
db_client = get_db_async_client()
|
||
|
||
log_metadata.debug(f"Enqueue nodes for {node_id}: {output}")
|
||
|
||
for next_execution in await _enqueue_next_nodes(
|
||
db_client=db_client,
|
||
node=output.node,
|
||
output=output.data,
|
||
user_id=graph_exec.user_id,
|
||
graph_exec_id=graph_exec.graph_exec_id,
|
||
graph_id=graph_exec.graph_id,
|
||
graph_version=graph_exec.graph_version,
|
||
log_metadata=log_metadata,
|
||
nodes_input_masks=nodes_input_masks,
|
||
execution_context=graph_exec.execution_context,
|
||
):
|
||
execution_queue.add(next_execution)
|
||
|
||
def _handle_agent_run_notif(
|
||
self,
|
||
db_client: "DatabaseManagerClient",
|
||
graph_exec: GraphExecutionEntry,
|
||
exec_stats: GraphExecutionStats,
|
||
):
|
||
metadata = db_client.get_graph_metadata(
|
||
graph_exec.graph_id, graph_exec.graph_version
|
||
)
|
||
outputs = db_client.get_node_executions(
|
||
graph_exec.graph_exec_id,
|
||
block_ids=[AgentOutputBlock().id],
|
||
)
|
||
|
||
named_outputs = [
|
||
{
|
||
key: value[0] if key == "name" else value
|
||
for key, value in output.output_data.items()
|
||
}
|
||
for output in outputs
|
||
]
|
||
|
||
queue_notification(
|
||
NotificationEventModel(
|
||
user_id=graph_exec.user_id,
|
||
type=NotificationType.AGENT_RUN,
|
||
data=AgentRunData(
|
||
outputs=named_outputs,
|
||
agent_name=metadata.name if metadata else "Unknown Agent",
|
||
credits_used=exec_stats.cost,
|
||
execution_time=exec_stats.walltime,
|
||
graph_id=graph_exec.graph_id,
|
||
node_count=exec_stats.node_count,
|
||
),
|
||
)
|
||
)
|
||
|
||
def _handle_insufficient_funds_notif(
|
||
self,
|
||
db_client: "DatabaseManagerClient",
|
||
user_id: str,
|
||
graph_id: str,
|
||
e: InsufficientBalanceError,
|
||
):
|
||
# Check if we've already sent a notification for this user+agent combo.
|
||
# We only send one notification per user per agent until they top up credits.
|
||
redis_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
|
||
try:
|
||
redis_client = redis.get_redis()
|
||
# SET NX returns True only if the key was newly set (didn't exist)
|
||
is_new_notification = redis_client.set(
|
||
redis_key,
|
||
"1",
|
||
nx=True,
|
||
ex=INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS,
|
||
)
|
||
if not is_new_notification:
|
||
# Already notified for this user+agent, skip all notifications
|
||
logger.debug(
|
||
f"Skipping duplicate insufficient funds notification for "
|
||
f"user={user_id}, graph={graph_id}"
|
||
)
|
||
return
|
||
except Exception as redis_error:
|
||
# If Redis fails, log and continue to send the notification
|
||
# (better to occasionally duplicate than to never notify)
|
||
logger.warning(
|
||
f"Failed to check/set insufficient funds notification flag in Redis: "
|
||
f"{redis_error}"
|
||
)
|
||
|
||
shortfall = abs(e.amount) - e.balance
|
||
metadata = db_client.get_graph_metadata(graph_id)
|
||
base_url = (
|
||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||
)
|
||
|
||
# Queue user email notification
|
||
queue_notification(
|
||
NotificationEventModel(
|
||
user_id=user_id,
|
||
type=NotificationType.ZERO_BALANCE,
|
||
data=ZeroBalanceData(
|
||
current_balance=e.balance,
|
||
billing_page_link=f"{base_url}/profile/credits",
|
||
shortfall=shortfall,
|
||
agent_name=metadata.name if metadata else "Unknown Agent",
|
||
),
|
||
)
|
||
)
|
||
|
||
# Send Discord system alert
|
||
try:
|
||
user_email = db_client.get_user_email_by_id(user_id)
|
||
|
||
alert_message = (
|
||
f"❌ **Insufficient Funds Alert**\n"
|
||
f"User: {user_email or user_id}\n"
|
||
f"Agent: {metadata.name if metadata else 'Unknown Agent'}\n"
|
||
f"Current balance: ${e.balance / 100:.2f}\n"
|
||
f"Attempted cost: ${abs(e.amount) / 100:.2f}\n"
|
||
f"Shortfall: ${abs(shortfall) / 100:.2f}\n"
|
||
f"[View User Details]({base_url}/admin/spending?search={user_email})"
|
||
)
|
||
|
||
get_notification_manager_client().discord_system_alert(
|
||
alert_message, DiscordChannel.PRODUCT
|
||
)
|
||
except Exception as alert_error:
|
||
logger.error(
|
||
f"Failed to send insufficient funds Discord alert: {alert_error}"
|
||
)
|
||
|
||
def _handle_low_balance(
|
||
self,
|
||
db_client: "DatabaseManagerClient",
|
||
user_id: str,
|
||
current_balance: int,
|
||
transaction_cost: int,
|
||
):
|
||
"""Check and handle low balance scenarios after a transaction"""
|
||
LOW_BALANCE_THRESHOLD = settings.config.low_balance_threshold
|
||
|
||
balance_before = current_balance + transaction_cost
|
||
|
||
if (
|
||
current_balance < LOW_BALANCE_THRESHOLD
|
||
and balance_before >= LOW_BALANCE_THRESHOLD
|
||
):
|
||
base_url = (
|
||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||
)
|
||
queue_notification(
|
||
NotificationEventModel(
|
||
user_id=user_id,
|
||
type=NotificationType.LOW_BALANCE,
|
||
data=LowBalanceData(
|
||
current_balance=current_balance,
|
||
billing_page_link=f"{base_url}/profile/credits",
|
||
),
|
||
)
|
||
)
|
||
|
||
try:
|
||
user_email = db_client.get_user_email_by_id(user_id)
|
||
alert_message = (
|
||
f"⚠️ **Low Balance Alert**\n"
|
||
f"User: {user_email or user_id}\n"
|
||
f"Balance dropped below ${LOW_BALANCE_THRESHOLD / 100:.2f}\n"
|
||
f"Current balance: ${current_balance / 100:.2f}\n"
|
||
f"Transaction cost: ${transaction_cost / 100:.2f}\n"
|
||
f"[View User Details]({base_url}/admin/spending?search={user_email})"
|
||
)
|
||
get_notification_manager_client().discord_system_alert(
|
||
alert_message, DiscordChannel.PRODUCT
|
||
)
|
||
except Exception as e:
|
||
logger.error(f"Failed to send low balance Discord alert: {e}")
|
||
|
||
|
||
class ExecutionManager(AppProcess):
|
||
def __init__(self):
|
||
super().__init__()
|
||
self.pool_size = settings.config.num_graph_workers
|
||
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
|
||
self.executor_id = str(uuid.uuid4())
|
||
|
||
self._executor = None
|
||
self._stop_consuming = None
|
||
|
||
self._cancel_thread = None
|
||
self._cancel_client = None
|
||
self._run_thread = None
|
||
self._run_client = None
|
||
|
||
self._execution_locks = {}
|
||
|
||
@property
|
||
def cancel_thread(self) -> threading.Thread:
|
||
if self._cancel_thread is None:
|
||
self._cancel_thread = threading.Thread(
|
||
target=lambda: self._consume_execution_cancel(),
|
||
daemon=True,
|
||
)
|
||
return self._cancel_thread
|
||
|
||
@property
|
||
def run_thread(self) -> threading.Thread:
|
||
if self._run_thread is None:
|
||
self._run_thread = threading.Thread(
|
||
target=lambda: self._consume_execution_run(),
|
||
daemon=True,
|
||
)
|
||
return self._run_thread
|
||
|
||
@property
|
||
def stop_consuming(self) -> threading.Event:
|
||
if self._stop_consuming is None:
|
||
self._stop_consuming = threading.Event()
|
||
return self._stop_consuming
|
||
|
||
@property
|
||
def executor(self) -> ThreadPoolExecutor:
|
||
if self._executor is None:
|
||
self._executor = ThreadPoolExecutor(
|
||
max_workers=self.pool_size,
|
||
initializer=init_worker,
|
||
)
|
||
return self._executor
|
||
|
||
@property
|
||
def cancel_client(self) -> SyncRabbitMQ:
|
||
if self._cancel_client is None:
|
||
self._cancel_client = SyncRabbitMQ(create_execution_queue_config())
|
||
return self._cancel_client
|
||
|
||
@property
|
||
def run_client(self) -> SyncRabbitMQ:
|
||
if self._run_client is None:
|
||
self._run_client = SyncRabbitMQ(create_execution_queue_config())
|
||
return self._run_client
|
||
|
||
def run(self):
|
||
logger.info(
|
||
f"[{self.service_name}] 🆔 Pod assigned executor_id: {self.executor_id}"
|
||
)
|
||
logger.info(f"[{self.service_name}] ⏳ Spawn max-{self.pool_size} workers...")
|
||
|
||
pool_size_gauge.set(self.pool_size)
|
||
self._update_prompt_metrics()
|
||
start_http_server(settings.config.execution_manager_port)
|
||
|
||
self.cancel_thread.start()
|
||
self.run_thread.start()
|
||
|
||
        while True:
            time.sleep(1e5)

    @continuous_retry()
    def _consume_execution_cancel(self):
        if self.stop_consuming.is_set() and not self.active_graph_runs:
            logger.info(
                f"[{self.service_name}] Stop reconnecting cancel consumer since the service is cleaned up."
            )
            return

        # Check if channel is closed and force reconnection if needed
        if not self.cancel_client.is_ready:
            self.cancel_client.disconnect()
            self.cancel_client.connect()
        cancel_channel = self.cancel_client.get_channel()
        cancel_channel.basic_consume(
            queue=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
            on_message_callback=self._handle_cancel_message,
            auto_ack=True,
        )
        logger.info(f"[{self.service_name}] ⏳ Starting cancel message consumer...")
        cancel_channel.start_consuming()
        if not self.stop_consuming.is_set() or self.active_graph_runs:
            raise RuntimeError(
                f"[{self.service_name}] ❌ cancel message consumer is stopped: {cancel_channel}"
            )
        logger.info(
            f"[{self.service_name}] ✅ Cancel message consumer stopped gracefully"
        )

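    # Note: both consumers raise RuntimeError if start_consuming() returns while
    # the service is still live; @continuous_retry then presumably re-invokes
    # the method, which rebuilds the connection and channel.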
    @continuous_retry()
    def _consume_execution_run(self):
        # Long-running executions are handled by:
        # 1. Long consumer timeout (x-consumer-timeout) allows long running agent
        # 2. Enhanced connection settings (5 retries, 1s delay) for quick reconnection
        # 3. Process monitoring ensures failed executors release messages back to queue
        if self.stop_consuming.is_set():
            logger.info(
                f"[{self.service_name}] Stop reconnecting execution consumer since the service is cleaned up."
            )
            return

        # Check if channel is closed and force reconnection if needed
        if not self.run_client.is_ready:
            self.run_client.disconnect()
            self.run_client.connect()
        run_channel = self.run_client.get_channel()
        run_channel.basic_qos(prefetch_count=self.pool_size)

        # Configure consumer for long-running graph executions
        # auto_ack=False: Don't acknowledge messages until execution completes (prevents data loss)
        run_channel.basic_consume(
            queue=GRAPH_EXECUTION_QUEUE_NAME,
            on_message_callback=self._handle_run_message,
            auto_ack=False,
            consumer_tag="graph_execution_consumer",
        )
        run_channel.confirm_delivery()
        logger.info(f"[{self.service_name}] ⏳ Starting to consume run messages...")
        run_channel.start_consuming()
        if not self.stop_consuming.is_set():
            raise RuntimeError(
                f"[{self.service_name}] ❌ run message consumer is stopped: {run_channel}"
            )
        logger.info(f"[{self.service_name}] ✅ Run message consumer stopped gracefully")

    @error_logged(swallow=True)
    def _handle_cancel_message(
        self,
        _channel: BlockingChannel,
        _method: Basic.Deliver,
        _properties: BasicProperties,
        body: bytes,
    ):
        """
        Called whenever we receive a CANCEL message from the queue.
        (With auto_ack=True, message is considered 'acked' automatically.)
        """
        request = CancelExecutionEvent.model_validate_json(body)
        graph_exec_id = request.graph_exec_id
        if not graph_exec_id:
            logger.warning(
                f"[{self.service_name}] Cancel message missing 'graph_exec_id'"
            )
            return
        if graph_exec_id not in self.active_graph_runs:
            logger.debug(
                f"[{self.service_name}] Cancel received for {graph_exec_id} but not active."
            )
            return

        _, cancel_event = self.active_graph_runs[graph_exec_id]
        logger.info(f"[{self.service_name}] Received cancel for {graph_exec_id}")
        if not cancel_event.is_set():
            cancel_event.set()
        else:
            logger.debug(
                f"[{self.service_name}] Cancel already set for {graph_exec_id}"
            )

    def _handle_run_message(
        self,
        _channel: BlockingChannel,
        method: Basic.Deliver,
        _properties: BasicProperties,
        body: bytes,
    ):
        delivery_tag = method.delivery_tag

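        # Requeue strategy for the helper below: a plain basic_nack(requeue=True)
        # returns the message to the front of the queue, so when
        # requeue_by_republishing is enabled the message is instead republished
        # to the exchange and then nacked without requeue, effectively sending
        # it to the back of the queue so a busy pod does not keep re-reading it.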
        @func_retry
        def _ack_message(reject: bool, requeue: bool):
            """
            Acknowledge or reject the message based on execution status.

            Args:
                reject: Whether to reject the message
                requeue: Whether to requeue the message
            """

            # Connection can be lost, so always get a fresh channel
            channel = self.run_client.get_channel()
            if reject:
                if requeue and settings.config.requeue_by_republishing:
                    # Send rejected message to back of queue using republishing
                    def _republish_to_back():
                        try:
                            # First republish to back of queue
                            self.run_client.publish_message(
                                routing_key=GRAPH_EXECUTION_ROUTING_KEY,
                                message=body.decode(),  # publish_message expects string, not bytes
                                exchange=GRAPH_EXECUTION_EXCHANGE,
                            )
                            # Then reject without requeue (message already republished)
                            channel.basic_nack(delivery_tag, requeue=False)
                            logger.info("Message requeued to back of queue")
                        except Exception as e:
                            logger.error(
                                f"[{self.service_name}] Failed to requeue message to back: {e}"
                            )
                            # Fall back to traditional requeue on failure
                            channel.basic_nack(delivery_tag, requeue=True)

                    channel.connection.add_callback_threadsafe(_republish_to_back)
                else:
                    # Traditional requeue (goes to front) or no requeue
                    channel.connection.add_callback_threadsafe(
                        lambda: channel.basic_nack(delivery_tag, requeue=requeue)
                    )
            else:
                channel.connection.add_callback_threadsafe(
                    lambda: channel.basic_ack(delivery_tag)
                )

        # Check if we're shutting down - reject new messages but keep connection alive
        if self.stop_consuming.is_set():
            logger.info(
                f"[{self.service_name}] Rejecting new execution during shutdown"
            )
            _ack_message(reject=True, requeue=True)
            return

        # Check if we can accept more runs
        self._cleanup_completed_runs()
        if len(self.active_graph_runs) >= self.pool_size:
            _ack_message(reject=True, requeue=True)
            return

        try:
            graph_exec_entry = GraphExecutionEntry.model_validate_json(body)
        except Exception as e:
            logger.error(
                f"[{self.service_name}] Could not parse run message: {e}, body={body}"
            )
            _ack_message(reject=True, requeue=False)
            return

        graph_exec_id = graph_exec_entry.graph_exec_id
        user_id = graph_exec_entry.user_id
        graph_id = graph_exec_entry.graph_id
        root_exec_id = graph_exec_entry.execution_context.root_execution_id
        parent_exec_id = graph_exec_entry.execution_context.parent_execution_id

        logger.info(
            f"[{self.service_name}] Received RUN for graph_exec_id={graph_exec_id}, user_id={user_id}, executor_id={self.executor_id}"
            + (f", root={root_exec_id}" if root_exec_id else "")
            + (f", parent={parent_exec_id}" if parent_exec_id else "")
        )

        # Check if root execution is already terminated (prevents orphaned child executions)
        if root_exec_id and root_exec_id != graph_exec_id:
            parent_exec = get_db_client().get_graph_execution_meta(
                execution_id=root_exec_id,
                user_id=user_id,
            )
            if parent_exec and parent_exec.status == ExecutionStatus.TERMINATED:
                logger.info(
                    f"[{self.service_name}] Skipping execution {graph_exec_id} - parent {root_exec_id} is TERMINATED"
                )
                # Mark this child as terminated since parent was stopped
                get_db_client().update_graph_execution_stats(
                    graph_exec_id=graph_exec_id,
                    status=ExecutionStatus.TERMINATED,
                )
                _ack_message(reject=False, requeue=False)
                return

        # Check user rate limit before processing
        try:
            # Only check executions from the last 24 hours for performance
            current_running_count = get_db_client().get_graph_executions_count(
                user_id=user_id,
                graph_id=graph_id,
                statuses=[ExecutionStatus.RUNNING],
                created_time_gte=datetime.now(timezone.utc) - timedelta(hours=24),
            )

            if (
                current_running_count
                >= settings.config.max_concurrent_graph_executions_per_user
            ):
                logger.warning(
                    f"[{self.service_name}] Rate limit exceeded for user {user_id} on graph {graph_id}: "
                    f"{current_running_count}/{settings.config.max_concurrent_graph_executions_per_user} running executions"
                )
                _ack_message(reject=True, requeue=True)
                return

        except Exception as e:
            logger.error(
                f"[{self.service_name}] Failed to check rate limit for user {user_id}: {e}, proceeding with execution"
            )
            # If rate limit check fails, proceed to avoid blocking executions

        # Check for local duplicate execution first
        if graph_exec_id in self.active_graph_runs:
            logger.warning(
                f"[{self.service_name}] Graph {graph_exec_id} already running locally; rejecting duplicate."
            )
            _ack_message(reject=True, requeue=True)
            return

        # Try to acquire cluster-wide execution lock
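        # try_acquire() returns the id of whoever holds the lock: our own
        # executor_id on success, another pod's id if the run is already taken,
        # or None (interpreted below as Redis being unavailable).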
        cluster_lock = ClusterLock(
            redis=redis.get_redis(),
            key=f"exec_lock:{graph_exec_id}",
            owner_id=self.executor_id,
            timeout=settings.config.cluster_lock_timeout,
        )
        current_owner = cluster_lock.try_acquire()
        if current_owner != self.executor_id:
            # Either someone else has it or Redis is unavailable
            if current_owner is not None:
                logger.warning(
                    f"[{self.service_name}] Graph {graph_exec_id} already running on pod {current_owner}, current executor_id={self.executor_id}"
                )
                _ack_message(reject=True, requeue=False)
            else:
                logger.warning(
                    f"[{self.service_name}] Could not acquire lock for {graph_exec_id} - Redis unavailable"
                )
                _ack_message(reject=True, requeue=True)
            return

        # Wrap entire block after successful lock acquisition
        try:
            self._execution_locks[graph_exec_id] = cluster_lock

            logger.info(
                f"[{self.service_name}] Successfully acquired cluster lock for {graph_exec_id}, executor_id={self.executor_id}"
            )

            cancel_event = threading.Event()
            future = self.executor.submit(
                execute_graph, graph_exec_entry, cancel_event, cluster_lock
            )
            self.active_graph_runs[graph_exec_id] = (future, cancel_event)
        except Exception as e:
            logger.warning(
                f"[{self.service_name}] Failed to setup execution for {graph_exec_id}: {type(e).__name__}: {e}"
            )
            # Release cluster lock before requeue
            cluster_lock.release()
            if graph_exec_id in self._execution_locks:
                del self._execution_locks[graph_exec_id]
            _ack_message(reject=True, requeue=True)
            return
        self._update_prompt_metrics()

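        # The run message is only acked once the submitted graph execution
        # finishes (auto_ack=False above): failures are nacked and requeued so
        # another pod can retry, successes are acked, and the cluster lock is
        # always released.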
        def _on_run_done(f: Future):
            logger.info(f"[{self.service_name}] Run completed for {graph_exec_id}")
            try:
                if exec_error := f.exception():
                    logger.error(
                        f"[{self.service_name}] Execution for {graph_exec_id} failed: {type(exec_error)} {exec_error}"
                    )
                    _ack_message(reject=True, requeue=True)
                else:
                    _ack_message(reject=False, requeue=False)
            except BaseException as e:
                logger.exception(
                    f"[{self.service_name}] Error in run completion callback: {e}"
                )
            finally:
                # Release the cluster-wide execution lock
                if graph_exec_id in self._execution_locks:
                    logger.info(
                        f"[{self.service_name}] Releasing cluster lock for {graph_exec_id}, executor_id={self.executor_id}"
                    )
                    self._execution_locks[graph_exec_id].release()
                    del self._execution_locks[graph_exec_id]
                self._cleanup_completed_runs()

        future.add_done_callback(_on_run_done)

    def _cleanup_completed_runs(self) -> list[str]:
        """Remove completed futures from active_graph_runs and update metrics"""
        completed_runs = []
        for graph_exec_id, (future, _) in self.active_graph_runs.items():
            if future.done():
                completed_runs.append(graph_exec_id)

        for geid in completed_runs:
            logger.info(f"[{self.service_name}] ✅ Cleaned up completed run {geid}")
            self.active_graph_runs.pop(geid, None)

        self._update_prompt_metrics()
        return completed_runs

    def _update_prompt_metrics(self):
        active_count = len(self.active_graph_runs)
        active_runs_gauge.set(active_count)
        if self.stop_consuming.is_set():
            utilization_gauge.set(1.0)
        else:
            utilization_gauge.set(active_count / self.pool_size)

    def _stop_message_consumers(
        self, thread: threading.Thread, client: SyncRabbitMQ, prefix: str
    ):
        try:
            channel = client.get_channel()
            channel.connection.add_callback_threadsafe(lambda: channel.stop_consuming())

            try:
                thread.join(timeout=300)
            except TimeoutError:
                logger.error(
                    f"{prefix} ⚠️ Run thread did not finish in time, forcing disconnect"
                )

            client.disconnect()
            logger.info(f"{prefix} ✅ Run client disconnected")
        except Exception as e:
            logger.error(f"{prefix} ⚠️ Error disconnecting run client: {type(e)} {e}")

    def cleanup(self):
        """Override cleanup to implement graceful shutdown with active execution waiting."""
        prefix = f"[{self.service_name}][on_graph_executor_stop {os.getpid()}]"
        logger.info(f"{prefix} 🧹 Starting graceful shutdown...")

        # Signal the consumer thread to stop (thread-safe)
        try:
            self.stop_consuming.set()
            run_channel = self.run_client.get_channel()
            run_channel.connection.add_callback_threadsafe(
                lambda: run_channel.stop_consuming()
            )
            logger.info(f"{prefix} ✅ Exec consumer has been signaled to stop")
        except Exception as e:
            logger.error(f"{prefix} ⚠️ Error signaling consumer to stop: {type(e)} {e}")

        # Wait for active executions to complete
        if self.active_graph_runs:
            logger.info(
                f"{prefix} ⏳ Waiting for {len(self.active_graph_runs)} active executions to complete..."
            )

            max_wait = GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS
            wait_interval = 5
            waited = 0

            while waited < max_wait:
                self._cleanup_completed_runs()
                if not self.active_graph_runs:
                    logger.info(f"{prefix} ✅ All active executions completed")
                    break
                else:
                    ids = [k.split("-")[0] for k in self.active_graph_runs.keys()]
                    logger.info(
                        f"{prefix} ⏳ Still waiting for {len(self.active_graph_runs)} executions: {ids}"
                    )

                    for graph_exec_id in self.active_graph_runs:
                        if lock := self._execution_locks.get(graph_exec_id):
                            lock.refresh()

                time.sleep(wait_interval)
                waited += wait_interval

            if self.active_graph_runs:
                logger.error(
                    f"{prefix} ⚠️ {len(self.active_graph_runs)} executions still running after {max_wait}s"
                )
            else:
                logger.info(f"{prefix} ✅ All executions completed gracefully")

        # Shutdown the executor
        try:
            self.executor.shutdown(cancel_futures=True, wait=False)
            logger.info(f"{prefix} ✅ Executor shutdown completed")
        except Exception as e:
            logger.error(f"{prefix} ⚠️ Error during executor shutdown: {type(e)} {e}")

        # Release remaining execution locks
        try:
            for lock in self._execution_locks.values():
                lock.release()
            self._execution_locks.clear()
            logger.info(f"{prefix} ✅ Released execution locks")
        except Exception as e:
            logger.warning(f"{prefix} ⚠️ Failed to release all locks: {e}")

        # Disconnect the run execution consumer
        self._stop_message_consumers(
            self.run_thread,
            self.run_client,
            prefix + " [run-consumer]",
        )
        self._stop_message_consumers(
            self.cancel_thread,
            self.cancel_client,
            prefix + " [cancel-consumer]",
        )

        logger.info(f"{prefix} ✅ Finished GraphExec cleanup")

        super().cleanup()


# ------- UTILITIES ------- #


def get_db_client() -> "DatabaseManagerClient":
    return get_database_manager_client()


def get_db_async_client() -> "DatabaseManagerAsyncClient":
    return get_database_manager_async_client()


@func_retry
async def send_async_execution_update(
    entry: GraphExecution | NodeExecutionResult | None,
) -> None:
    if entry is None:
        return
    await get_async_execution_event_bus().publish(entry)


@func_retry
def send_execution_update(entry: GraphExecution | NodeExecutionResult | None):
    if entry is None:
        return
    return get_execution_event_bus().publish(entry)


async def async_update_node_execution_status(
    db_client: "DatabaseManagerAsyncClient",
    exec_id: str,
    status: ExecutionStatus,
    execution_data: BlockInput | None = None,
    stats: dict[str, Any] | None = None,
) -> NodeExecutionResult:
    """Sets status and fetches+broadcasts the latest state of the node execution"""
    exec_update = await db_client.update_node_execution_status(
        exec_id, status, execution_data, stats
    )
    await send_async_execution_update(exec_update)
    return exec_update


def update_node_execution_status(
    db_client: "DatabaseManagerClient",
    exec_id: str,
    status: ExecutionStatus,
    execution_data: BlockInput | None = None,
    stats: dict[str, Any] | None = None,
) -> NodeExecutionResult:
    """Sets status and fetches+broadcasts the latest state of the node execution"""
    exec_update = db_client.update_node_execution_status(
        exec_id, status, execution_data, stats
    )
    send_execution_update(exec_update)
    return exec_update


async def async_update_graph_execution_state(
    db_client: "DatabaseManagerAsyncClient",
    graph_exec_id: str,
    status: ExecutionStatus | None = None,
    stats: GraphExecutionStats | None = None,
) -> GraphExecution | None:
    """Sets status and fetches+broadcasts the latest state of the graph execution"""
    graph_update = await db_client.update_graph_execution_stats(
        graph_exec_id, status, stats
    )
    if graph_update:
        await send_async_execution_update(graph_update)
    else:
        logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
    return graph_update


def update_graph_execution_state(
    db_client: "DatabaseManagerClient",
    graph_exec_id: str,
    status: ExecutionStatus | None = None,
    stats: GraphExecutionStats | None = None,
) -> GraphExecution | None:
    """Sets status and fetches+broadcasts the latest state of the graph execution"""
    graph_update = db_client.update_graph_execution_stats(graph_exec_id, status, stats)
    if graph_update:
        send_execution_update(graph_update)
    else:
        logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
    return graph_update


@asynccontextmanager
async def synchronized(key: str, timeout: int = settings.config.cluster_lock_timeout):
    r = await redis.get_redis_async()
    lock: AsyncRedisLock = r.lock(f"lock:{key}", timeout=timeout)
    try:
        await lock.acquire()
        yield
    finally:
        if await lock.locked() and await lock.owned():
            try:
                await lock.release()
            except Exception as e:
                logger.warning(f"Failed to release lock for key {key}: {e}")


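# Example use of `synchronized` (illustrative only; `charge_user` is a
# hypothetical helper, not part of this module): serialize a critical section
# across pods under a single Redis lock key.
#
#     async with synchronized(f"user-credit:{user_id}"):
#         await charge_user(user_id, cost)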
def increment_execution_count(user_id: str) -> int:
    """
    Increment the execution count for a given user;
    this count is used to charge the user for the execution cost.
    """
    r = redis.get_redis()
    k = f"uec:{user_id}"  # User Execution Count global key
    counter = cast(int, r.incr(k))
    if counter == 1:
        r.expire(k, settings.config.execution_counter_expiration_time)
    return counter