feat(serverless-spark): Add get_batch tool

This commit is contained in:
Dave Borowitz
2025-09-16 17:17:29 -07:00
parent 5d1be9caf9
commit 7ad10720b4
10 changed files with 509 additions and 72 deletions

View File

@@ -155,6 +155,7 @@ import (
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttables"
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql"
_ "github.com/googleapis/genai-toolbox/internal/tools/redis"
_ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkgetbatch"
_ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches"
_ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerexecutesql"
_ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerlisttables"

View File

@@ -1467,7 +1467,7 @@ func TestPrebuiltTools(t *testing.T) {
wantToolset: server.ToolsetConfigs{
"serverless_spark_tools": tools.ToolsetConfig{
Name: "serverless_spark_tools",
ToolNames: []string{"list_batches"},
ToolNames: []string{"list_batches", "get_batch"},
},
},
},

View File

@@ -17,6 +17,8 @@ Apache Spark.
- [`serverless-spark-list-batches`](../tools/serverless-spark/serverless-spark-list-batches.md)
List and filter Serverless Spark batches.
- [`serverless-spark-get-batch`](../tools/serverless-spark/serverless-spark-get-batch.md)
Get a Serverless Spark batch.
## Requirements

View File

@@ -4,4 +4,7 @@ type: docs
weight: 1
description: >
Tools that work with Google Cloud Serverless for Apache Spark Sources.
---
---
- [serverless-spark-get-batch](./serverless-spark-get-batch.md)
- [serverless-spark-list-batches](./serverless-spark-list-batches.md)

View File

@@ -0,0 +1,84 @@
---
title: "serverless-spark-get-batch"
type: docs
weight: 1
description: >
A "serverless-spark-get-batch" tool gets a single Spark batch from the source.
aliases:
- /resources/tools/serverless-spark-get-batch
---
# serverless-spark-get-batch
The `serverless-spark-get-batch` tool allows you to retrieve a specific
Serverless Spark batch job. It's compatible with the following sources:
- [serverless-spark](../../sources/serverless-spark.md)
`serverless-spark-get-batch` accepts the following parameters:
- **`name`**: The short name of the batch, e.g. for
  `projects/my-project/locations/us-central1/batches/my-batch`, pass `my-batch`.
The tool gets the `project` and `location` from the source configuration.
## Example
```yaml
tools:
get_my_batch:
kind: serverless-spark-get-batch
source: my-serverless-spark-source
description: Use this tool to get a serverless spark batch.
```
## Response Format
The response is a full Batch JSON object as defined in the [API
spec](https://cloud.google.com/dataproc-serverless/docs/reference/rest/v1/projects.locations.batches#Batch).
Example with a reduced set of fields:
```json
{
"createTime": "2025-10-10T15:15:21.303146Z",
"creator": "alice@example.com",
"labels": {
"goog-dataproc-batch-uuid": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
"goog-dataproc-location": "us-central1"
},
"name": "projects/google.com:hadoop-cloud-dev/locations/us-central1/batches/alice-20251010-abcd",
"operation": "projects/google.com:hadoop-cloud-dev/regions/us-central1/operations/11111111-2222-3333-4444-555555555555",
"runtimeConfig": {
"properties": {
"spark:spark.driver.cores": "4",
"spark:spark.driver.memory": "12200m"
}
},
"sparkBatch": {
"jarFileUris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
"mainClass": "org.apache.spark.examples.SparkPi"
},
"state": "SUCCEEDED",
"stateHistory": [
{
"state": "PENDING",
"stateStartTime": "2025-10-10T15:15:21.303146Z"
},
{
"state": "RUNNING",
"stateStartTime": "2025-10-10T15:16:41.291747Z"
}
],
"stateTime": "2025-10-10T15:17:21.265493Z",
"uuid": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
}
```
## Reference
| **field** | **type** | **required** | **description** |
| ------------ | :------: | :----------: | -------------------------------------------------- |
| kind | string | true | Must be "serverless-spark-get-batch". |
| source | string | true | Name of the source the tool should use. |
| description | string | true | Description of the tool that is passed to the LLM. |
| authRequired | string[] | false | List of auth services required to invoke this tool |

2
go.mod
View File

@@ -56,6 +56,7 @@ require (
golang.org/x/oauth2 v0.32.0
google.golang.org/api v0.251.0
google.golang.org/genproto v0.0.0-20251007200510-49b9836ed3ff
google.golang.org/protobuf v1.36.10
modernc.org/sqlite v1.39.1
)
@@ -180,7 +181,6 @@ require (
google.golang.org/genproto/googleapis/api v0.0.0-20251002232023-7c0ddcbb5797 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251002232023-7c0ddcbb5797 // indirect
google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
modernc.org/libc v1.66.10 // indirect

View File

@@ -22,7 +22,11 @@ tools:
list_batches:
kind: serverless-spark-list-batches
source: serverless-spark-source
get_batch:
kind: serverless-spark-get-batch
source: serverless-spark-source
toolsets:
serverless_spark_tools:
- list_batches
- get_batch

View File

@@ -0,0 +1,171 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package serverlesssparkgetbatch
import (
"context"
"encoding/json"
"fmt"
"strings"
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/serverlessspark"
"github.com/googleapis/genai-toolbox/internal/tools"
"google.golang.org/protobuf/encoding/protojson"
)
// kind is the registered tool-kind identifier for this tool.
const kind = "serverless-spark-get-batch"

// init registers the tool kind with the global tool registry. A failed
// registration means the kind name is already taken, which is a
// programming error, so it panics.
func init() {
	registered := tools.Register(kind, newConfig)
	if !registered {
		panic(fmt.Sprintf("tool kind %q already registered", kind))
	}
}
// newConfig decodes a tool configuration of this kind from YAML,
// pre-seeding the Name field with the tool's map key.
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
	cfg := Config{Name: name}
	err := decoder.DecodeContext(ctx, &cfg)
	if err != nil {
		return nil, err
	}
	return cfg, nil
}
// Config is the YAML configuration for a serverless-spark-get-batch
// tool, as it appears under the top-level `tools:` section.
type Config struct {
	// Name is the tool instance name (the key in the tools map).
	Name string `yaml:"name" validate:"required"`
	// Kind must be "serverless-spark-get-batch".
	Kind string `yaml:"kind" validate:"required"`
	// Source names the serverless-spark source this tool uses.
	Source string `yaml:"source" validate:"required"`
	// Description is passed to the LLM; a default is applied in
	// Initialize when empty.
	Description string `yaml:"description"`
	// AuthRequired lists auth services required to invoke this tool.
	AuthRequired []string `yaml:"authRequired"`
}

// Compile-time check that Config implements tools.ToolConfig.
var _ tools.ToolConfig = Config{}
// ToolConfigKind returns the unique kind string for this tool
// ("serverless-spark-get-batch").
func (cfg Config) ToolConfigKind() string {
	return kind
}
// Initialize resolves the configured source and constructs the runnable
// Tool instance. It fails if the named source does not exist or is not a
// serverless-spark source.
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
	rawS, ok := srcs[cfg.Source]
	if !ok {
		return nil, fmt.Errorf("source %q not found", cfg.Source)
	}
	ds, ok := rawS.(*serverlessspark.Source)
	if !ok {
		return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, serverlessspark.SourceKind)
	}
	// Fall back to a generic description when none was configured.
	desc := cfg.Description
	if desc == "" {
		desc = "Gets a Serverless Spark (aka Dataproc Serverless) batch"
	}
	// The tool accepts a single parameter: the short batch name. The
	// project and location come from the source configuration.
	allParameters := tools.Parameters{
		tools.NewStringParameter("name", "The short name of the batch, e.g. for \"projects/my-project/locations/us-central1/batches/my-batch\", pass \"my-batch\" (the project and location are inherited from the source)"),
	}
	// NOTE(review): the second return value of McpManifest is discarded —
	// confirm it cannot fail for a static parameter list like this one.
	inputSchema, _ := allParameters.McpManifest()
	mcpManifest := tools.McpManifest{
		Name:        cfg.Name,
		Description: desc,
		InputSchema: inputSchema,
	}
	return Tool{
		Name:         cfg.Name,
		Kind:         kind,
		Source:       ds,
		AuthRequired: cfg.AuthRequired,
		manifest:     tools.Manifest{Description: desc, Parameters: allParameters.Manifest()},
		mcpManifest:  mcpManifest,
		Parameters:   allParameters,
	}, nil
}
// Tool is the runtime implementation of the serverless-spark-get-batch
// tool, bound to a resolved serverless-spark source.
type Tool struct {
	// Name is the configured tool instance name.
	Name string `yaml:"name"`
	// Kind is always "serverless-spark-get-batch".
	Kind string `yaml:"kind"`
	// Description is the tool description passed to the LLM.
	Description string `yaml:"description"`
	// AuthRequired lists auth services required to invoke this tool.
	AuthRequired []string `yaml:"authRequired"`
	// Source supplies the batch controller client plus the project and
	// location used to build fully-qualified batch names.
	Source *serverlessspark.Source
	// manifest and mcpManifest are precomputed in Initialize and served
	// verbatim by Manifest/McpManifest.
	manifest    tools.Manifest
	mcpManifest tools.McpManifest
	// Parameters describes the tool's input parameters ("name" only).
	Parameters tools.Parameters
}
// Invoke fetches a single Serverless Spark batch by its short name and
// returns it as a generic JSON object using the public (protojson)
// field names, matching the REST API's Batch resource representation.
//
// The project and location are taken from the tool's source; the caller
// supplies only the short batch name. accessToken is unused: client
// OAuth is not supported (see RequiresClientAuthorization), so ADCs are
// relied on instead.
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
	client := t.Source.GetBatchControllerClient()
	paramMap := params.AsMap()
	name, ok := paramMap["name"].(string)
	if !ok {
		// Distinguish "absent" from "present but not a string" so the
		// error message points at the actual problem.
		if raw, present := paramMap["name"]; present {
			return nil, fmt.Errorf("parameter \"name\" must be a string, got %T", raw)
		}
		return nil, fmt.Errorf("missing required parameter: name")
	}
	// Reject fully-qualified resource names: project and location are
	// inherited from the source configuration.
	if strings.Contains(name, "/") {
		return nil, fmt.Errorf("name must be a short batch name without '/': %s", name)
	}
	req := &dataprocpb.GetBatchRequest{
		Name: fmt.Sprintf("projects/%s/locations/%s/batches/%s", t.Source.Project, t.Source.Location, name),
	}
	batchPb, err := client.GetBatch(ctx, req)
	if err != nil {
		return nil, fmt.Errorf("failed to get batch: %w", err)
	}
	// Round-trip through protojson so the result uses the API's JSON
	// field names rather than Go struct field names.
	jsonBytes, err := protojson.Marshal(batchPb)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal batch to JSON: %w", err)
	}
	var result map[string]any
	if err := json.Unmarshal(jsonBytes, &result); err != nil {
		return nil, fmt.Errorf("failed to unmarshal batch JSON: %w", err)
	}
	return result, nil
}
// ParseParams validates raw request data and auth claims against the
// tool's declared parameters, returning the typed parameter values.
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
	return tools.ParseParams(t.Parameters, data, claims)
}
// Manifest returns the tool manifest precomputed during Initialize.
func (t Tool) Manifest() tools.Manifest {
	return t.manifest
}
// McpManifest returns the MCP tool manifest precomputed during Initialize.
func (t Tool) McpManifest() tools.McpManifest {
	return t.mcpManifest
}
// Authorized reports whether the given verified auth services satisfy
// this tool's AuthRequired list, delegating to tools.IsAuthorized.
func (t Tool) Authorized(services []string) bool {
	return tools.IsAuthorized(t.AuthRequired, services)
}
// RequiresClientAuthorization reports whether the caller must supply an
// OAuth access token. Always false for this tool.
func (t Tool) RequiresClientAuthorization() bool {
	// Client OAuth not supported, rely on ADCs.
	return false
}

