feat(tool/bigquery-execute-sql)!: add allowed datasets support (#1443)

## Description
This introduces a breaking change. The bigquery-execute-sql tool will
now enforce the allowed datasets setting from its BigQuery source
configuration. Previously, this setting had no effect on the tool.

---
> Should include a concise description of the changes (bug or feature), its
> impact, along with a summary of the solution

## PR Checklist

---
> Thank you for opening a Pull Request! Before submitting your PR, there are a
> few things you can do to make sure it goes smoothly:

- [ ] Make sure you reviewed

[CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [ ] Make sure to open an issue as a

[bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)
- [ ] Make sure to add `!` if this involves a breaking change

🛠️ Fixes https://github.com/googleapis/genai-toolbox/issues/873

---------

Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com>
This commit is contained in:
Huan Chen
2025-09-25 12:08:50 -07:00
committed by GitHub
parent a3bd2e9927
commit 9501ebbdbc
5 changed files with 1138 additions and 21 deletions

View File

@@ -15,9 +15,23 @@ It's compatible with the following sources:
- [bigquery](../../sources/bigquery.md)
`bigquery-execute-sql` takes a required `sql` input parameter and runs the SQL
statement against the configured `source`. It also supports an optional `dry_run`
parameter to validate a query without executing it.
`bigquery-execute-sql` accepts the following parameters:
- **`sql`** (required): The GoogleSQL statement to execute.
- **`dry_run`** (optional): If set to `true`, the query is validated but not run,
returning information about the execution instead. Defaults to `false`.
The tool's behavior is influenced by the `allowedDatasets` restriction on the
`bigquery` source:
- **Without `allowedDatasets` restriction:** The tool can execute any valid GoogleSQL
query.
- **With `allowedDatasets` restriction:** Before execution, the tool performs a dry run
to analyze the query.
It will reject the query if it attempts to access any table outside the
`allowedDatasets` list. To enforce this restriction, the following operations
are also disallowed:
- **Dataset-level operations** (e.g., `CREATE SCHEMA`, `ALTER SCHEMA`).
- **Unanalyzable operations** where the accessed tables cannot be determined
statically (e.g., `EXECUTE IMMEDIATE`, `CREATE PROCEDURE`, `CALL`).
## Example

View File

@@ -0,0 +1,431 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package bigquerycommon
import (
"fmt"
"strings"
"unicode"
)
// parserState defines the state of the SQL parser's state machine.
// The parser tracks whether it is scanning normal SQL text, one of the
// BigQuery string-literal forms, or a comment, so that table-like
// identifiers appearing inside strings or comments are ignored.
type parserState int

const (
	stateNormal parserState = iota
	// String states
	stateInSingleQuoteString
	stateInDoubleQuoteString
	stateInTripleSingleQuoteString
	stateInTripleDoubleQuoteString
	stateInRawSingleQuoteString
	stateInRawDoubleQuoteString
	stateInRawTripleSingleQuoteString
	stateInRawTripleDoubleQuoteString
	// Comment states
	stateInSingleLineCommentDash
	stateInSingleLineCommentHash
	stateInMultiLineComment
)

// SQL statement verbs
const (
	verbCreate = "create"
	verbAlter  = "alter"
	verbDrop   = "drop"
	verbSelect = "select"
	verbInsert = "insert"
	verbUpdate = "update"
	verbDelete = "delete"
	verbMerge  = "merge"
)

// tableFollowsKeywords lists keywords after which the next multi-part
// identifier (or parenthesized subquery) is treated as a table reference.
var tableFollowsKeywords = map[string]bool{
	"from":   true,
	"join":   true,
	"update": true,
	"into":   true, // INSERT INTO, MERGE INTO
	"table":  true, // CREATE TABLE, ALTER TABLE
	"using":  true, // MERGE ... USING
	"insert": true, // INSERT my_table
	"merge":  true, // MERGE my_table
}

// tableContextExitKeywords lists keywords that end a table-reference
// context (e.g. the FROM list), so later identifiers are not mistaken
// for table names.
var tableContextExitKeywords = map[string]bool{
	"where":  true,
	"group":  true, // GROUP BY
	"having": true,
	"order":  true, // ORDER BY
	"limit":  true,
	"window": true,
	"on":     true, // JOIN ... ON
	"set":    true, // UPDATE ... SET
	"when":   true, // MERGE ... WHEN
}
// TableParser is the main entry point for parsing a SQL string to find all
// referenced table IDs. It handles multi-statement SQL, comments, and
// parenthesized subqueries, and rejects statements whose table accesses
// cannot be analyzed statically. The returned slice is deduplicated and in
// no particular order.
func TableParser(sql, defaultProjectID string) ([]string, error) {
	seen := make(map[string]struct{})
	visited := make(map[string]struct{})
	_, err := parseSQL(sql, defaultProjectID, seen, visited, false)
	if err != nil {
		return nil, err
	}
	ids := make([]string, 0, len(seen))
	for id := range seen {
		ids = append(ids, id)
	}
	return ids, nil
}
// parseSQL is the core recursive function that processes SQL strings.
// It runs a character-level state machine over sql, skipping string
// literals (including raw and triple-quoted forms) and comments,
// collecting referenced multi-part table identifiers into tableIDSet,
// and recursing into parenthesized subqueries. Statements whose table
// accesses cannot be determined statically (EXECUTE IMMEDIATE, CALL,
// CREATE [OR REPLACE] PROCEDURE/FUNCTION) and dataset-level DDL
// (CREATE/ALTER/DROP SCHEMA or DATASET) are rejected with an error.
//
// It returns the number of characters consumed — used by the recursive
// subquery case, which stops at the matching ')' — and an error, if any.
//
// NOTE(review): i indexes the runes slice while `remaining := sql[i:]`
// slices by byte offset; these diverge once sql contains a non-ASCII
// character, so this parser is only reliable for ASCII input — confirm
// callers guarantee that, or switch to byte-based iteration.
func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visitedSQLs map[string]struct{}, inSubquery bool) (int, error) {
	// Prevent infinite recursion.
	if _, ok := visitedSQLs[sql]; ok {
		return len(sql), nil
	}
	visitedSQLs[sql] = struct{}{}
	state := stateNormal
	// expectingTable is set right after a keyword (FROM, JOIN, INTO, ...)
	// that may be followed by a table name or a parenthesized subquery.
	expectingTable := false
	var lastTableKeyword, lastToken, statementVerb string
	runes := []rune(sql)
	for i := 0; i < len(runes); {
		char := runes[i]
		remaining := sql[i:]
		switch state {
		case stateNormal:
			// Comment openers.
			if strings.HasPrefix(remaining, "--") {
				state = stateInSingleLineCommentDash
				i += 2
				continue
			}
			if strings.HasPrefix(remaining, "#") {
				state = stateInSingleLineCommentHash
				i++
				continue
			}
			if strings.HasPrefix(remaining, "/*") {
				state = stateInMultiLineComment
				i += 2
				continue
			}
			if char == '(' {
				if expectingTable {
					// The subquery starts after '('.
					consumed, err := parseSQL(remaining[1:], defaultProjectID, tableIDSet, visitedSQLs, true)
					if err != nil {
						return 0, err
					}
					// Advance i by the length of the subquery + the opening parenthesis.
					// The recursive call returns what it consumed, including the closing parenthesis.
					i += consumed + 1
					// For most keywords, we expect only one table. `from` can have multiple "tables" (subqueries).
					if lastTableKeyword != "from" {
						expectingTable = false
					}
					continue
				}
			}
			if char == ')' {
				if inSubquery {
					// End of this subquery; report how much we consumed.
					return i + 1, nil
				}
			}
			if char == ';' {
				// Statement boundary: reset per-statement tracking.
				statementVerb = ""
				lastToken = ""
				i++
				continue
			}
			// Raw strings must be checked before regular strings.
			if strings.HasPrefix(remaining, "r'''") || strings.HasPrefix(remaining, "R'''") {
				state = stateInRawTripleSingleQuoteString
				i += 4
				continue
			}
			if strings.HasPrefix(remaining, `r"""`) || strings.HasPrefix(remaining, `R"""`) {
				state = stateInRawTripleDoubleQuoteString
				i += 4
				continue
			}
			if strings.HasPrefix(remaining, "r'") || strings.HasPrefix(remaining, "R'") {
				state = stateInRawSingleQuoteString
				i += 2
				continue
			}
			if strings.HasPrefix(remaining, `r"`) || strings.HasPrefix(remaining, `R"`) {
				state = stateInRawDoubleQuoteString
				i += 2
				continue
			}
			if strings.HasPrefix(remaining, "'''") {
				state = stateInTripleSingleQuoteString
				i += 3
				continue
			}
			if strings.HasPrefix(remaining, `"""`) {
				state = stateInTripleDoubleQuoteString
				i += 3
				continue
			}
			if char == '\'' {
				state = stateInSingleQuoteString
				i++
				continue
			}
			if char == '"' {
				state = stateInDoubleQuoteString
				i++
				continue
			}
			if unicode.IsLetter(char) || char == '`' {
				// Start of a (possibly dot-separated, possibly quoted) identifier.
				parts, consumed, err := parseIdentifierSequence(remaining)
				if err != nil {
					return 0, err
				}
				if consumed == 0 {
					i++
					continue
				}
				if len(parts) == 1 {
					// Single-part identifier: treat as a keyword candidate.
					keyword := strings.ToLower(parts[0])
					switch keyword {
					case "call":
						return 0, fmt.Errorf("CALL is not allowed when dataset restrictions are in place, as the called procedure's contents cannot be safely analyzed")
					case "immediate":
						if lastToken == "execute" {
							return 0, fmt.Errorf("EXECUTE IMMEDIATE is not allowed when dataset restrictions are in place, as its contents cannot be safely analyzed")
						}
					case "procedure", "function":
						if lastToken == "create" || lastToken == "create or replace" {
							return 0, fmt.Errorf("unanalyzable statements like '%s %s' are not allowed", strings.ToUpper(lastToken), strings.ToUpper(keyword))
						}
					case verbCreate, verbAlter, verbDrop, verbSelect, verbInsert, verbUpdate, verbDelete, verbMerge:
						// Remember the first verb of the statement only.
						if statementVerb == "" {
							statementVerb = keyword
						}
					}
					if statementVerb == verbCreate || statementVerb == verbAlter || statementVerb == verbDrop {
						if keyword == "schema" || keyword == "dataset" {
							return 0, fmt.Errorf("dataset-level operations like '%s %s' are not allowed when dataset restrictions are in place", strings.ToUpper(statementVerb), strings.ToUpper(keyword))
						}
					}
					if _, ok := tableFollowsKeywords[keyword]; ok {
						expectingTable = true
						lastTableKeyword = keyword
					} else if _, ok := tableContextExitKeywords[keyword]; ok {
						expectingTable = false
						lastTableKeyword = ""
					}
					// Track the "create or replace" multi-word prefix across tokens.
					if lastToken == "create" && keyword == "or" {
						lastToken = "create or"
					} else if lastToken == "create or" && keyword == "replace" {
						lastToken = "create or replace"
					} else {
						lastToken = keyword
					}
				} else if len(parts) >= 2 {
					// This is a multi-part identifier. If we were expecting a table, this is it.
					if expectingTable {
						tableID, err := formatTableID(parts, defaultProjectID)
						if err != nil {
							return 0, err
						}
						if tableID != "" {
							tableIDSet[tableID] = struct{}{}
						}
						// For most keywords, we expect only one table.
						if lastTableKeyword != "from" {
							expectingTable = false
						}
					}
					lastToken = ""
				}
				i += consumed
				continue
			}
			i++
		case stateInSingleQuoteString:
			if char == '\\' {
				i += 2 // Skip backslash and the escaped character.
				continue
			}
			if char == '\'' {
				state = stateNormal
			}
			i++
		case stateInDoubleQuoteString:
			if char == '\\' {
				i += 2 // Skip backslash and the escaped character.
				continue
			}
			if char == '"' {
				state = stateNormal
			}
			i++
		case stateInTripleSingleQuoteString:
			if strings.HasPrefix(remaining, "'''") {
				state = stateNormal
				i += 3
			} else {
				i++
			}
		case stateInTripleDoubleQuoteString:
			if strings.HasPrefix(remaining, `"""`) {
				state = stateNormal
				i += 3
			} else {
				i++
			}
		case stateInSingleLineCommentDash, stateInSingleLineCommentHash:
			// Single-line comments end at the newline.
			if char == '\n' {
				state = stateNormal
			}
			i++
		case stateInMultiLineComment:
			if strings.HasPrefix(remaining, "*/") {
				state = stateNormal
				i += 2
			} else {
				i++
			}
		case stateInRawSingleQuoteString:
			// Raw strings have no escape sequences; close on the bare quote.
			if char == '\'' {
				state = stateNormal
			}
			i++
		case stateInRawDoubleQuoteString:
			if char == '"' {
				state = stateNormal
			}
			i++
		case stateInRawTripleSingleQuoteString:
			if strings.HasPrefix(remaining, "'''") {
				state = stateNormal
				i += 3
			} else {
				i++
			}
		case stateInRawTripleDoubleQuoteString:
			if strings.HasPrefix(remaining, `"""`) {
				state = stateNormal
				i += 3
			} else {
				i++
			}
		}
	}
	if inSubquery {
		// We were parsing a '('-opened subquery but never saw the ')'.
		return 0, fmt.Errorf("unclosed subquery parenthesis")
	}
	return len(sql), nil
}
// parseIdentifierSequence parses a sequence of dot-separated identifiers
// starting at the beginning of s (each element may be preceded by
// whitespace). Elements are either backtick-quoted — and may themselves
// contain dots, as in `project.dataset.table` — or unquoted runs of
// letters, digits, '_' and '-'.
//
// It returns the dot-split identifier parts, the number of bytes of s
// consumed, and an error for an unterminated backtick quote. A return of
// (nil, 0, nil) means no identifier starts at this position.
func parseIdentifierSequence(s string) ([]string, int, error) {
	var parts []string
	var totalConsumed int
	for {
		// Skip whitespace before this element.
		remaining := s[totalConsumed:]
		trimmed := strings.TrimLeftFunc(remaining, unicode.IsSpace)
		totalConsumed += len(remaining) - len(trimmed)
		current := s[totalConsumed:]
		if len(current) == 0 {
			break
		}
		var part string
		var consumed int
		if current[0] == '`' {
			// Quoted identifier: scan to the closing backtick.
			end := strings.Index(current[1:], "`")
			if end == -1 {
				return nil, 0, fmt.Errorf("unclosed backtick identifier")
			}
			part = current[1 : end+1]
			consumed = end + 2
		} else if unicode.IsLetter(rune(current[0])) {
			// Unquoted identifier: consume letters, digits, '_' and '-'.
			// (current[0] is a single byte, so only an ASCII letter can
			// start an unquoted identifier here.)
			end := strings.IndexFunc(current, func(r rune) bool {
				return !unicode.IsLetter(r) && !unicode.IsNumber(r) && r != '_' && r != '-'
			})
			if end == -1 {
				end = len(current)
			}
			part = current[:end]
			consumed = end
		} else {
			// Not an identifier character; stop without consuming it.
			break
		}
		// A quoted element may contain dots itself (`project.dataset.table`),
		// so split every element on dots.
		parts = append(parts, strings.Split(part, ".")...)
		totalConsumed += consumed
		// Continue only if the very next byte chains another element with '.'.
		if len(s) <= totalConsumed || s[totalConsumed] != '.' {
			break
		}
		totalConsumed++ // skip the '.' separator
	}
	return parts, totalConsumed, nil
}
// formatTableID builds a fully qualified "project.dataset.table" ID from
// identifier parts. Two-part names are qualified with defaultProjectID;
// an empty defaultProjectID in that case is an error. Anything that is
// not a two- or three-part name (a CTE, a column path, etc.) yields an
// empty ID and no error, signalling the caller to ignore it.
func formatTableID(parts []string, defaultProjectID string) (string, error) {
	joined := strings.Join(parts, ".")
	switch len(parts) {
	case 3: // project.dataset.table — already fully qualified.
		return joined, nil
	case 2: // dataset.table — needs the default project.
		if defaultProjectID == "" {
			return "", fmt.Errorf("query contains table '%s' without project ID, and no default project ID is provided", joined)
		}
		return defaultProjectID + "." + joined, nil
	default:
		// Not a table identifier; the caller skips empty IDs.
		return "", nil
	}
}

View File

@@ -0,0 +1,496 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package bigquerycommon_test
import (
"sort"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
)
// TestTableParser exercises bigquerycommon.TableParser over a table of SQL
// snippets, covering quoting styles, comments, string literals (including
// raw and triple-quoted forms), subqueries, multi-statement scripts, and
// the statement kinds that must be rejected under dataset restrictions
// (EXECUTE IMMEDIATE, CALL, CREATE PROCEDURE/FUNCTION, dataset-level DDL).
func TestTableParser(t *testing.T) {
	testCases := []struct {
		name             string
		sql              string
		defaultProjectID string   // project used to qualify dataset.table references
		want             []string // expected fully qualified table IDs (order-independent)
		wantErr          bool
		wantErrMsg       string // optional substring the error must contain
	}{
		{
			name:             "single fully qualified table",
			sql:              "SELECT * FROM `my-project.my_dataset.my_table`",
			defaultProjectID: "default-proj",
			want:             []string{"my-project.my_dataset.my_table"},
			wantErr:          false,
		},
		{
			name:             "multiple statements with same table",
			sql:              "select * from proj1.data1.tbl1 limit 1; select A.b from proj1.data1.tbl1 as A limit 1;",
			defaultProjectID: "default-proj",
			want:             []string{"proj1.data1.tbl1"},
			wantErr:          false,
		},
		{
			name:             "multiple fully qualified tables",
			sql:              "SELECT * FROM `proj1.data1`.`tbl1` JOIN proj2.`data2.tbl2` ON id",
			defaultProjectID: "default-proj",
			want:             []string{"proj1.data1.tbl1", "proj2.data2.tbl2"},
			wantErr:          false,
		},
		{
			name:             "duplicate tables",
			sql:              "SELECT * FROM `proj1.data1.tbl1` JOIN proj1.data1.tbl1 ON id",
			defaultProjectID: "default-proj",
			want:             []string{"proj1.data1.tbl1"},
			wantErr:          false,
		},
		{
			name:             "partial table with default project",
			sql:              "SELECT * FROM `my_dataset`.my_table",
			defaultProjectID: "default-proj",
			want:             []string{"default-proj.my_dataset.my_table"},
			wantErr:          false,
		},
		{
			name:             "partial table without default project",
			sql:              "SELECT * FROM `my_dataset.my_table`",
			defaultProjectID: "",
			want:             nil,
			wantErr:          true,
		},
		{
			name:             "mixed fully qualified and partial tables",
			sql:              "SELECT t1.*, t2.* FROM `proj1.data1.tbl1` AS t1 JOIN `data2.tbl2` AS t2 ON t1.id = t2.id",
			defaultProjectID: "default-proj",
			want:             []string{"proj1.data1.tbl1", "default-proj.data2.tbl2"},
			wantErr:          false,
		},
		{
			name:             "no tables",
			sql:              "SELECT 1+1",
			defaultProjectID: "default-proj",
			want:             []string{},
			wantErr:          false,
		},
		{
			name:             "ignore single part identifiers (like CTEs)",
			sql:              "WITH my_cte AS (SELECT 1) SELECT * FROM `my_cte`",
			defaultProjectID: "default-proj",
			want:             []string{},
			wantErr:          false,
		},
		{
			name:             "complex CTE",
			sql:              "WITH cte1 AS (SELECT * FROM `real.table.one`), cte2 AS (SELECT * FROM cte1) SELECT * FROM cte2 JOIN `real.table.two` ON true",
			defaultProjectID: "default-proj",
			want:             []string{"real.table.one", "real.table.two"},
			wantErr:          false,
		},
		{
			name:             "nested subquery should be parsed",
			sql:              "SELECT * FROM (SELECT a FROM (SELECT A.b FROM `real.table.nested` AS A))",
			defaultProjectID: "default-proj",
			want:             []string{"real.table.nested"},
			wantErr:          false,
		},
		{
			name:             "from clause with unnest",
			sql:              "SELECT event.name FROM `my-project.my_dataset.my_table` AS A, UNNEST(A.events) AS event",
			defaultProjectID: "default-proj",
			want:             []string{"my-project.my_dataset.my_table"},
			wantErr:          false,
		},
		{
			name:             "ignore more than 3 parts",
			sql:              "SELECT * FROM `proj.data.tbl.col`",
			defaultProjectID: "default-proj",
			want:             []string{},
			wantErr:          false,
		},
		{
			name:             "complex query",
			sql:              "SELECT name FROM (SELECT name FROM `proj1.data1.tbl1`) UNION ALL SELECT name FROM `data2.tbl2`",
			defaultProjectID: "default-proj",
			want:             []string{"proj1.data1.tbl1", "default-proj.data2.tbl2"},
			wantErr:          false,
		},
		{
			name:             "empty sql",
			sql:              "",
			defaultProjectID: "default-proj",
			want:             []string{},
			wantErr:          false,
		},
		{
			name:             "with comments",
			sql:              "SELECT * FROM `proj1.data1.tbl1`; -- comment `fake.table.one` \n SELECT * FROM `proj2.data2.tbl2`; # comment `fake.table.two`",
			defaultProjectID: "default-proj",
			want:             []string{"proj1.data1.tbl1", "proj2.data2.tbl2"},
			wantErr:          false,
		},
		{
			name:             "multi-statement with semicolon",
			sql:              "SELECT * FROM `proj1.data1.tbl1`; SELECT * FROM `proj2.data2.tbl2`",
			defaultProjectID: "default-proj",
			want:             []string{"proj1.data1.tbl1", "proj2.data2.tbl2"},
			wantErr:          false,
		},
		{
			name:             "simple execute immediate",
			sql:              "EXECUTE IMMEDIATE 'SELECT * FROM `exec.proj.tbl`'",
			defaultProjectID: "default-proj",
			want:             nil,
			wantErr:          true,
			wantErrMsg:       "EXECUTE IMMEDIATE is not allowed when dataset restrictions are in place",
		},
		{
			name:             "execute immediate with multiple spaces",
			sql:              "EXECUTE IMMEDIATE 'SELECT 1'",
			defaultProjectID: "default-proj",
			want:             nil,
			wantErr:          true,
			wantErrMsg:       "EXECUTE IMMEDIATE is not allowed when dataset restrictions are in place",
		},
		{
			name:             "execute immediate with newline",
			sql:              "EXECUTE\nIMMEDIATE 'SELECT 1'",
			defaultProjectID: "default-proj",
			want:             nil,
			wantErr:          true,
			wantErrMsg:       "EXECUTE IMMEDIATE is not allowed when dataset restrictions are in place",
		},
		{
			name:             "execute immediate with comment",
			sql:              "EXECUTE -- some comment\n IMMEDIATE 'SELECT * FROM `exec.proj.tbl`'",
			defaultProjectID: "default-proj",
			want:             nil,
			wantErr:          true,
			wantErrMsg:       "EXECUTE IMMEDIATE is not allowed when dataset restrictions are in place",
		},
		{
			name:             "nested execute immediate",
			sql:              "EXECUTE IMMEDIATE \"EXECUTE IMMEDIATE '''SELECT * FROM `nested.exec.tbl`'''\"",
			defaultProjectID: "default-proj",
			want:             nil,
			wantErr:          true,
			wantErrMsg:       "EXECUTE IMMEDIATE is not allowed when dataset restrictions are in place",
		},
		{
			name:             "begin execute immediate",
			sql:              "BEGIN EXECUTE IMMEDIATE 'SELECT * FROM `exec.proj.tbl`'; END;",
			defaultProjectID: "default-proj",
			want:             nil,
			wantErr:          true,
			wantErrMsg:       "EXECUTE IMMEDIATE is not allowed when dataset restrictions are in place",
		},
		{
			name:             "table inside string literal should be ignored",
			sql:              "SELECT * FROM `real.table.one` WHERE name = 'select * from `fake.table.two`'",
			defaultProjectID: "default-proj",
			want:             []string{"real.table.one"},
			wantErr:          false,
		},
		{
			name:             "string with escaped single quote",
			sql:              "SELECT 'this is a string with an escaped quote \\' and a fake table `fake.table.one`' FROM `real.table.two`",
			defaultProjectID: "default-proj",
			want:             []string{"real.table.two"},
			wantErr:          false,
		},
		{
			name:             "string with escaped double quote",
			sql:              `SELECT "this is a string with an escaped quote \" and a fake table ` + "`fake.table.one`" + `" FROM ` + "`real.table.two`",
			defaultProjectID: "default-proj",
			want:             []string{"real.table.two"},
			wantErr:          false,
		},
		{
			name:             "multi-line comment",
			sql:              "/* `fake.table.1` */ SELECT * FROM `real.table.2`",
			defaultProjectID: "default-proj",
			want:             []string{"real.table.2"},
			wantErr:          false,
		},
		{
			name:             "raw string with backslash should be ignored",
			sql:              "SELECT * FROM `real.table.one` WHERE name = r'a raw string with a \\ and a fake table `fake.table.two`'",
			defaultProjectID: "default-proj",
			want:             []string{"real.table.one"},
			wantErr:          false,
		},
		{
			name:             "capital R raw string with quotes inside should be ignored",
			sql:              `SELECT * FROM ` + "`real.table.one`" + ` WHERE name = R"""a raw string with a ' and a " and a \ and a fake table ` + "`fake.table.two`" + `"""`,
			defaultProjectID: "default-proj",
			want:             []string{"real.table.one"},
			wantErr:          false,
		},
		{
			name:             "triple quoted raw string should be ignored",
			sql:              "SELECT * FROM `real.table.one` WHERE name = r'''a raw string with a ' and a \" and a \\ and a fake table `fake.table.two`'''",
			defaultProjectID: "default-proj",
			want:             []string{"real.table.one"},
			wantErr:          false,
		},
		{
			name:             "triple quoted capital R raw string should be ignored",
			sql:              `SELECT * FROM ` + "`real.table.one`" + ` WHERE name = R"""a raw string with a ' and a " and a \ and a fake table ` + "`fake.table.two`" + `"""`,
			defaultProjectID: "default-proj",
			want:             []string{"real.table.one"},
			wantErr:          false,
		},
		{
			name:             "unquoted fully qualified table",
			sql:              "SELECT * FROM my-project.my_dataset.my_table",
			defaultProjectID: "default-proj",
			want:             []string{"my-project.my_dataset.my_table"},
			wantErr:          false,
		},
		{
			name:             "unquoted partial table with default project",
			sql:              "SELECT * FROM my_dataset.my_table",
			defaultProjectID: "default-proj",
			want:             []string{"default-proj.my_dataset.my_table"},
			wantErr:          false,
		},
		{
			name:             "unquoted partial table without default project",
			sql:              "SELECT * FROM my_dataset.my_table",
			defaultProjectID: "",
			want:             nil,
			wantErr:          true,
		},
		{
			name:             "mixed quoting style 1",
			sql:              "SELECT * FROM `my-project`.my_dataset.my_table",
			defaultProjectID: "default-proj",
			want:             []string{"my-project.my_dataset.my_table"},
			wantErr:          false,
		},
		{
			name:             "mixed quoting style 2",
			sql:              "SELECT * FROM `my-project`.`my_dataset`.my_table",
			defaultProjectID: "default-proj",
			want:             []string{"my-project.my_dataset.my_table"},
			wantErr:          false,
		},
		{
			name:             "mixed quoting style 3",
			sql:              "SELECT * FROM `my-project`.`my_dataset`.`my_table`",
			defaultProjectID: "default-proj",
			want:             []string{"my-project.my_dataset.my_table"},
			wantErr:          false,
		},
		{
			name:             "mixed quoted and unquoted tables",
			sql:              "SELECT * FROM `proj1.data1.tbl1` JOIN proj2.data2.tbl2 ON id",
			defaultProjectID: "default-proj",
			want:             []string{"proj1.data1.tbl1", "proj2.data2.tbl2"},
			wantErr:          false,
		},
		{
			name:             "create table statement",
			sql:              "CREATE TABLE `my-project.my_dataset.my_table` (x INT64)",
			defaultProjectID: "default-proj",
			want:             []string{"my-project.my_dataset.my_table"},
			wantErr:          false,
		},
		{
			name:             "insert into statement",
			sql:              "INSERT INTO `my-project.my_dataset.my_table` (x) VALUES (1)",
			defaultProjectID: "default-proj",
			want:             []string{"my-project.my_dataset.my_table"},
			wantErr:          false,
		},
		{
			name:             "update statement",
			sql:              "UPDATE `my-project.my_dataset.my_table` SET x = 2 WHERE true",
			defaultProjectID: "default-proj",
			want:             []string{"my-project.my_dataset.my_table"},
			wantErr:          false,
		},
		{
			name:             "delete from statement",
			sql:              "DELETE FROM `my-project.my_dataset.my_table` WHERE true",
			defaultProjectID: "default-proj",
			want:             []string{"my-project.my_dataset.my_table"},
			wantErr:          false,
		},
		{
			name:             "merge into statement",
			sql:              "MERGE `proj.data.target` T USING `proj.data.source` S ON T.id = S.id WHEN NOT MATCHED THEN INSERT ROW",
			defaultProjectID: "default-proj",
			want:             []string{"proj.data.source", "proj.data.target"},
			wantErr:          false,
		},
		{
			name:             "create schema statement",
			sql:              "CREATE SCHEMA `my-project.my_dataset`",
			defaultProjectID: "default-proj",
			want:             nil,
			wantErr:          true,
			wantErrMsg:       "dataset-level operations like 'CREATE SCHEMA' are not allowed",
		},
		{
			name:             "create dataset statement",
			sql:              "CREATE DATASET `my-project.my_dataset`",
			defaultProjectID: "default-proj",
			want:             nil,
			wantErr:          true,
			wantErrMsg:       "dataset-level operations like 'CREATE DATASET' are not allowed",
		},
		{
			name:             "drop schema statement",
			sql:              "DROP SCHEMA `my-project.my_dataset`",
			defaultProjectID: "default-proj",
			want:             nil,
			wantErr:          true,
			wantErrMsg:       "dataset-level operations like 'DROP SCHEMA' are not allowed",
		},
		{
			name:             "drop dataset statement",
			sql:              "DROP DATASET `my-project.my_dataset`",
			defaultProjectID: "default-proj",
			want:             nil,
			wantErr:          true,
			wantErrMsg:       "dataset-level operations like 'DROP DATASET' are not allowed",
		},
		{
			name:             "alter schema statement",
			sql:              "ALTER SCHEMA my_dataset SET OPTIONS(description='new description')",
			defaultProjectID: "default-proj",
			want:             nil,
			wantErr:          true,
			wantErrMsg:       "dataset-level operations like 'ALTER SCHEMA' are not allowed",
		},
		{
			name:             "alter dataset statement",
			sql:              "ALTER DATASET my_dataset SET OPTIONS(description='new description')",
			defaultProjectID: "default-proj",
			want:             nil,
			wantErr:          true,
			wantErrMsg:       "dataset-level operations like 'ALTER DATASET' are not allowed",
		},
		{
			name:             "begin...end block",
			sql:              "BEGIN CREATE TABLE `proj.data.tbl1` (x INT64); INSERT `proj.data.tbl2` (y) VALUES (1); END;",
			defaultProjectID: "default-proj",
			want:             []string{"proj.data.tbl1", "proj.data.tbl2"},
			wantErr:          false,
		},
		{
			name: "complex begin...end block with comments and different quoting",
			sql: `
BEGIN
-- Create a new table
CREATE TABLE proj.data.tbl1 (x INT64);
/* Insert some data from another table */
INSERT INTO ` + "`proj.data.tbl2`" + ` (y) SELECT y FROM proj.data.source;
END;`,
			defaultProjectID: "default-proj",
			want:             []string{"proj.data.source", "proj.data.tbl1", "proj.data.tbl2"},
			wantErr:          false,
		},
		{
			name:             "call fully qualified procedure",
			sql:              "CALL my-project.my_dataset.my_procedure()",
			defaultProjectID: "default-proj",
			want:             nil,
			wantErr:          true,
			wantErrMsg:       "CALL is not allowed when dataset restrictions are in place",
		},
		{
			name:             "call partially qualified procedure",
			sql:              "CALL my_dataset.my_procedure()",
			defaultProjectID: "default-proj",
			want:             nil,
			wantErr:          true,
			wantErrMsg:       "CALL is not allowed when dataset restrictions are in place",
		},
		{
			name:             "call procedure in begin...end block",
			sql:              "BEGIN CALL proj.data.proc1(); SELECT * FROM proj.data.tbl1; END;",
			defaultProjectID: "default-proj",
			want:             nil,
			wantErr:          true,
			wantErrMsg:       "CALL is not allowed when dataset restrictions are in place",
		},
		{
			name:             "call procedure with newline",
			sql:              "CALL\nmy_dataset.my_procedure()",
			defaultProjectID: "default-proj",
			want:             nil,
			wantErr:          true,
			wantErrMsg:       "CALL is not allowed when dataset restrictions are in place",
		},
		{
			name:             "call procedure without default project should fail",
			sql:              "CALL my_dataset.my_procedure()",
			defaultProjectID: "",
			want:             nil,
			wantErr:          true,
			wantErrMsg:       "CALL is not allowed when dataset restrictions are in place",
		},
		{
			name:             "create procedure statement",
			sql:              "CREATE PROCEDURE my_dataset.my_procedure() BEGIN SELECT 1; END;",
			defaultProjectID: "default-proj",
			want:             nil,
			wantErr:          true,
			wantErrMsg:       "unanalyzable statements like 'CREATE PROCEDURE' are not allowed",
		},
		{
			name:             "create or replace procedure statement",
			sql:              "CREATE\n OR \nREPLACE \nPROCEDURE my_dataset.my_procedure() BEGIN SELECT 1; END;",
			defaultProjectID: "default-proj",
			want:             nil,
			wantErr:          true,
			wantErrMsg:       "unanalyzable statements like 'CREATE OR REPLACE PROCEDURE' are not allowed",
		},
		{
			name:             "create function statement",
			sql:              "CREATE FUNCTION my_dataset.my_function() RETURNS INT64 AS (1);",
			defaultProjectID: "default-proj",
			want:             nil,
			wantErr:          true,
			wantErrMsg:       "unanalyzable statements like 'CREATE FUNCTION' are not allowed",
		},
	}
	// Run each case as a subtest; errors are checked first, then the table
	// sets are compared order-independently.
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := bigquerycommon.TableParser(tc.sql, tc.defaultProjectID)
			if (err != nil) != tc.wantErr {
				t.Errorf("TableParser() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if tc.wantErr && tc.wantErrMsg != "" {
				if err == nil || !strings.Contains(err.Error(), tc.wantErrMsg) {
					t.Errorf("TableParser() error = %v, want err containing %q", err, tc.wantErrMsg)
				}
			}
			// Sort slices to ensure comparison is order-independent.
			sort.Strings(got)
			sort.Strings(tc.want)
			if diff := cmp.Diff(tc.want, got); diff != "" {
				t.Errorf("TableParser() mismatch (-want +got):\n%s", diff)
			}
		})
	}
}

