mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-04-09 03:02:26 -04:00
## Description
This change addresses the ask where the user may want to use custom
scopes. For instance, the default scope (bigquery) falls short from
running sql that utilizes integration with other google products, such
as Drive, Vertex AI, Cloud Run etc. With this change the user would be
able to configure custom scopes depending on their use case.
The custom scopes can be configured in the tools.yaml file, e.g.:
```yaml
sources:
bigquery-source:
kind: "bigquery"
project: ${BIGQUERY_PROJECT}
location: ${BIGQUERY_LOCATION:}
useClientOAuth: ${BIGQUERY_USE_CLIENT_OAUTH:false}
scopes:
- "https://www.googleapis.com/auth/bigquery"
- "https://www.googleapis.com/auth/drive"
```
and if the [bigquery prebuilt
config](https://github.com/googleapis/genai-toolbox/blob/main/internal/prebuiltconfigs/tools/bigquery.yaml)
is being used, then it can be set in the environment variable as well:
```shell
$ export BIGQUERY_SCOPES="https://www.googleapis.com/auth/bigquery,https://www.googleapis.com/auth/drive"
```
## PR Checklist
> Thank you for opening a Pull Request! Before submitting your PR, there
are a
> few things you can do to make sure it goes smoothly:
- [ ] Make sure you reviewed
[CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [x] Make sure to open an issue as a
[bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
before writing your code! That way we can discuss the change, evaluate
designs, and agree on the general idea
- [x] Ensure the tests and linter pass
- [x] Code coverage does not decrease (if any source code was changed)
- [x] Appropriate docs were updated (if necessary)
- [ ] Make sure to add `!` if this involve a breaking change
🛠️ Fixes #1942
879 lines
27 KiB
Go
879 lines
27 KiB
Go
// Copyright 2025 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package bigquery
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"math/big"
|
|
"net/http"
|
|
"reflect"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
bigqueryapi "cloud.google.com/go/bigquery"
|
|
dataplexapi "cloud.google.com/go/dataplex/apiv1"
|
|
"github.com/goccy/go-yaml"
|
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
|
"github.com/googleapis/genai-toolbox/internal/util"
|
|
"github.com/googleapis/genai-toolbox/internal/util/orderedmap"
|
|
"go.opentelemetry.io/otel/trace"
|
|
"golang.org/x/oauth2"
|
|
"golang.org/x/oauth2/google"
|
|
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
|
"google.golang.org/api/googleapi"
|
|
"google.golang.org/api/impersonate"
|
|
"google.golang.org/api/iterator"
|
|
"google.golang.org/api/option"
|
|
)
|
|
|
|
const SourceKind string = "bigquery"
|
|
|
|
// CloudPlatformScope is a broad scope for Google Cloud Platform services.
|
|
const CloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform"
|
|
|
|
const (
|
|
// No write operations are allowed.
|
|
WriteModeBlocked string = "blocked"
|
|
// Only protected write operations are allowed in a BigQuery session.
|
|
WriteModeProtected string = "protected"
|
|
// All write operations are allowed.
|
|
WriteModeAllowed string = "allowed"
|
|
)
|
|
|
|
// validate interface
|
|
var _ sources.SourceConfig = Config{}
|
|
|
|
type BigqueryClientCreator func(tokenString string, wantRestService bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
|
|
|
|
type BigQuerySessionProvider func(ctx context.Context) (*Session, error)
|
|
|
|
type DataplexClientCreator func(tokenString string) (*dataplexapi.CatalogClient, error)
|
|
|
|
func init() {
|
|
if !sources.Register(SourceKind, newConfig) {
|
|
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
|
}
|
|
}
|
|
|
|
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) {
|
|
actual := Config{Name: name}
|
|
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
|
return nil, err
|
|
}
|
|
return actual, nil
|
|
}
|
|
|
|
type Config struct {
|
|
// BigQuery configs
|
|
Name string `yaml:"name" validate:"required"`
|
|
Kind string `yaml:"kind" validate:"required"`
|
|
Project string `yaml:"project" validate:"required"`
|
|
Location string `yaml:"location"`
|
|
WriteMode string `yaml:"writeMode"`
|
|
AllowedDatasets StringOrStringSlice `yaml:"allowedDatasets"`
|
|
UseClientOAuth bool `yaml:"useClientOAuth"`
|
|
ImpersonateServiceAccount string `yaml:"impersonateServiceAccount"`
|
|
Scopes StringOrStringSlice `yaml:"scopes"`
|
|
}
|
|
|
|
// StringOrStringSlice is a custom type that can unmarshal both a single string
|
|
// (which it splits by comma) and a sequence of strings into a string slice.
|
|
type StringOrStringSlice []string
|
|
|
|
// UnmarshalYAML implements the yaml.Unmarshaler interface.
|
|
func (s *StringOrStringSlice) UnmarshalYAML(unmarshal func(any) error) error {
|
|
var v any
|
|
if err := unmarshal(&v); err != nil {
|
|
return err
|
|
}
|
|
switch val := v.(type) {
|
|
case string:
|
|
*s = strings.Split(val, ",")
|
|
return nil
|
|
case []any:
|
|
for _, item := range val {
|
|
if str, ok := item.(string); ok {
|
|
*s = append(*s, str)
|
|
} else {
|
|
return fmt.Errorf("element in sequence is not a string: %v", item)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
return fmt.Errorf("cannot unmarshal %T into StringOrStringSlice", v)
|
|
}
|
|
|
|
func (r Config) SourceConfigKind() string {
|
|
// Returns BigQuery source kind
|
|
return SourceKind
|
|
}
|
|
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
|
if r.WriteMode == "" {
|
|
r.WriteMode = WriteModeAllowed
|
|
}
|
|
|
|
if r.WriteMode == WriteModeProtected && r.UseClientOAuth {
|
|
// The protected mode only allows write operations to the session's temporary datasets.
|
|
// when using client OAuth, a new session is created every
|
|
// time a BigQuery tool is invoked. Therefore, no session data can
|
|
// be preserved as needed by the protected mode.
|
|
return nil, fmt.Errorf("writeMode 'protected' cannot be used with useClientOAuth 'true'")
|
|
}
|
|
|
|
if r.UseClientOAuth && r.ImpersonateServiceAccount != "" {
|
|
return nil, fmt.Errorf("useClientOAuth cannot be used with impersonateServiceAccount")
|
|
}
|
|
|
|
var client *bigqueryapi.Client
|
|
var restService *bigqueryrestapi.Service
|
|
var tokenSource oauth2.TokenSource
|
|
var clientCreator BigqueryClientCreator
|
|
var err error
|
|
|
|
s := &Source{
|
|
Config: r,
|
|
Client: client,
|
|
RestService: restService,
|
|
TokenSource: tokenSource,
|
|
MaxQueryResultRows: 50,
|
|
ClientCreator: clientCreator,
|
|
}
|
|
|
|
if r.UseClientOAuth {
|
|
// use client OAuth
|
|
baseClientCreator, err := newBigQueryClientCreator(ctx, tracer, r.Project, r.Location, r.Name)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error constructing client creator: %w", err)
|
|
}
|
|
setupClientCaching(s, baseClientCreator)
|
|
|
|
} else {
|
|
// Initializes a BigQuery Google SQL source
|
|
client, restService, tokenSource, err = initBigQueryConnection(ctx, tracer, r.Name, r.Project, r.Location, r.ImpersonateServiceAccount, r.Scopes)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error creating client from ADC: %w", err)
|
|
}
|
|
s.Client = client
|
|
s.RestService = restService
|
|
s.TokenSource = tokenSource
|
|
}
|
|
|
|
allowedDatasets := make(map[string]struct{})
|
|
// Get full id of allowed datasets and verify they exist.
|
|
if len(r.AllowedDatasets) > 0 {
|
|
for _, allowed := range r.AllowedDatasets {
|
|
var projectID, datasetID, allowedFullID string
|
|
if strings.Contains(allowed, ".") {
|
|
parts := strings.Split(allowed, ".")
|
|
if len(parts) != 2 {
|
|
return nil, fmt.Errorf("invalid allowedDataset format: %q, expected 'project.dataset' or 'dataset'", allowed)
|
|
}
|
|
projectID = parts[0]
|
|
datasetID = parts[1]
|
|
allowedFullID = allowed
|
|
} else {
|
|
projectID = r.Project
|
|
datasetID = allowed
|
|
allowedFullID = fmt.Sprintf("%s.%s", projectID, datasetID)
|
|
}
|
|
|
|
if s.Client != nil {
|
|
dataset := s.Client.DatasetInProject(projectID, datasetID)
|
|
_, err := dataset.Metadata(ctx)
|
|
if err != nil {
|
|
if gerr, ok := err.(*googleapi.Error); ok && gerr.Code == http.StatusNotFound {
|
|
return nil, fmt.Errorf("allowedDataset '%s' not found in project '%s'", datasetID, projectID)
|
|
}
|
|
return nil, fmt.Errorf("failed to verify allowedDataset '%s' in project '%s': %w", datasetID, projectID, err)
|
|
}
|
|
}
|
|
allowedDatasets[allowedFullID] = struct{}{}
|
|
}
|
|
}
|
|
|
|
s.AllowedDatasets = allowedDatasets
|
|
s.SessionProvider = s.newBigQuerySessionProvider()
|
|
|
|
if r.WriteMode != WriteModeAllowed && r.WriteMode != WriteModeBlocked && r.WriteMode != WriteModeProtected {
|
|
return nil, fmt.Errorf("invalid writeMode %q: must be one of %q, %q, or %q", r.WriteMode, WriteModeAllowed, WriteModeProtected, WriteModeBlocked)
|
|
}
|
|
s.makeDataplexCatalogClient = s.lazyInitDataplexClient(ctx, tracer)
|
|
return s, nil
|
|
}
|
|
|
|
// setupClientCaching initializes caches and wraps the base client creator with caching logic.
|
|
func setupClientCaching(s *Source, baseCreator BigqueryClientCreator) {
|
|
// Define eviction handlers
|
|
onBqEvict := func(key string, value interface{}) {
|
|
if client, ok := value.(*bigqueryapi.Client); ok && client != nil {
|
|
client.Close()
|
|
}
|
|
}
|
|
onDataplexEvict := func(key string, value interface{}) {
|
|
if client, ok := value.(*dataplexapi.CatalogClient); ok && client != nil {
|
|
client.Close()
|
|
}
|
|
}
|
|
|
|
// Initialize caches
|
|
s.bqClientCache = NewCache(onBqEvict)
|
|
s.bqRestCache = NewCache(nil)
|
|
s.dataplexCache = NewCache(onDataplexEvict)
|
|
|
|
// Create the caching wrapper for the client creator
|
|
s.ClientCreator = func(tokenString string, wantRestService bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
|
|
// Check cache
|
|
bqClientVal, bqFound := s.bqClientCache.Get(tokenString)
|
|
|
|
if wantRestService {
|
|
restServiceVal, restFound := s.bqRestCache.Get(tokenString)
|
|
if bqFound && restFound {
|
|
// Cache hit for both
|
|
return bqClientVal.(*bigqueryapi.Client), restServiceVal.(*bigqueryrestapi.Service), nil
|
|
}
|
|
} else {
|
|
if bqFound {
|
|
return bqClientVal.(*bigqueryapi.Client), nil, nil
|
|
}
|
|
}
|
|
|
|
// Cache miss - call the client creator
|
|
client, restService, err := baseCreator(tokenString, wantRestService)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
// Set in cache
|
|
s.bqClientCache.Set(tokenString, client)
|
|
if wantRestService && restService != nil {
|
|
s.bqRestCache.Set(tokenString, restService)
|
|
}
|
|
|
|
return client, restService, nil
|
|
}
|
|
}
|
|
|
|
var _ sources.Source = &Source{}
|
|
|
|
type Source struct {
|
|
Config
|
|
Client *bigqueryapi.Client
|
|
RestService *bigqueryrestapi.Service
|
|
TokenSource oauth2.TokenSource
|
|
MaxQueryResultRows int
|
|
ClientCreator BigqueryClientCreator
|
|
AllowedDatasets map[string]struct{}
|
|
sessionMutex sync.Mutex
|
|
makeDataplexCatalogClient func() (*dataplexapi.CatalogClient, DataplexClientCreator, error)
|
|
SessionProvider BigQuerySessionProvider
|
|
Session *Session
|
|
|
|
// Caches for OAuth clients
|
|
bqClientCache *Cache
|
|
bqRestCache *Cache
|
|
dataplexCache *Cache
|
|
}
|
|
|
|
type Session struct {
|
|
ID string
|
|
ProjectID string
|
|
DatasetID string
|
|
CreationTime time.Time
|
|
LastUsed time.Time
|
|
}
|
|
|
|
func (s *Source) SourceKind() string {
|
|
// Returns BigQuery Google SQL source kind
|
|
return SourceKind
|
|
}
|
|
|
|
func (s *Source) ToConfig() sources.SourceConfig {
|
|
return s.Config
|
|
}
|
|
|
|
func (s *Source) BigQueryClient() *bigqueryapi.Client {
|
|
return s.Client
|
|
}
|
|
|
|
func (s *Source) BigQueryRestService() *bigqueryrestapi.Service {
|
|
return s.RestService
|
|
}
|
|
|
|
func (s *Source) BigQueryWriteMode() string {
|
|
return s.WriteMode
|
|
}
|
|
|
|
func (s *Source) BigQuerySession() BigQuerySessionProvider {
|
|
return s.SessionProvider
|
|
}
|
|
|
|
func (s *Source) newBigQuerySessionProvider() BigQuerySessionProvider {
|
|
return func(ctx context.Context) (*Session, error) {
|
|
if s.WriteMode != WriteModeProtected {
|
|
return nil, nil
|
|
}
|
|
|
|
s.sessionMutex.Lock()
|
|
defer s.sessionMutex.Unlock()
|
|
|
|
logger, err := util.LoggerFromContext(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get logger from context: %w", err)
|
|
}
|
|
|
|
if s.Session != nil {
|
|
// Absolute 7-day lifetime check.
|
|
const sessionMaxLifetime = 7 * 24 * time.Hour
|
|
// This assumes a single task will not exceed 30 minutes, preventing it from failing mid-execution.
|
|
const refreshThreshold = 30 * time.Minute
|
|
if time.Since(s.Session.CreationTime) > (sessionMaxLifetime - refreshThreshold) {
|
|
logger.DebugContext(ctx, "Session is approaching its 7-day maximum lifetime. Creating a new one.")
|
|
} else {
|
|
job := &bigqueryrestapi.Job{
|
|
Configuration: &bigqueryrestapi.JobConfiguration{
|
|
DryRun: true,
|
|
Query: &bigqueryrestapi.JobConfigurationQuery{
|
|
Query: "SELECT 1",
|
|
UseLegacySql: new(bool),
|
|
ConnectionProperties: []*bigqueryrestapi.ConnectionProperty{{Key: "session_id", Value: s.Session.ID}},
|
|
},
|
|
},
|
|
}
|
|
_, err := s.RestService.Jobs.Insert(s.Project, job).Do()
|
|
if err == nil {
|
|
s.Session.LastUsed = time.Now()
|
|
return s.Session, nil
|
|
}
|
|
logger.DebugContext(ctx, "Session validation failed (likely expired), creating a new one.", "error", err)
|
|
}
|
|
}
|
|
|
|
// Create a new session if one doesn't exist, it has passed its 7-day lifetime,
|
|
// or it failed the validation dry run.
|
|
|
|
creationTime := time.Now()
|
|
job := &bigqueryrestapi.Job{
|
|
JobReference: &bigqueryrestapi.JobReference{
|
|
ProjectId: s.Project,
|
|
Location: s.Location,
|
|
},
|
|
Configuration: &bigqueryrestapi.JobConfiguration{
|
|
DryRun: true,
|
|
Query: &bigqueryrestapi.JobConfigurationQuery{
|
|
Query: "SELECT 1",
|
|
CreateSession: true,
|
|
},
|
|
},
|
|
}
|
|
|
|
createdJob, err := s.RestService.Jobs.Insert(s.Project, job).Do()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create new session: %w", err)
|
|
}
|
|
|
|
var sessionID, sessionDatasetID, projectID string
|
|
if createdJob.Status != nil && createdJob.Statistics.SessionInfo != nil {
|
|
sessionID = createdJob.Statistics.SessionInfo.SessionId
|
|
} else {
|
|
return nil, fmt.Errorf("failed to get session ID from new session job")
|
|
}
|
|
|
|
if createdJob.Configuration != nil && createdJob.Configuration.Query != nil && createdJob.Configuration.Query.DestinationTable != nil {
|
|
sessionDatasetID = createdJob.Configuration.Query.DestinationTable.DatasetId
|
|
projectID = createdJob.Configuration.Query.DestinationTable.ProjectId
|
|
} else {
|
|
return nil, fmt.Errorf("failed to get session dataset ID from new session job")
|
|
}
|
|
|
|
s.Session = &Session{
|
|
ID: sessionID,
|
|
ProjectID: projectID,
|
|
DatasetID: sessionDatasetID,
|
|
CreationTime: creationTime,
|
|
LastUsed: creationTime,
|
|
}
|
|
return s.Session, nil
|
|
}
|
|
}
|
|
|
|
func (s *Source) UseClientAuthorization() bool {
|
|
return s.UseClientOAuth
|
|
}
|
|
|
|
func (s *Source) BigQueryProject() string {
|
|
return s.Project
|
|
}
|
|
|
|
func (s *Source) BigQueryLocation() string {
|
|
return s.Location
|
|
}
|
|
|
|
func (s *Source) BigQueryTokenSource() oauth2.TokenSource {
|
|
return s.TokenSource
|
|
}
|
|
|
|
func (s *Source) BigQueryTokenSourceWithScope(ctx context.Context, scopes []string) (oauth2.TokenSource, error) {
|
|
if len(scopes) == 0 {
|
|
scopes = s.Scopes
|
|
if len(scopes) == 0 {
|
|
scopes = []string{CloudPlatformScope}
|
|
}
|
|
}
|
|
|
|
if s.ImpersonateServiceAccount != "" {
|
|
// Create impersonated credentials token source with the requested scopes
|
|
ts, err := impersonate.CredentialsTokenSource(ctx, impersonate.CredentialsConfig{
|
|
TargetPrincipal: s.ImpersonateServiceAccount,
|
|
Scopes: scopes,
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create impersonated credentials for %q with scopes %v: %w", s.ImpersonateServiceAccount, scopes, err)
|
|
}
|
|
return ts, nil
|
|
}
|
|
return google.DefaultTokenSource(ctx, scopes...)
|
|
}
|
|
|
|
func (s *Source) GetMaxQueryResultRows() int {
|
|
return s.MaxQueryResultRows
|
|
}
|
|
|
|
func (s *Source) BigQueryClientCreator() BigqueryClientCreator {
|
|
return s.ClientCreator
|
|
}
|
|
|
|
func (s *Source) BigQueryAllowedDatasets() []string {
|
|
if len(s.AllowedDatasets) == 0 {
|
|
return nil
|
|
}
|
|
datasets := make([]string, 0, len(s.AllowedDatasets))
|
|
for d := range s.AllowedDatasets {
|
|
datasets = append(datasets, d)
|
|
}
|
|
return datasets
|
|
}
|
|
|
|
// IsDatasetAllowed checks if a given dataset is accessible based on the source's configuration.
|
|
func (s *Source) IsDatasetAllowed(projectID, datasetID string) bool {
|
|
// If the normalized map is empty, it means no restrictions were configured.
|
|
if len(s.AllowedDatasets) == 0 {
|
|
return true
|
|
}
|
|
|
|
targetDataset := fmt.Sprintf("%s.%s", projectID, datasetID)
|
|
_, ok := s.AllowedDatasets[targetDataset]
|
|
return ok
|
|
}
|
|
|
|
func (s *Source) MakeDataplexCatalogClient() func() (*dataplexapi.CatalogClient, DataplexClientCreator, error) {
|
|
return s.makeDataplexCatalogClient
|
|
}
|
|
|
|
func (s *Source) lazyInitDataplexClient(ctx context.Context, tracer trace.Tracer) func() (*dataplexapi.CatalogClient, DataplexClientCreator, error) {
|
|
var once sync.Once
|
|
var client *dataplexapi.CatalogClient
|
|
var clientCreator DataplexClientCreator
|
|
var err error
|
|
|
|
return func() (*dataplexapi.CatalogClient, DataplexClientCreator, error) {
|
|
once.Do(func() {
|
|
c, cc, e := initDataplexConnection(ctx, tracer, s.Name, s.Project, s.UseClientOAuth, s.ImpersonateServiceAccount, s.Scopes)
|
|
if e != nil {
|
|
err = fmt.Errorf("failed to initialize dataplex client: %w", e)
|
|
return
|
|
}
|
|
client = c
|
|
|
|
// If using OAuth, wrap the provided client creator (cc) with caching logic
|
|
if s.UseClientOAuth && cc != nil {
|
|
clientCreator = func(tokenString string) (*dataplexapi.CatalogClient, error) {
|
|
// Check cache
|
|
if val, found := s.dataplexCache.Get(tokenString); found {
|
|
return val.(*dataplexapi.CatalogClient), nil
|
|
}
|
|
|
|
// Cache miss - call client creator
|
|
dpClient, err := cc(tokenString)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Set in cache
|
|
s.dataplexCache.Set(tokenString, dpClient)
|
|
return dpClient, nil
|
|
}
|
|
} else {
|
|
// Not using OAuth or no creator was returned
|
|
clientCreator = cc
|
|
}
|
|
})
|
|
return client, clientCreator, err
|
|
}
|
|
}
|
|
|
|
func (s *Source) RetrieveClientAndService(accessToken tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
|
|
bqClient := s.BigQueryClient()
|
|
restService := s.BigQueryRestService()
|
|
|
|
// Initialize new client if using user OAuth token
|
|
if s.UseClientAuthorization() {
|
|
tokenStr, err := accessToken.ParseBearerToken()
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("error parsing access token: %w", err)
|
|
}
|
|
bqClient, restService, err = s.BigQueryClientCreator()(tokenStr, true)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
|
}
|
|
}
|
|
return bqClient, restService, nil
|
|
}
|
|
|
|
func (s *Source) RunSQL(ctx context.Context, bqClient *bigqueryapi.Client, statement, statementType string, params []bigqueryapi.QueryParameter, connProps []*bigqueryapi.ConnectionProperty) (any, error) {
|
|
query := bqClient.Query(statement)
|
|
query.Location = bqClient.Location
|
|
if params != nil {
|
|
query.Parameters = params
|
|
}
|
|
if connProps != nil {
|
|
query.ConnectionProperties = connProps
|
|
}
|
|
|
|
// This block handles SELECT statements, which return a row set.
|
|
// We iterate through the results, convert each row into a map of
|
|
// column names to values, and return the collection of rows.
|
|
job, err := query.Run(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to execute query: %w", err)
|
|
}
|
|
it, err := job.Read(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to read query results: %w", err)
|
|
}
|
|
|
|
var out []any
|
|
for {
|
|
var val []bigqueryapi.Value
|
|
err = it.Next(&val)
|
|
if err == iterator.Done {
|
|
break
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to iterate through query results: %w", err)
|
|
}
|
|
schema := it.Schema
|
|
row := orderedmap.Row{}
|
|
for i, field := range schema {
|
|
row.Add(field.Name, NormalizeValue(val[i]))
|
|
}
|
|
out = append(out, row)
|
|
}
|
|
// If the query returned any rows, return them directly.
|
|
if len(out) > 0 {
|
|
return out, nil
|
|
}
|
|
|
|
// This handles the standard case for a SELECT query that successfully
|
|
// executes but returns zero rows.
|
|
if statementType == "SELECT" {
|
|
return "The query returned 0 rows.", nil
|
|
}
|
|
// This is the fallback for a successful query that doesn't return content.
|
|
// In most cases, this will be for DML/DDL statements like INSERT, UPDATE, CREATE, etc.
|
|
// However, it is also possible that this was a query that was expected to return rows
|
|
// but returned none, a case that we cannot distinguish here.
|
|
return "Query executed successfully and returned no content.", nil
|
|
}
|
|
|
|
// NormalizeValue converts BigQuery specific types to standard JSON-compatible types.
|
|
// Specifically, it handles *big.Rat (used for NUMERIC/BIGNUMERIC) by converting
|
|
// them to decimal strings with up to 38 digits of precision, trimming trailing zeros.
|
|
// It recursively handles slices (arrays) and maps (structs) using reflection.
|
|
func NormalizeValue(v any) any {
|
|
if v == nil {
|
|
return nil
|
|
}
|
|
|
|
// Handle *big.Rat specifically.
|
|
if rat, ok := v.(*big.Rat); ok {
|
|
// Convert big.Rat to a decimal string.
|
|
// Use a precision of 38 digits (enough for BIGNUMERIC and NUMERIC)
|
|
// and trim trailing zeros to match BigQuery's behavior.
|
|
s := rat.FloatString(38)
|
|
if strings.Contains(s, ".") {
|
|
s = strings.TrimRight(s, "0")
|
|
s = strings.TrimRight(s, ".")
|
|
}
|
|
return s
|
|
}
|
|
|
|
// Use reflection for slices and maps to handle various underlying types.
|
|
rv := reflect.ValueOf(v)
|
|
switch rv.Kind() {
|
|
case reflect.Slice, reflect.Array:
|
|
// Preserve []byte as is, so json.Marshal encodes it as Base64 string (BigQuery BYTES behavior).
|
|
if rv.Type().Elem().Kind() == reflect.Uint8 {
|
|
return v
|
|
}
|
|
newSlice := make([]any, rv.Len())
|
|
for i := 0; i < rv.Len(); i++ {
|
|
newSlice[i] = NormalizeValue(rv.Index(i).Interface())
|
|
}
|
|
return newSlice
|
|
case reflect.Map:
|
|
// Ensure keys are strings to produce a JSON-compatible map.
|
|
if rv.Type().Key().Kind() != reflect.String {
|
|
return v
|
|
}
|
|
newMap := make(map[string]any, rv.Len())
|
|
iter := rv.MapRange()
|
|
for iter.Next() {
|
|
newMap[iter.Key().String()] = NormalizeValue(iter.Value().Interface())
|
|
}
|
|
return newMap
|
|
}
|
|
|
|
return v
|
|
}
|
|
|
|
func initBigQueryConnection(
|
|
ctx context.Context,
|
|
tracer trace.Tracer,
|
|
name string,
|
|
project string,
|
|
location string,
|
|
impersonateServiceAccount string,
|
|
scopes []string,
|
|
) (*bigqueryapi.Client, *bigqueryrestapi.Service, oauth2.TokenSource, error) {
|
|
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
|
defer span.End()
|
|
|
|
userAgent, err := util.UserAgentFromContext(ctx)
|
|
if err != nil {
|
|
return nil, nil, nil, err
|
|
}
|
|
|
|
var tokenSource oauth2.TokenSource
|
|
var opts []option.ClientOption
|
|
|
|
var credScopes []string
|
|
if len(scopes) > 0 {
|
|
credScopes = scopes
|
|
} else if impersonateServiceAccount != "" {
|
|
credScopes = []string{CloudPlatformScope}
|
|
} else {
|
|
credScopes = []string{bigqueryapi.Scope}
|
|
}
|
|
|
|
if impersonateServiceAccount != "" {
|
|
// Create impersonated credentials token source
|
|
// This broader scope is needed for tools like conversational analytics
|
|
cloudPlatformTokenSource, err := impersonate.CredentialsTokenSource(ctx, impersonate.CredentialsConfig{
|
|
TargetPrincipal: impersonateServiceAccount,
|
|
Scopes: credScopes,
|
|
})
|
|
if err != nil {
|
|
return nil, nil, nil, fmt.Errorf("failed to create impersonated credentials for %q: %w", impersonateServiceAccount, err)
|
|
}
|
|
tokenSource = cloudPlatformTokenSource
|
|
opts = []option.ClientOption{
|
|
option.WithUserAgent(userAgent),
|
|
option.WithTokenSource(cloudPlatformTokenSource),
|
|
}
|
|
} else {
|
|
// Use default credentials
|
|
cred, err := google.FindDefaultCredentials(ctx, credScopes...)
|
|
if err != nil {
|
|
return nil, nil, nil, fmt.Errorf("failed to find default Google Cloud credentials with scopes %v: %w", credScopes, err)
|
|
}
|
|
tokenSource = cred.TokenSource
|
|
opts = []option.ClientOption{
|
|
option.WithUserAgent(userAgent),
|
|
option.WithCredentials(cred),
|
|
}
|
|
}
|
|
|
|
// Initialize the high-level BigQuery client
|
|
client, err := bigqueryapi.NewClient(ctx, project, opts...)
|
|
if err != nil {
|
|
return nil, nil, nil, fmt.Errorf("failed to create BigQuery client for project %q: %w", project, err)
|
|
}
|
|
client.Location = location
|
|
|
|
// Initialize the low-level BigQuery REST service using the same credentials
|
|
restService, err := bigqueryrestapi.NewService(ctx, opts...)
|
|
if err != nil {
|
|
return nil, nil, nil, fmt.Errorf("failed to create BigQuery v2 service: %w", err)
|
|
}
|
|
|
|
return client, restService, tokenSource, nil
|
|
}
|
|
|
|
// initBigQueryConnectionWithOAuthToken initialize a BigQuery client with an
|
|
// OAuth access token.
|
|
func initBigQueryConnectionWithOAuthToken(
|
|
ctx context.Context,
|
|
tracer trace.Tracer,
|
|
project string,
|
|
location string,
|
|
name string,
|
|
userAgent string,
|
|
tokenString string,
|
|
wantRestService bool,
|
|
) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
|
|
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
|
defer span.End()
|
|
// Construct token source
|
|
token := &oauth2.Token{
|
|
AccessToken: string(tokenString),
|
|
}
|
|
ts := oauth2.StaticTokenSource(token)
|
|
|
|
// Initialize the BigQuery client with tokenSource
|
|
client, err := bigqueryapi.NewClient(ctx, project, option.WithUserAgent(userAgent), option.WithTokenSource(ts))
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to create BigQuery client for project %q: %w", project, err)
|
|
}
|
|
client.Location = location
|
|
|
|
if wantRestService {
|
|
// Initialize the low-level BigQuery REST service using the same credentials
|
|
restService, err := bigqueryrestapi.NewService(ctx, option.WithUserAgent(userAgent), option.WithTokenSource(ts))
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to create BigQuery v2 service: %w", err)
|
|
}
|
|
return client, restService, nil
|
|
}
|
|
|
|
return client, nil, nil
|
|
}
|
|
|
|
// newBigQueryClientCreator sets the project parameters for the init helper
|
|
// function. The returned function takes in an OAuth access token and uses it to
|
|
// create a BQ client.
|
|
func newBigQueryClientCreator(
|
|
ctx context.Context,
|
|
tracer trace.Tracer,
|
|
project string,
|
|
location string,
|
|
name string,
|
|
) (func(string, bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error), error) {
|
|
userAgent, err := util.UserAgentFromContext(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return func(tokenString string, wantRestService bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
|
|
return initBigQueryConnectionWithOAuthToken(ctx, tracer, project, location, name, userAgent, tokenString, wantRestService)
|
|
}, nil
|
|
}
|
|
|
|
func initDataplexConnection(
|
|
ctx context.Context,
|
|
tracer trace.Tracer,
|
|
name string,
|
|
project string,
|
|
useClientOAuth bool,
|
|
impersonateServiceAccount string,
|
|
scopes []string,
|
|
) (*dataplexapi.CatalogClient, DataplexClientCreator, error) {
|
|
var client *dataplexapi.CatalogClient
|
|
var clientCreator DataplexClientCreator
|
|
var err error
|
|
|
|
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
|
defer span.End()
|
|
|
|
userAgent, err := util.UserAgentFromContext(ctx)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
if useClientOAuth {
|
|
clientCreator = newDataplexClientCreator(ctx, project, userAgent)
|
|
} else {
|
|
var opts []option.ClientOption
|
|
|
|
credScopes := scopes
|
|
if len(credScopes) == 0 {
|
|
credScopes = []string{CloudPlatformScope}
|
|
}
|
|
|
|
if impersonateServiceAccount != "" {
|
|
// Create impersonated credentials token source
|
|
ts, err := impersonate.CredentialsTokenSource(ctx, impersonate.CredentialsConfig{
|
|
TargetPrincipal: impersonateServiceAccount,
|
|
Scopes: credScopes,
|
|
})
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to create impersonated credentials for %q: %w", impersonateServiceAccount, err)
|
|
}
|
|
opts = []option.ClientOption{
|
|
option.WithUserAgent(userAgent),
|
|
option.WithTokenSource(ts),
|
|
}
|
|
} else {
|
|
// Use default credentials
|
|
cred, err := google.FindDefaultCredentials(ctx, credScopes...)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to find default Google Cloud credentials: %w", err)
|
|
}
|
|
opts = []option.ClientOption{
|
|
option.WithUserAgent(userAgent),
|
|
option.WithCredentials(cred),
|
|
}
|
|
}
|
|
|
|
client, err = dataplexapi.NewCatalogClient(ctx, opts...)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to create Dataplex client for project %q: %w", project, err)
|
|
}
|
|
}
|
|
|
|
return client, clientCreator, nil
|
|
}
|
|
|
|
func initDataplexConnectionWithOAuthToken(
|
|
ctx context.Context,
|
|
project string,
|
|
userAgent string,
|
|
tokenString string,
|
|
) (*dataplexapi.CatalogClient, error) {
|
|
// Construct token source
|
|
token := &oauth2.Token{
|
|
AccessToken: string(tokenString),
|
|
}
|
|
ts := oauth2.StaticTokenSource(token)
|
|
|
|
client, err := dataplexapi.NewCatalogClient(ctx, option.WithUserAgent(userAgent), option.WithTokenSource(ts))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create Dataplex client for project %q: %w", project, err)
|
|
}
|
|
return client, nil
|
|
}
|
|
|
|
func newDataplexClientCreator(
|
|
ctx context.Context,
|
|
project string,
|
|
userAgent string,
|
|
) func(string) (*dataplexapi.CatalogClient, error) {
|
|
return func(tokenString string) (*dataplexapi.CatalogClient, error) {
|
|
return initDataplexConnectionWithOAuthToken(ctx, project, userAgent, tokenString)
|
|
}
|
|
}
|