mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
9 Commits
fix/git-di
...
rds-iam-au
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4e3b5796a5 | ||
|
|
09abb3fc8f | ||
|
|
7ae445618f | ||
|
|
6a3a3c8e2c | ||
|
|
df55ce4001 | ||
|
|
b61200592f | ||
|
|
d2faf9e180 | ||
|
|
eafdadc2be | ||
|
|
5d505fb292 |
@@ -3,16 +3,18 @@ from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from google.cloud.sql.connector import Connector
|
||||
from sqlalchemy import create_engine
|
||||
from storage.base import Base
|
||||
from sqlalchemy import create_engine, event
|
||||
|
||||
target_metadata = Base.metadata
|
||||
from storage.base import Base
|
||||
|
||||
DB_USER = os.getenv('DB_USER', 'postgres')
|
||||
DB_PASS = os.getenv('DB_PASS', 'postgres')
|
||||
DB_HOST = os.getenv('DB_HOST', 'localhost')
|
||||
DB_PORT = os.getenv('DB_PORT', '5432')
|
||||
DB_NAME = os.getenv('DB_NAME', 'openhands')
|
||||
DB_SCHEMA = os.getenv('DB_SCHEMA')
|
||||
DB_AUTH_TYPE = os.getenv('DB_AUTH_TYPE', 'password') # 'password' or 'rds-iam'
|
||||
AWS_REGION = os.getenv('AWS_REGION', 'us-east-1') # AWS region for RDS IAM auth
|
||||
|
||||
GCP_DB_INSTANCE = os.getenv('GCP_DB_INSTANCE')
|
||||
GCP_PROJECT = os.getenv('GCP_PROJECT')
|
||||
@@ -21,6 +23,24 @@ GCP_REGION = os.getenv('GCP_REGION')
|
||||
POOL_SIZE = int(os.getenv('DB_POOL_SIZE', '25'))
|
||||
MAX_OVERFLOW = int(os.getenv('DB_MAX_OVERFLOW', '10'))
|
||||
|
||||
target_metadata = Base.metadata
|
||||
# Set schema for target metadata if DB_SCHEMA is provided
|
||||
if DB_SCHEMA:
|
||||
target_metadata.schema = DB_SCHEMA
|
||||
|
||||
# RDS IAM authentication setup
|
||||
if DB_AUTH_TYPE == 'rds-iam':
|
||||
import boto3
|
||||
|
||||
# boto3 client (reused for token generation)
|
||||
rds = boto3.client('rds', region_name=AWS_REGION)
|
||||
|
||||
def get_auth_token():
    """Generate a fresh IAM DB auth token.

    Asks the module-level boto3 RDS client (``rds``) for a token for
    ``DB_USER`` on ``DB_HOST``/``DB_PORT``. Nothing is cached here; every
    call requests a new token from the client.

    Returns:
        Whatever ``rds.generate_db_auth_token`` returns.

    Raises:
        ValueError: if ``DB_PORT`` is not a valid integer string.
    """
    return rds.generate_db_auth_token(
        DBHostname=DB_HOST,
        # DB_PORT is read from the environment as a string; boto3 documents
        # Port as an int, so convert (and normalise stray whitespace) here.
        Port=int(DB_PORT),
        DBUsername=DB_USER,
    )
|
||||
|
||||
|
||||
def get_engine(database_name=DB_NAME):
|
||||
"""Create SQLAlchemy engine with optional database name."""
|
||||
@@ -29,29 +49,87 @@ def get_engine(database_name=DB_NAME):
|
||||
def get_db_connection():
|
||||
connector = Connector()
|
||||
instance_string = f'{GCP_PROJECT}:{GCP_REGION}:{GCP_DB_INSTANCE}'
|
||||
return connector.connect(
|
||||
instance_string,
|
||||
'pg8000',
|
||||
user=DB_USER,
|
||||
password=DB_PASS.strip(),
|
||||
db=database_name,
|
||||
)
|
||||
connect_kwargs = {
|
||||
'user': DB_USER,
|
||||
'password': DB_PASS.strip(),
|
||||
'db': database_name,
|
||||
}
|
||||
# Note: pg8000 doesn't accept 'options' parameter, so we'll handle schema via SQL
|
||||
# Schema will be set after connection via event listener
|
||||
return connector.connect(instance_string, 'pg8000', **connect_kwargs)
|
||||
|
||||
return create_engine(
|
||||
engine = create_engine(
|
||||
'postgresql+pg8000://',
|
||||
creator=get_db_connection,
|
||||
pool_size=POOL_SIZE,
|
||||
max_overflow=MAX_OVERFLOW,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
# Set schema via SQL after connection if specified
|
||||
if DB_SCHEMA:
|
||||
# Listener for this engine's 'connect' event: on the DBAPI connection it is
# handed, it runs "SET search_path TO <DB_SCHEMA>" and then commits.
@event.listens_for(engine, 'connect')
|
||||
def set_search_path(dbapi_connection, connection_record):
|
||||
with dbapi_connection.cursor() as cursor:
|
||||
# DB_SCHEMA is interpolated into the statement text as-is (no quoting).
cursor.execute(f"SET search_path TO {DB_SCHEMA}")
|
||||
# NOTE(review): indentation is lost in this view, so it is not visible
# whether the commit runs inside or after the cursor's with-block.
dbapi_connection.commit()
|
||||
|
||||
return engine
|
||||
else:
|
||||
url = f'postgresql://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{database_name}'
|
||||
return create_engine(
|
||||
url,
|
||||
pool_size=POOL_SIZE,
|
||||
max_overflow=MAX_OVERFLOW,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
if DB_AUTH_TYPE == 'rds-iam':
|
||||
# Build a SQLAlchemy connection URL with a dummy password — token will be injected dynamically
|
||||
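# NOTE(review): RDS IAM authentication requires an SSL/TLS connection; pg8000 does not enable SSL unless an ssl_context is supplied — confirm one is configured for this engine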
|
||||
# For pg8000, we cannot use URL parameters like options, so schema must be handled differently
|
||||
base_url = (
|
||||
f'postgresql+pg8000://{DB_USER}:dummy-password'
|
||||
f'@{DB_HOST}:{DB_PORT}/{database_name}'
|
||||
)
|
||||
engine = create_engine(
|
||||
base_url,
|
||||
pool_size=POOL_SIZE,
|
||||
max_overflow=MAX_OVERFLOW,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
# Hook: before a connection is made, inject a fresh token
|
||||
@event.listens_for(engine, 'do_connect')
def provide_token(dialect, conn_rec, cargs, cparams):
    """Swap the URL's placeholder password for a freshly generated IAM token.

    Runs on the engine's ``do_connect`` event: overwrites
    ``cparams['password']`` with the result of ``get_auth_token()`` and then
    opens the connection itself through ``dialect.connect``.
    """
    cparams['password'] = get_auth_token()
    return dialect.connect(*cargs, **cparams)
|
||||
|
||||
# Hook: after connection is established, set the schema if specified
|
||||
if DB_SCHEMA:
|
||||
# Listener for this engine's 'connect' event: on the DBAPI connection it is
# handed, it runs "SET search_path TO <DB_SCHEMA>" and then commits.
@event.listens_for(engine, 'connect')
|
||||
def set_search_path(dbapi_connection, connection_record):
|
||||
with dbapi_connection.cursor() as cursor:
|
||||
# DB_SCHEMA is interpolated into the statement text as-is (no quoting).
cursor.execute(f"SET search_path TO {DB_SCHEMA}")
|
||||
# NOTE(review): indentation is lost in this view, so it is not visible
# whether the commit runs inside or after the cursor's with-block.
dbapi_connection.commit()
|
||||
|
||||
return engine
|
||||
else:
|
||||
# Regular password authentication
|
||||
# Use postgresql:// (default driver) but handle schema via SQL to be safe
|
||||
url = (
|
||||
f'postgresql://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{database_name}'
|
||||
)
|
||||
engine = create_engine(
|
||||
url,
|
||||
pool_size=POOL_SIZE,
|
||||
max_overflow=MAX_OVERFLOW,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
# Set schema via SQL after connection if specified
|
||||
if DB_SCHEMA:
|
||||
# Listener for this engine's 'connect' event: on the DBAPI connection it is
# handed, it runs "SET search_path TO <DB_SCHEMA>" and then commits.
@event.listens_for(engine, 'connect')
|
||||
def set_search_path(dbapi_connection, connection_record):
|
||||
with dbapi_connection.cursor() as cursor:
|
||||
# DB_SCHEMA is interpolated into the statement text as-is (no quoting).
cursor.execute(f"SET search_path TO {DB_SCHEMA}")
|
||||
# NOTE(review): indentation is lost in this view, so it is not visible
# whether the commit runs inside or after the cursor's with-block.
dbapi_connection.commit()
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
engine = get_engine()
|
||||
@@ -83,6 +161,7 @@ def run_migrations_offline() -> None:
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={'paramstyle': 'named'},
|
||||
version_table_schema=target_metadata.schema,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from google.cloud.sql.connector import Connector
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import NullPool
|
||||
@@ -13,6 +14,9 @@ DB_PORT = os.environ.get('DB_PORT', '5432') # for non-GCP environments
|
||||
DB_USER = os.environ.get('DB_USER', 'postgres')
|
||||
DB_PASS = os.environ.get('DB_PASS', 'postgres').strip()
|
||||
DB_NAME = os.environ.get('DB_NAME', 'openhands')
|
||||
DB_SCHEMA = os.environ.get('DB_SCHEMA') # PostgreSQL schema name
|
||||
DB_AUTH_TYPE = os.environ.get('DB_AUTH_TYPE', 'password') # 'password' or 'rds-iam'
|
||||
AWS_REGION = os.environ.get('AWS_REGION', 'us-east-1') # AWS region for RDS IAM auth
|
||||
|
||||
GCP_DB_INSTANCE = os.environ.get('GCP_DB_INSTANCE') # for GCP environments
|
||||
GCP_PROJECT = os.environ.get('GCP_PROJECT')
|
||||
@@ -21,6 +25,19 @@ GCP_REGION = os.environ.get('GCP_REGION')
|
||||
POOL_SIZE = int(os.environ.get('DB_POOL_SIZE', '25'))
|
||||
MAX_OVERFLOW = int(os.environ.get('DB_MAX_OVERFLOW', '10'))
|
||||
|
||||
# RDS IAM authentication setup
|
||||
if DB_AUTH_TYPE == 'rds-iam':
|
||||
import boto3
|
||||
|
||||
# boto3 client (reused for token generation)
|
||||
rds = boto3.client('rds', region_name=AWS_REGION)
|
||||
|
||||
def get_auth_token():
    """Generate a fresh IAM DB auth token.

    Asks the module-level boto3 RDS client (``rds``) for a token for
    ``DB_USER`` on ``DB_HOST``/``DB_PORT``. Nothing is cached here; every
    call requests a new token from the client.

    Returns:
        Whatever ``rds.generate_db_auth_token`` returns.

    Raises:
        ValueError: if ``DB_PORT`` is not a valid integer string.
    """
    return rds.generate_db_auth_token(
        DBHostname=DB_HOST,
        # DB_PORT is read from the environment as a string; boto3 documents
        # Port as an int, so convert (and normalise stray whitespace) here.
        Port=int(DB_PORT),
        DBUsername=DB_USER,
    )
|
||||
|
||||
|
||||
def _get_db_engine():
|
||||
if GCP_DB_INSTANCE: # GCP environments
|
||||
@@ -28,38 +45,104 @@ def _get_db_engine():
|
||||
def get_db_connection():
|
||||
connector = Connector()
|
||||
instance_string = f'{GCP_PROJECT}:{GCP_REGION}:{GCP_DB_INSTANCE}'
|
||||
return connector.connect(
|
||||
instance_string, 'pg8000', user=DB_USER, password=DB_PASS, db=DB_NAME
|
||||
)
|
||||
connect_kwargs = {
|
||||
'user': DB_USER,
|
||||
'password': DB_PASS,
|
||||
'db': DB_NAME,
|
||||
}
|
||||
# Note: pg8000 doesn't accept 'options' parameter, so we'll handle schema via SQL
|
||||
# Schema will be set after connection via event listener
|
||||
return connector.connect(instance_string, 'pg8000', **connect_kwargs)
|
||||
|
||||
return create_engine(
|
||||
engine = create_engine(
|
||||
'postgresql+pg8000://',
|
||||
creator=get_db_connection,
|
||||
pool_size=POOL_SIZE,
|
||||
max_overflow=MAX_OVERFLOW,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
# Set schema via SQL after connection if specified
|
||||
if DB_SCHEMA:
|
||||
# Listener for this engine's 'connect' event: on the DBAPI connection it is
# handed, it runs "SET search_path TO <DB_SCHEMA>" and then commits.
@event.listens_for(engine, 'connect')
|
||||
def set_search_path(dbapi_connection, connection_record):
|
||||
with dbapi_connection.cursor() as cursor:
|
||||
# DB_SCHEMA is interpolated into the statement text as-is (no quoting).
cursor.execute(f"SET search_path TO {DB_SCHEMA}")
|
||||
# NOTE(review): indentation is lost in this view, so it is not visible
# whether the commit runs inside or after the cursor's with-block.
dbapi_connection.commit()
|
||||
|
||||
return engine
|
||||
else:
|
||||
host_string = (
|
||||
f'postgresql+pg8000://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}'
|
||||
)
|
||||
return create_engine(
|
||||
host_string,
|
||||
pool_size=POOL_SIZE,
|
||||
max_overflow=MAX_OVERFLOW,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
if DB_AUTH_TYPE == 'rds-iam':
|
||||
# Build a SQLAlchemy connection URL with a dummy password — token will be injected dynamically
|
||||
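# NOTE(review): RDS IAM authentication requires an SSL/TLS connection; pg8000 does not enable SSL unless an ssl_context is supplied — confirm one is configured for this engine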
|
||||
# For pg8000, we cannot use URL parameters like options, so schema must be handled differently
|
||||
base_url = (
|
||||
f'postgresql+pg8000://{DB_USER}:dummy-password'
|
||||
f'@{DB_HOST}:{DB_PORT}/{DB_NAME}'
|
||||
)
|
||||
engine = create_engine(
|
||||
base_url,
|
||||
pool_size=POOL_SIZE,
|
||||
max_overflow=MAX_OVERFLOW,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
# Hook: before a connection is made, inject a fresh token
|
||||
@event.listens_for(engine, 'do_connect')
def provide_token(dialect, conn_rec, cargs, cparams):
    """Swap the URL's placeholder password for a freshly generated IAM token.

    Runs on the engine's ``do_connect`` event: overwrites
    ``cparams['password']`` with the result of ``get_auth_token()`` and then
    opens the connection itself through ``dialect.connect``.
    """
    cparams['password'] = get_auth_token()
    return dialect.connect(*cargs, **cparams)
|
||||
|
||||
# Hook: after connection is established, set the schema if specified
|
||||
if DB_SCHEMA:
|
||||
# Listener for this engine's 'connect' event: on the DBAPI connection it is
# handed, it runs "SET search_path TO <DB_SCHEMA>" and then commits.
@event.listens_for(engine, 'connect')
|
||||
def set_search_path(dbapi_connection, connection_record):
|
||||
with dbapi_connection.cursor() as cursor:
|
||||
# DB_SCHEMA is interpolated into the statement text as-is (no quoting).
cursor.execute(f"SET search_path TO {DB_SCHEMA}")
|
||||
# NOTE(review): indentation is lost in this view, so it is not visible
# whether the commit runs inside or after the cursor's with-block.
dbapi_connection.commit()
|
||||
|
||||
return engine
|
||||
else:
|
||||
# Regular password authentication with pg8000
|
||||
# pg8000 doesn't accept options as URL parameter, so handle schema via SQL
|
||||
host_string = (
|
||||
f'postgresql+pg8000://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}'
|
||||
)
|
||||
engine = create_engine(
|
||||
host_string,
|
||||
pool_size=POOL_SIZE,
|
||||
max_overflow=MAX_OVERFLOW,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
# Set schema via SQL after connection if specified
|
||||
if DB_SCHEMA:
|
||||
# Listener for this engine's 'connect' event: on the DBAPI connection it is
# handed, it runs "SET search_path TO <DB_SCHEMA>" and then commits.
@event.listens_for(engine, 'connect')
|
||||
def set_search_path(dbapi_connection, connection_record):
|
||||
with dbapi_connection.cursor() as cursor:
|
||||
# DB_SCHEMA is interpolated into the statement text as-is (no quoting).
cursor.execute(f"SET search_path TO {DB_SCHEMA}")
|
||||
# NOTE(review): indentation is lost in this view, so it is not visible
# whether the commit runs inside or after the cursor's with-block.
dbapi_connection.commit()
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
async def async_creator():
|
||||
loop = asyncio.get_running_loop()
|
||||
async with Connector(loop=loop) as connector:
|
||||
connect_kwargs: dict[str, Any] = {
|
||||
'user': DB_USER,
|
||||
'password': DB_PASS,
|
||||
'db': DB_NAME,
|
||||
}
|
||||
# Add schema support for async GCP connections
|
||||
if DB_SCHEMA:
|
||||
connect_kwargs['server_settings'] = {'search_path': DB_SCHEMA}
|
||||
conn = await connector.connect_async(
|
||||
f'{GCP_PROJECT}:{GCP_REGION}:{GCP_DB_INSTANCE}', # Cloud SQL instance connection name
|
||||
'asyncpg',
|
||||
user=DB_USER,
|
||||
password=DB_PASS,
|
||||
db=DB_NAME,
|
||||
**connect_kwargs,
|
||||
)
|
||||
return conn
|
||||
|
||||
@@ -87,14 +170,49 @@ def _get_async_db_engine():
|
||||
poolclass=NullPool,
|
||||
)
|
||||
else:
|
||||
host_string = (
|
||||
f'postgresql+asyncpg://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}'
|
||||
)
|
||||
return create_async_engine(
|
||||
host_string,
|
||||
# Use NullPool to disable connection pooling and avoid event loop issues
|
||||
poolclass=NullPool,
|
||||
)
|
||||
if DB_AUTH_TYPE == 'rds-iam':
|
||||
# Build a SQLAlchemy connection URL with a dummy password — token will be injected dynamically
|
||||
# Note: SSL is enabled by default for asyncpg when connecting to RDS
|
||||
# For asyncpg, we cannot use URL parameters like options, so schema must be handled differently
|
||||
base_url = (
|
||||
f'postgresql+asyncpg://{DB_USER}:dummy-password'
|
||||
f'@{DB_HOST}:{DB_PORT}/{DB_NAME}'
|
||||
)
|
||||
engine = create_async_engine(
|
||||
base_url, echo=True, pool_pre_ping=True, poolclass=NullPool
|
||||
)
|
||||
|
||||
# Hook: before a connection is made, inject a fresh token and set schema
|
||||
# Listener for the 'do_connect' event on the async engine's underlying
# sync_engine: it replaces the password in the connect parameters with a
# newly generated auth token, optionally adds the schema as a server setting,
# and returns the connection opened via dialect.connect.
@event.listens_for(engine.sync_engine, 'do_connect')
|
||||
def provide_token(dialect, conn_rec, cargs, cparams):
|
||||
token = get_auth_token()
|
||||
# Replace password in connect arguments
|
||||
cparams['password'] = token
|
||||
# Set schema via server_settings for asyncpg
|
||||
# NOTE(review): indentation is lost in this view; presumably only the
# server_settings assignment is guarded by DB_SCHEMA and the return below
# always runs — confirm against the real file.
if DB_SCHEMA:
|
||||
cparams['server_settings'] = {'search_path': DB_SCHEMA}
|
||||
return dialect.connect(*cargs, **cparams)
|
||||
|
||||
return engine
|
||||
else:
|
||||
# Regular password authentication with asyncpg
|
||||
# asyncpg doesn't accept options as URL parameter, so handle schema via server_settings
|
||||
host_string = f'postgresql+asyncpg://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}'
|
||||
engine = create_async_engine(
|
||||
host_string,
|
||||
# Use NullPool to disable connection pooling and avoid event loop issues
|
||||
poolclass=NullPool,
|
||||
)
|
||||
|
||||
# Set schema via server_settings for asyncpg
|
||||
if DB_SCHEMA:
|
||||
@event.listens_for(engine.sync_engine, 'do_connect')
def set_schema(dialect, conn_rec, cargs, cparams):
    """Pass DB_SCHEMA to the driver as its ``search_path`` server setting.

    Runs on the ``do_connect`` event of the async engine's ``sync_engine``:
    stores ``{'search_path': DB_SCHEMA}`` under ``cparams['server_settings']``
    and opens the connection through ``dialect.connect``.
    """
    schema_settings = {'search_path': DB_SCHEMA}
    cparams['server_settings'] = schema_settings
    return dialect.connect(*cargs, **cparams)
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
engine = _get_db_engine()
|
||||
|
||||
Reference in New Issue
Block a user