feat(source/bigquery): add optional write mode config (#1157)

Summary
Adds an optional write_mode configuration to the BigQuery source,
enhancing security by controlling the types of SQL statements that can
be executed to prevent unauthorized data modification.

Key Changes
Added writeMode Configuration: A new write_mode field is added to the
BigQuery source, supporting three modes:

allowed (Default): Permits all SQL statements.

blocked: Allows only SELECT queries.

protected: Enables session-based execution, restricting write operations
(like CREATE TABLE) to the session's temporary dataset, thus protecting
permanent datasets. Note: at the moment, this won't work with
useClientOAuth, will fix this in the future.

These restrictions primarily apply to the bigquery-execute-sql tool and
the session may be used in other tools.
This commit is contained in:
Huan Chen
2025-10-08 13:40:48 -07:00
committed by GitHub
parent 2c4d73b77b
commit 63adc78bea
14 changed files with 789 additions and 106 deletions

View File

@@ -119,6 +119,7 @@ sources:
kind: "bigquery"
project: "my-project-id"
# location: "US" # Optional: Specifies the location for query jobs.
# writeMode: "allowed" # One of: allowed, blocked, protected. Defaults to "allowed".
# allowedDatasets: # Optional: Restricts tool access to a specific list of datasets.
# - "my_dataset_1"
# - "other_project.my_dataset_2"
@@ -133,6 +134,7 @@ sources:
project: "my-project-id"
useClientOAuth: true
# location: "US" # Optional: Specifies the location for query jobs.
# writeMode: "allowed" # One of: allowed, blocked, protected. Defaults to "allowed".
# allowedDatasets: # Optional: Restricts tool access to a specific list of datasets.
# - "my_dataset_1"
# - "other_project.my_dataset_2"
@@ -145,5 +147,6 @@ sources:
| kind | string | true | Must be "bigquery". |
| project | string | true | Id of the Google Cloud project to use for billing and as the default project for BigQuery resources. |
| location | string | false | Specifies the location (e.g., 'us', 'asia-northeast1') in which to run the query job. This location must match the location of any tables referenced in the query. Defaults to the table's location or 'US' if the location cannot be determined. [Learn More](https://cloud.google.com/bigquery/docs/locations) |
| writeMode | string | false | Controls the write behavior for tools. `allowed` (default): All queries are permitted. `blocked`: Only `SELECT` statements are allowed for the `bigquery-execute-sql` tool. `protected`: Enables session-based execution where all tools associated with this source instance share the same [BigQuery session](https://cloud.google.com/bigquery/docs/sessions-intro). This allows for stateful operations using temporary tables (e.g., `CREATE TEMP TABLE`). For `bigquery-execute-sql`, `SELECT` statements can be used on all tables, but write operations are restricted to the session's temporary dataset. For tools like `bigquery-sql`, `bigquery-forecast`, and `bigquery-analyze-contribution`, the `writeMode` restrictions do not apply, but they will operate within the shared session. **Note:** The `protected` mode cannot be used with `useClientOAuth: true`. It is also not recommended for multi-user server environments, as all users would share the same session. A session is terminated automatically after 24 hours of inactivity or after 7 days, whichever comes first. A new session is created on the next request, and any temporary data from the previous session will be lost. |
| allowedDatasets | []string | false | An optional list of dataset IDs that tools using this source are allowed to access. If provided, any tool operation attempting to access a dataset not in this list will be rejected. To enforce this, two types of operations are also disallowed: 1) Dataset-level operations (e.g., `CREATE SCHEMA`), and 2) operations where table access cannot be statically analyzed (e.g., `EXECUTE IMMEDIATE`, `CREATE PROCEDURE`). If a single dataset is provided, it will be treated as the default for prebuilt tools. |
| useClientOAuth | bool | false | If true, forwards the client's OAuth access token from the "Authorization" header to downstream queries. |
| useClientOAuth | bool | false | If true, forwards the client's OAuth access token from the "Authorization" header to downstream queries. **Note:** This cannot be used with `writeMode: protected`. |

View File

@@ -39,6 +39,13 @@ It's compatible with the following sources:
insights. Can be `'NO_PRUNING'` or `'PRUNE_REDUNDANT_INSIGHTS'`. Defaults to
`'PRUNE_REDUNDANT_INSIGHTS'`.
The behavior of this tool is influenced by the `writeMode` setting on its `bigquery` source:
- **`allowed` (default) and `blocked`:** These modes do not impose any special restrictions on the `bigquery-analyze-contribution` tool.
- **`protected`:** This mode enables session-based execution. The tool will operate within the same BigQuery session as other
tools using the same source. This allows the `input_data` parameter to be a query that references temporary resources (e.g.,
`TEMP` tables) created within that session.
## Example

View File

