mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 15:38:08 -05:00
feat(tools/bigquery): add templateParameters field for bigquery (#699)
Add templateParameters to support non-filter parameters and DDL statements. Part of #535
This commit is contained in:
@@ -53,13 +53,14 @@ var _ compatibleSource = &bigqueryds.Source{}
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Statement string `yaml:"statement" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Statement string `yaml:"statement" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
TemplateParameters tools.Parameters `yaml:"templateParameters"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
@@ -82,22 +83,26 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters, paramManifest, paramMcpManifest := tools.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: cfg.Parameters.McpManifest(),
|
||||
InputSchema: paramMcpManifest,
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: cfg.Parameters,
|
||||
Statement: cfg.Statement,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Client: s.BigQueryClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: cfg.Parameters,
|
||||
TemplateParameters: cfg.TemplateParameters,
|
||||
AllParams: allParameters,
|
||||
Statement: cfg.Statement,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Client: s.BigQueryClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -106,10 +111,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
TemplateParameters tools.Parameters `yaml:"templateParameters"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
Statement string
|
||||
@@ -118,11 +125,22 @@ type Tool struct {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
namedArgs := make([]bigqueryapi.QueryParameter, 0, len(params))
|
||||
paramsMap := params.AsReversedMap()
|
||||
for _, v := range params.AsSlice() {
|
||||
paramName := paramsMap[v]
|
||||
if strings.Contains(t.Statement, "@"+paramName) {
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract template params %w", err)
|
||||
}
|
||||
|
||||
newParams, err := tools.GetParams(t.Parameters, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
||||
}
|
||||
|
||||
namedArgs := make([]bigqueryapi.QueryParameter, 0, len(newParams))
|
||||
newParamsMap := newParams.AsReversedMap()
|
||||
for _, v := range newParams.AsSlice() {
|
||||
paramName := newParamsMap[v]
|
||||
if strings.Contains(newStatement, "@"+paramName) {
|
||||
namedArgs = append(namedArgs, bigqueryapi.QueryParameter{
|
||||
Name: paramName,
|
||||
Value: v,
|
||||
@@ -134,7 +152,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, erro
|
||||
}
|
||||
}
|
||||
|
||||
query := t.Client.Query(t.Statement)
|
||||
query := t.Client.Query(newStatement)
|
||||
query.Parameters = namedArgs
|
||||
query.Location = t.Client.Location
|
||||
|
||||
@@ -164,7 +182,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, erro
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.Parameters, data, claims)
|
||||
return tools.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
|
||||
@@ -82,3 +82,76 @@ func TestParseFromYamlBigQuery(t *testing.T) {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestParseFromYamlWithTemplateBigQuery(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: bigquery-sql
|
||||
source: my-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
templateParameters:
|
||||
- name: tableName
|
||||
type: string
|
||||
description: The table to select hotels from.
|
||||
- name: fieldArray
|
||||
type: array
|
||||
description: The columns to return for the query.
|
||||
items:
|
||||
name: column
|
||||
type: string
|
||||
description: A column name that will be returned from the query.
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": bigquery.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "bigquery-sql",
|
||||
Source: "my-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
AuthRequired: []string{},
|
||||
Parameters: []tools.Parameter{
|
||||
tools.NewStringParameter("country", "some description"),
|
||||
},
|
||||
TemplateParameters: []tools.Parameter{
|
||||
tools.NewStringParameter("tableName", "The table to select hotels from."),
|
||||
tools.NewArrayParameter("fieldArray", "The columns to return for the query.", tools.NewStringParameter("column", "A column name that will be returned from the query.")),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -94,6 +94,12 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
datasetName,
|
||||
strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
)
|
||||
tableNameTemplateParam := fmt.Sprintf("`%s.%s.template_param_table_%s`",
|
||||
BIGQUERY_PROJECT,
|
||||
datasetName,
|
||||
strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
)
|
||||
|
||||
// set up data for param tool
|
||||
create_statement1, insert_statement1, tool_statement1, params1 := getBigQueryParamToolInfo(tableNameParam)
|
||||
teardownTable1 := setupBigQueryTable(t, ctx, client, create_statement1, insert_statement1, datasetName, tableNameParam, params1)
|
||||
@@ -107,6 +113,8 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, BIGQUERY_TOOL_KIND, tool_statement1, tool_statement2)
|
||||
toolsFile = addBigQueryPrebuiltToolsConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := getBigQueryTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, BIGQUERY_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined)
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
@@ -132,6 +140,11 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
invokeParamWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
templateParamTestConfig := tests.NewTemplateParameterTestConfig(
|
||||
tests.WithCreateColArray(`["id INT64", "name STRING", "age INT64"]`),
|
||||
)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, templateParamTestConfig)
|
||||
|
||||
runBigQueryExecuteSqlToolInvokeTest(t, select1Want, invokeParamWant, tableNameParam)
|
||||
runBigQueryListDatasetToolInvokeTest(t, datasetName)
|
||||
runBigQueryGetDatasetInfoToolInvokeTest(t, datasetName, datasetInfoWant)
|
||||
@@ -169,6 +182,13 @@ func getBigQueryAuthToolInfo(tableName string) (string, string, string, []bigque
|
||||
return createStatement, insertStatement, toolStatement, params
|
||||
}
|
||||
|
||||
// getBigQueryTmplToolStatement returns statements for template parameter test cases for bigquery kind
|
||||
func getBigQueryTmplToolStatement() (string, string) {
|
||||
tmplSelectCombined := "SELECT * FROM {{.tableName}} WHERE id = ? ORDER BY id"
|
||||
tmplSelectFilterCombined := "SELECT * FROM {{.tableName}} WHERE {{.columnFilter}} = ? ORDER BY id"
|
||||
return tmplSelectCombined, tmplSelectFilterCombined
|
||||
}
|
||||
|
||||
func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.Client, create_statement, insert_statement, datasetName string, tableName string, params []bigqueryapi.QueryParameter) func(*testing.T) {
|
||||
// Create dataset
|
||||
dataset := client.Dataset(datasetName)
|
||||
|
||||
@@ -159,7 +159,7 @@ func AddTemplateParamConfig(t *testing.T, config map[string]any, toolKind, tmplS
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Create table tool with template parameters",
|
||||
"statement": "SELECT * FROM {{.tableName}}",
|
||||
"statement": "SELECT * FROM {{.tableName}} ORDER BY id",
|
||||
"templateParameters": []tools.Parameter{
|
||||
tools.NewStringParameter("tableName", "some description"),
|
||||
},
|
||||
@@ -178,7 +178,7 @@ func AddTemplateParamConfig(t *testing.T, config map[string]any, toolKind, tmplS
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Create table tool with template parameters",
|
||||
"statement": "SELECT {{array .fields}} FROM {{.tableName}}",
|
||||
"statement": "SELECT {{array .fields}} FROM {{.tableName}} ORDER BY id",
|
||||
"templateParameters": []tools.Parameter{
|
||||
tools.NewStringParameter("tableName", "some description"),
|
||||
tools.NewArrayParameter("fields", "The fields to select from", tools.NewStringParameter("field", "A field that will be returned from the query.")),
|
||||
|
||||
@@ -219,6 +219,7 @@ type TemplateParameterTestConfig struct {
|
||||
select1Want string
|
||||
nameFieldArray string
|
||||
nameColFilter string
|
||||
createColArray string
|
||||
}
|
||||
|
||||
type Option func(*TemplateParameterTestConfig)
|
||||
@@ -265,6 +266,13 @@ func WithReplaceNameColFilter(s string) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// WithCreateColArray is the option function to configure replaceNameColFilter.
|
||||
func WithCreateColArray(s string) Option {
|
||||
return func(c *TemplateParameterTestConfig) {
|
||||
c.createColArray = s
|
||||
}
|
||||
}
|
||||
|
||||
// NewTemplateParameterTestConfig creates a new TemplateParameterTestConfig instances with options.
|
||||
func NewTemplateParameterTestConfig(options ...Option) *TemplateParameterTestConfig {
|
||||
templateParamTestOption := &TemplateParameterTestConfig{
|
||||
@@ -274,6 +282,7 @@ func NewTemplateParameterTestConfig(options ...Option) *TemplateParameterTestCon
|
||||
select1Want: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"}]",
|
||||
nameFieldArray: `["name"]`,
|
||||
nameColFilter: "name",
|
||||
createColArray: `["id INT","name VARCHAR(20)","age INT"]`,
|
||||
}
|
||||
|
||||
// Apply provided options
|
||||
@@ -304,7 +313,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, config
|
||||
ddl: true,
|
||||
api: "http://127.0.0.1:5000/api/tool/create-table-templateParams-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "columns":["id INT","name VARCHAR(20)","age INT"]}`, tableName))),
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "columns":%s}`, tableName, config.createColArray))),
|
||||
want: "null",
|
||||
isErr: false,
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user