feat(source/cloudsql-pg): add configuration for public and private IP (#114)

Allow user to set if their database uses private or public ip. The
reason we add this is because the dialer require different
initialization with private and public ip.

By default, toolbox will use public ip.
This commit is contained in:
Yuan
2024-12-05 16:08:15 -08:00
committed by GitHub
parent e815dc49f4
commit 6479c1dbe2
6 changed files with 94 additions and 18 deletions

View File

@@ -271,6 +271,7 @@ func TestParseToolFile(t *testing.T) {
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
IPType: "public",
Database: "my_db",
},
},

View File

@@ -121,13 +121,13 @@ func (c *SourceConfigs) UnmarshalYAML(node *yaml.Node) error {
}
switch k.Kind {
case alloydbpgsrc.SourceKind:
actual := alloydbpgsrc.Config{Name: name, IP_type: "public"}
actual := alloydbpgsrc.Config{Name: name, IPType: "public"}
if err := n.Decode(&actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
}
(*c)[name] = actual
case cloudsqlpgsrc.SourceKind:
actual := cloudsqlpgsrc.Config{Name: name}
actual := cloudsqlpgsrc.Config{Name: name, IPType: "public"}
if err := n.Decode(&actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
}

View File

@@ -37,7 +37,7 @@ type Config struct {
Region string `yaml:"region"`
Cluster string `yaml:"cluster"`
Instance string `yaml:"instance"`
IP_type sources.IPType `yaml:"ip_type"`
IPType sources.IPType `yaml:"ip_type"`
User string `yaml:"user"`
Password string `yaml:"password"`
Database string `yaml:"database"`
@@ -48,7 +48,7 @@ func (r Config) SourceConfigKind() string {
}
func (r Config) Initialize() (sources.Source, error) {
pool, err := initAlloyDBPgConnectionPool(r.Project, r.Region, r.Cluster, r.Instance, r.IP_type.String(), r.User, r.Password, r.Database)
pool, err := initAlloyDBPgConnectionPool(r.Project, r.Region, r.Cluster, r.Instance, r.IPType.String(), r.User, r.Password, r.Database)
if err != nil {
return nil, fmt.Errorf("unable to create pool: %w", err)
}

View File

@@ -51,7 +51,7 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
Region: "my-region",
Cluster: "my-cluster",
Instance: "my-instance",
IP_type: "public",
IPType: "public",
Database: "my_db",
},
},
@@ -77,7 +77,7 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
Region: "my-region",
Cluster: "my-cluster",
Instance: "my-instance",
IP_type: "private",
IPType: "private",
Database: "my_db",
},
},

View File

@@ -18,6 +18,7 @@ import (
"context"
"fmt"
"net"
"strings"
"cloud.google.com/go/cloudsqlconn"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -30,14 +31,15 @@ const SourceKind string = "cloud-sql-postgres"
var _ sources.SourceConfig = Config{}
type Config struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Project string `yaml:"project"`
Region string `yaml:"region"`
Instance string `yaml:"instance"`
User string `yaml:"user"`
Password string `yaml:"password"`
Database string `yaml:"database"`
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Project string `yaml:"project"`
Region string `yaml:"region"`
Instance string `yaml:"instance"`
IPType sources.IPType `yaml:"ip_type"`
User string `yaml:"user"`
Password string `yaml:"password"`
Database string `yaml:"database"`
}
func (r Config) SourceConfigKind() string {
@@ -45,7 +47,7 @@ func (r Config) SourceConfigKind() string {
}
func (r Config) Initialize() (sources.Source, error) {
pool, err := initCloudSQLPgConnectionPool(r.Project, r.Region, r.Instance, r.User, r.Password, r.Database)
pool, err := initCloudSQLPgConnectionPool(r.Project, r.Region, r.Instance, r.IPType.String(), r.User, r.Password, r.Database)
if err != nil {
return nil, fmt.Errorf("unable to create pool: %w", err)
}
@@ -79,7 +81,18 @@ func (s *Source) PostgresPool() *pgxpool.Pool {
return s.Pool
}
func initCloudSQLPgConnectionPool(project, region, instance, user, pass, dbname string) (*pgxpool.Pool, error) {
func getDialOpts(ip_type string) ([]cloudsqlconn.DialOption, error) {
switch strings.ToLower(ip_type) {
case "private":
return []cloudsqlconn.DialOption{cloudsqlconn.WithPrivateIP()}, nil
case "public":
return []cloudsqlconn.DialOption{cloudsqlconn.WithPublicIP()}, nil
default:
return nil, fmt.Errorf("invalid ip_type %s", ip_type)
}
}
func initCloudSQLPgConnectionPool(project, region, instance, ip_type, user, pass, dbname string) (*pgxpool.Pool, error) {
// Configure the driver to connect to the database
dsn := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=disable", user, pass, dbname)
config, err := pgxpool.ParseConfig(dsn)
@@ -87,8 +100,12 @@ func initCloudSQLPgConnectionPool(project, region, instance, user, pass, dbname
return nil, fmt.Errorf("unable to parse connection uri: %w", err)
}
// Create a new dialer with any options
d, err := cloudsqlconn.NewDialer(context.Background())
// Create a new dialer with options
dialOpts, err := getDialOpts(ip_type)
if err != nil {
return nil, err
}
d, err := cloudsqlconn.NewDialer(context.Background(), cloudsqlconn.WithDefaultDialOptions(dialOpts...))
if err != nil {
return nil, fmt.Errorf("unable to parse connection uri: %w", err)
}

View File

@@ -48,6 +48,31 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
IPType: "public",
Database: "my_db",
},
},
},
{
desc: "basic example",
in: `
sources:
my-pg-instance:
kind: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
ip_type: private
database: my_db
`,
want: server.SourceConfigs{
"my-pg-instance": cloudsqlpg.Config{
Name: "my-pg-instance",
Kind: cloudsqlpg.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
IPType: "private",
Database: "my_db",
},
},
@@ -70,3 +95,36 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
}
}
func FailParseFromYamlCloudSQLPg(t *testing.T) {
tcs := []struct {
desc string
in string
}{
{
desc: "invalid ip_type",
in: `
sources:
my-pg-instance:
kind: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
ip_type: fail
database: my_db
`,
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail: %s", err)
}
})
}
}