feat(tools/bigquery-conversational-analytics)!: Add allowed datasets support (#1411)

## Description

---
Add support to allowed datasets for conversational-analytics tool in
bigquery.

## PR Checklist

---
> Thank you for opening a Pull Request! Before submitting your PR, there
are a
> few things you can do to make sure it goes smoothly:

- [ ] Make sure you reviewed

[CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [ ] Make sure to open an issue as a

[bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)
- [ ] Make sure to add `!` if this involve a breaking change

🛠️ Fixes #<issue_number_goes_here>
This commit is contained in:
Huan Chen
2025-09-17 16:25:36 -07:00
committed by GitHub
parent 80b7488ad2
commit 345bd6af52
3 changed files with 123 additions and 5 deletions

View File

@@ -26,12 +26,23 @@ for instructions.
It's compatible with the following sources:
- [bigquery](../sources/bigquery.md)
- [bigquery](../../sources/bigquery.md)
The tool takes the following input parameters:
`bigquery-conversational-analytics` accepts the following parameters:
* `user_query_with_context`: The user's question, potentially including conversation history and system instructions for context.
* `table_references`: A JSON string of a list of BigQuery tables to use as context. Each object in the list must contain `projectId`, `datasetId`, and `tableId`. Example: `'[{"projectId": "my-gcp-project", "datasetId": "my_dataset", "tableId": "my_table"}]'`
- **`user_query_with_context`:** The user's question, potentially including conversation history and system
instructions for context.
- **`table_references`:** A JSON string of a list of BigQuery tables to use as context.
Each object in the list must contain `projectId`, `datasetId`, and `tableId`. Example:
`'[{"projectId": "my-gcp-project", "datasetId": "my_dataset", "tableId": "my_table"}]'`
The tool's behavior regarding these parameters is influenced by the `allowedDatasets`
restriction on the `bigquery` source:
- **Without `allowedDatasets` restriction:** The tool can use tables from any
dataset specified in the `table_references` parameter.
- **With `allowedDatasets` restriction:** Before processing the request, the tool
verifies that every table in `table_references` belongs to a dataset in the allowed
list. If any table is from a dataset that is not in the list, the request is denied.
## Example

View File

@@ -59,6 +59,8 @@ type compatibleSource interface {
BigQueryLocation() string
GetMaxQueryResultRows() int
UseClientAuthorization() bool
IsDatasetAllowed(projectID, datasetID string) bool
BigQueryAllowedDatasets() []string
}
type BQTableReference struct {
@@ -134,8 +136,17 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
allowedDatasets := s.BigQueryAllowedDatasets()
tableRefsDescription := `A JSON string of a list of BigQuery tables to use as context. Each object in the list must contain 'projectId', 'datasetId', and 'tableId'. Example: '[{"projectId": "my-gcp-project", "datasetId": "my_dataset", "tableId": "my_table"}]'.`
if len(allowedDatasets) > 0 {
datasetIDs := []string{}
for _, ds := range allowedDatasets {
datasetIDs = append(datasetIDs, fmt.Sprintf("`%s`", ds))
}
tableRefsDescription += fmt.Sprintf(" The tables must only be from datasets in the following list: %s.", strings.Join(datasetIDs, ", "))
}
userQueryParameter := tools.NewStringParameter("user_query_with_context", "The user's question, potentially including conversation history and system instructions for context.")
tableRefsParameter := tools.NewStringParameter("table_references", `A JSON string of a list of BigQuery tables to use as context. Each object in the list must contain 'projectId', 'datasetId', and 'tableId'. Example: '[{"projectId": "my-gcp-project", "datasetId": "my_dataset", "tableId": "my_table"}]'`)
tableRefsParameter := tools.NewStringParameter("table_references", tableRefsDescription)
parameters := tools.Parameters{userQueryParameter, tableRefsParameter}
@@ -170,6 +181,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
MaxQueryResultRows: s.GetMaxQueryResultRows(),
IsDatasetAllowed: s.IsDatasetAllowed,
AllowedDatasets: allowedDatasets,
}
return t, nil
}
@@ -191,6 +204,8 @@ type Tool struct {
manifest tools.Manifest
mcpManifest tools.McpManifest
MaxQueryResultRows int
IsDatasetAllowed func(projectID, datasetID string) bool
AllowedDatasets []string
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
@@ -233,6 +248,14 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
}
}
if len(t.AllowedDatasets) > 0 {
for _, tableRef := range tableRefs {
if !t.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) {
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", tableRef.ProjectID, tableRef.DatasetID, tableRef.TableID)
}
}
}
// Construct URL, headers, and payload
projectID := t.Project
location := t.Location

