feat(tool/bigquery-forecast)!: add allowed datasets support to forecast (#1412)

## Description
---

This introduces a breaking change. The bigquery-forecast tool will now
enforce the allowed datasets setting from its BigQuery source
configuration. Previously, this setting had no effect on the tool.

## PR Checklist

---
> Thank you for opening a Pull Request! Before submitting your PR, there are a
> few things you can do to make sure it goes smoothly:

- [ ] Make sure you reviewed
  [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [ ] Make sure to open an issue as a
  [bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
  before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)
- [ ] Make sure to add `!` if this involves a breaking change

🛠️ Fixes #<issue_number_goes_here>

---------

Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com>
This commit is contained in:
Huan Chen
2025-09-24 16:59:00 -07:00
committed by GitHub
parent 20cb43a249
commit 88bac7e36f
6 changed files with 239 additions and 84 deletions

View File

@@ -33,6 +33,13 @@ query based on the provided parameters:
- **horizon** (integer, optional): The number of future time steps you want to
predict. It defaults to 10 if not specified.
The tool's behavior regarding these parameters is influenced by the `allowedDatasets` restriction on the `bigquery` source:
- **Without `allowedDatasets` restriction:** The tool can use any table or query for the `history_data` parameter.
- **With `allowedDatasets` restriction:** The tool verifies that the `history_data` parameter only accesses tables
within the allowed datasets. If `history_data` is a table ID, the tool checks if the table's dataset is in the
allowed list. If `history_data` is a query, the tool performs a dry run to analyze the query and rejects it
if it accesses any table outside the allowed list.
## Example
```yaml

View File

@@ -0,0 +1,55 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package bigquerycommon
import (
"context"
"fmt"
bigqueryapi "cloud.google.com/go/bigquery"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
)
// DryRunQuery submits sql to BigQuery as a dry-run query job, which validates
// the statement and returns its metadata without executing it.
//
// The job is inserted under projectID in the given location, always with
// legacy SQL disabled. params are forwarded unchanged as the job's query
// parameters, and every entry of connProps is copied into the REST API's
// connection-property type; both may be nil.
//
// On success the returned job is guaranteed to have non-nil Statistics and
// Statistics.Query, so callers can read fields such as StatementType or
// ReferencedTables without further nil checks. An error is returned if the
// insert call fails or if the response carries no query statistics.
func DryRunQuery(ctx context.Context, restService *bigqueryrestapi.Service, projectID string, location string, sql string, params []*bigqueryrestapi.QueryParameter, connProps []*bigqueryapi.ConnectionProperty) (*bigqueryrestapi.Job, error) {
	useLegacySql := false

	// Translate the high-level client's connection properties into the REST
	// representation expected by the job configuration.
	restConnProps := make([]*bigqueryrestapi.ConnectionProperty, len(connProps))
	for i, prop := range connProps {
		restConnProps[i] = &bigqueryrestapi.ConnectionProperty{Key: prop.Key, Value: prop.Value}
	}

	jobToInsert := &bigqueryrestapi.Job{
		JobReference: &bigqueryrestapi.JobReference{
			ProjectId: projectID,
			Location:  location,
		},
		Configuration: &bigqueryrestapi.JobConfiguration{
			DryRun: true,
			Query: &bigqueryrestapi.JobConfigurationQuery{
				Query:                sql,
				UseLegacySql:         &useLegacySql,
				ConnectionProperties: restConnProps,
				QueryParameters:      params,
			},
		},
	}

	insertResponse, err := restService.Jobs.Insert(projectID, jobToInsert).Context(ctx).Do()
	if err != nil {
		return nil, fmt.Errorf("failed to insert dry run job: %w", err)
	}

	// Callers dereference Statistics.Query directly; surface a missing
	// payload as an error here instead of letting them panic on a nil pointer.
	if insertResponse == nil || insertResponse.Statistics == nil || insertResponse.Statistics.Query == nil {
		return nil, fmt.Errorf("dry run job response did not include query statistics")
	}
	return insertResponse, nil
}

View File

@@ -24,6 +24,7 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools"
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
"github.com/googleapis/genai-toolbox/internal/util"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
"google.golang.org/api/iterator"
@@ -160,7 +161,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
}
}
dryRunJob, err := dryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, sql)
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, nil)
if err != nil {
return nil, fmt.Errorf("query validation failed during dry run: %w", err)
}
@@ -248,27 +249,3 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
// RequiresClientAuthorization reports the value of the tool's UseClientOAuth
// field.
func (t Tool) RequiresClientAuthorization() bool {
return t.UseClientOAuth
}
// dryRunQuery inserts sql as a dry-run query job (legacy SQL disabled) under
// projectID in the given location and returns the job BigQuery sends back.
// The insert error, if any, is wrapped and returned with a nil job.
func dryRunQuery(ctx context.Context, restService *bigqueryrestapi.Service, projectID string, location string, sql string) (*bigqueryrestapi.Job, error) {
	legacySQL := false

	queryConfig := &bigqueryrestapi.JobConfigurationQuery{
		Query:        sql,
		UseLegacySql: &legacySQL,
	}
	dryRun := &bigqueryrestapi.Job{
		JobReference:  &bigqueryrestapi.JobReference{ProjectId: projectID, Location: location},
		Configuration: &bigqueryrestapi.JobConfiguration{DryRun: true, Query: queryConfig},
	}

	job, err := restService.Jobs.Insert(projectID, dryRun).Context(ctx).Do()
	if err != nil {
		return nil, fmt.Errorf("failed to insert dry run job: %w", err)
	}
	return job, nil
}

View File

@@ -24,6 +24,7 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools"
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
"github.com/googleapis/genai-toolbox/internal/util"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
"google.golang.org/api/iterator"
@@ -50,6 +51,8 @@ type compatibleSource interface {
BigQueryRestService() *bigqueryrestapi.Service
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool
IsDatasetAllowed(projectID, datasetID string) bool
BigQueryAllowedDatasets() []string
}
// validate compatible sources are still compatible
@@ -85,8 +88,17 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
historyDataParameter := tools.NewStringParameter("history_data",
"The table id or the query of the history time series data.")
allowedDatasets := s.BigQueryAllowedDatasets()
historyDataDescription := "The table id or the query of the history time series data."
if len(allowedDatasets) > 0 {
datasetIDs := []string{}
for _, ds := range allowedDatasets {
datasetIDs = append(datasetIDs, fmt.Sprintf("`%s`", ds))
}
historyDataDescription += fmt.Sprintf(" The query or table must only access datasets from the following list: %s.", strings.Join(datasetIDs, ", "))
}
historyDataParameter := tools.NewStringParameter("history_data", historyDataDescription)
timestampColumnNameParameter := tools.NewStringParameter("timestamp_col",
"The name of the time series timestamp column.")
dataColumnNameParameter := tools.NewStringParameter("data_col",
@@ -106,16 +118,18 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Name: cfg.Name,
Kind: kind,
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Name: cfg.Name,
Kind: kind,
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
IsDatasetAllowed: s.IsDatasetAllowed,
AllowedDatasets: allowedDatasets,
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -130,11 +144,13 @@ type Tool struct {
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters tools.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
ClientCreator bigqueryds.BigqueryClientCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
ClientCreator bigqueryds.BigqueryClientCreator
IsDatasetAllowed func(projectID, datasetID string) bool
AllowedDatasets []string
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
@@ -175,8 +191,48 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
var historyDataSource string
trimmedUpperHistoryData := strings.TrimSpace(strings.ToUpper(historyData))
if strings.HasPrefix(trimmedUpperHistoryData, "SELECT") || strings.HasPrefix(trimmedUpperHistoryData, "WITH") {
if len(t.AllowedDatasets) > 0 {
dryRunJob, err := bqutil.DryRunQuery(ctx, t.RestService, t.Client.Project(), t.Client.Location, historyData, nil, nil)
if err != nil {
return nil, fmt.Errorf("query validation failed during dry run: %w", err)
}
statementType := dryRunJob.Statistics.Query.StatementType
if statementType != "SELECT" {
return nil, fmt.Errorf("the 'history_data' parameter only supports a table ID or a SELECT query. The provided query has statement type '%s'", statementType)
}
queryStats := dryRunJob.Statistics.Query
if queryStats != nil {
for _, tableRef := range queryStats.ReferencedTables {
if !t.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
return nil, fmt.Errorf("query in history_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId)
}
}
} else {
return nil, fmt.Errorf("could not analyze query in history_data to validate against allowed datasets")
}
}
historyDataSource = fmt.Sprintf("(%s)", historyData)
} else {
if len(t.AllowedDatasets) > 0 {
parts := strings.Split(historyData, ".")
var projectID, datasetID string
switch len(parts) {
case 3: // project.dataset.table
projectID = parts[0]
datasetID = parts[1]
case 2: // dataset.table
projectID = t.Client.Project()
datasetID = parts[0]
default:
return nil, fmt.Errorf("invalid table ID format for 'history_data': %q. Expected 'dataset.table' or 'project.dataset.table'", historyData)
}
if !t.IsDatasetAllowed(projectID, datasetID) {
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, historyData)
}
}
historyDataSource = fmt.Sprintf("TABLE `%s`", historyData)
}

View File

@@ -26,6 +26,7 @@ import (
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools"
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
"google.golang.org/api/iterator"
)
@@ -236,7 +237,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
query.Parameters = highLevelParams
query.Location = bqClient.Location
dryRunJob, err := dryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, newStatement, lowLevelParams, query.ConnectionProperties)
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, newStatement, lowLevelParams, query.ConnectionProperties)
if err != nil {
// This is a fallback check in case the switch logic was bypassed.
return nil, fmt.Errorf("final query validation failed: %w", err)
@@ -319,42 +320,3 @@ func BQTypeStringFromToolType(toolType string) (string, error) {
return "", fmt.Errorf("unsupported tool parameter type for BigQuery: %s", toolType)
}
}
// dryRunQuery inserts sql as a dry-run query job (legacy SQL disabled) under
// projectID in the given location. params are passed through as the job's
// query parameters and each entry of connProps is copied into the REST
// connection-property type. The job returned by BigQuery is handed back to
// the caller; an insert failure is wrapped and returned with a nil job.
func dryRunQuery(
	ctx context.Context,
	restService *bigqueryrestapi.Service,
	projectID string,
	location string,
	sql string,
	params []*bigqueryrestapi.QueryParameter,
	connProps []*bigqueryapi.ConnectionProperty,
) (*bigqueryrestapi.Job, error) {
	legacySQL := false

	// Convert the client-library properties to their REST equivalents.
	restProps := make([]*bigqueryrestapi.ConnectionProperty, 0, len(connProps))
	for _, p := range connProps {
		restProps = append(restProps, &bigqueryrestapi.ConnectionProperty{Key: p.Key, Value: p.Value})
	}

	queryConfig := &bigqueryrestapi.JobConfigurationQuery{
		Query:                sql,
		UseLegacySql:         &legacySQL,
		ConnectionProperties: restProps,
		QueryParameters:      params,
	}
	dryRun := &bigqueryrestapi.Job{
		JobReference:  &bigqueryrestapi.JobReference{ProjectId: projectID, Location: location},
		Configuration: &bigqueryrestapi.JobConfiguration{DryRun: true, Query: queryConfig},
	}

	job, err := restService.Jobs.Insert(projectID, dryRun).Context(ctx).Do()
	if err != nil {
		return nil, fmt.Errorf("failed to insert dry run job: %w", err)
	}
	return job, nil
}

View File

@@ -74,7 +74,7 @@ func initBigQueryConnection(project string) (*bigqueryapi.Client, error) {
func TestBigQueryToolEndpoints(t *testing.T) {
sourceConfig := getBigQueryVars(t)
ctx, cancel := context.WithTimeout(context.Background(), 6*time.Minute)
ctx, cancel := context.WithTimeout(context.Background(), 7*time.Minute)
defer cancel()
var args []string
@@ -204,7 +204,7 @@ func TestBigQueryToolEndpoints(t *testing.T) {
}
func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
client, err := initBigQueryConnection(BigqueryProject)
@@ -274,6 +274,11 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
"source": "my-instance",
"description": "Tool to ask BigQuery conversational analytics",
},
"forecast-restricted": map[string]any{
"kind": "bigquery-forecast",
"source": "my-instance",
"description": "Tool to forecast",
},
}
// Create config file
@@ -304,6 +309,8 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
runListTableIdsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, allowedForecastTableName2)
runConversationalAnalyticsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, disallowedTableName)
runConversationalAnalyticsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, disallowedTableName)
runForecastWithRestriction(t, allowedForecastTableFullName1, disallowedForecastTableFullName)
runForecastWithRestriction(t, allowedForecastTableFullName2, disallowedForecastTableFullName)
}
// getBigQueryParamToolInfo returns statements and param for my-tool for bigquery kind
@@ -2384,3 +2391,94 @@ func runBigQuerySearchCatalogToolInvokeTest(t *testing.T, datasetName string, ta
})
}
}
// runForecastWithRestriction invokes the "forecast-restricted" tool over HTTP
// and checks how it treats allowed and disallowed history_data inputs.
//
// allowedTableFullName and disallowedTableFullName are table names that may be
// wrapped in backticks; the backticks are stripped for the plain table-ID
// cases and kept for the SELECT-query cases. The disallowed name must contain
// at least two dot-separated segments, since the first two are used to build
// the dataset name expected in error messages.
//
// Each case posts to the local invoke endpoint and asserts the HTTP status
// plus a substring of either the "result" field or the raw response body.
func runForecastWithRestriction(t *testing.T, allowedTableFullName, disallowedTableFullName string) {
	allowedTableUnquoted := strings.ReplaceAll(allowedTableFullName, "`", "")
	disallowedTableUnquoted := strings.ReplaceAll(disallowedTableFullName, "`", "")

	// Fail with a clear message rather than a slice-bounds panic when the
	// disallowed table name is not qualified.
	disallowedParts := strings.Split(disallowedTableUnquoted, ".")
	if len(disallowedParts) < 2 {
		t.Fatalf("disallowed table name %q must have at least two dot-separated segments", disallowedTableUnquoted)
	}
	disallowedDatasetFQN := strings.Join(disallowedParts[0:2], ".")

	testCases := []struct {
		name           string
		historyData    string
		wantStatusCode int
		wantInResult   string
		wantInError    string
	}{
		{
			name:           "invoke with allowed table name",
			historyData:    allowedTableUnquoted,
			wantStatusCode: http.StatusOK,
			wantInResult:   `"forecast_timestamp"`,
		},
		{
			name:           "invoke with disallowed table name",
			historyData:    disallowedTableUnquoted,
			wantStatusCode: http.StatusBadRequest,
			wantInError:    fmt.Sprintf("access to dataset '%s' (from table '%s') is not allowed", disallowedDatasetFQN, disallowedTableUnquoted),
		},
		{
			name:           "invoke with query on allowed table",
			historyData:    fmt.Sprintf("SELECT * FROM %s", allowedTableFullName),
			wantStatusCode: http.StatusOK,
			wantInResult:   `"forecast_timestamp"`,
		},
		{
			name:           "invoke with query on disallowed table",
			historyData:    fmt.Sprintf("SELECT * FROM %s", disallowedTableFullName),
			wantStatusCode: http.StatusBadRequest,
			wantInError:    fmt.Sprintf("query in history_data accesses dataset '%s', which is not in the allowed list", disallowedDatasetFQN),
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			requestBytes, err := json.Marshal(map[string]any{
				"history_data":  tc.historyData,
				"timestamp_col": "ts",
				"data_col":      "data",
			})
			if err != nil {
				t.Fatalf("failed to marshal request body: %v", err)
			}

			req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:5000/api/tool/forecast-restricted/invoke", bytes.NewReader(requestBytes))
			if err != nil {
				t.Fatalf("unable to create request: %s", err)
			}
			req.Header.Set("Content-Type", "application/json")

			resp, err := http.DefaultClient.Do(req)
			if err != nil {
				t.Fatalf("unable to send request: %s", err)
			}
			defer resp.Body.Close()

			// Read the body exactly once so every check below sees the same
			// bytes and a read failure is reported as such.
			responseBytes, err := io.ReadAll(resp.Body)
			if err != nil {
				t.Fatalf("unable to read response body: %v", err)
			}

			if resp.StatusCode != tc.wantStatusCode {
				t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, tc.wantStatusCode, string(responseBytes))
			}

			if tc.wantInResult != "" {
				var respBody map[string]any
				if err := json.Unmarshal(responseBytes, &respBody); err != nil {
					t.Fatalf("error parsing response body: %v", err)
				}
				got, ok := respBody["result"].(string)
				if !ok {
					t.Fatalf("unable to find result in response body")
				}
				if !strings.Contains(got, tc.wantInResult) {
					t.Errorf("unexpected result: got %q, want to contain %q", got, tc.wantInResult)
				}
			}

			if tc.wantInError != "" && !strings.Contains(string(responseBytes), tc.wantInError) {
				t.Errorf("unexpected error message: got %q, want to contain %q", string(responseBytes), tc.wantInError)
			}
		})
	}
}