Mirror of https://github.com/Significant-Gravitas/AutoGPT.git
feat(backend): Improve cancel execution reliability (#9889)
When an executor dies, an ongoing execution is not retried and is left stuck in the RUNNING status. This change avoids that scenario by allowing an execution entry that is not in the QUEUED status to be picked up again, at the low-probability risk of a double execution (see the sketch after the checklist).

### Changes 🏗️

* Allow executions in a non-QUEUED status to be re-executed.
* Improve cleanup of the node & graph executors.
* Consume cancellation requests on a separate thread so they are not blocked behind other messages.
* Remove the unused retry loop in the execution manager.

### Checklist 📋

#### For code changes:

- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Run an agent, kill the server, restart it; the agent execution restarted.
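A minimal sketch of the relaxed pickup rule described above (illustrative Python; `ExecutionStatus` mirrors the enum in the diff below, while `should_execute` is a hypothetical helper, not the actual AutoGPT API):

```python
from enum import Enum

class ExecutionStatus(str, Enum):
    QUEUED = "QUEUED"
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"

def should_execute(status: ExecutionStatus, *, queued_only: bool) -> bool:
    # Old rule (queued_only=True): only QUEUED entries may start, so a run
    # whose executor died while RUNNING is skipped forever on redelivery.
    # New rule (queued_only=False): a redelivered non-terminal entry restarts,
    # trading a small chance of double execution for recoverability.
    if queued_only:
        return status == ExecutionStatus.QUEUED
    return status != ExecutionStatus.COMPLETED

assert not should_execute(ExecutionStatus.RUNNING, queued_only=True)   # stuck
assert should_execute(ExecutionStatus.RUNNING, queued_only=False)      # recovered
```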
```diff
@@ -492,21 +492,12 @@ async def upsert_execution_output(
 async def update_graph_execution_start_time(
     graph_exec_id: str,
 ) -> GraphExecution | None:
-    count = await AgentGraphExecution.prisma().update_many(
-        where={
-            "id": graph_exec_id,
-            "executionStatus": ExecutionStatus.QUEUED,
-        },
+    res = await AgentGraphExecution.prisma().update(
+        where={"id": graph_exec_id},
         data={
             "executionStatus": ExecutionStatus.RUNNING,
             "startedAt": datetime.now(tz=timezone.utc),
         },
+        include=GRAPH_EXECUTION_INCLUDE,
     )
-    if count == 0:
-        return None
-
-    res = await AgentGraphExecution.prisma().find_unique(
-        where={"id": graph_exec_id},
-        include=GRAPH_EXECUTION_INCLUDE,
-    )
     return GraphExecution.from_db(res) if res else None
```
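The removed `update_many` acted as an optimistic lock: its status filter made QUEUED → RUNNING a compare-and-set, and `count == 0` meant "not queued, skip". The new `update` flips the status unconditionally and returns the row in a single round trip. A dict-backed sketch of the two semantics (a hypothetical stand-in for the Prisma/Postgres calls):

```python
import threading

_store = {"exec-1": {"executionStatus": "QUEUED"}}
_lock = threading.Lock()

def start_if_queued(exec_id: str) -> bool:
    """Old semantics: atomically flip QUEUED -> RUNNING; returns False for any
    other state, which is what left crashed runs stuck in RUNNING."""
    with _lock:
        row = _store.get(exec_id)
        if not row or row["executionStatus"] != "QUEUED":
            return False
        row["executionStatus"] = "RUNNING"
        return True

def start_unconditionally(exec_id: str) -> dict | None:
    """New semantics: flip to RUNNING regardless of current state, so a
    redelivered execution can restart after its executor dies."""
    with _lock:
        row = _store.get(exec_id)
        if row:
            row["executionStatus"] = "RUNNING"
        return row
```

The remaining hunks below touch the `Executor`/`ExecutionManager` process code and, at the end, its settings.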
```diff
@@ -5,7 +5,6 @@ import os
 import signal
 import sys
 import threading
-import time
 from concurrent.futures import Future, ProcessPoolExecutor
 from contextlib import contextmanager
 from multiprocessing.pool import AsyncResult, Pool
```
```diff
@@ -23,13 +22,15 @@ from backend.data.notifications import (
     NotificationEventDTO,
     NotificationType,
 )
+from backend.data.rabbitmq import SyncRabbitMQ
+from backend.executor.utils import create_execution_queue_config
 from backend.util.exceptions import InsufficientBalanceError

 if TYPE_CHECKING:
     from backend.executor import DatabaseManager
     from backend.notifications.notifications import NotificationManager

-from autogpt_libs.utils.cache import clear_thread_cache, thread_cached
+from autogpt_libs.utils.cache import thread_cached
 from prometheus_client import Gauge, start_http_server

 from backend.blocks.agent import AgentExecutorBlock
```
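For context on the changed import: `@thread_cached` memoizes a factory per thread, which is why the old code had to `clear_thread_cache(get_execution_queue)` before reconnecting. A minimal sketch of such a decorator (the `autogpt_libs` implementation may differ, e.g. by also keying on call arguments):

```python
import threading
from functools import wraps

def thread_cached(func):
    """Cache one result per thread (sketch; ignores call arguments)."""
    local = threading.local()

    @wraps(func)
    def wrapper(*args, **kwargs):
        if not hasattr(local, "value"):
            local.value = func(*args, **kwargs)
        return local.value

    return wrapper

@thread_cached
def get_connection() -> dict:
    return {"conn": object()}  # stand-in for an expensive client setup
```

With each consumer now constructing its own `SyncRabbitMQ` client (see the `_run` hunk below), the cached queue no longer needs to be cleared, so `clear_thread_cache` is dropped.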
```diff
@@ -446,36 +447,28 @@ class Executor:
         # Set up shutdown handlers
         cls.shutdown_lock = threading.Lock()
-        atexit.register(cls.on_node_executor_stop)  # handle regular shutdown
-        signal.signal(  # handle termination
-            signal.SIGTERM, lambda _, __: cls.on_node_executor_sigterm()
-        )
+        atexit.register(cls.on_node_executor_stop)
+        signal.signal(signal.SIGTERM, lambda _, __: cls.on_node_executor_sigterm())
+        signal.signal(signal.SIGINT, lambda _, __: cls.on_node_executor_sigterm())

     @classmethod
-    def on_node_executor_stop(cls):
+    def on_node_executor_stop(cls, log=logger.info):
         if not cls.shutdown_lock.acquire(blocking=False):
             return  # already shutting down

-        logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
+        log(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
         cls.creds_manager.release_all_locks()
-        logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
+        log(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
         redis.disconnect()
-        logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB manager...")
+        log(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB manager...")
         close_service_client(cls.db_client)
-        logger.info(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")
+        log(f"[on_node_executor_stop {cls.pid}] ✅ Finished NodeExec cleanup")
         sys.exit(0)

     @classmethod
     def on_node_executor_sigterm(cls):
-        llprint(f"[on_node_executor_sigterm {cls.pid}] ⚠️ SIGTERM received")
-        if not cls.shutdown_lock.acquire(blocking=False):
-            return  # already shutting down
-
-        llprint(f"[on_node_executor_stop {cls.pid}] ⏳ Releasing locks...")
-        cls.creds_manager.release_all_locks()
-        llprint(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting Redis...")
-        redis.disconnect()
-        llprint(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")
-        sys.exit(0)
+        llprint(f"[on_node_executor_sigterm {cls.pid}] ⚠️ NodeExec SIGTERM received")
+        cls.on_node_executor_stop(log=llprint)

     @classmethod
     @error_logged
```
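The pattern in this hunk: `atexit`, SIGTERM, and SIGINT all funnel into one idempotent cleanup routine; a non-blocking lock rejects double entry, and the logging callable is injected so the signal path can use a signal-safe writer. A self-contained sketch of the same shape (names are illustrative):

```python
import atexit
import os
import signal
import sys
import threading

def llprint(message: str) -> None:
    # Single-syscall write: safe inside signal handlers, unlike logging/print.
    os.write(sys.stdout.fileno(), (message + "\n").encode())

class Worker:
    shutdown_lock = threading.Lock()

    @classmethod
    def install_handlers(cls) -> None:
        atexit.register(cls.stop)  # regular interpreter exit
        signal.signal(signal.SIGTERM, lambda _, __: cls.on_sigterm())
        signal.signal(signal.SIGINT, lambda _, __: cls.on_sigterm())

    @classmethod
    def stop(cls, log=print) -> None:
        if not cls.shutdown_lock.acquire(blocking=False):
            return  # another path (signal vs. atexit) is already cleaning up
        log("releasing locks, closing connections...")

    @classmethod
    def on_sigterm(cls) -> None:
        llprint("SIGTERM/SIGINT received")
        cls.stop(log=llprint)
        sys.exit(0)  # triggers atexit, which the lock then makes a no-op
```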
```diff
@@ -551,21 +544,7 @@ class Executor:
         cls.pid = os.getpid()
         cls.notification_service = get_notification_service()
         cls._init_node_executor_pool()
-        logger.info(
-            f"Graph executor {cls.pid} started with {cls.pool_size} node workers"
-        )
-
-        # Set up shutdown handler
-        atexit.register(cls.on_graph_executor_stop)
-
-    @classmethod
-    def on_graph_executor_stop(cls):
-        prefix = f"[on_graph_executor_stop {cls.pid}]"
-        logger.info(f"{prefix} ⏳ Terminating node executor pool...")
-        cls.executor.terminate()
-        logger.info(f"{prefix} ⏳ Disconnecting DB manager...")
-        close_service_client(cls.db_client)
-        logger.info(f"{prefix} ✅ Finished cleanup")
+        logger.info(f"GraphExec {cls.pid} started with {cls.pool_size} node workers")

     @classmethod
     def _init_node_executor_pool(cls):
```
```diff
@@ -592,7 +571,7 @@ class Executor:
         )
         if exec_meta is None:
             logger.warning(
-                f"Skipped graph execution {graph_exec.graph_exec_id}, the graph execution is not found or not currently in the QUEUED state."
+                f"Skipped graph execution {graph_exec.graph_exec_id}, the graph execution is not found."
             )
             return

```
```diff
@@ -909,13 +888,14 @@ class ExecutionManager(AppProcess):
         self.pool_size = settings.config.num_graph_workers
         self.running = True
         self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
+        atexit.register(self._on_cleanup)
+        signal.signal(signal.SIGTERM, lambda sig, frame: self._on_sigterm())
+        signal.signal(signal.SIGINT, lambda sig, frame: self._on_sigterm())

     def run(self):
         pool_size_gauge.set(self.pool_size)
         active_runs_gauge.set(0)
         utilization_gauge.set(0)
-        retry_count_max = settings.config.execution_manager_loop_max_retry
-        retry_count = 0

         self.metrics_server = threading.Thread(
             target=start_http_server,
```
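The gauges and metrics thread retained by this hunk follow the standard `prometheus_client` pattern; a minimal sketch (gauge names and the port are illustrative, not the exact ones in the codebase):

```python
import threading
from prometheus_client import Gauge, start_http_server

pool_size_gauge = Gauge("pool_size", "Maximum number of graph workers")
active_runs_gauge = Gauge("active_runs", "Currently executing graph runs")

# start_http_server() spawns its own serving thread and returns quickly;
# wrapping it in a Thread, as the manager does, also keeps any bind error
# or startup delay off the main service path.
metrics_server = threading.Thread(target=start_http_server, args=(8005,), daemon=True)
metrics_server.start()

pool_size_gauge.set(10)
active_runs_gauge.set(0)
```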
```diff
@@ -924,27 +904,7 @@
         )
         self.metrics_server.start()
         logger.info(f"[{self.service_name}] Starting execution manager...")
-
-        for retry_count in range(retry_count_max):
-            try:
-                self._run()
-            except Exception as e:
-                if not self.running:
-                    break
-                logger.exception(
-                    f"[{self.service_name}] Error in execution manager: {e}"
-                )
-
-            if retry_count >= retry_count_max:
-                logger.error(
-                    f"[{self.service_name}] Max retries reached ({retry_count_max}), exiting..."
-                )
-                break
-            else:
-                logger.info(
-                    f"[{self.service_name}] Retrying execution loop in {retry_count} seconds..."
-                )
-                time.sleep(retry_count)
+        self._run()

     def _run(self):
         logger.info(f"[{self.service_name}] ⏳ Spawn max-{self.pool_size} workers...")
```
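Why the removed loop counts as "unused": its exit guard was dead code, since `retry_count` ranges over `0..retry_count_max - 1` and can never reach `retry_count_max`, and even a clean return from `_run()` fell through to the retry branch and slept. A short demonstration of the dead guard:

```python
retry_count_max = 5
for retry_count in range(retry_count_max):
    # range(n) yields 0..n-1, so this condition can never be true.
    assert not (retry_count >= retry_count_max)
```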
```diff
@@ -956,23 +916,33 @@
         logger.info(f"[{self.service_name}] ⏳ Connecting to Redis...")
         redis.connect()

-        # Consume Cancel & Run execution requests.
-        clear_thread_cache(get_execution_queue)
-        channel = get_execution_queue().get_channel()
-        channel.basic_qos(prefetch_count=self.pool_size)
-        channel.basic_consume(
-            queue=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
-            on_message_callback=self._handle_cancel_message,
-            auto_ack=True,
-        )
-        channel.basic_consume(
+        cancel_client = SyncRabbitMQ(create_execution_queue_config())
+        cancel_client.connect()
+        cancel_channel = cancel_client.get_channel()
+        logger.info(f"[{self.service_name}] ⏳ Starting cancel message consumer...")
+        threading.Thread(
+            target=lambda: (
+                cancel_channel.basic_consume(
+                    queue=GRAPH_EXECUTION_CANCEL_QUEUE_NAME,
+                    on_message_callback=self._handle_cancel_message,
+                    auto_ack=True,
+                ),
+                cancel_channel.start_consuming(),
+            ),
+            daemon=True,
+        ).start()
+
+        run_client = SyncRabbitMQ(create_execution_queue_config())
+        run_client.connect()
+        run_channel = run_client.get_channel()
+        run_channel.basic_qos(prefetch_count=self.pool_size)
+        run_channel.basic_consume(
             queue=GRAPH_EXECUTION_QUEUE_NAME,
             on_message_callback=self._handle_run_message,
             auto_ack=False,
         )

-        logger.info(f"[{self.service_name}] Ready to consume messages...")
-        channel.start_consuming()
+        logger.info(f"[{self.service_name}] ⏳ Starting to consume run messages...")
+        run_channel.start_consuming()

     def _handle_cancel_message(
         self,
```
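The key change in this hunk: instead of multiplexing cancel and run consumption on one blocking channel (where a slow run handler or an exhausted prefetch window can starve cancels), each consumer gets its own client, and the cancel consumer runs in a daemon thread. A minimal sketch of the same shape using raw `pika` (queue names and callbacks are placeholders; AutoGPT wraps this in its `SyncRabbitMQ` helper):

```python
import threading
import pika

def make_channel():
    # One connection per thread: pika's BlockingConnection is not thread-safe.
    conn = pika.BlockingConnection(pika.ConnectionParameters("localhost"))
    return conn.channel()

def on_cancel(ch, method, properties, body):
    print("cancel request:", body)

def on_run(ch, method, properties, body):
    print("run request:", body)
    ch.basic_ack(delivery_tag=method.delivery_tag)  # unacked messages are redelivered if we die

def consume_cancels():
    ch = make_channel()
    ch.basic_consume(queue="cancel_queue", on_message_callback=on_cancel, auto_ack=True)
    ch.start_consuming()  # blocks this daemon thread only

threading.Thread(target=consume_cancels, daemon=True).start()

run_ch = make_channel()
run_ch.basic_qos(prefetch_count=10)  # cap in-flight runs, e.g. to the worker pool size
run_ch.basic_consume(queue="run_queue", on_message_callback=on_run, auto_ack=False)
run_ch.start_consuming()  # the main thread handles run messages
```

Run messages keep `auto_ack=False`, so a message is redelivered if the manager dies before acking — exactly the situation the relaxed QUEUED check now recovers from. Cancels stay `auto_ack=True`, presumably because a lost cancel is more benign than a blocked one.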
```diff
@@ -1069,19 +1039,29 @@
     def cleanup(self):
         super().cleanup()
+        self._on_cleanup()

-        logger.info(f"[{self.service_name}] ⏳ Shutting down service loop...")
+    def _on_sigterm(self):
+        llprint(f"[{self.service_name}] ⚠️ GraphExec SIGTERM received")
+        self._on_cleanup(log=llprint)
+
+    def _on_cleanup(self, log=logger.info):
+        prefix = f"[{self.service_name}][on_graph_executor_stop {os.getpid()}]"
+        log(f"{prefix} ⏳ Shutting down service loop...")
         self.running = False

-        logger.info(f"[{self.service_name}] ⏳ Shutting down RabbitMQ channel...")
+        log(f"{prefix} ⏳ Shutting down RabbitMQ channel...")
         get_execution_queue().get_channel().stop_consuming()

-        logger.info(f"[{self.service_name}] ⏳ Shutting down graph executor pool...")
-        self.executor.shutdown(cancel_futures=True)
+        if hasattr(self, "executor"):
+            log(f"{prefix} ⏳ Shutting down GraphExec pool...")
+            self.executor.shutdown(cancel_futures=True, wait=True)

-        logger.info(f"[{self.service_name}] ⏳ Disconnecting Redis...")
+        log(f"{prefix} ⏳ Disconnecting Redis...")
         redis.disconnect()

+        log(f"{prefix} ✅ Finished GraphExec cleanup")
+
     @property
     def db_client(self) -> "DatabaseManager":
         return get_db_client()
```
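Two details in this cleanup hunk are worth unpacking: the `hasattr(self, "executor")` guard covers the case where `atexit` or a signal fires before `_run()` ever created the pool (the handlers are registered in the constructor), and `shutdown(cancel_futures=True, wait=True)` cancels queued work while letting in-flight work finish. A runnable sketch of the shutdown semantics (`cancel_futures` requires Python 3.9+):

```python
import time
from concurrent.futures import ProcessPoolExecutor

if __name__ == "__main__":
    pool = ProcessPoolExecutor(max_workers=2)
    futures = [pool.submit(time.sleep, 0.2) for _ in range(8)]
    # Cancels everything still queued, lets already-running tasks finish,
    # then blocks until the worker processes have exited.
    pool.shutdown(wait=True, cancel_futures=True)
    print(sum(f.cancelled() for f in futures), "of 8 tasks were cancelled before starting")
```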
```diff
@@ -1124,5 +1104,4 @@ def llprint(message: str):
     Low-level print/log helper function for use in signal handlers.
     Regular log/print statements are not allowed in signal handlers.
     """
-    if logger.getEffectiveLevel() == logging.DEBUG:
-        os.write(sys.stdout.fileno(), (message + "\n").encode())
+    os.write(sys.stdout.fileno(), (message + "\n").encode())
```
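This makes `llprint` unconditional rather than DEBUG-only, so shutdown progress is always visible. The `os.write` detour exists because `print` and `logging` take internal locks and buffers that are not async-signal-safe; a handler that fires while the main thread holds the logging lock can deadlock. Minimal usage sketch:

```python
import os
import signal
import sys

def handler(signum, frame):
    # A single write() syscall is async-signal-safe; print()/logging are not.
    os.write(sys.stdout.fileno(), b"signal received, shutting down\n")
    sys.exit(0)

signal.signal(signal.SIGTERM, handler)
```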
```diff
@@ -137,10 +137,6 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
         default=8002,
         description="The port for execution manager daemon to run on",
     )
-    execution_manager_loop_max_retry: int = Field(
-        default=5,
-        description="The maximum number of retries for the execution manager loop",
-    )

     execution_scheduler_port: int = Field(
         default=8003,
```
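With the retry loop gone, `execution_manager_loop_max_retry` has no readers, so the setting is deleted. For reference, the surrounding `Config` fields follow the usual pydantic-settings pattern (a sketch assuming pydantic v2; the field name `execution_manager_port` is inferred from the visible default/description, not shown in the hunk):

```python
from pydantic import Field
from pydantic_settings import BaseSettings

class Config(BaseSettings):
    # Hypothetical field name; the default and description match the hunk above.
    execution_manager_port: int = Field(
        default=8002,
        description="The port for execution manager daemon to run on",
    )
```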