feat: add support for array type parameters (#26)

Adds support for "array" type parameters. Uses a subet of JSONSchema for
specification, in that arrays can be specified in the following way:

```yaml
parameters:
    name: "my_array"
    type: "array"
    description: "some description"
    items:
       type: "integer"
```
This commit is contained in:
Kurtis Van Gent
2024-10-25 15:54:14 -06:00
committed by GitHub
parent 91664d2cb7
commit 3903e860bc
4 changed files with 304 additions and 74 deletions

View File

@@ -62,7 +62,7 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
toolName := chi.URLParam(r, "toolName")
tool, ok := s.tools[toolName]
if !ok {
err := fmt.Errorf("Invalid tool name. Tool with name %q does not exist", toolName)
err := fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
_ = render.Render(w, r, newErrResponse(err, http.StatusNotFound))
return
}
@@ -70,21 +70,21 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
var data map[string]interface{}
if err := render.DecodeJSON(r.Body, &data); err != nil {
render.Status(r, http.StatusBadRequest)
err := fmt.Errorf("Request body was invalid JSON: %w", err)
err := fmt.Errorf("request body was invalid JSON: %w", err)
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
return
}
params, err := tool.ParseParams(data)
if err != nil {
err := fmt.Errorf("Provided parameters were invalid: %w", err)
err := fmt.Errorf("provided parameters were invalid: %w", err)
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
return
}
res, err := tool.Invoke(params)
if err != nil {
err := fmt.Errorf("Error while invoking tool: %w", err)
err := fmt.Errorf("error while invoking tool: %w", err)
_ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError))
return
}

View File

