Mirror of https://github.com/googleapis/genai-toolbox.git, synced 2026-01-08 15:14:00 -05:00
feat(tools/bigquery-analyze-contribution): add allowed dataset support (#1675)
## Description

> Should include a concise description of the changes (bug or feature), its
> impact, along with a summary of the solution

## PR Checklist

> Thank you for opening a Pull Request! Before submitting your PR, there are a
> few things you can do to make sure it goes smoothly:

- [ ] Make sure you reviewed [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)
- [ ] Make sure to add `!` if this involves a breaking change 🛠️

Fixes #<issue_number_goes_here>

---------

Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com>
@@ -46,6 +46,13 @@ The behavior of this tool is influenced by the `writeMode` setting on its `bigqu
tools using the same source. This allows the `input_data` parameter to be a query that references temporary resources (e.g.,
`TEMP` tables) created within that session.

The tool's behavior is also influenced by the `allowedDatasets` restriction on the `bigquery` source:

- **Without `allowedDatasets` restriction:** The tool can use any table or query for the `input_data` parameter.
- **With `allowedDatasets` restriction:** The tool verifies that the `input_data` parameter only accesses tables within the allowed datasets.
  - If `input_data` is a table ID, the tool checks if the table's dataset is in the allowed list (a minimal sketch of this check follows the list).
  - If `input_data` is a query, the tool performs a dry run to analyze the query and rejects it if it accesses any table outside the allowed list.
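The table-ID branch described above boils down to splitting the ID and testing dataset membership. A minimal, self-contained sketch under assumed names (`datasetOfTable`, the `allowed` map, and the example project and dataset IDs are hypothetical; the tool's actual check lives in its `Invoke` method, shown in the Go changes further down):

```go
package main

import (
    "fmt"
    "strings"
)

// datasetOfTable splits a BigQuery table ID into project and dataset,
// falling back to a default project for "dataset.table" IDs.
// Hypothetical helper, for illustration only.
func datasetOfTable(tableID, defaultProject string) (project, dataset string, err error) {
    parts := strings.Split(tableID, ".")
    switch len(parts) {
    case 3: // project.dataset.table
        return parts[0], parts[1], nil
    case 2: // dataset.table
        return defaultProject, parts[0], nil
    default:
        return "", "", fmt.Errorf("invalid table ID %q", tableID)
    }
}

func main() {
    // Hypothetical allowed list, keyed by "project.dataset".
    allowed := map[string]bool{"my-project.allowed_dataset_1": true}

    for _, id := range []string{
        "my-project.allowed_dataset_1.sales", // in an allowed dataset
        "my-project.secret_dataset.sales",    // outside the allowed list
    } {
        p, d, err := datasetOfTable(id, "my-project")
        if err != nil {
            fmt.Println(id, "->", err)
            continue
        }
        fmt.Printf("%s -> allowed=%v\n", id, allowed[p+"."+d])
    }
}
```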
## Example
@@ -25,6 +25,7 @@ import (
    "github.com/googleapis/genai-toolbox/internal/sources"
    bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
    "github.com/googleapis/genai-toolbox/internal/tools"
    bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
    bigqueryrestapi "google.golang.org/api/bigquery/v2"
    "google.golang.org/api/iterator"
)
@@ -50,6 +51,8 @@ type compatibleSource interface {
    BigQueryRestService() *bigqueryrestapi.Service
    BigQueryClientCreator() bigqueryds.BigqueryClientCreator
    UseClientAuthorization() bool
    IsDatasetAllowed(projectID, datasetID string) bool
    BigQueryAllowedDatasets() []string
    BigQuerySession() bigqueryds.BigQuerySessionProvider
}
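For context, the pieces this change adds to the interface are the two allowed-dataset methods plus the session provider. The sketch below shows a tiny standalone type exposing the same allowed-dataset policy shape, e.g. for a unit-test fake; the membership semantics (an empty list means unrestricted, entries are `project.dataset` strings) are inferred from the rest of this diff, and everything except the two method names and signatures is an assumption.

```go
package main

import (
    "fmt"
    "strings"
)

// allowedDatasetPolicy is a hypothetical stand-in for the allowed-dataset
// portion of the source interface above.
type allowedDatasetPolicy struct {
    datasets []string // fully qualified "project.dataset" entries
}

func (p allowedDatasetPolicy) BigQueryAllowedDatasets() []string {
    return p.datasets
}

func (p allowedDatasetPolicy) IsDatasetAllowed(projectID, datasetID string) bool {
    if len(p.datasets) == 0 {
        return true // no restriction configured
    }
    want := projectID + "." + datasetID
    for _, ds := range p.datasets {
        if strings.EqualFold(ds, want) {
            return true
        }
    }
    return false
}

func main() {
    p := allowedDatasetPolicy{datasets: []string{"my-project.allowed_dataset_1"}}
    fmt.Println(p.IsDatasetAllowed("my-project", "allowed_dataset_1")) // true
    fmt.Println(p.IsDatasetAllowed("my-project", "other_dataset"))     // false
}
```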
@@ -86,8 +89,17 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
        return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
    }

    inputDataParameter := tools.NewStringParameter("input_data",
        "The data that contain the test and control data to analyze. Can be a fully qualified BigQuery table ID or a SQL query.")
    allowedDatasets := s.BigQueryAllowedDatasets()
    inputDataDescription := "The data that contain the test and control data to analyze. Can be a fully qualified BigQuery table ID or a SQL query."
    if len(allowedDatasets) > 0 {
        datasetIDs := []string{}
        for _, ds := range allowedDatasets {
            datasetIDs = append(datasetIDs, fmt.Sprintf("`%s`", ds))
        }
        inputDataDescription += fmt.Sprintf(" The query or table must only access datasets from the following list: %s.", strings.Join(datasetIDs, ", "))
    }

    inputDataParameter := tools.NewStringParameter("input_data", inputDataDescription)
    contributionMetricParameter := tools.NewStringParameter("contribution_metric",
        `The name of the column that contains the metric to analyze.
Provides the expression to use to calculate the metric you are analyzing.
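To see what the description-building block above produces, here is a standalone sketch with a hypothetical `allowedDatasets` value; the wording is copied from the strings in the diff.

```go
package main

import (
    "fmt"
    "strings"
)

func main() {
    // Hypothetical allowedDatasets value exposed by the source.
    allowedDatasets := []string{"my-project.allowed_dataset_1", "my-project.allowed_dataset_2"}

    inputDataDescription := "The data that contain the test and control data to analyze. Can be a fully qualified BigQuery table ID or a SQL query."
    if len(allowedDatasets) > 0 {
        datasetIDs := []string{}
        for _, ds := range allowedDatasets {
            datasetIDs = append(datasetIDs, fmt.Sprintf("`%s`", ds))
        }
        inputDataDescription += fmt.Sprintf(" The query or table must only access datasets from the following list: %s.", strings.Join(datasetIDs, ", "))
    }

    // Prints the base description followed by:
    // " The query or table must only access datasets from the following list:
    //   `my-project.allowed_dataset_1`, `my-project.allowed_dataset_2`."
    fmt.Println(inputDataDescription)
}
```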
@@ -123,17 +135,19 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)

    // finish tool setup
    t := Tool{
        Name: cfg.Name,
        Kind: kind,
        Parameters: parameters,
        AuthRequired: cfg.AuthRequired,
        UseClientOAuth: s.UseClientAuthorization(),
        ClientCreator: s.BigQueryClientCreator(),
        Client: s.BigQueryClient(),
        RestService: s.BigQueryRestService(),
        SessionProvider: s.BigQuerySession(),
        manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
        mcpManifest: mcpManifest,
        Name: cfg.Name,
        Kind: kind,
        Parameters: parameters,
        AuthRequired: cfg.AuthRequired,
        UseClientOAuth: s.UseClientAuthorization(),
        ClientCreator: s.BigQueryClientCreator(),
        Client: s.BigQueryClient(),
        RestService: s.BigQueryRestService(),
        IsDatasetAllowed: s.IsDatasetAllowed,
        AllowedDatasets: allowedDatasets,
        SessionProvider: s.BigQuerySession(),
        manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
        mcpManifest: mcpManifest,
    }
    return t, nil
}
@@ -148,12 +162,14 @@ type Tool struct {
    UseClientOAuth bool `yaml:"useClientOAuth"`
    Parameters tools.Parameters `yaml:"parameters"`

    Client *bigqueryapi.Client
    RestService *bigqueryrestapi.Service
    ClientCreator bigqueryds.BigqueryClientCreator
    SessionProvider bigqueryds.BigQuerySessionProvider
    manifest tools.Manifest
    mcpManifest tools.McpManifest
    Client *bigqueryapi.Client
    RestService *bigqueryrestapi.Service
    ClientCreator bigqueryds.BigqueryClientCreator
    IsDatasetAllowed func(projectID, datasetID string) bool
    AllowedDatasets []string
    SessionProvider bigqueryds.BigQuerySessionProvider
    manifest tools.Manifest
    mcpManifest tools.McpManifest
}

// Invoke runs the contribution analysis.
@@ -164,6 +180,22 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
        return nil, fmt.Errorf("unable to cast input_data parameter %s", paramsMap["input_data"])
    }

    bqClient := t.Client
    restService := t.RestService
    var err error

    // Initialize new client if using user OAuth token
    if t.UseClientOAuth {
        tokenStr, err := accessToken.ParseBearerToken()
        if err != nil {
            return nil, fmt.Errorf("error parsing access token: %w", err)
        }
        bqClient, restService, err = t.ClientCreator(tokenStr, true)
        if err != nil {
            return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
        }
    }

    modelID := fmt.Sprintf("contribution_analysis_model_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))

    var options []string
@@ -196,8 +228,54 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
    var inputDataSource string
    trimmedUpperInputData := strings.TrimSpace(strings.ToUpper(inputData))
    if strings.HasPrefix(trimmedUpperInputData, "SELECT") || strings.HasPrefix(trimmedUpperInputData, "WITH") {
        if len(t.AllowedDatasets) > 0 {
            var connProps []*bigqueryapi.ConnectionProperty
            session, err := t.SessionProvider(ctx)
            if err != nil {
                return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
            }
            if session != nil {
                connProps = []*bigqueryapi.ConnectionProperty{
                    {Key: "session_id", Value: session.ID},
                }
            }
            dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, inputData, nil, connProps)
            if err != nil {
                return nil, fmt.Errorf("query validation failed: %w", err)
            }
            statementType := dryRunJob.Statistics.Query.StatementType
            if statementType != "SELECT" {
                return nil, fmt.Errorf("the 'input_data' parameter only supports a table ID or a SELECT query. The provided query has statement type '%s'", statementType)
            }

            queryStats := dryRunJob.Statistics.Query
            if queryStats != nil {
                for _, tableRef := range queryStats.ReferencedTables {
                    if !t.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
                        return nil, fmt.Errorf("query in input_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId)
                    }
                }
            } else {
                return nil, fmt.Errorf("could not analyze query in input_data to validate against allowed datasets")
            }
        }
        inputDataSource = fmt.Sprintf("(%s)", inputData)
    } else {
        if len(t.AllowedDatasets) > 0 {
            parts := strings.Split(inputData, ".")
            var projectID, datasetID string
            switch len(parts) {
            case 3: // project.dataset.table
                projectID, datasetID = parts[0], parts[1]
            case 2: // dataset.table
                projectID, datasetID = t.Client.Project(), parts[0]
            default:
                return nil, fmt.Errorf("invalid table ID format for 'input_data': %q. Expected 'dataset.table' or 'project.dataset.table'", inputData)
            }
            if !t.IsDatasetAllowed(projectID, datasetID) {
                return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, inputData)
            }
        }
        inputDataSource = fmt.Sprintf("SELECT * FROM `%s`", inputData)
    }
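The query path above hinges on the dry-run job's statistics. A trimmed, standalone sketch of that referenced-tables check, assuming a `*bigqueryrestapi.Job` has already been obtained from a dry run and taking the allowed-dataset policy as a plain callback (the helper name and the sample values are hypothetical; the field names are the ones used in the diff):

```go
package main

import (
    "fmt"

    bigqueryrestapi "google.golang.org/api/bigquery/v2"
)

// validateReferencedTables rejects a dry-run job whose query touches a dataset
// outside the allowed list. isDatasetAllowed is whatever policy the source
// provides (IsDatasetAllowed in the diff above).
func validateReferencedTables(job *bigqueryrestapi.Job, isDatasetAllowed func(projectID, datasetID string) bool) error {
    if job == nil || job.Statistics == nil || job.Statistics.Query == nil {
        return fmt.Errorf("could not analyze query to validate against allowed datasets")
    }
    for _, ref := range job.Statistics.Query.ReferencedTables {
        if !isDatasetAllowed(ref.ProjectId, ref.DatasetId) {
            return fmt.Errorf("query accesses dataset '%s.%s', which is not in the allowed list", ref.ProjectId, ref.DatasetId)
        }
    }
    return nil
}

func main() {
    // Hypothetical dry-run result referencing a single table.
    job := &bigqueryrestapi.Job{
        Statistics: &bigqueryrestapi.JobStatistics{
            Query: &bigqueryrestapi.JobStatistics2{
                ReferencedTables: []*bigqueryrestapi.TableReference{
                    {ProjectId: "my-project", DatasetId: "secret_dataset", TableId: "t"},
                },
            },
        },
    }
    allowed := func(p, d string) bool { return p == "my-project" && d == "allowed_dataset_1" }
    fmt.Println(validateReferencedTables(job, allowed)) // reports the disallowed dataset
}
```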
@@ -209,21 +287,6 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
        inputDataSource,
    )

    bqClient := t.Client
    var err error

    // Initialize new client if using user OAuth token
    if t.UseClientOAuth {
        tokenStr, err := accessToken.ParseBearerToken()
        if err != nil {
            return nil, fmt.Errorf("error parsing access token: %w", err)
        }
        bqClient, _, err = t.ClientCreator(tokenStr, false)
        if err != nil {
            return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
        }
    }

    createModelQuery := bqClient.Query(createModelSQL)

    // Get session from provider if in protected mode.
@@ -205,7 +205,7 @@ func TestBigQueryToolEndpoints(t *testing.T) {
}

func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
    ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
    ctx, cancel := context.WithTimeout(context.Background(), 4*time.Minute)
    defer cancel()

    client, err := initBigQueryConnection(BigqueryProject)
@@ -225,6 +225,9 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
    allowedForecastTableName2 := "allowed_forecast_table_2"
    disallowedForecastTableName := "disallowed_forecast_table"

    allowedAnalyzeContributionTableName1 := "allowed_analyze_contribution_table_1"
    allowedAnalyzeContributionTableName2 := "allowed_analyze_contribution_table_2"
    disallowedAnalyzeContributionTableName := "disallowed_analyze_contribution_table"
    // Setup allowed table
    allowedTableNameParam1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedTableName1)
    createAllowedTableStmt1 := fmt.Sprintf("CREATE TABLE %s (id INT64)", allowedTableNameParam1)
@@ -259,6 +262,23 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
    teardownDisallowedForecast := setupBigQueryTable(t, ctx, client, createDisallowedForecastStmt, insertDisallowedForecastStmt, disallowedDatasetName, disallowedForecastTableFullName, disallowedForecastParams)
    defer teardownDisallowedForecast(t)

    // Setup allowed analyze contribution table
    allowedAnalyzeContributionTableFullName1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedAnalyzeContributionTableName1)
    createAnalyzeContributionStmt1, insertAnalyzeContributionStmt1, analyzeContributionParams1 := getBigQueryAnalyzeContributionToolInfo(allowedAnalyzeContributionTableFullName1)
    teardownAllowedAnalyzeContribution1 := setupBigQueryTable(t, ctx, client, createAnalyzeContributionStmt1, insertAnalyzeContributionStmt1, allowedDatasetName1, allowedAnalyzeContributionTableFullName1, analyzeContributionParams1)
    defer teardownAllowedAnalyzeContribution1(t)

    allowedAnalyzeContributionTableFullName2 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName2, allowedAnalyzeContributionTableName2)
    createAnalyzeContributionStmt2, insertAnalyzeContributionStmt2, analyzeContributionParams2 := getBigQueryAnalyzeContributionToolInfo(allowedAnalyzeContributionTableFullName2)
    teardownAllowedAnalyzeContribution2 := setupBigQueryTable(t, ctx, client, createAnalyzeContributionStmt2, insertAnalyzeContributionStmt2, allowedDatasetName2, allowedAnalyzeContributionTableFullName2, analyzeContributionParams2)
    defer teardownAllowedAnalyzeContribution2(t)

    // Setup disallowed analyze contribution table
    disallowedAnalyzeContributionTableFullName := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, disallowedAnalyzeContributionTableName)
    createDisallowedAnalyzeContributionStmt, insertDisallowedAnalyzeContributionStmt, disallowedAnalyzeContributionParams := getBigQueryAnalyzeContributionToolInfo(disallowedAnalyzeContributionTableFullName)
    teardownDisallowedAnalyzeContribution := setupBigQueryTable(t, ctx, client, createDisallowedAnalyzeContributionStmt, insertDisallowedAnalyzeContributionStmt, disallowedDatasetName, disallowedAnalyzeContributionTableFullName, disallowedAnalyzeContributionParams)
    defer teardownDisallowedAnalyzeContribution(t)

    // Configure source with dataset restriction.
    sourceConfig := getBigQueryVars(t)
    sourceConfig["allowedDatasets"] = []string{allowedDatasetName1, allowedDatasetName2}
@@ -300,6 +320,11 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
            "source": "my-instance",
            "description": "Tool to forecast",
        },
        "analyze-contribution-restricted": map[string]any{
            "kind": "bigquery-analyze-contribution",
            "source": "my-instance",
            "description": "Tool to analyze contribution",
        },
    }

    // Create config file
@@ -327,8 +352,8 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {

    // Run tests
    runListDatasetIdsWithRestriction(t, allowedDatasetName1, allowedDatasetName2)
    runListTableIdsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, allowedForecastTableName1)
    runListTableIdsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, allowedForecastTableName2)
    runListTableIdsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, allowedForecastTableName1, allowedAnalyzeContributionTableName1)
    runListTableIdsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, allowedForecastTableName2, allowedAnalyzeContributionTableName2)
    runGetDatasetInfoWithRestriction(t, allowedDatasetName1, disallowedDatasetName)
    runGetDatasetInfoWithRestriction(t, allowedDatasetName2, disallowedDatasetName)
    runGetTableInfoWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, disallowedTableName)
@@ -339,6 +364,8 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
    runConversationalAnalyticsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, disallowedTableName)
    runForecastWithRestriction(t, allowedForecastTableFullName1, disallowedForecastTableFullName)
    runForecastWithRestriction(t, allowedForecastTableFullName2, disallowedForecastTableFullName)
    runAnalyzeContributionWithRestriction(t, allowedAnalyzeContributionTableFullName1, disallowedAnalyzeContributionTableFullName)
    runAnalyzeContributionWithRestriction(t, allowedAnalyzeContributionTableFullName2, disallowedAnalyzeContributionTableFullName)
}

func TestBigQueryWriteModeAllowed(t *testing.T) {
@@ -3125,3 +3152,86 @@ func runForecastWithRestriction(t *testing.T, allowedTableFullName, disallowedTa
        })
    }
}

func runAnalyzeContributionWithRestriction(t *testing.T, allowedTableFullName, disallowedTableFullName string) {
    allowedTableUnquoted := strings.ReplaceAll(allowedTableFullName, "`", "")
    disallowedTableUnquoted := strings.ReplaceAll(disallowedTableFullName, "`", "")
    disallowedDatasetFQN := strings.Join(strings.Split(disallowedTableUnquoted, ".")[0:2], ".")

    testCases := []struct {
        name string
        inputData string
        wantStatusCode int
        wantInResult string
        wantInError string
    }{
        {
            name: "invoke with allowed table name",
            inputData: allowedTableUnquoted,
            wantStatusCode: http.StatusOK,
            wantInResult: `"relative_difference"`,
        },
        {
            name: "invoke with disallowed table name",
            inputData: disallowedTableUnquoted,
            wantStatusCode: http.StatusBadRequest,
            wantInError: fmt.Sprintf("access to dataset '%s' (from table '%s') is not allowed", disallowedDatasetFQN, disallowedTableUnquoted),
        },
        {
            name: "invoke with query on allowed table",
            inputData: fmt.Sprintf("SELECT * FROM %s", allowedTableFullName),
            wantStatusCode: http.StatusOK,
            wantInResult: `"relative_difference"`,
        },
        {
            name: "invoke with query on disallowed table",
            inputData: fmt.Sprintf("SELECT * FROM %s", disallowedTableFullName),
            wantStatusCode: http.StatusBadRequest,
            wantInError: fmt.Sprintf("query in input_data accesses dataset '%s', which is not in the allowed list", disallowedDatasetFQN),
        },
    }

    for _, tc := range testCases {
        t.Run(tc.name, func(t *testing.T) {
            requestBodyMap := map[string]any{
                "input_data": tc.inputData,
                "contribution_metric": "SUM(metric)",
                "is_test_col": "is_test",
                "dimension_id_cols": []string{"dim1", "dim2"},
            }
            bodyBytes, err := json.Marshal(requestBodyMap)
            if err != nil {
                t.Fatalf("failed to marshal request body: %v", err)
            }
            body := bytes.NewBuffer(bodyBytes)

            resp, bodyBytes := tests.RunRequest(t, http.MethodPost, "http://127.0.0.1:5000/api/tool/analyze-contribution-restricted/invoke", body, nil)

            if resp.StatusCode != tc.wantStatusCode {
                t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes))
            }

            var respBody map[string]interface{}
            if err := json.Unmarshal(bodyBytes, &respBody); err != nil {
                t.Fatalf("error parsing response body: %v", err)
            }

            if tc.wantInResult != "" {
                got, ok := respBody["result"].(string)
                if !ok {
                    t.Fatalf("unable to find result in response body")
                }

                if !strings.Contains(got, tc.wantInResult) {
                    t.Errorf("unexpected result: got %q, want to contain %q", string(bodyBytes), tc.wantInResult)
                }
            }

            if tc.wantInError != "" {
                if !strings.Contains(string(bodyBytes), tc.wantInError) {
                    t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError)
                }
            }
        })
    }
}
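Outside the test harness, invoking the restricted tool is a plain HTTP POST. Below is a sketch of the same request the test sends, using only the standard library; the endpoint path and request fields come from the test above, while the server address, table ID, and column names are placeholders for whatever a real deployment uses.

```go
package main

import (
    "bytes"
    "encoding/json"
    "fmt"
    "io"
    "net/http"
)

func main() {
    // Request body mirroring the integration test; the table ID and column
    // names are hypothetical and must match a real table.
    reqBody := map[string]any{
        "input_data":          "my-project.allowed_dataset_1.allowed_analyze_contribution_table_1",
        "contribution_metric": "SUM(metric)",
        "is_test_col":         "is_test",
        "dimension_id_cols":   []string{"dim1", "dim2"},
    }
    b, err := json.Marshal(reqBody)
    if err != nil {
        panic(err)
    }

    // Same endpoint shape the test invokes on a locally running toolbox server.
    resp, err := http.Post("http://127.0.0.1:5000/api/tool/analyze-contribution-restricted/invoke",
        "application/json", bytes.NewReader(b))
    if err != nil {
        panic(err)
    }
    defer resp.Body.Close()

    out, _ := io.ReadAll(resp.Body)
    // A disallowed dataset yields HTTP 400 with the "not in the allowed list" message.
    fmt.Println(resp.StatusCode, string(out))
}
```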