feat(tools/bigquery-analyze-contribution): add allowed dataset support (#1675)

## Description

> Should include a concise description of the changes (bug or feature),
it's
> impact, along with a summary of the solution

## PR Checklist

> Thank you for opening a Pull Request! Before submitting your PR, there
are a
> few things you can do to make sure it goes smoothly:

- [ ] Make sure you reviewed

[CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [ ] Make sure to open an issue as a

[bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
  before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)
- [ ] Make sure to add `!` if this involve a breaking change

🛠️ Fixes #<issue_number_goes_here>

---------

Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com>
This commit is contained in:
Huan Chen
2025-10-30 14:44:30 -07:00
committed by GitHub
parent a2006ad577
commit ef28e39e90
3 changed files with 217 additions and 37 deletions

View File

@@ -46,6 +46,13 @@ The behavior of this tool is influenced by the `writeMode` setting on its `bigqu
tools using the same source. This allows the `input_data` parameter to be a query that references temporary resources (e.g.,
`TEMP` tables) created within that session.
The tool's behavior is also influenced by the `allowedDatasets` restriction on the `bigquery` source:
- **Without `allowedDatasets` restriction:** The tool can use any table or query for the `input_data` parameter.
- **With `allowedDatasets` restriction:** The tool verifies that the `input_data` parameter only accesses tables within the allowed datasets.
- If `input_data` is a table ID, the tool checks if the table's dataset is in the allowed list.
- If `input_data` is a query, the tool performs a dry run to analyze the query and rejects it if it accesses any table outside the allowed list.
## Example

View File

@@ -25,6 +25,7 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools"
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
"google.golang.org/api/iterator"
)
@@ -50,6 +51,8 @@ type compatibleSource interface {
BigQueryRestService() *bigqueryrestapi.Service
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool
IsDatasetAllowed(projectID, datasetID string) bool
BigQueryAllowedDatasets() []string
BigQuerySession() bigqueryds.BigQuerySessionProvider
}
@@ -86,8 +89,17 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
inputDataParameter := tools.NewStringParameter("input_data",
"The data that contain the test and control data to analyze. Can be a fully qualified BigQuery table ID or a SQL query.")
allowedDatasets := s.BigQueryAllowedDatasets()
inputDataDescription := "The data that contain the test and control data to analyze. Can be a fully qualified BigQuery table ID or a SQL query."
if len(allowedDatasets) > 0 {
datasetIDs := []string{}
for _, ds := range allowedDatasets {
datasetIDs = append(datasetIDs, fmt.Sprintf("`%s`", ds))
}
inputDataDescription += fmt.Sprintf(" The query or table must only access datasets from the following list: %s.", strings.Join(datasetIDs, ", "))
}
inputDataParameter := tools.NewStringParameter("input_data", inputDataDescription)
contributionMetricParameter := tools.NewStringParameter("contribution_metric",
`The name of the column that contains the metric to analyze.
Provides the expression to use to calculate the metric you are analyzing.
@@ -123,17 +135,19 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Name: cfg.Name,
Kind: kind,
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
SessionProvider: s.BigQuerySession(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Name: cfg.Name,
Kind: kind,
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
IsDatasetAllowed: s.IsDatasetAllowed,
AllowedDatasets: allowedDatasets,
SessionProvider: s.BigQuerySession(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -148,12 +162,14 @@ type Tool struct {
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters tools.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
ClientCreator bigqueryds.BigqueryClientCreator
SessionProvider bigqueryds.BigQuerySessionProvider
manifest tools.Manifest
mcpManifest tools.McpManifest
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
ClientCreator bigqueryds.BigqueryClientCreator
IsDatasetAllowed func(projectID, datasetID string) bool
AllowedDatasets []string
SessionProvider bigqueryds.BigQuerySessionProvider
manifest tools.Manifest
mcpManifest tools.McpManifest
}
// Invoke runs the contribution analysis.
@@ -164,6 +180,22 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
return nil, fmt.Errorf("unable to cast input_data parameter %s", paramsMap["input_data"])
}
bqClient := t.Client
restService := t.RestService
var err error
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, restService, err = t.ClientCreator(tokenStr, true)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
}
modelID := fmt.Sprintf("contribution_analysis_model_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
var options []string
@@ -196,8 +228,54 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
var inputDataSource string
trimmedUpperInputData := strings.TrimSpace(strings.ToUpper(inputData))
if strings.HasPrefix(trimmedUpperInputData, "SELECT") || strings.HasPrefix(trimmedUpperInputData, "WITH") {
if len(t.AllowedDatasets) > 0 {
var connProps []*bigqueryapi.ConnectionProperty
session, err := t.SessionProvider(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
}
if session != nil {
connProps = []*bigqueryapi.ConnectionProperty{
{Key: "session_id", Value: session.ID},
}
}
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, inputData, nil, connProps)
if err != nil {
return nil, fmt.Errorf("query validation failed: %w", err)
}
statementType := dryRunJob.Statistics.Query.StatementType
if statementType != "SELECT" {
return nil, fmt.Errorf("the 'input_data' parameter only supports a table ID or a SELECT query. The provided query has statement type '%s'", statementType)
}
queryStats := dryRunJob.Statistics.Query
if queryStats != nil {
for _, tableRef := range queryStats.ReferencedTables {
if !t.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
return nil, fmt.Errorf("query in input_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId)
}
}
} else {
return nil, fmt.Errorf("could not analyze query in input_data to validate against allowed datasets")
}
}
inputDataSource = fmt.Sprintf("(%s)", inputData)
} else {
if len(t.AllowedDatasets) > 0 {
parts := strings.Split(inputData, ".")
var projectID, datasetID string
switch len(parts) {
case 3: // project.dataset.table
projectID, datasetID = parts[0], parts[1]
case 2: // dataset.table
projectID, datasetID = t.Client.Project(), parts[0]
default:
return nil, fmt.Errorf("invalid table ID format for 'input_data': %q. Expected 'dataset.table' or 'project.dataset.table'", inputData)
}
if !t.IsDatasetAllowed(projectID, datasetID) {
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, inputData)
}
}
inputDataSource = fmt.Sprintf("SELECT * FROM `%s`", inputData)
}
@@ -209,21 +287,6 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
inputDataSource,
)
bqClient := t.Client
var err error
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
}
createModelQuery := bqClient.Query(createModelSQL)
// Get session from provider if in protected mode.

