feat(tools/bigquery-get-table-info)!: add allowed dataset support (#1093)

This introduces a breaking change. The bigquery-get-table-info tool will
now enforce the allowed datasets setting from its BigQuery source
configuration. Previously, this setting had no effect on the tool.

Part of https://github.com/googleapis/genai-toolbox/issues/873

---------

Co-authored-by: Nikunj Badjatya <nikunj.badjatya@harness.io>
This commit is contained in:
Huan Chen
2025-10-08 16:41:40 -07:00
committed by GitHub
parent 86eecc356d
commit acb205ca47
5 changed files with 166 additions and 54 deletions

View File

@@ -15,10 +15,20 @@ It's compatible with the following sources:
- [bigquery](../../sources/bigquery.md)
`bigquery-get-table-info` takes `dataset` and `table` parameters to specify
the target table. It also optionally accepts a `project` parameter to define
the Google Cloud project ID. If the `project` parameter is not provided, the
tool defaults to using the project defined in the source configuration.
`bigquery-get-table-info` accepts the following parameters:
- **`table`** (required): The name of the table for which to retrieve metadata.
- **`dataset`** (required): The dataset containing the specified table.
- **`project`** (optional): The Google Cloud project ID. If not provided, the
tool defaults to the project from the source configuration.
The tool's behavior regarding these parameters is influenced by the
`allowedDatasets` restriction on the `bigquery` source:
- **Without `allowedDatasets` restriction:** The tool can retrieve metadata for
any table specified by the `table`, `dataset`, and `project` parameters.
- **With `allowedDatasets` restriction:** Before retrieving metadata, the tool
verifies that the requested dataset is in the allowed list. If it is not, the
request is denied. If only one dataset is specified in the `allowedDatasets`
list, it will be used as the default value for the `dataset` parameter.
## Example

View File

@@ -17,8 +17,11 @@ package bigquerycommon
import (
"context"
"fmt"
"sort"
"strings"
bigqueryapi "cloud.google.com/go/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
)
@@ -69,3 +72,49 @@ func BQTypeStringFromToolType(toolType string) (string, error) {
return "", fmt.Errorf("unsupported tool parameter type for BigQuery: %s", toolType)
}
}
// InitializeDatasetParameters generates project and dataset tool parameters based on allowedDatasets.
func InitializeDatasetParameters(
allowedDatasets []string,
defaultProjectID string,
projectKey, datasetKey string,
projectDescription, datasetDescription string,
) (projectParam, datasetParam tools.Parameter) {
if len(allowedDatasets) > 0 {
if len(allowedDatasets) == 1 {
parts := strings.Split(allowedDatasets[0], ".")
defaultProjectID = parts[0]
datasetID := parts[1]
projectDescription += fmt.Sprintf(" Must be `%s`.", defaultProjectID)
datasetDescription += fmt.Sprintf(" Must be `%s`.", datasetID)
datasetParam = tools.NewStringParameterWithDefault(datasetKey, datasetID, datasetDescription)
} else {
datasetIDsByProject := make(map[string][]string)
for _, ds := range allowedDatasets {
parts := strings.Split(ds, ".")
project := parts[0]
dataset := parts[1]
datasetIDsByProject[project] = append(datasetIDsByProject[project], fmt.Sprintf("`%s`", dataset))
}
var datasetDescriptions, projectIDList []string
for project, datasets := range datasetIDsByProject {
sort.Strings(datasets)
projectIDList = append(projectIDList, fmt.Sprintf("`%s`", project))
datasetList := strings.Join(datasets, ", ")
datasetDescriptions = append(datasetDescriptions, fmt.Sprintf("%s from project `%s`", datasetList, project))
}
sort.Strings(projectIDList)
sort.Strings(datasetDescriptions)
projectDescription += fmt.Sprintf(" Must be one of the following: %s.", strings.Join(projectIDList, ", "))
datasetDescription += fmt.Sprintf(" Must be one of the allowed datasets: %s.", strings.Join(datasetDescriptions, "; "))
datasetParam = tools.NewStringParameter(datasetKey, datasetDescription)
}
} else {
datasetParam = tools.NewStringParameter(datasetKey, datasetDescription)
}
projectParam = tools.NewStringParameterWithDefault(projectKey, defaultProjectID, projectDescription)
return projectParam, datasetParam
}

View File

@@ -23,6 +23,7 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools"
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
)
const kind string = "bigquery-get-table-info"
@@ -49,6 +50,8 @@ type compatibleSource interface {
BigQueryClient() *bigqueryapi.Client
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool
IsDatasetAllowed(projectID, datasetID string) bool
BigQueryAllowedDatasets() []string
}
// validate compatible sources are still compatible
@@ -84,8 +87,19 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
projectParameter := tools.NewStringParameterWithDefault(projectKey, s.BigQueryProject(), "The Google Cloud project ID containing the dataset and table.")
datasetParameter := tools.NewStringParameter(datasetKey, "The table's parent dataset.")
defaultProjectID := s.BigQueryProject()
projectDescription := "The Google Cloud project ID containing the dataset and table."
datasetDescription := "The table's parent dataset."
var datasetParameter tools.Parameter
var projectParameter tools.Parameter
projectParameter, datasetParameter = bqutil.InitializeDatasetParameters(
s.BigQueryAllowedDatasets(),
defaultProjectID,
projectKey, datasetKey,
projectDescription, datasetDescription,
)
tableParameter := tools.NewStringParameter(tableKey, "The table to get metadata information.")
parameters := tools.Parameters{projectParameter, datasetParameter, tableParameter}
@@ -93,15 +107,16 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Name: cfg.Name,
Kind: kind,
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Name: cfg.Name,
Kind: kind,
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
IsDatasetAllowed: s.IsDatasetAllowed,
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -116,11 +131,12 @@ type Tool struct {
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters tools.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
ClientCreator bigqueryds.BigqueryClientCreator
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
Client *bigqueryapi.Client
ClientCreator bigqueryds.BigqueryClientCreator
Statement string
IsDatasetAllowed func(projectID, datasetID string) bool
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
@@ -140,6 +156,10 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", tableKey)
}
if !t.IsDatasetAllowed(projectId, datasetId) {
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
}
bqClient := t.Client
var err error

