diff --git a/docs/en/resources/tools/bigquery/bigquery-analyze-contribution.md b/docs/en/resources/tools/bigquery/bigquery-analyze-contribution.md index 64060eaab5..97af3c8e16 100644 --- a/docs/en/resources/tools/bigquery/bigquery-analyze-contribution.md +++ b/docs/en/resources/tools/bigquery/bigquery-analyze-contribution.md @@ -46,6 +46,13 @@ The behavior of this tool is influenced by the `writeMode` setting on its `bigqu tools using the same source. This allows the `input_data` parameter to be a query that references temporary resources (e.g., `TEMP` tables) created within that session. +The tool's behavior is also influenced by the `allowedDatasets` restriction on the `bigquery` source: + +- **Without `allowedDatasets` restriction:** The tool can use any table or query for the `input_data` parameter. +- **With `allowedDatasets` restriction:** The tool verifies that the `input_data` parameter only accesses tables within the allowed datasets. + - If `input_data` is a table ID, the tool checks if the table's dataset is in the allowed list. + - If `input_data` is a query, the tool performs a dry run to analyze the query and rejects it if it accesses any table outside the allowed list. + ## Example diff --git a/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go b/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go index 971cb918f2..a5115e7a2a 100644 --- a/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go +++ b/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go @@ -25,6 +25,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery" "github.com/googleapis/genai-toolbox/internal/tools" + bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon" bigqueryrestapi "google.golang.org/api/bigquery/v2" "google.golang.org/api/iterator" ) @@ -50,6 +51,8 @@ type compatibleSource interface { BigQueryRestService() *bigqueryrestapi.Service BigQueryClientCreator() bigqueryds.BigqueryClientCreator UseClientAuthorization() bool + IsDatasetAllowed(projectID, datasetID string) bool + BigQueryAllowedDatasets() []string BigQuerySession() bigqueryds.BigQuerySessionProvider } @@ -86,8 +89,17 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) } - inputDataParameter := tools.NewStringParameter("input_data", - "The data that contain the test and control data to analyze. Can be a fully qualified BigQuery table ID or a SQL query.") + allowedDatasets := s.BigQueryAllowedDatasets() + inputDataDescription := "The data that contain the test and control data to analyze. Can be a fully qualified BigQuery table ID or a SQL query." + if len(allowedDatasets) > 0 { + datasetIDs := []string{} + for _, ds := range allowedDatasets { + datasetIDs = append(datasetIDs, fmt.Sprintf("`%s`", ds)) + } + inputDataDescription += fmt.Sprintf(" The query or table must only access datasets from the following list: %s.", strings.Join(datasetIDs, ", ")) + } + + inputDataParameter := tools.NewStringParameter("input_data", inputDataDescription) contributionMetricParameter := tools.NewStringParameter("contribution_metric", `The name of the column that contains the metric to analyze. Provides the expression to use to calculate the metric you are analyzing. @@ -123,17 +135,19 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Name: cfg.Name, - Kind: kind, - Parameters: parameters, - AuthRequired: cfg.AuthRequired, - UseClientOAuth: s.UseClientAuthorization(), - ClientCreator: s.BigQueryClientCreator(), - Client: s.BigQueryClient(), - RestService: s.BigQueryRestService(), - SessionProvider: s.BigQuerySession(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Name: cfg.Name, + Kind: kind, + Parameters: parameters, + AuthRequired: cfg.AuthRequired, + UseClientOAuth: s.UseClientAuthorization(), + ClientCreator: s.BigQueryClientCreator(), + Client: s.BigQueryClient(), + RestService: s.BigQueryRestService(), + IsDatasetAllowed: s.IsDatasetAllowed, + AllowedDatasets: allowedDatasets, + SessionProvider: s.BigQuerySession(), + manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -148,12 +162,14 @@ type Tool struct { UseClientOAuth bool `yaml:"useClientOAuth"` Parameters tools.Parameters `yaml:"parameters"` - Client *bigqueryapi.Client - RestService *bigqueryrestapi.Service - ClientCreator bigqueryds.BigqueryClientCreator - SessionProvider bigqueryds.BigQuerySessionProvider - manifest tools.Manifest - mcpManifest tools.McpManifest + Client *bigqueryapi.Client + RestService *bigqueryrestapi.Service + ClientCreator bigqueryds.BigqueryClientCreator + IsDatasetAllowed func(projectID, datasetID string) bool + AllowedDatasets []string + SessionProvider bigqueryds.BigQuerySessionProvider + manifest tools.Manifest + mcpManifest tools.McpManifest } // Invoke runs the contribution analysis. @@ -164,6 +180,22 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken return nil, fmt.Errorf("unable to cast input_data parameter %s", paramsMap["input_data"]) } + bqClient := t.Client + restService := t.RestService + var err error + + // Initialize new client if using user OAuth token + if t.UseClientOAuth { + tokenStr, err := accessToken.ParseBearerToken() + if err != nil { + return nil, fmt.Errorf("error parsing access token: %w", err) + } + bqClient, restService, err = t.ClientCreator(tokenStr, true) + if err != nil { + return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) + } + } + modelID := fmt.Sprintf("contribution_analysis_model_%s", strings.ReplaceAll(uuid.New().String(), "-", "")) var options []string @@ -196,8 +228,54 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken var inputDataSource string trimmedUpperInputData := strings.TrimSpace(strings.ToUpper(inputData)) if strings.HasPrefix(trimmedUpperInputData, "SELECT") || strings.HasPrefix(trimmedUpperInputData, "WITH") { + if len(t.AllowedDatasets) > 0 { + var connProps []*bigqueryapi.ConnectionProperty + session, err := t.SessionProvider(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get BigQuery session: %w", err) + } + if session != nil { + connProps = []*bigqueryapi.ConnectionProperty{ + {Key: "session_id", Value: session.ID}, + } + } + dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, inputData, nil, connProps) + if err != nil { + return nil, fmt.Errorf("query validation failed: %w", err) + } + statementType := dryRunJob.Statistics.Query.StatementType + if statementType != "SELECT" { + return nil, fmt.Errorf("the 'input_data' parameter only supports a table ID or a SELECT query. The provided query has statement type '%s'", statementType) + } + + queryStats := dryRunJob.Statistics.Query + if queryStats != nil { + for _, tableRef := range queryStats.ReferencedTables { + if !t.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) { + return nil, fmt.Errorf("query in input_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId) + } + } + } else { + return nil, fmt.Errorf("could not analyze query in input_data to validate against allowed datasets") + } + } inputDataSource = fmt.Sprintf("(%s)", inputData) } else { + if len(t.AllowedDatasets) > 0 { + parts := strings.Split(inputData, ".") + var projectID, datasetID string + switch len(parts) { + case 3: // project.dataset.table + projectID, datasetID = parts[0], parts[1] + case 2: // dataset.table + projectID, datasetID = t.Client.Project(), parts[0] + default: + return nil, fmt.Errorf("invalid table ID format for 'input_data': %q. Expected 'dataset.table' or 'project.dataset.table'", inputData) + } + if !t.IsDatasetAllowed(projectID, datasetID) { + return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, inputData) + } + } inputDataSource = fmt.Sprintf("SELECT * FROM `%s`", inputData) } @@ -209,21 +287,6 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken inputDataSource, ) - bqClient := t.Client - var err error - - // Initialize new client if using user OAuth token - if t.UseClientOAuth { - tokenStr, err := accessToken.ParseBearerToken() - if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) - } - bqClient, _, err = t.ClientCreator(tokenStr, false) - if err != nil { - return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) - } - } - createModelQuery := bqClient.Query(createModelSQL) // Get session from provider if in protected mode. diff --git a/tests/bigquery/bigquery_integration_test.go b/tests/bigquery/bigquery_integration_test.go index 40a0fee825..081ddeb0dc 100644 --- a/tests/bigquery/bigquery_integration_test.go +++ b/tests/bigquery/bigquery_integration_test.go @@ -205,7 +205,7 @@ func TestBigQueryToolEndpoints(t *testing.T) { } func TestBigQueryToolWithDatasetRestriction(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), 4*time.Minute) defer cancel() client, err := initBigQueryConnection(BigqueryProject) @@ -225,6 +225,9 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) { allowedForecastTableName2 := "allowed_forecast_table_2" disallowedForecastTableName := "disallowed_forecast_table" + allowedAnalyzeContributionTableName1 := "allowed_analyze_contribution_table_1" + allowedAnalyzeContributionTableName2 := "allowed_analyze_contribution_table_2" + disallowedAnalyzeContributionTableName := "disallowed_analyze_contribution_table" // Setup allowed table allowedTableNameParam1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedTableName1) createAllowedTableStmt1 := fmt.Sprintf("CREATE TABLE %s (id INT64)", allowedTableNameParam1) @@ -259,6 +262,23 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) { teardownDisallowedForecast := setupBigQueryTable(t, ctx, client, createDisallowedForecastStmt, insertDisallowedForecastStmt, disallowedDatasetName, disallowedForecastTableFullName, disallowedForecastParams) defer teardownDisallowedForecast(t) + // Setup allowed analyze contribution table + allowedAnalyzeContributionTableFullName1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedAnalyzeContributionTableName1) + createAnalyzeContributionStmt1, insertAnalyzeContributionStmt1, analyzeContributionParams1 := getBigQueryAnalyzeContributionToolInfo(allowedAnalyzeContributionTableFullName1) + teardownAllowedAnalyzeContribution1 := setupBigQueryTable(t, ctx, client, createAnalyzeContributionStmt1, insertAnalyzeContributionStmt1, allowedDatasetName1, allowedAnalyzeContributionTableFullName1, analyzeContributionParams1) + defer teardownAllowedAnalyzeContribution1(t) + + allowedAnalyzeContributionTableFullName2 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName2, allowedAnalyzeContributionTableName2) + createAnalyzeContributionStmt2, insertAnalyzeContributionStmt2, analyzeContributionParams2 := getBigQueryAnalyzeContributionToolInfo(allowedAnalyzeContributionTableFullName2) + teardownAllowedAnalyzeContribution2 := setupBigQueryTable(t, ctx, client, createAnalyzeContributionStmt2, insertAnalyzeContributionStmt2, allowedDatasetName2, allowedAnalyzeContributionTableFullName2, analyzeContributionParams2) + defer teardownAllowedAnalyzeContribution2(t) + + // Setup disallowed analyze contribution table + disallowedAnalyzeContributionTableFullName := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, disallowedAnalyzeContributionTableName) + createDisallowedAnalyzeContributionStmt, insertDisallowedAnalyzeContributionStmt, disallowedAnalyzeContributionParams := getBigQueryAnalyzeContributionToolInfo(disallowedAnalyzeContributionTableFullName) + teardownDisallowedAnalyzeContribution := setupBigQueryTable(t, ctx, client, createDisallowedAnalyzeContributionStmt, insertDisallowedAnalyzeContributionStmt, disallowedDatasetName, disallowedAnalyzeContributionTableFullName, disallowedAnalyzeContributionParams) + defer teardownDisallowedAnalyzeContribution(t) + // Configure source with dataset restriction. sourceConfig := getBigQueryVars(t) sourceConfig["allowedDatasets"] = []string{allowedDatasetName1, allowedDatasetName2} @@ -300,6 +320,11 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) { "source": "my-instance", "description": "Tool to forecast", }, + "analyze-contribution-restricted": map[string]any{ + "kind": "bigquery-analyze-contribution", + "source": "my-instance", + "description": "Tool to analyze contribution", + }, } // Create config file @@ -327,8 +352,8 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) { // Run tests runListDatasetIdsWithRestriction(t, allowedDatasetName1, allowedDatasetName2) - runListTableIdsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, allowedForecastTableName1) - runListTableIdsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, allowedForecastTableName2) + runListTableIdsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, allowedForecastTableName1, allowedAnalyzeContributionTableName1) + runListTableIdsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, allowedForecastTableName2, allowedAnalyzeContributionTableName2) runGetDatasetInfoWithRestriction(t, allowedDatasetName1, disallowedDatasetName) runGetDatasetInfoWithRestriction(t, allowedDatasetName2, disallowedDatasetName) runGetTableInfoWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, disallowedTableName) @@ -339,6 +364,8 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) { runConversationalAnalyticsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, disallowedTableName) runForecastWithRestriction(t, allowedForecastTableFullName1, disallowedForecastTableFullName) runForecastWithRestriction(t, allowedForecastTableFullName2, disallowedForecastTableFullName) + runAnalyzeContributionWithRestriction(t, allowedAnalyzeContributionTableFullName1, disallowedAnalyzeContributionTableFullName) + runAnalyzeContributionWithRestriction(t, allowedAnalyzeContributionTableFullName2, disallowedAnalyzeContributionTableFullName) } func TestBigQueryWriteModeAllowed(t *testing.T) { @@ -3125,3 +3152,86 @@ func runForecastWithRestriction(t *testing.T, allowedTableFullName, disallowedTa }) } } + +func runAnalyzeContributionWithRestriction(t *testing.T, allowedTableFullName, disallowedTableFullName string) { + allowedTableUnquoted := strings.ReplaceAll(allowedTableFullName, "`", "") + disallowedTableUnquoted := strings.ReplaceAll(disallowedTableFullName, "`", "") + disallowedDatasetFQN := strings.Join(strings.Split(disallowedTableUnquoted, ".")[0:2], ".") + + testCases := []struct { + name string + inputData string + wantStatusCode int + wantInResult string + wantInError string + }{ + { + name: "invoke with allowed table name", + inputData: allowedTableUnquoted, + wantStatusCode: http.StatusOK, + wantInResult: `"relative_difference"`, + }, + { + name: "invoke with disallowed table name", + inputData: disallowedTableUnquoted, + wantStatusCode: http.StatusBadRequest, + wantInError: fmt.Sprintf("access to dataset '%s' (from table '%s') is not allowed", disallowedDatasetFQN, disallowedTableUnquoted), + }, + { + name: "invoke with query on allowed table", + inputData: fmt.Sprintf("SELECT * FROM %s", allowedTableFullName), + wantStatusCode: http.StatusOK, + wantInResult: `"relative_difference"`, + }, + { + name: "invoke with query on disallowed table", + inputData: fmt.Sprintf("SELECT * FROM %s", disallowedTableFullName), + wantStatusCode: http.StatusBadRequest, + wantInError: fmt.Sprintf("query in input_data accesses dataset '%s', which is not in the allowed list", disallowedDatasetFQN), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + requestBodyMap := map[string]any{ + "input_data": tc.inputData, + "contribution_metric": "SUM(metric)", + "is_test_col": "is_test", + "dimension_id_cols": []string{"dim1", "dim2"}, + } + bodyBytes, err := json.Marshal(requestBodyMap) + if err != nil { + t.Fatalf("failed to marshal request body: %v", err) + } + body := bytes.NewBuffer(bodyBytes) + + resp, bodyBytes := tests.RunRequest(t, http.MethodPost, "http://127.0.0.1:5000/api/tool/analyze-contribution-restricted/invoke", body, nil) + + if resp.StatusCode != tc.wantStatusCode { + t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes)) + } + + var respBody map[string]interface{} + if err := json.Unmarshal(bodyBytes, &respBody); err != nil { + t.Fatalf("error parsing response body: %v", err) + } + + if tc.wantInResult != "" { + got, ok := respBody["result"].(string) + if !ok { + t.Fatalf("unable to find result in response body") + } + + if !strings.Contains(got, tc.wantInResult) { + t.Errorf("unexpected result: got %q, want to contain %q", string(bodyBytes), tc.wantInResult) + } + } + + if tc.wantInError != "" { + if !strings.Contains(string(bodyBytes), tc.wantInError) { + t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError) + } + } + }) + } +}