fix(bigquery): Move useClientOAuth config from tool to source (#1279)

This commit is contained in:
Wenxin Du
2025-08-29 13:47:00 -04:00
committed by GitHub
parent 89af3a4ca3
commit 8d20a48f13
4 changed files with 102 additions and 37 deletions

View File

@@ -53,10 +53,11 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
type Config struct {
// BigQuery configs
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
Location string `yaml:"location"`
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
Location string `yaml:"location"`
UseClientOAuth bool `yaml:"useClientOAuth"`
}
func (r Config) SourceConfigKind() string {
@@ -65,10 +66,23 @@ func (r Config) SourceConfigKind() string {
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
// Initializes a BigQuery Google SQL source
client, restService, tokenSource, clientCreator, err := initBigQueryConnection(ctx, tracer, r.Name, r.Project, r.Location)
if err != nil {
return nil, err
var client *bigqueryapi.Client
var restService *bigqueryrestapi.Service
var tokenSource oauth2.TokenSource
var clientCreator BigqueryClientCreator
var err error
if r.UseClientOAuth {
clientCreator, err = newBigQueryClientCreator(ctx, tracer, r.Project, r.Location, r.Name)
if err != nil {
return nil, fmt.Errorf("error constructing client creator: %w", err)
}
} else {
// Initializes a BigQuery Google SQL source
client, restService, tokenSource, err = initBigQueryConnection(ctx, tracer, r.Name, r.Project, r.Location)
if err != nil {
return nil, fmt.Errorf("error creating client from ADC: %w", err)
}
}
s := &Source{
@@ -79,6 +93,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
TokenSource: tokenSource,
MaxQueryResultRows: 50,
ClientCreator: clientCreator,
UseClientOAuth: r.UseClientOAuth,
}
return s, nil
@@ -95,6 +110,7 @@ type Source struct {
TokenSource oauth2.TokenSource
MaxQueryResultRows int
ClientCreator BigqueryClientCreator
UseClientOAuth bool
}
func (s *Source) SourceKind() string {
@@ -110,6 +126,10 @@ func (s *Source) BigQueryRestService() *bigqueryrestapi.Service {
return s.RestService
}
func (s *Source) UseClientAuthorization() bool {
return s.UseClientOAuth
}
func (s *Source) BigQueryTokenSource() oauth2.TokenSource {
return s.TokenSource
}
@@ -128,46 +148,49 @@ func initBigQueryConnection(
name string,
project string,
location string,
) (*bigqueryapi.Client, *bigqueryrestapi.Service, oauth2.TokenSource, BigqueryClientCreator, error) {
) (*bigqueryapi.Client, *bigqueryrestapi.Service, oauth2.TokenSource, error) {
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
cred, err := google.FindDefaultCredentials(ctx, bigqueryapi.Scope)
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("failed to find default Google Cloud credentials with scope %q: %w", bigqueryapi.Scope, err)
return nil, nil, nil, fmt.Errorf("failed to find default Google Cloud credentials with scope %q: %w", bigqueryapi.Scope, err)
}
userAgent, err := util.UserAgentFromContext(ctx)
if err != nil {
return nil, nil, nil, nil, err
return nil, nil, nil, err
}
// Initialize the high-level BigQuery client
client, err := bigqueryapi.NewClient(ctx, project, option.WithUserAgent(userAgent), option.WithCredentials(cred))
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("failed to create BigQuery client for project %q: %w", project, err)
return nil, nil, nil, fmt.Errorf("failed to create BigQuery client for project %q: %w", project, err)
}
client.Location = location
// Initialize the low-level BigQuery REST service using the same credentials
restService, err := bigqueryrestapi.NewService(ctx, option.WithUserAgent(userAgent), option.WithCredentials(cred))
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("failed to create BigQuery v2 service: %w", err)
return nil, nil, nil, fmt.Errorf("failed to create BigQuery v2 service: %w", err)
}
clientCreator := newBigQueryClientCreator(ctx, project, location, userAgent)
return client, restService, cred.TokenSource, clientCreator, nil
return client, restService, cred.TokenSource, nil
}
// initBigQueryConnectionWithOAuthToken initialize a BigQuery client with an
// OAuth access token.
func initBigQueryConnectionWithOAuthToken(
ctx context.Context,
tracer trace.Tracer,
project string,
location string,
name string,
userAgent string,
tokenString tools.AccessToken,
) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// Construct token source
token := &oauth2.Token{
AccessToken: string(tokenString),
@@ -195,11 +218,17 @@ func initBigQueryConnectionWithOAuthToken(
// create a BQ client.
func newBigQueryClientCreator(
ctx context.Context,
tracer trace.Tracer,
project string,
location string,
userAgent string,
) func(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
return func(tokenString tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
return initBigQueryConnectionWithOAuthToken(ctx, project, location, userAgent, tokenString)
name string,
) (func(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error), error) {
userAgent, err := util.UserAgentFromContext(ctx)
if err != nil {
return nil, err
}
return func(tokenString tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
return initBigQueryConnectionWithOAuthToken(ctx, tracer, project, location, name, userAgent, tokenString)
}, nil
}

View File

@@ -41,10 +41,31 @@ func TestParseFromYamlBigQuery(t *testing.T) {
`,
want: server.SourceConfigs{
"my-instance": bigquery.Config{
Name: "my-instance",
Kind: bigquery.SourceKind,
Project: "my-project",
Location: "us",
Name: "my-instance",
Kind: bigquery.SourceKind,
Project: "my-project",
Location: "us",
UseClientOAuth: false,
},
},
},
{
desc: "use client auth example",
in: `
sources:
my-instance:
kind: bigquery
project: my-project
location: us
useClientOAuth: true
`,
want: server.SourceConfigs{
"my-instance": bigquery.Config{
Name: "my-instance",
Kind: bigquery.SourceKind,
Project: "my-project",
Location: "us",
UseClientOAuth: true,
},
},
},

View File

@@ -50,6 +50,7 @@ type compatibleSource interface {
BigQueryClient() *bigqueryapi.Client
BigQueryRestService() *bigqueryrestapi.Service
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool
}
// validate compatible sources are still compatible
@@ -105,17 +106,17 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
Name: cfg.Name,
Kind: kind,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: cfg.UseClientOAuth,
Parameters: cfg.Parameters,
TemplateParameters: cfg.TemplateParameters,
AllParams: allParameters,
Statement: cfg.Statement,
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
ClientCreator: s.BigQueryClientCreator(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Statement: cfg.Statement,
UseClientOAuth: s.UseClientAuthorization(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
ClientCreator: s.BigQueryClientCreator(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -231,9 +232,9 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
query = bqClient.Query(newStatement)
query.Parameters = highLevelParams
query.Location = t.Client.Location
query.Location = bqClient.Location
dryRunJob, err := dryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, newStatement, lowLevelParams, query.ConnectionProperties)
dryRunJob, err := dryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, newStatement, lowLevelParams, query.ConnectionProperties)
if err != nil {
// This is a fallback check in case the switch logic was bypassed.
return nil, fmt.Errorf("final query validation failed: %w", err)

View File

@@ -133,6 +133,7 @@ func TestBigQueryToolEndpoints(t *testing.T) {
// Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, BigqueryToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
toolsFile = addClientAuthSourceConfig(t, toolsFile)
toolsFile = addBigQuerySqlToolConfig(t, toolsFile, dataTypeToolStmt, arrayDataTypeToolStmt)
toolsFile = addBigQueryPrebuiltToolsConfig(t, toolsFile)
tmplSelectCombined, tmplSelectFilterCombined := getBigQueryTmplToolStatement()
@@ -440,6 +441,20 @@ func addBigQueryPrebuiltToolsConfig(t *testing.T, config map[string]any) map[str
return config
}
func addClientAuthSourceConfig(t *testing.T, config map[string]any) map[string]any {
sources, ok := config["sources"].(map[string]any)
if !ok {
t.Fatalf("unable to get sources from config")
}
sources["my-client-auth-source"] = map[string]any{
"kind": BigquerySourceKind,
"project": BigqueryProject,
"useClientOAuth": true,
}
config["sources"] = sources
return config
}
func addBigQuerySqlToolConfig(t *testing.T, config map[string]any, toolStatement, arrayToolStatement string) map[string]any {
tools, ok := config["tools"].(map[string]any)
if !ok {
@@ -470,11 +485,10 @@ func addBigQuerySqlToolConfig(t *testing.T, config map[string]any, toolStatement
},
}
tools["my-client-auth-tool"] = map[string]any{
"kind": "bigquery-sql",
"source": "my-instance",
"description": "Tool to test client authorization.",
"useClientOAuth": true,
"statement": "SELECT 1",
"kind": "bigquery-sql",
"source": "my-client-auth-source",
"description": "Tool to test client authorization.",
"statement": "SELECT 1",
}
config["tools"] = tools
return config