mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-10 07:58:12 -05:00
feat(tools/bigquery-conversational-analytics)!: Add allowed datasets support (#1411)
## Description --- Add support to allowed datasets for conversational-analytics tool in bigquery. ## PR Checklist --- > Thank you for opening a Pull Request! Before submitting your PR, there are a > few things you can do to make sure it goes smoothly: - [ ] Make sure you reviewed [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md) - [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) - [ ] Make sure to add `!` if this involve a breaking change 🛠️ Fixes #<issue_number_goes_here>
This commit is contained in:
@@ -26,12 +26,23 @@ for instructions.
|
||||
|
||||
It's compatible with the following sources:
|
||||
|
||||
- [bigquery](../sources/bigquery.md)
|
||||
- [bigquery](../../sources/bigquery.md)
|
||||
|
||||
The tool takes the following input parameters:
|
||||
`bigquery-conversational-analytics` accepts the following parameters:
|
||||
|
||||
* `user_query_with_context`: The user's question, potentially including conversation history and system instructions for context.
|
||||
* `table_references`: A JSON string of a list of BigQuery tables to use as context. Each object in the list must contain `projectId`, `datasetId`, and `tableId`. Example: `'[{"projectId": "my-gcp-project", "datasetId": "my_dataset", "tableId": "my_table"}]'`
|
||||
- **`user_query_with_context`:** The user's question, potentially including conversation history and system
|
||||
instructions for context.
|
||||
- **`table_references`:** A JSON string of a list of BigQuery tables to use as context.
|
||||
Each object in the list must contain `projectId`, `datasetId`, and `tableId`. Example:
|
||||
`'[{"projectId": "my-gcp-project", "datasetId": "my_dataset", "tableId": "my_table"}]'`
|
||||
|
||||
The tool's behavior regarding these parameters is influenced by the `allowedDatasets`
|
||||
restriction on the `bigquery` source:
|
||||
- **Without `allowedDatasets` restriction:** The tool can use tables from any
|
||||
dataset specified in the `table_references` parameter.
|
||||
- **With `allowedDatasets` restriction:** Before processing the request, the tool
|
||||
verifies that every table in `table_references` belongs to a dataset in the allowed
|
||||
list. If any table is from a dataset that is not in the list, the request is denied.
|
||||
|
||||
## Example
|
||||
|
||||
|
||||
@@ -59,6 +59,8 @@ type compatibleSource interface {
|
||||
BigQueryLocation() string
|
||||
GetMaxQueryResultRows() int
|
||||
UseClientAuthorization() bool
|
||||
IsDatasetAllowed(projectID, datasetID string) bool
|
||||
BigQueryAllowedDatasets() []string
|
||||
}
|
||||
|
||||
type BQTableReference struct {
|
||||
@@ -134,8 +136,17 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
allowedDatasets := s.BigQueryAllowedDatasets()
|
||||
tableRefsDescription := `A JSON string of a list of BigQuery tables to use as context. Each object in the list must contain 'projectId', 'datasetId', and 'tableId'. Example: '[{"projectId": "my-gcp-project", "datasetId": "my_dataset", "tableId": "my_table"}]'.`
|
||||
if len(allowedDatasets) > 0 {
|
||||
datasetIDs := []string{}
|
||||
for _, ds := range allowedDatasets {
|
||||
datasetIDs = append(datasetIDs, fmt.Sprintf("`%s`", ds))
|
||||
}
|
||||
tableRefsDescription += fmt.Sprintf(" The tables must only be from datasets in the following list: %s.", strings.Join(datasetIDs, ", "))
|
||||
}
|
||||
userQueryParameter := tools.NewStringParameter("user_query_with_context", "The user's question, potentially including conversation history and system instructions for context.")
|
||||
tableRefsParameter := tools.NewStringParameter("table_references", `A JSON string of a list of BigQuery tables to use as context. Each object in the list must contain 'projectId', 'datasetId', and 'tableId'. Example: '[{"projectId": "my-gcp-project", "datasetId": "my_dataset", "tableId": "my_table"}]'`)
|
||||
tableRefsParameter := tools.NewStringParameter("table_references", tableRefsDescription)
|
||||
|
||||
parameters := tools.Parameters{userQueryParameter, tableRefsParameter}
|
||||
|
||||
@@ -170,6 +181,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
MaxQueryResultRows: s.GetMaxQueryResultRows(),
|
||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
||||
AllowedDatasets: allowedDatasets,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -191,6 +204,8 @@ type Tool struct {
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
MaxQueryResultRows int
|
||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
||||
AllowedDatasets []string
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
@@ -233,6 +248,14 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
}
|
||||
}
|
||||
|
||||
if len(t.AllowedDatasets) > 0 {
|
||||
for _, tableRef := range tableRefs {
|
||||
if !t.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) {
|
||||
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", tableRef.ProjectID, tableRef.DatasetID, tableRef.TableID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Construct URL, headers, and payload
|
||||
projectID := t.Project
|
||||
location := t.Location
|
||||
|
||||
@@ -269,6 +269,11 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
|
||||
"source": "my-instance",
|
||||
"description": "Tool to list table within a dataset",
|
||||
},
|
||||
"conversational-analytics-restricted": map[string]any{
|
||||
"kind": "bigquery-conversational-analytics",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to ask BigQuery conversational analytics",
|
||||
},
|
||||
}
|
||||
|
||||
// Create config file
|
||||
@@ -297,6 +302,8 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
|
||||
// Run tests
|
||||
runListTableIdsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, allowedForecastTableName1)
|
||||
runListTableIdsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, allowedForecastTableName2)
|
||||
runConversationalAnalyticsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, disallowedTableName)
|
||||
runConversationalAnalyticsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, disallowedTableName)
|
||||
}
|
||||
|
||||
// getBigQueryParamToolInfo returns statements and param for my-tool for bigquery kind
|
||||
@@ -2142,6 +2149,83 @@ func runListTableIdsWithRestriction(t *testing.T, allowedDatasetName, disallowed
|
||||
}
|
||||
}
|
||||
|
||||
func runConversationalAnalyticsWithRestriction(t *testing.T, allowedDatasetName, disallowedDatasetName, allowedTableName, disallowedTableName string) {
|
||||
allowedTableRefsJSON := fmt.Sprintf(`[{"projectId":"%s","datasetId":"%s","tableId":"%s"}]`, BigqueryProject, allowedDatasetName, allowedTableName)
|
||||
disallowedTableRefsJSON := fmt.Sprintf(`[{"projectId":"%s","datasetId":"%s","tableId":"%s"}]`, BigqueryProject, disallowedDatasetName, disallowedTableName)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
tableRefs string
|
||||
wantStatusCode int
|
||||
wantInResult string
|
||||
wantInError string
|
||||
}{
|
||||
{
|
||||
name: "invoke with allowed table",
|
||||
tableRefs: allowedTableRefsJSON,
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantInResult: `Answer`,
|
||||
},
|
||||
{
|
||||
name: "invoke with disallowed table",
|
||||
tableRefs: disallowedTableRefsJSON,
|
||||
wantStatusCode: http.StatusBadRequest,
|
||||
wantInError: fmt.Sprintf("access to dataset '%s.%s' (from table '%s') is not allowed", BigqueryProject, disallowedDatasetName, disallowedTableName),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
requestBodyMap := map[string]any{
|
||||
"user_query_with_context": "What is in the table?",
|
||||
"table_references": tc.tableRefs,
|
||||
}
|
||||
bodyBytes, err := json.Marshal(requestBodyMap)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal request body: %v", err)
|
||||
}
|
||||
body := bytes.NewBuffer(bodyBytes)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:5000/api/tool/conversational-analytics-restricted/invoke", body)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
if tc.wantInResult != "" {
|
||||
var respBody map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil {
|
||||
t.Fatalf("error parsing response body: %v", err)
|
||||
}
|
||||
got, ok := respBody["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
if !strings.Contains(got, tc.wantInResult) {
|
||||
t.Errorf("unexpected result: got %q, want to contain %q", got, tc.wantInResult)
|
||||
}
|
||||
}
|
||||
|
||||
if tc.wantInError != "" {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
if !strings.Contains(string(bodyBytes), tc.wantInError) {
|
||||
t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runBigQuerySearchCatalogToolInvokeTest(t *testing.T, datasetName string, tableName string) {
|
||||
// Get ID token
|
||||
idToken, err := tests.GetGoogleIdToken(tests.ClientId)
|
||||
|
||||
Reference in New Issue
Block a user