mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-11 00:18:17 -05:00
fix(tools/bigquery-sql): ensure invoke always returns a non-null value (#1020)
This is to make bigquery-sql consistent with bigquery-execute-sql. May not be necessary to have. - Added a dry run step to identify the query type (e.g., SELECT, DML), which allows the tool to correctly handle the query's output. - The recommended high-level client, cloud.google.com/go/bigquery, does not expose the statement type from a dry run. To circumvent this limitation, the low-level BigQuery REST API client (google.golang.org/api/bigquery/v2) was added to gain access to these necessary details. --------- Co-authored-by: Averi Kitsch <akitsch@google.com>
This commit is contained in:
@@ -17,6 +17,7 @@ package bigquerysql
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
bigqueryapi "cloud.google.com/go/bigquery"
|
||||
@@ -24,6 +25,7 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||
"google.golang.org/api/iterator"
|
||||
)
|
||||
|
||||
@@ -45,6 +47,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
BigQueryClient() *bigqueryapi.Client
|
||||
BigQueryRestService() *bigqueryrestapi.Service
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
@@ -101,6 +104,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
Statement: cfg.Statement,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Client: s.BigQueryClient(),
|
||||
RestService: s.BigQueryRestService(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -117,15 +121,17 @@ type Tool struct {
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
TemplateParameters tools.Parameters `yaml:"templateParameters"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Statement string
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
namedArgs := make([]bigqueryapi.QueryParameter, 0, len(params))
|
||||
highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(t.Parameters))
|
||||
lowLevelParams := make([]*bigqueryrestapi.QueryParameter, 0, len(t.Parameters))
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
@@ -136,14 +142,11 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error)
|
||||
name := p.GetName()
|
||||
value := paramsMap[name]
|
||||
|
||||
// BigQuery's QueryParameter only accepts typed slices as input
|
||||
// This checks if the param is an array.
|
||||
// If yes, convert []any to typed slice (e.g []string, []int)
|
||||
switch arrayParam := p.(type) {
|
||||
case *tools.ArrayParameter:
|
||||
// This block for converting []any to typed slices is still necessary and correct.
|
||||
if arrayParam, ok := p.(*tools.ArrayParameter); ok {
|
||||
arrayParamValue, ok := value.([]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to convert parameter `%s` to []any %w", name, err)
|
||||
return nil, fmt.Errorf("unable to convert parameter `%s` to []any", name)
|
||||
}
|
||||
itemType := arrayParam.GetItems().GetType()
|
||||
var err error
|
||||
@@ -153,22 +156,69 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error)
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(t.Statement, "@"+name) {
|
||||
namedArgs = append(namedArgs, bigqueryapi.QueryParameter{
|
||||
Name: name,
|
||||
Value: value,
|
||||
})
|
||||
} else {
|
||||
namedArgs = append(namedArgs, bigqueryapi.QueryParameter{
|
||||
Value: value,
|
||||
})
|
||||
// Determine if the parameter is named or positional for the high-level client.
|
||||
var paramNameForHighLevel string
|
||||
if strings.Contains(newStatement, "@"+name) {
|
||||
paramNameForHighLevel = name
|
||||
}
|
||||
|
||||
// 1. Create the high-level parameter for the final query execution.
|
||||
highLevelParams = append(highLevelParams, bigqueryapi.QueryParameter{
|
||||
Name: paramNameForHighLevel,
|
||||
Value: value,
|
||||
})
|
||||
|
||||
// 2. Create the low-level parameter for the dry run, using the defined type from `p`.
|
||||
lowLevelParam := &bigqueryrestapi.QueryParameter{
|
||||
Name: paramNameForHighLevel,
|
||||
ParameterType: &bigqueryrestapi.QueryParameterType{},
|
||||
ParameterValue: &bigqueryrestapi.QueryParameterValue{},
|
||||
}
|
||||
|
||||
if arrayParam, ok := p.(*tools.ArrayParameter); ok {
|
||||
// Handle array types based on their defined item type.
|
||||
lowLevelParam.ParameterType.Type = "ARRAY"
|
||||
itemType, err := BQTypeStringFromToolType(arrayParam.GetItems().GetType())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lowLevelParam.ParameterType.ArrayType = &bigqueryrestapi.QueryParameterType{Type: itemType}
|
||||
|
||||
// Build the array values.
|
||||
sliceVal := reflect.ValueOf(value)
|
||||
arrayValues := make([]*bigqueryrestapi.QueryParameterValue, sliceVal.Len())
|
||||
for i := 0; i < sliceVal.Len(); i++ {
|
||||
arrayValues[i] = &bigqueryrestapi.QueryParameterValue{
|
||||
Value: fmt.Sprintf("%v", sliceVal.Index(i).Interface()),
|
||||
}
|
||||
}
|
||||
lowLevelParam.ParameterValue.ArrayValues = arrayValues
|
||||
} else {
|
||||
// Handle scalar types based on their defined type.
|
||||
bqType, err := BQTypeStringFromToolType(p.GetType())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lowLevelParam.ParameterType.Type = bqType
|
||||
lowLevelParam.ParameterValue.Value = fmt.Sprintf("%v", value)
|
||||
}
|
||||
lowLevelParams = append(lowLevelParams, lowLevelParam)
|
||||
}
|
||||
|
||||
query := t.Client.Query(newStatement)
|
||||
query.Parameters = namedArgs
|
||||
query.Parameters = highLevelParams
|
||||
query.Location = t.Client.Location
|
||||
|
||||
dryRunJob, err := dryRunQuery(ctx, t.RestService, t.Client.Project(), t.Client.Location, newStatement, lowLevelParams, query.ConnectionProperties)
|
||||
if err != nil {
|
||||
// This is a fallback check in case the switch logic was bypassed.
|
||||
return nil, fmt.Errorf("final query validation failed: %w", err)
|
||||
}
|
||||
statementType := dryRunJob.Statistics.Query.StatementType
|
||||
|
||||
// This block handles SELECT statements, which return a row set.
|
||||
// We iterate through the results, convert each row into a map of
|
||||
// column names to values, and return the collection of rows.
|
||||
it, err := query.Read(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
@@ -177,7 +227,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error)
|
||||
var out []any
|
||||
for {
|
||||
var row map[string]bigqueryapi.Value
|
||||
err := it.Next(&row)
|
||||
err = it.Next(&row)
|
||||
if err == iterator.Done {
|
||||
break
|
||||
}
|
||||
@@ -190,8 +240,21 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error)
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
// If the query returned any rows, return them directly.
|
||||
if len(out) > 0 {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
return out, nil
|
||||
// This handles the standard case for a SELECT query that successfully
|
||||
// executes but returns zero rows.
|
||||
if statementType == "SELECT" {
|
||||
return "The query returned 0 rows.", nil
|
||||
}
|
||||
// This is the fallback for a successful query that doesn't return content.
|
||||
// In most cases, this will be for DML/DDL statements like INSERT, UPDATE, CREATE, etc.
|
||||
// However, it is also possible that this was a query that was expected to return rows
|
||||
// but returned none, a case that we cannot distinguish here.
|
||||
return "Query executed successfully and returned no content.", nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
@@ -209,3 +272,58 @@ func (t Tool) McpManifest() tools.McpManifest {
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func BQTypeStringFromToolType(toolType string) (string, error) {
|
||||
switch toolType {
|
||||
case "string":
|
||||
return "STRING", nil
|
||||
case "integer":
|
||||
return "INT64", nil
|
||||
case "float":
|
||||
return "FLOAT64", nil
|
||||
case "boolean":
|
||||
return "BOOL", nil
|
||||
// Note: 'array' is handled separately as it has a nested item type.
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported tool parameter type for BigQuery: %s", toolType)
|
||||
}
|
||||
}
|
||||
|
||||
func dryRunQuery(
|
||||
ctx context.Context,
|
||||
restService *bigqueryrestapi.Service,
|
||||
projectID string,
|
||||
location string,
|
||||
sql string,
|
||||
params []*bigqueryrestapi.QueryParameter,
|
||||
connProps []*bigqueryapi.ConnectionProperty,
|
||||
) (*bigqueryrestapi.Job, error) {
|
||||
useLegacySql := false
|
||||
|
||||
restConnProps := make([]*bigqueryrestapi.ConnectionProperty, len(connProps))
|
||||
for i, prop := range connProps {
|
||||
restConnProps[i] = &bigqueryrestapi.ConnectionProperty{Key: prop.Key, Value: prop.Value}
|
||||
}
|
||||
|
||||
jobToInsert := &bigqueryrestapi.Job{
|
||||
JobReference: &bigqueryrestapi.JobReference{
|
||||
ProjectId: projectID,
|
||||
Location: location,
|
||||
},
|
||||
Configuration: &bigqueryrestapi.JobConfiguration{
|
||||
DryRun: true,
|
||||
Query: &bigqueryrestapi.JobConfigurationQuery{
|
||||
Query: sql,
|
||||
UseLegacySql: &useLegacySql,
|
||||
ConnectionProperties: restConnProps,
|
||||
QueryParameters: params,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
insertResponse, err := restService.Jobs.Insert(projectID, jobToInsert).Context(ctx).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to insert dry run job: %w", err)
|
||||
}
|
||||
return insertResponse, nil
|
||||
}
|
||||
|
||||
@@ -72,7 +72,7 @@ func initBigQueryConnection(project string) (*bigqueryapi.Client, error) {
|
||||
|
||||
func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
sourceConfig := getBigQueryVars(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var args []string
|
||||
@@ -100,6 +100,11 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
datasetName,
|
||||
strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
)
|
||||
tableNameDataType := fmt.Sprintf("`%s.%s.datatype_table_%s`",
|
||||
BigqueryProject,
|
||||
datasetName,
|
||||
strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
)
|
||||
|
||||
// set up data for param tool
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := getBigQueryParamToolInfo(tableNameParam)
|
||||
@@ -111,8 +116,14 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
teardownTable2 := setupBigQueryTable(t, ctx, client, createAuthTableStmt, insertAuthTableStmt, datasetName, tableNameAuth, authTestParams)
|
||||
defer teardownTable2(t)
|
||||
|
||||
// set up data for data type test tool
|
||||
createDataTypeTableStmt, insertDataTypeTableStmt, dataTypeToolStmt, arrayDataTypeToolStmt, dataTypeTestParams := getBigQueryDataTypeTestInfo(tableNameDataType)
|
||||
teardownTable3 := setupBigQueryTable(t, ctx, client, createDataTypeTableStmt, insertDataTypeTableStmt, datasetName, tableNameDataType, dataTypeTestParams)
|
||||
defer teardownTable3(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, BigqueryToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
||||
toolsFile = addBigQuerySqlToolConfig(t, toolsFile, dataTypeToolStmt, arrayDataTypeToolStmt)
|
||||
toolsFile = addBigQueryPrebuiltToolsConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := getBigQueryTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, BigqueryToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
@@ -135,18 +146,23 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
|
||||
select1Want := "[{\"f0_\":1}]"
|
||||
// Partial message; the full error message is too long.
|
||||
failInvocationWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: googleapi: Error 400: Syntax error: Unexpected identifier \"SELEC\" at [1:1]`
|
||||
failInvocationWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"final query validation failed: failed to insert dry run job: googleapi: Error 400: Syntax error: Unexpected identifier \"SELEC\" at [1:1]`
|
||||
datasetInfoWant := "\"Location\":\"US\",\"DefaultTableExpiration\":0,\"Labels\":null,\"Access\":"
|
||||
tableInfoWant := "{\"Name\":\"\",\"Location\":\"US\",\"Description\":\"\",\"Schema\":[{\"Name\":\"id\""
|
||||
ddlWant := `"Query executed successfully and returned no content."`
|
||||
invokeParamWant, invokeIdNullWant, nullWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeIdNullWant, nullWant, false, true)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
templateParamTestConfig := tests.NewTemplateParameterTestConfig(
|
||||
tests.WithCreateColArray(`["id INT64", "name STRING", "age INT64"]`),
|
||||
tests.WithDdlWant(ddlWant),
|
||||
tests.WithSelectEmptyWant(`"The query returned 0 rows."`),
|
||||
tests.WithInsert1Want(ddlWant),
|
||||
)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, templateParamTestConfig)
|
||||
|
||||
runBigQueryExecuteSqlToolInvokeTest(t, select1Want, invokeParamWant, tableNameParam)
|
||||
runBigQueryDataTypeTests(t)
|
||||
runBigQueryListDatasetToolInvokeTest(t, datasetName)
|
||||
runBigQueryGetDatasetInfoToolInvokeTest(t, datasetName, datasetInfoWant)
|
||||
runBigQueryListTableIdsToolInvokeTest(t, datasetName, tableName)
|
||||
@@ -187,6 +203,22 @@ func getBigQueryAuthToolInfo(tableName string) (string, string, string, []bigque
|
||||
return createStatement, insertStatement, toolStatement, params
|
||||
}
|
||||
|
||||
// getBigQueryDataTypeTestInfo returns statements and params for data type tests.
|
||||
func getBigQueryDataTypeTestInfo(tableName string) (string, string, string, string, []bigqueryapi.QueryParameter) {
|
||||
createStatement := fmt.Sprintf(`
|
||||
CREATE TABLE IF NOT EXISTS %s (id INT64, int_val INT64, string_val STRING, float_val FLOAT64, bool_val BOOL);`, tableName)
|
||||
insertStatement := fmt.Sprintf(`
|
||||
INSERT INTO %s (id, int_val, string_val, float_val, bool_val) VALUES (?, ?, ?, ?, ?), (?, ?, ?, ?, ?), (?, ?, ?, ?, ?);`, tableName)
|
||||
toolStatement := fmt.Sprintf(`SELECT * FROM %s WHERE int_val = ? AND string_val = ? AND float_val = ? AND bool_val = ?;`, tableName)
|
||||
arrayToolStatement := fmt.Sprintf(`SELECT * FROM %s WHERE int_val IN UNNEST(@int_array) AND string_val IN UNNEST(@string_array) AND float_val IN UNNEST(@float_array) AND bool_val IN UNNEST(@bool_array) ORDER BY id;`, tableName)
|
||||
params := []bigqueryapi.QueryParameter{
|
||||
{Value: int64(1)}, {Value: int64(123)}, {Value: "hello"}, {Value: 3.14}, {Value: true},
|
||||
{Value: int64(2)}, {Value: int64(-456)}, {Value: "world"}, {Value: -0.55}, {Value: false},
|
||||
{Value: int64(3)}, {Value: int64(789)}, {Value: "test"}, {Value: 100.1}, {Value: true},
|
||||
}
|
||||
return createStatement, insertStatement, toolStatement, arrayToolStatement, params
|
||||
}
|
||||
|
||||
// getBigQueryTmplToolStatement returns statements for template parameter test cases for bigquery kind
|
||||
func getBigQueryTmplToolStatement() (string, string) {
|
||||
tmplSelectCombined := "SELECT * FROM {{.tableName}} WHERE id = ? ORDER BY id"
|
||||
@@ -345,6 +377,40 @@ func addBigQueryPrebuiltToolsConfig(t *testing.T, config map[string]any) map[str
|
||||
return config
|
||||
}
|
||||
|
||||
func addBigQuerySqlToolConfig(t *testing.T, config map[string]any, toolStatement, arrayToolStatement string) map[string]any {
|
||||
tools, ok := config["tools"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("unable to get tools from config")
|
||||
}
|
||||
tools["my-scalar-datatype-tool"] = map[string]any{
|
||||
"kind": "bigquery-sql",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test various scalar data types.",
|
||||
"statement": toolStatement,
|
||||
"parameters": []any{
|
||||
map[string]any{"name": "int_val", "type": "integer", "description": "an integer value"},
|
||||
map[string]any{"name": "string_val", "type": "string", "description": "a string value"},
|
||||
map[string]any{"name": "float_val", "type": "float", "description": "a float value"},
|
||||
map[string]any{"name": "bool_val", "type": "boolean", "description": "a boolean value"},
|
||||
},
|
||||
}
|
||||
tools["my-array-datatype-tool"] = map[string]any{
|
||||
"kind": "bigquery-sql",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test various array data types.",
|
||||
"statement": arrayToolStatement,
|
||||
"parameters": []any{
|
||||
map[string]any{"name": "int_array", "type": "array", "description": "an array of integer values", "items": map[string]any{"name": "item", "type": "integer", "description": "desc"}},
|
||||
map[string]any{"name": "string_array", "type": "array", "description": "an array of string values", "items": map[string]any{"name": "item", "type": "string", "description": "desc"}},
|
||||
map[string]any{"name": "float_array", "type": "array", "description": "an array of float values", "items": map[string]any{"name": "item", "type": "float", "description": "desc"}},
|
||||
map[string]any{"name": "bool_array", "type": "array", "description": "an array of boolean values", "items": map[string]any{"name": "item", "type": "boolean", "description": "desc"}},
|
||||
},
|
||||
}
|
||||
|
||||
config["tools"] = tools
|
||||
return config
|
||||
}
|
||||
|
||||
func runBigQueryExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamWant, tableNameParam string) {
|
||||
// Get ID token
|
||||
idToken, err := tests.GetGoogleIdToken(tests.ClientId)
|
||||
@@ -490,6 +556,84 @@ func runBigQueryExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamW
|
||||
}
|
||||
}
|
||||
|
||||
func runBigQueryDataTypeTests(t *testing.T) {
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestHeader map[string]string
|
||||
requestBody io.Reader
|
||||
want string
|
||||
isErr bool
|
||||
}{
|
||||
{
|
||||
name: "invoke my-scalar-datatype-tool with values",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-scalar-datatype-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"int_val": 123, "string_val": "hello", "float_val": 3.14, "bool_val": true}`)),
|
||||
want: `[{"bool_val":true,"float_val":3.14,"id":1,"int_val":123,"string_val":"hello"}]`,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke my-scalar-datatype-tool with missing params",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-scalar-datatype-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"int_val": 123}`)),
|
||||
isErr: true,
|
||||
},
|
||||
{
|
||||
name: "invoke my-array-datatype-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-array-datatype-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"int_array": [123, 789], "string_array": ["hello", "test"], "float_array": [3.14, 100.1], "bool_array": [true]}`)),
|
||||
want: `[{"bool_val":true,"float_val":3.14,"id":1,"int_val":123,"string_val":"hello"},{"bool_val":true,"float_val":100.1,"id":3,"int_val":789,"string_val":"test"}]`,
|
||||
isErr: false,
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Send Tool invocation request
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
for k, v := range tc.requestHeader {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if tc.isErr {
|
||||
return
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
// Check response body
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
|
||||
got, ok := body["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
|
||||
if got != tc.want {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runBigQueryListDatasetToolInvokeTest(t *testing.T, datasetWant string) {
|
||||
// Get ID token
|
||||
idToken, err := tests.GetGoogleIdToken(tests.ClientId)
|
||||
|
||||
@@ -405,14 +405,16 @@ func RunToolInvokeTest(t *testing.T, select1Want, invokeParamWant, invokeIdNullW
|
||||
|
||||
// TemplateParameterTestConfig represents the various configuration options for template parameter tests.
|
||||
type TemplateParameterTestConfig struct {
|
||||
ignoreDdl bool
|
||||
ignoreInsert bool
|
||||
selectAllWant string
|
||||
select1Want string
|
||||
nameFieldArray string
|
||||
nameColFilter string
|
||||
createColArray string
|
||||
insert1Want string
|
||||
ignoreDdl bool
|
||||
ignoreInsert bool
|
||||
ddlWant string
|
||||
selectAllWant string
|
||||
select1Want string
|
||||
selectEmptyWant string
|
||||
nameFieldArray string
|
||||
nameColFilter string
|
||||
createColArray string
|
||||
insert1Want string
|
||||
}
|
||||
|
||||
type Option func(*TemplateParameterTestConfig)
|
||||
@@ -431,6 +433,13 @@ func WithIgnoreInsert() Option {
|
||||
}
|
||||
}
|
||||
|
||||
// WithDdlWant is the option function to configure ddlWant.
|
||||
func WithDdlWant(s string) Option {
|
||||
return func(c *TemplateParameterTestConfig) {
|
||||
c.ddlWant = s
|
||||
}
|
||||
}
|
||||
|
||||
// WithSelectAllWant is the option function to configure selectAllWant.
|
||||
func WithSelectAllWant(s string) Option {
|
||||
return func(c *TemplateParameterTestConfig) {
|
||||
@@ -445,6 +454,13 @@ func WithSelect1Want(s string) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// WithSelectEmptyWant is the option function to configure selectEmptyWant.
|
||||
func WithSelectEmptyWant(s string) Option {
|
||||
return func(c *TemplateParameterTestConfig) {
|
||||
c.selectEmptyWant = s
|
||||
}
|
||||
}
|
||||
|
||||
// WithReplaceNameFieldArray is the option function to configure replaceNameFieldArray.
|
||||
func WithReplaceNameFieldArray(s string) Option {
|
||||
return func(c *TemplateParameterTestConfig) {
|
||||
@@ -475,14 +491,16 @@ func WithInsert1Want(s string) Option {
|
||||
// NewTemplateParameterTestConfig creates a new TemplateParameterTestConfig instances with options.
|
||||
func NewTemplateParameterTestConfig(options ...Option) *TemplateParameterTestConfig {
|
||||
templateParamTestOption := &TemplateParameterTestConfig{
|
||||
ignoreDdl: false,
|
||||
ignoreInsert: false,
|
||||
selectAllWant: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"},{\"age\":100,\"id\":2,\"name\":\"Alice\"}]",
|
||||
select1Want: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"}]",
|
||||
nameFieldArray: `["name"]`,
|
||||
nameColFilter: "name",
|
||||
createColArray: `["id INT","name VARCHAR(20)","age INT"]`,
|
||||
insert1Want: "null",
|
||||
ignoreDdl: false,
|
||||
ignoreInsert: false,
|
||||
ddlWant: "null",
|
||||
selectAllWant: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"},{\"age\":100,\"id\":2,\"name\":\"Alice\"}]",
|
||||
select1Want: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"}]",
|
||||
selectEmptyWant: "null",
|
||||
nameFieldArray: `["name"]`,
|
||||
nameColFilter: "name",
|
||||
createColArray: `["id INT","name VARCHAR(20)","age INT"]`,
|
||||
insert1Want: "null",
|
||||
}
|
||||
|
||||
// Apply provided options
|
||||
@@ -514,7 +532,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, config
|
||||
api: "http://127.0.0.1:5000/api/tool/create-table-templateParams-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "columns":%s}`, tableName, config.createColArray))),
|
||||
want: "null",
|
||||
want: config.ddlWant,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
@@ -551,6 +569,14 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, config
|
||||
want: config.select1Want,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke select-templateParams-combined-tool with no results",
|
||||
api: "http://127.0.0.1:5000/api/tool/select-templateParams-combined-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"id": 999, "tableName": "%s"}`, tableName))),
|
||||
want: config.selectEmptyWant,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke select-fields-templateParams-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/select-fields-templateParams-tool/invoke",
|
||||
@@ -573,7 +599,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, config
|
||||
api: "http://127.0.0.1:5000/api/tool/drop-table-templateParams-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s"}`, tableName))),
|
||||
want: "null",
|
||||
want: config.ddlWant,
|
||||
isErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user