refactor(sources/bigquery): move source implementation in Invoke() function to Source (#2242)

Move source-related queries from `Invoke()` function into Source.

This is an effort to generalizing tools to work with any Source that
implements a specific interface. This will provide a better segregation
of the roles for Tools vs Source.

Tool's role will be limited to the following:
* Resolve any pre-implementation steps or parameters (e.g. template
parameters)
* Retrieving Source
* Calling the source's implementation
This commit is contained in:
Yuan Teoh
2025-12-30 21:43:09 -08:00
committed by GitHub
parent f9df2635c6
commit 0f27f956c7
13 changed files with 288 additions and 481 deletions

View File

@@ -17,7 +17,9 @@ package bigquery
import (
"context"
"fmt"
"math/big"
"net/http"
"reflect"
"strings"
"sync"
"time"
@@ -26,13 +28,16 @@ import (
dataplexapi "cloud.google.com/go/dataplex/apiv1"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/orderedmap"
"go.opentelemetry.io/otel/trace"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
"google.golang.org/api/googleapi"
"google.golang.org/api/impersonate"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
)
@@ -483,6 +488,131 @@ func (s *Source) lazyInitDataplexClient(ctx context.Context, tracer trace.Tracer
}
}
func (s *Source) RetrieveClientAndService(accessToken tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
bqClient := s.BigQueryClient()
restService := s.BigQueryRestService()
// Initialize new client if using user OAuth token
if s.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, restService, err = s.BigQueryClientCreator()(tokenStr, true)
if err != nil {
return nil, nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
}
return bqClient, restService, nil
}
func (s *Source) RunSQL(ctx context.Context, bqClient *bigqueryapi.Client, statement, statementType string, params []bigqueryapi.QueryParameter, connProps []*bigqueryapi.ConnectionProperty) (any, error) {
query := bqClient.Query(statement)
query.Location = bqClient.Location
if params != nil {
query.Parameters = params
}
if connProps != nil {
query.ConnectionProperties = connProps
}
// This block handles SELECT statements, which return a row set.
// We iterate through the results, convert each row into a map of
// column names to values, and return the collection of rows.
job, err := query.Run(ctx)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
it, err := job.Read(ctx)
if err != nil {
return nil, fmt.Errorf("unable to read query results: %w", err)
}
var out []any
for {
var val []bigqueryapi.Value
err = it.Next(&val)
if err == iterator.Done {
break
}
if err != nil {
return nil, fmt.Errorf("unable to iterate through query results: %w", err)
}
schema := it.Schema
row := orderedmap.Row{}
for i, field := range schema {
row.Add(field.Name, NormalizeValue(val[i]))
}
out = append(out, row)
}
// If the query returned any rows, return them directly.
if len(out) > 0 {
return out, nil
}
// This handles the standard case for a SELECT query that successfully
// executes but returns zero rows.
if statementType == "SELECT" {
return "The query returned 0 rows.", nil
}
// This is the fallback for a successful query that doesn't return content.
// In most cases, this will be for DML/DDL statements like INSERT, UPDATE, CREATE, etc.
// However, it is also possible that this was a query that was expected to return rows
// but returned none, a case that we cannot distinguish here.
return "Query executed successfully and returned no content.", nil
}
// NormalizeValue converts BigQuery specific types to standard JSON-compatible types.
// Specifically, it handles *big.Rat (used for NUMERIC/BIGNUMERIC) by converting
// them to decimal strings with up to 38 digits of precision, trimming trailing zeros.
// It recursively handles slices (arrays) and maps (structs) using reflection.
func NormalizeValue(v any) any {
if v == nil {
return nil
}
// Handle *big.Rat specifically.
if rat, ok := v.(*big.Rat); ok {
// Convert big.Rat to a decimal string.
// Use a precision of 38 digits (enough for BIGNUMERIC and NUMERIC)
// and trim trailing zeros to match BigQuery's behavior.
s := rat.FloatString(38)
if strings.Contains(s, ".") {
s = strings.TrimRight(s, "0")
s = strings.TrimRight(s, ".")
}
return s
}
// Use reflection for slices and maps to handle various underlying types.
rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Slice, reflect.Array:
// Preserve []byte as is, so json.Marshal encodes it as Base64 string (BigQuery BYTES behavior).
if rv.Type().Elem().Kind() == reflect.Uint8 {
return v
}
newSlice := make([]any, rv.Len())
for i := 0; i < rv.Len(); i++ {
newSlice[i] = NormalizeValue(rv.Index(i).Interface())
}
return newSlice
case reflect.Map:
// Ensure keys are strings to produce a JSON-compatible map.
if rv.Type().Key().Kind() != reflect.String {
return v
}
newMap := make(map[string]any, rv.Len())
iter := rv.MapRange()
for iter.Next() {
newMap[iter.Key().String()] = NormalizeValue(iter.Value().Interface())
}
return newMap
}
return v
}
func initBigQueryConnection(
ctx context.Context,
tracer trace.Tracer,

View File

@@ -15,6 +15,8 @@
package bigquery_test
import (
"math/big"
"reflect"
"testing"
yaml "github.com/goccy/go-yaml"
@@ -195,3 +197,105 @@ func TestFailParseFromYaml(t *testing.T) {
})
}
}
func TestNormalizeValue(t *testing.T) {
tests := []struct {
name string
input any
expected any
}{
{
name: "big.Rat 1/3 (NUMERIC scale 9)",
input: new(big.Rat).SetFrac64(1, 3), // 0.33333333333...
expected: "0.33333333333333333333333333333333333333", // FloatString(38)
},
{
name: "big.Rat 19/2 (9.5)",
input: new(big.Rat).SetFrac64(19, 2),
expected: "9.5",
},
{
name: "big.Rat 12341/10 (1234.1)",
input: new(big.Rat).SetFrac64(12341, 10),
expected: "1234.1",
},
{
name: "big.Rat 10/1 (10)",
input: new(big.Rat).SetFrac64(10, 1),
expected: "10",
},
{
name: "string",
input: "hello",
expected: "hello",
},
{
name: "int",
input: 123,
expected: 123,
},
{
name: "nested slice of big.Rat",
input: []any{
new(big.Rat).SetFrac64(19, 2),
new(big.Rat).SetFrac64(1, 4),
},
expected: []any{"9.5", "0.25"},
},
{
name: "nested map of big.Rat",
input: map[string]any{
"val1": new(big.Rat).SetFrac64(19, 2),
"val2": new(big.Rat).SetFrac64(1, 2),
},
expected: map[string]any{
"val1": "9.5",
"val2": "0.5",
},
},
{
name: "complex nested structure",
input: map[string]any{
"list": []any{
map[string]any{
"rat": new(big.Rat).SetFrac64(3, 2),
},
},
},
expected: map[string]any{
"list": []any{
map[string]any{
"rat": "1.5",
},
},
},
},
{
name: "slice of *big.Rat",
input: []*big.Rat{
new(big.Rat).SetFrac64(19, 2),
new(big.Rat).SetFrac64(1, 4),
},
expected: []any{"9.5", "0.25"},
},
{
name: "slice of strings",
input: []string{"a", "b"},
expected: []any{"a", "b"},
},
{
name: "byte slice (BYTES)",
input: []byte("hello"),
expected: []byte("hello"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := bigquery.NormalizeValue(tt.input)
if !reflect.DeepEqual(got, tt.expected) {
t.Errorf("NormalizeValue() = %v, want %v", got, tt.expected)
}
})
}
}

