// Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package server import ( "context" "encoding/json" "fmt" "io" "net/http" "net/http/httptest" "os" "testing" "github.com/googleapis/genai-toolbox/internal/log" "github.com/googleapis/genai-toolbox/internal/telemetry" "github.com/googleapis/genai-toolbox/internal/tools" ) var _ tools.Tool = &MockTool{} const fakeVersionString = "0.0.0" type MockTool struct { Name string Description string Params []tools.Parameter } func (t MockTool) Invoke(tools.ParamValues) (string, error) { return "", nil } // claims is a map of user info decoded from an auth token func (t MockTool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (tools.ParamValues, error) { return tools.ParseParams(t.Params, data, claimsMap) } func (t MockTool) Manifest() tools.Manifest { pMs := make([]tools.ParameterManifest, 0, len(t.Params)) for _, p := range t.Params { pMs = append(pMs, p.Manifest()) } return tools.Manifest{Description: t.Description, Parameters: pMs} } func (t MockTool) Authorized(verifiedAuthSources []string) bool { return true } func TestToolsetEndpoint(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() // Set up resources to test against tool1 := MockTool{ Name: "no_params", Params: []tools.Parameter{}, } tool2 := MockTool{ Name: "some_params", Params: tools.Parameters{ tools.NewIntParameter("param1", "This is the first parameter."), tools.NewIntParameter("param2", "This is the second parameter."), }, } toolsMap := map[string]tools.Tool{tool1.Name: tool1, tool2.Name: tool2} toolsets := make(map[string]tools.Toolset) for name, l := range map[string][]string{ "": {tool1.Name, tool2.Name}, "tool1_only": {tool1.Name}, "tool2_only": {tool2.Name}, } { tc := tools.ToolsetConfig{Name: name, ToolNames: l} m, err := tc.Initialize(fakeVersionString, toolsMap) if err != nil { t.Fatalf("unable to initialize toolset %q: %s", name, err) } toolsets[name] = m } testLogger, err := log.NewStdLogger(os.Stdout, os.Stderr, "info") if err != nil { t.Fatalf("unable to initialize logger: %s", err) } otelShutdown, err := telemetry.SetupOTel(ctx, fakeVersionString, "", false, "toolbox") if err != nil { t.Fatalf("unable to setup otel: %s", err) } defer func() { err := otelShutdown(ctx) if err != nil { t.Fatalf("error shutting down OpenTelemetry: %s", err) } }() instrumentation, err := CreateTelemetryInstrumentation(fakeVersionString) if err != nil { t.Fatalf("unable to create custom metrics: %s", err) } server := Server{logger: testLogger, instrumentation: instrumentation, tools: toolsMap, toolsets: toolsets} r, err := apiRouter(&server) if err != nil { t.Fatalf("unable to initialize router: %s", err) } ts := httptest.NewServer(r) defer ts.Close() // wantResponse is a struct for checks against test cases type wantResponse struct { statusCode int isErr bool version string tools []string } testCases := []struct { name string toolsetName string want wantResponse }{ { name: "'default' manifest", toolsetName: "", want: wantResponse{ statusCode: http.StatusOK, version: fakeVersionString, tools: []string{tool1.Name, tool2.Name}, }, }, { name: "invalid toolset name", toolsetName: "some_imaginary_toolset", want: wantResponse{ statusCode: http.StatusNotFound, isErr: true, }, }, { name: "single toolset 1", toolsetName: "tool1_only", want: wantResponse{ statusCode: http.StatusOK, version: fakeVersionString, tools: []string{tool1.Name}, }, }, { name: "single toolset 2", toolsetName: "tool2_only", want: wantResponse{ statusCode: http.StatusOK, version: fakeVersionString, tools: []string{tool2.Name}, }, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { resp, body, err := testRequest(ts, http.MethodGet, fmt.Sprintf("/toolset/%s", tc.toolsetName), nil) if err != nil { t.Fatalf("unexpected error during request: %s", err) } if contentType := resp.Header.Get("Content-type"); contentType != "application/json" { t.Fatalf("unexpected content-type header: want %s, got %s", "application/json", contentType) } if resp.StatusCode != tc.want.statusCode { t.Logf("response body: %s", body) t.Fatalf("unexpected status code: want %d, got %d", tc.want.statusCode, resp.StatusCode) } if tc.want.isErr { // skip the rest of the checks if this is an error case return } var m tools.ToolsetManifest err = json.Unmarshal(body, &m) if err != nil { t.Fatalf("unable to parse ToolsetManifest: %s", err) } // Check the version is correct if m.ServerVersion != tc.want.version { t.Fatalf("unexpected ServerVersion: want %q, got %q", tc.want.version, m.ServerVersion) } // validate that the tools in the toolset are correct for _, name := range tc.want.tools { _, ok := m.ToolsManifest[name] if !ok { t.Errorf("%q tool not found in manfiest", name) } } }) } } func TestToolGetEndpoint(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() // Set up resources to test against tool1 := MockTool{ Name: "no_params", Params: []tools.Parameter{}, } tool2 := MockTool{ Name: "some_params", Params: tools.Parameters{ tools.NewIntParameter("param1", "This is the first parameter."), tools.NewIntParameter("param2", "This is the second parameter."), }, } toolsMap := map[string]tools.Tool{tool1.Name: tool1, tool2.Name: tool2} testLogger, err := log.NewStdLogger(os.Stdout, os.Stderr, "info") if err != nil { t.Fatalf("unable to initialize logger: %s", err) } otelShutdown, err := telemetry.SetupOTel(ctx, fakeVersionString, "", false, "toolbox") if err != nil { t.Fatalf("unable to setup otel: %s", err) } defer func() { err := otelShutdown(ctx) if err != nil { t.Fatalf("error shutting down OpenTelemetry: %s", err) } }() instrumentation, err := CreateTelemetryInstrumentation(fakeVersionString) if err != nil { t.Fatalf("unable to create custom metrics: %s", err) } server := Server{version: fakeVersionString, logger: testLogger, instrumentation: instrumentation, tools: toolsMap} r, err := apiRouter(&server) if err != nil { t.Fatalf("unable to initialize router: %s", err) } ts := httptest.NewServer(r) defer ts.Close() // wantResponse is a struct for checks against test cases type wantResponse struct { statusCode int isErr bool version string tools []string } testCases := []struct { name string toolName string want wantResponse }{ { name: "tool1", toolName: tool1.Name, want: wantResponse{ statusCode: http.StatusOK, version: fakeVersionString, tools: []string{tool1.Name}, }, }, { name: "tool2", toolName: tool2.Name, want: wantResponse{ statusCode: http.StatusOK, version: fakeVersionString, tools: []string{tool2.Name}, }, }, { name: "invalid tool", toolName: "some_imaginary_tool", want: wantResponse{ statusCode: http.StatusNotFound, isErr: true, }, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { resp, body, err := testRequest(ts, http.MethodGet, fmt.Sprintf("/tool/%s", tc.toolName), nil) if err != nil { t.Fatalf("unexpected error during request: %s", err) } if contentType := resp.Header.Get("Content-type"); contentType != "application/json" { t.Fatalf("unexpected content-type header: want %s, got %s", "application/json", contentType) } if resp.StatusCode != tc.want.statusCode { t.Logf("response body: %s", body) t.Fatalf("unexpected status code: want %d, got %d", tc.want.statusCode, resp.StatusCode) } if tc.want.isErr { // skip the rest of the checks if this is an error case return } var m tools.ToolsetManifest err = json.Unmarshal(body, &m) if err != nil { t.Fatalf("unable to parse ToolsetManifest: %s", err) } // Check the version is correct if m.ServerVersion != tc.want.version { t.Fatalf("unexpected ServerVersion: want %q, got %q", tc.want.version, m.ServerVersion) } // validate that the tools in the toolset are correct for _, name := range tc.want.tools { _, ok := m.ToolsManifest[name] if !ok { t.Errorf("%q tool not found in manfiest", name) } } }) } } func testRequest(ts *httptest.Server, method, path string, body io.Reader) (*http.Response, []byte, error) { req, err := http.NewRequest(method, ts.URL+path, body) if err != nil { return nil, nil, fmt.Errorf("unable to create request: %w", err) } resp, err := http.DefaultClient.Do(req) if err != nil { return nil, nil, fmt.Errorf("unable to send request: %w", err) } respBody, err := io.ReadAll(resp.Body) if err != nil { return nil, nil, fmt.Errorf("unable to read request body: %w", err) } defer resp.Body.Close() return resp, respBody, nil }