mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-11 16:38:15 -05:00
feat(tool/bigquery-list-dataset-ids)!: add allowed datasets support (#1573)
## Description This introduces a breaking change. The bigquery-list-dataset-ids tool will now enforce the allowed datasets setting from its BigQuery source configuration. Previously, this setting had no effect on the tool. The tool's behavior regarding this parameter is influenced by the `allowedDatasets` restriction on the `bigquery` source: - **Without `allowedDatasets` restriction:** The tool can list datasets from any project specified by the `project` parameter. - **With `allowedDatasets` restriction:** The tool directly returns the pre-configured list of dataset IDs from the source, and the `project` parameter is ignored. --- > Should include a concise description of the changes (bug or feature), it's > impact, along with a summary of the solution ## PR Checklist --- > Thank you for opening a Pull Request! Before submitting your PR, there are a > few things you can do to make sure it goes smoothly: - [ ] Make sure you reviewed [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md) - [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) - [ ] Make sure to add `!` if this involve a breaking change 🛠️ Part of https://github.com/googleapis/genai-toolbox/issues/873 --------- Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com>
This commit is contained in:
@@ -15,9 +15,17 @@ It's compatible with the following sources:
|
||||
|
||||
- [bigquery](../../sources/bigquery.md)
|
||||
|
||||
`bigquery-list-dataset-ids` optionally accepts a `project` parameter to define
|
||||
the Google Cloud project ID. If the `project` parameter is not provided, the
|
||||
tool defaults to using the project defined in the source configuration.
|
||||
`bigquery-list-dataset-ids` accepts the following parameter:
|
||||
- **`project`** (optional): Defines the Google Cloud project ID. If not provided,
|
||||
the tool defaults to the project from the source configuration.
|
||||
|
||||
The tool's behavior regarding this parameter is influenced by the
|
||||
`allowedDatasets` restriction on the `bigquery` source:
|
||||
- **Without `allowedDatasets` restriction:** The tool can list datasets from any
|
||||
project specified by the `project` parameter.
|
||||
- **With `allowedDatasets` restriction:** The tool directly returns the
|
||||
pre-configured list of dataset IDs from the source, and the `project`
|
||||
parameter is ignored.
|
||||
|
||||
## Example
|
||||
|
||||
|
||||
@@ -48,6 +48,7 @@ type compatibleSource interface {
|
||||
BigQueryClient() *bigqueryapi.Client
|
||||
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
|
||||
UseClientAuthorization() bool
|
||||
BigQueryAllowedDatasets() []string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
@@ -83,7 +84,17 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
projectParameter := tools.NewStringParameterWithDefault(projectKey, s.BigQueryProject(), "The Google Cloud project to list dataset ids.")
|
||||
var projectParameter tools.Parameter
|
||||
var projectParameterDescription string
|
||||
|
||||
allowedDatasets := s.BigQueryAllowedDatasets()
|
||||
if len(allowedDatasets) > 0 {
|
||||
projectParameterDescription = "This parameter will be ignored. The list of datasets is restricted to a pre-configured list; No need to provide a project ID."
|
||||
} else {
|
||||
projectParameterDescription = "The Google Cloud project to list dataset ids."
|
||||
}
|
||||
|
||||
projectParameter = tools.NewStringParameterWithDefault(projectKey, s.BigQueryProject(), projectParameterDescription)
|
||||
|
||||
parameters := tools.Parameters{projectParameter}
|
||||
|
||||
@@ -91,15 +102,16 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: parameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: parameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
AllowedDatasets: allowedDatasets,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -114,14 +126,18 @@ type Tool struct {
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Client *bigqueryapi.Client
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
Statement string
|
||||
AllowedDatasets []string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
if len(t.AllowedDatasets) > 0 {
|
||||
return t.AllowedDatasets, nil
|
||||
}
|
||||
mapParams := params.AsMap()
|
||||
projectId, ok := mapParams[projectKey].(string)
|
||||
if !ok {
|
||||
|
||||
@@ -264,6 +264,11 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
|
||||
|
||||
// Configure tool
|
||||
toolsConfig := map[string]any{
|
||||
"list-dataset-ids-restricted": map[string]any{
|
||||
"kind": "bigquery-list-dataset-ids",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to list dataset ids",
|
||||
},
|
||||
"list-table-ids-restricted": map[string]any{
|
||||
"kind": "bigquery-list-table-ids",
|
||||
"source": "my-instance",
|
||||
@@ -310,6 +315,7 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
|
||||
}
|
||||
|
||||
// Run tests
|
||||
runListDatasetIdsWithRestriction(t, allowedDatasetName1, allowedDatasetName2)
|
||||
runListTableIdsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, allowedForecastTableName1)
|
||||
runListTableIdsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, allowedForecastTableName2)
|
||||
runExecuteSqlWithRestriction(t, allowedTableNameParam1, disallowedTableNameParam)
|
||||
@@ -2080,6 +2086,43 @@ func runBigQueryConversationalAnalyticsInvokeTest(t *testing.T, datasetName, tab
|
||||
}
|
||||
}
|
||||
|
||||
func runListDatasetIdsWithRestriction(t *testing.T, allowedDatasetName1, allowedDatasetName2 string) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantStatusCode int
|
||||
wantResult string
|
||||
}{
|
||||
{
|
||||
name: "invoke list-dataset-ids with restriction",
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantResult: fmt.Sprintf(`["%s.%s","%s.%s"]`, BigqueryProject, allowedDatasetName1, BigqueryProject, allowedDatasetName2),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
body := bytes.NewBuffer([]byte(`{}`))
|
||||
resp, bodyBytes := tests.RunRequest(t, http.MethodPost, "http://127.0.0.1:5000/api/tool/list-dataset-ids-restricted/invoke", body, nil)
|
||||
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var respBody map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &respBody); err != nil {
|
||||
t.Fatalf("error parsing response body: %v", err)
|
||||
}
|
||||
got, ok := respBody["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
if got != tc.wantResult {
|
||||
t.Errorf("unexpected result: got %q, want %q", got, tc.wantResult)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runListTableIdsWithRestriction(t *testing.T, allowedDatasetName, disallowedDatasetName string, allowedTableNames ...string) {
|
||||
sort.Strings(allowedTableNames)
|
||||
var quotedNames []string
|
||||
|
||||
Reference in New Issue
Block a user