View File

@@ -28,7 +28,6 @@ import (
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
"google.golang.org/api/iterator"
)
const kind string = "bigquery-analyze-contribution"
@@ -49,12 +48,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
BigQueryClient() *bigqueryapi.Client
BigQueryRestService() *bigqueryrestapi.Service
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool
IsDatasetAllowed(projectID, datasetID string) bool
BigQueryAllowedDatasets() []string
BigQuerySession() bigqueryds.BigQuerySessionProvider
RetrieveClientAndService(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
RunSQL(context.Context, *bigqueryapi.Client, string, string, []bigqueryapi.QueryParameter, []*bigqueryapi.ConnectionProperty) (any, error)
}
type Config struct {
@@ -166,19 +165,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("unable to cast input_data parameter %s", paramsMap["input_data"])
}
bqClient := source.BigQueryClient()
restService := source.BigQueryRestService()
// Initialize new client if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
bqClient, restService, err := source.RetrieveClientAndService(accessToken)
if err != nil {
return nil, err
}
modelID := fmt.Sprintf("contribution_analysis_model_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
@@ -314,43 +303,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
getInsightsSQL := fmt.Sprintf("SELECT * FROM ML.GET_INSIGHTS(MODEL %s)", modelID)
getInsightsQuery := bqClient.Query(getInsightsSQL)
getInsightsQuery.ConnectionProperties = []*bigqueryapi.ConnectionProperty{{Key: "session_id", Value: sessionID}}
job, err := getInsightsQuery.Run(ctx)
if err != nil {
return nil, fmt.Errorf("failed to execute get insights query: %w", err)
}
it, err := job.Read(ctx)
if err != nil {
return nil, fmt.Errorf("unable to read query results: %w", err)
}
var out []any
for {
var row map[string]bigqueryapi.Value
err := it.Next(&row)
if err == iterator.Done {
break
}
if err != nil {
return nil, fmt.Errorf("failed to iterate through query results: %w", err)
}
vMap := make(map[string]any)
for key, value := range row {
vMap[key] = value
}
out = append(out, vMap)
}
if len(out) > 0 {
return out, nil
}
// This handles the standard case for a SELECT query that successfully
// executes but returns zero rows.
return "The query returned 0 rows.", nil
connProps := []*bigqueryapi.ConnectionProperty{{Key: "session_id", Value: sessionID}}
return source.RunSQL(ctx, bqClient, getInsightsSQL, "SELECT", nil, connProps)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -1,123 +0,0 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package bigquerycommon
import (
"math/big"
"reflect"
"testing"
)
func TestNormalizeValue(t *testing.T) {
tests := []struct {
name string
input any
expected any
}{
{
name: "big.Rat 1/3 (NUMERIC scale 9)",
input: new(big.Rat).SetFrac64(1, 3), // 0.33333333333...
expected: "0.33333333333333333333333333333333333333", // FloatString(38)
},
{
name: "big.Rat 19/2 (9.5)",
input: new(big.Rat).SetFrac64(19, 2),
expected: "9.5",
},
{
name: "big.Rat 12341/10 (1234.1)",
input: new(big.Rat).SetFrac64(12341, 10),
expected: "1234.1",
},
{
name: "big.Rat 10/1 (10)",
input: new(big.Rat).SetFrac64(10, 1),
expected: "10",
},
{
name: "string",
input: "hello",
expected: "hello",
},
{
name: "int",
input: 123,
expected: 123,
},
{
name: "nested slice of big.Rat",
input: []any{
new(big.Rat).SetFrac64(19, 2),
new(big.Rat).SetFrac64(1, 4),
},
expected: []any{"9.5", "0.25"},
},
{
name: "nested map of big.Rat",
input: map[string]any{
"val1": new(big.Rat).SetFrac64(19, 2),
"val2": new(big.Rat).SetFrac64(1, 2),
},
expected: map[string]any{
"val1": "9.5",
"val2": "0.5",
},
},
{
name: "complex nested structure",
input: map[string]any{
"list": []any{
map[string]any{
"rat": new(big.Rat).SetFrac64(3, 2),
},
},
},
expected: map[string]any{
"list": []any{
map[string]any{
"rat": "1.5",
},
},
},
},
{
name: "slice of *big.Rat",
input: []*big.Rat{
new(big.Rat).SetFrac64(19, 2),
new(big.Rat).SetFrac64(1, 4),
},
expected: []any{"9.5", "0.25"},
},
{
name: "slice of strings",
input: []string{"a", "b"},
expected: []any{"a", "b"},
},
{
name: "byte slice (BYTES)",
input: []byte("hello"),
expected: []byte("hello"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NormalizeValue(tt.input)
if !reflect.DeepEqual(got, tt.expected) {
t.Errorf("NormalizeValue() = %v, want %v", got, tt.expected)
}
})
}
}

View File

@@ -17,8 +17,6 @@ package bigquerycommon
import (
"context"
"fmt"
"math/big"
"reflect"
"sort"
"strings"
@@ -120,54 +118,3 @@ func InitializeDatasetParameters(
return projectParam, datasetParam
}
// NormalizeValue converts BigQuery specific types to standard JSON-compatible types.
// Specifically, it handles *big.Rat (used for NUMERIC/BIGNUMERIC) by converting
// them to decimal strings with up to 38 digits of precision, trimming trailing zeros.
// It recursively handles slices (arrays) and maps (structs) using reflection.
func NormalizeValue(v any) any {
if v == nil {
return nil
}
// Handle *big.Rat specifically.
if rat, ok := v.(*big.Rat); ok {
// Convert big.Rat to a decimal string.
// Use a precision of 38 digits (enough for BIGNUMERIC and NUMERIC)
// and trim trailing zeros to match BigQuery's behavior.
s := rat.FloatString(38)
if strings.Contains(s, ".") {
s = strings.TrimRight(s, "0")
s = strings.TrimRight(s, ".")
}
return s
}
// Use reflection for slices and maps to handle various underlying types.
rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Slice, reflect.Array:
// Preserve []byte as is, so json.Marshal encodes it as Base64 string (BigQuery BYTES behavior).
if rv.Type().Elem().Kind() == reflect.Uint8 {
return v
}
newSlice := make([]any, rv.Len())
for i := 0; i < rv.Len(); i++ {
newSlice[i] = NormalizeValue(rv.Index(i).Interface())
}
return newSlice
case reflect.Map:
// Ensure keys are strings to produce a JSON-compatible map.
if rv.Type().Key().Kind() != reflect.String {
return v
}
newMap := make(map[string]any, rv.Len())
iter := rv.MapRange()
for iter.Next() {
newMap[iter.Key().String()] = NormalizeValue(iter.Value().Interface())
}
return newMap
}
return v
}

View File

@@ -27,10 +27,8 @@ import (
"github.com/googleapis/genai-toolbox/internal/tools"
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/orderedmap"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
"google.golang.org/api/iterator"
)
const kind string = "bigquery-execute-sql"
@@ -53,11 +51,11 @@ type compatibleSource interface {
BigQueryClient() *bigqueryapi.Client
BigQuerySession() bigqueryds.BigQuerySessionProvider
BigQueryWriteMode() string
BigQueryRestService() *bigqueryrestapi.Service
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool
IsDatasetAllowed(projectID, datasetID string) bool
BigQueryAllowedDatasets() []string
RetrieveClientAndService(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
RunSQL(context.Context, *bigqueryapi.Client, string, string, []bigqueryapi.QueryParameter, []*bigqueryapi.ConnectionProperty) (any, error)
}
type Config struct {
@@ -169,19 +167,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("unable to cast dry_run parameter %s", paramsMap["dry_run"])
}
bqClient := source.BigQueryClient()
restService := source.BigQueryRestService()
// Initialize new client if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
bqClient, restService, err := source.RetrieveClientAndService(accessToken)
if err != nil {
return nil, err
}
var connProps []*bigqueryapi.ConnectionProperty
@@ -283,61 +271,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return "Dry run was requested, but no job information was returned.", nil
}
query := bqClient.Query(sql)
query.Location = bqClient.Location
query.ConnectionProperties = connProps
// Log the query executed for debugging.
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("error getting logger: %s", err)
}
logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql))
// This block handles SELECT statements, which return a row set.
// We iterate through the results, convert each row into a map of
// column names to values, and return the collection of rows.
var out []any
job, err := query.Run(ctx)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
it, err := job.Read(ctx)
if err != nil {
return nil, fmt.Errorf("unable to read query results: %w", err)
}
for {
var val []bigqueryapi.Value
err = it.Next(&val)
if err == iterator.Done {
break
}
if err != nil {
return nil, fmt.Errorf("unable to iterate through query results: %w", err)
}
schema := it.Schema
row := orderedmap.Row{}
for i, field := range schema {
row.Add(field.Name, bqutil.NormalizeValue(val[i]))
}
out = append(out, row)
}
// If the query returned any rows, return them directly.
if len(out) > 0 {
return out, nil
}
// This handles the standard case for a SELECT query that successfully
// executes but returns zero rows.
if statementType == "SELECT" {
return "The query returned 0 rows.", nil
}
// This is the fallback for a successful query that doesn't return content.
// In most cases, this will be for DML/DDL statements like INSERT, UPDATE, CREATE, etc.
// However, it is also possible that this was a query that was expected to return rows
// but returned none, a case that we cannot distinguish here.
return "Query executed successfully and returned no content.", nil
return source.RunSQL(ctx, bqClient, sql, statementType, nil, connProps)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -28,7 +28,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
"google.golang.org/api/iterator"
)
const kind string = "bigquery-forecast"
@@ -49,12 +48,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
BigQueryClient() *bigqueryapi.Client
BigQueryRestService() *bigqueryrestapi.Service
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool
IsDatasetAllowed(projectID, datasetID string) bool
BigQueryAllowedDatasets() []string
BigQuerySession() bigqueryds.BigQuerySessionProvider
RetrieveClientAndService(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
RunSQL(context.Context, *bigqueryapi.Client, string, string, []bigqueryapi.QueryParameter, []*bigqueryapi.ConnectionProperty) (any, error)
}
type Config struct {
@@ -173,19 +172,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
}
bqClient := source.BigQueryClient()
restService := source.BigQueryRestService()
// Initialize new client if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
bqClient, restService, err := source.RetrieveClientAndService(accessToken)
if err != nil {
return nil, err
}
var historyDataSource string
@@ -251,7 +240,6 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
idColsFormatted := fmt.Sprintf("['%s']", strings.Join(idCols, "', '"))
idColsArg = fmt.Sprintf(", id_cols => %s", idColsFormatted)
}
sql := fmt.Sprintf(`SELECT *
FROM AI.FORECAST(
%s,
@@ -260,16 +248,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
horizon => %d%s)`,
historyDataSource, dataCol, timestampCol, horizon, idColsArg)
// JobStatistics.QueryStatistics.StatementType
query := bqClient.Query(sql)
query.Location = bqClient.Location
session, err := source.BigQuerySession()(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
}
var connProps []*bigqueryapi.ConnectionProperty
if session != nil {
// Add session ID to the connection properties for subsequent calls.
query.ConnectionProperties = []*bigqueryapi.ConnectionProperty{
connProps = []*bigqueryapi.ConnectionProperty{
{Key: "session_id", Value: session.ID},
}
}
@@ -281,40 +267,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql))
// This block handles SELECT statements, which return a row set.
// We iterate through the results, convert each row into a map of
// column names to values, and return the collection of rows.
var out []any
job, err := query.Run(ctx)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
it, err := job.Read(ctx)
if err != nil {
return nil, fmt.Errorf("unable to read query results: %w", err)
}
for {
var row map[string]bigqueryapi.Value
err = it.Next(&row)
if err == iterator.Done {
break
}
if err != nil {
return nil, fmt.Errorf("unable to iterate through query results: %w", err)
}
vMap := make(map[string]any)
for key, value := range row {
vMap[key] = value
}
out = append(out, vMap)
}
// If the query returned any rows, return them directly.
if len(out) > 0 {
return out, nil
}
// This handles the standard case for a SELECT query that successfully
return "The query returned 0 rows.", nil
return source.RunSQL(ctx, bqClient, sql, "SELECT", nil, connProps)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -21,10 +21,10 @@ import (
bigqueryapi "cloud.google.com/go/bigquery"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools"
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
)
const kind string = "bigquery-get-dataset-info"
@@ -47,11 +47,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
BigQueryProject() string
BigQueryClient() *bigqueryapi.Client
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool
IsDatasetAllowed(projectID, datasetID string) bool
BigQueryAllowedDatasets() []string
RetrieveClientAndService(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
}
type Config struct {
@@ -138,18 +137,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
}
bqClient := source.BigQueryClient()
// Initialize new client if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
bqClient, _, err := source.RetrieveClientAndService(accessToken)
if err != nil {
return nil, err
}
if !source.IsDatasetAllowed(projectId, datasetId) {

View File

@@ -21,10 +21,10 @@ import (
bigqueryapi "cloud.google.com/go/bigquery"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools"
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
)
const kind string = "bigquery-get-table-info"
@@ -48,11 +48,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
BigQueryProject() string
BigQueryClient() *bigqueryapi.Client
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool
IsDatasetAllowed(projectID, datasetID string) bool
BigQueryAllowedDatasets() []string
RetrieveClientAndService(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
}
type Config struct {
@@ -151,18 +150,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
}
bqClient := source.BigQueryClient()
// Initialize new client if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
bqClient, _, err := source.RetrieveClientAndService(accessToken)
if err != nil {
return nil, err
}
dsHandle := bqClient.DatasetInProject(projectId, datasetId)

View File

@@ -21,9 +21,9 @@ import (
bigqueryapi "cloud.google.com/go/bigquery"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
"google.golang.org/api/iterator"
)
@@ -46,10 +46,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
BigQueryProject() string
BigQueryClient() *bigqueryapi.Client
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool
BigQueryAllowedDatasets() []string
RetrieveClientAndService(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
}
type Config struct {
@@ -135,17 +134,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey)
}
bqClient := source.BigQueryClient()
// Initialize new client if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
bqClient, _, err := source.RetrieveClientAndService(accessToken)
if err != nil {
return nil, err
}
datasetIterator := bqClient.Datasets(ctx)
datasetIterator.ProjectID = projectId

