ALL-4627 Database Fixes (#12156)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Tim O'Farrell
2025-12-26 09:19:51 -07:00
committed by GitHub
parent c80f70392f
commit cb8c1fa263
5 changed files with 58 additions and 43 deletions

View File

@@ -19,17 +19,23 @@ GCP_REGION = os.environ.get('GCP_REGION')
POOL_SIZE = int(os.environ.get('DB_POOL_SIZE', '25'))
MAX_OVERFLOW = int(os.environ.get('DB_MAX_OVERFLOW', '10'))
POOL_RECYCLE = int(os.environ.get('DB_POOL_RECYCLE', '1800'))
# Initialize Cloud SQL Connector once at module level for GCP environments.
_connector = None
def _get_db_engine():
if GCP_DB_INSTANCE: # GCP environments
def get_db_connection():
global _connector
from google.cloud.sql.connector import Connector
connector = Connector()
if not _connector:
_connector = Connector()
instance_string = f'{GCP_PROJECT}:{GCP_REGION}:{GCP_DB_INSTANCE}'
return connector.connect(
return _connector.connect(
instance_string, 'pg8000', user=DB_USER, password=DB_PASS, db=DB_NAME
)
@@ -38,6 +44,7 @@ def _get_db_engine():
creator=get_db_connection,
pool_size=POOL_SIZE,
max_overflow=MAX_OVERFLOW,
pool_recycle=POOL_RECYCLE,
pool_pre_ping=True,
)
else:
@@ -48,6 +55,7 @@ def _get_db_engine():
host_string,
pool_size=POOL_SIZE,
max_overflow=MAX_OVERFLOW,
pool_recycle=POOL_RECYCLE,
pool_pre_ping=True,
)

View File

