diff --git a/docs/en/resources/sources/cloud-sql-mysql.md b/docs/en/resources/sources/cloud-sql-mysql.md index 93b2f71b41..188bcbce26 100644 --- a/docs/en/resources/sources/cloud-sql-mysql.md +++ b/docs/en/resources/sources/cloud-sql-mysql.md @@ -88,13 +88,40 @@ mTLS. [public-ip]: https://cloud.google.com/sql/docs/mysql/configure-ip [conn-overview]: https://cloud.google.com/sql/docs/mysql/connect-overview -### Database User +### Authentication -Currently, this source only uses standard authentication. You will need to [create -a MySQL user][cloud-sql-users] to login to the database with. +This source supports both password-based authentication and IAM +authentication (using your [Application Default Credentials][adc]). + +#### Standard Authentication + +To connect using user/password, [create +a MySQL user][cloud-sql-users] and input your credentials in the `user` and +`password` fields. + +```yaml +user: ${USER_NAME} +password: ${PASSWORD} +``` [cloud-sql-users]: https://cloud.google.com/sql/docs/mysql/create-manage-users +#### IAM Authentication + +To connect using IAM authentication: + +1. Prepare your database instance and user following this [guide][iam-guide]. +2. You could choose one of the two ways to log in: + - Specify your IAM email as the `user`. + - Leave your `user` field blank. Toolbox will fetch the [ADC][adc] + automatically and log in using the email associated with it. + +3. Leave the `password` field blank. + +[iam-guide]: https://cloud.google.com/sql/docs/mysql/iam-logins +[cloudsql-users]: https://cloud.google.com/sql/docs/mysql/create-manage-users + + ## Example ```yaml @@ -124,6 +151,6 @@ instead of hardcoding your secrets into the configuration file. | region | string | true | Name of the GCP region that the cluster was created in (e.g. "us-central1"). | | instance | string | true | Name of the Cloud SQL instance within the cluster (e.g. "my-instance"). | | database | string | true | Name of the MySQL database to connect to (e.g. "my_db"). | -| user | string | true | Name of the MySQL user to connect as (e.g. "my-pg-user"). | -| password | string | true | Password of the MySQL user (e.g. "my-password"). | +| user | string | false | Name of the MySQL user to connect as (e.g "my-mysql-user"). Defaults to IAM auth using [ADC][adc] email if unspecified. | +| password | string | false | Password of the MySQL user (e.g. "my-password"). Defaults to attempting IAM authentication if unspecified. | | ipType | string | false | IP Type of the Cloud SQL instance, must be either `public`, `private`, or `psc`. Default: `public`. | diff --git a/internal/sources/alloydbpg/alloydb_pg.go b/internal/sources/alloydbpg/alloydb_pg.go index edf4310720..a5a7cb03aa 100644 --- a/internal/sources/alloydbpg/alloydb_pg.go +++ b/internal/sources/alloydbpg/alloydb_pg.go @@ -141,7 +141,7 @@ func getConnectionConfig(ctx context.Context, user, pass, dbname string) (string // If password is provided without an username, raise an error return "", useIAM, fmt.Errorf("password is provided without a username. Please provide both a username and password, or leave both fields empty") } - email, err := sources.GetIAMPrincipalEmailFromADC(ctx) + email, err := sources.GetIAMPrincipalEmailFromADC(ctx, "postgres") if err != nil { return "", useIAM, fmt.Errorf("error getting email from ADC: %v", err) } diff --git a/internal/sources/cloudsqlmysql/cloud_sql_mysql.go b/internal/sources/cloudsqlmysql/cloud_sql_mysql.go index 4bdee7f3a0..797985454b 100644 --- a/internal/sources/cloudsqlmysql/cloud_sql_mysql.go +++ b/internal/sources/cloudsqlmysql/cloud_sql_mysql.go @@ -54,8 +54,8 @@ type Config struct { Region string `yaml:"region" validate:"required"` Instance string `yaml:"instance" validate:"required"` IPType sources.IPType `yaml:"ipType"` - User string `yaml:"user" validate:"required"` - Password string `yaml:"password" validate:"required"` + User string `yaml:"user"` + Password string `yaml:"password"` Database string `yaml:"database" validate:"required"` } @@ -100,31 +100,89 @@ func (s *Source) MySQLPool() *sql.DB { return s.Pool } +func getConnectionConfig(ctx context.Context, user, pass string) (string, string, bool, error) { + useIAM := true + + // If username and password both provided, use password authentication + if user != "" && pass != "" { + useIAM = false + return user, pass, useIAM, nil + } + + // If username is empty, fetch email from ADC + // otherwise, use username as IAM email + if user == "" { + if pass != "" { + return "", "", useIAM, fmt.Errorf("password is provided without a username. Please provide both a username and password, or leave both fields empty") + } + email, err := sources.GetIAMPrincipalEmailFromADC(ctx, "mysql") + if err != nil { + return "", "", useIAM, fmt.Errorf("error getting email from ADC: %v", err) + } + user = email + } + + // Pass the user, empty password and useIAM set to true + return user, pass, useIAM, nil +} + func initCloudSQLMySQLConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*sql.DB, error) { //nolint:all // Reassigned ctx ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) defer span.End() + // Configure the driver to connect to the database + user, pass, useIAM, err := getConnectionConfig(ctx, user, pass) + if err != nil { + return nil, fmt.Errorf("unable to get Cloud SQL connection config: %w", err) + } + // Create a new dialer with options userAgent, err := util.UserAgentFromContext(ctx) if err != nil { return nil, err } - opts, err := sources.GetCloudSQLOpts(ipType, userAgent, false) + opts, err := sources.GetCloudSQLOpts(ipType, userAgent, useIAM) if err != nil { return nil, err } - if !slices.Contains(sql.Drivers(), "cloudsql-mysql") { - _, err = mysql.RegisterDriver("cloudsql-mysql", opts...) - if err != nil { + // Use a unique driver name based on the source name. + driverName := fmt.Sprintf("cloudsql-mysql-%s", name) + + if !slices.Contains(sql.Drivers(), driverName) { + if _, err := mysql.RegisterDriver(driverName, opts...); err != nil { return nil, fmt.Errorf("unable to register driver: %w", err) } } + + var dsn string // Tell the driver to use the Cloud SQL Go Connector to create connections - dsn := fmt.Sprintf("%s:%s@cloudsql-mysql(%s:%s:%s)/%s?connectionAttributes=program_name:%s", user, pass, project, region, instance, dbname, url.QueryEscape(userAgent)) + if useIAM { + dsn = fmt.Sprintf("%s@%s(%s:%s:%s)/%s?connectionAttributes=program_name:%s", + user, + driverName, + project, + region, + instance, + dbname, + url.QueryEscape(userAgent), + ) + } else { + dsn = fmt.Sprintf("%s:%s@%s(%s:%s:%s)/%s?connectionAttributes=program_name:%s", + user, + pass, + driverName, + project, + region, + instance, + dbname, + url.QueryEscape(userAgent), + ) + } + db, err := sql.Open( - "cloudsql-mysql", + driverName, dsn, ) if err != nil { diff --git a/internal/sources/cloudsqlpg/cloud_sql_pg.go b/internal/sources/cloudsqlpg/cloud_sql_pg.go index f13c67d5d0..3de83993bb 100644 --- a/internal/sources/cloudsqlpg/cloud_sql_pg.go +++ b/internal/sources/cloudsqlpg/cloud_sql_pg.go @@ -120,7 +120,7 @@ func getConnectionConfig(ctx context.Context, user, pass, dbname string) (string // If password is provided without an username, raise an error return "", useIAM, fmt.Errorf("password is provided without a username. Please provide both a username and password, or leave both fields empty") } - email, err := sources.GetIAMPrincipalEmailFromADC(ctx) + email, err := sources.GetIAMPrincipalEmailFromADC(ctx, "postgres") if err != nil { return "", useIAM, fmt.Errorf("error getting email from ADC: %v", err) } diff --git a/internal/sources/util.go b/internal/sources/util.go index 0a78c1b2a6..d2b2210ddd 100644 --- a/internal/sources/util.go +++ b/internal/sources/util.go @@ -48,7 +48,7 @@ func GetCloudSQLOpts(ipType, userAgent string, useIAM bool) ([]cloudsqlconn.Opti } // GetIAMPrincipalEmailFromADC finds the email associated with ADC -func GetIAMPrincipalEmailFromADC(ctx context.Context) (string, error) { +func GetIAMPrincipalEmailFromADC(ctx context.Context, dbType string) (string, error) { // Finds ADC and returns an HTTP client associated with it client, err := google.DefaultClient(ctx, "https://www.googleapis.com/auth/userinfo.email") @@ -83,9 +83,31 @@ func GetIAMPrincipalEmailFromADC(ctx context.Context) (string, error) { if !ok { return "", fmt.Errorf("email not found in response: %v", err) } - // service account email used for IAM should trim the suffix - email := strings.TrimSuffix(emailValue.(string), ".gserviceaccount.com") - return email, nil + + fullEmail, ok := emailValue.(string) + if !ok { + return "", fmt.Errorf("email field is not a string") + } + + var username string + // Format the username based on Database Type + switch strings.ToLower(dbType) { + case "mysql": + username, _, _ = strings.Cut(fullEmail, "@") + + case "postgres": + // service account email used for IAM should trim the suffix + username = strings.TrimSuffix(fullEmail, ".gserviceaccount.com") + + default: + return "", fmt.Errorf("unsupported dbType: %s. Use 'mysql' or 'postgres'", dbType) + } + + if username == "" { + return "", fmt.Errorf("username from ADC cannot be an empty string") + } + + return username, nil } func GetIAMAccessToken(ctx context.Context) (string, error) { diff --git a/tests/cloudsqlmysql/cloud_sql_mysql_integration_test.go b/tests/cloudsqlmysql/cloud_sql_mysql_integration_test.go index 559ae5fb06..192c779ea9 100644 --- a/tests/cloudsqlmysql/cloud_sql_mysql_integration_test.go +++ b/tests/cloudsqlmysql/cloud_sql_mysql_integration_test.go @@ -192,3 +192,104 @@ func TestCloudSQLMySQLIpConnection(t *testing.T) { }) } } + +func TestCloudSQLMySQLIAMConnection(t *testing.T) { + getCloudSQLMySQLVars(t) + // service account email used for IAM should trim the suffix + serviceAccountEmail, _, _ := strings.Cut(tests.ServiceAccountEmail, "@") + + noPassSourceConfig := map[string]any{ + "kind": CloudSQLMySQLSourceKind, + "project": CloudSQLMySQLProject, + "instance": CloudSQLMySQLInstance, + "region": CloudSQLMySQLRegion, + "database": CloudSQLMySQLDatabase, + "user": serviceAccountEmail, + } + noUserSourceConfig := map[string]any{ + "kind": CloudSQLMySQLSourceKind, + "project": CloudSQLMySQLProject, + "instance": CloudSQLMySQLInstance, + "region": CloudSQLMySQLRegion, + "database": CloudSQLMySQLDatabase, + "password": "random", + } + noUserNoPassSourceConfig := map[string]any{ + "kind": CloudSQLMySQLSourceKind, + "project": CloudSQLMySQLProject, + "instance": CloudSQLMySQLInstance, + "region": CloudSQLMySQLRegion, + "database": CloudSQLMySQLDatabase, + } + tcs := []struct { + name string + sourceConfig map[string]any + isErr bool + }{ + { + name: "no user no pass", + sourceConfig: noUserNoPassSourceConfig, + isErr: false, + }, + { + name: "no password", + sourceConfig: noPassSourceConfig, + isErr: false, + }, + { + name: "no user", + sourceConfig: noUserSourceConfig, + isErr: true, + }, + } + for i, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + // Generate a UNIQUE source name for this test case. + // It ensures the app registers a unique driver name + // like "cloudsql-mysql-iam-test-0", preventing conflicts. + uniqueSourceName := fmt.Sprintf("iam-test-%d", i) + + // Construct the tools config manually (Copied from RunSourceConnectionTest) + toolsFile := map[string]any{ + "sources": map[string]any{ + uniqueSourceName: tc.sourceConfig, + }, + "tools": map[string]any{ + "my-simple-tool": map[string]any{ + "kind": CloudSQLMySQLToolKind, + "source": uniqueSourceName, + "description": "Simple tool to test end to end functionality.", + "statement": "SELECT 1;", + }, + }, + } + + // Start the Toolbox Command + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + // Wait for the server to be ready + waitCtx, waitCancel := context.WithTimeout(ctx, 10*time.Second) + defer waitCancel() + + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + if tc.isErr { + return + } + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("Connection test failure: toolbox didn't start successfully: %s", err) + } + + if tc.isErr { + t.Fatalf("Expected error but test passed.") + } + }) + } +}