mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(backend): Agent Executor reliability; make RPC to DB manager durable (#10516)
feat(backend): Agent Executor reliability — make RPC to the DB manager durable (#10516). A failure in a DB RPC call can cause the whole agent execution to fail; this change minimizes the chance of such errors. ### Changes 🏗️ * Enable request retry * Increase transaction timeout * Use better typing on the DB queries * Gracefully handle insufficient balance ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: <!-- Put your test plan here: --> - [x] Manual tests
This commit is contained in:
@@ -357,15 +357,15 @@ class UserCreditBase(ABC):
|
||||
amount = min(-user_balance, 0)
|
||||
|
||||
# Create the transaction
|
||||
transaction_data = CreditTransactionCreateInput(
|
||||
userId=user_id,
|
||||
amount=amount,
|
||||
runningBalance=user_balance + amount,
|
||||
type=transaction_type,
|
||||
metadata=metadata,
|
||||
isActive=is_active,
|
||||
createdAt=self.time_now(),
|
||||
)
|
||||
transaction_data: CreditTransactionCreateInput = {
|
||||
"userId": user_id,
|
||||
"amount": amount,
|
||||
"runningBalance": user_balance + amount,
|
||||
"type": transaction_type,
|
||||
"metadata": metadata,
|
||||
"isActive": is_active,
|
||||
"createdAt": self.time_now(),
|
||||
}
|
||||
if transaction_key:
|
||||
transaction_data["transactionKey"] = transaction_key
|
||||
tx = await CreditTransaction.prisma().create(data=transaction_data)
|
||||
|
||||
@@ -79,16 +79,39 @@ async def disconnect():
|
||||
raise ConnectionError("Failed to disconnect from Prisma.")
|
||||
|
||||
|
||||
# Transaction timeout constant (in milliseconds)
TRANSACTION_TIMEOUT = 15000  # 15 seconds - Increased from 5s to prevent timeout errors


@asynccontextmanager
async def transaction(timeout: int | None = None):
    """
    Create a database transaction with optional timeout.

    Args:
        timeout: Transaction timeout in milliseconds. If None, uses
            TRANSACTION_TIMEOUT (15s).

    Yields:
        The Prisma transaction handle for use within the ``async with`` body.
    """
    if timeout is None:
        timeout = TRANSACTION_TIMEOUT

    async with prisma.tx(timeout=timeout) as tx:
        yield tx
|
||||
|
||||
|
||||
@asynccontextmanager
async def locked_transaction(key: str, timeout: int | None = None):
    """
    Create a database transaction guarded by a Postgres advisory lock.

    Args:
        key: Lock key for advisory lock.
        timeout: Transaction timeout in milliseconds. If None, uses
            TRANSACTION_TIMEOUT (15s).

    Yields:
        The Prisma transaction handle, held while the advisory lock is owned.
    """
    if timeout is None:
        timeout = TRANSACTION_TIMEOUT

    # crc32 maps the arbitrary string key onto the 32-bit integer space
    # expected by pg_advisory_xact_lock.
    lock_key = zlib.crc32(key.encode("utf-8"))
    async with transaction(timeout=timeout) as tx:
        # Transaction-scoped lock: released automatically when the
        # transaction commits or rolls back.
        await tx.execute_raw("SELECT pg_advisory_xact_lock($1)", lock_key)
        yield tx
|
||||
|
||||
|
||||
@@ -583,10 +583,10 @@ async def upsert_execution_output(
|
||||
"""
|
||||
Insert AgentNodeExecutionInputOutput record for as one of AgentNodeExecution.Output.
|
||||
"""
|
||||
data = AgentNodeExecutionInputOutputCreateInput(
|
||||
name=output_name,
|
||||
referencedByOutputExecId=node_exec_id,
|
||||
)
|
||||
data: AgentNodeExecutionInputOutputCreateInput = {
|
||||
"name": output_name,
|
||||
"referencedByOutputExecId": node_exec_id,
|
||||
}
|
||||
if output_data is not None:
|
||||
data["data"] = SafeJson(output_data)
|
||||
await AgentNodeExecutionInputOutput.prisma().create(data=data)
|
||||
|
||||
@@ -737,7 +737,8 @@ class Executor:
|
||||
execution_count=increment_execution_count(graph_exec.user_id),
|
||||
execution_stats=execution_stats,
|
||||
)
|
||||
except InsufficientBalanceError as error:
|
||||
except InsufficientBalanceError as balance_error:
|
||||
error = balance_error # Set error to trigger FAILED status
|
||||
node_exec_id = queued_node_exec.node_exec_id
|
||||
db_client.upsert_execution_output(
|
||||
node_exec_id=node_exec_id,
|
||||
@@ -749,7 +750,6 @@ class Executor:
|
||||
exec_id=node_exec_id,
|
||||
status=ExecutionStatus.FAILED,
|
||||
)
|
||||
execution_status = ExecutionStatus.FAILED
|
||||
|
||||
cls._handle_low_balance_notif(
|
||||
db_client,
|
||||
@@ -758,7 +758,8 @@ class Executor:
|
||||
execution_stats,
|
||||
error,
|
||||
)
|
||||
raise
|
||||
# Gracefully stop the execution loop
|
||||
break
|
||||
|
||||
# Add input overrides -----------------------------
|
||||
node_id = queued_node_exec.node_id
|
||||
@@ -833,8 +834,10 @@ class Executor:
|
||||
time.sleep(0.1)
|
||||
|
||||
# loop done --------------------------------------------------
|
||||
execution_status = ExecutionStatus.COMPLETED
|
||||
return execution_stats, execution_status, error
|
||||
# Determine final execution status based on whether there was an error
|
||||
execution_status = (
|
||||
ExecutionStatus.FAILED if error else ExecutionStatus.COMPLETED
|
||||
)
|
||||
|
||||
except CancelledError as exc:
|
||||
execution_status = ExecutionStatus.TERMINATED
|
||||
@@ -849,62 +852,90 @@ class Executor:
|
||||
f"Failed graph execution {graph_exec.graph_exec_id}: {error}"
|
||||
)
|
||||
finally:
|
||||
# Cancel and wait for all node executions to complete
|
||||
for node_id, inflight_exec in running_node_execution.items():
|
||||
if inflight_exec.is_done():
|
||||
continue
|
||||
log_metadata.info(f"Stopping node execution {node_id}")
|
||||
inflight_exec.stop()
|
||||
|
||||
for node_id, inflight_eval in running_node_evaluation.items():
|
||||
if inflight_eval.done():
|
||||
continue
|
||||
log_metadata.info(f"Stopping node evaluation {node_id}")
|
||||
inflight_eval.cancel()
|
||||
|
||||
for node_id, inflight_exec in running_node_execution.items():
|
||||
if inflight_exec.is_done():
|
||||
continue
|
||||
try:
|
||||
inflight_exec.wait_for_cancellation(timeout=60.0)
|
||||
except TimeoutError:
|
||||
log_metadata.exception(
|
||||
f"Node execution #{node_id} did not stop in time, "
|
||||
"it may be stuck or taking too long."
|
||||
)
|
||||
|
||||
for node_id, inflight_eval in running_node_evaluation.items():
|
||||
if inflight_eval.done():
|
||||
continue
|
||||
try:
|
||||
inflight_eval.result(timeout=60.0)
|
||||
except TimeoutError:
|
||||
log_metadata.exception(
|
||||
f"Node evaluation #{node_id} did not stop in time, "
|
||||
"it may be stuck or taking too long."
|
||||
)
|
||||
|
||||
if execution_status in [ExecutionStatus.TERMINATED, ExecutionStatus.FAILED]:
|
||||
inflight_executions = db_client.get_node_executions(
|
||||
graph_exec.graph_exec_id,
|
||||
statuses=[
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.RUNNING,
|
||||
],
|
||||
include_exec_data=False,
|
||||
)
|
||||
db_client.update_node_execution_status_batch(
|
||||
[node_exec.node_exec_id for node_exec in inflight_executions],
|
||||
status=execution_status,
|
||||
stats={"error": str(error)} if error else None,
|
||||
)
|
||||
for node_exec in inflight_executions:
|
||||
node_exec.status = execution_status
|
||||
send_execution_update(node_exec)
|
||||
|
||||
clean_exec_files(graph_exec.graph_exec_id)
|
||||
# Use helper method with error handling to ensure cleanup never fails
|
||||
cls._cleanup_graph_execution(
|
||||
running_node_execution=running_node_execution,
|
||||
running_node_evaluation=running_node_evaluation,
|
||||
execution_status=execution_status,
|
||||
error=error,
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
log_metadata=log_metadata,
|
||||
db_client=db_client,
|
||||
)
|
||||
return execution_stats, execution_status, error
|
||||
|
||||
@classmethod
@error_logged(swallow=True)
def _cleanup_graph_execution(
    cls,
    running_node_execution: dict[str, "NodeExecutionProgress"],
    running_node_evaluation: dict[str, Future],
    execution_status: ExecutionStatus,
    error: Exception | None,
    graph_exec_id: str,
    log_metadata: LogMetadata,
    db_client: "DatabaseManagerClient",
) -> None:
    """
    Clean up running node executions and evaluations when graph execution ends.

    This method is decorated with @error_logged(swallow=True) so that cleanup
    never raises out of the caller's ``finally`` block.

    Args:
        running_node_execution: In-flight node executions keyed by node id.
        running_node_evaluation: In-flight node evaluation futures keyed by node id.
        execution_status: Final status of the graph execution.
        error: The error that ended the execution, if any.
        graph_exec_id: Id of the graph execution being cleaned up.
        log_metadata: Structured logger for this execution.
        db_client: RPC client to the database manager.
    """
    # First pass: signal everything still running to stop (non-blocking) ...
    for node_id, inflight_exec in running_node_execution.items():
        if inflight_exec.is_done():
            continue
        log_metadata.info(f"Stopping node execution {node_id}")
        inflight_exec.stop()

    for node_id, inflight_eval in running_node_evaluation.items():
        if inflight_eval.done():
            continue
        log_metadata.info(f"Stopping node evaluation {node_id}")
        inflight_eval.cancel()

    # ... second pass: wait (bounded) for each to actually finish, logging
    # rather than raising if one does not stop in time.
    for node_id, inflight_exec in running_node_execution.items():
        if inflight_exec.is_done():
            continue
        try:
            inflight_exec.wait_for_cancellation(timeout=60.0)
        except TimeoutError:
            log_metadata.exception(
                f"Node execution #{node_id} did not stop in time, "
                "it may be stuck or taking too long."
            )

    for node_id, inflight_eval in running_node_evaluation.items():
        if inflight_eval.done():
            continue
        try:
            inflight_eval.result(timeout=60.0)
        except TimeoutError:
            log_metadata.exception(
                f"Node evaluation #{node_id} did not stop in time, "
                "it may be stuck or taking too long."
            )

    if execution_status in [ExecutionStatus.TERMINATED, ExecutionStatus.FAILED]:
        # Propagate the terminal status to any node executions still marked
        # QUEUED/RUNNING in the DB so they don't linger in a stale state.
        inflight_executions = db_client.get_node_executions(
            graph_exec_id,
            statuses=[
                ExecutionStatus.QUEUED,
                ExecutionStatus.RUNNING,
            ],
            include_exec_data=False,
        )
        db_client.update_node_execution_status_batch(
            [node_exec.node_exec_id for node_exec in inflight_executions],
            status=execution_status,
            stats={"error": str(error)} if error else None,
        )
        # Notify listeners of the status change for each affected node.
        for node_exec in inflight_executions:
            node_exec.status = execution_status
            send_execution_update(node_exec)

    clean_exec_files(graph_exec_id)
|
||||
|
||||
@classmethod
|
||||
async def _process_node_output(
|
||||
cls,
|
||||
|
||||
@@ -27,7 +27,7 @@ from fastapi import FastAPI, Request, responses
|
||||
from pydantic import BaseModel, TypeAdapter, create_model
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
retry_if_not_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential_jitter,
|
||||
)
|
||||
@@ -275,33 +275,25 @@ def get_service_client(
|
||||
service_client_type: Type[ASC],
|
||||
call_timeout: int | None = api_call_timeout,
|
||||
health_check: bool = True,
|
||||
request_retry: bool | int = False,
|
||||
request_retry: bool = False,
|
||||
) -> ASC:
|
||||
|
||||
def _maybe_retry(fn: Callable[..., R]) -> Callable[..., R]:
|
||||
"""Decorate *fn* with tenacity retry when enabled."""
|
||||
nonlocal request_retry
|
||||
|
||||
if isinstance(request_retry, int):
|
||||
retry_attempts = request_retry
|
||||
request_retry = True
|
||||
else:
|
||||
retry_attempts = api_comm_retry
|
||||
|
||||
if not request_retry:
|
||||
return fn
|
||||
|
||||
return retry(
|
||||
reraise=True,
|
||||
stop=stop_after_attempt(retry_attempts),
|
||||
stop=stop_after_attempt(api_comm_retry),
|
||||
wait=wait_exponential_jitter(max=4.0),
|
||||
retry=retry_if_exception_type(
|
||||
retry=retry_if_not_exception_type(
|
||||
(
|
||||
httpx.ConnectError,
|
||||
httpx.ReadTimeout,
|
||||
httpx.WriteTimeout,
|
||||
httpx.ConnectTimeout,
|
||||
httpx.RemoteProtocolError,
|
||||
# Don't retry these specific exceptions that won't be fixed by retrying
|
||||
ValueError, # Invalid input/parameters
|
||||
KeyError, # Missing required data
|
||||
TypeError, # Wrong data types
|
||||
AttributeError, # Missing attributes
|
||||
)
|
||||
),
|
||||
)(fn)
|
||||
|
||||
Reference in New Issue
Block a user