mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-10 16:08:16 -05:00
feat(source/mysql): support queryParams in MySQL source (#1299)
Fixes #1286 ### Motivation * Allow secure connections to PostgreSQL without custom code. ### Changes #### Sources * `mysql`: `Config.QueryParams` added; DSN building rewritten via `url.Values`. --------- Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com> Co-authored-by: Yuan Teoh <yuanteoh@google.com>
This commit is contained in:
@@ -42,6 +42,9 @@ sources:
|
||||
database: my_db
|
||||
user: ${USER_NAME}
|
||||
password: ${PASSWORD}
|
||||
# Optional TLS and other driver parameters. For example, enable preferred TLS:
|
||||
# queryParams:
|
||||
# tls: preferred
|
||||
queryTimeout: 30s # Optional: query timeout duration
|
||||
```
|
||||
|
||||
@@ -61,3 +64,4 @@ instead of hardcoding your secrets into the configuration file.
|
||||
| user | string | true | Name of the MySQL user to connect as (e.g. "my-mysql-user"). |
|
||||
| password | string | true | Password of the MySQL user (e.g. "my-password"). |
|
||||
| queryTimeout | string | false | Maximum time to wait for query execution (e.g. "30s", "2m"). By default, no timeout is applied. |
|
||||
| queryParams | map<string,string> | false | Arbitrary DSN parameters passed to the driver (e.g. `tls: preferred`, `charset: utf8mb4`). Useful for enabling TLS or other connection options. |
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
@@ -46,14 +47,15 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Host string `yaml:"host" validate:"required"`
|
||||
Port string `yaml:"port" validate:"required"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
Password string `yaml:"password" validate:"required"`
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
QueryTimeout string `yaml:"queryTimeout"`
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Host string `yaml:"host" validate:"required"`
|
||||
Port string `yaml:"port" validate:"required"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
Password string `yaml:"password" validate:"required"`
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
QueryTimeout string `yaml:"queryTimeout"`
|
||||
QueryParams map[string]string `yaml:"queryParams"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
@@ -61,7 +63,7 @@ func (r Config) SourceConfigKind() string {
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
pool, err := initMySQLConnectionPool(ctx, tracer, r.Name, r.Host, r.Port, r.User, r.Password, r.Database, r.QueryTimeout)
|
||||
pool, err := initMySQLConnectionPool(ctx, tracer, r.Name, r.Host, r.Port, r.User, r.Password, r.Database, r.QueryTimeout, r.QueryParams)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create pool: %w", err)
|
||||
}
|
||||
@@ -95,21 +97,34 @@ func (s *Source) MySQLPool() *sql.DB {
|
||||
return s.Pool
|
||||
}
|
||||
|
||||
func initMySQLConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string) (*sql.DB, error) {
|
||||
func initMySQLConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string, queryParams map[string]string) (*sql.DB, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
// Configure the driver to connect to the database
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", user, pass, host, port, dbname)
|
||||
// Build query parameters via url.Values for deterministic order and proper escaping.
|
||||
values := url.Values{}
|
||||
|
||||
// Add query timeout to DSN if specified
|
||||
// Derive readTimeout from queryTimeout when provided.
|
||||
if queryTimeout != "" {
|
||||
timeout, err := time.ParseDuration(queryTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid queryTimeout %q: %w", queryTimeout, err)
|
||||
}
|
||||
dsn += "&readTimeout=" + timeout.String()
|
||||
values.Set("readTimeout", timeout.String())
|
||||
}
|
||||
|
||||
// Custom user parameters
|
||||
for k, v := range queryParams {
|
||||
if v == "" {
|
||||
continue // skip empty values
|
||||
}
|
||||
values.Set(k, v)
|
||||
}
|
||||
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", user, pass, host, port, dbname)
|
||||
if enc := values.Encode(); enc != "" {
|
||||
dsn += "&" + enc
|
||||
}
|
||||
|
||||
// Interact with the driver directly as you normally would
|
||||
|
||||
@@ -15,10 +15,15 @@
|
||||
package mysql_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"go.opentelemetry.io/otel/trace/noop"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/mysql"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
@@ -80,9 +85,41 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with query params",
|
||||
in: `
|
||||
sources:
|
||||
my-mysql-instance:
|
||||
kind: mysql
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
queryParams:
|
||||
tls: preferred
|
||||
charset: utf8mb4
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-mysql-instance": mysql.Config{
|
||||
Name: "my-mysql-instance",
|
||||
Kind: mysql.SourceKind,
|
||||
Host: "0.0.0.0",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
User: "my_user",
|
||||
Password: "my_pass",
|
||||
QueryParams: map[string]string{
|
||||
"tls": "preferred",
|
||||
"charset": "utf8mb4",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
@@ -91,8 +128,8 @@ func TestParseFromYamlCloudSQLMySQL(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
if diff := cmp.Diff(tc.want, got.Sources, cmpopts.EquateEmpty()); diff != "" {
|
||||
t.Fatalf("mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -118,7 +155,7 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
password: my_pass
|
||||
foo: bar
|
||||
`,
|
||||
err: "unable to parse source \"my-mysql-instance\" as \"mysql\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | kind: mysql\n 5 | password: my_pass\n 6 | ",
|
||||
err: "unknown field \"foo\"",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
@@ -131,11 +168,27 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
err: "unable to parse source \"my-mysql-instance\" as \"mysql\": Key: 'Config.Host' Error:Field validation for 'Host' failed on the 'required' tag",
|
||||
err: "Field validation for 'Host' failed",
|
||||
},
|
||||
{
|
||||
desc: "invalid query params type",
|
||||
in: `
|
||||
sources:
|
||||
my-mysql-instance:
|
||||
kind: mysql
|
||||
host: 0.0.0.0
|
||||
port: 3306
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
queryParams: not-a-map
|
||||
`,
|
||||
err: "string was used where mapping is expected",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
@@ -145,9 +198,32 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if errStr != tc.err {
|
||||
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
|
||||
if !strings.Contains(errStr, tc.err) {
|
||||
t.Fatalf("unexpected error: got %q, want substring %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFailInitialization test error during initialization without attempting a DB connection.
|
||||
func TestFailInitialization(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := mysql.Config{
|
||||
Name: "instance",
|
||||
Kind: "mysql",
|
||||
Host: "localhost",
|
||||
Port: "3306",
|
||||
Database: "db",
|
||||
User: "user",
|
||||
Password: "pass",
|
||||
QueryTimeout: "abc", // invalid duration
|
||||
}
|
||||
_, err := cfg.Initialize(context.Background(), noop.NewTracerProvider().Tracer("test"))
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for invalid queryTimeout, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "invalid queryTimeout") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user