feat(backend): Avoid loading all node executions on large continuous agent (#9667)

A graph that processes tens of thousands of nodes will cripple the
system since the API tries to load all of them and dump them into the
browser. The scope of this change is to avoid such an issue by only
returning the last 1000 node executions.

### Changes 🏗️

* Return only 1000 node executions from `AgentNodeExecutions` reference.
* Unify the include clause (and its format) for fetching
`AgentNodeExecutions` in one place.
* Fix & optimize `cancel_execution` logic to always set both the graph &
node execution status in batch.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Execute a graph in a loop that runs 10000 nodes; it should
only display the last 1000 nodes when refreshed. Cancelling the graph
should also not cripple the server.
This commit is contained in:
Zamil Majdy
2025-03-21 19:48:35 +07:00
committed by GitHub
parent 5411e18bd0
commit f01b31873f
6 changed files with 121 additions and 88 deletions

View File

@@ -10,6 +10,7 @@ from prisma.models import (
AgentNodeExecution,
AgentNodeExecutionInputOutput,
)
from prisma.types import AgentNodeExecutionUpdateInput, AgentNodeExecutionWhereInput
from pydantic import BaseModel
from backend.data.block import BlockData, BlockInput, CompletedBlockOutput
@@ -292,13 +293,19 @@ async def update_graph_execution_start_time(graph_exec_id: str) -> ExecutionResu
async def update_graph_execution_stats(
graph_exec_id: str,
status: ExecutionStatus,
stats: GraphExecutionStats,
stats: GraphExecutionStats | None = None,
) -> ExecutionResult:
data = stats.model_dump()
if isinstance(data["error"], Exception):
data = stats.model_dump() if stats else {}
if isinstance(data.get("error"), Exception):
data["error"] = str(data["error"])
res = await AgentGraphExecution.prisma().update(
where={"id": graph_exec_id},
where={
"id": graph_exec_id,
"OR": [
{"executionStatus": ExecutionStatus.RUNNING},
{"executionStatus": ExecutionStatus.QUEUED},
],
},
data={
"executionStatus": status,
"stats": Json(data),
@@ -320,6 +327,17 @@ async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionSta
)
async def update_execution_status_batch(
    node_exec_ids: list[str],
    status: ExecutionStatus,
    stats: dict[str, Any] | None = None,
):
    """
    Set the execution status (and optional stats) on many node executions at once.

    Batch counterpart of `update_execution_status`: a single `update_many`
    round trip instead of one update per node execution — used e.g. when
    terminating all unfinished node executions of a cancelled graph run.

    Args:
        node_exec_ids: IDs of the AgentNodeExecution rows to update.
        status: New execution status applied to every matched row.
        stats: Optional stats payload stored alongside the status.
    """
    # Same payload builder as the single-row path, so timestamps/status
    # fields stay consistent between the two code paths.
    await AgentNodeExecution.prisma().update_many(
        where={"id": {"in": node_exec_ids}},
        data=_get_update_status_data(status, None, stats),
    )
async def update_execution_status(
node_exec_id: str,
status: ExecutionStatus,
@@ -329,20 +347,9 @@ async def update_execution_status(
if status == ExecutionStatus.QUEUED and execution_data is None:
raise ValueError("Execution data must be provided when queuing an execution.")
now = datetime.now(tz=timezone.utc)
data = {
**({"executionStatus": status}),
**({"queuedTime": now} if status == ExecutionStatus.QUEUED else {}),
**({"startedTime": now} if status == ExecutionStatus.RUNNING else {}),
**({"endedTime": now} if status == ExecutionStatus.FAILED else {}),
**({"endedTime": now} if status == ExecutionStatus.COMPLETED else {}),
**({"executionData": Json(execution_data)} if execution_data else {}),
**({"stats": Json(stats)} if stats else {}),
}
res = await AgentNodeExecution.prisma().update(
where={"id": node_exec_id},
data=data, # type: ignore
data=_get_update_status_data(status, execution_data, stats),
include=EXECUTION_RESULT_INCLUDE,
)
if not res:
@@ -351,6 +358,29 @@ async def update_execution_status(
return ExecutionResult.from_db(res)
def _get_update_status_data(
    status: ExecutionStatus,
    execution_data: BlockInput | None = None,
    stats: dict[str, Any] | None = None,
) -> AgentNodeExecutionUpdateInput:
    """
    Build the Prisma update payload for a node-execution status change.

    Stamps the status-appropriate timestamp field (queuedTime / startedTime /
    endedTime) with the current UTC time, and attaches the execution data and
    stats blobs when provided.
    """
    timestamp = datetime.now(tz=timezone.utc)
    payload: AgentNodeExecutionUpdateInput = {"executionStatus": status}

    if status == ExecutionStatus.RUNNING:
        payload["startedTime"] = timestamp
    elif status == ExecutionStatus.QUEUED:
        payload["queuedTime"] = timestamp
    elif status == ExecutionStatus.FAILED or status == ExecutionStatus.COMPLETED:
        payload["endedTime"] = timestamp

    # Truthiness checks on purpose: empty payloads are treated as "nothing to store".
    if execution_data:
        payload["executionData"] = Json(execution_data)
    if stats:
        payload["stats"] = Json(stats)

    return payload
async def delete_execution(
graph_exec_id: str, user_id: str, soft_delete: bool = True
) -> None:
@@ -368,41 +398,29 @@ async def delete_execution(
)
async def get_execution_results(
    graph_exec_id: str,
    block_ids: list[str] | None = None,
    statuses: list[ExecutionStatus] | None = None,
    limit: int | None = None,
) -> list[ExecutionResult]:
    """
    Fetch node execution results for a graph execution, optionally filtered.

    Args:
        graph_exec_id: ID of the graph execution whose node executions to load.
        block_ids: If given, only executions of nodes using one of these block IDs.
        statuses: If given, only executions in any of these statuses.
        limit: If given, cap the number of rows returned.

    Returns:
        ExecutionResults ordered by queuedTime ascending, with addedTime as a
        fallback ordering key for incomplete executions that have no queuedTime.
    """
    where_clause: AgentNodeExecutionWhereInput = {
        "agentGraphExecutionId": graph_exec_id,
    }
    if block_ids:
        where_clause["AgentNode"] = {"is": {"agentBlockId": {"in": block_ids}}}
    if statuses:
        where_clause["OR"] = [{"executionStatus": status} for status in statuses]

    executions = await AgentNodeExecution.prisma().find_many(
        # Fix: a leftover duplicate `where=` keyword (merge artifact) made this
        # call a SyntaxError; the filtered clause is the intended argument.
        where=where_clause,
        include=EXECUTION_RESULT_INCLUDE,
        order=[
            {"queuedTime": "asc"},
            {"addedTime": "asc"},  # Fallback: incomplete execs have no queuedTime.
        ],
        take=limit,
    )
    return [ExecutionResult.from_db(execution) for execution in executions]
async def get_executions_in_timerange(
    user_id: str, start_time: str, end_time: str
) -> list[ExecutionResult]:
    """
    Fetch a user's non-deleted graph executions whose startedAt falls within
    [start_time, end_time] (ISO-8601 strings), as ExecutionResult objects.

    Raises:
        DatabaseError: wrapping any failure, including unparsable timestamps.
    """
    try:
        window_start = datetime.fromisoformat(start_time)
        window_end = datetime.fromisoformat(end_time)
        rows = await AgentGraphExecution.prisma().find_many(
            where={
                "startedAt": {"gte": window_start, "lte": window_end},
                "userId": user_id,
                "isDeleted": False,
            },
            include=GRAPH_EXECUTION_INCLUDE,
        )
        return [ExecutionResult.from_graph(row) for row in rows]
    except Exception as e:
        raise DatabaseError(
            f"Failed to get executions in timerange {start_time} to {end_time} for user {user_id}: {e}"
        ) from e
LIST_SPLIT = "_$_"
DICT_SPLIT = "_#_"
OBJC_SPLIT = "_@_"
@@ -545,7 +563,10 @@ async def get_latest_execution(node_id: str, graph_eid: str) -> ExecutionResult
"agentGraphExecutionId": graph_eid,
"executionStatus": {"not": ExecutionStatus.INCOMPLETE}, # type: ignore
},
order={"queuedTime": "desc"},
order=[
{"queuedTime": "desc"},
{"addedTime": "desc"},
],
include=EXECUTION_RESULT_INCLUDE,
)
if not execution:

View File

@@ -23,7 +23,7 @@ from backend.util import type as type_utils
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
from .db import BaseDbModel, transaction
from .execution import ExecutionResult, ExecutionStatus
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE, GRAPH_EXECUTION_INCLUDE
from .integrations import Webhook
logger = logging.getLogger(__name__)
@@ -216,10 +216,13 @@ class GraphExecution(GraphExecutionMeta):
graph_exec = GraphExecutionMeta.from_db(_graph_exec)
node_executions = [
ExecutionResult.from_db(ne, _graph_exec.userId)
for ne in _graph_exec.AgentNodeExecutions
]
node_executions = sorted(
[
ExecutionResult.from_db(ne, _graph_exec.userId)
for ne in _graph_exec.AgentNodeExecutions
],
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
)
inputs = {
**{
@@ -658,20 +661,13 @@ async def get_execution_meta(
return GraphExecutionMeta.from_db(execution) if execution else None
async def get_execution(
    user_id: str,
    execution_id: str,
) -> GraphExecution | None:
    """
    Load a single non-deleted graph execution owned by `user_id`.

    Returns:
        The GraphExecution with its node executions attached via the shared
        GRAPH_EXECUTION_INCLUDE clause, or None if no matching row exists.
    """
    execution = await AgentGraphExecution.prisma().find_first(
        where={"id": execution_id, "isDeleted": False, "userId": user_id},
        # Fix: a stale inline include dict duplicated the `include=` keyword
        # (merge artifact, SyntaxError); the shared clause is the intended one
        # and also carries the ordering and the node-execution fetch limit.
        include=GRAPH_EXECUTION_INCLUDE,
    )
    return GraphExecution.from_db(execution) if execution else None

View File

@@ -18,6 +18,8 @@ EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
"AgentGraphExecution": True,
}
MAX_NODE_EXECUTIONS_FETCH = 1000
GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
"AgentNodeExecutions": {
"include": {
@@ -25,10 +27,17 @@ GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
"Output": True,
"AgentNode": True,
"AgentGraphExecution": True,
}
},
"order_by": [
{"queuedTime": "desc"},
# Fallback: Incomplete execs has no queuedTime.
{"addedTime": "desc"},
],
"take": MAX_NODE_EXECUTIONS_FETCH, # Avoid loading excessive node executions.
}
}
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
"AgentNodes": {"include": AGENT_NODE_INCLUDE} # type: ignore
}

View File

@@ -8,6 +8,7 @@ from backend.data.execution import (
get_incomplete_executions,
get_latest_execution,
update_execution_status,
update_execution_status_batch,
update_graph_execution_start_time,
update_graph_execution_stats,
update_node_execution_stats,
@@ -69,6 +70,7 @@ class DatabaseManager(AppService):
get_incomplete_executions = exposed_run_and_wait(get_incomplete_executions)
get_latest_execution = exposed_run_and_wait(get_latest_execution)
update_execution_status = exposed_run_and_wait(update_execution_status)
update_execution_status_batch = exposed_run_and_wait(update_execution_status_batch)
update_graph_execution_start_time = exposed_run_and_wait(
update_graph_execution_start_time
)

View File

@@ -783,7 +783,10 @@ class Executor:
metadata = cls.db_client.get_graph_metadata(
graph_exec.graph_id, graph_exec.graph_version
)
outputs = cls.db_client.get_execution_results(graph_exec.graph_exec_id)
outputs = cls.db_client.get_execution_results(
graph_exec.graph_exec_id,
block_ids=[AgentOutputBlock().id],
)
named_outputs = [
{
@@ -791,7 +794,6 @@ class Executor:
for key, value in output.output_data.items()
}
for output in outputs
if output.block_id == AgentOutputBlock().id
]
event = NotificationEventDTO(
@@ -985,29 +987,36 @@ class ExecutionManager(AppService):
3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
"""
if graph_exec_id not in self.active_graph_runs:
raise Exception(
logger.warning(
f"Graph execution #{graph_exec_id} not active/running: "
"possibly already completed/cancelled."
)
else:
future, cancel_event = self.active_graph_runs[graph_exec_id]
if not cancel_event.is_set():
cancel_event.set()
future.result()
future, cancel_event = self.active_graph_runs[graph_exec_id]
if cancel_event.is_set():
return
cancel_event.set()
future.result()
# Update the status of the unfinished node executions
node_execs = self.db_client.get_execution_results(graph_exec_id)
# Update the status of the graph & node executions
self.db_client.update_graph_execution_stats(
graph_exec_id,
ExecutionStatus.TERMINATED,
)
node_execs = self.db_client.get_execution_results(
graph_exec_id=graph_exec_id,
statuses=[
ExecutionStatus.QUEUED,
ExecutionStatus.RUNNING,
ExecutionStatus.INCOMPLETE,
],
)
self.db_client.update_execution_status_batch(
[node_exec.node_exec_id for node_exec in node_execs],
ExecutionStatus.TERMINATED,
)
for node_exec in node_execs:
if node_exec.status not in (
ExecutionStatus.COMPLETED,
ExecutionStatus.FAILED,
):
exec_update = self.db_client.update_execution_status(
node_exec.node_exec_id, ExecutionStatus.TERMINATED
)
self.db_client.send_execution_update(exec_update)
node_exec.status = ExecutionStatus.TERMINATED
self.db_client.send_execution_update(node_exec)
def _validate_node_input_credentials(self, graph: GraphModel, user_id: str):
"""Checks all credentials for all nodes of the graph"""

View File

@@ -640,12 +640,8 @@ async def get_graph_execution(
graph_exec_id: str,
user_id: Annotated[str, Depends(get_user_id)],
) -> graph_db.GraphExecution:
graph = await graph_db.get_graph(graph_id, user_id=user_id)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
result = await graph_db.get_execution(execution_id=graph_exec_id, user_id=user_id)
if not result:
if not result or result.graph_id != graph_id:
raise HTTPException(
status_code=404, detail=f"Graph execution #{graph_exec_id} not found."
)