From 57b77bca09ce6ee260bd64af9be5fcef593e9acb Mon Sep 17 00:00:00 2001 From: Mark L <73659136+liuxiaopai-ai@users.noreply.github.com> Date: Thu, 19 Feb 2026 04:35:44 +0800 Subject: [PATCH] feat(sources/postgres): add configurable pgx query execution mode (#2477) Adds optional `queryExecMode` to postgres source config, allowing users to set pgx `DefaultQueryExecMode` for compatibility with external connection poolers (e.g. transaction pooling). Supported values: - cache_statement (default) - cache_describe - describe_exec - exec - simple_protocol Implementation details: - parse DSN with `pgxpool.ParseConfig` - map `queryExecMode` to `pgx.QueryExecMode*` - create pool via `pgxpool.NewWithConfig` - validate config using `oneof` tag - document new field in postgres source docs - add parser/validation tests Tests run: `go test -v ./internal/sources/postgres -count=1 -vet=off` Refs #2385 --------- Co-authored-by: Molt (OpenClaw) Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com> Co-authored-by: Yuan Teoh --- docs/en/resources/sources/postgres.md | 1 + internal/sources/postgres/postgres.go | 52 ++++++++++++---- internal/sources/postgres/postgres_test.go | 71 ++++++++++++++++++++++ 3 files changed, 113 insertions(+), 11 deletions(-) diff --git a/docs/en/resources/sources/postgres.md b/docs/en/resources/sources/postgres.md index ed7c77aeee2..c42d8dca644 100644 --- a/docs/en/resources/sources/postgres.md +++ b/docs/en/resources/sources/postgres.md @@ -133,3 +133,4 @@ instead of hardcoding your secrets into the configuration file. | user | string | true | Name of the Postgres user to connect as (e.g. "my-pg-user"). | | password | string | true | Password of the Postgres user (e.g. "my-password"). | | queryParams | map[string]string | false | Raw query to be added to the db connection string. | +| queryExecMode | string | false | pgx query execution mode. Valid values: `cache_statement` (default), `cache_describe`, `describe_exec`, `exec`, `simple_protocol`. Useful with connection poolers that don't support prepared statement caching. | diff --git a/internal/sources/postgres/postgres.go b/internal/sources/postgres/postgres.go index 65d7a17b75d..b2ed44a95d4 100644 --- a/internal/sources/postgres/postgres.go +++ b/internal/sources/postgres/postgres.go @@ -24,6 +24,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/orderedmap" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "go.opentelemetry.io/otel/trace" ) @@ -48,14 +49,15 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources } type Config struct { - Name string `yaml:"name" validate:"required"` - Type string `yaml:"type" validate:"required"` - Host string `yaml:"host" validate:"required"` - Port string `yaml:"port" validate:"required"` - User string `yaml:"user" validate:"required"` - Password string `yaml:"password" validate:"required"` - Database string `yaml:"database" validate:"required"` - QueryParams map[string]string `yaml:"queryParams"` + Name string `yaml:"name" validate:"required"` + Type string `yaml:"type" validate:"required"` + Host string `yaml:"host" validate:"required"` + Port string `yaml:"port" validate:"required"` + User string `yaml:"user" validate:"required"` + Password string `yaml:"password" validate:"required"` + Database string `yaml:"database" validate:"required"` + QueryParams map[string]string `yaml:"queryParams"` + QueryExecMode string `yaml:"queryExecMode" validate:"omitempty,oneof=cache_statement cache_describe describe_exec exec simple_protocol"` } func (r Config) SourceConfigType() string { @@ -63,7 +65,7 @@ func (r Config) SourceConfigType() string { } func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) { - pool, err := initPostgresConnectionPool(ctx, tracer, r.Name, r.Host, r.Port, r.User, r.Password, r.Database, r.QueryParams) + pool, err := initPostgresConnectionPool(ctx, tracer, r.Name, r.Host, r.Port, r.User, r.Password, r.Database, r.QueryParams, r.QueryExecMode) if err != nil { return nil, fmt.Errorf("unable to create pool: %w", err) } @@ -126,7 +128,7 @@ func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (an return out, nil } -func initPostgresConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string, queryParams map[string]string) (*pgxpool.Pool, error) { +func initPostgresConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string, queryParams map[string]string, queryExecMode string) (*pgxpool.Pool, error) { //nolint:all // Reassigned ctx ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name) defer span.End() @@ -150,7 +152,18 @@ func initPostgresConnectionPool(ctx context.Context, tracer trace.Tracer, name, Path: dbname, RawQuery: ConvertParamMapToRawQuery(queryParams), } - pool, err := pgxpool.New(ctx, url.String()) + config, err := pgxpool.ParseConfig(url.String()) + if err != nil { + return nil, fmt.Errorf("unable to parse connection uri: %w", err) + } + + execMode, err := ParseQueryExecMode(queryExecMode) + if err != nil { + return nil, err + } + config.ConnConfig.DefaultQueryExecMode = execMode + + pool, err := pgxpool.NewWithConfig(ctx, config) if err != nil { return nil, fmt.Errorf("unable to create connection pool: %w", err) } @@ -165,3 +178,20 @@ func ConvertParamMapToRawQuery(queryParams map[string]string) string { } return strings.Join(queryArray, "&") } + +func ParseQueryExecMode(queryExecMode string) (pgx.QueryExecMode, error) { + switch queryExecMode { + case "", "cache_statement": + return pgx.QueryExecModeCacheStatement, nil + case "cache_describe": + return pgx.QueryExecModeCacheDescribe, nil + case "describe_exec": + return pgx.QueryExecModeDescribeExec, nil + case "exec": + return pgx.QueryExecModeExec, nil + case "simple_protocol": + return pgx.QueryExecModeSimpleProtocol, nil + default: + return 0, fmt.Errorf("invalid queryExecMode %q: must be one of %q, %q, %q, %q, or %q", queryExecMode, "cache_statement", "cache_describe", "describe_exec", "exec", "simple_protocol") + } +} diff --git a/internal/sources/postgres/postgres_test.go b/internal/sources/postgres/postgres_test.go index 4bcde0420e5..0441f356a3d 100644 --- a/internal/sources/postgres/postgres_test.go +++ b/internal/sources/postgres/postgres_test.go @@ -25,6 +25,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/jackc/pgx/v5" ) func TestParseFromYamlPostgres(t *testing.T) { @@ -88,6 +89,32 @@ func TestParseFromYamlPostgres(t *testing.T) { }, }, }, + { + desc: "example with query exec mode", + in: ` + kind: sources + name: my-pg-instance + type: postgres + host: my-host + port: my-port + database: my_db + user: my_user + password: my_pass + queryExecMode: simple_protocol + `, + want: map[string]sources.SourceConfig{ + "my-pg-instance": postgres.Config{ + Name: "my-pg-instance", + Type: postgres.SourceType, + Host: "my-host", + Port: "my-port", + Database: "my_db", + User: "my_user", + Password: "my_pass", + QueryExecMode: "simple_protocol", + }, + }, + }, } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { @@ -137,6 +164,21 @@ func TestFailParseFromYaml(t *testing.T) { `, err: "error unmarshaling sources: unable to parse source \"my-pg-instance\" as \"postgres\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag", }, + { + desc: "invalid query exec mode", + in: ` + kind: sources + name: my-pg-instance + type: postgres + host: my-host + port: my-port + database: my_db + user: my_user + password: my_pass + queryExecMode: invalid_mode + `, + err: "error unmarshaling sources: unable to parse source \"my-pg-instance\" as \"postgres\": [6:16] Key: 'Config.QueryExecMode' Error:Field validation for 'QueryExecMode' failed on the 'oneof' tag\n 3 | name: my-pg-instance\n 4 | password: my_pass\n 5 | port: my-port\n> 6 | queryExecMode: invalid_mode\n ^\n 7 | type: postgres\n 8 | user: my_user", + }, } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { @@ -193,3 +235,32 @@ func TestConvertParamMapToRawQuery(t *testing.T) { }) } } + +func TestParseQueryExecMode(t *testing.T) { + tcs := []struct { + desc string + in string + want pgx.QueryExecMode + wantErr bool + }{ + {desc: "empty (default)", in: "", want: pgx.QueryExecModeCacheStatement}, + {desc: "cache_statement", in: "cache_statement", want: pgx.QueryExecModeCacheStatement}, + {desc: "cache_describe", in: "cache_describe", want: pgx.QueryExecModeCacheDescribe}, + {desc: "describe_exec", in: "describe_exec", want: pgx.QueryExecModeDescribeExec}, + {desc: "exec", in: "exec", want: pgx.QueryExecModeExec}, + {desc: "simple_protocol", in: "simple_protocol", want: pgx.QueryExecModeSimpleProtocol}, + {desc: "invalid mode", in: "invalid_mode", wantErr: true}, + } + + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got, err := postgres.ParseQueryExecMode(tc.in) + if (err != nil) != tc.wantErr { + t.Fatalf("parseQueryExecMode() error = %v, wantErr %v", err, tc.wantErr) + } + if !tc.wantErr && got != tc.want { + t.Errorf("parseQueryExecMode() = %v, want %v", got, tc.want) + } + }) + } +}