mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 07:28:05 -05:00
To facilitate moving the invocation implementation to Source, we need to move the parameters package to `internal/util`. This is necessary because certain parameters may not be fully resolvable before implementation. Since both `internal/sources` and `internal/tools` need access to `parameters`, it makes more sense for the parameters implementation to live in `internal/util`.
204 lines
6.2 KiB
Go
204 lines
6.2 KiB
Go
// Copyright 2025 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package prompts_test
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"strings"
|
|
"testing"
|
|
|
|
yaml "github.com/goccy/go-yaml"
|
|
"github.com/google/go-cmp/cmp"
|
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
|
_ "github.com/googleapis/genai-toolbox/internal/prompts/custom"
|
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
|
)
|
|
|
|
type mockPromptConfig struct {
|
|
name string
|
|
kind string
|
|
}
|
|
|
|
func (m *mockPromptConfig) PromptConfigKind() string { return m.kind }
|
|
func (m *mockPromptConfig) Initialize() (prompts.Prompt, error) { return nil, nil }
|
|
|
|
var errMockFactory = errors.New("mock factory error")
|
|
|
|
func mockFactory(ctx context.Context, name string, decoder *yaml.Decoder) (prompts.PromptConfig, error) {
|
|
return &mockPromptConfig{name: name, kind: "mockKind"}, nil
|
|
}
|
|
|
|
func mockErrorFactory(ctx context.Context, name string, decoder *yaml.Decoder) (prompts.PromptConfig, error) {
|
|
return nil, errMockFactory
|
|
}
|
|
|
|
// TestRegistry exercises the prompt-kind registry. It covers registering a
// factory, rejecting a duplicate registration, decoding through a registered
// factory, reporting an unknown kind, propagating factory errors, and falling
// back to the "custom" kind when the kind is empty.
//
// NOTE: the registry is shared package state and Register returns false for a
// kind that is already registered. Running this test more than once in the same
// process (e.g. `go test -count=2`) therefore fails the first-registration
// checks.
func TestRegistry(t *testing.T) {
	t.Parallel()
	ctx := context.Background()

	t.Run("RegisterAndDecodeSuccess", func(t *testing.T) {
		kind := "testKindSuccess"
		if !prompts.Register(kind, mockFactory) {
			t.Fatal("expected registration to succeed")
		}
		// This should fail because we are registering a duplicate
		if prompts.Register(kind, mockFactory) {
			t.Fatal("expected duplicate registration to fail")
		}

		// mockFactory never reads the decoder, so an empty document is enough.
		decoder := yaml.NewDecoder(strings.NewReader(""))
		config, err := prompts.DecodeConfig(ctx, kind, "testPrompt", decoder)
		if err != nil {
			t.Fatalf("expected DecodeConfig to succeed, but got error: %v", err)
		}
		if config == nil {
			t.Fatal("expected a non-nil config")
		}
		// Confirm the config really came from mockFactory rather than some
		// other registered factory.
		if got := config.PromptConfigKind(); got != "mockKind" {
			t.Errorf("expected config kind %q from mockFactory, but got %q", "mockKind", got)
		}
	})

	t.Run("DecodeUnknownKind", func(t *testing.T) {
		decoder := yaml.NewDecoder(strings.NewReader(""))
		_, err := prompts.DecodeConfig(ctx, "unregisteredKind", "testPrompt", decoder)
		if err == nil {
			t.Fatal("expected an error for unknown kind, but got nil")
		}
		if !strings.Contains(err.Error(), "unknown prompt kind") {
			t.Errorf("expected error to contain 'unknown prompt kind', but got: %v", err)
		}
	})

	t.Run("FactoryReturnsError", func(t *testing.T) {
		kind := "testKindError"
		if !prompts.Register(kind, mockErrorFactory) {
			t.Fatal("expected registration to succeed")
		}

		decoder := yaml.NewDecoder(strings.NewReader(""))
		_, err := prompts.DecodeConfig(ctx, kind, "testPrompt", decoder)
		if err == nil {
			t.Fatal("expected an error from the factory, but got nil")
		}
		// DecodeConfig may wrap the factory error, so match with errors.Is
		// and print the actual error to make a mismatch diagnosable.
		if !errors.Is(err, errMockFactory) {
			t.Errorf("expected error to wrap %v, but got: %v", errMockFactory, err)
		}
	})

	t.Run("DecodeDefaultsToCustom", func(t *testing.T) {
		// Relies on the blank import of internal/prompts/custom having
		// registered the "custom" kind.
		decoder := yaml.NewDecoder(strings.NewReader("description: A test prompt"))
		config, err := prompts.DecodeConfig(ctx, "", "testDefaultPrompt", decoder)
		if err != nil {
			t.Fatalf("expected DecodeConfig with empty kind to succeed, but got error: %v", err)
		}
		if config == nil {
			t.Fatal("expected a non-nil config for default kind")
		}
		if config.PromptConfigKind() != "custom" {
			t.Errorf("expected default kind to be 'custom', but got %q", config.PromptConfigKind())
		}
	})
}
|
|
|
|
// TestGetMcpManifest verifies that GetMcpManifest passes the prompt name and
// description through unchanged. It also checks that each argument becomes an
// ArgMcpManifest with the parameter's name, description and required flag.
// Per the expectations below, a parameter built without an explicit required
// flag is reported as required.
func TestGetMcpManifest(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name        string
		promptName  string
		description string
		args        prompts.Arguments
		want        prompts.McpManifest
	}{
		{
			name:        "No arguments",
			promptName:  "test-prompt",
			description: "A test prompt.",
			args:        prompts.Arguments{},
			want: prompts.McpManifest{
				Name:        "test-prompt",
				Description: "A test prompt.",
				Arguments:   []prompts.ArgMcpManifest{},
			},
		},
		{
			name:        "With arguments",
			promptName:  "arg-prompt",
			description: "Prompt with args.",
			args: prompts.Arguments{
				{Parameter: parameters.NewStringParameter("param1", "First param")},
				{Parameter: parameters.NewIntParameterWithRequired("param2", "Second param", false)},
			},
			want: prompts.McpManifest{
				Name:        "arg-prompt",
				Description: "Prompt with args.",
				Arguments: []prompts.ArgMcpManifest{
					{Name: "param1", Description: "First param", Required: true},
					{Name: "param2", Description: "Second param", Required: false},
				},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := prompts.GetMcpManifest(tt.promptName, tt.description, tt.args)
			diff := cmp.Diff(tt.want, got)
			if diff == "" {
				return
			}
			t.Errorf("GetMcpManifest() mismatch (-want +got):\n%s", diff)
		})
	}
}
|
|
|
|
// TestGetManifest verifies that GetManifest passes the description through
// and converts each prompt argument into a parameters.ParameterManifest. Each
// manifest carries name, type, required flag, description and an empty
// AuthServices list.
func TestGetManifest(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name        string
		description string
		args        prompts.Arguments
		want        prompts.Manifest
	}{
		{
			name:        "No arguments",
			description: "A simple prompt.",
			args:        prompts.Arguments{},
			want: prompts.Manifest{
				Description: "A simple prompt.",
				Arguments:   []parameters.ParameterManifest{},
			},
		},
		{
			name:        "With arguments",
			description: "Prompt with arguments.",
			args: prompts.Arguments{
				{Parameter: parameters.NewStringParameter("param1", "First param")},
				{Parameter: parameters.NewBooleanParameterWithRequired("param2", "Second param", false)},
			},
			want: prompts.Manifest{
				Description: "Prompt with arguments.",
				Arguments: []parameters.ParameterManifest{
					{Name: "param1", Type: "string", Required: true, Description: "First param", AuthServices: []string{}},
					{Name: "param2", Type: "boolean", Required: false, Description: "Second param", AuthServices: []string{}},
				},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := prompts.GetManifest(tt.description, tt.args)
			diff := cmp.Diff(tt.want, got)
			if diff == "" {
				return
			}
			t.Errorf("GetManifest() mismatch (-want +got):\n%s", diff)
		})
	}
}
|