Files
genai-toolbox/internal/server/api_test.go
2025-01-13 15:38:38 -08:00

355 lines
9.5 KiB
Go

// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/googleapis/genai-toolbox/internal/log"
"github.com/googleapis/genai-toolbox/internal/telemetry"
"github.com/googleapis/genai-toolbox/internal/tools"
)
// Compile-time assertion that *MockTool satisfies the tools.Tool interface.
var _ tools.Tool = (*MockTool)(nil)

// fakeVersionString is the server version string used by the test fixtures.
const fakeVersionString = "0.0.0"
// MockTool is a minimal tools.Tool implementation used to drive the API
// handlers in these tests.
type MockTool struct {
	// Name identifies the tool in the test maps; Description is surfaced
	// through Manifest.
	Name, Description string
	// Params are the parameters the tool accepts and advertises.
	Params []tools.Parameter
}
// Invoke satisfies tools.Tool. The mock ignores its argument and always
// succeeds with an empty result.
func (t MockTool) Invoke(tools.ParamValues) (string, error) {
	var result string
	return result, nil
}
// ParseParams delegates to tools.ParseParams, checking data against t.Params.
// claimsMap holds user info decoded from auth tokens and is passed through
// unchanged (presumably keyed by auth source name — TODO confirm).
func (t MockTool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (tools.ParamValues, error) {
	params, err := tools.ParseParams(t.Params, data, claimsMap)
	return params, err
}
// Manifest builds the tool's manifest from its Description and the manifest
// of each parameter, preserving parameter order.
func (t MockTool) Manifest() tools.Manifest {
	params := make([]tools.ParameterManifest, len(t.Params))
	for i, p := range t.Params {
		params[i] = p.Manifest()
	}
	return tools.Manifest{Description: t.Description, Parameters: params}
}
// Authorized satisfies tools.Tool. The mock imposes no auth requirements, so
// it grants access regardless of verifiedAuthSources.
func (t MockTool) Authorized(verifiedAuthSources []string) bool {
	const alwaysAllowed = true
	return alwaysAllowed
}
// TestToolsetEndpoint exercises GET /toolset/{toolsetName} against an
// in-process API router backed by two mock tools. For every case it checks the
// Content-Type header and the status code; for successful responses it also
// checks that the returned ToolsetManifest reports the expected server version
// and contains exactly the expected tools (no missing and no extra entries).
func TestToolsetEndpoint(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Set up resources to test against
	tool1 := MockTool{
		Name:   "no_params",
		Params: []tools.Parameter{},
	}
	tool2 := MockTool{
		Name: "some_params",
		Params: tools.Parameters{
			tools.NewIntParameter("param1", "This is the first parameter."),
			tools.NewIntParameter("param2", "This is the second parameter."),
		},
	}
	toolsMap := map[string]tools.Tool{tool1.Name: tool1, tool2.Name: tool2}

	// The empty name is the toolset exercised by the "'default' manifest" case.
	toolsets := make(map[string]tools.Toolset)
	for name, l := range map[string][]string{
		"":           {tool1.Name, tool2.Name},
		"tool1_only": {tool1.Name},
		"tool2_only": {tool2.Name},
	} {
		tc := tools.ToolsetConfig{Name: name, ToolNames: l}
		m, err := tc.Initialize(fakeVersionString, toolsMap)
		if err != nil {
			t.Fatalf("unable to initialize toolset %q: %s", name, err)
		}
		toolsets[name] = m
	}

	testLogger, err := log.NewStdLogger(os.Stdout, os.Stderr, "info")
	if err != nil {
		t.Fatalf("unable to initialize logger: %s", err)
	}

	otelShutdown, err := telemetry.SetupOTel(ctx, fakeVersionString, "", false, "toolbox")
	if err != nil {
		t.Fatalf("unable to setup otel: %s", err)
	}
	defer func() {
		// Report, rather than abort: this runs as the test is finishing.
		if err := otelShutdown(ctx); err != nil {
			t.Errorf("error shutting down OpenTelemetry: %s", err)
		}
	}()

	instrumentation, err := CreateTelemetryInstrumentation(fakeVersionString)
	if err != nil {
		t.Fatalf("unable to create custom metrics: %s", err)
	}

	server := Server{version: fakeVersionString, logger: testLogger, instrumentation: instrumentation, tools: toolsMap, toolsets: toolsets}
	r, err := apiRouter(&server)
	if err != nil {
		t.Fatalf("unable to initialize router: %s", err)
	}
	ts := httptest.NewServer(r)
	defer ts.Close()

	// wantResponse is a struct for checks against test cases
	type wantResponse struct {
		statusCode int
		isErr      bool
		version    string
		tools      []string
	}

	testCases := []struct {
		name        string
		toolsetName string
		want        wantResponse
	}{
		{
			name:        "'default' manifest",
			toolsetName: "",
			want: wantResponse{
				statusCode: http.StatusOK,
				version:    fakeVersionString,
				tools:      []string{tool1.Name, tool2.Name},
			},
		},
		{
			name:        "invalid toolset name",
			toolsetName: "some_imaginary_toolset",
			want: wantResponse{
				statusCode: http.StatusNotFound,
				isErr:      true,
			},
		},
		{
			name:        "single toolset 1",
			toolsetName: "tool1_only",
			want: wantResponse{
				statusCode: http.StatusOK,
				version:    fakeVersionString,
				tools:      []string{tool1.Name},
			},
		},
		{
			name:        "single toolset 2",
			toolsetName: "tool2_only",
			want: wantResponse{
				statusCode: http.StatusOK,
				version:    fakeVersionString,
				tools:      []string{tool2.Name},
			},
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			resp, body, err := testRequest(ts, http.MethodGet, fmt.Sprintf("/toolset/%s", tc.toolsetName), nil)
			if err != nil {
				t.Fatalf("unexpected error during request: %s", err)
			}

			if contentType := resp.Header.Get("Content-type"); contentType != "application/json" {
				t.Fatalf("unexpected content-type header: want %s, got %s", "application/json", contentType)
			}

			if resp.StatusCode != tc.want.statusCode {
				t.Logf("response body: %s", body)
				t.Fatalf("unexpected status code: want %d, got %d", tc.want.statusCode, resp.StatusCode)
			}
			if tc.want.isErr {
				// skip the rest of the checks if this is an error case
				return
			}

			var m tools.ToolsetManifest
			if err := json.Unmarshal(body, &m); err != nil {
				t.Fatalf("unable to parse ToolsetManifest: %s", err)
			}
			// Check the version is correct
			if m.ServerVersion != tc.want.version {
				t.Fatalf("unexpected ServerVersion: want %q, got %q", tc.want.version, m.ServerVersion)
			}
			// The toolset must contain exactly the expected tools: checking
			// only presence would let a toolset that leaks extra tools pass.
			if got, want := len(m.ToolsManifest), len(tc.want.tools); got != want {
				t.Errorf("unexpected number of tools in manifest: want %d, got %d", want, got)
			}
			for _, name := range tc.want.tools {
				if _, ok := m.ToolsManifest[name]; !ok {
					t.Errorf("%q tool not found in manifest", name)
				}
			}
		})
	}
}
// TestToolGetEndpoint exercises GET /tool/{toolName} against an in-process API
// router backed by two mock tools. For every case it checks the Content-Type
// header and the status code; for successful responses it also checks that the
// returned ToolsetManifest reports the expected server version and contains
// exactly the requested tool.
func TestToolGetEndpoint(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Set up resources to test against
	tool1 := MockTool{
		Name:   "no_params",
		Params: []tools.Parameter{},
	}
	tool2 := MockTool{
		Name: "some_params",
		Params: tools.Parameters{
			tools.NewIntParameter("param1", "This is the first parameter."),
			tools.NewIntParameter("param2", "This is the second parameter."),
		},
	}
	toolsMap := map[string]tools.Tool{tool1.Name: tool1, tool2.Name: tool2}

	testLogger, err := log.NewStdLogger(os.Stdout, os.Stderr, "info")
	if err != nil {
		t.Fatalf("unable to initialize logger: %s", err)
	}

	otelShutdown, err := telemetry.SetupOTel(ctx, fakeVersionString, "", false, "toolbox")
	if err != nil {
		t.Fatalf("unable to setup otel: %s", err)
	}
	defer func() {
		// Report, rather than abort: this runs as the test is finishing.
		if err := otelShutdown(ctx); err != nil {
			t.Errorf("error shutting down OpenTelemetry: %s", err)
		}
	}()

	instrumentation, err := CreateTelemetryInstrumentation(fakeVersionString)
	if err != nil {
		t.Fatalf("unable to create custom metrics: %s", err)
	}

	server := Server{version: fakeVersionString, logger: testLogger, instrumentation: instrumentation, tools: toolsMap}
	r, err := apiRouter(&server)
	if err != nil {
		t.Fatalf("unable to initialize router: %s", err)
	}
	ts := httptest.NewServer(r)
	defer ts.Close()

	// wantResponse is a struct for checks against test cases
	type wantResponse struct {
		statusCode int
		isErr      bool
		version    string
		tools      []string
	}

	testCases := []struct {
		name     string
		toolName string
		want     wantResponse
	}{
		{
			name:     "tool1",
			toolName: tool1.Name,
			want: wantResponse{
				statusCode: http.StatusOK,
				version:    fakeVersionString,
				tools:      []string{tool1.Name},
			},
		},
		{
			name:     "tool2",
			toolName: tool2.Name,
			want: wantResponse{
				statusCode: http.StatusOK,
				version:    fakeVersionString,
				tools:      []string{tool2.Name},
			},
		},
		{
			name:     "invalid tool",
			toolName: "some_imaginary_tool",
			want: wantResponse{
				statusCode: http.StatusNotFound,
				isErr:      true,
			},
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			resp, body, err := testRequest(ts, http.MethodGet, fmt.Sprintf("/tool/%s", tc.toolName), nil)
			if err != nil {
				t.Fatalf("unexpected error during request: %s", err)
			}

			if contentType := resp.Header.Get("Content-type"); contentType != "application/json" {
				t.Fatalf("unexpected content-type header: want %s, got %s", "application/json", contentType)
			}

			if resp.StatusCode != tc.want.statusCode {
				t.Logf("response body: %s", body)
				t.Fatalf("unexpected status code: want %d, got %d", tc.want.statusCode, resp.StatusCode)
			}
			if tc.want.isErr {
				// skip the rest of the checks if this is an error case
				return
			}

			var m tools.ToolsetManifest
			if err := json.Unmarshal(body, &m); err != nil {
				t.Fatalf("unable to parse ToolsetManifest: %s", err)
			}
			// Check the version is correct
			if m.ServerVersion != tc.want.version {
				t.Fatalf("unexpected ServerVersion: want %q, got %q", tc.want.version, m.ServerVersion)
			}
			// The manifest must describe exactly the requested tool(s); a
			// presence-only check would miss unrelated tools leaking in.
			if got, want := len(m.ToolsManifest), len(tc.want.tools); got != want {
				t.Errorf("unexpected number of tools in manifest: want %d, got %d", want, got)
			}
			for _, name := range tc.want.tools {
				if _, ok := m.ToolsManifest[name]; !ok {
					t.Errorf("%q tool not found in manifest", name)
				}
			}
		})
	}
}
// testRequest sends a request with the given method, path (relative to ts.URL)
// and optional body to the test server using http.DefaultClient.
//
// It returns the response, the fully read response body, and any error from
// building the request, sending it, or reading the body. The response body is
// always closed before returning, so callers must use the returned []byte
// rather than resp.Body.
func testRequest(ts *httptest.Server, method, path string, body io.Reader) (*http.Response, []byte, error) {
	req, err := http.NewRequest(method, ts.URL+path, body)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to create request: %w", err)
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to send request: %w", err)
	}
	// Close immediately after a successful Do so the body is released even
	// when reading it fails below.
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to read response body: %w", err)
	}

	return resp, respBody, nil
}