mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(backend): Make Graph & Node Execution Stats Update Durable (#10529)
Graph and node execution can fail for many reasons, and a failure can sometimes disrupt stats tracking and produce inaccurate results. The scope of this PR is to minimize such issues.

### Changes 🏗️

* Catch `BaseException` in the `time_measured` decorator so that `asyncio.CancelledError` is also caught.
* Make sure the node & graph stats updates are executed on cancellation and on exception.
* Protect the graph execution stats update with the thread lock to avoid a race condition.

### Checklist 📋

#### For code changes:

- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] Existing automated tests.

---------

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -412,20 +412,24 @@ class Executor:
|
||||
)
|
||||
db_client = get_db_async_client()
|
||||
node = await db_client.get_node(node_exec.node_id)
|
||||
execution_stats = NodeExecutionStats()
|
||||
|
||||
timing_info, execution_stats = await cls._on_node_execution(
|
||||
timing_info, status = await cls._on_node_execution(
|
||||
node=node,
|
||||
node_exec=node_exec,
|
||||
node_exec_progress=node_exec_progress,
|
||||
stats=execution_stats,
|
||||
db_client=db_client,
|
||||
log_metadata=log_metadata,
|
||||
nodes_input_masks=nodes_input_masks,
|
||||
)
|
||||
if isinstance(status, BaseException):
|
||||
raise status
|
||||
|
||||
execution_stats.walltime = timing_info.wall_time
|
||||
execution_stats.cputime = timing_info.cpu_time
|
||||
|
||||
graph_stats, graph_stats_lock = graph_stats_pair
|
||||
|
||||
with graph_stats_lock:
|
||||
graph_stats.node_count += 1 + execution_stats.extra_steps
|
||||
graph_stats.nodes_cputime += execution_stats.cputime
|
||||
@@ -439,24 +443,18 @@ class Executor:
|
||||
if node_error and not isinstance(node_error, str):
|
||||
node_stats["error"] = str(node_error) or node_stats.__class__.__name__
|
||||
|
||||
if isinstance(node_error, Exception):
|
||||
status = ExecutionStatus.FAILED
|
||||
elif isinstance(node_error, BaseException):
|
||||
status = ExecutionStatus.TERMINATED
|
||||
else:
|
||||
status = ExecutionStatus.COMPLETED
|
||||
|
||||
await async_update_node_execution_status(
|
||||
db_client=db_client,
|
||||
exec_id=node_exec.node_exec_id,
|
||||
status=status,
|
||||
stats=node_stats,
|
||||
)
|
||||
|
||||
await async_update_graph_execution_state(
|
||||
db_client=db_client,
|
||||
graph_exec_id=node_exec.graph_exec_id,
|
||||
stats=graph_stats,
|
||||
await asyncio.gather(
|
||||
async_update_node_execution_status(
|
||||
db_client=db_client,
|
||||
exec_id=node_exec.node_exec_id,
|
||||
status=status,
|
||||
stats=node_stats,
|
||||
),
|
||||
async_update_graph_execution_state(
|
||||
db_client=db_client,
|
||||
graph_exec_id=node_exec.graph_exec_id,
|
||||
stats=graph_stats,
|
||||
),
|
||||
)
|
||||
|
||||
return execution_stats
|
||||
@@ -468,12 +466,11 @@ class Executor:
|
||||
node: Node,
|
||||
node_exec: NodeExecutionEntry,
|
||||
node_exec_progress: NodeExecutionProgress,
|
||||
stats: NodeExecutionStats,
|
||||
db_client: "DatabaseManagerAsyncClient",
|
||||
log_metadata: LogMetadata,
|
||||
nodes_input_masks: Optional[dict[str, dict[str, JsonValue]]] = None,
|
||||
) -> NodeExecutionStats:
|
||||
stats = NodeExecutionStats()
|
||||
|
||||
) -> ExecutionStatus:
|
||||
try:
|
||||
log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
|
||||
await async_update_node_execution_status(
|
||||
@@ -497,6 +494,8 @@ class Executor:
|
||||
)
|
||||
)
|
||||
log_metadata.info(f"Finished node execution {node_exec.node_exec_id}")
|
||||
status = ExecutionStatus.COMPLETED
|
||||
|
||||
except BaseException as e:
|
||||
stats.error = e
|
||||
|
||||
@@ -505,18 +504,21 @@ class Executor:
|
||||
log_metadata.info(
|
||||
f"Expected failure on node execution {node_exec.node_exec_id}: {e}"
|
||||
)
|
||||
status = ExecutionStatus.FAILED
|
||||
elif isinstance(e, Exception):
|
||||
# If the exception is not a ValueError, it is unexpected.
|
||||
log_metadata.exception(
|
||||
f"Unexpected failure on node execution {node_exec.node_exec_id}: {type(e).__name__} - {e}"
|
||||
)
|
||||
status = ExecutionStatus.FAILED
|
||||
else:
|
||||
# CancelledError or SystemExit
|
||||
log_metadata.warning(
|
||||
f"Interuption error on node execution {node_exec.node_exec_id}: {type(e).__name__}"
|
||||
)
|
||||
status = ExecutionStatus.TERMINATED
|
||||
|
||||
return stats
|
||||
return status
|
||||
|
||||
@classmethod
|
||||
@func_retry
|
||||
@@ -591,29 +593,27 @@ class Executor:
|
||||
)
|
||||
return
|
||||
|
||||
timing_info, (exec_stats, status, error) = cls._on_graph_execution(
|
||||
if exec_meta.stats is None:
|
||||
exec_stats = GraphExecutionStats()
|
||||
else:
|
||||
exec_stats = exec_meta.stats.to_db()
|
||||
|
||||
timing_info, status = cls._on_graph_execution(
|
||||
graph_exec=graph_exec,
|
||||
cancel=cancel,
|
||||
log_metadata=log_metadata,
|
||||
execution_stats=(
|
||||
exec_meta.stats.to_db() if exec_meta.stats else GraphExecutionStats()
|
||||
),
|
||||
execution_stats=exec_stats,
|
||||
)
|
||||
exec_stats.walltime += timing_info.wall_time
|
||||
exec_stats.cputime += timing_info.cpu_time
|
||||
exec_stats.error = str(error) if error else exec_stats.error
|
||||
|
||||
if status not in {
|
||||
ExecutionStatus.COMPLETED,
|
||||
ExecutionStatus.TERMINATED,
|
||||
ExecutionStatus.FAILED,
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Graph Execution #{graph_exec.graph_exec_id} ended with unexpected status {status}"
|
||||
)
|
||||
|
||||
# Generate AI activity status before updating stats
|
||||
try:
|
||||
# Failure handling
|
||||
if isinstance(status, BaseException):
|
||||
raise status
|
||||
exec_meta.status = status
|
||||
|
||||
# Activity status handling
|
||||
activity_status = asyncio.run_coroutine_threadsafe(
|
||||
generate_activity_status_for_execution(
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
@@ -634,30 +634,29 @@ class Executor:
|
||||
"Activity status generation disabled, not setting field"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log_metadata.error(f"Failed to generate activity status: {str(e)}")
|
||||
# Communication handling
|
||||
cls._handle_agent_run_notif(db_client, graph_exec, exec_stats)
|
||||
|
||||
update_graph_execution_state(
|
||||
db_client=db_client,
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
status=status,
|
||||
stats=exec_stats,
|
||||
)
|
||||
|
||||
cls._handle_agent_run_notif(db_client, graph_exec, exec_stats)
|
||||
finally:
|
||||
update_graph_execution_state(
|
||||
db_client=db_client,
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
status=exec_meta.status,
|
||||
stats=exec_stats,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _charge_usage(
|
||||
cls,
|
||||
node_exec: NodeExecutionEntry,
|
||||
execution_count: int,
|
||||
execution_stats: GraphExecutionStats,
|
||||
):
|
||||
) -> int:
|
||||
total_cost = 0
|
||||
db_client = get_db_client()
|
||||
block = get_block(node_exec.block_id)
|
||||
if not block:
|
||||
logger.error(f"Block {node_exec.block_id} not found.")
|
||||
return
|
||||
return total_cost
|
||||
|
||||
cost, matching_filter = block_usage_cost(
|
||||
block=block, input_data=node_exec.inputs
|
||||
@@ -677,7 +676,7 @@ class Executor:
|
||||
reason=f"Ran block {node_exec.block_id} {block.name}",
|
||||
),
|
||||
)
|
||||
execution_stats.cost += cost
|
||||
total_cost += cost
|
||||
|
||||
cost, usage_count = execution_usage_cost(execution_count)
|
||||
if cost > 0:
|
||||
@@ -694,7 +693,9 @@ class Executor:
|
||||
reason=f"Execution Cost for {usage_count} blocks of ex_id:{node_exec.graph_exec_id} g_id:{node_exec.graph_id}",
|
||||
),
|
||||
)
|
||||
execution_stats.cost += cost
|
||||
total_cost += cost
|
||||
|
||||
return total_cost
|
||||
|
||||
@classmethod
|
||||
@time_measured
|
||||
@@ -704,7 +705,7 @@ class Executor:
|
||||
cancel: threading.Event,
|
||||
log_metadata: LogMetadata,
|
||||
execution_stats: GraphExecutionStats,
|
||||
) -> tuple[GraphExecutionStats, ExecutionStatus, Exception | None]:
|
||||
) -> ExecutionStatus:
|
||||
"""
|
||||
Returns:
|
||||
ExecutionStatus: The final status of the graph execution; statistics are recorded on the `execution_stats` argument.
|
||||
@@ -761,11 +762,12 @@ class Executor:
|
||||
|
||||
# Charge usage (may raise) ------------------------------
|
||||
try:
|
||||
cls._charge_usage(
|
||||
cost = cls._charge_usage(
|
||||
node_exec=queued_node_exec,
|
||||
execution_count=increment_execution_count(graph_exec.user_id),
|
||||
execution_stats=execution_stats,
|
||||
)
|
||||
with execution_stats_lock:
|
||||
execution_stats.cost += cost
|
||||
except InsufficientBalanceError as balance_error:
|
||||
error = balance_error # Set error to trigger FAILED status
|
||||
node_exec_id = queued_node_exec.node_exec_id
|
||||
@@ -877,10 +879,12 @@ class Executor:
|
||||
else:
|
||||
execution_status = ExecutionStatus.COMPLETED
|
||||
|
||||
return execution_stats, execution_status, error
|
||||
if error:
|
||||
execution_stats.error = str(error) or type(error).__name__
|
||||
|
||||
return execution_status
|
||||
|
||||
except BaseException as exc:
|
||||
execution_status = ExecutionStatus.FAILED
|
||||
error = (
|
||||
exc
|
||||
if isinstance(exc, Exception)
|
||||
@@ -889,7 +893,8 @@ class Executor:
|
||||
|
||||
known_errors = (InsufficientBalanceError,)
|
||||
if isinstance(error, known_errors):
|
||||
return execution_stats, execution_status, error
|
||||
execution_stats.error = str(error)
|
||||
return ExecutionStatus.FAILED
|
||||
|
||||
log_metadata.exception(
|
||||
f"Failed graph execution {graph_exec.graph_exec_id}: {error}"
|
||||
|
||||
@@ -841,7 +841,7 @@ async def add_graph_execution(
|
||||
await bus.publish(graph_exec)
|
||||
|
||||
return graph_exec
|
||||
except Exception as e:
|
||||
except BaseException as e:
|
||||
logger.error(f"Unable to publish graph #{graph_id} exec #{graph_exec.id}: {e}")
|
||||
await edb.update_node_execution_status_batch(
|
||||
[node_exec.node_exec_id for node_exec in graph_exec.node_executions],
|
||||
|
||||
@@ -42,16 +42,22 @@ T = TypeVar("T")
|
||||
logger = TruncatedLogger(logging.getLogger(__name__))
|
||||
|
||||
|
||||
def time_measured(func: Callable[P, T]) -> Callable[P, Tuple[TimingInfo, T]]:
|
||||
def time_measured(
|
||||
func: Callable[P, T],
|
||||
) -> Callable[P, Tuple[TimingInfo, T | BaseException]]:
|
||||
"""
|
||||
Decorator to measure the time taken by a synchronous function to execute.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> Tuple[TimingInfo, T]:
|
||||
def wrapper(
|
||||
*args: P.args, **kwargs: P.kwargs
|
||||
) -> Tuple[TimingInfo, T | BaseException]:
|
||||
start_wall_time, start_cpu_time = _start_measurement()
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
except BaseException as e:
|
||||
result = e
|
||||
finally:
|
||||
wall_duration, cpu_duration = _end_measurement(
|
||||
start_wall_time, start_cpu_time
|
||||
@@ -64,16 +70,20 @@ def time_measured(func: Callable[P, T]) -> Callable[P, Tuple[TimingInfo, T]]:
|
||||
|
||||
def async_time_measured(
|
||||
func: Callable[P, Awaitable[T]],
|
||||
) -> Callable[P, Awaitable[Tuple[TimingInfo, T]]]:
|
||||
) -> Callable[P, Awaitable[Tuple[TimingInfo, T | BaseException]]]:
|
||||
"""
|
||||
Decorator to measure the time taken by an async function to execute.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> Tuple[TimingInfo, T]:
|
||||
async def async_wrapper(
|
||||
*args: P.args, **kwargs: P.kwargs
|
||||
) -> Tuple[TimingInfo, T | BaseException]:
|
||||
start_wall_time, start_cpu_time = _start_measurement()
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
except BaseException as e:
|
||||
result = e
|
||||
finally:
|
||||
wall_duration, cpu_duration = _end_measurement(
|
||||
start_wall_time, start_cpu_time
|
||||
|
||||
Reference in New Issue
Block a user