View File

@@ -17,14 +17,13 @@ package bigquerylisttableids
import (
"context"
"fmt"
"sort"
"strings"
bigqueryapi "cloud.google.com/go/bigquery"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools"
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
"google.golang.org/api/iterator"
)
@@ -92,39 +91,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
projectDescription := "The Google Cloud project ID containing the dataset."
datasetDescription := "The dataset to list table ids."
var datasetParameter tools.Parameter
allowedDatasets := s.BigQueryAllowedDatasets()
if len(allowedDatasets) > 0 {
if len(allowedDatasets) == 1 {
parts := strings.Split(allowedDatasets[0], ".")
defaultProjectID = parts[0]
datasetID := parts[1]
projectDescription += fmt.Sprintf(" Must be `%s`.", defaultProjectID)
datasetDescription += fmt.Sprintf(" Must be `%s`.", datasetID)
datasetParameter = tools.NewStringParameterWithDefault(datasetKey, datasetID, datasetDescription)
} else {
datasetIDsByProject := make(map[string][]string)
for _, ds := range allowedDatasets {
parts := strings.Split(ds, ".")
project := parts[0]
dataset := parts[1]
datasetIDsByProject[project] = append(datasetIDsByProject[project], fmt.Sprintf("`%s`", dataset))
}
var projectParameter tools.Parameter
var datasetDescriptions, projectIDList []string
for project, datasets := range datasetIDsByProject {
sort.Strings(datasets)
projectIDList = append(projectIDList, fmt.Sprintf("`%s`", project))
datasetList := strings.Join(datasets, ", ")
datasetDescriptions = append(datasetDescriptions, fmt.Sprintf("%s from project `%s`", datasetList, project))
}
projectDescription += fmt.Sprintf(" Must be one of the following: %s.", strings.Join(projectIDList, ", "))
datasetDescription += fmt.Sprintf(" Must be one of the allowed datasets: %s.", strings.Join(datasetDescriptions, "; "))
datasetParameter = tools.NewStringParameter(datasetKey, datasetDescription)
}
} else {
datasetParameter = tools.NewStringParameter(datasetKey, datasetDescription)
}
projectParameter := tools.NewStringParameterWithDefault(projectKey, defaultProjectID, projectDescription)
projectParameter, datasetParameter = bqutil.InitializeDatasetParameters(
s.BigQueryAllowedDatasets(),
defaultProjectID,
projectKey, datasetKey,
projectDescription, datasetDescription,
)
parameters := tools.Parameters{projectParameter, datasetParameter}

View File

@@ -274,6 +274,11 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
"source": "my-instance",
"description": "Tool to list table within a dataset",
},
"get-table-info-restricted": map[string]any{
"kind": "bigquery-get-table-info",
"source": "my-instance",
"description": "Tool to get table info",
},
"execute-sql-restricted": map[string]any{
"kind": "bigquery-execute-sql",
"source": "my-instance",
@@ -318,6 +323,8 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
runListDatasetIdsWithRestriction(t, allowedDatasetName1, allowedDatasetName2)
runListTableIdsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, allowedForecastTableName1)
runListTableIdsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, allowedForecastTableName2)
runGetTableInfoWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, disallowedTableName)
runGetTableInfoWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, disallowedTableName)
runExecuteSqlWithRestriction(t, allowedTableNameParam1, disallowedTableNameParam)
runExecuteSqlWithRestriction(t, allowedTableNameParam2, disallowedTableNameParam)
runConversationalAnalyticsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, disallowedTableName)
@@ -2582,6 +2589,58 @@ func runListTableIdsWithRestriction(t *testing.T, allowedDatasetName, disallowed
}
}
func runGetTableInfoWithRestriction(t *testing.T, allowedDatasetName, disallowedDatasetName, allowedTableName, disallowedTableName string) {
testCases := []struct {
name string
dataset string
table string
wantStatusCode int
wantInError string
}{
{
name: "invoke on allowed table",
dataset: allowedDatasetName,
table: allowedTableName,
wantStatusCode: http.StatusOK,
},
{
name: "invoke on disallowed table",
dataset: disallowedDatasetName,
table: disallowedTableName,
wantStatusCode: http.StatusBadRequest,
wantInError: fmt.Sprintf("access denied to dataset '%s'", disallowedDatasetName),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"dataset":"%s", "table":"%s"}`, tc.dataset, tc.table)))
req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:5000/api/tool/get-table-info-restricted/invoke", body)
if err != nil {
t.Fatalf("unable to create request: %s", err)
}
req.Header.Add("Content-type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("unable to send request: %s", err)
}
defer resp.Body.Close()
if resp.StatusCode != tc.wantStatusCode {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes))
}
if tc.wantInError != "" {
bodyBytes, _ := io.ReadAll(resp.Body)
if !strings.Contains(string(bodyBytes), tc.wantInError) {
t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError)
}
}
})
}
}
func runExecuteSqlWithRestriction(t *testing.T, allowedTableFullName, disallowedTableFullName string) {
allowedTableParts := strings.Split(strings.Trim(allowedTableFullName, "`"), ".")
if len(allowedTableParts) != 3 {