mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(backend): Avoid swallowing exception on graph execution failure (#10260)
Graph executions that fail due to an interruption or an unknown error should be put back on the queue. However, swallowing the error meant that the execution was never marked as failed. ### Changes 🏗️ * Avoid repeatedly updating the graph execution status on each node execution step. * Added a guardrail that prevents finalizing a graph execution whose status is not terminal (completed, terminated, or failed). * Avoid acknowledging messages from the queue if the graph execution is not yet completed. ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: <!-- Put your test plan here: --> - [x] Run graph execution, kill the process, re-run the process --------- Co-authored-by: Swifty <craigswift13@gmail.com>
This commit is contained in:
@@ -588,12 +588,10 @@ async def update_graph_execution_start_time(
|
||||
|
||||
async def update_graph_execution_stats(
|
||||
graph_exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
status: ExecutionStatus | None = None,
|
||||
stats: GraphExecutionStats | None = None,
|
||||
) -> GraphExecution | None:
|
||||
update_data: AgentGraphExecutionUpdateManyMutationInput = {
|
||||
"executionStatus": status
|
||||
}
|
||||
update_data: AgentGraphExecutionUpdateManyMutationInput = {}
|
||||
|
||||
if stats:
|
||||
stats_dict = stats.model_dump()
|
||||
@@ -601,6 +599,9 @@ async def update_graph_execution_stats(
|
||||
stats_dict["error"] = str(stats_dict["error"])
|
||||
update_data["stats"] = Json(stats_dict)
|
||||
|
||||
if status:
|
||||
update_data["executionStatus"] = status
|
||||
|
||||
updated_count = await AgentGraphExecution.prisma().update_many(
|
||||
where={
|
||||
"id": graph_exec_id,
|
||||
|
||||
@@ -421,7 +421,7 @@ class Executor:
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@async_error_logged
|
||||
@async_error_logged(swallow=True)
|
||||
async def on_node_execution(
|
||||
cls,
|
||||
node_exec: NodeExecutionEntry,
|
||||
@@ -529,7 +529,7 @@ class Executor:
|
||||
logger.info(f"[GraphExecutor] {cls.pid} started")
|
||||
|
||||
@classmethod
|
||||
@error_logged
|
||||
@error_logged(swallow=False)
|
||||
def on_graph_execution(
|
||||
cls, graph_exec: GraphExecutionEntry, cancel: threading.Event
|
||||
):
|
||||
@@ -581,6 +581,15 @@ class Executor:
|
||||
exec_stats.cputime += timing_info.cpu_time
|
||||
exec_stats.error = str(error) if error else exec_stats.error
|
||||
|
||||
if status not in {
|
||||
ExecutionStatus.COMPLETED,
|
||||
ExecutionStatus.TERMINATED,
|
||||
ExecutionStatus.FAILED,
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Graph Execution #{graph_exec.graph_exec_id} ended with unexpected status {status}"
|
||||
)
|
||||
|
||||
if graph_exec_result := db_client.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
status=status,
|
||||
@@ -684,7 +693,6 @@ class Executor:
|
||||
|
||||
if _graph_exec := db_client.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
status=execution_status,
|
||||
stats=execution_stats,
|
||||
):
|
||||
send_execution_update(_graph_exec)
|
||||
|
||||
@@ -2,7 +2,17 @@ import functools
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Awaitable, Callable, Coroutine, ParamSpec, Tuple, TypeVar
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Literal,
|
||||
ParamSpec,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
overload,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -72,37 +82,115 @@ def async_time_measured(
|
||||
return async_wrapper
|
||||
|
||||
|
||||
def error_logged(func: Callable[P, T]) -> Callable[P, T | None]:
|
||||
@overload
|
||||
def error_logged(
|
||||
*, swallow: Literal[True]
|
||||
) -> Callable[[Callable[P, T]], Callable[P, T | None]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def error_logged(
|
||||
*, swallow: Literal[False]
|
||||
) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def error_logged() -> Callable[[Callable[P, T]], Callable[P, T | None]]: ...
|
||||
|
||||
|
||||
def error_logged(
|
||||
*, swallow: bool = True
|
||||
) -> (
|
||||
Callable[[Callable[P, T]], Callable[P, T | None]]
|
||||
| Callable[[Callable[P, T]], Callable[P, T]]
|
||||
):
|
||||
"""
|
||||
Decorator to suppress and log any exceptions raised by a function.
|
||||
Decorator to log any exceptions raised by a function, with optional suppression.
|
||||
|
||||
Args:
|
||||
swallow: Whether to suppress the exception (True) or re-raise it (False). Default is True.
|
||||
|
||||
Usage:
|
||||
@error_logged() # Default behavior (swallow errors)
|
||||
@error_logged(swallow=False) # Log and re-raise errors
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T | None:
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Error when calling function {func.__name__} with arguments {args} {kwargs}: {e}"
|
||||
)
|
||||
def decorator(f: Callable[P, T]) -> Callable[P, T | None]:
|
||||
@functools.wraps(f)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T | None:
|
||||
try:
|
||||
return f(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Error when calling function {f.__name__} with arguments {args} {kwargs}: {e}"
|
||||
)
|
||||
if not swallow:
|
||||
raise
|
||||
return None
|
||||
|
||||
return wrapper
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@overload
|
||||
def async_error_logged(
|
||||
func: Callable[P, Coroutine[Any, Any, T]],
|
||||
) -> Callable[P, Coroutine[Any, Any, T | None]]:
|
||||
*, swallow: Literal[True]
|
||||
) -> Callable[
|
||||
[Callable[P, Coroutine[Any, Any, T]]], Callable[P, Coroutine[Any, Any, T | None]]
|
||||
]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def async_error_logged(
|
||||
*, swallow: Literal[False]
|
||||
) -> Callable[
|
||||
[Callable[P, Coroutine[Any, Any, T]]], Callable[P, Coroutine[Any, Any, T]]
|
||||
]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def async_error_logged() -> Callable[
|
||||
[Callable[P, Coroutine[Any, Any, T]]],
|
||||
Callable[P, Coroutine[Any, Any, T | None]],
|
||||
]: ...
|
||||
|
||||
|
||||
def async_error_logged(*, swallow: bool = True) -> (
|
||||
Callable[
|
||||
[Callable[P, Coroutine[Any, Any, T]]],
|
||||
Callable[P, Coroutine[Any, Any, T | None]],
|
||||
]
|
||||
| Callable[
|
||||
[Callable[P, Coroutine[Any, Any, T]]], Callable[P, Coroutine[Any, Any, T]]
|
||||
]
|
||||
):
|
||||
"""
|
||||
Decorator to suppress and log any exceptions raised by an async function.
|
||||
Decorator to log any exceptions raised by an async function, with optional suppression.
|
||||
|
||||
Args:
|
||||
swallow: Whether to suppress the exception (True) or re-raise it (False). Default is True.
|
||||
|
||||
Usage:
|
||||
@async_error_logged() # Default behavior (swallow errors)
|
||||
@async_error_logged(swallow=False) # Log and re-raise errors
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T | None:
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Error when calling async function {func.__name__} with arguments {args} {kwargs}: {e}"
|
||||
)
|
||||
def decorator(
|
||||
f: Callable[P, Coroutine[Any, Any, T]]
|
||||
) -> Callable[P, Coroutine[Any, Any, T | None]]:
|
||||
@functools.wraps(f)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T | None:
|
||||
try:
|
||||
return await f(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Error when calling async function {f.__name__} with arguments {args} {kwargs}: {e}"
|
||||
)
|
||||
if not swallow:
|
||||
raise
|
||||
return None
|
||||
|
||||
return wrapper
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import time
|
||||
|
||||
from backend.util.decorator import error_logged, time_measured
|
||||
import pytest
|
||||
|
||||
from backend.util.decorator import async_error_logged, error_logged, time_measured
|
||||
|
||||
|
||||
@time_measured
|
||||
@@ -9,18 +11,64 @@ def example_function(a: int, b: int, c: int) -> int:
|
||||
return a + b + c
|
||||
|
||||
|
||||
@error_logged
|
||||
def example_function_with_error(a: int, b: int, c: int) -> int:
|
||||
raise ValueError("This is a test error")
|
||||
@error_logged(swallow=True)
|
||||
def example_function_with_error_swallowed(a: int, b: int, c: int) -> int:
|
||||
raise ValueError("This error should be swallowed")
|
||||
|
||||
|
||||
@error_logged(swallow=False)
|
||||
def example_function_with_error_not_swallowed(a: int, b: int, c: int) -> int:
|
||||
raise ValueError("This error should NOT be swallowed")
|
||||
|
||||
|
||||
@async_error_logged(swallow=True)
|
||||
async def async_function_with_error_swallowed() -> int:
|
||||
raise ValueError("This async error should be swallowed")
|
||||
|
||||
|
||||
@async_error_logged(swallow=False)
|
||||
async def async_function_with_error_not_swallowed() -> int:
|
||||
raise ValueError("This async error should NOT be swallowed")
|
||||
|
||||
|
||||
def test_timer_decorator():
|
||||
"""Test that the time_measured decorator correctly measures execution time."""
|
||||
info, res = example_function(1, 2, 3)
|
||||
assert info.cpu_time >= 0
|
||||
assert info.wall_time >= 0.4
|
||||
assert res == 6
|
||||
|
||||
|
||||
def test_error_decorator():
|
||||
res = example_function_with_error(1, 2, 3)
|
||||
def test_error_decorator_swallow_true():
|
||||
"""Test that error_logged(swallow=True) logs and swallows errors."""
|
||||
res = example_function_with_error_swallowed(1, 2, 3)
|
||||
assert res is None
|
||||
|
||||
|
||||
def test_error_decorator_swallow_false():
|
||||
"""Test that error_logged(swallow=False) logs errors but re-raises them."""
|
||||
with pytest.raises(ValueError, match="This error should NOT be swallowed"):
|
||||
example_function_with_error_not_swallowed(1, 2, 3)
|
||||
|
||||
|
||||
def test_async_error_decorator_swallow_true():
|
||||
"""Test that async_error_logged(swallow=True) logs and swallows errors."""
|
||||
import asyncio
|
||||
|
||||
async def run_test():
|
||||
res = await async_function_with_error_swallowed()
|
||||
return res
|
||||
|
||||
res = asyncio.run(run_test())
|
||||
assert res is None
|
||||
|
||||
|
||||
def test_async_error_decorator_swallow_false():
|
||||
"""Test that async_error_logged(swallow=False) logs errors but re-raises them."""
|
||||
import asyncio
|
||||
|
||||
async def run_test():
|
||||
await async_function_with_error_not_swallowed()
|
||||
|
||||
with pytest.raises(ValueError, match="This async error should NOT be swallowed"):
|
||||
asyncio.run(run_test())
|
||||
|
||||
@@ -51,7 +51,7 @@ class ServiceTestClient(AppServiceClient):
|
||||
subtract_async = endpoint_to_async(ServiceTest.subtract)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_creation(server):
|
||||
with ServiceTest():
|
||||
client = get_service_client(ServiceTestClient)
|
||||
|
||||
Reference in New Issue
Block a user