Mirror of https://github.com/googleapis/genai-toolbox.git, synced 2026-02-11 15:45:06 -05:00.
This PR introduces a significant update to the Toolbox configuration file format, one of the primary **breaking changes** required for the implementation of the Advanced Control Plane.

# Summary of Changes

The configuration schema has been updated to enforce resource isolation and to facilitate atomic, incremental updates.

* Resource Isolation: Resource definitions are now separated into individual blocks, with a distinct structure for each resource type (Source, Tool, Toolset, etc.). This improves the readability, management, and auditing of configuration files.
* Field Name Modification: Internal field names have been modified to align with declarative methodologies. Specifically, the configuration now separates kind (the general resource type, e.g., Source) from type (the specific implementation, e.g., Postgres).

# User Impact

Existing tools.yaml configuration files are now in an outdated format. Users must eventually update their files to the new YAML format.

# Mitigation & Compatibility

Backward compatibility is maintained during this transition, so no immediate user action is required for existing files.

* Immediate Backward Compatibility: The source code includes a pre-processing layer that automatically detects outdated configuration files (v1 format) and converts them to the new v2 format under the hood.
* [COMING SOON] Migration Support: The new toolbox migrate subcommand will be introduced to allow users to automatically convert their old configuration files to the latest format.

# Example

Example config file in the v2 format:

```yaml
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
---
kind: authServices
name: my-google-auth
type: google
clientId: testing-id
---
kind: tools
name: example_tool
type: postgres-sql
source: my-pg-instance
description: some description
statement: SELECT * FROM SQL_STATEMENT;
parameters:
  - name: country
    type: string
    description: some description
---
kind: tools
name: example_tool_2
type: postgres-sql
source: my-pg-instance
description: returning the number one
statement: SELECT 1;
---
kind: toolsets
name: example_toolset
tools:
  - example_tool
```

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Averi Kitsch <akitsch@google.com>
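For contrast with the v2 example above, here is a hypothetical sketch of the same resources in the older v1-style layout: a single document with resources nested under top-level maps, keyed by name, mirroring the map structure the integration test below builds programmatically. The exact v1 field names are assumptions for illustration, not a definitive schema.

```yaml
# Hypothetical v1-style layout (assumed field names; the real v1 schema may differ).
sources:
  my-pg-instance:
    type: cloud-sql-postgres
    project: my-project
    region: my-region
    instance: my-instance
    database: my_db
    user: my_user
    password: my_pass
authServices:
  my-google-auth:
    type: google
    clientId: testing-id
tools:
  example_tool:
    type: postgres-sql
    source: my-pg-instance
    description: some description
    statement: SELECT * FROM SQL_STATEMENT;
toolsets:
  example_toolset:
    - example_tool
```

The pre-processing layer described above would detect this single-document shape and convert it to the multi-document v2 format shown in the example.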
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package serverlessspark

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"maps"
	"net/http"
	"os"
	"reflect"
	"regexp"
	"slices"
	"strings"
	"testing"
	"time"

	dataproc "cloud.google.com/go/dataproc/v2/apiv1"
	"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
	"github.com/google/go-cmp/cmp"
	"github.com/googleapis/genai-toolbox/internal/sources/serverlessspark"
	"github.com/googleapis/genai-toolbox/internal/testutils"
	"github.com/googleapis/genai-toolbox/tests"
	"google.golang.org/api/iterator"
	"google.golang.org/api/option"
	"google.golang.org/protobuf/encoding/protojson"
	"google.golang.org/protobuf/testing/protocmp"
)

var (
	serverlessSparkLocation       = os.Getenv("SERVERLESS_SPARK_LOCATION")
	serverlessSparkProject        = os.Getenv("SERVERLESS_SPARK_PROJECT")
	serverlessSparkServiceAccount = os.Getenv("SERVERLESS_SPARK_SERVICE_ACCOUNT")
)

const (
	batchURLPrefix = "https://console.cloud.google.com/dataproc/batches/"
	logsURLPrefix  = "https://console.cloud.google.com/logs/viewer?"
)

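// getServerlessSparkVars returns the serverless-spark source configuration from
// environment variables, failing the test if any required variable is unset.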
func getServerlessSparkVars(t *testing.T) map[string]any {
	switch "" {
	case serverlessSparkLocation:
		t.Fatal("'SERVERLESS_SPARK_LOCATION' not set")
	case serverlessSparkProject:
		t.Fatal("'SERVERLESS_SPARK_PROJECT' not set")
	case serverlessSparkServiceAccount:
		t.Fatal("'SERVERLESS_SPARK_SERVICE_ACCOUNT' not set")
	}

	return map[string]any{
		"type":     "serverless-spark",
		"project":  serverlessSparkProject,
		"location": serverlessSparkLocation,
	}
}

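// TestServerlessSparkToolEndpoints starts the toolbox server with serverless-spark
// tools configured and exercises their HTTP invocation endpoints end to end.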
func TestServerlessSparkToolEndpoints(t *testing.T) {
	sourceConfig := getServerlessSparkVars(t)
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
	defer cancel()

	toolsFile := map[string]any{
		"sources": map[string]any{
			"my-spark": sourceConfig,
		},
		"authServices": map[string]any{
			"my-google-auth": map[string]any{
				"type":     "google",
				"clientId": tests.ClientId,
			},
		},
		"tools": map[string]any{
			"list-batches": map[string]any{
				"type":   "serverless-spark-list-batches",
				"source": "my-spark",
			},
			"list-batches-with-auth": map[string]any{
				"type":         "serverless-spark-list-batches",
				"source":       "my-spark",
				"authRequired": []string{"my-google-auth"},
			},
			"get-batch": map[string]any{
				"type":   "serverless-spark-get-batch",
				"source": "my-spark",
			},
			"get-batch-with-auth": map[string]any{
				"type":         "serverless-spark-get-batch",
				"source":       "my-spark",
				"authRequired": []string{"my-google-auth"},
			},
			"cancel-batch": map[string]any{
				"type":   "serverless-spark-cancel-batch",
				"source": "my-spark",
			},
			"cancel-batch-with-auth": map[string]any{
				"type":         "serverless-spark-cancel-batch",
				"source":       "my-spark",
				"authRequired": []string{"my-google-auth"},
			},
			"create-pyspark-batch": map[string]any{
				"type":   "serverless-spark-create-pyspark-batch",
				"source": "my-spark",
				"environmentConfig": map[string]any{
					"executionConfig": map[string]any{
						"serviceAccount": serverlessSparkServiceAccount,
					},
				},
			},
			"create-pyspark-batch-2-3": map[string]any{
				"type":          "serverless-spark-create-pyspark-batch",
				"source":        "my-spark",
				"runtimeConfig": map[string]any{"version": "2.3"},
				"environmentConfig": map[string]any{
					"executionConfig": map[string]any{
						"serviceAccount": serverlessSparkServiceAccount,
					},
				},
			},
			"create-pyspark-batch-with-auth": map[string]any{
				"type":         "serverless-spark-create-pyspark-batch",
				"source":       "my-spark",
				"authRequired": []string{"my-google-auth"},
			},
			"create-spark-batch": map[string]any{
				"type":   "serverless-spark-create-spark-batch",
				"source": "my-spark",
				"environmentConfig": map[string]any{
					"executionConfig": map[string]any{
						"serviceAccount": serverlessSparkServiceAccount,
					},
				},
			},
			"create-spark-batch-2-3": map[string]any{
				"type":          "serverless-spark-create-spark-batch",
				"source":        "my-spark",
				"runtimeConfig": map[string]any{"version": "2.3"},
				"environmentConfig": map[string]any{
					"executionConfig": map[string]any{
						"serviceAccount": serverlessSparkServiceAccount,
					},
				},
			},
			"create-spark-batch-with-auth": map[string]any{
				"type":         "serverless-spark-create-spark-batch",
				"source":       "my-spark",
				"authRequired": []string{"my-google-auth"},
			},
		},
	}

	cmd, cleanup, err := tests.StartCmd(ctx, toolsFile)
	if err != nil {
		t.Fatalf("command initialization returned an error: %s", err)
	}
	defer cleanup()

	waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()
	out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
	if err != nil {
		t.Logf("toolbox command logs: \n%s", out)
		t.Fatalf("toolbox didn't start successfully: %s", err)
	}

	endpoint := fmt.Sprintf("%s-dataproc.googleapis.com:443", serverlessSparkLocation)
	client, err := dataproc.NewBatchControllerClient(ctx, option.WithEndpoint(endpoint))
	if err != nil {
		t.Fatalf("failed to create dataproc client: %v", err)
	}
	defer client.Close()

	t.Run("list-batches", func(t *testing.T) {
		// list-batches is sensitive to state changes, so this test must run sequentially.
		t.Run("success", func(t *testing.T) {
			runListBatchesTest(t, client, ctx)
		})
		t.Run("errors", func(t *testing.T) {
			t.Parallel()
			tcs := []struct {
				name     string
				toolName string
				request  map[string]any
				wantCode int
				wantMsg  string
			}{
				{
					name:     "zero page size",
					toolName: "list-batches",
					request:  map[string]any{"pageSize": 0},
					wantCode: http.StatusBadRequest,
					wantMsg:  "pageSize must be positive: 0",
				},
				{
					name:     "negative page size",
					toolName: "list-batches",
					request:  map[string]any{"pageSize": -1},
					wantCode: http.StatusBadRequest,
					wantMsg:  "pageSize must be positive: -1",
				},
			}
			for _, tc := range tcs {
				t.Run(tc.name, func(t *testing.T) {
					t.Parallel()
					testError(t, tc.toolName, tc.request, tc.wantCode, tc.wantMsg)
				})
			}
		})
		t.Run("auth", func(t *testing.T) {
			t.Parallel()
			runAuthTest(t, "list-batches-with-auth", map[string]any{"pageSize": 1}, http.StatusOK)
		})
	})

	// The following tool tests are independent and can run in parallel with each other.
	t.Run("parallel-tool-tests", func(t *testing.T) {
		t.Run("get-batch", func(t *testing.T) {
			t.Parallel()
			fullName := listBatchesRpc(t, client, ctx, "", 1, true)[0].Name
			t.Run("success", func(t *testing.T) {
				t.Parallel()
				runGetBatchTest(t, client, ctx, fullName)
			})
			t.Run("errors", func(t *testing.T) {
				t.Parallel()
				missingBatchFullName := fmt.Sprintf("projects/%s/locations/%s/batches/INVALID_BATCH", serverlessSparkProject, serverlessSparkLocation)
				tcs := []struct {
					name     string
					toolName string
					request  map[string]any
					wantCode int
					wantMsg  string
				}{
					{
						name:     "missing batch",
						toolName: "get-batch",
						request:  map[string]any{"name": "INVALID_BATCH"},
						wantCode: http.StatusBadRequest,
						wantMsg:  fmt.Sprintf("Not found: Batch projects/%s/locations/%s/batches/INVALID_BATCH", serverlessSparkProject, serverlessSparkLocation),
					},
					{
						name:     "full batch name",
						toolName: "get-batch",
						request:  map[string]any{"name": missingBatchFullName},
						wantCode: http.StatusBadRequest,
						wantMsg:  fmt.Sprintf("name must be a short batch name without '/': %s", missingBatchFullName),
					},
				}
				for _, tc := range tcs {
					t.Run(tc.name, func(t *testing.T) {
						t.Parallel()
						testError(t, tc.toolName, tc.request, tc.wantCode, tc.wantMsg)
					})
				}
			})
			t.Run("auth", func(t *testing.T) {
				t.Parallel()
				runAuthTest(t, "get-batch-with-auth", map[string]any{"name": shortName(fullName)}, http.StatusOK)
			})
		})

		t.Run("create-pyspark-batch", func(t *testing.T) {
			t.Parallel()

			t.Run("success", func(t *testing.T) {
				t.Parallel()
				piPy := "file:///usr/lib/spark/examples/src/main/python/pi.py"
				tcs := []struct {
					name           string
					toolName       string
					request        map[string]any
					waitForSuccess bool
					validate       func(t *testing.T, b *dataprocpb.Batch)
				}{
					{
						name:           "no params",
						toolName:       "create-pyspark-batch",
						waitForSuccess: true,
						request:        map[string]any{"mainFile": piPy},
					},
					// Tests below are just verifying options are set correctly on created batches,
					// they don't need to wait for success.
					{
						name:     "with arg",
						toolName: "create-pyspark-batch",
						request:  map[string]any{"mainFile": piPy, "args": []string{"100"}},
						validate: func(t *testing.T, b *dataprocpb.Batch) {
							if !cmp.Equal(b.GetPysparkBatch().Args, []string{"100"}) {
								t.Errorf("unexpected args: got %v, want %v", b.GetPysparkBatch().Args, []string{"100"})
							}
						},
					},
					{
						name:     "version",
						toolName: "create-pyspark-batch",
						request:  map[string]any{"mainFile": piPy, "version": "2.2"},
						validate: func(t *testing.T, b *dataprocpb.Batch) {
							v := b.GetRuntimeConfig().GetVersion()
							if v != "2.2" {
								t.Errorf("unexpected version: got %v, want 2.2", v)
							}
						},
					},
					{
						name:     "version param overrides tool",
						toolName: "create-pyspark-batch-2-3",
						request:  map[string]any{"mainFile": piPy, "version": "2.2"},
						validate: func(t *testing.T, b *dataprocpb.Batch) {
							v := b.GetRuntimeConfig().GetVersion()
							if v != "2.2" {
								t.Errorf("unexpected version: got %v, want 2.2", v)
							}
						},
					},
				}
				for _, tc := range tcs {
					t.Run(tc.name, func(t *testing.T) {
						t.Parallel()
						runCreateSparkBatchTest(t, client, ctx, tc.toolName, tc.request, tc.waitForSuccess, tc.validate)
					})
				}
			})

			t.Run("auth", func(t *testing.T) {
				t.Parallel()
				// Batch creation succeeds even with an invalid main file, but will fail quickly once running.
				runAuthTest(t, "create-pyspark-batch-with-auth", map[string]any{"mainFile": "file:///placeholder"}, http.StatusOK)
			})

			t.Run("errors", func(t *testing.T) {
				t.Parallel()
				tcs := []struct {
					name    string
					request map[string]any
					wantMsg string
				}{
					{
						name:    "missing main file",
						request: map[string]any{},
						wantMsg: "parameter \\\"mainFile\\\" is required",
					},
				}
				for _, tc := range tcs {
					t.Run(tc.name, func(t *testing.T) {
						t.Parallel()
						testError(t, "create-pyspark-batch", tc.request, http.StatusBadRequest, tc.wantMsg)
					})
				}
			})
		})

		t.Run("create-spark-batch", func(t *testing.T) {
			t.Parallel()

			t.Run("success", func(t *testing.T) {
				t.Parallel()
				tcs := []struct {
					name           string
					toolName       string
					request        map[string]any
					waitForSuccess bool
					validate       func(t *testing.T, b *dataprocpb.Batch)
				}{
					{
						name:           "main class",
						toolName:       "create-spark-batch",
						waitForSuccess: true,
						request:        javaReq(map[string]any{}),
					},
					{
						// spark-examples.jar doesn't have a Main-Class, so pick an arbitrary other
						// jar that does. Note there's a chance a subminor release of 2.2 will
						// upgrade Spark and its dependencies, causing a failure. If that happens,
						// find the new ivy jar filename and use that. The alternative would be to
						// pin a subminor version, but that's guaranteed to be GC'ed after 1 year,
						// whereas 2.2 is old enough it's unlikely to see a Spark version bump.
						name:           "main jar",
						toolName:       "create-spark-batch",
						waitForSuccess: true,
						request: map[string]any{
							"version":     "2.2",
							"mainJarFile": "file:///usr/lib/spark/jars/ivy-2.5.2.jar",
							"args":        []string{"-version"},
						},
					},
					// Tests below are just verifying options are set correctly on created batches,
					// they don't need to wait for success.
					{
						name:     "with arg",
						toolName: "create-spark-batch",
						request:  javaReq(map[string]any{"args": []string{"100"}}),
						validate: func(t *testing.T, b *dataprocpb.Batch) {
							if !cmp.Equal(b.GetSparkBatch().Args, []string{"100"}) {
								t.Errorf("unexpected args: got %v, want %v", b.GetSparkBatch().Args, []string{"100"})
							}
						},
					},
					{
						name:     "version",
						toolName: "create-spark-batch",
						request:  javaReq(map[string]any{"version": "2.2"}),
						validate: func(t *testing.T, b *dataprocpb.Batch) {
							v := b.GetRuntimeConfig().GetVersion()
							if v != "2.2" {
								t.Errorf("unexpected version: got %v, want 2.2", v)
							}
						},
					},
					{
						name:     "version param overrides tool",
						toolName: "create-spark-batch-2-3",
						request:  javaReq(map[string]any{"version": "2.2"}),
						validate: func(t *testing.T, b *dataprocpb.Batch) {
							v := b.GetRuntimeConfig().GetVersion()
							if v != "2.2" {
								t.Errorf("unexpected version: got %v, want 2.2", v)
							}
						},
					},
				}
				for _, tc := range tcs {
					t.Run(tc.name, func(t *testing.T) {
						t.Parallel()
						runCreateSparkBatchTest(t, client, ctx, tc.toolName, tc.request, tc.waitForSuccess, tc.validate)
					})
				}
			})

			t.Run("auth", func(t *testing.T) {
				t.Parallel()
				// Batch creation succeeds even with an invalid main file, but will fail quickly once running.
				runAuthTest(t, "create-spark-batch-with-auth", map[string]any{"mainJarFile": "file:///placeholder"}, http.StatusOK)
			})

			t.Run("errors", func(t *testing.T) {
				t.Parallel()
				tcs := []struct {
					name    string
					request map[string]any
					wantMsg string
				}{
					{
						name:    "no main jar or main class",
						request: map[string]any{},
						wantMsg: "must provide either mainJarFile or mainClass",
					},
					{
						name: "both main jar and main class",
						request: map[string]any{
							"mainJarFile": "my.jar",
							"mainClass":   "com.example.MyClass",
						},
						wantMsg: "cannot provide both mainJarFile and mainClass",
					},
					{
						name: "main class without jar files",
						request: map[string]any{
							"mainClass": "com.example.MyClass",
						},
						wantMsg: "jarFiles is required when mainClass is provided",
					},
				}
				for _, tc := range tcs {
					t.Run(tc.name, func(t *testing.T) {
						t.Parallel()
						testError(t, "create-spark-batch", tc.request, http.StatusBadRequest, tc.wantMsg)
					})
				}
			})
		})

		t.Run("cancel-batch", func(t *testing.T) {
			t.Parallel()
			t.Run("success", func(t *testing.T) {
				t.Parallel()
				tcs := []struct {
					name         string
					getBatchName func(t *testing.T) string
				}{
					{
						name: "running batch",
						getBatchName: func(t *testing.T) string {
							return createBatch(t, client, ctx)
						},
					},
					{
						name: "succeeded batch",
						getBatchName: func(t *testing.T) string {
							return listBatchesRpc(t, client, ctx, "state = SUCCEEDED", 1, true)[0].Name
						},
					},
				}

				for _, tc := range tcs {
					t.Run(tc.name, func(t *testing.T) {
						t.Parallel()
						runCancelBatchTest(t, client, ctx, tc.getBatchName(t))
					})
				}
			})
			t.Run("errors", func(t *testing.T) {
				t.Parallel()
				// Find a batch that's already completed.
				completedBatchOp := listBatchesRpc(t, client, ctx, "state = SUCCEEDED", 1, true)[0].Operation
				fullOpName := fmt.Sprintf("projects/%s/locations/%s/operations/%s", serverlessSparkProject, serverlessSparkLocation, shortName(completedBatchOp))
				tcs := []struct {
					name     string
					toolName string
					request  map[string]any
					wantCode int
					wantMsg  string
				}{
					{
						name:     "missing op parameter",
						toolName: "cancel-batch",
						request:  map[string]any{},
						wantCode: http.StatusBadRequest,
						wantMsg:  "parameter \\\"operation\\\" is required",
					},
					{
						name:     "nonexistent op",
						toolName: "cancel-batch",
						request:  map[string]any{"operation": "INVALID_OPERATION"},
						wantCode: http.StatusBadRequest,
						wantMsg:  "Operation not found",
					},
					{
						name:     "full op name",
						toolName: "cancel-batch",
						request:  map[string]any{"operation": fullOpName},
						wantCode: http.StatusBadRequest,
						wantMsg:  fmt.Sprintf("operation must be a short operation name without '/': %s", fullOpName),
					},
				}
				for _, tc := range tcs {
					t.Run(tc.name, func(t *testing.T) {
						t.Parallel()
						testError(t, tc.toolName, tc.request, tc.wantCode, tc.wantMsg)
					})
				}
			})
			t.Run("auth", func(t *testing.T) {
				t.Parallel()
				runAuthTest(t, "cancel-batch-with-auth", map[string]any{"operation": "INVALID_OPERATION"}, http.StatusBadRequest)
			})
		})
	})
}

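// waitForBatch polls GetBatch until the batch reaches one of desiredStates, failing
// the test on timeout or when the batch lands in an unexpected terminal state.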
func waitForBatch(t *testing.T, client *dataproc.BatchControllerClient, parentCtx context.Context, batch string, desiredStates []dataprocpb.Batch_State, timeout time.Duration) {
	t.Logf("waiting %s for batch %s to reach one of %v", timeout, batch, desiredStates)
	ctx, cancel := context.WithTimeout(parentCtx, timeout)
	defer cancel()

	start := time.Now()
	lastLog := start
	for {
		select {
		case <-ctx.Done():
			t.Fatalf("timed out waiting for batch %s to reach one of %v", batch, desiredStates)
		default:
		}

		getReq := &dataprocpb.GetBatchRequest{Name: batch}
		batch, err := client.GetBatch(ctx, getReq)
		if err != nil {
			t.Fatalf("failed to get batch %s: %v", batch, err)
		}

		now := time.Now()
		if now.Sub(lastLog) >= 30*time.Second {
			t.Logf("%s: batch %s is in state %s after %s", t.Name(), batch.Name, batch.State, now.Sub(start))
			lastLog = now
		}

		if slices.Contains(desiredStates, batch.State) {
			return
		}

		if batch.State == dataprocpb.Batch_FAILED || batch.State == dataprocpb.Batch_CANCELLED || batch.State == dataprocpb.Batch_SUCCEEDED {
			t.Fatalf("batch op %s is in a terminal state %s, but wanted one of %v. State message: %s", batch.Name, batch.State, desiredStates, batch.StateMessage)
		}
		time.Sleep(2 * time.Second)
	}
}

// createBatch creates a test batch and immediately returns the batch name, without waiting for the
// batch to start or complete.
func createBatch(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context) string {
	parent := fmt.Sprintf("projects/%s/locations/%s", serverlessSparkProject, serverlessSparkLocation)
	req := &dataprocpb.CreateBatchRequest{
		Parent: parent,
		Batch: &dataprocpb.Batch{
			BatchConfig: &dataprocpb.Batch_SparkBatch{
				SparkBatch: &dataprocpb.SparkBatch{
					Driver: &dataprocpb.SparkBatch_MainClass{
						MainClass: "org.apache.spark.examples.SparkPi",
					},
					JarFileUris: []string{
						"file:///usr/lib/spark/examples/jars/spark-examples.jar",
					},
					Args: []string{"1000"},
				},
			},
			EnvironmentConfig: &dataprocpb.EnvironmentConfig{
				ExecutionConfig: &dataprocpb.ExecutionConfig{
					ServiceAccount: serverlessSparkServiceAccount,
				},
			},
		},
	}

	createOp, err := client.CreateBatch(ctx, req)
	if err != nil {
		t.Fatalf("failed to create batch: %v", err)
	}
	meta, err := createOp.Metadata()
	if err != nil {
		t.Fatalf("failed to get batch metadata: %v", err)
	}

	// Wait for the batch to become at least PENDING; it typically takes >10s to go from PENDING to
	// RUNNING, giving the cancel batch tests plenty of time to cancel it before it completes.
	waitForBatch(t, client, ctx, meta.Batch, []dataprocpb.Batch_State{dataprocpb.Batch_PENDING, dataprocpb.Batch_RUNNING}, 1*time.Minute)
	return meta.Batch
}

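// runCancelBatchTest cancels the given batch via the cancel-batch tool and, unless
// the batch already succeeded, waits for it to reach CANCELLING or CANCELLED.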
func runCancelBatchTest(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context, batchName string) {
	// First get the batch details directly from the Go proto API.
	batch, err := client.GetBatch(ctx, &dataprocpb.GetBatchRequest{Name: batchName})
	if err != nil {
		t.Fatalf("failed to get batch: %s", err)
	}

	request := map[string]any{"operation": shortName(batch.Operation)}
	resp, err := invokeTool("cancel-batch", request, nil)
	if err != nil {
		t.Fatalf("invokeTool failed: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		bodyBytes, _ := io.ReadAll(resp.Body)
		t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
	}

	if batch.State != dataprocpb.Batch_SUCCEEDED {
		waitForBatch(t, client, ctx, batchName, []dataprocpb.Batch_State{dataprocpb.Batch_CANCELLING, dataprocpb.Batch_CANCELLED}, 2*time.Minute)
	}
}

// runListBatchesTest invokes the running list-batches tool and ensures it returns the correct
// number of results. It can run successfully against any GCP project that contains at least 2 total
// Serverless Spark batches.
func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context) {
	batch2 := listBatchesRpc(t, client, ctx, "", 2, true)
	batch20 := listBatchesRpc(t, client, ctx, "", 20, false)

	tcs := []struct {
		name     string
		filter   string
		pageSize int
		numPages int
		want     []serverlessspark.Batch
	}{
		{name: "one page", pageSize: 2, numPages: 1, want: batch2},
		{name: "two pages", pageSize: 1, numPages: 2, want: batch2},
		{name: "20 batches", pageSize: 20, numPages: 1, want: batch20},
		{name: "omit page size", numPages: 1, want: batch20},
		{
			name:     "filtered",
			filter:   "state = SUCCEEDED",
			pageSize: 2,
			numPages: 1,
			want:     listBatchesRpc(t, client, ctx, "state = SUCCEEDED", 2, true),
		},
		{
			name:     "empty",
			filter:   "state = SUCCEEDED AND state = FAILED",
			pageSize: 1,
			numPages: 1,
			want:     nil,
		},
	}

	for _, tc := range tcs {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			var actual []serverlessspark.Batch
			var pageToken string
			for i := 0; i < tc.numPages; i++ {
				request := map[string]any{
					"filter":    tc.filter,
					"pageToken": pageToken,
				}
				if tc.pageSize > 0 {
					request["pageSize"] = tc.pageSize
				}

				resp, err := invokeTool("list-batches", request, nil)
				if err != nil {
					t.Fatalf("invokeTool failed: %v", err)
				}
				defer resp.Body.Close()

				if resp.StatusCode != http.StatusOK {
					bodyBytes, _ := io.ReadAll(resp.Body)
					t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
				}

				var body map[string]any
				if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
					t.Fatalf("error parsing response body: %v", err)
				}

				result, ok := body["result"].(string)
				if !ok {
					t.Fatalf("unable to find result in response body")
				}

				var listResponse serverlessspark.ListBatchesResponse
				if err := json.Unmarshal([]byte(result), &listResponse); err != nil {
					t.Fatalf("error unmarshalling result: %s", err)
				}
				actual = append(actual, listResponse.Batches...)
				pageToken = listResponse.NextPageToken
			}

			if !reflect.DeepEqual(actual, tc.want) {
				t.Fatalf("unexpected batches: got %+v, want %+v", actual, tc.want)
			}

			// want has URLs because it's created from Batch instances by the same utility function
			// used by the tool internals. Double-check that the URLs are reasonable.
			for _, batch := range tc.want {
				if !strings.HasPrefix(batch.ConsoleURL, batchURLPrefix) {
					t.Errorf("unexpected consoleUrl in batch: %#v", batch)
				}
				if !strings.HasPrefix(batch.LogsURL, logsURLPrefix) {
					t.Errorf("unexpected logsUrl in batch: %#v", batch)
				}
			}
		})
	}
}

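// listBatchesRpc fetches up to n of the most recent batches matching filter directly
// from the Dataproc API and converts them to the tool's Batch representation; when
// exact is set, it requires exactly n results.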
func listBatchesRpc(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context, filter string, n int, exact bool) []serverlessspark.Batch {
	parent := fmt.Sprintf("projects/%s/locations/%s", serverlessSparkProject, serverlessSparkLocation)
	req := &dataprocpb.ListBatchesRequest{
		Parent:   parent,
		PageSize: 2,
		OrderBy:  "create_time desc",
	}
	if filter != "" {
		req.Filter = filter
	}

	it := client.ListBatches(ctx, req)
	pager := iterator.NewPager(it, n, "")
	var batchPbs []*dataprocpb.Batch
	_, err := pager.NextPage(&batchPbs)
	if err != nil {
		t.Fatalf("failed to list batches: %s", err)
	}
	if exact && len(batchPbs) != n {
		t.Fatalf("expected exactly %d batches, got %d", n, len(batchPbs))
	}
	if !exact && (len(batchPbs) == 0 || len(batchPbs) > n) {
		t.Fatalf("expected between 1 and %d batches, got %d", n, len(batchPbs))
	}
	batches, err := serverlessspark.ToBatches(batchPbs)
	if err != nil {
		t.Fatalf("failed to convert batches to JSON: %v", err)
	}

	return batches
}

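// runAuthTest invokes the given tool with a valid, an invalid, and a missing Google
// ID token, expecting wantStatus for the valid token and 401 otherwise.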
func runAuthTest(t *testing.T, toolName string, request map[string]any, wantStatus int) {
	idToken, err := tests.GetGoogleIdToken(tests.ClientId)
	if err != nil {
		t.Fatalf("error getting Google ID token: %s", err)
	}
	tcs := []struct {
		name       string
		headers    map[string]string
		wantStatus int
	}{
		{
			name:       "valid auth token",
			headers:    map[string]string{"my-google-auth_token": idToken},
			wantStatus: wantStatus,
		},
		{
			name:       "invalid auth token",
			headers:    map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
			wantStatus: http.StatusUnauthorized,
		},
		{
			name:       "no auth token",
			headers:    nil,
			wantStatus: http.StatusUnauthorized,
		},
	}

	for _, tc := range tcs {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			resp, err := invokeTool(toolName, request, tc.headers)
			if err != nil {
				t.Fatalf("invokeTool failed: %v", err)
			}
			defer resp.Body.Close()

			if resp.StatusCode != tc.wantStatus {
				bodyBytes, _ := io.ReadAll(resp.Body)
				t.Fatalf("response status code is not %d, got %d: %s", tc.wantStatus, resp.StatusCode, string(bodyBytes))
			}
		})
	}
}

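// runGetBatchTest fetches a batch via the get-batch tool and compares it field by
// field against the batch returned directly by the Dataproc API.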
func runGetBatchTest(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context, fullName string) {
	// First get the batch details directly from the Go proto API.
	req := &dataprocpb.GetBatchRequest{
		Name: fullName,
	}
	rawWantBatchPb, err := client.GetBatch(ctx, req)
	if err != nil {
		t.Fatalf("failed to get batch: %s", err)
	}

	// Trim unknown fields from the proto by marshalling and unmarshalling.
	jsonBytes, err := protojson.Marshal(rawWantBatchPb)
	if err != nil {
		t.Fatalf("failed to marshal batch to JSON: %s", err)
	}
	var wantBatchPb dataprocpb.Batch
	if err := protojson.Unmarshal(jsonBytes, &wantBatchPb); err != nil {
		t.Fatalf("error unmarshalling result: %s", err)
	}

	tcs := []struct {
		name      string
		batchName string
		want      *dataprocpb.Batch
	}{
		{
			name:      "found batch",
			batchName: shortName(fullName),
			want:      &wantBatchPb,
		},
	}

	for _, tc := range tcs {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			request := map[string]any{"name": tc.batchName}
			resp, err := invokeTool("get-batch", request, nil)
			if err != nil {
				t.Fatalf("invokeTool failed: %v", err)
			}
			defer resp.Body.Close()
			if resp.StatusCode != http.StatusOK {
				bodyBytes, _ := io.ReadAll(resp.Body)
				t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
			}
			var body map[string]any
			if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
				t.Fatalf("error parsing response body: %v", err)
			}
			result, ok := body["result"].(string)
			if !ok {
				t.Fatalf("unable to find result in response body")
			}
			var wrappedResult map[string]any
			if err := json.Unmarshal([]byte(result), &wrappedResult); err != nil {
				t.Fatalf("error unmarshalling result: %s", err)
			}
			consoleURL, ok := wrappedResult["consoleUrl"].(string)
			if !ok || !strings.HasPrefix(consoleURL, batchURLPrefix) {
				t.Errorf("unexpected consoleUrl: %v", consoleURL)
			}
			logsURL, ok := wrappedResult["logsUrl"].(string)
			if !ok || !strings.HasPrefix(logsURL, logsURLPrefix) {
				t.Errorf("unexpected logsUrl: %v", logsURL)
			}
			batchJSON, err := json.Marshal(wrappedResult["batch"])
			if err != nil {
				t.Fatalf("failed to marshal batch: %v", err)
			}

			// Unmarshal JSON to proto for proto-aware deep comparison.
			var batch dataprocpb.Batch
			if err := protojson.Unmarshal(batchJSON, &batch); err != nil {
				t.Fatalf("error unmarshalling batch from wrapped result: %s", err)
			}

			if !cmp.Equal(&batch, tc.want, protocmp.Transform()) {
				diff := cmp.Diff(&batch, tc.want, protocmp.Transform())
				t.Errorf("GetBatch() returned diff (-got +want):\n%s", diff)
			}
		})
	}
}

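// javaReq returns a create-spark-batch request for the SparkPi example, merged with
// any overrides from req.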
func javaReq(req map[string]any) map[string]any {
	merged := map[string]any{
		"mainClass": "org.apache.spark.examples.SparkPi",
		"jarFiles":  []string{"file:///usr/lib/spark/examples/jars/spark-examples.jar"},
	}
	maps.Copy(merged, req)
	return merged
}

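// runCreateSparkBatchTest invokes a batch-creation tool, checks the returned console
// and logs URLs and operation metadata, then optionally validates the created batch
// and waits for it to succeed.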
func runCreateSparkBatchTest(
	t *testing.T,
	client *dataproc.BatchControllerClient,
	ctx context.Context,
	toolName string,
	request map[string]any,
	waitForSuccess bool,
	validate func(t *testing.T, b *dataprocpb.Batch),
) {
	resp, err := invokeTool(toolName, request, nil)
	if err != nil {
		t.Fatalf("invokeTool failed: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		bodyBytes, _ := io.ReadAll(resp.Body)
		t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
	}

	var body map[string]any
	if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
		t.Fatalf("error parsing response body: %v", err)
	}

	result, ok := body["result"].(string)
	if !ok {
		t.Fatalf("unable to find result in response body")
	}

	var resultMap map[string]any
	if err := json.Unmarshal([]byte(result), &resultMap); err != nil {
		t.Fatalf("failed to unmarshal result: %v", err)
	}
	consoleURL, ok := resultMap["consoleUrl"].(string)
	if !ok || !strings.HasPrefix(consoleURL, batchURLPrefix) {
		t.Errorf("unexpected consoleUrl: %v", consoleURL)
	}
	logsURL, ok := resultMap["logsUrl"].(string)
	if !ok || !strings.HasPrefix(logsURL, logsURLPrefix) {
		t.Errorf("unexpected logsUrl: %v", logsURL)
	}
	metaMap, ok := resultMap["opMetadata"].(map[string]any)
	if !ok {
		t.Fatalf("unexpected opMetadata: %v", metaMap)
	}
	metaJson, err := json.Marshal(metaMap)
	if err != nil {
		t.Fatalf("failed to marshal op metadata to JSON: %s", err)
	}
	var meta dataprocpb.BatchOperationMetadata
	if err := json.Unmarshal([]byte(metaJson), &meta); err != nil {
		t.Fatalf("failed to unmarshal result: %v", err)
	}

	if validate != nil {
		b, err := client.GetBatch(ctx, &dataprocpb.GetBatchRequest{Name: meta.Batch})
		if err != nil {
			t.Fatalf("failed to get batch %s: %s", meta.Batch, err)
		}
		validate(t, b)
	}

	if waitForSuccess {
		waitForBatch(t, client, ctx, meta.Batch, []dataprocpb.Batch_State{dataprocpb.Batch_SUCCEEDED}, 5*time.Minute)
	}
}

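// testError invokes a tool and asserts the expected HTTP status code and that the
// response body contains wantMsg.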
func testError(t *testing.T, toolName string, request map[string]any, wantCode int, wantMsg string) {
	resp, err := invokeTool(toolName, request, nil)
	if err != nil {
		t.Fatalf("invokeTool failed: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != wantCode {
		bodyBytes, _ := io.ReadAll(resp.Body)
		t.Fatalf("response status code is not %d, got %d: %s", wantCode, resp.StatusCode, string(bodyBytes))
	}

	bodyBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("failed to read response body: %v", err)
	}

	if !bytes.Contains(bodyBytes, []byte(wantMsg)) {
		t.Fatalf("response body does not contain %q: %s", wantMsg, string(bodyBytes))
	}
}

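// invokeTool POSTs the request as JSON to the toolbox server's tool-invocation
// endpoint, adding any extra headers (e.g., auth tokens).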
func invokeTool(toolName string, request map[string]any, headers map[string]string) (*http.Response, error) {
	requestBytes, err := json.Marshal(request)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	url := fmt.Sprintf("http://127.0.0.1:5000/api/tool/%s/invoke", toolName)
	req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(requestBytes))
	if err != nil {
		return nil, fmt.Errorf("unable to create request: %w", err)
	}
	req.Header.Add("Content-type", "application/json")
	for k, v := range headers {
		req.Header.Add(k, v)
	}

	return http.DefaultClient.Do(req)
}

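// shortName returns the final path segment of a fully qualified resource name.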
func shortName(fullName string) string {
	parts := strings.Split(fullName, "/")
	return parts[len(parts)-1]
}