Compare commits

...

2 Commits

Author SHA1 Message Date
Yuan Teoh
9af7970322 feat: support regex in enum type 2025-09-25 14:46:35 -07:00
Yuan Teoh
d96fceacd2 feat: add new enum parameter type 2025-09-25 11:33:30 -07:00
3 changed files with 435 additions and 4 deletions

View File

@@ -153,6 +153,32 @@ will be thrown in case of value type mismatch.
valueType: integer # This enforces the value type for all entries. valueType: integer # This enforces the value type for all entries.
``` ```
### Enum Parameters
The `enum` type allow users to specify a set of allowed values with that
parameter. When toolbox parse the input of parameters, it will check against the
allowed values.
```yaml
parameter:
- name: airline
type: enum
description: name of airline.
enumType: string
allowedValues:
- cymbalair
- delta
```
Other than the regular fields required with the `enumType` specified, below are
the additional fields that are needed when using `enum` type.
| **field** | **type** | **required** | **description** |
|---------------|:--------:|:------------:|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| enumType | string | true | This indicates the type of the value. Must be one of the supported parameter type (e.g. `string`/ `integer` / `float` / `boolean` / `array` / `map`). |
| escape | bool | false | Indicates if the value will be escaped if used with `templateParameters`. Escaping will add double quotes (or backticks/square brackets depending on the source) depending on the database. This is defaulted to `false`. |
| allowedValues | []string | true | Input value will be checked against this field. |
### Authenticated Parameters ### Authenticated Parameters
Authenticated parameters are automatically populated with user Authenticated parameters are automatically populated with user

View File

