mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(executor): Improve execution ordering to allow depth-first execution (#10142)
Allowing depth-first execution will unlock faster processing latency and a better sense of progress. <img width="950" alt="image" src="https://github.com/user-attachments/assets/e2a0e11a-8bc5-4a65-a10d-b5b6c6383354" /> ### Changes 🏗️ * Prioritize adding a new execution over processing execution output * Make sure to enqueue each node once when processing output instead of draining a single node and moving on. ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: <!-- Put your test plan here: --> - [x] Run company follower count finder agent. --------- Co-authored-by: Swifty <craigswift13@gmail.com>
This commit is contained in:
@@ -628,15 +628,11 @@ async def update_node_execution_stats(
|
||||
data = stats.model_dump()
|
||||
if isinstance(data["error"], Exception):
|
||||
data["error"] = str(data["error"])
|
||||
execution_status = ExecutionStatus.FAILED
|
||||
else:
|
||||
execution_status = ExecutionStatus.COMPLETED
|
||||
|
||||
res = await AgentNodeExecution.prisma().update(
|
||||
where={"id": node_exec_id},
|
||||
data={
|
||||
"stats": Json(data),
|
||||
"executionStatus": execution_status,
|
||||
"endedTime": datetime.now(tz=timezone.utc),
|
||||
},
|
||||
include=EXECUTION_RESULT_INCLUDE,
|
||||
|
||||
@@ -706,6 +706,42 @@ class Executor:
|
||||
)
|
||||
running_executions[output.node.id].add_output(output)
|
||||
|
||||
def drain_done_task(node_exec_id: str, result: object):
|
||||
if not isinstance(result, NodeExecutionStats):
|
||||
log_metadata.error(f"Unexpected result #{node_exec_id}: {type(result)}")
|
||||
return
|
||||
|
||||
nonlocal execution_stats
|
||||
execution_stats.node_count += 1
|
||||
execution_stats.nodes_cputime += result.cputime
|
||||
execution_stats.nodes_walltime += result.walltime
|
||||
if (err := result.error) and isinstance(err, Exception):
|
||||
execution_stats.node_error_count += 1
|
||||
cls.db_client.update_node_execution_status(
|
||||
node_exec_id=node_exec_id,
|
||||
status=ExecutionStatus.FAILED,
|
||||
)
|
||||
else:
|
||||
cls.db_client.update_node_execution_status(
|
||||
node_exec_id=node_exec_id,
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
)
|
||||
|
||||
if _graph_exec := cls.db_client.update_graph_execution_stats(
|
||||
graph_exec_id=graph_exec.graph_exec_id,
|
||||
status=execution_status,
|
||||
stats=execution_stats,
|
||||
):
|
||||
send_execution_update(_graph_exec)
|
||||
else:
|
||||
logger.error(
|
||||
"Callback for "
|
||||
f"finished node execution #{node_exec_id} "
|
||||
"could not update execution stats "
|
||||
f"for graph execution #{graph_exec.graph_exec_id}; "
|
||||
f"triggered while graph exec status = {execution_status}"
|
||||
)
|
||||
|
||||
def cancel_handler():
|
||||
nonlocal execution_status
|
||||
|
||||
@@ -739,38 +775,12 @@ class Executor:
|
||||
execution_queue.add(node_exec.to_node_execution_entry())
|
||||
|
||||
running_executions: dict[str, NodeExecutionProgress] = defaultdict(
|
||||
lambda: NodeExecutionProgress(drain_output_queue)
|
||||
lambda: NodeExecutionProgress(
|
||||
drain_output_queue=drain_output_queue,
|
||||
drain_done_task=drain_done_task,
|
||||
)
|
||||
)
|
||||
|
||||
def make_exec_callback(exec_data: NodeExecutionEntry):
|
||||
def callback(result: object):
|
||||
if not isinstance(result, NodeExecutionStats):
|
||||
return
|
||||
|
||||
nonlocal execution_stats
|
||||
execution_stats.node_count += 1
|
||||
execution_stats.nodes_cputime += result.cputime
|
||||
execution_stats.nodes_walltime += result.walltime
|
||||
if (err := result.error) and isinstance(err, Exception):
|
||||
execution_stats.node_error_count += 1
|
||||
|
||||
if _graph_exec := cls.db_client.update_graph_execution_stats(
|
||||
graph_exec_id=exec_data.graph_exec_id,
|
||||
status=execution_status,
|
||||
stats=execution_stats,
|
||||
):
|
||||
send_execution_update(_graph_exec)
|
||||
else:
|
||||
logger.error(
|
||||
"Callback for "
|
||||
f"finished node execution #{exec_data.node_exec_id} "
|
||||
"could not update execution stats "
|
||||
f"for graph execution #{exec_data.graph_exec_id}; "
|
||||
f"triggered while graph exec status = {execution_status}"
|
||||
)
|
||||
|
||||
return callback
|
||||
|
||||
while not execution_queue.empty():
|
||||
if cancel.is_set():
|
||||
execution_status = ExecutionStatus.TERMINATED
|
||||
@@ -829,7 +839,6 @@ class Executor:
|
||||
cls.executor.apply_async(
|
||||
cls.on_node_execution,
|
||||
(output_queue, queued_node_exec, node_creds_map),
|
||||
callback=make_exec_callback(queued_node_exec),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -845,9 +854,6 @@ class Executor:
|
||||
execution_status = ExecutionStatus.TERMINATED
|
||||
return execution_stats, execution_status, error
|
||||
|
||||
if not execution_queue.empty():
|
||||
break # yield to parent loop to execute new queue items
|
||||
|
||||
log_metadata.debug(f"Waiting on execution of node {node_id}")
|
||||
while output := execution.pop_output():
|
||||
cls._process_node_output(
|
||||
@@ -858,11 +864,20 @@ class Executor:
|
||||
node_creds_map=node_creds_map,
|
||||
execution_queue=execution_queue,
|
||||
)
|
||||
if not execution_queue.empty():
|
||||
break # Prioritize executing next nodes over enqueuing outputs
|
||||
|
||||
if execution.is_done(1):
|
||||
if execution.is_done():
|
||||
running_executions.pop(node_id)
|
||||
else:
|
||||
time.sleep(0.1)
|
||||
|
||||
if not execution_queue.empty():
|
||||
continue # Make sure each node is checked once
|
||||
|
||||
if execution_queue.empty() and running_executions:
|
||||
log_metadata.debug(
|
||||
"No more nodes to execute, waiting for outputs..."
|
||||
)
|
||||
time.sleep(0.1)
|
||||
|
||||
log_metadata.info(f"Finished graph execution {graph_exec.graph_exec_id}")
|
||||
execution_status = ExecutionStatus.COMPLETED
|
||||
|
||||
@@ -818,10 +818,15 @@ class ExecutionOutputEntry(BaseModel):
|
||||
|
||||
|
||||
class NodeExecutionProgress:
|
||||
def __init__(self, drain_output_queue: Callable[[], None]):
|
||||
def __init__(
|
||||
self,
|
||||
drain_output_queue: Callable[[], None],
|
||||
drain_done_task: Callable[[str, object], None],
|
||||
):
|
||||
self.output: dict[str, list[ExecutionOutputEntry]] = defaultdict(list)
|
||||
self.tasks: dict[str, AsyncResult] = {}
|
||||
self.drain_output_queue = drain_output_queue
|
||||
self.drain_done_task = drain_done_task
|
||||
|
||||
def add_task(self, node_exec_id: str, task: AsyncResult):
|
||||
self.tasks[node_exec_id] = task
|
||||
@@ -868,7 +873,9 @@ class NodeExecutionProgress:
|
||||
if self.output[exec_id]:
|
||||
return False
|
||||
|
||||
self.tasks.pop(exec_id)
|
||||
if task := self.tasks.pop(exec_id):
|
||||
self.drain_done_task(exec_id, task.get())
|
||||
|
||||
return True
|
||||
|
||||
def _next_exec(self) -> str | None:
|
||||
|
||||
Reference in New Issue
Block a user