mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-10 07:58:12 -05:00
feat: support regex in enum type
This commit is contained in:
@@ -19,6 +19,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
"text/template"
|
||||
@@ -1353,9 +1354,11 @@ func (p *EnumParameter) UnmarshalYAML(ctx context.Context, unmarshal func(interf
|
||||
|
||||
// Parse validates and parses an incoming value for enum parameter.
|
||||
func (p *EnumParameter) Parse(v any) (any, error) {
|
||||
input := fmt.Sprintf("%v", v)
|
||||
var exists bool
|
||||
for _, av := range p.AllowedValues {
|
||||
if av == v {
|
||||
target := fmt.Sprintf("%v", av)
|
||||
if MatchStringOrRegex(input, target) {
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
@@ -1367,6 +1370,15 @@ func (p *EnumParameter) Parse(v any) (any, error) {
|
||||
return p.EnumItem.Parse(v)
|
||||
}
|
||||
|
||||
// MatchStringOrRegex checks if the input matches the target
|
||||
func MatchStringOrRegex(input, target string) bool {
|
||||
re, err := regexp.Compile(target)
|
||||
if err != nil {
|
||||
return strings.Contains(input, target)
|
||||
}
|
||||
return re.MatchString(input)
|
||||
}
|
||||
|
||||
func (p *EnumParameter) GetAuthServices() []ParamAuthService {
|
||||
return p.AuthServices
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ package tools_test
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -859,6 +860,16 @@ func TestParametersParse(t *testing.T) {
|
||||
"enum_string": "invalid",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enum with integer",
|
||||
params: tools.Parameters{
|
||||
tools.NewEnumParameter(tools.NewIntParameter("enum_int", "enum of int"), false, []any{"^[1-5]$"}),
|
||||
},
|
||||
in: map[string]any{
|
||||
"enum_int": 4,
|
||||
},
|
||||
want: tools.ParamValues{tools.ParamValue{Name: "enum_int", Value: 4}},
|
||||
},
|
||||
{
|
||||
name: "string default",
|
||||
params: tools.Parameters{
|
||||
@@ -2026,3 +2037,75 @@ func TestCheckParamRequired(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchStringOrRegex(t *testing.T) {
|
||||
tcs := []struct {
|
||||
name string
|
||||
input string
|
||||
target string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "exact string",
|
||||
input: "foo",
|
||||
target: "foo",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "exact integer",
|
||||
input: fmt.Sprintf("%v", 5),
|
||||
target: fmt.Sprintf("%v", 5),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "wrong integer",
|
||||
input: fmt.Sprintf("%v", 4),
|
||||
target: fmt.Sprintf("%v", 5),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "exact boolean",
|
||||
input: fmt.Sprintf("%v", true),
|
||||
target: fmt.Sprintf("%v", true),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "target contains input",
|
||||
input: "foo",
|
||||
target: "foo bar",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "regex any string",
|
||||
input: "foo",
|
||||
target: ".*",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "regex",
|
||||
input: "foo6",
|
||||
target: `foo\d+`,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "regex of numbers",
|
||||
input: "4",
|
||||
target: "^[1-5]$",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "regex of numbers invalid",
|
||||
input: "7",
|
||||
target: "^[1-5]$",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := tools.MatchStringOrRegex(tc.input, tc.target)
|
||||
if got != tc.want {
|
||||
t.Fatalf("got %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user