mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 15:38:08 -05:00
feat(sources/postgres): add support for queryParams (#1047)
Add support for `queryParams` for users that would like to connect with
additional query parameters.
```
sources:
my-pg-source:
kind: postgres
host: 127.0.0.1
port: 5432
database: my_db
user: ${USER_NAME}
password: ${PASSWORD}
queryParams:
sslmode: verify-full
sslrootcert: /tmp/ca.crt
```
`queryParams` will be added as raw query of the database connection url.
Fixes #963
This commit is contained in:
@@ -57,11 +57,12 @@ instead of hardcoding your secrets into the configuration file.
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-----------|:--------:|:------------:|------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "postgres". |
|
||||
| host | string | true | IP address to connect to (e.g. "127.0.0.1") |
|
||||
| port | string | true | Port to connect to (e.g. "5432") |
|
||||
| database | string | true | Name of the Postgres database to connect to (e.g. "my_db"). |
|
||||
| user | string | true | Name of the Postgres user to connect as (e.g. "my-pg-user"). |
|
||||
| password | string | true | Password of the Postgres user (e.g. "my-password"). |
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-------------|:------------------:|:------------:|------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "postgres". |
|
||||
| host | string | true | IP address to connect to (e.g. "127.0.0.1") |
|
||||
| port | string | true | Port to connect to (e.g. "5432") |
|
||||
| database | string | true | Name of the Postgres database to connect to (e.g. "my_db"). |
|
||||
| user | string | true | Name of the Postgres user to connect as (e.g. "my-pg-user"). |
|
||||
| password | string | true | Password of the Postgres user (e.g. "my-password"). |
|
||||
| queryParams | map[string]string | false | Raw query to be added to the db connection string. |
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
@@ -45,13 +46,14 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Host string `yaml:"host" validate:"required"`
|
||||
Port string `yaml:"port" validate:"required"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
Password string `yaml:"password" validate:"required"`
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Host string `yaml:"host" validate:"required"`
|
||||
Port string `yaml:"port" validate:"required"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
Password string `yaml:"password" validate:"required"`
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
QueryParams map[string]string `yaml:"queryParams"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
@@ -59,7 +61,7 @@ func (r Config) SourceConfigKind() string {
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
pool, err := initPostgresConnectionPool(ctx, tracer, r.Name, r.Host, r.Port, r.User, r.Password, r.Database)
|
||||
pool, err := initPostgresConnectionPool(ctx, tracer, r.Name, r.Host, r.Port, r.User, r.Password, r.Database, r.QueryParams)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create pool: %w", err)
|
||||
}
|
||||
@@ -93,17 +95,18 @@ func (s *Source) PostgresPool() *pgxpool.Pool {
|
||||
return s.Pool
|
||||
}
|
||||
|
||||
func initPostgresConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string) (*pgxpool.Pool, error) {
|
||||
func initPostgresConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string, queryParams map[string]string) (*pgxpool.Pool, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
// urlExample := "postgres:dd//username:password@localhost:5432/database_name"
|
||||
url := &url.URL{
|
||||
Scheme: "postgres",
|
||||
User: url.UserPassword(user, pass),
|
||||
Host: fmt.Sprintf("%s:%s", host, port),
|
||||
Path: dbname,
|
||||
Scheme: "postgres",
|
||||
User: url.UserPassword(user, pass),
|
||||
Host: fmt.Sprintf("%s:%s", host, port),
|
||||
Path: dbname,
|
||||
RawQuery: ConvertParamMapToRawQuery(queryParams),
|
||||
}
|
||||
pool, err := pgxpool.New(ctx, url.String())
|
||||
if err != nil {
|
||||
@@ -112,3 +115,11 @@ func initPostgresConnectionPool(ctx context.Context, tracer trace.Tracer, name,
|
||||
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
func ConvertParamMapToRawQuery(queryParams map[string]string) string {
|
||||
queryArray := []string{}
|
||||
for k, v := range queryParams {
|
||||
queryArray = append(queryArray, fmt.Sprintf("%s=%s", k, v))
|
||||
}
|
||||
return strings.Join(queryArray, "&")
|
||||
}
|
||||
|
||||
@@ -15,6 +15,8 @@
|
||||
package postgres_test
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
@@ -54,6 +56,37 @@ func TestParseFromYamlPostgres(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "example with query params",
|
||||
in: `
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: postgres
|
||||
host: my-host
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
queryParams:
|
||||
sslmode: verify-full
|
||||
sslrootcert: /tmp/ca.crt
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-pg-instance": postgres.Config{
|
||||
Name: "my-pg-instance",
|
||||
Kind: postgres.SourceKind,
|
||||
Host: "my-host",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
User: "my_user",
|
||||
Password: "my_pass",
|
||||
QueryParams: map[string]string{
|
||||
"sslmode": "verify-full",
|
||||
"sslrootcert": "/tmp/ca.crt",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
@@ -125,3 +158,45 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertParamMapToRawQuery(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in map[string]string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
desc: "nil param",
|
||||
in: nil,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
desc: "single query param",
|
||||
in: map[string]string{
|
||||
"foo": "bar",
|
||||
},
|
||||
want: "foo=bar",
|
||||
},
|
||||
{
|
||||
desc: "more than one query param",
|
||||
in: map[string]string{
|
||||
"foo": "bar",
|
||||
"hello": "world",
|
||||
},
|
||||
want: "foo=bar&hello=world",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := postgres.ConvertParamMapToRawQuery(tc.in)
|
||||
if strings.Contains(got, "&") {
|
||||
splitGot := strings.Split(got, "&")
|
||||
sort.Strings(splitGot)
|
||||
got = strings.Join(splitGot, "&")
|
||||
}
|
||||
if got != tc.want {
|
||||
t.Fatalf("incorrect conversion: got %s want %s", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user