From cb8c1fa263852b40ff16eb5466b79cee28f6a8b0 Mon Sep 17 00:00:00 2001 From: Tim O'Farrell Date: Fri, 26 Dec 2025 09:19:51 -0700 Subject: [PATCH] ALL-4627 Database Fixes (#12156) Co-authored-by: openhands --- enterprise/storage/database.py | 12 +++- .../app_conversation_router.py | 15 +++-- .../event_callback/webhook_router.py | 4 -- .../services/db_session_injector.py | 61 +++++++++++-------- .../app_server/test_db_session_injector.py | 9 +-- 5 files changed, 58 insertions(+), 43 deletions(-) diff --git a/enterprise/storage/database.py b/enterprise/storage/database.py index f0d8e9d62c..ec06550e03 100644 --- a/enterprise/storage/database.py +++ b/enterprise/storage/database.py @@ -19,17 +19,23 @@ GCP_REGION = os.environ.get('GCP_REGION') POOL_SIZE = int(os.environ.get('DB_POOL_SIZE', '25')) MAX_OVERFLOW = int(os.environ.get('DB_MAX_OVERFLOW', '10')) +POOL_RECYCLE = int(os.environ.get('DB_POOL_RECYCLE', '1800')) + +# Initialize Cloud SQL Connector once at module level for GCP environments. +_connector = None def _get_db_engine(): if GCP_DB_INSTANCE: # GCP environments def get_db_connection(): + global _connector from google.cloud.sql.connector import Connector - connector = Connector() + if not _connector: + _connector = Connector() instance_string = f'{GCP_PROJECT}:{GCP_REGION}:{GCP_DB_INSTANCE}' - return connector.connect( + return _connector.connect( instance_string, 'pg8000', user=DB_USER, password=DB_PASS, db=DB_NAME ) @@ -38,6 +44,7 @@ def _get_db_engine(): creator=get_db_connection, pool_size=POOL_SIZE, max_overflow=MAX_OVERFLOW, + pool_recycle=POOL_RECYCLE, pool_pre_ping=True, ) else: @@ -48,6 +55,7 @@ def _get_db_engine(): host_string, pool_size=POOL_SIZE, max_overflow=MAX_OVERFLOW, + pool_recycle=POOL_RECYCLE, pool_pre_ping=True, ) diff --git a/openhands/app_server/app_conversation/app_conversation_router.py b/openhands/app_server/app_conversation/app_conversation_router.py index 532602dbca..f68b80ba4e 100644 --- a/openhands/app_server/app_conversation/app_conversation_router.py +++ b/openhands/app_server/app_conversation/app_conversation_router.py @@ -210,11 +210,16 @@ async def start_app_conversation( set_db_session_keep_open(request.state, True) set_httpx_client_keep_open(request.state, True) - """Start an app conversation start task and return it.""" - async_iter = app_conversation_service.start_app_conversation(start_request) - result = await anext(async_iter) - asyncio.create_task(_consume_remaining(async_iter, db_session, httpx_client)) - return result + try: + """Start an app conversation start task and return it.""" + async_iter = app_conversation_service.start_app_conversation(start_request) + result = await anext(async_iter) + asyncio.create_task(_consume_remaining(async_iter, db_session, httpx_client)) + return result + except Exception: + await db_session.close() + await httpx_client.aclose() + raise @router.post('/stream-start') diff --git a/openhands/app_server/event_callback/webhook_router.py b/openhands/app_server/event_callback/webhook_router.py index 37ae9d89b2..62dd7bec16 100644 --- a/openhands/app_server/event_callback/webhook_router.py +++ b/openhands/app_server/event_callback/webhook_router.py @@ -21,12 +21,10 @@ from openhands.app_server.app_conversation.app_conversation_models import ( ) from openhands.app_server.config import ( depends_app_conversation_info_service, - depends_db_session, depends_event_service, depends_jwt_service, depends_sandbox_service, get_event_callback_service, - get_global_config, ) from openhands.app_server.errors import AuthError from openhands.app_server.event.event_service import EventService @@ -54,8 +52,6 @@ sandbox_service_dependency = depends_sandbox_service() event_service_dependency = depends_event_service() app_conversation_info_service_dependency = depends_app_conversation_info_service() jwt_dependency = depends_jwt_service() -config = get_global_config() -db_session_dependency = depends_db_session() _logger = logging.getLogger(__name__) diff --git a/openhands/app_server/services/db_session_injector.py b/openhands/app_server/services/db_session_injector.py index c59243af91..737e1ff879 100644 --- a/openhands/app_server/services/db_session_injector.py +++ b/openhands/app_server/services/db_session_injector.py @@ -4,8 +4,9 @@ import asyncio import logging import os from pathlib import Path -from typing import AsyncGenerator +from typing import Any, AsyncGenerator +import asyncpg from fastapi import Request from pydantic import BaseModel, PrivateAttr, SecretStr, model_validator from sqlalchemy import Engine, create_engine @@ -33,6 +34,7 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]): echo: bool = False pool_size: int = 25 max_overflow: int = 10 + pool_recycle: int = 1800 gcp_db_instance: str | None = None gcp_project: str | None = None gcp_region: str | None = None @@ -42,6 +44,7 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]): _async_engine: AsyncEngine | None = PrivateAttr(default=None) _session_maker: sessionmaker | None = PrivateAttr(default=None) _async_session_maker: async_sessionmaker | None = PrivateAttr(default=None) + _gcp_connector: Any = PrivateAttr(default=None) @model_validator(mode='after') def fill_empty_fields(self): @@ -65,14 +68,18 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]): return self def _create_gcp_db_connection(self): - # Lazy import because lib does not import if user does not have posgres installed - from google.cloud.sql.connector import Connector + gcp_connector = self._gcp_connector + if gcp_connector is None: + # Lazy import because lib does not import if user does not have posgres installed + from google.cloud.sql.connector import Connector + + gcp_connector = Connector() + self._gcp_connector = gcp_connector - connector = Connector() instance_string = f'{self.gcp_project}:{self.gcp_region}:{self.gcp_db_instance}' password = self.password assert password is not None - return connector.connect( + return gcp_connector.connect( instance_string, 'pg8000', user=self.user, @@ -81,21 +88,25 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]): ) async def _create_async_gcp_db_connection(self): - # Lazy import because lib does not import if user does not have posgres installed - from google.cloud.sql.connector import Connector + gcp_connector = self._gcp_connector + if gcp_connector is None: + # Lazy import because lib does not import if user does not have posgres installed + from google.cloud.sql.connector import Connector - loop = asyncio.get_running_loop() - async with Connector(loop=loop) as connector: - password = self.password - assert password is not None - conn = await connector.connect_async( - f'{self.gcp_project}:{self.gcp_region}:{self.gcp_db_instance}', - 'asyncpg', - user=self.user, - password=password.get_secret_value(), - db=self.name, - ) - return conn + loop = asyncio.get_running_loop() + gcp_connector = Connector(loop=loop) + self._gcp_connector = gcp_connector + + password = self.password + assert password is not None + conn = await gcp_connector.connect_async( + f'{self.gcp_project}:{self.gcp_region}:{self.gcp_db_instance}', + 'asyncpg', + user=self.user, + password=password.get_secret_value(), + db=self.name, + ) + return conn def _create_gcp_engine(self): engine = create_engine( @@ -112,10 +123,8 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]): AsyncAdapt_asyncpg_connection, ) - engine = self._create_gcp_engine() - return AsyncAdapt_asyncpg_connection( - engine.dialect.dbapi, + asyncpg, await self._create_async_gcp_db_connection(), prepared_statement_cache_size=100, ) @@ -125,12 +134,9 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]): AsyncAdapt_asyncpg_connection, ) - base_engine = self._create_gcp_engine() - dbapi = base_engine.dialect.dbapi - def adapted_creator(): return AsyncAdapt_asyncpg_connection( - dbapi, + asyncpg, await_only(self._create_async_gcp_db_connection()), prepared_statement_cache_size=100, ) @@ -141,6 +147,7 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]): pool_size=self.pool_size, max_overflow=self.max_overflow, pool_pre_ping=True, + pool_recycle=self.pool_recycle, ) async def get_async_db_engine(self) -> AsyncEngine: @@ -174,6 +181,7 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]): url, pool_size=self.pool_size, max_overflow=self.max_overflow, + pool_recycle=self.pool_recycle, pool_pre_ping=True, ) else: @@ -214,6 +222,7 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]): url, pool_size=self.pool_size, max_overflow=self.max_overflow, + pool_recycle=self.pool_recycle, pool_pre_ping=True, ) self._engine = engine diff --git a/tests/unit/app_server/test_db_session_injector.py b/tests/unit/app_server/test_db_session_injector.py index fd0908817c..8183cc26b9 100644 --- a/tests/unit/app_server/test_db_session_injector.py +++ b/tests/unit/app_server/test_db_session_injector.py @@ -456,13 +456,10 @@ class TestDbSessionInjectorGCPIntegration: # Mock the google.cloud.sql.connector module with patch.dict('sys.modules', {'google.cloud.sql.connector': MagicMock()}): mock_connector_module = sys.modules['google.cloud.sql.connector'] - mock_connector = AsyncMock() - mock_connector_module.Connector.return_value.__aenter__.return_value = ( - mock_connector - ) - mock_connector_module.Connector.return_value.__aexit__.return_value = None + mock_connector = MagicMock() + mock_connector_module.Connector.return_value = mock_connector mock_connection = AsyncMock() - mock_connector.connect_async.return_value = mock_connection + mock_connector.connect_async = AsyncMock(return_value=mock_connection) connection = await gcp_db_session_injector._create_async_gcp_db_connection()