diff --git a/cmd/root.go b/cmd/root.go
index 41761fa550..693405bc52 100644
--- a/cmd/root.go
+++ b/cmd/root.go
@@ -198,6 +198,7 @@ import (
 	_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/redis"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcancelbatch"
+	_ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcreatepysparkbatch"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkgetbatch"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/singlestore/singlestoreexecutesql"
diff --git a/cmd/root_test.go b/cmd/root_test.go
index bf697415e9..c8ae8042f4 100644
--- a/cmd/root_test.go
+++ b/cmd/root_test.go
@@ -1558,7 +1558,7 @@ func TestPrebuiltTools(t *testing.T) {
 			wantToolset: server.ToolsetConfigs{
 				"serverless_spark_tools": tools.ToolsetConfig{
 					Name:      "serverless_spark_tools",
-					ToolNames: []string{"list_batches", "get_batch", "cancel_batch"},
+					ToolNames: []string{"list_batches", "get_batch", "cancel_batch", "create_pyspark_batch"},
 				},
 			},
 		},
diff --git a/docs/en/resources/sources/serverless-spark.md b/docs/en/resources/sources/serverless-spark.md
index 0d137d36b7..d032e63460 100644
--- a/docs/en/resources/sources/serverless-spark.md
+++ b/docs/en/resources/sources/serverless-spark.md
@@ -21,6 +21,8 @@ Apache Spark.
   Get a Serverless Spark batch.
 - [`serverless-spark-cancel-batch`](../tools/serverless-spark/serverless-spark-cancel-batch.md)
   Cancel a running Serverless Spark batch operation.
+- [`serverless-spark-create-pyspark-batch`](../tools/serverless-spark/serverless-spark-create-pyspark-batch.md)
+  Create a Serverless Spark PySpark batch operation.
 
 ## Requirements
 
diff --git a/docs/en/resources/tools/serverless-spark/_index.md b/docs/en/resources/tools/serverless-spark/_index.md
index 4974a07b19..e5ff3c18a4 100644
--- a/docs/en/resources/tools/serverless-spark/_index.md
+++ b/docs/en/resources/tools/serverless-spark/_index.md
@@ -9,3 +9,4 @@ description: >
 - [serverless-spark-get-batch](./serverless-spark-get-batch.md)
 - [serverless-spark-list-batches](./serverless-spark-list-batches.md)
 - [serverless-spark-cancel-batch](./serverless-spark-cancel-batch.md)
+- [serverless-spark-create-pyspark-batch](./serverless-spark-create-pyspark-batch.md)
diff --git a/docs/en/resources/tools/serverless-spark/serverless-spark-create-pyspark-batch.md b/docs/en/resources/tools/serverless-spark/serverless-spark-create-pyspark-batch.md
new file mode 100644
index 0000000000..cc58e38412
--- /dev/null
+++ b/docs/en/resources/tools/serverless-spark/serverless-spark-create-pyspark-batch.md
@@ -0,0 +1,107 @@
+---
+title: "serverless-spark-create-pyspark-batch"
+type: docs
+weight: 2
+description: >
+  A "serverless-spark-create-pyspark-batch" tool submits a Spark batch to run asynchronously.
+aliases:
+  - /resources/tools/serverless-spark-create-pyspark-batch
+---
+
+## About
+
+A `serverless-spark-create-pyspark-batch` tool submits a PySpark batch to a
+Google Cloud Serverless for Apache Spark source. The workload runs
+asynchronously and typically takes around a minute to start; its status can be
+polled using the [get batch](serverless-spark-get-batch.md) tool.
+
+It's compatible with the following sources:
+
+- [serverless-spark](../../sources/serverless-spark.md)
+
+`serverless-spark-create-pyspark-batch` accepts the following parameters:
+
+- **`mainFile`**: The path to the main Python file, as a gs://... URI.
+- **`args`**: Optional. A list of arguments passed to the main file.
+- **`version`**: Optional. The Serverless [runtime
+  version](https://docs.cloud.google.com/dataproc-serverless/docs/concepts/versions/dataproc-serverless-versions)
+  to execute with.
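+
+For example, the following invocation request runs a script with one argument
+on runtime version 2.2 (the `gs://` paths here are illustrative):
+
+```json
+{
+  "mainFile": "gs://my-bucket/scripts/main.py",
+  "args": ["gs://my-bucket/data/input.csv"],
+  "version": "2.2"
+}
+```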
+
+## Custom Configuration
+
+This tool supports custom
+[`runtimeConfig`](https://docs.cloud.google.com/dataproc-serverless/docs/reference/rest/v1/RuntimeConfig)
+and
+[`environmentConfig`](https://docs.cloud.google.com/dataproc-serverless/docs/reference/rest/v1/EnvironmentConfig)
+settings, which can be specified in a `tools.yaml` file. These configurations
+are parsed as YAML and passed to the Dataproc API.
+
+**Note:** If your project requires custom runtime or environment configuration,
+you must write a custom `tools.yaml`; you cannot use the `serverless-spark`
+prebuilt config.
+
+### Example `tools.yaml`
+
+```yaml
+tools:
+  create_pyspark_batch:
+    kind: "serverless-spark-create-pyspark-batch"
+    source: "my-serverless-spark-source"
+    runtimeConfig:
+      properties:
+        spark.driver.memory: "1024m"
+    environmentConfig:
+      executionConfig:
+        networkUri: "my-network"
+```
+
+## Response Format
+
+The response is an
+[operation](https://docs.cloud.google.com/dataproc-serverless/docs/reference/rest/v1/projects.locations.operations#resource:-operation)
+metadata JSON object, corresponding to
+[batch operation metadata](https://pkg.go.dev/cloud.google.com/go/dataproc/v2/apiv1/dataprocpb#BatchOperationMetadata).
+Example:
+
+```json
+{
+  "batch": "projects/myproject/locations/us-central1/batches/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
+  "batchUuid": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
+  "createTime": "2025-11-19T16:36:47.607119Z",
+  "description": "Batch",
+  "labels": {
+    "goog-dataproc-batch-uuid": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
+    "goog-dataproc-location": "us-central1"
+  },
+  "operationType": "BATCH",
+  "warnings": [
+    "No runtime version specified. Using the default runtime version."
+  ]
+}
+```
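+
+The `batch` field is the full resource name of the created batch; pass it to
+the [get batch](serverless-spark-get-batch.md) tool to poll the workload until
+it reaches a terminal state (`SUCCEEDED`, `FAILED`, or `CANCELLED`).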
+
+## Reference
+
+| **field**         | **type** | **required** | **description**                                                                                                                                          |
+| ----------------- | :------: | :----------: | --------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| kind              |  string  |     true     | Must be "serverless-spark-create-pyspark-batch".                                                                                                          |
+| source            |  string  |     true     | Name of the source the tool should use.                                                                                                                   |
+| description       |  string  |    false     | Description of the tool that is passed to the LLM.                                                                                                        |
+| runtimeConfig     |   map    |    false     | [Runtime config](https://docs.cloud.google.com/dataproc-serverless/docs/reference/rest/v1/RuntimeConfig) for all batches created with this tool.          |
+| environmentConfig |   map    |    false     | [Environment config](https://docs.cloud.google.com/dataproc-serverless/docs/reference/rest/v1/EnvironmentConfig) for all batches created with this tool.  |
+| authRequired      | string[] |    false     | List of auth services required to invoke this tool.                                                                                                      |
diff --git a/internal/prebuiltconfigs/tools/serverless-spark.yaml b/internal/prebuiltconfigs/tools/serverless-spark.yaml
index 7d78b18a95..287925eb04 100644
--- a/internal/prebuiltconfigs/tools/serverless-spark.yaml
+++ b/internal/prebuiltconfigs/tools/serverless-spark.yaml
@@ -28,9 +28,13 @@ tools:
   cancel_batch:
     kind: serverless-spark-cancel-batch
     source: serverless-spark-source
+  create_pyspark_batch:
+    kind: serverless-spark-create-pyspark-batch
+    source: serverless-spark-source
 
 toolsets:
   serverless_spark_tools:
     - list_batches
     - get_batch
     - cancel_batch
+    - create_pyspark_batch
diff --git a/internal/tools/serverlessspark/serverlesssparkcreatepysparkbatch/serverlesssparkcreatepysparkbatch.go b/internal/tools/serverlessspark/serverlesssparkcreatepysparkbatch/serverlesssparkcreatepysparkbatch.go
new file mode 100644
index 0000000000..454f9b6015
--- /dev/null
+++ b/internal/tools/serverlessspark/serverlesssparkcreatepysparkbatch/serverlesssparkcreatepysparkbatch.go
@@ -0,0 +1,263 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package serverlesssparkcreatepysparkbatch
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+
+	dataproc "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
+	"github.com/goccy/go-yaml"
+	"github.com/googleapis/genai-toolbox/internal/sources"
+	"github.com/googleapis/genai-toolbox/internal/sources/serverlessspark"
+	"github.com/googleapis/genai-toolbox/internal/tools"
+	"github.com/googleapis/genai-toolbox/internal/util/parameters"
+	"google.golang.org/protobuf/encoding/protojson"
+	"google.golang.org/protobuf/proto"
+)
+
+const kind = "serverless-spark-create-pyspark-batch"
+
+func init() {
+	if !tools.Register(kind, newConfig) {
+		panic(fmt.Sprintf("tool kind %q already registered", kind))
+	}
+}
+
+func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
+	// Use a temporary struct to decode the YAML, so that we can handle the proto
+	// conversion for RuntimeConfig and EnvironmentConfig.
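+	// The raw YAML values are marshaled to JSON first because protojson only
+	// decodes JSON; protojson also rejects unknown fields, so typos in these
+	// configs surface as unmarshal errors instead of being silently dropped.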
+	var ymlCfg struct {
+		Name              string   `yaml:"name"`
+		Kind              string   `yaml:"kind"`
+		Source            string   `yaml:"source"`
+		Description       string   `yaml:"description"`
+		RuntimeConfig     any      `yaml:"runtimeConfig"`
+		EnvironmentConfig any      `yaml:"environmentConfig"`
+		AuthRequired      []string `yaml:"authRequired"`
+	}
+
+	if err := decoder.DecodeContext(ctx, &ymlCfg); err != nil {
+		return nil, err
+	}
+
+	cfg := Config{
+		Name:         name,
+		Kind:         ymlCfg.Kind,
+		Source:       ymlCfg.Source,
+		Description:  ymlCfg.Description,
+		AuthRequired: ymlCfg.AuthRequired,
+	}
+
+	if ymlCfg.RuntimeConfig != nil {
+		rc := &dataproc.RuntimeConfig{}
+		jsonData, err := json.Marshal(ymlCfg.RuntimeConfig)
+		if err != nil {
+			return nil, fmt.Errorf("failed to marshal runtimeConfig: %w", err)
+		}
+		if err := protojson.Unmarshal(jsonData, rc); err != nil {
+			return nil, fmt.Errorf("failed to unmarshal runtimeConfig: %w", err)
+		}
+		cfg.RuntimeConfig = rc
+	}
+
+	if ymlCfg.EnvironmentConfig != nil {
+		ec := &dataproc.EnvironmentConfig{}
+		jsonData, err := json.Marshal(ymlCfg.EnvironmentConfig)
+		if err != nil {
+			return nil, fmt.Errorf("failed to marshal environmentConfig: %w", err)
+		}
+		if err := protojson.Unmarshal(jsonData, ec); err != nil {
+			return nil, fmt.Errorf("failed to unmarshal environmentConfig: %w", err)
+		}
+		cfg.EnvironmentConfig = ec
+	}
+
+	return cfg, nil
+}
+
+type Config struct {
+	Name              string                      `yaml:"name" validate:"required"`
+	Kind              string                      `yaml:"kind" validate:"required"`
+	Source            string                      `yaml:"source" validate:"required"`
+	Description       string                      `yaml:"description"`
+	RuntimeConfig     *dataproc.RuntimeConfig     `yaml:"runtimeConfig"`
+	EnvironmentConfig *dataproc.EnvironmentConfig `yaml:"environmentConfig"`
+	AuthRequired      []string                    `yaml:"authRequired"`
+}
+
+// validate interface
+var _ tools.ToolConfig = Config{}
+
+// ToolConfigKind returns the unique name for this tool.
+func (cfg Config) ToolConfigKind() string {
+	return kind
+}
+
+// Initialize creates a new Tool instance.
+func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
+	rawS, ok := srcs[cfg.Source]
+	if !ok {
+		return nil, fmt.Errorf("source %q not found", cfg.Source)
+	}
+
+	ds, ok := rawS.(*serverlessspark.Source)
+	if !ok {
+		return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, serverlessspark.SourceKind)
+	}
+
+	desc := cfg.Description
+	if desc == "" {
+		desc = "Creates a Serverless Spark (aka Dataproc Serverless) PySpark batch operation."
+	}
+
+	allParameters := parameters.Parameters{
+		parameters.NewStringParameterWithRequired("mainFile", "The path to the main Python file, as a gs://... URI.", true),
+		parameters.NewArrayParameterWithRequired("args", "Optional. A list of arguments passed to the main file.", false, parameters.NewStringParameter("arg", "An argument.")),
+		parameters.NewStringParameterWithRequired("version", "Optional. The Serverless runtime version to execute with.", false),
+	}
+	inputSchema, _ := allParameters.McpManifest()
+
+	mcpManifest := tools.McpManifest{
+		Name:        cfg.Name,
+		Description: desc,
+		InputSchema: inputSchema,
+	}
+
+	return &Tool{
+		Config:      cfg,
+		Source:      ds,
+		manifest:    tools.Manifest{Description: desc, Parameters: allParameters.Manifest()},
+		mcpManifest: mcpManifest,
+		Parameters:  allParameters,
+	}, nil
+}
+
+// Tool is the implementation of the tool.
+type Tool struct {
+	Config
+
+	Source *serverlessspark.Source
+
+	manifest    tools.Manifest
+	mcpManifest tools.McpManifest
+	Parameters  parameters.Parameters
+}
+
+// Invoke executes the tool's operation.
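+// It builds a PySpark batch from the invocation parameters, applies any
+// tool-level runtime and environment config, submits the batch, and returns
+// the long-running operation's metadata.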
+func (t *Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
+	client := t.Source.GetBatchControllerClient()
+
+	paramMap := params.AsMap()
+
+	mainFile := paramMap["mainFile"].(string)
+
+	batch := &dataproc.Batch{
+		BatchConfig: &dataproc.Batch_PysparkBatch{
+			PysparkBatch: &dataproc.PySparkBatch{
+				MainPythonFileUri: mainFile,
+			},
+		},
+	}
+
+	if args, ok := paramMap["args"].([]any); ok {
+		for _, arg := range args {
+			batch.GetPysparkBatch().Args = append(batch.GetPysparkBatch().Args, fmt.Sprintf("%v", arg))
+		}
+	}
+
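+	// Clone the tool-level configs so that per-invocation changes (such as the
+	// version override below) never mutate the shared config templates.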
+	if t.Config.RuntimeConfig != nil {
+		batch.RuntimeConfig = proto.Clone(t.Config.RuntimeConfig).(*dataproc.RuntimeConfig)
+	}
+
+	if t.Config.EnvironmentConfig != nil {
+		batch.EnvironmentConfig = proto.Clone(t.Config.EnvironmentConfig).(*dataproc.EnvironmentConfig)
+	}
+
+	if version, ok := paramMap["version"].(string); ok && version != "" {
+		if batch.RuntimeConfig == nil {
+			batch.RuntimeConfig = &dataproc.RuntimeConfig{}
+		}
+		batch.RuntimeConfig.Version = version
+	}
+
+	req := &dataproc.CreateBatchRequest{
+		Parent: fmt.Sprintf("projects/%s/locations/%s", t.Source.Project, t.Source.Location),
+		Batch:  batch,
+	}
+
+	op, err := client.CreateBatch(ctx, req)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create batch: %w", err)
+	}
+
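+	// CreateBatch returns a long-running operation; return its metadata
+	// immediately rather than waiting for completion, since callers poll the
+	// batch with the get batch tool.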
+	meta, err := op.Metadata()
+	if err != nil {
+		return nil, fmt.Errorf("failed to get create batch op metadata: %w", err)
+	}
+
+	jsonBytes, err := protojson.Marshal(meta)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal create batch op metadata to JSON: %w", err)
+	}
+
+	var result map[string]any
+	if err := json.Unmarshal(jsonBytes, &result); err != nil {
+		return nil, fmt.Errorf("failed to unmarshal create batch op metadata JSON: %w", err)
+	}
+
+	return result, nil
+}
+
+func (t *Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
+	return parameters.ParseParams(t.Parameters, data, claims)
+}
+
+func (t *Tool) Manifest() tools.Manifest {
+	return t.manifest
+}
+
+func (t *Tool) McpManifest() tools.McpManifest {
+	return t.mcpManifest
+}
+
+func (t *Tool) Authorized(services []string) bool {
+	return tools.IsAuthorized(t.AuthRequired, services)
+}
+
+func (t *Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
+	// Client OAuth not supported, rely on ADCs.
+	return false
+}
+
+func (t *Tool) ToConfig() tools.ToolConfig {
+	return t.Config
+}
+
+func (t *Tool) GetAuthTokenHeaderName() string {
+	return "Authorization"
+}
diff --git a/internal/tools/serverlessspark/serverlesssparkcreatepysparkbatch/serverlesssparkcreatepysparkbatch_test.go b/internal/tools/serverlessspark/serverlesssparkcreatepysparkbatch/serverlesssparkcreatepysparkbatch_test.go
new file mode 100644
index 0000000000..28384c2df4
--- /dev/null
+++ b/internal/tools/serverlessspark/serverlesssparkcreatepysparkbatch/serverlesssparkcreatepysparkbatch_test.go
@@ -0,0 +1,144 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package serverlesssparkcreatepysparkbatch_test
+
+import (
+	"strings"
+	"testing"
+
+	dataproc "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
+	"github.com/goccy/go-yaml"
+	"github.com/google/go-cmp/cmp"
+	"github.com/googleapis/genai-toolbox/internal/server"
+	"github.com/googleapis/genai-toolbox/internal/testutils"
+	"github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcreatepysparkbatch"
+	"google.golang.org/protobuf/testing/protocmp"
+)
+
+func TestParseFromYaml(t *testing.T) {
+	ctx, err := testutils.ContextWithNewLogger()
+	if err != nil {
+		t.Fatalf("unexpected error: %s", err)
+	}
+	tcs := []struct {
+		desc    string
+		in      string
+		want    server.ToolConfigs
+		wantErr string
+	}{
+		{
+			desc: "basic example",
+			in: `
+			tools:
+				example_tool:
+					kind: serverless-spark-create-pyspark-batch
+					source: my-instance
+					description: some description
+			`,
+			want: server.ToolConfigs{
+				"example_tool": serverlesssparkcreatepysparkbatch.Config{
+					Name:         "example_tool",
+					Kind:         "serverless-spark-create-pyspark-batch",
+					Source:       "my-instance",
+					Description:  "some description",
+					AuthRequired: []string{},
+				},
+			},
+		},
+		{
+			desc: "detailed config",
+			in: `
+			tools:
+				example_tool:
+					kind: serverless-spark-create-pyspark-batch
+					source: my-instance
+					description: some description
+					runtimeConfig:
+						properties:
+							"spark.driver.memory": "1g"
+					environmentConfig:
+						executionConfig:
+							networkUri: "my-network"
+			`,
+			want: server.ToolConfigs{
+				"example_tool": serverlesssparkcreatepysparkbatch.Config{
+					Name:        "example_tool",
+					Kind:        "serverless-spark-create-pyspark-batch",
+					Source:      "my-instance",
+					Description: "some description",
+					RuntimeConfig: &dataproc.RuntimeConfig{
+						Properties: map[string]string{"spark.driver.memory": "1g"},
+					},
+					EnvironmentConfig: &dataproc.EnvironmentConfig{
+						ExecutionConfig: &dataproc.ExecutionConfig{
+							Network: &dataproc.ExecutionConfig_NetworkUri{NetworkUri: "my-network"},
+						},
+					},
+					AuthRequired: []string{},
+				},
+			},
+		},
+		{
+			desc: "invalid runtime config",
+			in: `
+			tools:
+				example_tool:
+					kind: serverless-spark-create-pyspark-batch
+					source: my-instance
+					description: some description
+					runtimeConfig:
+						invalidField: true
+			`,
+			wantErr: "unmarshal runtimeConfig",
+		},
+		{
+			desc: "invalid environment config",
+			in: `
+			tools:
+				example_tool:
+					kind: serverless-spark-create-pyspark-batch
+					source: my-instance
+					description: some description
+					environmentConfig:
+						invalidField: true
+			`,
+			wantErr: "unmarshal environmentConfig",
+		},
+	}
+	for _, tc := range tcs {
+		t.Run(tc.desc, func(t *testing.T) {
+			got := struct {
+				Tools server.ToolConfigs `yaml:"tools"`
+			}{}
+			err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got, yaml.Strict())
+			if tc.wantErr != "" {
+				if err == nil {
+					t.Fatal("expected error, got nil")
+				}
+				if !strings.Contains(err.Error(), tc.wantErr) {
+					t.Fatalf("expected error to contain %q, got %q", tc.wantErr, err)
+				}
+				return
+			}
+			if err != nil {
+				t.Fatalf("unable to unmarshal: %s", err)
+			}
+
+			if diff := cmp.Diff(tc.want, got.Tools, protocmp.Transform()); diff != "" {
+				t.Fatalf("incorrect parse: diff %v", diff)
+			}
+		})
+	}
+}
diff --git a/tests/serverlessspark/serverless_spark_integration_test.go b/tests/serverlessspark/serverless_spark_integration_test.go
index f2fa106f8a..cfd5079b2a 100644
--- a/tests/serverlessspark/serverless_spark_integration_test.go
+++ b/tests/serverlessspark/serverless_spark_integration_test.go
@@ -66,7 +66,7 @@ func getServerlessSparkVars(t *testing.T) map[string]any {
 func TestServerlessSparkToolEndpoints(t *testing.T) {
 	sourceConfig := getServerlessSparkVars(t)
 
-	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
 	defer cancel()
 
 	toolsFile := map[string]any{
@@ -107,6 +107,30 @@ func TestServerlessSparkToolEndpoints(t *testing.T) {
 			"source":       "my-spark",
 			"authRequired": []string{"my-google-auth"},
 		},
+		"create-pyspark-batch": map[string]any{
+			"kind":   "serverless-spark-create-pyspark-batch",
+			"source": "my-spark",
+			"environmentConfig": map[string]any{
+				"executionConfig": map[string]any{
+					"serviceAccount": serverlessSparkServiceAccount,
+				},
+			},
+		},
+		"create-pyspark-batch-2-3": map[string]any{
+			"kind":          "serverless-spark-create-pyspark-batch",
+			"source":        "my-spark",
+			"runtimeConfig": map[string]any{"version": "2.3"},
+			"environmentConfig": map[string]any{
+				"executionConfig": map[string]any{
+					"serviceAccount": serverlessSparkServiceAccount,
+				},
+			},
+		},
+		"create-pyspark-batch-with-auth": map[string]any{
+			"kind":         "serverless-spark-create-pyspark-batch",
+			"source":       "my-spark",
+			"authRequired": []string{"my-google-auth"},
+		},
 	},
 }
 
@@ -220,6 +244,96 @@ func TestServerlessSparkToolEndpoints(t *testing.T) {
 		})
 	})
 
+	t.Run("create-pyspark-batch", func(t *testing.T) {
+		t.Parallel()
+
+		t.Run("success", func(t *testing.T) {
+			t.Parallel()
+			piPy := "file:///usr/lib/spark/examples/src/main/python/pi.py"
+			tcs := []struct {
+				name           string
+				toolName       string
+				request        map[string]any
+				waitForSuccess bool
+				validate       func(t *testing.T, b *dataprocpb.Batch)
+			}{
+				{
+					name:           "no params",
+					toolName:       "create-pyspark-batch",
+					waitForSuccess: true,
+					request:        map[string]any{"mainFile": piPy},
+				},
+				// Tests below are just verifying options are set correctly on created batches;
+				// they don't need to wait for success.
+				{
+					name:     "with arg",
+					toolName: "create-pyspark-batch",
+					request:  map[string]any{"mainFile": piPy, "args": []string{"100"}},
+					validate: func(t *testing.T, b *dataprocpb.Batch) {
+						if !cmp.Equal(b.GetPysparkBatch().Args, []string{"100"}) {
+							t.Errorf("unexpected args: got %v, want %v", b.GetPysparkBatch().Args, []string{"100"})
+						}
+					},
+				},
+				{
+					name:     "version",
+					toolName: "create-pyspark-batch",
+					request:  map[string]any{"mainFile": piPy, "version": "2.2"},
+					validate: func(t *testing.T, b *dataprocpb.Batch) {
+						v := b.GetRuntimeConfig().GetVersion()
+						if v != "2.2" {
+							t.Errorf("unexpected version: got %v, want 2.2", v)
+						}
+					},
+				},
+				{
+					name:     "version param overrides tool",
+					toolName: "create-pyspark-batch-2-3",
+					request:  map[string]any{"mainFile": piPy, "version": "2.2"},
+					validate: func(t *testing.T, b *dataprocpb.Batch) {
+						v := b.GetRuntimeConfig().GetVersion()
+						if v != "2.2" {
+							t.Errorf("unexpected version: got %v, want 2.2", v)
+						}
+					},
+				},
+			}
+			for _, tc := range tcs {
+				t.Run(tc.name, func(t *testing.T) {
+					t.Parallel()
+					runCreatePysparkBatchTest(t, client, ctx, tc.toolName, tc.request, tc.waitForSuccess, tc.validate)
+				})
+			}
+		})
+
+		t.Run("auth", func(t *testing.T) {
+			t.Parallel()
+			// Batch creation succeeds even with an invalid main file, but will fail quickly once running.
+			runAuthTest(t, "create-pyspark-batch-with-auth", map[string]any{"mainFile": "file:///placeholder"}, http.StatusOK)
+		})
+
+		t.Run("errors", func(t *testing.T) {
+			t.Parallel()
+			tcs := []struct {
+				name    string
+				request map[string]any
+				wantMsg string
+			}{
+				{
+					name:    "missing main file",
+					request: map[string]any{},
+					wantMsg: "parameter \\\"mainFile\\\" is required",
+				},
+			}
+			for _, tc := range tcs {
+				t.Run(tc.name, func(t *testing.T) {
+					t.Parallel()
+					testError(t, "create-pyspark-batch", tc.request, http.StatusBadRequest, tc.wantMsg)
+				})
+			}
+		})
+	})
+
 	t.Run("cancel-batch", func(t *testing.T) {
 		t.Parallel()
 		t.Run("success", func(t *testing.T) {
@@ -299,13 +413,16 @@ func TestServerlessSparkToolEndpoints(t *testing.T) {
 }
 
 func waitForBatch(t *testing.T, client *dataproc.BatchControllerClient, parentCtx context.Context, batch string, desiredStates []dataprocpb.Batch_State, timeout time.Duration) {
+	t.Logf("waiting %s for batch %s to reach one of %v", timeout, batch, desiredStates)
 	ctx, cancel := context.WithTimeout(parentCtx, timeout)
 	defer cancel()
 
+	start := time.Now()
+	lastLog := start
 	for {
 		select {
 		case <-ctx.Done():
-			t.Fatalf("timed out waiting for batch %s to reach one of states %v", batch, desiredStates)
+			t.Fatalf("timed out waiting for batch %s to reach one of %v", batch, desiredStates)
 		default:
 		}
 
@@ -315,12 +432,18 @@ func waitForBatch(t *testing.T, client *dataproc.BatchControllerClient, parentCt
 			t.Fatalf("failed to get batch %s: %v", batch, err)
 		}
 
+		now := time.Now()
+		if now.Sub(lastLog) >= 30*time.Second {
+			t.Logf("%s: batch %s is in state %s after %s", t.Name(), batch.Name, batch.State, now.Sub(start))
+			lastLog = now
+		}
+
 		if slices.Contains(desiredStates, batch.State) {
 			return
 		}
 
 		if batch.State == dataprocpb.Batch_FAILED || batch.State == dataprocpb.Batch_CANCELLED || batch.State == dataprocpb.Batch_SUCCEEDED {
-			t.Fatalf("batch op %s is in a terminal state %s, but wanted one of %v. State message: %s", batch, batch.State, desiredStates, batch.StateMessage)
+			t.Fatalf("batch op %s is in a terminal state %s, but wanted one of %v. State message: %s", batch.Name, batch.State, desiredStates, batch.StateMessage)
 		}
 		time.Sleep(2 * time.Second)
 	}
@@ -362,7 +485,7 @@
 	}
 
 	// Wait for the batch to become at least PENDING; it typically takes >10s to go from PENDING to
-	// RUNNING, giving us plenty of time to cancel it before it completes.
+	// RUNNING, giving the cancel batch tests plenty of time to cancel it before it completes.
 	waitForBatch(t, client, ctx, meta.Batch, []dataprocpb.Batch_State{dataprocpb.Batch_PENDING, dataprocpb.Batch_RUNNING}, 1*time.Minute)
 	return meta.Batch
 }
@@ -615,6 +738,54 @@
 	}
 }
 
+func runCreatePysparkBatchTest(
+	t *testing.T,
+	client *dataproc.BatchControllerClient,
+	ctx context.Context,
+	toolName string,
+	request map[string]any,
+	waitForSuccess bool,
+	validate func(t *testing.T, b *dataprocpb.Batch),
+) {
+	resp, err := invokeTool(toolName, request, nil)
+	if err != nil {
+		t.Fatalf("invokeTool failed: %v", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		bodyBytes, _ := io.ReadAll(resp.Body)
+		t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
+	}
+
+	var body map[string]any
+	if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
+		t.Fatalf("error parsing response body: %v", err)
+	}
+
+	result, ok := body["result"].(string)
+	if !ok {
+		t.Fatalf("unable to find result in response body")
+	}
+
+	var meta dataprocpb.BatchOperationMetadata
+	if err := json.Unmarshal([]byte(result), &meta); err != nil {
+		t.Fatalf("failed to unmarshal result: %v", err)
+	}
+
+	if validate != nil {
+		b, err := client.GetBatch(ctx, &dataprocpb.GetBatchRequest{Name: meta.Batch})
+		if err != nil {
+			t.Fatalf("failed to get batch: %s", err)
+		}
+		validate(t, b)
+	}
+
+	if waitForSuccess {
+		waitForBatch(t, client, ctx, meta.Batch, []dataprocpb.Batch_State{dataprocpb.Batch_SUCCEEDED}, 5*time.Minute)
+	}
+}
+
 func testError(t *testing.T, toolName string, request map[string]any, wantCode int, wantMsg string) {
 	resp, err := invokeTool(toolName, request, nil)
 	if err != nil {