mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 07:28:05 -05:00
feat(serverless-spark): add cancel-batch tool
This commit is contained in:
@@ -751,8 +751,9 @@ steps:
|
||||
entrypoint: /bin/bash
|
||||
env:
|
||||
- "GOPATH=/gopath"
|
||||
- "SERVERLESS_SPARK_PROJECT=$PROJECT_ID"
|
||||
- "SERVERLESS_SPARK_LOCATION=$_REGION"
|
||||
- "SERVERLESS_SPARK_PROJECT=$PROJECT_ID"
|
||||
- "SERVERLESS_SPARK_SERVICE_ACCOUNT=$SERVICE_ACCOUNT_EMAIL"
|
||||
secretEnv: ["CLIENT_ID"]
|
||||
volumes:
|
||||
- name: "go"
|
||||
|
||||
@@ -163,6 +163,7 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistviews"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/redis"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcancelbatch"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkgetbatch"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerexecutesql"
|
||||
|
||||
@@ -1474,7 +1474,7 @@ func TestPrebuiltTools(t *testing.T) {
|
||||
wantToolset: server.ToolsetConfigs{
|
||||
"serverless_spark_tools": tools.ToolsetConfig{
|
||||
Name: "serverless_spark_tools",
|
||||
ToolNames: []string{"list_batches", "get_batch"},
|
||||
ToolNames: []string{"list_batches", "get_batch", "cancel_batch"},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -19,6 +19,8 @@ Apache Spark.
|
||||
List and filter Serverless Spark batches.
|
||||
- [`serverless-spark-get-batch`](../tools/serverless-spark/serverless-spark-get-batch.md)
|
||||
Get a Serverless Spark batch.
|
||||
- [`serverless-spark-cancel-batch`](../tools/serverless-spark/serverless-spark-cancel-batch.md)
|
||||
Cancel a running Serverless Spark batch operation.
|
||||
|
||||
## Requirements
|
||||
|
||||
|
||||
@@ -8,3 +8,4 @@ description: >
|
||||
|
||||
- [serverless-spark-get-batch](./serverless-spark-get-batch.md)
|
||||
- [serverless-spark-list-batches](./serverless-spark-list-batches.md)
|
||||
- [serverless-spark-cancel-batch](./serverless-spark-cancel-batch.md)
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
---
|
||||
title: "serverless-spark-cancel-batch"
|
||||
type: docs
|
||||
weight: 2
|
||||
description: >
|
||||
A "serverless-spark-cancel-batch" tool cancels a running Spark batch operation.
|
||||
aliases:
|
||||
- /resources/tools/serverless-spark-cancel-batch
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
`serverless-spark-cancel-batch` tool cancels a running Spark batch operation in
|
||||
a Google Cloud Serverless for Apache Spark source. The cancellation request is
|
||||
asynchronous, so the batch state will not change immediately after the tool
|
||||
returns; it can take a minute or so for the cancellation to be reflected.
|
||||
|
||||
It's compatible with the following sources:
|
||||
|
||||
- [serverless-spark](../../sources/serverless-spark.md)
|
||||
|
||||
`serverless-spark-cancel-batch` accepts the following parameters:
|
||||
|
||||
- **`operation`** (required): The name of the operation to cancel. For example, for `projects/my-project/locations/us-central1/operations/my-operation`, you would pass `my-operation`.
|
||||
|
||||
The tool inherits the `project` and `location` from the source configuration.
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
cancel_spark_batch:
|
||||
kind: serverless-spark-cancel-batch
|
||||
source: my-serverless-spark-source
|
||||
description: Use this tool to cancel a running serverless spark batch operation.
|
||||
```
|
||||
|
||||
## Response Format
|
||||
|
||||
```json
|
||||
"Cancelled [projects/my-project/regions/us-central1/operations/my-operation]."
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
| ------------ | :------: | :----------: | -------------------------------------------------- |
|
||||
| kind | string | true | Must be "serverless-spark-cancel-batch". |
|
||||
| source | string | true | Name of the source the tool should use. |
|
||||
| description | string | true | Description of the tool that is passed to the LLM. |
|
||||
| authRequired | string[] | false | List of auth services required to invoke this tool |
|
||||
2
go.mod
2
go.mod
@@ -13,6 +13,7 @@ require (
|
||||
cloud.google.com/go/dataproc/v2 v2.15.0
|
||||
cloud.google.com/go/firestore v1.20.0
|
||||
cloud.google.com/go/geminidataanalytics v0.2.1
|
||||
cloud.google.com/go/longrunning v0.7.0
|
||||
cloud.google.com/go/spanner v1.86.1
|
||||
github.com/ClickHouse/clickhouse-go/v2 v2.40.3
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.53.0
|
||||
@@ -80,7 +81,6 @@ require (
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.9.0 // indirect
|
||||
cloud.google.com/go/iam v1.5.3 // indirect
|
||||
cloud.google.com/go/longrunning v0.7.0 // indirect
|
||||
cloud.google.com/go/monitoring v1.24.3 // indirect
|
||||
cloud.google.com/go/trace v1.11.7 // indirect
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
|
||||
@@ -25,8 +25,12 @@ tools:
|
||||
get_batch:
|
||||
kind: serverless-spark-get-batch
|
||||
source: serverless-spark-source
|
||||
cancel_batch:
|
||||
kind: serverless-spark-cancel-batch
|
||||
source: serverless-spark-source
|
||||
|
||||
toolsets:
|
||||
serverless_spark_tools:
|
||||
- list_batches
|
||||
- get_batch
|
||||
- cancel_batch
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"cloud.google.com/go/longrunning/autogen"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
@@ -66,13 +67,18 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create dataproc client: %w", err)
|
||||
}
|
||||
opsClient, err := longrunning.NewOperationsClient(ctx, option.WithEndpoint(endpoint), option.WithUserAgent(ua))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create longrunning client: %w", err)
|
||||
}
|
||||
|
||||
s := &Source{
|
||||
Name: r.Name,
|
||||
Kind: SourceKind,
|
||||
Project: r.Project,
|
||||
Location: r.Location,
|
||||
Client: client,
|
||||
Name: r.Name,
|
||||
Kind: SourceKind,
|
||||
Project: r.Project,
|
||||
Location: r.Location,
|
||||
Client: client,
|
||||
OpsClient: opsClient,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
@@ -80,11 +86,12 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
var _ sources.Source = &Source{}
|
||||
|
||||
type Source struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Project string
|
||||
Location string
|
||||
Client *dataproc.BatchControllerClient
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Project string
|
||||
Location string
|
||||
Client *dataproc.BatchControllerClient
|
||||
OpsClient *longrunning.OperationsClient
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
@@ -94,3 +101,17 @@ func (s *Source) SourceKind() string {
|
||||
func (s *Source) GetBatchControllerClient() *dataproc.BatchControllerClient {
|
||||
return s.Client
|
||||
}
|
||||
|
||||
func (s *Source) GetOperationsClient(ctx context.Context) (*longrunning.OperationsClient, error) {
|
||||
return s.OpsClient, nil
|
||||
}
|
||||
|
||||
func (s *Source) Close() error {
|
||||
if err := s.Client.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.OpsClient.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,162 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package serverlesssparkcancelbatch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"cloud.google.com/go/longrunning/autogen/longrunningpb"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/serverlessspark"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
const kind = "serverless-spark-cancel-batch"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
// ToolConfigKind returns the unique name for this tool.
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
// Initialize creates a new Tool instance.
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
ds, ok := rawS.(*serverlessspark.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, serverlessspark.SourceKind)
|
||||
}
|
||||
|
||||
desc := cfg.Description
|
||||
if desc == "" {
|
||||
desc = "Cancels a running Serverless Spark (aka Dataproc Serverless) batch operation. Note that the batch state will not change immediately after the tool returns; it can take a minute or so for the cancellation to be reflected."
|
||||
}
|
||||
|
||||
allParameters := tools.Parameters{
|
||||
tools.NewStringParameter("operation", "The name of the operation to cancel, e.g. for \"projects/my-project/locations/us-central1/operations/my-operation\", pass \"my-operation\""),
|
||||
}
|
||||
inputSchema, _ := allParameters.McpManifest()
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: desc,
|
||||
InputSchema: inputSchema,
|
||||
}
|
||||
|
||||
return &Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Source: ds,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
manifest: tools.Manifest{Description: desc, Parameters: allParameters.Manifest()},
|
||||
mcpManifest: mcpManifest,
|
||||
Parameters: allParameters,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Tool is the implementation of the tool.
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
|
||||
Source *serverlessspark.Source
|
||||
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Parameters tools.Parameters
|
||||
}
|
||||
|
||||
// Invoke executes the tool's operation.
|
||||
func (t *Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
client, err := t.Source.GetOperationsClient(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get operations client: %w", err)
|
||||
}
|
||||
|
||||
paramMap := params.AsMap()
|
||||
operation, ok := paramMap["operation"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing required parameter: operation")
|
||||
}
|
||||
|
||||
if strings.Contains(operation, "/") {
|
||||
return nil, fmt.Errorf("operation must be a short operation name without '/': %s", operation)
|
||||
}
|
||||
|
||||
req := &longrunningpb.CancelOperationRequest{
|
||||
Name: fmt.Sprintf("projects/%s/locations/%s/operations/%s", t.Source.Project, t.Source.Location, operation),
|
||||
}
|
||||
|
||||
err = client.CancelOperation(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to cancel operation: %w", err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Cancelled [%s].", operation), nil
|
||||
}
|
||||
|
||||
func (t *Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.Parameters, data, claims)
|
||||
}
|
||||
|
||||
func (t *Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t *Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t *Tool) Authorized(services []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, services)
|
||||
}
|
||||
|
||||
func (t *Tool) RequiresClientAuthorization() bool {
|
||||
// Client OAuth not supported, rely on ADCs.
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package serverlesssparkcancelbatch_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcancelbatch"
|
||||
)
|
||||
|
||||
func TestParseFromYaml(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: serverless-spark-cancel-batch
|
||||
source: my-instance
|
||||
description: some description
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": serverlesssparkcancelbatch.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "serverless-spark-cancel-batch",
|
||||
Source: "my-instance",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got, yaml.Strict())
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -127,6 +127,7 @@ type Batch struct {
|
||||
State string `json:"state"`
|
||||
Creator string `json:"creator"`
|
||||
CreateTime string `json:"createTime"`
|
||||
Operation string `json:"operation"`
|
||||
}
|
||||
|
||||
// Invoke executes the tool's operation.
|
||||
@@ -177,6 +178,7 @@ func ToBatches(batchPbs []*dataprocpb.Batch) []Batch {
|
||||
State: batchPb.State.Enum().String(),
|
||||
Creator: batchPb.Creator,
|
||||
CreateTime: batchPb.CreateTime.AsTime().Format(time.RFC3339),
|
||||
Operation: batchPb.Operation,
|
||||
}
|
||||
batches = append(batches, batch)
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"os"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -41,16 +42,19 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
serverlessSparkProject = os.Getenv("SERVERLESS_SPARK_PROJECT")
|
||||
serverlessSparkLocation = os.Getenv("SERVERLESS_SPARK_LOCATION")
|
||||
serverlessSparkLocation = os.Getenv("SERVERLESS_SPARK_LOCATION")
|
||||
serverlessSparkProject = os.Getenv("SERVERLESS_SPARK_PROJECT")
|
||||
serverlessSparkServiceAccount = os.Getenv("SERVERLESS_SPARK_SERVICE_ACCOUNT")
|
||||
)
|
||||
|
||||
func getServerlessSparkVars(t *testing.T) map[string]any {
|
||||
switch "" {
|
||||
case serverlessSparkProject:
|
||||
t.Fatal("'SERVERLESS_SPARK_PROJECT' not set")
|
||||
case serverlessSparkLocation:
|
||||
t.Fatal("'SERVERLESS_SPARK_LOCATION' not set")
|
||||
case serverlessSparkProject:
|
||||
t.Fatal("'SERVERLESS_SPARK_PROJECT' not set")
|
||||
case serverlessSparkServiceAccount:
|
||||
t.Fatal("'SERVERLESS_SPARK_SERVICE_ACCOUNT' not set")
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
@@ -94,6 +98,15 @@ func TestServerlessSparkToolEndpoints(t *testing.T) {
|
||||
"source": "my-spark",
|
||||
"authRequired": []string{"my-google-auth"},
|
||||
},
|
||||
"cancel-batch": map[string]any{
|
||||
"kind": "serverless-spark-cancel-batch",
|
||||
"source": "my-spark",
|
||||
},
|
||||
"cancel-batch-with-auth": map[string]any{
|
||||
"kind": "serverless-spark-cancel-batch",
|
||||
"source": "my-spark",
|
||||
"authRequired": []string{"my-google-auth"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -206,9 +219,178 @@ func TestServerlessSparkToolEndpoints(t *testing.T) {
|
||||
runAuthTest(t, "get-batch-with-auth", map[string]any{"name": shortName(fullName)}, http.StatusOK)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("cancel-batch", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tcs := []struct {
|
||||
name string
|
||||
getBatchName func(t *testing.T) string
|
||||
}{
|
||||
{
|
||||
name: "running batch",
|
||||
getBatchName: func(t *testing.T) string {
|
||||
return createBatch(t, client, ctx)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "succeeded batch",
|
||||
getBatchName: func(t *testing.T) string {
|
||||
return listBatchesRpc(t, client, ctx, "state = SUCCEEDED", 1, true)[0].Name
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
runCancelBatchTest(t, client, ctx, tc.getBatchName(t))
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("errors", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Find a batch that's already completed.
|
||||
completedBatchOp := listBatchesRpc(t, client, ctx, "state = SUCCEEDED", 1, true)[0].Operation
|
||||
fullOpName := fmt.Sprintf("projects/%s/locations/%s/operations/%s", serverlessSparkProject, serverlessSparkLocation, shortName(completedBatchOp))
|
||||
tcs := []struct {
|
||||
name string
|
||||
toolName string
|
||||
request map[string]any
|
||||
wantCode int
|
||||
wantMsg string
|
||||
}{
|
||||
{
|
||||
name: "missing op parameter",
|
||||
toolName: "cancel-batch",
|
||||
request: map[string]any{},
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantMsg: "parameter \\\"operation\\\" is required",
|
||||
},
|
||||
{
|
||||
name: "nonexistent op",
|
||||
toolName: "cancel-batch",
|
||||
request: map[string]any{"operation": "INVALID_OPERATION"},
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantMsg: "Operation not found",
|
||||
},
|
||||
{
|
||||
name: "full op name",
|
||||
toolName: "cancel-batch",
|
||||
request: map[string]any{"operation": fullOpName},
|
||||
wantCode: http.StatusBadRequest,
|
||||
wantMsg: fmt.Sprintf("operation must be a short operation name without '/': %s", fullOpName),
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testError(t, tc.toolName, tc.request, tc.wantCode, tc.wantMsg)
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("auth", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
runAuthTest(t, "cancel-batch-with-auth", map[string]any{"operation": "INVALID_OPERATION"}, http.StatusBadRequest)
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func waitForBatch(t *testing.T, client *dataproc.BatchControllerClient, parentCtx context.Context, batch string, desiredStates []dataprocpb.Batch_State, timeout time.Duration) {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, timeout)
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.Fatalf("timed out waiting for batch %s to reach one of states %v", batch, desiredStates)
|
||||
default:
|
||||
}
|
||||
|
||||
getReq := &dataprocpb.GetBatchRequest{Name: batch}
|
||||
batch, err := client.GetBatch(ctx, getReq)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get batch %s: %v", batch, err)
|
||||
}
|
||||
|
||||
if slices.Contains(desiredStates, batch.State) {
|
||||
return
|
||||
}
|
||||
|
||||
if batch.State == dataprocpb.Batch_FAILED || batch.State == dataprocpb.Batch_CANCELLED || batch.State == dataprocpb.Batch_SUCCEEDED {
|
||||
t.Fatalf("batch op %s is in a terminal state %s, but wanted one of %v. State message: %s", batch, batch.State, desiredStates, batch.StateMessage)
|
||||
}
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
// createBatch creates a test batch and immediately returns the batch name, without waiting for the
|
||||
// batch to start or complete.
|
||||
func createBatch(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context) string {
|
||||
parent := fmt.Sprintf("projects/%s/locations/%s", serverlessSparkProject, serverlessSparkLocation)
|
||||
req := &dataprocpb.CreateBatchRequest{
|
||||
Parent: parent,
|
||||
Batch: &dataprocpb.Batch{
|
||||
BatchConfig: &dataprocpb.Batch_SparkBatch{
|
||||
SparkBatch: &dataprocpb.SparkBatch{
|
||||
Driver: &dataprocpb.SparkBatch_MainClass{
|
||||
MainClass: "org.apache.spark.examples.SparkPi",
|
||||
},
|
||||
JarFileUris: []string{
|
||||
"file:///usr/lib/spark/examples/jars/spark-examples.jar",
|
||||
},
|
||||
Args: []string{"1000"},
|
||||
},
|
||||
},
|
||||
EnvironmentConfig: &dataprocpb.EnvironmentConfig{
|
||||
ExecutionConfig: &dataprocpb.ExecutionConfig{
|
||||
ServiceAccount: serverlessSparkServiceAccount,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
createOp, err := client.CreateBatch(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create batch: %v", err)
|
||||
}
|
||||
meta, err := createOp.Metadata()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get batch metadata: %v", err)
|
||||
}
|
||||
|
||||
// Wait for the batch to become at least PENDING; it typically takes >10s to go from PENDING to
|
||||
// RUNNING, giving us plenty of time to cancel it before it completes.
|
||||
waitForBatch(t, client, ctx, meta.Batch, []dataprocpb.Batch_State{dataprocpb.Batch_PENDING, dataprocpb.Batch_RUNNING}, 1*time.Minute)
|
||||
return meta.Batch
|
||||
}
|
||||
|
||||
func runCancelBatchTest(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context, batchName string) {
|
||||
// First get the batch details directly from the Go proto API.
|
||||
batch, err := client.GetBatch(ctx, &dataprocpb.GetBatchRequest{Name: batchName})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get batch: %s", err)
|
||||
}
|
||||
|
||||
request := map[string]any{"operation": shortName(batch.Operation)}
|
||||
resp, err := invokeTool("cancel-batch", request, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("invokeTool failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
if batch.State != dataprocpb.Batch_SUCCEEDED {
|
||||
waitForBatch(t, client, ctx, batchName, []dataprocpb.Batch_State{dataprocpb.Batch_CANCELLING, dataprocpb.Batch_CANCELLED}, 2*time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
// runListBatchesTest invokes the running list-batches tool and ensures it returns the correct
|
||||
// number of results. It can run successfully against any GCP project that contains at least 2 total
|
||||
// Serverless Spark batches.
|
||||
|
||||
Reference in New Issue
Block a user