View File

@@ -269,6 +269,11 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
"source": "my-instance",
"description": "Tool to list table within a dataset",
},
"conversational-analytics-restricted": map[string]any{
"kind": "bigquery-conversational-analytics",
"source": "my-instance",
"description": "Tool to ask BigQuery conversational analytics",
},
}
// Create config file
@@ -297,6 +302,8 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
// Run tests
runListTableIdsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, allowedForecastTableName1)
runListTableIdsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, allowedForecastTableName2)
runConversationalAnalyticsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, disallowedTableName)
runConversationalAnalyticsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, disallowedTableName)
}
// getBigQueryParamToolInfo returns statements and param for my-tool for bigquery kind
@@ -2142,6 +2149,83 @@ func runListTableIdsWithRestriction(t *testing.T, allowedDatasetName, disallowed
}
}
func runConversationalAnalyticsWithRestriction(t *testing.T, allowedDatasetName, disallowedDatasetName, allowedTableName, disallowedTableName string) {
allowedTableRefsJSON := fmt.Sprintf(`[{"projectId":"%s","datasetId":"%s","tableId":"%s"}]`, BigqueryProject, allowedDatasetName, allowedTableName)
disallowedTableRefsJSON := fmt.Sprintf(`[{"projectId":"%s","datasetId":"%s","tableId":"%s"}]`, BigqueryProject, disallowedDatasetName, disallowedTableName)
testCases := []struct {
name string
tableRefs string
wantStatusCode int
wantInResult string
wantInError string
}{
{
name: "invoke with allowed table",
tableRefs: allowedTableRefsJSON,
wantStatusCode: http.StatusOK,
wantInResult: `Answer`,
},
{
name: "invoke with disallowed table",
tableRefs: disallowedTableRefsJSON,
wantStatusCode: http.StatusBadRequest,
wantInError: fmt.Sprintf("access to dataset '%s.%s' (from table '%s') is not allowed", BigqueryProject, disallowedDatasetName, disallowedTableName),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
requestBodyMap := map[string]any{
"user_query_with_context": "What is in the table?",
"table_references": tc.tableRefs,
}
bodyBytes, err := json.Marshal(requestBodyMap)
if err != nil {
t.Fatalf("failed to marshal request body: %v", err)
}
body := bytes.NewBuffer(bodyBytes)
req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:5000/api/tool/conversational-analytics-restricted/invoke", body)
if err != nil {
t.Fatalf("unable to create request: %s", err)
}
req.Header.Add("Content-type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("unable to send request: %s", err)
}
defer resp.Body.Close()
if resp.StatusCode != tc.wantStatusCode {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes))
}
if tc.wantInResult != "" {
var respBody map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil {
t.Fatalf("error parsing response body: %v", err)
}
got, ok := respBody["result"].(string)
if !ok {
t.Fatalf("unable to find result in response body")
}
if !strings.Contains(got, tc.wantInResult) {
t.Errorf("unexpected result: got %q, want to contain %q", got, tc.wantInResult)
}
}
if tc.wantInError != "" {
bodyBytes, _ := io.ReadAll(resp.Body)
if !strings.Contains(string(bodyBytes), tc.wantInError) {
t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError)
}
}
})
}
}
func runBigQuerySearchCatalogToolInvokeTest(t *testing.T, datasetName string, tableName string) {
// Get ID token
idToken, err := tests.GetGoogleIdToken(tests.ClientId)