fix(bigquery)!: Add Bearer parsing to auth token (#1386)

Previously we propagate tokens directly to the BQ API. But MCP inspector
adds a "Bearer" prefix to all authorization header. We will need to
parse the token accordingly to make it work.
This commit is contained in:
Wenxin Du
2025-09-09 15:47:52 -04:00
committed by GitHub
parent cce602f280
commit b5f9780a59
12 changed files with 63 additions and 16 deletions

View File

@@ -21,7 +21,6 @@ import (
bigqueryapi "cloud.google.com/go/bigquery"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/trace"
"golang.org/x/oauth2"
@@ -35,7 +34,7 @@ const SourceKind string = "bigquery"
// validate interface
var _ sources.SourceConfig = Config{}
type BigqueryClientCreator func(tokenString tools.AccessToken, wantRestService bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
type BigqueryClientCreator func(tokenString string, wantRestService bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
func init() {
if !sources.Register(SourceKind, newConfig) {
@@ -199,7 +198,7 @@ func initBigQueryConnectionWithOAuthToken(
location string,
name string,
userAgent string,
tokenString tools.AccessToken,
tokenString string,
wantRestService bool,
) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
@@ -238,13 +237,13 @@ func newBigQueryClientCreator(
project string,
location string,
name string,
) (func(tools.AccessToken, bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error), error) {
) (func(string, bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error), error) {
userAgent, err := util.UserAgentFromContext(ctx)
if err != nil {
return nil, err
}
return func(tokenString tools.AccessToken, wantRestService bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
return func(tokenString string, wantRestService bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
return initBigQueryConnectionWithOAuthToken(ctx, tracer, project, location, name, userAgent, tokenString, wantRestService)
}, nil
}

View File

@@ -183,14 +183,18 @@ type Tool struct {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
var tokenStr string
var err error
// Get credentials for the API call
if t.UseClientOAuth {
// Use client-side access token
if accessToken == "" {
return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header")
return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header: %w", tools.ErrUnauthorized)
}
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
tokenStr = string(accessToken)
} else {
// Use ADC
if t.TokenSource == nil {

View File

@@ -150,7 +150,11 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
var err error
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
bqClient, restService, err = t.ClientCreator(accessToken, true)
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, restService, err = t.ClientCreator(tokenStr, true)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}

View File

@@ -199,7 +199,11 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
bqClient, _, err = t.ClientCreator(accessToken, false)
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}

View File

@@ -142,7 +142,11 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
bqClient, _, err = t.ClientCreator(accessToken, false)
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}

View File

@@ -149,7 +149,11 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
var err error
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
bqClient, _, err = t.ClientCreator(accessToken, false)
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}

View File

@@ -133,10 +133,13 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
}
bqClient := t.Client
var err error
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
bqClient, _, err = t.ClientCreator(accessToken, false)
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}

View File

@@ -139,10 +139,13 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
}
bqClient := t.Client
var err error
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
bqClient, _, err = t.ClientCreator(accessToken, false)
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}

View File

@@ -222,7 +222,11 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
bqClient, restService, err = t.ClientCreator(accessToken, true)
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, restService, err = t.ClientCreator(tokenStr, true)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}

View File

@@ -19,6 +19,7 @@ import (
"errors"
"fmt"
"slices"
"strings"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -66,6 +67,14 @@ type ToolConfig interface {
type AccessToken string
func (token AccessToken) ParseBearerToken() (string, error) {
headerParts := strings.Split(string(token), " ")
if len(headerParts) != 2 || strings.ToLower(headerParts[0]) != "bearer" {
return "", fmt.Errorf("authorization header must be in the format 'Bearer <token>': %w", ErrUnauthorized)
}
return headerParts[1], nil
}
type Tool interface {
Invoke(context.Context, ParamValues, AccessToken) (any, error)
ParseParams(map[string]any, map[string]map[string]any) (ParamValues, error)

View File

@@ -542,6 +542,7 @@ func runBigQueryExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamW
if err != nil {
t.Fatalf("error getting access token from ADC: %s", err)
}
accessToken = "Bearer " + accessToken
// Test tool invoke endpoint
invokeTcs := []struct {
@@ -824,6 +825,7 @@ func runBigQueryForecastToolInvokeTest(t *testing.T, tableName string) {
if err != nil {
t.Fatalf("error getting access token from ADC: %s", err)
}
accessToken = "Bearer " + accessToken
historyDataTable := strings.ReplaceAll(tableName, "`", "")
historyDataQuery := fmt.Sprintf("SELECT ts, data, id FROM %s", tableName)
@@ -1040,6 +1042,7 @@ func runBigQueryListDatasetToolInvokeTest(t *testing.T, datasetWant string) {
if err != nil {
t.Fatalf("error getting access token from ADC: %s", err)
}
accessToken = "Bearer " + accessToken
// Test tool invoke endpoint
invokeTcs := []struct {
@@ -1161,6 +1164,7 @@ func runBigQueryGetDatasetInfoToolInvokeTest(t *testing.T, datasetName, datasetI
if err != nil {
t.Fatalf("error getting access token from ADC: %s", err)
}
accessToken = "Bearer " + accessToken
// Test tool invoke endpoint
invokeTcs := []struct {
@@ -1310,6 +1314,7 @@ func runBigQueryListTableIdsToolInvokeTest(t *testing.T, datasetName, tablename_
if err != nil {
t.Fatalf("error getting access token from ADC: %s", err)
}
accessToken = "Bearer " + accessToken
// Test tool invoke endpoint
invokeTcs := []struct {
@@ -1459,6 +1464,7 @@ func runBigQueryGetTableInfoToolInvokeTest(t *testing.T, datasetName, tableName,
if err != nil {
t.Fatalf("error getting access token from ADC: %s", err)
}
accessToken = "Bearer " + accessToken
// Test tool invoke endpoint
invokeTcs := []struct {
@@ -1608,6 +1614,7 @@ func runBigQueryConversationalAnalyticsInvokeTest(t *testing.T, datasetName, tab
if err != nil {
t.Fatalf("error getting access token from ADC: %s", err)
}
accessToken = "Bearer " + accessToken
tableRefsJSON := fmt.Sprintf(`[{"projectId":"%s","datasetId":"%s","tableId":"%s"}]`, BigqueryProject, datasetName, tableName)

View File

@@ -273,6 +273,7 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
if err != nil {
t.Fatalf("error getting access token from ADC: %s", err)
}
accessToken = "Bearer " + accessToken
// Test tool invoke endpoint
invokeTcs := []struct {
@@ -841,6 +842,7 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant, select1Want string, opti
if err != nil {
t.Fatalf("error getting access token from ADC: %s", err)
}
accessToken = "Bearer " + accessToken
idToken, err := GetGoogleIdToken(ClientId)
if err != nil {