feat(server): Clean up resources when spinning down services/processes (#7938)

- Add SIGTERM handler and `cleanup()` hook to `AppProcess`
- Implement `cleanup()` on `AppService` to close DB and Redis connections
- Implement `cleanup()` on `ExecutionManager` to shut down worker pool
- Add `atexit` and SIGTERM handlers to node executor to close DB connection and shut down node workers
- Improve logging in `.executor.manager`
- Fix shutdown order of `.util.test:SpinTestServer`
Reinier van der Leer
2024-09-06 16:50:59 +02:00
committed by GitHub
parent b12dba13f4
commit 3bd8040d6a
4 changed files with 172 additions and 104 deletions
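All of these changes apply one pattern: register a single idempotent cleanup routine for both normal interpreter exit (`atexit`) and SIGTERM, guarded by a non-blocking lock so it runs at most once no matter which path fires first. A minimal standalone sketch of that pattern follows; the names are illustrative, not the actual AutoGPT classes.

import atexit
import signal
import sys
import threading

shutdown_lock = threading.Lock()  # guards against running cleanup twice

def cleanup():
    # Safe to reach from both atexit and the SIGTERM handler:
    # the non-blocking acquire makes this a run-at-most-once routine.
    if not shutdown_lock.acquire(blocking=False):
        return  # already shutting down
    print("closing DB/Redis connections, stopping workers...")

atexit.register(cleanup)  # regular shutdown path

def on_sigterm(signum, frame):  # termination by a parent process
    cleanup()
    sys.exit(0)

signal.signal(signal.SIGTERM, on_sigterm)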

View File

@@ -1,6 +1,10 @@
 import asyncio
+import atexit
 import logging
 import multiprocessing
+import os
+import signal
+import sys
 import threading
 from concurrent.futures import Future, ProcessPoolExecutor
 from contextlib import contextmanager
@@ -114,9 +118,7 @@ def execute_node(
     if input_data is None:
         logger.error(
             "Skip execution, input validation error",
-            extra={
-                "json_fields": {**log_metadata, "error": error},
-            },
+            extra={"json_fields": {**log_metadata, "error": error}},
         )
         return
@@ -254,22 +256,14 @@ def _enqueue_next_nodes(
             if not next_node_input:
                 logger.warning(
                     f"Skipped queueing {suffix}",
-                    extra={
-                        "json_fields": {
-                            **log_metadata,
-                        }
-                    },
+                    extra={"json_fields": {**log_metadata}},
                 )
                 return enqueued_executions

             # Input is complete, enqueue the execution.
             logger.info(
                 f"Enqueued {suffix}",
-                extra={
-                    "json_fields": {
-                        **log_metadata,
-                    }
-                },
+                extra={"json_fields": {**log_metadata}},
             )
             enqueued_executions.append(
                 add_enqueued_execution(next_node_exec_id, next_node_id, next_node_input)
@@ -402,29 +396,62 @@ class Executor:
     @classmethod
     def on_node_executor_start(cls):
         configure_logging()
-        cls.logger = logging.getLogger("node_executor")
         cls.loop = asyncio.new_event_loop()
+        cls.pid = os.getpid()
         cls.loop.run_until_complete(db.connect())
         cls.agent_server_client = get_agent_server_client()
+
+        # Set up shutdown handlers
+        cls.shutdown_lock = threading.Lock()
+        atexit.register(cls.on_node_executor_stop)  # handle regular shutdown
+        signal.signal(  # handle termination
+            signal.SIGTERM, lambda _, __: cls.on_node_executor_sigterm()
+        )
+
+    @classmethod
+    def on_node_executor_stop(cls):
+        if not cls.shutdown_lock.acquire(blocking=False):
+            return  # already shutting down
+
+        logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB...")
+        cls.loop.run_until_complete(db.disconnect())
+        logger.info(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")
+
+    @classmethod
+    def on_node_executor_sigterm(cls):
+        llprint(f"[on_node_executor_sigterm {cls.pid}] ⚠️ SIGTERM received")
+        if not cls.shutdown_lock.acquire(blocking=False):
+            return  # already shutting down, no need to self-terminate
+
+        llprint(f"[on_node_executor_sigterm {cls.pid}] ⏳ Disconnecting DB...")
+        cls.loop.run_until_complete(db.disconnect())
+        llprint(f"[on_node_executor_sigterm {cls.pid}] ✅ Finished cleanup")
+        sys.exit(0)

     @classmethod
     @error_logged
-    def on_node_execution(cls, q: ExecutionQueue[NodeExecution], data: NodeExecution):
+    def on_node_execution(
+        cls, q: ExecutionQueue[NodeExecution], node_exec: NodeExecution
+    ):
         log_metadata = get_log_metadata(
-            graph_eid=data.graph_exec_id,
-            graph_id=data.graph_id,
-            node_eid=data.node_exec_id,
-            node_id=data.node_id,
+            graph_eid=node_exec.graph_exec_id,
+            graph_id=node_exec.graph_id,
+            node_eid=node_exec.node_exec_id,
+            node_id=node_exec.node_id,
             block_name="-",
         )
         execution_stats = {}
-        timing_info, _ = cls._on_node_execution(q, data, log_metadata, execution_stats)
+        timing_info, _ = cls._on_node_execution(
+            q, node_exec, log_metadata, execution_stats
+        )
         execution_stats["walltime"] = timing_info.wall_time
         execution_stats["cputime"] = timing_info.cpu_time
         cls.loop.run_until_complete(
-            update_node_execution_stats(data.node_exec_id, execution_stats)
+            update_node_execution_stats(node_exec.node_exec_id, execution_stats)
         )
@classmethod
@@ -432,32 +459,26 @@ class Executor:
     def _on_node_execution(
         cls,
         q: ExecutionQueue[NodeExecution],
-        d: NodeExecution,
+        node_exec: NodeExecution,
         log_metadata: dict,
         stats: dict[str, Any] | None = None,
     ):
         try:
-            cls.logger.info(
-                "Start node execution",
-                extra={
-                    "json_fields": {
-                        **log_metadata,
-                    }
-                },
+            logger.info(
+                f"Start node execution {node_exec.node_exec_id}",
+                extra={"json_fields": {**log_metadata}},
             )
-            for execution in execute_node(cls.loop, cls.agent_server_client, d, stats):
+            for execution in execute_node(
+                cls.loop, cls.agent_server_client, node_exec, stats
+            ):
                 q.add(execution)
-            cls.logger.info(
-                "Finished node execution",
-                extra={
-                    "json_fields": {
-                        **log_metadata,
-                    }
-                },
+            logger.info(
+                f"Finished node execution {node_exec.node_exec_id}",
+                extra={"json_fields": {**log_metadata}},
             )
         except Exception as e:
-            cls.logger.exception(
-                f"Failed node execution: {e}",
+            logger.exception(
+                f"Failed node execution {node_exec.node_exec_id}: {e}",
                 extra={
                     **log_metadata,
                 },
@@ -466,12 +487,26 @@ class Executor:
     @classmethod
     def on_graph_executor_start(cls):
         configure_logging()
-        cls.logger = logging.getLogger("graph_executor")
-        cls.loop = asyncio.new_event_loop()
-        cls.loop.run_until_complete(db.connect())
         cls.pool_size = Config().num_node_workers
+        cls.loop = asyncio.new_event_loop()
+        cls.pid = os.getpid()
+        cls.loop.run_until_complete(db.connect())
         cls._init_node_executor_pool()
-        logger.info(f"Graph executor started with max-{cls.pool_size} node workers.")
+        logger.info(
+            f"Graph executor {cls.pid} started with {cls.pool_size} node workers"
+        )
+        # Set up shutdown handler
+        atexit.register(cls.on_graph_executor_stop)
+
+    @classmethod
+    def on_graph_executor_stop(cls):
+        logger.info(
+            f"[on_graph_executor_stop {cls.pid}] ⏳ Terminating node executor pool..."
+        )
+        cls.executor.terminate()

     @classmethod
     def _init_node_executor_pool(cls):
@@ -482,19 +517,21 @@ class Executor:
     @classmethod
     @error_logged
-    def on_graph_execution(cls, data: GraphExecution, cancel: threading.Event):
+    def on_graph_execution(cls, graph_exec: GraphExecution, cancel: threading.Event):
         log_metadata = get_log_metadata(
-            graph_eid=data.graph_exec_id,
-            graph_id=data.graph_id,
+            graph_eid=graph_exec.graph_exec_id,
+            graph_id=graph_exec.graph_id,
             node_id="*",
             node_eid="*",
             block_name="-",
         )
-        timing_info, node_count = cls._on_graph_execution(data, cancel, log_metadata)
+        timing_info, node_count = cls._on_graph_execution(
+            graph_exec, cancel, log_metadata
+        )
         cls.loop.run_until_complete(
             update_graph_execution_stats(
-                data.graph_exec_id,
+                graph_exec.graph_exec_id,
                 {
                     "walltime": timing_info.wall_time,
                     "cputime": timing_info.cpu_time,
@@ -506,15 +543,11 @@ class Executor:
     @classmethod
     @time_measured
     def _on_graph_execution(
-        cls, graph_data: GraphExecution, cancel: threading.Event, log_metadata: dict
+        cls, graph_exec: GraphExecution, cancel: threading.Event, log_metadata: dict
     ) -> int:
-        cls.logger.info(
-            "Start graph execution",
-            extra={
-                "json_fields": {
-                    **log_metadata,
-                }
-            },
+        logger.info(
+            f"Start graph execution {graph_exec.graph_exec_id}",
+            extra={"json_fields": {**log_metadata}},
         )
         n_node_executions = 0
         finished = False
@@ -526,7 +559,7 @@ class Executor:
                 return
             cls.executor.terminate()
             logger.info(
-                f"Terminated graph execution {graph_data.graph_exec_id}",
+                f"Terminated graph execution {graph_exec.graph_exec_id}",
                 extra={"json_fields": {**log_metadata}},
             )
             cls._init_node_executor_pool()
@@ -536,7 +569,7 @@ class Executor:
         try:
             queue = ExecutionQueue[NodeExecution]()
-            for node_exec in graph_data.start_node_execs:
+            for node_exec in graph_exec.start_node_execs:
                 queue.add(node_exec)

             running_executions: dict[str, AsyncResult] = {}
@@ -566,7 +599,11 @@ class Executor:
                     # Re-enqueueing the data back to the queue will disrupt the order.
                     execution.wait()

-                logger.debug(f"Dispatching execution of node {exec_data.node_id}")
+                logger.debug(
+                    f"Dispatching node execution {exec_data.node_exec_id} "
+                    f"for node {exec_data.node_id}",
+                    extra={**log_metadata},
+                )
                 running_executions[exec_data.node_id] = cls.executor.apply_async(
                     cls.on_node_execution,
                     (queue, exec_data),
@@ -577,41 +614,30 @@ class Executor:
                 while queue.empty() and running_executions:
                     logger.debug(
                         "Queue empty; running nodes: "
-                        f"{list(running_executions.keys())}"
+                        f"{list(running_executions.keys())}",
+                        extra={"json_fields": {**log_metadata}},
                     )
                     for node_id, execution in list(running_executions.items()):
                         if cancel.is_set():
                             return n_node_executions
                         if not queue.empty():
-                            logger.debug(
-                                "Queue no longer empty! Returning to dispatching loop."
-                            )
                             break  # yield to parent loop to execute new queue items
-                        logger.debug(f"Waiting on execution of node {node_id}")
-                        execution.wait(3)
                         logger.debug(
-                            f"State of execution of node {node_id} after waiting: "
-                            f"{'DONE' if execution.ready() else 'RUNNING'}"
+                            f"Waiting on execution of node {node_id}",
+                            extra={"json_fields": {**log_metadata}},
                         )
+                        execution.wait(3)

-            cls.logger.info(
-                "Finished graph execution",
-                extra={
-                    "json_fields": {
-                        **log_metadata,
-                    }
-                },
+            logger.info(
+                f"Finished graph execution {graph_exec.graph_exec_id}",
+                extra={"json_fields": {**log_metadata}},
             )
         except Exception as e:
             logger.exception(
-                f"Failed graph execution: {e}",
-                extra={
-                    "json_fields": {
-                        **log_metadata,
-                    }
-                },
+                f"Failed graph execution {graph_exec.graph_exec_id}: {e}",
+                extra={"json_fields": {**log_metadata}},
             )
         finally:
             if not cancel.is_set():
@@ -628,29 +654,33 @@ class ExecutionManager(AppService):
         self.queue = ExecutionQueue[GraphExecution]()
         self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}

-    # def __del__(self):
-    #     self.sync_manager.shutdown()
-
     def run_service(self):
-        with ProcessPoolExecutor(
+        self.executor = ProcessPoolExecutor(
             max_workers=self.pool_size,
             initializer=Executor.on_graph_executor_start,
-        ) as executor:
-            sync_manager = multiprocessing.Manager()
-            logger.info(
-                f"Execution manager started with max-{self.pool_size} graph workers."
-            )
+        )
+        sync_manager = multiprocessing.Manager()
+        logger.info(f"ExecutionManager started with max-{self.pool_size} graph workers")

-            while True:
-                graph_exec_data = self.queue.get()
-                graph_exec_id = graph_exec_data.graph_exec_id
-                cancel_event = sync_manager.Event()
-                future = executor.submit(
-                    Executor.on_graph_execution, graph_exec_data, cancel_event
-                )
-                self.active_graph_runs[graph_exec_id] = (future, cancel_event)
-                future.add_done_callback(
-                    lambda _: self.active_graph_runs.pop(graph_exec_id)
-                )
+        while True:
+            graph_exec_data = self.queue.get()
+            graph_exec_id = graph_exec_data.graph_exec_id
+            logger.debug(
+                f"[ExecutionManager] Dispatching graph execution {graph_exec_id}"
+            )
+            cancel_event = sync_manager.Event()
+            future = self.executor.submit(
+                Executor.on_graph_execution, graph_exec_data, cancel_event
+            )
+            self.active_graph_runs[graph_exec_id] = (future, cancel_event)
+            future.add_done_callback(
+                lambda _: self.active_graph_runs.pop(graph_exec_id)
+            )
+
+    def cleanup(self):
+        logger.info(f"[{__class__.__name__}] ⏳ Shutting down graph executor pool...")
+        self.executor.shutdown(cancel_futures=True)
+        super().cleanup()

     @property
     def agent_server_client(self) -> "AgentServer":
@@ -754,3 +784,12 @@ class ExecutionManager(AppService):
             )
         )
         self.agent_server_client.send_execution_update(exec_update.model_dump())
+
+
+def llprint(message: str):
+    """
+    Low-level print/log helper function for use in signal handlers.
+    Regular log/print statements are not allowed in signal handlers.
+    """
+    if logger.getEffectiveLevel() == logging.DEBUG:
+        os.write(sys.stdout.fileno(), (message + "\n").encode())
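A note on the `llprint` helper added above: `print` goes through Python's buffered I/O and can raise "reentrant call" errors if a signal interrupts a write already in progress, and the `logging` machinery takes locks with similar hazards. Writing straight to the file descriptor with `os.write` sidesteps both. A hedged usage sketch, with a hypothetical handler name:

import os
import signal
import sys

def handle_sigterm(signum, frame):
    # os.write() is a direct, unbuffered syscall, so it is safe here even
    # if the signal arrived mid-print(); buffered I/O and logging are not.
    os.write(sys.stdout.fileno(), b"SIGTERM received, cleaning up...\n")
    sys.exit(0)

signal.signal(signal.SIGTERM, handle_sigterm)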

View File

@@ -1,4 +1,6 @@
 import logging
+import os
+import signal
 import sys
 from abc import ABC, abstractmethod
 from multiprocessing import Process, set_start_method
@@ -7,6 +9,8 @@ from typing import Optional
 from autogpt_server.util.logging import configure_logging
 from autogpt_server.util.metrics import sentry_init

+logger = logging.getLogger(__name__)
+

 class AppProcess(ABC):
     """
@@ -19,6 +23,8 @@ class AppProcess(ABC):
     configure_logging()
     sentry_init()

+    # Methods that are executed INSIDE the process #
+
     @abstractmethod
     def run(self):
         """
@@ -26,6 +32,13 @@ class AppProcess(ABC):
"""
pass
def cleanup(self):
"""
Implement this method on a subclass to do post-execution cleanup,
e.g. disconnecting from a database or terminating child processes.
"""
pass
def health_check(self):
"""
A method to check the health of the process.
@@ -33,13 +46,21 @@ class AppProcess(ABC):
         pass

     def execute_run_command(self, silent):
+        signal.signal(signal.SIGTERM, self._self_terminate)
+
         try:
             if silent:
                 sys.stdout = open(os.devnull, "w")
                 sys.stderr = open(os.devnull, "w")
             self.run()
-        except KeyboardInterrupt or SystemExit as e:
-            print(f"Process terminated: {e}")
+        except (KeyboardInterrupt, SystemExit):
+            logger.info(f"[{self.__class__.__name__}] Terminated; quitting...")
+
+    def _self_terminate(self, signum: int, frame):
+        self.cleanup()
+        sys.exit(0)
+
+    # Methods that are executed OUTSIDE the process #

     def __enter__(self):
         self.start(background=True)
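With the hooks above, a subclass only has to override `cleanup()`; `execute_run_command` installs `_self_terminate` as the SIGTERM handler, which calls `cleanup()` and then exits. A hypothetical subclass sketch (the class name and the file-handle resource are illustrative only, not part of the commit):

from autogpt_server.util.process import AppProcess  # import path as used in this diff

class WorkerProcess(AppProcess):  # hypothetical example subclass
    def run(self):
        self.log_file = open("/tmp/worker.log", "w")  # stand-in resource
        while True:  # stand-in for the real service loop
            self.log_file.write("tick\n")

    def cleanup(self):
        # Runs inside _self_terminate() when the parent sends SIGTERM,
        # right before sys.exit(0).
        self.log_file.close()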

View File

@@ -102,6 +102,14 @@ class AppService(AppProcess):
         # Run the main service (if it's not implemented, just sleep).
         self.run_service()

+    def cleanup(self):
+        if self.use_db:
+            logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting DB...")
+            self.run_and_wait(db.disconnect())
+        if self.use_redis:
+            logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting Redis...")
+            self.run_and_wait(self.event_queue.close())
+
     @conn_retry
     def __start_pyro(self):
         daemon = pyro.Daemon(host=pyro_host)

View File

@@ -73,10 +73,10 @@ class SpinTestServer:
     async def __aexit__(self, exc_type, exc_val, exc_tb):
         await db.disconnect()

-        self.name_server.__exit__(exc_type, exc_val, exc_tb)
-        self.agent_server.__exit__(exc_type, exc_val, exc_tb)
-        self.exec_manager.__exit__(exc_type, exc_val, exc_tb)
         self.scheduler.__exit__(exc_type, exc_val, exc_tb)
+        self.exec_manager.__exit__(exc_type, exc_val, exc_tb)
+        self.agent_server.__exit__(exc_type, exc_val, exc_tb)
+        self.name_server.__exit__(exc_type, exc_val, exc_tb)

     def setup_dependency_overrides(self):
         # Override get_user_id for testing
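The reordered `__aexit__` above shuts services down in the reverse of their startup order, so dependents stop before the things they depend on. `contextlib.ExitStack` provides the same LIFO guarantee automatically; a small self-contained sketch (the `service` helper is a stand-in, not project code):

from contextlib import ExitStack, contextmanager

@contextmanager
def service(name: str):
    # Stand-in for a context-managed AppProcess.
    print(f"start {name}")
    try:
        yield name
    finally:
        print(f"stop {name}")

with ExitStack() as stack:
    for name in ("name_server", "agent_server", "exec_manager", "scheduler"):
        stack.enter_context(service(name))

# The stack unwinds in LIFO order -- scheduler, exec_manager, agent_server,
# name_server -- matching the corrected shutdown order above.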