feat(serverless-spark): add cancel-batch tool

Dave Borowitz
2025-10-29 11:47:30 -07:00
parent 8ef0566e1e
commit 2881683226
13 changed files with 516 additions and 17 deletions

View File

@@ -751,8 +751,9 @@ steps:
entrypoint: /bin/bash
env:
- "GOPATH=/gopath"
- "SERVERLESS_SPARK_PROJECT=$PROJECT_ID"
- "SERVERLESS_SPARK_LOCATION=$_REGION"
- "SERVERLESS_SPARK_PROJECT=$PROJECT_ID"
- "SERVERLESS_SPARK_SERVICE_ACCOUNT=$SERVICE_ACCOUNT_EMAIL"
secretEnv: ["CLIENT_ID"]
volumes:
- name: "go"

View File

@@ -163,6 +163,7 @@ import (
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistviews"
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql"
_ "github.com/googleapis/genai-toolbox/internal/tools/redis"
_ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcancelbatch"
_ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkgetbatch"
_ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches"
_ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerexecutesql"

View File

@@ -1474,7 +1474,7 @@ func TestPrebuiltTools(t *testing.T) {
wantToolset: server.ToolsetConfigs{
"serverless_spark_tools": tools.ToolsetConfig{
Name: "serverless_spark_tools",
-    ToolNames: []string{"list_batches", "get_batch"},
+    ToolNames: []string{"list_batches", "get_batch", "cancel_batch"},
},
},
},

View File

@@ -19,6 +19,8 @@ Apache Spark.
List and filter Serverless Spark batches.
- [`serverless-spark-get-batch`](../tools/serverless-spark/serverless-spark-get-batch.md)
Get a Serverless Spark batch.
- [`serverless-spark-cancel-batch`](../tools/serverless-spark/serverless-spark-cancel-batch.md)
Cancel a running Serverless Spark batch operation.
## Requirements

View File

@@ -8,3 +8,4 @@ description: >
- [serverless-spark-get-batch](./serverless-spark-get-batch.md)
- [serverless-spark-list-batches](./serverless-spark-list-batches.md)
- [serverless-spark-cancel-batch](./serverless-spark-cancel-batch.md)

View File

@@ -0,0 +1,51 @@
---
title: "serverless-spark-cancel-batch"
type: docs
weight: 2
description: >
A "serverless-spark-cancel-batch" tool cancels a running Spark batch operation.
aliases:
- /resources/tools/serverless-spark-cancel-batch
---
## About
The `serverless-spark-cancel-batch` tool cancels a running Spark batch operation in
a Google Cloud Serverless for Apache Spark source. The cancellation request is
asynchronous, so the batch state will not change immediately after the tool
returns; it can take a minute or so for the cancellation to be reflected.
It's compatible with the following sources:
- [serverless-spark](../../sources/serverless-spark.md)
`serverless-spark-cancel-batch` accepts the following parameters:
- **`operation`** (required): The name of the operation to cancel. For example, for `projects/my-project/locations/us-central1/operations/my-operation`, you would pass `my-operation`.
The tool inherits the `project` and `location` from the source configuration.
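If you have a fully qualified operation name (for example, from a `list_batches` result, which now includes an `operation` field), the short form is just the final path segment. A minimal Go sketch; the name here is illustrative:

```go
package main

import (
	"fmt"
	"path"
)

func main() {
	// A fully qualified operation name, e.g. copied from a list_batches result.
	full := "projects/my-project/locations/us-central1/operations/my-operation"
	// The tool expects only the final path segment.
	fmt.Println(path.Base(full)) // my-operation
}
```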
## Example
```yaml
tools:
cancel_spark_batch:
kind: serverless-spark-cancel-batch
source: my-serverless-spark-source
description: Use this tool to cancel a running serverless spark batch operation.
```
## Response Format
```json
"Cancelled [projects/my-project/regions/us-central1/operations/my-operation]."
```
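Because the cancel is asynchronous, a caller that needs confirmation can poll the batch until it reaches `CANCELLING` or `CANCELLED`, which is what the integration tests below do. A minimal sketch using the Dataproc Go client; the batch name, timeout, and endpoint handling are illustrative assumptions:

```go
package main

import (
	"context"
	"fmt"
	"log"
	"time"

	dataproc "cloud.google.com/go/dataproc/v2/apiv1"
	"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
)

func main() {
	ctx := context.Background()
	// For a regional resource you may also need option.WithEndpoint,
	// e.g. "us-central1-dataproc.googleapis.com:443".
	client, err := dataproc.NewBatchControllerClient(ctx)
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()

	// Illustrative batch resource name.
	name := "projects/my-project/locations/us-central1/batches/my-batch"
	deadline := time.Now().Add(2 * time.Minute)
	for time.Now().Before(deadline) {
		b, err := client.GetBatch(ctx, &dataprocpb.GetBatchRequest{Name: name})
		if err != nil {
			log.Fatal(err)
		}
		if b.State == dataprocpb.Batch_CANCELLING || b.State == dataprocpb.Batch_CANCELLED {
			fmt.Println("cancellation reflected:", b.State)
			return
		}
		time.Sleep(2 * time.Second)
	}
	fmt.Println("timed out waiting for cancellation to be reflected")
}
```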
## Reference
| **field** | **type** | **required** | **description** |
| ------------ | :------: | :----------: | -------------------------------------------------- |
| kind | string | true | Must be "serverless-spark-cancel-batch". |
| source | string | true | Name of the source the tool should use. |
| description | string | true | Description of the tool that is passed to the LLM. |
| authRequired | string[] | false | List of auth services required to invoke this tool |
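For a quick manual check, the tool can also be invoked over the Toolbox HTTP API, mirroring the `invokeTool` helper in the integration tests below. A hedged sketch, assuming a local server on the default port and the `cancel_spark_batch` tool from the example above:

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"log"
	"net/http"
)

func main() {
	// Assumed local Toolbox server and invoke path; adjust host/port as needed.
	url := "http://127.0.0.1:5000/api/tool/cancel_spark_batch/invoke"
	body := bytes.NewBufferString(`{"operation": "my-operation"}`)
	resp, err := http.Post(url, "application/json", body)
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()
	out, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.StatusCode, string(out))
}
```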

go.mod
View File

@@ -13,6 +13,7 @@ require (
cloud.google.com/go/dataproc/v2 v2.15.0
cloud.google.com/go/firestore v1.20.0
cloud.google.com/go/geminidataanalytics v0.2.1
+	cloud.google.com/go/longrunning v0.7.0
cloud.google.com/go/spanner v1.86.1
github.com/ClickHouse/clickhouse-go/v2 v2.40.3
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.53.0
@@ -80,7 +81,6 @@ require (
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
cloud.google.com/go/compute/metadata v0.9.0 // indirect
cloud.google.com/go/iam v1.5.3 // indirect
-	cloud.google.com/go/longrunning v0.7.0 // indirect
cloud.google.com/go/monitoring v1.24.3 // indirect
cloud.google.com/go/trace v1.11.7 // indirect
filippo.io/edwards25519 v1.1.0 // indirect

View File

@@ -25,8 +25,12 @@ tools:
get_batch:
kind: serverless-spark-get-batch
source: serverless-spark-source
cancel_batch:
kind: serverless-spark-cancel-batch
source: serverless-spark-source
toolsets:
serverless_spark_tools:
- list_batches
- get_batch
- cancel_batch

View File

@@ -22,6 +22,7 @@ import (
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"cloud.google.com/go/longrunning/autogen"
"go.opentelemetry.io/otel/trace"
"google.golang.org/api/option"
)
@@ -66,13 +67,18 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
if err != nil {
return nil, fmt.Errorf("failed to create dataproc client: %w", err)
}
+	opsClient, err := longrunning.NewOperationsClient(ctx, option.WithEndpoint(endpoint), option.WithUserAgent(ua))
+	if err != nil {
+		return nil, fmt.Errorf("failed to create longrunning client: %w", err)
+	}
 	s := &Source{
-		Name:     r.Name,
-		Kind:     SourceKind,
-		Project:  r.Project,
-		Location: r.Location,
-		Client:   client,
+		Name:      r.Name,
+		Kind:      SourceKind,
+		Project:   r.Project,
+		Location:  r.Location,
+		Client:    client,
+		OpsClient: opsClient,
}
return s, nil
}
@@ -80,11 +86,12 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
var _ sources.Source = &Source{}
type Source struct {
-	Name     string `yaml:"name"`
-	Kind     string `yaml:"kind"`
-	Project  string
-	Location string
-	Client   *dataproc.BatchControllerClient
+	Name      string `yaml:"name"`
+	Kind      string `yaml:"kind"`
+	Project   string
+	Location  string
+	Client    *dataproc.BatchControllerClient
+	OpsClient *longrunning.OperationsClient
}
func (s *Source) SourceKind() string {
@@ -94,3 +101,17 @@ func (s *Source) SourceKind() string {
func (s *Source) GetBatchControllerClient() *dataproc.BatchControllerClient {
return s.Client
}
+func (s *Source) GetOperationsClient(ctx context.Context) (*longrunning.OperationsClient, error) {
+	return s.OpsClient, nil
+}
+
+func (s *Source) Close() error {
+	if err := s.Client.Close(); err != nil {
+		return err
+	}
+	if err := s.OpsClient.Close(); err != nil {
+		return err
+	}
+	return nil
+}

View File

@@ -0,0 +1,162 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package serverlesssparkcancelbatch
import (
"context"
"fmt"
"strings"
"cloud.google.com/go/longrunning/autogen/longrunningpb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/serverlessspark"
"github.com/googleapis/genai-toolbox/internal/tools"
)
const kind = "serverless-spark-cancel-batch"
func init() {
if !tools.Register(kind, newConfig) {
panic(fmt.Sprintf("tool kind %q already registered", kind))
}
}
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
actual := Config{Name: name}
if err := decoder.DecodeContext(ctx, &actual); err != nil {
return nil, err
}
return actual, nil
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Source string `yaml:"source" validate:"required"`
Description string `yaml:"description"`
AuthRequired []string `yaml:"authRequired"`
}
// validate interface
var _ tools.ToolConfig = Config{}
// ToolConfigKind returns the unique name for this tool.
func (cfg Config) ToolConfigKind() string {
return kind
}
// Initialize creates a new Tool instance.
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("source %q not found", cfg.Source)
}
ds, ok := rawS.(*serverlessspark.Source)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, serverlessspark.SourceKind)
}
desc := cfg.Description
if desc == "" {
desc = "Cancels a running Serverless Spark (aka Dataproc Serverless) batch operation. Note that the batch state will not change immediately after the tool returns; it can take a minute or so for the cancellation to be reflected."
}
allParameters := tools.Parameters{
tools.NewStringParameter("operation", "The name of the operation to cancel, e.g. for \"projects/my-project/locations/us-central1/operations/my-operation\", pass \"my-operation\""),
}
inputSchema, _ := allParameters.McpManifest()
mcpManifest := tools.McpManifest{
Name: cfg.Name,
Description: desc,
InputSchema: inputSchema,
}
return &Tool{
Name: cfg.Name,
Kind: kind,
Source: ds,
AuthRequired: cfg.AuthRequired,
manifest: tools.Manifest{Description: desc, Parameters: allParameters.Manifest()},
mcpManifest: mcpManifest,
Parameters: allParameters,
}, nil
}
// Tool is the implementation of the tool.
type Tool struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Description string `yaml:"description"`
AuthRequired []string `yaml:"authRequired"`
Source *serverlessspark.Source
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters tools.Parameters
}
// Invoke executes the tool's operation.
func (t *Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
client, err := t.Source.GetOperationsClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get operations client: %w", err)
}
paramMap := params.AsMap()
operation, ok := paramMap["operation"].(string)
if !ok {
return nil, fmt.Errorf("missing required parameter: operation")
}
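	// Accept only the short operation name; the fully qualified name is built below from the source's project and location.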
if strings.Contains(operation, "/") {
return nil, fmt.Errorf("operation must be a short operation name without '/': %s", operation)
}
req := &longrunningpb.CancelOperationRequest{
Name: fmt.Sprintf("projects/%s/locations/%s/operations/%s", t.Source.Project, t.Source.Location, operation),
}
err = client.CancelOperation(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to cancel operation: %w", err)
}
return fmt.Sprintf("Cancelled [%s].", operation), nil
}
func (t *Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
return tools.ParseParams(t.Parameters, data, claims)
}
func (t *Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t *Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
func (t *Tool) Authorized(services []string) bool {
return tools.IsAuthorized(t.AuthRequired, services)
}
func (t *Tool) RequiresClientAuthorization() bool {
// Client OAuth not supported, rely on ADCs.
return false
}

View File

@@ -0,0 +1,72 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package serverlesssparkcancelbatch_test
import (
"testing"
"github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcancelbatch"
)
func TestParseFromYaml(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
tcs := []struct {
desc string
in string
want server.ToolConfigs
}{
{
desc: "basic example",
in: `
tools:
example_tool:
kind: serverless-spark-cancel-batch
source: my-instance
description: some description
`,
want: server.ToolConfigs{
"example_tool": serverlesssparkcancelbatch.Config{
Name: "example_tool",
Kind: "serverless-spark-cancel-batch",
Source: "my-instance",
Description: "some description",
AuthRequired: []string{},
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Tools server.ToolConfigs `yaml:"tools"`
}{}
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got, yaml.Strict())
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
t.Fatalf("incorrect parse: diff %v", diff)
}
})
}
}

View File

@@ -127,6 +127,7 @@ type Batch struct {
State string `json:"state"`
Creator string `json:"creator"`
CreateTime string `json:"createTime"`
Operation string `json:"operation"`
}
// Invoke executes the tool's operation.
@@ -177,6 +178,7 @@ func ToBatches(batchPbs []*dataprocpb.Batch) []Batch {
State: batchPb.State.Enum().String(),
Creator: batchPb.Creator,
CreateTime: batchPb.CreateTime.AsTime().Format(time.RFC3339),
Operation: batchPb.Operation,
}
batches = append(batches, batch)
}

View File

@@ -24,6 +24,7 @@ import (
"os"
"reflect"
"regexp"
"slices"
"strings"
"testing"
"time"
@@ -41,16 +42,19 @@ import (
)
var (
-	serverlessSparkProject  = os.Getenv("SERVERLESS_SPARK_PROJECT")
-	serverlessSparkLocation = os.Getenv("SERVERLESS_SPARK_LOCATION")
+	serverlessSparkLocation       = os.Getenv("SERVERLESS_SPARK_LOCATION")
+	serverlessSparkProject        = os.Getenv("SERVERLESS_SPARK_PROJECT")
+	serverlessSparkServiceAccount = os.Getenv("SERVERLESS_SPARK_SERVICE_ACCOUNT")
)
func getServerlessSparkVars(t *testing.T) map[string]any {
switch "" {
-	case serverlessSparkProject:
-		t.Fatal("'SERVERLESS_SPARK_PROJECT' not set")
 	case serverlessSparkLocation:
 		t.Fatal("'SERVERLESS_SPARK_LOCATION' not set")
+	case serverlessSparkProject:
+		t.Fatal("'SERVERLESS_SPARK_PROJECT' not set")
+	case serverlessSparkServiceAccount:
+		t.Fatal("'SERVERLESS_SPARK_SERVICE_ACCOUNT' not set")
}
return map[string]any{
@@ -94,6 +98,15 @@ func TestServerlessSparkToolEndpoints(t *testing.T) {
"source": "my-spark",
"authRequired": []string{"my-google-auth"},
},
"cancel-batch": map[string]any{
"kind": "serverless-spark-cancel-batch",
"source": "my-spark",
},
"cancel-batch-with-auth": map[string]any{
"kind": "serverless-spark-cancel-batch",
"source": "my-spark",
"authRequired": []string{"my-google-auth"},
},
},
}
@@ -206,9 +219,178 @@ func TestServerlessSparkToolEndpoints(t *testing.T) {
runAuthTest(t, "get-batch-with-auth", map[string]any{"name": shortName(fullName)}, http.StatusOK)
})
})
t.Run("cancel-batch", func(t *testing.T) {
t.Parallel()
t.Run("success", func(t *testing.T) {
t.Parallel()
tcs := []struct {
name string
getBatchName func(t *testing.T) string
}{
{
name: "running batch",
getBatchName: func(t *testing.T) string {
return createBatch(t, client, ctx)
},
},
{
name: "succeeded batch",
getBatchName: func(t *testing.T) string {
return listBatchesRpc(t, client, ctx, "state = SUCCEEDED", 1, true)[0].Name
},
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
runCancelBatchTest(t, client, ctx, tc.getBatchName(t))
})
}
})
t.Run("errors", func(t *testing.T) {
t.Parallel()
// Find a batch that's already completed.
completedBatchOp := listBatchesRpc(t, client, ctx, "state = SUCCEEDED", 1, true)[0].Operation
fullOpName := fmt.Sprintf("projects/%s/locations/%s/operations/%s", serverlessSparkProject, serverlessSparkLocation, shortName(completedBatchOp))
tcs := []struct {
name string
toolName string
request map[string]any
wantCode int
wantMsg string
}{
{
name: "missing op parameter",
toolName: "cancel-batch",
request: map[string]any{},
wantCode: http.StatusBadRequest,
wantMsg: "parameter \\\"operation\\\" is required",
},
{
name: "nonexistent op",
toolName: "cancel-batch",
request: map[string]any{"operation": "INVALID_OPERATION"},
wantCode: http.StatusBadRequest,
wantMsg: "Operation not found",
},
{
name: "full op name",
toolName: "cancel-batch",
request: map[string]any{"operation": fullOpName},
wantCode: http.StatusBadRequest,
wantMsg: fmt.Sprintf("operation must be a short operation name without '/': %s", fullOpName),
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
testError(t, tc.toolName, tc.request, tc.wantCode, tc.wantMsg)
})
}
})
t.Run("auth", func(t *testing.T) {
t.Parallel()
runAuthTest(t, "cancel-batch-with-auth", map[string]any{"operation": "INVALID_OPERATION"}, http.StatusBadRequest)
})
})
})
}
func waitForBatch(t *testing.T, client *dataproc.BatchControllerClient, parentCtx context.Context, batch string, desiredStates []dataprocpb.Batch_State, timeout time.Duration) {
ctx, cancel := context.WithTimeout(parentCtx, timeout)
defer cancel()
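	// Poll every 2s until the batch reaches a desired state; fail fast on an unexpected terminal state or when the timeout context expires.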
for {
select {
case <-ctx.Done():
t.Fatalf("timed out waiting for batch %s to reach one of states %v", batch, desiredStates)
default:
}
getReq := &dataprocpb.GetBatchRequest{Name: batch}
b, err := client.GetBatch(ctx, getReq)
if err != nil {
	t.Fatalf("failed to get batch %s: %v", batch, err)
}
if slices.Contains(desiredStates, b.State) {
	return
}
if b.State == dataprocpb.Batch_FAILED || b.State == dataprocpb.Batch_CANCELLED || b.State == dataprocpb.Batch_SUCCEEDED {
	t.Fatalf("batch %s is in a terminal state %s, but wanted one of %v. State message: %s", batch, b.State, desiredStates, b.StateMessage)
}
time.Sleep(2 * time.Second)
}
}
// createBatch creates a test batch and immediately returns the batch name, without waiting for the
// batch to start or complete.
func createBatch(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context) string {
parent := fmt.Sprintf("projects/%s/locations/%s", serverlessSparkProject, serverlessSparkLocation)
req := &dataprocpb.CreateBatchRequest{
Parent: parent,
Batch: &dataprocpb.Batch{
BatchConfig: &dataprocpb.Batch_SparkBatch{
SparkBatch: &dataprocpb.SparkBatch{
Driver: &dataprocpb.SparkBatch_MainClass{
MainClass: "org.apache.spark.examples.SparkPi",
},
JarFileUris: []string{
"file:///usr/lib/spark/examples/jars/spark-examples.jar",
},
Args: []string{"1000"},
},
},
EnvironmentConfig: &dataprocpb.EnvironmentConfig{
ExecutionConfig: &dataprocpb.ExecutionConfig{
ServiceAccount: serverlessSparkServiceAccount,
},
},
},
}
createOp, err := client.CreateBatch(ctx, req)
if err != nil {
t.Fatalf("failed to create batch: %v", err)
}
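	// The create operation's metadata (BatchOperationMetadata) carries the new batch's resource name.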
meta, err := createOp.Metadata()
if err != nil {
t.Fatalf("failed to get batch metadata: %v", err)
}
// Wait for the batch to become at least PENDING; it typically takes >10s to go from PENDING to
// RUNNING, giving us plenty of time to cancel it before it completes.
waitForBatch(t, client, ctx, meta.Batch, []dataprocpb.Batch_State{dataprocpb.Batch_PENDING, dataprocpb.Batch_RUNNING}, 1*time.Minute)
return meta.Batch
}
func runCancelBatchTest(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context, batchName string) {
// First get the batch details directly from the Go proto API.
batch, err := client.GetBatch(ctx, &dataprocpb.GetBatchRequest{Name: batchName})
if err != nil {
t.Fatalf("failed to get batch: %s", err)
}
request := map[string]any{"operation": shortName(batch.Operation)}
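	// Invoke the cancel-batch tool through the server's HTTP endpoint, passing only the short operation name.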
resp, err := invokeTool("cancel-batch", request, nil)
if err != nil {
t.Fatalf("invokeTool failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
}
if batch.State != dataprocpb.Batch_SUCCEEDED {
waitForBatch(t, client, ctx, batchName, []dataprocpb.Batch_State{dataprocpb.Batch_CANCELLING, dataprocpb.Batch_CANCELLED}, 2*time.Minute)
}
}
// runListBatchesTest invokes the list-batches tool and ensures it returns the correct
// number of results. It can run successfully against any GCP project that contains at least 2 total
// Serverless Spark batches.