mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-10 16:08:16 -05:00
feat(tools/postgressql): add templateParameters field (#615)
Add new tool field, templateParameters, to support non-filter parameters and DDL statements Fix #535 for postgressql tool. --------- Co-authored-by: Yuan <45984206+Yuan325@users.noreply.github.com>
This commit is contained in:
@@ -15,10 +15,12 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
)
|
||||
@@ -146,6 +148,62 @@ func ParseParams(ps Parameters, data map[string]any, claimsMap map[string]map[st
|
||||
return params, nil
|
||||
}
|
||||
|
||||
// helper function to convert a string array parameter to a comma separated string
|
||||
func ConvertArrayParamToString(param any) (string, error) {
|
||||
switch v := param.(type) {
|
||||
case []any:
|
||||
var stringValues []string
|
||||
for _, item := range v {
|
||||
stringVal, ok := item.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("templateParameter only supports string arrays")
|
||||
}
|
||||
stringValues = append(stringValues, stringVal)
|
||||
}
|
||||
return strings.Join(stringValues, ", "), nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid parameter type, expected array of type string")
|
||||
}
|
||||
}
|
||||
|
||||
// GetParams return the ParamValues that are associated with the Parameters.
|
||||
func GetParams(params Parameters, paramValuesMap map[string]any) (ParamValues, error) {
|
||||
resultParamValues := make(ParamValues, 0)
|
||||
for _, p := range params {
|
||||
k := p.GetName()
|
||||
v, ok := paramValuesMap[k]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing parameter %s", k)
|
||||
}
|
||||
resultParamValues = append(resultParamValues, ParamValue{Name: k, Value: v})
|
||||
}
|
||||
return resultParamValues, nil
|
||||
}
|
||||
|
||||
func ResolveTemplateParams(templateParams Parameters, originalStatement string, paramsMap map[string]any) (string, error) {
|
||||
templateParamsValues, err := GetParams(templateParams, paramsMap)
|
||||
templateParamsMap := templateParamsValues.AsMap()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error getting template params %s", err)
|
||||
}
|
||||
|
||||
funcMap := template.FuncMap{
|
||||
"array": ConvertArrayParamToString,
|
||||
}
|
||||
t, err := template.New("statement").Funcs(funcMap).Parse(originalStatement)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error creating go template %s", err)
|
||||
}
|
||||
var result bytes.Buffer
|
||||
err = t.Execute(&result, templateParamsMap)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error executing go template %s", err)
|
||||
}
|
||||
|
||||
modifiedStatement := result.String()
|
||||
return modifiedStatement, nil
|
||||
}
|
||||
|
||||
type Parameter interface {
|
||||
// Note: It's typically not idiomatic to include "Get" in the function name,
|
||||
// but this is done to differentiate it from the fields in CommonParameter.
|
||||
|
||||
@@ -980,3 +980,310 @@ func TestFailParametersUnmarshal(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertArrayParamToString(t *testing.T) {
|
||||
tcs := []struct {
|
||||
name string
|
||||
in []any
|
||||
want string
|
||||
}{
|
||||
{
|
||||
in: []any{
|
||||
"id",
|
||||
"name",
|
||||
"location",
|
||||
},
|
||||
want: "id, name, location",
|
||||
},
|
||||
{
|
||||
in: []any{
|
||||
"id",
|
||||
},
|
||||
want: "id",
|
||||
},
|
||||
{
|
||||
in: []any{
|
||||
"id",
|
||||
"5",
|
||||
"false",
|
||||
},
|
||||
want: "id, 5, false",
|
||||
},
|
||||
{
|
||||
in: []any{},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
in: []any{},
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, _ := tools.ConvertArrayParamToString(tc.in)
|
||||
if diff := cmp.Diff(tc.want, got); diff != "" {
|
||||
t.Fatalf("incorrect array param conversion: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailConvertArrayParamToString(t *testing.T) {
|
||||
tcs := []struct {
|
||||
name string
|
||||
in []any
|
||||
err string
|
||||
}{
|
||||
{
|
||||
in: []any{5, 10, 15},
|
||||
err: "templateParameter only supports string arrays",
|
||||
},
|
||||
{
|
||||
in: []any{"id", "name", 15},
|
||||
err: "templateParameter only supports string arrays",
|
||||
},
|
||||
{
|
||||
in: []any{false},
|
||||
err: "templateParameter only supports string arrays",
|
||||
},
|
||||
{
|
||||
in: []any{10, true},
|
||||
err: "templateParameter only supports string arrays",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := tools.ConvertArrayParamToString(tc.in)
|
||||
errStr := err.Error()
|
||||
if errStr != tc.err {
|
||||
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetParams(t *testing.T) {
|
||||
tcs := []struct {
|
||||
name string
|
||||
in map[string]any
|
||||
params tools.Parameters
|
||||
want tools.ParamValues
|
||||
}{
|
||||
{
|
||||
name: "parameters to include and exclude",
|
||||
params: tools.Parameters{
|
||||
tools.NewStringParameter("my_string_inc", "this should be included"),
|
||||
tools.NewStringParameter("my_string_inc2", "this should be included"),
|
||||
},
|
||||
in: map[string]any{
|
||||
"my_string_inc": "hello world A",
|
||||
"my_string_inc2": "hello world B",
|
||||
"my_string_exc": "hello world C",
|
||||
},
|
||||
want: tools.ParamValues{
|
||||
tools.ParamValue{Name: "my_string_inc", Value: "hello world A"},
|
||||
tools.ParamValue{Name: "my_string_inc2", Value: "hello world B"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "include all",
|
||||
params: tools.Parameters{
|
||||
tools.NewStringParameter("my_string_inc", "this should be included"),
|
||||
},
|
||||
in: map[string]any{
|
||||
"my_string_inc": "hello world A",
|
||||
},
|
||||
want: tools.ParamValues{
|
||||
tools.ParamValue{Name: "my_string_inc", Value: "hello world A"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "exclude all",
|
||||
params: tools.Parameters{},
|
||||
in: map[string]any{
|
||||
"my_string_exc": "hello world A",
|
||||
"my_string_exc2": "hello world B",
|
||||
},
|
||||
want: tools.ParamValues{},
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
params: tools.Parameters{},
|
||||
in: map[string]any{},
|
||||
want: tools.ParamValues{},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, _ := tools.GetParams(tc.params, tc.in)
|
||||
if diff := cmp.Diff(tc.want, got); diff != "" {
|
||||
t.Fatalf("incorrect get params: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailGetParams(t *testing.T) {
|
||||
tcs := []struct {
|
||||
name string
|
||||
params tools.Parameters
|
||||
in map[string]any
|
||||
err string
|
||||
}{
|
||||
{
|
||||
name: "missing the only parameter",
|
||||
params: tools.Parameters{tools.NewStringParameter("my_string", "this was missing")},
|
||||
in: map[string]any{},
|
||||
err: "missing parameter my_string",
|
||||
},
|
||||
{
|
||||
name: "missing one parameter of multiple",
|
||||
params: tools.Parameters{
|
||||
tools.NewStringParameter("my_string_inc", "this should be included"),
|
||||
tools.NewStringParameter("my_string_exc", "this was missing"),
|
||||
},
|
||||
in: map[string]any{
|
||||
"my_string_inc": "hello world A",
|
||||
},
|
||||
err: "missing parameter my_string_exc",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := tools.GetParams(tc.params, tc.in)
|
||||
errStr := err.Error()
|
||||
if errStr != tc.err {
|
||||
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveTemplateParameters(t *testing.T) {
|
||||
tcs := []struct {
|
||||
name string
|
||||
templateParams tools.Parameters
|
||||
statement string
|
||||
in map[string]any
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "single template parameter",
|
||||
templateParams: tools.Parameters{
|
||||
tools.NewStringParameter("tableName", "this is a string template parameter"),
|
||||
},
|
||||
statement: "SELECT * FROM {{.tableName}}",
|
||||
in: map[string]any{
|
||||
"tableName": "hotels",
|
||||
},
|
||||
want: "SELECT * FROM hotels",
|
||||
},
|
||||
{
|
||||
name: "multiple template parameters",
|
||||
templateParams: tools.Parameters{
|
||||
tools.NewStringParameter("tableName", "this is a string template parameter"),
|
||||
tools.NewStringParameter("columnName", "this is a string template parameter"),
|
||||
},
|
||||
statement: "SELECT * FROM {{.tableName}} WHERE {{.columnName}} = 'Hilton'",
|
||||
in: map[string]any{
|
||||
"tableName": "hotels",
|
||||
"columnName": "name",
|
||||
},
|
||||
want: "SELECT * FROM hotels WHERE name = 'Hilton'",
|
||||
},
|
||||
{
|
||||
name: "standard and template parameter",
|
||||
templateParams: tools.Parameters{
|
||||
tools.NewStringParameter("tableName", "this is a string template parameter"),
|
||||
tools.NewStringParameter("hotelName", "this is a string parameter"),
|
||||
},
|
||||
statement: "SELECT * FROM {{.tableName}} WHERE name = $1",
|
||||
in: map[string]any{
|
||||
"tableName": "hotels",
|
||||
"hotelName": "name",
|
||||
},
|
||||
want: "SELECT * FROM hotels WHERE name = $1",
|
||||
},
|
||||
{
|
||||
name: "standard parameter",
|
||||
templateParams: tools.Parameters{
|
||||
tools.NewStringParameter("hotelName", "this is a string parameter"),
|
||||
},
|
||||
statement: "SELECT * FROM hotels WHERE name = $1",
|
||||
in: map[string]any{
|
||||
"hotelName": "hotels",
|
||||
},
|
||||
want: "SELECT * FROM hotels WHERE name = $1",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, _ := tools.ResolveTemplateParams(tc.templateParams, tc.statement, tc.in)
|
||||
if diff := cmp.Diff(tc.want, got); diff != "" {
|
||||
t.Fatalf("incorrect resolved template params: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailResolveTemplateParameters(t *testing.T) {
|
||||
tcs := []struct {
|
||||
name string
|
||||
templateParams tools.Parameters
|
||||
statement string
|
||||
in map[string]any
|
||||
err string
|
||||
}{
|
||||
{
|
||||
name: "wrong param name",
|
||||
templateParams: tools.Parameters{
|
||||
tools.NewStringParameter("tableName", "this is a string template parameter"),
|
||||
},
|
||||
statement: "SELECT * FROM {{.missingParam}}",
|
||||
in: map[string]any{},
|
||||
err: "error getting template params missing parameter tableName",
|
||||
},
|
||||
{
|
||||
name: "incomplete param template",
|
||||
templateParams: tools.Parameters{
|
||||
tools.NewStringParameter("tableName", "this is a string template parameter"),
|
||||
},
|
||||
statement: "SELECT * FROM {{.tableName",
|
||||
in: map[string]any{
|
||||
"tableName": "hotels",
|
||||
},
|
||||
err: "error creating go template template: statement:1: unclosed action",
|
||||
},
|
||||
{
|
||||
name: "undefined function",
|
||||
templateParams: tools.Parameters{
|
||||
tools.NewStringParameter("tableName", "this is a string template parameter"),
|
||||
},
|
||||
statement: "SELECT * FROM {{json .tableName}}",
|
||||
in: map[string]any{
|
||||
"tableName": "hotels",
|
||||
},
|
||||
err: "error creating go template template: statement:1: function \"json\" not defined",
|
||||
},
|
||||
{
|
||||
name: "undefined method",
|
||||
templateParams: tools.Parameters{
|
||||
tools.NewStringParameter("tableName", "this is a string template parameter"),
|
||||
},
|
||||
statement: "SELECT * FROM {{.tableName .wrong}}",
|
||||
in: map[string]any{
|
||||
"tableName": "hotels",
|
||||
},
|
||||
err: "error executing go template template: statement:1:16: executing \"statement\" at <.tableName>: tableName is not a method but has arguments",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := tools.ResolveTemplateParams(tc.templateParams, tc.statement, tc.in)
|
||||
errStr := err.Error()
|
||||
if errStr != tc.err {
|
||||
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ package postgressql
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
@@ -55,13 +56,14 @@ var _ compatibleSource = &postgres.Source{}
|
||||
var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Statement string `yaml:"statement" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Statement string `yaml:"statement" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
TemplateParameters tools.Parameters `yaml:"templateParameters"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
@@ -84,22 +86,62 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters := slices.Concat(cfg.Parameters, cfg.TemplateParameters)
|
||||
|
||||
paramManifest := slices.Concat(
|
||||
cfg.Parameters.Manifest(),
|
||||
cfg.TemplateParameters.Manifest(),
|
||||
)
|
||||
if paramManifest == nil {
|
||||
paramManifest = make([]tools.ParameterManifest, 0)
|
||||
}
|
||||
|
||||
parametersMcpManifest := cfg.Parameters.McpManifest()
|
||||
templateParametersMcpManifest := cfg.TemplateParameters.McpManifest()
|
||||
|
||||
// Concatenate parameters for MCP `required` field
|
||||
concatRequiredManifest := slices.Concat(
|
||||
parametersMcpManifest.Required,
|
||||
templateParametersMcpManifest.Required,
|
||||
)
|
||||
if concatRequiredManifest == nil {
|
||||
concatRequiredManifest = []string{}
|
||||
}
|
||||
|
||||
// Concatenate parameters for MCP `properties` field
|
||||
concatPropertiesManifest := make(map[string]tools.ParameterMcpManifest)
|
||||
for name, p := range parametersMcpManifest.Properties {
|
||||
concatPropertiesManifest[name] = p
|
||||
}
|
||||
for name, p := range templateParametersMcpManifest.Properties {
|
||||
concatPropertiesManifest[name] = p
|
||||
}
|
||||
|
||||
// Create a new McpToolsSchema with all parameters
|
||||
paramMcpManifest := tools.McpToolsSchema{
|
||||
Type: "object",
|
||||
Properties: concatPropertiesManifest,
|
||||
Required: concatRequiredManifest,
|
||||
}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: cfg.Parameters.McpManifest(),
|
||||
InputSchema: paramMcpManifest,
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: cfg.Parameters,
|
||||
Statement: cfg.Statement,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Pool: s.PostgresPool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: cfg.Parameters,
|
||||
TemplateParameters: cfg.TemplateParameters,
|
||||
AllParams: allParameters,
|
||||
Statement: cfg.Statement,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Pool: s.PostgresPool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -108,10 +150,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
TemplateParameters tools.Parameters `yaml:"templateParameters"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
Pool *pgxpool.Pool
|
||||
Statement string
|
||||
@@ -120,8 +164,18 @@ type Tool struct {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
sliceParams := params.AsSlice()
|
||||
results, err := t.Pool.Query(ctx, t.Statement, sliceParams...)
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract template params %w", err)
|
||||
}
|
||||
|
||||
newParams, err := tools.GetParams(t.Parameters, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
||||
}
|
||||
sliceParams := newParams.AsSlice()
|
||||
results, err := t.Pool.Query(ctx, newStatement, sliceParams...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
@@ -145,7 +199,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, erro
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.Parameters, data, claims)
|
||||
return tools.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
|
||||
@@ -92,3 +92,76 @@ func TestParseFromYamlPostgres(t *testing.T) {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestParseFromYamlWithTemplateParamsPostgres(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: postgres-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: name
|
||||
type: string
|
||||
description: some description
|
||||
templateParameters:
|
||||
- name: tableName
|
||||
type: string
|
||||
description: The table to select hotels from.
|
||||
- name: fieldArray
|
||||
type: array
|
||||
description: The columns to return for the query.
|
||||
items:
|
||||
name: column
|
||||
type: string
|
||||
description: A column name that will be returned from the query.
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": postgressql.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "postgres-sql",
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
AuthRequired: []string{},
|
||||
Parameters: []tools.Parameter{
|
||||
tools.NewStringParameter("name", "some description"),
|
||||
},
|
||||
TemplateParameters: []tools.Parameter{
|
||||
tools.NewStringParameter("tableName", "The table to select hotels from."),
|
||||
tools.NewArrayParameter("fieldArray", "The columns to return for the query.", tools.NewStringParameter("column", "A column name that will be returned from the query.")),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -131,6 +131,7 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
|
||||
// create table name with UUID
|
||||
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
// set up data for param tool
|
||||
create_statement1, insert_statement1, tool_statement1, params1 := tests.GetPostgresSQLParamToolInfo(tableNameParam)
|
||||
@@ -145,7 +146,7 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, ALLOYDB_POSTGRES_TOOL_KIND, tool_statement1, tool_statement2)
|
||||
toolsFile = tests.AddPgExecuteSqlConfig(t, toolsFile)
|
||||
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile)
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
@@ -167,6 +168,7 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam)
|
||||
}
|
||||
|
||||
// Test connection with different IP type
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
@@ -128,6 +129,85 @@ func AddPgExecuteSqlConfig(t *testing.T, config map[string]any) map[string]any {
|
||||
return config
|
||||
}
|
||||
|
||||
func AddTemplateParamConfig(t *testing.T, config map[string]any) map[string]any {
|
||||
toolsMap, ok := config["tools"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("unable to get tools from config")
|
||||
}
|
||||
toolsMap["create-table-templateParams-tool"] = map[string]any{
|
||||
"kind": "postgres-sql",
|
||||
"source": "my-instance",
|
||||
"description": "Create table tool with template parameters",
|
||||
"statement": "CREATE TABLE {{.tableName}} ({{array .columns}})",
|
||||
"templateParameters": []tools.Parameter{
|
||||
tools.NewStringParameter("tableName", "some description"),
|
||||
tools.NewArrayParameter("columns", "The columns to create", tools.NewStringParameter("column", "A column name that will be created")),
|
||||
},
|
||||
}
|
||||
toolsMap["insert-table-templateParams-tool"] = map[string]any{
|
||||
"kind": "postgres-sql",
|
||||
"source": "my-instance",
|
||||
"description": "Insert tool with template parameters",
|
||||
"statement": "INSERT INTO {{.tableName}} ({{array .columns}}) VALUES ({{.values}})",
|
||||
"templateParameters": []tools.Parameter{
|
||||
tools.NewStringParameter("tableName", "some description"),
|
||||
tools.NewArrayParameter("columns", "The columns to insert into", tools.NewStringParameter("column", "A column name that will be returned from the query.")),
|
||||
tools.NewStringParameter("values", "The values to insert as a comma separated string"),
|
||||
},
|
||||
}
|
||||
toolsMap["select-templateParams-tool"] = map[string]any{
|
||||
"kind": "postgres-sql",
|
||||
"source": "my-instance",
|
||||
"description": "Create table tool with template parameters",
|
||||
"statement": "SELECT * FROM {{.tableName}}",
|
||||
"templateParameters": []tools.Parameter{
|
||||
tools.NewStringParameter("tableName", "some description"),
|
||||
},
|
||||
}
|
||||
toolsMap["select-templateParams-combined-tool"] = map[string]any{
|
||||
"kind": "postgres-sql",
|
||||
"source": "my-instance",
|
||||
"description": "Create table tool with template parameters",
|
||||
"statement": "SELECT * FROM {{.tableName}} WHERE id = $1",
|
||||
"parameters": []tools.Parameter{tools.NewStringParameter("id", "the id of the user")},
|
||||
"templateParameters": []tools.Parameter{
|
||||
tools.NewStringParameter("tableName", "some description"),
|
||||
},
|
||||
}
|
||||
toolsMap["select-fields-templateParams-tool"] = map[string]any{
|
||||
"kind": "postgres-sql",
|
||||
"source": "my-instance",
|
||||
"description": "Create table tool with template parameters",
|
||||
"statement": "SELECT {{array .fields}} FROM {{.tableName}}",
|
||||
"templateParameters": []tools.Parameter{
|
||||
tools.NewStringParameter("tableName", "some description"),
|
||||
tools.NewArrayParameter("fields", "The fields to select from", tools.NewStringParameter("field", "A field that will be returned from the query.")),
|
||||
},
|
||||
}
|
||||
toolsMap["select-filter-templateParams-combined-tool"] = map[string]any{
|
||||
"kind": "postgres-sql",
|
||||
"source": "my-instance",
|
||||
"description": "Create table tool with template parameters",
|
||||
"statement": "SELECT * FROM {{.tableName}} WHERE {{.columnFilter}} = $1",
|
||||
"parameters": []tools.Parameter{tools.NewStringParameter("name", "the name of the user")},
|
||||
"templateParameters": []tools.Parameter{
|
||||
tools.NewStringParameter("tableName", "some description"),
|
||||
tools.NewStringParameter("columnFilter", "some description"),
|
||||
},
|
||||
}
|
||||
toolsMap["drop-table-templateParams-tool"] = map[string]any{
|
||||
"kind": "postgres-sql",
|
||||
"source": "my-instance",
|
||||
"description": "Drop table tool with template parameters",
|
||||
"statement": "DROP TABLE IF EXISTS {{.tableName}}",
|
||||
"templateParameters": []tools.Parameter{
|
||||
tools.NewStringParameter("tableName", "some description"),
|
||||
},
|
||||
}
|
||||
config["tools"] = toolsMap
|
||||
return config
|
||||
}
|
||||
|
||||
// AddMySqlExecuteSqlConfig gets the tools config for `mysql-execute-sql`
|
||||
func AddMySqlExecuteSqlConfig(t *testing.T, config map[string]any) map[string]any {
|
||||
tools, ok := config["tools"].(map[string]any)
|
||||
|
||||
123
tests/tool.go
123
tests/tool.go
@@ -211,6 +211,129 @@ func RunToolInvokeTest(t *testing.T, select_1_want, invoke_param_want string) {
|
||||
}
|
||||
}
|
||||
|
||||
func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string) {
|
||||
select_all_want := "[{\"age\":21,\"id\":1,\"name\":\"Alex\"},{\"age\":100,\"id\":2,\"name\":\"Alice\"}]"
|
||||
select_only_1_want := "[{\"age\":21,\"id\":1,\"name\":\"Alex\"}]"
|
||||
select_only_names_want := "[{\"name\":\"Alex\"},{\"name\":\"Alice\"}]"
|
||||
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestHeader map[string]string
|
||||
requestBody io.Reader
|
||||
want string
|
||||
isErr bool
|
||||
}{
|
||||
{
|
||||
name: "invoke create-table-templateParams-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/create-table-templateParams-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "columns":["id INT","name VARCHAR(20)","age INT"]}`, tableName))),
|
||||
want: "null",
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke insert-table-templateParams-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/insert-table-templateParams-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "columns":["id","name","age"], "values":"1, 'Alex', 21"}`, tableName))),
|
||||
want: "null",
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke insert-table-templateParams-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/insert-table-templateParams-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "columns":["id","name","age"], "values":"2, 'Alice', 100"}`, tableName))),
|
||||
want: "null",
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke select-templateParams-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/select-templateParams-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s"}`, tableName))),
|
||||
want: select_all_want,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke select-templateParams-combined-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/select-templateParams-combined-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"id": "1", "tableName": "%s"}`, tableName))),
|
||||
want: select_only_1_want,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke select-fields-templateParams-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/select-fields-templateParams-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "fields":["name"]}`, tableName))),
|
||||
want: select_only_names_want,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke select-filter-templateParams-combined-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/select-filter-templateParams-combined-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"name": "Alex", "tableName": "%s", "columnFilter": "name"}`, tableName))),
|
||||
want: select_only_1_want,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke drop-table-templateParams-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/drop-table-templateParams-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s"}`, tableName))),
|
||||
want: "null",
|
||||
isErr: false,
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Send Tool invocation request
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
for k, v := range tc.requestHeader {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if tc.isErr {
|
||||
return
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
// Check response body
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
|
||||
got, ok := body["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
|
||||
if got != tc.want {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func RunExecuteSqlToolInvokeTest(t *testing.T, createTableStatement string, select_1_want string) {
|
||||
// Get ID token
|
||||
idToken, err := GetGoogleIdToken(ClientId)
|
||||
|
||||
Reference in New Issue
Block a user