mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-02-05 04:35:14 -05:00
Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> Co-authored-by: Wenxin Du <117315983+duwenxin99@users.noreply.github.com>
355 lines
9.5 KiB
Go
355 lines
9.5 KiB
Go
// Copyright 2024 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package server
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"testing"
|
|
|
|
"github.com/googleapis/genai-toolbox/internal/log"
|
|
"github.com/googleapis/genai-toolbox/internal/telemetry"
|
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
|
)
|
|
|
|
// Compile-time assertion that MockTool implements the tools.Tool interface.
var _ tools.Tool = (*MockTool)(nil)
|
|
|
|
const (
	// fakeVersionString is the server version used when building the
	// toolsets, telemetry, and servers in these tests.
	fakeVersionString = "0.0.0"
)
|
|
|
|
type MockTool struct {
|
|
Name string
|
|
Description string
|
|
Params []tools.Parameter
|
|
}
|
|
|
|
func (t MockTool) Invoke(tools.ParamValues) (string, error) {
|
|
return "", nil
|
|
}
|
|
|
|
// claims is a map of user info decoded from an auth token
|
|
func (t MockTool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (tools.ParamValues, error) {
|
|
return tools.ParseParams(t.Params, data, claimsMap)
|
|
}
|
|
|
|
func (t MockTool) Manifest() tools.Manifest {
|
|
pMs := make([]tools.ParameterManifest, 0, len(t.Params))
|
|
for _, p := range t.Params {
|
|
pMs = append(pMs, p.Manifest())
|
|
}
|
|
return tools.Manifest{Description: t.Description, Parameters: pMs}
|
|
}
|
|
|
|
func (t MockTool) Authorized(verifiedAuthSources []string) bool {
|
|
return true
|
|
}
|
|
|
|
func TestToolsetEndpoint(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
// Set up resources to test against
|
|
tool1 := MockTool{
|
|
Name: "no_params",
|
|
Params: []tools.Parameter{},
|
|
}
|
|
tool2 := MockTool{
|
|
Name: "some_params",
|
|
Params: tools.Parameters{
|
|
tools.NewIntParameter("param1", "This is the first parameter."),
|
|
tools.NewIntParameter("param2", "This is the second parameter."),
|
|
},
|
|
}
|
|
toolsMap := map[string]tools.Tool{tool1.Name: tool1, tool2.Name: tool2}
|
|
|
|
toolsets := make(map[string]tools.Toolset)
|
|
for name, l := range map[string][]string{
|
|
"": {tool1.Name, tool2.Name},
|
|
"tool1_only": {tool1.Name},
|
|
"tool2_only": {tool2.Name},
|
|
} {
|
|
tc := tools.ToolsetConfig{Name: name, ToolNames: l}
|
|
m, err := tc.Initialize(fakeVersionString, toolsMap)
|
|
if err != nil {
|
|
t.Fatalf("unable to initialize toolset %q: %s", name, err)
|
|
}
|
|
toolsets[name] = m
|
|
}
|
|
|
|
testLogger, err := log.NewStdLogger(os.Stdout, os.Stderr, "info")
|
|
if err != nil {
|
|
t.Fatalf("unable to initialize logger: %s", err)
|
|
}
|
|
|
|
otelShutdown, err := telemetry.SetupOTel(ctx, fakeVersionString, "", false, "toolbox")
|
|
if err != nil {
|
|
t.Fatalf("unable to setup otel: %s", err)
|
|
}
|
|
defer func() {
|
|
err := otelShutdown(ctx)
|
|
if err != nil {
|
|
t.Fatalf("error shutting down OpenTelemetry: %s", err)
|
|
}
|
|
}()
|
|
instrumentation, err := CreateTelemetryInstrumentation(fakeVersionString)
|
|
if err != nil {
|
|
t.Fatalf("unable to create custom metrics: %s", err)
|
|
}
|
|
|
|
server := Server{logger: testLogger, instrumentation: instrumentation, tools: toolsMap, toolsets: toolsets}
|
|
r, err := apiRouter(&server)
|
|
if err != nil {
|
|
t.Fatalf("unable to initialize router: %s", err)
|
|
}
|
|
ts := httptest.NewServer(r)
|
|
defer ts.Close()
|
|
|
|
// wantResponse is a struct for checks against test cases
|
|
type wantResponse struct {
|
|
statusCode int
|
|
isErr bool
|
|
version string
|
|
tools []string
|
|
}
|
|
|
|
testCases := []struct {
|
|
name string
|
|
toolsetName string
|
|
want wantResponse
|
|
}{
|
|
{
|
|
name: "'default' manifest",
|
|
toolsetName: "",
|
|
want: wantResponse{
|
|
statusCode: http.StatusOK,
|
|
version: fakeVersionString,
|
|
tools: []string{tool1.Name, tool2.Name},
|
|
},
|
|
},
|
|
{
|
|
name: "invalid toolset name",
|
|
toolsetName: "some_imaginary_toolset",
|
|
want: wantResponse{
|
|
statusCode: http.StatusNotFound,
|
|
isErr: true,
|
|
},
|
|
},
|
|
{
|
|
name: "single toolset 1",
|
|
toolsetName: "tool1_only",
|
|
want: wantResponse{
|
|
statusCode: http.StatusOK,
|
|
version: fakeVersionString,
|
|
tools: []string{tool1.Name},
|
|
},
|
|
},
|
|
{
|
|
name: "single toolset 2",
|
|
toolsetName: "tool2_only",
|
|
want: wantResponse{
|
|
statusCode: http.StatusOK,
|
|
version: fakeVersionString,
|
|
tools: []string{tool2.Name},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
resp, body, err := testRequest(ts, http.MethodGet, fmt.Sprintf("/toolset/%s", tc.toolsetName), nil)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error during request: %s", err)
|
|
}
|
|
|
|
if contentType := resp.Header.Get("Content-type"); contentType != "application/json" {
|
|
t.Fatalf("unexpected content-type header: want %s, got %s", "application/json", contentType)
|
|
}
|
|
|
|
if resp.StatusCode != tc.want.statusCode {
|
|
t.Logf("response body: %s", body)
|
|
t.Fatalf("unexpected status code: want %d, got %d", tc.want.statusCode, resp.StatusCode)
|
|
}
|
|
if tc.want.isErr {
|
|
// skip the rest of the checks if this is an error case
|
|
return
|
|
}
|
|
var m tools.ToolsetManifest
|
|
err = json.Unmarshal(body, &m)
|
|
if err != nil {
|
|
t.Fatalf("unable to parse ToolsetManifest: %s", err)
|
|
}
|
|
// Check the version is correct
|
|
if m.ServerVersion != tc.want.version {
|
|
t.Fatalf("unexpected ServerVersion: want %q, got %q", tc.want.version, m.ServerVersion)
|
|
}
|
|
// validate that the tools in the toolset are correct
|
|
for _, name := range tc.want.tools {
|
|
_, ok := m.ToolsManifest[name]
|
|
if !ok {
|
|
t.Errorf("%q tool not found in manfiest", name)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
func TestToolGetEndpoint(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
// Set up resources to test against
|
|
tool1 := MockTool{
|
|
Name: "no_params",
|
|
Params: []tools.Parameter{},
|
|
}
|
|
tool2 := MockTool{
|
|
Name: "some_params",
|
|
Params: tools.Parameters{
|
|
tools.NewIntParameter("param1", "This is the first parameter."),
|
|
tools.NewIntParameter("param2", "This is the second parameter."),
|
|
},
|
|
}
|
|
toolsMap := map[string]tools.Tool{tool1.Name: tool1, tool2.Name: tool2}
|
|
|
|
testLogger, err := log.NewStdLogger(os.Stdout, os.Stderr, "info")
|
|
if err != nil {
|
|
t.Fatalf("unable to initialize logger: %s", err)
|
|
}
|
|
|
|
otelShutdown, err := telemetry.SetupOTel(ctx, fakeVersionString, "", false, "toolbox")
|
|
if err != nil {
|
|
t.Fatalf("unable to setup otel: %s", err)
|
|
}
|
|
defer func() {
|
|
err := otelShutdown(ctx)
|
|
if err != nil {
|
|
t.Fatalf("error shutting down OpenTelemetry: %s", err)
|
|
}
|
|
}()
|
|
instrumentation, err := CreateTelemetryInstrumentation(fakeVersionString)
|
|
if err != nil {
|
|
t.Fatalf("unable to create custom metrics: %s", err)
|
|
}
|
|
|
|
server := Server{version: fakeVersionString, logger: testLogger, instrumentation: instrumentation, tools: toolsMap}
|
|
r, err := apiRouter(&server)
|
|
if err != nil {
|
|
t.Fatalf("unable to initialize router: %s", err)
|
|
}
|
|
ts := httptest.NewServer(r)
|
|
defer ts.Close()
|
|
|
|
// wantResponse is a struct for checks against test cases
|
|
type wantResponse struct {
|
|
statusCode int
|
|
isErr bool
|
|
version string
|
|
tools []string
|
|
}
|
|
|
|
testCases := []struct {
|
|
name string
|
|
toolName string
|
|
want wantResponse
|
|
}{
|
|
{
|
|
name: "tool1",
|
|
toolName: tool1.Name,
|
|
want: wantResponse{
|
|
statusCode: http.StatusOK,
|
|
version: fakeVersionString,
|
|
tools: []string{tool1.Name},
|
|
},
|
|
},
|
|
{
|
|
name: "tool2",
|
|
toolName: tool2.Name,
|
|
want: wantResponse{
|
|
statusCode: http.StatusOK,
|
|
version: fakeVersionString,
|
|
tools: []string{tool2.Name},
|
|
},
|
|
},
|
|
{
|
|
name: "invalid tool",
|
|
toolName: "some_imaginary_tool",
|
|
want: wantResponse{
|
|
statusCode: http.StatusNotFound,
|
|
isErr: true,
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
resp, body, err := testRequest(ts, http.MethodGet, fmt.Sprintf("/tool/%s", tc.toolName), nil)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error during request: %s", err)
|
|
}
|
|
|
|
if contentType := resp.Header.Get("Content-type"); contentType != "application/json" {
|
|
t.Fatalf("unexpected content-type header: want %s, got %s", "application/json", contentType)
|
|
}
|
|
|
|
if resp.StatusCode != tc.want.statusCode {
|
|
t.Logf("response body: %s", body)
|
|
t.Fatalf("unexpected status code: want %d, got %d", tc.want.statusCode, resp.StatusCode)
|
|
}
|
|
if tc.want.isErr {
|
|
// skip the rest of the checks if this is an error case
|
|
return
|
|
}
|
|
var m tools.ToolsetManifest
|
|
err = json.Unmarshal(body, &m)
|
|
if err != nil {
|
|
t.Fatalf("unable to parse ToolsetManifest: %s", err)
|
|
}
|
|
// Check the version is correct
|
|
if m.ServerVersion != tc.want.version {
|
|
t.Fatalf("unexpected ServerVersion: want %q, got %q", tc.want.version, m.ServerVersion)
|
|
}
|
|
// validate that the tools in the toolset are correct
|
|
for _, name := range tc.want.tools {
|
|
_, ok := m.ToolsManifest[name]
|
|
if !ok {
|
|
t.Errorf("%q tool not found in manfiest", name)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// testRequest sends a request with the given method, path (appended to the
// server's base URL), and body to ts. It returns the response and its fully
// read body. The response body is always closed before returning, so callers
// must use the returned bytes rather than resp.Body.
func testRequest(ts *httptest.Server, method, path string, body io.Reader) (*http.Response, []byte, error) {
	req, err := http.NewRequest(method, ts.URL+path, body)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to create request: %w", err)
	}

	// ts.Client is preconfigured for the test server (e.g. its TLS config).
	resp, err := ts.Client().Do(req)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to send request: %w", err)
	}
	// Close on every path, including when reading the body fails below.
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to read response body: %w", err)
	}

	return resp, respBody, nil
}
|