@@ -210,11 +210,16 @@ async def start_app_conversation(
set_db_session_keep_open(request.state, True)
set_httpx_client_keep_open(request.state, True)
"""Start an app conversation start task and return it."""
async_iter = app_conversation_service.start_app_conversation(start_request)
result = await anext(async_iter)
asyncio.create_task(_consume_remaining(async_iter, db_session, httpx_client))
return result
try:
"""Start an app conversation start task and return it."""
async_iter = app_conversation_service.start_app_conversation(start_request)
result = await anext(async_iter)
asyncio.create_task(_consume_remaining(async_iter, db_session, httpx_client))
return result
except Exception:
await db_session.close()
await httpx_client.aclose()
raise
@router.post('/stream-start')

View File

@@ -21,12 +21,10 @@ from openhands.app_server.app_conversation.app_conversation_models import (
)
from openhands.app_server.config import (
depends_app_conversation_info_service,
depends_db_session,
depends_event_service,
depends_jwt_service,
depends_sandbox_service,
get_event_callback_service,
get_global_config,
)
from openhands.app_server.errors import AuthError
from openhands.app_server.event.event_service import EventService
@@ -54,8 +52,6 @@ sandbox_service_dependency = depends_sandbox_service()
event_service_dependency = depends_event_service()
app_conversation_info_service_dependency = depends_app_conversation_info_service()
jwt_dependency = depends_jwt_service()
config = get_global_config()
db_session_dependency = depends_db_session()
_logger = logging.getLogger(__name__)

View File

@@ -4,8 +4,9 @@ import asyncio
import logging
import os
from pathlib import Path
from typing import AsyncGenerator
from typing import Any, AsyncGenerator
import asyncpg
from fastapi import Request
from pydantic import BaseModel, PrivateAttr, SecretStr, model_validator
from sqlalchemy import Engine, create_engine
@@ -33,6 +34,7 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]):
echo: bool = False
pool_size: int = 25
max_overflow: int = 10
pool_recycle: int = 1800
gcp_db_instance: str | None = None
gcp_project: str | None = None
gcp_region: str | None = None
@@ -42,6 +44,7 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]):
_async_engine: AsyncEngine | None = PrivateAttr(default=None)
_session_maker: sessionmaker | None = PrivateAttr(default=None)
_async_session_maker: async_sessionmaker | None = PrivateAttr(default=None)
_gcp_connector: Any = PrivateAttr(default=None)
@model_validator(mode='after')
def fill_empty_fields(self):
@@ -65,14 +68,18 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]):
return self
def _create_gcp_db_connection(self):
# Lazy import because lib does not import if user does not have posgres installed
from google.cloud.sql.connector import Connector
gcp_connector = self._gcp_connector
if gcp_connector is None:
# Lazy import because lib does not import if user does not have posgres installed
from google.cloud.sql.connector import Connector
gcp_connector = Connector()
self._gcp_connector = gcp_connector
connector = Connector()
instance_string = f'{self.gcp_project}:{self.gcp_region}:{self.gcp_db_instance}'
password = self.password
assert password is not None
return connector.connect(
return gcp_connector.connect(
instance_string,
'pg8000',
user=self.user,
@@ -81,21 +88,25 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]):
)
async def _create_async_gcp_db_connection(self):
# Lazy import because lib does not import if user does not have posgres installed
from google.cloud.sql.connector import Connector
gcp_connector = self._gcp_connector
if gcp_connector is None:
# Lazy import because lib does not import if user does not have posgres installed
from google.cloud.sql.connector import Connector
loop = asyncio.get_running_loop()
async with Connector(loop=loop) as connector:
password = self.password
assert password is not None
conn = await connector.connect_async(
f'{self.gcp_project}:{self.gcp_region}:{self.gcp_db_instance}',
'asyncpg',
user=self.user,
password=password.get_secret_value(),
db=self.name,
)
return conn
loop = asyncio.get_running_loop()
gcp_connector = Connector(loop=loop)
self._gcp_connector = gcp_connector
password = self.password
assert password is not None
conn = await gcp_connector.connect_async(
f'{self.gcp_project}:{self.gcp_region}:{self.gcp_db_instance}',
'asyncpg',
user=self.user,
password=password.get_secret_value(),
db=self.name,
)
return conn
def _create_gcp_engine(self):
engine = create_engine(
@@ -112,10 +123,8 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]):
AsyncAdapt_asyncpg_connection,
)
engine = self._create_gcp_engine()
return AsyncAdapt_asyncpg_connection(
engine.dialect.dbapi,
asyncpg,
await self._create_async_gcp_db_connection(),
prepared_statement_cache_size=100,
)
@@ -125,12 +134,9 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]):
AsyncAdapt_asyncpg_connection,
)
base_engine = self._create_gcp_engine()
dbapi = base_engine.dialect.dbapi
def adapted_creator():
return AsyncAdapt_asyncpg_connection(
dbapi,
asyncpg,
await_only(self._create_async_gcp_db_connection()),
prepared_statement_cache_size=100,
)
@@ -141,6 +147,7 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]):
pool_size=self.pool_size,
max_overflow=self.max_overflow,
pool_pre_ping=True,
pool_recycle=self.pool_recycle,
)
async def get_async_db_engine(self) -> AsyncEngine:
@@ -174,6 +181,7 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]):
url,
pool_size=self.pool_size,
max_overflow=self.max_overflow,
pool_recycle=self.pool_recycle,
pool_pre_ping=True,
)
else:
@@ -214,6 +222,7 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]):
url,
pool_size=self.pool_size,
max_overflow=self.max_overflow,
pool_recycle=self.pool_recycle,
pool_pre_ping=True,
)
self._engine = engine

View File

@@ -456,13 +456,10 @@ class TestDbSessionInjectorGCPIntegration:
# Mock the google.cloud.sql.connector module
with patch.dict('sys.modules', {'google.cloud.sql.connector': MagicMock()}):
mock_connector_module = sys.modules['google.cloud.sql.connector']
mock_connector = AsyncMock()
mock_connector_module.Connector.return_value.__aenter__.return_value = (
mock_connector
)
mock_connector_module.Connector.return_value.__aexit__.return_value = None
mock_connector = MagicMock()
mock_connector_module.Connector.return_value = mock_connector
mock_connection = AsyncMock()
mock_connector.connect_async.return_value = mock_connection
mock_connector.connect_async = AsyncMock(return_value=mock_connection)
connection = await gcp_db_session_injector._create_async_gcp_db_connection()