diff --git a/docs/en/resources/tools/bigquery/bigquery-conversational-analytics.md b/docs/en/resources/tools/bigquery/bigquery-conversational-analytics.md index eee68286ca..ca3913eb9a 100644 --- a/docs/en/resources/tools/bigquery/bigquery-conversational-analytics.md +++ b/docs/en/resources/tools/bigquery/bigquery-conversational-analytics.md @@ -26,12 +26,23 @@ for instructions. It's compatible with the following sources: -- [bigquery](../sources/bigquery.md) +- [bigquery](../../sources/bigquery.md) -The tool takes the following input parameters: +`bigquery-conversational-analytics` accepts the following parameters: -* `user_query_with_context`: The user's question, potentially including conversation history and system instructions for context. -* `table_references`: A JSON string of a list of BigQuery tables to use as context. Each object in the list must contain `projectId`, `datasetId`, and `tableId`. Example: `'[{"projectId": "my-gcp-project", "datasetId": "my_dataset", "tableId": "my_table"}]'` +- **`user_query_with_context`:** The user's question, potentially including conversation history and system +instructions for context. +- **`table_references`:** A JSON string of a list of BigQuery tables to use as context. +Each object in the list must contain `projectId`, `datasetId`, and `tableId`. Example: +`'[{"projectId": "my-gcp-project", "datasetId": "my_dataset", "tableId": "my_table"}]'` + +The tool's behavior regarding these parameters is influenced by the `allowedDatasets` +restriction on the `bigquery` source: +- **Without `allowedDatasets` restriction:** The tool can use tables from any +dataset specified in the `table_references` parameter. +- **With `allowedDatasets` restriction:** Before processing the request, the tool +verifies that every table in `table_references` belongs to a dataset in the allowed +list. If any table is from a dataset that is not in the list, the request is denied. ## Example diff --git a/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go b/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go index 0c05a7ee9b..8ab3c8bdb2 100644 --- a/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go +++ b/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go @@ -59,6 +59,8 @@ type compatibleSource interface { BigQueryLocation() string GetMaxQueryResultRows() int UseClientAuthorization() bool + IsDatasetAllowed(projectID, datasetID string) bool + BigQueryAllowedDatasets() []string } type BQTableReference struct { @@ -134,8 +136,17 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) } + allowedDatasets := s.BigQueryAllowedDatasets() + tableRefsDescription := `A JSON string of a list of BigQuery tables to use as context. Each object in the list must contain 'projectId', 'datasetId', and 'tableId'. Example: '[{"projectId": "my-gcp-project", "datasetId": "my_dataset", "tableId": "my_table"}]'.` + if len(allowedDatasets) > 0 { + datasetIDs := []string{} + for _, ds := range allowedDatasets { + datasetIDs = append(datasetIDs, fmt.Sprintf("`%s`", ds)) + } + tableRefsDescription += fmt.Sprintf(" The tables must only be from datasets in the following list: %s.", strings.Join(datasetIDs, ", ")) + } userQueryParameter := tools.NewStringParameter("user_query_with_context", "The user's question, potentially including conversation history and system instructions for context.") - tableRefsParameter := tools.NewStringParameter("table_references", `A JSON string of a list of BigQuery tables to use as context. Each object in the list must contain 'projectId', 'datasetId', and 'tableId'. Example: '[{"projectId": "my-gcp-project", "datasetId": "my_dataset", "tableId": "my_table"}]'`) + tableRefsParameter := tools.NewStringParameter("table_references", tableRefsDescription) parameters := tools.Parameters{userQueryParameter, tableRefsParameter} @@ -170,6 +181,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, MaxQueryResultRows: s.GetMaxQueryResultRows(), + IsDatasetAllowed: s.IsDatasetAllowed, + AllowedDatasets: allowedDatasets, } return t, nil } @@ -191,6 +204,8 @@ type Tool struct { manifest tools.Manifest mcpManifest tools.McpManifest MaxQueryResultRows int + IsDatasetAllowed func(projectID, datasetID string) bool + AllowedDatasets []string } func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) { @@ -233,6 +248,14 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken } } + if len(t.AllowedDatasets) > 0 { + for _, tableRef := range tableRefs { + if !t.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) { + return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", tableRef.ProjectID, tableRef.DatasetID, tableRef.TableID) + } + } + } + // Construct URL, headers, and payload projectID := t.Project location := t.Location diff --git a/tests/bigquery/bigquery_integration_test.go b/tests/bigquery/bigquery_integration_test.go index b80aa4827a..c4ec453b59 100644 --- a/tests/bigquery/bigquery_integration_test.go +++ b/tests/bigquery/bigquery_integration_test.go @@ -269,6 +269,11 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) { "source": "my-instance", "description": "Tool to list table within a dataset", }, + "conversational-analytics-restricted": map[string]any{ + "kind": "bigquery-conversational-analytics", + "source": "my-instance", + "description": "Tool to ask BigQuery conversational analytics", + }, } // Create config file @@ -297,6 +302,8 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) { // Run tests runListTableIdsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, allowedForecastTableName1) runListTableIdsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, allowedForecastTableName2) + runConversationalAnalyticsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, disallowedTableName) + runConversationalAnalyticsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, disallowedTableName) } // getBigQueryParamToolInfo returns statements and param for my-tool for bigquery kind @@ -2142,6 +2149,83 @@ func runListTableIdsWithRestriction(t *testing.T, allowedDatasetName, disallowed } } +func runConversationalAnalyticsWithRestriction(t *testing.T, allowedDatasetName, disallowedDatasetName, allowedTableName, disallowedTableName string) { + allowedTableRefsJSON := fmt.Sprintf(`[{"projectId":"%s","datasetId":"%s","tableId":"%s"}]`, BigqueryProject, allowedDatasetName, allowedTableName) + disallowedTableRefsJSON := fmt.Sprintf(`[{"projectId":"%s","datasetId":"%s","tableId":"%s"}]`, BigqueryProject, disallowedDatasetName, disallowedTableName) + + testCases := []struct { + name string + tableRefs string + wantStatusCode int + wantInResult string + wantInError string + }{ + { + name: "invoke with allowed table", + tableRefs: allowedTableRefsJSON, + wantStatusCode: http.StatusOK, + wantInResult: `Answer`, + }, + { + name: "invoke with disallowed table", + tableRefs: disallowedTableRefsJSON, + wantStatusCode: http.StatusBadRequest, + wantInError: fmt.Sprintf("access to dataset '%s.%s' (from table '%s') is not allowed", BigqueryProject, disallowedDatasetName, disallowedTableName), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + requestBodyMap := map[string]any{ + "user_query_with_context": "What is in the table?", + "table_references": tc.tableRefs, + } + bodyBytes, err := json.Marshal(requestBodyMap) + if err != nil { + t.Fatalf("failed to marshal request body: %v", err) + } + body := bytes.NewBuffer(bodyBytes) + + req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:5000/api/tool/conversational-analytics-restricted/invoke", body) + if err != nil { + t.Fatalf("unable to create request: %s", err) + } + req.Header.Add("Content-type", "application/json") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("unable to send request: %s", err) + } + defer resp.Body.Close() + + if resp.StatusCode != tc.wantStatusCode { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes)) + } + + if tc.wantInResult != "" { + var respBody map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { + t.Fatalf("error parsing response body: %v", err) + } + got, ok := respBody["result"].(string) + if !ok { + t.Fatalf("unable to find result in response body") + } + if !strings.Contains(got, tc.wantInResult) { + t.Errorf("unexpected result: got %q, want to contain %q", got, tc.wantInResult) + } + } + + if tc.wantInError != "" { + bodyBytes, _ := io.ReadAll(resp.Body) + if !strings.Contains(string(bodyBytes), tc.wantInError) { + t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError) + } + } + }) + } +} + func runBigQuerySearchCatalogToolInvokeTest(t *testing.T, datasetName string, tableName string) { // Get ID token idToken, err := tests.GetGoogleIdToken(tests.ClientId)