diff --git a/internal/server/api.go b/internal/server/api.go index 0be594e314..7b2a447501 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -62,7 +62,7 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) { toolName := chi.URLParam(r, "toolName") tool, ok := s.tools[toolName] if !ok { - err := fmt.Errorf("Invalid tool name. Tool with name %q does not exist", toolName) + err := fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) _ = render.Render(w, r, newErrResponse(err, http.StatusNotFound)) return } @@ -70,21 +70,21 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) { var data map[string]interface{} if err := render.DecodeJSON(r.Body, &data); err != nil { render.Status(r, http.StatusBadRequest) - err := fmt.Errorf("Request body was invalid JSON: %w", err) + err := fmt.Errorf("request body was invalid JSON: %w", err) _ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest)) return } params, err := tool.ParseParams(data) if err != nil { - err := fmt.Errorf("Provided parameters were invalid: %w", err) + err := fmt.Errorf("provided parameters were invalid: %w", err) _ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest)) return } res, err := tool.Invoke(params) if err != nil { - err := fmt.Errorf("Error while invoking tool: %w", err) + err := fmt.Errorf("error while invoking tool: %w", err) _ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError)) return } diff --git a/internal/sources/cloud_sql_pg.go b/internal/sources/cloud_sql_pg.go index 8fc525ebf0..c127c72d97 100644 --- a/internal/sources/cloud_sql_pg.go +++ b/internal/sources/cloud_sql_pg.go @@ -46,12 +46,12 @@ func (r CloudSQLPgConfig) sourceKind() string { func (r CloudSQLPgConfig) Initialize() (Source, error) { pool, err := initConnectionPool(r.Project, r.Region, r.Instance, r.User, r.Password, r.Database) if err != nil { - return nil, fmt.Errorf("Unable to create pool: %w", err) + return nil, fmt.Errorf("unable to create pool: %w", err) } err = pool.Ping(context.Background()) if err != nil { - return nil, fmt.Errorf("Unable to connect successfully: %w", err) + return nil, fmt.Errorf("unable to connect successfully: %w", err) } s := CloudSQLPgSource{ @@ -72,16 +72,16 @@ type CloudSQLPgSource struct { func initConnectionPool(project, region, instance, user, pass, dbname string) (*pgxpool.Pool, error) { // Configure the driver to connect to the database - dsn := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=disable", user, pass, dbname) + dsn := fmt.Sprintf("user=%s passwgit sord=%s dbname=%s sslmode=disable", user, pass, dbname) config, err := pgxpool.ParseConfig(dsn) if err != nil { - return nil, fmt.Errorf("Unable to parse connection uri: %w", err) + return nil, fmt.Errorf("unable to parse connection uri: %w", err) } // Create a new dialer with any options d, err := cloudsqlconn.NewDialer(context.Background()) if err != nil { - return nil, fmt.Errorf("Unable to parse connection uri: %w", err) + return nil, fmt.Errorf("unable to parse connection uri: %w", err) } // Tell the driver to use the Cloud SQL Go Connector to create connections diff --git a/internal/tools/parameters.go b/internal/tools/parameters.go index 82a1e2b0e9..3a5f304ed2 100644 --- a/internal/tools/parameters.go +++ b/internal/tools/parameters.go @@ -22,11 +22,29 @@ import ( const ( typeString = "string" - typeInt = "int" + typeInt = "integer" typeFloat = "float" - typeBool = "bool" + typeBool = "boolean" + typeArray = "array" ) +// ParseParams is a helper function for parsing Parameters from an arbitraryJSON object. +func ParseParams(ps Parameters, data map[string]any) ([]any, error) { + params := []any{} + for _, p := range ps { + v, ok := data[p.GetName()] + if !ok { + return nil, fmt.Errorf("parameter %q is required!", p.GetName()) + } + newV, err := p.Parse(v) + if err != nil { + return nil, fmt.Errorf("unable to parse value for %q: %w", p.GetName(), err) + } + params = append(params, newV) + } + return params, nil +} + type Parameter interface { // Note: It's typically not idiomatic to include "Get" in the function name, // but this is done to differentiate it from the fields in CommonParameter. @@ -47,44 +65,56 @@ func (c *Parameters) UnmarshalYAML(node *yaml.Node) error { return err } for _, n := range nodeList { - var p CommonParameter - err := n.Decode(&p) + p, err := parseFromYamlNode(&n) if err != nil { - return fmt.Errorf("parameter missing required fields") - } - switch p.Type { - case typeString: - a := &StringParameter{} - if err := n.Decode(a); err != nil { - return fmt.Errorf("unable to parse as %q: %w", p.Type, err) - } - *c = append(*c, a) - case typeInt: - a := &IntParameter{} - if err := n.Decode(a); err != nil { - return fmt.Errorf("unable to parse as %q: %w", p.Type, err) - } - *c = append(*c, a) - case typeFloat: - a := &FloatParameter{} - if err := n.Decode(a); err != nil { - return fmt.Errorf("unable to parse as %q: %w", p.Type, err) - } - *c = append(*c, a) - case typeBool: - a := &BooleanParameter{} - if err := n.Decode(a); err != nil { - return fmt.Errorf("unable to parse as %q: %w", p.Type, err) - } - *c = append(*c, a) - default: - return fmt.Errorf("%q is not valid type for a parameter!", p.GetName()) + return err } + (*c) = append((*c), p) } - return nil } +func parseFromYamlNode(node *yaml.Node) (Parameter, error) { + var p CommonParameter + err := node.Decode(&p) + if err != nil { + return nil, fmt.Errorf("parameter missing required fields") + } + switch p.Type { + case typeString: + a := &StringParameter{} + if err := node.Decode(a); err != nil { + return nil, fmt.Errorf("unable to parse as %q: %w", p.Type, err) + } + return a, nil + case typeInt: + a := &IntParameter{} + if err := node.Decode(a); err != nil { + return nil, fmt.Errorf("unable to parse as %q: %w", p.Type, err) + } + return a, nil + case typeFloat: + a := &FloatParameter{} + if err := node.Decode(a); err != nil { + return nil, fmt.Errorf("unable to parse as %q: %w", p.Type, err) + } + return a, nil + case typeBool: + a := &BooleanParameter{} + if err := node.Decode(a); err != nil { + return nil, fmt.Errorf("unable to parse as %q: %w", p.Type, err) + } + return a, nil + case typeArray: + a := &ArrayParameter{} + if err := node.Decode(a); err != nil { + return nil, fmt.Errorf("unable to parse as %q: %w", p.Type, err) + } + return a, nil + } + return nil, fmt.Errorf("%q is not valid type for a parameter!", p.Type) +} + func generateManifests(ps []Parameter) []ParameterManifest { rtn := make([]ParameterManifest, 0, len(ps)) for _, p := range ps { @@ -135,7 +165,7 @@ type ParseTypeError struct { } func (e ParseTypeError) Error() string { - return fmt.Sprintf("Error parsing parameter %q: %q not type %q", e.Name, e.Value, e.Type) + return fmt.Sprintf("%q not type %q", e.Value, e.Type) } // NewStringParameter is a convenience function for initializing a StringParameter. @@ -158,11 +188,11 @@ type StringParameter struct { // Parse casts the value "v" as a "string". func (p *StringParameter) Parse(v any) (any, error) { - v, ok := v.(string) + newV, ok := v.(string) if !ok { return nil, &ParseTypeError{p.Name, p.Type, v} } - return v, nil + return newV, nil } // NewIntParameter is a convenience function for initializing a IntParameter. @@ -184,11 +214,11 @@ type IntParameter struct { } func (p *IntParameter) Parse(v any) (any, error) { - v, ok := v.(int64) + newV, ok := v.(int) if !ok { return nil, &ParseTypeError{p.Name, p.Type, v} } - return v, nil + return newV, nil } // NewFloatParameter is a convenience function for initializing a FloatParameter. @@ -210,11 +240,11 @@ type FloatParameter struct { } func (p *FloatParameter) Parse(v any) (any, error) { - v, ok := v.(float64) + newV, ok := v.(float64) if !ok { return nil, &ParseTypeError{p.Name, p.Type, v} } - return v, nil + return newV, nil } // NewBooleanParameter is a convenience function for initializing a BooleanParameter. @@ -236,26 +266,73 @@ type BooleanParameter struct { } func (p *BooleanParameter) Parse(v any) (any, error) { - v, ok := v.(bool) + newV, ok := v.(bool) if !ok { return nil, &ParseTypeError{p.Name, p.Type, v} } - return v, nil + return newV, nil } -// ParseParams is a helper function for parsing Parameters from an arbitraryJSON object. -func ParseParams(ps Parameters, data map[string]any) ([]any, error) { - params := []any{} - for _, p := range ps { - v, ok := data[p.GetName()] - if !ok { - return nil, fmt.Errorf("Parameter %q is required!", p.GetName()) - } - v, err := p.Parse(v) - if err != nil { - return nil, fmt.Errorf("unable to parse value for %q: %w", p.GetName(), err) - } - params = append(params, v) +// NewArrayParameter is a convenience function for initializing an ArrayParameter. +func NewArrayParameter(name, desc string, items Parameter) *ArrayParameter { + return &ArrayParameter{ + CommonParameter: CommonParameter{ + Name: name, + Type: typeArray, + Desc: desc, + }, + Items: items, } - return params, nil +} + +var _ Parameter = &ArrayParameter{} + +// ArrayParameter is a parameter representing the "array" type. +type ArrayParameter struct { + CommonParameter `yaml:",inline"` + Items Parameter `yaml:"items"` +} + +func (p *ArrayParameter) UnmarshalYAML(node *yaml.Node) error { + if err := node.Decode(&p.CommonParameter); err != nil { + return err + } + // Find the node that represents the "items" field name + idx, ok := findIdxByValue(node.Content, "items") + if !ok { + return fmt.Errorf("array parameter missing 'items' field!") + } + // Parse items from the "value" of "items" field + i, err := parseFromYamlNode(node.Content[idx+1]) + if err != nil { + return fmt.Errorf("unable to parse 'items' field: %w", err) + } + p.Items = i + return nil +} + +// findIdxByValue returns the index of the first node where value matches +func findIdxByValue(nodes []*yaml.Node, value string) (int, bool) { + for idx, n := range nodes { + if n.Value == value { + return idx, true + } + } + return 0, false +} + +func (p *ArrayParameter) Parse(v any) (any, error) { + arrVal, ok := v.([]any) + if !ok { + return nil, &ParseTypeError{p.Name, p.Type, arrVal} + } + rtn := make([]any, 0, len(arrVal)) + for idx, val := range arrVal { + val, err := p.Items.Parse(val) + if err != nil { + return nil, fmt.Errorf("unable to parse element #%d: %w", idx, err) + } + rtn = append(rtn, val) + } + return rtn, nil } diff --git a/internal/tools/parameters_test.go b/internal/tools/parameters_test.go index 4a939e45c5..68788c5474 100644 --- a/internal/tools/parameters_test.go +++ b/internal/tools/parameters_test.go @@ -15,6 +15,7 @@ package tools_test import ( + "reflect" "testing" "github.com/google/go-cmp/cmp" @@ -22,15 +23,15 @@ import ( "gopkg.in/yaml.v3" ) -func TestParameters(t *testing.T) { +func TestParametersMarhsall(t *testing.T) { tcs := []struct { name string - in []map[string]string + in []map[string]any want tools.Parameters }{ { name: "string", - in: []map[string]string{ + in: []map[string]any{ { "name": "my_string", "type": "string", @@ -43,10 +44,10 @@ func TestParameters(t *testing.T) { }, { name: "int", - in: []map[string]string{ + in: []map[string]any{ { "name": "my_integer", - "type": "int", + "type": "integer", "description": "this param is an int", }, }, @@ -56,7 +57,7 @@ func TestParameters(t *testing.T) { }, { name: "float", - in: []map[string]string{ + in: []map[string]any{ { "name": "my_float", "type": "float", @@ -69,10 +70,10 @@ func TestParameters(t *testing.T) { }, { name: "bool", - in: []map[string]string{ + in: []map[string]any{ { "name": "my_bool", - "type": "bool", + "type": "boolean", "description": "this param is a boolean", }, }, @@ -80,6 +81,38 @@ func TestParameters(t *testing.T) { tools.NewBooleanParameter("my_bool", "this param is a boolean"), }, }, + { + name: "string array", + in: []map[string]any{ + { + "name": "my_array", + "type": "array", + "description": "this param is an array of strings", + "items": map[string]string{ + "type": "string", + }, + }, + }, + want: tools.Parameters{ + tools.NewArrayParameter("my_array", "this param is an array of strings", tools.NewStringParameter("", "")), + }, + }, + { + name: "float array", + in: []map[string]any{ + { + "name": "my_array", + "type": "array", + "description": "this param is an array of floats", + "items": map[string]string{ + "type": "float", + }, + }, + }, + want: tools.Parameters{ + tools.NewArrayParameter("my_array", "this param is an array of floats", tools.NewFloatParameter("", "")), + }, + }, } for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { @@ -100,3 +133,123 @@ func TestParameters(t *testing.T) { }) } } + +func TestParametersParse(t *testing.T) { + tcs := []struct { + name string + params tools.Parameters + in map[string]any + want []any + }{ + { + name: "string", + params: tools.Parameters{ + tools.NewStringParameter("my_string", "this param is a string"), + }, + in: map[string]any{ + "my_string": "hello world", + }, + want: []any{"hello world"}, + }, + { + name: "not string", + params: tools.Parameters{ + tools.NewStringParameter("my_string", "this param is a string"), + }, + in: map[string]any{ + "my_string": 4, + }, + }, + { + name: "int", + params: tools.Parameters{ + tools.NewIntParameter("my_int", "this param is an int"), + }, + in: map[string]any{ + "my_int": 100, + }, + want: []any{100}, + }, + { + name: "not int", + params: tools.Parameters{ + tools.NewIntParameter("my_int", "this param is an int"), + }, + in: map[string]any{ + "my_int": 14.5, + }, + }, + { + name: "float", + params: tools.Parameters{ + tools.NewFloatParameter("my_float", "this param is a float"), + }, + in: map[string]any{ + "my_float": 1.5, + }, + want: []any{1.5}, + }, + { + name: "not float", + params: tools.Parameters{ + tools.NewFloatParameter("my_float", "this param is a float"), + }, + in: map[string]any{ + "my_float": true, + }, + }, + { + name: "bool", + params: tools.Parameters{ + tools.NewBooleanParameter("my_bool", "this param is a bool"), + }, + in: map[string]any{ + "my_bool": true, + }, + want: []any{true}, + }, + { + name: "bool", + params: tools.Parameters{ + tools.NewBooleanParameter("my_bool", "this param is a bool"), + }, + in: map[string]any{ + "my_bool": "true", + }, + }, + } + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + // parse map to bytes + data, err := yaml.Marshal(tc.in) + if err != nil { + t.Fatalf("unable to marshal input to yaml: %s", err) + } + // parse bytes to object + var m map[string]any + err = yaml.Unmarshal(data, &m) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + + gotAll, err := tools.ParseParams(tc.params, m) + if err != nil { + if len(tc.want) == 0 { + // error is expected if no items in want + return + } + t.Fatalf("unexpected error from ParseParams: %s", err) + } + for i, got := range gotAll { + want := tc.want[i] + if got != want { + t.Fatalf("unexpected value: got %q, want %q", got, want) + } + gotType, wantType := reflect.TypeOf(got), reflect.TypeOf(want) + if gotType != wantType { + t.Fatalf("unexpected value: got %q, want %q", got, want) + } + } + }) + } +}