feat(backend): Agent Executor reliability; make RPC to DB manager durable (#10516)

Transient failures in RPCs to the DB manager can cause the whole agent execution to fail. This change
minimizes the chance of such errors by making those RPCs durable.

### Changes 🏗️

* Enable request retry
* Increase transaction timeout
* Use better typing on the DB query 
* Gracefully handles insufficient balance

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] Manual tests
This commit is contained in:
Zamil Majdy
2025-08-01 01:04:56 +08:00
committed by GitHub
parent 216762575c
commit 686d811062
5 changed files with 139 additions and 93 deletions

View File

@@ -357,15 +357,15 @@ class UserCreditBase(ABC):
amount = min(-user_balance, 0)
# Create the transaction
transaction_data = CreditTransactionCreateInput(
userId=user_id,
amount=amount,
runningBalance=user_balance + amount,
type=transaction_type,
metadata=metadata,
isActive=is_active,
createdAt=self.time_now(),
)
transaction_data: CreditTransactionCreateInput = {
"userId": user_id,
"amount": amount,
"runningBalance": user_balance + amount,
"type": transaction_type,
"metadata": metadata,
"isActive": is_active,
"createdAt": self.time_now(),
}
if transaction_key:
transaction_data["transactionKey"] = transaction_key
tx = await CreditTransaction.prisma().create(data=transaction_data)

View File

@@ -79,16 +79,39 @@ async def disconnect():
raise ConnectionError("Failed to disconnect from Prisma.")
# Transaction timeout constant (in milliseconds)
TRANSACTION_TIMEOUT = 15000 # 15 seconds - Increased from 5s to prevent timeout errors
@asynccontextmanager
async def transaction():
async with prisma.tx() as tx:
async def transaction(timeout: int | None = None):
"""
Create a database transaction with optional timeout.
Args:
timeout: Transaction timeout in milliseconds. If None, uses TRANSACTION_TIMEOUT (15s).
"""
if timeout is None:
timeout = TRANSACTION_TIMEOUT
async with prisma.tx(timeout=timeout) as tx:
yield tx
@asynccontextmanager
async def locked_transaction(key: str, timeout: int | None = None):
    """
    Create a database transaction guarded by a Postgres advisory lock.

    Args:
        key: Lock key for the advisory lock; hashed to a 32-bit integer.
        timeout: Transaction timeout in milliseconds. If None, uses
            TRANSACTION_TIMEOUT (15s).

    Yields:
        The transaction client, after the advisory lock has been acquired.
        The lock is released automatically when the transaction ends.
    """
    if timeout is None:
        timeout = TRANSACTION_TIMEOUT
    # crc32 maps the arbitrary string key onto the integer key space that
    # pg_advisory_xact_lock expects.
    lock_key = zlib.crc32(key.encode("utf-8"))
    async with transaction(timeout=timeout) as tx:
        # xact-scoped lock: held until this transaction commits or rolls back.
        await tx.execute_raw("SELECT pg_advisory_xact_lock($1)", lock_key)
        yield tx

View File

