mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-10 07:58:12 -05:00
ci(cloudsqlmysql): add integration tests (#243)
Add integration test for CloudSQL for MySQL. Added other integration tests' tag into `.golangci.yaml`, and fixing lint errors. Moved getCloudSQLDialOpts to `common_test.go` since it is used across all three cloud sql integration tests.
This commit is contained in:
@@ -43,7 +43,7 @@ steps:
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
go test -race -v -tags=integration,cloudsql ./tests
|
||||
go test -race -v -tags=integration,cloudsqlpg ./tests
|
||||
|
||||
- id: "alloydb-pg"
|
||||
name: golang:1
|
||||
@@ -139,6 +139,27 @@ steps:
|
||||
- |
|
||||
go test -race -v -tags=integration,cloudsqlmssql ./tests
|
||||
|
||||
- id: "cloud-sql-mysql"
|
||||
name: golang:1
|
||||
waitFor: ["install-dependencies"]
|
||||
entrypoint: /bin/bash
|
||||
env:
|
||||
- "GOPATH=/gopath"
|
||||
- "CLOUD_SQL_MYSQL_PROJECT=$PROJECT_ID"
|
||||
- "CLOUD_SQL_MYSQL_INSTANCE=$_CLOUD_SQL_MYSQL_INSTANCE"
|
||||
- "CLOUD_SQL_MYSQL_DATABASE=$_DATABASE_NAME"
|
||||
- "CLOUD_SQL_MYSQL_REGION=$_REGION"
|
||||
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
|
||||
secretEnv:
|
||||
["CLOUD_SQL_MYSQL_USER", "CLOUD_SQL_MYSQL_PASS", "CLIENT_ID"]
|
||||
volumes:
|
||||
- name: "go"
|
||||
path: "/gopath"
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
go test -race -v -tags=integration,cloudsqlmysql ./tests
|
||||
|
||||
availableSecrets:
|
||||
secretManager:
|
||||
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest
|
||||
@@ -163,6 +184,10 @@ availableSecrets:
|
||||
env: CLOUD_SQL_MSSQL_USER
|
||||
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_mssql_pass/versions/latest
|
||||
env: CLOUD_SQL_MSSQL_PASS
|
||||
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_mysql_user/versions/latest
|
||||
env: CLOUD_SQL_MYSQL_USER
|
||||
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_mysql_pass/versions/latest
|
||||
env: CLOUD_SQL_MYSQL_PASS
|
||||
|
||||
options:
|
||||
logging: CLOUD_LOGGING_ONLY
|
||||
@@ -183,3 +208,4 @@ substitutions:
|
||||
_SPANNER_INSTANCE: "spanner-testing"
|
||||
_NEO4J_DATABASE: "neo4j"
|
||||
_CLOUD_SQL_MSSQL_INSTANCE: "cloud-sql-mssql-testing"
|
||||
_CLOUD_SQL_MYSQL_INSTANCE: "cloud-sql-mysql-testing"
|
||||
|
||||
@@ -33,7 +33,10 @@ issues:
|
||||
run:
|
||||
build-tags:
|
||||
- integration
|
||||
- cloudsql
|
||||
- cloudsqlpg
|
||||
- postgres
|
||||
- alloydb
|
||||
- spanner
|
||||
- cloudsqlmssql
|
||||
- cloudsqlmysql
|
||||
- neo4j
|
||||
|
||||
2
go.mod
2
go.mod
@@ -13,7 +13,6 @@ require (
|
||||
github.com/go-chi/chi/v5 v5.1.0
|
||||
github.com/go-chi/httplog/v2 v2.1.1
|
||||
github.com/go-chi/render v1.0.3
|
||||
github.com/go-sql-driver/mysql v1.8.1
|
||||
github.com/goccy/go-yaml v1.15.13
|
||||
github.com/google/go-cmp v0.6.0
|
||||
github.com/google/uuid v1.6.0
|
||||
@@ -56,6 +55,7 @@ require (
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/go-logr/logr v1.4.2 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
|
||||
|
||||
@@ -18,11 +18,11 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"cloud.google.com/go/cloudsqlconn"
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"cloud.google.com/go/cloudsqlconn/mysql/mysql"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
@@ -104,24 +104,22 @@ func initCloudSQLMySQLConnectionPool(ctx context.Context, tracer trace.Tracer, n
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d, err := cloudsqlconn.NewDialer(context.Background())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse connection uri: %w", err)
|
||||
|
||||
if !slices.Contains(sql.Drivers(), "cloudsql-mysql") {
|
||||
_, err = mysql.RegisterDriver("cloudsql-mysql", cloudsqlconn.WithDefaultDialOptions(dialOpts...))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to register driver: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Tell the driver to use the Cloud SQL Go Connector to create connections
|
||||
i := fmt.Sprintf("%s:%s:%s", project, region, instance)
|
||||
mysql.RegisterDialContext("cloudsqlconn", func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
return d.Dial(ctx, i, dialOpts...)
|
||||
})
|
||||
|
||||
// Configure the driver to connect to the database
|
||||
dbURI := fmt.Sprintf("%s:%s@cloudsqlconn(localhost:3306)/%s?parseTime=true", user, pass, dbname)
|
||||
|
||||
// Interact with the driver directly as you normally would
|
||||
pool, err := sql.Open("mysql", dbURI)
|
||||
dsn := fmt.Sprintf("%s:%s@cloudsql-mysql(%s:%s:%s)/%s", user, pass, project, region, instance, dbname)
|
||||
db, err := sql.Open(
|
||||
"cloudsql-mysql",
|
||||
dsn,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sql.Open: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
return pool, nil
|
||||
return db, nil
|
||||
}
|
||||
|
||||
@@ -65,6 +65,8 @@ func RunGoogleAuthenticatedParameterTest(t *testing.T, sourceConfig map[string]a
|
||||
statement = fmt.Sprintf("SELECT * FROM %s WHERE email = $1;", tableName)
|
||||
case strings.EqualFold(toolKind, "mssql-sql"):
|
||||
statement = fmt.Sprintf("SELECT * FROM %s WHERE email = @email;", tableName)
|
||||
case strings.EqualFold(toolKind, "mysql-sql"):
|
||||
statement = fmt.Sprintf("SELECT * FROM %s WHERE email = ?;", tableName)
|
||||
default:
|
||||
t.Fatalf("invalid tool kind: %s", toolKind)
|
||||
}
|
||||
@@ -132,7 +134,7 @@ func RunGoogleAuthenticatedParameterTest(t *testing.T, sourceConfig map[string]a
|
||||
// Tools using database/sql interface only outputs `int64` instead of `int32`
|
||||
var wantString string
|
||||
switch toolKind {
|
||||
case "mssql-sql":
|
||||
case "mssql-sql", "mysql-sql":
|
||||
wantString = fmt.Sprintf("Stub tool call for \"my-auth-tool\"! Parameters parsed: [{\"email\" \"%s\"}] \n Output: [%%!s(int64=1) Alice %s]", SERVICE_ACCOUNT_EMAIL, SERVICE_ACCOUNT_EMAIL)
|
||||
default:
|
||||
wantString = fmt.Sprintf("Stub tool call for \"my-auth-tool\"! Parameters parsed: [{\"email\" \"%s\"}] \n Output: [%%!s(int32=1) Alice %s]", SERVICE_ACCOUNT_EMAIL, SERVICE_ACCOUNT_EMAIL)
|
||||
@@ -216,7 +218,7 @@ func RunAuthRequiredToolInvocationTest(t *testing.T, sourceConfig map[string]any
|
||||
// Tools using database/sql interface only outputs `int64` instead of `int32`
|
||||
var wantString string
|
||||
switch toolKind {
|
||||
case "mssql-sql":
|
||||
case "mssql-sql", "mysql-sql":
|
||||
wantString = "Stub tool call for \"my-auth-tool\"! Parameters parsed: [] \n Output: [%!s(int64=1)]"
|
||||
default:
|
||||
wantString = "Stub tool call for \"my-auth-tool\"! Parameters parsed: [] \n Output: [%!s(int32=1)]"
|
||||
|
||||
@@ -78,24 +78,13 @@ func requireCloudSQLMssqlVars(t *testing.T) map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
func getDialOpts(ipType string) ([]cloudsqlconn.DialOption, error) {
|
||||
switch strings.ToLower(ipType) {
|
||||
case "private":
|
||||
return []cloudsqlconn.DialOption{cloudsqlconn.WithPrivateIP()}, nil
|
||||
case "public":
|
||||
return []cloudsqlconn.DialOption{cloudsqlconn.WithPublicIP()}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid ipType %s", ipType)
|
||||
}
|
||||
}
|
||||
|
||||
// Copied over from cloud_sql_mssql.go
|
||||
func initCloudSQLMssqlConnection(project, region, instance, ipAddress, ipType, user, pass, dbname string) (*sql.DB, error) {
|
||||
// Create dsn
|
||||
dsn := fmt.Sprintf("sqlserver://%s:%s@%s?database=%s&cloudsql=%s:%s:%s", user, pass, ipAddress, dbname, project, region, instance)
|
||||
|
||||
// Get dial options
|
||||
dialOpts, err := getDialOpts(ipType)
|
||||
dialOpts, err := GetCloudSQLDialOpts(ipType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -243,7 +232,7 @@ func TestCloudSQLMssql(t *testing.T) {
|
||||
}
|
||||
|
||||
// Set up tool calling with parameters test table
|
||||
func setupParamTest(t *testing.T, tableName string) (func(*testing.T), error) {
|
||||
func setupCloudSQLMssqlParamTest(t *testing.T, tableName string) (func(*testing.T), error) {
|
||||
// Set up Tool invocation with parameters test
|
||||
db, err := initCloudSQLMssqlConnection(CLOUD_SQL_MSSQL_PROJECT, CLOUD_SQL_MSSQL_REGION, CLOUD_SQL_MSSQL_INSTANCE, CLOUD_SQL_MSSQL_IP, "public", CLOUD_SQL_MSSQL_USER, CLOUD_SQL_MSSQL_PASS, CLOUD_SQL_MSSQL_DATABASE)
|
||||
if err != nil {
|
||||
@@ -285,7 +274,7 @@ func setupParamTest(t *testing.T, tableName string) (func(*testing.T), error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestToolInvocationWithParams(t *testing.T) {
|
||||
func TestToolInvocationCloudSQLMssqlWithParams(t *testing.T) {
|
||||
// create source config
|
||||
sourceConfig := requireCloudSQLMssqlVars(t)
|
||||
|
||||
@@ -293,7 +282,7 @@ func TestToolInvocationWithParams(t *testing.T) {
|
||||
tableName := "param_test_table_" + strings.Replace(uuid.New().String(), "-", "", -1)
|
||||
|
||||
// test setup function reterns teardown function
|
||||
teardownTest, err := setupParamTest(t, tableName)
|
||||
teardownTest, err := setupCloudSQLMssqlParamTest(t, tableName)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to set up auth test: %s", err)
|
||||
}
|
||||
|
||||
299
tests/cloud_sql_mysql_integration_test.go
Normal file
299
tests/cloud_sql_mysql_integration_test.go
Normal file
@@ -0,0 +1,299 @@
|
||||
//go:build integration && cloudsqlmysql
|
||||
|
||||
//
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tests
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cloud.google.com/go/cloudsqlconn"
|
||||
"cloud.google.com/go/cloudsqlconn/mysql/mysql"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var (
|
||||
CLOUD_SQL_MYSQL_PROJECT = os.Getenv("CLOUD_SQL_MYSQL_PROJECT")
|
||||
CLOUD_SQL_MYSQL_REGION = os.Getenv("CLOUD_SQL_MYSQL_REGION")
|
||||
CLOUD_SQL_MYSQL_INSTANCE = os.Getenv("CLOUD_SQL_MYSQL_INSTANCE")
|
||||
CLOUD_SQL_MYSQL_DATABASE = os.Getenv("CLOUD_SQL_MYSQL_DATABASE")
|
||||
CLOUD_SQL_MYSQL_USER = os.Getenv("CLOUD_SQL_MYSQL_USER")
|
||||
CLOUD_SQL_MYSQL_PASS = os.Getenv("CLOUD_SQL_MYSQL_PASS")
|
||||
)
|
||||
|
||||
func requireCloudSQLMySQLVars(t *testing.T) map[string]any {
|
||||
switch "" {
|
||||
case CLOUD_SQL_MYSQL_PROJECT:
|
||||
t.Fatal("'CLOUD_SQL_MYSQL_PROJECT' not set")
|
||||
case CLOUD_SQL_MYSQL_REGION:
|
||||
t.Fatal("'CLOUD_SQL_MYSQL_REGION' not set")
|
||||
case CLOUD_SQL_MYSQL_INSTANCE:
|
||||
t.Fatal("'CLOUD_SQL_MYSQL_INSTANCE' not set")
|
||||
case CLOUD_SQL_MYSQL_DATABASE:
|
||||
t.Fatal("'CLOUD_SQL_MYSQL_DATABASE' not set")
|
||||
case CLOUD_SQL_MYSQL_USER:
|
||||
t.Fatal("'CLOUD_SQL_MYSQL_USER' not set")
|
||||
case CLOUD_SQL_MYSQL_PASS:
|
||||
t.Fatal("'CLOUD_SQL_MYSQL_PASS' not set")
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"kind": "cloud-sql-mysql",
|
||||
"project": CLOUD_SQL_MYSQL_PROJECT,
|
||||
"instance": CLOUD_SQL_MYSQL_INSTANCE,
|
||||
"region": CLOUD_SQL_MYSQL_REGION,
|
||||
"database": CLOUD_SQL_MYSQL_DATABASE,
|
||||
"user": CLOUD_SQL_MYSQL_USER,
|
||||
"password": CLOUD_SQL_MYSQL_PASS,
|
||||
}
|
||||
}
|
||||
|
||||
// Copied over from cloud_sql_mysql.go
|
||||
func initCloudSQLMySQLConnectionPool(project, region, instance, ipType, user, pass, dbname string) (*sql.DB, error) {
|
||||
|
||||
// Create a new dialer with options
|
||||
dialOpts, err := GetCloudSQLDialOpts(ipType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !slices.Contains(sql.Drivers(), "cloudsql-mysql") {
|
||||
_, err = mysql.RegisterDriver("cloudsql-mysql", cloudsqlconn.WithDefaultDialOptions(dialOpts...))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to register driver: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Tell the driver to use the Cloud SQL Go Connector to create connections
|
||||
dsn := fmt.Sprintf("%s:%s@cloudsql-mysql(%s:%s:%s)/%s", user, pass, project, region, instance, dbname)
|
||||
db, err := sql.Open(
|
||||
"cloudsql-mysql",
|
||||
dsn,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func TestCloudSQLMySQL(t *testing.T) {
|
||||
sourceConfig := requireCloudSQLMySQLVars(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var args []string
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := map[string]any{
|
||||
"sources": map[string]any{
|
||||
"my-mysql-instance": sourceConfig,
|
||||
},
|
||||
"tools": map[string]any{
|
||||
"my-simple-tool": map[string]any{
|
||||
"kind": "mysql-sql",
|
||||
"source": "my-mysql-instance",
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"statement": "SELECT 1;",
|
||||
},
|
||||
},
|
||||
}
|
||||
cmd, cleanup, err := StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
|
||||
// Test tool get endpoint
|
||||
tcs := []struct {
|
||||
name string
|
||||
api string
|
||||
want map[string]any
|
||||
}{
|
||||
{
|
||||
name: "get my-simple-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/",
|
||||
want: map[string]any{
|
||||
"my-simple-tool": map[string]any{
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"parameters": []any{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp, err := http.Get(tc.api)
|
||||
if err != nil {
|
||||
t.Fatalf("error when sending a request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
|
||||
got, ok := body["tools"]
|
||||
if !ok {
|
||||
t.Fatalf("unable to find tools in response body")
|
||||
}
|
||||
if !reflect.DeepEqual(got, tc.want) {
|
||||
t.Fatalf("got %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestBody io.Reader
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "invoke my-simple-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
want: "Stub tool call for \"my-simple-tool\"! Parameters parsed: [] \n Output: [%!s(int64=1)]",
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp, err := http.Post(tc.api, "application/json", tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("error when sending a request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
got, ok := body["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
|
||||
if got != tc.want {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Set up auth test database table
|
||||
func setupCloudSQLMySQLAuthTest(t *testing.T, ctx context.Context, tableName string) func(*testing.T) {
|
||||
// set up testt
|
||||
pool, err := initCloudSQLMySQLConnectionPool(CLOUD_SQL_MYSQL_PROJECT, CLOUD_SQL_MYSQL_REGION, CLOUD_SQL_MYSQL_INSTANCE, "public", CLOUD_SQL_MYSQL_USER, CLOUD_SQL_MYSQL_PASS, CLOUD_SQL_MYSQL_DATABASE)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create Cloud SQL connection pool: %s", err)
|
||||
}
|
||||
|
||||
err = pool.PingContext(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to connect to test database: %s", err)
|
||||
}
|
||||
|
||||
_, err = pool.QueryContext(ctx, fmt.Sprintf(`
|
||||
CREATE TABLE %s (
|
||||
id INT NOT NULL AUTO_INCREMENT PRIMARY KEY,
|
||||
name VARCHAR(255),
|
||||
email VARCHAR(255)
|
||||
);
|
||||
`, tableName))
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create test table: %s", err)
|
||||
}
|
||||
|
||||
// Insert test data
|
||||
statement := fmt.Sprintf(`
|
||||
INSERT INTO %s (name, email)
|
||||
VALUES (?, ?), (?, ?)
|
||||
`, tableName)
|
||||
|
||||
params := []any{"Alice", SERVICE_ACCOUNT_EMAIL, "Jane", "janedoe@gmail.com"}
|
||||
_, err = pool.QueryContext(ctx, statement, params...)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to insert test data: %s", err)
|
||||
}
|
||||
|
||||
return func(t *testing.T) {
|
||||
// tear down test
|
||||
_, err := pool.ExecContext(ctx, fmt.Sprintf(`DROP TABLE %s;`, tableName))
|
||||
if err != nil {
|
||||
t.Errorf("Teardown failed: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloudSQLMySQLGoogleAuthenticatedParameter(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// create test configs
|
||||
sourceConfig := requireCloudSQLMySQLVars(t)
|
||||
|
||||
// create table name with UUID
|
||||
tableName := "auth_table_" + strings.Replace(uuid.New().String(), "-", "", -1)
|
||||
|
||||
// test setup function reterns teardown function
|
||||
teardownTest := setupCloudSQLMySQLAuthTest(t, ctx, tableName)
|
||||
defer teardownTest(t)
|
||||
|
||||
// call generic auth test helper
|
||||
RunGoogleAuthenticatedParameterTest(t, sourceConfig, "mysql-sql", tableName)
|
||||
|
||||
}
|
||||
|
||||
func TestCloudSQLMySQLAuthRequiredToolInvocation(t *testing.T) {
|
||||
// create test configs
|
||||
sourceConfig := requireCloudSQLMySQLVars(t)
|
||||
|
||||
// call generic auth test helper
|
||||
RunAuthRequiredToolInvocationTest(t, sourceConfig, "mysql-sql")
|
||||
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build integration && cloudsql
|
||||
//go:build integration && cloudsqlpg
|
||||
|
||||
//
|
||||
// Copyright 2024 Google LLC
|
||||
@@ -73,18 +73,6 @@ func requireCloudSQLPgVars(t *testing.T) map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
// Copied over from cloud_sql_pg.go
|
||||
func getCloudSQLDialOpts(ip_type string) ([]cloudsqlconn.DialOption, error) {
|
||||
switch strings.ToLower(ip_type) {
|
||||
case "private":
|
||||
return []cloudsqlconn.DialOption{cloudsqlconn.WithPrivateIP()}, nil
|
||||
case "public":
|
||||
return []cloudsqlconn.DialOption{cloudsqlconn.WithPublicIP()}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid ip_type %s", ip_type)
|
||||
}
|
||||
}
|
||||
|
||||
// Copied over from cloud_sql_pg.go
|
||||
func initCloudSQLPgConnectionPool(project, region, instance, ip_type, user, pass, dbname string) (*pgxpool.Pool, error) {
|
||||
// Configure the driver to connect to the database
|
||||
@@ -95,7 +83,7 @@ func initCloudSQLPgConnectionPool(project, region, instance, ip_type, user, pass
|
||||
}
|
||||
|
||||
// Create a new dialer with options
|
||||
dialOpts, err := getCloudSQLDialOpts(ip_type)
|
||||
dialOpts, err := GetCloudSQLDialOpts(ip_type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -34,6 +34,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cloud.google.com/go/cloudsqlconn"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/cmd"
|
||||
@@ -380,3 +381,15 @@ func RunSourceConnectionTest(t *testing.T, sourceConfig map[string]any, toolKind
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// GetCloudSQLDialOpts returns cloud sql connector's dial option for ip type.
|
||||
func GetCloudSQLDialOpts(ipType string) ([]cloudsqlconn.DialOption, error) {
|
||||
switch strings.ToLower(ipType) {
|
||||
case "private":
|
||||
return []cloudsqlconn.DialOption{cloudsqlconn.WithPrivateIP()}, nil
|
||||
case "public":
|
||||
return []cloudsqlconn.DialOption{cloudsqlconn.WithPublicIP()}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid ipType %s", ipType)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user