feat(backend): Improve cancel execution reliability (#9889)

When an executor dies, an ongoing execution will not be retried and will
just stuck in the running status.
This change avoids such a scenario by allowing an execution of an entry
that is not in QUEUED status with the low-probability risk of double
execution.

### Changes 🏗️

* Allow non-QUEUED status to be re-executed.
* Improve cleanup of node & graph executor.
* Make a cancellation request consumption a separate thread to avoid
being blocked by other messages.
* Remove unused retry loop on the execution manager.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] Run agent, kill the server, re-run it, agent restarted.
This commit is contained in:
Zamil Majdy
2025-04-30 00:06:03 +07:00
committed by GitHub
parent d5dc687484
commit 9fa62c03f6
3 changed files with 61 additions and 95 deletions

View File

@@ -492,21 +492,12 @@ async def upsert_execution_output(
async def update_graph_execution_start_time(
graph_exec_id: str,
) -> GraphExecution | None:
count = await AgentGraphExecution.prisma().update_many(
where={
"id": graph_exec_id,
"executionStatus": ExecutionStatus.QUEUED,
},
res = await AgentGraphExecution.prisma().update(
where={"id": graph_exec_id},
data={
"executionStatus": ExecutionStatus.RUNNING,
"startedAt": datetime.now(tz=timezone.utc),
},
)
if count == 0:
return None
res = await AgentGraphExecution.prisma().find_unique(
where={"id": graph_exec_id},
include=GRAPH_EXECUTION_INCLUDE,
)
return GraphExecution.from_db(res) if res else None

View File

@@ -5,7 +5,6 @@ import os
import signal
import sys
import threading
import time
from concurrent.futures import Future, ProcessPoolExecutor
from contextlib import contextmanager
from multiprocessing.pool import AsyncResult, Pool
@@ -23,13 +22,15 @@ from backend.data.notifications import (
NotificationEventDTO,
NotificationType,
)
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.utils import create_execution_queue_config
from backend.util.exceptions import InsufficientBalanceError
if TYPE_CHECKING:
from backend.executor import DatabaseManager
from backend.notifications.notifications import NotificationManager
from autogpt_libs.utils.cache import clear_thread_cache, thread_cached
from autogpt_libs.utils.cache import thread_cached
from prometheus_client import Gauge, start_http_server
from backend.blocks.agent import AgentExecutorBlock
@@ -446,36 +447,28 @@ class Executor:
# Set up shutdown handlers
cls.shutdown_lock = threading.Lock()
atexit.register(cls.on_node_executor_stop) # handle regular shutdown
signal.signal( # handle termination
signal.SIGTERM, lambda _, __: cls.on_node_executor_sigterm()
)
atexit.register(cls.on_node_executor_stop)
signal.signal(signal.SIGTERM, lambda _, __: cls.on_node_executor_sigterm())
signal.signal(signal.SIGINT, lambda _, __: cls.on_node_executor_sigterm())
@classmethod
def on_node_executor_stop(cls):
def on_node_executor_stop(cls, log=logger.info):
if not cls.shutdown_lock.acquire(blocking=False):
return # already shutting down
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
log(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
cls.creds_manager.release_all_locks()
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
log(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
redis.disconnect()
logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB manager...")
log(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB manager...")
close_service_client(cls.db_client)
logger.info(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")
log(f"[on_node_executor_stop {cls.pid}] ✅ Finished NodeExec cleanup")
sys.exit(0)
@classmethod
def on_node_executor_sigterm(cls):
llprint(f"[on_node_executor_sigterm {cls.pid}] ⚠️ SIGTERM received")
if not cls.shutdown_lock.acquire(blocking=False):
return # already shutting down
llprint(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
cls.creds_manager.release_all_locks()
llprint(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
redis.disconnect()
llprint(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")
sys.exit(0)
llprint(f"[on_node_executor_sigterm {cls.pid}] ⚠️ NodeExec SIGTERM received")
cls.on_node_executor_stop(log=llprint)
@classmethod
@error_logged
@@ -551,21 +544,7 @@ class Executor:
cls.pid = os.getpid()
cls.notification_service = get_notification_service()
cls._init_node_executor_pool()
logger.info(
f"Graph executor {cls.pid} started with {cls.pool_size} node workers"
)
# Set up shutdown handler
atexit.register(cls.on_graph_executor_stop)
@classmethod
def on_graph_executor_stop(cls):
prefix = f"[on_graph_executor_stop {cls.pid}]"
logger.info(f"{prefix} ⏳ Terminating node executor pool...")
cls.executor.terminate()
logger.info(f"{prefix} ⏳ Disconnecting DB manager...")
close_service_client(cls.db_client)
logger.info(f"{prefix} ✅ Finished cleanup")
logger.info(f"GraphExec {cls.pid} started with {cls.pool_size} node workers")
@classmethod
def _init_node_executor_pool(cls):
@@ -592,7 +571,7 @@ class Executor:
)
if exec_meta is None:
logger.warning(
f"Skipped graph execution {graph_exec.graph_exec_id}, the graph execution is not found or not currently in the QUEUED state."
f"Skipped graph execution {graph_exec.graph_exec_id}, the graph execution is not found."
)
return
@@ -909,13 +888,14 @@ class ExecutionManager(AppProcess):
self.pool_size = settings.config.num_graph_workers
self.running = True
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
atexit.register(self._on_cleanup)
signal.signal(signal.SIGTERM, lambda sig, frame: self._on_sigterm())
signal.signal(signal.SIGINT, lambda sig, frame: self._on_sigterm())
def run(self):
pool_size_gauge.set(self.pool_size)
active_runs_gauge.set(0)
utilization_gauge.set(0)
retry_count_max = settings.config.execution_manager_loop_max_retry
retry_count = 0
self.metrics_server = threading.Thread(
target=start_http_server,
@@ -924,27 +904,7 @@ class ExecutionManager(AppProcess):
)
self.metrics_server.start()
logger.info(f"[{self.service_name}] Starting execution manager...")
for retry_count in range(retry_count_max):
try:
self._run()
except Exception as e:
if not self.running:
break
logger.exception(
f"[{self.service_name}] Error in execution manager: {e}"
)
if retry_count >= retry_count_max:
logger.error(
f"[{self.service_name}] Max retries reached ({retry_count_max}), exiting..."
)
break
else:
logger.info(
f"[{self.service_name}] Retrying execution loop in {retry_count} seconds..."
)
time.sleep(retry_count)
self._run()
def _run(self):
logger.info(f"[{self.service_name}] ⏳ Spawn max-{self.pool_size} workers...")
@@ -956,23 +916,33 @@ class ExecutionManager(AppProcess):
logger.info(f"[{self.service_name}] ⏳ Connecting to Redis...")
redis.connect()
# Consume Cancel & Run execution requests.
clear_thread_cache(get_execution_queue)
channel = get_execution_queue().get_channel()
channel.basic_qos(prefetch_count=self.pool_size)
channel.basic_consume(
queue=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
on_message_callback=self._handle_cancel_message,
auto_ack=True,
)
channel.basic_consume(
cancel_client = SyncRabbitMQ(create_execution_queue_config())
cancel_client.connect()
cancel_channel = cancel_client.get_channel()
logger.info(f"[{self.service_name}] ⏳ Starting cancel message consumer...")
threading.Thread(
target=lambda: (
cancel_channel.basic_consume(
queue=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
on_message_callback=self._handle_cancel_message,
auto_ack=True,
),
cancel_channel.start_consuming(),
),
daemon=True,
).start()
run_client = SyncRabbitMQ(create_execution_queue_config())
run_client.connect()
run_channel = run_client.get_channel()
run_channel.basic_qos(prefetch_count=self.pool_size)
run_channel.basic_consume(
queue=GRAPH_EXECUTION_QUEUE_NAME,
on_message_callback=self._handle_run_message,
auto_ack=False,
)
logger.info(f"[{self.service_name}] Ready to consume messages...")
channel.start_consuming()
logger.info(f"[{self.service_name}] ⏳ Starting to consume run messages...")
run_channel.start_consuming()
def _handle_cancel_message(
self,
@@ -1069,19 +1039,29 @@ class ExecutionManager(AppProcess):
def cleanup(self):
super().cleanup()
self._on_cleanup()
logger.info(f"[{self.service_name}] ⏳ Shutting down service loop...")
def _on_sigterm(self):
llprint(f"[{self.service_name}] ⚠️ GraphExec SIGTERM received")
self._on_cleanup(log=llprint)
def _on_cleanup(self, log=logger.info):
prefix = f"[{self.service_name}][on_graph_executor_stop {os.getpid()}]"
log(f"{prefix} ⏳ Shutting down service loop...")
self.running = False
logger.info(f"[{self.service_name}] ⏳ Shutting down RabbitMQ channel...")
log(f"{prefix} ⏳ Shutting down RabbitMQ channel...")
get_execution_queue().get_channel().stop_consuming()
logger.info(f"[{self.service_name}] ⏳ Shutting down graph executor pool...")
self.executor.shutdown(cancel_futures=True)
if hasattr(self, "executor"):
log(f"{prefix} ⏳ Shutting down GraphExec pool...")
self.executor.shutdown(cancel_futures=True, wait=True)
logger.info(f"[{self.service_name}] ⏳ Disconnecting Redis...")
log(f"{prefix} ⏳ Disconnecting Redis...")
redis.disconnect()
log(f"{prefix} ✅ Finished GraphExec cleanup")
@property
def db_client(self) -> "DatabaseManager":
return get_db_client()
@@ -1124,5 +1104,4 @@ def llprint(message: str):
Low-level print/log helper function for use in signal handlers.
Regular log/print statements are not allowed in signal handlers.
"""
if logger.getEffectiveLevel() == logging.DEBUG:
os.write(sys.stdout.fileno(), (message + "\n").encode())
os.write(sys.stdout.fileno(), (message + "\n").encode())

View File

@@ -137,10 +137,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
default=8002,
description="The port for execution manager daemon to run on",
)
execution_manager_loop_max_retry: int = Field(
default=5,
description="The maximum number of retries for the execution manager loop",
)
execution_scheduler_port: int = Field(
default=8003,