mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-12 00:49:08 -05:00
feat: Add IAM authentication to AlloyDB Source (#399)
Add IAM support for AlloyDB source connection: https://pkg.go.dev/cloud.google.com/go/alloydbconn#section-readme
This commit is contained in:
@@ -38,6 +38,11 @@ permissions):
|
||||
- `roles/alloydb.client`
|
||||
- `roles/serviceusage.serviceUsageConsumer`
|
||||
|
||||
To connect to your AlloyDB Source using IAM authentication:
|
||||
|
||||
1. Specify your IAM email as the `user` or leave it blank for Toolbox to fetch from ADC.
|
||||
2. Leave the `password` field blank.
|
||||
|
||||
[alloydb-go-conn]: https://github.com/GoogleCloudPlatform/alloydb-go-connector
|
||||
[adc]: https://cloud.google.com/docs/authentication#adc
|
||||
[set-adc]: https://cloud.google.com/docs/authentication/provide-credentials-adc
|
||||
@@ -83,14 +88,14 @@ sources:
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-----------|:--------:|:------------:|-------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "alloydb-postgres". |
|
||||
| project | string | true | Id of the GCP project that the cluster was created in (e.g. "my-project-id"). |
|
||||
| region | string | true | Name of the GCP region that the cluster was created in (e.g. "us-central1"). |
|
||||
| cluster | string | true | Name of the AlloyDB cluster (e.g. "my-cluster"). |
|
||||
| instance | string | true | Name of the AlloyDB instance within the cluster (e.g. "my-instance"). |
|
||||
| database | string | true | Name of the Postgres database to connect to (e.g. "my_db"). |
|
||||
| user | string | true | Name of the Postgres user to connect as (e.g. "my-pg-user"). |
|
||||
| password | string | true | Password of the Postgres user (e.g. "my-password"). |
|
||||
| ipType | string | false | IP Type of the AlloyDB instance; must be one of `public` or `private`. Default: `public`. |
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-----------|:--------:|:------------:|--------------------------------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "alloydb-postgres". |
|
||||
| project | string | true | Id of the GCP project that the cluster was created in (e.g. "my-project-id"). |
|
||||
| region | string | true | Name of the GCP region that the cluster was created in (e.g. "us-central1"). |
|
||||
| cluster | string | true | Name of the AlloyDB cluster (e.g. "my-cluster"). |
|
||||
| instance | string | true | Name of the AlloyDB instance within the cluster (e.g. "my-instance"). |
|
||||
| database | string | true | Name of the Postgres database to connect to (e.g. "my_db"). |
|
||||
| user | string | false | Name of the Postgres user to connect as (e.g. "my-pg-user"). Defaults to IAM auth using [ADC][adc] email if unspecified. |
|
||||
| password | string | false | Password of the Postgres user (e.g. "my-password"). Defaults to attempting IAM authentication if unspecified. |
|
||||
| ipType | string | false | IP Type of the AlloyDB instance; must be one of `public` or `private`. Default: `public`. |
|
||||
|
||||
2
go.mod
2
go.mod
@@ -30,6 +30,7 @@ require (
|
||||
go.opentelemetry.io/otel/sdk v1.35.0
|
||||
go.opentelemetry.io/otel/sdk/metric v1.35.0
|
||||
go.opentelemetry.io/otel/trace v1.35.0
|
||||
golang.org/x/oauth2 v0.28.0
|
||||
google.golang.org/api v0.228.0
|
||||
)
|
||||
|
||||
@@ -89,7 +90,6 @@ require (
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/crypto v0.36.0 // indirect
|
||||
golang.org/x/net v0.37.0 // indirect
|
||||
golang.org/x/oauth2 v0.28.0 // indirect
|
||||
golang.org/x/sync v0.12.0 // indirect
|
||||
golang.org/x/sys v0.31.0 // indirect
|
||||
golang.org/x/text v0.23.0 // indirect
|
||||
|
||||
@@ -40,8 +40,8 @@ type Config struct {
|
||||
Cluster string `yaml:"cluster" validate:"required"`
|
||||
Instance string `yaml:"instance" validate:"required"`
|
||||
IPType sources.IPType `yaml:"ipType" validate:"required"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
Password string `yaml:"password" validate:"required"`
|
||||
User string `yaml:"user"`
|
||||
Password string `yaml:"password"`
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
}
|
||||
|
||||
@@ -84,7 +84,7 @@ func (s *Source) PostgresPool() *pgxpool.Pool {
|
||||
return s.Pool
|
||||
}
|
||||
|
||||
func getOpts(ipType, userAgent string) ([]alloydbconn.Option, error) {
|
||||
func getOpts(ipType, userAgent string, useIAM bool) ([]alloydbconn.Option, error) {
|
||||
opts := []alloydbconn.Option{alloydbconn.WithUserAgent(userAgent)}
|
||||
switch strings.ToLower(ipType) {
|
||||
case "private":
|
||||
@@ -94,27 +94,62 @@ func getOpts(ipType, userAgent string) ([]alloydbconn.Option, error) {
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid ipType %s", ipType)
|
||||
}
|
||||
|
||||
if useIAM {
|
||||
opts = append(opts, alloydbconn.WithIAMAuthN())
|
||||
}
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func getConnectionConfig(ctx context.Context, user, pass, dbname string) (string, bool, error) {
|
||||
useIAM := true
|
||||
|
||||
// If username and password both provided, use password authentication
|
||||
if user != "" && pass != "" {
|
||||
dsn := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=disable", user, pass, dbname)
|
||||
useIAM = false
|
||||
return dsn, useIAM, nil
|
||||
}
|
||||
|
||||
// If username is empty, fetch email from ADC
|
||||
// otherwise, use username as IAM email
|
||||
if user == "" {
|
||||
if pass != "" {
|
||||
// If password is provided without an username, raise an error
|
||||
return "", useIAM, fmt.Errorf("password is provided without a username. Please provide both a username and password, or leave both fields empty")
|
||||
}
|
||||
email, err := sources.GetIAMPrincipalEmailFromADC(ctx)
|
||||
if err != nil {
|
||||
return "", useIAM, fmt.Errorf("error getting email from ADC: %v", err)
|
||||
}
|
||||
user = email
|
||||
}
|
||||
|
||||
// Construct IAM connection string with username
|
||||
dsn := fmt.Sprintf("user=%s dbname=%s sslmode=disable", user, dbname)
|
||||
return dsn, useIAM, nil
|
||||
}
|
||||
|
||||
func initAlloyDBPgConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, cluster, instance, ipType, user, pass, dbname string) (*pgxpool.Pool, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
// Configure the driver to connect to the database
|
||||
dsn := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=disable", user, pass, dbname)
|
||||
dsn, useIAM, err := getConnectionConfig(ctx, user, pass, dbname)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get AlloyDB connection config: %w", err)
|
||||
}
|
||||
|
||||
config, err := pgxpool.ParseConfig(dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse connection uri: %w", err)
|
||||
}
|
||||
|
||||
// Create a new dialer with options
|
||||
userAgent, err := util.UserAgentFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
opts, err := getOpts(ipType, userAgent)
|
||||
opts, err := getOpts(ipType, userAgent, useIAM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -15,10 +15,15 @@
|
||||
package sources
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"cloud.google.com/go/cloudsqlconn"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
// GetCloudSQLDialOpts retrieve dial options with the right ip type and user agent for cloud sql
|
||||
@@ -35,3 +40,44 @@ func GetCloudSQLOpts(ipType, userAgent string) ([]cloudsqlconn.Option, error) {
|
||||
}
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
// GetIAMPrincipalEmailFromADC finds the email associated with ADC
|
||||
func GetIAMPrincipalEmailFromADC(ctx context.Context) (string, error) {
|
||||
// Finds ADC and returns an HTTP client associated with it
|
||||
client, err := google.DefaultClient(ctx,
|
||||
"https://www.googleapis.com/auth/userinfo.email")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to call userinfo endpoint: %w", err)
|
||||
}
|
||||
|
||||
// Retrieve the email associated with the token
|
||||
resp, err := client.Get("https://oauth2.googleapis.com/tokeninfo")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to call tokeninfo endpoint: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error reading response body %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("tokeninfo endpoint returned non-OK status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
// Unmarshal response body and get `email`
|
||||
var responseJSON map[string]any
|
||||
err = json.Unmarshal(bodyBytes, &responseJSON)
|
||||
if err != nil {
|
||||
|
||||
return "", fmt.Errorf("error parsing JSON: %v", err)
|
||||
}
|
||||
|
||||
emailValue, ok := responseJSON["email"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("email not found in response: %v", err)
|
||||
}
|
||||
// service account email used for IAM should trim the suffix
|
||||
email := strings.TrimSuffix(emailValue.(string), ".gserviceaccount.com")
|
||||
return email, nil
|
||||
}
|
||||
|
||||
@@ -188,7 +188,77 @@ func TestAlloyDBIpConnection(t *testing.T) {
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
sourceConfig["ipType"] = tc.ipType
|
||||
RunSourceConnectionTest(t, sourceConfig, ALLOYDB_POSTGRES_TOOL_KIND)
|
||||
err := RunSourceConnectionTest(t, sourceConfig, ALLOYDB_POSTGRES_TOOL_KIND)
|
||||
if err != nil {
|
||||
t.Fatalf("Connection test failure: %s", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test IAM connection
|
||||
func TestAlloyDBIAMConnection(t *testing.T) {
|
||||
getAlloyDBPgVars(t)
|
||||
// service account email used for IAM should trim the suffix
|
||||
serviceAccountEmail := strings.TrimSuffix(SERVICE_ACCOUNT_EMAIL, ".gserviceaccount.com")
|
||||
|
||||
noPassSourceConfig := map[string]any{
|
||||
"kind": ALLOYDB_POSTGRES_SOURCE_KIND,
|
||||
"project": ALLOYDB_POSTGRES_PROJECT,
|
||||
"cluster": ALLOYDB_POSTGRES_CLUSTER,
|
||||
"instance": ALLOYDB_POSTGRES_INSTANCE,
|
||||
"region": ALLOYDB_POSTGRES_REGION,
|
||||
"database": ALLOYDB_POSTGRES_DATABASE,
|
||||
"user": serviceAccountEmail,
|
||||
}
|
||||
|
||||
noUserSourceConfig := map[string]any{
|
||||
"kind": ALLOYDB_POSTGRES_SOURCE_KIND,
|
||||
"project": ALLOYDB_POSTGRES_PROJECT,
|
||||
"cluster": ALLOYDB_POSTGRES_CLUSTER,
|
||||
"instance": ALLOYDB_POSTGRES_INSTANCE,
|
||||
"region": ALLOYDB_POSTGRES_REGION,
|
||||
"database": ALLOYDB_POSTGRES_DATABASE,
|
||||
"password": "random",
|
||||
}
|
||||
|
||||
noUserNoPassSourceConfig := map[string]any{
|
||||
"kind": ALLOYDB_POSTGRES_SOURCE_KIND,
|
||||
"project": ALLOYDB_POSTGRES_PROJECT,
|
||||
"cluster": ALLOYDB_POSTGRES_CLUSTER,
|
||||
"instance": ALLOYDB_POSTGRES_INSTANCE,
|
||||
"region": ALLOYDB_POSTGRES_REGION,
|
||||
"database": ALLOYDB_POSTGRES_DATABASE,
|
||||
}
|
||||
tcs := []struct {
|
||||
name string
|
||||
sourceConfig map[string]any
|
||||
isErr bool
|
||||
}{
|
||||
{
|
||||
name: "no user no pass",
|
||||
sourceConfig: noUserNoPassSourceConfig,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "no password",
|
||||
sourceConfig: noPassSourceConfig,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "no user",
|
||||
sourceConfig: noUserSourceConfig,
|
||||
isErr: true,
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := RunSourceConnectionTest(t, tc.sourceConfig, ALLOYDB_POSTGRES_TOOL_KIND)
|
||||
if err != nil {
|
||||
if !tc.isErr {
|
||||
t.Fatalf("Connection test failure: %s", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -175,7 +175,10 @@ func TestCloudSQLMssqlIpConnection(t *testing.T) {
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
sourceConfig["ipType"] = tc.ipType
|
||||
RunSourceConnectionTest(t, sourceConfig, CLOUD_SQL_MSSQL_TOOL_KIND)
|
||||
err := RunSourceConnectionTest(t, sourceConfig, CLOUD_SQL_MSSQL_TOOL_KIND)
|
||||
if err != nil {
|
||||
t.Fatalf("Connection test failure: %s", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -170,7 +170,10 @@ func TestCloudSQLMysqlIpConnection(t *testing.T) {
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
sourceConfig["ipType"] = tc.ipType
|
||||
RunSourceConnectionTest(t, sourceConfig, CLOUD_SQL_MYSQL_TOOL_KIND)
|
||||
err := RunSourceConnectionTest(t, sourceConfig, CLOUD_SQL_MYSQL_TOOL_KIND)
|
||||
if err != nil {
|
||||
t.Fatalf("Connection test failure: %s", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -174,7 +174,10 @@ func TestCloudSQLPgIpConnection(t *testing.T) {
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
sourceConfig["ipType"] = tc.ipType
|
||||
RunSourceConnectionTest(t, sourceConfig, CLOUD_SQL_POSTGRES_TOOL_KIND)
|
||||
err := RunSourceConnectionTest(t, sourceConfig, CLOUD_SQL_POSTGRES_TOOL_KIND)
|
||||
if err != nil {
|
||||
t.Fatalf("Connection test failure: %s", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ import (
|
||||
)
|
||||
|
||||
// RunSourceConnection test for source connection
|
||||
func RunSourceConnectionTest(t *testing.T, sourceConfig map[string]any, toolKind string) {
|
||||
func RunSourceConnectionTest(t *testing.T, sourceConfig map[string]any, toolKind string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
@@ -53,7 +53,7 @@ func RunSourceConnectionTest(t *testing.T, sourceConfig map[string]any, toolKind
|
||||
}
|
||||
cmd, cleanup, err := StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
return fmt.Errorf("command initialization returned an error: %s", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
@@ -62,8 +62,9 @@ func RunSourceConnectionTest(t *testing.T, sourceConfig map[string]any, toolKind
|
||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
return fmt.Errorf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCloudSQLDialOpts returns cloud sql connector's dial option for ip type.
|
||||
|
||||
Reference in New Issue
Block a user