Compare commits

...

2 Commits

Author SHA1 Message Date
Yuan Teoh
9af7970322 feat: support regex in enum type 2025-09-25 14:46:35 -07:00
Yuan Teoh
d96fceacd2 feat: add new enum parameter type 2025-09-25 11:33:30 -07:00
3 changed files with 435 additions and 4 deletions

View File

@@ -153,6 +153,32 @@ will be thrown in case of value type mismatch.
valueType: integer # This enforces the value type for all entries.
```
### Enum Parameters
The `enum` type allow users to specify a set of allowed values with that
parameter. When toolbox parse the input of parameters, it will check against the
allowed values.
```yaml
parameter:
- name: airline
type: enum
description: name of airline.
enumType: string
allowedValues:
- cymbalair
- delta
```
Other than the regular fields required with the `enumType` specified, below are
the additional fields that are needed when using `enum` type.
| **field** | **type** | **required** | **description** |
|---------------|:--------:|:------------:|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| enumType | string | true | This indicates the type of the value. Must be one of the supported parameter type (e.g. `string`/ `integer` / `float` / `boolean` / `array` / `map`). |
| escape | bool | false | Indicates if the value will be escaped if used with `templateParameters`. Escaping will add double quotes (or backticks/square brackets depending on the source) depending on the database. This is defaulted to `false`. |
| allowedValues | []string | true | Input value will be checked against this field. |
### Authenticated Parameters
Authenticated parameters are automatically populated with user

View File

@@ -19,10 +19,12 @@ import (
"context"
"encoding/json"
"fmt"
"regexp"
"slices"
"strings"
"text/template"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/util"
)
@@ -33,6 +35,7 @@ const (
typeBool = "boolean"
typeArray = "array"
typeMap = "map"
typeEnum = "enum"
)
// ParamValues is an ordered list of ParamValue
@@ -179,6 +182,11 @@ func GetParams(params Parameters, paramValuesMap map[string]any) (ParamValues, e
if !ok {
return nil, fmt.Errorf("missing parameter %s", k)
}
if p.GetType() == typeEnum {
if p.(*EnumParameter).GetEscape() {
v = fmt.Sprintf(`"%s"`, v)
}
}
resultParamValues = append(resultParamValues, ParamValue{Name: k, Value: v})
}
return resultParamValues, nil
@@ -233,6 +241,7 @@ type Parameter interface {
// but this is done to differentiate it from the fields in CommonParameter.
GetName() string
GetType() string
GetDesc() string
GetDefault() any
GetRequired() bool
GetAuthServices() []ParamAuthService
@@ -278,7 +287,7 @@ func parseParamFromDelayedUnmarshaler(ctx context.Context, u *util.DelayedUnmars
t, ok := p["type"]
if !ok {
return nil, fmt.Errorf("parameter is missing 'type' field: %w", err)
return nil, fmt.Errorf("parameter is missing 'type' field")
}
dec, err := util.NewStrictDecoder(p)
@@ -356,6 +365,17 @@ func parseParamFromDelayedUnmarshaler(ctx context.Context, u *util.DelayedUnmars
a.AuthSources = nil
}
return a, nil
case typeEnum:
a := &EnumParameter{}
if err := dec.DecodeContext(ctx, a); err != nil {
return nil, fmt.Errorf("unable to parse as %q: %w", t, err)
}
if a.AuthSources != nil {
logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` for parameters instead")
a.AuthServices = append(a.AuthServices, a.AuthSources...)
a.AuthSources = nil
}
return a, nil
}
return nil, fmt.Errorf("%q is not valid type for a parameter", t)
}
@@ -427,6 +447,11 @@ func (p *CommonParameter) GetType() string {
return p.Type
}
// GetDesc returns the description specified for the Parameter.
func (p *CommonParameter) GetDesc() string {
return p.Desc
}
// GetRequired returns the type specified for the Parameter.
func (p *CommonParameter) GetRequired() bool {
// parameters are defaulted to required
@@ -1230,3 +1255,155 @@ func (p *MapParameter) McpManifest() ParameterMcpManifest {
AdditionalProperties: additionalProperties,
}
}
// NewEnumParameter is a convenience function for initializing a EnumParameter.
func NewEnumParameter(param Parameter, escape bool, allowedValues []any) *EnumParameter {
d := param.GetDefault()
r := param.GetRequired()
return &EnumParameter{
CommonParameter: CommonParameter{
Name: param.GetName(),
Type: typeEnum,
Desc: param.GetDesc(),
Required: &r,
AuthServices: param.GetAuthServices(),
},
EnumType: param.GetType(),
Escape: escape,
AllowedValues: allowedValues,
EnumItem: param,
Default: &d,
}
}
// EnumParameter is a parameter that allow users to specify
// allowedValues to provide a fixed set of values. This will
// make parameter, especially templateParameter more secure and safe.
type EnumParameter struct {
CommonParameter `yaml:",inline"`
Default *any `yaml:"default"`
EnumType string `yaml:"enumType"`
Escape bool `yaml:"escape"`
AllowedValues []any `yaml:"allowedValues"`
EnumItem Parameter
}
// Ensure EnumParameter implements the Parameter interface.
var _ Parameter = &EnumParameter{}
// UnmarshalYAML handles parsing the EnumParameter from YAML input.
func (p *EnumParameter) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
var rawItem map[string]any
if err := unmarshal(&rawItem); err != nil {
return fmt.Errorf("error parsing enum parameter: %w", err)
}
// extract enum parameter known fields
enumType, ok := rawItem["enumType"].(string)
if !ok {
return fmt.Errorf("error parsing 'enumType' field")
}
escape := false
if v, ok := rawItem["escape"]; ok {
if escape, ok = v.(bool); !ok {
return fmt.Errorf("error parsing 'escape' field")
}
}
allowedValues, ok := rawItem["allowedValues"].([]any)
if !ok {
return fmt.Errorf("error parsing 'allowedValues' field")
}
rawItem["type"] = enumType
// remove the extracted field from the map
delete(rawItem, "enumType")
delete(rawItem, "escape")
delete(rawItem, "allowedValues")
// create a util.DelayedUnmarshaler from the remaining fields
m, err := yaml.Marshal(rawItem)
if err != nil {
return fmt.Errorf("error marshaling remaining fields from enum parameter")
}
var delayedUnmarshaler util.DelayedUnmarshaler
if err = yaml.UnmarshalContext(ctx, m, &delayedUnmarshaler); err != nil {
return fmt.Errorf("error unmarhaling into DelayedUnmarshaler")
}
parameter, err := parseParamFromDelayedUnmarshaler(ctx, &delayedUnmarshaler)
if err != nil {
return err
}
d := parameter.GetDefault()
r := parameter.GetRequired()
p.Default = &d
p.CommonParameter = CommonParameter{
Name: parameter.GetName(),
Type: "enum",
Desc: parameter.GetDesc(),
Required: &r,
AuthServices: parameter.GetAuthServices(),
}
p.EnumType = enumType
p.Escape = escape
p.AllowedValues = allowedValues
p.EnumItem = parameter
return nil
}
// Parse validates and parses an incoming value for enum parameter.
func (p *EnumParameter) Parse(v any) (any, error) {
input := fmt.Sprintf("%v", v)
var exists bool
for _, av := range p.AllowedValues {
target := fmt.Sprintf("%v", av)
if MatchStringOrRegex(input, target) {
exists = true
break
}
}
if !exists {
return nil, fmt.Errorf("unable to parse enum parameter: input is not part of allowed values")
}
return p.EnumItem.Parse(v)
}
// MatchStringOrRegex checks if the input matches the target
func MatchStringOrRegex(input, target string) bool {
re, err := regexp.Compile(target)
if err != nil {
return strings.Contains(input, target)
}
return re.MatchString(input)
}
func (p *EnumParameter) GetAuthServices() []ParamAuthService {
return p.AuthServices
}
func (p *EnumParameter) GetDefault() any {
if p.Default == nil {
return nil
}
return *p.Default
}
func (p *EnumParameter) GetEnumType() string {
return p.EnumType
}
func (p *EnumParameter) GetEscape() bool {
return p.Escape
}
// Manifest returns the manifest for the EnumParameter.
func (p *EnumParameter) Manifest() ParameterManifest {
return p.EnumItem.Manifest()
}
// McpManifest returns the MCP manifest for EnumParameter.
func (p *EnumParameter) McpManifest() ParameterMcpManifest {
return p.EnumItem.McpManifest()
}

