mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-07 22:54:06 -05:00
feat(source/cloudsqlmysql): add support for IAM authentication in Cloud SQL MySQL source (#2050)
## Description This PR adds the support for IAM authentication in the Cloud SQL MySQL source ## PR Checklist > Thank you for opening a Pull Request! Before submitting your PR, there are a > few things you can do to make sure it goes smoothly: - [x] Make sure you reviewed [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md) - [x] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [x] Ensure the tests and linter pass - [x] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) - [x] Make sure to add `!` if this involve a breaking change 🛠️ Fixes #<issue_number_goes_here>
This commit is contained in:
@@ -88,13 +88,40 @@ mTLS.
|
||||
[public-ip]: https://cloud.google.com/sql/docs/mysql/configure-ip
|
||||
[conn-overview]: https://cloud.google.com/sql/docs/mysql/connect-overview
|
||||
|
||||
### Database User
|
||||
### Authentication
|
||||
|
||||
Currently, this source only uses standard authentication. You will need to [create
|
||||
a MySQL user][cloud-sql-users] to login to the database with.
|
||||
This source supports both password-based authentication and IAM
|
||||
authentication (using your [Application Default Credentials][adc]).
|
||||
|
||||
#### Standard Authentication
|
||||
|
||||
To connect using user/password, [create
|
||||
a MySQL user][cloud-sql-users] and input your credentials in the `user` and
|
||||
`password` fields.
|
||||
|
||||
```yaml
|
||||
user: ${USER_NAME}
|
||||
password: ${PASSWORD}
|
||||
```
|
||||
|
||||
[cloud-sql-users]: https://cloud.google.com/sql/docs/mysql/create-manage-users
|
||||
|
||||
#### IAM Authentication
|
||||
|
||||
To connect using IAM authentication:
|
||||
|
||||
1. Prepare your database instance and user following this [guide][iam-guide].
|
||||
2. You could choose one of the two ways to log in:
|
||||
- Specify your IAM email as the `user`.
|
||||
- Leave your `user` field blank. Toolbox will fetch the [ADC][adc]
|
||||
automatically and log in using the email associated with it.
|
||||
|
||||
3. Leave the `password` field blank.
|
||||
|
||||
[iam-guide]: https://cloud.google.com/sql/docs/mysql/iam-logins
|
||||
[cloudsql-users]: https://cloud.google.com/sql/docs/mysql/create-manage-users
|
||||
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
@@ -124,6 +151,6 @@ instead of hardcoding your secrets into the configuration file.
|
||||
| region | string | true | Name of the GCP region that the cluster was created in (e.g. "us-central1"). |
|
||||
| instance | string | true | Name of the Cloud SQL instance within the cluster (e.g. "my-instance"). |
|
||||
| database | string | true | Name of the MySQL database to connect to (e.g. "my_db"). |
|
||||
| user | string | true | Name of the MySQL user to connect as (e.g. "my-pg-user"). |
|
||||
| password | string | true | Password of the MySQL user (e.g. "my-password"). |
|
||||
| user | string | false | Name of the MySQL user to connect as (e.g "my-mysql-user"). Defaults to IAM auth using [ADC][adc] email if unspecified. |
|
||||
| password | string | false | Password of the MySQL user (e.g. "my-password"). Defaults to attempting IAM authentication if unspecified. |
|
||||
| ipType | string | false | IP Type of the Cloud SQL instance, must be either `public`, `private`, or `psc`. Default: `public`. |
|
||||
|
||||
@@ -141,7 +141,7 @@ func getConnectionConfig(ctx context.Context, user, pass, dbname string) (string
|
||||
// If password is provided without an username, raise an error
|
||||
return "", useIAM, fmt.Errorf("password is provided without a username. Please provide both a username and password, or leave both fields empty")
|
||||
}
|
||||
email, err := sources.GetIAMPrincipalEmailFromADC(ctx)
|
||||
email, err := sources.GetIAMPrincipalEmailFromADC(ctx, "postgres")
|
||||
if err != nil {
|
||||
return "", useIAM, fmt.Errorf("error getting email from ADC: %v", err)
|
||||
}
|
||||
|
||||
@@ -54,8 +54,8 @@ type Config struct {
|
||||
Region string `yaml:"region" validate:"required"`
|
||||
Instance string `yaml:"instance" validate:"required"`
|
||||
IPType sources.IPType `yaml:"ipType"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
Password string `yaml:"password" validate:"required"`
|
||||
User string `yaml:"user"`
|
||||
Password string `yaml:"password"`
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
}
|
||||
|
||||
@@ -100,31 +100,89 @@ func (s *Source) MySQLPool() *sql.DB {
|
||||
return s.Pool
|
||||
}
|
||||
|
||||
func getConnectionConfig(ctx context.Context, user, pass string) (string, string, bool, error) {
|
||||
useIAM := true
|
||||
|
||||
// If username and password both provided, use password authentication
|
||||
if user != "" && pass != "" {
|
||||
useIAM = false
|
||||
return user, pass, useIAM, nil
|
||||
}
|
||||
|
||||
// If username is empty, fetch email from ADC
|
||||
// otherwise, use username as IAM email
|
||||
if user == "" {
|
||||
if pass != "" {
|
||||
return "", "", useIAM, fmt.Errorf("password is provided without a username. Please provide both a username and password, or leave both fields empty")
|
||||
}
|
||||
email, err := sources.GetIAMPrincipalEmailFromADC(ctx, "mysql")
|
||||
if err != nil {
|
||||
return "", "", useIAM, fmt.Errorf("error getting email from ADC: %v", err)
|
||||
}
|
||||
user = email
|
||||
}
|
||||
|
||||
// Pass the user, empty password and useIAM set to true
|
||||
return user, pass, useIAM, nil
|
||||
}
|
||||
|
||||
func initCloudSQLMySQLConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*sql.DB, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
// Configure the driver to connect to the database
|
||||
user, pass, useIAM, err := getConnectionConfig(ctx, user, pass)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get Cloud SQL connection config: %w", err)
|
||||
}
|
||||
|
||||
// Create a new dialer with options
|
||||
userAgent, err := util.UserAgentFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
opts, err := sources.GetCloudSQLOpts(ipType, userAgent, false)
|
||||
opts, err := sources.GetCloudSQLOpts(ipType, userAgent, useIAM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !slices.Contains(sql.Drivers(), "cloudsql-mysql") {
|
||||
_, err = mysql.RegisterDriver("cloudsql-mysql", opts...)
|
||||
if err != nil {
|
||||
// Use a unique driver name based on the source name.
|
||||
driverName := fmt.Sprintf("cloudsql-mysql-%s", name)
|
||||
|
||||
if !slices.Contains(sql.Drivers(), driverName) {
|
||||
if _, err := mysql.RegisterDriver(driverName, opts...); err != nil {
|
||||
return nil, fmt.Errorf("unable to register driver: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var dsn string
|
||||
// Tell the driver to use the Cloud SQL Go Connector to create connections
|
||||
dsn := fmt.Sprintf("%s:%s@cloudsql-mysql(%s:%s:%s)/%s?connectionAttributes=program_name:%s", user, pass, project, region, instance, dbname, url.QueryEscape(userAgent))
|
||||
if useIAM {
|
||||
dsn = fmt.Sprintf("%s@%s(%s:%s:%s)/%s?connectionAttributes=program_name:%s",
|
||||
user,
|
||||
driverName,
|
||||
project,
|
||||
region,
|
||||
instance,
|
||||
dbname,
|
||||
url.QueryEscape(userAgent),
|
||||
)
|
||||
} else {
|
||||
dsn = fmt.Sprintf("%s:%s@%s(%s:%s:%s)/%s?connectionAttributes=program_name:%s",
|
||||
user,
|
||||
pass,
|
||||
driverName,
|
||||
project,
|
||||
region,
|
||||
instance,
|
||||
dbname,
|
||||
url.QueryEscape(userAgent),
|
||||
)
|
||||
}
|
||||
|
||||
db, err := sql.Open(
|
||||
"cloudsql-mysql",
|
||||
driverName,
|
||||
dsn,
|
||||
)
|
||||
if err != nil {
|
||||
|
||||
@@ -120,7 +120,7 @@ func getConnectionConfig(ctx context.Context, user, pass, dbname string) (string
|
||||
// If password is provided without an username, raise an error
|
||||
return "", useIAM, fmt.Errorf("password is provided without a username. Please provide both a username and password, or leave both fields empty")
|
||||
}
|
||||
email, err := sources.GetIAMPrincipalEmailFromADC(ctx)
|
||||
email, err := sources.GetIAMPrincipalEmailFromADC(ctx, "postgres")
|
||||
if err != nil {
|
||||
return "", useIAM, fmt.Errorf("error getting email from ADC: %v", err)
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ func GetCloudSQLOpts(ipType, userAgent string, useIAM bool) ([]cloudsqlconn.Opti
|
||||
}
|
||||
|
||||
// GetIAMPrincipalEmailFromADC finds the email associated with ADC
|
||||
func GetIAMPrincipalEmailFromADC(ctx context.Context) (string, error) {
|
||||
func GetIAMPrincipalEmailFromADC(ctx context.Context, dbType string) (string, error) {
|
||||
// Finds ADC and returns an HTTP client associated with it
|
||||
client, err := google.DefaultClient(ctx,
|
||||
"https://www.googleapis.com/auth/userinfo.email")
|
||||
@@ -83,9 +83,31 @@ func GetIAMPrincipalEmailFromADC(ctx context.Context) (string, error) {
|
||||
if !ok {
|
||||
return "", fmt.Errorf("email not found in response: %v", err)
|
||||
}
|
||||
// service account email used for IAM should trim the suffix
|
||||
email := strings.TrimSuffix(emailValue.(string), ".gserviceaccount.com")
|
||||
return email, nil
|
||||
|
||||
fullEmail, ok := emailValue.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("email field is not a string")
|
||||
}
|
||||
|
||||
var username string
|
||||
// Format the username based on Database Type
|
||||
switch strings.ToLower(dbType) {
|
||||
case "mysql":
|
||||
username, _, _ = strings.Cut(fullEmail, "@")
|
||||
|
||||
case "postgres":
|
||||
// service account email used for IAM should trim the suffix
|
||||
username = strings.TrimSuffix(fullEmail, ".gserviceaccount.com")
|
||||
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported dbType: %s. Use 'mysql' or 'postgres'", dbType)
|
||||
}
|
||||
|
||||
if username == "" {
|
||||
return "", fmt.Errorf("username from ADC cannot be an empty string")
|
||||
}
|
||||
|
||||
return username, nil
|
||||
}
|
||||
|
||||
func GetIAMAccessToken(ctx context.Context) (string, error) {
|
||||
|
||||
@@ -192,3 +192,104 @@ func TestCloudSQLMySQLIpConnection(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloudSQLMySQLIAMConnection(t *testing.T) {
|
||||
getCloudSQLMySQLVars(t)
|
||||
// service account email used for IAM should trim the suffix
|
||||
serviceAccountEmail, _, _ := strings.Cut(tests.ServiceAccountEmail, "@")
|
||||
|
||||
noPassSourceConfig := map[string]any{
|
||||
"kind": CloudSQLMySQLSourceKind,
|
||||
"project": CloudSQLMySQLProject,
|
||||
"instance": CloudSQLMySQLInstance,
|
||||
"region": CloudSQLMySQLRegion,
|
||||
"database": CloudSQLMySQLDatabase,
|
||||
"user": serviceAccountEmail,
|
||||
}
|
||||
noUserSourceConfig := map[string]any{
|
||||
"kind": CloudSQLMySQLSourceKind,
|
||||
"project": CloudSQLMySQLProject,
|
||||
"instance": CloudSQLMySQLInstance,
|
||||
"region": CloudSQLMySQLRegion,
|
||||
"database": CloudSQLMySQLDatabase,
|
||||
"password": "random",
|
||||
}
|
||||
noUserNoPassSourceConfig := map[string]any{
|
||||
"kind": CloudSQLMySQLSourceKind,
|
||||
"project": CloudSQLMySQLProject,
|
||||
"instance": CloudSQLMySQLInstance,
|
||||
"region": CloudSQLMySQLRegion,
|
||||
"database": CloudSQLMySQLDatabase,
|
||||
}
|
||||
tcs := []struct {
|
||||
name string
|
||||
sourceConfig map[string]any
|
||||
isErr bool
|
||||
}{
|
||||
{
|
||||
name: "no user no pass",
|
||||
sourceConfig: noUserNoPassSourceConfig,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "no password",
|
||||
sourceConfig: noPassSourceConfig,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "no user",
|
||||
sourceConfig: noUserSourceConfig,
|
||||
isErr: true,
|
||||
},
|
||||
}
|
||||
for i, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// Generate a UNIQUE source name for this test case.
|
||||
// It ensures the app registers a unique driver name
|
||||
// like "cloudsql-mysql-iam-test-0", preventing conflicts.
|
||||
uniqueSourceName := fmt.Sprintf("iam-test-%d", i)
|
||||
|
||||
// Construct the tools config manually (Copied from RunSourceConnectionTest)
|
||||
toolsFile := map[string]any{
|
||||
"sources": map[string]any{
|
||||
uniqueSourceName: tc.sourceConfig,
|
||||
},
|
||||
"tools": map[string]any{
|
||||
"my-simple-tool": map[string]any{
|
||||
"kind": CloudSQLMySQLToolKind,
|
||||
"source": uniqueSourceName,
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"statement": "SELECT 1;",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Start the Toolbox Command
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
// Wait for the server to be ready
|
||||
waitCtx, waitCancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer waitCancel()
|
||||
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
if err != nil {
|
||||
if tc.isErr {
|
||||
return
|
||||
}
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("Connection test failure: toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
|
||||
if tc.isErr {
|
||||
t.Fatalf("Expected error but test passed.")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user