mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-11 00:18:17 -05:00
feat(tools/bigtable): add templateParameters field for bigtable (#692)
Add templateParameters to support non-filter parameters and DDL statements. Added a new argument `ignoreInsert` at integration test. Bigtable only allow `SELECT` statement. This is used to filter insert statement for bigtable. Part of #535
This commit is contained in:
@@ -51,13 +51,14 @@ var _ compatibleSource = &bigtabledb.Source{}
|
||||
var compatibleSources = [...]string{bigtabledb.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Statement string `yaml:"statement" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Statement string `yaml:"statement" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
TemplateParameters tools.Parameters `yaml:"templateParameters"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
@@ -80,22 +81,26 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters, paramManifest, paramMcpManifest := tools.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: cfg.Parameters.McpManifest(),
|
||||
InputSchema: paramMcpManifest,
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: cfg.Parameters,
|
||||
Statement: cfg.Statement,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Client: s.BigtableClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: cfg.Parameters,
|
||||
TemplateParameters: cfg.TemplateParameters,
|
||||
AllParams: allParameters,
|
||||
Statement: cfg.Statement,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Client: s.BigtableClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -104,10 +109,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
TemplateParameters tools.Parameters `yaml:"templateParameters"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
Client *bigtable.Client
|
||||
Statement string
|
||||
@@ -141,21 +148,32 @@ func getMapParamsType(tparams tools.Parameters, params tools.ParamValues) (map[s
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
mapParamsType, err := getMapParamsType(t.Parameters, params)
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract template params %w", err)
|
||||
}
|
||||
|
||||
newParams, err := tools.GetParams(t.Parameters, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
||||
}
|
||||
|
||||
mapParamsType, err := getMapParamsType(t.Parameters, newParams)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fail to get map params: %w", err)
|
||||
}
|
||||
|
||||
ps, err := t.Client.PrepareStatement(
|
||||
ctx,
|
||||
t.Statement,
|
||||
newStatement,
|
||||
mapParamsType,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to prepare statement: %w", err)
|
||||
}
|
||||
|
||||
bs, err := ps.Bind(params.AsMap())
|
||||
bs, err := ps.Bind(newParams.AsMap())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to bind: %w", err)
|
||||
}
|
||||
@@ -183,7 +201,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, erro
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.Parameters, data, claims)
|
||||
return tools.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
|
||||
@@ -82,3 +82,76 @@ func TestParseFromYamlBigtable(t *testing.T) {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestParseFromYamlWithTemplateBigtable(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: bigtable-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
templateParameters:
|
||||
- name: tableName
|
||||
type: string
|
||||
description: The table to select hotels from.
|
||||
- name: fieldArray
|
||||
type: array
|
||||
description: The columns to return for the query.
|
||||
items:
|
||||
name: column
|
||||
type: string
|
||||
description: A column name that will be returned from the query.
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": bigtable.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "bigtable-sql",
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
AuthRequired: []string{},
|
||||
Parameters: []tools.Parameter{
|
||||
tools.NewStringParameter("country", "some description"),
|
||||
},
|
||||
TemplateParameters: []tools.Parameter{
|
||||
tools.NewStringParameter("tableName", "The table to select hotels from."),
|
||||
tools.NewArrayParameter("fieldArray", "The columns to return for the query.", tools.NewStringParameter("column", "A column name that will be returned from the query.")),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -170,7 +170,7 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tmplSelectAllWant, tmplSelect1Want, false)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tmplSelectAllWant, tmplSelect1Want, "", "", false, false)
|
||||
}
|
||||
|
||||
// Test connection with different IP type
|
||||
|
||||
@@ -29,6 +29,7 @@ import (
|
||||
|
||||
"cloud.google.com/go/bigtable"
|
||||
"github.com/google/uuid"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
)
|
||||
|
||||
@@ -69,6 +70,7 @@ func TestBigtableToolEndpoints(t *testing.T) {
|
||||
|
||||
tableName := "param_table" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
tableNameTemplateParam := "tmpl_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
columnFamilyName := "cf"
|
||||
muts, rowKeys := getTestData(columnFamilyName)
|
||||
@@ -85,8 +87,14 @@ func TestBigtableToolEndpoints(t *testing.T) {
|
||||
teardownTable2 := setupBtTable(t, ctx, sourceConfig["project"].(string), sourceConfig["instance"].(string), tableNameAuth, columnFamilyName, muts, rowKeys)
|
||||
defer teardownTable2(t)
|
||||
|
||||
mutsTmpl, rowKeysTmpl := getTestDataTemplateParam(columnFamilyName)
|
||||
teardownTableTmpl := setupBtTable(t, ctx, sourceConfig["project"].(string), sourceConfig["instance"].(string), tableNameTemplateParam, columnFamilyName, mutsTmpl, rowKeysTmpl)
|
||||
defer teardownTableTmpl(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, BIGTABLE_TOOL_KIND, param_test_statement, auth_tool_statement)
|
||||
toolsFile = addTemplateParamConfig(t, toolsFile)
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
@@ -106,9 +114,20 @@ func TestBigtableToolEndpoints(t *testing.T) {
|
||||
// Actual test parameters are set in https://github.com/googleapis/genai-toolbox/blob/52b09a67cb40ac0c5f461598b4673136699a3089/tests/tool_test.go#L250
|
||||
select1Want := "[{\"$col1\":1}]"
|
||||
failInvocationWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to prepare statement: rpc error: code = InvalidArgument desc = Syntax error: Unexpected identifier \"SELEC\" [at 1:1]"}],"isError":true}}`
|
||||
invokeParamWant, mcpInvokeParamWant, _, _ := tests.GetNonSpannerInvokeParamWant()
|
||||
invokeParamWant, mcpInvokeParamWant, tmplSelectAllWant, tmplSelect1Want := tests.GetNonSpannerInvokeParamWant()
|
||||
replaceNameFieldArray := `["CAST(cf['name'] AS string) as name"]`
|
||||
replaceNameColFilter := "CAST(cf['name'] AS string)"
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tmplSelectAllWant, tmplSelect1Want, replaceNameFieldArray, replaceNameColFilter, true, true)
|
||||
}
|
||||
|
||||
func convertToBytes(v int) []byte {
|
||||
binary1 := new(bytes.Buffer)
|
||||
if err := binary.Write(binary1, binary.BigEndian, int64(v)); err != nil {
|
||||
log.Fatalf("Unable to encode id: %v", err)
|
||||
}
|
||||
return binary1.Bytes()
|
||||
}
|
||||
|
||||
func getTestData(columnFamilyName string) ([]*bigtable.Mutation, []string) {
|
||||
@@ -117,11 +136,7 @@ func getTestData(columnFamilyName string) ([]*bigtable.Mutation, []string) {
|
||||
|
||||
var ids [3][]byte
|
||||
for i := range ids {
|
||||
binary1 := new(bytes.Buffer)
|
||||
if err := binary.Write(binary1, binary.BigEndian, int64(i+1)); err != nil {
|
||||
log.Fatalf("Unable to encode id: %v", err)
|
||||
}
|
||||
ids[i] = binary1.Bytes()
|
||||
ids[i] = convertToBytes(i + 1)
|
||||
}
|
||||
|
||||
now := bigtable.Time(time.Now())
|
||||
@@ -154,6 +169,41 @@ func getTestData(columnFamilyName string) ([]*bigtable.Mutation, []string) {
|
||||
return muts, rowKeys
|
||||
}
|
||||
|
||||
func getTestDataTemplateParam(columnFamilyName string) ([]*bigtable.Mutation, []string) {
|
||||
muts := []*bigtable.Mutation{}
|
||||
rowKeys := []string{}
|
||||
|
||||
var ids [2][]byte
|
||||
for i := range ids {
|
||||
ids[i] = convertToBytes(i + 1)
|
||||
}
|
||||
|
||||
now := bigtable.Time(time.Now())
|
||||
for rowKey, mutData := range map[string]map[string][]byte{
|
||||
// Do not change the test data without checking tests/common_test.go.
|
||||
// The structure and value of seed data has to match https://github.com/googleapis/genai-toolbox/blob/4dba0df12dc438eca3cb476ef52aa17cdf232c12/tests/common_test.go#L200-L251
|
||||
// Expected values are defined in https://github.com/googleapis/genai-toolbox/blob/52b09a67cb40ac0c5f461598b4673136699a3089/tests/tool_test.go#L229-L310
|
||||
"row-01": {
|
||||
"name": []byte("Alex"),
|
||||
"age": convertToBytes(21),
|
||||
"id": ids[0],
|
||||
},
|
||||
"row-02": {
|
||||
"name": []byte("Alice"),
|
||||
"age": convertToBytes(100),
|
||||
"id": ids[1],
|
||||
},
|
||||
} {
|
||||
mut := bigtable.NewMutation()
|
||||
for col, v := range mutData {
|
||||
mut.Set(columnFamilyName, col, now, v)
|
||||
}
|
||||
muts = append(muts, mut)
|
||||
rowKeys = append(rowKeys, rowKey)
|
||||
}
|
||||
return muts, rowKeys
|
||||
}
|
||||
|
||||
func setupBtTable(t *testing.T, ctx context.Context, projectId string, instance string, tableName string, columnFamilyName string, muts []*bigtable.Mutation, rowKeys []string) func(*testing.T) {
|
||||
// Creating clients
|
||||
adminClient, err := bigtable.NewAdminClient(ctx, projectId, instance)
|
||||
@@ -213,3 +263,52 @@ func setupBtTable(t *testing.T, ctx context.Context, projectId string, instance
|
||||
defer adminClient.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func addTemplateParamConfig(t *testing.T, config map[string]any) map[string]any {
|
||||
toolsMap, ok := config["tools"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("unable to get tools from config")
|
||||
}
|
||||
toolsMap["select-templateParams-tool"] = map[string]any{
|
||||
"kind": "bigtable-sql",
|
||||
"source": "my-instance",
|
||||
"description": "Create table tool with template parameters",
|
||||
"statement": "SELECT TO_INT64(cf['age']) as age, TO_INT64(cf['id']) as id, CAST(cf['name'] AS string) as name, FROM {{.tableName}};",
|
||||
"templateParameters": []tools.Parameter{
|
||||
tools.NewStringParameter("tableName", "some description"),
|
||||
},
|
||||
}
|
||||
toolsMap["select-templateParams-combined-tool"] = map[string]any{
|
||||
"kind": "bigtable-sql",
|
||||
"source": "my-instance",
|
||||
"description": "Create table tool with template parameters",
|
||||
"statement": "SELECT TO_INT64(cf['age']) as age, TO_INT64(cf['id']) as id, CAST(cf['name'] AS string) as name, FROM {{.tableName}} WHERE TO_INT64(cf['id']) = @id;",
|
||||
"parameters": []tools.Parameter{tools.NewIntParameter("id", "the id of the user")},
|
||||
"templateParameters": []tools.Parameter{
|
||||
tools.NewStringParameter("tableName", "some description"),
|
||||
},
|
||||
}
|
||||
toolsMap["select-fields-templateParams-tool"] = map[string]any{
|
||||
"kind": "bigtable-sql",
|
||||
"source": "my-instance",
|
||||
"description": "Create table tool with template parameters",
|
||||
"statement": "SELECT {{array .fields}}, FROM {{.tableName}};",
|
||||
"templateParameters": []tools.Parameter{
|
||||
tools.NewStringParameter("tableName", "some description"),
|
||||
tools.NewArrayParameter("fields", "The fields to select from", tools.NewStringParameter("field", "A field that will be returned from the query.")),
|
||||
},
|
||||
}
|
||||
toolsMap["select-filter-templateParams-combined-tool"] = map[string]any{
|
||||
"kind": "bigtable-sql",
|
||||
"source": "my-instance",
|
||||
"description": "Create table tool with template parameters",
|
||||
"statement": "SELECT TO_INT64(cf['age']) as age, TO_INT64(cf['id']) as id, CAST(cf['name'] AS string) as name, FROM {{.tableName}} WHERE {{.columnFilter}} = @name;",
|
||||
"parameters": []tools.Parameter{tools.NewStringParameter("name", "the name of the user")},
|
||||
"templateParameters": []tools.Parameter{
|
||||
tools.NewStringParameter("tableName", "some description"),
|
||||
tools.NewStringParameter("columnFilter", "some description"),
|
||||
},
|
||||
}
|
||||
config["tools"] = toolsMap
|
||||
return config
|
||||
}
|
||||
|
||||
@@ -157,7 +157,7 @@ func TestCloudSQLMssqlToolEndpoints(t *testing.T) {
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tmplSelectAllWant, tmplSelect1Want, false)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tmplSelectAllWant, tmplSelect1Want, "", "", false, false)
|
||||
}
|
||||
|
||||
// Test connection with different IP type
|
||||
|
||||
@@ -151,7 +151,7 @@ func TestCloudSQLMysqlToolEndpoints(t *testing.T) {
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tmplSelectAllWant, tmplSelect1Want, false)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tmplSelectAllWant, tmplSelect1Want, "", "", false, false)
|
||||
}
|
||||
|
||||
// Test connection with different IP type
|
||||
|
||||
@@ -155,7 +155,7 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tmplSelectAllWant, tmplSelect1Want, false)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tmplSelectAllWant, tmplSelect1Want, "", "", false, false)
|
||||
}
|
||||
|
||||
// Test connection with different IP type
|
||||
|
||||
@@ -129,5 +129,5 @@ func TestMssqlToolEndpoints(t *testing.T) {
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tmplSelectAllWant, tmplSelect1Want, false)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tmplSelectAllWant, tmplSelect1Want, "", "", false, false)
|
||||
}
|
||||
|
||||
@@ -128,5 +128,5 @@ func TestMysqlToolEndpoints(t *testing.T) {
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tmplSelectAllWant, tmplSelect1Want, false)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tmplSelectAllWant, tmplSelect1Want, "", "", false, false)
|
||||
}
|
||||
|
||||
@@ -128,5 +128,5 @@ func TestPostgres(t *testing.T) {
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tmplSelectAllWant, tmplSelect1Want, false)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tmplSelectAllWant, tmplSelect1Want, "", "", false, false)
|
||||
}
|
||||
|
||||
@@ -161,7 +161,7 @@ func TestSpannerToolEndpoints(t *testing.T) {
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
runSpannerSchemaToolInvokeTest(t, accessSchemaWant)
|
||||
runSpannerExecuteSqlToolInvokeTest(t, select1Want, invokeParamWant, tableNameParam, tableNameAuth)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tmplSelectAllWant, tmplSelect1Want, true)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tmplSelectAllWant, tmplSelect1Want, "", "", true, false)
|
||||
}
|
||||
|
||||
// getSpannerToolInfo returns statements and param for my-param-tool for spanner-sql kind
|
||||
|
||||
@@ -156,5 +156,5 @@ func TestSQLiteToolEndpoint(t *testing.T) {
|
||||
invokeParamWant, mcpInvokeParamWant, tmplSelectAllWant, tmplSelect1Want := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tmplSelectAllWant, tmplSelect1Want, false)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tmplSelectAllWant, tmplSelect1Want, "", "", false, false)
|
||||
}
|
||||
|
||||
@@ -211,13 +211,26 @@ func RunToolInvokeTest(t *testing.T, select1Want, invokeParamWant string) {
|
||||
}
|
||||
}
|
||||
|
||||
func RunToolInvokeWithTemplateParameters(t *testing.T, tableName, select_all_want, select_only_1_want string, ignoreDdl bool) {
|
||||
// RunToolInvokeWithTemplateParameters runs tool invoke test cases with template parameters.
|
||||
// ignoreDdl is used for sources that does not support DDL statement.
|
||||
// replaceNameFieldArray and replaceNameColFilter is used for bigtable since it have a different formatting for sql statement.
|
||||
// ignoreInsert is used for bigtable since it does not support other DML statement other than `SELECT`.
|
||||
func RunToolInvokeWithTemplateParameters(t *testing.T, tableName, select_all_want, select_only_1_want, replaceNameFieldArray, replaceNameColFilter string, ignoreDdl, ignoreInsert bool) {
|
||||
select_only_names_want := "[{\"name\":\"Alex\"},{\"name\":\"Alice\"}]"
|
||||
nameFieldArray := `["name"]`
|
||||
nameColFilter := "name"
|
||||
if replaceNameFieldArray != "" {
|
||||
nameFieldArray = replaceNameFieldArray
|
||||
}
|
||||
if replaceNameColFilter != "" {
|
||||
nameColFilter = replaceNameColFilter
|
||||
}
|
||||
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
ddl bool
|
||||
insert bool
|
||||
api string
|
||||
requestHeader map[string]string
|
||||
requestBody io.Reader
|
||||
@@ -235,6 +248,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName, select_all_wan
|
||||
},
|
||||
{
|
||||
name: "invoke insert-table-templateParams-tool",
|
||||
insert: true,
|
||||
api: "http://127.0.0.1:5000/api/tool/insert-table-templateParams-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "columns":["id","name","age"], "values":"1, 'Alex', 21"}`, tableName))),
|
||||
@@ -243,6 +257,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName, select_all_wan
|
||||
},
|
||||
{
|
||||
name: "invoke insert-table-templateParams-tool",
|
||||
insert: true,
|
||||
api: "http://127.0.0.1:5000/api/tool/insert-table-templateParams-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "columns":["id","name","age"], "values":"2, 'Alice', 100"}`, tableName))),
|
||||
@@ -269,7 +284,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName, select_all_wan
|
||||
name: "invoke select-fields-templateParams-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/select-fields-templateParams-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "fields":["name"]}`, tableName))),
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "fields":%s}`, tableName, nameFieldArray))),
|
||||
want: select_only_names_want,
|
||||
isErr: false,
|
||||
},
|
||||
@@ -277,7 +292,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName, select_all_wan
|
||||
name: "invoke select-filter-templateParams-combined-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/select-filter-templateParams-combined-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"name": "Alex", "tableName": "%s", "columnFilter": "name"}`, tableName))),
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"name": "Alex", "tableName": "%s", "columnFilter": "%s"}`, tableName, nameColFilter))),
|
||||
want: select_only_1_want,
|
||||
isErr: false,
|
||||
},
|
||||
@@ -293,7 +308,11 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName, select_all_wan
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if !tc.ddl || (tc.ddl && !ignoreDdl) {
|
||||
// if test case is DDL and source does not ignore ddl test cases
|
||||
ddlAllow := !tc.ddl || (tc.ddl && !ignoreDdl)
|
||||
// if test case is insert statement and source does not ignore insert test cases
|
||||
insertAllow := !tc.insert || (tc.insert && !ignoreInsert)
|
||||
if ddlAllow && insertAllow {
|
||||
// Send Tool invocation request
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user