mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-08 15:14:00 -05:00
First part of the implementation to support semantic search in tools. Second part: https://github.com/googleapis/genai-toolbox/pull/2151
351 lines
9.4 KiB
Go
351 lines
9.4 KiB
Go
// Copyright 2025 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package server
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"testing"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/googleapis/genai-toolbox/internal/log"
|
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
|
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
|
"github.com/googleapis/genai-toolbox/internal/telemetry"
|
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
|
)
|
|
|
|
// fakeVersionString is the placeholder version passed to every component
// that needs a version string while running tests.
const fakeVersionString = "0.0.0"
|
|
|
|
var (
|
|
_ tools.Tool = MockTool{}
|
|
_ prompts.Prompt = MockPrompt{}
|
|
)
|
|
|
|
// MockTool is used to mock tools in tests
|
|
type MockTool struct {
|
|
Name string
|
|
Description string
|
|
Params []parameters.Parameter
|
|
manifest tools.Manifest
|
|
unauthorized bool
|
|
requiresClientAuthrorization bool
|
|
}
|
|
|
|
func (t MockTool) Invoke(context.Context, tools.SourceProvider, parameters.ParamValues, tools.AccessToken) (any, error) {
|
|
mock := []any{t.Name}
|
|
return mock, nil
|
|
}
|
|
|
|
func (t MockTool) ToConfig() tools.ToolConfig {
|
|
return nil
|
|
}
|
|
|
|
// claims is a map of user info decoded from an auth token
|
|
func (t MockTool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (parameters.ParamValues, error) {
|
|
return parameters.ParseParams(t.Params, data, claimsMap)
|
|
}
|
|
|
|
func (t MockTool) Manifest() tools.Manifest {
|
|
pMs := make([]parameters.ParameterManifest, 0, len(t.Params))
|
|
for _, p := range t.Params {
|
|
pMs = append(pMs, p.Manifest())
|
|
}
|
|
return tools.Manifest{Description: t.Description, Parameters: pMs}
|
|
}
|
|
|
|
func (t MockTool) Authorized(verifiedAuthServices []string) bool {
|
|
// defaulted to true
|
|
return !t.unauthorized
|
|
}
|
|
|
|
func (t MockTool) RequiresClientAuthorization(tools.SourceProvider) (bool, error) {
|
|
// defaulted to false
|
|
return t.requiresClientAuthrorization, nil
|
|
}
|
|
|
|
func (t MockTool) McpManifest() tools.McpManifest {
|
|
properties := make(map[string]parameters.ParameterMcpManifest)
|
|
required := make([]string, 0)
|
|
authParams := make(map[string][]string)
|
|
|
|
for _, p := range t.Params {
|
|
name := p.GetName()
|
|
paramManifest, authParamList := p.McpManifest()
|
|
properties[name] = paramManifest
|
|
required = append(required, name)
|
|
|
|
if len(authParamList) > 0 {
|
|
authParams[name] = authParamList
|
|
}
|
|
}
|
|
|
|
toolsSchema := parameters.McpToolsSchema{
|
|
Type: "object",
|
|
Properties: properties,
|
|
Required: required,
|
|
}
|
|
|
|
mcpManifest := tools.McpManifest{
|
|
Name: t.Name,
|
|
Description: t.Description,
|
|
InputSchema: toolsSchema,
|
|
}
|
|
|
|
if len(authParams) > 0 {
|
|
mcpManifest.Metadata = map[string]any{
|
|
"toolbox/authParams": authParams,
|
|
}
|
|
}
|
|
|
|
return mcpManifest
|
|
}
|
|
|
|
func (t MockTool) GetAuthTokenHeaderName(tools.SourceProvider) (string, error) {
|
|
return "Authorization", nil
|
|
}
|
|
|
|
// MockPrompt is used to mock prompts in tests
|
|
type MockPrompt struct {
|
|
Name string
|
|
Description string
|
|
Args prompts.Arguments
|
|
}
|
|
|
|
func (p MockPrompt) SubstituteParams(vals parameters.ParamValues) (any, error) {
|
|
return []prompts.Message{
|
|
{
|
|
Role: "user",
|
|
Content: fmt.Sprintf("substituted %s", p.Name),
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
func (p MockPrompt) ParseArgs(data map[string]any, claimsMap map[string]map[string]any) (parameters.ParamValues, error) {
|
|
var params parameters.Parameters
|
|
for _, arg := range p.Args {
|
|
params = append(params, arg.Parameter)
|
|
}
|
|
return parameters.ParseParams(params, data, claimsMap)
|
|
}
|
|
|
|
func (p MockPrompt) Manifest() prompts.Manifest {
|
|
var argManifests []parameters.ParameterManifest
|
|
for _, arg := range p.Args {
|
|
argManifests = append(argManifests, arg.Manifest())
|
|
}
|
|
return prompts.Manifest{
|
|
Description: p.Description,
|
|
Arguments: argManifests,
|
|
}
|
|
}
|
|
|
|
func (p MockPrompt) McpManifest() prompts.McpManifest {
|
|
return prompts.GetMcpManifest(p.Name, p.Description, p.Args)
|
|
}
|
|
|
|
func (p MockPrompt) ToConfig() prompts.PromptConfig {
|
|
return nil
|
|
}
|
|
|
|
var tool1 = MockTool{
|
|
Name: "no_params",
|
|
Params: []parameters.Parameter{},
|
|
}
|
|
|
|
var tool2 = MockTool{
|
|
Name: "some_params",
|
|
Params: parameters.Parameters{
|
|
parameters.NewIntParameter("param1", "This is the first parameter."),
|
|
parameters.NewIntParameter("param2", "This is the second parameter."),
|
|
},
|
|
}
|
|
|
|
var tool3 = MockTool{
|
|
Name: "array_param",
|
|
Description: "some description",
|
|
Params: parameters.Parameters{
|
|
parameters.NewArrayParameter("my_array", "this param is an array of strings", parameters.NewStringParameter("my_string", "string item")),
|
|
},
|
|
}
|
|
|
|
var tool4 = MockTool{
|
|
Name: "unauthorized_tool",
|
|
Params: []parameters.Parameter{},
|
|
unauthorized: true,
|
|
}
|
|
|
|
var tool5 = MockTool{
|
|
Name: "require_client_auth_tool",
|
|
Params: []parameters.Parameter{},
|
|
requiresClientAuthrorization: true,
|
|
}
|
|
|
|
var prompt1 = MockPrompt{
|
|
Name: "prompt1",
|
|
Args: prompts.Arguments{},
|
|
}
|
|
|
|
var prompt2 = MockPrompt{
|
|
Name: "prompt2",
|
|
Args: prompts.Arguments{
|
|
{Parameter: parameters.NewStringParameter("arg1", "This is the first argument.")},
|
|
},
|
|
}
|
|
|
|
// setUpResources setups resources to test against
|
|
func setUpResources(t *testing.T, mockTools []MockTool, mockPrompts []MockPrompt) (map[string]tools.Tool, map[string]tools.Toolset, map[string]prompts.Prompt, map[string]prompts.Promptset) {
|
|
toolsMap := make(map[string]tools.Tool)
|
|
var allTools []string
|
|
for _, tool := range mockTools {
|
|
tool.manifest = tool.Manifest()
|
|
toolsMap[tool.Name] = tool
|
|
allTools = append(allTools, tool.Name)
|
|
}
|
|
|
|
toolsets := make(map[string]tools.Toolset)
|
|
for name, l := range map[string][]string{
|
|
"": allTools,
|
|
"tool1_only": {allTools[0]},
|
|
"tool2_only": {allTools[1]},
|
|
} {
|
|
tc := tools.ToolsetConfig{Name: name, ToolNames: l}
|
|
m, err := tc.Initialize(fakeVersionString, toolsMap)
|
|
if err != nil {
|
|
t.Fatalf("unable to initialize toolset %q: %s", name, err)
|
|
}
|
|
toolsets[name] = m
|
|
}
|
|
|
|
promptsMap := make(map[string]prompts.Prompt)
|
|
var allPrompts []string
|
|
for _, prompt := range mockPrompts {
|
|
promptsMap[prompt.Name] = prompt
|
|
allPrompts = append(allPrompts, prompt.Name)
|
|
}
|
|
|
|
promptsets := make(map[string]prompts.Promptset)
|
|
if len(allPrompts) > 0 {
|
|
psc := prompts.PromptsetConfig{Name: "", PromptNames: allPrompts}
|
|
ps, err := psc.Initialize(fakeVersionString, promptsMap)
|
|
if err != nil {
|
|
t.Fatalf("unable to initialize default promptset: %s", err)
|
|
}
|
|
promptsets[""] = ps
|
|
}
|
|
|
|
return toolsMap, toolsets, promptsMap, promptsets
|
|
}
|
|
|
|
// setUpServer create a new server with tools, toolsets, prompts, and promptsets.
|
|
func setUpServer(t *testing.T, router string, tools map[string]tools.Tool, toolsets map[string]tools.Toolset, prompts map[string]prompts.Prompt, promptsets map[string]prompts.Promptset) (chi.Router, func()) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
testLogger, err := log.NewStdLogger(os.Stdout, os.Stderr, "info")
|
|
if err != nil {
|
|
t.Fatalf("unable to initialize logger: %s", err)
|
|
}
|
|
|
|
otelShutdown, err := telemetry.SetupOTel(ctx, fakeVersionString, "", false, "toolbox")
|
|
if err != nil {
|
|
t.Fatalf("unable to setup otel: %s", err)
|
|
}
|
|
|
|
instrumentation, err := telemetry.CreateTelemetryInstrumentation(fakeVersionString)
|
|
if err != nil {
|
|
t.Fatalf("unable to create custom metrics: %s", err)
|
|
}
|
|
|
|
sseManager := newSseManager(ctx)
|
|
|
|
resourceManager := resources.NewResourceManager(nil, nil, nil, tools, toolsets, prompts, promptsets)
|
|
|
|
server := Server{
|
|
version: fakeVersionString,
|
|
logger: testLogger,
|
|
instrumentation: instrumentation,
|
|
sseManager: sseManager,
|
|
ResourceMgr: resourceManager,
|
|
}
|
|
|
|
var r chi.Router
|
|
switch router {
|
|
case "api":
|
|
r, err = apiRouter(&server)
|
|
if err != nil {
|
|
t.Fatalf("unable to initialize api router: %s", err)
|
|
}
|
|
case "mcp":
|
|
r, err = mcpRouter(&server)
|
|
if err != nil {
|
|
t.Fatalf("unable to initialize mcp router: %s", err)
|
|
}
|
|
default:
|
|
t.Fatalf("unknown router")
|
|
}
|
|
shutdown := func() {
|
|
// cancel context
|
|
cancel()
|
|
// shutdown otel
|
|
err := otelShutdown(ctx)
|
|
if err != nil {
|
|
t.Fatalf("error shutting down OpenTelemetry: %s", err)
|
|
}
|
|
}
|
|
|
|
return r, shutdown
|
|
}
|
|
|
|
func runServer(r chi.Router, tls bool) *httptest.Server {
|
|
var ts *httptest.Server
|
|
if tls {
|
|
ts = httptest.NewTLSServer(r)
|
|
} else {
|
|
ts = httptest.NewServer(r)
|
|
}
|
|
return ts
|
|
}
|
|
|
|
// runRequest sends a request with the given method, path, body and headers to
// the test server ts and returns the response together with its fully read
// body.
//
// Content-Type is set to "application/json" first; the entries of header are
// applied afterwards and may therefore override it. The response body is
// closed before runRequest returns, so callers must use the returned byte
// slice rather than resp.Body.
//
// On any failure (building the request, sending it, or reading the response
// body) the response and body are nil and a wrapped error is returned.
func runRequest(ts *httptest.Server, method, path string, body io.Reader, header map[string]string) (*http.Response, []byte, error) {
	req, err := http.NewRequest(method, ts.URL+path, body)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to create request: %w", err)
	}

	req.Header.Set("Content-Type", "application/json")
	for k, v := range header {
		req.Header.Set(k, v)
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to send request: %w", err)
	}
	// Register the close before reading so the body is released even when
	// the read below fails part-way through.
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to read response body: %w", err)
	}

	return resp, respBody, nil
}
|