From 28816832265250de97d84e6ba38bf6c35e040796 Mon Sep 17 00:00:00 2001 From: Dave Borowitz Date: Wed, 29 Oct 2025 11:47:30 -0700 Subject: [PATCH] feat(serverless-spark): add cancel-batch tool --- .ci/integration.cloudbuild.yaml | 3 +- cmd/root.go | 1 + cmd/root_test.go | 2 +- docs/en/resources/sources/serverless-spark.md | 2 + .../tools/serverless-spark/_index.md | 1 + .../serverless-spark-cancel-batch.md | 51 +++++ go.mod | 2 +- .../tools/serverless-spark.yaml | 4 + .../serverlessspark/serverlessspark.go | 41 +++- .../serverlesssparkcancelbatch.go | 162 +++++++++++++++ .../serverlesssparkcancelbatch_test.go | 72 +++++++ .../serverlesssparklistbatches.go | 2 + .../serverless_spark_integration_test.go | 190 +++++++++++++++++- 13 files changed, 516 insertions(+), 17 deletions(-) create mode 100644 docs/en/resources/tools/serverless-spark/serverless-spark-cancel-batch.md create mode 100644 internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch.go create mode 100644 internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch_test.go diff --git a/.ci/integration.cloudbuild.yaml b/.ci/integration.cloudbuild.yaml index 17cfa1d3bf..3dc6896ce0 100644 --- a/.ci/integration.cloudbuild.yaml +++ b/.ci/integration.cloudbuild.yaml @@ -751,8 +751,9 @@ steps: entrypoint: /bin/bash env: - "GOPATH=/gopath" - - "SERVERLESS_SPARK_PROJECT=$PROJECT_ID" - "SERVERLESS_SPARK_LOCATION=$_REGION" + - "SERVERLESS_SPARK_PROJECT=$PROJECT_ID" + - "SERVERLESS_SPARK_SERVICE_ACCOUNT=$SERVICE_ACCOUNT_EMAIL" secretEnv: ["CLIENT_ID"] volumes: - name: "go" diff --git a/cmd/root.go b/cmd/root.go index 885ba0b687..9fa25c6775 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -163,6 +163,7 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistviews" _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql" _ "github.com/googleapis/genai-toolbox/internal/tools/redis" + _ 
"github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcancelbatch" _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkgetbatch" _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches" _ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerexecutesql" diff --git a/cmd/root_test.go b/cmd/root_test.go index 0449adca00..cb971d1a04 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -1474,7 +1474,7 @@ func TestPrebuiltTools(t *testing.T) { wantToolset: server.ToolsetConfigs{ "serverless_spark_tools": tools.ToolsetConfig{ Name: "serverless_spark_tools", - ToolNames: []string{"list_batches", "get_batch"}, + ToolNames: []string{"list_batches", "get_batch", "cancel_batch"}, }, }, }, diff --git a/docs/en/resources/sources/serverless-spark.md b/docs/en/resources/sources/serverless-spark.md index c6ebbfc5d8..0d137d36b7 100644 --- a/docs/en/resources/sources/serverless-spark.md +++ b/docs/en/resources/sources/serverless-spark.md @@ -19,6 +19,8 @@ Apache Spark. List and filter Serverless Spark batches. - [`serverless-spark-get-batch`](../tools/serverless-spark/serverless-spark-get-batch.md) Get a Serverless Spark batch. +- [`serverless-spark-cancel-batch`](../tools/serverless-spark/serverless-spark-cancel-batch.md) + Cancel a running Serverless Spark batch operation. 
## Requirements diff --git a/docs/en/resources/tools/serverless-spark/_index.md b/docs/en/resources/tools/serverless-spark/_index.md index 7e9867aeb6..4974a07b19 100644 --- a/docs/en/resources/tools/serverless-spark/_index.md +++ b/docs/en/resources/tools/serverless-spark/_index.md @@ -8,3 +8,4 @@ description: > - [serverless-spark-get-batch](./serverless-spark-get-batch.md) - [serverless-spark-list-batches](./serverless-spark-list-batches.md) +- [serverless-spark-cancel-batch](./serverless-spark-cancel-batch.md) diff --git a/docs/en/resources/tools/serverless-spark/serverless-spark-cancel-batch.md b/docs/en/resources/tools/serverless-spark/serverless-spark-cancel-batch.md new file mode 100644 index 0000000000..4321d64fee --- /dev/null +++ b/docs/en/resources/tools/serverless-spark/serverless-spark-cancel-batch.md @@ -0,0 +1,51 @@ +--- +title: "serverless-spark-cancel-batch" +type: docs +weight: 2 +description: > + A "serverless-spark-cancel-batch" tool cancels a running Spark batch operation. +aliases: + - /resources/tools/serverless-spark-cancel-batch +--- + +## About + +The `serverless-spark-cancel-batch` tool cancels a running Spark batch operation in + a Google Cloud Serverless for Apache Spark source. The cancellation request is + asynchronous, so the batch state will not change immediately after the tool + returns; it can take a minute or so for the cancellation to be reflected. + +It's compatible with the following sources: + +- [serverless-spark](../../sources/serverless-spark.md) + +`serverless-spark-cancel-batch` accepts the following parameters: + +- **`operation`** (required): The name of the operation to cancel. For example, for `projects/my-project/locations/us-central1/operations/my-operation`, you would pass `my-operation`. + +The tool inherits the `project` and `location` from the source configuration. 
+ +## Example + +```yaml +tools: + cancel_spark_batch: + kind: serverless-spark-cancel-batch + source: my-serverless-spark-source + description: Use this tool to cancel a running serverless spark batch operation. +``` + +## Response Format + +```json +"Cancelled [my-operation]." +``` + +## Reference + +| **field** | **type** | **required** | **description** | +| ------------ | :------: | :----------: | -------------------------------------------------- | +| kind | string | true | Must be "serverless-spark-cancel-batch". | +| source | string | true | Name of the source the tool should use. | +| description | string | true | Description of the tool that is passed to the LLM. | +| authRequired | string[] | false | List of auth services required to invoke this tool | diff --git a/go.mod b/go.mod index b194e13f4e..2e91d9d3fe 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( cloud.google.com/go/dataproc/v2 v2.15.0 cloud.google.com/go/firestore v1.20.0 cloud.google.com/go/geminidataanalytics v0.2.1 + cloud.google.com/go/longrunning v0.7.0 cloud.google.com/go/spanner v1.86.1 github.com/ClickHouse/clickhouse-go/v2 v2.40.3 github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.53.0 @@ -80,7 +81,6 @@ require ( cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect cloud.google.com/go/compute/metadata v0.9.0 // indirect cloud.google.com/go/iam v1.5.3 // indirect - cloud.google.com/go/longrunning v0.7.0 // indirect cloud.google.com/go/monitoring v1.24.3 // indirect cloud.google.com/go/trace v1.11.7 // indirect filippo.io/edwards25519 v1.1.0 // indirect diff --git a/internal/prebuiltconfigs/tools/serverless-spark.yaml b/internal/prebuiltconfigs/tools/serverless-spark.yaml index 3ef0a2834a..7d78b18a95 100644 --- a/internal/prebuiltconfigs/tools/serverless-spark.yaml +++ b/internal/prebuiltconfigs/tools/serverless-spark.yaml @@ -25,8 +25,12 @@ tools: get_batch: kind: serverless-spark-get-batch 
source: serverless-spark-source + cancel_batch: + kind: serverless-spark-cancel-batch + source: serverless-spark-source toolsets: serverless_spark_tools: - list_batches - get_batch + - cancel_batch diff --git a/internal/sources/serverlessspark/serverlessspark.go b/internal/sources/serverlessspark/serverlessspark.go index 10cdaf1f78..7a8a769635 100644 --- a/internal/sources/serverlessspark/serverlessspark.go +++ b/internal/sources/serverlessspark/serverlessspark.go @@ -22,6 +22,7 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/util" + "cloud.google.com/go/longrunning/autogen" "go.opentelemetry.io/otel/trace" "google.golang.org/api/option" ) @@ -66,13 +67,18 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So if err != nil { return nil, fmt.Errorf("failed to create dataproc client: %w", err) } + opsClient, err := longrunning.NewOperationsClient(ctx, option.WithEndpoint(endpoint), option.WithUserAgent(ua)) + if err != nil { + return nil, fmt.Errorf("failed to create longrunning client: %w", err) + } s := &Source{ - Name: r.Name, - Kind: SourceKind, - Project: r.Project, - Location: r.Location, - Client: client, + Name: r.Name, + Kind: SourceKind, + Project: r.Project, + Location: r.Location, + Client: client, + OpsClient: opsClient, } return s, nil } @@ -80,11 +86,12 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So var _ sources.Source = &Source{} type Source struct { - Name string `yaml:"name"` - Kind string `yaml:"kind"` - Project string - Location string - Client *dataproc.BatchControllerClient + Name string `yaml:"name"` + Kind string `yaml:"kind"` + Project string + Location string + Client *dataproc.BatchControllerClient + OpsClient *longrunning.OperationsClient } func (s *Source) SourceKind() string { @@ -94,3 +101,17 @@ func (s *Source) SourceKind() string { func (s *Source) GetBatchControllerClient() 
*dataproc.BatchControllerClient { return s.Client } + +func (s *Source) GetOperationsClient(ctx context.Context) (*longrunning.OperationsClient, error) { + return s.OpsClient, nil +} + +func (s *Source) Close() error { + if err := s.Client.Close(); err != nil { + return err + } + if err := s.OpsClient.Close(); err != nil { + return err + } + return nil +} diff --git a/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch.go b/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch.go new file mode 100644 index 0000000000..36a0be4095 --- /dev/null +++ b/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch.go @@ -0,0 +1,162 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package serverlesssparkcancelbatch + +import ( + "context" + "fmt" + "strings" + + "cloud.google.com/go/longrunning/autogen/longrunningpb" + "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/sources/serverlessspark" + "github.com/googleapis/genai-toolbox/internal/tools" +) + +const kind = "serverless-spark-cancel-batch" + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +// ToolConfigKind returns the unique name for this tool. +func (cfg Config) ToolConfigKind() string { + return kind +} + +// Initialize creates a new Tool instance. +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + rawS, ok := srcs[cfg.Source] + if !ok { + return nil, fmt.Errorf("source %q not found", cfg.Source) + } + + ds, ok := rawS.(*serverlessspark.Source) + if !ok { + return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, serverlessspark.SourceKind) + } + + desc := cfg.Description + if desc == "" { + desc = "Cancels a running Serverless Spark (aka Dataproc Serverless) batch operation. Note that the batch state will not change immediately after the tool returns; it can take a minute or so for the cancellation to be reflected." 
+ } + + allParameters := tools.Parameters{ + tools.NewStringParameter("operation", "The name of the operation to cancel, e.g. for \"projects/my-project/locations/us-central1/operations/my-operation\", pass \"my-operation\""), + } + inputSchema, _ := allParameters.McpManifest() + + mcpManifest := tools.McpManifest{ + Name: cfg.Name, + Description: desc, + InputSchema: inputSchema, + } + + return &Tool{ + Name: cfg.Name, + Kind: kind, + Source: ds, + AuthRequired: cfg.AuthRequired, + manifest: tools.Manifest{Description: desc, Parameters: allParameters.Manifest()}, + mcpManifest: mcpManifest, + Parameters: allParameters, + }, nil +} + +// Tool is the implementation of the tool. +type Tool struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + Description string `yaml:"description"` + AuthRequired []string `yaml:"authRequired"` + + Source *serverlessspark.Source + + manifest tools.Manifest + mcpManifest tools.McpManifest + Parameters tools.Parameters +} + +// Invoke executes the tool's operation. 
+func (t *Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) { + client, err := t.Source.GetOperationsClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get operations client: %w", err) + } + + paramMap := params.AsMap() + operation, ok := paramMap["operation"].(string) + if !ok { + return nil, fmt.Errorf("missing required parameter: operation") + } + + if strings.Contains(operation, "/") { + return nil, fmt.Errorf("operation must be a short operation name without '/': %s", operation) + } + + req := &longrunningpb.CancelOperationRequest{ + Name: fmt.Sprintf("projects/%s/locations/%s/operations/%s", t.Source.Project, t.Source.Location, operation), + } + + err = client.CancelOperation(ctx, req) + if err != nil { + return nil, fmt.Errorf("failed to cancel operation: %w", err) + } + + return fmt.Sprintf("Cancelled [%s].", operation), nil +} + +func (t *Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) { + return tools.ParseParams(t.Parameters, data, claims) +} + +func (t *Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t *Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t *Tool) Authorized(services []string) bool { + return tools.IsAuthorized(t.AuthRequired, services) +} + +func (t *Tool) RequiresClientAuthorization() bool { + // Client OAuth not supported, rely on ADCs. 
+ return false +} diff --git a/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch_test.go b/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch_test.go new file mode 100644 index 0000000000..5348399a32 --- /dev/null +++ b/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch_test.go @@ -0,0 +1,72 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package serverlesssparkcancelbatch_test + +import ( + "testing" + + "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcancelbatch" +) + +func TestParseFromYaml(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: serverless-spark-cancel-batch + source: my-instance + description: some description + `, + want: server.ToolConfigs{ + "example_tool": serverlesssparkcancelbatch.Config{ + Name: "example_tool", + Kind: "serverless-spark-cancel-batch", + Source: "my-instance", + Description: "some description", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := 
range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got, yaml.Strict()) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} diff --git a/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches.go b/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches.go index 7a5cacb4da..a45e504a9b 100644 --- a/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches.go +++ b/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches.go @@ -127,6 +127,7 @@ type Batch struct { State string `json:"state"` Creator string `json:"creator"` CreateTime string `json:"createTime"` + Operation string `json:"operation"` } // Invoke executes the tool's operation. 
@@ -177,6 +178,7 @@ func ToBatches(batchPbs []*dataprocpb.Batch) []Batch { State: batchPb.State.Enum().String(), Creator: batchPb.Creator, CreateTime: batchPb.CreateTime.AsTime().Format(time.RFC3339), + Operation: batchPb.Operation, } batches = append(batches, batch) } diff --git a/tests/serverlessspark/serverless_spark_integration_test.go b/tests/serverlessspark/serverless_spark_integration_test.go index 9455f4e6e9..f2fa106f8a 100644 --- a/tests/serverlessspark/serverless_spark_integration_test.go +++ b/tests/serverlessspark/serverless_spark_integration_test.go @@ -24,6 +24,7 @@ import ( "os" "reflect" "regexp" + "slices" "strings" "testing" "time" @@ -41,16 +42,19 @@ import ( ) var ( - serverlessSparkProject = os.Getenv("SERVERLESS_SPARK_PROJECT") - serverlessSparkLocation = os.Getenv("SERVERLESS_SPARK_LOCATION") + serverlessSparkLocation = os.Getenv("SERVERLESS_SPARK_LOCATION") + serverlessSparkProject = os.Getenv("SERVERLESS_SPARK_PROJECT") + serverlessSparkServiceAccount = os.Getenv("SERVERLESS_SPARK_SERVICE_ACCOUNT") ) func getServerlessSparkVars(t *testing.T) map[string]any { switch "" { - case serverlessSparkProject: - t.Fatal("'SERVERLESS_SPARK_PROJECT' not set") case serverlessSparkLocation: t.Fatal("'SERVERLESS_SPARK_LOCATION' not set") + case serverlessSparkProject: + t.Fatal("'SERVERLESS_SPARK_PROJECT' not set") + case serverlessSparkServiceAccount: + t.Fatal("'SERVERLESS_SPARK_SERVICE_ACCOUNT' not set") } return map[string]any{ @@ -94,6 +98,15 @@ func TestServerlessSparkToolEndpoints(t *testing.T) { "source": "my-spark", "authRequired": []string{"my-google-auth"}, }, + "cancel-batch": map[string]any{ + "kind": "serverless-spark-cancel-batch", + "source": "my-spark", + }, + "cancel-batch-with-auth": map[string]any{ + "kind": "serverless-spark-cancel-batch", + "source": "my-spark", + "authRequired": []string{"my-google-auth"}, + }, }, } @@ -206,9 +219,178 @@ func TestServerlessSparkToolEndpoints(t *testing.T) { runAuthTest(t, "get-batch-with-auth", 
map[string]any{"name": shortName(fullName)}, http.StatusOK) }) }) + + t.Run("cancel-batch", func(t *testing.T) { + t.Parallel() + t.Run("success", func(t *testing.T) { + t.Parallel() + tcs := []struct { + name string + getBatchName func(t *testing.T) string + }{ + { + name: "running batch", + getBatchName: func(t *testing.T) string { + return createBatch(t, client, ctx) + }, + }, + { + name: "succeeded batch", + getBatchName: func(t *testing.T) string { + return listBatchesRpc(t, client, ctx, "state = SUCCEEDED", 1, true)[0].Name + }, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + runCancelBatchTest(t, client, ctx, tc.getBatchName(t)) + }) + } + }) + t.Run("errors", func(t *testing.T) { + t.Parallel() + // Find a batch that's already completed. + completedBatchOp := listBatchesRpc(t, client, ctx, "state = SUCCEEDED", 1, true)[0].Operation + fullOpName := fmt.Sprintf("projects/%s/locations/%s/operations/%s", serverlessSparkProject, serverlessSparkLocation, shortName(completedBatchOp)) + tcs := []struct { + name string + toolName string + request map[string]any + wantCode int + wantMsg string + }{ + { + name: "missing op parameter", + toolName: "cancel-batch", + request: map[string]any{}, + wantCode: http.StatusBadRequest, + wantMsg: "parameter \\\"operation\\\" is required", + }, + { + name: "nonexistent op", + toolName: "cancel-batch", + request: map[string]any{"operation": "INVALID_OPERATION"}, + wantCode: http.StatusBadRequest, + wantMsg: "Operation not found", + }, + { + name: "full op name", + toolName: "cancel-batch", + request: map[string]any{"operation": fullOpName}, + wantCode: http.StatusBadRequest, + wantMsg: fmt.Sprintf("operation must be a short operation name without '/': %s", fullOpName), + }, + } + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + testError(t, tc.toolName, tc.request, tc.wantCode, tc.wantMsg) + }) + } + }) + t.Run("auth", func(t *testing.T) { + t.Parallel() + 
runAuthTest(t, "cancel-batch-with-auth", map[string]any{"operation": "INVALID_OPERATION"}, http.StatusBadRequest) + }) + }) }) } +func waitForBatch(t *testing.T, client *dataproc.BatchControllerClient, parentCtx context.Context, batch string, desiredStates []dataprocpb.Batch_State, timeout time.Duration) { + ctx, cancel := context.WithTimeout(parentCtx, timeout) + defer cancel() + + for { + select { + case <-ctx.Done(): + t.Fatalf("timed out waiting for batch %s to reach one of states %v", batch, desiredStates) + default: + } + + getReq := &dataprocpb.GetBatchRequest{Name: batch} + batch, err := client.GetBatch(ctx, getReq) + if err != nil { + t.Fatalf("failed to get batch %s: %v", batch, err) + } + + if slices.Contains(desiredStates, batch.State) { + return + } + + if batch.State == dataprocpb.Batch_FAILED || batch.State == dataprocpb.Batch_CANCELLED || batch.State == dataprocpb.Batch_SUCCEEDED { + t.Fatalf("batch op %s is in a terminal state %s, but wanted one of %v. State message: %s", batch, batch.State, desiredStates, batch.StateMessage) + } + time.Sleep(2 * time.Second) + } +} + +// createBatch creates a test batch and immediately returns the batch name, without waiting for the +// batch to start or complete. 
+func createBatch(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context) string { + parent := fmt.Sprintf("projects/%s/locations/%s", serverlessSparkProject, serverlessSparkLocation) + req := &dataprocpb.CreateBatchRequest{ + Parent: parent, + Batch: &dataprocpb.Batch{ + BatchConfig: &dataprocpb.Batch_SparkBatch{ + SparkBatch: &dataprocpb.SparkBatch{ + Driver: &dataprocpb.SparkBatch_MainClass{ + MainClass: "org.apache.spark.examples.SparkPi", + }, + JarFileUris: []string{ + "file:///usr/lib/spark/examples/jars/spark-examples.jar", + }, + Args: []string{"1000"}, + }, + }, + EnvironmentConfig: &dataprocpb.EnvironmentConfig{ + ExecutionConfig: &dataprocpb.ExecutionConfig{ + ServiceAccount: serverlessSparkServiceAccount, + }, + }, + }, + } + + createOp, err := client.CreateBatch(ctx, req) + if err != nil { + t.Fatalf("failed to create batch: %v", err) + } + meta, err := createOp.Metadata() + if err != nil { + t.Fatalf("failed to get batch metadata: %v", err) + } + + // Wait for the batch to become at least PENDING; it typically takes >10s to go from PENDING to + // RUNNING, giving us plenty of time to cancel it before it completes. + waitForBatch(t, client, ctx, meta.Batch, []dataprocpb.Batch_State{dataprocpb.Batch_PENDING, dataprocpb.Batch_RUNNING}, 1*time.Minute) + return meta.Batch +} + +func runCancelBatchTest(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context, batchName string) { + // First get the batch details directly from the Go proto API. 
+ batch, err := client.GetBatch(ctx, &dataprocpb.GetBatchRequest{Name: batchName}) + if err != nil { + t.Fatalf("failed to get batch: %s", err) + } + + request := map[string]any{"operation": shortName(batch.Operation)} + resp, err := invokeTool("cancel-batch", request, nil) + if err != nil { + t.Fatalf("invokeTool failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + } + + if batch.State != dataprocpb.Batch_SUCCEEDED { + waitForBatch(t, client, ctx, batchName, []dataprocpb.Batch_State{dataprocpb.Batch_CANCELLING, dataprocpb.Batch_CANCELLED}, 2*time.Minute) + } +} + // runListBatchesTest invokes the running list-batches tool and ensures it returns the correct // number of results. It can run successfully against any GCP project that contains at least 2 total // Serverless Spark batches.