feat: add templateParameters field for mysqlsql (#663)

Add `templateParameters to support non-filter parameters and DDL
statements.

Part of #535
This commit is contained in:
Yuan
2025-06-05 14:37:30 -07:00
committed by GitHub
parent 71250e1ced
commit 0a08d2c15d
5 changed files with 139 additions and 26 deletions

View File

@@ -53,13 +53,14 @@ var _ compatibleSource = &mysql.Source{}
var compatibleSources = [...]string{cloudsqlmysql.SourceKind, mysql.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Source string `yaml:"source" validate:"required"`
Description string `yaml:"description" validate:"required"`
Statement string `yaml:"statement" validate:"required"`
AuthRequired []string `yaml:"authRequired"`
Parameters tools.Parameters `yaml:"parameters"`
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Source string `yaml:"source" validate:"required"`
Description string `yaml:"description" validate:"required"`
Statement string `yaml:"statement" validate:"required"`
AuthRequired []string `yaml:"authRequired"`
Parameters tools.Parameters `yaml:"parameters"`
TemplateParameters tools.Parameters `yaml:"templateParameters"`
}
// validate interface
@@ -82,22 +83,26 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
allParameters, paramManifest, paramMcpManifest := tools.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)
mcpManifest := tools.McpManifest{
Name: cfg.Name,
Description: cfg.Description,
InputSchema: cfg.Parameters.McpManifest(),
InputSchema: paramMcpManifest,
}
// finish tool setup
t := Tool{
Name: cfg.Name,
Kind: kind,
Parameters: cfg.Parameters,
Statement: cfg.Statement,
AuthRequired: cfg.AuthRequired,
Pool: s.MySQLPool(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Name: cfg.Name,
Kind: kind,
Parameters: cfg.Parameters,
TemplateParameters: cfg.TemplateParameters,
AllParams: allParameters,
Statement: cfg.Statement,
AuthRequired: cfg.AuthRequired,
Pool: s.MySQLPool(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -106,10 +111,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
var _ tools.Tool = Tool{}
type Tool struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
AuthRequired []string `yaml:"authRequired"`
Parameters tools.Parameters `yaml:"parameters"`
Name string `yaml:"name"`
Kind string `yaml:"kind"`
AuthRequired []string `yaml:"authRequired"`
Parameters tools.Parameters `yaml:"parameters"`
TemplateParameters tools.Parameters `yaml:"templateParameters"`
AllParams tools.Parameters `yaml:"allParams"`
Pool *sql.DB
Statement string
@@ -118,9 +125,19 @@ type Tool struct {
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
sliceParams := params.AsSlice()
paramsMap := params.AsMap()
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
if err != nil {
return nil, fmt.Errorf("unable to extract template params %w", err)
}
results, err := t.Pool.QueryContext(ctx, t.Statement, sliceParams...)
newParams, err := tools.GetParams(t.Parameters, paramsMap)
if err != nil {
return nil, fmt.Errorf("unable to extract standard params %w", err)
}
sliceParams := newParams.AsSlice()
results, err := t.Pool.QueryContext(ctx, newStatement, sliceParams...)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
@@ -179,7 +196,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, erro
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
return tools.ParseParams(t.Parameters, data, claims)
return tools.ParseParams(t.AllParams, data, claims)
}
func (t Tool) Manifest() tools.Manifest {

View File

@@ -90,5 +90,86 @@ func TestParseFromYamlMySQL(t *testing.T) {
}
})
}
}
func TestParseFromYamlWithTemplateParamsMySQL(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
tcs := []struct {
desc string
in string
want server.ToolConfigs
}{
{
desc: "basic example",
in: `
tools:
example_tool:
kind: mysql-sql
source: my-mysql-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
authRequired:
- my-google-auth-service
- other-auth-service
parameters:
- name: country
type: string
description: some description
authServices:
- name: my-google-auth-service
field: user_id
- name: other-auth-service
field: user_id
templateParameters:
- name: tableName
type: string
description: The table to select hotels from.
- name: fieldArray
type: array
description: The columns to return for the query.
items:
name: column
type: string
description: A column name that will be returned from the query.
`,
want: server.ToolConfigs{
"example_tool": mysqlsql.Config{
Name: "example_tool",
Kind: "mysql-sql",
Source: "my-mysql-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
Parameters: []tools.Parameter{
tools.NewStringParameterWithAuth("country", "some description",
[]tools.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"},
{Name: "other-auth-service", Field: "user_id"}}),
},
TemplateParameters: []tools.Parameter{
tools.NewStringParameter("tableName", "The table to select hotels from."),
tools.NewArrayParameter("fieldArray", "The columns to return for the query.", tools.NewStringParameter("column", "A column name that will be returned from the query.")),
},
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Tools server.ToolConfigs `yaml:"tools"`
}{}
// Parse contents
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
t.Fatalf("incorrect parse: diff %v", diff)
}
})
}
}

View File

@@ -112,6 +112,7 @@ func TestCloudSQLMysqlToolEndpoints(t *testing.T) {
// create table name with UUID
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
// set up data for param tool
create_statement1, insert_statement1, tool_statement1, params1 := tests.GetMysqlParamToolInfo(tableNameParam)
@@ -126,6 +127,8 @@ func TestCloudSQLMysqlToolEndpoints(t *testing.T) {
// Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, CLOUD_SQL_MYSQL_TOOL_KIND, tool_statement1, tool_statement2)
toolsFile = tests.AddMySqlExecuteSqlConfig(t, toolsFile)
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMysqlSQLTmplToolStatement()
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CLOUD_SQL_MYSQL_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined)
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {
@@ -148,6 +151,7 @@ func TestCloudSQLMysqlToolEndpoints(t *testing.T) {
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam)
}
// Test connection with different IP type

View File

@@ -297,7 +297,7 @@ func GetMssqlLAuthToolInfo(tableName string) (string, string, string, []any) {
return create_statement, insert_statement, tool_statement, params
}
// GetMysqlParamToolInfo returns statements and param for my-param-tool mssql-sql kind
// GetMysqlParamToolInfo returns statements and param for my-param-tool mysql-sql kind
func GetMysqlParamToolInfo(tableName string) (string, string, string, []any) {
create_statement := fmt.Sprintf("CREATE TABLE %s (id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255));", tableName)
insert_statement := fmt.Sprintf("INSERT INTO %s (name) VALUES (?), (?), (?);", tableName)
@@ -306,7 +306,7 @@ func GetMysqlParamToolInfo(tableName string) (string, string, string, []any) {
return create_statement, insert_statement, tool_statement, params
}
// GetMysqlLAuthToolInfo returns statements and param of my-auth-tool for mssql-sql kind
// GetMysqlLAuthToolInfo returns statements and param of my-auth-tool for mysql-sql kind
func GetMysqlLAuthToolInfo(tableName string) (string, string, string, []any) {
create_statement := fmt.Sprintf("CREATE TABLE %s (id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255), email VARCHAR(255));", tableName)
insert_statement := fmt.Sprintf("INSERT INTO %s (name, email) VALUES (?, ?), (?, ?)", tableName)
@@ -315,6 +315,13 @@ func GetMysqlLAuthToolInfo(tableName string) (string, string, string, []any) {
return create_statement, insert_statement, tool_statement, params
}
// GetMysqlSQLTmplToolStatement returns statements and param for template parameter test cases for mysql-sql kind
func GetMysqlSQLTmplToolStatement() (string, string) {
tmplSelectCombined := "SELECT * FROM {{.tableName}} WHERE id = ?"
tmplSelectFilterCombined := "SELECT * FROM {{.tableName}} WHERE {{.columnFilter}} = ?"
return tmplSelectCombined, tmplSelectFilterCombined
}
func GetNonSpannerInvokeParamWant() (string, string) {
invokeParamWant := "[{\"id\":1,\"name\":\"Alice\"},{\"id\":3,\"name\":\"Sid\"}]"
mcpInvokeParamWant := `{"jsonrpc":"2.0","id":"my-param-tool","result":{"content":[{"type":"text","text":"{\"id\":1,\"name\":\"Alice\"}"},{"type":"text","text":"{\"id\":3,\"name\":\"Sid\"}"}]}}`

View File

@@ -89,6 +89,7 @@ func TestMysqlToolEndpoints(t *testing.T) {
// create table name with UUID
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
// set up data for param tool
create_statement1, insert_statement1, tool_statement1, params1 := tests.GetMysqlParamToolInfo(tableNameParam)
@@ -103,6 +104,8 @@ func TestMysqlToolEndpoints(t *testing.T) {
// Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, MYSQL_TOOL_KIND, tool_statement1, tool_statement2)
toolsFile = tests.AddMySqlExecuteSqlConfig(t, toolsFile)
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMysqlSQLTmplToolStatement()
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, MYSQL_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined)
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {
@@ -125,4 +128,5 @@ func TestMysqlToolEndpoints(t *testing.T) {
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam)
}