chore: update params type (#98)

Different databases require different types for `Params` field when
adding parameters to their statement. e.g. alloydb, cloudsql, and
postgres uses `pgxpool` to query and build sql statement, whereas
spanner uses `Spanner` library.

Added a new `ParamValue` struct. `ParseParams` helper function parses
arbitraryJSON object into `[]ParamValue`, and the tool's invoke will
convert `[]ParamValue` into it's required type.

---------

Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com>
This commit is contained in:
Yuan
2024-12-05 16:05:09 -08:00
committed by GitHub
parent e88ec409d1
commit e815dc49f4
5 changed files with 83 additions and 18 deletions

View File

@@ -35,11 +35,11 @@ type MockTool struct {
Params []tools.Parameter
}
func (t MockTool) Invoke([]any) (string, error) {
func (t MockTool) Invoke(tools.ParamValues) (string, error) {
return "", nil
}
func (t MockTool) ParseParams(data map[string]any) ([]any, error) {
func (t MockTool) ParseParams(data map[string]any) (tools.ParamValues, error) {
return tools.ParseParams(t.Params, data)
}

View File

@@ -28,11 +28,40 @@ const (
typeArray = "array"
)
// ParseParams is a helper function for parsing Parameters from an arbitraryJSON object.
func ParseParams(ps Parameters, data map[string]any) ([]any, error) {
// ParamValues is an ordered list of ParamValue
type ParamValues []ParamValue
// ParamValue represents the parameter's name and value.
type ParamValue struct {
Name string
Value any
}
// AsSlice returns a slice of the Param's values (in order).
func (p ParamValues) AsSlice() []any {
params := []any{}
for _, p := range p {
params = append(params, p.Value)
}
return params
}
// AsMap returns a map of ParamValue's names to values.
func (p ParamValues) AsMap() map[string]interface{} {
params := make(map[string]interface{})
for _, p := range p {
params[p.Name] = p.Value
}
return params
}
// ParseParams parses specified Parameters from data and returns them as ParamValues.
func ParseParams(ps Parameters, data map[string]any) (ParamValues, error) {
params := make([]ParamValue, 0, len(ps))
for _, p := range ps {
v, ok := data[p.GetName()]
name := p.GetName()
v, ok := data[name]
if !ok {
return nil, fmt.Errorf("parameter %q is required!", p.GetName())
}
@@ -40,7 +69,7 @@ func ParseParams(ps Parameters, data map[string]any) ([]any, error) {
if err != nil {
return nil, fmt.Errorf("unable to parse value for %q: %w", p.GetName(), err)
}
params = append(params, newV)
params = append(params, ParamValue{Name: name, Value: newV})
}
return params, nil
}

View File

@@ -139,7 +139,7 @@ func TestParametersParse(t *testing.T) {
name string
params tools.Parameters
in map[string]any
want []any
want tools.ParamValues
}{
{
name: "string",
@@ -149,7 +149,7 @@ func TestParametersParse(t *testing.T) {
in: map[string]any{
"my_string": "hello world",
},
want: []any{"hello world"},
want: tools.ParamValues{tools.ParamValue{Name: "my_string", Value: "hello world"}},
},
{
name: "not string",
@@ -168,7 +168,7 @@ func TestParametersParse(t *testing.T) {
in: map[string]any{
"my_int": 100,
},
want: []any{100},
want: tools.ParamValues{tools.ParamValue{Name: "my_int", Value: 100}},
},
{
name: "not int",
@@ -187,7 +187,7 @@ func TestParametersParse(t *testing.T) {
in: map[string]any{
"my_float": 1.5,
},
want: []any{1.5},
want: tools.ParamValues{tools.ParamValue{Name: "my_float", Value: 1.5}},
},
{
name: "not float",
@@ -206,15 +206,15 @@ func TestParametersParse(t *testing.T) {
in: map[string]any{
"my_bool": true,
},
want: []any{true},
want: tools.ParamValues{tools.ParamValue{Name: "my_bool", Value: true}},
},
{
name: "bool",
name: "not bool",
params: tools.Parameters{
tools.NewBooleanParameter("my_bool", "this param is a bool"),
},
in: map[string]any{
"my_bool": "true",
"my_bool": 1.5,
},
},
}
@@ -253,3 +253,38 @@ func TestParametersParse(t *testing.T) {
})
}
}
func TestParamValues(t *testing.T) {
tcs := []struct {
name string
in tools.ParamValues
wantSlice []any
wantMap map[string]interface{}
}{
{
name: "string",
in: tools.ParamValues{tools.ParamValue{Name: "my_bool", Value: true}, tools.ParamValue{Name: "my_string", Value: "hello world"}},
wantSlice: []any{true, "hello world"},
wantMap: map[string]interface{}{"my_bool": true, "my_string": "hello world"},
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
gotSlice := tc.in.AsSlice()
gotMap := tc.in.AsMap()
for i, got := range gotSlice {
want := tc.wantSlice[i]
if got != want {
t.Fatalf("unexpected value: got %q, want %q", got, want)
}
}
for i, got := range gotMap {
want := tc.wantMap[i]
if got != want {
t.Fatalf("unexpected value: got %q, want %q", got, want)
}
}
})
}
}

View File

@@ -105,9 +105,10 @@ type Tool struct {
manifest tools.Manifest
}
func (t Tool) Invoke(params []any) (string, error) {
func (t Tool) Invoke(params tools.ParamValues) (string, error) {
sliceParams := params.AsSlice()
fmt.Printf("Invoked tool %s\n", t.Name)
results, err := t.Pool.Query(context.Background(), t.Statement, params...)
results, err := t.Pool.Query(context.Background(), t.Statement, sliceParams...)
if err != nil {
return "", fmt.Errorf("unable to execute query: %w", err)
}
@@ -124,7 +125,7 @@ func (t Tool) Invoke(params []any) (string, error) {
return fmt.Sprintf("Stub tool call for %q! Parameters parsed: %q \n Output: %s", t.Name, params, out.String()), nil
}
func (t Tool) ParseParams(data map[string]any) ([]any, error) {
func (t Tool) ParseParams(data map[string]any) (tools.ParamValues, error) {
return tools.ParseParams(t.Parameters, data)
}

View File

@@ -24,8 +24,8 @@ type ToolConfig interface {
}
type Tool interface {
Invoke([]any) (string, error)
ParseParams(data map[string]any) ([]any, error)
Invoke(ParamValues) (string, error)
ParseParams(data map[string]any) (ParamValues, error)
Manifest() Manifest
}