mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(backend): Avoid loading all node executions on large continuos agent (#9667)
A graph that processes tens of thousands of nodes will cripple the system since the API tries to load all of them and dump them into the browser. The scope of this change is to avoid such an issue by only returning the last 1000 node executions. ### Changes 🏗️ * Return only 1000 node executions from `AgentNodeExecutions` reference. * Unify the include clause for fetching `AgentNodeExecutions` in one place and its format. * Fix & optimize `cancel_execution` logic to always set both the graph & node execution status in batch. ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] Execute a graph in a loop that executes 10000 nodes, it should only display the last 1000 nodes when refreshed. Cancelling the graph should also not cripple the server.
This commit is contained in:
@@ -10,6 +10,7 @@ from prisma.models import (
|
||||
AgentNodeExecution,
|
||||
AgentNodeExecutionInputOutput,
|
||||
)
|
||||
from prisma.types import AgentNodeExecutionUpdateInput, AgentNodeExecutionWhereInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.block import BlockData, BlockInput, CompletedBlockOutput
|
||||
@@ -292,13 +293,19 @@ async def update_graph_execution_start_time(graph_exec_id: str) -> ExecutionResu
|
||||
async def update_graph_execution_stats(
|
||||
graph_exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
stats: GraphExecutionStats,
|
||||
stats: GraphExecutionStats | None = None,
|
||||
) -> ExecutionResult:
|
||||
data = stats.model_dump()
|
||||
if isinstance(data["error"], Exception):
|
||||
data = stats.model_dump() if stats else {}
|
||||
if isinstance(data.get("error"), Exception):
|
||||
data["error"] = str(data["error"])
|
||||
res = await AgentGraphExecution.prisma().update(
|
||||
where={"id": graph_exec_id},
|
||||
where={
|
||||
"id": graph_exec_id,
|
||||
"OR": [
|
||||
{"executionStatus": ExecutionStatus.RUNNING},
|
||||
{"executionStatus": ExecutionStatus.QUEUED},
|
||||
],
|
||||
},
|
||||
data={
|
||||
"executionStatus": status,
|
||||
"stats": Json(data),
|
||||
@@ -320,6 +327,17 @@ async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionSta
|
||||
)
|
||||
|
||||
|
||||
async def update_execution_status_batch(
|
||||
node_exec_ids: list[str],
|
||||
status: ExecutionStatus,
|
||||
stats: dict[str, Any] | None = None,
|
||||
):
|
||||
await AgentNodeExecution.prisma().update_many(
|
||||
where={"id": {"in": node_exec_ids}},
|
||||
data=_get_update_status_data(status, None, stats),
|
||||
)
|
||||
|
||||
|
||||
async def update_execution_status(
|
||||
node_exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
@@ -329,20 +347,9 @@ async def update_execution_status(
|
||||
if status == ExecutionStatus.QUEUED and execution_data is None:
|
||||
raise ValueError("Execution data must be provided when queuing an execution.")
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
data = {
|
||||
**({"executionStatus": status}),
|
||||
**({"queuedTime": now} if status == ExecutionStatus.QUEUED else {}),
|
||||
**({"startedTime": now} if status == ExecutionStatus.RUNNING else {}),
|
||||
**({"endedTime": now} if status == ExecutionStatus.FAILED else {}),
|
||||
**({"endedTime": now} if status == ExecutionStatus.COMPLETED else {}),
|
||||
**({"executionData": Json(execution_data)} if execution_data else {}),
|
||||
**({"stats": Json(stats)} if stats else {}),
|
||||
}
|
||||
|
||||
res = await AgentNodeExecution.prisma().update(
|
||||
where={"id": node_exec_id},
|
||||
data=data, # type: ignore
|
||||
data=_get_update_status_data(status, execution_data, stats),
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
)
|
||||
if not res:
|
||||
@@ -351,6 +358,29 @@ async def update_execution_status(
|
||||
return ExecutionResult.from_db(res)
|
||||
|
||||
|
||||
def _get_update_status_data(
|
||||
status: ExecutionStatus,
|
||||
execution_data: BlockInput | None = None,
|
||||
stats: dict[str, Any] | None = None,
|
||||
) -> AgentNodeExecutionUpdateInput:
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
update_data: AgentNodeExecutionUpdateInput = {"executionStatus": status}
|
||||
|
||||
if status == ExecutionStatus.QUEUED:
|
||||
update_data["queuedTime"] = now
|
||||
elif status == ExecutionStatus.RUNNING:
|
||||
update_data["startedTime"] = now
|
||||
elif status in (ExecutionStatus.FAILED, ExecutionStatus.COMPLETED):
|
||||
update_data["endedTime"] = now
|
||||
|
||||
if execution_data:
|
||||
update_data["executionData"] = Json(execution_data)
|
||||
if stats:
|
||||
update_data["stats"] = Json(stats)
|
||||
|
||||
return update_data
|
||||
|
||||
|
||||
async def delete_execution(
|
||||
graph_exec_id: str, user_id: str, soft_delete: bool = True
|
||||
) -> None:
|
||||
@@ -368,41 +398,29 @@ async def delete_execution(
|
||||
)
|
||||
|
||||
|
||||
async def get_execution_results(graph_exec_id: str) -> list[ExecutionResult]:
|
||||
async def get_execution_results(
|
||||
graph_exec_id: str,
|
||||
block_ids: list[str] | None = None,
|
||||
statuses: list[ExecutionStatus] | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[ExecutionResult]:
|
||||
where_clause: AgentNodeExecutionWhereInput = {
|
||||
"agentGraphExecutionId": graph_exec_id,
|
||||
}
|
||||
if block_ids:
|
||||
where_clause["AgentNode"] = {"is": {"agentBlockId": {"in": block_ids}}}
|
||||
if statuses:
|
||||
where_clause["OR"] = [{"executionStatus": status} for status in statuses]
|
||||
|
||||
executions = await AgentNodeExecution.prisma().find_many(
|
||||
where={"agentGraphExecutionId": graph_exec_id},
|
||||
where=where_clause,
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
order=[
|
||||
{"queuedTime": "asc"},
|
||||
{"addedTime": "asc"}, # Fallback: Incomplete execs has no queuedTime.
|
||||
],
|
||||
take=limit,
|
||||
)
|
||||
res = [ExecutionResult.from_db(execution) for execution in executions]
|
||||
return res
|
||||
|
||||
|
||||
async def get_executions_in_timerange(
|
||||
user_id: str, start_time: str, end_time: str
|
||||
) -> list[ExecutionResult]:
|
||||
try:
|
||||
executions = await AgentGraphExecution.prisma().find_many(
|
||||
where={
|
||||
"startedAt": {
|
||||
"gte": datetime.fromisoformat(start_time),
|
||||
"lte": datetime.fromisoformat(end_time),
|
||||
},
|
||||
"userId": user_id,
|
||||
"isDeleted": False,
|
||||
},
|
||||
include=GRAPH_EXECUTION_INCLUDE,
|
||||
)
|
||||
return [ExecutionResult.from_graph(execution) for execution in executions]
|
||||
except Exception as e:
|
||||
raise DatabaseError(
|
||||
f"Failed to get executions in timerange {start_time} to {end_time} for user {user_id}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
LIST_SPLIT = "_$_"
|
||||
DICT_SPLIT = "_#_"
|
||||
OBJC_SPLIT = "_@_"
|
||||
@@ -545,7 +563,10 @@ async def get_latest_execution(node_id: str, graph_eid: str) -> ExecutionResult
|
||||
"agentGraphExecutionId": graph_eid,
|
||||
"executionStatus": {"not": ExecutionStatus.INCOMPLETE}, # type: ignore
|
||||
},
|
||||
order={"queuedTime": "desc"},
|
||||
order=[
|
||||
{"queuedTime": "desc"},
|
||||
{"addedTime": "desc"},
|
||||
],
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
)
|
||||
if not execution:
|
||||
|
||||
@@ -23,7 +23,7 @@ from backend.util import type as type_utils
|
||||
from .block import Block, BlockInput, BlockSchema, BlockType, get_block, get_blocks
|
||||
from .db import BaseDbModel, transaction
|
||||
from .execution import ExecutionResult, ExecutionStatus
|
||||
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE
|
||||
from .includes import AGENT_GRAPH_INCLUDE, AGENT_NODE_INCLUDE, GRAPH_EXECUTION_INCLUDE
|
||||
from .integrations import Webhook
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -216,10 +216,13 @@ class GraphExecution(GraphExecutionMeta):
|
||||
|
||||
graph_exec = GraphExecutionMeta.from_db(_graph_exec)
|
||||
|
||||
node_executions = [
|
||||
ExecutionResult.from_db(ne, _graph_exec.userId)
|
||||
for ne in _graph_exec.AgentNodeExecutions
|
||||
]
|
||||
node_executions = sorted(
|
||||
[
|
||||
ExecutionResult.from_db(ne, _graph_exec.userId)
|
||||
for ne in _graph_exec.AgentNodeExecutions
|
||||
],
|
||||
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
|
||||
)
|
||||
|
||||
inputs = {
|
||||
**{
|
||||
@@ -658,20 +661,13 @@ async def get_execution_meta(
|
||||
return GraphExecutionMeta.from_db(execution) if execution else None
|
||||
|
||||
|
||||
async def get_execution(user_id: str, execution_id: str) -> GraphExecution | None:
|
||||
async def get_execution(
|
||||
user_id: str,
|
||||
execution_id: str,
|
||||
) -> GraphExecution | None:
|
||||
execution = await AgentGraphExecution.prisma().find_first(
|
||||
where={"id": execution_id, "isDeleted": False, "userId": user_id},
|
||||
include={
|
||||
"AgentNodeExecutions": {
|
||||
"include": {"AgentNode": True, "Input": True, "Output": True},
|
||||
"order_by": [
|
||||
{"queuedTime": "asc"},
|
||||
{ # Fallback: Incomplete execs has no queuedTime.
|
||||
"addedTime": "asc"
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
include=GRAPH_EXECUTION_INCLUDE,
|
||||
)
|
||||
return GraphExecution.from_db(execution) if execution else None
|
||||
|
||||
|
||||
@@ -18,6 +18,8 @@ EXECUTION_RESULT_INCLUDE: prisma.types.AgentNodeExecutionInclude = {
|
||||
"AgentGraphExecution": True,
|
||||
}
|
||||
|
||||
MAX_NODE_EXECUTIONS_FETCH = 1000
|
||||
|
||||
GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
|
||||
"AgentNodeExecutions": {
|
||||
"include": {
|
||||
@@ -25,10 +27,17 @@ GRAPH_EXECUTION_INCLUDE: prisma.types.AgentGraphExecutionInclude = {
|
||||
"Output": True,
|
||||
"AgentNode": True,
|
||||
"AgentGraphExecution": True,
|
||||
}
|
||||
},
|
||||
"order_by": [
|
||||
{"queuedTime": "desc"},
|
||||
# Fallback: Incomplete execs has no queuedTime.
|
||||
{"addedTime": "desc"},
|
||||
],
|
||||
"take": MAX_NODE_EXECUTIONS_FETCH, # Avoid loading excessive node executions.
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
INTEGRATION_WEBHOOK_INCLUDE: prisma.types.IntegrationWebhookInclude = {
|
||||
"AgentNodes": {"include": AGENT_NODE_INCLUDE} # type: ignore
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ from backend.data.execution import (
|
||||
get_incomplete_executions,
|
||||
get_latest_execution,
|
||||
update_execution_status,
|
||||
update_execution_status_batch,
|
||||
update_graph_execution_start_time,
|
||||
update_graph_execution_stats,
|
||||
update_node_execution_stats,
|
||||
@@ -69,6 +70,7 @@ class DatabaseManager(AppService):
|
||||
get_incomplete_executions = exposed_run_and_wait(get_incomplete_executions)
|
||||
get_latest_execution = exposed_run_and_wait(get_latest_execution)
|
||||
update_execution_status = exposed_run_and_wait(update_execution_status)
|
||||
update_execution_status_batch = exposed_run_and_wait(update_execution_status_batch)
|
||||
update_graph_execution_start_time = exposed_run_and_wait(
|
||||
update_graph_execution_start_time
|
||||
)
|
||||
|
||||
@@ -783,7 +783,10 @@ class Executor:
|
||||
metadata = cls.db_client.get_graph_metadata(
|
||||
graph_exec.graph_id, graph_exec.graph_version
|
||||
)
|
||||
outputs = cls.db_client.get_execution_results(graph_exec.graph_exec_id)
|
||||
outputs = cls.db_client.get_execution_results(
|
||||
graph_exec.graph_exec_id,
|
||||
block_ids=[AgentOutputBlock().id],
|
||||
)
|
||||
|
||||
named_outputs = [
|
||||
{
|
||||
@@ -791,7 +794,6 @@ class Executor:
|
||||
for key, value in output.output_data.items()
|
||||
}
|
||||
for output in outputs
|
||||
if output.block_id == AgentOutputBlock().id
|
||||
]
|
||||
|
||||
event = NotificationEventDTO(
|
||||
@@ -985,29 +987,36 @@ class ExecutionManager(AppService):
|
||||
3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
|
||||
"""
|
||||
if graph_exec_id not in self.active_graph_runs:
|
||||
raise Exception(
|
||||
logger.warning(
|
||||
f"Graph execution #{graph_exec_id} not active/running: "
|
||||
"possibly already completed/cancelled."
|
||||
)
|
||||
else:
|
||||
future, cancel_event = self.active_graph_runs[graph_exec_id]
|
||||
if not cancel_event.is_set():
|
||||
cancel_event.set()
|
||||
future.result()
|
||||
|
||||
future, cancel_event = self.active_graph_runs[graph_exec_id]
|
||||
if cancel_event.is_set():
|
||||
return
|
||||
|
||||
cancel_event.set()
|
||||
future.result()
|
||||
|
||||
# Update the status of the unfinished node executions
|
||||
node_execs = self.db_client.get_execution_results(graph_exec_id)
|
||||
# Update the status of the graph & node executions
|
||||
self.db_client.update_graph_execution_stats(
|
||||
graph_exec_id,
|
||||
ExecutionStatus.TERMINATED,
|
||||
)
|
||||
node_execs = self.db_client.get_execution_results(
|
||||
graph_exec_id=graph_exec_id,
|
||||
statuses=[
|
||||
ExecutionStatus.QUEUED,
|
||||
ExecutionStatus.RUNNING,
|
||||
ExecutionStatus.INCOMPLETE,
|
||||
],
|
||||
)
|
||||
self.db_client.update_execution_status_batch(
|
||||
[node_exec.node_exec_id for node_exec in node_execs],
|
||||
ExecutionStatus.TERMINATED,
|
||||
)
|
||||
for node_exec in node_execs:
|
||||
if node_exec.status not in (
|
||||
ExecutionStatus.COMPLETED,
|
||||
ExecutionStatus.FAILED,
|
||||
):
|
||||
exec_update = self.db_client.update_execution_status(
|
||||
node_exec.node_exec_id, ExecutionStatus.TERMINATED
|
||||
)
|
||||
self.db_client.send_execution_update(exec_update)
|
||||
node_exec.status = ExecutionStatus.TERMINATED
|
||||
self.db_client.send_execution_update(node_exec)
|
||||
|
||||
def _validate_node_input_credentials(self, graph: GraphModel, user_id: str):
|
||||
"""Checks all credentials for all nodes of the graph"""
|
||||
|
||||
@@ -640,12 +640,8 @@ async def get_graph_execution(
|
||||
graph_exec_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> graph_db.GraphExecution:
|
||||
graph = await graph_db.get_graph(graph_id, user_id=user_id)
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
|
||||
result = await graph_db.get_execution(execution_id=graph_exec_id, user_id=user_id)
|
||||
if not result:
|
||||
if not result or result.graph_id != graph_id:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Graph execution #{graph_exec_id} not found."
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user