@@ -583,10 +583,10 @@ async def upsert_execution_output(
"""
Insert AgentNodeExecutionInputOutput record for as one of AgentNodeExecution.Output.
"""
data = AgentNodeExecutionInputOutputCreateInput(
name=output_name,
referencedByOutputExecId=node_exec_id,
)
data: AgentNodeExecutionInputOutputCreateInput = {
"name": output_name,
"referencedByOutputExecId": node_exec_id,
}
if output_data is not None:
data["data"] = SafeJson(output_data)
await AgentNodeExecutionInputOutput.prisma().create(data=data)

View File

@@ -737,7 +737,8 @@ class Executor:
execution_count=increment_execution_count(graph_exec.user_id),
execution_stats=execution_stats,
)
except InsufficientBalanceError as error:
except InsufficientBalanceError as balance_error:
error = balance_error # Set error to trigger FAILED status
node_exec_id = queued_node_exec.node_exec_id
db_client.upsert_execution_output(
node_exec_id=node_exec_id,
@@ -749,7 +750,6 @@ class Executor:
exec_id=node_exec_id,
status=ExecutionStatus.FAILED,
)
execution_status = ExecutionStatus.FAILED
cls._handle_low_balance_notif(
db_client,
@@ -758,7 +758,8 @@ class Executor:
execution_stats,
error,
)
raise
# Gracefully stop the execution loop
break
# Add input overrides -----------------------------
node_id = queued_node_exec.node_id
@@ -833,8 +834,10 @@ class Executor:
time.sleep(0.1)
# loop done --------------------------------------------------
execution_status = ExecutionStatus.COMPLETED
return execution_stats, execution_status, error
# Determine final execution status based on whether there was an error
execution_status = (
ExecutionStatus.FAILED if error else ExecutionStatus.COMPLETED
)
except CancelledError as exc:
execution_status = ExecutionStatus.TERMINATED
@@ -849,62 +852,90 @@ class Executor:
f"Failed graph execution {graph_exec.graph_exec_id}: {error}"
)
finally:
# Cancel and wait for all node executions to complete
for node_id, inflight_exec in running_node_execution.items():
if inflight_exec.is_done():
continue
log_metadata.info(f"Stopping node execution {node_id}")
inflight_exec.stop()
for node_id, inflight_eval in running_node_evaluation.items():
if inflight_eval.done():
continue
log_metadata.info(f"Stopping node evaluation {node_id}")
inflight_eval.cancel()
for node_id, inflight_exec in running_node_execution.items():
if inflight_exec.is_done():
continue
try:
inflight_exec.wait_for_cancellation(timeout=60.0)
except TimeoutError:
log_metadata.exception(
f"Node execution #{node_id} did not stop in time, "
"it may be stuck or taking too long."
)
for node_id, inflight_eval in running_node_evaluation.items():
if inflight_eval.done():
continue
try:
inflight_eval.result(timeout=60.0)
except TimeoutError:
log_metadata.exception(
f"Node evaluation #{node_id} did not stop in time, "
"it may be stuck or taking too long."
)
if execution_status in [ExecutionStatus.TERMINATED, ExecutionStatus.FAILED]:
inflight_executions = db_client.get_node_executions(
graph_exec.graph_exec_id,
statuses=[
ExecutionStatus.QUEUED,
ExecutionStatus.RUNNING,
],
include_exec_data=False,
)
db_client.update_node_execution_status_batch(
[node_exec.node_exec_id for node_exec in inflight_executions],
status=execution_status,
stats={"error": str(error)} if error else None,
)
for node_exec in inflight_executions:
node_exec.status = execution_status
send_execution_update(node_exec)
clean_exec_files(graph_exec.graph_exec_id)
# Use helper method with error handling to ensure cleanup never fails
cls._cleanup_graph_execution(
running_node_execution=running_node_execution,
running_node_evaluation=running_node_evaluation,
execution_status=execution_status,
error=error,
graph_exec_id=graph_exec.graph_exec_id,
log_metadata=log_metadata,
db_client=db_client,
)
return execution_stats, execution_status, error
@classmethod
@error_logged(swallow=True)
def _cleanup_graph_execution(
    cls,
    running_node_execution: dict[str, "NodeExecutionProgress"],
    running_node_evaluation: dict[str, Future],
    execution_status: ExecutionStatus,
    error: Exception | None,
    graph_exec_id: str,
    log_metadata: LogMetadata,
    db_client: "DatabaseManagerClient",
) -> None:
    """
    Clean up running node executions and evaluations when graph execution ends.

    Cancellation is requested for ALL in-flight work first, and only then do we
    wait on each item, so shutdowns overlap instead of running sequentially.
    Finally, if the graph ended TERMINATED/FAILED, any executions still QUEUED
    or RUNNING in the DB are moved to that terminal status and broadcast.

    This method is decorated with @error_logged(swallow=True) to ensure cleanup
    never fails in the finally block.
    """
    # Cancel and wait for all node executions to complete
    # Pass 1: request stop on every unfinished node execution (non-blocking).
    for node_id, inflight_exec in running_node_execution.items():
        if inflight_exec.is_done():
            continue
        log_metadata.info(f"Stopping node execution {node_id}")
        inflight_exec.stop()

    # Pass 1b: request cancellation of every pending node evaluation future.
    for node_id, inflight_eval in running_node_evaluation.items():
        if inflight_eval.done():
            continue
        log_metadata.info(f"Stopping node evaluation {node_id}")
        inflight_eval.cancel()

    # Pass 2: wait (bounded) for each execution to actually stop; a stuck one
    # is logged but does not block cleanup of the rest.
    for node_id, inflight_exec in running_node_execution.items():
        if inflight_exec.is_done():
            continue
        try:
            inflight_exec.wait_for_cancellation(timeout=60.0)
        except TimeoutError:
            log_metadata.exception(
                f"Node execution #{node_id} did not stop in time, "
                "it may be stuck or taking too long."
            )

    # Pass 2b: same bounded wait for evaluation futures.
    for node_id, inflight_eval in running_node_evaluation.items():
        if inflight_eval.done():
            continue
        try:
            inflight_eval.result(timeout=60.0)
        except TimeoutError:
            log_metadata.exception(
                f"Node evaluation #{node_id} did not stop in time, "
                "it may be stuck or taking too long."
            )

    # Reconcile DB state: anything still QUEUED/RUNNING inherits the graph's
    # terminal status so no execution is left dangling.
    if execution_status in [ExecutionStatus.TERMINATED, ExecutionStatus.FAILED]:
        inflight_executions = db_client.get_node_executions(
            graph_exec_id,
            statuses=[
                ExecutionStatus.QUEUED,
                ExecutionStatus.RUNNING,
            ],
            include_exec_data=False,
        )
        db_client.update_node_execution_status_batch(
            [node_exec.node_exec_id for node_exec in inflight_executions],
            status=execution_status,
            stats={"error": str(error)} if error else None,
        )
        # Push the new status to any live subscribers.
        for node_exec in inflight_executions:
            node_exec.status = execution_status
            send_execution_update(node_exec)

    # Remove any temp files produced by this graph execution.
    clean_exec_files(graph_exec_id)
@classmethod
async def _process_node_output(
cls,

View File

@@ -27,7 +27,7 @@ from fastapi import FastAPI, Request, responses
from pydantic import BaseModel, TypeAdapter, create_model
from tenacity import (
retry,
retry_if_exception_type,
retry_if_not_exception_type,
stop_after_attempt,
wait_exponential_jitter,
)
@@ -275,33 +275,25 @@ def get_service_client(
service_client_type: Type[ASC],
call_timeout: int | None = api_call_timeout,
health_check: bool = True,
request_retry: bool | int = False,
request_retry: bool = False,
) -> ASC:
def _maybe_retry(fn: Callable[..., R]) -> Callable[..., R]:
    """Decorate *fn* with tenacity retry when enabled."""
    if not request_retry:
        return fn
    return retry(
        reraise=True,
        stop=stop_after_attempt(api_comm_retry),
        wait=wait_exponential_jitter(max=4.0),
        # Retry every failure EXCEPT errors that a retry cannot fix.
        retry=retry_if_not_exception_type(
            (
                # Don't retry these specific exceptions that won't be fixed by retrying
                ValueError,  # Invalid input/parameters
                KeyError,  # Missing required data
                TypeError,  # Wrong data types
                AttributeError,  # Missing attributes
            )
        ),
    )(fn)