mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 07:28:05 -05:00
feat(tools/bigquery-get-dataset-info)!: add allowed dataset support (#1654)
## Description --- This introduces a breaking change. The bigquery-get-dataset-info tool will now enforce the allowed datasets setting from its BigQuery source configuration. Previously, this setting had no effect on the tool. ## PR Checklist --- > Thank you for opening a Pull Request! Before submitting your PR, there are a > few things you can do to make sure it goes smoothly: - [ ] Make sure you reviewed [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md) - [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) - [ ] Make sure to add `!` if this involve a breaking change 🛠️ Part of https://github.com/googleapis/genai-toolbox/issues/873
This commit is contained in:
@@ -15,10 +15,19 @@ It's compatible with the following sources:
|
||||
|
||||
- [bigquery](../../sources/bigquery.md)
|
||||
|
||||
`bigquery-get-dataset-info` takes a `dataset` parameter to specify the dataset
|
||||
on the given source. It also optionally accepts a `project` parameter to
|
||||
define the Google Cloud project ID. If the `project` parameter is not provided,
|
||||
the tool defaults to using the project defined in the source configuration.
|
||||
`bigquery-get-dataset-info` accepts the following parameters:
|
||||
- **`dataset`** (required): Specifies the dataset for which to retrieve metadata.
|
||||
- **`project`** (optional): Defines the Google Cloud project ID. If not provided,
|
||||
the tool defaults to the project from the source configuration.
|
||||
|
||||
The tool's behavior regarding these parameters is influenced by the
|
||||
`allowedDatasets` restriction on the `bigquery` source:
|
||||
- **Without `allowedDatasets` restriction:** The tool can retrieve metadata for
|
||||
any dataset specified by the `dataset` and `project` parameters.
|
||||
- **With `allowedDatasets` restriction:** Before retrieving metadata, the tool
|
||||
verifies that the requested dataset is in the allowed list. If it is not, the
|
||||
request is denied. If only one dataset is specified in the `allowedDatasets`
|
||||
list, it will be used as the default value for the `dataset` parameter.
|
||||
|
||||
## Example
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
|
||||
)
|
||||
|
||||
const kind string = "bigquery-get-dataset-info"
|
||||
@@ -48,6 +49,8 @@ type compatibleSource interface {
|
||||
BigQueryClient() *bigqueryapi.Client
|
||||
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
|
||||
UseClientAuthorization() bool
|
||||
IsDatasetAllowed(projectID, datasetID string) bool
|
||||
BigQueryAllowedDatasets() []string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
@@ -83,23 +86,33 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
projectParameter := tools.NewStringParameterWithDefault(projectKey, s.BigQueryProject(), "The Google Cloud project ID containing the dataset.")
|
||||
datasetParameter := tools.NewStringParameter(datasetKey, "The dataset to get metadata information.")
|
||||
defaultProjectID := s.BigQueryProject()
|
||||
projectDescription := "The Google Cloud project ID containing the dataset."
|
||||
datasetDescription := "The dataset to get metadata information. Can be in `project.dataset` format."
|
||||
var datasetParameter tools.Parameter
|
||||
var projectParameter tools.Parameter
|
||||
|
||||
projectParameter, datasetParameter = bqutil.InitializeDatasetParameters(
|
||||
s.BigQueryAllowedDatasets(),
|
||||
defaultProjectID,
|
||||
projectKey, datasetKey,
|
||||
projectDescription, datasetDescription)
|
||||
parameters := tools.Parameters{projectParameter, datasetParameter}
|
||||
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, parameters)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: parameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: parameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -114,11 +127,12 @@ type Tool struct {
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Client *bigqueryapi.Client
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
Statement string
|
||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
@@ -147,11 +161,16 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !t.IsDatasetAllowed(projectId, datasetId) {
|
||||
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
|
||||
}
|
||||
|
||||
dsHandle := bqClient.DatasetInProject(projectId, datasetId)
|
||||
|
||||
metadata, err := dsHandle.Metadata(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get metadata for dataset %s (in project %s): %w", datasetId, bqClient.Project(), err)
|
||||
return nil, fmt.Errorf("failed to get metadata for dataset %s (in project %s): %w", datasetId, projectId, err)
|
||||
}
|
||||
|
||||
return metadata, nil
|
||||
|
||||
@@ -275,6 +275,11 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
|
||||
"source": "my-instance",
|
||||
"description": "Tool to list table within a dataset",
|
||||
},
|
||||
"get-dataset-info-restricted": map[string]any{
|
||||
"kind": "bigquery-get-dataset-info",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to get dataset info",
|
||||
},
|
||||
"get-table-info-restricted": map[string]any{
|
||||
"kind": "bigquery-get-table-info",
|
||||
"source": "my-instance",
|
||||
@@ -324,6 +329,8 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
|
||||
runListDatasetIdsWithRestriction(t, allowedDatasetName1, allowedDatasetName2)
|
||||
runListTableIdsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, allowedForecastTableName1)
|
||||
runListTableIdsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, allowedForecastTableName2)
|
||||
runGetDatasetInfoWithRestriction(t, allowedDatasetName1, disallowedDatasetName)
|
||||
runGetDatasetInfoWithRestriction(t, allowedDatasetName2, disallowedDatasetName)
|
||||
runGetTableInfoWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, disallowedTableName)
|
||||
runGetTableInfoWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, disallowedTableName)
|
||||
runExecuteSqlWithRestriction(t, allowedTableNameParam1, disallowedTableNameParam)
|
||||
@@ -2474,7 +2481,7 @@ func runListDatasetIdsWithRestriction(t *testing.T, allowedDatasetName1, allowed
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantStatusCode int
|
||||
wantElements []string
|
||||
wantElements []string
|
||||
}{
|
||||
{
|
||||
name: "invoke list-dataset-ids with restriction",
|
||||
@@ -2499,7 +2506,7 @@ func runListDatasetIdsWithRestriction(t *testing.T, allowedDatasetName1, allowed
|
||||
if err := json.Unmarshal(bodyBytes, &respBody); err != nil {
|
||||
t.Fatalf("error parsing response body: %v", err)
|
||||
}
|
||||
|
||||
|
||||
gotJSON, ok := respBody["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("unable to find 'result' as a string in response body: %s", string(bodyBytes))
|
||||
@@ -2603,6 +2610,55 @@ func runListTableIdsWithRestriction(t *testing.T, allowedDatasetName, disallowed
|
||||
}
|
||||
}
|
||||
|
||||
func runGetDatasetInfoWithRestriction(t *testing.T, allowedDatasetName, disallowedDatasetName string) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
dataset string
|
||||
wantStatusCode int
|
||||
wantInError string
|
||||
}{
|
||||
{
|
||||
name: "invoke on allowed dataset",
|
||||
dataset: allowedDatasetName,
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "invoke on disallowed dataset",
|
||||
dataset: disallowedDatasetName,
|
||||
wantStatusCode: http.StatusBadRequest,
|
||||
wantInError: fmt.Sprintf("access denied to dataset '%s'", disallowedDatasetName),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"dataset":"%s"}`, tc.dataset)))
|
||||
req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:5000/api/tool/get-dataset-info-restricted/invoke", body)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
if tc.wantInError != "" {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
if !strings.Contains(string(bodyBytes), tc.wantInError) {
|
||||
t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runGetTableInfoWithRestriction(t *testing.T, allowedDatasetName, disallowedDatasetName, allowedTableName, disallowedTableName string) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
|
||||
Reference in New Issue
Block a user