mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-10 07:58:12 -05:00
feat(source/cloudsql-pg): add configuration for public and private IP (#114)
Allow user to set if their database uses private or public ip. The reason we add this is because the dialer require different initialization with private and public ip. By default, toolbox will use public ip.
This commit is contained in:
@@ -271,6 +271,7 @@ func TestParseToolFile(t *testing.T) {
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
IPType: "public",
|
||||
Database: "my_db",
|
||||
},
|
||||
},
|
||||
|
||||
@@ -121,13 +121,13 @@ func (c *SourceConfigs) UnmarshalYAML(node *yaml.Node) error {
|
||||
}
|
||||
switch k.Kind {
|
||||
case alloydbpgsrc.SourceKind:
|
||||
actual := alloydbpgsrc.Config{Name: name, IP_type: "public"}
|
||||
actual := alloydbpgsrc.Config{Name: name, IPType: "public"}
|
||||
if err := n.Decode(&actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
case cloudsqlpgsrc.SourceKind:
|
||||
actual := cloudsqlpgsrc.Config{Name: name}
|
||||
actual := cloudsqlpgsrc.Config{Name: name, IPType: "public"}
|
||||
if err := n.Decode(&actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
|
||||
@@ -37,7 +37,7 @@ type Config struct {
|
||||
Region string `yaml:"region"`
|
||||
Cluster string `yaml:"cluster"`
|
||||
Instance string `yaml:"instance"`
|
||||
IP_type sources.IPType `yaml:"ip_type"`
|
||||
IPType sources.IPType `yaml:"ip_type"`
|
||||
User string `yaml:"user"`
|
||||
Password string `yaml:"password"`
|
||||
Database string `yaml:"database"`
|
||||
@@ -48,7 +48,7 @@ func (r Config) SourceConfigKind() string {
|
||||
}
|
||||
|
||||
func (r Config) Initialize() (sources.Source, error) {
|
||||
pool, err := initAlloyDBPgConnectionPool(r.Project, r.Region, r.Cluster, r.Instance, r.IP_type.String(), r.User, r.Password, r.Database)
|
||||
pool, err := initAlloyDBPgConnectionPool(r.Project, r.Region, r.Cluster, r.Instance, r.IPType.String(), r.User, r.Password, r.Database)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create pool: %w", err)
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
|
||||
Region: "my-region",
|
||||
Cluster: "my-cluster",
|
||||
Instance: "my-instance",
|
||||
IP_type: "public",
|
||||
IPType: "public",
|
||||
Database: "my_db",
|
||||
},
|
||||
},
|
||||
@@ -77,7 +77,7 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
|
||||
Region: "my-region",
|
||||
Cluster: "my-cluster",
|
||||
Instance: "my-instance",
|
||||
IP_type: "private",
|
||||
IPType: "private",
|
||||
Database: "my_db",
|
||||
},
|
||||
},
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"cloud.google.com/go/cloudsqlconn"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
@@ -30,14 +31,15 @@ const SourceKind string = "cloud-sql-postgres"
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Project string `yaml:"project"`
|
||||
Region string `yaml:"region"`
|
||||
Instance string `yaml:"instance"`
|
||||
User string `yaml:"user"`
|
||||
Password string `yaml:"password"`
|
||||
Database string `yaml:"database"`
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Project string `yaml:"project"`
|
||||
Region string `yaml:"region"`
|
||||
Instance string `yaml:"instance"`
|
||||
IPType sources.IPType `yaml:"ip_type"`
|
||||
User string `yaml:"user"`
|
||||
Password string `yaml:"password"`
|
||||
Database string `yaml:"database"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
@@ -45,7 +47,7 @@ func (r Config) SourceConfigKind() string {
|
||||
}
|
||||
|
||||
func (r Config) Initialize() (sources.Source, error) {
|
||||
pool, err := initCloudSQLPgConnectionPool(r.Project, r.Region, r.Instance, r.User, r.Password, r.Database)
|
||||
pool, err := initCloudSQLPgConnectionPool(r.Project, r.Region, r.Instance, r.IPType.String(), r.User, r.Password, r.Database)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create pool: %w", err)
|
||||
}
|
||||
@@ -79,7 +81,18 @@ func (s *Source) PostgresPool() *pgxpool.Pool {
|
||||
return s.Pool
|
||||
}
|
||||
|
||||
func initCloudSQLPgConnectionPool(project, region, instance, user, pass, dbname string) (*pgxpool.Pool, error) {
|
||||
func getDialOpts(ip_type string) ([]cloudsqlconn.DialOption, error) {
|
||||
switch strings.ToLower(ip_type) {
|
||||
case "private":
|
||||
return []cloudsqlconn.DialOption{cloudsqlconn.WithPrivateIP()}, nil
|
||||
case "public":
|
||||
return []cloudsqlconn.DialOption{cloudsqlconn.WithPublicIP()}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid ip_type %s", ip_type)
|
||||
}
|
||||
}
|
||||
|
||||
func initCloudSQLPgConnectionPool(project, region, instance, ip_type, user, pass, dbname string) (*pgxpool.Pool, error) {
|
||||
// Configure the driver to connect to the database
|
||||
dsn := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=disable", user, pass, dbname)
|
||||
config, err := pgxpool.ParseConfig(dsn)
|
||||
@@ -87,8 +100,12 @@ func initCloudSQLPgConnectionPool(project, region, instance, user, pass, dbname
|
||||
return nil, fmt.Errorf("unable to parse connection uri: %w", err)
|
||||
}
|
||||
|
||||
// Create a new dialer with any options
|
||||
d, err := cloudsqlconn.NewDialer(context.Background())
|
||||
// Create a new dialer with options
|
||||
dialOpts, err := getDialOpts(ip_type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d, err := cloudsqlconn.NewDialer(context.Background(), cloudsqlconn.WithDefaultDialOptions(dialOpts...))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse connection uri: %w", err)
|
||||
}
|
||||
|
||||
@@ -48,6 +48,31 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
IPType: "public",
|
||||
Database: "my_db",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
ip_type: private
|
||||
database: my_db
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-pg-instance": cloudsqlpg.Config{
|
||||
Name: "my-pg-instance",
|
||||
Kind: cloudsqlpg.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
IPType: "private",
|
||||
Database: "my_db",
|
||||
},
|
||||
},
|
||||
@@ -70,3 +95,36 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func FailParseFromYamlCloudSQLPg(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
}{
|
||||
{
|
||||
desc: "invalid ip_type",
|
||||
in: `
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
ip_type: fail
|
||||
database: my_db
|
||||
`,
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail: %s", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user