View File

@@ -205,7 +205,7 @@ func TestBigQueryToolEndpoints(t *testing.T) {
}
func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Minute)
defer cancel()
client, err := initBigQueryConnection(BigqueryProject)
@@ -225,6 +225,9 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
allowedForecastTableName2 := "allowed_forecast_table_2"
disallowedForecastTableName := "disallowed_forecast_table"
allowedAnalyzeContributionTableName1 := "allowed_analyze_contribution_table_1"
allowedAnalyzeContributionTableName2 := "allowed_analyze_contribution_table_2"
disallowedAnalyzeContributionTableName := "disallowed_analyze_contribution_table"
// Setup allowed table
allowedTableNameParam1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedTableName1)
createAllowedTableStmt1 := fmt.Sprintf("CREATE TABLE %s (id INT64)", allowedTableNameParam1)
@@ -259,6 +262,23 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
teardownDisallowedForecast := setupBigQueryTable(t, ctx, client, createDisallowedForecastStmt, insertDisallowedForecastStmt, disallowedDatasetName, disallowedForecastTableFullName, disallowedForecastParams)
defer teardownDisallowedForecast(t)
// Setup allowed analyze contribution table
allowedAnalyzeContributionTableFullName1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedAnalyzeContributionTableName1)
createAnalyzeContributionStmt1, insertAnalyzeContributionStmt1, analyzeContributionParams1 := getBigQueryAnalyzeContributionToolInfo(allowedAnalyzeContributionTableFullName1)
teardownAllowedAnalyzeContribution1 := setupBigQueryTable(t, ctx, client, createAnalyzeContributionStmt1, insertAnalyzeContributionStmt1, allowedDatasetName1, allowedAnalyzeContributionTableFullName1, analyzeContributionParams1)
defer teardownAllowedAnalyzeContribution1(t)
allowedAnalyzeContributionTableFullName2 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName2, allowedAnalyzeContributionTableName2)
createAnalyzeContributionStmt2, insertAnalyzeContributionStmt2, analyzeContributionParams2 := getBigQueryAnalyzeContributionToolInfo(allowedAnalyzeContributionTableFullName2)
teardownAllowedAnalyzeContribution2 := setupBigQueryTable(t, ctx, client, createAnalyzeContributionStmt2, insertAnalyzeContributionStmt2, allowedDatasetName2, allowedAnalyzeContributionTableFullName2, analyzeContributionParams2)
defer teardownAllowedAnalyzeContribution2(t)
// Setup disallowed analyze contribution table
disallowedAnalyzeContributionTableFullName := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, disallowedAnalyzeContributionTableName)
createDisallowedAnalyzeContributionStmt, insertDisallowedAnalyzeContributionStmt, disallowedAnalyzeContributionParams := getBigQueryAnalyzeContributionToolInfo(disallowedAnalyzeContributionTableFullName)
teardownDisallowedAnalyzeContribution := setupBigQueryTable(t, ctx, client, createDisallowedAnalyzeContributionStmt, insertDisallowedAnalyzeContributionStmt, disallowedDatasetName, disallowedAnalyzeContributionTableFullName, disallowedAnalyzeContributionParams)
defer teardownDisallowedAnalyzeContribution(t)
// Configure source with dataset restriction.
sourceConfig := getBigQueryVars(t)
sourceConfig["allowedDatasets"] = []string{allowedDatasetName1, allowedDatasetName2}
@@ -300,6 +320,11 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
"source": "my-instance",
"description": "Tool to forecast",
},
"analyze-contribution-restricted": map[string]any{
"kind": "bigquery-analyze-contribution",
"source": "my-instance",
"description": "Tool to analyze contribution",
},
}
// Create config file
@@ -327,8 +352,8 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
// Run tests
runListDatasetIdsWithRestriction(t, allowedDatasetName1, allowedDatasetName2)
runListTableIdsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, allowedForecastTableName1)
runListTableIdsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, allowedForecastTableName2)
runListTableIdsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, allowedForecastTableName1, allowedAnalyzeContributionTableName1)
runListTableIdsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, allowedForecastTableName2, allowedAnalyzeContributionTableName2)
runGetDatasetInfoWithRestriction(t, allowedDatasetName1, disallowedDatasetName)
runGetDatasetInfoWithRestriction(t, allowedDatasetName2, disallowedDatasetName)
runGetTableInfoWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, disallowedTableName)
@@ -339,6 +364,8 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
runConversationalAnalyticsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, disallowedTableName)
runForecastWithRestriction(t, allowedForecastTableFullName1, disallowedForecastTableFullName)
runForecastWithRestriction(t, allowedForecastTableFullName2, disallowedForecastTableFullName)
runAnalyzeContributionWithRestriction(t, allowedAnalyzeContributionTableFullName1, disallowedAnalyzeContributionTableFullName)
runAnalyzeContributionWithRestriction(t, allowedAnalyzeContributionTableFullName2, disallowedAnalyzeContributionTableFullName)
}
func TestBigQueryWriteModeAllowed(t *testing.T) {
@@ -3125,3 +3152,86 @@ func runForecastWithRestriction(t *testing.T, allowedTableFullName, disallowedTa
})
}
}
func runAnalyzeContributionWithRestriction(t *testing.T, allowedTableFullName, disallowedTableFullName string) {
allowedTableUnquoted := strings.ReplaceAll(allowedTableFullName, "`", "")
disallowedTableUnquoted := strings.ReplaceAll(disallowedTableFullName, "`", "")
disallowedDatasetFQN := strings.Join(strings.Split(disallowedTableUnquoted, ".")[0:2], ".")
testCases := []struct {
name string
inputData string
wantStatusCode int
wantInResult string
wantInError string
}{
{
name: "invoke with allowed table name",
inputData: allowedTableUnquoted,
wantStatusCode: http.StatusOK,
wantInResult: `"relative_difference"`,
},
{
name: "invoke with disallowed table name",
inputData: disallowedTableUnquoted,
wantStatusCode: http.StatusBadRequest,
wantInError: fmt.Sprintf("access to dataset '%s' (from table '%s') is not allowed", disallowedDatasetFQN, disallowedTableUnquoted),
},
{
name: "invoke with query on allowed table",
inputData: fmt.Sprintf("SELECT * FROM %s", allowedTableFullName),
wantStatusCode: http.StatusOK,
wantInResult: `"relative_difference"`,
},
{
name: "invoke with query on disallowed table",
inputData: fmt.Sprintf("SELECT * FROM %s", disallowedTableFullName),
wantStatusCode: http.StatusBadRequest,
wantInError: fmt.Sprintf("query in input_data accesses dataset '%s', which is not in the allowed list", disallowedDatasetFQN),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
requestBodyMap := map[string]any{
"input_data": tc.inputData,
"contribution_metric": "SUM(metric)",
"is_test_col": "is_test",
"dimension_id_cols": []string{"dim1", "dim2"},
}
bodyBytes, err := json.Marshal(requestBodyMap)
if err != nil {
t.Fatalf("failed to marshal request body: %v", err)
}
body := bytes.NewBuffer(bodyBytes)
resp, bodyBytes := tests.RunRequest(t, http.MethodPost, "http://127.0.0.1:5000/api/tool/analyze-contribution-restricted/invoke", body, nil)
if resp.StatusCode != tc.wantStatusCode {
t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes))
}
var respBody map[string]interface{}
if err := json.Unmarshal(bodyBytes, &respBody); err != nil {
t.Fatalf("error parsing response body: %v", err)
}
if tc.wantInResult != "" {
got, ok := respBody["result"].(string)
if !ok {
t.Fatalf("unable to find result in response body")
}
if !strings.Contains(got, tc.wantInResult) {
t.Errorf("unexpected result: got %q, want to contain %q", string(bodyBytes), tc.wantInResult)
}
}
if tc.wantInError != "" {
if !strings.Contains(string(bodyBytes), tc.wantInError) {
t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError)
}
}
})
}
}