Mirror of https://github.com/googleapis/genai-toolbox.git (synced 2026-01-09 15:38:08 -05:00)

feat(serverless-spark): Add get_batch tool
@@ -155,6 +155,7 @@ import (
    _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttables"
    _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql"
    _ "github.com/googleapis/genai-toolbox/internal/tools/redis"
    _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkgetbatch"
    _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches"
    _ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerexecutesql"
    _ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerlisttables"

@@ -1467,7 +1467,7 @@ func TestPrebuiltTools(t *testing.T) {
            wantToolset: server.ToolsetConfigs{
                "serverless_spark_tools": tools.ToolsetConfig{
                    Name:      "serverless_spark_tools",
                    ToolNames: []string{"list_batches"},
                    ToolNames: []string{"list_batches", "get_batch"},
                },
            },
        },

@@ -17,6 +17,8 @@ Apache Spark.

- [`serverless-spark-list-batches`](../tools/serverless-spark/serverless-spark-list-batches.md)
  List and filter Serverless Spark batches.
- [`serverless-spark-get-batch`](../tools/serverless-spark/serverless-spark-get-batch.md)
  Get a Serverless Spark batch.

## Requirements

@@ -4,4 +4,7 @@ type: docs
weight: 1
description: >
  Tools that work with Google Cloud Serverless for Apache Spark Sources.
---

- [serverless-spark-get-batch](./serverless-spark-get-batch.md)
- [serverless-spark-list-batches](./serverless-spark-list-batches.md)

@@ -0,0 +1,84 @@
---
title: "serverless-spark-get-batch"
type: docs
weight: 1
description: >
  A "serverless-spark-get-batch" tool gets a single Spark batch from the source.
aliases:
- /resources/tools/serverless-spark-get-batch
---

# serverless-spark-get-batch

The `serverless-spark-get-batch` tool allows you to retrieve a specific
Serverless Spark batch job. It's compatible with the following sources:

- [serverless-spark](../../sources/serverless-spark.md)

`serverless-spark-get-batch` accepts the following parameters:

- **`name`**: The short name of the batch, e.g. for
  `projects/my-project/locations/us-central1/batches/my-batch`, pass `my-batch`.

The tool gets the `project` and `location` from the source configuration.
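
For illustration, the expansion the tool performs is equivalent to the following sketch, which mirrors the `Invoke` implementation added later in this commit; the `resolveBatchName` helper is illustrative only and not part of the tool's API:

```go
package main

import (
	"fmt"
	"strings"
)

// resolveBatchName builds the fully qualified batch resource name from the
// source's project/location and the short batch name, rejecting names that
// already look fully qualified. (Illustrative helper, not part of the commit.)
func resolveBatchName(project, location, name string) (string, error) {
	if strings.Contains(name, "/") {
		return "", fmt.Errorf("name must be a short batch name without '/': %s", name)
	}
	return fmt.Sprintf("projects/%s/locations/%s/batches/%s", project, location, name), nil
}

func main() {
	full, _ := resolveBatchName("my-project", "us-central1", "my-batch")
	fmt.Println(full) // projects/my-project/locations/us-central1/batches/my-batch
}
```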

## Example

```yaml
tools:
  get_my_batch:
    kind: serverless-spark-get-batch
    source: my-serverless-spark-source
    description: Use this tool to get a serverless spark batch.
```

## Response Format

The response is a full Batch JSON object as defined in the [API
spec](https://cloud.google.com/dataproc-serverless/docs/reference/rest/v1/projects.locations.batches#Batch).
Example with a reduced set of fields:

```json
{
  "createTime": "2025-10-10T15:15:21.303146Z",
  "creator": "alice@example.com",
  "labels": {
    "goog-dataproc-batch-uuid": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
    "goog-dataproc-location": "us-central1"
  },
  "name": "projects/google.com:hadoop-cloud-dev/locations/us-central1/batches/alice-20251010-abcd",
  "operation": "projects/google.com:hadoop-cloud-dev/regions/us-central1/operations/11111111-2222-3333-4444-555555555555",
  "runtimeConfig": {
    "properties": {
      "spark:spark.driver.cores": "4",
      "spark:spark.driver.memory": "12200m"
    }
  },
  "sparkBatch": {
    "jarFileUris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
    "mainClass": "org.apache.spark.examples.SparkPi"
  },
  "state": "SUCCEEDED",
  "stateHistory": [
    {
      "state": "PENDING",
      "stateStartTime": "2025-10-10T15:15:21.303146Z"
    },
    {
      "state": "RUNNING",
      "stateStartTime": "2025-10-10T15:16:41.291747Z"
    }
  ],
  "stateTime": "2025-10-10T15:17:21.265493Z",
  "uuid": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
}
```
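
For typed access to this response in Go, a consumer can decode the JSON back into a `dataprocpb.Batch` proto with `protojson`, which is also how the integration test added in this commit verifies the tool's output. A minimal sketch; the abbreviated JSON literal and `my-project` values are placeholders:

```go
package main

import (
	"fmt"

	"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
	"google.golang.org/protobuf/encoding/protojson"
)

func main() {
	// Abbreviated response; in practice this is the JSON returned by the tool.
	raw := `{"name": "projects/my-project/locations/us-central1/batches/my-batch", "state": "SUCCEEDED"}`

	var batch dataprocpb.Batch
	if err := protojson.Unmarshal([]byte(raw), &batch); err != nil {
		panic(err)
	}
	fmt.Println(batch.GetState()) // SUCCEEDED
}
```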

## Reference

| **field** | **type** | **required** | **description** |
| ------------ | :------: | :----------: | -------------------------------------------------- |
| kind | string | true | Must be "serverless-spark-get-batch". |
| source | string | true | Name of the source the tool should use. |
| description | string | true | Description of the tool that is passed to the LLM. |
| authRequired | string[] | false | List of auth services required to invoke this tool |

go.mod
@@ -56,6 +56,7 @@ require (
    golang.org/x/oauth2 v0.32.0
    google.golang.org/api v0.251.0
    google.golang.org/genproto v0.0.0-20251007200510-49b9836ed3ff
    google.golang.org/protobuf v1.36.10
    modernc.org/sqlite v1.39.1
)

@@ -180,7 +181,6 @@ require (
    google.golang.org/genproto/googleapis/api v0.0.0-20251002232023-7c0ddcbb5797 // indirect
    google.golang.org/genproto/googleapis/rpc v0.0.0-20251002232023-7c0ddcbb5797 // indirect
    google.golang.org/grpc v1.75.1 // indirect
    google.golang.org/protobuf v1.36.10 // indirect
    gopkg.in/inf.v0 v0.9.1 // indirect
    gopkg.in/ini.v1 v1.67.0 // indirect
    modernc.org/libc v1.66.10 // indirect

@@ -22,7 +22,11 @@ tools:
  list_batches:
    kind: serverless-spark-list-batches
    source: serverless-spark-source
  get_batch:
    kind: serverless-spark-get-batch
    source: serverless-spark-source

toolsets:
  serverless_spark_tools:
    - list_batches
    - get_batch

@@ -0,0 +1,171 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package serverlesssparkgetbatch

import (
    "context"
    "encoding/json"
    "fmt"
    "strings"

    "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
    "github.com/goccy/go-yaml"
    "github.com/googleapis/genai-toolbox/internal/sources"
    "github.com/googleapis/genai-toolbox/internal/sources/serverlessspark"
    "github.com/googleapis/genai-toolbox/internal/tools"
    "google.golang.org/protobuf/encoding/protojson"
)

const kind = "serverless-spark-get-batch"

func init() {
    if !tools.Register(kind, newConfig) {
        panic(fmt.Sprintf("tool kind %q already registered", kind))
    }
}

func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
    actual := Config{Name: name}
    if err := decoder.DecodeContext(ctx, &actual); err != nil {
        return nil, err
    }
    return actual, nil
}

type Config struct {
    Name         string   `yaml:"name" validate:"required"`
    Kind         string   `yaml:"kind" validate:"required"`
    Source       string   `yaml:"source" validate:"required"`
    Description  string   `yaml:"description"`
    AuthRequired []string `yaml:"authRequired"`
}

// validate interface
var _ tools.ToolConfig = Config{}

// ToolConfigKind returns the unique name for this tool.
func (cfg Config) ToolConfigKind() string {
    return kind
}

// Initialize creates a new Tool instance.
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
    rawS, ok := srcs[cfg.Source]
    if !ok {
        return nil, fmt.Errorf("source %q not found", cfg.Source)
    }

    ds, ok := rawS.(*serverlessspark.Source)
    if !ok {
        return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, serverlessspark.SourceKind)
    }

    desc := cfg.Description
    if desc == "" {
        desc = "Gets a Serverless Spark (aka Dataproc Serverless) batch"
    }

    allParameters := tools.Parameters{
        tools.NewStringParameter("name", "The short name of the batch, e.g. for \"projects/my-project/locations/us-central1/batches/my-batch\", pass \"my-batch\" (the project and location are inherited from the source)"),
    }
    inputSchema, _ := allParameters.McpManifest()

    mcpManifest := tools.McpManifest{
        Name:        cfg.Name,
        Description: desc,
        InputSchema: inputSchema,
    }

    return Tool{
        Name:         cfg.Name,
        Kind:         kind,
        Source:       ds,
        AuthRequired: cfg.AuthRequired,
        manifest:     tools.Manifest{Description: desc, Parameters: allParameters.Manifest()},
        mcpManifest:  mcpManifest,
        Parameters:   allParameters,
    }, nil
}

// Tool is the implementation of the tool.
type Tool struct {
    Name         string   `yaml:"name"`
    Kind         string   `yaml:"kind"`
    Description  string   `yaml:"description"`
    AuthRequired []string `yaml:"authRequired"`

    Source *serverlessspark.Source

    manifest    tools.Manifest
    mcpManifest tools.McpManifest
    Parameters  tools.Parameters
}

// Invoke executes the tool's operation.
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
    client := t.Source.GetBatchControllerClient()

    paramMap := params.AsMap()
    name, ok := paramMap["name"].(string)
    if !ok {
        return nil, fmt.Errorf("missing required parameter: name")
    }

    if strings.Contains(name, "/") {
        return nil, fmt.Errorf("name must be a short batch name without '/': %s", name)
    }

    req := &dataprocpb.GetBatchRequest{
        Name: fmt.Sprintf("projects/%s/locations/%s/batches/%s", t.Source.Project, t.Source.Location, name),
    }

    batchPb, err := client.GetBatch(ctx, req)
    if err != nil {
        return nil, fmt.Errorf("failed to get batch: %w", err)
    }

    jsonBytes, err := protojson.Marshal(batchPb)
    if err != nil {
        return nil, fmt.Errorf("failed to marshal batch to JSON: %w", err)
    }

    var result map[string]any
    if err := json.Unmarshal(jsonBytes, &result); err != nil {
        return nil, fmt.Errorf("failed to unmarshal batch JSON: %w", err)
    }

    return result, nil
}

func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
    return tools.ParseParams(t.Parameters, data, claims)
}

func (t Tool) Manifest() tools.Manifest {
    return t.manifest
}

func (t Tool) McpManifest() tools.McpManifest {
    return t.mcpManifest
}

func (t Tool) Authorized(services []string) bool {
    return tools.IsAuthorized(t.AuthRequired, services)
}

func (t Tool) RequiresClientAuthorization() bool {
    // Client OAuth not supported, rely on ADCs.
    return false
}

@@ -0,0 +1,72 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package serverlesssparkgetbatch_test

import (
    "testing"

    "github.com/goccy/go-yaml"
    "github.com/google/go-cmp/cmp"
    "github.com/googleapis/genai-toolbox/internal/server"
    "github.com/googleapis/genai-toolbox/internal/testutils"
    "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkgetbatch"
)

func TestParseFromYaml(t *testing.T) {
    ctx, err := testutils.ContextWithNewLogger()
    if err != nil {
        t.Fatalf("unexpected error: %s", err)
    }
    tcs := []struct {
        desc string
        in   string
        want server.ToolConfigs
    }{
        {
            desc: "basic example",
            in: `
            tools:
                example_tool:
                    kind: serverless-spark-get-batch
                    source: my-instance
                    description: some description
            `,
            want: server.ToolConfigs{
                "example_tool": serverlesssparkgetbatch.Config{
                    Name:         "example_tool",
                    Kind:         "serverless-spark-get-batch",
                    Source:       "my-instance",
                    Description:  "some description",
                    AuthRequired: []string{},
                },
            },
        },
    }
    for _, tc := range tcs {
        t.Run(tc.desc, func(t *testing.T) {
            got := struct {
                Tools server.ToolConfigs `yaml:"tools"`
            }{}
            err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got, yaml.Strict())
            if err != nil {
                t.Fatalf("unable to unmarshal: %s", err)
            }

            if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
                t.Fatalf("incorrect parse: diff %v", diff)
            }
        })
    }
}

@@ -24,16 +24,20 @@ import (
    "os"
    "reflect"
    "regexp"
    "strings"
    "testing"
    "time"

    dataproc "cloud.google.com/go/dataproc/v2/apiv1"
    "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
    "github.com/google/go-cmp/cmp"
    "github.com/googleapis/genai-toolbox/internal/testutils"
    "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches"
    "github.com/googleapis/genai-toolbox/tests"
    "google.golang.org/api/iterator"
    "google.golang.org/api/option"
    "google.golang.org/protobuf/encoding/protojson"
    "google.golang.org/protobuf/testing/protocmp"
)

var (
@@ -81,6 +85,15 @@ func TestServerlessSparkToolEndpoints(t *testing.T) {
            "source":       "my-spark",
            "authRequired": []string{"my-google-auth"},
        },
        "get-batch": map[string]any{
            "kind":   "serverless-spark-get-batch",
            "source": "my-spark",
        },
        "get-batch-with-auth": map[string]any{
            "kind":         "serverless-spark-get-batch",
            "source":       "my-spark",
            "authRequired": []string{"my-google-auth"},
        },
    },
}

@@ -106,13 +119,20 @@ func TestServerlessSparkToolEndpoints(t *testing.T) {
    defer client.Close()

    runListBatchesTest(t, client, ctx)
    runListBatchesErrorTest(t)
    runListBatchesAuthTest(t)

    fullName := listBatchesRpc(t, client, ctx, "", 1, true)[0].Name
    runGetBatchTest(t, client, ctx, fullName)

    runErrorTest(t)

    // Get the most recent batch, which is all we need for this test.
    runAuthTest(t, "list-batches-with-auth", map[string]any{"pageSize": 1})
    runAuthTest(t, "get-batch-with-auth", map[string]any{"name": shortName(fullName)})
}

// runListBatchesTest invokes the running list-batches tool and ensures it returns the correct
// number of results. It can run successfully against any GCP project that has at least 2 succeeded
// or failed Serverless Spark batches, of any age.
// number of results. It can run successfully against any GCP project that contains at least 2 total
// Serverless Spark batches.
func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context) {
    batch2 := listBatchesRpc(t, client, ctx, "", 2, true)
    batch20 := listBatchesRpc(t, client, ctx, "", 20, false)
@@ -145,7 +165,7 @@ func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ct
    }

    for _, tc := range tcs {
        t.Run(tc.name, func(t *testing.T) {
        t.Run("list-batches "+tc.name, func(t *testing.T) {
            var actual []serverlesssparklistbatches.Batch
            var pageToken string
            for i := 0; i < tc.numPages; i++ {
@@ -157,9 +177,9 @@
                request["pageSize"] = tc.pageSize
            }

            resp, err := invokeListBatches("list-batches", request, nil)
            resp, err := invokeTool("list-batches", request, nil)
            if err != nil {
                t.Fatalf("invokeListBatches failed: %v", err)
                t.Fatalf("invokeTool failed: %v", err)
            }
            defer resp.Body.Close()

@@ -221,35 +241,160 @@ func listBatchesRpc(t *testing.T, client *dataproc.BatchControllerClient, ctx co
    return serverlesssparklistbatches.ToBatches(batchPbs)
}

func runListBatchesErrorTest(t *testing.T) {
func runAuthTest(t *testing.T, toolName string, request map[string]any) {
    idToken, err := tests.GetGoogleIdToken(tests.ClientId)
    if err != nil {
        t.Fatalf("error getting Google ID token: %s", err)
    }
    tcs := []struct {
        name       string
        headers    map[string]string
        wantStatus int
    }{
        {
            name:       "valid auth token",
            headers:    map[string]string{"my-google-auth_token": idToken},
            wantStatus: http.StatusOK,
        },
        {
            name:       "invalid auth token",
            headers:    map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
            wantStatus: http.StatusUnauthorized,
        },
        {
            name:       "no auth token",
            headers:    nil,
            wantStatus: http.StatusUnauthorized,
        },
    }

    for _, tc := range tcs {
        t.Run(toolName+" "+tc.name, func(t *testing.T) {
            resp, err := invokeTool(toolName, request, tc.headers)
            if err != nil {
                t.Fatalf("invokeTool failed: %v", err)
            }
            defer resp.Body.Close()

            if resp.StatusCode != tc.wantStatus {
                bodyBytes, _ := io.ReadAll(resp.Body)
                t.Fatalf("response status code is not %d, got %d: %s", tc.wantStatus, resp.StatusCode, string(bodyBytes))
            }
        })
    }
}

func runGetBatchTest(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context, fullName string) {
|
||||
// First get the batch details directly from the Go proto API.
|
||||
req := &dataprocpb.GetBatchRequest{
|
||||
Name: fullName,
|
||||
}
|
||||
rawWantBatchPb, err := client.GetBatch(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get batch: %s", err)
|
||||
}
|
||||
|
||||
// Trim unknown fields from the proto by marshalling and unmarshalling.
|
||||
jsonBytes, err := protojson.Marshal(rawWantBatchPb)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal batch to JSON: %s", err)
|
||||
}
|
||||
var wantBatchPb dataprocpb.Batch
|
||||
if err := protojson.Unmarshal(jsonBytes, &wantBatchPb); err != nil {
|
||||
t.Fatalf("error unmarshalling result: %s", err)
|
||||
}
|
||||
|
||||
tcs := []struct {
|
||||
name string
|
||||
batchName string
|
||||
want *dataprocpb.Batch
|
||||
}{
|
||||
{
|
||||
name: "found batch",
|
||||
batchName: shortName(fullName),
|
||||
want: &wantBatchPb,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tcs {
|
||||
t.Run("get-batch "+tc.name, func(t *testing.T) {
|
||||
request := map[string]any{"name": tc.batchName}
|
||||
resp, err := invokeTool("get-batch", request, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("invokeTool failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
var body map[string]any
|
||||
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
||||
t.Fatalf("error parsing response body: %v", err)
|
||||
}
|
||||
result, ok := body["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
|
||||
// Unmarshal JSON to proto for proto-aware deep comparison.
|
||||
var batch dataprocpb.Batch
|
||||
if err := protojson.Unmarshal([]byte(result), &batch); err != nil {
|
||||
t.Fatalf("error unmarshalling result: %s", err)
|
||||
}
|
||||
|
||||
if !cmp.Equal(&batch, tc.want, protocmp.Transform()) {
|
||||
diff := cmp.Diff(&batch, tc.want, protocmp.Transform())
|
||||
t.Errorf("GetBatch() returned diff (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runErrorTest(t *testing.T) {
|
||||
missingBatchFullName := fmt.Sprintf("projects/%s/locations/%s/batches/INVALID_BATCH", serverlessSparkProject, serverlessSparkLocation)
|
||||
tcs := []struct {
|
||||
name string
|
||||
pageSize int
|
||||
toolName string
|
||||
request map[string]any
|
||||
wantCode int
|
||||
wantMsg string
|
||||
}{
|
||||
{
|
||||
name: "zero page size",
|
||||
pageSize: 0,
|
||||
name: "list-batches zero page size",
|
||||
toolName: "list-batches",
|
||||
request: map[string]any{"pageSize": 0},
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantMsg: "pageSize must be positive: 0",
|
||||
},
|
||||
{
|
||||
name: "negative page size",
|
||||
pageSize: -1,
|
||||
name: "list-batches negative page size",
|
||||
toolName: "list-batches",
|
||||
request: map[string]any{"pageSize": -1},
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantMsg: "pageSize must be positive: -1",
|
||||
},
|
||||
{
|
||||
name: "get-batch missing batch",
|
||||
toolName: "get-batch",
|
||||
request: map[string]any{"name": "INVALID_BATCH"},
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantMsg: fmt.Sprintf("Not found: Batch projects/%s/locations/%s/batches/INVALID_BATCH", serverlessSparkProject, serverlessSparkLocation),
|
||||
},
|
||||
{
|
||||
name: "get-batch full batch name",
|
||||
toolName: "get-batch",
|
||||
request: map[string]any{"name": missingBatchFullName},
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantMsg: fmt.Sprintf("name must be a short batch name without '/': %s", missingBatchFullName),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
request := map[string]any{
|
||||
"pageSize": tc.pageSize,
|
||||
}
|
||||
resp, err := invokeListBatches("list-batches", request, nil)
|
||||
resp, err := invokeTool(tc.toolName, tc.request, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("invokeListBatches failed: %v", err)
|
||||
t.Fatalf("invokeTool failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
@@ -270,57 +415,7 @@ func runListBatchesErrorTest(t *testing.T) {
    }
}

func runListBatchesAuthTest(t *testing.T) {
    idToken, err := tests.GetGoogleIdToken(tests.ClientId)
    if err != nil {
        t.Fatalf("error getting Google ID token: %s", err)
    }
    tcs := []struct {
        name       string
        toolName   string
        headers    map[string]string
        wantStatus int
    }{
        {
            name:       "valid auth token",
            toolName:   "list-batches-with-auth",
            headers:    map[string]string{"my-google-auth_token": idToken},
            wantStatus: http.StatusOK,
        },
        {
            name:       "invalid auth token",
            toolName:   "list-batches-with-auth",
            headers:    map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
            wantStatus: http.StatusUnauthorized,
        },
        {
            name:       "no auth token",
            toolName:   "list-batches-with-auth",
            headers:    nil,
            wantStatus: http.StatusUnauthorized,
        },
    }

    for _, tc := range tcs {
        t.Run(tc.name, func(t *testing.T) {
            request := map[string]any{
                "pageSize": 1,
            }
            resp, err := invokeListBatches(tc.toolName, request, tc.headers)
            if err != nil {
                t.Fatalf("invokeListBatches failed: %v", err)
            }
            defer resp.Body.Close()

            if resp.StatusCode != tc.wantStatus {
                bodyBytes, _ := io.ReadAll(resp.Body)
                t.Fatalf("response status code is not %d, got %d: %s", tc.wantStatus, resp.StatusCode, string(bodyBytes))
            }
        })
    }
}

func invokeListBatches(toolName string, request map[string]any, headers map[string]string) (*http.Response, error) {
func invokeTool(toolName string, request map[string]any, headers map[string]string) (*http.Response, error) {
    requestBytes, err := json.Marshal(request)
    if err != nil {
        return nil, fmt.Errorf("failed to marshal request: %w", err)
@@ -338,3 +433,8 @@ func invokeListBatches(toolName string, request map[string]any, headers map[stri

    return http.DefaultClient.Do(req)
}

func shortName(fullName string) string {
    parts := strings.Split(fullName, "/")
    return parts[len(parts)-1]
}