fix(backend): improve executor reliability and error handling (#10526)

This PR improves the reliability of the executor system by addressing
several race conditions and improving error handling throughout the
execution pipeline.

### Changes 🏗️

- **Consolidated exception handling**: Now using `BaseException` to
properly catch all types of interruptions including `CancelledError` and
`SystemExit`
- **Atomic stats updates**: Moved node execution stats updates to be
atomic with graph stats updates to prevent race conditions
- **Improved cleanup handling**: Added proper timeout handling (3600s)
for stuck executions during cleanup
- **Fixed concurrent update race conditions**: Node execution updates
are now properly synchronized with graph execution updates
- **Better error propagation**: Improved error type preservation and
status management throughout the execution chain
- **Graph resumption support**: Added proper handling for resuming
terminated and failed graph executions
- **Removed deprecated methods**: Removed `update_node_execution_stats`
in favor of atomic updates

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Execute a graph with multiple nodes and verify stats are updated correctly
  - [x] Cancel a running graph execution and verify proper cleanup
  - [x] Simulate node failures and verify error propagation
  - [x] Test graph resumption after termination/failure
  - [x] Verify no race conditions in concurrent node execution updates

#### For configuration changes:
- [x] `.env.example` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my
changes
- [x] I have included a list of my configuration changes in the PR
description (under **Changes**)

🤖 Generated with [Claude Code](https://claude.ai/code)

---------

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Zamil Majdy
2025-08-02 18:41:59 +08:00
committed by GitHub
parent 4283798dc2
commit 69d873debc
10 changed files with 342 additions and 277 deletions

View File

@@ -1,4 +1,3 @@
import asyncio
import logging
from typing import Any, Optional
@@ -95,23 +94,14 @@ class AgentExecutorBlock(Block):
logger=logger,
):
yield name, data
except asyncio.CancelledError:
except BaseException as e:
await self._stop(
graph_exec_id=graph_exec.id,
user_id=input_data.user_id,
logger=logger,
)
logger.warning(
f"Execution of graph {input_data.graph_id}v{input_data.graph_version} was cancelled."
)
except Exception as e:
await self._stop(
graph_exec_id=graph_exec.id,
user_id=input_data.user_id,
logger=logger,
)
logger.error(
f"Execution of graph {input_data.graph_id}v{input_data.graph_version} failed: {e}, execution is stopped."
f"Execution of graph {input_data.graph_id}v{input_data.graph_version} failed: {e.__class__.__name__} {str(e)}; execution is stopped."
)
raise
@@ -197,6 +187,7 @@ class AgentExecutorBlock(Block):
await execution_utils.stop_graph_execution(
graph_exec_id=graph_exec_id,
user_id=user_id,
wait_timeout=3600,
)
logger.info(f"Execution {log_id} stopped successfully.")
except Exception as e:

View File

@@ -58,7 +58,7 @@ from .includes import (
GRAPH_EXECUTION_INCLUDE_WITH_NODES,
graph_execution_include,
)
from .model import GraphExecutionStats, NodeExecutionStats
from .model import GraphExecutionStats
T = TypeVar("T")
@@ -636,6 +636,8 @@ async def update_graph_execution_stats(
"OR": [
{"executionStatus": ExecutionStatus.RUNNING},
{"executionStatus": ExecutionStatus.QUEUED},
# Terminated graph can be resumed.
{"executionStatus": ExecutionStatus.TERMINATED},
],
},
data=update_data,
@@ -652,27 +654,6 @@ async def update_graph_execution_stats(
return GraphExecution.from_db(graph_exec)
async def update_node_execution_stats(
node_exec_id: str, stats: NodeExecutionStats
) -> NodeExecutionResult:
data = stats.model_dump()
if isinstance(data["error"], Exception):
data["error"] = str(data["error"])
res = await AgentNodeExecution.prisma().update(
where={"id": node_exec_id},
data={
"stats": SafeJson(data),
"endedTime": datetime.now(tz=timezone.utc),
},
include=EXECUTION_RESULT_INCLUDE,
)
if not res:
raise ValueError(f"Node execution {node_exec_id} not found.")
return NodeExecutionResult.from_db(res)
async def update_node_execution_status_batch(
node_exec_ids: list[str],
status: ExecutionStatus,

View File

@@ -644,7 +644,7 @@ class NodeExecutionStats(BaseModel):
arbitrary_types_allowed=True,
)
error: Optional[Exception | str] = None
error: Optional[BaseException | str] = None
walltime: float = 0
cputime: float = 0
input_size: int = 0

View File

@@ -16,7 +16,6 @@ from backend.data.execution import (
set_execution_kv_data,
update_graph_execution_start_time,
update_graph_execution_stats,
update_node_execution_stats,
update_node_execution_status,
update_node_execution_status_batch,
upsert_execution_input,
@@ -106,7 +105,6 @@ class DatabaseManager(AppService):
update_node_execution_status_batch = _(update_node_execution_status_batch)
update_graph_execution_start_time = _(update_graph_execution_start_time)
update_graph_execution_stats = _(update_graph_execution_stats)
update_node_execution_stats = _(update_node_execution_stats)
upsert_execution_input = _(upsert_execution_input)
upsert_execution_output = _(upsert_execution_output)
get_execution_kv_data = _(get_execution_kv_data)
@@ -167,7 +165,6 @@ class DatabaseManagerClient(AppServiceClient):
update_node_execution_status_batch = _(d.update_node_execution_status_batch)
update_graph_execution_start_time = _(d.update_graph_execution_start_time)
update_graph_execution_stats = _(d.update_graph_execution_stats)
update_node_execution_stats = _(d.update_node_execution_stats)
upsert_execution_input = _(d.upsert_execution_input)
upsert_execution_output = _(d.upsert_execution_output)
get_execution_kv_data = _(d.get_execution_kv_data)
@@ -230,7 +227,6 @@ class DatabaseManagerAsyncClient(AppServiceClient):
upsert_execution_input = d.upsert_execution_input
upsert_execution_output = d.upsert_execution_output
update_graph_execution_stats = d.update_graph_execution_stats
update_node_execution_stats = d.update_node_execution_stats
update_node_execution_status = d.update_node_execution_status
update_node_execution_status_batch = d.update_node_execution_status_batch
update_user_integrations = d.update_user_integrations

View File

@@ -6,7 +6,7 @@ import sys
import threading
import time
from collections import defaultdict
from concurrent.futures import CancelledError, Future, ProcessPoolExecutor
from concurrent.futures import Future, ProcessPoolExecutor
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
@@ -410,7 +410,8 @@ class Executor:
cls,
node_exec: NodeExecutionEntry,
node_exec_progress: NodeExecutionProgress,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]],
graph_stats_pair: tuple[GraphExecutionStats, threading.Lock],
) -> NodeExecutionStats:
log_metadata = LogMetadata(
logger=_logger,
@@ -424,25 +425,52 @@ class Executor:
db_client = get_db_async_client()
node = await db_client.get_node(node_exec.node_id)
execution_stats = NodeExecutionStats()
timing_info, _ = await cls._on_node_execution(
timing_info, execution_stats = await cls._on_node_execution(
node=node,
node_exec=node_exec,
node_exec_progress=node_exec_progress,
db_client=db_client,
log_metadata=log_metadata,
stats=execution_stats,
nodes_input_masks=nodes_input_masks,
)
execution_stats.walltime = timing_info.wall_time
execution_stats.cputime = timing_info.cpu_time
if isinstance(execution_stats.error, Exception):
execution_stats.error = str(execution_stats.error)
exec_update = await db_client.update_node_execution_stats(
node_exec.node_exec_id, execution_stats
graph_stats, graph_stats_lock = graph_stats_pair
with graph_stats_lock:
graph_stats.node_count += 1 + execution_stats.extra_steps
graph_stats.nodes_cputime += execution_stats.cputime
graph_stats.nodes_walltime += execution_stats.walltime
graph_stats.cost += execution_stats.extra_cost
if isinstance(execution_stats.error, Exception):
graph_stats.node_error_count += 1
node_error = execution_stats.error
node_stats = execution_stats.model_dump()
if node_error and not isinstance(node_error, str):
node_stats["error"] = str(node_error) or node_stats.__class__.__name__
if isinstance(node_error, Exception):
status = ExecutionStatus.FAILED
elif isinstance(node_error, BaseException):
status = ExecutionStatus.TERMINATED
else:
status = ExecutionStatus.COMPLETED
await async_update_node_execution_status(
db_client=db_client,
exec_id=node_exec.node_exec_id,
status=status,
stats=node_stats,
)
await send_async_execution_update(exec_update)
await async_update_graph_execution_state(
db_client=db_client,
graph_exec_id=node_exec.graph_exec_id,
stats=graph_stats,
)
return execution_stats
@classmethod
@@ -454,9 +482,10 @@ class Executor:
node_exec_progress: NodeExecutionProgress,
db_client: "DatabaseManagerAsyncClient",
log_metadata: LogMetadata,
stats: NodeExecutionStats | None = None,
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
):
) -> NodeExecutionStats:
stats = NodeExecutionStats()
try:
log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
await async_update_node_execution_status(
@@ -480,19 +509,26 @@ class Executor:
)
)
log_metadata.info(f"Finished node execution {node_exec.node_exec_id}")
except Exception as e:
# Avoid user error being marked as an actual error.
except BaseException as e:
stats.error = e
if isinstance(e, ValueError):
# Avoid user error being marked as an actual error.
log_metadata.info(
f"Failed node execution {node_exec.node_exec_id}: {e}"
f"Expected failure on node execution {node_exec.node_exec_id}: {e}"
)
elif isinstance(e, Exception):
# If the exception is not a ValueError, it is unexpected.
log_metadata.exception(
f"Unexpected failure on node execution {node_exec.node_exec_id}: {type(e).__name__} - {e}"
)
else:
log_metadata.exception(
f"Failed node execution {node_exec.node_exec_id}: {e}"
# CancelledError or SystemExit
log_metadata.warning(
f"Interuption error on node execution {node_exec.node_exec_id}: {type(e).__name__}"
)
if stats is not None:
stats.error = e
return stats
@classmethod
@func_retry
@@ -551,6 +587,16 @@ class Executor:
log_metadata.info(
f"⚙️ Graph execution #{graph_exec.graph_exec_id} is already running, continuing where it left off."
)
elif exec_meta.status == ExecutionStatus.FAILED:
exec_meta.status = ExecutionStatus.RUNNING
log_metadata.info(
f"⚙️ Graph execution #{graph_exec.graph_exec_id} was disturbed, continuing where it left off."
)
update_graph_execution_state(
db_client=db_client,
graph_exec_id=graph_exec.graph_exec_id,
status=ExecutionStatus.RUNNING,
)
else:
log_metadata.warning(
f"Skipped graph execution {graph_exec.graph_exec_id}, the graph execution status is `{exec_meta.status}`."
@@ -602,14 +648,13 @@ class Executor:
except Exception as e:
log_metadata.error(f"Failed to generate activity status: {str(e)}")
# Don't set activity_status on exception - let it remain None/unset
if graph_exec_result := db_client.update_graph_execution_stats(
update_graph_execution_state(
db_client=db_client,
graph_exec_id=graph_exec.graph_exec_id,
status=status,
stats=exec_stats,
):
send_execution_update(graph_exec_result)
)
cls._handle_agent_run_notif(db_client, graph_exec, exec_stats)
@@ -664,7 +709,6 @@ class Executor:
execution_stats.cost += cost
@classmethod
@func_retry
@time_measured
def _on_graph_execution(
cls,
@@ -682,47 +726,11 @@ class Executor:
execution_status: ExecutionStatus = ExecutionStatus.RUNNING
error: Exception | None = None
db_client = get_db_client()
def on_done_task(node_exec_id: str, result: object):
if not isinstance(result, NodeExecutionStats):
log_metadata.error(f"Unexpected result #{node_exec_id}: {type(result)}")
return
nonlocal execution_stats
execution_stats.node_count += 1 + result.extra_steps
execution_stats.nodes_cputime += result.cputime
execution_stats.nodes_walltime += result.walltime
execution_stats.cost += result.extra_cost
if (err := result.error) and isinstance(err, Exception):
execution_stats.node_error_count += 1
update_node_execution_status(
db_client=db_client,
exec_id=node_exec_id,
status=ExecutionStatus.FAILED,
)
else:
update_node_execution_status(
db_client=db_client,
exec_id=node_exec_id,
status=ExecutionStatus.COMPLETED,
)
if _graph_exec := db_client.update_graph_execution_stats(
graph_exec_id=graph_exec.graph_exec_id,
stats=execution_stats,
):
send_execution_update(_graph_exec)
else:
log_metadata.error(
"Callback for finished node execution "
f"#{node_exec_id} could not update execution stats "
f"for graph execution #{graph_exec.graph_exec_id}; "
f"triggered while graph exec status = {execution_status}"
)
execution_stats_lock = threading.Lock()
# State holders ----------------------------------------------------
running_node_execution: dict[str, NodeExecutionProgress] = defaultdict(
lambda: NodeExecutionProgress(on_done_task=on_done_task)
NodeExecutionProgress
)
running_node_evaluation: dict[str, Future] = {}
execution_queue = ExecutionQueue[NodeExecutionEntry]()
@@ -741,7 +749,11 @@ class Executor:
# ------------------------------------------------------------
for node_exec in db_client.get_node_executions(
graph_exec.graph_exec_id,
statuses=[ExecutionStatus.RUNNING, ExecutionStatus.QUEUED],
statuses=[
ExecutionStatus.RUNNING,
ExecutionStatus.QUEUED,
ExecutionStatus.TERMINATED,
],
):
execution_queue.add(node_exec.to_node_execution_entry())
@@ -750,8 +762,7 @@ class Executor:
# ------------------------------------------------------------
while not execution_queue.empty():
if cancel.is_set():
execution_status = ExecutionStatus.TERMINATED
return execution_stats, execution_status, error
break
queued_node_exec = execution_queue.get()
@@ -804,6 +815,10 @@ class Executor:
node_exec=queued_node_exec,
node_exec_progress=running_node_execution[node_id],
nodes_input_masks=nodes_input_masks,
graph_stats_pair=(
execution_stats,
execution_stats_lock,
),
),
cls.node_execution_loop,
)
@@ -816,14 +831,16 @@ class Executor:
while execution_queue.empty() and (
running_node_execution or running_node_evaluation
):
if cancel.is_set():
break
# --------------------------------------------------
# Handle inflight evaluations ---------------------
# --------------------------------------------------
node_output_found = False
for node_id, inflight_exec in list(running_node_execution.items()):
if cancel.is_set():
execution_status = ExecutionStatus.TERMINATED
return execution_stats, execution_status, error
break
# node evaluation future -----------------
if inflight_eval := running_node_evaluation.get(node_id):
@@ -864,26 +881,36 @@ class Executor:
time.sleep(0.1)
# loop done --------------------------------------------------
# Determine final execution status based on whether there was an error
execution_status = (
ExecutionStatus.FAILED if error else ExecutionStatus.COMPLETED
# Determine final execution status based on whether there was an error or termination
if cancel.is_set():
execution_status = ExecutionStatus.TERMINATED
elif error is not None:
execution_status = ExecutionStatus.FAILED
else:
execution_status = ExecutionStatus.COMPLETED
return execution_stats, execution_status, error
except BaseException as exc:
execution_status = ExecutionStatus.FAILED
error = (
exc
if isinstance(exc, Exception)
else Exception(f"{exc.__class__.__name__}: {exc}")
)
except CancelledError as exc:
execution_status = ExecutionStatus.TERMINATED
error = exc
log_metadata.exception(
f"Cancelled graph execution {graph_exec.graph_exec_id}: {error}"
)
except Exception as exc:
execution_status = ExecutionStatus.FAILED
error = exc
known_errors = (InsufficientBalanceError,)
if isinstance(error, known_errors):
return execution_stats, execution_status, error
log_metadata.exception(
f"Failed graph execution {graph_exec.graph_exec_id}: {error}"
)
raise
finally:
# Use helper method with error handling to ensure cleanup never fails
cls._cleanup_graph_execution(
execution_queue=execution_queue,
running_node_execution=running_node_execution,
running_node_evaluation=running_node_evaluation,
execution_status=execution_status,
@@ -892,12 +919,12 @@ class Executor:
log_metadata=log_metadata,
db_client=db_client,
)
return execution_stats, execution_status, error
@classmethod
@error_logged(swallow=True)
def _cleanup_graph_execution(
cls,
execution_queue: ExecutionQueue[NodeExecutionEntry],
running_node_execution: dict[str, "NodeExecutionProgress"],
running_node_evaluation: dict[str, Future],
execution_status: ExecutionStatus,
@@ -911,6 +938,22 @@ class Executor:
This method is decorated with @error_logged(swallow=True) to ensure cleanup
never fails in the finally block.
"""
# Cancel and wait for all node evaluations to complete
for node_id, inflight_eval in running_node_evaluation.items():
if inflight_eval.done():
continue
log_metadata.info(f"Stopping node evaluation {node_id}")
inflight_eval.cancel()
for node_id, inflight_eval in running_node_evaluation.items():
try:
inflight_eval.result(timeout=3600.0)
except TimeoutError:
log_metadata.exception(
f"Node evaluation #{node_id} did not stop in time, "
"it may be stuck or taking too long."
)
# Cancel and wait for all node executions to complete
for node_id, inflight_exec in running_node_execution.items():
if inflight_exec.is_done():
@@ -918,51 +961,22 @@ class Executor:
log_metadata.info(f"Stopping node execution {node_id}")
inflight_exec.stop()
for node_id, inflight_eval in running_node_evaluation.items():
if inflight_eval.done():
continue
log_metadata.info(f"Stopping node evaluation {node_id}")
inflight_eval.cancel()
for node_id, inflight_exec in running_node_execution.items():
if inflight_exec.is_done():
continue
try:
inflight_exec.wait_for_cancellation(timeout=60.0)
inflight_exec.wait_for_done(timeout=3600.0)
except TimeoutError:
log_metadata.exception(
f"Node execution #{node_id} did not stop in time, "
"it may be stuck or taking too long."
)
for node_id, inflight_eval in running_node_evaluation.items():
if inflight_eval.done():
continue
try:
inflight_eval.result(timeout=60.0)
except TimeoutError:
log_metadata.exception(
f"Node evaluation #{node_id} did not stop in time, "
"it may be stuck or taking too long."
)
if execution_status in [ExecutionStatus.TERMINATED, ExecutionStatus.FAILED]:
inflight_executions = db_client.get_node_executions(
graph_exec_id,
statuses=[
ExecutionStatus.QUEUED,
ExecutionStatus.RUNNING,
],
include_exec_data=False,
)
db_client.update_node_execution_status_batch(
[node_exec.node_exec_id for node_exec in inflight_executions],
while queued_execution := execution_queue.get_or_none():
update_node_execution_status(
db_client=db_client,
exec_id=queued_execution.node_exec_id,
status=execution_status,
stats={"error": str(error)} if error else None,
)
for node_exec in inflight_executions:
node_exec.status = execution_status
send_execution_update(node_exec)
clean_exec_files(graph_exec_id)
@@ -1010,6 +1024,15 @@ class Executor:
nodes_input_masks=nodes_input_masks,
):
execution_queue.add(next_execution)
except asyncio.CancelledError as e:
log_metadata.warning(
f"Node execution {output.node_exec_id} was cancelled: {e}"
)
await async_update_node_execution_status(
db_client=db_client,
exec_id=output.node_exec_id,
status=ExecutionStatus.TERMINATED,
)
except Exception as e:
log_metadata.exception(f"Failed to process node output: {e}")
await db_client.upsert_execution_output(
@@ -1142,7 +1165,7 @@ class ExecutionManager(AppProcess):
def _consume_execution_run(self):
# Long-running executions are handled by:
# 1. Disabled consumer timeout (x-consumer-timeout: 0) allows unlimited execution time
# 1. Long consumer timeout (x-consumer-timeout) allows long running agent
# 2. Enhanced connection settings (5 retries, 1s delay) for quick reconnection
# 3. Process monitoring ensures failed executors release messages back to queue
@@ -1165,6 +1188,7 @@ class ExecutionManager(AppProcess):
run_channel.start_consuming()
raise RuntimeError(f"❌ run message consumer is stopped: {run_channel}")
@error_logged(swallow=True)
def _handle_cancel_message(
self,
channel: BlockingChannel,
@@ -1176,31 +1200,27 @@ class ExecutionManager(AppProcess):
Called whenever we receive a CANCEL message from the queue.
(With auto_ack=True, message is considered 'acked' automatically.)
"""
try:
request = CancelExecutionEvent.model_validate_json(body)
graph_exec_id = request.graph_exec_id
if not graph_exec_id:
logger.warning(
f"[{self.service_name}] Cancel message missing 'graph_exec_id'"
)
return
if graph_exec_id not in self.active_graph_runs:
logger.debug(
f"[{self.service_name}] Cancel received for {graph_exec_id} but not active."
)
return
request = CancelExecutionEvent.model_validate_json(body)
graph_exec_id = request.graph_exec_id
if not graph_exec_id:
logger.warning(
f"[{self.service_name}] Cancel message missing 'graph_exec_id'"
)
return
if graph_exec_id not in self.active_graph_runs:
logger.debug(
f"[{self.service_name}] Cancel received for {graph_exec_id} but not active."
)
return
_, cancel_event = self.active_graph_runs[graph_exec_id]
logger.info(f"[{self.service_name}] Received cancel for {graph_exec_id}")
if not cancel_event.is_set():
cancel_event.set()
else:
logger.debug(
f"[{self.service_name}] Cancel already set for {graph_exec_id}"
)
except Exception as e:
logger.exception(f"Error handling cancel message: {e}")
_, cancel_event = self.active_graph_runs[graph_exec_id]
logger.info(f"[{self.service_name}] Received cancel for {graph_exec_id}")
if not cancel_event.is_set():
cancel_event.set()
else:
logger.debug(
f"[{self.service_name}] Cancel already set for {graph_exec_id}"
)
def _handle_run_message(
self,
@@ -1229,7 +1249,8 @@ class ExecutionManager(AppProcess):
f"[{self.service_name}] Received RUN for graph_exec_id={graph_exec_id}"
)
if graph_exec_id in self.active_graph_runs:
logger.warning(
# TODO: Make this check cluster-wide, prevent duplicate runs across executor pods.
logger.error(
f"[{self.service_name}] Graph {graph_exec_id} already running; rejecting duplicate run."
)
channel.basic_nack(delivery_tag, requeue=False)
@@ -1249,7 +1270,7 @@ class ExecutionManager(AppProcess):
try:
if exec_error := f.exception():
logger.error(
f"[{self.service_name}] Execution for {graph_exec_id} failed: {exec_error}"
f"[{self.service_name}] Execution for {graph_exec_id} failed: {type(exec_error)} {exec_error}"
)
try:
channel.connection.add_callback_threadsafe(
@@ -1388,6 +1409,38 @@ def update_node_execution_status(
return exec_update
async def async_update_graph_execution_state(
db_client: "DatabaseManagerAsyncClient",
graph_exec_id: str,
status: ExecutionStatus | None = None,
stats: GraphExecutionStats | None = None,
) -> GraphExecution | None:
"""Sets status and fetches+broadcasts the latest state of the graph execution"""
graph_update = await db_client.update_graph_execution_stats(
graph_exec_id, status, stats
)
if graph_update:
await send_async_execution_update(graph_update)
else:
logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
return graph_update
def update_graph_execution_state(
db_client: "DatabaseManagerClient",
graph_exec_id: str,
status: ExecutionStatus | None = None,
stats: GraphExecutionStats | None = None,
) -> GraphExecution | None:
"""Sets status and fetches+broadcasts the latest state of the graph execution"""
graph_update = db_client.update_graph_execution_stats(graph_exec_id, status, stats)
if graph_update:
send_execution_update(graph_update)
else:
logger.error(f"Failed to update graph execution stats for {graph_exec_id}")
return graph_update
@asynccontextmanager
async def synchronized(key: str, timeout: int = 60):
r = await redis.get_redis_async()

View File

@@ -4,7 +4,7 @@ import threading
import time
from collections import defaultdict
from concurrent.futures import Future
from typing import TYPE_CHECKING, Any, Callable, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, cast
from autogpt_libs.utils.cache import thread_cached
from pydantic import BaseModel, JsonValue
@@ -669,13 +669,13 @@ def create_execution_queue_config() -> RabbitMQConfig:
durable=True,
auto_delete=False,
arguments={
# x-consumer-timeout (0 = disabled)
# x-consumer-timeout (1 week)
# Problem: Default 30-minute consumer timeout kills long-running graph executions
# Original error: "Consumer acknowledgement timed out after 1800000 ms (30 minutes)"
# Solution: Disable consumer timeout entirely - let graphs run indefinitely
# Safety: Heartbeat mechanism now handles dead consumer detection instead
# Use case: Graph executions that take hours to complete (AI model training, etc.)
"x-consumer-timeout": 0,
"x-consumer-timeout": (7 * 24 * 60 * 60 * 1000), # 7 days in milliseconds
},
)
cancel_queue = Queue(
@@ -737,51 +737,28 @@ async def stop_graph_execution(
ExecutionStatus.QUEUED,
ExecutionStatus.INCOMPLETE,
]:
break
# If the graph is still on the queue, we can prevent them from being executed
# by setting the status to TERMINATED.
graph_exec.status = ExecutionStatus.TERMINATED
await asyncio.gather(
# Update graph execution status
db.update_graph_execution_stats(
graph_exec_id=graph_exec.id,
status=ExecutionStatus.TERMINATED,
),
# Publish graph execution event
get_async_execution_event_bus().publish(graph_exec),
)
return
if graph_exec.status == ExecutionStatus.RUNNING:
await asyncio.sleep(0.1)
# Set the termination status if the graph is not stopped after the timeout.
if graph_exec := await db.get_graph_execution_meta(
execution_id=graph_exec_id, user_id=user_id
):
# If the graph is still on the queue, we can prevent them from being executed
# by setting the status to TERMINATED.
node_execs = await db.get_node_executions(
graph_exec_id=graph_exec_id,
statuses=[
ExecutionStatus.QUEUED,
ExecutionStatus.RUNNING,
],
include_exec_data=False,
)
graph_exec.status = ExecutionStatus.TERMINATED
for node_exec in node_execs:
node_exec.status = ExecutionStatus.TERMINATED
await asyncio.gather(
# Update node execution statuses
db.update_node_execution_status_batch(
[node_exec.node_exec_id for node_exec in node_execs],
ExecutionStatus.TERMINATED,
),
# Publish node execution events
*[
get_async_execution_event_bus().publish(node_exec)
for node_exec in node_execs
],
)
await asyncio.gather(
# Update graph execution status
db.update_graph_execution_stats(
graph_exec_id=graph_exec_id,
status=ExecutionStatus.TERMINATED,
),
# Publish graph execution event
get_async_execution_event_bus().publish(graph_exec),
)
raise TimeoutError(
f"Graph execution #{graph_exec_id} will need to take longer than {wait_timeout} seconds to stop. "
f"You can check the status of the execution in the UI or try again later."
)
async def add_graph_execution(
@@ -888,13 +865,9 @@ class ExecutionOutputEntry(BaseModel):
class NodeExecutionProgress:
def __init__(
self,
on_done_task: Callable[[str, object], None],
):
def __init__(self):
self.output: dict[str, list[ExecutionOutputEntry]] = defaultdict(list)
self.tasks: dict[str, Future] = {}
self.on_done_task = on_done_task
self._lock = threading.Lock()
def add_task(self, node_exec_id: str, task: Future):
@@ -934,7 +907,9 @@ class NodeExecutionProgress:
except TimeoutError:
pass
except Exception as e:
logger.error(f"Task for exec ID {exec_id} failed with error: {str(e)}")
logger.error(
f"Task for exec ID {exec_id} failed with error: {e.__class__.__name__} {str(e)}"
)
pass
return self.is_done(0)
@@ -952,7 +927,7 @@ class NodeExecutionProgress:
cancelled_ids.append(task_id)
return cancelled_ids
def wait_for_cancellation(self, timeout: float = 5.0):
def wait_for_done(self, timeout: float = 5.0):
"""
Wait for all cancelled tasks to complete cancellation.
@@ -962,9 +937,12 @@ class NodeExecutionProgress:
start_time = time.time()
while time.time() - start_time < timeout:
# Check if all tasks are done (either completed or cancelled)
if all(task.done() for task in self.tasks.values()):
return True
while self.pop_output():
pass
if self.is_done():
return
time.sleep(0.1) # Small delay to avoid busy waiting
raise TimeoutError(
@@ -983,11 +961,7 @@ class NodeExecutionProgress:
if self.output[exec_id]:
return False
if task := self.tasks.pop(exec_id):
try:
self.on_done_task(exec_id, task.result())
except Exception as e:
logger.error(f"Task for exec ID {exec_id} failed with error: {str(e)}")
self.tasks.pop(exec_id)
return True
def _next_exec(self) -> str | None:

View File

@@ -16,6 +16,8 @@ from typing import (
from pydantic import BaseModel
from backend.util.logging import TruncatedLogger
class TimingInfo(BaseModel):
cpu_time: float
@@ -37,7 +39,7 @@ def _end_measurement(
P = ParamSpec("P")
T = TypeVar("T")
logger = logging.getLogger(__name__)
logger = TruncatedLogger(logging.getLogger(__name__))
def time_measured(func: Callable[P, T]) -> Callable[P, Tuple[TimingInfo, T]]:
@@ -120,7 +122,7 @@ def error_logged(
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T | None:
try:
return f(*args, **kwargs)
except Exception as e:
except BaseException as e:
logger.exception(
f"Error when calling function {f.__name__} with arguments {args} {kwargs}: {e}"
)
@@ -177,13 +179,13 @@ def async_error_logged(*, swallow: bool = True) -> (
"""
def decorator(
f: Callable[P, Coroutine[Any, Any, T]]
f: Callable[P, Coroutine[Any, Any, T]],
) -> Callable[P, Coroutine[Any, Any, T | None]]:
@functools.wraps(f)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T | None:
try:
return await f(*args, **kwargs)
except Exception as e:
except BaseException as e:
logger.exception(
f"Error when calling async function {f.__name__} with arguments {args} {kwargs}: {e}"
)

View File

@@ -6,13 +6,83 @@ import time
from functools import wraps
from uuid import uuid4
from tenacity import retry, stop_after_attempt, wait_exponential
from tenacity import (
retry,
retry_if_not_exception_type,
stop_after_attempt,
wait_exponential_jitter,
)
from backend.util.process import get_service_name
logger = logging.getLogger(__name__)
def _create_retry_callback(context: str = ""):
"""Create a retry callback with optional context."""
def callback(retry_state):
attempt_number = retry_state.attempt_number
exception = retry_state.outcome.exception()
func_name = getattr(retry_state.fn, "__name__", "unknown")
prefix = f"{context}: " if context else ""
if retry_state.outcome.failed and retry_state.next_action is None:
# Final failure
logger.error(
f"{prefix}Giving up after {attempt_number} attempts for '{func_name}': "
f"{type(exception).__name__}: {exception}"
)
else:
# Retry attempt
logger.warning(
f"{prefix}Retry attempt {attempt_number} for '{func_name}': "
f"{type(exception).__name__}: {exception}"
)
return callback
def create_retry_decorator(
max_attempts: int = 5,
exclude_exceptions: tuple[type[BaseException], ...] = (),
max_wait: float = 30.0,
context: str = "",
reraise: bool = True,
):
"""
Create a preconfigured retry decorator with sensible defaults.
Uses exponential backoff with jitter by default.
Args:
max_attempts: Maximum number of attempts (default: 5)
exclude_exceptions: Tuple of exception types to not retry on
max_wait: Maximum wait time in seconds (default: 30)
context: Optional context string for log messages
reraise: Whether to reraise the final exception (default: True)
Returns:
Configured retry decorator
"""
if exclude_exceptions:
return retry(
stop=stop_after_attempt(max_attempts),
wait=wait_exponential_jitter(max=max_wait),
before_sleep=_create_retry_callback(context),
reraise=reraise,
retry=retry_if_not_exception_type(exclude_exceptions),
)
else:
return retry(
stop=stop_after_attempt(max_attempts),
wait=wait_exponential_jitter(max=max_wait),
before_sleep=_create_retry_callback(context),
reraise=reraise,
)
def _log_prefix(resource_name: str, conn_id: str):
"""
Returns a prefix string for logging purposes.
@@ -26,8 +96,6 @@ def conn_retry(
resource_name: str,
action_name: str,
max_retry: int = 5,
multiplier: int = 1,
min_wait: float = 1,
max_wait: float = 30,
):
conn_id = str(uuid4())
@@ -35,13 +103,20 @@ def conn_retry(
def on_retry(retry_state):
prefix = _log_prefix(resource_name, conn_id)
exception = retry_state.outcome.exception()
logger.warning(f"{prefix} {action_name} failed: {exception}. Retrying now...")
if retry_state.outcome.failed and retry_state.next_action is None:
logger.error(f"{prefix} {action_name} failed after retries: {exception}")
else:
logger.warning(
f"{prefix} {action_name} failed: {exception}. Retrying now..."
)
def decorator(func):
is_coroutine = asyncio.iscoroutinefunction(func)
# Use static retry configuration
retry_decorator = retry(
stop=stop_after_attempt(max_retry + 1),
wait=wait_exponential(multiplier=multiplier, min=min_wait, max=max_wait),
stop=stop_after_attempt(max_retry + 1), # +1 for the initial attempt
wait=wait_exponential_jitter(max=max_wait),
before_sleep=on_retry,
reraise=True,
)
@@ -76,11 +151,8 @@ def conn_retry(
return decorator
func_retry = retry(
reraise=False,
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=1, min=1, max=30),
)
# Preconfigured retry decorator for general functions
func_retry = create_retry_decorator(max_attempts=5, reraise=False)
def continuous_retry(*, retry_delay: float = 1.0):

View File

@@ -8,7 +8,7 @@ from backend.util.retry import conn_retry
def test_conn_retry_sync_function():
retry_count = 0
@conn_retry("Test", "Test function", max_retry=2, max_wait=0.1, min_wait=0.1)
@conn_retry("Test", "Test function", max_retry=2, max_wait=0.1)
def test_function():
nonlocal retry_count
retry_count -= 1
@@ -30,7 +30,7 @@ def test_conn_retry_sync_function():
async def test_conn_retry_async_function():
retry_count = 0
@conn_retry("Test", "Test function", max_retry=2, max_wait=0.1, min_wait=0.1)
@conn_retry("Test", "Test function", max_retry=2, max_wait=0.1)
async def test_function():
nonlocal retry_count
await asyncio.sleep(1)

View File

@@ -1,4 +1,6 @@
import asyncio
import concurrent
import concurrent.futures
import inspect
import logging
import os
@@ -25,18 +27,12 @@ import uvicorn
from autogpt_libs.logging.utils import generate_uvicorn_config
from fastapi import FastAPI, Request, responses
from pydantic import BaseModel, TypeAdapter, create_model
from tenacity import (
retry,
retry_if_not_exception_type,
stop_after_attempt,
wait_exponential_jitter,
)
import backend.util.exceptions as exceptions
from backend.util.json import to_dict
from backend.util.metrics import sentry_init
from backend.util.process import AppProcess, get_service_name
from backend.util.retry import conn_retry
from backend.util.retry import conn_retry, create_retry_decorator
from backend.util.settings import Config
logger = logging.getLogger(__name__)
@@ -290,18 +286,19 @@ def get_service_client(
if not request_retry:
return fn
return retry(
reraise=True,
stop=stop_after_attempt(api_comm_retry),
wait=wait_exponential_jitter(max=5.0),
retry=retry_if_not_exception_type(
(
# Don't retry these specific exceptions that won't be fixed by retrying
ValueError, # Invalid input/parameters
KeyError, # Missing required data
TypeError, # Wrong data types
AttributeError, # Missing attributes
)
# Use preconfigured retry decorator for service communication
return create_retry_decorator(
max_attempts=api_comm_retry,
max_wait=5.0,
context="Service communication",
exclude_exceptions=(
# Don't retry these specific exceptions that won't be fixed by retrying
ValueError, # Invalid input/parameters
KeyError, # Missing required data
TypeError, # Wrong data types
AttributeError, # Missing attributes
asyncio.CancelledError, # Task was cancelled
concurrent.futures.CancelledError, # Future was cancelled
),
)(fn)
@@ -379,7 +376,6 @@ def get_service_client(
self._connection_failure_count = 0
return response.json()
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error in {method_name}: {e.response.text}")
error = RemoteCallError.model_validate(e.response.json())
# DEBUG HELP: if you made a custom exception, make sure you override self.args to be how to make your exception
raise EXCEPTION_MAPPING.get(error.type, Exception)(