@@ -19,10 +19,12 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"regexp"
"slices" "slices"
"strings" "strings"
"text/template" "text/template"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util"
) )
@@ -33,6 +35,7 @@ const (
typeBool = "boolean" typeBool = "boolean"
typeArray = "array" typeArray = "array"
typeMap = "map" typeMap = "map"
typeEnum = "enum"
) )
// ParamValues is an ordered list of ParamValue // ParamValues is an ordered list of ParamValue
@@ -179,6 +182,11 @@ func GetParams(params Parameters, paramValuesMap map[string]any) (ParamValues, e
if !ok { if !ok {
return nil, fmt.Errorf("missing parameter %s", k) return nil, fmt.Errorf("missing parameter %s", k)
} }
if p.GetType() == typeEnum {
if p.(*EnumParameter).GetEscape() {
v = fmt.Sprintf(`"%s"`, v)
}
}
resultParamValues = append(resultParamValues, ParamValue{Name: k, Value: v}) resultParamValues = append(resultParamValues, ParamValue{Name: k, Value: v})
} }
return resultParamValues, nil return resultParamValues, nil
@@ -233,6 +241,7 @@ type Parameter interface {
// but this is done to differentiate it from the fields in CommonParameter. // but this is done to differentiate it from the fields in CommonParameter.
GetName() string GetName() string
GetType() string GetType() string
GetDesc() string
GetDefault() any GetDefault() any
GetRequired() bool GetRequired() bool
GetAuthServices() []ParamAuthService GetAuthServices() []ParamAuthService
@@ -278,7 +287,7 @@ func parseParamFromDelayedUnmarshaler(ctx context.Context, u *util.DelayedUnmars
t, ok := p["type"] t, ok := p["type"]
if !ok { if !ok {
return nil, fmt.Errorf("parameter is missing 'type' field: %w", err) return nil, fmt.Errorf("parameter is missing 'type' field")
} }
dec, err := util.NewStrictDecoder(p) dec, err := util.NewStrictDecoder(p)
@@ -356,6 +365,17 @@ func parseParamFromDelayedUnmarshaler(ctx context.Context, u *util.DelayedUnmars
a.AuthSources = nil a.AuthSources = nil
} }
return a, nil return a, nil
case typeEnum:
a := &EnumParameter{}
if err := dec.DecodeContext(ctx, a); err != nil {
return nil, fmt.Errorf("unable to parse as %q: %w", t, err)
}
if a.AuthSources != nil {
logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` for parameters instead")
a.AuthServices = append(a.AuthServices, a.AuthSources...)
a.AuthSources = nil
}
return a, nil
} }
return nil, fmt.Errorf("%q is not valid type for a parameter", t) return nil, fmt.Errorf("%q is not valid type for a parameter", t)
} }
@@ -427,6 +447,11 @@ func (p *CommonParameter) GetType() string {
return p.Type return p.Type
} }
// GetDesc returns the description specified for the Parameter.
func (p *CommonParameter) GetDesc() string {
return p.Desc
}
// GetRequired returns the type specified for the Parameter. // GetRequired returns the type specified for the Parameter.
func (p *CommonParameter) GetRequired() bool { func (p *CommonParameter) GetRequired() bool {
// parameters are defaulted to required // parameters are defaulted to required
@@ -1230,3 +1255,155 @@ func (p *MapParameter) McpManifest() ParameterMcpManifest {
AdditionalProperties: additionalProperties, AdditionalProperties: additionalProperties,
} }
} }
// NewEnumParameter is a convenience function for initializing a EnumParameter.
func NewEnumParameter(param Parameter, escape bool, allowedValues []any) *EnumParameter {
d := param.GetDefault()
r := param.GetRequired()
return &EnumParameter{
CommonParameter: CommonParameter{
Name: param.GetName(),
Type: typeEnum,
Desc: param.GetDesc(),
Required: &r,
AuthServices: param.GetAuthServices(),
},
EnumType: param.GetType(),
Escape: escape,
AllowedValues: allowedValues,
EnumItem: param,
Default: &d,
}
}
// EnumParameter is a parameter that allow users to specify
// allowedValues to provide a fixed set of values. This will
// make parameter, especially templateParameter more secure and safe.
type EnumParameter struct {
CommonParameter `yaml:",inline"`
Default *any `yaml:"default"`
EnumType string `yaml:"enumType"`
Escape bool `yaml:"escape"`
AllowedValues []any `yaml:"allowedValues"`
EnumItem Parameter
}
// Ensure EnumParameter implements the Parameter interface.
var _ Parameter = &EnumParameter{}
// UnmarshalYAML handles parsing the EnumParameter from YAML input.
func (p *EnumParameter) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
var rawItem map[string]any
if err := unmarshal(&rawItem); err != nil {
return fmt.Errorf("error parsing enum parameter: %w", err)
}
// extract enum parameter known fields
enumType, ok := rawItem["enumType"].(string)
if !ok {
return fmt.Errorf("error parsing 'enumType' field")
}
escape := false
if v, ok := rawItem["escape"]; ok {
if escape, ok = v.(bool); !ok {
return fmt.Errorf("error parsing 'escape' field")
}
}
allowedValues, ok := rawItem["allowedValues"].([]any)
if !ok {
return fmt.Errorf("error parsing 'allowedValues' field")
}
rawItem["type"] = enumType
// remove the extracted field from the map
delete(rawItem, "enumType")
delete(rawItem, "escape")
delete(rawItem, "allowedValues")
// create a util.DelayedUnmarshaler from the remaining fields
m, err := yaml.Marshal(rawItem)
if err != nil {
return fmt.Errorf("error marshaling remaining fields from enum parameter")
}
var delayedUnmarshaler util.DelayedUnmarshaler
if err = yaml.UnmarshalContext(ctx, m, &delayedUnmarshaler); err != nil {
return fmt.Errorf("error unmarhaling into DelayedUnmarshaler")
}
parameter, err := parseParamFromDelayedUnmarshaler(ctx, &delayedUnmarshaler)
if err != nil {
return err
}
d := parameter.GetDefault()
r := parameter.GetRequired()
p.Default = &d
p.CommonParameter = CommonParameter{
Name: parameter.GetName(),
Type: "enum",
Desc: parameter.GetDesc(),
Required: &r,
AuthServices: parameter.GetAuthServices(),
}
p.EnumType = enumType
p.Escape = escape
p.AllowedValues = allowedValues
p.EnumItem = parameter
return nil
}
// Parse validates and parses an incoming value for enum parameter.
func (p *EnumParameter) Parse(v any) (any, error) {
input := fmt.Sprintf("%v", v)
var exists bool
for _, av := range p.AllowedValues {
target := fmt.Sprintf("%v", av)
if MatchStringOrRegex(input, target) {
exists = true
break
}
}
if !exists {
return nil, fmt.Errorf("unable to parse enum parameter: input is not part of allowed values")
}
return p.EnumItem.Parse(v)
}
// MatchStringOrRegex checks if the input matches the target
func MatchStringOrRegex(input, target string) bool {
re, err := regexp.Compile(target)
if err != nil {
return strings.Contains(input, target)
}
return re.MatchString(input)
}
func (p *EnumParameter) GetAuthServices() []ParamAuthService {
return p.AuthServices
}
func (p *EnumParameter) GetDefault() any {
if p.Default == nil {
return nil
}
return *p.Default
}
func (p *EnumParameter) GetEnumType() string {
return p.EnumType
}
func (p *EnumParameter) GetEscape() bool {
return p.Escape
}
// Manifest returns the manifest for the EnumParameter.
func (p *EnumParameter) Manifest() ParameterManifest {
return p.EnumItem.Manifest()
}
// McpManifest returns the MCP manifest for EnumParameter.
func (p *EnumParameter) McpManifest() ParameterMcpManifest {
return p.EnumItem.McpManifest()
}

View File

@@ -17,6 +17,7 @@ package tools_test
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt"
"math" "math"
"strings" "strings"
"testing" "testing"
@@ -351,6 +352,37 @@ func TestParametersMarshal(t *testing.T) {
tools.NewMapParameter("my_generic_map", "this param is a generic map", ""), tools.NewMapParameter("my_generic_map", "this param is a generic map", ""),
}, },
}, },
{
name: "enum string",
in: []map[string]any{
{
"name": "enum_string",
"type": "enum",
"enumType": "string",
"description": "enum string parameter",
"allowedValues": []any{"foo", "bar"},
},
},
want: tools.Parameters{
tools.NewEnumParameter(tools.NewStringParameter("enum_string", "enum string parameter"), false, []any{"foo", "bar"}),
},
},
{
name: "enum string with escape",
in: []map[string]any{
{
"name": "enum_string",
"type": "enum",
"enumType": "string",
"description": "enum string parameter",
"allowedValues": []any{"foo", "bar"},
"escape": true,
},
},
want: tools.Parameters{
tools.NewEnumParameter(tools.NewStringParameter("enum_string", "enum string parameter"), true, []any{"foo", "bar"}),
},
},
} }
for _, tc := range tcs { for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@@ -669,6 +701,31 @@ func TestAuthParametersMarshal(t *testing.T) {
tools.NewMapParameterWithAuth("my_map", "this param is a map of strings", "string", authServices), tools.NewMapParameterWithAuth("my_map", "this param is a map of strings", "string", authServices),
}, },
}, },
{
name: "enum",
in: []map[string]any{
{
"name": "enum_string",
"type": "enum",
"description": "enum of strings",
"enumType": "string",
"allowedValues": []any{"foo", "bar"},
"authServices": []map[string]string{
{
"name": "my-google-auth-service",
"field": "user_id",
},
{
"name": "other-auth-service",
"field": "user_id",
},
},
},
},
want: tools.Parameters{
tools.NewEnumParameter(tools.NewStringParameterWithAuth("enum_string", "enum of strings", authServices), false, []any{"foo", "bar"}),
},
},
} }
for _, tc := range tcs { for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@@ -784,6 +841,35 @@ func TestParametersParse(t *testing.T) {
"my_bool": 1.5, "my_bool": 1.5,
}, },
}, },
{
name: "enum",
params: tools.Parameters{
tools.NewEnumParameter(tools.NewStringParameter("enum_string", "enum of strings"), false, []any{"foo", "bar"}),
},
in: map[string]any{
"enum_string": "foo",
},
want: tools.ParamValues{tools.ParamValue{Name: "enum_string", Value: "foo"}},
},
{
name: "enum not allowed",
params: tools.Parameters{
tools.NewEnumParameter(tools.NewStringParameter("enum_string", "enum of strings"), false, []any{"foo", "bar"}),
},
in: map[string]any{
"enum_string": "invalid",
},
},
{
name: "enum with integer",
params: tools.Parameters{
tools.NewEnumParameter(tools.NewIntParameter("enum_int", "enum of int"), false, []any{"^[1-5]$"}),
},
in: map[string]any{
"enum_int": 4,
},
want: tools.ParamValues{tools.ParamValue{Name: "enum_int", Value: 4}},
},
{ {
name: "string default", name: "string default",
params: tools.Parameters{ params: tools.Parameters{
@@ -824,6 +910,14 @@ func TestParametersParse(t *testing.T) {
in: map[string]any{}, in: map[string]any{},
want: tools.ParamValues{tools.ParamValue{Name: "my_bool", Value: true}}, want: tools.ParamValues{tools.ParamValue{Name: "my_bool", Value: true}},
}, },
{
name: "enum default",
params: tools.Parameters{
tools.NewEnumParameter(tools.NewStringParameterWithDefault("enum_string", "foo", "enum of strings"), false, []any{"foo", "bar"}),
},
in: map[string]any{},
want: tools.ParamValues{tools.ParamValue{Name: "enum_string", Value: "foo"}},
},
{ {
name: "string not required", name: "string not required",
params: tools.Parameters{ params: tools.Parameters{
@@ -856,6 +950,14 @@ func TestParametersParse(t *testing.T) {
in: map[string]any{}, in: map[string]any{},
want: tools.ParamValues{tools.ParamValue{Name: "my_bool", Value: nil}}, want: tools.ParamValues{tools.ParamValue{Name: "my_bool", Value: nil}},
}, },
{
name: "enum not required",
params: tools.Parameters{
tools.NewEnumParameter(tools.NewStringParameterWithRequired("enum_string", "enum of strings", false), true, []any{"foo", "bar"}),
},
in: map[string]any{},
want: tools.ParamValues{tools.ParamValue{Name: "enum_string", Value: nil}},
},
{ {
name: "map", name: "map",
params: tools.Parameters{ params: tools.Parameters{
@@ -1197,6 +1299,17 @@ func TestParamManifest(t *testing.T) {
Items: &tools.ParameterManifest{Name: "foo-string", Type: "string", Required: true, Description: "bar", AuthServices: []string{}}, Items: &tools.ParameterManifest{Name: "foo-string", Type: "string", Required: true, Description: "bar", AuthServices: []string{}},
}, },
}, },
{
name: "enum with string",
in: tools.NewEnumParameter(tools.NewStringParameter("foo-enum", "enum of strings"), false, []any{"foo", "bar"}),
want: tools.ParameterManifest{
Name: "foo-enum",
Type: "string",
Required: true,
Description: "enum of strings",
AuthServices: []string{},
},
},
{ {
name: "string default", name: "string default",
in: tools.NewStringParameterWithDefault("foo-string", "foo", "bar"), in: tools.NewStringParameterWithDefault("foo-string", "foo", "bar"),
@@ -1229,6 +1342,17 @@ func TestParamManifest(t *testing.T) {
Items: &tools.ParameterManifest{Name: "foo-string", Type: "string", Required: false, Description: "bar", AuthServices: []string{}}, Items: &tools.ParameterManifest{Name: "foo-string", Type: "string", Required: false, Description: "bar", AuthServices: []string{}},
}, },
}, },
{
name: "enum with string default",
in: tools.NewEnumParameter(tools.NewStringParameterWithDefault("foo-enum", "foo", "enum of strings"), false, []any{"foo", "bar"}),
want: tools.ParameterManifest{
Name: "foo-enum",
Type: "string",
Required: false,
Description: "enum of strings",
AuthServices: []string{},
},
},
{ {
name: "string not required", name: "string not required",
in: tools.NewStringParameterWithRequired("foo-string", "bar", false), in: tools.NewStringParameterWithRequired("foo-string", "bar", false),
@@ -1261,6 +1385,17 @@ func TestParamManifest(t *testing.T) {
Items: &tools.ParameterManifest{Name: "foo-string", Type: "string", Required: false, Description: "bar", AuthServices: []string{}}, Items: &tools.ParameterManifest{Name: "foo-string", Type: "string", Required: false, Description: "bar", AuthServices: []string{}},
}, },
}, },
{
name: "enum with string not required",
in: tools.NewEnumParameter(tools.NewStringParameterWithRequired("foo-enum", "enum of strings", false), false, []any{"foo", "bar"}),
want: tools.ParameterManifest{
Name: "foo-enum",
Type: "string",
Required: false,
Description: "enum of strings",
AuthServices: []string{},
},
},
{ {
name: "map with string values", name: "map with string values",
in: tools.NewMapParameter("foo-map", "bar", "string"), in: tools.NewMapParameter("foo-map", "bar", "string"),
@@ -1343,7 +1478,6 @@ func TestParamMcpManifest(t *testing.T) {
Items: &tools.ParameterMcpManifest{Type: "string", Description: "bar"}, Items: &tools.ParameterMcpManifest{Type: "string", Description: "bar"},
}, },
}, },
{ {
name: "map with string values", name: "map with string values",
in: tools.NewMapParameter("foo-map", "bar", "string"), in: tools.NewMapParameter("foo-map", "bar", "string"),
@@ -1362,6 +1496,11 @@ func TestParamMcpManifest(t *testing.T) {
AdditionalProperties: true, AdditionalProperties: true,
}, },
}, },
{
name: "enum param",
in: tools.NewEnumParameter(tools.NewStringParameter("foo-enum", "enum of strings"), false, []any{"foo", "bar"}),
want: tools.ParameterMcpManifest{Type: "string", Description: "enum of strings"},
},
} }
for _, tc := range tcs { for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@@ -1389,6 +1528,7 @@ func TestMcpManifest(t *testing.T) {
tools.NewArrayParameter("foo-array2", "bar", tools.NewStringParameter("foo-string", "bar")), tools.NewArrayParameter("foo-array2", "bar", tools.NewStringParameter("foo-string", "bar")),
tools.NewMapParameter("foo-map-int", "a map of ints", "integer"), tools.NewMapParameter("foo-map-int", "a map of ints", "integer"),
tools.NewMapParameter("foo-map-any", "a map of any", ""), tools.NewMapParameter("foo-map-any", "a map of any", ""),
tools.NewEnumParameter(tools.NewStringParameter("foo-enum-string", "enum of strings"), false, []any{"foo", "bar"}),
}, },
want: tools.McpToolsSchema{ want: tools.McpToolsSchema{
Type: "object", Type: "object",
@@ -1412,8 +1552,12 @@ func TestMcpManifest(t *testing.T) {
Description: "a map of any", Description: "a map of any",
AdditionalProperties: true, AdditionalProperties: true,
}, },
"foo-enum-string": {
Type: "string",
Description: "enum of strings",
},
}, },
Required: []string{"foo-string2", "foo-int2", "foo-float", "foo-array2", "foo-map-int", "foo-map-any"}, Required: []string{"foo-string2", "foo-int2", "foo-float", "foo-array2", "foo-map-int", "foo-map-any", "foo-enum-string"},
}, },
}, },
} }
@@ -1455,7 +1599,7 @@ func TestFailParametersUnmarshal(t *testing.T) {
"description": "this is a param for string", "description": "this is a param for string",
}, },
}, },
err: "parameter is missing 'type' field: %!w(<nil>)", err: "parameter is missing 'type' field",
}, },
{ {
name: "common parameter missing description", name: "common parameter missing description",
@@ -1663,6 +1807,18 @@ func TestGetParams(t *testing.T) {
in: map[string]any{}, in: map[string]any{},
want: tools.ParamValues{}, want: tools.ParamValues{},
}, },
{
name: "enum with escape",
params: tools.Parameters{
tools.NewEnumParameter(tools.NewStringParameter("my_string_enum", "string of enums"), true, []any{"foo", "bar"}),
},
in: map[string]any{
"my_string_enum": "foo",
},
want: tools.ParamValues{
tools.ParamValue{Name: "my_string_enum", Value: `"foo"`},
},
},
} }
for _, tc := range tcs { for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@@ -1881,3 +2037,75 @@ func TestCheckParamRequired(t *testing.T) {
}) })
} }
} }
func TestMatchStringOrRegex(t *testing.T) {
tcs := []struct {
name string
input string
target string
want bool
}{
{
name: "exact string",
input: "foo",
target: "foo",
want: true,
},
{
name: "exact integer",
input: fmt.Sprintf("%v", 5),
target: fmt.Sprintf("%v", 5),
want: true,
},
{
name: "wrong integer",
input: fmt.Sprintf("%v", 4),
target: fmt.Sprintf("%v", 5),
want: false,
},
{
name: "exact boolean",
input: fmt.Sprintf("%v", true),
target: fmt.Sprintf("%v", true),
want: true,
},
{
name: "target contains input",
input: "foo",
target: "foo bar",
want: false,
},
{
name: "regex any string",
input: "foo",
target: ".*",
want: true,
},
{
name: "regex",
input: "foo6",
target: `foo\d+`,
want: true,
},
{
name: "regex of numbers",
input: "4",
target: "^[1-5]$",
want: true,
},
{
name: "regex of numbers invalid",
input: "7",
target: "^[1-5]$",
want: false,
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
got := tools.MatchStringOrRegex(tc.input, tc.target)
if got != tc.want {
t.Fatalf("got %v, want %v", got, tc.want)
}
})
}
}