mirror of https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-09 15:17:59 -05:00
fix(backend): Implement graceful shutdown in AppService to prevent RPC errors (#11240)
We're currently seeing errors in the `DatabaseManager` while it's shutting down, like:

```
WARNING [DatabaseManager] Termination request: SystemExit; 0 executing cleanup.
INFO [DatabaseManager] ⏳ Disconnecting Database...
INFO [PID-1|THREAD-29|DatabaseManager|Prisma-82fb1994-4b87-40c1-8869-fbd97bd33fc8] Releasing connection started...
INFO [PID-1|THREAD-29|DatabaseManager|Prisma-82fb1994-4b87-40c1-8869-fbd97bd33fc8] Releasing connection completed successfully.
INFO [DatabaseManager] Terminated.
ERROR POST /create_or_add_to_user_notification_batch failed: Failed to create or add to notification batch for user {user_id} and type AGENT_RUN: NoneType: None
```

This points to two issues:
- The service doesn't wait for pending RPC calls to finish before terminating.
- We're using `logger.exception` outside an error-handling context, so the confusing and not very useful `NoneType: None` is printed instead of the actual error info.

### Changes 🏗️

- Implement graceful shutdown in `AppService` so in-flight RPC calls can finish (a minimal sketch of the pattern follows below)
- Add tests for graceful shutdown
- Prevent `AppService` from accepting new requests during shutdown
- Rework `AppService` lifecycle management; add support for an async `lifespan`
- Fix `AppService` endpoint error logging
- Improve logging in `AppProcess` and `AppService`

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - Deploy to Dev cluster, then `kubectl rollout restart` the different services a few times
    - [x] -> `DatabaseManager` doesn't break on re-deployment
    - [x] -> `Scheduler` doesn't break on re-deployment
    - [x] -> `NotificationManager` doesn't break on re-deployment
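For reference, here is a minimal, self-contained sketch of the shutdown pattern this PR adopts. It is not the actual `AppService` code: the `SlowService` class, the `/slow` route, and port `8005` are made up for illustration. On the first SIGTERM, the handler asks the Uvicorn server to stop accepting new connections and drain in-flight requests via `Server.handle_exit()`, and cleanup only runs once the server thread has exited; a second SIGTERM terminates immediately.

```python
# Hypothetical sketch of graceful shutdown for an RPC-style FastAPI service.
# SlowService, the /slow route, and port 8005 are illustrative only.
import asyncio
import signal
import sys
import threading

import uvicorn
from fastapi import FastAPI


class SlowService:
    def __init__(self) -> None:
        self._shutting_down = False
        self.app = FastAPI()
        self.app.add_api_route("/slow", self.slow, methods=["POST"])
        self.http_server = uvicorn.Server(
            uvicorn.Config(self.app, host="127.0.0.1", port=8005, log_level="info")
        )

    async def slow(self) -> dict:
        await asyncio.sleep(5)  # simulate a long-running in-flight RPC call
        return {"message": "completed"}

    def _self_terminate(self, signum: int, frame) -> None:
        if not self._shutting_down:
            # First SIGTERM: stop accepting new requests, let in-flight ones drain.
            self._shutting_down = True
            self.http_server.handle_exit(signum, frame)
        else:
            # Second SIGTERM: expedite shutdown.
            sys.exit(0)

    def run(self) -> None:
        signal.signal(signal.SIGTERM, self._self_terminate)
        # Run the server off the main thread so Uvicorn doesn't install its own
        # signal handlers (it only does so on the main thread).
        server_thread = threading.Thread(target=self.http_server.run, daemon=True)
        server_thread.start()
        server_thread.join()  # returns once graceful shutdown has finished
        print("cleanup can run here, after all pending requests completed")


if __name__ == "__main__":
    SlowService().run()
```

Running the server on a non-main thread is the key trick: Uvicorn only installs its own signal handlers on the main thread, so the service keeps ownership of SIGTERM. This mirrors the "HACK" noted in the diff below, where the shared event loop is run outside the main thread for the same reason.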
This commit is contained in:
committed by GitHub
parent acb946801b
commit e06e7ff33f
@@ -1,6 +1,7 @@
import logging
import signal
import threading
import warnings
from contextlib import contextmanager
from enum import Enum

@@ -26,6 +27,13 @@ from backend.sdk import (
    SchemaField,
)

# Suppress false positive cleanup warning of litellm (a dependency of stagehand)
warnings.filterwarnings(
    "ignore",
    message="coroutine 'close_litellm_async_clients' was never awaited",
    category=RuntimeWarning,
)

# Store the original method
original_register_signal_handlers = stagehand.main.Stagehand._register_signal_handlers


@@ -45,9 +45,6 @@ class MainApp(AppProcess):

        app.main(silent=True)

    def cleanup(self):
        pass


@click.group()
def main():

@@ -1,5 +1,6 @@
import logging
from typing import Callable, Concatenate, ParamSpec, TypeVar, cast
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cast

from backend.data import db
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
@@ -57,6 +58,9 @@ from backend.util.service import (
)
from backend.util.settings import Config

if TYPE_CHECKING:
    from fastapi import FastAPI

config = Config()
logger = logging.getLogger(__name__)
P = ParamSpec("P")
@@ -76,15 +80,17 @@ async def _get_credits(user_id: str) -> int:


class DatabaseManager(AppService):
    def run_service(self) -> None:
        logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
        self.run_and_wait(db.connect())
        super().run_service()
    @asynccontextmanager
    async def lifespan(self, app: "FastAPI"):
        async with super().lifespan(app):
            logger.info(f"[{self.service_name}] ⏳ Connecting to Database...")
            await db.connect()

    def cleanup(self):
        super().cleanup()
        logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
        self.run_and_wait(db.disconnect())
            logger.info(f"[{self.service_name}] ✅ Ready")
            yield

            logger.info(f"[{self.service_name}] ⏳ Disconnecting Database...")
            await db.disconnect()

    async def health_check(self) -> str:
        if not db.is_connected():

@@ -1714,6 +1714,8 @@ class ExecutionManager(AppProcess):

        logger.info(f"{prefix} ✅ Finished GraphExec cleanup")

        super().cleanup()


# ------- UTILITIES ------- #

@@ -248,7 +248,7 @@ class Scheduler(AppService):
            raise UnhealthyServiceError("Scheduler is still initializing")

        # Check if we're in the middle of cleanup
        if self.cleaned_up:
        if self._shutting_down:
            return await super().health_check()

        # Normal operation - check if scheduler is running
@@ -375,7 +375,6 @@ class Scheduler(AppService):
        super().run_service()

    def cleanup(self):
        super().cleanup()
        if self.scheduler:
            logger.info("⏳ Shutting down scheduler...")
            self.scheduler.shutdown(wait=True)
@@ -390,7 +389,7 @@ class Scheduler(AppService):
            logger.info("⏳ Waiting for event loop thread to finish...")
            _event_loop_thread.join(timeout=SCHEDULER_OPERATION_TIMEOUT_SECONDS)

        logger.info("Scheduler cleanup complete.")
        super().cleanup()

    @expose
    def add_graph_execution_schedule(

@@ -1017,10 +1017,14 @@ class NotificationManager(AppService):
            logger.exception(f"Fatal error in consumer for {queue_name}: {e}")
            raise

    @continuous_retry()
    def run_service(self):
        self.run_and_wait(self._run_service())
        # Queue the main _run_service task
        asyncio.run_coroutine_threadsafe(self._run_service(), self.shared_event_loop)

        # Start the main event loop
        super().run_service()

    @continuous_retry()
    async def _run_service(self):
        logger.info(f"[{self.service_name}] ⏳ Configuring RabbitMQ...")
        self.rabbitmq_service = rabbitmq.AsyncRabbitMQ(self.rabbitmq_config)
@@ -1086,10 +1090,11 @@ class NotificationManager(AppService):
    def cleanup(self):
        """Cleanup service resources"""
        self.running = False
        super().cleanup()
        logger.info(f"[{self.service_name}] ⏳ Disconnecting RabbitMQ...")
        logger.info("⏳ Disconnecting RabbitMQ...")
        self.run_and_wait(self.rabbitmq_service.disconnect())

        super().cleanup()


class NotificationManagerClient(AppServiceClient):
    @classmethod

@@ -321,10 +321,6 @@ class AgentServer(backend.util.service.AppProcess):

        uvicorn.run(**uvicorn_config)

    def cleanup(self):
        super().cleanup()
        logger.info(f"[{self.service_name}] ⏳ Shutting down Agent Server...")

    @staticmethod
    async def test_execute_graph(
        graph_id: str,

@@ -329,7 +329,3 @@ class WebsocketServer(AppProcess):
            port=Config().websocket_server_port,
            log_config=None,
        )

    def cleanup(self):
        super().cleanup()
        logger.info(f"[{self.service_name}] ⏳ Shutting down WebSocket Server...")

@@ -19,7 +19,8 @@ class AppProcess(ABC):
    """

    process: Optional[Process] = None
    cleaned_up = False
    _shutting_down: bool = False
    _cleaned_up: bool = False

    if "forkserver" in get_all_start_methods():
        set_start_method("forkserver", force=True)
@@ -43,7 +44,6 @@ class AppProcess(ABC):
    def service_name(self) -> str:
        return self.__class__.__name__

    @abstractmethod
    def cleanup(self):
        """
        Implement this method on a subclass to do post-execution cleanup,
@@ -65,7 +65,8 @@ class AppProcess(ABC):
            self.run()
        except BaseException as e:
            logger.warning(
                f"[{self.service_name}] Termination request: {type(e).__name__}; {e} executing cleanup."
                f"[{self.service_name}] 🛑 Terminating because of {type(e).__name__}: {e}",  # noqa
                exc_info=e if not isinstance(e, SystemExit) else None,
            )
            # Send error to Sentry before cleanup
            if not isinstance(e, (KeyboardInterrupt, SystemExit)):
@@ -76,8 +77,12 @@ class AppProcess(ABC):
                except Exception:
                    pass  # Silently ignore if Sentry isn't available
        finally:
            self.cleanup()
            logger.info(f"[{self.service_name}] Terminated.")
            if not self._cleaned_up:
                self._cleaned_up = True
                logger.info(f"[{self.service_name}] 🧹 Running cleanup")
                self.cleanup()
                logger.info(f"[{self.service_name}] ✅ Cleanup done")
            logger.info(f"[{self.service_name}] 🛑 Terminated")

    @staticmethod
    def llprint(message: str):
@@ -88,8 +93,8 @@ class AppProcess(ABC):
        os.write(sys.stdout.fileno(), (message + "\n").encode())

    def _self_terminate(self, signum: int, frame):
        if not self.cleaned_up:
            self.cleaned_up = True
        if not self._shutting_down:
            self._shutting_down = True
            sys.exit(0)
        else:
            self.llprint(

@@ -4,9 +4,12 @@ import concurrent.futures
import inspect
import logging
import os
import signal
import sys
import threading
import time
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from functools import update_wrapper
from typing import (
    Any,
@@ -111,14 +114,44 @@ class BaseAppService(AppProcess, ABC):
        return target_host

    def run_service(self) -> None:
        while True:
            time.sleep(10)
        # HACK: run the main event loop outside the main thread to disable Uvicorn's
        # internal signal handlers, since there is no config option for this :(
        shared_asyncio_thread = threading.Thread(
            target=self._run_shared_event_loop,
            daemon=True,
            name=f"{self.service_name}-shared-event-loop",
        )
        shared_asyncio_thread.start()
        shared_asyncio_thread.join()

    def _run_shared_event_loop(self) -> None:
        try:
            self.shared_event_loop.run_forever()
        finally:
            logger.info(f"[{self.service_name}] 🛑 Shared event loop stopped")
            self.shared_event_loop.close()  # ensure held resources are released

    def run_and_wait(self, coro: Coroutine[Any, Any, T]) -> T:
        return asyncio.run_coroutine_threadsafe(coro, self.shared_event_loop).result()

    def run(self):
        self.shared_event_loop = asyncio.get_event_loop()
        self.shared_event_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.shared_event_loop)

    def cleanup(self):
        """
        **💡 Overriding `AppService.lifespan` may be a more convenient option.**

        Implement this method on a subclass to do post-execution cleanup,
        e.g. disconnecting from a database or terminating child processes.

        **Note:** if you override this method in a subclass, it must call
        `super().cleanup()` *at the end*!
        """
        # Stop the shared event loop to allow resource clean-up
        self.shared_event_loop.call_soon_threadsafe(self.shared_event_loop.stop)

        super().cleanup()


class RemoteCallError(BaseModel):
@@ -179,6 +212,7 @@ EXCEPTION_MAPPING = {

class AppService(BaseAppService, ABC):
    fastapi_app: FastAPI
    http_server: uvicorn.Server | None = None
    log_level: str = "info"

    def set_log_level(self, log_level: str):
@@ -190,11 +224,10 @@ class AppService(BaseAppService, ABC):
    def _handle_internal_http_error(status_code: int = 500, log_error: bool = True):
        def handler(request: Request, exc: Exception):
            if log_error:
                if status_code == 500:
                    log = logger.exception
                else:
                    log = logger.error
                log(f"{request.method} {request.url.path} failed: {exc}")
                logger.error(
                    f"{request.method} {request.url.path} failed: {exc}",
                    exc_info=exc if status_code == 500 else None,
                )
            return responses.JSONResponse(
                status_code=status_code,
                content=RemoteCallError(
@@ -256,13 +289,13 @@ class AppService(BaseAppService, ABC):

        return sync_endpoint

    @conn_retry("FastAPI server", "Starting FastAPI server")
    @conn_retry("FastAPI server", "Running FastAPI server")
    def __start_fastapi(self):
        logger.info(
            f"[{self.service_name}] Starting RPC server at http://{api_host}:{self.get_port()}"
        )

        server = uvicorn.Server(
        self.http_server = uvicorn.Server(
            uvicorn.Config(
                self.fastapi_app,
                host=api_host,
@@ -271,18 +304,76 @@ class AppService(BaseAppService, ABC):
                log_level=self.log_level,
            )
        )
        self.shared_event_loop.run_until_complete(server.serve())
        self.run_and_wait(self.http_server.serve())

        # Perform clean-up when the server exits
        if not self._cleaned_up:
            self._cleaned_up = True
            logger.info(f"[{self.service_name}] 🧹 Running cleanup")
            self.cleanup()
            logger.info(f"[{self.service_name}] ✅ Cleanup done")

    def _self_terminate(self, signum: int, frame):
        """Pass SIGTERM to Uvicorn so it can shut down gracefully"""
        signame = signal.Signals(signum).name
        if not self._shutting_down:
            self._shutting_down = True
            if self.http_server:
                logger.info(
                    f"[{self.service_name}] 🛑 Received {signame} ({signum}) - "
                    "Entering RPC server graceful shutdown"
                )
                self.http_server.handle_exit(signum, frame)  # stop accepting requests

                # NOTE: Actually stopping the process is triggered by:
                # 1. The call to self.cleanup() at the end of __start_fastapi() 👆🏼
                # 2. BaseAppService.cleanup() stopping the shared event loop
            else:
                logger.warning(
                    f"[{self.service_name}] {signame} received before HTTP server init."
                    " Terminating..."
                )
                sys.exit(0)

        else:
            # Expedite shutdown on second SIGTERM
            logger.info(
                f"[{self.service_name}] 🛑🛑 Received {signame} ({signum}), "
                "but shutdown is already underway. Terminating..."
            )
            sys.exit(0)

    @asynccontextmanager
    async def lifespan(self, app: FastAPI):
        """
        The FastAPI/Uvicorn server's lifespan manager, used for setup and shutdown.

        You can extend and use this in a subclass like:
        ```
        @asynccontextmanager
        async def lifespan(self, app: FastAPI):
            async with super().lifespan(app):
                await db.connect()
                yield
                await db.disconnect()
        ```
        """
        # Startup - this runs before Uvicorn starts accepting connections

        yield

        # Shutdown - this runs when FastAPI/Uvicorn shuts down
        logger.info(f"[{self.service_name}] ✅ FastAPI has finished")

    async def health_check(self) -> str:
        """
        A method to check the health of the process.
        """
        """A method to check the health of the process."""
        return "OK"

    def run(self):
        sentry_init()
        super().run()
        self.fastapi_app = FastAPI()

        self.fastapi_app = FastAPI(lifespan=self.lifespan)

        # Add Prometheus instrumentation to all services
        try:
@@ -325,7 +416,11 @@ class AppService(BaseAppService, ABC):
        )

        # Start the FastAPI server in a separate thread.
        api_thread = threading.Thread(target=self.__start_fastapi, daemon=True)
        api_thread = threading.Thread(
            target=self.__start_fastapi,
            daemon=True,
            name=f"{self.service_name}-http-server",
        )
        api_thread.start()

        # Run the main service loop (blocking).

@@ -1,3 +1,5 @@
import asyncio
import contextlib
import time
from functools import cached_property
from unittest.mock import Mock
@@ -18,20 +20,11 @@ from backend.util.service import (
TEST_SERVICE_PORT = 8765


def wait_for_service_ready(service_client_type, timeout_seconds=30):
    """Helper method to wait for a service to be ready using health check with retry."""
    client = get_service_client(service_client_type, request_retry=True)
    client.health_check()  # This will retry until service is ready


class ServiceTest(AppService):
    def __init__(self):
        super().__init__()
        self.fail_count = 0

    def cleanup(self):
        pass

    @classmethod
    def get_port(cls) -> int:
        return TEST_SERVICE_PORT
@@ -41,10 +34,17 @@ class ServiceTest(AppService):
        result = super().__enter__()

        # Wait for the service to be ready
        wait_for_service_ready(ServiceTestClient)
        self.wait_until_ready()

        return result

    def wait_until_ready(self, timeout_seconds: int = 5):
        """Helper method to wait for a service to be ready using health check with retry."""
        client = get_service_client(
            ServiceTestClient, call_timeout=timeout_seconds, request_retry=True
        )
        client.health_check()  # This will retry until service is ready

    @expose
    def add(self, a: int, b: int) -> int:
        return a + b
@@ -490,3 +490,167 @@ class TestHTTPErrorRetryBehavior:
        )

        assert exc_info.value.status_code == status_code


class TestGracefulShutdownService(AppService):
    """Test service with slow endpoints for testing graceful shutdown"""

    @classmethod
    def get_port(cls) -> int:
        return 18999  # Use a specific test port

    def __init__(self):
        super().__init__()
        self.request_log = []
        self.cleanup_called = False
        self.cleanup_completed = False

    @expose
    async def slow_endpoint(self, duration: int = 5) -> dict:
        """Endpoint that takes time to complete"""
        start_time = time.time()
        self.request_log.append(f"slow_endpoint started at {start_time}")

        await asyncio.sleep(duration)

        end_time = time.time()
        result = {
            "message": "completed",
            "duration": end_time - start_time,
            "start_time": start_time,
            "end_time": end_time,
        }
        self.request_log.append(f"slow_endpoint completed at {end_time}")
        return result

    @expose
    def fast_endpoint(self) -> dict:
        """Fast endpoint for testing rejection during shutdown"""
        timestamp = time.time()
        self.request_log.append(f"fast_endpoint called at {timestamp}")
        return {"message": "fast", "timestamp": timestamp}

    def cleanup(self):
        """Override cleanup to track when it's called"""
        self.cleanup_called = True
        self.request_log.append(f"cleanup started at {time.time()}")

        # Call parent cleanup
        super().cleanup()

        self.cleanup_completed = True
        self.request_log.append(f"cleanup completed at {time.time()}")


@pytest.fixture(scope="function")
async def test_service():
    """Run the test service in a separate process"""

    service = TestGracefulShutdownService()
    service.start(background=True)

    base_url = f"http://localhost:{service.get_port()}"

    await wait_until_service_ready(base_url)
    yield service, base_url

    service.stop()


async def wait_until_service_ready(base_url: str, timeout: float = 10):
    start_time = time.time()
    while time.time() - start_time <= timeout:
        async with httpx.AsyncClient(timeout=5) as client:
            with contextlib.suppress(httpx.ConnectError):
                response = await client.get(f"{base_url}/health_check", timeout=5)

                if response.status_code == 200 and response.json() == "OK":
                    return

        await asyncio.sleep(0.5)

    raise RuntimeError(f"Service at {base_url} not available after {timeout} seconds")


async def send_slow_request(base_url: str) -> dict:
    """Send a slow request and return the result"""
    async with httpx.AsyncClient(timeout=30) as client:
        response = await client.post(f"{base_url}/slow_endpoint", json={"duration": 5})
        assert response.status_code == 200
        return response.json()


@pytest.mark.asyncio
async def test_graceful_shutdown(test_service):
    """Test that AppService handles graceful shutdown correctly"""
    service, test_service_url = test_service

    # Start a slow request that should complete even after shutdown
    slow_task = asyncio.create_task(send_slow_request(test_service_url))

    # Give the slow request time to start
    await asyncio.sleep(1)

    # Send SIGTERM to the service process
    shutdown_start_time = time.time()
    service.process.terminate()  # This sends SIGTERM

    # Wait a moment for shutdown to start
    await asyncio.sleep(0.5)

    # Try to send a new request - should be rejected or connection refused
    try:
        async with httpx.AsyncClient(timeout=5) as client:
            response = await client.post(f"{test_service_url}/fast_endpoint", json={})
            # Should get 503 Service Unavailable during shutdown
            assert response.status_code == 503
            assert "shutting down" in response.json()["detail"].lower()
    except httpx.ConnectError:
        # Connection refused is also acceptable - server stopped accepting
        pass

    # The slow request should still complete successfully
    slow_result = await slow_task
    assert slow_result["message"] == "completed"
    assert 4.9 < slow_result["duration"] < 5.5  # Should have taken ~5 seconds

    # Wait for the service to fully shut down
    service.process.join(timeout=15)
    shutdown_end_time = time.time()

    # Verify the service actually terminated
    assert not service.process.is_alive()

    # Verify shutdown took reasonable time (slow request - 1s + cleanup)
    shutdown_duration = shutdown_end_time - shutdown_start_time
    assert 4 <= shutdown_duration <= 6  # ~5s request - 1s + buffer

    print(f"Shutdown took {shutdown_duration:.2f} seconds")
    print(f"Slow request completed in: {slow_result['duration']:.2f} seconds")


@pytest.mark.asyncio
async def test_health_check_during_shutdown(test_service):
    """Test that health checks behave correctly during shutdown"""
    service, test_service_url = test_service

    # Health check should pass initially
    async with httpx.AsyncClient(timeout=5) as client:
        response = await client.get(f"{test_service_url}/health_check")
        assert response.status_code == 200

    # Send SIGTERM
    service.process.terminate()

    # Wait for shutdown to begin
    await asyncio.sleep(1)

    # Health check should now fail or connection should be refused
    try:
        async with httpx.AsyncClient(timeout=5) as client:
            response = await client.get(f"{test_service_url}/health_check")
            # Could either get 503, 500 (unhealthy), or connection error
            assert response.status_code in [500, 503]
    except (httpx.ConnectError, httpx.ConnectTimeout):
        # Connection refused/timeout is also acceptable
        pass