diff --git a/.ci/integration.cloudbuild.yaml b/.ci/integration.cloudbuild.yaml index 8868ffe7dd..eb052475ac 100644 --- a/.ci/integration.cloudbuild.yaml +++ b/.ci/integration.cloudbuild.yaml @@ -43,7 +43,7 @@ steps: args: - -c - | - go test -race -v -tags=integration,cloudsql ./tests + go test -race -v -tags=integration,cloudsqlpg ./tests - id: "alloydb-pg" name: golang:1 @@ -139,6 +139,27 @@ steps: - | go test -race -v -tags=integration,cloudsqlmssql ./tests + - id: "cloud-sql-mysql" + name: golang:1 + waitFor: ["install-dependencies"] + entrypoint: /bin/bash + env: + - "GOPATH=/gopath" + - "CLOUD_SQL_MYSQL_PROJECT=$PROJECT_ID" + - "CLOUD_SQL_MYSQL_INSTANCE=$_CLOUD_SQL_MYSQL_INSTANCE" + - "CLOUD_SQL_MYSQL_DATABASE=$_DATABASE_NAME" + - "CLOUD_SQL_MYSQL_REGION=$_REGION" + - "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL" + secretEnv: + ["CLOUD_SQL_MYSQL_USER", "CLOUD_SQL_MYSQL_PASS", "CLIENT_ID"] + volumes: + - name: "go" + path: "/gopath" + args: + - -c + - | + go test -race -v -tags=integration,cloudsqlmysql ./tests + availableSecrets: secretManager: - versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest @@ -163,6 +184,10 @@ availableSecrets: env: CLOUD_SQL_MSSQL_USER - versionName: projects/$PROJECT_ID/secrets/cloud_sql_mssql_pass/versions/latest env: CLOUD_SQL_MSSQL_PASS + - versionName: projects/$PROJECT_ID/secrets/cloud_sql_mysql_user/versions/latest + env: CLOUD_SQL_MYSQL_USER + - versionName: projects/$PROJECT_ID/secrets/cloud_sql_mysql_pass/versions/latest + env: CLOUD_SQL_MYSQL_PASS options: logging: CLOUD_LOGGING_ONLY @@ -183,3 +208,4 @@ substitutions: _SPANNER_INSTANCE: "spanner-testing" _NEO4J_DATABASE: "neo4j" _CLOUD_SQL_MSSQL_INSTANCE: "cloud-sql-mssql-testing" + _CLOUD_SQL_MYSQL_INSTANCE: "cloud-sql-mysql-testing" diff --git a/.golangci.yaml b/.golangci.yaml index d8dd174222..64acdde110 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -33,7 +33,10 @@ issues: run: build-tags: - integration - - cloudsql + - cloudsqlpg - postgres - alloydb - spanner + - cloudsqlmssql + - cloudsqlmysql + - neo4j diff --git a/go.mod b/go.mod index 3ff2ce4f84..4d27ca4376 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,6 @@ require ( github.com/go-chi/chi/v5 v5.1.0 github.com/go-chi/httplog/v2 v2.1.1 github.com/go-chi/render v1.0.3 - github.com/go-sql-driver/mysql v1.8.1 github.com/goccy/go-yaml v1.15.13 github.com/google/go-cmp v0.6.0 github.com/google/uuid v1.6.0 @@ -56,6 +55,7 @@ require ( github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect diff --git a/internal/sources/cloudsqlmysql/cloud_sql_mysql.go b/internal/sources/cloudsqlmysql/cloud_sql_mysql.go index 89846873c1..e74a6d14eb 100644 --- a/internal/sources/cloudsqlmysql/cloud_sql_mysql.go +++ b/internal/sources/cloudsqlmysql/cloud_sql_mysql.go @@ -18,11 +18,11 @@ import ( "context" "database/sql" "fmt" - "net" + "slices" "strings" "cloud.google.com/go/cloudsqlconn" - "github.com/go-sql-driver/mysql" + "cloud.google.com/go/cloudsqlconn/mysql/mysql" "github.com/googleapis/genai-toolbox/internal/sources" "go.opentelemetry.io/otel/trace" ) @@ -104,24 +104,22 @@ func initCloudSQLMySQLConnectionPool(ctx context.Context, tracer trace.Tracer, n if err != nil { return nil, err } - d, err := cloudsqlconn.NewDialer(context.Background()) - if err != nil { - return nil, fmt.Errorf("unable to parse connection uri: %w", err) + + if !slices.Contains(sql.Drivers(), "cloudsql-mysql") { + _, err = mysql.RegisterDriver("cloudsql-mysql", cloudsqlconn.WithDefaultDialOptions(dialOpts...)) + if err != nil { + return nil, fmt.Errorf("unable to register driver: %w", err) + } } // Tell the driver to use the Cloud SQL Go Connector to create connections - i := fmt.Sprintf("%s:%s:%s", project, region, instance) - mysql.RegisterDialContext("cloudsqlconn", func(ctx context.Context, addr string) (net.Conn, error) { - return d.Dial(ctx, i, dialOpts...) - }) - - // Configure the driver to connect to the database - dbURI := fmt.Sprintf("%s:%s@cloudsqlconn(localhost:3306)/%s?parseTime=true", user, pass, dbname) - - // Interact with the driver directly as you normally would - pool, err := sql.Open("mysql", dbURI) + dsn := fmt.Sprintf("%s:%s@cloudsql-mysql(%s:%s:%s)/%s", user, pass, project, region, instance, dbname) + db, err := sql.Open( + "cloudsql-mysql", + dsn, + ) if err != nil { - return nil, fmt.Errorf("sql.Open: %w", err) + return nil, err } - return pool, nil + return db, nil } diff --git a/tests/auth_test.go b/tests/auth_test.go index 9c26a8a7ad..721b6eff87 100644 --- a/tests/auth_test.go +++ b/tests/auth_test.go @@ -65,6 +65,8 @@ func RunGoogleAuthenticatedParameterTest(t *testing.T, sourceConfig map[string]a statement = fmt.Sprintf("SELECT * FROM %s WHERE email = $1;", tableName) case strings.EqualFold(toolKind, "mssql-sql"): statement = fmt.Sprintf("SELECT * FROM %s WHERE email = @email;", tableName) + case strings.EqualFold(toolKind, "mysql-sql"): + statement = fmt.Sprintf("SELECT * FROM %s WHERE email = ?;", tableName) default: t.Fatalf("invalid tool kind: %s", toolKind) } @@ -132,7 +134,7 @@ func RunGoogleAuthenticatedParameterTest(t *testing.T, sourceConfig map[string]a // Tools using database/sql interface only outputs `int64` instead of `int32` var wantString string switch toolKind { - case "mssql-sql": + case "mssql-sql", "mysql-sql": wantString = fmt.Sprintf("Stub tool call for \"my-auth-tool\"! Parameters parsed: [{\"email\" \"%s\"}] \n Output: [%%!s(int64=1) Alice %s]", SERVICE_ACCOUNT_EMAIL, SERVICE_ACCOUNT_EMAIL) default: wantString = fmt.Sprintf("Stub tool call for \"my-auth-tool\"! Parameters parsed: [{\"email\" \"%s\"}] \n Output: [%%!s(int32=1) Alice %s]", SERVICE_ACCOUNT_EMAIL, SERVICE_ACCOUNT_EMAIL) @@ -216,7 +218,7 @@ func RunAuthRequiredToolInvocationTest(t *testing.T, sourceConfig map[string]any // Tools using database/sql interface only outputs `int64` instead of `int32` var wantString string switch toolKind { - case "mssql-sql": + case "mssql-sql", "mysql-sql": wantString = "Stub tool call for \"my-auth-tool\"! Parameters parsed: [] \n Output: [%!s(int64=1)]" default: wantString = "Stub tool call for \"my-auth-tool\"! Parameters parsed: [] \n Output: [%!s(int32=1)]" diff --git a/tests/cloud_sql_mssql_integration_test.go b/tests/cloud_sql_mssql_integration_test.go index cea51e0230..ce0d49af2a 100644 --- a/tests/cloud_sql_mssql_integration_test.go +++ b/tests/cloud_sql_mssql_integration_test.go @@ -78,24 +78,13 @@ func requireCloudSQLMssqlVars(t *testing.T) map[string]any { } } -func getDialOpts(ipType string) ([]cloudsqlconn.DialOption, error) { - switch strings.ToLower(ipType) { - case "private": - return []cloudsqlconn.DialOption{cloudsqlconn.WithPrivateIP()}, nil - case "public": - return []cloudsqlconn.DialOption{cloudsqlconn.WithPublicIP()}, nil - default: - return nil, fmt.Errorf("invalid ipType %s", ipType) - } -} - // Copied over from cloud_sql_mssql.go func initCloudSQLMssqlConnection(project, region, instance, ipAddress, ipType, user, pass, dbname string) (*sql.DB, error) { // Create dsn dsn := fmt.Sprintf("sqlserver://%s:%s@%s?database=%s&cloudsql=%s:%s:%s", user, pass, ipAddress, dbname, project, region, instance) // Get dial options - dialOpts, err := getDialOpts(ipType) + dialOpts, err := GetCloudSQLDialOpts(ipType) if err != nil { return nil, err } @@ -243,7 +232,7 @@ func TestCloudSQLMssql(t *testing.T) { } // Set up tool calling with parameters test table -func setupParamTest(t *testing.T, tableName string) (func(*testing.T), error) { +func setupCloudSQLMssqlParamTest(t *testing.T, tableName string) (func(*testing.T), error) { // Set up Tool invocation with parameters test db, err := initCloudSQLMssqlConnection(CLOUD_SQL_MSSQL_PROJECT, CLOUD_SQL_MSSQL_REGION, CLOUD_SQL_MSSQL_INSTANCE, CLOUD_SQL_MSSQL_IP, "public", CLOUD_SQL_MSSQL_USER, CLOUD_SQL_MSSQL_PASS, CLOUD_SQL_MSSQL_DATABASE) if err != nil { @@ -285,7 +274,7 @@ func setupParamTest(t *testing.T, tableName string) (func(*testing.T), error) { }, nil } -func TestToolInvocationWithParams(t *testing.T) { +func TestToolInvocationCloudSQLMssqlWithParams(t *testing.T) { // create source config sourceConfig := requireCloudSQLMssqlVars(t) @@ -293,7 +282,7 @@ func TestToolInvocationWithParams(t *testing.T) { tableName := "param_test_table_" + strings.Replace(uuid.New().String(), "-", "", -1) // test setup function reterns teardown function - teardownTest, err := setupParamTest(t, tableName) + teardownTest, err := setupCloudSQLMssqlParamTest(t, tableName) if err != nil { t.Fatalf("Unable to set up auth test: %s", err) } diff --git a/tests/cloud_sql_mysql_integration_test.go b/tests/cloud_sql_mysql_integration_test.go new file mode 100644 index 0000000000..716e7bf36f --- /dev/null +++ b/tests/cloud_sql_mysql_integration_test.go @@ -0,0 +1,299 @@ +//go:build integration && cloudsqlmysql + +// +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "reflect" + "regexp" + "slices" + "strings" + "testing" + "time" + + "cloud.google.com/go/cloudsqlconn" + "cloud.google.com/go/cloudsqlconn/mysql/mysql" + "github.com/google/uuid" +) + +var ( + CLOUD_SQL_MYSQL_PROJECT = os.Getenv("CLOUD_SQL_MYSQL_PROJECT") + CLOUD_SQL_MYSQL_REGION = os.Getenv("CLOUD_SQL_MYSQL_REGION") + CLOUD_SQL_MYSQL_INSTANCE = os.Getenv("CLOUD_SQL_MYSQL_INSTANCE") + CLOUD_SQL_MYSQL_DATABASE = os.Getenv("CLOUD_SQL_MYSQL_DATABASE") + CLOUD_SQL_MYSQL_USER = os.Getenv("CLOUD_SQL_MYSQL_USER") + CLOUD_SQL_MYSQL_PASS = os.Getenv("CLOUD_SQL_MYSQL_PASS") +) + +func requireCloudSQLMySQLVars(t *testing.T) map[string]any { + switch "" { + case CLOUD_SQL_MYSQL_PROJECT: + t.Fatal("'CLOUD_SQL_MYSQL_PROJECT' not set") + case CLOUD_SQL_MYSQL_REGION: + t.Fatal("'CLOUD_SQL_MYSQL_REGION' not set") + case CLOUD_SQL_MYSQL_INSTANCE: + t.Fatal("'CLOUD_SQL_MYSQL_INSTANCE' not set") + case CLOUD_SQL_MYSQL_DATABASE: + t.Fatal("'CLOUD_SQL_MYSQL_DATABASE' not set") + case CLOUD_SQL_MYSQL_USER: + t.Fatal("'CLOUD_SQL_MYSQL_USER' not set") + case CLOUD_SQL_MYSQL_PASS: + t.Fatal("'CLOUD_SQL_MYSQL_PASS' not set") + } + + return map[string]any{ + "kind": "cloud-sql-mysql", + "project": CLOUD_SQL_MYSQL_PROJECT, + "instance": CLOUD_SQL_MYSQL_INSTANCE, + "region": CLOUD_SQL_MYSQL_REGION, + "database": CLOUD_SQL_MYSQL_DATABASE, + "user": CLOUD_SQL_MYSQL_USER, + "password": CLOUD_SQL_MYSQL_PASS, + } +} + +// Copied over from cloud_sql_mysql.go +func initCloudSQLMySQLConnectionPool(project, region, instance, ipType, user, pass, dbname string) (*sql.DB, error) { + + // Create a new dialer with options + dialOpts, err := GetCloudSQLDialOpts(ipType) + if err != nil { + return nil, err + } + + if !slices.Contains(sql.Drivers(), "cloudsql-mysql") { + _, err = mysql.RegisterDriver("cloudsql-mysql", cloudsqlconn.WithDefaultDialOptions(dialOpts...)) + if err != nil { + return nil, fmt.Errorf("unable to register driver: %w", err) + } + } + + // Tell the driver to use the Cloud SQL Go Connector to create connections + dsn := fmt.Sprintf("%s:%s@cloudsql-mysql(%s:%s:%s)/%s", user, pass, project, region, instance, dbname) + db, err := sql.Open( + "cloudsql-mysql", + dsn, + ) + if err != nil { + return nil, err + } + return db, nil +} + +func TestCloudSQLMySQL(t *testing.T) { + sourceConfig := requireCloudSQLMySQLVars(t) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + var args []string + + // Write config into a file and pass it to command + toolsFile := map[string]any{ + "sources": map[string]any{ + "my-mysql-instance": sourceConfig, + }, + "tools": map[string]any{ + "my-simple-tool": map[string]any{ + "kind": "mysql-sql", + "source": "my-mysql-instance", + "description": "Simple tool to test end to end functionality.", + "statement": "SELECT 1;", + }, + }, + } + cmd, cleanup, err := StartCmd(ctx, toolsFile, args...) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`)) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + + // Test tool get endpoint + tcs := []struct { + name string + api string + want map[string]any + }{ + { + name: "get my-simple-tool", + api: "http://127.0.0.1:5000/api/tool/my-simple-tool/", + want: map[string]any{ + "my-simple-tool": map[string]any{ + "description": "Simple tool to test end to end functionality.", + "parameters": []any{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + resp, err := http.Get(tc.api) + if err != nil { + t.Fatalf("error when sending a request: %s", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + } + + var body map[string]interface{} + err = json.NewDecoder(resp.Body).Decode(&body) + if err != nil { + t.Fatalf("error parsing response body") + } + + got, ok := body["tools"] + if !ok { + t.Fatalf("unable to find tools in response body") + } + if !reflect.DeepEqual(got, tc.want) { + t.Fatalf("got %q, want %q", got, tc.want) + } + }) + } + + // Test tool invoke endpoint + invokeTcs := []struct { + name string + api string + requestBody io.Reader + want string + }{ + { + name: "invoke my-simple-tool", + api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke", + requestBody: bytes.NewBuffer([]byte(`{}`)), + want: "Stub tool call for \"my-simple-tool\"! Parameters parsed: [] \n Output: [%!s(int64=1)]", + }, + } + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + resp, err := http.Post(tc.api, "application/json", tc.requestBody) + if err != nil { + t.Fatalf("error when sending a request: %s", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + } + + var body map[string]interface{} + err = json.NewDecoder(resp.Body).Decode(&body) + if err != nil { + t.Fatalf("error parsing response body") + } + got, ok := body["result"].(string) + if !ok { + t.Fatalf("unable to find result in response body") + } + + if got != tc.want { + t.Fatalf("unexpected value: got %q, want %q", got, tc.want) + } + }) + } +} + +// Set up auth test database table +func setupCloudSQLMySQLAuthTest(t *testing.T, ctx context.Context, tableName string) func(*testing.T) { + // set up testt + pool, err := initCloudSQLMySQLConnectionPool(CLOUD_SQL_MYSQL_PROJECT, CLOUD_SQL_MYSQL_REGION, CLOUD_SQL_MYSQL_INSTANCE, "public", CLOUD_SQL_MYSQL_USER, CLOUD_SQL_MYSQL_PASS, CLOUD_SQL_MYSQL_DATABASE) + if err != nil { + t.Fatalf("unable to create Cloud SQL connection pool: %s", err) + } + + err = pool.PingContext(ctx) + if err != nil { + t.Fatalf("unable to connect to test database: %s", err) + } + + _, err = pool.QueryContext(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(255), + email VARCHAR(255) + ); + `, tableName)) + if err != nil { + t.Fatalf("unable to create test table: %s", err) + } + + // Insert test data + statement := fmt.Sprintf(` + INSERT INTO %s (name, email) + VALUES (?, ?), (?, ?) + `, tableName) + + params := []any{"Alice", SERVICE_ACCOUNT_EMAIL, "Jane", "janedoe@gmail.com"} + _, err = pool.QueryContext(ctx, statement, params...) + if err != nil { + t.Fatalf("unable to insert test data: %s", err) + } + + return func(t *testing.T) { + // tear down test + _, err := pool.ExecContext(ctx, fmt.Sprintf(`DROP TABLE %s;`, tableName)) + if err != nil { + t.Errorf("Teardown failed: %s", err) + } + } +} + +func TestCloudSQLMySQLGoogleAuthenticatedParameter(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + // create test configs + sourceConfig := requireCloudSQLMySQLVars(t) + + // create table name with UUID + tableName := "auth_table_" + strings.Replace(uuid.New().String(), "-", "", -1) + + // test setup function reterns teardown function + teardownTest := setupCloudSQLMySQLAuthTest(t, ctx, tableName) + defer teardownTest(t) + + // call generic auth test helper + RunGoogleAuthenticatedParameterTest(t, sourceConfig, "mysql-sql", tableName) + +} + +func TestCloudSQLMySQLAuthRequiredToolInvocation(t *testing.T) { + // create test configs + sourceConfig := requireCloudSQLMySQLVars(t) + + // call generic auth test helper + RunAuthRequiredToolInvocationTest(t, sourceConfig, "mysql-sql") + +} diff --git a/tests/cloud_sql_pg_integration_test.go b/tests/cloud_sql_pg_integration_test.go index fc80cf6c67..bd391d8dfa 100644 --- a/tests/cloud_sql_pg_integration_test.go +++ b/tests/cloud_sql_pg_integration_test.go @@ -1,4 +1,4 @@ -//go:build integration && cloudsql +//go:build integration && cloudsqlpg // // Copyright 2024 Google LLC @@ -73,18 +73,6 @@ func requireCloudSQLPgVars(t *testing.T) map[string]any { } } -// Copied over from cloud_sql_pg.go -func getCloudSQLDialOpts(ip_type string) ([]cloudsqlconn.DialOption, error) { - switch strings.ToLower(ip_type) { - case "private": - return []cloudsqlconn.DialOption{cloudsqlconn.WithPrivateIP()}, nil - case "public": - return []cloudsqlconn.DialOption{cloudsqlconn.WithPublicIP()}, nil - default: - return nil, fmt.Errorf("invalid ip_type %s", ip_type) - } -} - // Copied over from cloud_sql_pg.go func initCloudSQLPgConnectionPool(project, region, instance, ip_type, user, pass, dbname string) (*pgxpool.Pool, error) { // Configure the driver to connect to the database @@ -95,7 +83,7 @@ func initCloudSQLPgConnectionPool(project, region, instance, ip_type, user, pass } // Create a new dialer with options - dialOpts, err := getCloudSQLDialOpts(ip_type) + dialOpts, err := GetCloudSQLDialOpts(ip_type) if err != nil { return nil, err } diff --git a/tests/common_test.go b/tests/common_test.go index f768481ced..6434a0772a 100644 --- a/tests/common_test.go +++ b/tests/common_test.go @@ -34,6 +34,7 @@ import ( "testing" "time" + "cloud.google.com/go/cloudsqlconn" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/cmd" @@ -380,3 +381,15 @@ func RunSourceConnectionTest(t *testing.T, sourceConfig map[string]any, toolKind t.Fatalf("toolbox didn't start successfully: %s", err) } } + +// GetCloudSQLDialOpts returns cloud sql connector's dial option for ip type. +func GetCloudSQLDialOpts(ipType string) ([]cloudsqlconn.DialOption, error) { + switch strings.ToLower(ipType) { + case "private": + return []cloudsqlconn.DialOption{cloudsqlconn.WithPrivateIP()}, nil + case "public": + return []cloudsqlconn.DialOption{cloudsqlconn.WithPublicIP()}, nil + default: + return nil, fmt.Errorf("invalid ipType %s", ipType) + } +}