mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-11 08:28:11 -05:00
fix(bigquery)!: Add Bearer parsing to auth token (#1386)
Previously we propagate tokens directly to the BQ API. But MCP inspector adds a "Bearer" prefix to all authorization header. We will need to parse the token accordingly to make it work.
This commit is contained in:
@@ -21,7 +21,6 @@ import (
|
||||
bigqueryapi "cloud.google.com/go/bigquery"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/oauth2"
|
||||
@@ -35,7 +34,7 @@ const SourceKind string = "bigquery"
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
type BigqueryClientCreator func(tokenString tools.AccessToken, wantRestService bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
|
||||
type BigqueryClientCreator func(tokenString string, wantRestService bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
@@ -199,7 +198,7 @@ func initBigQueryConnectionWithOAuthToken(
|
||||
location string,
|
||||
name string,
|
||||
userAgent string,
|
||||
tokenString tools.AccessToken,
|
||||
tokenString string,
|
||||
wantRestService bool,
|
||||
) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
@@ -238,13 +237,13 @@ func newBigQueryClientCreator(
|
||||
project string,
|
||||
location string,
|
||||
name string,
|
||||
) (func(tools.AccessToken, bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error), error) {
|
||||
) (func(string, bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error), error) {
|
||||
userAgent, err := util.UserAgentFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return func(tokenString tools.AccessToken, wantRestService bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
|
||||
return func(tokenString string, wantRestService bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
|
||||
return initBigQueryConnectionWithOAuthToken(ctx, tracer, project, location, name, userAgent, tokenString, wantRestService)
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -183,14 +183,18 @@ type Tool struct {
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
var tokenStr string
|
||||
var err error
|
||||
|
||||
// Get credentials for the API call
|
||||
if t.UseClientOAuth {
|
||||
// Use client-side access token
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header")
|
||||
return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header: %w", tools.ErrUnauthorized)
|
||||
}
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
tokenStr = string(accessToken)
|
||||
} else {
|
||||
// Use ADC
|
||||
if t.TokenSource == nil {
|
||||
|
||||
@@ -150,7 +150,11 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
var err error
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
bqClient, restService, err = t.ClientCreator(accessToken, true)
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, restService, err = t.ClientCreator(tokenStr, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
|
||||
@@ -199,7 +199,11 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
bqClient, _, err = t.ClientCreator(accessToken, false)
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, _, err = t.ClientCreator(tokenStr, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
|
||||
@@ -142,7 +142,11 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
bqClient, _, err = t.ClientCreator(accessToken, false)
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, _, err = t.ClientCreator(tokenStr, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
|
||||
@@ -149,7 +149,11 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
var err error
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
bqClient, _, err = t.ClientCreator(accessToken, false)
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, _, err = t.ClientCreator(tokenStr, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
|
||||
@@ -133,10 +133,13 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
var err error
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
bqClient, _, err = t.ClientCreator(accessToken, false)
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, _, err = t.ClientCreator(tokenStr, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
|
||||
@@ -139,10 +139,13 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
var err error
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
bqClient, _, err = t.ClientCreator(accessToken, false)
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, _, err = t.ClientCreator(tokenStr, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
|
||||
@@ -222,7 +222,11 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
bqClient, restService, err = t.ClientCreator(accessToken, true)
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, restService, err = t.ClientCreator(tokenStr, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
@@ -66,6 +67,14 @@ type ToolConfig interface {
|
||||
|
||||
type AccessToken string
|
||||
|
||||
func (token AccessToken) ParseBearerToken() (string, error) {
|
||||
headerParts := strings.Split(string(token), " ")
|
||||
if len(headerParts) != 2 || strings.ToLower(headerParts[0]) != "bearer" {
|
||||
return "", fmt.Errorf("authorization header must be in the format 'Bearer <token>': %w", ErrUnauthorized)
|
||||
}
|
||||
return headerParts[1], nil
|
||||
}
|
||||
|
||||
type Tool interface {
|
||||
Invoke(context.Context, ParamValues, AccessToken) (any, error)
|
||||
ParseParams(map[string]any, map[string]map[string]any) (ParamValues, error)
|
||||
|
||||
@@ -542,6 +542,7 @@ func runBigQueryExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamW
|
||||
if err != nil {
|
||||
t.Fatalf("error getting access token from ADC: %s", err)
|
||||
}
|
||||
accessToken = "Bearer " + accessToken
|
||||
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
@@ -824,6 +825,7 @@ func runBigQueryForecastToolInvokeTest(t *testing.T, tableName string) {
|
||||
if err != nil {
|
||||
t.Fatalf("error getting access token from ADC: %s", err)
|
||||
}
|
||||
accessToken = "Bearer " + accessToken
|
||||
|
||||
historyDataTable := strings.ReplaceAll(tableName, "`", "")
|
||||
historyDataQuery := fmt.Sprintf("SELECT ts, data, id FROM %s", tableName)
|
||||
@@ -1040,6 +1042,7 @@ func runBigQueryListDatasetToolInvokeTest(t *testing.T, datasetWant string) {
|
||||
if err != nil {
|
||||
t.Fatalf("error getting access token from ADC: %s", err)
|
||||
}
|
||||
accessToken = "Bearer " + accessToken
|
||||
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
@@ -1161,6 +1164,7 @@ func runBigQueryGetDatasetInfoToolInvokeTest(t *testing.T, datasetName, datasetI
|
||||
if err != nil {
|
||||
t.Fatalf("error getting access token from ADC: %s", err)
|
||||
}
|
||||
accessToken = "Bearer " + accessToken
|
||||
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
@@ -1310,6 +1314,7 @@ func runBigQueryListTableIdsToolInvokeTest(t *testing.T, datasetName, tablename_
|
||||
if err != nil {
|
||||
t.Fatalf("error getting access token from ADC: %s", err)
|
||||
}
|
||||
accessToken = "Bearer " + accessToken
|
||||
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
@@ -1459,6 +1464,7 @@ func runBigQueryGetTableInfoToolInvokeTest(t *testing.T, datasetName, tableName,
|
||||
if err != nil {
|
||||
t.Fatalf("error getting access token from ADC: %s", err)
|
||||
}
|
||||
accessToken = "Bearer " + accessToken
|
||||
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
@@ -1608,6 +1614,7 @@ func runBigQueryConversationalAnalyticsInvokeTest(t *testing.T, datasetName, tab
|
||||
if err != nil {
|
||||
t.Fatalf("error getting access token from ADC: %s", err)
|
||||
}
|
||||
accessToken = "Bearer " + accessToken
|
||||
|
||||
tableRefsJSON := fmt.Sprintf(`[{"projectId":"%s","datasetId":"%s","tableId":"%s"}]`, BigqueryProject, datasetName, tableName)
|
||||
|
||||
|
||||
@@ -273,6 +273,7 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
|
||||
if err != nil {
|
||||
t.Fatalf("error getting access token from ADC: %s", err)
|
||||
}
|
||||
accessToken = "Bearer " + accessToken
|
||||
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
@@ -841,6 +842,7 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant, select1Want string, opti
|
||||
if err != nil {
|
||||
t.Fatalf("error getting access token from ADC: %s", err)
|
||||
}
|
||||
accessToken = "Bearer " + accessToken
|
||||
|
||||
idToken, err := GetGoogleIdToken(ClientId)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user