View File

@@ -18,6 +18,7 @@ import (
"context"
"encoding/json"
"fmt"
"strings"
bigqueryapi "cloud.google.com/go/bigquery"
yaml "github.com/goccy/go-yaml"
@@ -51,6 +52,8 @@ type compatibleSource interface {
BigQueryRestService() *bigqueryrestapi.Service
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool
IsDatasetAllowed(projectID, datasetID string) bool
BigQueryAllowedDatasets() []string
}
// validate compatible sources are still compatible
@@ -86,7 +89,28 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
sqlParameter := tools.NewStringParameter("sql", "The sql to execute.")
sqlDescription := "The sql to execute."
allowedDatasets := s.BigQueryAllowedDatasets()
if len(allowedDatasets) > 0 {
datasetIDs := []string{}
for _, ds := range allowedDatasets {
datasetIDs = append(datasetIDs, fmt.Sprintf("`%s`", ds))
}
if len(datasetIDs) == 1 {
parts := strings.Split(allowedDatasets[0], ".")
if len(parts) < 2 {
return nil, fmt.Errorf("expected split to have 2 parts: %s", allowedDatasets[0])
}
datasetID := parts[1]
sqlDescription += fmt.Sprintf(" The query must only access the %s dataset. "+
"To query a table within this dataset (e.g., `my_table`), "+
"qualify it with the dataset id (e.g., `%s.my_table`).", datasetIDs[0], datasetID)
} else {
sqlDescription += fmt.Sprintf(" The query must only access datasets from the following list: %s.", strings.Join(datasetIDs, ", "))
}
}
sqlParameter := tools.NewStringParameter("sql", sqlDescription)
dryRunParameter := tools.NewBooleanParameterWithDefault(
"dry_run",
false,
@@ -103,16 +127,18 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Name: cfg.Name,
Kind: kind,
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Name: cfg.Name,
Kind: kind,
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
IsDatasetAllowed: s.IsDatasetAllowed,
AllowedDatasets: allowedDatasets,
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -127,11 +153,13 @@ type Tool struct {
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters tools.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
ClientCreator bigqueryds.BigqueryClientCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
ClientCreator bigqueryds.BigqueryClientCreator
IsDatasetAllowed func(projectID, datasetID string) bool
AllowedDatasets []string
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
@@ -165,6 +193,61 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
if err != nil {
return nil, fmt.Errorf("query validation failed during dry run: %w", err)
}
statementType := dryRunJob.Statistics.Query.StatementType
if len(t.AllowedDatasets) > 0 {
switch statementType {
case "CREATE_SCHEMA", "DROP_SCHEMA", "ALTER_SCHEMA":
return nil, fmt.Errorf("dataset-level operations like '%s' are not allowed when dataset restrictions are in place", statementType)
case "CREATE_FUNCTION", "CREATE_TABLE_FUNCTION", "CREATE_PROCEDURE":
return nil, fmt.Errorf("creating stored routines ('%s') is not allowed when dataset restrictions are in place, as their contents cannot be safely analyzed", statementType)
case "CALL":
return nil, fmt.Errorf("calling stored procedures ('%s') is not allowed when dataset restrictions are in place, as their contents cannot be safely analyzed", statementType)
}
// Use a map to avoid duplicate table names.
tableIDSet := make(map[string]struct{})
// Get all tables from the dry run result. This is the most reliable method.
queryStats := dryRunJob.Statistics.Query
if queryStats != nil {
for _, tableRef := range queryStats.ReferencedTables {
tableIDSet[fmt.Sprintf("%s.%s.%s", tableRef.ProjectId, tableRef.DatasetId, tableRef.TableId)] = struct{}{}
}
if tableRef := queryStats.DdlTargetTable; tableRef != nil {
tableIDSet[fmt.Sprintf("%s.%s.%s", tableRef.ProjectId, tableRef.DatasetId, tableRef.TableId)] = struct{}{}
}
if tableRef := queryStats.DdlDestinationTable; tableRef != nil {
tableIDSet[fmt.Sprintf("%s.%s.%s", tableRef.ProjectId, tableRef.DatasetId, tableRef.TableId)] = struct{}{}
}
}
var tableNames []string
if len(tableIDSet) > 0 {
for tableID := range tableIDSet {
tableNames = append(tableNames, tableID)
}
} else if statementType != "SELECT" {
// If dry run yields no tables, fall back to the parser for non-SELECT statements
// to catch unsafe operations like EXECUTE IMMEDIATE.
parsedTables, parseErr := bqutil.TableParser(sql, t.Client.Project())
if parseErr != nil {
// If parsing fails (e.g., EXECUTE IMMEDIATE), we cannot guarantee safety, so we must fail.
return nil, fmt.Errorf("could not parse tables from query to validate against allowed datasets: %w", parseErr)
}
tableNames = parsedTables
}
for _, tableID := range tableNames {
parts := strings.Split(tableID, ".")
if len(parts) == 3 {
projectID, datasetID := parts[0], parts[1]
if !t.IsDatasetAllowed(projectID, datasetID) {
return nil, fmt.Errorf("query accesses dataset '%s.%s', which is not in the allowed list", projectID, datasetID)
}
}
}
}
if dryRun {
if dryRunJob != nil {
@@ -178,8 +261,6 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
return "Dry run was requested, but no job information was returned.", nil
}
statementType := dryRunJob.Statistics.Query.StatementType
// JobStatistics.QueryStatistics.StatementType
query := bqClient.Query(sql)
query.Location = bqClient.Location

View File

@@ -269,6 +269,11 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
"source": "my-instance",
"description": "Tool to list table within a dataset",
},
"execute-sql-restricted": map[string]any{
"kind": "bigquery-execute-sql",
"source": "my-instance",
"description": "Tool to execute SQL",
},
"conversational-analytics-restricted": map[string]any{
"kind": "bigquery-conversational-analytics",
"source": "my-instance",
@@ -307,6 +312,8 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
// Run tests
runListTableIdsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, allowedForecastTableName1)
runListTableIdsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, allowedForecastTableName2)
runExecuteSqlWithRestriction(t, allowedTableNameParam1, disallowedTableNameParam)
runExecuteSqlWithRestriction(t, allowedTableNameParam2, disallowedTableNameParam)
runConversationalAnalyticsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, disallowedTableName)
runConversationalAnalyticsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, disallowedTableName)
runForecastWithRestriction(t, allowedForecastTableFullName1, disallowedForecastTableFullName)
@@ -2156,6 +2163,94 @@ func runListTableIdsWithRestriction(t *testing.T, allowedDatasetName, disallowed
}
}
// runExecuteSqlWithRestriction exercises the "execute-sql-restricted" tool and
// verifies that the BigQuery source's allowed-datasets setting is enforced:
// a SELECT on an allowed table succeeds, while queries that touch disallowed
// datasets or use unsafe statement types (schema-level DDL, stored routines,
// EXECUTE IMMEDIATE) are rejected with the expected error text.
//
// allowedTableFullName and disallowedTableFullName are fully-qualified
// "project.dataset.table" names, optionally wrapped in backticks.
func runExecuteSqlWithRestriction(t *testing.T, allowedTableFullName, disallowedTableFullName string) {
	// Extract the dataset ID from the allowed table name; the DDL test cases
	// below target that dataset directly.
	allowedTableParts := strings.Split(strings.Trim(allowedTableFullName, "`"), ".")
	if len(allowedTableParts) != 3 {
		t.Fatalf("invalid allowed table name format: %s", allowedTableFullName)
	}
	allowedDatasetID := allowedTableParts[1]

	testCases := []struct {
		name           string
		sql            string
		wantStatusCode int
		wantInError    string // substring expected in the response body; empty means "don't check"
	}{
		{
			name:           "invoke on allowed table",
			sql:            fmt.Sprintf("SELECT * FROM %s", allowedTableFullName),
			wantStatusCode: http.StatusOK,
		},
		{
			name:           "invoke on disallowed table",
			sql:            fmt.Sprintf("SELECT * FROM %s", disallowedTableFullName),
			wantStatusCode: http.StatusBadRequest,
			// The tool reports the offending "project.dataset" pair.
			wantInError: fmt.Sprintf("query accesses dataset '%s', which is not in the allowed list",
				strings.Join(
					strings.Split(strings.Trim(disallowedTableFullName, "`"), ".")[0:2],
					".")),
		},
		{
			name:           "disallowed create schema",
			sql:            "CREATE SCHEMA another_dataset",
			wantStatusCode: http.StatusBadRequest,
			wantInError:    "dataset-level operations like 'CREATE_SCHEMA' are not allowed",
		},
		{
			name:           "disallowed alter schema",
			sql:            fmt.Sprintf("ALTER SCHEMA %s SET OPTIONS(description='new one')", allowedDatasetID),
			wantStatusCode: http.StatusBadRequest,
			wantInError:    "dataset-level operations like 'ALTER_SCHEMA' are not allowed",
		},
		{
			name:           "disallowed create function",
			sql:            fmt.Sprintf("CREATE FUNCTION %s.my_func() RETURNS INT64 AS (1)", allowedDatasetID),
			wantStatusCode: http.StatusBadRequest,
			wantInError:    "creating stored routines ('CREATE_FUNCTION') is not allowed",
		},
		{
			name:           "disallowed create procedure",
			sql:            fmt.Sprintf("CREATE PROCEDURE %s.my_proc() BEGIN SELECT 1; END", allowedDatasetID),
			wantStatusCode: http.StatusBadRequest,
			// CREATE_PROCEDURE is rejected by the same "stored routines" branch
			// as CREATE_FUNCTION; BigQuery statement types use underscores.
			wantInError: "creating stored routines ('CREATE_PROCEDURE') is not allowed",
		},
		{
			name:           "disallowed execute immediate",
			sql:            "EXECUTE IMMEDIATE 'SELECT 1'",
			wantStatusCode: http.StatusBadRequest,
			// NOTE(review): this substring is assumed to surface from the SQL
			// table parser's error when the statement cannot be analyzed.
			wantInError: "EXECUTE IMMEDIATE is not allowed when dataset restrictions are in place",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"sql":"%s"}`, tc.sql)))
			req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:5000/api/tool/execute-sql-restricted/invoke", body)
			if err != nil {
				t.Fatalf("unable to create request: %s", err)
			}
			req.Header.Add("Content-type", "application/json")
			resp, err := http.DefaultClient.Do(req)
			if err != nil {
				t.Fatalf("unable to send request: %s", err)
			}
			defer resp.Body.Close()

			// Read the body once; it is needed both for the status-mismatch
			// diagnostic and for the expected-error substring check.
			bodyBytes, err := io.ReadAll(resp.Body)
			if err != nil {
				t.Fatalf("unable to read response body: %s", err)
			}
			if resp.StatusCode != tc.wantStatusCode {
				t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes))
			}
			if tc.wantInError != "" && !strings.Contains(string(bodyBytes), tc.wantInError) {
				t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError)
			}
		})
	}
}
func runConversationalAnalyticsWithRestriction(t *testing.T, allowedDatasetName, disallowedDatasetName, allowedTableName, disallowedTableName string) {
allowedTableRefsJSON := fmt.Sprintf(`[{"projectId":"%s","datasetId":"%s","tableId":"%s"}]`, BigqueryProject, allowedDatasetName, allowedTableName)
disallowedTableRefsJSON := fmt.Sprintf(`[{"projectId":"%s","datasetId":"%s","tableId":"%s"}]`, BigqueryProject, disallowedDatasetName, disallowedTableName)