Mirror of https://github.com/Significant-Gravitas/AutoGPT.git
feat(server): Clean up resources when spinning down services/processes (#7938)
- Add SIGTERM handler and `cleanup()` hook to `AppProcess`
- Implement `cleanup()` on `AppService` to close DB and Redis connections
- Implement `cleanup()` on `ExecutionManager` to shut down worker pool
- Add `atexit` and SIGTERM handlers to node executor to close DB connection and shut down node workers
- Improve logging in `.executor.manager`
- Fix shutdown order of `.util.test:SpinTestServer`
Committed via GitHub
parent b12dba13f4
commit 3bd8040d6a
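The pattern applied throughout the diff below is: give each long-lived process a cleanup() hook, register it with atexit for normal shutdown, install a SIGTERM handler that runs the same hook and then exits, and guard the hook with a non-blocking lock so the two paths never run it twice. The standalone sketch that follows only illustrates that pattern; the class and resource names in it are illustrative stand-ins, not the repository's actual APIs.

    import atexit
    import os
    import signal
    import sys
    import threading


    class Worker:
        """Illustrative stand-in for a worker process that owns external resources."""

        def start(self):
            self._shutdown_lock = threading.Lock()
            atexit.register(self.cleanup)                    # normal interpreter exit
            signal.signal(signal.SIGTERM, self._on_sigterm)  # external termination

        def cleanup(self):
            # The non-blocking acquire makes cleanup idempotent: whichever path
            # gets here first (atexit or SIGTERM) does the work, the other returns.
            if not self._shutdown_lock.acquire(blocking=False):
                return
            # Release resources here (close DB/Redis connections, stop pools, ...).
            os.write(sys.stdout.fileno(), b"resources released\n")

        def _on_sigterm(self, signum, frame):
            self.cleanup()
            sys.exit(0)  # raises SystemExit so the process stops cleanly


    if __name__ == "__main__":
        worker = Worker()
        worker.start()
        signal.raise_signal(signal.SIGTERM)  # simulate `kill <pid>` for the demo

Running the sketch prints "resources released" exactly once, even though both the SIGTERM path and the atexit path attempt cleanup.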
@@ -1,6 +1,10 @@
 import asyncio
+import atexit
 import logging
 import multiprocessing
+import os
+import signal
+import sys
 import threading
 from concurrent.futures import Future, ProcessPoolExecutor
 from contextlib import contextmanager

@@ -114,9 +118,7 @@ def execute_node(
     if input_data is None:
         logger.error(
             "Skip execution, input validation error",
-            extra={
-                "json_fields": {**log_metadata, "error": error},
-            },
+            extra={"json_fields": {**log_metadata, "error": error}},
         )
         return
 

@@ -254,22 +256,14 @@ def _enqueue_next_nodes(
         if not next_node_input:
             logger.warning(
                 f"Skipped queueing {suffix}",
-                extra={
-                    "json_fields": {
-                        **log_metadata,
-                    }
-                },
+                extra={"json_fields": {**log_metadata}},
             )
             return enqueued_executions
 
         # Input is complete, enqueue the execution.
         logger.info(
             f"Enqueued {suffix}",
-            extra={
-                "json_fields": {
-                    **log_metadata,
-                }
-            },
+            extra={"json_fields": {**log_metadata}},
         )
         enqueued_executions.append(
             add_enqueued_execution(next_node_exec_id, next_node_id, next_node_input)

@@ -402,29 +396,62 @@ class Executor:
     @classmethod
     def on_node_executor_start(cls):
         configure_logging()
-        cls.logger = logging.getLogger("node_executor")
+
         cls.loop = asyncio.new_event_loop()
+        cls.pid = os.getpid()
+
         cls.loop.run_until_complete(db.connect())
         cls.agent_server_client = get_agent_server_client()
 
+        # Set up shutdown handlers
+        cls.shutdown_lock = threading.Lock()
+        atexit.register(cls.on_node_executor_stop)  # handle regular shutdown
+        signal.signal(  # handle termination
+            signal.SIGTERM, lambda _, __: cls.on_node_executor_sigterm()
+        )
+
+    @classmethod
+    def on_node_executor_stop(cls):
+        if not cls.shutdown_lock.acquire(blocking=False):
+            return  # already shutting down
+
+        logger.info(f"[on_node_executor_stop {cls.pid}] ⏳ Disconnecting DB...")
+        cls.loop.run_until_complete(db.disconnect())
+        logger.info(f"[on_node_executor_stop {cls.pid}] ✅ Finished cleanup")
+
+    @classmethod
+    def on_node_executor_sigterm(cls):
+        llprint(f"[on_node_executor_sigterm {cls.pid}] ⚠️ SIGTERM received")
+        if not cls.shutdown_lock.acquire(blocking=False):
+            return  # already shutting down, no need to self-terminate
+
+        llprint(f"[on_node_executor_sigterm {cls.pid}] ⏳ Disconnecting DB...")
+        cls.loop.run_until_complete(db.disconnect())
+        llprint(f"[on_node_executor_sigterm {cls.pid}] ✅ Finished cleanup")
+        sys.exit(0)
+
     @classmethod
     @error_logged
-    def on_node_execution(cls, q: ExecutionQueue[NodeExecution], data: NodeExecution):
+    def on_node_execution(
+        cls, q: ExecutionQueue[NodeExecution], node_exec: NodeExecution
+    ):
         log_metadata = get_log_metadata(
-            graph_eid=data.graph_exec_id,
-            graph_id=data.graph_id,
-            node_eid=data.node_exec_id,
-            node_id=data.node_id,
+            graph_eid=node_exec.graph_exec_id,
+            graph_id=node_exec.graph_id,
+            node_eid=node_exec.node_exec_id,
+            node_id=node_exec.node_id,
             block_name="-",
         )
 
         execution_stats = {}
-        timing_info, _ = cls._on_node_execution(q, data, log_metadata, execution_stats)
+        timing_info, _ = cls._on_node_execution(
+            q, node_exec, log_metadata, execution_stats
+        )
         execution_stats["walltime"] = timing_info.wall_time
         execution_stats["cputime"] = timing_info.cpu_time
 
         cls.loop.run_until_complete(
-            update_node_execution_stats(data.node_exec_id, execution_stats)
+            update_node_execution_stats(node_exec.node_exec_id, execution_stats)
         )
 
     @classmethod

@@ -432,32 +459,26 @@ class Executor:
     def _on_node_execution(
         cls,
         q: ExecutionQueue[NodeExecution],
-        d: NodeExecution,
+        node_exec: NodeExecution,
         log_metadata: dict,
         stats: dict[str, Any] | None = None,
     ):
         try:
-            cls.logger.info(
-                "Start node execution",
-                extra={
-                    "json_fields": {
-                        **log_metadata,
-                    }
-                },
+            logger.info(
+                f"Start node execution {node_exec.node_exec_id}",
+                extra={"json_fields": {**log_metadata}},
             )
-            for execution in execute_node(cls.loop, cls.agent_server_client, d, stats):
+            for execution in execute_node(
+                cls.loop, cls.agent_server_client, node_exec, stats
+            ):
                 q.add(execution)
-            cls.logger.info(
-                "Finished node execution",
-                extra={
-                    "json_fields": {
-                        **log_metadata,
-                    }
-                },
+            logger.info(
+                f"Finished node execution {node_exec.node_exec_id}",
+                extra={"json_fields": {**log_metadata}},
             )
         except Exception as e:
-            cls.logger.exception(
-                f"Failed node execution: {e}",
+            logger.exception(
+                f"Failed node execution {node_exec.node_exec_id}: {e}",
                 extra={
                     **log_metadata,
                 },

@@ -466,12 +487,26 @@ class Executor:
     @classmethod
     def on_graph_executor_start(cls):
         configure_logging()
-        cls.logger = logging.getLogger("graph_executor")
-        cls.loop = asyncio.new_event_loop()
-        cls.loop.run_until_complete(db.connect())
+
         cls.pool_size = Config().num_node_workers
+        cls.loop = asyncio.new_event_loop()
+        cls.pid = os.getpid()
+
+        cls.loop.run_until_complete(db.connect())
         cls._init_node_executor_pool()
-        logger.info(f"Graph executor started with max-{cls.pool_size} node workers.")
+        logger.info(
+            f"Graph executor {cls.pid} started with {cls.pool_size} node workers"
+        )
 
+        # Set up shutdown handler
+        atexit.register(cls.on_graph_executor_stop)
+
+    @classmethod
+    def on_graph_executor_stop(cls):
+        logger.info(
+            f"[on_graph_executor_stop {cls.pid}] ⏳ Terminating node executor pool..."
+        )
+        cls.executor.terminate()
+
     @classmethod
     def _init_node_executor_pool(cls):

@@ -482,19 +517,21 @@ class Executor:
 
     @classmethod
     @error_logged
-    def on_graph_execution(cls, data: GraphExecution, cancel: threading.Event):
+    def on_graph_execution(cls, graph_exec: GraphExecution, cancel: threading.Event):
         log_metadata = get_log_metadata(
-            graph_eid=data.graph_exec_id,
-            graph_id=data.graph_id,
+            graph_eid=graph_exec.graph_exec_id,
+            graph_id=graph_exec.graph_id,
             node_id="*",
             node_eid="*",
             block_name="-",
         )
-        timing_info, node_count = cls._on_graph_execution(data, cancel, log_metadata)
+        timing_info, node_count = cls._on_graph_execution(
+            graph_exec, cancel, log_metadata
+        )
 
         cls.loop.run_until_complete(
             update_graph_execution_stats(
-                data.graph_exec_id,
+                graph_exec.graph_exec_id,
                 {
                     "walltime": timing_info.wall_time,
                     "cputime": timing_info.cpu_time,

@@ -506,15 +543,11 @@ class Executor:
     @classmethod
     @time_measured
     def _on_graph_execution(
-        cls, graph_data: GraphExecution, cancel: threading.Event, log_metadata: dict
+        cls, graph_exec: GraphExecution, cancel: threading.Event, log_metadata: dict
     ) -> int:
-        cls.logger.info(
-            "Start graph execution",
-            extra={
-                "json_fields": {
-                    **log_metadata,
-                }
-            },
+        logger.info(
+            f"Start graph execution {graph_exec.graph_exec_id}",
+            extra={"json_fields": {**log_metadata}},
         )
         n_node_executions = 0
         finished = False

@@ -526,7 +559,7 @@
                 return
             cls.executor.terminate()
             logger.info(
-                f"Terminated graph execution {graph_data.graph_exec_id}",
+                f"Terminated graph execution {graph_exec.graph_exec_id}",
                 extra={"json_fields": {**log_metadata}},
             )
             cls._init_node_executor_pool()

@@ -536,7 +569,7 @@
 
         try:
             queue = ExecutionQueue[NodeExecution]()
-            for node_exec in graph_data.start_node_execs:
+            for node_exec in graph_exec.start_node_execs:
                 queue.add(node_exec)
 
             running_executions: dict[str, AsyncResult] = {}

@@ -566,7 +599,11 @@
                     # Re-enqueueing the data back to the queue will disrupt the order.
                     execution.wait()
 
-                logger.debug(f"Dispatching execution of node {exec_data.node_id}")
+                logger.debug(
+                    f"Dispatching node execution {exec_data.node_exec_id} "
+                    f"for node {exec_data.node_id}",
+                    extra={**log_metadata},
+                )
                 running_executions[exec_data.node_id] = cls.executor.apply_async(
                     cls.on_node_execution,
                     (queue, exec_data),

@@ -577,41 +614,30 @@
                 while queue.empty() and running_executions:
                     logger.debug(
                         "Queue empty; running nodes: "
-                        f"{list(running_executions.keys())}"
+                        f"{list(running_executions.keys())}",
+                        extra={"json_fields": {**log_metadata}},
                     )
                     for node_id, execution in list(running_executions.items()):
                         if cancel.is_set():
                             return n_node_executions
 
                         if not queue.empty():
-                            logger.debug(
-                                "Queue no longer empty! Returning to dispatching loop."
-                            )
                             break  # yield to parent loop to execute new queue items
 
-                        logger.debug(f"Waiting on execution of node {node_id}")
-                        execution.wait(3)
                         logger.debug(
-                            f"State of execution of node {node_id} after waiting: "
-                            f"{'DONE' if execution.ready() else 'RUNNING'}"
+                            f"Waiting on execution of node {node_id}",
+                            extra={"json_fields": {**log_metadata}},
                         )
+                        execution.wait(3)
 
-            cls.logger.info(
-                "Finished graph execution",
-                extra={
-                    "json_fields": {
-                        **log_metadata,
-                    }
-                },
+            logger.info(
+                f"Finished graph execution {graph_exec.graph_exec_id}",
+                extra={"json_fields": {**log_metadata}},
             )
         except Exception as e:
             logger.exception(
-                f"Failed graph execution: {e}",
-                extra={
-                    "json_fields": {
-                        **log_metadata,
-                    }
-                },
+                f"Failed graph execution {graph_exec.graph_exec_id}: {e}",
+                extra={"json_fields": {**log_metadata}},
             )
         finally:
             if not cancel.is_set():

@@ -628,29 +654,33 @@ class ExecutionManager(AppService):
         self.queue = ExecutionQueue[GraphExecution]()
         self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
 
-    # def __del__(self):
-    #     self.sync_manager.shutdown()
-
     def run_service(self):
-        with ProcessPoolExecutor(
+        self.executor = ProcessPoolExecutor(
             max_workers=self.pool_size,
             initializer=Executor.on_graph_executor_start,
-        ) as executor:
-            sync_manager = multiprocessing.Manager()
-            logger.info(
-                f"Execution manager started with max-{self.pool_size} graph workers."
-            )
-            while True:
-                graph_exec_data = self.queue.get()
-                graph_exec_id = graph_exec_data.graph_exec_id
-                cancel_event = sync_manager.Event()
-                future = executor.submit(
-                    Executor.on_graph_execution, graph_exec_data, cancel_event
-                )
-                self.active_graph_runs[graph_exec_id] = (future, cancel_event)
-                future.add_done_callback(
-                    lambda _: self.active_graph_runs.pop(graph_exec_id)
-                )
+        )
+        sync_manager = multiprocessing.Manager()
+        logger.info(f"ExecutionManager started with max-{self.pool_size} graph workers")
+        while True:
+            graph_exec_data = self.queue.get()
+            graph_exec_id = graph_exec_data.graph_exec_id
+            logger.debug(
+                f"[ExecutionManager] Dispatching graph execution {graph_exec_id}"
+            )
+            cancel_event = sync_manager.Event()
+            future = self.executor.submit(
+                Executor.on_graph_execution, graph_exec_data, cancel_event
+            )
+            self.active_graph_runs[graph_exec_id] = (future, cancel_event)
+            future.add_done_callback(
+                lambda _: self.active_graph_runs.pop(graph_exec_id)
+            )
+
+    def cleanup(self):
+        logger.info(f"[{__class__.__name__}] ⏳ Shutting down graph executor pool...")
+        self.executor.shutdown(cancel_futures=True)
+
+        super().cleanup()
 
     @property
     def agent_server_client(self) -> "AgentServer":

@@ -754,3 +784,12 @@ class ExecutionManager(AppService):
                 )
             )
         self.agent_server_client.send_execution_update(exec_update.model_dump())
+
+
+def llprint(message: str):
+    """
+    Low-level print/log helper function for use in signal handlers.
+    Regular log/print statements are not allowed in signal handlers.
+    """
+    if logger.getEffectiveLevel() == logging.DEBUG:
+        os.write(sys.stdout.fileno(), (message + "\n").encode())
@@ -1,4 +1,6 @@
+import logging
 import os
+import signal
 import sys
 from abc import ABC, abstractmethod
 from multiprocessing import Process, set_start_method

@@ -7,6 +9,8 @@ from typing import Optional
 from autogpt_server.util.logging import configure_logging
 from autogpt_server.util.metrics import sentry_init
 
+logger = logging.getLogger(__name__)
+
 
 class AppProcess(ABC):
     """

@@ -19,6 +23,8 @@ class AppProcess(ABC):
         configure_logging()
         sentry_init()
 
+    # Methods that are executed INSIDE the process #
+
     @abstractmethod
     def run(self):
         """

@@ -26,6 +32,13 @@ class AppProcess(ABC):
         """
         pass
 
+    def cleanup(self):
+        """
+        Implement this method on a subclass to do post-execution cleanup,
+        e.g. disconnecting from a database or terminating child processes.
+        """
+        pass
+
     def health_check(self):
         """
         A method to check the health of the process.

@@ -33,13 +46,21 @@ class AppProcess(ABC):
         pass
 
     def execute_run_command(self, silent):
+        signal.signal(signal.SIGTERM, self._self_terminate)
+
         try:
             if silent:
                 sys.stdout = open(os.devnull, "w")
                 sys.stderr = open(os.devnull, "w")
             self.run()
-        except KeyboardInterrupt or SystemExit as e:
-            print(f"Process terminated: {e}")
+        except (KeyboardInterrupt, SystemExit):
+            logger.info(f"[{self.__class__.__name__}] Terminated; quitting...")
 
+    def _self_terminate(self, signum: int, frame):
+        self.cleanup()
+        sys.exit(0)
+
+    # Methods that are executed OUTSIDE the process #
+
     def __enter__(self):
         self.start(background=True)
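A side note on the exception-handling fix in the last hunk above: `except KeyboardInterrupt or SystemExit:` does not catch both exception types. The expression `KeyboardInterrupt or SystemExit` evaluates to just `KeyboardInterrupt`, so a SystemExit raised by sys.exit() (for example from the new _self_terminate handler) would previously have escaped the handler. A quick standalone check of that Python semantics (not AutoGPT code):

    # Demonstrates why the old `except KeyboardInterrupt or SystemExit` form was
    # incorrect and why the tuple form in the new code is right.
    assert (KeyboardInterrupt or SystemExit) is KeyboardInterrupt  # `or` returns the first truthy operand


    def run_with_old_handler():
        try:
            raise SystemExit(0)
        except KeyboardInterrupt or SystemExit:  # effectively `except KeyboardInterrupt`
            return "caught"


    def run_with_new_handler():
        try:
            raise SystemExit(0)
        except (KeyboardInterrupt, SystemExit):  # catches either type
            return "caught"


    try:
        run_with_old_handler()
    except SystemExit:
        print("old form let SystemExit escape")  # this branch runs

    print(run_with_new_handler())  # prints "caught"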
@@ -102,6 +102,14 @@ class AppService(AppProcess):
         # Run the main service (if it's not implemented, just sleep).
         self.run_service()
 
+    def cleanup(self):
+        if self.use_db:
+            logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting DB...")
+            self.run_and_wait(db.disconnect())
+        if self.use_redis:
+            logger.info(f"[{self.__class__.__name__}] ⏳ Disconnecting Redis...")
+            self.run_and_wait(self.event_queue.close())
+
     @conn_retry
     def __start_pyro(self):
         daemon = pyro.Daemon(host=pyro_host)
@@ -73,10 +73,10 @@ class SpinTestServer:
     async def __aexit__(self, exc_type, exc_val, exc_tb):
         await db.disconnect()
 
-        self.name_server.__exit__(exc_type, exc_val, exc_tb)
-        self.agent_server.__exit__(exc_type, exc_val, exc_tb)
-        self.exec_manager.__exit__(exc_type, exc_val, exc_tb)
         self.scheduler.__exit__(exc_type, exc_val, exc_tb)
+        self.exec_manager.__exit__(exc_type, exc_val, exc_tb)
+        self.agent_server.__exit__(exc_type, exc_val, exc_tb)
+        self.name_server.__exit__(exc_type, exc_val, exc_tb)
 
     def setup_dependency_overrides(self):
         # Override get_user_id for testing
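The reordering in this last hunk makes test-server teardown run in the reverse of startup order, so the execution manager goes down before the agent server and name server it talks to. As a generic illustration (not AutoGPT code), contextlib.ExitStack gives the same guarantee automatically, since it unwinds registered context managers last-in, first-out:

    from contextlib import ExitStack, contextmanager


    @contextmanager
    def service(name):
        print(f"start {name}")
        try:
            yield name
        finally:
            print(f"stop {name}")


    # ExitStack tears services down in reverse order of registration (LIFO),
    # which is the property the reordered __aexit__ above restores by hand.
    with ExitStack() as stack:
        for name in ["name_server", "agent_server", "exec_manager", "scheduler"]:
            stack.enter_context(service(name))
    # Shutdown order printed: scheduler, exec_manager, agent_server, name_server.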