mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-10 16:08:16 -05:00
feat: support requesting a single tool (#56)
Adds support for getting a ToolsManifest with a single tool when a GET `/tools/$toolname` request is sent.
This commit is contained in:
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/go-chi/render"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
// apiRouter creates a router that represents the routes under /api
|
||||
@@ -28,12 +29,14 @@ func apiRouter(s *Server) (chi.Router, error) {
|
||||
r := chi.NewRouter()
|
||||
|
||||
r.Use(middleware.AllowContentType("application/json"))
|
||||
r.Use(middleware.StripSlashes)
|
||||
r.Use(render.SetContentType(render.ContentTypeJSON))
|
||||
|
||||
r.Get("/toolset/", func(w http.ResponseWriter, r *http.Request) { toolsetHandler(s, w, r) })
|
||||
r.Get("/toolset", func(w http.ResponseWriter, r *http.Request) { toolsetHandler(s, w, r) })
|
||||
r.Get("/toolset/{toolsetName}", func(w http.ResponseWriter, r *http.Request) { toolsetHandler(s, w, r) })
|
||||
|
||||
r.Route("/tool/{toolName}", func(r chi.Router) {
|
||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) { toolGetHandler(s, w, r) })
|
||||
r.Post("/invoke", func(w http.ResponseWriter, r *http.Request) { toolInvokeHandler(s, w, r) })
|
||||
})
|
||||
|
||||
@@ -51,6 +54,26 @@ func toolsetHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
render.JSON(w, r, toolset.Manifest)
|
||||
}
|
||||
|
||||
// toolGetHandler handles requests for a single Tool.
|
||||
func toolGetHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
toolName := chi.URLParam(r, "toolName")
|
||||
tool, ok := s.tools[toolName]
|
||||
if !ok {
|
||||
err := fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusNotFound))
|
||||
return
|
||||
}
|
||||
// TODO: this can be optimized later with some caching
|
||||
m := tools.ToolsetManifest{
|
||||
ServerVersion: s.conf.Version,
|
||||
ToolsManifest: map[string]tools.Manifest{
|
||||
toolName: tool.Manifest(),
|
||||
},
|
||||
}
|
||||
|
||||
render.JSON(w, r, m)
|
||||
}
|
||||
|
||||
// toolInvokeHandler handles the API request to invoke a specific Tool.
|
||||
func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
toolName := chi.URLParam(r, "toolName")
|
||||
|
||||
@@ -64,15 +64,15 @@ func TestToolsetEndpoint(t *testing.T) {
|
||||
toolsets[name] = m
|
||||
}
|
||||
|
||||
server := Server{tools: toolsMap, toolsets: toolsets}
|
||||
server := Server{conf: ServerConfig{}, tools: toolsMap, toolsets: toolsets}
|
||||
r, err := apiRouter(&server)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to initalize router: %s", err)
|
||||
t.Fatalf("unable to initialize router: %s", err)
|
||||
}
|
||||
ts := httptest.NewServer(r)
|
||||
defer ts.Close()
|
||||
|
||||
// wantRepsonse is a struct for checks against test cases
|
||||
// wantResponse is a struct for checks against test cases
|
||||
type wantResponse struct {
|
||||
statusCode int
|
||||
isErr bool
|
||||
@@ -160,6 +160,108 @@ func TestToolsetEndpoint(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
func TestToolGetEndpoint(t *testing.T) {
|
||||
// Set up resources to test against
|
||||
tool1 := MockTool{
|
||||
Name: "no_params",
|
||||
Params: []tools.Parameter{},
|
||||
}
|
||||
tool2 := MockTool{
|
||||
Name: "some_params",
|
||||
Params: tools.Parameters{
|
||||
tools.NewIntParameter("param1", "This is the first parameter."),
|
||||
tools.NewIntParameter("param2", "This is the second parameter."),
|
||||
},
|
||||
}
|
||||
toolsMap := map[string]tools.Tool{tool1.Name: tool1, tool2.Name: tool2}
|
||||
|
||||
server := Server{conf: ServerConfig{Version: "0.0.0"}, tools: toolsMap}
|
||||
r, err := apiRouter(&server)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to initialize router: %s", err)
|
||||
}
|
||||
ts := httptest.NewServer(r)
|
||||
defer ts.Close()
|
||||
|
||||
// wantResponse is a struct for checks against test cases
|
||||
type wantResponse struct {
|
||||
statusCode int
|
||||
isErr bool
|
||||
version string
|
||||
tools []string
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
toolName string
|
||||
want wantResponse
|
||||
}{
|
||||
{
|
||||
name: "tool1",
|
||||
toolName: tool1.Name,
|
||||
want: wantResponse{
|
||||
statusCode: http.StatusOK,
|
||||
version: "0.0.0",
|
||||
tools: []string{tool1.Name},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool2",
|
||||
toolName: tool2.Name,
|
||||
want: wantResponse{
|
||||
statusCode: http.StatusOK,
|
||||
version: "0.0.0",
|
||||
tools: []string{tool2.Name},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid tool",
|
||||
toolName: "some_imaginary_tool",
|
||||
want: wantResponse{
|
||||
statusCode: http.StatusNotFound,
|
||||
isErr: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp, body, err := testRequest(ts, http.MethodGet, fmt.Sprintf("/tool/%s", tc.toolName), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error during request: %s", err)
|
||||
}
|
||||
|
||||
if contentType := resp.Header.Get("Content-type"); contentType != "application/json" {
|
||||
t.Fatalf("unexpected content-type header: want %s, got %s", "application/json", contentType)
|
||||
}
|
||||
|
||||
if resp.StatusCode != tc.want.statusCode {
|
||||
t.Logf("response body: %s", body)
|
||||
t.Fatalf("unexpected status code: want %d, got %d", tc.want.statusCode, resp.StatusCode)
|
||||
}
|
||||
if tc.want.isErr {
|
||||
// skip the rest of the checks if this is an error case
|
||||
return
|
||||
}
|
||||
var m tools.ToolsetManifest
|
||||
err = json.Unmarshal(body, &m)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to parse ToolsetManifest: %s", err)
|
||||
}
|
||||
// Check the version is correct
|
||||
if m.ServerVersion != tc.want.version {
|
||||
t.Fatalf("unexpected ServerVersion: want %q, got %q", tc.want.version, m.ServerVersion)
|
||||
}
|
||||
// validate that the tools in the toolset are correct
|
||||
for _, name := range tc.want.tools {
|
||||
_, ok := m.ToolsManifest[name]
|
||||
if !ok {
|
||||
t.Errorf("%q tool not found in manfiest", name)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testRequest(ts *httptest.Server, method, path string, body io.Reader) (*http.Response, []byte, error) {
|
||||
req, err := http.NewRequest(method, ts.URL+path, body)
|
||||
|
||||
Reference in New Issue
Block a user