View File

@@ -0,0 +1,72 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package serverlesssparkgetbatch_test
import (
"testing"
"github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkgetbatch"
)
// TestParseFromYaml verifies that a serverless-spark-get-batch tool
// definition embedded in a top-level `tools:` YAML document decodes
// into the expected Config value under strict unmarshalling.
func TestParseFromYaml(t *testing.T) {
	ctx, err := testutils.ContextWithNewLogger()
	if err != nil {
		t.Fatalf("unexpected error: %s", err)
	}
	tcs := []struct {
		desc string
		in   string
		want server.ToolConfigs
	}{
		{
			desc: "basic example",
			in: `
tools:
example_tool:
kind: serverless-spark-get-batch
source: my-instance
description: some description
`,
			want: server.ToolConfigs{
				"example_tool": serverlesssparkgetbatch.Config{
					Name:         "example_tool",
					Kind:         "serverless-spark-get-batch",
					Source:       "my-instance",
					Description:  "some description",
					AuthRequired: []string{},
				},
			},
		},
	}
	for _, tc := range tcs {
		t.Run(tc.desc, func(t *testing.T) {
			// Decode into an anonymous wrapper so the `tools:` key maps
			// onto server.ToolConfigs.
			got := struct {
				Tools server.ToolConfigs `yaml:"tools"`
			}{}
			// Strict mode rejects unknown fields in the tool config.
			err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got, yaml.Strict())
			if err != nil {
				t.Fatalf("unable to unmarshal: %s", err)
			}
			if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
				t.Fatalf("incorrect parse: diff %v", diff)
			}
		})
	}
}

