mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-10 07:58:12 -05:00
feat(tools/bigquery-get-table-info)!: add allowed dataset support (#1093)
This introduces a breaking change. The bigquery-get-table-info tool will now enforce the allowed datasets setting from its BigQuery source configuration. Previously, this setting had no effect on the tool. Part of https://github.com/googleapis/genai-toolbox/issues/873 --------- Co-authored-by: Nikunj Badjatya <nikunj.badjatya@harness.io>
This commit is contained in:
@@ -15,10 +15,20 @@ It's compatible with the following sources:
|
||||
|
||||
- [bigquery](../../sources/bigquery.md)
|
||||
|
||||
`bigquery-get-table-info` takes `dataset` and `table` parameters to specify
|
||||
the target table. It also optionally accepts a `project` parameter to define
|
||||
the Google Cloud project ID. If the `project` parameter is not provided, the
|
||||
tool defaults to using the project defined in the source configuration.
|
||||
`bigquery-get-table-info` accepts the following parameters:
|
||||
- **`table`** (required): The name of the table for which to retrieve metadata.
|
||||
- **`dataset`** (required): The dataset containing the specified table.
|
||||
- **`project`** (optional): The Google Cloud project ID. If not provided, the
|
||||
tool defaults to the project from the source configuration.
|
||||
|
||||
The tool's behavior regarding these parameters is influenced by the
|
||||
`allowedDatasets` restriction on the `bigquery` source:
|
||||
- **Without `allowedDatasets` restriction:** The tool can retrieve metadata for
|
||||
any table specified by the `table`, `dataset`, and `project` parameters.
|
||||
- **With `allowedDatasets` restriction:** Before retrieving metadata, the tool
|
||||
verifies that the requested dataset is in the allowed list. If it is not, the
|
||||
request is denied. If only one dataset is specified in the `allowedDatasets`
|
||||
list, it will be used as the default value for the `dataset` parameter.
|
||||
|
||||
## Example
|
||||
|
||||
|
||||
@@ -17,8 +17,11 @@ package bigquerycommon
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
bigqueryapi "cloud.google.com/go/bigquery"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||
)
|
||||
|
||||
@@ -69,3 +72,49 @@ func BQTypeStringFromToolType(toolType string) (string, error) {
|
||||
return "", fmt.Errorf("unsupported tool parameter type for BigQuery: %s", toolType)
|
||||
}
|
||||
}
|
||||
|
||||
// InitializeDatasetParameters generates project and dataset tool parameters based on allowedDatasets.
|
||||
func InitializeDatasetParameters(
|
||||
allowedDatasets []string,
|
||||
defaultProjectID string,
|
||||
projectKey, datasetKey string,
|
||||
projectDescription, datasetDescription string,
|
||||
) (projectParam, datasetParam tools.Parameter) {
|
||||
if len(allowedDatasets) > 0 {
|
||||
if len(allowedDatasets) == 1 {
|
||||
parts := strings.Split(allowedDatasets[0], ".")
|
||||
defaultProjectID = parts[0]
|
||||
datasetID := parts[1]
|
||||
projectDescription += fmt.Sprintf(" Must be `%s`.", defaultProjectID)
|
||||
datasetDescription += fmt.Sprintf(" Must be `%s`.", datasetID)
|
||||
datasetParam = tools.NewStringParameterWithDefault(datasetKey, datasetID, datasetDescription)
|
||||
} else {
|
||||
datasetIDsByProject := make(map[string][]string)
|
||||
for _, ds := range allowedDatasets {
|
||||
parts := strings.Split(ds, ".")
|
||||
project := parts[0]
|
||||
dataset := parts[1]
|
||||
datasetIDsByProject[project] = append(datasetIDsByProject[project], fmt.Sprintf("`%s`", dataset))
|
||||
}
|
||||
|
||||
var datasetDescriptions, projectIDList []string
|
||||
for project, datasets := range datasetIDsByProject {
|
||||
sort.Strings(datasets)
|
||||
projectIDList = append(projectIDList, fmt.Sprintf("`%s`", project))
|
||||
datasetList := strings.Join(datasets, ", ")
|
||||
datasetDescriptions = append(datasetDescriptions, fmt.Sprintf("%s from project `%s`", datasetList, project))
|
||||
}
|
||||
sort.Strings(projectIDList)
|
||||
sort.Strings(datasetDescriptions)
|
||||
projectDescription += fmt.Sprintf(" Must be one of the following: %s.", strings.Join(projectIDList, ", "))
|
||||
datasetDescription += fmt.Sprintf(" Must be one of the allowed datasets: %s.", strings.Join(datasetDescriptions, "; "))
|
||||
datasetParam = tools.NewStringParameter(datasetKey, datasetDescription)
|
||||
}
|
||||
} else {
|
||||
datasetParam = tools.NewStringParameter(datasetKey, datasetDescription)
|
||||
}
|
||||
|
||||
projectParam = tools.NewStringParameterWithDefault(projectKey, defaultProjectID, projectDescription)
|
||||
|
||||
return projectParam, datasetParam
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
|
||||
)
|
||||
|
||||
const kind string = "bigquery-get-table-info"
|
||||
@@ -49,6 +50,8 @@ type compatibleSource interface {
|
||||
BigQueryClient() *bigqueryapi.Client
|
||||
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
|
||||
UseClientAuthorization() bool
|
||||
IsDatasetAllowed(projectID, datasetID string) bool
|
||||
BigQueryAllowedDatasets() []string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
@@ -84,8 +87,19 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
projectParameter := tools.NewStringParameterWithDefault(projectKey, s.BigQueryProject(), "The Google Cloud project ID containing the dataset and table.")
|
||||
datasetParameter := tools.NewStringParameter(datasetKey, "The table's parent dataset.")
|
||||
defaultProjectID := s.BigQueryProject()
|
||||
projectDescription := "The Google Cloud project ID containing the dataset and table."
|
||||
datasetDescription := "The table's parent dataset."
|
||||
var datasetParameter tools.Parameter
|
||||
var projectParameter tools.Parameter
|
||||
|
||||
projectParameter, datasetParameter = bqutil.InitializeDatasetParameters(
|
||||
s.BigQueryAllowedDatasets(),
|
||||
defaultProjectID,
|
||||
projectKey, datasetKey,
|
||||
projectDescription, datasetDescription,
|
||||
)
|
||||
|
||||
tableParameter := tools.NewStringParameter(tableKey, "The table to get metadata information.")
|
||||
parameters := tools.Parameters{projectParameter, datasetParameter, tableParameter}
|
||||
|
||||
@@ -93,15 +107,16 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: parameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: parameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -116,11 +131,12 @@ type Tool struct {
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Client *bigqueryapi.Client
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
Statement string
|
||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
@@ -140,6 +156,10 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", tableKey)
|
||||
}
|
||||
|
||||
if !t.IsDatasetAllowed(projectId, datasetId) {
|
||||
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
|
||||
var err error
|
||||
|
||||
@@ -17,14 +17,13 @@ package bigquerylisttableids
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
bigqueryapi "cloud.google.com/go/bigquery"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
|
||||
"google.golang.org/api/iterator"
|
||||
)
|
||||
|
||||
@@ -92,39 +91,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
projectDescription := "The Google Cloud project ID containing the dataset."
|
||||
datasetDescription := "The dataset to list table ids."
|
||||
var datasetParameter tools.Parameter
|
||||
allowedDatasets := s.BigQueryAllowedDatasets()
|
||||
if len(allowedDatasets) > 0 {
|
||||
if len(allowedDatasets) == 1 {
|
||||
parts := strings.Split(allowedDatasets[0], ".")
|
||||
defaultProjectID = parts[0]
|
||||
datasetID := parts[1]
|
||||
projectDescription += fmt.Sprintf(" Must be `%s`.", defaultProjectID)
|
||||
datasetDescription += fmt.Sprintf(" Must be `%s`.", datasetID)
|
||||
datasetParameter = tools.NewStringParameterWithDefault(datasetKey, datasetID, datasetDescription)
|
||||
} else {
|
||||
datasetIDsByProject := make(map[string][]string)
|
||||
for _, ds := range allowedDatasets {
|
||||
parts := strings.Split(ds, ".")
|
||||
project := parts[0]
|
||||
dataset := parts[1]
|
||||
datasetIDsByProject[project] = append(datasetIDsByProject[project], fmt.Sprintf("`%s`", dataset))
|
||||
}
|
||||
var projectParameter tools.Parameter
|
||||
|
||||
var datasetDescriptions, projectIDList []string
|
||||
for project, datasets := range datasetIDsByProject {
|
||||
sort.Strings(datasets)
|
||||
projectIDList = append(projectIDList, fmt.Sprintf("`%s`", project))
|
||||
datasetList := strings.Join(datasets, ", ")
|
||||
datasetDescriptions = append(datasetDescriptions, fmt.Sprintf("%s from project `%s`", datasetList, project))
|
||||
}
|
||||
projectDescription += fmt.Sprintf(" Must be one of the following: %s.", strings.Join(projectIDList, ", "))
|
||||
datasetDescription += fmt.Sprintf(" Must be one of the allowed datasets: %s.", strings.Join(datasetDescriptions, "; "))
|
||||
datasetParameter = tools.NewStringParameter(datasetKey, datasetDescription)
|
||||
}
|
||||
} else {
|
||||
datasetParameter = tools.NewStringParameter(datasetKey, datasetDescription)
|
||||
}
|
||||
projectParameter := tools.NewStringParameterWithDefault(projectKey, defaultProjectID, projectDescription)
|
||||
projectParameter, datasetParameter = bqutil.InitializeDatasetParameters(
|
||||
s.BigQueryAllowedDatasets(),
|
||||
defaultProjectID,
|
||||
projectKey, datasetKey,
|
||||
projectDescription, datasetDescription,
|
||||
)
|
||||
|
||||
parameters := tools.Parameters{projectParameter, datasetParameter}
|
||||
|
||||
|
||||
@@ -274,6 +274,11 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
|
||||
"source": "my-instance",
|
||||
"description": "Tool to list table within a dataset",
|
||||
},
|
||||
"get-table-info-restricted": map[string]any{
|
||||
"kind": "bigquery-get-table-info",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to get table info",
|
||||
},
|
||||
"execute-sql-restricted": map[string]any{
|
||||
"kind": "bigquery-execute-sql",
|
||||
"source": "my-instance",
|
||||
@@ -318,6 +323,8 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
|
||||
runListDatasetIdsWithRestriction(t, allowedDatasetName1, allowedDatasetName2)
|
||||
runListTableIdsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, allowedForecastTableName1)
|
||||
runListTableIdsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, allowedForecastTableName2)
|
||||
runGetTableInfoWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, disallowedTableName)
|
||||
runGetTableInfoWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, disallowedTableName)
|
||||
runExecuteSqlWithRestriction(t, allowedTableNameParam1, disallowedTableNameParam)
|
||||
runExecuteSqlWithRestriction(t, allowedTableNameParam2, disallowedTableNameParam)
|
||||
runConversationalAnalyticsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, disallowedTableName)
|
||||
@@ -2582,6 +2589,58 @@ func runListTableIdsWithRestriction(t *testing.T, allowedDatasetName, disallowed
|
||||
}
|
||||
}
|
||||
|
||||
func runGetTableInfoWithRestriction(t *testing.T, allowedDatasetName, disallowedDatasetName, allowedTableName, disallowedTableName string) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
dataset string
|
||||
table string
|
||||
wantStatusCode int
|
||||
wantInError string
|
||||
}{
|
||||
{
|
||||
name: "invoke on allowed table",
|
||||
dataset: allowedDatasetName,
|
||||
table: allowedTableName,
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "invoke on disallowed table",
|
||||
dataset: disallowedDatasetName,
|
||||
table: disallowedTableName,
|
||||
wantStatusCode: http.StatusBadRequest,
|
||||
wantInError: fmt.Sprintf("access denied to dataset '%s'", disallowedDatasetName),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"dataset":"%s", "table":"%s"}`, tc.dataset, tc.table)))
|
||||
req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:5000/api/tool/get-table-info-restricted/invoke", body)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
if tc.wantInError != "" {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
if !strings.Contains(string(bodyBytes), tc.wantInError) {
|
||||
t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runExecuteSqlWithRestriction(t *testing.T, allowedTableFullName, disallowedTableFullName string) {
|
||||
allowedTableParts := strings.Split(strings.Trim(allowedTableFullName, "`"), ".")
|
||||
if len(allowedTableParts) != 3 {
|
||||
|
||||
Reference in New Issue
Block a user