@@ -20,8 +20,15 @@ It's compatible with the following sources:
- **`dry_run`** (optional): If set to `true`, the query is validated but not run,
returning information about the execution instead. Defaults to `false`.
The behavior of this tool is influenced by the `writeMode` setting on its `bigquery` source:
- **`allowed` (default):** All SQL statements are permitted.
- **`blocked`:** Only `SELECT` statements are allowed. Any other type of statement (e.g., `INSERT`, `UPDATE`, `CREATE`) will be rejected.
- **`protected`:** This mode enables session-based execution. `SELECT` statements can be used on all tables, while write operations are allowed only for the session's temporary dataset (e.g., `CREATE TEMP TABLE ...`). This prevents modifications to permanent datasets while allowing stateful, multi-step operations within a secure session.
The tool's behavior is influenced by the `allowedDatasets` restriction on the
`bigquery` source:
`bigquery` source. Similar to `writeMode`, this setting provides an additional layer of security by controlling which datasets can be accessed:
- **Without `allowedDatasets` restriction:** The tool can execute any valid GoogleSQL
query.
- **With `allowedDatasets` restriction:** Before execution, the tool performs a dry run
@@ -33,6 +40,8 @@ The tool's behavior is influenced by the `allowedDatasets` restriction on the
- **Unanalyzable operations** where the accessed tables cannot be determined
statically (e.g., `EXECUTE IMMEDIATE`, `CREATE PROCEDURE`, `CALL`).
> **Note:** This tool is intended for developer assistant workflows with human-in-the-loop and shouldn't be used for production agents.
## Example
```yaml

View File

@@ -33,12 +33,19 @@ query based on the provided parameters:
- **horizon** (integer, optional): The number of future time steps you want to
predict. It defaults to 10 if not specified.
The tool's behavior regarding these parameters is influenced by the `allowedDatasets` restriction on the `bigquery` source:
The behavior of this tool is influenced by the `writeMode` setting on its `bigquery` source:
- **`allowed` (default) and `blocked`:** These modes do not impose any special restrictions on the `bigquery-forecast` tool.
- **`protected`:** This mode enables session-based execution. The tool will operate within the same BigQuery session as other
tools using the same source. This allows the `history_data` parameter to be a query that references temporary resources (e.g.,
`TEMP` tables) created within that session.
The tool's behavior is also influenced by the `allowedDatasets` restriction on the `bigquery` source:
- **Without `allowedDatasets` restriction:** The tool can use any table or query for the `history_data` parameter.
- **With `allowedDatasets` restriction:** The tool verifies that the `history_data` parameter only accesses tables
within the allowed datasets. If `history_data` is a table ID, the tool checks if the table's dataset is in the
allowed list. If `history_data` is a query, the tool performs a dry run to analyze the query and rejects it
if it accesses any table outside the allowed list.
- **With `allowedDatasets` restriction:** The tool verifies that the `history_data` parameter only accesses tables within the allowed datasets.
- If `history_data` is a table ID, the tool checks if the table's dataset is in the allowed list.
- If `history_data` is a query, the tool performs a dry run to analyze the query and rejects it if it accesses any table outside the allowed list.
## Example

View File

@@ -15,6 +15,11 @@ the following sources:
- [bigquery](../../sources/bigquery.md)
The behavior of this tool is influenced by the `writeMode` setting on its `bigquery` source:
- **`allowed` (default) and `blocked`:** These modes do not impose any restrictions on the `bigquery-sql` tool. The pre-defined SQL statement will be executed as-is.
- **`protected`:** This mode enables session-based execution. The tool will operate within the same BigQuery session as other tools using the same source, allowing it to interact with temporary resources like `TEMP` tables created within that session.
### GoogleSQL
BigQuery uses [GoogleSQL][bigquery-googlesql] for querying data. The integration

View File

@@ -20,6 +20,7 @@ import (
"net/http"
"strings"
"sync"
"time"
bigqueryapi "cloud.google.com/go/bigquery"
dataplexapi "cloud.google.com/go/dataplex/apiv1"
@@ -36,11 +37,22 @@ import (
const SourceKind string = "bigquery"
const (
// No write operations are allowed.
WriteModeBlocked string = "blocked"
// Only protected write operations are allowed in a BigQuery session.
WriteModeProtected string = "protected"
// All write operations are allowed.
WriteModeAllowed string = "allowed"
)
// validate interface
var _ sources.SourceConfig = Config{}
type BigqueryClientCreator func(tokenString string, wantRestService bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
type BigQuerySessionProvider func(ctx context.Context) (*Session, error)
type DataplexClientCreator func(tokenString string) (*dataplexapi.CatalogClient, error)
func init() {
@@ -63,6 +75,7 @@ type Config struct {
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
Location string `yaml:"location"`
WriteMode string `yaml:"writeMode"`
AllowedDatasets []string `yaml:"allowedDatasets"`
UseClientOAuth bool `yaml:"useClientOAuth"`
}
@@ -73,6 +86,14 @@ func (r Config) SourceConfigKind() string {
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
if r.WriteMode == "" {
r.WriteMode = WriteModeAllowed
}
if r.WriteMode == WriteModeProtected && r.UseClientOAuth {
return nil, fmt.Errorf("writeMode 'protected' cannot be used with useClientOAuth 'true'")
}
var client *bigqueryapi.Client
var restService *bigqueryrestapi.Service
var tokenSource oauth2.TokenSource
@@ -133,9 +154,15 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
TokenSource: tokenSource,
MaxQueryResultRows: 50,
ClientCreator: clientCreator,
WriteMode: r.WriteMode,
AllowedDatasets: allowedDatasets,
UseClientOAuth: r.UseClientOAuth,
}
s.SessionProvider = s.newBigQuerySessionProvider()
if r.WriteMode != WriteModeAllowed && r.WriteMode != WriteModeBlocked && r.WriteMode != WriteModeProtected {
return nil, fmt.Errorf("invalid writeMode %q: must be one of %q, %q, or %q", r.WriteMode, WriteModeAllowed, WriteModeProtected, WriteModeBlocked)
}
s.makeDataplexCatalogClient = s.lazyInitDataplexClient(ctx, tracer)
return s, nil
@@ -156,7 +183,19 @@ type Source struct {
ClientCreator BigqueryClientCreator
AllowedDatasets map[string]struct{}
UseClientOAuth bool
WriteMode string
sessionMutex sync.Mutex
makeDataplexCatalogClient func() (*dataplexapi.CatalogClient, DataplexClientCreator, error)
SessionProvider BigQuerySessionProvider
Session *Session
}
type Session struct {
ID string
ProjectID string
DatasetID string
CreationTime time.Time
LastUsed time.Time
}
func (s *Source) SourceKind() string {
@@ -172,6 +211,103 @@ func (s *Source) BigQueryRestService() *bigqueryrestapi.Service {
return s.RestService
}
func (s *Source) BigQueryWriteMode() string {
return s.WriteMode
}
func (s *Source) BigQuerySession() BigQuerySessionProvider {
return s.SessionProvider
}
func (s *Source) newBigQuerySessionProvider() BigQuerySessionProvider {
return func(ctx context.Context) (*Session, error) {
if s.WriteMode != WriteModeProtected {
return nil, nil
}
s.sessionMutex.Lock()
defer s.sessionMutex.Unlock()
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get logger from context: %w", err)
}
if s.Session != nil {
// Absolute 7-day lifetime check.
const sessionMaxLifetime = 7 * 24 * time.Hour
// This assumes a single task will not exceed 30 minutes, preventing it from failing mid-execution.
const refreshThreshold = 30 * time.Minute
if time.Since(s.Session.CreationTime) > (sessionMaxLifetime - refreshThreshold) {
logger.DebugContext(ctx, "Session is approaching its 7-day maximum lifetime. Creating a new one.")
} else {
job := &bigqueryrestapi.Job{
Configuration: &bigqueryrestapi.JobConfiguration{
DryRun: true,
Query: &bigqueryrestapi.JobConfigurationQuery{
Query: "SELECT 1",
UseLegacySql: new(bool),
ConnectionProperties: []*bigqueryrestapi.ConnectionProperty{{Key: "session_id", Value: s.Session.ID}},
},
},
}
_, err := s.RestService.Jobs.Insert(s.Project, job).Do()
if err == nil {
s.Session.LastUsed = time.Now()
return s.Session, nil
}
logger.DebugContext(ctx, "Session validation failed (likely expired), creating a new one.", "error", err)
}
}
// Create a new session if one doesn't exist, it has passed its 7-day lifetime,
// or it failed the validation dry run.
creationTime := time.Now()
job := &bigqueryrestapi.Job{
JobReference: &bigqueryrestapi.JobReference{
ProjectId: s.Project,
Location: s.Location,
},
Configuration: &bigqueryrestapi.JobConfiguration{
DryRun: true,
Query: &bigqueryrestapi.JobConfigurationQuery{
Query: "SELECT 1",
CreateSession: true,
},
},
}
createdJob, err := s.RestService.Jobs.Insert(s.Project, job).Do()
if err != nil {
return nil, fmt.Errorf("failed to create new session: %w", err)
}
var sessionID, sessionDatasetID, projectID string
if createdJob.Status != nil && createdJob.Statistics.SessionInfo != nil {
sessionID = createdJob.Statistics.SessionInfo.SessionId
} else {
return nil, fmt.Errorf("failed to get session ID from new session job")
}
if createdJob.Configuration != nil && createdJob.Configuration.Query != nil && createdJob.Configuration.Query.DestinationTable != nil {
sessionDatasetID = createdJob.Configuration.Query.DestinationTable.DatasetId
projectID = createdJob.Configuration.Query.DestinationTable.ProjectId
} else {
return nil, fmt.Errorf("failed to get session dataset ID from new session job")
}
s.Session = &Session{
ID: sessionID,
ProjectID: projectID,
DatasetID: sessionDatasetID,
CreationTime: creationTime,
LastUsed: creationTime,
}
return s.Session, nil
}
}
func (s *Source) UseClientAuthorization() bool {
return s.UseClientOAuth
}
@@ -257,7 +393,7 @@ func initBigQueryConnection(
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
cred, err := google.FindDefaultCredentials(ctx, bigqueryapi.Scope)
cred, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/cloud-platform")
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to find default Google Cloud credentials with scope %q: %w", bigqueryapi.Scope, err)
}

View File

@@ -37,14 +37,34 @@ func TestParseFromYamlBigQuery(t *testing.T) {
my-instance:
kind: bigquery
project: my-project
location: us
`,
want: server.SourceConfigs{
"my-instance": bigquery.Config{
Name: "my-instance",
Kind: bigquery.SourceKind,
Project: "my-project",
Location: "",
WriteMode: "",
},
},
},
{
desc: "all fields specified",
in: `
sources:
my-instance:
kind: bigquery
project: my-project
location: asia
writeMode: blocked
`,
want: server.SourceConfigs{
"my-instance": bigquery.Config{
Name: "my-instance",
Kind: bigquery.SourceKind,
Project: "my-project",
Location: "us",
Location: "asia",
WriteMode: "blocked",
UseClientOAuth: false,
},
},

View File

@@ -50,6 +50,7 @@ type compatibleSource interface {
BigQueryRestService() *bigqueryrestapi.Service
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool
BigQuerySession() bigqueryds.BigQuerySessionProvider
}
// validate compatible sources are still compatible
@@ -122,16 +123,17 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Name: cfg.Name,
Kind: kind,
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Name: cfg.Name,
Kind: kind,
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
SessionProvider: s.BigQuerySession(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -146,11 +148,12 @@ type Tool struct {
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters tools.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
ClientCreator bigqueryds.BigqueryClientCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
ClientCreator bigqueryds.BigqueryClientCreator
SessionProvider bigqueryds.BigQuerySessionProvider
manifest tools.Manifest
mcpManifest tools.McpManifest
}
// Invoke runs the contribution analysis.
@@ -222,7 +225,22 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
}
createModelQuery := bqClient.Query(createModelSQL)
createModelQuery.CreateSession = true
// Get session from provider if in protected mode.
// Otherwise, a new session will be created by the first query.
session, err := t.SessionProvider(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
}
if session != nil {
createModelQuery.ConnectionProperties = []*bigqueryapi.ConnectionProperty{
{Key: "session_id", Value: session.ID},
}
} else {
// If not in protected mode, create a session for this invocation.
createModelQuery.CreateSession = true
}
createModelJob, err := createModelQuery.Run(ctx)
if err != nil {
return nil, fmt.Errorf("failed to start create model job: %w", err)
@@ -236,16 +254,21 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
return nil, fmt.Errorf("create model job failed: %w", err)
}
if status.Statistics == nil || status.Statistics.SessionInfo == nil || status.Statistics.SessionInfo.SessionID == "" {
return nil, fmt.Errorf("failed to create a BigQuery session")
// Determine the session ID to use for subsequent queries.
// It's either from the pre-existing session (protected mode) or the one just created.
var sessionID string
if session != nil {
sessionID = session.ID
} else if status.Statistics != nil && status.Statistics.SessionInfo != nil {
sessionID = status.Statistics.SessionInfo.SessionID
} else {
return nil, fmt.Errorf("failed to get or create a BigQuery session ID")
}
sessionID := status.Statistics.SessionInfo.SessionID
getInsightsSQL := fmt.Sprintf("SELECT * FROM ML.GET_INSIGHTS(MODEL %s)", modelID)
getInsightsQuery := bqClient.Query(getInsightsSQL)
getInsightsQuery.ConnectionProperties = []*bigqueryapi.ConnectionProperty{
{Key: "session_id", Value: sessionID},
}
getInsightsQuery.ConnectionProperties = []*bigqueryapi.ConnectionProperty{{Key: "session_id", Value: sessionID}}
job, err := getInsightsQuery.Run(ctx)
if err != nil {

View File

@@ -53,3 +53,19 @@ func DryRunQuery(ctx context.Context, restService *bigqueryrestapi.Service, proj
}
return insertResponse, nil
}
// BQTypeStringFromToolType converts a tool parameter type string to a BigQuery standard SQL type string.
func BQTypeStringFromToolType(toolType string) (string, error) {
switch toolType {
case "string":
return "STRING", nil
case "integer":
return "INT64", nil
case "float":
return "FLOAT64", nil
case "boolean":
return "BOOL", nil
default:
return "", fmt.Errorf("unsupported tool parameter type for BigQuery: %s", toolType)
}
}

View File

@@ -49,6 +49,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
BigQueryClient() *bigqueryapi.Client
BigQuerySession() bigqueryds.BigQuerySessionProvider
BigQueryWriteMode() string
BigQueryRestService() *bigqueryrestapi.Service
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool
@@ -89,33 +91,43 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
sqlDescription := "The sql to execute."
var sqlDescriptionBuilder strings.Builder
switch s.BigQueryWriteMode() {
case bigqueryds.WriteModeBlocked:
sqlDescriptionBuilder.WriteString("The SQL to execute. In 'blocked' mode, only SELECT statements are allowed; other statement types will fail.")
case bigqueryds.WriteModeProtected:
sqlDescriptionBuilder.WriteString("The SQL to execute. Only SELECT statements and writes to the session's temporary dataset are allowed (e.g., `CREATE TEMP TABLE ...`).")
default: // WriteModeAllowed
sqlDescriptionBuilder.WriteString("The SQL to execute.")
}
allowedDatasets := s.BigQueryAllowedDatasets()
if len(allowedDatasets) > 0 {
datasetIDs := []string{}
for _, ds := range allowedDatasets {
datasetIDs = append(datasetIDs, fmt.Sprintf("`%s`", ds))
}
if len(datasetIDs) == 1 {
parts := strings.Split(allowedDatasets[0], ".")
if len(allowedDatasets) == 1 {
datasetFQN := allowedDatasets[0]
parts := strings.Split(datasetFQN, ".")
if len(parts) < 2 {
return nil, fmt.Errorf("expected split to have 2 parts: %s", allowedDatasets[0])
return nil, fmt.Errorf("expected allowedDataset to have at least 2 parts (project.dataset): %s", datasetFQN)
}
datasetID := parts[1]
sqlDescription += fmt.Sprintf(" The query must only access the %s dataset. "+
sqlDescriptionBuilder.WriteString(fmt.Sprintf(" The query must only access the `%s` dataset. "+
"To query a table within this dataset (e.g., `my_table`), "+
"qualify it with the dataset id (e.g., `%s.my_table`).", datasetIDs[0], datasetID)
"qualify it with the dataset id (e.g., `%s.my_table`).", datasetFQN, datasetID))
} else {
sqlDescription += fmt.Sprintf(" The query must only access datasets from the following list: %s.", strings.Join(datasetIDs, ", "))
datasetIDs := []string{}
for _, ds := range allowedDatasets {
datasetIDs = append(datasetIDs, fmt.Sprintf("`%s`", ds))
}
sqlDescriptionBuilder.WriteString(fmt.Sprintf(" The query must only access datasets from the following list: %s.", strings.Join(datasetIDs, ", ")))
}
}
sqlParameter := tools.NewStringParameter("sql", sqlDescription)
sqlParameter := tools.NewStringParameter("sql", sqlDescriptionBuilder.String())
dryRunParameter := tools.NewBooleanParameterWithDefault(
"dry_run",
false,
"If set to true, the query will be validated and information about the execution "+
"will be returned without running the query. Defaults to false.",
"If set to true, the query will be validated and information about the execution will be returned "+
"without running the query. Defaults to false.",
)
parameters := tools.Parameters{sqlParameter, dryRunParameter}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, parameters)
@@ -130,6 +142,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
WriteMode: s.BigQueryWriteMode(),
SessionProvider: s.BigQuerySession(),
IsDatasetAllowed: s.IsDatasetAllowed,
AllowedDatasets: allowedDatasets,
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
@@ -150,6 +164,8 @@ type Tool struct {
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
WriteMode string
SessionProvider bigqueryds.BigQuerySessionProvider
ClientCreator bigqueryds.BigqueryClientCreator
IsDatasetAllowed func(projectID, datasetID string) bool
AllowedDatasets []string
@@ -184,12 +200,39 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
}
}
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, nil)
if err != nil {
return nil, fmt.Errorf("query validation failed during dry run: %w", err)
var connProps []*bigqueryapi.ConnectionProperty
var session *bigqueryds.Session
if t.WriteMode == bigqueryds.WriteModeProtected {
session, err = t.SessionProvider(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get BigQuery session for protected mode: %w", err)
}
connProps = []*bigqueryapi.ConnectionProperty{
{Key: "session_id", Value: session.ID},
}
}
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, connProps)
if err != nil {
return nil, fmt.Errorf("query validation failed: %w", err)
}
statementType := dryRunJob.Statistics.Query.StatementType
switch t.WriteMode {
case bigqueryds.WriteModeBlocked:
if statementType != "SELECT" {
return nil, fmt.Errorf("write mode is 'blocked', only SELECT statements are allowed")
}
case bigqueryds.WriteModeProtected:
if dryRunJob.Configuration != nil && dryRunJob.Configuration.Query != nil {
if dest := dryRunJob.Configuration.Query.DestinationTable; dest != nil && dest.DatasetId != session.DatasetID {
return nil, fmt.Errorf("protected write mode only supports SELECT statements, or write operations in the anonymous "+
"dataset of a BigQuery session, but destination was %q", dest.DatasetId)
}
}
}
if len(t.AllowedDatasets) > 0 {
switch statementType {
case "CREATE_SCHEMA", "DROP_SCHEMA", "ALTER_SCHEMA":
@@ -259,6 +302,8 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
query := bqClient.Query(sql)
query.Location = bqClient.Location
query.ConnectionProperties = connProps
// Log the query executed for debugging.
logger, err := util.LoggerFromContext(ctx)
if err != nil {
@@ -270,10 +315,14 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
// We iterate through the results, convert each row into a map of
// column names to values, and return the collection of rows.
var out []any
it, err := query.Read(ctx)
job, err := query.Run(ctx)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
it, err := job.Read(ctx)
if err != nil {
return nil, fmt.Errorf("unable to read query results: %w", err)
}
for {
var row map[string]bigqueryapi.Value
err = it.Next(&row)

View File

@@ -53,6 +53,7 @@ type compatibleSource interface {
UseClientAuthorization() bool
IsDatasetAllowed(projectID, datasetID string) bool
BigQueryAllowedDatasets() []string
BigQuerySession() bigqueryds.BigQuerySessionProvider
}
// validate compatible sources are still compatible
@@ -123,6 +124,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
IsDatasetAllowed: s.IsDatasetAllowed,
SessionProvider: s.BigQuerySession(),
AllowedDatasets: allowedDatasets,
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -145,6 +147,7 @@ type Tool struct {
ClientCreator bigqueryds.BigqueryClientCreator
IsDatasetAllowed func(projectID, datasetID string) bool
AllowedDatasets []string
SessionProvider bigqueryds.BigQuerySessionProvider
manifest tools.Manifest
mcpManifest tools.McpManifest
}
@@ -184,13 +187,39 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
}
}
bqClient := t.Client
restService := t.RestService
var err error
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, restService, err = t.ClientCreator(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
}
var historyDataSource string
trimmedUpperHistoryData := strings.TrimSpace(strings.ToUpper(historyData))
if strings.HasPrefix(trimmedUpperHistoryData, "SELECT") || strings.HasPrefix(trimmedUpperHistoryData, "WITH") {
if len(t.AllowedDatasets) > 0 {
dryRunJob, err := bqutil.DryRunQuery(ctx, t.RestService, t.Client.Project(), t.Client.Location, historyData, nil, nil)
var connProps []*bigqueryapi.ConnectionProperty
session, err := t.SessionProvider(ctx)
if err != nil {
return nil, fmt.Errorf("query validation failed during dry run: %w", err)
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
}
if session != nil {
connProps = []*bigqueryapi.ConnectionProperty{
{Key: "session_id", Value: session.ID},
}
}
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, historyData, nil, connProps)
if err != nil {
return nil, fmt.Errorf("query validation failed: %w", err)
}
statementType := dryRunJob.Statistics.Query.StatementType
if statementType != "SELECT" {
@@ -246,24 +275,19 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
horizon => %d%s)`,
historyDataSource, dataCol, timestampCol, horizon, idColsArg)
bqClient := t.Client
var err error
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
}
// JobStatistics.QueryStatistics.StatementType
query := bqClient.Query(sql)
query.Location = bqClient.Location
session, err := t.SessionProvider(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
}
if session != nil {
// Add session ID to the connection properties for subsequent calls.
query.ConnectionProperties = []*bigqueryapi.ConnectionProperty{
{Key: "session_id", Value: session.ID},
}
}
// Log the query executed for debugging.
logger, err := util.LoggerFromContext(ctx)
@@ -276,10 +300,14 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
// We iterate through the results, convert each row into a map of
// column names to values, and return the collection of rows.
var out []any
it, err := query.Read(ctx)
job, err := query.Run(ctx)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
it, err := job.Read(ctx)
if err != nil {
return nil, fmt.Errorf("unable to read query results: %w", err)
}
for {
var row map[string]bigqueryapi.Value
err = it.Next(&row)

View File

@@ -49,6 +49,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
BigQueryClient() *bigqueryapi.Client
BigQuerySession() bigqueryds.BigQuerySessionProvider
BigQueryWriteMode() string
BigQueryRestService() *bigqueryrestapi.Service
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool
@@ -106,13 +108,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
TemplateParameters: cfg.TemplateParameters,
AllParams: allParameters,
Statement: cfg.Statement,
UseClientOAuth: s.UseClientAuthorization(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
ClientCreator: s.BigQueryClientCreator(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Statement: cfg.Statement,
UseClientOAuth: s.UseClientAuthorization(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
SessionProvider: s.BigQuerySession(),
ClientCreator: s.BigQueryClientCreator(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -129,12 +132,13 @@ type Tool struct {
TemplateParameters tools.Parameters `yaml:"templateParameters"`
AllParams tools.Parameters `yaml:"allParams"`
Statement string
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
ClientCreator bigqueryds.BigqueryClientCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
Statement string
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
SessionProvider bigqueryds.BigQuerySessionProvider
ClientCreator bigqueryds.BigqueryClientCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
@@ -187,7 +191,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
if arrayParam, ok := p.(*tools.ArrayParameter); ok {
// Handle array types based on their defined item type.
lowLevelParam.ParameterType.Type = "ARRAY"
itemType, err := BQTypeStringFromToolType(arrayParam.GetItems().GetType())
itemType, err := bqutil.BQTypeStringFromToolType(arrayParam.GetItems().GetType())
if err != nil {
return nil, err
}
@@ -204,7 +208,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
lowLevelParam.ParameterValue.ArrayValues = arrayValues
} else {
// Handle scalar types based on their defined type.
bqType, err := BQTypeStringFromToolType(p.GetType())
bqType, err := bqutil.BQTypeStringFromToolType(p.GetType())
if err != nil {
return nil, err
}
@@ -233,20 +237,36 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
query.Parameters = highLevelParams
query.Location = bqClient.Location
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, newStatement, lowLevelParams, query.ConnectionProperties)
if err != nil {
// This is a fallback check in case the switch logic was bypassed.
return nil, fmt.Errorf("final query validation failed: %w", err)
connProps := []*bigqueryapi.ConnectionProperty{}
if t.SessionProvider != nil {
session, err := t.SessionProvider(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
}
if session != nil {
// Add session ID to the connection properties for subsequent calls.
connProps = append(connProps, &bigqueryapi.ConnectionProperty{Key: "session_id", Value: session.ID})
}
}
query.ConnectionProperties = connProps
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), query.Location, newStatement, lowLevelParams, connProps)
if err != nil {
return nil, fmt.Errorf("query validation failed: %w", err)
}
statementType := dryRunJob.Statistics.Query.StatementType
// This block handles SELECT statements, which return a row set.
// We iterate through the results, convert each row into a map of
// column names to values, and return the collection of rows.
it, err := query.Read(ctx)
job, err := query.Run(ctx)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
it, err := job.Read(ctx)
if err != nil {
return nil, fmt.Errorf("unable to read query results: %w", err)
}
var out []any
for {
@@ -300,19 +320,3 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
func (t Tool) RequiresClientAuthorization() bool {
return t.UseClientOAuth
}
func BQTypeStringFromToolType(toolType string) (string, error) {
switch toolType {
case "string":
return "STRING", nil
case "integer":
return "INT64", nil
case "float":
return "FLOAT64", nil
case "boolean":
return "BOOL", nil
// Note: 'array' is handled separately as it has a nested item type.
default:
return "", fmt.Errorf("unsupported tool parameter type for BigQuery: %s", toolType)
}
}

View File

@@ -174,7 +174,7 @@ func TestBigQueryToolEndpoints(t *testing.T) {
ddlWant := `"Query executed successfully and returned no content."`
dataInsightsWant := `(?s)Schema Resolved.*Retrieval Query.*SQL Generated.*Answer`
// Partial message; the full error message is too long.
mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"final query validation failed: failed to insert dry run job: googleapi: Error 400: Syntax error: Unexpected identifier \"SELEC\" at [1:1]`
mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"query validation failed: failed to insert dry run job: googleapi: Error 400: Syntax error: Unexpected identifier \"SELEC\" at [1:1]`
mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"f0_\":1}"}]}}`
createColArray := `["id INT64", "name STRING", "age INT64"]`
selectEmptyWant := `"The query returned 0 rows."`
@@ -326,6 +326,165 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
runForecastWithRestriction(t, allowedForecastTableFullName2, disallowedForecastTableFullName)
}
func TestBigQueryWriteModeAllowed(t *testing.T) {
sourceConfig := getBigQueryVars(t)
sourceConfig["writeMode"] = "allowed"
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
datasetName := fmt.Sprintf("temp_toolbox_test_allowed_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
client, err := initBigQueryConnection(BigqueryProject)
if err != nil {
t.Fatalf("unable to create BigQuery connection: %s", err)
}
dataset := client.Dataset(datasetName)
if err := dataset.Create(ctx, &bigqueryapi.DatasetMetadata{Name: datasetName}); err != nil {
t.Fatalf("Failed to create dataset %q: %v", datasetName, err)
}
defer func() {
if err := dataset.DeleteWithContents(ctx); err != nil {
t.Logf("failed to cleanup dataset %s: %v", datasetName, err)
}
}()
toolsFile := map[string]any{
"sources": map[string]any{
"my-instance": sourceConfig,
},
"tools": map[string]any{
"my-exec-sql-tool": map[string]any{
"kind": "bigquery-execute-sql",
"source": "my-instance",
"description": "Tool to execute sql",
},
},
}
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile)
if err != nil {
t.Fatalf("command initialization returned an error: %s", err)
}
defer cleanup()
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
if err != nil {
t.Logf("toolbox command logs: \n%s", out)
t.Fatalf("toolbox didn't start successfully: %s", err)
}
runBigQueryWriteModeAllowedTest(t, datasetName)
}
func TestBigQueryWriteModeBlocked(t *testing.T) {
sourceConfig := getBigQueryVars(t)
sourceConfig["writeMode"] = "blocked"
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
datasetName := fmt.Sprintf("temp_toolbox_test_blocked_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
tableName := fmt.Sprintf("param_table_blocked_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
tableNameParam := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, datasetName, tableName)
client, err := initBigQueryConnection(BigqueryProject)
if err != nil {
t.Fatalf("unable to create BigQuery connection: %s", err)
}
createParamTableStmt, insertParamTableStmt, _, _, _, _, paramTestParams := getBigQueryParamToolInfo(tableNameParam)
teardownTable := setupBigQueryTable(t, ctx, client, createParamTableStmt, insertParamTableStmt, datasetName, tableNameParam, paramTestParams)
defer teardownTable(t)
toolsFile := map[string]any{
"sources": map[string]any{"my-instance": sourceConfig},
"tools": map[string]any{
"my-exec-sql-tool": map[string]any{"kind": "bigquery-execute-sql", "source": "my-instance", "description": "Tool to execute sql"},
},
}
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile)
if err != nil {
t.Fatalf("command initialization returned an error: %s", err)
}
defer cleanup()
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
if err != nil {
t.Logf("toolbox command logs: \n%s", out)
t.Fatalf("toolbox didn't start successfully: %s", err)
}
runBigQueryWriteModeBlockedTest(t, tableNameParam, datasetName)
}
func TestBigQueryWriteModeProtected(t *testing.T) {
sourceConfig := getBigQueryVars(t)
sourceConfig["writeMode"] = "protected"
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
permanentDatasetName := fmt.Sprintf("perm_dataset_protected_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
client, err := initBigQueryConnection(BigqueryProject)
if err != nil {
t.Fatalf("unable to create BigQuery connection: %s", err)
}
dataset := client.Dataset(permanentDatasetName)
if err := dataset.Create(ctx, &bigqueryapi.DatasetMetadata{Name: permanentDatasetName}); err != nil {
t.Fatalf("Failed to create dataset %q: %v", permanentDatasetName, err)
}
defer func() {
if err := dataset.DeleteWithContents(ctx); err != nil {
t.Logf("failed to cleanup dataset %s: %v", permanentDatasetName, err)
}
}()
toolsFile := map[string]any{
"sources": map[string]any{"my-instance": sourceConfig},
"tools": map[string]any{
"my-exec-sql-tool": map[string]any{"kind": "bigquery-execute-sql", "source": "my-instance", "description": "Tool to execute sql"},
"my-sql-tool-protected": map[string]any{
"kind": "bigquery-sql",
"source": "my-instance",
"description": "Tool to query from the session",
"statement": "SELECT * FROM my_shared_temp_table",
},
"my-forecast-tool-protected": map[string]any{
"kind": "bigquery-forecast",
"source": "my-instance",
"description": "Tool to forecast from session temp table",
},
"my-analyze-contribution-tool-protected": map[string]any{
"kind": "bigquery-analyze-contribution",
"source": "my-instance",
"description": "Tool to analyze contribution from session temp table",
},
},
}
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile)
if err != nil {
t.Fatalf("command initialization returned an error: %s", err)
}
defer cleanup()
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
if err != nil {
t.Logf("toolbox command logs: \n%s", out)
t.Fatalf("toolbox didn't start successfully: %s", err)
}
runBigQueryWriteModeProtectedTest(t, permanentDatasetName)
}
// getBigQueryParamToolInfo returns statements and param for my-tool for bigquery kind
func getBigQueryParamToolInfo(tableName string) (string, string, string, string, string, string, []bigqueryapi.QueryParameter) {
createStatement := fmt.Sprintf(`
@@ -900,6 +1059,223 @@ func runBigQueryExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamW
}
}
// runInvokeRequest sends a POST request to the given API endpoint and returns the response and parsed JSON body.
func runInvokeRequest(t *testing.T, api, body string, headers map[string]string) (*http.Response, map[string]interface{}) {
t.Helper()
req, err := http.NewRequest(http.MethodPost, api, bytes.NewBufferString(body))
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
for k, v := range headers {
req.Header.Add(k, v)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("failed to send request: %v", err)
}
var result map[string]interface{}
// Use a TeeReader to be able to read the body multiple times (for logging on failure)
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read response body: %v", err)
}
resp.Body.Close() // Close original body
resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Replace with a new reader
if err := json.Unmarshal(bodyBytes, &result); err != nil {
t.Logf("Failed to decode response body: %s", string(bodyBytes))
t.Fatalf("failed to decode response: %v", err)
}
return resp, result
}
func runBigQueryWriteModeAllowedTest(t *testing.T, datasetName string) {
t.Run("CREATE TABLE should succeed", func(t *testing.T) {
sql := fmt.Sprintf("CREATE TABLE %s.new_table (x INT64)", datasetName)
body := fmt.Sprintf(`{"sql": "%s"}`, sql)
resp, result := runInvokeRequest(t, "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", body, nil)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, http.StatusOK, string(bodyBytes))
}
resStr, ok := result["result"].(string)
if !ok {
t.Fatalf("expected 'result' field in response, got %v", result)
}
if resStr != `"Query executed successfully and returned no content."` {
t.Errorf("unexpected result: got %q, want %q", resStr, `"Query executed successfully and returned no content."`)
}
})
}
func runBigQueryWriteModeBlockedTest(t *testing.T, tableNameParam, datasetName string) {
testCases := []struct {
name string
sql string
wantStatusCode int
wantInError string
wantResult string
}{
{"SELECT statement should succeed", fmt.Sprintf("SELECT * FROM %s WHERE id = 1", tableNameParam), http.StatusOK, "", `[{"id":1,"name":"Alice"}]`},
{"INSERT statement should fail", fmt.Sprintf("INSERT INTO %s (id, name) VALUES (10, 'test')", tableNameParam), http.StatusBadRequest, "write mode is 'blocked', only SELECT statements are allowed", ""},
{"CREATE TABLE statement should fail", fmt.Sprintf("CREATE TABLE %s.new_table (x INT64)", datasetName), http.StatusBadRequest, "write mode is 'blocked', only SELECT statements are allowed", ""},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
body := fmt.Sprintf(`{"sql": "%s"}`, tc.sql)
resp, result := runInvokeRequest(t, "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", body, nil)
defer resp.Body.Close()
if resp.StatusCode != tc.wantStatusCode {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes))
}
if tc.wantInError != "" {
errStr, ok := result["error"].(string)
if !ok {
t.Fatalf("expected 'error' field in response, got %v", result)
}
if !strings.Contains(errStr, tc.wantInError) {
t.Fatalf("expected error message to contain %q, but got %q", tc.wantInError, errStr)
}
}
if tc.wantResult != "" {
resStr, ok := result["result"].(string)
if !ok {
t.Fatalf("expected 'result' field in response, got %v", result)
}
if resStr != tc.wantResult {
t.Fatalf("unexpected result: got %q, want %q", resStr, tc.wantResult)
}
}
})
}
}
func runBigQueryWriteModeProtectedTest(t *testing.T, permanentDatasetName string) {
testCases := []struct {
name string
toolName string
requestBody string
wantStatusCode int
wantInError string
wantResult string
}{
{
name: "CREATE TABLE to permanent dataset should fail",
toolName: "my-exec-sql-tool",
requestBody: fmt.Sprintf(`{"sql": "CREATE TABLE %s.new_table (x INT64)"}`, permanentDatasetName),
wantStatusCode: http.StatusBadRequest,
wantInError: "protected write mode only supports SELECT statements, or write operations in the anonymous dataset",
wantResult: "",
},
{
name: "CREATE TEMP TABLE should succeed",
toolName: "my-exec-sql-tool",
requestBody: `{"sql": "CREATE TEMP TABLE my_shared_temp_table (x INT64)"}`,
wantStatusCode: http.StatusOK,
wantInError: "",
wantResult: `"Query executed successfully and returned no content."`,
},
{
name: "INSERT into TEMP TABLE should succeed",
toolName: "my-exec-sql-tool",
requestBody: `{"sql": "INSERT INTO my_shared_temp_table (x) VALUES (42)"}`,
wantStatusCode: http.StatusOK,
wantInError: "",
wantResult: `"Query executed successfully and returned no content."`,
},
{
name: "SELECT from TEMP TABLE with exec-sql should succeed",
toolName: "my-exec-sql-tool",
requestBody: `{"sql": "SELECT * FROM my_shared_temp_table"}`,
wantStatusCode: http.StatusOK,
wantInError: "",
wantResult: `[{"x":42}]`,
},
{
name: "SELECT from TEMP TABLE with sql-tool should succeed",
toolName: "my-sql-tool-protected",
requestBody: `{}`,
wantStatusCode: http.StatusOK,
wantInError: "",
wantResult: `[{"x":42}]`,
},
{
name: "CREATE TEMP TABLE for forecast should succeed",
toolName: "my-exec-sql-tool",
requestBody: `{"sql": "CREATE TEMP TABLE forecast_temp_table (ts TIMESTAMP, data FLOAT64) AS SELECT TIMESTAMP('2025-01-01T00:00:00Z') AS ts, 10.0 AS data UNION ALL SELECT TIMESTAMP('2025-01-01T01:00:00Z'), 11.0 UNION ALL SELECT TIMESTAMP('2025-01-01T02:00:00Z'), 12.0 UNION ALL SELECT TIMESTAMP('2025-01-01T03:00:00Z'), 13.0"}`,
wantStatusCode: http.StatusOK,
wantInError: "",
wantResult: `"Query executed successfully and returned no content."`,
},
{
name: "Forecast from TEMP TABLE should succeed",
toolName: "my-forecast-tool-protected",
requestBody: `{"history_data": "SELECT * FROM forecast_temp_table", "timestamp_col": "ts", "data_col": "data", "horizon": 1}`,
wantStatusCode: http.StatusOK,
wantInError: "",
wantResult: `"forecast_timestamp"`,
},
{
name: "CREATE TEMP TABLE for contribution analysis should succeed",
toolName: "my-exec-sql-tool",
requestBody: `{"sql": "CREATE TEMP TABLE contribution_temp_table (dim1 STRING, is_test BOOL, metric FLOAT64) AS SELECT 'a' as dim1, true as is_test, 100.0 as metric UNION ALL SELECT 'b', false, 120.0"}`,
wantStatusCode: http.StatusOK,
wantInError: "",
wantResult: `"Query executed successfully and returned no content."`,
},
{
name: "Analyze contribution from TEMP TABLE should succeed",
toolName: "my-analyze-contribution-tool-protected",
requestBody: `{"input_data": "SELECT * FROM contribution_temp_table", "contribution_metric": "SUM(metric)", "is_test_col": "is_test", "dimension_id_cols": ["dim1"]}`,
wantStatusCode: http.StatusOK,
wantInError: "",
wantResult: `"relative_difference"`,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
api := fmt.Sprintf("http://127.0.0.1:5000/api/tool/%s/invoke", tc.toolName)
resp, result := runInvokeRequest(t, api, tc.requestBody, nil)
defer resp.Body.Close()
if resp.StatusCode != tc.wantStatusCode {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes))
}
if tc.wantInError != "" {
errStr, ok := result["error"].(string)
if !ok {
t.Fatalf("expected 'error' field in response, got %v", result)
}
if !strings.Contains(errStr, tc.wantInError) {
t.Fatalf("expected error message to contain %q, but got %q", tc.wantInError, errStr)
}
}
if tc.wantResult != "" {
resStr, ok := result["result"].(string)
if !ok {
t.Fatalf("expected 'result' field in response, got %v", result)
}
if !strings.Contains(resStr, tc.wantResult) {
t.Fatalf("expected %q to contain %q, but it did not", resStr, tc.wantResult)
}
}
})
}
}
func runBigQueryExecuteSqlToolInvokeDryRunTest(t *testing.T, datasetName string) {
// Get ID token
idToken, err := tests.GetGoogleIdToken(tests.ClientId)