View File

@@ -24,16 +24,20 @@ import (
"os"
"reflect"
"regexp"
"strings"
"testing"
"time"
dataproc "cloud.google.com/go/dataproc/v2/apiv1"
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches"
"github.com/googleapis/genai-toolbox/tests"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/testing/protocmp"
)
var (
@@ -81,6 +85,15 @@ func TestServerlessSparkToolEndpoints(t *testing.T) {
"source": "my-spark",
"authRequired": []string{"my-google-auth"},
},
"get-batch": map[string]any{
"kind": "serverless-spark-get-batch",
"source": "my-spark",
},
"get-batch-with-auth": map[string]any{
"kind": "serverless-spark-get-batch",
"source": "my-spark",
"authRequired": []string{"my-google-auth"},
},
},
}
@@ -106,13 +119,20 @@ func TestServerlessSparkToolEndpoints(t *testing.T) {
defer client.Close()
runListBatchesTest(t, client, ctx)
runListBatchesErrorTest(t)
runListBatchesAuthTest(t)
fullName := listBatchesRpc(t, client, ctx, "", 1, true)[0].Name
runGetBatchTest(t, client, ctx, fullName)
runErrorTest(t)
// Get the most recent batch, which is all we need for this test.
runAuthTest(t, "list-batches-with-auth", map[string]any{"pageSize": 1})
runAuthTest(t, "get-batch-with-auth", map[string]any{"name": shortName(fullName)})
}
// runListBatchesTest invokes the running list-batches tool and ensures it returns the correct
// number of results. It can run successfully against any GCP project that has at least 2 succeeded
// or failed Serverless Spark batches, of any age.
// number of results. It can run successfully against any GCP project that contains at least 2 total
// Serverless Spark batches.
func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context) {
batch2 := listBatchesRpc(t, client, ctx, "", 2, true)
batch20 := listBatchesRpc(t, client, ctx, "", 20, false)
@@ -145,7 +165,7 @@ func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ct
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
t.Run("list-batches "+tc.name, func(t *testing.T) {
var actual []serverlesssparklistbatches.Batch
var pageToken string
for i := 0; i < tc.numPages; i++ {
@@ -157,9 +177,9 @@ func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ct
request["pageSize"] = tc.pageSize
}
resp, err := invokeListBatches("list-batches", request, nil)
resp, err := invokeTool("list-batches", request, nil)
if err != nil {
t.Fatalf("invokeListBatches failed: %v", err)
t.Fatalf("invokeTool failed: %v", err)
}
defer resp.Body.Close()
@@ -221,35 +241,160 @@ func listBatchesRpc(t *testing.T, client *dataproc.BatchControllerClient, ctx co
return serverlesssparklistbatches.ToBatches(batchPbs)
}
func runListBatchesErrorTest(t *testing.T) {
// runAuthTest invokes the named tool (one configured with authRequired)
// with the given request under three auth conditions — a valid Google
// ID token, an invalid token, and no token — and asserts the expected
// HTTP status for each.
func runAuthTest(t *testing.T, toolName string, request map[string]any) {
	idToken, err := tests.GetGoogleIdToken(tests.ClientId)
	if err != nil {
		t.Fatalf("error getting Google ID token: %s", err)
	}
	tcs := []struct {
		name       string
		headers    map[string]string
		wantStatus int
	}{
		{
			name:       "valid auth token",
			headers:    map[string]string{"my-google-auth_token": idToken},
			wantStatus: http.StatusOK,
		},
		{
			name:       "invalid auth token",
			headers:    map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
			wantStatus: http.StatusUnauthorized,
		},
		{
			name:       "no auth token",
			headers:    nil,
			wantStatus: http.StatusUnauthorized,
		},
	}
	for _, tc := range tcs {
		// Prefix subtests with the tool name so failures from different
		// tools sharing this helper are distinguishable.
		t.Run(toolName+" "+tc.name, func(t *testing.T) {
			resp, err := invokeTool(toolName, request, tc.headers)
			if err != nil {
				t.Fatalf("invokeTool failed: %v", err)
			}
			defer resp.Body.Close()
			if resp.StatusCode != tc.wantStatus {
				bodyBytes, _ := io.ReadAll(resp.Body)
				t.Fatalf("response status code is not %d, got %d: %s", tc.wantStatus, resp.StatusCode, string(bodyBytes))
			}
		})
	}
}
// runGetBatchTest invokes the running get-batch tool for the batch
// identified by fullName and compares the JSON the tool returns against
// the same batch fetched directly through the Dataproc Go client, using
// proto-aware comparison (protocmp).
func runGetBatchTest(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context, fullName string) {
	// First get the batch details directly from the Go proto API.
	req := &dataprocpb.GetBatchRequest{
		Name: fullName,
	}
	rawWantBatchPb, err := client.GetBatch(ctx, req)
	if err != nil {
		t.Fatalf("failed to get batch: %s", err)
	}
	// Trim unknown fields from the proto by marshalling and unmarshalling.
	jsonBytes, err := protojson.Marshal(rawWantBatchPb)
	if err != nil {
		t.Fatalf("failed to marshal batch to JSON: %s", err)
	}
	var wantBatchPb dataprocpb.Batch
	if err := protojson.Unmarshal(jsonBytes, &wantBatchPb); err != nil {
		t.Fatalf("error unmarshalling result: %s", err)
	}
	tcs := []struct {
		name      string
		batchName string
		want      *dataprocpb.Batch
	}{
		{
			name:      "found batch",
			batchName: shortName(fullName),
			want:      &wantBatchPb,
		},
	}
	for _, tc := range tcs {
		t.Run("get-batch "+tc.name, func(t *testing.T) {
			// The tool accepts only the short batch name; project and
			// location come from its source configuration.
			request := map[string]any{"name": tc.batchName}
			resp, err := invokeTool("get-batch", request, nil)
			if err != nil {
				t.Fatalf("invokeTool failed: %v", err)
			}
			defer resp.Body.Close()
			if resp.StatusCode != http.StatusOK {
				bodyBytes, _ := io.ReadAll(resp.Body)
				t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
			}
			var body map[string]any
			if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
				t.Fatalf("error parsing response body: %v", err)
			}
			// The tool result arrives as a JSON-encoded string field.
			result, ok := body["result"].(string)
			if !ok {
				t.Fatalf("unable to find result in response body")
			}
			// Unmarshal JSON to proto for proto-aware deep comparison.
			var batch dataprocpb.Batch
			if err := protojson.Unmarshal([]byte(result), &batch); err != nil {
				t.Fatalf("error unmarshalling result: %s", err)
			}
			if !cmp.Equal(&batch, tc.want, protocmp.Transform()) {
				diff := cmp.Diff(&batch, tc.want, protocmp.Transform())
				t.Errorf("GetBatch() returned diff (-got +want):\n%s", diff)
			}
		})
	}
}
func runErrorTest(t *testing.T) {
missingBatchFullName := fmt.Sprintf("projects/%s/locations/%s/batches/INVALID_BATCH", serverlessSparkProject, serverlessSparkLocation)
tcs := []struct {
name string
pageSize int
toolName string
request map[string]any
wantCode int
wantMsg string
}{
{
name: "zero page size",
pageSize: 0,
name: "list-batches zero page size",
toolName: "list-batches",
request: map[string]any{"pageSize": 0},
wantCode: http.StatusBadRequest,
wantMsg: "pageSize must be positive: 0",
},
{
name: "negative page size",
pageSize: -1,
name: "list-batches negative page size",
toolName: "list-batches",
request: map[string]any{"pageSize": -1},
wantCode: http.StatusBadRequest,
wantMsg: "pageSize must be positive: -1",
},
{
name: "get-batch missing batch",
toolName: "get-batch",
request: map[string]any{"name": "INVALID_BATCH"},
wantCode: http.StatusBadRequest,
wantMsg: fmt.Sprintf("Not found: Batch projects/%s/locations/%s/batches/INVALID_BATCH", serverlessSparkProject, serverlessSparkLocation),
},
{
name: "get-batch full batch name",
toolName: "get-batch",
request: map[string]any{"name": missingBatchFullName},
wantCode: http.StatusBadRequest,
wantMsg: fmt.Sprintf("name must be a short batch name without '/': %s", missingBatchFullName),
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
request := map[string]any{
"pageSize": tc.pageSize,
}
resp, err := invokeListBatches("list-batches", request, nil)
resp, err := invokeTool(tc.toolName, tc.request, nil)
if err != nil {
t.Fatalf("invokeListBatches failed: %v", err)
t.Fatalf("invokeTool failed: %v", err)
}
defer resp.Body.Close()
@@ -270,57 +415,7 @@ func runListBatchesErrorTest(t *testing.T) {
}
}
func runListBatchesAuthTest(t *testing.T) {
idToken, err := tests.GetGoogleIdToken(tests.ClientId)
if err != nil {
t.Fatalf("error getting Google ID token: %s", err)
}
tcs := []struct {
name string
toolName string
headers map[string]string
wantStatus int
}{
{
name: "valid auth token",
toolName: "list-batches-with-auth",
headers: map[string]string{"my-google-auth_token": idToken},
wantStatus: http.StatusOK,
},
{
name: "invalid auth token",
toolName: "list-batches-with-auth",
headers: map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
wantStatus: http.StatusUnauthorized,
},
{
name: "no auth token",
toolName: "list-batches-with-auth",
headers: nil,
wantStatus: http.StatusUnauthorized,
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
request := map[string]any{
"pageSize": 1,
}
resp, err := invokeListBatches(tc.toolName, request, tc.headers)
if err != nil {
t.Fatalf("invokeListBatches failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != tc.wantStatus {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("response status code is not %d, got %d: %s", tc.wantStatus, resp.StatusCode, string(bodyBytes))
}
})
}
}
func invokeListBatches(toolName string, request map[string]any, headers map[string]string) (*http.Response, error) {
func invokeTool(toolName string, request map[string]any, headers map[string]string) (*http.Response, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
@@ -338,3 +433,8 @@ func invokeListBatches(toolName string, request map[string]any, headers map[stri
return http.DefaultClient.Do(req)
}
// shortName returns the final path segment of a fully-qualified batch
// resource name (everything after the last '/'); a name with no '/' is
// returned unchanged.
func shortName(fullName string) string {
	if i := strings.LastIndex(fullName, "/"); i >= 0 {
		return fullName[i+1:]
	}
	return fullName
}