View File

@@ -17,6 +17,7 @@ package tools_test
import (
"bytes"
"encoding/json"
"fmt"
"math"
"strings"
"testing"
@@ -351,6 +352,37 @@ func TestParametersMarshal(t *testing.T) {
tools.NewMapParameter("my_generic_map", "this param is a generic map", ""),
},
},
{
name: "enum string",
in: []map[string]any{
{
"name": "enum_string",
"type": "enum",
"enumType": "string",
"description": "enum string parameter",
"allowedValues": []any{"foo", "bar"},
},
},
want: tools.Parameters{
tools.NewEnumParameter(tools.NewStringParameter("enum_string", "enum string parameter"), false, []any{"foo", "bar"}),
},
},
{
name: "enum string with escape",
in: []map[string]any{
{
"name": "enum_string",
"type": "enum",
"enumType": "string",
"description": "enum string parameter",
"allowedValues": []any{"foo", "bar"},
"escape": true,
},
},
want: tools.Parameters{
tools.NewEnumParameter(tools.NewStringParameter("enum_string", "enum string parameter"), true, []any{"foo", "bar"}),
},
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
@@ -669,6 +701,31 @@ func TestAuthParametersMarshal(t *testing.T) {
tools.NewMapParameterWithAuth("my_map", "this param is a map of strings", "string", authServices),
},
},
{
name: "enum",
in: []map[string]any{
{
"name": "enum_string",
"type": "enum",
"description": "enum of strings",
"enumType": "string",
"allowedValues": []any{"foo", "bar"},
"authServices": []map[string]string{
{
"name": "my-google-auth-service",
"field": "user_id",
},
{
"name": "other-auth-service",
"field": "user_id",
},
},
},
},
want: tools.Parameters{
tools.NewEnumParameter(tools.NewStringParameterWithAuth("enum_string", "enum of strings", authServices), false, []any{"foo", "bar"}),
},
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
@@ -784,6 +841,35 @@ func TestParametersParse(t *testing.T) {
"my_bool": 1.5,
},
},
{
name: "enum",
params: tools.Parameters{
tools.NewEnumParameter(tools.NewStringParameter("enum_string", "enum of strings"), false, []any{"foo", "bar"}),
},
in: map[string]any{
"enum_string": "foo",
},
want: tools.ParamValues{tools.ParamValue{Name: "enum_string", Value: "foo"}},
},
{
name: "enum not allowed",
params: tools.Parameters{
tools.NewEnumParameter(tools.NewStringParameter("enum_string", "enum of strings"), false, []any{"foo", "bar"}),
},
in: map[string]any{
"enum_string": "invalid",
},
},
{
name: "enum with integer",
params: tools.Parameters{
tools.NewEnumParameter(tools.NewIntParameter("enum_int", "enum of int"), false, []any{"^[1-5]$"}),
},
in: map[string]any{
"enum_int": 4,
},
want: tools.ParamValues{tools.ParamValue{Name: "enum_int", Value: 4}},
},
{
name: "string default",
params: tools.Parameters{
@@ -824,6 +910,14 @@ func TestParametersParse(t *testing.T) {
in: map[string]any{},
want: tools.ParamValues{tools.ParamValue{Name: "my_bool", Value: true}},
},
{
name: "enum default",
params: tools.Parameters{
tools.NewEnumParameter(tools.NewStringParameterWithDefault("enum_string", "foo", "enum of strings"), false, []any{"foo", "bar"}),
},
in: map[string]any{},
want: tools.ParamValues{tools.ParamValue{Name: "enum_string", Value: "foo"}},
},
{
name: "string not required",
params: tools.Parameters{
@@ -856,6 +950,14 @@ func TestParametersParse(t *testing.T) {
in: map[string]any{},
want: tools.ParamValues{tools.ParamValue{Name: "my_bool", Value: nil}},
},
{
name: "enum not required",
params: tools.Parameters{
tools.NewEnumParameter(tools.NewStringParameterWithRequired("enum_string", "enum of strings", false), true, []any{"foo", "bar"}),
},
in: map[string]any{},
want: tools.ParamValues{tools.ParamValue{Name: "enum_string", Value: nil}},
},
{
name: "map",
params: tools.Parameters{
@@ -1197,6 +1299,17 @@ func TestParamManifest(t *testing.T) {
Items: &tools.ParameterManifest{Name: "foo-string", Type: "string", Required: true, Description: "bar", AuthServices: []string{}},
},
},
{
name: "enum with string",
in: tools.NewEnumParameter(tools.NewStringParameter("foo-enum", "enum of strings"), false, []any{"foo", "bar"}),
want: tools.ParameterManifest{
Name: "foo-enum",
Type: "string",
Required: true,
Description: "enum of strings",
AuthServices: []string{},
},
},
{
name: "string default",
in: tools.NewStringParameterWithDefault("foo-string", "foo", "bar"),
@@ -1229,6 +1342,17 @@ func TestParamManifest(t *testing.T) {
Items: &tools.ParameterManifest{Name: "foo-string", Type: "string", Required: false, Description: "bar", AuthServices: []string{}},
},
},
{
name: "enum with string default",
in: tools.NewEnumParameter(tools.NewStringParameterWithDefault("foo-enum", "foo", "enum of strings"), false, []any{"foo", "bar"}),
want: tools.ParameterManifest{
Name: "foo-enum",
Type: "string",
Required: false,
Description: "enum of strings",
AuthServices: []string{},
},
},
{
name: "string not required",
in: tools.NewStringParameterWithRequired("foo-string", "bar", false),
@@ -1261,6 +1385,17 @@ func TestParamManifest(t *testing.T) {
Items: &tools.ParameterManifest{Name: "foo-string", Type: "string", Required: false, Description: "bar", AuthServices: []string{}},
},
},
{
name: "enum with string not required",
in: tools.NewEnumParameter(tools.NewStringParameterWithRequired("foo-enum", "enum of strings", false), false, []any{"foo", "bar"}),
want: tools.ParameterManifest{
Name: "foo-enum",
Type: "string",
Required: false,
Description: "enum of strings",
AuthServices: []string{},
},
},
{
name: "map with string values",
in: tools.NewMapParameter("foo-map", "bar", "string"),
@@ -1343,7 +1478,6 @@ func TestParamMcpManifest(t *testing.T) {
Items: &tools.ParameterMcpManifest{Type: "string", Description: "bar"},
},
},
{
name: "map with string values",
in: tools.NewMapParameter("foo-map", "bar", "string"),
@@ -1362,6 +1496,11 @@ func TestParamMcpManifest(t *testing.T) {
AdditionalProperties: true,
},
},
{
name: "enum param",
in: tools.NewEnumParameter(tools.NewStringParameter("foo-enum", "enum of strings"), false, []any{"foo", "bar"}),
want: tools.ParameterMcpManifest{Type: "string", Description: "enum of strings"},
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
@@ -1389,6 +1528,7 @@ func TestMcpManifest(t *testing.T) {
tools.NewArrayParameter("foo-array2", "bar", tools.NewStringParameter("foo-string", "bar")),
tools.NewMapParameter("foo-map-int", "a map of ints", "integer"),
tools.NewMapParameter("foo-map-any", "a map of any", ""),
tools.NewEnumParameter(tools.NewStringParameter("foo-enum-string", "enum of strings"), false, []any{"foo", "bar"}),
},
want: tools.McpToolsSchema{
Type: "object",
@@ -1412,8 +1552,12 @@ func TestMcpManifest(t *testing.T) {
Description: "a map of any",
AdditionalProperties: true,
},
"foo-enum-string": {
Type: "string",
Description: "enum of strings",
},
},
Required: []string{"foo-string2", "foo-int2", "foo-float", "foo-array2", "foo-map-int", "foo-map-any"},
Required: []string{"foo-string2", "foo-int2", "foo-float", "foo-array2", "foo-map-int", "foo-map-any", "foo-enum-string"},
},
},
}
@@ -1455,7 +1599,7 @@ func TestFailParametersUnmarshal(t *testing.T) {
"description": "this is a param for string",
},
},
err: "parameter is missing 'type' field: %!w(<nil>)",
err: "parameter is missing 'type' field",
},
{
name: "common parameter missing description",
@@ -1663,6 +1807,18 @@ func TestGetParams(t *testing.T) {
in: map[string]any{},
want: tools.ParamValues{},
},
{
name: "enum with escape",
params: tools.Parameters{
tools.NewEnumParameter(tools.NewStringParameter("my_string_enum", "string of enums"), true, []any{"foo", "bar"}),
},
in: map[string]any{
"my_string_enum": "foo",
},
want: tools.ParamValues{
tools.ParamValue{Name: "my_string_enum", Value: `"foo"`},
},
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
@@ -1881,3 +2037,75 @@ func TestCheckParamRequired(t *testing.T) {
})
}
}
func TestMatchStringOrRegex(t *testing.T) {
tcs := []struct {
name string
input string
target string
want bool
}{
{
name: "exact string",
input: "foo",
target: "foo",
want: true,
},
{
name: "exact integer",
input: fmt.Sprintf("%v", 5),
target: fmt.Sprintf("%v", 5),
want: true,
},
{
name: "wrong integer",
input: fmt.Sprintf("%v", 4),
target: fmt.Sprintf("%v", 5),
want: false,
},
{
name: "exact boolean",
input: fmt.Sprintf("%v", true),
target: fmt.Sprintf("%v", true),
want: true,
},
{
name: "target contains input",
input: "foo",
target: "foo bar",
want: false,
},
{
name: "regex any string",
input: "foo",
target: ".*",
want: true,
},
{
name: "regex",
input: "foo6",
target: `foo\d+`,
want: true,
},
{
name: "regex of numbers",
input: "4",
target: "^[1-5]$",
want: true,
},
{
name: "regex of numbers invalid",
input: "7",
target: "^[1-5]$",
want: false,
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
got := tools.MatchStringOrRegex(tc.input, tc.target)
if got != tc.want {
t.Fatalf("got %v, want %v", got, tc.want)
}
})
}
}