ci: Add Spanner integration tests (#441)

Add Spanner integration tests
This commit is contained in:
Wenxin Du
2025-04-16 17:04:50 -04:00
committed by GitHub
parent 4dba0df12d
commit 29560d66a0
5 changed files with 153 additions and 103 deletions

View File

@@ -115,6 +115,8 @@ steps:
- "SPANNER_PROJECT=$PROJECT_ID"
- "SPANNER_DATABASE=$_DATABASE_NAME"
- "SPANNER_INSTANCE=$_SPANNER_INSTANCE"
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
secretEnv: ["CLIENT_ID"]
volumes:
- name: "go"
path: "/gopath"

1
go.mod
View File

@@ -41,6 +41,7 @@ require (
cloud.google.com/go/auth v0.15.0 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
cloud.google.com/go/compute/metadata v0.6.0 // indirect
cloud.google.com/go/iam v1.4.2 // indirect
cloud.google.com/go/longrunning v0.6.6 // indirect
cloud.google.com/go/monitoring v1.24.1 // indirect
cloud.google.com/go/trace v1.11.3 // indirect

View File

@@ -249,3 +249,26 @@ func GetMysqlLAuthToolInfo(tableName string) (string, string, string, []any) {
params := []any{"Alice", SERVICE_ACCOUNT_EMAIL, "Jane", "janedoe@gmail.com"}
return create_statement, insert_statement, tool_statement, params
}
// GetSpannerToolInfo returns statements and param for my-param-tool for spanner-sql kind
func GetSpannerParamToolInfo(tableName string) (string, string, string, map[string]any) {
create_statement := fmt.Sprintf("CREATE TABLE %s (id INT64, name STRING(MAX)) PRIMARY KEY (id)", tableName)
insert_statement := fmt.Sprintf("INSERT INTO %s (id, name) VALUES (1, @name1), (2, @name2), (3, @name3)", tableName)
tool_statement := fmt.Sprintf("SELECT * FROM %s WHERE id = @id OR name = @name", tableName)
params := map[string]any{"name1": "Alice", "name2": "Jane", "name3": "Sid"}
return create_statement, insert_statement, tool_statement, params
}
// GetSpannerAuthToolInfo returns statements and param of my-auth-tool for spanner-sql kind
func GetSpannerAuthToolInfo(tableName string) (string, string, string, map[string]any) {
create_statement := fmt.Sprintf("CREATE TABLE %s (id INT64, name STRING(MAX), email STRING(MAX)) PRIMARY KEY (id)", tableName)
insert_statement := fmt.Sprintf("INSERT INTO %s (id, name, email) VALUES (1, @name1, @email1), (2, @name2, @email2)", tableName)
tool_statement := fmt.Sprintf("SELECT name FROM %s WHERE email = @email", tableName)
params := map[string]any{
"name1": "Alice",
"email1": SERVICE_ACCOUNT_EMAIL,
"name2": "Jane",
"email2": "janedoe@gmail.com",
}
return create_statement, insert_statement, tool_statement, params
}

View File

@@ -17,16 +17,17 @@
package tests
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"fmt"
"os"
"reflect"
"regexp"
"strings"
"testing"
"time"
"cloud.google.com/go/spanner"
database "cloud.google.com/go/spanner/admin/database/apiv1"
"github.com/google/uuid"
)
var (
@@ -55,6 +56,33 @@ func getSpannerVars(t *testing.T) map[string]any {
}
}
func initSpannerClients(ctx context.Context, project, instance, dbname string) (*spanner.Client, *database.DatabaseAdminClient, error) {
// Configure the connection to the database
db := fmt.Sprintf("projects/%s/instances/%s/databases/%s", project, instance, dbname)
// Configure session pool to automatically clean inactive transactions
sessionPoolConfig := spanner.SessionPoolConfig{
TrackSessionHandles: true,
InactiveTransactionRemovalOptions: spanner.InactiveTransactionRemovalOptions{
ActionOnInactiveTransaction: spanner.WarnAndClose,
},
}
// Create Spanner client (for queries)
dataClient, err := spanner.NewClientWithConfig(context.Background(), db, spanner.ClientConfig{SessionPoolConfig: sessionPoolConfig})
if err != nil {
return nil, nil, fmt.Errorf("unable to create new Spanner client: %w", err)
}
// Create Spanner admin client (for creating databases)
adminClient, err := database.NewDatabaseAdminClient(ctx)
if err != nil {
return nil, nil, fmt.Errorf("unable to create new Spanner admin client: %w", err)
}
return dataClient, adminClient, nil
}
func TestSpannerToolEndpoints(t *testing.T) {
sourceConfig := getSpannerVars(t)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
@@ -62,20 +90,35 @@ func TestSpannerToolEndpoints(t *testing.T) {
var args []string
// Write config into a file and pass it to command
toolsFile := map[string]any{
"sources": map[string]any{
"my-spanner-instance": sourceConfig,
},
"tools": map[string]any{
"my-simple-tool": map[string]any{
"kind": "spanner-sql",
"source": "my-spanner-instance",
"description": "Simple tool to test end to end functionality.",
"statement": "SELECT 1;",
},
},
// Create Spanner client
dataClient, adminClient, err := initSpannerClients(ctx, SPANNER_PROJECT, SPANNER_INSTANCE, SPANNER_DATABASE)
if err != nil {
t.Fatalf("unable to create Spanner client: %s", err)
}
// create table name with UUID
tableNameParam := "param_table_" + strings.Replace(uuid.New().String(), "-", "", -1)
tableNameAuth := "auth_table_" + strings.Replace(uuid.New().String(), "-", "", -1)
// set up data for param tool
create_statement1, insert_statement1, tool_statement1, params1 := GetSpannerParamToolInfo(tableNameParam)
dbString := fmt.Sprintf(
"projects/%s/instances/%s/databases/%s",
SPANNER_PROJECT,
SPANNER_INSTANCE,
SPANNER_DATABASE,
)
teardownTable1 := SetupSpannerTable(t, ctx, adminClient, dataClient, create_statement1, insert_statement1, tableNameParam, dbString, params1)
defer teardownTable1(t)
// set up data for auth tool
create_statement2, insert_statement2, tool_statement2, params2 := GetSpannerAuthToolInfo(tableNameAuth)
teardownTable2 := SetupSpannerTable(t, ctx, adminClient, dataClient, create_statement2, insert_statement2, tableNameAuth, dbString, params2)
defer teardownTable2(t)
// Write config into a file and pass it to command
toolsFile := GetToolsConfig(sourceConfig, SPANNER_TOOL_KIND, tool_statement1, tool_statement2)
cmd, cleanup, err := StartCmd(ctx, toolsFile, args...)
if err != nil {
t.Fatalf("command initialization returned an error: %s", err)
@@ -90,88 +133,11 @@ func TestSpannerToolEndpoints(t *testing.T) {
t.Fatalf("toolbox didn't start successfully: %s", err)
}
// Test tool get endpoint
tcs := []struct {
name string
api string
want map[string]any
}{
{
name: "get my-simple-tool",
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/",
want: map[string]any{
"my-simple-tool": map[string]any{
"description": "Simple tool to test end to end functionality.",
"parameters": []any{},
},
},
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
resp, err := http.Get(tc.api)
if err != nil {
t.Fatalf("error when sending a request: %s", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Fatalf("response status code is not 200")
}
RunToolGetTest(t)
var body map[string]interface{}
err = json.NewDecoder(resp.Body).Decode(&body)
if err != nil {
t.Fatalf("error parsing response body")
}
select_1_want := "[{\"\":\"1\"}]"
fail_invocation_want := `"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute client: unable to parse row: spanner: code = "InvalidArgument", desc = "Syntax error: Unexpected identifier \"SELEC\" [at 1:1]\nSELEC 1;\n^"`
got, ok := body["tools"]
if !ok {
t.Fatalf("unable to find tools in response body")
}
if !reflect.DeepEqual(got, tc.want) {
t.Fatalf("got %q, want %q", got, tc.want)
}
})
}
// Test tool invoke endpoint
invokeTcs := []struct {
name string
api string
requestBody io.Reader
want string
}{
{
name: "invoke my-simple-tool",
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
requestBody: bytes.NewBuffer([]byte(`{}`)),
want: "[{\"\":\"1\"}]",
},
}
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
resp, err := http.Post(tc.api, "application/json", tc.requestBody)
if err != nil {
t.Fatalf("error when sending a request: %s", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Fatalf("response status code is not 200")
}
var body map[string]interface{}
err = json.NewDecoder(resp.Body).Decode(&body)
if err != nil {
t.Fatalf("error parsing response body")
}
got, ok := body["result"].(string)
if !ok {
t.Fatalf("unable to find result in response body")
}
if got != tc.want {
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
}
})
}
RunToolInvokeTest(t, select_1_want)
RunMCPToolCallMethod(t, fail_invocation_want)
}

View File

@@ -28,6 +28,9 @@ import (
"strings"
"testing"
"cloud.google.com/go/spanner"
database "cloud.google.com/go/spanner/admin/database/apiv1"
"cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
"github.com/googleapis/genai-toolbox/internal/server/mcp"
"github.com/jackc/pgx/v5/pgxpool"
)
@@ -119,6 +122,54 @@ func SetupMySQLTable(t *testing.T, ctx context.Context, pool *sql.DB, create_sta
}
}
// SetupSpannerTable creates and inserts data into a table of tool
// compatible with spanner-sql tool
func SetupSpannerTable(t *testing.T, ctx context.Context, adminClient *database.DatabaseAdminClient, dataClient *spanner.Client, create_statement, insert_statement, tableName, dbString string, params map[string]any) func(*testing.T) {
// Create table
op, err := adminClient.UpdateDatabaseDdl(ctx, &databasepb.UpdateDatabaseDdlRequest{
Database: dbString,
Statements: []string{create_statement},
})
if err != nil {
t.Fatalf("unable to start create table operation %s: %s", tableName, err)
}
err = op.Wait(ctx)
if err != nil {
t.Fatalf("unable to create test table %s: %s", tableName, err)
}
// Insert test data
_, err = dataClient.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
stmt := spanner.Statement{
SQL: insert_statement,
Params: params,
}
_, err := txn.Update(ctx, stmt)
return err
})
if err != nil {
t.Fatalf("unable to insert test data: %s", err)
}
return func(t *testing.T) {
// tear down test
op, err = adminClient.UpdateDatabaseDdl(ctx, &databasepb.UpdateDatabaseDdlRequest{
Database: dbString,
Statements: []string{fmt.Sprintf("DROP TABLE %s", tableName)},
})
if err != nil {
t.Errorf("unable to start drop table operation: %s", err)
return
}
err = op.Wait(ctx)
if err != nil {
t.Errorf("Teardown failed: %s", err)
}
}
}
// RunToolGet runs the tool get endpoint
func RunToolGetTest(t *testing.T) {
// Test tool get endpoint
@@ -294,11 +345,13 @@ func RunToolInvokeTest(t *testing.T, select_1_want string) {
if !ok {
t.Fatalf("unable to find result in response body")
}
// Remove `\` and `"` for string comparison
got = strings.ReplaceAll(got, "\\", "")
want := strings.ReplaceAll(tc.want, "\\", "")
got = strings.ReplaceAll(got, "\"", "")
want = strings.ReplaceAll(want, "\"", "")
if got != want {
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
}
@@ -431,9 +484,14 @@ func RunMCPToolCallMethod(t *testing.T, fail_invocation_want string) {
defer resp.Body.Close()
got := string(bytes.TrimSpace(respBody))
if got != tc.want {
fmt.Printf("res is %s\n\n", got)
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
// Remove `\` and `"` for string comparison
got = strings.ReplaceAll(got, "\\", "")
want := strings.ReplaceAll(tc.want, "\\", "")
got = strings.ReplaceAll(got, "\"", "")
want = strings.ReplaceAll(want, "\"", "")
if !strings.Contains(got, want) {
t.Fatalf("Expected substring not found:\ngot: %q\nwant: %q (to be contained within got)", got, want)
}
})
}