mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-08 22:38:05 -05:00
ALL-4627 Database Fixes (#12156)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -210,11 +210,16 @@ async def start_app_conversation(
|
||||
set_db_session_keep_open(request.state, True)
|
||||
set_httpx_client_keep_open(request.state, True)
|
||||
|
||||
"""Start an app conversation start task and return it."""
|
||||
async_iter = app_conversation_service.start_app_conversation(start_request)
|
||||
result = await anext(async_iter)
|
||||
asyncio.create_task(_consume_remaining(async_iter, db_session, httpx_client))
|
||||
return result
|
||||
try:
|
||||
"""Start an app conversation start task and return it."""
|
||||
async_iter = app_conversation_service.start_app_conversation(start_request)
|
||||
result = await anext(async_iter)
|
||||
asyncio.create_task(_consume_remaining(async_iter, db_session, httpx_client))
|
||||
return result
|
||||
except Exception:
|
||||
await db_session.close()
|
||||
await httpx_client.aclose()
|
||||
raise
|
||||
|
||||
|
||||
@router.post('/stream-start')
|
||||
|
||||
@@ -21,12 +21,10 @@ from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
)
|
||||
from openhands.app_server.config import (
|
||||
depends_app_conversation_info_service,
|
||||
depends_db_session,
|
||||
depends_event_service,
|
||||
depends_jwt_service,
|
||||
depends_sandbox_service,
|
||||
get_event_callback_service,
|
||||
get_global_config,
|
||||
)
|
||||
from openhands.app_server.errors import AuthError
|
||||
from openhands.app_server.event.event_service import EventService
|
||||
@@ -54,8 +52,6 @@ sandbox_service_dependency = depends_sandbox_service()
|
||||
event_service_dependency = depends_event_service()
|
||||
app_conversation_info_service_dependency = depends_app_conversation_info_service()
|
||||
jwt_dependency = depends_jwt_service()
|
||||
config = get_global_config()
|
||||
db_session_dependency = depends_db_session()
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
@@ -4,8 +4,9 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import asyncpg
|
||||
from fastapi import Request
|
||||
from pydantic import BaseModel, PrivateAttr, SecretStr, model_validator
|
||||
from sqlalchemy import Engine, create_engine
|
||||
@@ -33,6 +34,7 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]):
|
||||
echo: bool = False
|
||||
pool_size: int = 25
|
||||
max_overflow: int = 10
|
||||
pool_recycle: int = 1800
|
||||
gcp_db_instance: str | None = None
|
||||
gcp_project: str | None = None
|
||||
gcp_region: str | None = None
|
||||
@@ -42,6 +44,7 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]):
|
||||
_async_engine: AsyncEngine | None = PrivateAttr(default=None)
|
||||
_session_maker: sessionmaker | None = PrivateAttr(default=None)
|
||||
_async_session_maker: async_sessionmaker | None = PrivateAttr(default=None)
|
||||
_gcp_connector: Any = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode='after')
|
||||
def fill_empty_fields(self):
|
||||
@@ -65,14 +68,18 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]):
|
||||
return self
|
||||
|
||||
def _create_gcp_db_connection(self):
|
||||
# Lazy import because lib does not import if user does not have posgres installed
|
||||
from google.cloud.sql.connector import Connector
|
||||
gcp_connector = self._gcp_connector
|
||||
if gcp_connector is None:
|
||||
# Lazy import because lib does not import if user does not have posgres installed
|
||||
from google.cloud.sql.connector import Connector
|
||||
|
||||
gcp_connector = Connector()
|
||||
self._gcp_connector = gcp_connector
|
||||
|
||||
connector = Connector()
|
||||
instance_string = f'{self.gcp_project}:{self.gcp_region}:{self.gcp_db_instance}'
|
||||
password = self.password
|
||||
assert password is not None
|
||||
return connector.connect(
|
||||
return gcp_connector.connect(
|
||||
instance_string,
|
||||
'pg8000',
|
||||
user=self.user,
|
||||
@@ -81,21 +88,25 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]):
|
||||
)
|
||||
|
||||
async def _create_async_gcp_db_connection(self):
|
||||
# Lazy import because lib does not import if user does not have posgres installed
|
||||
from google.cloud.sql.connector import Connector
|
||||
gcp_connector = self._gcp_connector
|
||||
if gcp_connector is None:
|
||||
# Lazy import because lib does not import if user does not have posgres installed
|
||||
from google.cloud.sql.connector import Connector
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
async with Connector(loop=loop) as connector:
|
||||
password = self.password
|
||||
assert password is not None
|
||||
conn = await connector.connect_async(
|
||||
f'{self.gcp_project}:{self.gcp_region}:{self.gcp_db_instance}',
|
||||
'asyncpg',
|
||||
user=self.user,
|
||||
password=password.get_secret_value(),
|
||||
db=self.name,
|
||||
)
|
||||
return conn
|
||||
loop = asyncio.get_running_loop()
|
||||
gcp_connector = Connector(loop=loop)
|
||||
self._gcp_connector = gcp_connector
|
||||
|
||||
password = self.password
|
||||
assert password is not None
|
||||
conn = await gcp_connector.connect_async(
|
||||
f'{self.gcp_project}:{self.gcp_region}:{self.gcp_db_instance}',
|
||||
'asyncpg',
|
||||
user=self.user,
|
||||
password=password.get_secret_value(),
|
||||
db=self.name,
|
||||
)
|
||||
return conn
|
||||
|
||||
def _create_gcp_engine(self):
|
||||
engine = create_engine(
|
||||
@@ -112,10 +123,8 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]):
|
||||
AsyncAdapt_asyncpg_connection,
|
||||
)
|
||||
|
||||
engine = self._create_gcp_engine()
|
||||
|
||||
return AsyncAdapt_asyncpg_connection(
|
||||
engine.dialect.dbapi,
|
||||
asyncpg,
|
||||
await self._create_async_gcp_db_connection(),
|
||||
prepared_statement_cache_size=100,
|
||||
)
|
||||
@@ -125,12 +134,9 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]):
|
||||
AsyncAdapt_asyncpg_connection,
|
||||
)
|
||||
|
||||
base_engine = self._create_gcp_engine()
|
||||
dbapi = base_engine.dialect.dbapi
|
||||
|
||||
def adapted_creator():
|
||||
return AsyncAdapt_asyncpg_connection(
|
||||
dbapi,
|
||||
asyncpg,
|
||||
await_only(self._create_async_gcp_db_connection()),
|
||||
prepared_statement_cache_size=100,
|
||||
)
|
||||
@@ -141,6 +147,7 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]):
|
||||
pool_size=self.pool_size,
|
||||
max_overflow=self.max_overflow,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=self.pool_recycle,
|
||||
)
|
||||
|
||||
async def get_async_db_engine(self) -> AsyncEngine:
|
||||
@@ -174,6 +181,7 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]):
|
||||
url,
|
||||
pool_size=self.pool_size,
|
||||
max_overflow=self.max_overflow,
|
||||
pool_recycle=self.pool_recycle,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
else:
|
||||
@@ -214,6 +222,7 @@ class DbSessionInjector(BaseModel, Injector[async_sessionmaker]):
|
||||
url,
|
||||
pool_size=self.pool_size,
|
||||
max_overflow=self.max_overflow,
|
||||
pool_recycle=self.pool_recycle,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
self._engine = engine
|
||||
|
||||
Reference in New Issue
Block a user