mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-12 00:49:08 -05:00
feat(sources/bigquery): add support for user-credential passthrough (#1067)
Support end-user credential passthrough with the BigQuery source and the `bigquery-sql` tool. Support for other BQ tools will be added in subsequent PRs. Issue: https://github.com/googleapis/genai-toolbox/issues/813
This commit is contained in:
@@ -65,23 +65,40 @@ Connect your IDE to BigQuery using Toolbox.
|
||||
|
||||
BigQuery uses [Identity and Access Management (IAM)][iam-overview] to control
|
||||
user and group access to BigQuery resources like projects, datasets, and tables.
|
||||
Toolbox will use your [Application Default Credentials (ADC)][adc] to authorize
|
||||
and authenticate when interacting with [BigQuery][bigquery-docs].
|
||||
|
||||
In addition to [setting the ADC for your server][set-adc], you need to ensure
|
||||
the IAM identity has been given the correct IAM permissions for the queries
|
||||
you intend to run. Common roles include `roles/bigquery.user` (which includes
|
||||
permissions to run jobs and read data) or `roles/bigquery.dataViewer`. See
|
||||
[Introduction to BigQuery IAM][grant-permissions] for more information on
|
||||
applying IAM permissions and roles to an identity.
|
||||
### Authentication via Application Default Credentials (ADC)
|
||||
|
||||
By **default**, Toolbox will use your [Application Default Credentials (ADC)][adc] to authorize and authenticate when interacting with [BigQuery][bigquery-docs].
|
||||
|
||||
When using this method, you need to ensure the IAM identity associated with your
|
||||
ADC (such as a service account) has the correct permissions for the queries you
|
||||
intend to run. Common roles include `roles/bigquery.user` (which includes
|
||||
permissions to run jobs and read data) or `roles/bigbigquery.dataViewer`.
|
||||
Follow this [guide][set-adc] to set up your ADC.
|
||||
|
||||
### Authentication via User's OAuth Access Token
|
||||
|
||||
If the `useClientOAuth` parameter is set to `true`, Toolbox will instead use the
|
||||
OAuth access token for authentication. This token is parsed from the
|
||||
`Authorization` header passed in with the tool invocation request. This method
|
||||
allows Toolbox to make queries to [BigQuery][bigquery-docs] on behalf of the
|
||||
client or the end-user.
|
||||
|
||||
When using this on-behalf-of authentication, you must ensure that the
|
||||
identity used has been granted the correct IAM permissions. Currently,
|
||||
this option is only supported by the following BigQuery tools:
|
||||
|
||||
- [`bigquery-sql`](../tools/bigquery/bigquery-sql.md)
|
||||
Run SQL queries directly against BigQuery datasets.
|
||||
|
||||
[iam-overview]: https://cloud.google.com/bigquery/docs/access-control
|
||||
[adc]: https://cloud.google.com/docs/authentication#adc
|
||||
[set-adc]: https://cloud.google.com/docs/authentication/provide-credentials-adc
|
||||
[grant-permissions]: https://cloud.google.com/bigquery/docs/access-control
|
||||
|
||||
## Example
|
||||
|
||||
Initialize a BigQuery source that uses ADC:
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-bigquery-source:
|
||||
@@ -89,10 +106,21 @@ sources:
|
||||
project: "my-project-id"
|
||||
```
|
||||
|
||||
Initialize a BigQuery source that uses the client's access token:
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-bigquery-client-auth-source:
|
||||
kind: "bigquery"
|
||||
project: "my-project-id"
|
||||
useClientOAuth: true
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-----------|:--------:|:------------:|-------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "bigquery". |
|
||||
| project | string | true | Id of the GCP project that the cluster was created in (e.g. "my-project-id"). |
|
||||
| location | string | false | Specifies the location (e.g., 'us', 'asia-northeast1') in which to run the query job. This location must match the location of any tables referenced in the query. The default behavior is for it to be executed in the US multi-region |
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|----------------|:--------:|:------------:|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "bigquery". |
|
||||
| project | string | true | Id of the GCP project that the cluster was created in (e.g. "my-project-id"). |
|
||||
| location | string | false | Specifies the location (e.g., 'us', 'asia-northeast1') in which to run the query job. This location must match the location of any tables referenced in the query. The default behavior is for it to be executed in the US multi-region |
|
||||
| useClientOAuth | bool | false | If true, forwards the client's OAuth access token from the "Authorization" header to downstream queries. |
|
||||
|
||||
@@ -167,6 +167,20 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Extract OAuth access token from the "Authorization" header (currently for
|
||||
// BigQuery end-user credentials usage only)
|
||||
accessToken := tools.AccessToken(r.Header.Get("Authorization"))
|
||||
|
||||
// Check if this specific tool requires the standard authorization header
|
||||
if tool.RequiresClientAuthorization() {
|
||||
if accessToken == "" {
|
||||
err = fmt.Errorf("tool requires client authorization but access token is missing from the request header")
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Tool authentication
|
||||
// claimsFromAuth maps the name of the authservice to the claims retrieved from it.
|
||||
claimsFromAuth := make(map[string]map[string]any)
|
||||
@@ -225,10 +239,6 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
|
||||
|
||||
// Extract OAuth access token from the "Authorization" header (currently for
|
||||
// BigQuery end-user credentials usage only)
|
||||
accessToken := tools.AccessToken(r.Header.Get("Authorization"))
|
||||
|
||||
res, err := tool.Invoke(ctx, params, accessToken)
|
||||
|
||||
// Determine what error to return to the users.
|
||||
|
||||
@@ -218,6 +218,11 @@ func (c *ToolConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interfac
|
||||
return fmt.Errorf("unable to unmarshal %q: %w", name, err)
|
||||
}
|
||||
|
||||
// `authRequired` and `useClientOAuth` cannot be specified together
|
||||
if v["authRequired"] != nil && v["useClientOAuth"] == true {
|
||||
return fmt.Errorf("`authRequired` and `useClientOAuth` are mutually exclusive. Choose only one authentication method")
|
||||
}
|
||||
|
||||
// Make `authRequired` an empty list instead of nil for Tool manifest
|
||||
if v["authRequired"] == nil {
|
||||
v["authRequired"] = []string{}
|
||||
|
||||
@@ -404,6 +404,8 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Extract OAuth access token from the "Authorization" header (currently for
|
||||
// BigQuery end-user credentials usage only)
|
||||
accessToken := tools.AccessToken(r.Header.Get("Authorization"))
|
||||
|
||||
v, res, err := processMcpMessage(ctx, body, s, protocolVersion, toolsetName, accessToken)
|
||||
@@ -440,10 +442,10 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
if rpcResponse, ok := res.(jsonrpc.JSONRPCError); ok {
|
||||
code := rpcResponse.Error.Code
|
||||
switch {
|
||||
case code == jsonrpc.INTERNAL_ERROR:
|
||||
switch code {
|
||||
case jsonrpc.INTERNAL_ERROR:
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
case code == jsonrpc.INVALID_REQUEST:
|
||||
case jsonrpc.INVALID_REQUEST:
|
||||
errStr := err.Error()
|
||||
if errors.Is(err, tools.ErrUnauthorized) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
@@ -515,6 +517,7 @@ func processMcpMessage(ctx context.Context, body []byte, s *Server, protocolVers
|
||||
err = fmt.Errorf("toolset does not exist")
|
||||
return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
res, err := mcp.ProcessMethod(ctx, protocolVersion, baseMessage.Id, baseMessage.Method, toolset, s.ResourceMgr.GetToolsMap(), body, accessToken)
|
||||
return "", res, err
|
||||
}
|
||||
|
||||
@@ -14,7 +14,9 @@
|
||||
|
||||
package util
|
||||
|
||||
import "github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
import (
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
)
|
||||
|
||||
const (
|
||||
// SERVER_NAME is the server name used in Implementation.
|
||||
|
||||
@@ -91,6 +91,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolsMap map[st
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// Check if this specific tool requires the standard authorization header
|
||||
if tool.RequiresClientAuthorization() {
|
||||
if accessToken == "" {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), tools.ErrUnauthorized
|
||||
}
|
||||
}
|
||||
|
||||
// marshal arguments and decode it using decodeJSON instead to prevent loss between floats/int.
|
||||
aMarshal, err := json.Marshal(toolArgument)
|
||||
if err != nil {
|
||||
|
||||
@@ -91,6 +91,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolsMap map[st
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// Check if this specific tool requires the standard authorization header
|
||||
if tool.RequiresClientAuthorization() {
|
||||
if accessToken == "" {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), tools.ErrUnauthorized
|
||||
}
|
||||
}
|
||||
|
||||
// marshal arguments and decode it using decodeJSON instead to prevent loss between floats/int.
|
||||
aMarshal, err := json.Marshal(toolArgument)
|
||||
if err != nil {
|
||||
|
||||
@@ -91,6 +91,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolsMap map[st
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// Check if this specific tool requires the standard authorization header
|
||||
if tool.RequiresClientAuthorization() {
|
||||
if accessToken == "" {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), tools.ErrUnauthorized
|
||||
}
|
||||
}
|
||||
|
||||
// marshal arguments and decode it using decodeJSON instead to prevent loss between floats/int.
|
||||
aMarshal, err := json.Marshal(toolArgument)
|
||||
if err != nil {
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
bigqueryapi "cloud.google.com/go/bigquery"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/oauth2"
|
||||
@@ -34,6 +35,8 @@ const SourceKind string = "bigquery"
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
type BigqueryClientCreator func(tokenString tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
@@ -63,7 +66,7 @@ func (r Config) SourceConfigKind() string {
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
// Initializes a BigQuery Google SQL source
|
||||
client, restService, tokenSource, err := initBigQueryConnection(ctx, tracer, r.Name, r.Project, r.Location)
|
||||
client, restService, tokenSource, clientCreator, err := initBigQueryConnection(ctx, tracer, r.Name, r.Project, r.Location)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -75,6 +78,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
RestService: restService,
|
||||
TokenSource: tokenSource,
|
||||
MaxQueryResultRows: 50,
|
||||
ClientCreator: clientCreator,
|
||||
}
|
||||
return s, nil
|
||||
|
||||
@@ -90,6 +94,7 @@ type Source struct {
|
||||
RestService *bigqueryrestapi.Service
|
||||
TokenSource oauth2.TokenSource
|
||||
MaxQueryResultRows int
|
||||
ClientCreator BigqueryClientCreator
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
@@ -113,38 +118,88 @@ func (s *Source) GetMaxQueryResultRows() int {
|
||||
return s.MaxQueryResultRows
|
||||
}
|
||||
|
||||
func (s *Source) BigQueryClientCreator() BigqueryClientCreator {
|
||||
return s.ClientCreator
|
||||
}
|
||||
|
||||
func initBigQueryConnection(
|
||||
ctx context.Context,
|
||||
tracer trace.Tracer,
|
||||
name string,
|
||||
project string,
|
||||
location string,
|
||||
) (*bigqueryapi.Client, *bigqueryrestapi.Service, oauth2.TokenSource, error) {
|
||||
) (*bigqueryapi.Client, *bigqueryrestapi.Service, oauth2.TokenSource, BigqueryClientCreator, error) {
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
cred, err := google.FindDefaultCredentials(ctx, bigqueryapi.Scope)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("failed to find default Google Cloud credentials with scope %q: %w", bigqueryapi.Scope, err)
|
||||
return nil, nil, nil, nil, fmt.Errorf("failed to find default Google Cloud credentials with scope %q: %w", bigqueryapi.Scope, err)
|
||||
}
|
||||
|
||||
userAgent, err := util.UserAgentFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
// Initialize the high-level BigQuery client
|
||||
client, err := bigqueryapi.NewClient(ctx, project, option.WithUserAgent(userAgent), option.WithCredentials(cred))
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("failed to create BigQuery client for project %q: %w", project, err)
|
||||
return nil, nil, nil, nil, fmt.Errorf("failed to create BigQuery client for project %q: %w", project, err)
|
||||
}
|
||||
client.Location = location
|
||||
|
||||
// Initialize the low-level BigQuery REST service using the same credentials
|
||||
restService, err := bigqueryrestapi.NewService(ctx, option.WithUserAgent(userAgent), option.WithCredentials(cred))
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("failed to create BigQuery v2 service: %w", err)
|
||||
return nil, nil, nil, nil, fmt.Errorf("failed to create BigQuery v2 service: %w", err)
|
||||
}
|
||||
|
||||
return client, restService, cred.TokenSource, nil
|
||||
clientCreator := newBigQueryClientCreator(ctx, project, location, userAgent)
|
||||
return client, restService, cred.TokenSource, clientCreator, nil
|
||||
}
|
||||
|
||||
// initBigQueryConnectionWithOAuthToken initialize a BigQuery client with an
|
||||
// OAuth access token.
|
||||
func initBigQueryConnectionWithOAuthToken(
|
||||
ctx context.Context,
|
||||
project string,
|
||||
location string,
|
||||
userAgent string,
|
||||
tokenString tools.AccessToken,
|
||||
) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
|
||||
// Construct token source
|
||||
token := &oauth2.Token{
|
||||
AccessToken: string(tokenString),
|
||||
}
|
||||
ts := oauth2.StaticTokenSource(token)
|
||||
|
||||
// Initialize the BigQuery client with tokenSource
|
||||
client, err := bigqueryapi.NewClient(ctx, project, option.WithUserAgent(userAgent), option.WithTokenSource(ts))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create BigQuery client for project %q: %w", project, err)
|
||||
}
|
||||
client.Location = location
|
||||
|
||||
// Initialize the low-level BigQuery REST service using the same credentials
|
||||
restService, err := bigqueryrestapi.NewService(ctx, option.WithUserAgent(userAgent), option.WithTokenSource(ts))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create BigQuery v2 service: %w", err)
|
||||
}
|
||||
|
||||
return client, restService, nil
|
||||
}
|
||||
|
||||
// newBigQueryClientCreator sets the project parameters for the init helper
|
||||
// function. The returned function takes in an OAuth access token and uses it to
|
||||
// create a BQ client.
|
||||
func newBigQueryClientCreator(
|
||||
ctx context.Context,
|
||||
project string,
|
||||
location string,
|
||||
userAgent string,
|
||||
) func(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
|
||||
return func(tokenString tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
|
||||
return initBigQueryConnectionWithOAuthToken(ctx, project, location, userAgent, tokenString)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
bigqueryapi "cloud.google.com/go/bigquery"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
|
||||
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||
@@ -48,6 +49,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
BigQueryClient() *bigqueryapi.Client
|
||||
BigQueryRestService() *bigqueryrestapi.Service
|
||||
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
@@ -62,6 +64,7 @@ type Config struct {
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Statement string `yaml:"statement" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
TemplateParameters tools.Parameters `yaml:"templateParameters"`
|
||||
}
|
||||
@@ -101,15 +104,18 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
UseClientOAuth: cfg.UseClientOAuth,
|
||||
Parameters: cfg.Parameters,
|
||||
TemplateParameters: cfg.TemplateParameters,
|
||||
AllParams: allParameters,
|
||||
Statement: cfg.Statement,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Client: s.BigQueryClient(),
|
||||
RestService: s.BigQueryRestService(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
|
||||
Statement: cfg.Statement,
|
||||
Client: s.BigQueryClient(),
|
||||
RestService: s.BigQueryRestService(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -121,14 +127,17 @@ type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
TemplateParameters tools.Parameters `yaml:"templateParameters"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
Statement string
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
|
||||
Statement string
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
@@ -208,11 +217,23 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
lowLevelParams = append(lowLevelParams, lowLevelParam)
|
||||
}
|
||||
|
||||
query := t.Client.Query(newStatement)
|
||||
bqClient := t.Client
|
||||
restService := t.RestService
|
||||
var query *bigqueryapi.Query
|
||||
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
bqClient, restService, err = t.ClientCreator(accessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
query = bqClient.Query(newStatement)
|
||||
query.Parameters = highLevelParams
|
||||
query.Location = t.Client.Location
|
||||
|
||||
dryRunJob, err := dryRunQuery(ctx, t.RestService, t.Client.Project(), t.Client.Location, newStatement, lowLevelParams, query.ConnectionProperties)
|
||||
dryRunJob, err := dryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, newStatement, lowLevelParams, query.ConnectionProperties)
|
||||
if err != nil {
|
||||
// This is a fallback check in case the switch logic was bypassed.
|
||||
return nil, fmt.Errorf("final query validation failed: %w", err)
|
||||
@@ -277,7 +298,7 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization() bool {
|
||||
return false
|
||||
return t.UseClientOAuth
|
||||
}
|
||||
|
||||
func BQTypeStringFromToolType(toolType string) (string, error) {
|
||||
|
||||
@@ -166,8 +166,8 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
|
||||
// Run tests
|
||||
tests.RunToolGetTest(t)
|
||||
tests.RunToolInvokeTest(t, select1Want, tests.DisableOptionalNullParamTest())
|
||||
tests.RunMCPToolCallMethod(t, mcpMyFailToolWant)
|
||||
tests.RunToolInvokeTest(t, select1Want, tests.DisableOptionalNullParamTest(), tests.EnableClientAuthTest())
|
||||
tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, tests.EnableMcpClientAuthTest())
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam,
|
||||
tests.WithCreateColArray(createColArray),
|
||||
tests.WithDdlWant(ddlWant),
|
||||
@@ -468,7 +468,13 @@ func addBigQuerySqlToolConfig(t *testing.T, config map[string]any, toolStatement
|
||||
map[string]any{"name": "bool_array", "type": "array", "description": "an array of boolean values", "items": map[string]any{"name": "item", "type": "boolean", "description": "desc"}},
|
||||
},
|
||||
}
|
||||
|
||||
tools["my-client-auth-tool"] = map[string]any{
|
||||
"kind": "bigquery-sql",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test client authorization.",
|
||||
"useClientOAuth": true,
|
||||
"statement": "SELECT 1",
|
||||
}
|
||||
config["tools"] = tools
|
||||
return config
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ type InvokeTestConfig struct {
|
||||
nullWant string
|
||||
supportOptionalNullParam bool
|
||||
supportArrayParam bool
|
||||
supportClientAuth bool
|
||||
}
|
||||
|
||||
type InvokeTestOption func(*InvokeTestConfig)
|
||||
@@ -68,11 +69,21 @@ func DisableArrayTest() InvokeTestOption {
|
||||
}
|
||||
}
|
||||
|
||||
// EnableClientAuthTest runs the client authorization tests.
|
||||
// Only enable it if your source supports the `useClientOAuth` configuration.
|
||||
// Currently, this should only be used with the BigQuery tests.
|
||||
func EnableClientAuthTest() InvokeTestOption {
|
||||
return func(c *InvokeTestConfig) {
|
||||
c.supportClientAuth = true
|
||||
}
|
||||
}
|
||||
|
||||
/* Configurations for RunMCPToolCallMethod() */
|
||||
|
||||
// MCPTestConfig represents the various configuration options for mcp tool call tests.
|
||||
type MCPTestConfig struct {
|
||||
myToolId3NameAliceWant string
|
||||
supportClientAuth bool
|
||||
}
|
||||
|
||||
type McpTestOption func(*MCPTestConfig)
|
||||
@@ -85,6 +96,15 @@ func WithMcpMyToolId3NameAliceWant(s string) McpTestOption {
|
||||
}
|
||||
}
|
||||
|
||||
// EnableMcpClientAuthTest runs the client authorization tests.
|
||||
// Only enable it if your source supports the `useClientOAuth` configuration.
|
||||
// Currently, this should only be used with the BigQuery tests.
|
||||
func EnableMcpClientAuthTest() McpTestOption {
|
||||
return func(c *MCPTestConfig) {
|
||||
c.supportClientAuth = true
|
||||
}
|
||||
}
|
||||
|
||||
/* Configurations for RunExecuteSqlToolInvokeTest() */
|
||||
|
||||
// ExecuteSqlTestConfig represents the various configuration options for RunExecuteSqlToolInvokeTest()
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
)
|
||||
|
||||
// RunToolGet runs the tool get endpoint
|
||||
@@ -253,6 +254,7 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
|
||||
nullWant: "null",
|
||||
supportOptionalNullParam: true,
|
||||
supportArrayParam: true,
|
||||
supportClientAuth: false,
|
||||
}
|
||||
|
||||
// Apply provided options
|
||||
@@ -266,6 +268,12 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
|
||||
t.Fatalf("error getting Google ID token: %s", err)
|
||||
}
|
||||
|
||||
// Get access token
|
||||
accessToken, err := sources.GetIAMAccessToken(t.Context())
|
||||
if err != nil {
|
||||
t.Fatalf("error getting access token from ADC: %s", err)
|
||||
}
|
||||
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
@@ -390,6 +398,32 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
wantStatusCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-client-auth-tool with auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-client-auth-tool/invoke",
|
||||
enabled: configs.supportClientAuth,
|
||||
requestHeader: map[string]string{"Authorization": accessToken},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
wantBody: select1Want,
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-client-auth-tool without auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-client-auth-tool/invoke",
|
||||
enabled: configs.supportClientAuth,
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
wantStatusCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
|
||||
name: "Invoke my-client-auth-tool with invalid auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-client-auth-tool/invoke",
|
||||
enabled: configs.supportClientAuth,
|
||||
requestHeader: map[string]string{"Authorization": "Bearer invalid-token"},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
wantStatusCode: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -792,6 +826,7 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant string, options ...McpTes
|
||||
// Default values for MCPTestConfig
|
||||
configs := &MCPTestConfig{
|
||||
myToolId3NameAliceWant: `{"jsonrpc":"2.0","id":"my-tool","result":{"content":[{"type":"text","text":"{\"id\":1,\"name\":\"Alice\"}"},{"type":"text","text":"{\"id\":3,\"name\":\"Sid\"}"}]}}`,
|
||||
supportClientAuth: false,
|
||||
}
|
||||
|
||||
// Apply provided options
|
||||
@@ -801,6 +836,12 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant string, options ...McpTes
|
||||
|
||||
sessionId := RunInitialize(t, "2024-11-05")
|
||||
|
||||
// Get access token
|
||||
accessToken, err := sources.GetIAMAccessToken(t.Context())
|
||||
if err != nil {
|
||||
t.Fatalf("error getting access token from ADC: %s", err)
|
||||
}
|
||||
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
@@ -909,6 +950,63 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant string, options ...McpTes
|
||||
wantStatusCode: http.StatusUnauthorized,
|
||||
wantBody: "{\"jsonrpc\":\"2.0\",\"id\":\"invoke my-auth-required-tool\",\"error\":{\"code\":-32600,\"message\":\"unauthorized Tool call: `authRequired` is set for the target Tool but isn't supported through MCP Tool call: unauthorized\"}}",
|
||||
},
|
||||
|
||||
{
|
||||
name: "MCP Invoke my-client-auth-tool",
|
||||
enabled: configs.supportClientAuth,
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
requestHeader: map[string]string{"Authorization": accessToken},
|
||||
requestBody: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
Id: "invoke my-client-auth-tool",
|
||||
Request: jsonrpc.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: map[string]any{
|
||||
"name": "my-client-auth-tool",
|
||||
"arguments": map[string]any{},
|
||||
},
|
||||
},
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantBody: "{\"jsonrpc\":\"2.0\",\"id\":\"invoke my-client-auth-tool\",\"result\":{\"content\":[{\"type\":\"text\",\"text\":\"{\\\"f0_\\\":1}\"}]}}",
|
||||
},
|
||||
{
|
||||
name: "MCP Invoke my-client-auth-tool without access token",
|
||||
enabled: configs.supportClientAuth,
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
Id: "invoke my-client-auth-tool",
|
||||
Request: jsonrpc.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: map[string]any{
|
||||
"name": "my-client-auth-tool",
|
||||
"arguments": map[string]any{},
|
||||
},
|
||||
},
|
||||
wantStatusCode: http.StatusUnauthorized,
|
||||
wantBody: "{\"jsonrpc\":\"2.0\",\"id\":\"invoke my-client-auth-tool\",\"error\":{\"code\":-32600,\"message\":\"missing access token in the 'Authorization' header\"}",
|
||||
},
|
||||
{
|
||||
name: "MCP Invoke my-client-auth-tool with invalid access token",
|
||||
enabled: configs.supportClientAuth,
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
requestHeader: map[string]string{"Authorization": "Bearer invalid-token"},
|
||||
requestBody: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
Id: "invoke my-client-auth-tool",
|
||||
Request: jsonrpc.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: map[string]any{
|
||||
"name": "my-client-auth-tool",
|
||||
"arguments": map[string]any{},
|
||||
},
|
||||
},
|
||||
wantStatusCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "MCP Invoke my-fail-tool",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
|
||||
Reference in New Issue
Block a user