feat(source/mysql): support queryParams in MySQL source (#1299)

Fixes #1286

### Motivation
* Allow secure connections to PostgreSQL without custom code.

### Changes
#### Sources
* `mysql`: `Config.QueryParams` added; DSN building rewritten via
`url.Values`.

---------

Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com>
Co-authored-by: Yuan Teoh <yuanteoh@google.com>
This commit is contained in:
Valeriy
2025-09-03 20:48:22 +03:00
committed by GitHub
parent 8755e3db34
commit 3ae2526e0f
3 changed files with 115 additions and 20 deletions

View File

@@ -42,6 +42,9 @@ sources:
database: my_db
user: ${USER_NAME}
password: ${PASSWORD}
# Optional TLS and other driver parameters. For example, enable preferred TLS:
# queryParams:
# tls: preferred
queryTimeout: 30s # Optional: query timeout duration
```
@@ -61,3 +64,4 @@ instead of hardcoding your secrets into the configuration file.
| user | string | true | Name of the MySQL user to connect as (e.g. "my-mysql-user"). |
| password | string | true | Password of the MySQL user (e.g. "my-password"). |
| queryTimeout | string | false | Maximum time to wait for query execution (e.g. "30s", "2m"). By default, no timeout is applied. |
| queryParams | map<string,string> | false | Arbitrary DSN parameters passed to the driver (e.g. `tls: preferred`, `charset: utf8mb4`). Useful for enabling TLS or other connection options. |

View File

@@ -18,6 +18,7 @@ import (
"context"
"database/sql"
"fmt"
"net/url"
"time"
_ "github.com/go-sql-driver/mysql"
@@ -46,14 +47,15 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
User string `yaml:"user" validate:"required"`
Password string `yaml:"password" validate:"required"`
Database string `yaml:"database" validate:"required"`
QueryTimeout string `yaml:"queryTimeout"`
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Host string `yaml:"host" validate:"required"`
Port string `yaml:"port" validate:"required"`
User string `yaml:"user" validate:"required"`
Password string `yaml:"password" validate:"required"`
Database string `yaml:"database" validate:"required"`
QueryTimeout string `yaml:"queryTimeout"`
QueryParams map[string]string `yaml:"queryParams"`
}
func (r Config) SourceConfigKind() string {
@@ -61,7 +63,7 @@ func (r Config) SourceConfigKind() string {
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
pool, err := initMySQLConnectionPool(ctx, tracer, r.Name, r.Host, r.Port, r.User, r.Password, r.Database, r.QueryTimeout)
pool, err := initMySQLConnectionPool(ctx, tracer, r.Name, r.Host, r.Port, r.User, r.Password, r.Database, r.QueryTimeout, r.QueryParams)
if err != nil {
return nil, fmt.Errorf("unable to create pool: %w", err)
}
@@ -95,21 +97,34 @@ func (s *Source) MySQLPool() *sql.DB {
return s.Pool
}
func initMySQLConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string) (*sql.DB, error) {
func initMySQLConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string, queryParams map[string]string) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// Configure the driver to connect to the database
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", user, pass, host, port, dbname)
// Build query parameters via url.Values for deterministic order and proper escaping.
values := url.Values{}
// Add query timeout to DSN if specified
// Derive readTimeout from queryTimeout when provided.
if queryTimeout != "" {
timeout, err := time.ParseDuration(queryTimeout)
if err != nil {
return nil, fmt.Errorf("invalid queryTimeout %q: %w", queryTimeout, err)
}
dsn += "&readTimeout=" + timeout.String()
values.Set("readTimeout", timeout.String())
}
// Custom user parameters
for k, v := range queryParams {
if v == "" {
continue // skip empty values
}
values.Set(k, v)
}
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", user, pass, host, port, dbname)
if enc := values.Encode(); enc != "" {
dsn += "&" + enc
}
// Interact with the driver directly as you normally would

View File

@@ -15,10 +15,15 @@
package mysql_test
import (
"context"
"strings"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"go.opentelemetry.io/otel/trace/noop"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources/mysql"
"github.com/googleapis/genai-toolbox/internal/testutils"
@@ -80,9 +85,41 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
},
},
},
{
desc: "with query params",
in: `
sources:
my-mysql-instance:
kind: mysql
host: 0.0.0.0
port: my-port
database: my_db
user: my_user
password: my_pass
queryParams:
tls: preferred
charset: utf8mb4
`,
want: server.SourceConfigs{
"my-mysql-instance": mysql.Config{
Name: "my-mysql-instance",
Kind: mysql.SourceKind,
Host: "0.0.0.0",
Port: "my-port",
Database: "my_db",
User: "my_user",
Password: "my_pass",
QueryParams: map[string]string{
"tls": "preferred",
"charset": "utf8mb4",
},
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
@@ -91,8 +128,8 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
if diff := cmp.Diff(tc.want, got.Sources, cmpopts.EquateEmpty()); diff != "" {
t.Fatalf("mismatch (-want +got):\n%s", diff)
}
})
}
@@ -118,7 +155,7 @@ func TestFailParseFromYaml(t *testing.T) {
password: my_pass
foo: bar
`,
err: "unable to parse source \"my-mysql-instance\" as \"mysql\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | kind: mysql\n 5 | password: my_pass\n 6 | ",
err: "unknown field \"foo\"",
},
{
desc: "missing required field",
@@ -131,11 +168,27 @@ func TestFailParseFromYaml(t *testing.T) {
user: my_user
password: my_pass
`,
err: "unable to parse source \"my-mysql-instance\" as \"mysql\": Key: 'Config.Host' Error:Field validation for 'Host' failed on the 'required' tag",
err: "Field validation for 'Host' failed",
},
{
desc: "invalid query params type",
in: `
sources:
my-mysql-instance:
kind: mysql
host: 0.0.0.0
port: 3306
database: my_db
user: my_user
password: my_pass
queryParams: not-a-map
`,
err: "string was used where mapping is expected",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
@@ -145,9 +198,32 @@ func TestFailParseFromYaml(t *testing.T) {
t.Fatalf("expect parsing to fail")
}
errStr := err.Error()
if errStr != tc.err {
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
if !strings.Contains(errStr, tc.err) {
t.Fatalf("unexpected error: got %q, want substring %q", errStr, tc.err)
}
})
}
}
// TestFailInitialization test error during initialization without attempting a DB connection.
func TestFailInitialization(t *testing.T) {
t.Parallel()
cfg := mysql.Config{
Name: "instance",
Kind: "mysql",
Host: "localhost",
Port: "3306",
Database: "db",
User: "user",
Password: "pass",
QueryTimeout: "abc", // invalid duration
}
_, err := cfg.Initialize(context.Background(), noop.NewTracerProvider().Tracer("test"))
if err == nil {
t.Fatalf("expected error for invalid queryTimeout, got nil")
}
if !strings.Contains(err.Error(), "invalid queryTimeout") {
t.Fatalf("unexpected error: %v", err)
}
}