feat(executor): Improve execution ordering to allow depth-first execution (#10142)

Allowing depth-first execution will unlock lower processing latency and
a better sense of progress.

<img width="950" alt="image"
src="https://github.com/user-attachments/assets/e2a0e11a-8bc5-4a65-a10d-b5b6c6383354"
/>


### Changes 🏗️

* Prioritize adding a new execution over processing execution output
* Make sure to enqueue each node once when processing output, instead of
draining a single node completely and then moving on.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] Run company follower count finder agent.

---------

Co-authored-by: Swifty <craigswift13@gmail.com>
This commit is contained in:
Zamil Majdy
2025-06-10 19:41:31 +07:00
committed by GitHub
parent f9b37d2693
commit 210d457ecd
3 changed files with 61 additions and 43 deletions

View File

@@ -628,15 +628,11 @@ async def update_node_execution_stats(
data = stats.model_dump()
if isinstance(data["error"], Exception):
data["error"] = str(data["error"])
execution_status = ExecutionStatus.FAILED
else:
execution_status = ExecutionStatus.COMPLETED
res = await AgentNodeExecution.prisma().update(
where={"id": node_exec_id},
data={
"stats": Json(data),
"executionStatus": execution_status,
"endedTime": datetime.now(tz=timezone.utc),
},
include=EXECUTION_RESULT_INCLUDE,

View File

@@ -706,6 +706,42 @@ class Executor:
)
running_executions[output.node.id].add_output(output)
def drain_done_task(node_exec_id: str, result: object):
if not isinstance(result, NodeExecutionStats):
log_metadata.error(f"Unexpected result #{node_exec_id}: {type(result)}")
return
nonlocal execution_stats
execution_stats.node_count += 1
execution_stats.nodes_cputime += result.cputime
execution_stats.nodes_walltime += result.walltime
if (err := result.error) and isinstance(err, Exception):
execution_stats.node_error_count += 1
cls.db_client.update_node_execution_status(
node_exec_id=node_exec_id,
status=ExecutionStatus.FAILED,
)
else:
cls.db_client.update_node_execution_status(
node_exec_id=node_exec_id,
status=ExecutionStatus.COMPLETED,
)
if _graph_exec := cls.db_client.update_graph_execution_stats(
graph_exec_id=graph_exec.graph_exec_id,
status=execution_status,
stats=execution_stats,
):
send_execution_update(_graph_exec)
else:
logger.error(
"Callback for "
f"finished node execution #{node_exec_id} "
"could not update execution stats "
f"for graph execution #{graph_exec.graph_exec_id}; "
f"triggered while graph exec status = {execution_status}"
)
def cancel_handler():
nonlocal execution_status
@@ -739,38 +775,12 @@ class Executor:
execution_queue.add(node_exec.to_node_execution_entry())
running_executions: dict[str, NodeExecutionProgress] = defaultdict(
lambda: NodeExecutionProgress(drain_output_queue)
lambda: NodeExecutionProgress(
drain_output_queue=drain_output_queue,
drain_done_task=drain_done_task,
)
)
def make_exec_callback(exec_data: NodeExecutionEntry):
def callback(result: object):
if not isinstance(result, NodeExecutionStats):
return
nonlocal execution_stats
execution_stats.node_count += 1
execution_stats.nodes_cputime += result.cputime
execution_stats.nodes_walltime += result.walltime
if (err := result.error) and isinstance(err, Exception):
execution_stats.node_error_count += 1
if _graph_exec := cls.db_client.update_graph_execution_stats(
graph_exec_id=exec_data.graph_exec_id,
status=execution_status,
stats=execution_stats,
):
send_execution_update(_graph_exec)
else:
logger.error(
"Callback for "
f"finished node execution #{exec_data.node_exec_id} "
"could not update execution stats "
f"for graph execution #{exec_data.graph_exec_id}; "
f"triggered while graph exec status = {execution_status}"
)
return callback
while not execution_queue.empty():
if cancel.is_set():
execution_status = ExecutionStatus.TERMINATED
@@ -829,7 +839,6 @@ class Executor:
cls.executor.apply_async(
cls.on_node_execution,
(output_queue, queued_node_exec, node_creds_map),
callback=make_exec_callback(queued_node_exec),
),
)
@@ -845,9 +854,6 @@ class Executor:
execution_status = ExecutionStatus.TERMINATED
return execution_stats, execution_status, error
if not execution_queue.empty():
break # yield to parent loop to execute new queue items
log_metadata.debug(f"Waiting on execution of node {node_id}")
while output := execution.pop_output():
cls._process_node_output(
@@ -858,11 +864,20 @@ class Executor:
node_creds_map=node_creds_map,
execution_queue=execution_queue,
)
if not execution_queue.empty():
break # Prioritize executing next nodes than enqueuing outputs
if execution.is_done(1):
if execution.is_done():
running_executions.pop(node_id)
else:
time.sleep(0.1)
if not execution_queue.empty():
continue # Make sure each node is checked once
if execution_queue.empty() and running_executions:
log_metadata.debug(
"No more nodes to execute, waiting for outputs..."
)
time.sleep(0.1)
log_metadata.info(f"Finished graph execution {graph_exec.graph_exec_id}")
execution_status = ExecutionStatus.COMPLETED

View File

@@ -818,10 +818,15 @@ class ExecutionOutputEntry(BaseModel):
class NodeExecutionProgress:
def __init__(self, drain_output_queue: Callable[[], None]):
def __init__(
self,
drain_output_queue: Callable[[], None],
drain_done_task: Callable[[str, object], None],
):
self.output: dict[str, list[ExecutionOutputEntry]] = defaultdict(list)
self.tasks: dict[str, AsyncResult] = {}
self.drain_output_queue = drain_output_queue
self.drain_done_task = drain_done_task
def add_task(self, node_exec_id: str, task: AsyncResult):
self.tasks[node_exec_id] = task
@@ -868,7 +873,9 @@ class NodeExecutionProgress:
if self.output[exec_id]:
return False
self.tasks.pop(exec_id)
if task := self.tasks.pop(exec_id):
self.drain_done_task(exec_id, task.get())
return True
def _next_exec(self) -> str | None: