mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-02-08 22:25:21 -05:00
Adds logic to make the server shutdown gracefully, including better respecting cancelled contexts and providing up to 10 seconds to finish current connections.
313 lines
8.3 KiB
Go
313 lines
8.3 KiB
Go
// Copyright 2024 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package server
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"testing"
|
|
|
|
"github.com/googleapis/genai-toolbox/internal/log"
|
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
|
)
|
|
|
|
// Compile-time check that *MockTool satisfies the tools.Tool interface.
var _ tools.Tool = (*MockTool)(nil)
|
|
|
|
type MockTool struct {
|
|
Name string
|
|
Description string
|
|
Params []tools.Parameter
|
|
}
|
|
|
|
func (t MockTool) Invoke(tools.ParamValues) (string, error) {
|
|
return "", nil
|
|
}
|
|
|
|
// claims is a map of user info decoded from an auth token
|
|
func (t MockTool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (tools.ParamValues, error) {
|
|
return tools.ParseParams(t.Params, data, claimsMap)
|
|
}
|
|
|
|
func (t MockTool) Manifest() tools.Manifest {
|
|
pMs := make([]tools.ParameterManifest, 0, len(t.Params))
|
|
for _, p := range t.Params {
|
|
pMs = append(pMs, p.Manifest())
|
|
}
|
|
return tools.Manifest{Description: t.Description, Parameters: pMs}
|
|
}
|
|
|
|
func (t MockTool) Authorized(verifiedAuthSources []string) bool {
|
|
return true
|
|
}
|
|
|
|
func TestToolsetEndpoint(t *testing.T) {
|
|
// Set up resources to test against
|
|
tool1 := MockTool{
|
|
Name: "no_params",
|
|
Params: []tools.Parameter{},
|
|
}
|
|
tool2 := MockTool{
|
|
Name: "some_params",
|
|
Params: tools.Parameters{
|
|
tools.NewIntParameter("param1", "This is the first parameter."),
|
|
tools.NewIntParameter("param2", "This is the second parameter."),
|
|
},
|
|
}
|
|
toolsMap := map[string]tools.Tool{tool1.Name: tool1, tool2.Name: tool2}
|
|
|
|
toolsets := make(map[string]tools.Toolset)
|
|
for name, l := range map[string][]string{
|
|
"": {tool1.Name, tool2.Name},
|
|
"tool1_only": {tool1.Name},
|
|
"tool2_only": {tool2.Name},
|
|
} {
|
|
tc := tools.ToolsetConfig{Name: name, ToolNames: l}
|
|
m, err := tc.Initialize("0.0.0", toolsMap)
|
|
if err != nil {
|
|
t.Fatalf("unable to initialize toolset %q: %s", name, err)
|
|
}
|
|
toolsets[name] = m
|
|
}
|
|
|
|
testLogger, err := log.NewStdLogger(os.Stdout, os.Stderr, "info")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %s", err)
|
|
}
|
|
server := Server{logger: testLogger, tools: toolsMap, toolsets: toolsets}
|
|
r, err := apiRouter(&server)
|
|
if err != nil {
|
|
t.Fatalf("unable to initialize router: %s", err)
|
|
}
|
|
ts := httptest.NewServer(r)
|
|
defer ts.Close()
|
|
|
|
// wantResponse is a struct for checks against test cases
|
|
type wantResponse struct {
|
|
statusCode int
|
|
isErr bool
|
|
version string
|
|
tools []string
|
|
}
|
|
|
|
testCases := []struct {
|
|
name string
|
|
toolsetName string
|
|
want wantResponse
|
|
}{
|
|
{
|
|
name: "'default' manifest",
|
|
toolsetName: "",
|
|
want: wantResponse{
|
|
statusCode: http.StatusOK,
|
|
version: "0.0.0",
|
|
tools: []string{tool1.Name, tool2.Name},
|
|
},
|
|
},
|
|
{
|
|
name: "invalid toolset name",
|
|
toolsetName: "some_imaginary_toolset",
|
|
want: wantResponse{
|
|
statusCode: http.StatusNotFound,
|
|
isErr: true,
|
|
},
|
|
},
|
|
{
|
|
name: "single toolset 1",
|
|
toolsetName: "tool1_only",
|
|
want: wantResponse{
|
|
statusCode: http.StatusOK,
|
|
version: "0.0.0",
|
|
tools: []string{tool1.Name},
|
|
},
|
|
},
|
|
{
|
|
name: "single toolset 2",
|
|
toolsetName: "tool2_only",
|
|
want: wantResponse{
|
|
statusCode: http.StatusOK,
|
|
version: "0.0.0",
|
|
tools: []string{tool2.Name},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
resp, body, err := testRequest(ts, http.MethodGet, fmt.Sprintf("/toolset/%s", tc.toolsetName), nil)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error during request: %s", err)
|
|
}
|
|
|
|
if contentType := resp.Header.Get("Content-type"); contentType != "application/json" {
|
|
t.Fatalf("unexpected content-type header: want %s, got %s", "application/json", contentType)
|
|
}
|
|
|
|
if resp.StatusCode != tc.want.statusCode {
|
|
t.Logf("response body: %s", body)
|
|
t.Fatalf("unexpected status code: want %d, got %d", tc.want.statusCode, resp.StatusCode)
|
|
}
|
|
if tc.want.isErr {
|
|
// skip the rest of the checks if this is an error case
|
|
return
|
|
}
|
|
var m tools.ToolsetManifest
|
|
err = json.Unmarshal(body, &m)
|
|
if err != nil {
|
|
t.Fatalf("unable to parse ToolsetManifest: %s", err)
|
|
}
|
|
// Check the version is correct
|
|
if m.ServerVersion != tc.want.version {
|
|
t.Fatalf("unexpected ServerVersion: want %q, got %q", tc.want.version, m.ServerVersion)
|
|
}
|
|
// validate that the tools in the toolset are correct
|
|
for _, name := range tc.want.tools {
|
|
_, ok := m.ToolsManifest[name]
|
|
if !ok {
|
|
t.Errorf("%q tool not found in manfiest", name)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
func TestToolGetEndpoint(t *testing.T) {
|
|
// Set up resources to test against
|
|
tool1 := MockTool{
|
|
Name: "no_params",
|
|
Params: []tools.Parameter{},
|
|
}
|
|
tool2 := MockTool{
|
|
Name: "some_params",
|
|
Params: tools.Parameters{
|
|
tools.NewIntParameter("param1", "This is the first parameter."),
|
|
tools.NewIntParameter("param2", "This is the second parameter."),
|
|
},
|
|
}
|
|
toolsMap := map[string]tools.Tool{tool1.Name: tool1, tool2.Name: tool2}
|
|
|
|
testLogger, err := log.NewStdLogger(os.Stdout, os.Stderr, "info")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %s", err)
|
|
}
|
|
server := Server{version: "0.0.0", logger: testLogger, tools: toolsMap}
|
|
r, err := apiRouter(&server)
|
|
if err != nil {
|
|
t.Fatalf("unable to initialize router: %s", err)
|
|
}
|
|
ts := httptest.NewServer(r)
|
|
defer ts.Close()
|
|
|
|
// wantResponse is a struct for checks against test cases
|
|
type wantResponse struct {
|
|
statusCode int
|
|
isErr bool
|
|
version string
|
|
tools []string
|
|
}
|
|
|
|
testCases := []struct {
|
|
name string
|
|
toolName string
|
|
want wantResponse
|
|
}{
|
|
{
|
|
name: "tool1",
|
|
toolName: tool1.Name,
|
|
want: wantResponse{
|
|
statusCode: http.StatusOK,
|
|
version: "0.0.0",
|
|
tools: []string{tool1.Name},
|
|
},
|
|
},
|
|
{
|
|
name: "tool2",
|
|
toolName: tool2.Name,
|
|
want: wantResponse{
|
|
statusCode: http.StatusOK,
|
|
version: "0.0.0",
|
|
tools: []string{tool2.Name},
|
|
},
|
|
},
|
|
{
|
|
name: "invalid tool",
|
|
toolName: "some_imaginary_tool",
|
|
want: wantResponse{
|
|
statusCode: http.StatusNotFound,
|
|
isErr: true,
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
resp, body, err := testRequest(ts, http.MethodGet, fmt.Sprintf("/tool/%s", tc.toolName), nil)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error during request: %s", err)
|
|
}
|
|
|
|
if contentType := resp.Header.Get("Content-type"); contentType != "application/json" {
|
|
t.Fatalf("unexpected content-type header: want %s, got %s", "application/json", contentType)
|
|
}
|
|
|
|
if resp.StatusCode != tc.want.statusCode {
|
|
t.Logf("response body: %s", body)
|
|
t.Fatalf("unexpected status code: want %d, got %d", tc.want.statusCode, resp.StatusCode)
|
|
}
|
|
if tc.want.isErr {
|
|
// skip the rest of the checks if this is an error case
|
|
return
|
|
}
|
|
var m tools.ToolsetManifest
|
|
err = json.Unmarshal(body, &m)
|
|
if err != nil {
|
|
t.Fatalf("unable to parse ToolsetManifest: %s", err)
|
|
}
|
|
// Check the version is correct
|
|
if m.ServerVersion != tc.want.version {
|
|
t.Fatalf("unexpected ServerVersion: want %q, got %q", tc.want.version, m.ServerVersion)
|
|
}
|
|
// validate that the tools in the toolset are correct
|
|
for _, name := range tc.want.tools {
|
|
_, ok := m.ToolsManifest[name]
|
|
if !ok {
|
|
t.Errorf("%q tool not found in manfiest", name)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// testRequest sends a request with the given method and body to path on the
// test server ts and returns the response together with its fully read body.
//
// The response body is closed before this function returns, so callers must
// use the returned byte slice rather than reading resp.Body. A non-nil error
// means the request could not be created or sent, or the response body could
// not be read; in that case the response and body results are nil.
func testRequest(ts *httptest.Server, method, path string, body io.Reader) (*http.Response, []byte, error) {
	req, err := http.NewRequest(method, ts.URL+path, body)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to create request: %w", err)
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to send request: %w", err)
	}
	// Register the close before reading so the body is released even when
	// ReadAll fails part-way through.
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to read response body: %w", err)
	}

	return resp, respBody, nil
}
|