mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-11 08:28:11 -05:00
feat(source/bigquery): add optional write mode config (#1157)
Summary: Adds an optional write_mode configuration to the BigQuery source, enhancing security by controlling the types of SQL statements that can be executed to prevent unauthorized data modification. Key changes — Added writeMode configuration: a new write_mode field is added to the BigQuery source, supporting three modes: allowed (default) permits all SQL statements; blocked allows only SELECT queries; protected enables session-based execution, restricting write operations (such as CREATE TABLE) to the session's temporary dataset, thus protecting permanent datasets. Note: at the moment this does not work with useClientOAuth; this will be fixed in the future. These restrictions primarily apply to the bigquery-execute-sql tool; the session may also be used by other tools.
This commit is contained in:
@@ -119,6 +119,7 @@ sources:
|
||||
kind: "bigquery"
|
||||
project: "my-project-id"
|
||||
# location: "US" # Optional: Specifies the location for query jobs.
|
||||
# writeMode: "allowed" # One of: allowed, blocked, protected. Defaults to "allowed".
|
||||
# allowedDatasets: # Optional: Restricts tool access to a specific list of datasets.
|
||||
# - "my_dataset_1"
|
||||
# - "other_project.my_dataset_2"
|
||||
@@ -133,6 +134,7 @@ sources:
|
||||
project: "my-project-id"
|
||||
useClientOAuth: true
|
||||
# location: "US" # Optional: Specifies the location for query jobs.
|
||||
# writeMode: "allowed" # One of: allowed, blocked, protected. Defaults to "allowed".
|
||||
# allowedDatasets: # Optional: Restricts tool access to a specific list of datasets.
|
||||
# - "my_dataset_1"
|
||||
# - "other_project.my_dataset_2"
|
||||
@@ -145,5 +147,6 @@ sources:
|
||||
| kind | string | true | Must be "bigquery". |
|
||||
| project | string | true | Id of the Google Cloud project to use for billing and as the default project for BigQuery resources. |
|
||||
| location | string | false | Specifies the location (e.g., 'us', 'asia-northeast1') in which to run the query job. This location must match the location of any tables referenced in the query. Defaults to the table's location or 'US' if the location cannot be determined. [Learn More](https://cloud.google.com/bigquery/docs/locations) |
|
||||
| writeMode | string | false | Controls the write behavior for tools. `allowed` (default): All queries are permitted. `blocked`: Only `SELECT` statements are allowed for the `bigquery-execute-sql` tool. `protected`: Enables session-based execution where all tools associated with this source instance share the same [BigQuery session](https://cloud.google.com/bigquery/docs/sessions-intro). This allows for stateful operations using temporary tables (e.g., `CREATE TEMP TABLE`). For `bigquery-execute-sql`, `SELECT` statements can be used on all tables, but write operations are restricted to the session's temporary dataset. For tools like `bigquery-sql`, `bigquery-forecast`, and `bigquery-analyze-contribution`, the `writeMode` restrictions do not apply, but they will operate within the shared session. **Note:** The `protected` mode cannot be used with `useClientOAuth: true`. It is also not recommended for multi-user server environments, as all users would share the same session. A session is terminated automatically after 24 hours of inactivity or after 7 days, whichever comes first. A new session is created on the next request, and any temporary data from the previous session will be lost. |
|
||||
| allowedDatasets | []string | false | An optional list of dataset IDs that tools using this source are allowed to access. If provided, any tool operation attempting to access a dataset not in this list will be rejected. To enforce this, two types of operations are also disallowed: 1) Dataset-level operations (e.g., `CREATE SCHEMA`), and 2) operations where table access cannot be statically analyzed (e.g., `EXECUTE IMMEDIATE`, `CREATE PROCEDURE`). If a single dataset is provided, it will be treated as the default for prebuilt tools. |
|
||||
| useClientOAuth | bool | false | If true, forwards the client's OAuth access token from the "Authorization" header to downstream queries. |
|
||||
| useClientOAuth | bool | false | If true, forwards the client's OAuth access token from the "Authorization" header to downstream queries. **Note:** This cannot be used with `writeMode: protected`. |
|
||||
|
||||
@@ -39,6 +39,13 @@ It's compatible with the following sources:
|
||||
insights. Can be `'NO_PRUNING'` or `'PRUNE_REDUNDANT_INSIGHTS'`. Defaults to
|
||||
`'PRUNE_REDUNDANT_INSIGHTS'`.
|
||||
|
||||
The behavior of this tool is influenced by the `writeMode` setting on its `bigquery` source:
|
||||
|
||||
- **`allowed` (default) and `blocked`:** These modes do not impose any special restrictions on the `bigquery-analyze-contribution` tool.
|
||||
- **`protected`:** This mode enables session-based execution. The tool will operate within the same BigQuery session as other
|
||||
tools using the same source. This allows the `input_data` parameter to be a query that references temporary resources (e.g.,
|
||||
`TEMP` tables) created within that session.
|
||||
|
||||
|
||||
## Example
|
||||
|
||||
|
||||
@@ -20,8 +20,15 @@ It's compatible with the following sources:
|
||||
- **`dry_run`** (optional): If set to `true`, the query is validated but not run,
|
||||
returning information about the execution instead. Defaults to `false`.
|
||||
|
||||
The behavior of this tool is influenced by the `writeMode` setting on its `bigquery` source:
|
||||
|
||||
- **`allowed` (default):** All SQL statements are permitted.
|
||||
- **`blocked`:** Only `SELECT` statements are allowed. Any other type of statement (e.g., `INSERT`, `UPDATE`, `CREATE`) will be rejected.
|
||||
- **`protected`:** This mode enables session-based execution. `SELECT` statements can be used on all tables, while write operations are allowed only for the session's temporary dataset (e.g., `CREATE TEMP TABLE ...`). This prevents modifications to permanent datasets while allowing stateful, multi-step operations within a secure session.
|
||||
|
||||
The tool's behavior is influenced by the `allowedDatasets` restriction on the
|
||||
`bigquery` source:
|
||||
`bigquery` source. Similar to `writeMode`, this setting provides an additional layer of security by controlling which datasets can be accessed:
|
||||
|
||||
- **Without `allowedDatasets` restriction:** The tool can execute any valid GoogleSQL
|
||||
query.
|
||||
- **With `allowedDatasets` restriction:** Before execution, the tool performs a dry run
|
||||
@@ -33,6 +40,8 @@ The tool's behavior is influenced by the `allowedDatasets` restriction on the
|
||||
- **Unanalyzable operations** where the accessed tables cannot be determined
|
||||
statically (e.g., `EXECUTE IMMEDIATE`, `CREATE PROCEDURE`, `CALL`).
|
||||
|
||||
> **Note:** This tool is intended for developer assistant workflows with human-in-the-loop and shouldn't be used for production agents.
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
|
||||
@@ -33,12 +33,19 @@ query based on the provided parameters:
|
||||
- **horizon** (integer, optional): The number of future time steps you want to
|
||||
predict. It defaults to 10 if not specified.
|
||||
|
||||
The tool's behavior regarding these parameters is influenced by the `allowedDatasets` restriction on the `bigquery` source:
|
||||
The behavior of this tool is influenced by the `writeMode` setting on its `bigquery` source:
|
||||
|
||||
- **`allowed` (default) and `blocked`:** These modes do not impose any special restrictions on the `bigquery-forecast` tool.
|
||||
- **`protected`:** This mode enables session-based execution. The tool will operate within the same BigQuery session as other
|
||||
tools using the same source. This allows the `history_data` parameter to be a query that references temporary resources (e.g.,
|
||||
`TEMP` tables) created within that session.
|
||||
|
||||
The tool's behavior is also influenced by the `allowedDatasets` restriction on the `bigquery` source:
|
||||
|
||||
- **Without `allowedDatasets` restriction:** The tool can use any table or query for the `history_data` parameter.
|
||||
- **With `allowedDatasets` restriction:** The tool verifies that the `history_data` parameter only accesses tables
|
||||
within the allowed datasets. If `history_data` is a table ID, the tool checks if the table's dataset is in the
|
||||
allowed list. If `history_data` is a query, the tool performs a dry run to analyze the query and rejects it
|
||||
if it accesses any table outside the allowed list.
|
||||
- **With `allowedDatasets` restriction:** The tool verifies that the `history_data` parameter only accesses tables within the allowed datasets.
|
||||
- If `history_data` is a table ID, the tool checks if the table's dataset is in the allowed list.
|
||||
- If `history_data` is a query, the tool performs a dry run to analyze the query and rejects it if it accesses any table outside the allowed list.
|
||||
|
||||
## Example
|
||||
|
||||
|
||||
@@ -15,6 +15,11 @@ the following sources:
|
||||
|
||||
- [bigquery](../../sources/bigquery.md)
|
||||
|
||||
The behavior of this tool is influenced by the `writeMode` setting on its `bigquery` source:
|
||||
|
||||
- **`allowed` (default) and `blocked`:** These modes do not impose any restrictions on the `bigquery-sql` tool. The pre-defined SQL statement will be executed as-is.
|
||||
- **`protected`:** This mode enables session-based execution. The tool will operate within the same BigQuery session as other tools using the same source, allowing it to interact with temporary resources like `TEMP` tables created within that session.
|
||||
|
||||
### GoogleSQL
|
||||
|
||||
BigQuery uses [GoogleSQL][bigquery-googlesql] for querying data. The integration
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
bigqueryapi "cloud.google.com/go/bigquery"
|
||||
dataplexapi "cloud.google.com/go/dataplex/apiv1"
|
||||
@@ -36,11 +37,22 @@ import (
|
||||
|
||||
const SourceKind string = "bigquery"
|
||||
|
||||
const (
|
||||
// No write operations are allowed.
|
||||
WriteModeBlocked string = "blocked"
|
||||
// Only protected write operations are allowed in a BigQuery session.
|
||||
WriteModeProtected string = "protected"
|
||||
// All write operations are allowed.
|
||||
WriteModeAllowed string = "allowed"
|
||||
)
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
type BigqueryClientCreator func(tokenString string, wantRestService bool) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
|
||||
|
||||
type BigQuerySessionProvider func(ctx context.Context) (*Session, error)
|
||||
|
||||
type DataplexClientCreator func(tokenString string) (*dataplexapi.CatalogClient, error)
|
||||
|
||||
func init() {
|
||||
@@ -63,6 +75,7 @@ type Config struct {
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
Location string `yaml:"location"`
|
||||
WriteMode string `yaml:"writeMode"`
|
||||
AllowedDatasets []string `yaml:"allowedDatasets"`
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
}
|
||||
@@ -73,6 +86,14 @@ func (r Config) SourceConfigKind() string {
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
if r.WriteMode == "" {
|
||||
r.WriteMode = WriteModeAllowed
|
||||
}
|
||||
|
||||
if r.WriteMode == WriteModeProtected && r.UseClientOAuth {
|
||||
return nil, fmt.Errorf("writeMode 'protected' cannot be used with useClientOAuth 'true'")
|
||||
}
|
||||
|
||||
var client *bigqueryapi.Client
|
||||
var restService *bigqueryrestapi.Service
|
||||
var tokenSource oauth2.TokenSource
|
||||
@@ -133,9 +154,15 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
TokenSource: tokenSource,
|
||||
MaxQueryResultRows: 50,
|
||||
ClientCreator: clientCreator,
|
||||
WriteMode: r.WriteMode,
|
||||
AllowedDatasets: allowedDatasets,
|
||||
UseClientOAuth: r.UseClientOAuth,
|
||||
}
|
||||
s.SessionProvider = s.newBigQuerySessionProvider()
|
||||
|
||||
if r.WriteMode != WriteModeAllowed && r.WriteMode != WriteModeBlocked && r.WriteMode != WriteModeProtected {
|
||||
return nil, fmt.Errorf("invalid writeMode %q: must be one of %q, %q, or %q", r.WriteMode, WriteModeAllowed, WriteModeProtected, WriteModeBlocked)
|
||||
}
|
||||
s.makeDataplexCatalogClient = s.lazyInitDataplexClient(ctx, tracer)
|
||||
return s, nil
|
||||
|
||||
@@ -156,7 +183,19 @@ type Source struct {
|
||||
ClientCreator BigqueryClientCreator
|
||||
AllowedDatasets map[string]struct{}
|
||||
UseClientOAuth bool
|
||||
WriteMode string
|
||||
sessionMutex sync.Mutex
|
||||
makeDataplexCatalogClient func() (*dataplexapi.CatalogClient, DataplexClientCreator, error)
|
||||
SessionProvider BigQuerySessionProvider
|
||||
Session *Session
|
||||
}
|
||||
|
||||
// Session describes the BigQuery session shared by tools of a source
// configured with writeMode "protected". It is created lazily by the
// source's session provider and cached on the Source.
type Session struct {
	// ID is the session identifier returned in the session-creating
	// job's Statistics.SessionInfo.SessionId.
	ID string
	// ProjectID is the project of the session's temporary (anonymous)
	// dataset, taken from the creating job's destination table.
	ProjectID string
	// DatasetID is the session's temporary dataset; in protected mode
	// write operations are restricted to this dataset.
	DatasetID string
	// CreationTime is when the session was created; used to enforce the
	// 7-day maximum session lifetime.
	CreationTime time.Time
	// LastUsed is updated each time the cached session passes validation.
	LastUsed time.Time
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
@@ -172,6 +211,103 @@ func (s *Source) BigQueryRestService() *bigqueryrestapi.Service {
|
||||
return s.RestService
|
||||
}
|
||||
|
||||
func (s *Source) BigQueryWriteMode() string {
|
||||
return s.WriteMode
|
||||
}
|
||||
|
||||
func (s *Source) BigQuerySession() BigQuerySessionProvider {
|
||||
return s.SessionProvider
|
||||
}
|
||||
|
||||
func (s *Source) newBigQuerySessionProvider() BigQuerySessionProvider {
|
||||
return func(ctx context.Context) (*Session, error) {
|
||||
if s.WriteMode != WriteModeProtected {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
s.sessionMutex.Lock()
|
||||
defer s.sessionMutex.Unlock()
|
||||
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get logger from context: %w", err)
|
||||
}
|
||||
|
||||
if s.Session != nil {
|
||||
// Absolute 7-day lifetime check.
|
||||
const sessionMaxLifetime = 7 * 24 * time.Hour
|
||||
// This assumes a single task will not exceed 30 minutes, preventing it from failing mid-execution.
|
||||
const refreshThreshold = 30 * time.Minute
|
||||
if time.Since(s.Session.CreationTime) > (sessionMaxLifetime - refreshThreshold) {
|
||||
logger.DebugContext(ctx, "Session is approaching its 7-day maximum lifetime. Creating a new one.")
|
||||
} else {
|
||||
job := &bigqueryrestapi.Job{
|
||||
Configuration: &bigqueryrestapi.JobConfiguration{
|
||||
DryRun: true,
|
||||
Query: &bigqueryrestapi.JobConfigurationQuery{
|
||||
Query: "SELECT 1",
|
||||
UseLegacySql: new(bool),
|
||||
ConnectionProperties: []*bigqueryrestapi.ConnectionProperty{{Key: "session_id", Value: s.Session.ID}},
|
||||
},
|
||||
},
|
||||
}
|
||||
_, err := s.RestService.Jobs.Insert(s.Project, job).Do()
|
||||
if err == nil {
|
||||
s.Session.LastUsed = time.Now()
|
||||
return s.Session, nil
|
||||
}
|
||||
logger.DebugContext(ctx, "Session validation failed (likely expired), creating a new one.", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new session if one doesn't exist, it has passed its 7-day lifetime,
|
||||
// or it failed the validation dry run.
|
||||
|
||||
creationTime := time.Now()
|
||||
job := &bigqueryrestapi.Job{
|
||||
JobReference: &bigqueryrestapi.JobReference{
|
||||
ProjectId: s.Project,
|
||||
Location: s.Location,
|
||||
},
|
||||
Configuration: &bigqueryrestapi.JobConfiguration{
|
||||
DryRun: true,
|
||||
Query: &bigqueryrestapi.JobConfigurationQuery{
|
||||
Query: "SELECT 1",
|
||||
CreateSession: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
createdJob, err := s.RestService.Jobs.Insert(s.Project, job).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create new session: %w", err)
|
||||
}
|
||||
|
||||
var sessionID, sessionDatasetID, projectID string
|
||||
if createdJob.Status != nil && createdJob.Statistics.SessionInfo != nil {
|
||||
sessionID = createdJob.Statistics.SessionInfo.SessionId
|
||||
} else {
|
||||
return nil, fmt.Errorf("failed to get session ID from new session job")
|
||||
}
|
||||
|
||||
if createdJob.Configuration != nil && createdJob.Configuration.Query != nil && createdJob.Configuration.Query.DestinationTable != nil {
|
||||
sessionDatasetID = createdJob.Configuration.Query.DestinationTable.DatasetId
|
||||
projectID = createdJob.Configuration.Query.DestinationTable.ProjectId
|
||||
} else {
|
||||
return nil, fmt.Errorf("failed to get session dataset ID from new session job")
|
||||
}
|
||||
|
||||
s.Session = &Session{
|
||||
ID: sessionID,
|
||||
ProjectID: projectID,
|
||||
DatasetID: sessionDatasetID,
|
||||
CreationTime: creationTime,
|
||||
LastUsed: creationTime,
|
||||
}
|
||||
return s.Session, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Source) UseClientAuthorization() bool {
|
||||
return s.UseClientOAuth
|
||||
}
|
||||
@@ -257,7 +393,7 @@ func initBigQueryConnection(
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
cred, err := google.FindDefaultCredentials(ctx, bigqueryapi.Scope)
|
||||
cred, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/cloud-platform")
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("failed to find default Google Cloud credentials with scope %q: %w", bigqueryapi.Scope, err)
|
||||
}
|
||||
|
||||
@@ -37,14 +37,34 @@ func TestParseFromYamlBigQuery(t *testing.T) {
|
||||
my-instance:
|
||||
kind: bigquery
|
||||
project: my-project
|
||||
location: us
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": bigquery.Config{
|
||||
Name: "my-instance",
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "my-project",
|
||||
Location: "",
|
||||
WriteMode: "",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "all fields specified",
|
||||
in: `
|
||||
sources:
|
||||
my-instance:
|
||||
kind: bigquery
|
||||
project: my-project
|
||||
location: asia
|
||||
writeMode: blocked
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": bigquery.Config{
|
||||
Name: "my-instance",
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "my-project",
|
||||
Location: "us",
|
||||
Location: "asia",
|
||||
WriteMode: "blocked",
|
||||
UseClientOAuth: false,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -50,6 +50,7 @@ type compatibleSource interface {
|
||||
BigQueryRestService() *bigqueryrestapi.Service
|
||||
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
|
||||
UseClientAuthorization() bool
|
||||
BigQuerySession() bigqueryds.BigQuerySessionProvider
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
@@ -122,16 +123,17 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: parameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
RestService: s.BigQueryRestService(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: parameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
RestService: s.BigQueryRestService(),
|
||||
SessionProvider: s.BigQuerySession(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -146,11 +148,12 @@ type Tool struct {
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
SessionProvider bigqueryds.BigQuerySessionProvider
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
// Invoke runs the contribution analysis.
|
||||
@@ -222,7 +225,22 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
}
|
||||
|
||||
createModelQuery := bqClient.Query(createModelSQL)
|
||||
createModelQuery.CreateSession = true
|
||||
|
||||
// Get session from provider if in protected mode.
|
||||
// Otherwise, a new session will be created by the first query.
|
||||
session, err := t.SessionProvider(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
}
|
||||
|
||||
if session != nil {
|
||||
createModelQuery.ConnectionProperties = []*bigqueryapi.ConnectionProperty{
|
||||
{Key: "session_id", Value: session.ID},
|
||||
}
|
||||
} else {
|
||||
// If not in protected mode, create a session for this invocation.
|
||||
createModelQuery.CreateSession = true
|
||||
}
|
||||
createModelJob, err := createModelQuery.Run(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start create model job: %w", err)
|
||||
@@ -236,16 +254,21 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
return nil, fmt.Errorf("create model job failed: %w", err)
|
||||
}
|
||||
|
||||
if status.Statistics == nil || status.Statistics.SessionInfo == nil || status.Statistics.SessionInfo.SessionID == "" {
|
||||
return nil, fmt.Errorf("failed to create a BigQuery session")
|
||||
// Determine the session ID to use for subsequent queries.
|
||||
// It's either from the pre-existing session (protected mode) or the one just created.
|
||||
var sessionID string
|
||||
if session != nil {
|
||||
sessionID = session.ID
|
||||
} else if status.Statistics != nil && status.Statistics.SessionInfo != nil {
|
||||
sessionID = status.Statistics.SessionInfo.SessionID
|
||||
} else {
|
||||
return nil, fmt.Errorf("failed to get or create a BigQuery session ID")
|
||||
}
|
||||
sessionID := status.Statistics.SessionInfo.SessionID
|
||||
|
||||
getInsightsSQL := fmt.Sprintf("SELECT * FROM ML.GET_INSIGHTS(MODEL %s)", modelID)
|
||||
|
||||
getInsightsQuery := bqClient.Query(getInsightsSQL)
|
||||
getInsightsQuery.ConnectionProperties = []*bigqueryapi.ConnectionProperty{
|
||||
{Key: "session_id", Value: sessionID},
|
||||
}
|
||||
getInsightsQuery.ConnectionProperties = []*bigqueryapi.ConnectionProperty{{Key: "session_id", Value: sessionID}}
|
||||
|
||||
job, err := getInsightsQuery.Run(ctx)
|
||||
if err != nil {
|
||||
|
||||
@@ -53,3 +53,19 @@ func DryRunQuery(ctx context.Context, restService *bigqueryrestapi.Service, proj
|
||||
}
|
||||
return insertResponse, nil
|
||||
}
|
||||
|
||||
// BQTypeStringFromToolType converts a tool parameter type string to a BigQuery standard SQL type string.
func BQTypeStringFromToolType(toolType string) (string, error) {
	// Supported tool parameter types and their BigQuery SQL equivalents.
	bqTypes := map[string]string{
		"string":  "STRING",
		"integer": "INT64",
		"float":   "FLOAT64",
		"boolean": "BOOL",
	}
	bqType, ok := bqTypes[toolType]
	if !ok {
		return "", fmt.Errorf("unsupported tool parameter type for BigQuery: %s", toolType)
	}
	return bqType, nil
}
|
||||
|
||||
@@ -49,6 +49,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
BigQueryClient() *bigqueryapi.Client
|
||||
BigQuerySession() bigqueryds.BigQuerySessionProvider
|
||||
BigQueryWriteMode() string
|
||||
BigQueryRestService() *bigqueryrestapi.Service
|
||||
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
|
||||
UseClientAuthorization() bool
|
||||
@@ -89,33 +91,43 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
sqlDescription := "The sql to execute."
|
||||
var sqlDescriptionBuilder strings.Builder
|
||||
switch s.BigQueryWriteMode() {
|
||||
case bigqueryds.WriteModeBlocked:
|
||||
sqlDescriptionBuilder.WriteString("The SQL to execute. In 'blocked' mode, only SELECT statements are allowed; other statement types will fail.")
|
||||
case bigqueryds.WriteModeProtected:
|
||||
sqlDescriptionBuilder.WriteString("The SQL to execute. Only SELECT statements and writes to the session's temporary dataset are allowed (e.g., `CREATE TEMP TABLE ...`).")
|
||||
default: // WriteModeAllowed
|
||||
sqlDescriptionBuilder.WriteString("The SQL to execute.")
|
||||
}
|
||||
|
||||
allowedDatasets := s.BigQueryAllowedDatasets()
|
||||
if len(allowedDatasets) > 0 {
|
||||
datasetIDs := []string{}
|
||||
for _, ds := range allowedDatasets {
|
||||
datasetIDs = append(datasetIDs, fmt.Sprintf("`%s`", ds))
|
||||
}
|
||||
|
||||
if len(datasetIDs) == 1 {
|
||||
parts := strings.Split(allowedDatasets[0], ".")
|
||||
if len(allowedDatasets) == 1 {
|
||||
datasetFQN := allowedDatasets[0]
|
||||
parts := strings.Split(datasetFQN, ".")
|
||||
if len(parts) < 2 {
|
||||
return nil, fmt.Errorf("expected split to have 2 parts: %s", allowedDatasets[0])
|
||||
return nil, fmt.Errorf("expected allowedDataset to have at least 2 parts (project.dataset): %s", datasetFQN)
|
||||
}
|
||||
datasetID := parts[1]
|
||||
sqlDescription += fmt.Sprintf(" The query must only access the %s dataset. "+
|
||||
sqlDescriptionBuilder.WriteString(fmt.Sprintf(" The query must only access the `%s` dataset. "+
|
||||
"To query a table within this dataset (e.g., `my_table`), "+
|
||||
"qualify it with the dataset id (e.g., `%s.my_table`).", datasetIDs[0], datasetID)
|
||||
"qualify it with the dataset id (e.g., `%s.my_table`).", datasetFQN, datasetID))
|
||||
} else {
|
||||
sqlDescription += fmt.Sprintf(" The query must only access datasets from the following list: %s.", strings.Join(datasetIDs, ", "))
|
||||
datasetIDs := []string{}
|
||||
for _, ds := range allowedDatasets {
|
||||
datasetIDs = append(datasetIDs, fmt.Sprintf("`%s`", ds))
|
||||
}
|
||||
sqlDescriptionBuilder.WriteString(fmt.Sprintf(" The query must only access datasets from the following list: %s.", strings.Join(datasetIDs, ", ")))
|
||||
}
|
||||
}
|
||||
sqlParameter := tools.NewStringParameter("sql", sqlDescription)
|
||||
|
||||
sqlParameter := tools.NewStringParameter("sql", sqlDescriptionBuilder.String())
|
||||
dryRunParameter := tools.NewBooleanParameterWithDefault(
|
||||
"dry_run",
|
||||
false,
|
||||
"If set to true, the query will be validated and information about the execution "+
|
||||
"will be returned without running the query. Defaults to false.",
|
||||
"If set to true, the query will be validated and information about the execution will be returned "+
|
||||
"without running the query. Defaults to false.",
|
||||
)
|
||||
parameters := tools.Parameters{sqlParameter, dryRunParameter}
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, parameters)
|
||||
@@ -130,6 +142,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
RestService: s.BigQueryRestService(),
|
||||
WriteMode: s.BigQueryWriteMode(),
|
||||
SessionProvider: s.BigQuerySession(),
|
||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
||||
AllowedDatasets: allowedDatasets,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
@@ -150,6 +164,8 @@ type Tool struct {
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
WriteMode string
|
||||
SessionProvider bigqueryds.BigQuerySessionProvider
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
||||
AllowedDatasets []string
|
||||
@@ -184,12 +200,39 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
}
|
||||
}
|
||||
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query validation failed during dry run: %w", err)
|
||||
var connProps []*bigqueryapi.ConnectionProperty
|
||||
var session *bigqueryds.Session
|
||||
if t.WriteMode == bigqueryds.WriteModeProtected {
|
||||
session, err = t.SessionProvider(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session for protected mode: %w", err)
|
||||
}
|
||||
connProps = []*bigqueryapi.ConnectionProperty{
|
||||
{Key: "session_id", Value: session.ID},
|
||||
}
|
||||
}
|
||||
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, connProps)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
||||
}
|
||||
|
||||
statementType := dryRunJob.Statistics.Query.StatementType
|
||||
|
||||
switch t.WriteMode {
|
||||
case bigqueryds.WriteModeBlocked:
|
||||
if statementType != "SELECT" {
|
||||
return nil, fmt.Errorf("write mode is 'blocked', only SELECT statements are allowed")
|
||||
}
|
||||
case bigqueryds.WriteModeProtected:
|
||||
if dryRunJob.Configuration != nil && dryRunJob.Configuration.Query != nil {
|
||||
if dest := dryRunJob.Configuration.Query.DestinationTable; dest != nil && dest.DatasetId != session.DatasetID {
|
||||
return nil, fmt.Errorf("protected write mode only supports SELECT statements, or write operations in the anonymous "+
|
||||
"dataset of a BigQuery session, but destination was %q", dest.DatasetId)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(t.AllowedDatasets) > 0 {
|
||||
switch statementType {
|
||||
case "CREATE_SCHEMA", "DROP_SCHEMA", "ALTER_SCHEMA":
|
||||
@@ -259,6 +302,8 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
query := bqClient.Query(sql)
|
||||
query.Location = bqClient.Location
|
||||
|
||||
query.ConnectionProperties = connProps
|
||||
|
||||
// Log the query executed for debugging.
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
@@ -270,10 +315,14 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
// We iterate through the results, convert each row into a map of
|
||||
// column names to values, and return the collection of rows.
|
||||
var out []any
|
||||
it, err := query.Read(ctx)
|
||||
job, err := query.Run(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
it, err := job.Read(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read query results: %w", err)
|
||||
}
|
||||
for {
|
||||
var row map[string]bigqueryapi.Value
|
||||
err = it.Next(&row)
|
||||
|
||||
@@ -53,6 +53,7 @@ type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
IsDatasetAllowed(projectID, datasetID string) bool
|
||||
BigQueryAllowedDatasets() []string
|
||||
BigQuerySession() bigqueryds.BigQuerySessionProvider
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
@@ -123,6 +124,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
Client: s.BigQueryClient(),
|
||||
RestService: s.BigQueryRestService(),
|
||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
||||
SessionProvider: s.BigQuerySession(),
|
||||
AllowedDatasets: allowedDatasets,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
@@ -145,6 +147,7 @@ type Tool struct {
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
||||
AllowedDatasets []string
|
||||
SessionProvider bigqueryds.BigQuerySessionProvider
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
@@ -184,13 +187,39 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
}
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
restService := t.RestService
|
||||
var err error
|
||||
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, restService, err = t.ClientCreator(tokenStr, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var historyDataSource string
|
||||
trimmedUpperHistoryData := strings.TrimSpace(strings.ToUpper(historyData))
|
||||
if strings.HasPrefix(trimmedUpperHistoryData, "SELECT") || strings.HasPrefix(trimmedUpperHistoryData, "WITH") {
|
||||
if len(t.AllowedDatasets) > 0 {
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, t.RestService, t.Client.Project(), t.Client.Location, historyData, nil, nil)
|
||||
var connProps []*bigqueryapi.ConnectionProperty
|
||||
session, err := t.SessionProvider(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query validation failed during dry run: %w", err)
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
}
|
||||
if session != nil {
|
||||
connProps = []*bigqueryapi.ConnectionProperty{
|
||||
{Key: "session_id", Value: session.ID},
|
||||
}
|
||||
}
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, historyData, nil, connProps)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
||||
}
|
||||
statementType := dryRunJob.Statistics.Query.StatementType
|
||||
if statementType != "SELECT" {
|
||||
@@ -246,24 +275,19 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
horizon => %d%s)`,
|
||||
historyDataSource, dataCol, timestampCol, horizon, idColsArg)
|
||||
|
||||
bqClient := t.Client
|
||||
var err error
|
||||
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, _, err = t.ClientCreator(tokenStr, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// JobStatistics.QueryStatistics.StatementType
|
||||
query := bqClient.Query(sql)
|
||||
query.Location = bqClient.Location
|
||||
session, err := t.SessionProvider(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
}
|
||||
if session != nil {
|
||||
// Add session ID to the connection properties for subsequent calls.
|
||||
query.ConnectionProperties = []*bigqueryapi.ConnectionProperty{
|
||||
{Key: "session_id", Value: session.ID},
|
||||
}
|
||||
}
|
||||
|
||||
// Log the query executed for debugging.
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
@@ -276,10 +300,14 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
// We iterate through the results, convert each row into a map of
|
||||
// column names to values, and return the collection of rows.
|
||||
var out []any
|
||||
it, err := query.Read(ctx)
|
||||
job, err := query.Run(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
it, err := job.Read(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read query results: %w", err)
|
||||
}
|
||||
for {
|
||||
var row map[string]bigqueryapi.Value
|
||||
err = it.Next(&row)
|
||||
|
||||
@@ -49,6 +49,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
BigQueryClient() *bigqueryapi.Client
|
||||
BigQuerySession() bigqueryds.BigQuerySessionProvider
|
||||
BigQueryWriteMode() string
|
||||
BigQueryRestService() *bigqueryrestapi.Service
|
||||
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
|
||||
UseClientAuthorization() bool
|
||||
@@ -106,13 +108,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
TemplateParameters: cfg.TemplateParameters,
|
||||
AllParams: allParameters,
|
||||
|
||||
Statement: cfg.Statement,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
Client: s.BigQueryClient(),
|
||||
RestService: s.BigQueryRestService(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Statement: cfg.Statement,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
Client: s.BigQueryClient(),
|
||||
RestService: s.BigQueryRestService(),
|
||||
SessionProvider: s.BigQuerySession(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -129,12 +132,13 @@ type Tool struct {
|
||||
TemplateParameters tools.Parameters `yaml:"templateParameters"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
Statement string
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Statement string
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
SessionProvider bigqueryds.BigQuerySessionProvider
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
@@ -187,7 +191,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
if arrayParam, ok := p.(*tools.ArrayParameter); ok {
|
||||
// Handle array types based on their defined item type.
|
||||
lowLevelParam.ParameterType.Type = "ARRAY"
|
||||
itemType, err := BQTypeStringFromToolType(arrayParam.GetItems().GetType())
|
||||
itemType, err := bqutil.BQTypeStringFromToolType(arrayParam.GetItems().GetType())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -204,7 +208,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
lowLevelParam.ParameterValue.ArrayValues = arrayValues
|
||||
} else {
|
||||
// Handle scalar types based on their defined type.
|
||||
bqType, err := BQTypeStringFromToolType(p.GetType())
|
||||
bqType, err := bqutil.BQTypeStringFromToolType(p.GetType())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -233,20 +237,36 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
query.Parameters = highLevelParams
|
||||
query.Location = bqClient.Location
|
||||
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, newStatement, lowLevelParams, query.ConnectionProperties)
|
||||
if err != nil {
|
||||
// This is a fallback check in case the switch logic was bypassed.
|
||||
return nil, fmt.Errorf("final query validation failed: %w", err)
|
||||
connProps := []*bigqueryapi.ConnectionProperty{}
|
||||
if t.SessionProvider != nil {
|
||||
session, err := t.SessionProvider(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
}
|
||||
if session != nil {
|
||||
// Add session ID to the connection properties for subsequent calls.
|
||||
connProps = append(connProps, &bigqueryapi.ConnectionProperty{Key: "session_id", Value: session.ID})
|
||||
}
|
||||
}
|
||||
query.ConnectionProperties = connProps
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), query.Location, newStatement, lowLevelParams, connProps)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
||||
}
|
||||
|
||||
statementType := dryRunJob.Statistics.Query.StatementType
|
||||
|
||||
// This block handles SELECT statements, which return a row set.
|
||||
// We iterate through the results, convert each row into a map of
|
||||
// column names to values, and return the collection of rows.
|
||||
it, err := query.Read(ctx)
|
||||
job, err := query.Run(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
it, err := job.Read(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read query results: %w", err)
|
||||
}
|
||||
|
||||
var out []any
|
||||
for {
|
||||
@@ -300,19 +320,3 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
func (t Tool) RequiresClientAuthorization() bool {
|
||||
return t.UseClientOAuth
|
||||
}
|
||||
|
||||
func BQTypeStringFromToolType(toolType string) (string, error) {
|
||||
switch toolType {
|
||||
case "string":
|
||||
return "STRING", nil
|
||||
case "integer":
|
||||
return "INT64", nil
|
||||
case "float":
|
||||
return "FLOAT64", nil
|
||||
case "boolean":
|
||||
return "BOOL", nil
|
||||
// Note: 'array' is handled separately as it has a nested item type.
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported tool parameter type for BigQuery: %s", toolType)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -174,7 +174,7 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
ddlWant := `"Query executed successfully and returned no content."`
|
||||
dataInsightsWant := `(?s)Schema Resolved.*Retrieval Query.*SQL Generated.*Answer`
|
||||
// Partial message; the full error message is too long.
|
||||
mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"final query validation failed: failed to insert dry run job: googleapi: Error 400: Syntax error: Unexpected identifier \"SELEC\" at [1:1]`
|
||||
mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"query validation failed: failed to insert dry run job: googleapi: Error 400: Syntax error: Unexpected identifier \"SELEC\" at [1:1]`
|
||||
mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"f0_\":1}"}]}}`
|
||||
createColArray := `["id INT64", "name STRING", "age INT64"]`
|
||||
selectEmptyWant := `"The query returned 0 rows."`
|
||||
@@ -326,6 +326,165 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
|
||||
runForecastWithRestriction(t, allowedForecastTableFullName2, disallowedForecastTableFullName)
|
||||
}
|
||||
|
||||
func TestBigQueryWriteModeAllowed(t *testing.T) {
|
||||
sourceConfig := getBigQueryVars(t)
|
||||
sourceConfig["writeMode"] = "allowed"
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
datasetName := fmt.Sprintf("temp_toolbox_test_allowed_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
|
||||
|
||||
client, err := initBigQueryConnection(BigqueryProject)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create BigQuery connection: %s", err)
|
||||
}
|
||||
|
||||
dataset := client.Dataset(datasetName)
|
||||
if err := dataset.Create(ctx, &bigqueryapi.DatasetMetadata{Name: datasetName}); err != nil {
|
||||
t.Fatalf("Failed to create dataset %q: %v", datasetName, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := dataset.DeleteWithContents(ctx); err != nil {
|
||||
t.Logf("failed to cleanup dataset %s: %v", datasetName, err)
|
||||
}
|
||||
}()
|
||||
|
||||
toolsFile := map[string]any{
|
||||
"sources": map[string]any{
|
||||
"my-instance": sourceConfig,
|
||||
},
|
||||
"tools": map[string]any{
|
||||
"my-exec-sql-tool": map[string]any{
|
||||
"kind": "bigquery-execute-sql",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to execute sql",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
|
||||
runBigQueryWriteModeAllowedTest(t, datasetName)
|
||||
}
|
||||
|
||||
func TestBigQueryWriteModeBlocked(t *testing.T) {
|
||||
sourceConfig := getBigQueryVars(t)
|
||||
sourceConfig["writeMode"] = "blocked"
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
datasetName := fmt.Sprintf("temp_toolbox_test_blocked_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
|
||||
tableName := fmt.Sprintf("param_table_blocked_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
|
||||
tableNameParam := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, datasetName, tableName)
|
||||
|
||||
client, err := initBigQueryConnection(BigqueryProject)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create BigQuery connection: %s", err)
|
||||
}
|
||||
createParamTableStmt, insertParamTableStmt, _, _, _, _, paramTestParams := getBigQueryParamToolInfo(tableNameParam)
|
||||
teardownTable := setupBigQueryTable(t, ctx, client, createParamTableStmt, insertParamTableStmt, datasetName, tableNameParam, paramTestParams)
|
||||
defer teardownTable(t)
|
||||
|
||||
toolsFile := map[string]any{
|
||||
"sources": map[string]any{"my-instance": sourceConfig},
|
||||
"tools": map[string]any{
|
||||
"my-exec-sql-tool": map[string]any{"kind": "bigquery-execute-sql", "source": "my-instance", "description": "Tool to execute sql"},
|
||||
},
|
||||
}
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
|
||||
runBigQueryWriteModeBlockedTest(t, tableNameParam, datasetName)
|
||||
}
|
||||
|
||||
func TestBigQueryWriteModeProtected(t *testing.T) {
|
||||
sourceConfig := getBigQueryVars(t)
|
||||
sourceConfig["writeMode"] = "protected"
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
permanentDatasetName := fmt.Sprintf("perm_dataset_protected_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
|
||||
client, err := initBigQueryConnection(BigqueryProject)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create BigQuery connection: %s", err)
|
||||
}
|
||||
dataset := client.Dataset(permanentDatasetName)
|
||||
if err := dataset.Create(ctx, &bigqueryapi.DatasetMetadata{Name: permanentDatasetName}); err != nil {
|
||||
t.Fatalf("Failed to create dataset %q: %v", permanentDatasetName, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := dataset.DeleteWithContents(ctx); err != nil {
|
||||
t.Logf("failed to cleanup dataset %s: %v", permanentDatasetName, err)
|
||||
}
|
||||
}()
|
||||
|
||||
toolsFile := map[string]any{
|
||||
"sources": map[string]any{"my-instance": sourceConfig},
|
||||
"tools": map[string]any{
|
||||
"my-exec-sql-tool": map[string]any{"kind": "bigquery-execute-sql", "source": "my-instance", "description": "Tool to execute sql"},
|
||||
"my-sql-tool-protected": map[string]any{
|
||||
"kind": "bigquery-sql",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to query from the session",
|
||||
"statement": "SELECT * FROM my_shared_temp_table",
|
||||
},
|
||||
"my-forecast-tool-protected": map[string]any{
|
||||
"kind": "bigquery-forecast",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to forecast from session temp table",
|
||||
},
|
||||
"my-analyze-contribution-tool-protected": map[string]any{
|
||||
"kind": "bigquery-analyze-contribution",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to analyze contribution from session temp table",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
|
||||
runBigQueryWriteModeProtectedTest(t, permanentDatasetName)
|
||||
}
|
||||
|
||||
// getBigQueryParamToolInfo returns statements and param for my-tool for bigquery kind
|
||||
func getBigQueryParamToolInfo(tableName string) (string, string, string, string, string, string, []bigqueryapi.QueryParameter) {
|
||||
createStatement := fmt.Sprintf(`
|
||||
@@ -900,6 +1059,223 @@ func runBigQueryExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamW
|
||||
}
|
||||
}
|
||||
|
||||
// runInvokeRequest sends a POST request to the given API endpoint and returns the response and parsed JSON body.
|
||||
func runInvokeRequest(t *testing.T, api, body string, headers map[string]string) (*http.Response, map[string]interface{}) {
|
||||
t.Helper()
|
||||
req, err := http.NewRequest(http.MethodPost, api, bytes.NewBufferString(body))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
for k, v := range headers {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to send request: %v", err)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
// Use a TeeReader to be able to read the body multiple times (for logging on failure)
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read response body: %v", err)
|
||||
}
|
||||
resp.Body.Close() // Close original body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Replace with a new reader
|
||||
|
||||
if err := json.Unmarshal(bodyBytes, &result); err != nil {
|
||||
t.Logf("Failed to decode response body: %s", string(bodyBytes))
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
return resp, result
|
||||
}
|
||||
|
||||
func runBigQueryWriteModeAllowedTest(t *testing.T, datasetName string) {
|
||||
t.Run("CREATE TABLE should succeed", func(t *testing.T) {
|
||||
sql := fmt.Sprintf("CREATE TABLE %s.new_table (x INT64)", datasetName)
|
||||
body := fmt.Sprintf(`{"sql": "%s"}`, sql)
|
||||
resp, result := runInvokeRequest(t, "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", body, nil)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, http.StatusOK, string(bodyBytes))
|
||||
}
|
||||
|
||||
resStr, ok := result["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("expected 'result' field in response, got %v", result)
|
||||
}
|
||||
if resStr != `"Query executed successfully and returned no content."` {
|
||||
t.Errorf("unexpected result: got %q, want %q", resStr, `"Query executed successfully and returned no content."`)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func runBigQueryWriteModeBlockedTest(t *testing.T, tableNameParam, datasetName string) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
sql string
|
||||
wantStatusCode int
|
||||
wantInError string
|
||||
wantResult string
|
||||
}{
|
||||
{"SELECT statement should succeed", fmt.Sprintf("SELECT * FROM %s WHERE id = 1", tableNameParam), http.StatusOK, "", `[{"id":1,"name":"Alice"}]`},
|
||||
{"INSERT statement should fail", fmt.Sprintf("INSERT INTO %s (id, name) VALUES (10, 'test')", tableNameParam), http.StatusBadRequest, "write mode is 'blocked', only SELECT statements are allowed", ""},
|
||||
{"CREATE TABLE statement should fail", fmt.Sprintf("CREATE TABLE %s.new_table (x INT64)", datasetName), http.StatusBadRequest, "write mode is 'blocked', only SELECT statements are allowed", ""},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
body := fmt.Sprintf(`{"sql": "%s"}`, tc.sql)
|
||||
resp, result := runInvokeRequest(t, "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", body, nil)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
if tc.wantInError != "" {
|
||||
errStr, ok := result["error"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("expected 'error' field in response, got %v", result)
|
||||
}
|
||||
if !strings.Contains(errStr, tc.wantInError) {
|
||||
t.Fatalf("expected error message to contain %q, but got %q", tc.wantInError, errStr)
|
||||
}
|
||||
}
|
||||
if tc.wantResult != "" {
|
||||
resStr, ok := result["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("expected 'result' field in response, got %v", result)
|
||||
}
|
||||
if resStr != tc.wantResult {
|
||||
t.Fatalf("unexpected result: got %q, want %q", resStr, tc.wantResult)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runBigQueryWriteModeProtectedTest(t *testing.T, permanentDatasetName string) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
toolName string
|
||||
requestBody string
|
||||
wantStatusCode int
|
||||
wantInError string
|
||||
wantResult string
|
||||
}{
|
||||
{
|
||||
name: "CREATE TABLE to permanent dataset should fail",
|
||||
toolName: "my-exec-sql-tool",
|
||||
requestBody: fmt.Sprintf(`{"sql": "CREATE TABLE %s.new_table (x INT64)"}`, permanentDatasetName),
|
||||
wantStatusCode: http.StatusBadRequest,
|
||||
wantInError: "protected write mode only supports SELECT statements, or write operations in the anonymous dataset",
|
||||
wantResult: "",
|
||||
},
|
||||
{
|
||||
name: "CREATE TEMP TABLE should succeed",
|
||||
toolName: "my-exec-sql-tool",
|
||||
requestBody: `{"sql": "CREATE TEMP TABLE my_shared_temp_table (x INT64)"}`,
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantInError: "",
|
||||
wantResult: `"Query executed successfully and returned no content."`,
|
||||
},
|
||||
{
|
||||
name: "INSERT into TEMP TABLE should succeed",
|
||||
toolName: "my-exec-sql-tool",
|
||||
requestBody: `{"sql": "INSERT INTO my_shared_temp_table (x) VALUES (42)"}`,
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantInError: "",
|
||||
wantResult: `"Query executed successfully and returned no content."`,
|
||||
},
|
||||
{
|
||||
name: "SELECT from TEMP TABLE with exec-sql should succeed",
|
||||
toolName: "my-exec-sql-tool",
|
||||
requestBody: `{"sql": "SELECT * FROM my_shared_temp_table"}`,
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantInError: "",
|
||||
wantResult: `[{"x":42}]`,
|
||||
},
|
||||
{
|
||||
name: "SELECT from TEMP TABLE with sql-tool should succeed",
|
||||
toolName: "my-sql-tool-protected",
|
||||
requestBody: `{}`,
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantInError: "",
|
||||
wantResult: `[{"x":42}]`,
|
||||
},
|
||||
{
|
||||
name: "CREATE TEMP TABLE for forecast should succeed",
|
||||
toolName: "my-exec-sql-tool",
|
||||
requestBody: `{"sql": "CREATE TEMP TABLE forecast_temp_table (ts TIMESTAMP, data FLOAT64) AS SELECT TIMESTAMP('2025-01-01T00:00:00Z') AS ts, 10.0 AS data UNION ALL SELECT TIMESTAMP('2025-01-01T01:00:00Z'), 11.0 UNION ALL SELECT TIMESTAMP('2025-01-01T02:00:00Z'), 12.0 UNION ALL SELECT TIMESTAMP('2025-01-01T03:00:00Z'), 13.0"}`,
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantInError: "",
|
||||
wantResult: `"Query executed successfully and returned no content."`,
|
||||
},
|
||||
{
|
||||
name: "Forecast from TEMP TABLE should succeed",
|
||||
toolName: "my-forecast-tool-protected",
|
||||
requestBody: `{"history_data": "SELECT * FROM forecast_temp_table", "timestamp_col": "ts", "data_col": "data", "horizon": 1}`,
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantInError: "",
|
||||
wantResult: `"forecast_timestamp"`,
|
||||
},
|
||||
{
|
||||
name: "CREATE TEMP TABLE for contribution analysis should succeed",
|
||||
toolName: "my-exec-sql-tool",
|
||||
requestBody: `{"sql": "CREATE TEMP TABLE contribution_temp_table (dim1 STRING, is_test BOOL, metric FLOAT64) AS SELECT 'a' as dim1, true as is_test, 100.0 as metric UNION ALL SELECT 'b', false, 120.0"}`,
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantInError: "",
|
||||
wantResult: `"Query executed successfully and returned no content."`,
|
||||
},
|
||||
{
|
||||
name: "Analyze contribution from TEMP TABLE should succeed",
|
||||
toolName: "my-analyze-contribution-tool-protected",
|
||||
requestBody: `{"input_data": "SELECT * FROM contribution_temp_table", "contribution_metric": "SUM(metric)", "is_test_col": "is_test", "dimension_id_cols": ["dim1"]}`,
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantInError: "",
|
||||
wantResult: `"relative_difference"`,
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
api := fmt.Sprintf("http://127.0.0.1:5000/api/tool/%s/invoke", tc.toolName)
|
||||
resp, result := runInvokeRequest(t, api, tc.requestBody, nil)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
if tc.wantInError != "" {
|
||||
errStr, ok := result["error"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("expected 'error' field in response, got %v", result)
|
||||
}
|
||||
if !strings.Contains(errStr, tc.wantInError) {
|
||||
t.Fatalf("expected error message to contain %q, but got %q", tc.wantInError, errStr)
|
||||
}
|
||||
}
|
||||
|
||||
if tc.wantResult != "" {
|
||||
resStr, ok := result["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("expected 'result' field in response, got %v", result)
|
||||
}
|
||||
if !strings.Contains(resStr, tc.wantResult) {
|
||||
t.Fatalf("expected %q to contain %q, but it did not", resStr, tc.wantResult)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runBigQueryExecuteSqlToolInvokeDryRunTest(t *testing.T, datasetName string) {
|
||||
// Get ID token
|
||||
idToken, err := tests.GetGoogleIdToken(tests.ClientId)
|
||||
|
||||
Reference in New Issue
Block a user