mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-02-11 07:35:05 -05:00
This PR introduces a significant update to the Toolbox configuration file format, which is one of the primary **breaking changes** required for the implementation of the Advanced Control Plane. # Summary of Changes The configuration schema has been updated to enforce resource isolation and facilitate atomic, incremental updates. * Resource Isolation: Resource definitions are now separated into individual blocks, using a distinct structure for each resource type (Source, Tool, Toolset, etc.). This improves readability, management, and auditing of configuration files. * Field Name Modification: Internal field names have been modified to align with declarative methodologies. Specifically, the configuration now separates kind (general resource type, e.g., Source) from type (specific implementation, e.g., Postgres). # User Impact Existing tools.yaml configuration files are now in an outdated format. Users must eventually update their files to the new YAML format. # Mitigation & Compatibility Backward compatibility is maintained during this transition to ensure no immediate user action is required for existing files. * Immediate Backward Compatibility: The source code includes a pre-processing layer that automatically detects outdated configuration files (v1 format) and converts them to the new v2 format under the hood. * [COMING SOON] Migration Support: The new toolbox migrate subcommand will be introduced to allow users to automatically convert their old configuration files to the latest format. # Example Example for config file v2: ``` kind: sources name: my-pg-instance type: cloud-sql-postgres project: my-project region: my-region instance: my-instance database: my_db user: my_user password: my_pass --- kind: authServices name: my-google-auth type: google clientId: testing-id --- kind: tools name: example_tool type: postgres-sql source: my-pg-instance description: some description statement: SELECT * FROM SQL_STATEMENT; parameters: - name: country type: string description: some description --- kind: tools name: example_tool_2 type: postgres-sql source: my-pg-instance description: returning the number one statement: SELECT 1; --- kind: toolsets name: example_toolset tools: - example_tool ``` --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Averi Kitsch <akitsch@google.com>
295 lines
8.3 KiB
Go
295 lines
8.3 KiB
Go
// Copyright 2025 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package serverlessspark
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"time"
|
|
|
|
dataproc "cloud.google.com/go/dataproc/v2/apiv1"
|
|
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
|
|
longrunning "cloud.google.com/go/longrunning/autogen"
|
|
"cloud.google.com/go/longrunning/autogen/longrunningpb"
|
|
"github.com/goccy/go-yaml"
|
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
|
"github.com/googleapis/genai-toolbox/internal/util"
|
|
"go.opentelemetry.io/otel/trace"
|
|
"google.golang.org/api/iterator"
|
|
"google.golang.org/api/option"
|
|
"google.golang.org/protobuf/encoding/protojson"
|
|
)
|
|
|
|
const SourceType string = "serverless-spark"
|
|
|
|
// validate interface
|
|
var _ sources.SourceConfig = Config{}
|
|
|
|
func init() {
|
|
if !sources.Register(SourceType, newConfig) {
|
|
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
|
}
|
|
}
|
|
|
|
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) {
|
|
actual := Config{Name: name}
|
|
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
|
return nil, err
|
|
}
|
|
return actual, nil
|
|
}
|
|
|
|
type Config struct {
|
|
Name string `yaml:"name" validate:"required"`
|
|
Type string `yaml:"type" validate:"required"`
|
|
Project string `yaml:"project" validate:"required"`
|
|
Location string `yaml:"location" validate:"required"`
|
|
}
|
|
|
|
func (r Config) SourceConfigType() string {
|
|
return SourceType
|
|
}
|
|
|
|
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
|
ua, err := util.UserAgentFromContext(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error in User Agent retrieval: %s", err)
|
|
}
|
|
endpoint := fmt.Sprintf("%s-dataproc.googleapis.com:443", r.Location)
|
|
client, err := dataproc.NewBatchControllerClient(ctx, option.WithEndpoint(endpoint), option.WithUserAgent(ua))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create dataproc client: %w", err)
|
|
}
|
|
opsClient, err := longrunning.NewOperationsClient(ctx, option.WithEndpoint(endpoint), option.WithUserAgent(ua))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create longrunning client: %w", err)
|
|
}
|
|
|
|
s := &Source{
|
|
Config: r,
|
|
Client: client,
|
|
OpsClient: opsClient,
|
|
}
|
|
return s, nil
|
|
}
|
|
|
|
var _ sources.Source = &Source{}
|
|
|
|
type Source struct {
|
|
Config
|
|
Client *dataproc.BatchControllerClient
|
|
OpsClient *longrunning.OperationsClient
|
|
}
|
|
|
|
func (s *Source) SourceType() string {
|
|
return SourceType
|
|
}
|
|
|
|
func (s *Source) ToConfig() sources.SourceConfig {
|
|
return s.Config
|
|
}
|
|
|
|
func (s *Source) GetProject() string {
|
|
return s.Project
|
|
}
|
|
|
|
func (s *Source) GetLocation() string {
|
|
return s.Location
|
|
}
|
|
|
|
func (s *Source) GetBatchControllerClient() *dataproc.BatchControllerClient {
|
|
return s.Client
|
|
}
|
|
|
|
func (s *Source) GetOperationsClient(ctx context.Context) (*longrunning.OperationsClient, error) {
|
|
return s.OpsClient, nil
|
|
}
|
|
|
|
func (s *Source) Close() error {
|
|
if err := s.Client.Close(); err != nil {
|
|
return err
|
|
}
|
|
if err := s.OpsClient.Close(); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Source) CancelOperation(ctx context.Context, operation string) (any, error) {
|
|
req := &longrunningpb.CancelOperationRequest{
|
|
Name: fmt.Sprintf("projects/%s/locations/%s/operations/%s", s.GetProject(), s.GetLocation(), operation),
|
|
}
|
|
client, err := s.GetOperationsClient(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get operations client: %w", err)
|
|
}
|
|
err = client.CancelOperation(ctx, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to cancel operation: %w", err)
|
|
}
|
|
return fmt.Sprintf("Cancelled [%s].", operation), nil
|
|
}
|
|
|
|
func (s *Source) CreateBatch(ctx context.Context, batch *dataprocpb.Batch) (map[string]any, error) {
|
|
req := &dataprocpb.CreateBatchRequest{
|
|
Parent: fmt.Sprintf("projects/%s/locations/%s", s.GetProject(), s.GetLocation()),
|
|
Batch: batch,
|
|
}
|
|
|
|
client := s.GetBatchControllerClient()
|
|
op, err := client.CreateBatch(ctx, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create batch: %w", err)
|
|
}
|
|
meta, err := op.Metadata()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get create batch op metadata: %w", err)
|
|
}
|
|
|
|
projectID, location, batchID, err := ExtractBatchDetails(meta.GetBatch())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error extracting batch details from name %q: %v", meta.GetBatch(), err)
|
|
}
|
|
consoleUrl := BatchConsoleURL(projectID, location, batchID)
|
|
logsUrl := BatchLogsURL(projectID, location, batchID, meta.GetCreateTime().AsTime(), time.Time{})
|
|
|
|
wrappedResult := map[string]any{
|
|
"opMetadata": meta,
|
|
"consoleUrl": consoleUrl,
|
|
"logsUrl": logsUrl,
|
|
}
|
|
return wrappedResult, nil
|
|
}
|
|
|
|
// ListBatchesResponse is the response from the list batches API.
|
|
type ListBatchesResponse struct {
|
|
Batches []Batch `json:"batches"`
|
|
NextPageToken string `json:"nextPageToken"`
|
|
}
|
|
|
|
// Batch represents a single batch job.
|
|
type Batch struct {
|
|
Name string `json:"name"`
|
|
UUID string `json:"uuid"`
|
|
State string `json:"state"`
|
|
Creator string `json:"creator"`
|
|
CreateTime string `json:"createTime"`
|
|
Operation string `json:"operation"`
|
|
ConsoleURL string `json:"consoleUrl"`
|
|
LogsURL string `json:"logsUrl"`
|
|
}
|
|
|
|
func (s *Source) ListBatches(ctx context.Context, ps *int, pt, filter string) (any, error) {
|
|
client := s.GetBatchControllerClient()
|
|
parent := fmt.Sprintf("projects/%s/locations/%s", s.GetProject(), s.GetLocation())
|
|
req := &dataprocpb.ListBatchesRequest{
|
|
Parent: parent,
|
|
OrderBy: "create_time desc",
|
|
}
|
|
|
|
if ps != nil {
|
|
req.PageSize = int32(*ps)
|
|
}
|
|
if pt != "" {
|
|
req.PageToken = pt
|
|
}
|
|
if filter != "" {
|
|
req.Filter = filter
|
|
}
|
|
|
|
it := client.ListBatches(ctx, req)
|
|
pager := iterator.NewPager(it, int(req.PageSize), req.PageToken)
|
|
|
|
var batchPbs []*dataprocpb.Batch
|
|
nextPageToken, err := pager.NextPage(&batchPbs)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to list batches: %w", err)
|
|
}
|
|
|
|
batches, err := ToBatches(batchPbs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return ListBatchesResponse{Batches: batches, NextPageToken: nextPageToken}, nil
|
|
}
|
|
|
|
// ToBatches converts a slice of protobuf Batch messages to a slice of Batch structs.
|
|
func ToBatches(batchPbs []*dataprocpb.Batch) ([]Batch, error) {
|
|
batches := make([]Batch, 0, len(batchPbs))
|
|
for _, batchPb := range batchPbs {
|
|
consoleUrl, err := BatchConsoleURLFromProto(batchPb)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error generating console url: %v", err)
|
|
}
|
|
logsUrl, err := BatchLogsURLFromProto(batchPb)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error generating logs url: %v", err)
|
|
}
|
|
batch := Batch{
|
|
Name: batchPb.Name,
|
|
UUID: batchPb.Uuid,
|
|
State: batchPb.State.Enum().String(),
|
|
Creator: batchPb.Creator,
|
|
CreateTime: batchPb.CreateTime.AsTime().Format(time.RFC3339),
|
|
Operation: batchPb.Operation,
|
|
ConsoleURL: consoleUrl,
|
|
LogsURL: logsUrl,
|
|
}
|
|
batches = append(batches, batch)
|
|
}
|
|
return batches, nil
|
|
}
|
|
|
|
func (s *Source) GetBatch(ctx context.Context, name string) (map[string]any, error) {
|
|
client := s.GetBatchControllerClient()
|
|
req := &dataprocpb.GetBatchRequest{
|
|
Name: fmt.Sprintf("projects/%s/locations/%s/batches/%s", s.GetProject(), s.GetLocation(), name),
|
|
}
|
|
|
|
batchPb, err := client.GetBatch(ctx, req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get batch: %w", err)
|
|
}
|
|
|
|
jsonBytes, err := protojson.Marshal(batchPb)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal batch to JSON: %w", err)
|
|
}
|
|
|
|
var result map[string]any
|
|
if err := json.Unmarshal(jsonBytes, &result); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal batch JSON: %w", err)
|
|
}
|
|
|
|
consoleUrl, err := BatchConsoleURLFromProto(batchPb)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error generating console url: %v", err)
|
|
}
|
|
logsUrl, err := BatchLogsURLFromProto(batchPb)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error generating logs url: %v", err)
|
|
}
|
|
|
|
wrappedResult := map[string]any{
|
|
"consoleUrl": consoleUrl,
|
|
"logsUrl": logsUrl,
|
|
"batch": result,
|
|
}
|
|
|
|
return wrappedResult, nil
|
|
}
|