mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 15:38:08 -05:00
feat(bigquery-execute-sql): add dry run support (#1057)
Add optional `dry_run` parameter to bigquery-execute-sql, which defaults to `false`. When the `dry_run` parameter is set to `true`, the tool returns the metadata from the dry run instead of executing the query. Fixes #703 --------- Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com>
This commit is contained in:
@@ -15,8 +15,9 @@ It's compatible with the following sources:
|
||||
|
||||
- [bigquery](../../sources/bigquery.md)
|
||||
|
||||
`bigquery-execute-sql` takes one input parameter `sql` and runs the sql
|
||||
statement against the `source`.
|
||||
`bigquery-execute-sql` takes a required `sql` input parameter and runs the SQL
|
||||
statement against the configured `source`. It also supports an optional `dry_run`
|
||||
parameter to validate a query without executing it.
|
||||
|
||||
## Example
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ package bigqueryexecutesql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
bigqueryapi "cloud.google.com/go/bigquery"
|
||||
@@ -83,7 +84,13 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
}
|
||||
|
||||
sqlParameter := tools.NewStringParameter("sql", "The sql to execute.")
|
||||
parameters := tools.Parameters{sqlParameter}
|
||||
dryRunParameter := tools.NewBooleanParameterWithDefault(
|
||||
"dry_run",
|
||||
false,
|
||||
"If set to true, the query will be validated and information about the execution "+
|
||||
"will be returned without running the query. Defaults to false.",
|
||||
)
|
||||
parameters := tools.Parameters{sqlParameter, dryRunParameter}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
@@ -120,10 +127,14 @@ type Tool struct {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
sliceParams := params.AsSlice()
|
||||
sql, ok := sliceParams[0].(string)
|
||||
paramsMap := params.AsMap()
|
||||
sql, ok := paramsMap["sql"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to get cast %s", sliceParams[0])
|
||||
return nil, fmt.Errorf("unable to cast sql parameter %s", paramsMap["sql"])
|
||||
}
|
||||
dryRun, ok := paramsMap["dry_run"].(bool)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to cast dry_run parameter %s", paramsMap["dry_run"])
|
||||
}
|
||||
|
||||
dryRunJob, err := dryRunQuery(ctx, t.RestService, t.Client.Project(), t.Client.Location, sql)
|
||||
@@ -131,6 +142,18 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error)
|
||||
return nil, fmt.Errorf("query validation failed during dry run: %w", err)
|
||||
}
|
||||
|
||||
if dryRun {
|
||||
if dryRunJob != nil {
|
||||
jobJSON, err := json.MarshalIndent(dryRunJob, "", " ")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal dry run job to JSON: %w", err)
|
||||
}
|
||||
return string(jobJSON), nil
|
||||
}
|
||||
// This case should not be reached, but as a fallback, we return a message.
|
||||
return "Dry run was requested, but no job information was returned.", nil
|
||||
}
|
||||
|
||||
statementType := dryRunJob.Statistics.Query.StatementType
|
||||
// JobStatistics.QueryStatistics.StatementType
|
||||
query := t.Client.Query(sql)
|
||||
|
||||
@@ -162,6 +162,7 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, templateParamTestConfig)
|
||||
|
||||
runBigQueryExecuteSqlToolInvokeTest(t, select1Want, invokeParamWant, tableNameParam, ddlWant)
|
||||
runBigQueryExecuteSqlToolInvokeDryRunTest(t, datasetName)
|
||||
runBigQueryDataTypeTests(t)
|
||||
runBigQueryListDatasetToolInvokeTest(t, datasetName)
|
||||
runBigQueryGetDatasetInfoToolInvokeTest(t, datasetName, datasetInfoWant)
|
||||
@@ -556,6 +557,115 @@ func runBigQueryExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamW
|
||||
}
|
||||
}
|
||||
|
||||
func runBigQueryExecuteSqlToolInvokeDryRunTest(t *testing.T, datasetName string) {
|
||||
// Get ID token
|
||||
idToken, err := tests.GetGoogleIdToken(tests.ClientId)
|
||||
if err != nil {
|
||||
t.Fatalf("error getting Google ID token: %s", err)
|
||||
}
|
||||
|
||||
newTableName := fmt.Sprintf("%s.new_dry_run_table_%s", datasetName, strings.ReplaceAll(uuid.New().String(), "-", ""))
|
||||
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestHeader map[string]string
|
||||
requestBody io.Reader
|
||||
want string
|
||||
isErr bool
|
||||
}{
|
||||
{
|
||||
name: "invoke my-exec-sql-tool with dryRun",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1", "dry_run": true}`)),
|
||||
want: `\"statementType\": \"SELECT\"`,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke my-exec-sql-tool with dryRun create table",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"sql":"CREATE TABLE %s (id INT64, name STRING)", "dry_run": true}`, newTableName))),
|
||||
want: `\"statementType\": \"CREATE_TABLE\"`,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke my-exec-sql-tool with dryRun execute immediate",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"sql":"EXECUTE IMMEDIATE \"CREATE TABLE %s (id INT64, name STRING)\"", "dry_run": true}`, newTableName))),
|
||||
want: `\"statementType\": \"SCRIPT\"`,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-exec-sql-tool with dryRun and auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-exec-sql-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1", "dry_run": true}`)),
|
||||
isErr: false,
|
||||
want: `\"statementType\": \"SELECT\"`,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-exec-sql-tool with dryRun and invalid auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-exec-sql-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1","dry_run": true}`)),
|
||||
isErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-exec-sql-tool with dryRun and without auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-exec-sql-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1", "dry_run": true}`)),
|
||||
isErr: true,
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Send Tool invocation request
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
for k, v := range tc.requestHeader {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if tc.isErr {
|
||||
return
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
// Check response body
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
|
||||
got, ok := body["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
|
||||
if !strings.Contains(got, tc.want) {
|
||||
t.Fatalf("expected %q to contain %q, but it did not", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runBigQueryDataTypeTests(t *testing.T) {
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
|
||||
Reference in New Issue
Block a user