mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-10 07:58:12 -05:00
feat(source/bigquery): Add support for datasets selection (#1313)
## Description

- bigquery Source: The source configuration now supports a new `allowedDatasets` field, which defines the list of datasets the tools are allowed to access.
- bigquery-list-table-ids: Now verifies that the requested dataset is in the allowed datasets list before listing its tables. An error is returned if access is not permitted.

## PR Checklist

> Thank you for opening a Pull Request! Before submitting your PR, there are a
> few things you can do to make sure it goes smoothly:

- [ ] Make sure you reviewed [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/langchain-google-alloydb-pg-python/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)
- [ ] Make sure to add `!` if this involves a breaking change

🛠️ Fixes #<issue_number_goes_here>

---------

Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com>
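For orientation before the diff: under the configuration documented below, a source restricted to two datasets would look roughly like this sketch (the source name `my-bigquery-source` is illustrative):

```yaml
sources:
  my-bigquery-source:
    kind: "bigquery"
    project: "my-project-id"
    # Tools using this source may only touch the datasets listed here.
    allowedDatasets:
      - "my_dataset_1"               # dataset in the source's own project
      - "other_project.my_dataset_2" # dataset qualified with another project
```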
@@ -111,6 +111,9 @@ sources:
     kind: "bigquery"
     project: "my-project-id"
     # location: "US" # Optional: Specifies the location for query jobs.
+    # allowedDatasets: # Optional: Restricts tool access to a specific list of datasets.
+    # - "my_dataset_1"
+    # - "other_project.my_dataset_2"
 ```

 Initialize a BigQuery source that uses the client's access token:
@@ -122,6 +125,9 @@ sources:
     project: "my-project-id"
     useClientOAuth: true
     # location: "US" # Optional: Specifies the location for query jobs.
+    # allowedDatasets: # Optional: Restricts tool access to a specific list of datasets.
+    # - "my_dataset_1"
+    # - "other_project.my_dataset_2"
 ```

 ## Reference
@@ -131,4 +137,5 @@ sources:
 | kind | string | true | Must be "bigquery". |
 | project | string | true | Id of the Google Cloud project to use for billing and as the default project for BigQuery resources. |
 | location | string | false | Specifies the location (e.g., 'us', 'asia-northeast1') in which to run the query job. This location must match the location of any tables referenced in the query. Defaults to the table's location or 'US' if the location cannot be determined. [Learn More](https://cloud.google.com/bigquery/docs/locations) |
+| allowedDatasets | []string | false | An optional list of dataset IDs that tools using this source are allowed to access. If provided, any tool operation attempting to access a dataset not in this list will be rejected. To enforce this, two types of operations are also disallowed: 1) Dataset-level operations (e.g., `CREATE SCHEMA`), and 2) operations where table access cannot be statically analyzed (e.g., `EXECUTE IMMEDIATE`, `CREATE PROCEDURE`). If a single dataset is provided, it will be treated as the default for prebuilt tools. |
 | useClientOAuth | bool | false | If true, forwards the client's OAuth access token from the "Authorization" header to downstream queries. |
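As the `allowedDatasets` row above notes, a single-entry list doubles as the default dataset for prebuilt tools. A hedged sketch of that configuration (source name illustrative):

```yaml
sources:
  my-bigquery-source:
    kind: "bigquery"
    project: "my-project-id"
    # With exactly one entry, prebuilt tools default to this dataset, and any
    # request naming a different dataset is rejected.
    allowedDatasets:
      - "my_dataset"
```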
@@ -15,10 +15,19 @@ It's compatible with the following sources:

 - [bigquery](../../sources/bigquery.md)

-`bigquery-list-table-ids` takes a required `dataset` parameter to specify the dataset
-from which to list table IDs. It also optionally accepts a `project` parameter to
-define the Google Cloud project ID. If the `project` parameter is not provided, the
-tool defaults to using the project defined in the source configuration.
+`bigquery-list-table-ids` accepts the following parameters:
+- **`dataset`** (required): Specifies the dataset from which to list table IDs.
+- **`project`** (optional): Defines the Google Cloud project ID. If not provided,
+  the tool defaults to the project from the source configuration.
+
+The tool's behavior regarding these parameters is influenced by the
+`allowedDatasets` restriction on the `bigquery` source:
+- **Without `allowedDatasets` restriction:** The tool can list tables from any
+  dataset specified by the `dataset` and `project` parameters.
+- **With `allowedDatasets` restriction:** Before listing tables, the tool verifies
+  that the requested dataset is in the allowed list. If it is not, the request is
+  denied. If only one dataset is specified in the `allowedDatasets` list, it
+  will be used as the default value for the `dataset` parameter.

 ## Example
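The integration test at the bottom of this diff exercises the restriction over the server's HTTP API. As a standalone sketch of the same call (the tool name `list-table-ids-restricted`, port 5000, and the dataset value all mirror the test setup and are assumptions, not fixed values):

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Invoke the hypothetical restricted tool over the toolbox HTTP API.
	url := "http://127.0.0.1:5000/api/tool/list-table-ids-restricted/invoke"
	body := bytes.NewBufferString(`{"dataset":"my_dataset_1"}`)

	req, err := http.NewRequest(http.MethodPost, url, body)
	if err != nil {
		panic(err)
	}
	req.Header.Add("Content-type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// For a disallowed dataset, the body is expected to contain
	// "access denied to dataset ...", per the tool's error message.
	respBody, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.StatusCode, string(respBody))
}
```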
@@ -17,6 +17,8 @@ package bigquery
 import (
 	"context"
 	"fmt"
+	"net/http"
+	"strings"

 	bigqueryapi "cloud.google.com/go/bigquery"
 	"github.com/goccy/go-yaml"
@@ -26,6 +28,7 @@ import (
 	"golang.org/x/oauth2"
 	"golang.org/x/oauth2/google"
 	bigqueryrestapi "google.golang.org/api/bigquery/v2"
+	"google.golang.org/api/googleapi"
 	"google.golang.org/api/option"
 )
@@ -52,11 +55,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources

 type Config struct {
 	// BigQuery configs
-	Name           string `yaml:"name" validate:"required"`
-	Kind           string `yaml:"kind" validate:"required"`
-	Project        string `yaml:"project" validate:"required"`
-	Location       string `yaml:"location"`
-	UseClientOAuth bool   `yaml:"useClientOAuth"`
+	Name            string   `yaml:"name" validate:"required"`
+	Kind            string   `yaml:"kind" validate:"required"`
+	Project         string   `yaml:"project" validate:"required"`
+	Location        string   `yaml:"location"`
+	AllowedDatasets []string `yaml:"allowedDatasets"`
+	UseClientOAuth  bool     `yaml:"useClientOAuth"`
 }

 func (r Config) SourceConfigKind() string {
@@ -84,6 +88,37 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
 		}
 	}

+	allowedDatasets := make(map[string]struct{})
+	// Get full id of allowed datasets and verify they exist.
+	if len(r.AllowedDatasets) > 0 {
+		for _, allowed := range r.AllowedDatasets {
+			var projectID, datasetID, allowedFullID string
+			if strings.Contains(allowed, ".") {
+				parts := strings.Split(allowed, ".")
+				if len(parts) != 2 {
+					return nil, fmt.Errorf("invalid allowedDataset format: %q, expected 'project.dataset' or 'dataset'", allowed)
+				}
+				projectID = parts[0]
+				datasetID = parts[1]
+				allowedFullID = allowed
+			} else {
+				projectID = client.Project()
+				datasetID = allowed
+				allowedFullID = fmt.Sprintf("%s.%s", projectID, datasetID)
+			}
+
+			dataset := client.DatasetInProject(projectID, datasetID)
+			_, err := dataset.Metadata(ctx)
+			if err != nil {
+				if gerr, ok := err.(*googleapi.Error); ok && gerr.Code == http.StatusNotFound {
+					return nil, fmt.Errorf("allowedDataset '%s' not found in project '%s'", datasetID, projectID)
+				}
+				return nil, fmt.Errorf("failed to verify allowedDataset '%s' in project '%s': %w", datasetID, projectID, err)
+			}
+			allowedDatasets[allowedFullID] = struct{}{}
+		}
+	}
+
 	s := &Source{
 		Name: r.Name,
 		Kind: SourceKind,
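The hunk above normalizes every entry to a fully qualified `project.dataset` key before storing it, so all later checks are exact map lookups. A minimal, client-free sketch of just that parsing rule (`normalizeDataset` is a hypothetical helper; the real code additionally calls `dataset.Metadata(ctx)` to verify the dataset exists):

```go
package main

import (
	"fmt"
	"strings"
)

// normalizeDataset mirrors the rule in Initialize: a bare "dataset" is
// qualified with the default project, "project.dataset" is kept as-is,
// and anything with more than one dot is rejected.
func normalizeDataset(defaultProject, allowed string) (string, error) {
	if !strings.Contains(allowed, ".") {
		return fmt.Sprintf("%s.%s", defaultProject, allowed), nil
	}
	parts := strings.Split(allowed, ".")
	if len(parts) != 2 {
		return "", fmt.Errorf("invalid allowedDataset format: %q, expected 'project.dataset' or 'dataset'", allowed)
	}
	return allowed, nil
}

func main() {
	for _, in := range []string{"my_dataset", "other_project.my_dataset_2", "a.b.c"} {
		full, err := normalizeDataset("my-project-id", in)
		fmt.Printf("%-28s -> %q err=%v\n", in, full, err)
	}
}
```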
@@ -94,6 +129,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
 		TokenSource:        tokenSource,
 		MaxQueryResultRows: 50,
 		ClientCreator:      clientCreator,
+		AllowedDatasets:    allowedDatasets,
 		UseClientOAuth:     r.UseClientOAuth,
 	}
 	return s, nil
@@ -113,6 +149,7 @@ type Source struct {
 	TokenSource        oauth2.TokenSource
 	MaxQueryResultRows int
 	ClientCreator      BigqueryClientCreator
+	AllowedDatasets    map[string]struct{}
 	UseClientOAuth     bool
 }
@@ -153,6 +190,29 @@ func (s *Source) BigQueryClientCreator() BigqueryClientCreator {
 	return s.ClientCreator
 }

+func (s *Source) BigQueryAllowedDatasets() []string {
+	if len(s.AllowedDatasets) == 0 {
+		return nil
+	}
+	datasets := make([]string, 0, len(s.AllowedDatasets))
+	for d := range s.AllowedDatasets {
+		datasets = append(datasets, d)
+	}
+	return datasets
+}
+
+// IsDatasetAllowed checks if a given dataset is accessible based on the source's configuration.
+func (s *Source) IsDatasetAllowed(projectID, datasetID string) bool {
+	// If the normalized map is empty, it means no restrictions were configured.
+	if len(s.AllowedDatasets) == 0 {
+		return true
+	}
+
+	targetDataset := fmt.Sprintf("%s.%s", projectID, datasetID)
+	_, ok := s.AllowedDatasets[targetDataset]
+	return ok
+}
+
 func initBigQueryConnection(
 	ctx context.Context,
 	tracer trace.Tracer,
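`IsDatasetAllowed` has exactly two modes: an empty map means unrestricted, otherwise only exact `project.dataset` keys pass. A self-contained sketch of that lookup (the standalone function and sample values are illustrative):

```go
package main

import "fmt"

// isDatasetAllowed reproduces the check above against a plain map:
// an empty set means "no restriction configured", otherwise exact match only.
func isDatasetAllowed(allowed map[string]struct{}, projectID, datasetID string) bool {
	if len(allowed) == 0 {
		return true
	}
	_, ok := allowed[fmt.Sprintf("%s.%s", projectID, datasetID)]
	return ok
}

func main() {
	allowed := map[string]struct{}{"my-project-id.my_dataset_1": {}}
	fmt.Println(isDatasetAllowed(allowed, "my-project-id", "my_dataset_1")) // true
	fmt.Println(isDatasetAllowed(allowed, "my-project-id", "other_ds"))    // false
	fmt.Println(isDatasetAllowed(nil, "any-project", "any_dataset"))       // true
}
```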
@@ -69,6 +69,27 @@ func TestParseFromYamlBigQuery(t *testing.T) {
 				},
 			},
 		},
+		{
+			desc: "with allowed datasets example",
+			in: `
+			sources:
+				my-instance:
+					kind: bigquery
+					project: my-project
+					location: us
+					allowedDatasets:
+						- my_dataset
+			`,
+			want: server.SourceConfigs{
+				"my-instance": bigquery.Config{
+					Name:            "my-instance",
+					Kind:            bigquery.SourceKind,
+					Project:         "my-project",
+					Location:        "us",
+					AllowedDatasets: []string{"my_dataset"},
+				},
+			},
+		},
 	}
 	for _, tc := range tcs {
 		t.Run(tc.desc, func(t *testing.T) {
@@ -17,6 +17,8 @@ package bigquerylisttableids
 import (
 	"context"
 	"fmt"
+	"sort"
+	"strings"

 	bigqueryapi "cloud.google.com/go/bigquery"
 	yaml "github.com/goccy/go-yaml"
@@ -49,6 +51,8 @@ type compatibleSource interface {
 	BigQueryClientCreator() bigqueryds.BigqueryClientCreator
 	BigQueryProject() string
 	UseClientAuthorization() bool
+	IsDatasetAllowed(projectID, datasetID string) bool
+	BigQueryAllowedDatasets() []string
 }

 // validate compatible sources are still compatible
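Any source claiming compatibility with this tool now has to provide both methods. A hedged sketch of a stub satisfying just the new slice of the contract (the narrowed interface and stub type are hypothetical, for illustration only):

```go
package main

import "fmt"

// datasetRestrictor is a hypothetical narrowing of compatibleSource to the
// two methods this change adds.
type datasetRestrictor interface {
	IsDatasetAllowed(projectID, datasetID string) bool
	BigQueryAllowedDatasets() []string
}

// unrestrictedSource mimics a source configured without allowedDatasets.
type unrestrictedSource struct{}

func (unrestrictedSource) IsDatasetAllowed(projectID, datasetID string) bool { return true }
func (unrestrictedSource) BigQueryAllowedDatasets() []string                { return nil }

func main() {
	var s datasetRestrictor = unrestrictedSource{}
	fmt.Println(s.IsDatasetAllowed("p", "d"), s.BigQueryAllowedDatasets())
}
```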
@@ -84,8 +88,44 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
 		return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
 	}

-	projectParameter := tools.NewStringParameterWithDefault(projectKey, s.BigQueryProject(), "The Google Cloud project ID containing the dataset.")
-	datasetParameter := tools.NewStringParameter(datasetKey, "The dataset to list table ids.")
+	defaultProjectID := s.BigQueryProject()
+	projectDescription := "The Google Cloud project ID containing the dataset."
+	datasetDescription := "The dataset to list table ids."
+	var datasetParameter tools.Parameter
+	allowedDatasets := s.BigQueryAllowedDatasets()
+	if len(allowedDatasets) > 0 {
+		if len(allowedDatasets) == 1 {
+			parts := strings.Split(allowedDatasets[0], ".")
+			defaultProjectID = parts[0]
+			datasetID := parts[1]
+			projectDescription += fmt.Sprintf(" Must be `%s`.", defaultProjectID)
+			datasetDescription += fmt.Sprintf(" Must be `%s`.", datasetID)
+			datasetParameter = tools.NewStringParameterWithDefault(datasetKey, datasetID, datasetDescription)
+		} else {
+			datasetIDsByProject := make(map[string][]string)
+			for _, ds := range allowedDatasets {
+				parts := strings.Split(ds, ".")
+				project := parts[0]
+				dataset := parts[1]
+				datasetIDsByProject[project] = append(datasetIDsByProject[project], fmt.Sprintf("`%s`", dataset))
+			}
+
+			var datasetDescriptions, projectIDList []string
+			for project, datasets := range datasetIDsByProject {
+				sort.Strings(datasets)
+				projectIDList = append(projectIDList, fmt.Sprintf("`%s`", project))
+				datasetList := strings.Join(datasets, ", ")
+				datasetDescriptions = append(datasetDescriptions, fmt.Sprintf("%s from project `%s`", datasetList, project))
+			}
+			projectDescription += fmt.Sprintf(" Must be one of the following: %s.", strings.Join(projectIDList, ", "))
+			datasetDescription += fmt.Sprintf(" Must be one of the allowed datasets: %s.", strings.Join(datasetDescriptions, "; "))
+			datasetParameter = tools.NewStringParameter(datasetKey, datasetDescription)
+		}
+	} else {
+		datasetParameter = tools.NewStringParameter(datasetKey, datasetDescription)
+	}
+	projectParameter := tools.NewStringParameterWithDefault(projectKey, defaultProjectID, projectDescription)

 	parameters := tools.Parameters{projectParameter, datasetParameter}

 	mcpManifest := tools.McpManifest{
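For a concrete sense of the descriptions this hunk builds: assuming `allowedDatasets` contained `proj_a.sales` and `proj_a.users` (hypothetical names), the dataset parameter's description would read roughly "The dataset to list table ids. Must be one of the allowed datasets: `sales`, `users` from project `proj_a`." With a single allowed dataset, the parameter instead gains that dataset as its default value plus a "Must be ..." constraint, which is how one entry becomes the de facto default.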
@@ -96,15 +136,16 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)

 	// finish tool setup
 	t := Tool{
-		Name:           cfg.Name,
-		Kind:           kind,
-		Parameters:     parameters,
-		AuthRequired:   cfg.AuthRequired,
-		UseClientOAuth: s.UseClientAuthorization(),
-		ClientCreator:  s.BigQueryClientCreator(),
-		Client:         s.BigQueryClient(),
-		manifest:       tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
-		mcpManifest:    mcpManifest,
+		Name:             cfg.Name,
+		Kind:             kind,
+		Parameters:       parameters,
+		AuthRequired:     cfg.AuthRequired,
+		UseClientOAuth:   s.UseClientAuthorization(),
+		ClientCreator:    s.BigQueryClientCreator(),
+		Client:           s.BigQueryClient(),
+		IsDatasetAllowed: s.IsDatasetAllowed,
+		manifest:         tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
+		mcpManifest:      mcpManifest,
 	}
 	return t, nil
 }
@@ -119,11 +160,12 @@ type Tool struct {
 	UseClientOAuth bool             `yaml:"useClientOAuth"`
 	Parameters     tools.Parameters `yaml:"parameters"`

-	Client        *bigqueryapi.Client
-	ClientCreator bigqueryds.BigqueryClientCreator
-	Statement     string
-	manifest      tools.Manifest
-	mcpManifest   tools.McpManifest
+	Client           *bigqueryapi.Client
+	ClientCreator    bigqueryds.BigqueryClientCreator
+	IsDatasetAllowed func(projectID, datasetID string) bool
+	Statement        string
+	manifest         tools.Manifest
+	mcpManifest      tools.McpManifest
 }

 func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
@@ -138,6 +180,10 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
 		return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
 	}

+	if !t.IsDatasetAllowed(projectId, datasetId) {
+		return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
+	}
+
 	bqClient := t.Client
 	// Initialize new client if using user OAuth token
 	if t.UseClientOAuth {
@@ -161,7 +207,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
 			break
 		}
 		if err != nil {
-			return nil, fmt.Errorf("failed to iterate through tables in dataset %s.%s: %w", bqClient.Project(), datasetId, err)
+			return nil, fmt.Errorf("failed to iterate through tables in dataset %s.%s: %w", projectId, datasetId, err)
 		}

 	// Remove leading and trailing quotes
@@ -23,6 +23,7 @@ import (
 	"net/http"
 	"os"
 	"regexp"
+	"sort"
 	"strings"
 	"testing"
 	"time"
@@ -189,6 +190,102 @@ func TestBigQueryToolEndpoints(t *testing.T) {
 	runBigQueryConversationalAnalyticsInvokeTest(t, datasetName, tableName, dataInsightsWant)
 }

+func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+	defer cancel()
+
+	client, err := initBigQueryConnection(BigqueryProject)
+	if err != nil {
+		t.Fatalf("unable to create BigQuery client: %s", err)
+	}
+
+	// Create two datasets, one allowed, one not.
+	baseName := strings.ReplaceAll(uuid.New().String(), "-", "")
+	allowedDatasetName1 := fmt.Sprintf("allowed_dataset_1_%s", baseName)
+	allowedDatasetName2 := fmt.Sprintf("allowed_dataset_2_%s", baseName)
+	disallowedDatasetName := fmt.Sprintf("disallowed_dataset_%s", baseName)
+	allowedTableName1 := "allowed_table_1"
+	allowedTableName2 := "allowed_table_2"
+	disallowedTableName := "disallowed_table"
+	allowedForecastTableName1 := "allowed_forecast_table_1"
+	allowedForecastTableName2 := "allowed_forecast_table_2"
+	disallowedForecastTableName := "disallowed_forecast_table"
+
+	// Setup allowed table
+	allowedTableNameParam1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedTableName1)
+	createAllowedTableStmt1 := fmt.Sprintf("CREATE TABLE %s (id INT64)", allowedTableNameParam1)
+	teardownAllowed1 := setupBigQueryTable(t, ctx, client, createAllowedTableStmt1, "", allowedDatasetName1, allowedTableNameParam1, nil)
+	defer teardownAllowed1(t)
+
+	allowedTableNameParam2 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName2, allowedTableName2)
+	createAllowedTableStmt2 := fmt.Sprintf("CREATE TABLE %s (id INT64)", allowedTableNameParam2)
+	teardownAllowed2 := setupBigQueryTable(t, ctx, client, createAllowedTableStmt2, "", allowedDatasetName2, allowedTableNameParam2, nil)
+	defer teardownAllowed2(t)
+
+	// Setup allowed forecast table
+	allowedForecastTableFullName1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedForecastTableName1)
+	createForecastStmt1, insertForecastStmt1, forecastParams1 := getBigQueryForecastToolInfo(allowedForecastTableFullName1)
+	teardownAllowedForecast1 := setupBigQueryTable(t, ctx, client, createForecastStmt1, insertForecastStmt1, allowedDatasetName1, allowedForecastTableFullName1, forecastParams1)
+	defer teardownAllowedForecast1(t)
+
+	allowedForecastTableFullName2 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName2, allowedForecastTableName2)
+	createForecastStmt2, insertForecastStmt2, forecastParams2 := getBigQueryForecastToolInfo(allowedForecastTableFullName2)
+	teardownAllowedForecast2 := setupBigQueryTable(t, ctx, client, createForecastStmt2, insertForecastStmt2, allowedDatasetName2, allowedForecastTableFullName2, forecastParams2)
+	defer teardownAllowedForecast2(t)
+
+	// Setup disallowed table
+	disallowedTableNameParam := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, disallowedTableName)
+	createDisallowedTableStmt := fmt.Sprintf("CREATE TABLE %s (id INT64)", disallowedTableNameParam)
+	teardownDisallowed := setupBigQueryTable(t, ctx, client, createDisallowedTableStmt, "", disallowedDatasetName, disallowedTableNameParam, nil)
+	defer teardownDisallowed(t)
+
+	// Setup disallowed forecast table
+	disallowedForecastTableFullName := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, disallowedForecastTableName)
+	createDisallowedForecastStmt, insertDisallowedForecastStmt, disallowedForecastParams := getBigQueryForecastToolInfo(disallowedForecastTableFullName)
+	teardownDisallowedForecast := setupBigQueryTable(t, ctx, client, createDisallowedForecastStmt, insertDisallowedForecastStmt, disallowedDatasetName, disallowedForecastTableFullName, disallowedForecastParams)
+	defer teardownDisallowedForecast(t)
+
+	// Configure source with dataset restriction.
+	sourceConfig := getBigQueryVars(t)
+	sourceConfig["allowedDatasets"] = []string{allowedDatasetName1, allowedDatasetName2}
+
+	// Configure tool
+	toolsConfig := map[string]any{
+		"list-table-ids-restricted": map[string]any{
+			"kind":        "bigquery-list-table-ids",
+			"source":      "my-instance",
+			"description": "Tool to list table within a dataset",
+		},
+	}
+
+	// Create config file
+	config := map[string]any{
+		"sources": map[string]any{
+			"my-instance": sourceConfig,
+		},
+		"tools": toolsConfig,
+	}
+
+	// Start server
+	cmd, cleanup, err := tests.StartCmd(ctx, config)
+	if err != nil {
+		t.Fatalf("command initialization returned an error: %s", err)
+	}
+	defer cleanup()
+
+	waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
+	defer cancel()
+	out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
+	if err != nil {
+		t.Logf("toolbox command logs: \n%s", out)
+		t.Fatalf("toolbox didn't start successfully: %s", err)
+	}
+
+	// Run tests
+	runListTableIdsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, allowedForecastTableName1)
+	runListTableIdsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, allowedForecastTableName2)
+}
+
 // getBigQueryParamToolInfo returns statements and param for my-tool for bigquery kind
 func getBigQueryParamToolInfo(tableName string) (string, string, string, string, string, string, []bigqueryapi.QueryParameter) {
 	createStatement := fmt.Sprintf(`
@@ -295,19 +392,21 @@ func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.C
 		t.Fatalf("Create table job for %s failed: %v", tableName, err)
 	}

-	// Insert test data
-	insertQuery := client.Query(insertStatement)
-	insertQuery.Parameters = params
-	insertJob, err := insertQuery.Run(ctx)
-	if err != nil {
-		t.Fatalf("Failed to start insert job for %s: %v", tableName, err)
-	}
-	insertStatus, err := insertJob.Wait(ctx)
-	if err != nil {
-		t.Fatalf("Failed to wait for insert job for %s: %v", tableName, err)
-	}
-	if err := insertStatus.Err(); err != nil {
-		t.Fatalf("Insert job for %s failed: %v", tableName, err)
+	if len(params) > 0 {
+		// Insert test data
+		insertQuery := client.Query(insertStatement)
+		insertQuery.Parameters = params
+		insertJob, err := insertQuery.Run(ctx)
+		if err != nil {
+			t.Fatalf("Failed to start insert job for %s: %v", tableName, err)
+		}
+		insertStatus, err := insertJob.Wait(ctx)
+		if err != nil {
+			t.Fatalf("Failed to wait for insert job for %s: %v", tableName, err)
+		}
+		if err := insertStatus.Err(); err != nil {
+			t.Fatalf("Insert job for %s failed: %v", tableName, err)
+		}
 	}

 	return func(t *testing.T) {
@@ -1731,3 +1830,86 @@ func runBigQueryConversationalAnalyticsInvokeTest(t *testing.T, datasetName, tab
 		})
 	}
 }
+
+func runListTableIdsWithRestriction(t *testing.T, allowedDatasetName, disallowedDatasetName string, allowedTableNames ...string) {
+	sort.Strings(allowedTableNames)
+	var quotedNames []string
+	for _, name := range allowedTableNames {
+		quotedNames = append(quotedNames, fmt.Sprintf(`"%s"`, name))
+	}
+	wantResult := fmt.Sprintf(`[%s]`, strings.Join(quotedNames, ","))
+
+	testCases := []struct {
+		name           string
+		dataset        string
+		wantStatusCode int
+		wantInResult   string
+		wantInError    string
+	}{
+		{
+			name:           "invoke on allowed dataset",
+			dataset:        allowedDatasetName,
+			wantStatusCode: http.StatusOK,
+			wantInResult:   wantResult,
+		},
+		{
+			name:           "invoke on disallowed dataset",
+			dataset:        disallowedDatasetName,
+			wantStatusCode: http.StatusBadRequest, // Or the specific error code returned
+			wantInError:    fmt.Sprintf("access denied to dataset '%s'", disallowedDatasetName),
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"dataset":"%s"}`, tc.dataset)))
+			req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:5000/api/tool/list-table-ids-restricted/invoke", body)
+			if err != nil {
+				t.Fatalf("unable to create request: %s", err)
+			}
+			req.Header.Add("Content-type", "application/json")
+			resp, err := http.DefaultClient.Do(req)
+			if err != nil {
+				t.Fatalf("unable to send request: %s", err)
+			}
+			defer resp.Body.Close()
+
+			if resp.StatusCode != tc.wantStatusCode {
+				bodyBytes, _ := io.ReadAll(resp.Body)
+				t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes))
+			}
+
+			if tc.wantInResult != "" {
+				var respBody map[string]interface{}
+				if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil {
+					t.Fatalf("error parsing response body: %v", err)
+				}
+				got, ok := respBody["result"].(string)
+				if !ok {
+					t.Fatalf("unable to find result in response body")
+				}
+
+				var gotSlice []string
+				if err := json.Unmarshal([]byte(got), &gotSlice); err != nil {
+					t.Fatalf("error unmarshalling result: %v", err)
+				}
+				sort.Strings(gotSlice)
+				sortedGotBytes, err := json.Marshal(gotSlice)
+				if err != nil {
+					t.Fatalf("error marshalling sorted result: %v", err)
+				}
+
+				if string(sortedGotBytes) != tc.wantInResult {
+					t.Errorf("unexpected result: got %q, want %q", string(sortedGotBytes), tc.wantInResult)
+				}
+			}
+
+			if tc.wantInError != "" {
+				bodyBytes, _ := io.ReadAll(resp.Body)
+				if !strings.Contains(string(bodyBytes), tc.wantInError) {
+					t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError)
+				}
+			}
+		})
+	}
+}