mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-04-09 03:02:26 -04:00
To facilitate moving the invocation implementation into Source, we need to move the parameters package to `internal/util`. This is necessary because certain parameters may not be fully resolvable until invocation time. Since both `internal/sources` and `internal/tools` need access to `parameters`, it makes more sense for the parameters implementation to live under utils.
250 lines
6.9 KiB
Go
250 lines
6.9 KiB
Go
// Copyright 2025 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package prompts_test
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
"testing"
|
|
|
|
yaml "github.com/goccy/go-yaml"
|
|
"github.com/google/go-cmp/cmp"
|
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
|
"github.com/googleapis/genai-toolbox/internal/testutils"
|
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
|
)
|
|
|
|
// Test type aliases for convenience.
|
|
type (
|
|
Argument = prompts.Argument
|
|
ArgMcpManifest = prompts.ArgMcpManifest
|
|
Arguments = prompts.Arguments
|
|
)
|
|
|
|
// Ptr returns a pointer to a newly allocated copy of val. Mutating the value
// through the returned pointer does not affect the caller's original variable.
func Ptr[T any](val T) *T {
	p := new(T)
	*p = val
	return p
}
|
|
|
|
func makeArrayArg(name, desc string, items parameters.Parameter) Argument {
|
|
return Argument{Parameter: parameters.NewArrayParameter(name, desc, items)}
|
|
}
|
|
|
|
// TestArgMcpManifest verifies that Argument.McpManifest reports the name,
// description and required flag of the wrapped parameter. A parameter built
// with a default value is expected to be reported as not required.
func TestArgMcpManifest(t *testing.T) {
	t.Parallel()

	testCases := []struct {
		name     string
		arg      Argument
		expected ArgMcpManifest
	}{
		{
			name: "Required with no default",
			arg:  Argument{Parameter: parameters.NewStringParameterWithRequired("name1", "desc1", true)},
			expected: ArgMcpManifest{
				Name: "name1", Description: "desc1", Required: true,
			},
		},
		{
			name: "Not required with no default",
			arg:  Argument{Parameter: parameters.NewStringParameterWithRequired("name2", "desc2", false)},
			expected: ArgMcpManifest{
				Name: "name2", Description: "desc2", Required: false,
			},
		},
		{
			// Supplying a default makes the argument optional, so the manifest
			// must report Required: false even though it was never set explicitly.
			name: "Not required with default",
			arg:  Argument{Parameter: parameters.NewStringParameterWithDefault("name3", "defaultVal", "desc3")},
			expected: ArgMcpManifest{
				Name: "name3", Description: "desc3", Required: false,
			},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			got := tc.arg.McpManifest()
			if diff := cmp.Diff(tc.expected, got); diff != "" {
				t.Errorf("McpManifest() mismatch (-want +got):\n%s", diff)
			}
		})
	}
}
|
|
|
|
// TestArgumentsUnmarshalYAML tests all unmarshaling logic for the Arguments
// type. It covers four cases:
//   - the parameter type defaults to string when omitted;
//   - an explicit type is honored;
//   - nested array item definitions are parsed;
//   - errors for unsupported types are propagated.
func TestArgumentsUnmarshalYAML(t *testing.T) {
	t.Parallel()

	// transformFunc reduces a parameter to the fields this test cares about
	// (name, type, description), recursing into array items so nested
	// definitions are compared as well. It is declared before it is assigned
	// so that the closure can call itself.
	var transformFunc func(parameters.Parameter) any
	transformFunc = func(p parameters.Parameter) any {
		s := struct{ Name, Type, Desc string }{
			Name: p.GetName(),
			Type: p.GetType(),
			Desc: p.Manifest().Description,
		}
		if arr, ok := p.(*parameters.ArrayParameter); ok {
			s.Desc = fmt.Sprintf("%s items:%v", s.Desc, transformFunc(arr.GetItems()))
		}
		return s
	}
	// paramComparer lets cmp.Diff compare parsed parameters through
	// transformFunc rather than by their concrete struct fields, which may
	// include unexported state.
	paramComparer := cmp.Transformer("Parameter", transformFunc)

	testCases := []struct {
		name         string
		yamlInput    []map[string]any
		expectedArgs Arguments
		wantErr      string
	}{
		{
			name: "Defaults type to string when omitted",
			yamlInput: []map[string]any{
				{"name": "p1", "description": "d1"},
			},
			expectedArgs: Arguments{
				{Parameter: parameters.NewStringParameter("p1", "d1")},
			},
		},
		{
			name: "Respects type when present",
			yamlInput: []map[string]any{
				{"name": "p1", "description": "d1", "type": "integer"},
			},
			expectedArgs: Arguments{
				{Parameter: parameters.NewIntParameter("p1", "d1")},
			},
		},
		{
			name: "Parses complex types like arrays correctly",
			yamlInput: []map[string]any{
				{
					"name":        "param_array",
					"description": "an array",
					"type":        "array",
					"items": map[string]any{
						"name":        "item_name",
						"type":        "string",
						"description": "an item",
					},
				},
			},
			expectedArgs: Arguments{
				makeArrayArg("param_array", "an array", parameters.NewStringParameter("item_name", "an item")),
			},
		},
		{
			name: "Propagates parsing error for unsupported type",
			yamlInput: []map[string]any{
				{"name": "p1", "description": "d1", "type": "unsupported"},
			},
			wantErr: `"unsupported" is not valid type for a parameter`,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			yamlBytes, err := yaml.Marshal(tc.yamlInput)
			if err != nil {
				t.Fatalf("Test setup failure: could not marshal test input to YAML: %v", err)
			}
			var got Arguments
			ctx, err := testutils.ContextWithNewLogger()
			if err != nil {
				t.Fatalf("Failed to create logger using testutils: %v", err)
			}
			err = yaml.UnmarshalContext(ctx, yamlBytes, &got)

			if tc.wantErr != "" {
				if err == nil {
					t.Fatalf("UnmarshalContext() expected error but got nil")
				}
				if !strings.Contains(err.Error(), tc.wantErr) {
					t.Errorf("UnmarshalContext() error mismatch:\nwant to contain: %q\ngot: %q", tc.wantErr, err.Error())
				}
				return
			}

			if err != nil {
				t.Fatalf("UnmarshalContext() returned unexpected error: %v", err)
			}
			if diff := cmp.Diff(tc.expectedArgs, got, paramComparer); diff != "" {
				t.Errorf("UnmarshalContext() result mismatch (-want +got):\n%s", diff)
			}
		})
	}
}
|
|
|
|
// TestParseArguments checks prompts.ParseArguments in three ways:
//   - it converts a raw argument map into ParamValues in declaration order;
//   - an omitted optional argument yields a nil value;
//   - a missing required argument produces an error.
func TestParseArguments(t *testing.T) {
	t.Parallel()

	declared := prompts.Arguments{
		{Parameter: parameters.NewStringParameter("name", "A required name.")},
		{Parameter: parameters.NewIntParameterWithRequired("count", "An optional count.", false)},
	}

	tests := []struct {
		name    string
		argsIn  map[string]any
		want    parameters.ParamValues
		wantErr string
	}{
		{
			name:   "Success with all parameters provided",
			argsIn: map[string]any{"name": "test-name", "count": 42},
			want: parameters.ParamValues{
				{Name: "name", Value: "test-name"},
				{Name: "count", Value: 42},
			},
		},
		{
			name:   "Success with only required parameters",
			argsIn: map[string]any{"name": "another-name"},
			want: parameters.ParamValues{
				{Name: "name", Value: "another-name"},
				{Name: "count", Value: nil},
			},
		},
		{
			name:    "Failure with missing required parameter",
			argsIn:  map[string]any{"count": 123},
			wantErr: `parameter "name" is required`,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := prompts.ParseArguments(declared, tt.argsIn, nil)

			// Happy path: no error expected, so compare the parsed values.
			if tt.wantErr == "" {
				if err != nil {
					t.Fatalf("unexpected error: %v", err)
				}
				if diff := cmp.Diff(tt.want, got); diff != "" {
					t.Errorf("ParseArguments() result mismatch (-want +got):\n%s", diff)
				}
				return
			}

			// Failure path: an error containing wantErr is required.
			if err == nil {
				t.Fatalf("expected an error but got nil")
			}
			if msg := err.Error(); !strings.Contains(msg, tt.wantErr) {
				t.Errorf("error mismatch:\n want to contain: %q\n got: %q", tt.wantErr, msg)
			}
		})
	}
}
|