mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-04-09 03:02:26 -04:00
chore: update params type (#98)
Different databases require different types for `Params` field when adding parameters to their statement. e.g. alloydb, cloudsql, and postgres uses `pgxpool` to query and build sql statement, whereas spanner uses `Spanner` library. Added a new `ParamValue` struct. `ParseParams` helper function parses arbitraryJSON object into `[]ParamValue`, and the tool's invoke will convert `[]ParamValue` into it's required type. --------- Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com>
This commit is contained in:
@@ -35,11 +35,11 @@ type MockTool struct {
|
||||
Params []tools.Parameter
|
||||
}
|
||||
|
||||
func (t MockTool) Invoke([]any) (string, error) {
|
||||
func (t MockTool) Invoke(tools.ParamValues) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (t MockTool) ParseParams(data map[string]any) ([]any, error) {
|
||||
func (t MockTool) ParseParams(data map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.Params, data)
|
||||
}
|
||||
|
||||
|
||||
@@ -28,11 +28,40 @@ const (
|
||||
typeArray = "array"
|
||||
)
|
||||
|
||||
// ParseParams is a helper function for parsing Parameters from an arbitraryJSON object.
|
||||
func ParseParams(ps Parameters, data map[string]any) ([]any, error) {
|
||||
// ParamValues is an ordered list of ParamValue
|
||||
type ParamValues []ParamValue
|
||||
|
||||
// ParamValue represents the parameter's name and value.
|
||||
type ParamValue struct {
|
||||
Name string
|
||||
Value any
|
||||
}
|
||||
|
||||
// AsSlice returns a slice of the Param's values (in order).
|
||||
func (p ParamValues) AsSlice() []any {
|
||||
params := []any{}
|
||||
|
||||
for _, p := range p {
|
||||
params = append(params, p.Value)
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
// AsMap returns a map of ParamValue's names to values.
|
||||
func (p ParamValues) AsMap() map[string]interface{} {
|
||||
params := make(map[string]interface{})
|
||||
for _, p := range p {
|
||||
params[p.Name] = p.Value
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
// ParseParams parses specified Parameters from data and returns them as ParamValues.
|
||||
func ParseParams(ps Parameters, data map[string]any) (ParamValues, error) {
|
||||
params := make([]ParamValue, 0, len(ps))
|
||||
for _, p := range ps {
|
||||
v, ok := data[p.GetName()]
|
||||
name := p.GetName()
|
||||
v, ok := data[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("parameter %q is required!", p.GetName())
|
||||
}
|
||||
@@ -40,7 +69,7 @@ func ParseParams(ps Parameters, data map[string]any) ([]any, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse value for %q: %w", p.GetName(), err)
|
||||
}
|
||||
params = append(params, newV)
|
||||
params = append(params, ParamValue{Name: name, Value: newV})
|
||||
}
|
||||
return params, nil
|
||||
}
|
||||
|
||||
@@ -139,7 +139,7 @@ func TestParametersParse(t *testing.T) {
|
||||
name string
|
||||
params tools.Parameters
|
||||
in map[string]any
|
||||
want []any
|
||||
want tools.ParamValues
|
||||
}{
|
||||
{
|
||||
name: "string",
|
||||
@@ -149,7 +149,7 @@ func TestParametersParse(t *testing.T) {
|
||||
in: map[string]any{
|
||||
"my_string": "hello world",
|
||||
},
|
||||
want: []any{"hello world"},
|
||||
want: tools.ParamValues{tools.ParamValue{Name: "my_string", Value: "hello world"}},
|
||||
},
|
||||
{
|
||||
name: "not string",
|
||||
@@ -168,7 +168,7 @@ func TestParametersParse(t *testing.T) {
|
||||
in: map[string]any{
|
||||
"my_int": 100,
|
||||
},
|
||||
want: []any{100},
|
||||
want: tools.ParamValues{tools.ParamValue{Name: "my_int", Value: 100}},
|
||||
},
|
||||
{
|
||||
name: "not int",
|
||||
@@ -187,7 +187,7 @@ func TestParametersParse(t *testing.T) {
|
||||
in: map[string]any{
|
||||
"my_float": 1.5,
|
||||
},
|
||||
want: []any{1.5},
|
||||
want: tools.ParamValues{tools.ParamValue{Name: "my_float", Value: 1.5}},
|
||||
},
|
||||
{
|
||||
name: "not float",
|
||||
@@ -206,15 +206,15 @@ func TestParametersParse(t *testing.T) {
|
||||
in: map[string]any{
|
||||
"my_bool": true,
|
||||
},
|
||||
want: []any{true},
|
||||
want: tools.ParamValues{tools.ParamValue{Name: "my_bool", Value: true}},
|
||||
},
|
||||
{
|
||||
name: "bool",
|
||||
name: "not bool",
|
||||
params: tools.Parameters{
|
||||
tools.NewBooleanParameter("my_bool", "this param is a bool"),
|
||||
},
|
||||
in: map[string]any{
|
||||
"my_bool": "true",
|
||||
"my_bool": 1.5,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -253,3 +253,38 @@ func TestParametersParse(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParamValues(t *testing.T) {
|
||||
tcs := []struct {
|
||||
name string
|
||||
in tools.ParamValues
|
||||
wantSlice []any
|
||||
wantMap map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "string",
|
||||
in: tools.ParamValues{tools.ParamValue{Name: "my_bool", Value: true}, tools.ParamValue{Name: "my_string", Value: "hello world"}},
|
||||
wantSlice: []any{true, "hello world"},
|
||||
wantMap: map[string]interface{}{"my_bool": true, "my_string": "hello world"},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gotSlice := tc.in.AsSlice()
|
||||
gotMap := tc.in.AsMap()
|
||||
|
||||
for i, got := range gotSlice {
|
||||
want := tc.wantSlice[i]
|
||||
if got != want {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
for i, got := range gotMap {
|
||||
want := tc.wantMap[i]
|
||||
if got != want {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -105,9 +105,10 @@ type Tool struct {
|
||||
manifest tools.Manifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(params []any) (string, error) {
|
||||
func (t Tool) Invoke(params tools.ParamValues) (string, error) {
|
||||
sliceParams := params.AsSlice()
|
||||
fmt.Printf("Invoked tool %s\n", t.Name)
|
||||
results, err := t.Pool.Query(context.Background(), t.Statement, params...)
|
||||
results, err := t.Pool.Query(context.Background(), t.Statement, sliceParams...)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
@@ -124,7 +125,7 @@ func (t Tool) Invoke(params []any) (string, error) {
|
||||
return fmt.Sprintf("Stub tool call for %q! Parameters parsed: %q \n Output: %s", t.Name, params, out.String()), nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any) ([]any, error) {
|
||||
func (t Tool) ParseParams(data map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.Parameters, data)
|
||||
}
|
||||
|
||||
|
||||
@@ -24,8 +24,8 @@ type ToolConfig interface {
|
||||
}
|
||||
|
||||
type Tool interface {
|
||||
Invoke([]any) (string, error)
|
||||
ParseParams(data map[string]any) ([]any, error)
|
||||
Invoke(ParamValues) (string, error)
|
||||
ParseParams(data map[string]any) (ParamValues, error)
|
||||
Manifest() Manifest
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user