From 0f27f956c7d85e6f8a3f4c4589cf8020bc1d3d3b Mon Sep 17 00:00:00 2001 From: Yuan Teoh <45984206+Yuan325@users.noreply.github.com> Date: Tue, 30 Dec 2025 21:43:09 -0800 Subject: [PATCH] refactor(sources/bigquery): move source implementation in Invoke() function to Source (#2242) Move source-related queries from `Invoke()` function into Source. This is an effort to generalizing tools to work with any Source that implements a specific interface. This will provide a better segregation of the roles for Tools vs Source. Tool's role will be limited to the following: * Resolve any pre-implementation steps or parameters (e.g. template parameters) * Retrieving Source * Calling the source's implementation --- internal/sources/bigquery/bigquery.go | 130 ++++++++++++++++++ internal/sources/bigquery/bigquery_test.go | 104 ++++++++++++++ .../bigqueryanalyzecontribution.go | 60 +------- .../bigquerycommon/conversion_test.go | 123 ----------------- .../tools/bigquery/bigquerycommon/util.go | 53 ------- .../bigqueryexecutesql/bigqueryexecutesql.go | 72 +--------- .../bigqueryforecast/bigqueryforecast.go | 63 ++------- .../bigquerygetdatasetinfo.go | 20 +-- .../bigquerygettableinfo.go | 20 +-- .../bigquerylistdatasetids.go | 19 +-- .../bigquerylisttableids.go | 19 +-- .../tools/bigquery/bigquerysql/bigquerysql.go | 82 ++--------- tests/bigquery/bigquery_integration_test.go | 4 +- 13 files changed, 288 insertions(+), 481 deletions(-) delete mode 100644 internal/tools/bigquery/bigquerycommon/conversion_test.go diff --git a/internal/sources/bigquery/bigquery.go b/internal/sources/bigquery/bigquery.go index 3b2d823dc1..a0e170e144 100644 --- a/internal/sources/bigquery/bigquery.go +++ b/internal/sources/bigquery/bigquery.go @@ -17,7 +17,9 @@ package bigquery import ( "context" "fmt" + "math/big" "net/http" + "reflect" "strings" "sync" "time" @@ -26,13 +28,16 @@ import ( dataplexapi "cloud.google.com/go/dataplex/apiv1" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/orderedmap" "go.opentelemetry.io/otel/trace" "golang.org/x/oauth2" "golang.org/x/oauth2/google" bigqueryrestapi "google.golang.org/api/bigquery/v2" "google.golang.org/api/googleapi" "google.golang.org/api/impersonate" + "google.golang.org/api/iterator" "google.golang.org/api/option" ) @@ -483,6 +488,131 @@ func (s *Source) lazyInitDataplexClient(ctx context.Context, tracer trace.Tracer } } +func (s *Source) RetrieveClientAndService(accessToken tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) { + bqClient := s.BigQueryClient() + restService := s.BigQueryRestService() + + // Initialize new client if using user OAuth token + if s.UseClientAuthorization() { + tokenStr, err := accessToken.ParseBearerToken() + if err != nil { + return nil, nil, fmt.Errorf("error parsing access token: %w", err) + } + bqClient, restService, err = s.BigQueryClientCreator()(tokenStr, true) + if err != nil { + return nil, nil, fmt.Errorf("error creating client from OAuth access token: %w", err) + } + } + return bqClient, restService, nil +} + +func (s *Source) RunSQL(ctx context.Context, bqClient *bigqueryapi.Client, statement, statementType string, params []bigqueryapi.QueryParameter, connProps []*bigqueryapi.ConnectionProperty) (any, error) { + query := bqClient.Query(statement) + query.Location = bqClient.Location + if params != nil { + query.Parameters = params + } + if connProps != nil { + query.ConnectionProperties = connProps + } + + // This block handles SELECT statements, which return a row set. + // We iterate through the results, convert each row into a map of + // column names to values, and return the collection of rows. + job, err := query.Run(ctx) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + it, err := job.Read(ctx) + if err != nil { + return nil, fmt.Errorf("unable to read query results: %w", err) + } + + var out []any + for { + var val []bigqueryapi.Value + err = it.Next(&val) + if err == iterator.Done { + break + } + if err != nil { + return nil, fmt.Errorf("unable to iterate through query results: %w", err) + } + schema := it.Schema + row := orderedmap.Row{} + for i, field := range schema { + row.Add(field.Name, NormalizeValue(val[i])) + } + out = append(out, row) + } + // If the query returned any rows, return them directly. + if len(out) > 0 { + return out, nil + } + + // This handles the standard case for a SELECT query that successfully + // executes but returns zero rows. + if statementType == "SELECT" { + return "The query returned 0 rows.", nil + } + // This is the fallback for a successful query that doesn't return content. + // In most cases, this will be for DML/DDL statements like INSERT, UPDATE, CREATE, etc. + // However, it is also possible that this was a query that was expected to return rows + // but returned none, a case that we cannot distinguish here. + return "Query executed successfully and returned no content.", nil +} + +// NormalizeValue converts BigQuery specific types to standard JSON-compatible types. +// Specifically, it handles *big.Rat (used for NUMERIC/BIGNUMERIC) by converting +// them to decimal strings with up to 38 digits of precision, trimming trailing zeros. +// It recursively handles slices (arrays) and maps (structs) using reflection. +func NormalizeValue(v any) any { + if v == nil { + return nil + } + + // Handle *big.Rat specifically. + if rat, ok := v.(*big.Rat); ok { + // Convert big.Rat to a decimal string. + // Use a precision of 38 digits (enough for BIGNUMERIC and NUMERIC) + // and trim trailing zeros to match BigQuery's behavior. + s := rat.FloatString(38) + if strings.Contains(s, ".") { + s = strings.TrimRight(s, "0") + s = strings.TrimRight(s, ".") + } + return s + } + + // Use reflection for slices and maps to handle various underlying types. + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Slice, reflect.Array: + // Preserve []byte as is, so json.Marshal encodes it as Base64 string (BigQuery BYTES behavior). + if rv.Type().Elem().Kind() == reflect.Uint8 { + return v + } + newSlice := make([]any, rv.Len()) + for i := 0; i < rv.Len(); i++ { + newSlice[i] = NormalizeValue(rv.Index(i).Interface()) + } + return newSlice + case reflect.Map: + // Ensure keys are strings to produce a JSON-compatible map. + if rv.Type().Key().Kind() != reflect.String { + return v + } + newMap := make(map[string]any, rv.Len()) + iter := rv.MapRange() + for iter.Next() { + newMap[iter.Key().String()] = NormalizeValue(iter.Value().Interface()) + } + return newMap + } + + return v +} + func initBigQueryConnection( ctx context.Context, tracer trace.Tracer, diff --git a/internal/sources/bigquery/bigquery_test.go b/internal/sources/bigquery/bigquery_test.go index 58970a2ddb..bca167ff57 100644 --- a/internal/sources/bigquery/bigquery_test.go +++ b/internal/sources/bigquery/bigquery_test.go @@ -15,6 +15,8 @@ package bigquery_test import ( + "math/big" + "reflect" "testing" yaml "github.com/goccy/go-yaml" @@ -195,3 +197,105 @@ func TestFailParseFromYaml(t *testing.T) { }) } } + +func TestNormalizeValue(t *testing.T) { + tests := []struct { + name string + input any + expected any + }{ + { + name: "big.Rat 1/3 (NUMERIC scale 9)", + input: new(big.Rat).SetFrac64(1, 3), // 0.33333333333... + expected: "0.33333333333333333333333333333333333333", // FloatString(38) + }, + { + name: "big.Rat 19/2 (9.5)", + input: new(big.Rat).SetFrac64(19, 2), + expected: "9.5", + }, + { + name: "big.Rat 12341/10 (1234.1)", + input: new(big.Rat).SetFrac64(12341, 10), + expected: "1234.1", + }, + { + name: "big.Rat 10/1 (10)", + input: new(big.Rat).SetFrac64(10, 1), + expected: "10", + }, + { + name: "string", + input: "hello", + expected: "hello", + }, + { + name: "int", + input: 123, + expected: 123, + }, + { + name: "nested slice of big.Rat", + input: []any{ + new(big.Rat).SetFrac64(19, 2), + new(big.Rat).SetFrac64(1, 4), + }, + expected: []any{"9.5", "0.25"}, + }, + { + name: "nested map of big.Rat", + input: map[string]any{ + "val1": new(big.Rat).SetFrac64(19, 2), + "val2": new(big.Rat).SetFrac64(1, 2), + }, + expected: map[string]any{ + "val1": "9.5", + "val2": "0.5", + }, + }, + { + name: "complex nested structure", + input: map[string]any{ + "list": []any{ + map[string]any{ + "rat": new(big.Rat).SetFrac64(3, 2), + }, + }, + }, + expected: map[string]any{ + "list": []any{ + map[string]any{ + "rat": "1.5", + }, + }, + }, + }, + { + name: "slice of *big.Rat", + input: []*big.Rat{ + new(big.Rat).SetFrac64(19, 2), + new(big.Rat).SetFrac64(1, 4), + }, + expected: []any{"9.5", "0.25"}, + }, + { + name: "slice of strings", + input: []string{"a", "b"}, + expected: []any{"a", "b"}, + }, + { + name: "byte slice (BYTES)", + input: []byte("hello"), + expected: []byte("hello"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := bigquery.NormalizeValue(tt.input) + if !reflect.DeepEqual(got, tt.expected) { + t.Errorf("NormalizeValue() = %v, want %v", got, tt.expected) + } + }) + } +} diff --git a/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go b/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go index 61b90a1d11..f3312acfb3 100644 --- a/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go +++ b/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go @@ -28,7 +28,6 @@ import ( bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" bigqueryrestapi "google.golang.org/api/bigquery/v2" - "google.golang.org/api/iterator" ) const kind string = "bigquery-analyze-contribution" @@ -49,12 +48,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { BigQueryClient() *bigqueryapi.Client - BigQueryRestService() *bigqueryrestapi.Service - BigQueryClientCreator() bigqueryds.BigqueryClientCreator UseClientAuthorization() bool IsDatasetAllowed(projectID, datasetID string) bool BigQueryAllowedDatasets() []string BigQuerySession() bigqueryds.BigQuerySessionProvider + RetrieveClientAndService(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) + RunSQL(context.Context, *bigqueryapi.Client, string, string, []bigqueryapi.QueryParameter, []*bigqueryapi.ConnectionProperty) (any, error) } type Config struct { @@ -166,19 +165,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to cast input_data parameter %s", paramsMap["input_data"]) } - bqClient := source.BigQueryClient() - restService := source.BigQueryRestService() - - // Initialize new client if using user OAuth token - if source.UseClientAuthorization() { - tokenStr, err := accessToken.ParseBearerToken() - if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) - } - bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true) - if err != nil { - return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) - } + bqClient, restService, err := source.RetrieveClientAndService(accessToken) + if err != nil { + return nil, err } modelID := fmt.Sprintf("contribution_analysis_model_%s", strings.ReplaceAll(uuid.New().String(), "-", "")) @@ -314,43 +303,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } getInsightsSQL := fmt.Sprintf("SELECT * FROM ML.GET_INSIGHTS(MODEL %s)", modelID) - - getInsightsQuery := bqClient.Query(getInsightsSQL) - getInsightsQuery.ConnectionProperties = []*bigqueryapi.ConnectionProperty{{Key: "session_id", Value: sessionID}} - - job, err := getInsightsQuery.Run(ctx) - if err != nil { - return nil, fmt.Errorf("failed to execute get insights query: %w", err) - } - it, err := job.Read(ctx) - if err != nil { - return nil, fmt.Errorf("unable to read query results: %w", err) - } - - var out []any - for { - var row map[string]bigqueryapi.Value - err := it.Next(&row) - if err == iterator.Done { - break - } - if err != nil { - return nil, fmt.Errorf("failed to iterate through query results: %w", err) - } - vMap := make(map[string]any) - for key, value := range row { - vMap[key] = value - } - out = append(out, vMap) - } - - if len(out) > 0 { - return out, nil - } - - // This handles the standard case for a SELECT query that successfully - // executes but returns zero rows. - return "The query returned 0 rows.", nil + connProps := []*bigqueryapi.ConnectionProperty{{Key: "session_id", Value: sessionID}} + return source.RunSQL(ctx, bqClient, getInsightsSQL, "SELECT", nil, connProps) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/bigquery/bigquerycommon/conversion_test.go b/internal/tools/bigquery/bigquerycommon/conversion_test.go deleted file mode 100644 index c735d0ebe1..0000000000 --- a/internal/tools/bigquery/bigquerycommon/conversion_test.go +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright 2025 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package bigquerycommon - -import ( - "math/big" - "reflect" - "testing" -) - -func TestNormalizeValue(t *testing.T) { - tests := []struct { - name string - input any - expected any - }{ - { - name: "big.Rat 1/3 (NUMERIC scale 9)", - input: new(big.Rat).SetFrac64(1, 3), // 0.33333333333... - expected: "0.33333333333333333333333333333333333333", // FloatString(38) - }, - { - name: "big.Rat 19/2 (9.5)", - input: new(big.Rat).SetFrac64(19, 2), - expected: "9.5", - }, - { - name: "big.Rat 12341/10 (1234.1)", - input: new(big.Rat).SetFrac64(12341, 10), - expected: "1234.1", - }, - { - name: "big.Rat 10/1 (10)", - input: new(big.Rat).SetFrac64(10, 1), - expected: "10", - }, - { - name: "string", - input: "hello", - expected: "hello", - }, - { - name: "int", - input: 123, - expected: 123, - }, - { - name: "nested slice of big.Rat", - input: []any{ - new(big.Rat).SetFrac64(19, 2), - new(big.Rat).SetFrac64(1, 4), - }, - expected: []any{"9.5", "0.25"}, - }, - { - name: "nested map of big.Rat", - input: map[string]any{ - "val1": new(big.Rat).SetFrac64(19, 2), - "val2": new(big.Rat).SetFrac64(1, 2), - }, - expected: map[string]any{ - "val1": "9.5", - "val2": "0.5", - }, - }, - { - name: "complex nested structure", - input: map[string]any{ - "list": []any{ - map[string]any{ - "rat": new(big.Rat).SetFrac64(3, 2), - }, - }, - }, - expected: map[string]any{ - "list": []any{ - map[string]any{ - "rat": "1.5", - }, - }, - }, - }, - { - name: "slice of *big.Rat", - input: []*big.Rat{ - new(big.Rat).SetFrac64(19, 2), - new(big.Rat).SetFrac64(1, 4), - }, - expected: []any{"9.5", "0.25"}, - }, - { - name: "slice of strings", - input: []string{"a", "b"}, - expected: []any{"a", "b"}, - }, - { - name: "byte slice (BYTES)", - input: []byte("hello"), - expected: []byte("hello"), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := NormalizeValue(tt.input) - if !reflect.DeepEqual(got, tt.expected) { - t.Errorf("NormalizeValue() = %v, want %v", got, tt.expected) - } - }) - } -} diff --git a/internal/tools/bigquery/bigquerycommon/util.go b/internal/tools/bigquery/bigquerycommon/util.go index d9b6fd0283..5486ac36ed 100644 --- a/internal/tools/bigquery/bigquerycommon/util.go +++ b/internal/tools/bigquery/bigquerycommon/util.go @@ -17,8 +17,6 @@ package bigquerycommon import ( "context" "fmt" - "math/big" - "reflect" "sort" "strings" @@ -120,54 +118,3 @@ func InitializeDatasetParameters( return projectParam, datasetParam } - -// NormalizeValue converts BigQuery specific types to standard JSON-compatible types. -// Specifically, it handles *big.Rat (used for NUMERIC/BIGNUMERIC) by converting -// them to decimal strings with up to 38 digits of precision, trimming trailing zeros. -// It recursively handles slices (arrays) and maps (structs) using reflection. -func NormalizeValue(v any) any { - if v == nil { - return nil - } - - // Handle *big.Rat specifically. - if rat, ok := v.(*big.Rat); ok { - // Convert big.Rat to a decimal string. - // Use a precision of 38 digits (enough for BIGNUMERIC and NUMERIC) - // and trim trailing zeros to match BigQuery's behavior. - s := rat.FloatString(38) - if strings.Contains(s, ".") { - s = strings.TrimRight(s, "0") - s = strings.TrimRight(s, ".") - } - return s - } - - // Use reflection for slices and maps to handle various underlying types. - rv := reflect.ValueOf(v) - switch rv.Kind() { - case reflect.Slice, reflect.Array: - // Preserve []byte as is, so json.Marshal encodes it as Base64 string (BigQuery BYTES behavior). - if rv.Type().Elem().Kind() == reflect.Uint8 { - return v - } - newSlice := make([]any, rv.Len()) - for i := 0; i < rv.Len(); i++ { - newSlice[i] = NormalizeValue(rv.Index(i).Interface()) - } - return newSlice - case reflect.Map: - // Ensure keys are strings to produce a JSON-compatible map. - if rv.Type().Key().Kind() != reflect.String { - return v - } - newMap := make(map[string]any, rv.Len()) - iter := rv.MapRange() - for iter.Next() { - newMap[iter.Key().String()] = NormalizeValue(iter.Value().Interface()) - } - return newMap - } - - return v -} diff --git a/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go b/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go index a70d4d342d..3e248e1971 100644 --- a/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go +++ b/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go @@ -27,10 +27,8 @@ import ( "github.com/googleapis/genai-toolbox/internal/tools" bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon" "github.com/googleapis/genai-toolbox/internal/util" - "github.com/googleapis/genai-toolbox/internal/util/orderedmap" "github.com/googleapis/genai-toolbox/internal/util/parameters" bigqueryrestapi "google.golang.org/api/bigquery/v2" - "google.golang.org/api/iterator" ) const kind string = "bigquery-execute-sql" @@ -53,11 +51,11 @@ type compatibleSource interface { BigQueryClient() *bigqueryapi.Client BigQuerySession() bigqueryds.BigQuerySessionProvider BigQueryWriteMode() string - BigQueryRestService() *bigqueryrestapi.Service - BigQueryClientCreator() bigqueryds.BigqueryClientCreator UseClientAuthorization() bool IsDatasetAllowed(projectID, datasetID string) bool BigQueryAllowedDatasets() []string + RetrieveClientAndService(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) + RunSQL(context.Context, *bigqueryapi.Client, string, string, []bigqueryapi.QueryParameter, []*bigqueryapi.ConnectionProperty) (any, error) } type Config struct { @@ -169,19 +167,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to cast dry_run parameter %s", paramsMap["dry_run"]) } - bqClient := source.BigQueryClient() - restService := source.BigQueryRestService() - - // Initialize new client if using user OAuth token - if source.UseClientAuthorization() { - tokenStr, err := accessToken.ParseBearerToken() - if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) - } - bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true) - if err != nil { - return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) - } + bqClient, restService, err := source.RetrieveClientAndService(accessToken) + if err != nil { + return nil, err } var connProps []*bigqueryapi.ConnectionProperty @@ -283,61 +271,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return "Dry run was requested, but no job information was returned.", nil } - query := bqClient.Query(sql) - query.Location = bqClient.Location - - query.ConnectionProperties = connProps - // Log the query executed for debugging. logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("error getting logger: %s", err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) - - // This block handles SELECT statements, which return a row set. - // We iterate through the results, convert each row into a map of - // column names to values, and return the collection of rows. - var out []any - job, err := query.Run(ctx) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - it, err := job.Read(ctx) - if err != nil { - return nil, fmt.Errorf("unable to read query results: %w", err) - } - for { - var val []bigqueryapi.Value - err = it.Next(&val) - if err == iterator.Done { - break - } - if err != nil { - return nil, fmt.Errorf("unable to iterate through query results: %w", err) - } - schema := it.Schema - row := orderedmap.Row{} - for i, field := range schema { - row.Add(field.Name, bqutil.NormalizeValue(val[i])) - } - out = append(out, row) - } - // If the query returned any rows, return them directly. - if len(out) > 0 { - return out, nil - } - - // This handles the standard case for a SELECT query that successfully - // executes but returns zero rows. - if statementType == "SELECT" { - return "The query returned 0 rows.", nil - } - // This is the fallback for a successful query that doesn't return content. - // In most cases, this will be for DML/DDL statements like INSERT, UPDATE, CREATE, etc. - // However, it is also possible that this was a query that was expected to return rows - // but returned none, a case that we cannot distinguish here. - return "Query executed successfully and returned no content.", nil + return source.RunSQL(ctx, bqClient, sql, statementType, nil, connProps) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go b/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go index 034bce3501..b316aead3f 100644 --- a/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go +++ b/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go @@ -28,7 +28,6 @@ import ( "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" bigqueryrestapi "google.golang.org/api/bigquery/v2" - "google.golang.org/api/iterator" ) const kind string = "bigquery-forecast" @@ -49,12 +48,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { BigQueryClient() *bigqueryapi.Client - BigQueryRestService() *bigqueryrestapi.Service - BigQueryClientCreator() bigqueryds.BigqueryClientCreator UseClientAuthorization() bool IsDatasetAllowed(projectID, datasetID string) bool BigQueryAllowedDatasets() []string BigQuerySession() bigqueryds.BigQuerySessionProvider + RetrieveClientAndService(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) + RunSQL(context.Context, *bigqueryapi.Client, string, string, []bigqueryapi.QueryParameter, []*bigqueryapi.ConnectionProperty) (any, error) } type Config struct { @@ -173,19 +172,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } } - bqClient := source.BigQueryClient() - restService := source.BigQueryRestService() - - // Initialize new client if using user OAuth token - if source.UseClientAuthorization() { - tokenStr, err := accessToken.ParseBearerToken() - if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) - } - bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, false) - if err != nil { - return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) - } + bqClient, restService, err := source.RetrieveClientAndService(accessToken) + if err != nil { + return nil, err } var historyDataSource string @@ -251,7 +240,6 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para idColsFormatted := fmt.Sprintf("['%s']", strings.Join(idCols, "', '")) idColsArg = fmt.Sprintf(", id_cols => %s", idColsFormatted) } - sql := fmt.Sprintf(`SELECT * FROM AI.FORECAST( %s, @@ -260,16 +248,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para horizon => %d%s)`, historyDataSource, dataCol, timestampCol, horizon, idColsArg) - // JobStatistics.QueryStatistics.StatementType - query := bqClient.Query(sql) - query.Location = bqClient.Location session, err := source.BigQuerySession()(ctx) if err != nil { return nil, fmt.Errorf("failed to get BigQuery session: %w", err) } + var connProps []*bigqueryapi.ConnectionProperty if session != nil { // Add session ID to the connection properties for subsequent calls. - query.ConnectionProperties = []*bigqueryapi.ConnectionProperty{ + connProps = []*bigqueryapi.ConnectionProperty{ {Key: "session_id", Value: session.ID}, } } @@ -281,40 +267,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) - // This block handles SELECT statements, which return a row set. - // We iterate through the results, convert each row into a map of - // column names to values, and return the collection of rows. - var out []any - job, err := query.Run(ctx) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - it, err := job.Read(ctx) - if err != nil { - return nil, fmt.Errorf("unable to read query results: %w", err) - } - for { - var row map[string]bigqueryapi.Value - err = it.Next(&row) - if err == iterator.Done { - break - } - if err != nil { - return nil, fmt.Errorf("unable to iterate through query results: %w", err) - } - vMap := make(map[string]any) - for key, value := range row { - vMap[key] = value - } - out = append(out, vMap) - } - // If the query returned any rows, return them directly. - if len(out) > 0 { - return out, nil - } - - // This handles the standard case for a SELECT query that successfully - return "The query returned 0 rows.", nil + return source.RunSQL(ctx, bqClient, sql, "SELECT", nil, connProps) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/bigquery/bigquerygetdatasetinfo/bigquerygetdatasetinfo.go b/internal/tools/bigquery/bigquerygetdatasetinfo/bigquerygetdatasetinfo.go index b083c49e2c..545850066e 100644 --- a/internal/tools/bigquery/bigquerygetdatasetinfo/bigquerygetdatasetinfo.go +++ b/internal/tools/bigquery/bigquerygetdatasetinfo/bigquerygetdatasetinfo.go @@ -21,10 +21,10 @@ import ( bigqueryapi "cloud.google.com/go/bigquery" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery" "github.com/googleapis/genai-toolbox/internal/tools" bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" + bigqueryrestapi "google.golang.org/api/bigquery/v2" ) const kind string = "bigquery-get-dataset-info" @@ -47,11 +47,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { BigQueryProject() string - BigQueryClient() *bigqueryapi.Client - BigQueryClientCreator() bigqueryds.BigqueryClientCreator UseClientAuthorization() bool IsDatasetAllowed(projectID, datasetID string) bool BigQueryAllowedDatasets() []string + RetrieveClientAndService(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) } type Config struct { @@ -138,18 +137,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey) } - bqClient := source.BigQueryClient() - - // Initialize new client if using user OAuth token - if source.UseClientAuthorization() { - tokenStr, err := accessToken.ParseBearerToken() - if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) - } - bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false) - if err != nil { - return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) - } + bqClient, _, err := source.RetrieveClientAndService(accessToken) + if err != nil { + return nil, err } if !source.IsDatasetAllowed(projectId, datasetId) { diff --git a/internal/tools/bigquery/bigquerygettableinfo/bigquerygettableinfo.go b/internal/tools/bigquery/bigquerygettableinfo/bigquerygettableinfo.go index b896244ed0..4cfc91e55b 100644 --- a/internal/tools/bigquery/bigquerygettableinfo/bigquerygettableinfo.go +++ b/internal/tools/bigquery/bigquerygettableinfo/bigquerygettableinfo.go @@ -21,10 +21,10 @@ import ( bigqueryapi "cloud.google.com/go/bigquery" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery" "github.com/googleapis/genai-toolbox/internal/tools" bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" + bigqueryrestapi "google.golang.org/api/bigquery/v2" ) const kind string = "bigquery-get-table-info" @@ -48,11 +48,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { BigQueryProject() string - BigQueryClient() *bigqueryapi.Client - BigQueryClientCreator() bigqueryds.BigqueryClientCreator UseClientAuthorization() bool IsDatasetAllowed(projectID, datasetID string) bool BigQueryAllowedDatasets() []string + RetrieveClientAndService(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) } type Config struct { @@ -151,18 +150,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId) } - bqClient := source.BigQueryClient() - - // Initialize new client if using user OAuth token - if source.UseClientAuthorization() { - tokenStr, err := accessToken.ParseBearerToken() - if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) - } - bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false) - if err != nil { - return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) - } + bqClient, _, err := source.RetrieveClientAndService(accessToken) + if err != nil { + return nil, err } dsHandle := bqClient.DatasetInProject(projectId, datasetId) diff --git a/internal/tools/bigquery/bigquerylistdatasetids/bigquerylistdatasetids.go b/internal/tools/bigquery/bigquerylistdatasetids/bigquerylistdatasetids.go index dafe9b2246..93663f4f45 100644 --- a/internal/tools/bigquery/bigquerylistdatasetids/bigquerylistdatasetids.go +++ b/internal/tools/bigquery/bigquerylistdatasetids/bigquerylistdatasetids.go @@ -21,9 +21,9 @@ import ( bigqueryapi "cloud.google.com/go/bigquery" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" + bigqueryrestapi "google.golang.org/api/bigquery/v2" "google.golang.org/api/iterator" ) @@ -46,10 +46,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { BigQueryProject() string - BigQueryClient() *bigqueryapi.Client - BigQueryClientCreator() bigqueryds.BigqueryClientCreator UseClientAuthorization() bool BigQueryAllowedDatasets() []string + RetrieveClientAndService(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) } type Config struct { @@ -135,17 +134,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey) } - bqClient := source.BigQueryClient() - // Initialize new client if using user OAuth token - if source.UseClientAuthorization() { - tokenStr, err := accessToken.ParseBearerToken() - if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) - } - bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false) - if err != nil { - return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) - } + bqClient, _, err := source.RetrieveClientAndService(accessToken) + if err != nil { + return nil, err } datasetIterator := bqClient.Datasets(ctx) datasetIterator.ProjectID = projectId diff --git a/internal/tools/bigquery/bigquerylisttableids/bigquerylisttableids.go b/internal/tools/bigquery/bigquerylisttableids/bigquerylisttableids.go index 11987c6dac..e3f609f522 100644 --- a/internal/tools/bigquery/bigquerylisttableids/bigquerylisttableids.go +++ b/internal/tools/bigquery/bigquerylisttableids/bigquerylisttableids.go @@ -21,10 +21,10 @@ import ( bigqueryapi "cloud.google.com/go/bigquery" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery" "github.com/googleapis/genai-toolbox/internal/tools" bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" + bigqueryrestapi "google.golang.org/api/bigquery/v2" "google.golang.org/api/iterator" ) @@ -47,12 +47,11 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T } type compatibleSource interface { - BigQueryClient() *bigqueryapi.Client - BigQueryClientCreator() bigqueryds.BigqueryClientCreator BigQueryProject() string UseClientAuthorization() bool IsDatasetAllowed(projectID, datasetID string) bool BigQueryAllowedDatasets() []string + RetrieveClientAndService(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) } type Config struct { @@ -145,17 +144,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId) } - bqClient := source.BigQueryClient() - // Initialize new client if using user OAuth token - if source.UseClientAuthorization() { - tokenStr, err := accessToken.ParseBearerToken() - if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) - } - bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false) - if err != nil { - return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) - } + bqClient, _, err := source.RetrieveClientAndService(accessToken) + if err != nil { + return nil, err } dsHandle := bqClient.DatasetInProject(projectId, datasetId) diff --git a/internal/tools/bigquery/bigquerysql/bigquerysql.go b/internal/tools/bigquery/bigquerysql/bigquerysql.go index fa02f658eb..ff433a9ed5 100644 --- a/internal/tools/bigquery/bigquerysql/bigquerysql.go +++ b/internal/tools/bigquery/bigquerysql/bigquerysql.go @@ -23,13 +23,11 @@ import ( bigqueryapi "cloud.google.com/go/bigquery" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/util/parameters" - bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery" "github.com/googleapis/genai-toolbox/internal/tools" bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon" + "github.com/googleapis/genai-toolbox/internal/util/parameters" bigqueryrestapi "google.golang.org/api/bigquery/v2" - "google.golang.org/api/iterator" ) const kind string = "bigquery-sql" @@ -49,12 +47,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T } type compatibleSource interface { - BigQueryClient() *bigqueryapi.Client BigQuerySession() bigqueryds.BigQuerySessionProvider - BigQueryWriteMode() string - BigQueryRestService() *bigqueryrestapi.Service - BigQueryClientCreator() bigqueryds.BigqueryClientCreator UseClientAuthorization() bool + RetrieveClientAndService(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) + RunSQL(context.Context, *bigqueryapi.Client, string, string, []bigqueryapi.QueryParameter, []*bigqueryapi.ConnectionProperty) (any, error) } type Config struct { @@ -189,25 +185,6 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para lowLevelParams = append(lowLevelParams, lowLevelParam) } - bqClient := source.BigQueryClient() - restService := source.BigQueryRestService() - - // Initialize new client if using user OAuth token - if source.UseClientAuthorization() { - tokenStr, err := accessToken.ParseBearerToken() - if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) - } - bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true) - if err != nil { - return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) - } - } - - query := bqClient.Query(newStatement) - query.Parameters = highLevelParams - query.Location = bqClient.Location - connProps := []*bigqueryapi.ConnectionProperty{} if source.BigQuerySession() != nil { session, err := source.BigQuerySession()(ctx) @@ -219,57 +196,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para connProps = append(connProps, &bigqueryapi.ConnectionProperty{Key: "session_id", Value: session.ID}) } } - query.ConnectionProperties = connProps - dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), query.Location, newStatement, lowLevelParams, connProps) + + bqClient, restService, err := source.RetrieveClientAndService(accessToken) + if err != nil { + return nil, err + } + + dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, newStatement, lowLevelParams, connProps) if err != nil { return nil, fmt.Errorf("query validation failed: %w", err) } statementType := dryRunJob.Statistics.Query.StatementType - // This block handles SELECT statements, which return a row set. - // We iterate through the results, convert each row into a map of - // column names to values, and return the collection of rows. - job, err := query.Run(ctx) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - it, err := job.Read(ctx) - if err != nil { - return nil, fmt.Errorf("unable to read query results: %w", err) - } - - var out []any - for { - var row map[string]bigqueryapi.Value - err = it.Next(&row) - if err == iterator.Done { - break - } - if err != nil { - return nil, fmt.Errorf("unable to iterate through query results: %w", err) - } - vMap := make(map[string]any) - for key, value := range row { - vMap[key] = bqutil.NormalizeValue(value) - } - out = append(out, vMap) - } - // If the query returned any rows, return them directly. - if len(out) > 0 { - return out, nil - } - - // This handles the standard case for a SELECT query that successfully - // executes but returns zero rows. - if statementType == "SELECT" { - return "The query returned 0 rows.", nil - } - // This is the fallback for a successful query that doesn't return content. - // In most cases, this will be for DML/DDL statements like INSERT, UPDATE, CREATE, etc. - // However, it is also possible that this was a query that was expected to return rows - // but returned none, a case that we cannot distinguish here. - return "Query executed successfully and returned no content.", nil + return source.RunSQL(ctx, bqClient, newStatement, statementType, highLevelParams, connProps) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/tests/bigquery/bigquery_integration_test.go b/tests/bigquery/bigquery_integration_test.go index f79b8b7d31..de5126cd24 100644 --- a/tests/bigquery/bigquery_integration_test.go +++ b/tests/bigquery/bigquery_integration_test.go @@ -1701,7 +1701,7 @@ func runBigQueryDataTypeTests(t *testing.T) { api: "http://127.0.0.1:5000/api/tool/my-scalar-datatype-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{"int_val": 123, "string_val": "hello", "float_val": 3.14, "bool_val": true}`)), - want: `[{"bool_val":true,"float_val":3.14,"id":1,"int_val":123,"string_val":"hello"}]`, + want: `[{"id":1,"int_val":123,"string_val":"hello","float_val":3.14,"bool_val":true}]`, isErr: false, }, { @@ -1716,7 +1716,7 @@ func runBigQueryDataTypeTests(t *testing.T) { api: "http://127.0.0.1:5000/api/tool/my-array-datatype-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{"int_array": [123, 789], "string_array": ["hello", "test"], "float_array": [3.14, 100.1], "bool_array": [true]}`)), - want: `[{"bool_val":true,"float_val":3.14,"id":1,"int_val":123,"string_val":"hello"},{"bool_val":true,"float_val":100.1,"id":3,"int_val":789,"string_val":"test"}]`, + want: `[{"id":1,"int_val":123,"string_val":"hello","float_val":3.14,"bool_val":true},{"id":3,"int_val":789,"string_val":"test","float_val":100.1,"bool_val":true}]`, isErr: false, }, }