View File

@@ -21,10 +21,10 @@ import (
bigqueryapi "cloud.google.com/go/bigquery"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools"
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
"google.golang.org/api/iterator"
)
@@ -47,12 +47,11 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
BigQueryClient() *bigqueryapi.Client
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
BigQueryProject() string
UseClientAuthorization() bool
IsDatasetAllowed(projectID, datasetID string) bool
BigQueryAllowedDatasets() []string
RetrieveClientAndService(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
}
type Config struct {
@@ -145,17 +144,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
}
bqClient := source.BigQueryClient()
// Initialize new client if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
bqClient, _, err := source.RetrieveClientAndService(accessToken)
if err != nil {
return nil, err
}
dsHandle := bqClient.DatasetInProject(projectId, datasetId)

View File

@@ -23,13 +23,11 @@ import (
bigqueryapi "cloud.google.com/go/bigquery"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools"
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
"google.golang.org/api/iterator"
)
const kind string = "bigquery-sql"
@@ -49,12 +47,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
BigQueryClient() *bigqueryapi.Client
BigQuerySession() bigqueryds.BigQuerySessionProvider
BigQueryWriteMode() string
BigQueryRestService() *bigqueryrestapi.Service
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool
RetrieveClientAndService(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
RunSQL(context.Context, *bigqueryapi.Client, string, string, []bigqueryapi.QueryParameter, []*bigqueryapi.ConnectionProperty) (any, error)
}
type Config struct {
@@ -189,25 +185,6 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
lowLevelParams = append(lowLevelParams, lowLevelParam)
}
bqClient := source.BigQueryClient()
restService := source.BigQueryRestService()
// Initialize new client if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
}
query := bqClient.Query(newStatement)
query.Parameters = highLevelParams
query.Location = bqClient.Location
connProps := []*bigqueryapi.ConnectionProperty{}
if source.BigQuerySession() != nil {
session, err := source.BigQuerySession()(ctx)
@@ -219,57 +196,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
connProps = append(connProps, &bigqueryapi.ConnectionProperty{Key: "session_id", Value: session.ID})
}
}
query.ConnectionProperties = connProps
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), query.Location, newStatement, lowLevelParams, connProps)
bqClient, restService, err := source.RetrieveClientAndService(accessToken)
if err != nil {
return nil, err
}
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, newStatement, lowLevelParams, connProps)
if err != nil {
return nil, fmt.Errorf("query validation failed: %w", err)
}
statementType := dryRunJob.Statistics.Query.StatementType
// This block handles SELECT statements, which return a row set.
// We iterate through the results, convert each row into a map of
// column names to values, and return the collection of rows.
job, err := query.Run(ctx)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
it, err := job.Read(ctx)
if err != nil {
return nil, fmt.Errorf("unable to read query results: %w", err)
}
var out []any
for {
var row map[string]bigqueryapi.Value
err = it.Next(&row)
if err == iterator.Done {
break
}
if err != nil {
return nil, fmt.Errorf("unable to iterate through query results: %w", err)
}
vMap := make(map[string]any)
for key, value := range row {
vMap[key] = bqutil.NormalizeValue(value)
}
out = append(out, vMap)
}
// If the query returned any rows, return them directly.
if len(out) > 0 {
return out, nil
}
// This handles the standard case for a SELECT query that successfully
// executes but returns zero rows.
if statementType == "SELECT" {
return "The query returned 0 rows.", nil
}
// This is the fallback for a successful query that doesn't return content.
// In most cases, this will be for DML/DDL statements like INSERT, UPDATE, CREATE, etc.
// However, it is also possible that this was a query that was expected to return rows
// but returned none, a case that we cannot distinguish here.
return "Query executed successfully and returned no content.", nil
return source.RunSQL(ctx, bqClient, newStatement, statementType, highLevelParams, connProps)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -1701,7 +1701,7 @@ func runBigQueryDataTypeTests(t *testing.T) {
api: "http://127.0.0.1:5000/api/tool/my-scalar-datatype-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{"int_val": 123, "string_val": "hello", "float_val": 3.14, "bool_val": true}`)),
want: `[{"bool_val":true,"float_val":3.14,"id":1,"int_val":123,"string_val":"hello"}]`,
want: `[{"id":1,"int_val":123,"string_val":"hello","float_val":3.14,"bool_val":true}]`,
isErr: false,
},
{
@@ -1716,7 +1716,7 @@ func runBigQueryDataTypeTests(t *testing.T) {
api: "http://127.0.0.1:5000/api/tool/my-array-datatype-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{"int_array": [123, 789], "string_array": ["hello", "test"], "float_array": [3.14, 100.1], "bool_array": [true]}`)),
want: `[{"bool_val":true,"float_val":3.14,"id":1,"int_val":123,"string_val":"hello"},{"bool_val":true,"float_val":100.1,"id":3,"int_val":789,"string_val":"test"}]`,
want: `[{"id":1,"int_val":123,"string_val":"hello","float_val":3.14,"bool_val":true},{"id":3,"int_val":789,"string_val":"test","float_val":100.1,"bool_val":true}]`,
isErr: false,
},
}