fix(backend): Implement graceful shutdown in AppService to prevent RPC errors (#11240)

We're currently seeing errors in the `DatabaseManager` while it's
shutting down, like:

```
WARNING [DatabaseManager] Termination request: SystemExit; 0 executing cleanup.
INFO [DatabaseManager]  Disconnecting Database...
INFO [PID-1|THREAD-29|DatabaseManager|Prisma-82fb1994-4b87-40c1-8869-fbd97bd33fc8] Releasing connection started...
INFO [PID-1|THREAD-29|DatabaseManager|Prisma-82fb1994-4b87-40c1-8869-fbd97bd33fc8] Releasing connection completed successfully.
INFO [DatabaseManager] Terminated.
ERROR POST /create_or_add_to_user_notification_batch failed: Failed to create or add to notification batch for user {user_id} and type AGENT_RUN: NoneType: None
```

This indicates two issues:
- The service doesn't wait for pending RPC calls to finish before
terminating
- We're using `logger.exception` outside an error handling context,
causing the confusing and unhelpful `NoneType: None` to be printed
instead of the actual error info (see the sketch below)
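
For context, here's a minimal sketch of the logging issue (illustrative names,
not the actual service code): `logger.exception()` relies on `sys.exc_info()`,
so when it's called outside an `except` block there is no active exception and
the traceback renders as `NoneType: None`. Passing the exception object via
`exc_info` logs the real traceback either way.

```
import logging

logger = logging.getLogger("DatabaseManager")


def handle_failure_broken(exc: Exception):
    # Outside an `except` block there is no "current" exception, so this
    # appends the bogus traceback line "NoneType: None" to the log record.
    logger.exception(f"Request failed: {exc}")


def handle_failure_fixed(exc: Exception):
    # Attaching the exception object explicitly logs its real traceback,
    # regardless of where this is called from.
    logger.error(f"Request failed: {exc}", exc_info=exc)
```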

### Changes 🏗️

- Implement graceful shutdown in `AppService` so in-flight RPC calls can
finish
  - Add tests for graceful shutdown
  - Prevent `AppService` from accepting new requests during shutdown
- Rework `AppService` lifecycle management; add support for an async
`lifespan` hook (sketched below)
- Fix `AppService` endpoint error logging
- Improve logging in `AppProcess` and `AppService`
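
To illustrate the new async `lifespan` hook, here's a sketch based on the
`DatabaseManager` changes in this commit (`db` stands in for any async
resource; simplified, not the verbatim code):

```
from contextlib import asynccontextmanager

from fastapi import FastAPI


class DatabaseManager(AppService):
    @asynccontextmanager
    async def lifespan(self, app: FastAPI):
        async with super().lifespan(app):
            # Startup: runs before Uvicorn starts accepting connections
            await db.connect()
            yield  # the service handles requests while suspended here
            # Shutdown: runs after in-flight requests have drained
            await db.disconnect()
```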

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - Deploy to Dev cluster, then `kubectl rollout restart` the different services a few times
    - [x] -> `DatabaseManager` doesn't break on re-deployment
    - [x] -> `Scheduler` doesn't break on re-deployment
    - [x] -> `NotificationManager` doesn't break on re-deployment
Author: Reinier van der Leer
Committed by GitHub on 2025-10-25 16:47:19 +02:00
Parent: acb946801b · Commit: e06e7ff33f
11 changed files with 333 additions and 60 deletions

View File

@@ -1,6 +1,7 @@
import logging
import signal
import threading
import warnings
from contextlib import contextmanager
from enum import Enum
@@ -26,6 +27,13 @@ from backend.sdk import (
SchemaField,
)
# Suppress false positive cleanup warning of litellm (a dependency of stagehand)
warnings.filterwarnings(
"ignore",
message="coroutine 'close_litellm_async_clients' was never awaited",
category=RuntimeWarning,
)
# Store the original method
original_register_signal_handlers = stagehand.main.Stagehand._register_signal_handlers

View File

@@ -45,9 +45,6 @@ class MainApp(AppProcess):
app.main(silent=True)
def cleanup(self):
pass
@click.group()
def main():

View File

@@ -1,5 +1,6 @@
import logging
from typing import Callable, Concatenate, ParamSpec, TypeVar, cast
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cast
from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
@@ -57,6 +58,9 @@ from backend.util.service import (
)
from backend.util.settings import Config
if TYPE_CHECKING:
from fastapi import FastAPI
config = Config()
logger = logging.getLogger(__name__)
P = ParamSpec("P")
@@ -76,15 +80,17 @@ async def _get_credits(user_id: str) -> int:
class DatabaseManager(AppService):
def run_service(self) -> None:
logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
self.run_and_wait(db.connect())
super().run_service()
@asynccontextmanager
async def lifespan(self, app: "FastAPI"):
async with super().lifespan(app):
logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
await db.connect()
def cleanup(self):
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
self.run_and_wait(db.disconnect())
logger.info(f"[{self.service_name}] ✅ Ready")
yield
logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
await db.disconnect()
async def health_check(self) -> str:
if not db.is_connected():

View File

@@ -1714,6 +1714,8 @@ class ExecutionManager(AppProcess):
logger.info(f"{prefix} ✅ Finished GraphExec cleanup")
super().cleanup()
# ------- UTILITIES ------- #

View File

@@ -248,7 +248,7 @@ class Scheduler(AppService):
raise UnhealthyServiceError("Scheduler is still initializing")
# Check if we're in the middle of cleanup
if self.cleaned_up:
if self._shutting_down:
return await super().health_check()
# Normal operation - check if scheduler is running
@@ -375,7 +375,6 @@ class Scheduler(AppService):
super().run_service()
def cleanup(self):
super().cleanup()
if self.scheduler:
logger.info("⏳ Shutting down scheduler...")
self.scheduler.shutdown(wait=True)
@@ -390,7 +389,7 @@ class Scheduler(AppService):
logger.info("⏳ Waiting for event loop thread to finish...")
_event_loop_thread.join(timeout=SCHEDULER_OPERATION_TIMEOUT_SECONDS)
logger.info("Scheduler cleanup complete.")
super().cleanup()
@expose
def add_graph_execution_schedule(

View File

@@ -1017,10 +1017,14 @@ class NotificationManager(AppService):
logger.exception(f"Fatal error in consumer for {queue_name}: {e}")
raise
@continuous_retry()
def run_service(self):
self.run_and_wait(self._run_service())
# Queue the main _run_service task
asyncio.run_coroutine_threadsafe(self._run_service(), self.shared_event_loop)
# Start the main event loop
super().run_service()
@continuous_retry()
async def _run_service(self):
logger.info(f"[{self.service_name}] ⏳ Configuring RabbitMQ...")
self.rabbitmq_service = rabbitmq.AsyncRabbitMQ(self.rabbitmq_config)
@@ -1086,10 +1090,11 @@ class NotificationManager(AppService):
def cleanup(self):
"""Cleanup service resources"""
self.running = False
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Disconnecting RabbitMQ...")
logger.info("⏳ Disconnecting RabbitMQ...")
self.run_and_wait(self.rabbitmq_service.disconnect())
super().cleanup()
class NotificationManagerClient(AppServiceClient):
@classmethod

View File

@@ -321,10 +321,6 @@ class AgentServer(backend.util.service.AppProcess):
uvicorn.run(**uvicorn_config)
def cleanup(self):
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Shutting down Agent Server...")
@staticmethod
async def test_execute_graph(
graph_id: str,

View File

@@ -329,7 +329,3 @@ class WebsocketServer(AppProcess):
port=Config().websocket_server_port,
log_config=None,
)
def cleanup(self):
super().cleanup()
logger.info(f"[{self.service_name}] ⏳ Shutting down WebSocket Server...")

View File

@@ -19,7 +19,8 @@ class AppProcess(ABC):
"""
process: Optional[Process] = None
cleaned_up = False
_shutting_down: bool = False
_cleaned_up: bool = False
if "forkserver" in get_all_start_methods():
set_start_method("forkserver", force=True)
@@ -43,7 +44,6 @@ class AppProcess(ABC):
def service_name(self) -> str:
return self.__class__.__name__
@abstractmethod
def cleanup(self):
"""
Implement this method on a subclass to do post-execution cleanup,
@@ -65,7 +65,8 @@ class AppProcess(ABC):
self.run()
except BaseException as e:
logger.warning(
f"[{self.service_name}] Termination request: {type(e).__name__}; {e} executing cleanup."
f"[{self.service_name}] 🛑 Terminating because of {type(e).__name__}: {e}", # noqa
exc_info=e if not isinstance(e, SystemExit) else None,
)
# Send error to Sentry before cleanup
if not isinstance(e, (KeyboardInterrupt, SystemExit)):
@@ -76,8 +77,12 @@ class AppProcess(ABC):
except Exception:
pass # Silently ignore if Sentry isn't available
finally:
self.cleanup()
logger.info(f"[{self.service_name}] Terminated.")
if not self._cleaned_up:
self._cleaned_up = True
logger.info(f"[{self.service_name}] 🧹 Running cleanup")
self.cleanup()
logger.info(f"[{self.service_name}] ✅ Cleanup done")
logger.info(f"[{self.service_name}] 🛑 Terminated")
@staticmethod
def llprint(message: str):
@@ -88,8 +93,8 @@ class AppProcess(ABC):
os.write(sys.stdout.fileno(), (message + "\n").encode())
def _self_terminate(self, signum: int, frame):
if not self.cleaned_up:
self.cleaned_up = True
if not self._shutting_down:
self._shutting_down = True
sys.exit(0)
else:
self.llprint(

View File

@@ -4,9 +4,12 @@ import concurrent.futures
import inspect
import logging
import os
import signal
import sys
import threading
import time
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from functools import update_wrapper
from typing import (
Any,
@@ -111,14 +114,44 @@ class BaseAppService(AppProcess, ABC):
return target_host
def run_service(self) -> None:
while True:
time.sleep(10)
# HACK: run the main event loop outside the main thread to disable Uvicorn's
# internal signal handlers, since there is no config option for this :(
shared_asyncio_thread = threading.Thread(
target=self._run_shared_event_loop,
daemon=True,
name=f"{self.service_name}-shared-event-loop",
)
shared_asyncio_thread.start()
shared_asyncio_thread.join()
def _run_shared_event_loop(self) -> None:
try:
self.shared_event_loop.run_forever()
finally:
logger.info(f"[{self.service_name}] 🛑 Shared event loop stopped")
self.shared_event_loop.close() # ensure held resources are released
def run_and_wait(self, coro: Coroutine[Any, Any, T]) -> T:
return asyncio.run_coroutine_threadsafe(coro, self.shared_event_loop).result()
def run(self):
self.shared_event_loop = asyncio.get_event_loop()
self.shared_event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.shared_event_loop)
def cleanup(self):
"""
**💡 Overriding `AppService.lifespan` may be a more convenient option.**
Implement this method on a subclass to do post-execution cleanup,
e.g. disconnecting from a database or terminating child processes.
**Note:** if you override this method in a subclass, it must call
`super().cleanup()` *at the end*!
"""
# Stop the shared event loop to allow resource clean-up
self.shared_event_loop.call_soon_threadsafe(self.shared_event_loop.stop)
super().cleanup()
class RemoteCallError(BaseModel):
@@ -179,6 +212,7 @@ EXCEPTION_MAPPING = {
class AppService(BaseAppService, ABC):
fastapi_app: FastAPI
http_server: uvicorn.Server | None = None
log_level: str = "info"
def set_log_level(self, log_level: str):
@@ -190,11 +224,10 @@ class AppService(BaseAppService, ABC):
def _handle_internal_http_error(status_code: int = 500, log_error: bool = True):
def handler(request: Request, exc: Exception):
if log_error:
if status_code == 500:
log = logger.exception
else:
log = logger.error
log(f"{request.method} {request.url.path} failed: {exc}")
logger.error(
f"{request.method} {request.url.path} failed: {exc}",
exc_info=exc if status_code == 500 else None,
)
return responses.JSONResponse(
status_code=status_code,
content=RemoteCallError(
@@ -256,13 +289,13 @@ class AppService(BaseAppService, ABC):
return sync_endpoint
@conn_retry("FastAPI server", "Starting FastAPI server")
@conn_retry("FastAPI server", "Running FastAPI server")
def __start_fastapi(self):
logger.info(
f"[{self.service_name}] Starting RPC server at http://{api_host}:{self.get_port()}"
)
server = uvicorn.Server(
self.http_server = uvicorn.Server(
uvicorn.Config(
self.fastapi_app,
host=api_host,
@@ -271,18 +304,76 @@ class AppService(BaseAppService, ABC):
log_level=self.log_level,
)
)
self.shared_event_loop.run_until_complete(server.serve())
self.run_and_wait(self.http_server.serve())
# Perform clean-up when the server exits
if not self._cleaned_up:
self._cleaned_up = True
logger.info(f"[{self.service_name}] 🧹 Running cleanup")
self.cleanup()
logger.info(f"[{self.service_name}] ✅ Cleanup done")
def _self_terminate(self, signum: int, frame):
"""Pass SIGTERM to Uvicorn so it can shut down gracefully"""
signame = signal.Signals(signum).name
if not self._shutting_down:
self._shutting_down = True
if self.http_server:
logger.info(
f"[{self.service_name}] 🛑 Received {signame} ({signum}) - "
"Entering RPC server graceful shutdown"
)
self.http_server.handle_exit(signum, frame) # stop accepting requests
# NOTE: Actually stopping the process is triggered by:
# 1. The call to self.cleanup() at the end of __start_fastapi() 👆🏼
# 2. BaseAppService.cleanup() stopping the shared event loop
else:
logger.warning(
f"[{self.service_name}] {signame} received before HTTP server init."
" Terminating..."
)
sys.exit(0)
else:
# Expedite shutdown on second SIGTERM
logger.info(
f"[{self.service_name}] 🛑🛑 Received {signame} ({signum}), "
"but shutdown is already underway. Terminating..."
)
sys.exit(0)
@asynccontextmanager
async def lifespan(self, app: FastAPI):
"""
The FastAPI/Uvicorn server's lifespan manager, used for setup and shutdown.
You can extend and use this in a subclass like:
```
@asynccontextmanager
async def lifespan(self, app: FastAPI):
async with super().lifespan(app):
await db.connect()
yield
await db.disconnect()
```
"""
# Startup - this runs before Uvicorn starts accepting connections
yield
# Shutdown - this runs when FastAPI/Uvicorn shuts down
logger.info(f"[{self.service_name}] ✅ FastAPI has finished")
async def health_check(self) -> str:
"""
A method to check the health of the process.
"""
"""A method to check the health of the process."""
return "OK"
def run(self):
sentry_init()
super().run()
self.fastapi_app = FastAPI()
self.fastapi_app = FastAPI(lifespan=self.lifespan)
# Add Prometheus instrumentation to all services
try:
@@ -325,7 +416,11 @@ class AppService(BaseAppService, ABC):
)
# Start the FastAPI server in a separate thread.
api_thread = threading.Thread(target=self.__start_fastapi, daemon=True)
api_thread = threading.Thread(
target=self.__start_fastapi,
daemon=True,
name=f"{self.service_name}-http-server",
)
api_thread.start()
# Run the main service loop (blocking).

View File

@@ -1,3 +1,5 @@
import asyncio
import contextlib
import time
from functools import cached_property
from unittest.mock import Mock
@@ -18,20 +20,11 @@ from backend.util.service import (
TEST_SERVICE_PORT = 8765
def wait_for_service_ready(service_client_type, timeout_seconds=30):
"""Helper method to wait for a service to be ready using health check with retry."""
client = get_service_client(service_client_type, request_retry=True)
client.health_check() # This will retry until service is ready
class ServiceTest(AppService):
def __init__(self):
super().__init__()
self.fail_count = 0
def cleanup(self):
pass
@classmethod
def get_port(cls) -> int:
return TEST_SERVICE_PORT
@@ -41,10 +34,17 @@ class ServiceTest(AppService):
result = super().__enter__()
# Wait for the service to be ready
wait_for_service_ready(ServiceTestClient)
self.wait_until_ready()
return result
def wait_until_ready(self, timeout_seconds: int = 5):
"""Helper method to wait for a service to be ready using health check with retry."""
client = get_service_client(
ServiceTestClient, call_timeout=timeout_seconds, request_retry=True
)
client.health_check() # This will retry until service is ready\
@expose
def add(self, a: int, b: int) -> int:
return a + b
@@ -490,3 +490,167 @@ class TestHTTPErrorRetryBehavior:
)
assert exc_info.value.status_code == status_code
class TestGracefulShutdownService(AppService):
"""Test service with slow endpoints for testing graceful shutdown"""
@classmethod
def get_port(cls) -> int:
return 18999 # Use a specific test port
def __init__(self):
super().__init__()
self.request_log = []
self.cleanup_called = False
self.cleanup_completed = False
@expose
async def slow_endpoint(self, duration: int = 5) -> dict:
"""Endpoint that takes time to complete"""
start_time = time.time()
self.request_log.append(f"slow_endpoint started at {start_time}")
await asyncio.sleep(duration)
end_time = time.time()
result = {
"message": "completed",
"duration": end_time - start_time,
"start_time": start_time,
"end_time": end_time,
}
self.request_log.append(f"slow_endpoint completed at {end_time}")
return result
@expose
def fast_endpoint(self) -> dict:
"""Fast endpoint for testing rejection during shutdown"""
timestamp = time.time()
self.request_log.append(f"fast_endpoint called at {timestamp}")
return {"message": "fast", "timestamp": timestamp}
def cleanup(self):
"""Override cleanup to track when it's called"""
self.cleanup_called = True
self.request_log.append(f"cleanup started at {time.time()}")
# Call parent cleanup
super().cleanup()
self.cleanup_completed = True
self.request_log.append(f"cleanup completed at {time.time()}")
@pytest.fixture(scope="function")
async def test_service():
"""Run the test service in a separate process"""
service = TestGracefulShutdownService()
service.start(background=True)
base_url = f"http://localhost:{service.get_port()}"
await wait_until_service_ready(base_url)
yield service, base_url
service.stop()
async def wait_until_service_ready(base_url: str, timeout: float = 10):
start_time = time.time()
while time.time() - start_time <= timeout:
async with httpx.AsyncClient(timeout=5) as client:
with contextlib.suppress(httpx.ConnectError):
response = await client.get(f"{base_url}/health_check", timeout=5)
if response.status_code == 200 and response.json() == "OK":
return
await asyncio.sleep(0.5)
raise RuntimeError(f"Service at {base_url} not available after {timeout} seconds")
async def send_slow_request(base_url: str) -> dict:
"""Send a slow request and return the result"""
async with httpx.AsyncClient(timeout=30) as client:
response = await client.post(f"{base_url}/slow_endpoint", json={"duration": 5})
assert response.status_code == 200
return response.json()
@pytest.mark.asyncio
async def test_graceful_shutdown(test_service):
"""Test that AppService handles graceful shutdown correctly"""
service, test_service_url = test_service
# Start a slow request that should complete even after shutdown
slow_task = asyncio.create_task(send_slow_request(test_service_url))
# Give the slow request time to start
await asyncio.sleep(1)
# Send SIGTERM to the service process
shutdown_start_time = time.time()
service.process.terminate() # This sends SIGTERM
# Wait a moment for shutdown to start
await asyncio.sleep(0.5)
# Try to send a new request - should be rejected or connection refused
try:
async with httpx.AsyncClient(timeout=5) as client:
response = await client.post(f"{test_service_url}/fast_endpoint", json={})
# Should get 503 Service Unavailable during shutdown
assert response.status_code == 503
assert "shutting down" in response.json()["detail"].lower()
except httpx.ConnectError:
# Connection refused is also acceptable - server stopped accepting
pass
# The slow request should still complete successfully
slow_result = await slow_task
assert slow_result["message"] == "completed"
assert 4.9 < slow_result["duration"] < 5.5 # Should have taken ~5 seconds
# Wait for the service to fully shut down
service.process.join(timeout=15)
shutdown_end_time = time.time()
# Verify the service actually terminated
assert not service.process.is_alive()
# Verify shutdown took reasonable time (slow request - 1s + cleanup)
shutdown_duration = shutdown_end_time - shutdown_start_time
assert 4 <= shutdown_duration <= 6 # ~5s request - 1s + buffer
print(f"Shutdown took {shutdown_duration:.2f} seconds")
print(f"Slow request completed in: {slow_result['duration']:.2f} seconds")
@pytest.mark.asyncio
async def test_health_check_during_shutdown(test_service):
"""Test that health checks behave correctly during shutdown"""
service, test_service_url = test_service
# Health check should pass initially
async with httpx.AsyncClient(timeout=5) as client:
response = await client.get(f"{test_service_url}/health_check")
assert response.status_code == 200
# Send SIGTERM
service.process.terminate()
# Wait for shutdown to begin
await asyncio.sleep(1)
# Health check should now fail or connection should be refused
try:
async with httpx.AsyncClient(timeout=5) as client:
response = await client.get(f"{test_service_url}/health_check")
# Could either get 503, 500 (unhealthy), or connection error
assert response.status_code in [500, 503]
except (httpx.ConnectError, httpx.ConnectTimeout):
# Connection refused/timeout is also acceptable
pass