mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 15:38:08 -05:00
fix(bigquery): Move useClientOAuth config from tool to source (#1279)
This commit is contained in:
@@ -53,10 +53,11 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
// BigQuery configs
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
Location string `yaml:"location"`
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
Location string `yaml:"location"`
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
@@ -65,10 +66,23 @@ func (r Config) SourceConfigKind() string {
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
// Initializes a BigQuery Google SQL source
|
||||
client, restService, tokenSource, clientCreator, err := initBigQueryConnection(ctx, tracer, r.Name, r.Project, r.Location)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
var client *bigqueryapi.Client
|
||||
var restService *bigqueryrestapi.Service
|
||||
var tokenSource oauth2.TokenSource
|
||||
var clientCreator BigqueryClientCreator
|
||||
var err error
|
||||
|
||||
if r.UseClientOAuth {
|
||||
clientCreator, err = newBigQueryClientCreator(ctx, tracer, r.Project, r.Location, r.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error constructing client creator: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Initializes a BigQuery Google SQL source
|
||||
client, restService, tokenSource, err = initBigQueryConnection(ctx, tracer, r.Name, r.Project, r.Location)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from ADC: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
s := &Source{
|
||||
@@ -79,6 +93,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
TokenSource: tokenSource,
|
||||
MaxQueryResultRows: 50,
|
||||
ClientCreator: clientCreator,
|
||||
UseClientOAuth: r.UseClientOAuth,
|
||||
}
|
||||
return s, nil
|
||||
|
||||
@@ -95,6 +110,7 @@ type Source struct {
|
||||
TokenSource oauth2.TokenSource
|
||||
MaxQueryResultRows int
|
||||
ClientCreator BigqueryClientCreator
|
||||
UseClientOAuth bool
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
@@ -110,6 +126,10 @@ func (s *Source) BigQueryRestService() *bigqueryrestapi.Service {
|
||||
return s.RestService
|
||||
}
|
||||
|
||||
func (s *Source) UseClientAuthorization() bool {
|
||||
return s.UseClientOAuth
|
||||
}
|
||||
|
||||
func (s *Source) BigQueryTokenSource() oauth2.TokenSource {
|
||||
return s.TokenSource
|
||||
}
|
||||
@@ -128,46 +148,49 @@ func initBigQueryConnection(
|
||||
name string,
|
||||
project string,
|
||||
location string,
|
||||
) (*bigqueryapi.Client, *bigqueryrestapi.Service, oauth2.TokenSource, BigqueryClientCreator, error) {
|
||||
) (*bigqueryapi.Client, *bigqueryrestapi.Service, oauth2.TokenSource, error) {
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
cred, err := google.FindDefaultCredentials(ctx, bigqueryapi.Scope)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, fmt.Errorf("failed to find default Google Cloud credentials with scope %q: %w", bigqueryapi.Scope, err)
|
||||
return nil, nil, nil, fmt.Errorf("failed to find default Google Cloud credentials with scope %q: %w", bigqueryapi.Scope, err)
|
||||
}
|
||||
|
||||
userAgent, err := util.UserAgentFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
// Initialize the high-level BigQuery client
|
||||
client, err := bigqueryapi.NewClient(ctx, project, option.WithUserAgent(userAgent), option.WithCredentials(cred))
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, fmt.Errorf("failed to create BigQuery client for project %q: %w", project, err)
|
||||
return nil, nil, nil, fmt.Errorf("failed to create BigQuery client for project %q: %w", project, err)
|
||||
}
|
||||
client.Location = location
|
||||
|
||||
// Initialize the low-level BigQuery REST service using the same credentials
|
||||
restService, err := bigqueryrestapi.NewService(ctx, option.WithUserAgent(userAgent), option.WithCredentials(cred))
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, fmt.Errorf("failed to create BigQuery v2 service: %w", err)
|
||||
return nil, nil, nil, fmt.Errorf("failed to create BigQuery v2 service: %w", err)
|
||||
}
|
||||
|
||||
clientCreator := newBigQueryClientCreator(ctx, project, location, userAgent)
|
||||
return client, restService, cred.TokenSource, clientCreator, nil
|
||||
return client, restService, cred.TokenSource, nil
|
||||
}
|
||||
|
||||
// initBigQueryConnectionWithOAuthToken initialize a BigQuery client with an
|
||||
// OAuth access token.
|
||||
func initBigQueryConnectionWithOAuthToken(
|
||||
ctx context.Context,
|
||||
tracer trace.Tracer,
|
||||
project string,
|
||||
location string,
|
||||
name string,
|
||||
userAgent string,
|
||||
tokenString tools.AccessToken,
|
||||
) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
// Construct token source
|
||||
token := &oauth2.Token{
|
||||
AccessToken: string(tokenString),
|
||||
@@ -195,11 +218,17 @@ func initBigQueryConnectionWithOAuthToken(
|
||||
// create a BQ client.
|
||||
func newBigQueryClientCreator(
|
||||
ctx context.Context,
|
||||
tracer trace.Tracer,
|
||||
project string,
|
||||
location string,
|
||||
userAgent string,
|
||||
) func(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
|
||||
return func(tokenString tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
|
||||
return initBigQueryConnectionWithOAuthToken(ctx, project, location, userAgent, tokenString)
|
||||
name string,
|
||||
) (func(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error), error) {
|
||||
userAgent, err := util.UserAgentFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return func(tokenString tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
|
||||
return initBigQueryConnectionWithOAuthToken(ctx, tracer, project, location, name, userAgent, tokenString)
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -41,10 +41,31 @@ func TestParseFromYamlBigQuery(t *testing.T) {
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": bigquery.Config{
|
||||
Name: "my-instance",
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "my-project",
|
||||
Location: "us",
|
||||
Name: "my-instance",
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "my-project",
|
||||
Location: "us",
|
||||
UseClientOAuth: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "use client auth example",
|
||||
in: `
|
||||
sources:
|
||||
my-instance:
|
||||
kind: bigquery
|
||||
project: my-project
|
||||
location: us
|
||||
useClientOAuth: true
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": bigquery.Config{
|
||||
Name: "my-instance",
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "my-project",
|
||||
Location: "us",
|
||||
UseClientOAuth: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -50,6 +50,7 @@ type compatibleSource interface {
|
||||
BigQueryClient() *bigqueryapi.Client
|
||||
BigQueryRestService() *bigqueryrestapi.Service
|
||||
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
@@ -105,17 +106,17 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
UseClientOAuth: cfg.UseClientOAuth,
|
||||
Parameters: cfg.Parameters,
|
||||
TemplateParameters: cfg.TemplateParameters,
|
||||
AllParams: allParameters,
|
||||
|
||||
Statement: cfg.Statement,
|
||||
Client: s.BigQueryClient(),
|
||||
RestService: s.BigQueryRestService(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Statement: cfg.Statement,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
Client: s.BigQueryClient(),
|
||||
RestService: s.BigQueryRestService(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -231,9 +232,9 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
|
||||
query = bqClient.Query(newStatement)
|
||||
query.Parameters = highLevelParams
|
||||
query.Location = t.Client.Location
|
||||
query.Location = bqClient.Location
|
||||
|
||||
dryRunJob, err := dryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, newStatement, lowLevelParams, query.ConnectionProperties)
|
||||
dryRunJob, err := dryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, newStatement, lowLevelParams, query.ConnectionProperties)
|
||||
if err != nil {
|
||||
// This is a fallback check in case the switch logic was bypassed.
|
||||
return nil, fmt.Errorf("final query validation failed: %w", err)
|
||||
|
||||
@@ -133,6 +133,7 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, BigqueryToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
||||
toolsFile = addClientAuthSourceConfig(t, toolsFile)
|
||||
toolsFile = addBigQuerySqlToolConfig(t, toolsFile, dataTypeToolStmt, arrayDataTypeToolStmt)
|
||||
toolsFile = addBigQueryPrebuiltToolsConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := getBigQueryTmplToolStatement()
|
||||
@@ -440,6 +441,20 @@ func addBigQueryPrebuiltToolsConfig(t *testing.T, config map[string]any) map[str
|
||||
return config
|
||||
}
|
||||
|
||||
func addClientAuthSourceConfig(t *testing.T, config map[string]any) map[string]any {
|
||||
sources, ok := config["sources"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("unable to get sources from config")
|
||||
}
|
||||
sources["my-client-auth-source"] = map[string]any{
|
||||
"kind": BigquerySourceKind,
|
||||
"project": BigqueryProject,
|
||||
"useClientOAuth": true,
|
||||
}
|
||||
config["sources"] = sources
|
||||
return config
|
||||
}
|
||||
|
||||
func addBigQuerySqlToolConfig(t *testing.T, config map[string]any, toolStatement, arrayToolStatement string) map[string]any {
|
||||
tools, ok := config["tools"].(map[string]any)
|
||||
if !ok {
|
||||
@@ -470,11 +485,10 @@ func addBigQuerySqlToolConfig(t *testing.T, config map[string]any, toolStatement
|
||||
},
|
||||
}
|
||||
tools["my-client-auth-tool"] = map[string]any{
|
||||
"kind": "bigquery-sql",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test client authorization.",
|
||||
"useClientOAuth": true,
|
||||
"statement": "SELECT 1",
|
||||
"kind": "bigquery-sql",
|
||||
"source": "my-client-auth-source",
|
||||
"description": "Tool to test client authorization.",
|
||||
"statement": "SELECT 1",
|
||||
}
|
||||
config["tools"] = tools
|
||||
return config
|
||||
|
||||
Reference in New Issue
Block a user