mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 15:38:08 -05:00
feat(tools/couchbase): add templateParameters field for couchbase (#723)
Add templateParameters to support non-filter parameters and DDL statements. Part of https://github.com/googleapis/genai-toolbox/issues/535
This commit is contained in:
@@ -53,13 +53,14 @@ var _ compatibleSource = &couchbase.Source{}
|
||||
var compatibleSources = [...]string{couchbase.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Statement string `yaml:"statement" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Statement string `yaml:"statement" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
TemplateParameters tools.Parameters `yaml:"templateParameters"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
@@ -82,21 +83,25 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters, paramManifest, paramMcpManifest := tools.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: cfg.Parameters.McpManifest(),
|
||||
InputSchema: paramMcpManifest,
|
||||
}
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: cfg.Parameters,
|
||||
TemplateParameters: cfg.TemplateParameters,
|
||||
AllParams: allParameters,
|
||||
Statement: cfg.Statement,
|
||||
Scope: s.CouchbaseScope(),
|
||||
QueryScanConsistency: s.CouchbaseQueryScanConsistency(),
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
@@ -106,10 +111,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
TemplateParameters tools.Parameters `yaml:"templateParameters"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
|
||||
Scope *gocb.Scope
|
||||
QueryScanConsistency uint
|
||||
@@ -119,10 +126,19 @@ type Tool struct {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
namedParams := params.AsMap()
|
||||
results, err := t.Scope.Query(t.Statement, &gocb.QueryOptions{
|
||||
namedParamsMap := params.AsMap()
|
||||
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, namedParamsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract template params %w", err)
|
||||
}
|
||||
|
||||
newParams, err := tools.GetParams(t.Parameters, namedParamsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
||||
}
|
||||
results, err := t.Scope.Query(newStatement, &gocb.QueryOptions{
|
||||
ScanConsistency: gocb.QueryScanConsistency(t.QueryScanConsistency),
|
||||
NamedParameters: namedParams,
|
||||
NamedParameters: newParams.AsMap(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
@@ -141,7 +157,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, erro
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.Parameters, data, claimsMap)
|
||||
return tools.ParseParams(t.AllParams, data, claimsMap)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
|
||||
@@ -85,3 +85,67 @@ func TestParseFromYamlCouchbase(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFromYamlWithTemplateMssql(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: couchbase-sql
|
||||
source: my-couchbase-instance
|
||||
description: some tool description
|
||||
statement: |
|
||||
select * from {{.tableName}} WHERE name = $hotel;
|
||||
parameters:
|
||||
- name: hotel
|
||||
type: string
|
||||
description: hotel parameter description
|
||||
templateParameters:
|
||||
- name: tableName
|
||||
type: string
|
||||
description: The table to select hotels from.
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": couchbase.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "couchbase-sql",
|
||||
AuthRequired: []string{},
|
||||
Source: "my-couchbase-instance",
|
||||
Description: "some tool description",
|
||||
Statement: "select * from {{.tableName}} WHERE name = $hotel;\n",
|
||||
Parameters: []tools.Parameter{
|
||||
tools.NewStringParameter("hotel", "hotel parameter description"),
|
||||
},
|
||||
TemplateParameters: []tools.Parameter{
|
||||
tools.NewStringParameter("tableName", "The table to select hotels from."),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -147,7 +147,7 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, ALLOYDB_POSTGRES_TOOL_KIND, tool_statement1, tool_statement2)
|
||||
toolsFile = tests.AddPgExecuteSqlConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, ALLOYDB_POSTGRES_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined)
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, ALLOYDB_POSTGRES_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
|
||||
@@ -114,7 +114,7 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, BIGQUERY_TOOL_KIND, tool_statement1, tool_statement2)
|
||||
toolsFile = addBigQueryPrebuiltToolsConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := getBigQueryTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, BIGQUERY_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined)
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, BIGQUERY_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
|
||||
@@ -134,7 +134,7 @@ func TestCloudSQLMssqlToolEndpoints(t *testing.T) {
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, CLOUD_SQL_MSSQL_TOOL_KIND, tool_statement1, tool_statement2)
|
||||
toolsFile = tests.AddMssqlExecuteSqlConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMssqlTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CLOUD_SQL_MSSQL_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined)
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CLOUD_SQL_MSSQL_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
|
||||
@@ -128,7 +128,7 @@ func TestCloudSQLMysqlToolEndpoints(t *testing.T) {
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, CLOUD_SQL_MYSQL_TOOL_KIND, tool_statement1, tool_statement2)
|
||||
toolsFile = tests.AddMySqlExecuteSqlConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMysqlTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CLOUD_SQL_MYSQL_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined)
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CLOUD_SQL_MYSQL_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
|
||||
@@ -132,7 +132,7 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, CLOUD_SQL_POSTGRES_TOOL_KIND, tool_statement1, tool_statement2)
|
||||
toolsFile = tests.AddPgExecuteSqlConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CLOUD_SQL_POSTGRES_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined)
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CLOUD_SQL_POSTGRES_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
|
||||
@@ -129,11 +129,17 @@ func AddPgExecuteSqlConfig(t *testing.T, config map[string]any) map[string]any {
|
||||
return config
|
||||
}
|
||||
|
||||
func AddTemplateParamConfig(t *testing.T, config map[string]any, toolKind, tmplSelectCombined, tmplSelectFilterCombined string) map[string]any {
|
||||
func AddTemplateParamConfig(t *testing.T, config map[string]any, toolKind, tmplSelectCombined, tmplSelectFilterCombined string, tmplSelectAll string) map[string]any {
|
||||
toolsMap, ok := config["tools"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("unable to get tools from config")
|
||||
}
|
||||
|
||||
selectAll := "SELECT * FROM {{.tableName}} ORDER BY id"
|
||||
if tmplSelectAll != "" {
|
||||
selectAll = tmplSelectAll
|
||||
}
|
||||
|
||||
toolsMap["create-table-templateParams-tool"] = map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
@@ -159,7 +165,7 @@ func AddTemplateParamConfig(t *testing.T, config map[string]any, toolKind, tmplS
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Create table tool with template parameters",
|
||||
"statement": "SELECT * FROM {{.tableName}} ORDER BY id",
|
||||
"statement": selectAll,
|
||||
"templateParameters": []tools.Parameter{
|
||||
tools.NewStringParameter("tableName", "some description"),
|
||||
},
|
||||
|
||||
@@ -100,6 +100,7 @@ func TestCouchbaseToolEndpoints(t *testing.T) {
|
||||
// Create collection names with UUID
|
||||
collectionNameParam := "param_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
collectionNameAuth := "auth_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
collectionNameTemplateParam := "template_param_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
// Set up data for param tool
|
||||
paramToolStatement, params1 := getCouchbaseParamToolInfo(collectionNameParam)
|
||||
@@ -111,8 +112,14 @@ func TestCouchbaseToolEndpoints(t *testing.T) {
|
||||
teardownCollection2 := setupCouchbaseCollection(t, ctx, cluster, couchbaseBucket, couchbaseScope, collectionNameAuth, params2)
|
||||
defer teardownCollection2(t)
|
||||
|
||||
// Setup up table for template param tool
|
||||
tmplSelectCombined, tmplSelectFilterCombined, tmplSelectAll, params3 := getCouchbaseTemplateParamToolInfo()
|
||||
teardownCollection3 := setupCouchbaseCollection(t, ctx, cluster, couchbaseBucket, couchbaseScope, collectionNameTemplateParam, params3)
|
||||
defer teardownCollection3(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, couchbaseToolKind, paramToolStatement, authToolStatement)
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, couchbaseToolKind, tmplSelectCombined, tmplSelectFilterCombined, tmplSelectAll)
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
@@ -136,6 +143,14 @@ func TestCouchbaseToolEndpoints(t *testing.T) {
|
||||
invokeParamWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failMcpInvocationWant)
|
||||
|
||||
templateParamTestConfig := tests.NewTemplateParameterTestConfig(
|
||||
tests.WithIgnoreDdl(),
|
||||
tests.WithIgnoreInsert(),
|
||||
tests.WithSelect1Want("[{\"age\":21,\"id\":1,\"name\":\"Alex\"}]"),
|
||||
tests.WithSelectAllWant("[{\"age\":21,\"id\":1,\"name\":\"Alex\"},{\"age\":100,\"id\":2,\"name\":\"Alice\"}]"),
|
||||
)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, collectionNameTemplateParam, templateParamTestConfig)
|
||||
}
|
||||
|
||||
// setupCouchbaseCollection creates a scope and collection and inserts test data
|
||||
@@ -240,3 +255,15 @@ func getCouchbaseAuthToolInfo(collectionName string) (string, []map[string]any)
|
||||
}
|
||||
return toolStatement, params
|
||||
}
|
||||
|
||||
func getCouchbaseTemplateParamToolInfo() (string, string, string, []map[string]any) {
|
||||
tmplSelectCombined := "SELECT {{.tableName}}.* FROM {{.tableName}} WHERE id = $id"
|
||||
tmplSelectFilterCombined := "SELECT {{.tableName}}.* FROM {{.tableName}} WHERE {{.columnFilter}} = $name"
|
||||
tmplSelectAll := "SELECT {{.tableName}}.* FROM {{.tableName}}"
|
||||
|
||||
params := []map[string]any{
|
||||
{"name": "Alex", "id": 1, "age": 21},
|
||||
{"name": "Alice", "id": 2, "age": 100},
|
||||
}
|
||||
return tmplSelectCombined, tmplSelectFilterCombined, tmplSelectAll, params
|
||||
}
|
||||
|
||||
@@ -106,7 +106,7 @@ func TestMssqlToolEndpoints(t *testing.T) {
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, MSSQL_TOOL_KIND, tool_statement1, tool_statement2)
|
||||
toolsFile = tests.AddMssqlExecuteSqlConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMssqlTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, MSSQL_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined)
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, MSSQL_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
|
||||
@@ -105,7 +105,7 @@ func TestMysqlToolEndpoints(t *testing.T) {
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, MYSQL_TOOL_KIND, tool_statement1, tool_statement2)
|
||||
toolsFile = tests.AddMySqlExecuteSqlConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMysqlTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, MYSQL_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined)
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, MYSQL_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
|
||||
@@ -105,7 +105,7 @@ func TestPostgres(t *testing.T) {
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, POSTGRES_TOOL_KIND, tool_statement1, tool_statement2)
|
||||
toolsFile = tests.AddPgExecuteSqlConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, POSTGRES_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined)
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, POSTGRES_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
|
||||
@@ -133,7 +133,7 @@ func TestSQLiteToolEndpoint(t *testing.T) {
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, SQLITE_TOOL_KIND, tool_statement1, tool_statement2)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := getSQLiteTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, SQLITE_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined)
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, SQLITE_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
|
||||
@@ -393,6 +393,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, config
|
||||
for k, v := range tc.requestHeader {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
|
||||
Reference in New Issue
Block a user