mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 15:38:08 -05:00
feat: add support for array type parameters (#26)
Adds support for "array" type parameters. Uses a subet of JSONSchema for
specification, in that arrays can be specified in the following way:
```yaml
parameters:
name: "my_array"
type: "array"
description: "some description"
items:
type: "integer"
```
This commit is contained in:
@@ -62,7 +62,7 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
toolName := chi.URLParam(r, "toolName")
|
||||
tool, ok := s.tools[toolName]
|
||||
if !ok {
|
||||
err := fmt.Errorf("Invalid tool name. Tool with name %q does not exist", toolName)
|
||||
err := fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusNotFound))
|
||||
return
|
||||
}
|
||||
@@ -70,21 +70,21 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
var data map[string]interface{}
|
||||
if err := render.DecodeJSON(r.Body, &data); err != nil {
|
||||
render.Status(r, http.StatusBadRequest)
|
||||
err := fmt.Errorf("Request body was invalid JSON: %w", err)
|
||||
err := fmt.Errorf("request body was invalid JSON: %w", err)
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
|
||||
return
|
||||
}
|
||||
|
||||
params, err := tool.ParseParams(data)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("Provided parameters were invalid: %w", err)
|
||||
err := fmt.Errorf("provided parameters were invalid: %w", err)
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
|
||||
return
|
||||
}
|
||||
|
||||
res, err := tool.Invoke(params)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("Error while invoking tool: %w", err)
|
||||
err := fmt.Errorf("error while invoking tool: %w", err)
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -46,12 +46,12 @@ func (r CloudSQLPgConfig) sourceKind() string {
|
||||
func (r CloudSQLPgConfig) Initialize() (Source, error) {
|
||||
pool, err := initConnectionPool(r.Project, r.Region, r.Instance, r.User, r.Password, r.Database)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to create pool: %w", err)
|
||||
return nil, fmt.Errorf("unable to create pool: %w", err)
|
||||
}
|
||||
|
||||
err = pool.Ping(context.Background())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to connect successfully: %w", err)
|
||||
return nil, fmt.Errorf("unable to connect successfully: %w", err)
|
||||
}
|
||||
|
||||
s := CloudSQLPgSource{
|
||||
@@ -72,16 +72,16 @@ type CloudSQLPgSource struct {
|
||||
|
||||
func initConnectionPool(project, region, instance, user, pass, dbname string) (*pgxpool.Pool, error) {
|
||||
// Configure the driver to connect to the database
|
||||
dsn := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=disable", user, pass, dbname)
|
||||
dsn := fmt.Sprintf("user=%s passwgit sord=%s dbname=%s sslmode=disable", user, pass, dbname)
|
||||
config, err := pgxpool.ParseConfig(dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to parse connection uri: %w", err)
|
||||
return nil, fmt.Errorf("unable to parse connection uri: %w", err)
|
||||
}
|
||||
|
||||
// Create a new dialer with any options
|
||||
d, err := cloudsqlconn.NewDialer(context.Background())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to parse connection uri: %w", err)
|
||||
return nil, fmt.Errorf("unable to parse connection uri: %w", err)
|
||||
}
|
||||
|
||||
// Tell the driver to use the Cloud SQL Go Connector to create connections
|
||||
|
||||
@@ -22,11 +22,29 @@ import (
|
||||
|
||||
const (
|
||||
typeString = "string"
|
||||
typeInt = "int"
|
||||
typeInt = "integer"
|
||||
typeFloat = "float"
|
||||
typeBool = "bool"
|
||||
typeBool = "boolean"
|
||||
typeArray = "array"
|
||||
)
|
||||
|
||||
// ParseParams is a helper function for parsing Parameters from an arbitraryJSON object.
|
||||
func ParseParams(ps Parameters, data map[string]any) ([]any, error) {
|
||||
params := []any{}
|
||||
for _, p := range ps {
|
||||
v, ok := data[p.GetName()]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("parameter %q is required!", p.GetName())
|
||||
}
|
||||
newV, err := p.Parse(v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse value for %q: %w", p.GetName(), err)
|
||||
}
|
||||
params = append(params, newV)
|
||||
}
|
||||
return params, nil
|
||||
}
|
||||
|
||||
type Parameter interface {
|
||||
// Note: It's typically not idiomatic to include "Get" in the function name,
|
||||
// but this is done to differentiate it from the fields in CommonParameter.
|
||||
@@ -47,44 +65,56 @@ func (c *Parameters) UnmarshalYAML(node *yaml.Node) error {
|
||||
return err
|
||||
}
|
||||
for _, n := range nodeList {
|
||||
var p CommonParameter
|
||||
err := n.Decode(&p)
|
||||
p, err := parseFromYamlNode(&n)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parameter missing required fields")
|
||||
}
|
||||
switch p.Type {
|
||||
case typeString:
|
||||
a := &StringParameter{}
|
||||
if err := n.Decode(a); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", p.Type, err)
|
||||
}
|
||||
*c = append(*c, a)
|
||||
case typeInt:
|
||||
a := &IntParameter{}
|
||||
if err := n.Decode(a); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", p.Type, err)
|
||||
}
|
||||
*c = append(*c, a)
|
||||
case typeFloat:
|
||||
a := &FloatParameter{}
|
||||
if err := n.Decode(a); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", p.Type, err)
|
||||
}
|
||||
*c = append(*c, a)
|
||||
case typeBool:
|
||||
a := &BooleanParameter{}
|
||||
if err := n.Decode(a); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", p.Type, err)
|
||||
}
|
||||
*c = append(*c, a)
|
||||
default:
|
||||
return fmt.Errorf("%q is not valid type for a parameter!", p.GetName())
|
||||
return err
|
||||
}
|
||||
(*c) = append((*c), p)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseFromYamlNode(node *yaml.Node) (Parameter, error) {
|
||||
var p CommonParameter
|
||||
err := node.Decode(&p)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parameter missing required fields")
|
||||
}
|
||||
switch p.Type {
|
||||
case typeString:
|
||||
a := &StringParameter{}
|
||||
if err := node.Decode(a); err != nil {
|
||||
return nil, fmt.Errorf("unable to parse as %q: %w", p.Type, err)
|
||||
}
|
||||
return a, nil
|
||||
case typeInt:
|
||||
a := &IntParameter{}
|
||||
if err := node.Decode(a); err != nil {
|
||||
return nil, fmt.Errorf("unable to parse as %q: %w", p.Type, err)
|
||||
}
|
||||
return a, nil
|
||||
case typeFloat:
|
||||
a := &FloatParameter{}
|
||||
if err := node.Decode(a); err != nil {
|
||||
return nil, fmt.Errorf("unable to parse as %q: %w", p.Type, err)
|
||||
}
|
||||
return a, nil
|
||||
case typeBool:
|
||||
a := &BooleanParameter{}
|
||||
if err := node.Decode(a); err != nil {
|
||||
return nil, fmt.Errorf("unable to parse as %q: %w", p.Type, err)
|
||||
}
|
||||
return a, nil
|
||||
case typeArray:
|
||||
a := &ArrayParameter{}
|
||||
if err := node.Decode(a); err != nil {
|
||||
return nil, fmt.Errorf("unable to parse as %q: %w", p.Type, err)
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
return nil, fmt.Errorf("%q is not valid type for a parameter!", p.Type)
|
||||
}
|
||||
|
||||
func generateManifests(ps []Parameter) []ParameterManifest {
|
||||
rtn := make([]ParameterManifest, 0, len(ps))
|
||||
for _, p := range ps {
|
||||
@@ -135,7 +165,7 @@ type ParseTypeError struct {
|
||||
}
|
||||
|
||||
func (e ParseTypeError) Error() string {
|
||||
return fmt.Sprintf("Error parsing parameter %q: %q not type %q", e.Name, e.Value, e.Type)
|
||||
return fmt.Sprintf("%q not type %q", e.Value, e.Type)
|
||||
}
|
||||
|
||||
// NewStringParameter is a convenience function for initializing a StringParameter.
|
||||
@@ -158,11 +188,11 @@ type StringParameter struct {
|
||||
|
||||
// Parse casts the value "v" as a "string".
|
||||
func (p *StringParameter) Parse(v any) (any, error) {
|
||||
v, ok := v.(string)
|
||||
newV, ok := v.(string)
|
||||
if !ok {
|
||||
return nil, &ParseTypeError{p.Name, p.Type, v}
|
||||
}
|
||||
return v, nil
|
||||
return newV, nil
|
||||
}
|
||||
|
||||
// NewIntParameter is a convenience function for initializing a IntParameter.
|
||||
@@ -184,11 +214,11 @@ type IntParameter struct {
|
||||
}
|
||||
|
||||
func (p *IntParameter) Parse(v any) (any, error) {
|
||||
v, ok := v.(int64)
|
||||
newV, ok := v.(int)
|
||||
if !ok {
|
||||
return nil, &ParseTypeError{p.Name, p.Type, v}
|
||||
}
|
||||
return v, nil
|
||||
return newV, nil
|
||||
}
|
||||
|
||||
// NewFloatParameter is a convenience function for initializing a FloatParameter.
|
||||
@@ -210,11 +240,11 @@ type FloatParameter struct {
|
||||
}
|
||||
|
||||
func (p *FloatParameter) Parse(v any) (any, error) {
|
||||
v, ok := v.(float64)
|
||||
newV, ok := v.(float64)
|
||||
if !ok {
|
||||
return nil, &ParseTypeError{p.Name, p.Type, v}
|
||||
}
|
||||
return v, nil
|
||||
return newV, nil
|
||||
}
|
||||
|
||||
// NewBooleanParameter is a convenience function for initializing a BooleanParameter.
|
||||
@@ -236,26 +266,73 @@ type BooleanParameter struct {
|
||||
}
|
||||
|
||||
func (p *BooleanParameter) Parse(v any) (any, error) {
|
||||
v, ok := v.(bool)
|
||||
newV, ok := v.(bool)
|
||||
if !ok {
|
||||
return nil, &ParseTypeError{p.Name, p.Type, v}
|
||||
}
|
||||
return v, nil
|
||||
return newV, nil
|
||||
}
|
||||
|
||||
// ParseParams is a helper function for parsing Parameters from an arbitraryJSON object.
|
||||
func ParseParams(ps Parameters, data map[string]any) ([]any, error) {
|
||||
params := []any{}
|
||||
for _, p := range ps {
|
||||
v, ok := data[p.GetName()]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Parameter %q is required!", p.GetName())
|
||||
}
|
||||
v, err := p.Parse(v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse value for %q: %w", p.GetName(), err)
|
||||
}
|
||||
params = append(params, v)
|
||||
// NewArrayParameter is a convenience function for initializing an ArrayParameter.
|
||||
func NewArrayParameter(name, desc string, items Parameter) *ArrayParameter {
|
||||
return &ArrayParameter{
|
||||
CommonParameter: CommonParameter{
|
||||
Name: name,
|
||||
Type: typeArray,
|
||||
Desc: desc,
|
||||
},
|
||||
Items: items,
|
||||
}
|
||||
return params, nil
|
||||
}
|
||||
|
||||
var _ Parameter = &ArrayParameter{}
|
||||
|
||||
// ArrayParameter is a parameter representing the "array" type.
|
||||
type ArrayParameter struct {
|
||||
CommonParameter `yaml:",inline"`
|
||||
Items Parameter `yaml:"items"`
|
||||
}
|
||||
|
||||
func (p *ArrayParameter) UnmarshalYAML(node *yaml.Node) error {
|
||||
if err := node.Decode(&p.CommonParameter); err != nil {
|
||||
return err
|
||||
}
|
||||
// Find the node that represents the "items" field name
|
||||
idx, ok := findIdxByValue(node.Content, "items")
|
||||
if !ok {
|
||||
return fmt.Errorf("array parameter missing 'items' field!")
|
||||
}
|
||||
// Parse items from the "value" of "items" field
|
||||
i, err := parseFromYamlNode(node.Content[idx+1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to parse 'items' field: %w", err)
|
||||
}
|
||||
p.Items = i
|
||||
return nil
|
||||
}
|
||||
|
||||
// findIdxByValue returns the index of the first node where value matches
|
||||
func findIdxByValue(nodes []*yaml.Node, value string) (int, bool) {
|
||||
for idx, n := range nodes {
|
||||
if n.Value == value {
|
||||
return idx, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func (p *ArrayParameter) Parse(v any) (any, error) {
|
||||
arrVal, ok := v.([]any)
|
||||
if !ok {
|
||||
return nil, &ParseTypeError{p.Name, p.Type, arrVal}
|
||||
}
|
||||
rtn := make([]any, 0, len(arrVal))
|
||||
for idx, val := range arrVal {
|
||||
val, err := p.Items.Parse(val)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse element #%d: %w", idx, err)
|
||||
}
|
||||
rtn = append(rtn, val)
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
package tools_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@@ -22,15 +23,15 @@ import (
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestParameters(t *testing.T) {
|
||||
func TestParametersMarhsall(t *testing.T) {
|
||||
tcs := []struct {
|
||||
name string
|
||||
in []map[string]string
|
||||
in []map[string]any
|
||||
want tools.Parameters
|
||||
}{
|
||||
{
|
||||
name: "string",
|
||||
in: []map[string]string{
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "my_string",
|
||||
"type": "string",
|
||||
@@ -43,10 +44,10 @@ func TestParameters(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "int",
|
||||
in: []map[string]string{
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "my_integer",
|
||||
"type": "int",
|
||||
"type": "integer",
|
||||
"description": "this param is an int",
|
||||
},
|
||||
},
|
||||
@@ -56,7 +57,7 @@ func TestParameters(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "float",
|
||||
in: []map[string]string{
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "my_float",
|
||||
"type": "float",
|
||||
@@ -69,10 +70,10 @@ func TestParameters(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "bool",
|
||||
in: []map[string]string{
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "my_bool",
|
||||
"type": "bool",
|
||||
"type": "boolean",
|
||||
"description": "this param is a boolean",
|
||||
},
|
||||
},
|
||||
@@ -80,6 +81,38 @@ func TestParameters(t *testing.T) {
|
||||
tools.NewBooleanParameter("my_bool", "this param is a boolean"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "string array",
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "my_array",
|
||||
"type": "array",
|
||||
"description": "this param is an array of strings",
|
||||
"items": map[string]string{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: tools.Parameters{
|
||||
tools.NewArrayParameter("my_array", "this param is an array of strings", tools.NewStringParameter("", "")),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "float array",
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "my_array",
|
||||
"type": "array",
|
||||
"description": "this param is an array of floats",
|
||||
"items": map[string]string{
|
||||
"type": "float",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: tools.Parameters{
|
||||
tools.NewArrayParameter("my_array", "this param is an array of floats", tools.NewFloatParameter("", "")),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -100,3 +133,123 @@ func TestParameters(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParametersParse(t *testing.T) {
|
||||
tcs := []struct {
|
||||
name string
|
||||
params tools.Parameters
|
||||
in map[string]any
|
||||
want []any
|
||||
}{
|
||||
{
|
||||
name: "string",
|
||||
params: tools.Parameters{
|
||||
tools.NewStringParameter("my_string", "this param is a string"),
|
||||
},
|
||||
in: map[string]any{
|
||||
"my_string": "hello world",
|
||||
},
|
||||
want: []any{"hello world"},
|
||||
},
|
||||
{
|
||||
name: "not string",
|
||||
params: tools.Parameters{
|
||||
tools.NewStringParameter("my_string", "this param is a string"),
|
||||
},
|
||||
in: map[string]any{
|
||||
"my_string": 4,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "int",
|
||||
params: tools.Parameters{
|
||||
tools.NewIntParameter("my_int", "this param is an int"),
|
||||
},
|
||||
in: map[string]any{
|
||||
"my_int": 100,
|
||||
},
|
||||
want: []any{100},
|
||||
},
|
||||
{
|
||||
name: "not int",
|
||||
params: tools.Parameters{
|
||||
tools.NewIntParameter("my_int", "this param is an int"),
|
||||
},
|
||||
in: map[string]any{
|
||||
"my_int": 14.5,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "float",
|
||||
params: tools.Parameters{
|
||||
tools.NewFloatParameter("my_float", "this param is a float"),
|
||||
},
|
||||
in: map[string]any{
|
||||
"my_float": 1.5,
|
||||
},
|
||||
want: []any{1.5},
|
||||
},
|
||||
{
|
||||
name: "not float",
|
||||
params: tools.Parameters{
|
||||
tools.NewFloatParameter("my_float", "this param is a float"),
|
||||
},
|
||||
in: map[string]any{
|
||||
"my_float": true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bool",
|
||||
params: tools.Parameters{
|
||||
tools.NewBooleanParameter("my_bool", "this param is a bool"),
|
||||
},
|
||||
in: map[string]any{
|
||||
"my_bool": true,
|
||||
},
|
||||
want: []any{true},
|
||||
},
|
||||
{
|
||||
name: "bool",
|
||||
params: tools.Parameters{
|
||||
tools.NewBooleanParameter("my_bool", "this param is a bool"),
|
||||
},
|
||||
in: map[string]any{
|
||||
"my_bool": "true",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// parse map to bytes
|
||||
data, err := yaml.Marshal(tc.in)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to marshal input to yaml: %s", err)
|
||||
}
|
||||
// parse bytes to object
|
||||
var m map[string]any
|
||||
err = yaml.Unmarshal(data, &m)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
|
||||
gotAll, err := tools.ParseParams(tc.params, m)
|
||||
if err != nil {
|
||||
if len(tc.want) == 0 {
|
||||
// error is expected if no items in want
|
||||
return
|
||||
}
|
||||
t.Fatalf("unexpected error from ParseParams: %s", err)
|
||||
}
|
||||
for i, got := range gotAll {
|
||||
want := tc.want[i]
|
||||
if got != want {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, want)
|
||||
}
|
||||
gotType, wantType := reflect.TypeOf(got), reflect.TypeOf(want)
|
||||
if gotType != wantType {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user