diff --git a/internal/tools/sqlitesql/sqlitesql.go b/internal/tools/sqlitesql/sqlitesql.go index 254b303ea5..6cb35c4e98 100644 --- a/internal/tools/sqlitesql/sqlitesql.go +++ b/internal/tools/sqlitesql/sqlitesql.go @@ -51,13 +51,14 @@ var _ compatibleSource = &sqlite.Source{} var compatibleSources = [...]string{sqlite.SourceKind} type Config struct { - Name string `yaml:"name" validate:"required"` - Kind string `yaml:"kind" validate:"required"` - Source string `yaml:"source" validate:"required"` - Description string `yaml:"description" validate:"required"` - Statement string `yaml:"statement" validate:"required"` - AuthRequired []string `yaml:"authRequired"` - Parameters tools.Parameters `yaml:"parameters"` + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + Statement string `yaml:"statement" validate:"required"` + AuthRequired []string `yaml:"authRequired"` + Parameters tools.Parameters `yaml:"parameters"` + TemplateParameters tools.Parameters `yaml:"templateParameters"` } // validate interface @@ -80,22 +81,26 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) } + allParameters, paramManifest, paramMcpManifest := tools.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) + mcpManifest := tools.McpManifest{ Name: cfg.Name, Description: cfg.Description, - InputSchema: cfg.Parameters.McpManifest(), + InputSchema: paramMcpManifest, } // finish tool setup t := Tool{ - Name: cfg.Name, - Kind: kind, - Parameters: cfg.Parameters, - Statement: cfg.Statement, - AuthRequired: cfg.AuthRequired, - Db: s.SQLiteDB(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Name: cfg.Name, + Kind: kind, + Parameters: cfg.Parameters, + TemplateParameters: cfg.TemplateParameters, + AllParams: allParameters, + Statement: cfg.Statement, + AuthRequired: cfg.AuthRequired, + Db: s.SQLiteDB(), + manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -104,10 +109,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) var _ tools.Tool = Tool{} type Tool struct { - Name string `yaml:"name"` - Kind string `yaml:"kind"` - AuthRequired []string `yaml:"authRequired"` - Parameters tools.Parameters `yaml:"parameters"` + Name string `yaml:"name"` + Kind string `yaml:"kind"` + AuthRequired []string `yaml:"authRequired"` + Parameters tools.Parameters `yaml:"parameters"` + TemplateParameters tools.Parameters `yaml:"templateParameters"` + AllParams tools.Parameters `yaml:"allParams"` Db *sql.DB Statement string `yaml:"statement"` @@ -116,8 +123,19 @@ type Tool struct { } func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) { + paramsMap := params.AsMap() + newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) + if err != nil { + return nil, fmt.Errorf("unable to extract template params %w", err) + } + + newParams, err := tools.GetParams(t.Parameters, paramsMap) + if err != nil { + return nil, fmt.Errorf("unable to extract standard params %w", err) + } + // Execute the SQL query with parameters - rows, err := t.Db.QueryContext(ctx, t.Statement, params.AsSlice()...) + rows, err := t.Db.QueryContext(ctx, newStatement, newParams.AsSlice()...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -171,7 +189,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, erro } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) { - return tools.ParseParams(t.Parameters, data, claims) + return tools.ParseParams(t.AllParams, data, claims) } func (t Tool) Manifest() tools.Manifest { diff --git a/internal/tools/sqlitesql/sqlitesql_test.go b/internal/tools/sqlitesql/sqlitesql_test.go index 7da958e1a1..4f4fcfec21 100644 --- a/internal/tools/sqlitesql/sqlitesql_test.go +++ b/internal/tools/sqlitesql/sqlitesql_test.go @@ -92,3 +92,85 @@ func TestParseFromYamlSQLite(t *testing.T) { } } + +func TestParseFromYamlWithTemplateSqlite(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: sqlite-sql + source: my-sqlite-db + description: some description + statement: | + SELECT * FROM SQL_STATEMENT; + authRequired: + - my-google-auth-service + - other-auth-service + parameters: + - name: country + type: string + description: some description + authServices: + - name: my-google-auth-service + field: user_id + - name: other-auth-service + field: user_id + templateParameters: + - name: tableName + type: string + description: The table to select hotels from. + - name: fieldArray + type: array + description: The columns to return for the query. + items: + name: column + type: string + description: A column name that will be returned from the query. + `, + want: server.ToolConfigs{ + "example_tool": sqlitesql.Config{ + Name: "example_tool", + Kind: "sqlite-sql", + Source: "my-sqlite-db", + Description: "some description", + Statement: "SELECT * FROM SQL_STATEMENT;\n", + AuthRequired: []string{"my-google-auth-service", "other-auth-service"}, + Parameters: []tools.Parameter{ + tools.NewStringParameterWithAuth("country", "some description", + []tools.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"}, + {Name: "other-auth-service", Field: "user_id"}}), + }, + TemplateParameters: []tools.Parameter{ + tools.NewStringParameter("tableName", "The table to select hotels from."), + tools.NewArrayParameter("fieldArray", "The columns to return for the query.", tools.NewStringParameter("column", "A column name that will be returned from the query.")), + }, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + // Parse contents + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} diff --git a/tests/sqlite/sqlite_integration_test.go b/tests/sqlite/sqlite_integration_test.go index db1e80d5d5..c9a7a12309 100644 --- a/tests/sqlite/sqlite_integration_test.go +++ b/tests/sqlite/sqlite_integration_test.go @@ -96,6 +96,12 @@ func getSQLiteAuthToolInfo(tableName string) (string, string, string, []any) { return create_statement, insert_statement, tool_statement, params } +func getSQLiteTmplToolStatement() (string, string) { + tmplSelectCombined := "SELECT * FROM {{.tableName}} WHERE id = ?" + tmplSelectFilterCombined := "SELECT * FROM {{.tableName}} WHERE {{.columnFilter}} = ?" + return tmplSelectCombined, tmplSelectFilterCombined +} + func TestSQLiteToolEndpoint(t *testing.T) { db, teardownDb, sqliteDb, err := initSQLiteDb(t, SQLITE_DATABASE) if err != nil { @@ -114,6 +120,7 @@ func TestSQLiteToolEndpoint(t *testing.T) { // create table name with UUID tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") + tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") // set up data for param tool create_statement1, insert_statement1, tool_statement1, params1 := getSQLiteParamToolInfo(tableNameParam) @@ -125,6 +132,8 @@ func TestSQLiteToolEndpoint(t *testing.T) { // Write config into a file and pass it to command toolsFile := tests.GetToolsConfig(sourceConfig, SQLITE_TOOL_KIND, tool_statement1, tool_statement2) + tmplSelectCombined, tmplSelectFilterCombined := getSQLiteTmplToolStatement() + toolsFile = tests.AddTemplateParamConfig(t, toolsFile, SQLITE_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined) cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) if err != nil { @@ -147,4 +156,5 @@ func TestSQLiteToolEndpoint(t *testing.T) { invokeParamWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant() tests.RunToolInvokeTest(t, select1Want, invokeParamWant) tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant) + tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam) }