feat(tools/bigquery-get-dataset-info)!: add allowed dataset support (#1654)

## Description

---
This introduces a breaking change. The bigquery-get-dataset-info tool
will now enforce the allowed datasets setting from its BigQuery source
configuration. Previously, this setting had no effect on the tool.

## PR Checklist

---
> Thank you for opening a Pull Request! Before submitting your PR, there are a
> few things you can do to make sure it goes smoothly:

- [ ] Make sure you reviewed
  [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [ ] Make sure to open an issue as a
  [bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
  before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)
- [ ] Make sure to add `!` if this involves a breaking change

🛠️ Part of https://github.com/googleapis/genai-toolbox/issues/873
This commit is contained in:
Huan Chen
2025-10-30 13:45:02 -07:00
committed by GitHub
parent fae133930f
commit a2006ad577
3 changed files with 107 additions and 23 deletions

View File

@@ -15,10 +15,19 @@ It's compatible with the following sources:
- [bigquery](../../sources/bigquery.md)
`bigquery-get-dataset-info` takes a `dataset` parameter to specify the dataset
on the given source. It also optionally accepts a `project` parameter to
define the Google Cloud project ID. If the `project` parameter is not provided,
the tool defaults to using the project defined in the source configuration.
`bigquery-get-dataset-info` accepts the following parameters:
- **`dataset`** (required): Specifies the dataset for which to retrieve metadata.
- **`project`** (optional): Defines the Google Cloud project ID. If not provided,
the tool defaults to the project from the source configuration.
The tool's behavior regarding these parameters is influenced by the
`allowedDatasets` restriction on the `bigquery` source:
- **Without `allowedDatasets` restriction:** The tool can retrieve metadata for
any dataset specified by the `dataset` and `project` parameters.
- **With `allowedDatasets` restriction:** Before retrieving metadata, the tool
verifies that the requested dataset is in the allowed list. If it is not, the
request is denied. If only one dataset is specified in the `allowedDatasets`
list, it will be used as the default value for the `dataset` parameter.
## Example

View File

@@ -23,6 +23,7 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools"
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
)
const kind string = "bigquery-get-dataset-info"
@@ -48,6 +49,8 @@ type compatibleSource interface {
BigQueryClient() *bigqueryapi.Client
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool
IsDatasetAllowed(projectID, datasetID string) bool
BigQueryAllowedDatasets() []string
}
// validate compatible sources are still compatible
@@ -83,23 +86,33 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
projectParameter := tools.NewStringParameterWithDefault(projectKey, s.BigQueryProject(), "The Google Cloud project ID containing the dataset.")
datasetParameter := tools.NewStringParameter(datasetKey, "The dataset to get metadata information.")
defaultProjectID := s.BigQueryProject()
projectDescription := "The Google Cloud project ID containing the dataset."
datasetDescription := "The dataset to get metadata information. Can be in `project.dataset` format."
var datasetParameter tools.Parameter
var projectParameter tools.Parameter
projectParameter, datasetParameter = bqutil.InitializeDatasetParameters(
s.BigQueryAllowedDatasets(),
defaultProjectID,
projectKey, datasetKey,
projectDescription, datasetDescription)
parameters := tools.Parameters{projectParameter, datasetParameter}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, parameters)
// finish tool setup
t := Tool{
Name: cfg.Name,
Kind: kind,
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Name: cfg.Name,
Kind: kind,
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
IsDatasetAllowed: s.IsDatasetAllowed,
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -114,11 +127,12 @@ type Tool struct {
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters tools.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
ClientCreator bigqueryds.BigqueryClientCreator
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
Client *bigqueryapi.Client
ClientCreator bigqueryds.BigqueryClientCreator
Statement string
IsDatasetAllowed func(projectID, datasetID string) bool
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
@@ -147,11 +161,16 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
}
if !t.IsDatasetAllowed(projectId, datasetId) {
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
}
dsHandle := bqClient.DatasetInProject(projectId, datasetId)
metadata, err := dsHandle.Metadata(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get metadata for dataset %s (in project %s): %w", datasetId, bqClient.Project(), err)
return nil, fmt.Errorf("failed to get metadata for dataset %s (in project %s): %w", datasetId, projectId, err)
}
return metadata, nil

View File

@@ -275,6 +275,11 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
"source": "my-instance",
"description": "Tool to list table within a dataset",
},
"get-dataset-info-restricted": map[string]any{
"kind": "bigquery-get-dataset-info",
"source": "my-instance",
"description": "Tool to get dataset info",
},
"get-table-info-restricted": map[string]any{
"kind": "bigquery-get-table-info",
"source": "my-instance",
@@ -324,6 +329,8 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
runListDatasetIdsWithRestriction(t, allowedDatasetName1, allowedDatasetName2)
runListTableIdsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, allowedForecastTableName1)
runListTableIdsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, allowedForecastTableName2)
runGetDatasetInfoWithRestriction(t, allowedDatasetName1, disallowedDatasetName)
runGetDatasetInfoWithRestriction(t, allowedDatasetName2, disallowedDatasetName)
runGetTableInfoWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, disallowedTableName)
runGetTableInfoWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, disallowedTableName)
runExecuteSqlWithRestriction(t, allowedTableNameParam1, disallowedTableNameParam)
@@ -2474,7 +2481,7 @@ func runListDatasetIdsWithRestriction(t *testing.T, allowedDatasetName1, allowed
testCases := []struct {
name string
wantStatusCode int
wantElements []string
wantElements []string
}{
{
name: "invoke list-dataset-ids with restriction",
@@ -2499,7 +2506,7 @@ func runListDatasetIdsWithRestriction(t *testing.T, allowedDatasetName1, allowed
if err := json.Unmarshal(bodyBytes, &respBody); err != nil {
t.Fatalf("error parsing response body: %v", err)
}
gotJSON, ok := respBody["result"].(string)
if !ok {
t.Fatalf("unable to find 'result' as a string in response body: %s", string(bodyBytes))
@@ -2603,6 +2610,55 @@ func runListTableIdsWithRestriction(t *testing.T, allowedDatasetName, disallowed
}
}
// runGetDatasetInfoWithRestriction invokes the "get-dataset-info-restricted"
// tool over HTTP, once for allowedDatasetName and once for
// disallowedDatasetName.
//
// The allowed dataset is expected to return 200 OK. The disallowed dataset is
// expected to return 400 Bad Request with a response body that contains the
// "access denied to dataset '<name>'" message.
func runGetDatasetInfoWithRestriction(t *testing.T, allowedDatasetName, disallowedDatasetName string) {
	testCases := []struct {
		name           string
		dataset        string
		wantStatusCode int
		wantInError    string // substring required in the body; empty means "do not check"
	}{
		{
			name:           "invoke on allowed dataset",
			dataset:        allowedDatasetName,
			wantStatusCode: http.StatusOK,
		},
		{
			name:           "invoke on disallowed dataset",
			dataset:        disallowedDatasetName,
			wantStatusCode: http.StatusBadRequest,
			wantInError:    fmt.Sprintf("access denied to dataset '%s'", disallowedDatasetName),
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			body := bytes.NewBufferString(fmt.Sprintf(`{"dataset":"%s"}`, tc.dataset))
			req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:5000/api/tool/get-dataset-info-restricted/invoke", body)
			if err != nil {
				t.Fatalf("unable to create request: %s", err)
			}
			req.Header.Add("Content-type", "application/json")

			resp, err := http.DefaultClient.Do(req)
			if err != nil {
				t.Fatalf("unable to send request: %s", err)
			}
			defer resp.Body.Close()

			// Read the body exactly once and surface read failures directly,
			// so a broken read is never misreported as a missing error
			// message or an empty body in the diagnostics below.
			bodyBytes, err := io.ReadAll(resp.Body)
			if err != nil {
				t.Fatalf("unable to read response body: %s", err)
			}

			if resp.StatusCode != tc.wantStatusCode {
				t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes))
			}

			if tc.wantInError != "" && !strings.Contains(string(bodyBytes), tc.wantInError) {
				t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError)
			}
		})
	}
}
func runGetTableInfoWithRestriction(t *testing.T, allowedDatasetName, disallowedDatasetName, allowedTableName, disallowedTableName string) {
testCases := []struct {
name string