@@ -46,12 +46,12 @@ func (r CloudSQLPgConfig) sourceKind() string {
func (r CloudSQLPgConfig) Initialize() (Source, error) {
pool, err := initConnectionPool(r.Project, r.Region, r.Instance, r.User, r.Password, r.Database)
if err != nil {
return nil, fmt.Errorf("Unable to create pool: %w", err)
return nil, fmt.Errorf("unable to create pool: %w", err)
}
err = pool.Ping(context.Background())
if err != nil {
return nil, fmt.Errorf("Unable to connect successfully: %w", err)
return nil, fmt.Errorf("unable to connect successfully: %w", err)
}
s := CloudSQLPgSource{
@@ -72,16 +72,16 @@ type CloudSQLPgSource struct {
func initConnectionPool(project, region, instance, user, pass, dbname string) (*pgxpool.Pool, error) {
// Configure the driver to connect to the database
dsn := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=disable", user, pass, dbname)
dsn := fmt.Sprintf("user=%s passwgit sord=%s dbname=%s sslmode=disable", user, pass, dbname)
config, err := pgxpool.ParseConfig(dsn)
if err != nil {
return nil, fmt.Errorf("Unable to parse connection uri: %w", err)
return nil, fmt.Errorf("unable to parse connection uri: %w", err)
}
// Create a new dialer with any options
d, err := cloudsqlconn.NewDialer(context.Background())
if err != nil {
return nil, fmt.Errorf("Unable to parse connection uri: %w", err)
return nil, fmt.Errorf("unable to parse connection uri: %w", err)
}
// Tell the driver to use the Cloud SQL Go Connector to create connections

View File

@@ -22,11 +22,29 @@ import (
const (
typeString = "string"
typeInt = "int"
typeInt = "integer"
typeFloat = "float"
typeBool = "bool"
typeBool = "boolean"
typeArray = "array"
)
// ParseParams is a helper function for parsing Parameters from an arbitraryJSON object.
func ParseParams(ps Parameters, data map[string]any) ([]any, error) {
params := []any{}
for _, p := range ps {
v, ok := data[p.GetName()]
if !ok {
return nil, fmt.Errorf("parameter %q is required!", p.GetName())
}
newV, err := p.Parse(v)
if err != nil {
return nil, fmt.Errorf("unable to parse value for %q: %w", p.GetName(), err)
}
params = append(params, newV)
}
return params, nil
}
type Parameter interface {
// Note: It's typically not idiomatic to include "Get" in the function name,
// but this is done to differentiate it from the fields in CommonParameter.
@@ -47,44 +65,56 @@ func (c *Parameters) UnmarshalYAML(node *yaml.Node) error {
return err
}
for _, n := range nodeList {
var p CommonParameter
err := n.Decode(&p)
p, err := parseFromYamlNode(&n)
if err != nil {
return fmt.Errorf("parameter missing required fields")
}
switch p.Type {
case typeString:
a := &StringParameter{}
if err := n.Decode(a); err != nil {
return fmt.Errorf("unable to parse as %q: %w", p.Type, err)
}
*c = append(*c, a)
case typeInt:
a := &IntParameter{}
if err := n.Decode(a); err != nil {
return fmt.Errorf("unable to parse as %q: %w", p.Type, err)
}
*c = append(*c, a)
case typeFloat:
a := &FloatParameter{}
if err := n.Decode(a); err != nil {
return fmt.Errorf("unable to parse as %q: %w", p.Type, err)
}
*c = append(*c, a)
case typeBool:
a := &BooleanParameter{}
if err := n.Decode(a); err != nil {
return fmt.Errorf("unable to parse as %q: %w", p.Type, err)
}
*c = append(*c, a)
default:
return fmt.Errorf("%q is not valid type for a parameter!", p.GetName())
return err
}
(*c) = append((*c), p)
}
return nil
}
func parseFromYamlNode(node *yaml.Node) (Parameter, error) {
var p CommonParameter
err := node.Decode(&p)
if err != nil {
return nil, fmt.Errorf("parameter missing required fields")
}
switch p.Type {
case typeString:
a := &StringParameter{}
if err := node.Decode(a); err != nil {
return nil, fmt.Errorf("unable to parse as %q: %w", p.Type, err)
}
return a, nil
case typeInt:
a := &IntParameter{}
if err := node.Decode(a); err != nil {
return nil, fmt.Errorf("unable to parse as %q: %w", p.Type, err)
}
return a, nil
case typeFloat:
a := &FloatParameter{}
if err := node.Decode(a); err != nil {
return nil, fmt.Errorf("unable to parse as %q: %w", p.Type, err)
}
return a, nil
case typeBool:
a := &BooleanParameter{}
if err := node.Decode(a); err != nil {
return nil, fmt.Errorf("unable to parse as %q: %w", p.Type, err)
}
return a, nil
case typeArray:
a := &ArrayParameter{}
if err := node.Decode(a); err != nil {
return nil, fmt.Errorf("unable to parse as %q: %w", p.Type, err)
}
return a, nil
}
return nil, fmt.Errorf("%q is not valid type for a parameter!", p.Type)
}
func generateManifests(ps []Parameter) []ParameterManifest {
rtn := make([]ParameterManifest, 0, len(ps))
for _, p := range ps {
@@ -135,7 +165,7 @@ type ParseTypeError struct {
}
func (e ParseTypeError) Error() string {
return fmt.Sprintf("Error parsing parameter %q: %q not type %q", e.Name, e.Value, e.Type)
return fmt.Sprintf("%q not type %q", e.Value, e.Type)
}
// NewStringParameter is a convenience function for initializing a StringParameter.
@@ -158,11 +188,11 @@ type StringParameter struct {
// Parse casts the value "v" as a "string".
func (p *StringParameter) Parse(v any) (any, error) {
v, ok := v.(string)
newV, ok := v.(string)
if !ok {
return nil, &ParseTypeError{p.Name, p.Type, v}
}
return v, nil
return newV, nil
}
// NewIntParameter is a convenience function for initializing a IntParameter.
@@ -184,11 +214,11 @@ type IntParameter struct {
}
func (p *IntParameter) Parse(v any) (any, error) {
v, ok := v.(int64)
newV, ok := v.(int)
if !ok {
return nil, &ParseTypeError{p.Name, p.Type, v}
}
return v, nil
return newV, nil
}
// NewFloatParameter is a convenience function for initializing a FloatParameter.
@@ -210,11 +240,11 @@ type FloatParameter struct {
}
func (p *FloatParameter) Parse(v any) (any, error) {
v, ok := v.(float64)
newV, ok := v.(float64)
if !ok {
return nil, &ParseTypeError{p.Name, p.Type, v}
}
return v, nil
return newV, nil
}
// NewBooleanParameter is a convenience function for initializing a BooleanParameter.
@@ -236,26 +266,73 @@ type BooleanParameter struct {
}
func (p *BooleanParameter) Parse(v any) (any, error) {
v, ok := v.(bool)
newV, ok := v.(bool)
if !ok {
return nil, &ParseTypeError{p.Name, p.Type, v}
}
return v, nil
return newV, nil
}
// ParseParams is a helper function for parsing Parameters from an arbitraryJSON object.
func ParseParams(ps Parameters, data map[string]any) ([]any, error) {
params := []any{}
for _, p := range ps {
v, ok := data[p.GetName()]
if !ok {
return nil, fmt.Errorf("Parameter %q is required!", p.GetName())
}
v, err := p.Parse(v)
if err != nil {
return nil, fmt.Errorf("unable to parse value for %q: %w", p.GetName(), err)
}
params = append(params, v)
// NewArrayParameter is a convenience function for initializing an ArrayParameter.
func NewArrayParameter(name, desc string, items Parameter) *ArrayParameter {
return &ArrayParameter{
CommonParameter: CommonParameter{
Name: name,
Type: typeArray,
Desc: desc,
},
Items: items,
}
return params, nil
}
var _ Parameter = &ArrayParameter{}
// ArrayParameter is a parameter representing the "array" type.
type ArrayParameter struct {
CommonParameter `yaml:",inline"`
Items Parameter `yaml:"items"`
}
func (p *ArrayParameter) UnmarshalYAML(node *yaml.Node) error {
if err := node.Decode(&p.CommonParameter); err != nil {
return err
}
// Find the node that represents the "items" field name
idx, ok := findIdxByValue(node.Content, "items")
if !ok {
return fmt.Errorf("array parameter missing 'items' field!")
}
// Parse items from the "value" of "items" field
i, err := parseFromYamlNode(node.Content[idx+1])
if err != nil {
return fmt.Errorf("unable to parse 'items' field: %w", err)
}
p.Items = i
return nil
}
// findIdxByValue returns the index of the first node where value matches
func findIdxByValue(nodes []*yaml.Node, value string) (int, bool) {
for idx, n := range nodes {
if n.Value == value {
return idx, true
}
}
return 0, false
}
func (p *ArrayParameter) Parse(v any) (any, error) {
arrVal, ok := v.([]any)
if !ok {
return nil, &ParseTypeError{p.Name, p.Type, arrVal}
}
rtn := make([]any, 0, len(arrVal))
for idx, val := range arrVal {
val, err := p.Items.Parse(val)
if err != nil {
return nil, fmt.Errorf("unable to parse element #%d: %w", idx, err)
}
rtn = append(rtn, val)
}
return rtn, nil
}

View File

@@ -15,6 +15,7 @@
package tools_test
import (
"reflect"
"testing"
"github.com/google/go-cmp/cmp"
@@ -22,15 +23,15 @@ import (
"gopkg.in/yaml.v3"
)
func TestParameters(t *testing.T) {
func TestParametersMarhsall(t *testing.T) {
tcs := []struct {
name string
in []map[string]string
in []map[string]any
want tools.Parameters
}{
{
name: "string",
in: []map[string]string{
in: []map[string]any{
{
"name": "my_string",
"type": "string",
@@ -43,10 +44,10 @@ func TestParameters(t *testing.T) {
},
{
name: "int",
in: []map[string]string{
in: []map[string]any{
{
"name": "my_integer",
"type": "int",
"type": "integer",
"description": "this param is an int",
},
},
@@ -56,7 +57,7 @@ func TestParameters(t *testing.T) {
},
{
name: "float",
in: []map[string]string{
in: []map[string]any{
{
"name": "my_float",
"type": "float",
@@ -69,10 +70,10 @@ func TestParameters(t *testing.T) {
},
{
name: "bool",
in: []map[string]string{
in: []map[string]any{
{
"name": "my_bool",
"type": "bool",
"type": "boolean",
"description": "this param is a boolean",
},
},
@@ -80,6 +81,38 @@ func TestParameters(t *testing.T) {
tools.NewBooleanParameter("my_bool", "this param is a boolean"),
},
},
{
name: "string array",
in: []map[string]any{
{
"name": "my_array",
"type": "array",
"description": "this param is an array of strings",
"items": map[string]string{
"type": "string",
},
},
},
want: tools.Parameters{
tools.NewArrayParameter("my_array", "this param is an array of strings", tools.NewStringParameter("", "")),
},
},
{
name: "float array",
in: []map[string]any{
{
"name": "my_array",
"type": "array",
"description": "this param is an array of floats",
"items": map[string]string{
"type": "float",
},
},
},
want: tools.Parameters{
tools.NewArrayParameter("my_array", "this param is an array of floats", tools.NewFloatParameter("", "")),
},
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
@@ -100,3 +133,123 @@ func TestParameters(t *testing.T) {
})
}
}
func TestParametersParse(t *testing.T) {
tcs := []struct {
name string
params tools.Parameters
in map[string]any
want []any
}{
{
name: "string",
params: tools.Parameters{
tools.NewStringParameter("my_string", "this param is a string"),
},
in: map[string]any{
"my_string": "hello world",
},
want: []any{"hello world"},
},
{
name: "not string",
params: tools.Parameters{
tools.NewStringParameter("my_string", "this param is a string"),
},
in: map[string]any{
"my_string": 4,
},
},
{
name: "int",
params: tools.Parameters{
tools.NewIntParameter("my_int", "this param is an int"),
},
in: map[string]any{
"my_int": 100,
},
want: []any{100},
},
{
name: "not int",
params: tools.Parameters{
tools.NewIntParameter("my_int", "this param is an int"),
},
in: map[string]any{
"my_int": 14.5,
},
},
{
name: "float",
params: tools.Parameters{
tools.NewFloatParameter("my_float", "this param is a float"),
},
in: map[string]any{
"my_float": 1.5,
},
want: []any{1.5},
},
{
name: "not float",
params: tools.Parameters{
tools.NewFloatParameter("my_float", "this param is a float"),
},
in: map[string]any{
"my_float": true,
},
},
{
name: "bool",
params: tools.Parameters{
tools.NewBooleanParameter("my_bool", "this param is a bool"),
},
in: map[string]any{
"my_bool": true,
},
want: []any{true},
},
{
name: "bool",
params: tools.Parameters{
tools.NewBooleanParameter("my_bool", "this param is a bool"),
},
in: map[string]any{
"my_bool": "true",
},
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
// parse map to bytes
data, err := yaml.Marshal(tc.in)
if err != nil {
t.Fatalf("unable to marshal input to yaml: %s", err)
}
// parse bytes to object
var m map[string]any
err = yaml.Unmarshal(data, &m)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
gotAll, err := tools.ParseParams(tc.params, m)
if err != nil {
if len(tc.want) == 0 {
// error is expected if no items in want
return
}
t.Fatalf("unexpected error from ParseParams: %s", err)
}
for i, got := range gotAll {
want := tc.want[i]
if got != want {
t.Fatalf("unexpected value: got %q, want %q", got, want)
}
gotType, wantType := reflect.TypeOf(got), reflect.TypeOf(want)
if gotType != wantType {
t.Fatalf("unexpected value: got %q, want %q", got, want)
}
}
})
}
}