feat(tools/postgressql): add templateParameters field (#615)

Add new tool field, templateParameters, to support non-filter parameters 
and DDL statements

Fix #535 for postgressql tool.

---------

Co-authored-by: Yuan <45984206+Yuan325@users.noreply.github.com>
This commit is contained in:
AlexTalreja
2025-06-04 20:28:37 +00:00
committed by GitHub
parent 1830702fd8
commit b76346993f
7 changed files with 721 additions and 24 deletions

View File

@@ -15,10 +15,12 @@
package tools
import (
"bytes"
"context"
"encoding/json"
"fmt"
"strings"
"text/template"
"github.com/googleapis/genai-toolbox/internal/util"
)
@@ -146,6 +148,62 @@ func ParseParams(ps Parameters, data map[string]any, claimsMap map[string]map[st
return params, nil
}
// helper function to convert a string array parameter to a comma separated string
func ConvertArrayParamToString(param any) (string, error) {
switch v := param.(type) {
case []any:
var stringValues []string
for _, item := range v {
stringVal, ok := item.(string)
if !ok {
return "", fmt.Errorf("templateParameter only supports string arrays")
}
stringValues = append(stringValues, stringVal)
}
return strings.Join(stringValues, ", "), nil
default:
return "", fmt.Errorf("invalid parameter type, expected array of type string")
}
}
// GetParams return the ParamValues that are associated with the Parameters.
func GetParams(params Parameters, paramValuesMap map[string]any) (ParamValues, error) {
resultParamValues := make(ParamValues, 0)
for _, p := range params {
k := p.GetName()
v, ok := paramValuesMap[k]
if !ok {
return nil, fmt.Errorf("missing parameter %s", k)
}
resultParamValues = append(resultParamValues, ParamValue{Name: k, Value: v})
}
return resultParamValues, nil
}
func ResolveTemplateParams(templateParams Parameters, originalStatement string, paramsMap map[string]any) (string, error) {
templateParamsValues, err := GetParams(templateParams, paramsMap)
templateParamsMap := templateParamsValues.AsMap()
if err != nil {
return "", fmt.Errorf("error getting template params %s", err)
}
funcMap := template.FuncMap{
"array": ConvertArrayParamToString,
}
t, err := template.New("statement").Funcs(funcMap).Parse(originalStatement)
if err != nil {
return "", fmt.Errorf("error creating go template %s", err)
}
var result bytes.Buffer
err = t.Execute(&result, templateParamsMap)
if err != nil {
return "", fmt.Errorf("error executing go template %s", err)
}
modifiedStatement := result.String()
return modifiedStatement, nil
}
type Parameter interface {
// Note: It's typically not idiomatic to include "Get" in the function name,
// but this is done to differentiate it from the fields in CommonParameter.

View File

@@ -980,3 +980,310 @@ func TestFailParametersUnmarshal(t *testing.T) {
})
}
}
func TestConvertArrayParamToString(t *testing.T) {
tcs := []struct {
name string
in []any
want string
}{
{
in: []any{
"id",
"name",
"location",
},
want: "id, name, location",
},
{
in: []any{
"id",
},
want: "id",
},
{
in: []any{
"id",
"5",
"false",
},
want: "id, 5, false",
},
{
in: []any{},
want: "",
},
{
in: []any{},
want: "",
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
got, _ := tools.ConvertArrayParamToString(tc.in)
if diff := cmp.Diff(tc.want, got); diff != "" {
t.Fatalf("incorrect array param conversion: diff %v", diff)
}
})
}
}
func TestFailConvertArrayParamToString(t *testing.T) {
tcs := []struct {
name string
in []any
err string
}{
{
in: []any{5, 10, 15},
err: "templateParameter only supports string arrays",
},
{
in: []any{"id", "name", 15},
err: "templateParameter only supports string arrays",
},
{
in: []any{false},
err: "templateParameter only supports string arrays",
},
{
in: []any{10, true},
err: "templateParameter only supports string arrays",
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
_, err := tools.ConvertArrayParamToString(tc.in)
errStr := err.Error()
if errStr != tc.err {
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
}
})
}
}
func TestGetParams(t *testing.T) {
tcs := []struct {
name string
in map[string]any
params tools.Parameters
want tools.ParamValues
}{
{
name: "parameters to include and exclude",
params: tools.Parameters{
tools.NewStringParameter("my_string_inc", "this should be included"),
tools.NewStringParameter("my_string_inc2", "this should be included"),
},
in: map[string]any{
"my_string_inc": "hello world A",
"my_string_inc2": "hello world B",
"my_string_exc": "hello world C",
},
want: tools.ParamValues{
tools.ParamValue{Name: "my_string_inc", Value: "hello world A"},
tools.ParamValue{Name: "my_string_inc2", Value: "hello world B"},
},
},
{
name: "include all",
params: tools.Parameters{
tools.NewStringParameter("my_string_inc", "this should be included"),
},
in: map[string]any{
"my_string_inc": "hello world A",
},
want: tools.ParamValues{
tools.ParamValue{Name: "my_string_inc", Value: "hello world A"},
},
},
{
name: "exclude all",
params: tools.Parameters{},
in: map[string]any{
"my_string_exc": "hello world A",
"my_string_exc2": "hello world B",
},
want: tools.ParamValues{},
},
{
name: "empty",
params: tools.Parameters{},
in: map[string]any{},
want: tools.ParamValues{},
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
got, _ := tools.GetParams(tc.params, tc.in)
if diff := cmp.Diff(tc.want, got); diff != "" {
t.Fatalf("incorrect get params: diff %v", diff)
}
})
}
}
func TestFailGetParams(t *testing.T) {
tcs := []struct {
name string
params tools.Parameters
in map[string]any
err string
}{
{
name: "missing the only parameter",
params: tools.Parameters{tools.NewStringParameter("my_string", "this was missing")},
in: map[string]any{},
err: "missing parameter my_string",
},
{
name: "missing one parameter of multiple",
params: tools.Parameters{
tools.NewStringParameter("my_string_inc", "this should be included"),
tools.NewStringParameter("my_string_exc", "this was missing"),
},
in: map[string]any{
"my_string_inc": "hello world A",
},
err: "missing parameter my_string_exc",
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
_, err := tools.GetParams(tc.params, tc.in)
errStr := err.Error()
if errStr != tc.err {
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
}
})
}
}
func TestResolveTemplateParameters(t *testing.T) {
tcs := []struct {
name string
templateParams tools.Parameters
statement string
in map[string]any
want string
}{
{
name: "single template parameter",
templateParams: tools.Parameters{
tools.NewStringParameter("tableName", "this is a string template parameter"),
},
statement: "SELECT * FROM {{.tableName}}",
in: map[string]any{
"tableName": "hotels",
},
want: "SELECT * FROM hotels",
},
{
name: "multiple template parameters",
templateParams: tools.Parameters{
tools.NewStringParameter("tableName", "this is a string template parameter"),
tools.NewStringParameter("columnName", "this is a string template parameter"),
},
statement: "SELECT * FROM {{.tableName}} WHERE {{.columnName}} = 'Hilton'",
in: map[string]any{
"tableName": "hotels",
"columnName": "name",
},
want: "SELECT * FROM hotels WHERE name = 'Hilton'",
},
{
name: "standard and template parameter",
templateParams: tools.Parameters{
tools.NewStringParameter("tableName", "this is a string template parameter"),
tools.NewStringParameter("hotelName", "this is a string parameter"),
},
statement: "SELECT * FROM {{.tableName}} WHERE name = $1",
in: map[string]any{
"tableName": "hotels",
"hotelName": "name",
},
want: "SELECT * FROM hotels WHERE name = $1",
},
{
name: "standard parameter",
templateParams: tools.Parameters{
tools.NewStringParameter("hotelName", "this is a string parameter"),
},
statement: "SELECT * FROM hotels WHERE name = $1",
in: map[string]any{
"hotelName": "hotels",
},
want: "SELECT * FROM hotels WHERE name = $1",
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
got, _ := tools.ResolveTemplateParams(tc.templateParams, tc.statement, tc.in)
if diff := cmp.Diff(tc.want, got); diff != "" {
t.Fatalf("incorrect resolved template params: diff %v", diff)
}
})
}
}
func TestFailResolveTemplateParameters(t *testing.T) {
tcs := []struct {
name string
templateParams tools.Parameters
statement string
in map[string]any
err string
}{
{
name: "wrong param name",
templateParams: tools.Parameters{
tools.NewStringParameter("tableName", "this is a string template parameter"),
},
statement: "SELECT * FROM {{.missingParam}}",
in: map[string]any{},
err: "error getting template params missing parameter tableName",
},
{
name: "incomplete param template",
templateParams: tools.Parameters{
tools.NewStringParameter("tableName", "this is a string template parameter"),
},
statement: "SELECT * FROM {{.tableName",
in: map[string]any{
"tableName": "hotels",
},
err: "error creating go template template: statement:1: unclosed action",
},
{
name: "undefined function",
templateParams: tools.Parameters{
tools.NewStringParameter("tableName", "this is a string template parameter"),
},
statement: "SELECT * FROM {{json .tableName}}",
in: map[string]any{
"tableName": "hotels",
},
err: "error creating go template template: statement:1: function \"json\" not defined",
},
{
name: "undefined method",
templateParams: tools.Parameters{
tools.NewStringParameter("tableName", "this is a string template parameter"),
},
statement: "SELECT * FROM {{.tableName .wrong}}",
in: map[string]any{
"tableName": "hotels",
},
err: "error executing go template template: statement:1:16: executing \"statement\" at <.tableName>: tableName is not a method but has arguments",
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
_, err := tools.ResolveTemplateParams(tc.templateParams, tc.statement, tc.in)
errStr := err.Error()
if errStr != tc.err {
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
}
})
}
}

