feat(sources/postgres): add support for queryParams (#1047)

Add support for `queryParams` for users that would like to connect with
additional query parameters.

```
sources:
    my-pg-source:
        kind: postgres
        host: 127.0.0.1
        port: 5432
        database: my_db
        user: ${USER_NAME}
        password: ${PASSWORD}
        queryParams:
            sslmode: verify-full
            sslrootcert: /tmp/ca.crt
```

`queryParams` will be added as raw query of the database connection url.

Fixes #963
This commit is contained in:
Yuan Teoh
2025-08-12 15:34:56 -07:00
committed by GitHub
parent de3429bdf1
commit 7b57251402
3 changed files with 108 additions and 21 deletions

View File

@@ -57,11 +57,12 @@ instead of hardcoding your secrets into the configuration file.
## Reference
| **field** | **type** | **required** | **description** |
|-----------|:--------:|:------------:|------------------------------------------------------------------------|
| kind | string | true | Must be "postgres". |
| host | string | true | IP address to connect to (e.g. "127.0.0.1") |
| port | string | true | Port to connect to (e.g. "5432") |
| database | string | true | Name of the Postgres database to connect to (e.g. "my_db"). |
| user | string | true | Name of the Postgres user to connect as (e.g. "my-pg-user"). |
| password | string | true | Password of the Postgres user (e.g. "my-password"). |
| **field** | **type** | **required** | **description** |
|-------------|:------------------:|:------------:|------------------------------------------------------------------------|
| kind | string | true | Must be "postgres". |
| host | string | true | IP address to connect to (e.g. "127.0.0.1") |
| port | string | true | Port to connect to (e.g. "5432") |
| database | string | true | Name of the Postgres database to connect to (e.g. "my_db"). |
| user | string | true | Name of the Postgres user to connect as (e.g. "my-pg-user"). |
| password | string | true | Password of the Postgres user (e.g. "my-password"). |
| queryParams | map[string]string | false | Raw query to be added to the db connection string. |

View File

@@ -18,6 +18,7 @@ import (
"context"
"fmt"
"net/url"
"strings"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -45,13 +46,14 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
User string `yaml:"user" validate:"required"`
Password string `yaml:"password" validate:"required"`
Database string `yaml:"database" validate:"required"`
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
User string `yaml:"user" validate:"required"`
Password string `yaml:"password" validate:"required"`
Database string `yaml:"database" validate:"required"`
QueryParams map[string]string `yaml:"queryParams"`
}
func (r Config) SourceConfigKind() string {
@@ -59,7 +61,7 @@ func (r Config) SourceConfigKind() string {
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
pool, err := initPostgresConnectionPool(ctx, tracer, r.Name, r.Host, r.Port, r.User, r.Password, r.Database)
pool, err := initPostgresConnectionPool(ctx, tracer, r.Name, r.Host, r.Port, r.User, r.Password, r.Database, r.QueryParams)
if err != nil {
return nil, fmt.Errorf("unable to create pool: %w", err)
}
@@ -93,17 +95,18 @@ func (s *Source) PostgresPool() *pgxpool.Pool {
return s.Pool
}
func initPostgresConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string) (*pgxpool.Pool, error) {
func initPostgresConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string, queryParams map[string]string) (*pgxpool.Pool, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// urlExample := "postgres:dd//username:password@localhost:5432/database_name"
url := &url.URL{
Scheme: "postgres",
User: url.UserPassword(user, pass),
Host: fmt.Sprintf("%s:%s", host, port),
Path: dbname,
Scheme: "postgres",
User: url.UserPassword(user, pass),
Host: fmt.Sprintf("%s:%s", host, port),
Path: dbname,
RawQuery: ConvertParamMapToRawQuery(queryParams),
}
pool, err := pgxpool.New(ctx, url.String())
if err != nil {
@@ -112,3 +115,11 @@ func initPostgresConnectionPool(ctx context.Context, tracer trace.Tracer, name,
return pool, nil
}
func ConvertParamMapToRawQuery(queryParams map[string]string) string {
queryArray := []string{}
for k, v := range queryParams {
queryArray = append(queryArray, fmt.Sprintf("%s=%s", k, v))
}
return strings.Join(queryArray, "&")
}

View File

@@ -15,6 +15,8 @@
package postgres_test
import (
"sort"
"strings"
"testing"
yaml "github.com/goccy/go-yaml"
@@ -54,6 +56,37 @@ func TestParseFromYamlPostgres(t *testing.T) {
},
},
},
{
desc: "example with query params",
in: `
sources:
my-pg-instance:
kind: postgres
host: my-host
port: my-port
database: my_db
user: my_user
password: my_pass
queryParams:
sslmode: verify-full
sslrootcert: /tmp/ca.crt
`,
want: server.SourceConfigs{
"my-pg-instance": postgres.Config{
Name: "my-pg-instance",
Kind: postgres.SourceKind,
Host: "my-host",
Port: "my-port",
Database: "my_db",
User: "my_user",
Password: "my_pass",
QueryParams: map[string]string{
"sslmode": "verify-full",
"sslrootcert": "/tmp/ca.crt",
},
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
@@ -125,3 +158,45 @@ func TestFailParseFromYaml(t *testing.T) {
})
}
}
func TestConvertParamMapToRawQuery(t *testing.T) {
tcs := []struct {
desc string
in map[string]string
want string
}{
{
desc: "nil param",
in: nil,
want: "",
},
{
desc: "single query param",
in: map[string]string{
"foo": "bar",
},
want: "foo=bar",
},
{
desc: "more than one query param",
in: map[string]string{
"foo": "bar",
"hello": "world",
},
want: "foo=bar&hello=world",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := postgres.ConvertParamMapToRawQuery(tc.in)
if strings.Contains(got, "&") {
splitGot := strings.Split(got, "&")
sort.Strings(splitGot)
got = strings.Join(splitGot, "&")
}
if got != tc.want {
t.Fatalf("incorrect conversion: got %s want %s", got, tc.want)
}
})
}
}