Mirror of https://github.com/googleapis/genai-toolbox.git, synced 2026-01-09 07:28:05 -05:00
feat(tool/bigquery-execute-sql)!: add allowed datasets support (#1443)
## Description

This introduces a breaking change. The `bigquery-execute-sql` tool will now enforce the allowed datasets setting from its BigQuery source configuration. Previously, this setting had no effect on the tool.

---

> Should include a concise description of the changes (bug or feature), its
> impact, along with a summary of the solution

## PR Checklist

---

> Thank you for opening a Pull Request! Before submitting your PR, there are a
> few things you can do to make sure it goes smoothly:

- [ ] Make sure you reviewed [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)
- [ ] Make sure to add `!` if this involves a breaking change

🛠️ Fixes https://github.com/googleapis/genai-toolbox/issues/873

---------

Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com>
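For context, here is a minimal sketch of the kind of configuration this change now enforces, written in the `map[string]any` style the integration tests in this commit use to describe the YAML config. The `allowedDatasets` setting on the `bigquery` source is the setting referenced above; the `project` key and the exact `project.dataset` format of the entries are assumptions inferred from the updated documentation and tool code, not a verbatim copy of the source schema.

```go
package main

import "fmt"

func main() {
	// Illustrative only: mirrors how the integration tests represent the YAML config.
	// The allowedDatasets entry is what bigquery-execute-sql now enforces;
	// "my-project" / "my_dataset" are placeholders.
	config := map[string]any{
		"sources": map[string]any{
			"my-instance": map[string]any{
				"kind":            "bigquery",
				"project":         "my-project",                      // assumed key name
				"allowedDatasets": []string{"my-project.my_dataset"}, // restriction enforced by this change
			},
		},
		"tools": map[string]any{
			"execute-sql-restricted": map[string]any{
				"kind":        "bigquery-execute-sql",
				"source":      "my-instance",
				"description": "Tool to execute SQL",
			},
		},
	}
	fmt.Println(config)
}
```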
@@ -15,9 +15,23 @@ It's compatible with the following sources:
- [bigquery](../../sources/bigquery.md)

`bigquery-execute-sql` takes a required `sql` input parameter and runs the SQL
statement against the configured `source`. It also supports an optional `dry_run`
parameter to validate a query without executing it.

`bigquery-execute-sql` accepts the following parameters:

- **`sql`** (required): The GoogleSQL statement to execute.
- **`dry_run`** (optional): If set to `true`, the query is validated but not run,
  returning information about the execution instead. Defaults to `false`.

The tool's behavior is influenced by the `allowedDatasets` restriction on the
`bigquery` source:

- **Without `allowedDatasets` restriction:** The tool can execute any valid GoogleSQL
  query.
- **With `allowedDatasets` restriction:** Before execution, the tool performs a dry run
  to analyze the query. It will reject the query if it attempts to access any table
  outside the allowed `datasets` list. To enforce this restriction, the following
  operations are also disallowed:
  - **Dataset-level operations** (e.g., `CREATE SCHEMA`, `ALTER SCHEMA`).
  - **Unanalyzable operations** where the accessed tables cannot be determined
    statically (e.g., `EXECUTE IMMEDIATE`, `CREATE PROCEDURE`, `CALL`).

## Example
431 internal/tools/bigquery/bigquerycommon/table_name_parser.go (new file)
@@ -0,0 +1,431 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package bigquerycommon

import (
    "fmt"
    "strings"
    "unicode"
)

// parserState defines the state of the SQL parser's state machine.
type parserState int

const (
    stateNormal parserState = iota
    // String states
    stateInSingleQuoteString
    stateInDoubleQuoteString
    stateInTripleSingleQuoteString
    stateInTripleDoubleQuoteString
    stateInRawSingleQuoteString
    stateInRawDoubleQuoteString
    stateInRawTripleSingleQuoteString
    stateInRawTripleDoubleQuoteString
    // Comment states
    stateInSingleLineCommentDash
    stateInSingleLineCommentHash
    stateInMultiLineComment
)

// SQL statement verbs
const (
    verbCreate = "create"
    verbAlter  = "alter"
    verbDrop   = "drop"
    verbSelect = "select"
    verbInsert = "insert"
    verbUpdate = "update"
    verbDelete = "delete"
    verbMerge  = "merge"
)

var tableFollowsKeywords = map[string]bool{
    "from":   true,
    "join":   true,
    "update": true,
    "into":   true, // INSERT INTO, MERGE INTO
    "table":  true, // CREATE TABLE, ALTER TABLE
    "using":  true, // MERGE ... USING
    "insert": true, // INSERT my_table
    "merge":  true, // MERGE my_table
}

var tableContextExitKeywords = map[string]bool{
    "where":  true,
    "group":  true, // GROUP BY
    "having": true,
    "order":  true, // ORDER BY
    "limit":  true,
    "window": true,
    "on":     true, // JOIN ... ON
    "set":    true, // UPDATE ... SET
    "when":   true, // MERGE ... WHEN
}

// TableParser is the main entry point for parsing a SQL string to find all referenced table IDs.
// It handles multi-statement SQL, comments, and recursive parsing of EXECUTE IMMEDIATE statements.
func TableParser(sql, defaultProjectID string) ([]string, error) {
    tableIDSet := make(map[string]struct{})
    visitedSQLs := make(map[string]struct{})
    if _, err := parseSQL(sql, defaultProjectID, tableIDSet, visitedSQLs, false); err != nil {
        return nil, err
    }

    tableIDs := make([]string, 0, len(tableIDSet))
    for id := range tableIDSet {
        tableIDs = append(tableIDs, id)
    }
    return tableIDs, nil
}

// parseSQL is the core recursive function that processes SQL strings.
// It uses a state machine to find table names and recursively parse EXECUTE IMMEDIATE.
func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visitedSQLs map[string]struct{}, inSubquery bool) (int, error) {
    // Prevent infinite recursion.
    if _, ok := visitedSQLs[sql]; ok {
        return len(sql), nil
    }
    visitedSQLs[sql] = struct{}{}

    state := stateNormal
    expectingTable := false
    var lastTableKeyword, lastToken, statementVerb string
    runes := []rune(sql)

    for i := 0; i < len(runes); {
        char := runes[i]
        remaining := sql[i:]

        switch state {
        case stateNormal:
            if strings.HasPrefix(remaining, "--") {
                state = stateInSingleLineCommentDash
                i += 2
                continue
            }
            if strings.HasPrefix(remaining, "#") {
                state = stateInSingleLineCommentHash
                i++
                continue
            }
            if strings.HasPrefix(remaining, "/*") {
                state = stateInMultiLineComment
                i += 2
                continue
            }
            if char == '(' {
                if expectingTable {
                    // The subquery starts after '('.
                    consumed, err := parseSQL(remaining[1:], defaultProjectID, tableIDSet, visitedSQLs, true)
                    if err != nil {
                        return 0, err
                    }
                    // Advance i by the length of the subquery + the opening parenthesis.
                    // The recursive call returns what it consumed, including the closing parenthesis.
                    i += consumed + 1
                    // For most keywords, we expect only one table. `from` can have multiple "tables" (subqueries).
                    if lastTableKeyword != "from" {
                        expectingTable = false
                    }
                    continue
                }
            }
            if char == ')' {
                if inSubquery {
                    return i + 1, nil
                }
            }

            if char == ';' {
                statementVerb = ""
                lastToken = ""
                i++
                continue
            }

            // Raw strings must be checked before regular strings.
            if strings.HasPrefix(remaining, "r'''") || strings.HasPrefix(remaining, "R'''") {
                state = stateInRawTripleSingleQuoteString
                i += 4
                continue
            }
            if strings.HasPrefix(remaining, `r"""`) || strings.HasPrefix(remaining, `R"""`) {
                state = stateInRawTripleDoubleQuoteString
                i += 4
                continue
            }
            if strings.HasPrefix(remaining, "r'") || strings.HasPrefix(remaining, "R'") {
                state = stateInRawSingleQuoteString
                i += 2
                continue
            }
            if strings.HasPrefix(remaining, `r"`) || strings.HasPrefix(remaining, `R"`) {
                state = stateInRawDoubleQuoteString
                i += 2
                continue
            }
            if strings.HasPrefix(remaining, "'''") {
                state = stateInTripleSingleQuoteString
                i += 3
                continue
            }
            if strings.HasPrefix(remaining, `"""`) {
                state = stateInTripleDoubleQuoteString
                i += 3
                continue
            }
            if char == '\'' {
                state = stateInSingleQuoteString
                i++
                continue
            }
            if char == '"' {
                state = stateInDoubleQuoteString
                i++
                continue
            }

            if unicode.IsLetter(char) || char == '`' {
                parts, consumed, err := parseIdentifierSequence(remaining)
                if err != nil {
                    return 0, err
                }
                if consumed == 0 {
                    i++
                    continue
                }

                if len(parts) == 1 {
                    keyword := strings.ToLower(parts[0])
                    switch keyword {
                    case "call":
                        return 0, fmt.Errorf("CALL is not allowed when dataset restrictions are in place, as the called procedure's contents cannot be safely analyzed")
                    case "immediate":
                        if lastToken == "execute" {
                            return 0, fmt.Errorf("EXECUTE IMMEDIATE is not allowed when dataset restrictions are in place, as its contents cannot be safely analyzed")
                        }
                    case "procedure", "function":
                        if lastToken == "create" || lastToken == "create or replace" {
                            return 0, fmt.Errorf("unanalyzable statements like '%s %s' are not allowed", strings.ToUpper(lastToken), strings.ToUpper(keyword))
                        }
                    case verbCreate, verbAlter, verbDrop, verbSelect, verbInsert, verbUpdate, verbDelete, verbMerge:
                        if statementVerb == "" {
                            statementVerb = keyword
                        }
                    }

                    if statementVerb == verbCreate || statementVerb == verbAlter || statementVerb == verbDrop {
                        if keyword == "schema" || keyword == "dataset" {
                            return 0, fmt.Errorf("dataset-level operations like '%s %s' are not allowed when dataset restrictions are in place", strings.ToUpper(statementVerb), strings.ToUpper(keyword))
                        }
                    }

                    if _, ok := tableFollowsKeywords[keyword]; ok {
                        expectingTable = true
                        lastTableKeyword = keyword
                    } else if _, ok := tableContextExitKeywords[keyword]; ok {
                        expectingTable = false
                        lastTableKeyword = ""
                    }
                    if lastToken == "create" && keyword == "or" {
                        lastToken = "create or"
                    } else if lastToken == "create or" && keyword == "replace" {
                        lastToken = "create or replace"
                    } else {
                        lastToken = keyword
                    }
                } else if len(parts) >= 2 {
                    // This is a multi-part identifier. If we were expecting a table, this is it.
                    if expectingTable {
                        tableID, err := formatTableID(parts, defaultProjectID)
                        if err != nil {
                            return 0, err
                        }
                        if tableID != "" {
                            tableIDSet[tableID] = struct{}{}
                        }
                        // For most keywords, we expect only one table.
                        if lastTableKeyword != "from" {
                            expectingTable = false
                        }
                    }
                    lastToken = ""
                }

                i += consumed
                continue
            }
            i++

        case stateInSingleQuoteString:
            if char == '\\' {
                i += 2 // Skip backslash and the escaped character.
                continue
            }
            if char == '\'' {
                state = stateNormal
            }
            i++
        case stateInDoubleQuoteString:
            if char == '\\' {
                i += 2 // Skip backslash and the escaped character.
                continue
            }
            if char == '"' {
                state = stateNormal
            }
            i++
        case stateInTripleSingleQuoteString:
            if strings.HasPrefix(remaining, "'''") {
                state = stateNormal
                i += 3
            } else {
                i++
            }
        case stateInTripleDoubleQuoteString:
            if strings.HasPrefix(remaining, `"""`) {
                state = stateNormal
                i += 3
            } else {
                i++
            }
        case stateInSingleLineCommentDash, stateInSingleLineCommentHash:
            if char == '\n' {
                state = stateNormal
            }
            i++
        case stateInMultiLineComment:
            if strings.HasPrefix(remaining, "*/") {
                state = stateNormal
                i += 2
            } else {
                i++
            }
        case stateInRawSingleQuoteString:
            if char == '\'' {
                state = stateNormal
            }
            i++
        case stateInRawDoubleQuoteString:
            if char == '"' {
                state = stateNormal
            }
            i++
        case stateInRawTripleSingleQuoteString:
            if strings.HasPrefix(remaining, "'''") {
                state = stateNormal
                i += 3
            } else {
                i++
            }
        case stateInRawTripleDoubleQuoteString:
            if strings.HasPrefix(remaining, `"""`) {
                state = stateNormal
                i += 3
            } else {
                i++
            }
        }
    }

    if inSubquery {
        return 0, fmt.Errorf("unclosed subquery parenthesis")
    }
    return len(sql), nil
}

// parseIdentifierSequence parses a sequence of dot-separated identifiers.
// It returns the parts of the identifier, the number of characters consumed, and an error.
func parseIdentifierSequence(s string) ([]string, int, error) {
    var parts []string
    var totalConsumed int

    for {
        remaining := s[totalConsumed:]
        trimmed := strings.TrimLeftFunc(remaining, unicode.IsSpace)
        totalConsumed += len(remaining) - len(trimmed)
        current := s[totalConsumed:]

        if len(current) == 0 {
            break
        }

        var part string
        var consumed int

        if current[0] == '`' {
            end := strings.Index(current[1:], "`")
            if end == -1 {
                return nil, 0, fmt.Errorf("unclosed backtick identifier")
            }
            part = current[1 : end+1]
            consumed = end + 2
        } else if len(current) > 0 && unicode.IsLetter(rune(current[0])) {
            end := strings.IndexFunc(current, func(r rune) bool {
                return !unicode.IsLetter(r) && !unicode.IsNumber(r) && r != '_' && r != '-'
            })
            if end == -1 {
                part = current
                consumed = len(current)
            } else {
                part = current[:end]
                consumed = end
            }
        } else {
            break
        }

        if current[0] == '`' && strings.Contains(part, ".") {
            // This handles cases like `project.dataset.table` but not `project.dataset`.table.
            // If the character after the quoted identifier is not a dot, we treat it as a full name.
            if len(current) <= consumed || current[consumed] != '.' {
                parts = append(parts, strings.Split(part, ".")...)
                totalConsumed += consumed
                break
            }
        }

        parts = append(parts, strings.Split(part, ".")...)
        totalConsumed += consumed

        if len(s) <= totalConsumed || s[totalConsumed] != '.' {
            break
        }
        totalConsumed++
    }
    return parts, totalConsumed, nil
}

func formatTableID(parts []string, defaultProjectID string) (string, error) {
    if len(parts) < 2 || len(parts) > 3 {
        // Not a table identifier (could be a CTE, column, etc.).
        // Return the consumed length so the main loop can skip this identifier.
        return "", nil
    }

    var tableID string
    if len(parts) == 3 { // project.dataset.table
        tableID = strings.Join(parts, ".")
    } else { // dataset.table
        if defaultProjectID == "" {
            return "", fmt.Errorf("query contains table '%s' without project ID, and no default project ID is provided", strings.Join(parts, "."))
        }
        tableID = fmt.Sprintf("%s.%s", defaultProjectID, strings.Join(parts, "."))
    }

    return tableID, nil
}
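A quick usage sketch of the parser above, written as a test in the same style (and package) as the test file that follows; the query, the `default-proj` project ID, and the test name are illustrative, and the printed result mirrors the expectations in the table-driven tests below (order is not guaranteed).

```go
package bigquerycommon_test

import (
	"fmt"
	"testing"

	"github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
)

// TestTableParserUsageSketch is a hypothetical example, not part of the commit:
// it resolves the tables a query touches, qualifying dataset.table references
// that omit a project with the default project ID.
func TestTableParserUsageSketch(t *testing.T) {
	sql := "SELECT t1.*, t2.* FROM `proj1.data1.tbl1` AS t1 JOIN data2.tbl2 AS t2 ON t1.id = t2.id"
	tables, err := bigquerycommon.TableParser(sql, "default-proj")
	if err != nil {
		// EXECUTE IMMEDIATE, CALL, CREATE PROCEDURE, etc. would error out here.
		t.Fatal(err)
	}
	fmt.Println(tables) // e.g. [proj1.data1.tbl1 default-proj.data2.tbl2]
}
```

The returned fully qualified IDs are what the `bigquery-execute-sql` tool checks against the source's allowed datasets when the dry-run statistics do not expose any referenced tables.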
496 internal/tools/bigquery/bigquerycommon/table_name_parser_test.go (new file)
@@ -0,0 +1,496 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package bigquerycommon_test

import (
    "sort"
    "strings"
    "testing"

    "github.com/google/go-cmp/cmp"
    "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
)

func TestTableParser(t *testing.T) {
    testCases := []struct {
        name             string
        sql              string
        defaultProjectID string
        want             []string
        wantErr          bool
        wantErrMsg       string
    }{
        {
            name: "single fully qualified table",
            sql: "SELECT * FROM `my-project.my_dataset.my_table`",
            defaultProjectID: "default-proj",
            want: []string{"my-project.my_dataset.my_table"},
            wantErr: false,
        },
        {
            name: "multiple statements with same table",
            sql: "select * from proj1.data1.tbl1 limit 1; select A.b from proj1.data1.tbl1 as A limit 1;",
            defaultProjectID: "default-proj",
            want: []string{"proj1.data1.tbl1"},
            wantErr: false,
        },
        {
            name: "multiple fully qualified tables",
            sql: "SELECT * FROM `proj1.data1`.`tbl1` JOIN proj2.`data2.tbl2` ON id",
            defaultProjectID: "default-proj",
            want: []string{"proj1.data1.tbl1", "proj2.data2.tbl2"},
            wantErr: false,
        },
        {
            name: "duplicate tables",
            sql: "SELECT * FROM `proj1.data1.tbl1` JOIN proj1.data1.tbl1 ON id",
            defaultProjectID: "default-proj",
            want: []string{"proj1.data1.tbl1"},
            wantErr: false,
        },
        {
            name: "partial table with default project",
            sql: "SELECT * FROM `my_dataset`.my_table",
            defaultProjectID: "default-proj",
            want: []string{"default-proj.my_dataset.my_table"},
            wantErr: false,
        },
        {
            name: "partial table without default project",
            sql: "SELECT * FROM `my_dataset.my_table`",
            defaultProjectID: "",
            want: nil,
            wantErr: true,
        },
        {
            name: "mixed fully qualified and partial tables",
            sql: "SELECT t1.*, t2.* FROM `proj1.data1.tbl1` AS t1 JOIN `data2.tbl2` AS t2 ON t1.id = t2.id",
            defaultProjectID: "default-proj",
            want: []string{"proj1.data1.tbl1", "default-proj.data2.tbl2"},
            wantErr: false,
        },
        {
            name: "no tables",
            sql: "SELECT 1+1",
            defaultProjectID: "default-proj",
            want: []string{},
            wantErr: false,
        },
        {
            name: "ignore single part identifiers (like CTEs)",
            sql: "WITH my_cte AS (SELECT 1) SELECT * FROM `my_cte`",
            defaultProjectID: "default-proj",
            want: []string{},
            wantErr: false,
        },
        {
            name: "complex CTE",
            sql: "WITH cte1 AS (SELECT * FROM `real.table.one`), cte2 AS (SELECT * FROM cte1) SELECT * FROM cte2 JOIN `real.table.two` ON true",
            defaultProjectID: "default-proj",
            want: []string{"real.table.one", "real.table.two"},
            wantErr: false,
        },
        {
            name: "nested subquery should be parsed",
            sql: "SELECT * FROM (SELECT a FROM (SELECT A.b FROM `real.table.nested` AS A))",
            defaultProjectID: "default-proj",
            want: []string{"real.table.nested"},
            wantErr: false,
        },
        {
            name: "from clause with unnest",
            sql: "SELECT event.name FROM `my-project.my_dataset.my_table` AS A, UNNEST(A.events) AS event",
            defaultProjectID: "default-proj",
            want: []string{"my-project.my_dataset.my_table"},
            wantErr: false,
        },
        {
            name: "ignore more than 3 parts",
            sql: "SELECT * FROM `proj.data.tbl.col`",
            defaultProjectID: "default-proj",
            want: []string{},
            wantErr: false,
        },
        {
            name: "complex query",
            sql: "SELECT name FROM (SELECT name FROM `proj1.data1.tbl1`) UNION ALL SELECT name FROM `data2.tbl2`",
            defaultProjectID: "default-proj",
            want: []string{"proj1.data1.tbl1", "default-proj.data2.tbl2"},
            wantErr: false,
        },
        {
            name: "empty sql",
            sql: "",
            defaultProjectID: "default-proj",
            want: []string{},
            wantErr: false,
        },
        {
            name: "with comments",
            sql: "SELECT * FROM `proj1.data1.tbl1`; -- comment `fake.table.one` \n SELECT * FROM `proj2.data2.tbl2`; # comment `fake.table.two`",
            defaultProjectID: "default-proj",
            want: []string{"proj1.data1.tbl1", "proj2.data2.tbl2"},
            wantErr: false,
        },
        {
            name: "multi-statement with semicolon",
            sql: "SELECT * FROM `proj1.data1.tbl1`; SELECT * FROM `proj2.data2.tbl2`",
            defaultProjectID: "default-proj",
            want: []string{"proj1.data1.tbl1", "proj2.data2.tbl2"},
            wantErr: false,
        },
        {
            name: "simple execute immediate",
            sql: "EXECUTE IMMEDIATE 'SELECT * FROM `exec.proj.tbl`'",
            defaultProjectID: "default-proj",
            want: nil,
            wantErr: true,
            wantErrMsg: "EXECUTE IMMEDIATE is not allowed when dataset restrictions are in place",
        },
        {
            name: "execute immediate with multiple spaces",
            sql: "EXECUTE IMMEDIATE 'SELECT 1'",
            defaultProjectID: "default-proj",
            want: nil,
            wantErr: true,
            wantErrMsg: "EXECUTE IMMEDIATE is not allowed when dataset restrictions are in place",
        },
        {
            name: "execute immediate with newline",
            sql: "EXECUTE\nIMMEDIATE 'SELECT 1'",
            defaultProjectID: "default-proj",
            want: nil,
            wantErr: true,
            wantErrMsg: "EXECUTE IMMEDIATE is not allowed when dataset restrictions are in place",
        },
        {
            name: "execute immediate with comment",
            sql: "EXECUTE -- some comment\n IMMEDIATE 'SELECT * FROM `exec.proj.tbl`'",
            defaultProjectID: "default-proj",
            want: nil,
            wantErr: true,
            wantErrMsg: "EXECUTE IMMEDIATE is not allowed when dataset restrictions are in place",
        },
        {
            name: "nested execute immediate",
            sql: "EXECUTE IMMEDIATE \"EXECUTE IMMEDIATE '''SELECT * FROM `nested.exec.tbl`'''\"",
            defaultProjectID: "default-proj",
            want: nil,
            wantErr: true,
            wantErrMsg: "EXECUTE IMMEDIATE is not allowed when dataset restrictions are in place",
        },
        {
            name: "begin execute immediate",
            sql: "BEGIN EXECUTE IMMEDIATE 'SELECT * FROM `exec.proj.tbl`'; END;",
            defaultProjectID: "default-proj",
            want: nil,
            wantErr: true,
            wantErrMsg: "EXECUTE IMMEDIATE is not allowed when dataset restrictions are in place",
        },
        {
            name: "table inside string literal should be ignored",
            sql: "SELECT * FROM `real.table.one` WHERE name = 'select * from `fake.table.two`'",
            defaultProjectID: "default-proj",
            want: []string{"real.table.one"},
            wantErr: false,
        },
        {
            name: "string with escaped single quote",
            sql: "SELECT 'this is a string with an escaped quote \\' and a fake table `fake.table.one`' FROM `real.table.two`",
            defaultProjectID: "default-proj",
            want: []string{"real.table.two"},
            wantErr: false,
        },
        {
            name: "string with escaped double quote",
            sql: `SELECT "this is a string with an escaped quote \" and a fake table ` + "`fake.table.one`" + `" FROM ` + "`real.table.two`",
            defaultProjectID: "default-proj",
            want: []string{"real.table.two"},
            wantErr: false,
        },
        {
            name: "multi-line comment",
            sql: "/* `fake.table.1` */ SELECT * FROM `real.table.2`",
            defaultProjectID: "default-proj",
            want: []string{"real.table.2"},
            wantErr: false,
        },
        {
            name: "raw string with backslash should be ignored",
            sql: "SELECT * FROM `real.table.one` WHERE name = r'a raw string with a \\ and a fake table `fake.table.two`'",
            defaultProjectID: "default-proj",
            want: []string{"real.table.one"},
            wantErr: false,
        },
        {
            name: "capital R raw string with quotes inside should be ignored",
            sql: `SELECT * FROM ` + "`real.table.one`" + ` WHERE name = R"""a raw string with a ' and a " and a \ and a fake table ` + "`fake.table.two`" + `"""`,
            defaultProjectID: "default-proj",
            want: []string{"real.table.one"},
            wantErr: false,
        },
        {
            name: "triple quoted raw string should be ignored",
            sql: "SELECT * FROM `real.table.one` WHERE name = r'''a raw string with a ' and a \" and a \\ and a fake table `fake.table.two`'''",
            defaultProjectID: "default-proj",
            want: []string{"real.table.one"},
            wantErr: false,
        },
        {
            name: "triple quoted capital R raw string should be ignored",
            sql: `SELECT * FROM ` + "`real.table.one`" + ` WHERE name = R"""a raw string with a ' and a " and a \ and a fake table ` + "`fake.table.two`" + `"""`,
            defaultProjectID: "default-proj",
            want: []string{"real.table.one"},
            wantErr: false,
        },
        {
            name: "unquoted fully qualified table",
            sql: "SELECT * FROM my-project.my_dataset.my_table",
            defaultProjectID: "default-proj",
            want: []string{"my-project.my_dataset.my_table"},
            wantErr: false,
        },
        {
            name: "unquoted partial table with default project",
            sql: "SELECT * FROM my_dataset.my_table",
            defaultProjectID: "default-proj",
            want: []string{"default-proj.my_dataset.my_table"},
            wantErr: false,
        },
        {
            name: "unquoted partial table without default project",
            sql: "SELECT * FROM my_dataset.my_table",
            defaultProjectID: "",
            want: nil,
            wantErr: true,
        },
        {
            name: "mixed quoting style 1",
            sql: "SELECT * FROM `my-project`.my_dataset.my_table",
            defaultProjectID: "default-proj",
            want: []string{"my-project.my_dataset.my_table"},
            wantErr: false,
        },
        {
            name: "mixed quoting style 2",
            sql: "SELECT * FROM `my-project`.`my_dataset`.my_table",
            defaultProjectID: "default-proj",
            want: []string{"my-project.my_dataset.my_table"},
            wantErr: false,
        },
        {
            name: "mixed quoting style 3",
            sql: "SELECT * FROM `my-project`.`my_dataset`.`my_table`",
            defaultProjectID: "default-proj",
            want: []string{"my-project.my_dataset.my_table"},
            wantErr: false,
        },
        {
            name: "mixed quoted and unquoted tables",
            sql: "SELECT * FROM `proj1.data1.tbl1` JOIN proj2.data2.tbl2 ON id",
            defaultProjectID: "default-proj",
            want: []string{"proj1.data1.tbl1", "proj2.data2.tbl2"},
            wantErr: false,
        },
        {
            name: "create table statement",
            sql: "CREATE TABLE `my-project.my_dataset.my_table` (x INT64)",
            defaultProjectID: "default-proj",
            want: []string{"my-project.my_dataset.my_table"},
            wantErr: false,
        },
        {
            name: "insert into statement",
            sql: "INSERT INTO `my-project.my_dataset.my_table` (x) VALUES (1)",
            defaultProjectID: "default-proj",
            want: []string{"my-project.my_dataset.my_table"},
            wantErr: false,
        },
        {
            name: "update statement",
            sql: "UPDATE `my-project.my_dataset.my_table` SET x = 2 WHERE true",
            defaultProjectID: "default-proj",
            want: []string{"my-project.my_dataset.my_table"},
            wantErr: false,
        },
        {
            name: "delete from statement",
            sql: "DELETE FROM `my-project.my_dataset.my_table` WHERE true",
            defaultProjectID: "default-proj",
            want: []string{"my-project.my_dataset.my_table"},
            wantErr: false,
        },
        {
            name: "merge into statement",
            sql: "MERGE `proj.data.target` T USING `proj.data.source` S ON T.id = S.id WHEN NOT MATCHED THEN INSERT ROW",
            defaultProjectID: "default-proj",
            want: []string{"proj.data.source", "proj.data.target"},
            wantErr: false,
        },
        {
            name: "create schema statement",
            sql: "CREATE SCHEMA `my-project.my_dataset`",
            defaultProjectID: "default-proj",
            want: nil,
            wantErr: true,
            wantErrMsg: "dataset-level operations like 'CREATE SCHEMA' are not allowed",
        },
        {
            name: "create dataset statement",
            sql: "CREATE DATASET `my-project.my_dataset`",
            defaultProjectID: "default-proj",
            want: nil,
            wantErr: true,
            wantErrMsg: "dataset-level operations like 'CREATE DATASET' are not allowed",
        },
        {
            name: "drop schema statement",
            sql: "DROP SCHEMA `my-project.my_dataset`",
            defaultProjectID: "default-proj",
            want: nil,
            wantErr: true,
            wantErrMsg: "dataset-level operations like 'DROP SCHEMA' are not allowed",
        },
        {
            name: "drop dataset statement",
            sql: "DROP DATASET `my-project.my_dataset`",
            defaultProjectID: "default-proj",
            want: nil,
            wantErr: true,
            wantErrMsg: "dataset-level operations like 'DROP DATASET' are not allowed",
        },
        {
            name: "alter schema statement",
            sql: "ALTER SCHEMA my_dataset SET OPTIONS(description='new description')",
            defaultProjectID: "default-proj",
            want: nil,
            wantErr: true,
            wantErrMsg: "dataset-level operations like 'ALTER SCHEMA' are not allowed",
        },
        {
            name: "alter dataset statement",
            sql: "ALTER DATASET my_dataset SET OPTIONS(description='new description')",
            defaultProjectID: "default-proj",
            want: nil,
            wantErr: true,
            wantErrMsg: "dataset-level operations like 'ALTER DATASET' are not allowed",
        },
        {
            name: "begin...end block",
            sql: "BEGIN CREATE TABLE `proj.data.tbl1` (x INT64); INSERT `proj.data.tbl2` (y) VALUES (1); END;",
            defaultProjectID: "default-proj",
            want: []string{"proj.data.tbl1", "proj.data.tbl2"},
            wantErr: false,
        },
        {
            name: "complex begin...end block with comments and different quoting",
            sql: `
BEGIN
  -- Create a new table
  CREATE TABLE proj.data.tbl1 (x INT64);
  /* Insert some data from another table */
  INSERT INTO ` + "`proj.data.tbl2`" + ` (y) SELECT y FROM proj.data.source;
END;`,
            defaultProjectID: "default-proj",
            want: []string{"proj.data.source", "proj.data.tbl1", "proj.data.tbl2"},
            wantErr: false,
        },
        {
            name: "call fully qualified procedure",
            sql: "CALL my-project.my_dataset.my_procedure()",
            defaultProjectID: "default-proj",
            want: nil,
            wantErr: true,
            wantErrMsg: "CALL is not allowed when dataset restrictions are in place",
        },
        {
            name: "call partially qualified procedure",
            sql: "CALL my_dataset.my_procedure()",
            defaultProjectID: "default-proj",
            want: nil,
            wantErr: true,
            wantErrMsg: "CALL is not allowed when dataset restrictions are in place",
        },
        {
            name: "call procedure in begin...end block",
            sql: "BEGIN CALL proj.data.proc1(); SELECT * FROM proj.data.tbl1; END;",
            defaultProjectID: "default-proj",
            want: nil,
            wantErr: true,
            wantErrMsg: "CALL is not allowed when dataset restrictions are in place",
        },
        {
            name: "call procedure with newline",
            sql: "CALL\nmy_dataset.my_procedure()",
            defaultProjectID: "default-proj",
            want: nil,
            wantErr: true,
            wantErrMsg: "CALL is not allowed when dataset restrictions are in place",
        },
        {
            name: "call procedure without default project should fail",
            sql: "CALL my_dataset.my_procedure()",
            defaultProjectID: "",
            want: nil,
            wantErr: true,
            wantErrMsg: "CALL is not allowed when dataset restrictions are in place",
        },
        {
            name: "create procedure statement",
            sql: "CREATE PROCEDURE my_dataset.my_procedure() BEGIN SELECT 1; END;",
            defaultProjectID: "default-proj",
            want: nil,
            wantErr: true,
            wantErrMsg: "unanalyzable statements like 'CREATE PROCEDURE' are not allowed",
        },
        {
            name: "create or replace procedure statement",
            sql: "CREATE\n OR \nREPLACE \nPROCEDURE my_dataset.my_procedure() BEGIN SELECT 1; END;",
            defaultProjectID: "default-proj",
            want: nil,
            wantErr: true,
            wantErrMsg: "unanalyzable statements like 'CREATE OR REPLACE PROCEDURE' are not allowed",
        },
        {
            name: "create function statement",
            sql: "CREATE FUNCTION my_dataset.my_function() RETURNS INT64 AS (1);",
            defaultProjectID: "default-proj",
            want: nil,
            wantErr: true,
            wantErrMsg: "unanalyzable statements like 'CREATE FUNCTION' are not allowed",
        },
    }

    for _, tc := range testCases {
        t.Run(tc.name, func(t *testing.T) {
            got, err := bigquerycommon.TableParser(tc.sql, tc.defaultProjectID)
            if (err != nil) != tc.wantErr {
                t.Errorf("TableParser() error = %v, wantErr %v", err, tc.wantErr)
                return
            }
            if tc.wantErr && tc.wantErrMsg != "" {
                if err == nil || !strings.Contains(err.Error(), tc.wantErrMsg) {
                    t.Errorf("TableParser() error = %v, want err containing %q", err, tc.wantErrMsg)
                }
            }
            // Sort slices to ensure comparison is order-independent.
            sort.Strings(got)
            sort.Strings(tc.want)
            if diff := cmp.Diff(tc.want, got); diff != "" {
                t.Errorf("TableParser() mismatch (-want +got):\n%s", diff)
            }
        })
    }
}
@@ -18,6 +18,7 @@ import (
    "context"
    "encoding/json"
    "fmt"
    "strings"

    bigqueryapi "cloud.google.com/go/bigquery"
    yaml "github.com/goccy/go-yaml"
@@ -51,6 +52,8 @@ type compatibleSource interface {
    BigQueryRestService() *bigqueryrestapi.Service
    BigQueryClientCreator() bigqueryds.BigqueryClientCreator
    UseClientAuthorization() bool
    IsDatasetAllowed(projectID, datasetID string) bool
    BigQueryAllowedDatasets() []string
}

// validate compatible sources are still compatible
@@ -86,7 +89,28 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
        return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
    }

    sqlParameter := tools.NewStringParameter("sql", "The sql to execute.")
    sqlDescription := "The sql to execute."
    allowedDatasets := s.BigQueryAllowedDatasets()
    if len(allowedDatasets) > 0 {
        datasetIDs := []string{}
        for _, ds := range allowedDatasets {
            datasetIDs = append(datasetIDs, fmt.Sprintf("`%s`", ds))
        }

        if len(datasetIDs) == 1 {
            parts := strings.Split(allowedDatasets[0], ".")
            if len(parts) < 2 {
                return nil, fmt.Errorf("expected split to have 2 parts: %s", allowedDatasets[0])
            }
            datasetID := parts[1]
            sqlDescription += fmt.Sprintf(" The query must only access the %s dataset. "+
                "To query a table within this dataset (e.g., `my_table`), "+
                "qualify it with the dataset id (e.g., `%s.my_table`).", datasetIDs[0], datasetID)
        } else {
            sqlDescription += fmt.Sprintf(" The query must only access datasets from the following list: %s.", strings.Join(datasetIDs, ", "))
        }
    }
    sqlParameter := tools.NewStringParameter("sql", sqlDescription)
    dryRunParameter := tools.NewBooleanParameterWithDefault(
        "dry_run",
        false,
@@ -103,16 +127,18 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)

    // finish tool setup
    t := Tool{
        Name:         cfg.Name,
        Kind:         kind,
        Parameters:   parameters,
        AuthRequired: cfg.AuthRequired,
        UseClientOAuth: s.UseClientAuthorization(),
        ClientCreator: s.BigQueryClientCreator(),
        Client:       s.BigQueryClient(),
        RestService:  s.BigQueryRestService(),
        manifest:     tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
        mcpManifest:  mcpManifest,
        Name:             cfg.Name,
        Kind:             kind,
        Parameters:       parameters,
        AuthRequired:     cfg.AuthRequired,
        UseClientOAuth:   s.UseClientAuthorization(),
        ClientCreator:    s.BigQueryClientCreator(),
        Client:           s.BigQueryClient(),
        RestService:      s.BigQueryRestService(),
        IsDatasetAllowed: s.IsDatasetAllowed,
        AllowedDatasets:  allowedDatasets,
        manifest:         tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
        mcpManifest:      mcpManifest,
    }
    return t, nil
}
@@ -127,11 +153,13 @@ type Tool struct {
    UseClientOAuth bool              `yaml:"useClientOAuth"`
    Parameters     tools.Parameters `yaml:"parameters"`

    Client      *bigqueryapi.Client
    RestService *bigqueryrestapi.Service
    ClientCreator bigqueryds.BigqueryClientCreator
    manifest    tools.Manifest
    mcpManifest tools.McpManifest
    Client           *bigqueryapi.Client
    RestService      *bigqueryrestapi.Service
    ClientCreator    bigqueryds.BigqueryClientCreator
    IsDatasetAllowed func(projectID, datasetID string) bool
    AllowedDatasets  []string
    manifest         tools.Manifest
    mcpManifest      tools.McpManifest
}

func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
@@ -165,6 +193,61 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
    if err != nil {
        return nil, fmt.Errorf("query validation failed during dry run: %w", err)
    }
    statementType := dryRunJob.Statistics.Query.StatementType

    if len(t.AllowedDatasets) > 0 {
        switch statementType {
        case "CREATE_SCHEMA", "DROP_SCHEMA", "ALTER_SCHEMA":
            return nil, fmt.Errorf("dataset-level operations like '%s' are not allowed when dataset restrictions are in place", statementType)
        case "CREATE_FUNCTION", "CREATE_TABLE_FUNCTION", "CREATE_PROCEDURE":
            return nil, fmt.Errorf("creating stored routines ('%s') is not allowed when dataset restrictions are in place, as their contents cannot be safely analyzed", statementType)
        case "CALL":
            return nil, fmt.Errorf("calling stored procedures ('%s') is not allowed when dataset restrictions are in place, as their contents cannot be safely analyzed", statementType)
        }

        // Use a map to avoid duplicate table names.
        tableIDSet := make(map[string]struct{})

        // Get all tables from the dry run result. This is the most reliable method.
        queryStats := dryRunJob.Statistics.Query
        if queryStats != nil {
            for _, tableRef := range queryStats.ReferencedTables {
                tableIDSet[fmt.Sprintf("%s.%s.%s", tableRef.ProjectId, tableRef.DatasetId, tableRef.TableId)] = struct{}{}
            }
            if tableRef := queryStats.DdlTargetTable; tableRef != nil {
                tableIDSet[fmt.Sprintf("%s.%s.%s", tableRef.ProjectId, tableRef.DatasetId, tableRef.TableId)] = struct{}{}
            }
            if tableRef := queryStats.DdlDestinationTable; tableRef != nil {
                tableIDSet[fmt.Sprintf("%s.%s.%s", tableRef.ProjectId, tableRef.DatasetId, tableRef.TableId)] = struct{}{}
            }
        }

        var tableNames []string
        if len(tableIDSet) > 0 {
            for tableID := range tableIDSet {
                tableNames = append(tableNames, tableID)
            }
        } else if statementType != "SELECT" {
            // If dry run yields no tables, fall back to the parser for non-SELECT statements
            // to catch unsafe operations like EXECUTE IMMEDIATE.
            parsedTables, parseErr := bqutil.TableParser(sql, t.Client.Project())
            if parseErr != nil {
                // If parsing fails (e.g., EXECUTE IMMEDIATE), we cannot guarantee safety, so we must fail.
                return nil, fmt.Errorf("could not parse tables from query to validate against allowed datasets: %w", parseErr)
            }
            tableNames = parsedTables
        }

        for _, tableID := range tableNames {
            parts := strings.Split(tableID, ".")
            if len(parts) == 3 {
                projectID, datasetID := parts[0], parts[1]
                if !t.IsDatasetAllowed(projectID, datasetID) {
                    return nil, fmt.Errorf("query accesses dataset '%s.%s', which is not in the allowed list", projectID, datasetID)
                }
            }
        }
    }

    if dryRun {
        if dryRunJob != nil {
@@ -178,8 +261,6 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
        return "Dry run was requested, but no job information was returned.", nil
    }

    statementType := dryRunJob.Statistics.Query.StatementType
    // JobStatistics.QueryStatistics.StatementType
    query := bqClient.Query(sql)
    query.Location = bqClient.Location
@@ -269,6 +269,11 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
            "source":      "my-instance",
            "description": "Tool to list table within a dataset",
        },
        "execute-sql-restricted": map[string]any{
            "kind":        "bigquery-execute-sql",
            "source":      "my-instance",
            "description": "Tool to execute SQL",
        },
        "conversational-analytics-restricted": map[string]any{
            "kind":   "bigquery-conversational-analytics",
            "source": "my-instance",
@@ -307,6 +312,8 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
    // Run tests
    runListTableIdsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, allowedForecastTableName1)
    runListTableIdsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, allowedForecastTableName2)
    runExecuteSqlWithRestriction(t, allowedTableNameParam1, disallowedTableNameParam)
    runExecuteSqlWithRestriction(t, allowedTableNameParam2, disallowedTableNameParam)
    runConversationalAnalyticsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, disallowedTableName)
    runConversationalAnalyticsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, disallowedTableName)
    runForecastWithRestriction(t, allowedForecastTableFullName1, disallowedForecastTableFullName)
@@ -2156,6 +2163,94 @@ func runListTableIdsWithRestriction(t *testing.T, allowedDatasetName, disallowed
    }
}

func runExecuteSqlWithRestriction(t *testing.T, allowedTableFullName, disallowedTableFullName string) {
    allowedTableParts := strings.Split(strings.Trim(allowedTableFullName, "`"), ".")
    if len(allowedTableParts) != 3 {
        t.Fatalf("invalid allowed table name format: %s", allowedTableFullName)
    }
    allowedDatasetID := allowedTableParts[1]

    testCases := []struct {
        name           string
        sql            string
        wantStatusCode int
        wantInError    string
    }{
        {
            name: "invoke on allowed table",
            sql: fmt.Sprintf("SELECT * FROM %s", allowedTableFullName),
            wantStatusCode: http.StatusOK,
        },
        {
            name: "invoke on disallowed table",
            sql: fmt.Sprintf("SELECT * FROM %s", disallowedTableFullName),
            wantStatusCode: http.StatusBadRequest,
            wantInError: fmt.Sprintf("query accesses dataset '%s', which is not in the allowed list",
                strings.Join(
                    strings.Split(strings.Trim(disallowedTableFullName, "`"), ".")[0:2],
                    ".")),
        },
        {
            name: "disallowed create schema",
            sql: "CREATE SCHEMA another_dataset",
            wantStatusCode: http.StatusBadRequest,
            wantInError: "dataset-level operations like 'CREATE_SCHEMA' are not allowed",
        },
        {
            name: "disallowed alter schema",
            sql: fmt.Sprintf("ALTER SCHEMA %s SET OPTIONS(description='new one')", allowedDatasetID),
            wantStatusCode: http.StatusBadRequest,
            wantInError: "dataset-level operations like 'ALTER_SCHEMA' are not allowed",
        },
        {
            name: "disallowed create function",
            sql: fmt.Sprintf("CREATE FUNCTION %s.my_func() RETURNS INT64 AS (1)", allowedDatasetID),
            wantStatusCode: http.StatusBadRequest,
            wantInError: "creating stored routines ('CREATE_FUNCTION') is not allowed",
        },
        {
            name: "disallowed create procedure",
            sql: fmt.Sprintf("CREATE PROCEDURE %s.my_proc() BEGIN SELECT 1; END", allowedDatasetID),
            wantStatusCode: http.StatusBadRequest,
            wantInError: "unanalyzable statements like 'CREATE PROCEDURE' are not allowed",
        },
        {
            name: "disallowed execute immediate",
            sql: "EXECUTE IMMEDIATE 'SELECT 1'",
            wantStatusCode: http.StatusBadRequest,
            wantInError: "EXECUTE IMMEDIATE is not allowed when dataset restrictions are in place",
        },
    }

    for _, tc := range testCases {
        t.Run(tc.name, func(t *testing.T) {
            body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"sql":"%s"}`, tc.sql)))
            req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:5000/api/tool/execute-sql-restricted/invoke", body)
            if err != nil {
                t.Fatalf("unable to create request: %s", err)
            }
            req.Header.Add("Content-type", "application/json")
            resp, err := http.DefaultClient.Do(req)
            if err != nil {
                t.Fatalf("unable to send request: %s", err)
            }
            defer resp.Body.Close()

            if resp.StatusCode != tc.wantStatusCode {
                bodyBytes, _ := io.ReadAll(resp.Body)
                t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes))
            }

            if tc.wantInError != "" {
                bodyBytes, _ := io.ReadAll(resp.Body)
                if !strings.Contains(string(bodyBytes), tc.wantInError) {
                    t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError)
                }
            }
        })
    }
}

func runConversationalAnalyticsWithRestriction(t *testing.T, allowedDatasetName, disallowedDatasetName, allowedTableName, disallowedTableName string) {
    allowedTableRefsJSON := fmt.Sprintf(`[{"projectId":"%s","datasetId":"%s","tableId":"%s"}]`, BigqueryProject, allowedDatasetName, allowedTableName)
    disallowedTableRefsJSON := fmt.Sprintf(`[{"projectId":"%s","datasetId":"%s","tableId":"%s"}]`, BigqueryProject, disallowedDatasetName, disallowedTableName)