View File

@@ -17,6 +17,7 @@ package postgressql
import (
"context"
"fmt"
"slices"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -55,13 +56,14 @@ var _ compatibleSource = &postgres.Source{}
var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Source string `yaml:"source" validate:"required"`
Description string `yaml:"description" validate:"required"`
Statement string `yaml:"statement" validate:"required"`
AuthRequired []string `yaml:"authRequired"`
Parameters tools.Parameters `yaml:"parameters"`
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Source string `yaml:"source" validate:"required"`
Description string `yaml:"description" validate:"required"`
Statement string `yaml:"statement" validate:"required"`
AuthRequired []string `yaml:"authRequired"`
Parameters tools.Parameters `yaml:"parameters"`
TemplateParameters tools.Parameters `yaml:"templateParameters"`
}
// validate interface
@@ -84,22 +86,62 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
allParameters := slices.Concat(cfg.Parameters, cfg.TemplateParameters)
paramManifest := slices.Concat(
cfg.Parameters.Manifest(),
cfg.TemplateParameters.Manifest(),
)
if paramManifest == nil {
paramManifest = make([]tools.ParameterManifest, 0)
}
parametersMcpManifest := cfg.Parameters.McpManifest()
templateParametersMcpManifest := cfg.TemplateParameters.McpManifest()
// Concatenate parameters for MCP `required` field
concatRequiredManifest := slices.Concat(
parametersMcpManifest.Required,
templateParametersMcpManifest.Required,
)
if concatRequiredManifest == nil {
concatRequiredManifest = []string{}
}
// Concatenate parameters for MCP `properties` field
concatPropertiesManifest := make(map[string]tools.ParameterMcpManifest)
for name, p := range parametersMcpManifest.Properties {
concatPropertiesManifest[name] = p
}
for name, p := range templateParametersMcpManifest.Properties {
concatPropertiesManifest[name] = p
}
// Create a new McpToolsSchema with all parameters
paramMcpManifest := tools.McpToolsSchema{
Type: "object",
Properties: concatPropertiesManifest,
Required: concatRequiredManifest,
}
mcpManifest := tools.McpManifest{
Name: cfg.Name,
Description: cfg.Description,
InputSchema: cfg.Parameters.McpManifest(),
InputSchema: paramMcpManifest,
}
// finish tool setup
t := Tool{
Name: cfg.Name,
Kind: kind,
Parameters: cfg.Parameters,
Statement: cfg.Statement,
AuthRequired: cfg.AuthRequired,
Pool: s.PostgresPool(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Name: cfg.Name,
Kind: kind,
Parameters: cfg.Parameters,
TemplateParameters: cfg.TemplateParameters,
AllParams: allParameters,
Statement: cfg.Statement,
AuthRequired: cfg.AuthRequired,
Pool: s.PostgresPool(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -108,10 +150,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
var _ tools.Tool = Tool{}
type Tool struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
AuthRequired []string `yaml:"authRequired"`
Parameters tools.Parameters `yaml:"parameters"`
Name string `yaml:"name"`
Kind string `yaml:"kind"`
AuthRequired []string `yaml:"authRequired"`
Parameters tools.Parameters `yaml:"parameters"`
TemplateParameters tools.Parameters `yaml:"templateParameters"`
AllParams tools.Parameters `yaml:"allParams"`
Pool *pgxpool.Pool
Statement string
@@ -120,8 +164,18 @@ type Tool struct {
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
sliceParams := params.AsSlice()
results, err := t.Pool.Query(ctx, t.Statement, sliceParams...)
paramsMap := params.AsMap()
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
if err != nil {
return nil, fmt.Errorf("unable to extract template params %w", err)
}
newParams, err := tools.GetParams(t.Parameters, paramsMap)
if err != nil {
return nil, fmt.Errorf("unable to extract standard params %w", err)
}
sliceParams := newParams.AsSlice()
results, err := t.Pool.Query(ctx, newStatement, sliceParams...)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
@@ -145,7 +199,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, erro
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
return tools.ParseParams(t.Parameters, data, claims)
return tools.ParseParams(t.AllParams, data, claims)
}
func (t Tool) Manifest() tools.Manifest {

View File

@@ -92,3 +92,76 @@ func TestParseFromYamlPostgres(t *testing.T) {
}
}
func TestParseFromYamlWithTemplateParamsPostgres(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
tcs := []struct {
desc string
in string
want server.ToolConfigs
}{
{
desc: "basic example",
in: `
tools:
example_tool:
kind: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: name
type: string
description: some description
templateParameters:
- name: tableName
type: string
description: The table to select hotels from.
- name: fieldArray
type: array
description: The columns to return for the query.
items:
name: column
type: string
description: A column name that will be returned from the query.
`,
want: server.ToolConfigs{
"example_tool": postgressql.Config{
Name: "example_tool",
Kind: "postgres-sql",
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
AuthRequired: []string{},
Parameters: []tools.Parameter{
tools.NewStringParameter("name", "some description"),
},
TemplateParameters: []tools.Parameter{
tools.NewStringParameter("tableName", "The table to select hotels from."),
tools.NewArrayParameter("fieldArray", "The columns to return for the query.", tools.NewStringParameter("column", "A column name that will be returned from the query.")),
},
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Tools server.ToolConfigs `yaml:"tools"`
}{}
// Parse contents
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
t.Fatalf("incorrect parse: diff %v", diff)
}
})
}
}

View File

@@ -131,6 +131,7 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
// create table name with UUID
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
// set up data for param tool
create_statement1, insert_statement1, tool_statement1, params1 := tests.GetPostgresSQLParamToolInfo(tableNameParam)
@@ -145,7 +146,7 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
// Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, ALLOYDB_POSTGRES_TOOL_KIND, tool_statement1, tool_statement2)
toolsFile = tests.AddPgExecuteSqlConfig(t, toolsFile)
toolsFile = tests.AddTemplateParamConfig(t, toolsFile)
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {
t.Fatalf("command initialization returned an error: %s", err)
@@ -167,6 +168,7 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam)
}
// Test connection with different IP type

View File

@@ -23,6 +23,7 @@ import (
"fmt"
"testing"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/jackc/pgx/v5/pgxpool"
)
@@ -128,6 +129,85 @@ func AddPgExecuteSqlConfig(t *testing.T, config map[string]any) map[string]any {
return config
}
func AddTemplateParamConfig(t *testing.T, config map[string]any) map[string]any {
toolsMap, ok := config["tools"].(map[string]any)
if !ok {
t.Fatalf("unable to get tools from config")
}
toolsMap["create-table-templateParams-tool"] = map[string]any{
"kind": "postgres-sql",
"source": "my-instance",
"description": "Create table tool with template parameters",
"statement": "CREATE TABLE {{.tableName}} ({{array .columns}})",
"templateParameters": []tools.Parameter{
tools.NewStringParameter("tableName", "some description"),
tools.NewArrayParameter("columns", "The columns to create", tools.NewStringParameter("column", "A column name that will be created")),
},
}
toolsMap["insert-table-templateParams-tool"] = map[string]any{
"kind": "postgres-sql",
"source": "my-instance",
"description": "Insert tool with template parameters",
"statement": "INSERT INTO {{.tableName}} ({{array .columns}}) VALUES ({{.values}})",
"templateParameters": []tools.Parameter{
tools.NewStringParameter("tableName", "some description"),
tools.NewArrayParameter("columns", "The columns to insert into", tools.NewStringParameter("column", "A column name that will be returned from the query.")),
tools.NewStringParameter("values", "The values to insert as a comma separated string"),
},
}
toolsMap["select-templateParams-tool"] = map[string]any{
"kind": "postgres-sql",
"source": "my-instance",
"description": "Create table tool with template parameters",
"statement": "SELECT * FROM {{.tableName}}",
"templateParameters": []tools.Parameter{
tools.NewStringParameter("tableName", "some description"),
},
}
toolsMap["select-templateParams-combined-tool"] = map[string]any{
"kind": "postgres-sql",
"source": "my-instance",
"description": "Create table tool with template parameters",
"statement": "SELECT * FROM {{.tableName}} WHERE id = $1",
"parameters": []tools.Parameter{tools.NewStringParameter("id", "the id of the user")},
"templateParameters": []tools.Parameter{
tools.NewStringParameter("tableName", "some description"),
},
}
toolsMap["select-fields-templateParams-tool"] = map[string]any{
"kind": "postgres-sql",
"source": "my-instance",
"description": "Create table tool with template parameters",
"statement": "SELECT {{array .fields}} FROM {{.tableName}}",
"templateParameters": []tools.Parameter{
tools.NewStringParameter("tableName", "some description"),
tools.NewArrayParameter("fields", "The fields to select from", tools.NewStringParameter("field", "A field that will be returned from the query.")),
},
}
toolsMap["select-filter-templateParams-combined-tool"] = map[string]any{
"kind": "postgres-sql",
"source": "my-instance",
"description": "Create table tool with template parameters",
"statement": "SELECT * FROM {{.tableName}} WHERE {{.columnFilter}} = $1",
"parameters": []tools.Parameter{tools.NewStringParameter("name", "the name of the user")},
"templateParameters": []tools.Parameter{
tools.NewStringParameter("tableName", "some description"),
tools.NewStringParameter("columnFilter", "some description"),
},
}
toolsMap["drop-table-templateParams-tool"] = map[string]any{
"kind": "postgres-sql",
"source": "my-instance",
"description": "Drop table tool with template parameters",
"statement": "DROP TABLE IF EXISTS {{.tableName}}",
"templateParameters": []tools.Parameter{
tools.NewStringParameter("tableName", "some description"),
},
}
config["tools"] = toolsMap
return config
}
// AddMySqlExecuteSqlConfig gets the tools config for `mysql-execute-sql`
func AddMySqlExecuteSqlConfig(t *testing.T, config map[string]any) map[string]any {
tools, ok := config["tools"].(map[string]any)

View File

@@ -211,6 +211,129 @@ func RunToolInvokeTest(t *testing.T, select_1_want, invoke_param_want string) {
}
}
func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string) {
select_all_want := "[{\"age\":21,\"id\":1,\"name\":\"Alex\"},{\"age\":100,\"id\":2,\"name\":\"Alice\"}]"
select_only_1_want := "[{\"age\":21,\"id\":1,\"name\":\"Alex\"}]"
select_only_names_want := "[{\"name\":\"Alex\"},{\"name\":\"Alice\"}]"
// Test tool invoke endpoint
invokeTcs := []struct {
name string
api string
requestHeader map[string]string
requestBody io.Reader
want string
isErr bool
}{
{
name: "invoke create-table-templateParams-tool",
api: "http://127.0.0.1:5000/api/tool/create-table-templateParams-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "columns":["id INT","name VARCHAR(20)","age INT"]}`, tableName))),
want: "null",
isErr: false,
},
{
name: "invoke insert-table-templateParams-tool",
api: "http://127.0.0.1:5000/api/tool/insert-table-templateParams-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "columns":["id","name","age"], "values":"1, 'Alex', 21"}`, tableName))),
want: "null",
isErr: false,
},
{
name: "invoke insert-table-templateParams-tool",
api: "http://127.0.0.1:5000/api/tool/insert-table-templateParams-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "columns":["id","name","age"], "values":"2, 'Alice', 100"}`, tableName))),
want: "null",
isErr: false,
},
{
name: "invoke select-templateParams-tool",
api: "http://127.0.0.1:5000/api/tool/select-templateParams-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s"}`, tableName))),
want: select_all_want,
isErr: false,
},
{
name: "invoke select-templateParams-combined-tool",
api: "http://127.0.0.1:5000/api/tool/select-templateParams-combined-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"id": "1", "tableName": "%s"}`, tableName))),
want: select_only_1_want,
isErr: false,
},
{
name: "invoke select-fields-templateParams-tool",
api: "http://127.0.0.1:5000/api/tool/select-fields-templateParams-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "fields":["name"]}`, tableName))),
want: select_only_names_want,
isErr: false,
},
{
name: "invoke select-filter-templateParams-combined-tool",
api: "http://127.0.0.1:5000/api/tool/select-filter-templateParams-combined-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"name": "Alex", "tableName": "%s", "columnFilter": "name"}`, tableName))),
want: select_only_1_want,
isErr: false,
},
{
name: "invoke drop-table-templateParams-tool",
api: "http://127.0.0.1:5000/api/tool/drop-table-templateParams-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s"}`, tableName))),
want: "null",
isErr: false,
},
}
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
// Send Tool invocation request
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
if err != nil {
t.Fatalf("unable to create request: %s", err)
}
req.Header.Add("Content-type", "application/json")
for k, v := range tc.requestHeader {
req.Header.Add(k, v)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("unable to send request: %s", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if tc.isErr {
return
}
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
}
// Check response body
var body map[string]interface{}
err = json.NewDecoder(resp.Body).Decode(&body)
if err != nil {
t.Fatalf("error parsing response body")
}
got, ok := body["result"].(string)
if !ok {
t.Fatalf("unable to find result in response body")
}
if got != tc.want {
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
}
})
}
}
func RunExecuteSqlToolInvokeTest(t *testing.T, createTableStatement string, select_1_want string) {
// Get ID token
idToken, err := GetGoogleIdToken(ClientId)