feat: Add Toolset manifest endpoint (#11)

1. Calculate tool manifests when server starts.
2. Add toolset manifest endpoints.

---------

Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com>
This commit is contained in:
Wenxin Du
2024-10-17 18:43:58 -04:00
committed by GitHub
parent 26ed13561a
commit 61e7b78ad8
11 changed files with 323 additions and 77 deletions

View File

@@ -77,3 +77,11 @@ jobs:
- name: Run tests
run: go test -race -v ./...
env:
PROJECT: ${{ secrets.PROJECT }}
CLOUD_SQL_PG_INSTANCE: ${{ secrets.CLOUD_SQL_PG_INSTANCE }}
REGION: ${{ secrets.REGION }}
DATABASE: ${{ secrets.DATABASE }}
USER: ${{ secrets.USER }}
PASSWORD: ${{ secrets.PASSWORD }}

View File

@@ -109,21 +109,21 @@ func run(cmd *Command) error {
// Read tool file contents
buf, err := os.ReadFile(cmd.tools_file)
if err != nil {
return fmt.Errorf("Unable to read tool file at %q: %w", cmd.tools_file, err)
return fmt.Errorf("unable to read tool file at %q: %w", cmd.tools_file, err)
}
cmd.cfg.SourceConfigs, cmd.cfg.ToolConfigs, cmd.cfg.ToolsetConfigs, err = parseToolsFile(buf)
if err != nil {
return fmt.Errorf("Unable to parse tool file at %q: %w", cmd.tools_file, err)
return fmt.Errorf("unable to parse tool file at %q: %w", cmd.tools_file, err)
}
// run server
s, err := server.NewServer(cmd.cfg)
if err != nil {
return fmt.Errorf("Toolbox failed to start with the following error: %w", err)
return fmt.Errorf("toolbox failed to start with the following error: %w", err)
}
err = s.ListenAndServe(ctx)
if err != nil {
return fmt.Errorf("Toolbox crashed with the following error: %w", err)
return fmt.Errorf("toolbox crashed with the following error: %w", err)
}
return nil

View File

@@ -15,7 +15,9 @@
package server
import (
"encoding/json"
"fmt"
"log"
"net/http"
"github.com/go-chi/chi/v5"
@@ -24,36 +26,87 @@ import (
)
// apiRouter creates a router that represents the routes under /api
func apiRouter(s *Server) chi.Router {
func apiRouter(s *Server) (chi.Router, error) {
r := chi.NewRouter()
r.Get("/toolset/{toolsetName}", toolsetHandler(s))
r.Get("/toolset/", func(w http.ResponseWriter, r *http.Request) { toolsetHandler(s, w, r) })
r.Get("/toolset/{toolsetName}", func(w http.ResponseWriter, r *http.Request) { toolsetHandler(s, w, r) })
r.Route("/tool/{toolName}", func(r chi.Router) {
r.Use(middleware.AllowContentType("application/json"))
r.Post("/invoke", newToolHandler(s))
r.Post("/invoke", func(w http.ResponseWriter, r *http.Request) { toolInvokeHandler(s, w, r) })
})
return r
return r, nil
}
func toolsetHandler(s *Server) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
toolsetName := chi.URLParam(r, "toolsetName")
_, _ = w.Write([]byte(fmt.Sprintf("Stub for toolset %s manifest!", toolsetName)))
// toolInvokeHandler handles the request for information about a Toolset.
func toolsetHandler(s *Server, w http.ResponseWriter, r *http.Request) {
toolsetName := chi.URLParam(r, "toolsetName")
toolset, ok := s.toolsets[toolsetName]
if !ok {
http.Error(w, fmt.Sprintf("Toolset %q does not exist", toolsetName), http.StatusNotFound)
return
}
b, err := json.Marshal(toolset.Manifest)
if err != nil {
log.Printf("unable to JSON the toolset manifest: %s", err)
http.Error(w, "", http.StatusInternalServerError)
return
}
_, _ = w.Write(b)
}
// toolInvokeHandler handles the API request to invoke a specific Tool.
func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
toolName := chi.URLParam(r, "toolName")
tool, ok := s.tools[toolName]
if !ok {
err := fmt.Errorf("Invalid tool name. Tool with name %q does not exist", toolName)
_ = render.Render(w, r, newErrResponse(err, http.StatusNotFound))
return
}
var data map[string]interface{}
if err := render.DecodeJSON(r.Body, &data); err != nil {
render.Status(r, http.StatusBadRequest)
err := fmt.Errorf("Request body was invalid JSON: %w", err)
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
return
}
params, err := tool.ParseParams(data)
if err != nil {
err := fmt.Errorf("Provided parameters were invalid: %w", err)
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
return
}
res, err := tool.Invoke(params)
if err != nil {
err := fmt.Errorf("Error while invoking tool: %w", err)
_ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError))
return
}
_ = render.Render(w, r, &resultResponse{Result: res})
}
var _ render.Renderer = &resultResponse{} // Renderer interface for managing response payloads.
// resultResponse is the response sent back when the tool was invocated succesffully.
type resultResponse struct {
Result string `json:"result"` // result of tool invocation
}
// Render renders a single payload and respond to the client request.
func (rr resultResponse) Render(w http.ResponseWriter, r *http.Request) error {
render.Status(r, http.StatusOK)
return nil
}
var _ render.Renderer = &errResponse{} // Renderer interface for managing response payloads.
// newErrResponse is a helper function initalizing an ErrResponse
func newErrResponse(err error, code int) *errResponse {
return &errResponse{
@@ -78,39 +131,3 @@ func (e *errResponse) Render(w http.ResponseWriter, r *http.Request) error {
render.Status(r, e.HTTPStatusCode)
return nil
}
func newToolHandler(s *Server) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
toolName := chi.URLParam(r, "toolName")
tool, ok := s.tools[toolName]
if !ok {
err := fmt.Errorf("Invalid tool name. Tool with name %q does not exist", toolName)
_ = render.Render(w, r, newErrResponse(err, http.StatusNotFound))
return
}
var data map[string]interface{}
if err := render.DecodeJSON(r.Body, &data); err != nil {
render.Status(r, http.StatusBadRequest)
err := fmt.Errorf("Request body was invalid JSON: %w", err)
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
return
}
params, err := tool.ParseParams(data)
if err != nil {
err := fmt.Errorf("Provided parameters were invalid: %w", err)
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
return
}
res, err := tool.Invoke(params)
if err != nil {
err := fmt.Errorf("Error while invoking tool: %w", err)
_ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError))
return
}
_ = render.Render(w, r, &resultResponse{Result: res})
}
}

181
internal/server/api_test.go Normal file
View File

@@ -0,0 +1,181 @@
package server
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/googleapis/genai-toolbox/internal/tools"
)
var _ tools.Tool = &MockTool{}
type MockTool struct {
Name string
Description string
Params []tools.Parameter
}
func (t MockTool) Invoke([]any) (string, error) {
return "", nil
}
func (t MockTool) ParseParams(data map[string]any) ([]any, error) {
return tools.ParseParams(t.Params, data)
}
func (t MockTool) Manifest() tools.Manifest {
return tools.Manifest{Description: t.Description, Parameters: t.Params}
}
func TestToolsetEndpoint(t *testing.T) {
// Set up resources to test against
tool1 := MockTool{
Name: "no_params",
Params: []tools.Parameter{},
}
tool2 := MockTool{
Name: "some_params",
Params: []tools.Parameter{
{
Name: "param1",
Type: "int",
Description: "This is the first parameter.",
},
{
Name: "param2",
Type: "string",
Description: "This is the second parameter.",
},
},
}
toolsMap := map[string]tools.Tool{tool1.Name: tool1, tool2.Name: tool2}
toolsets := make(map[string]tools.Toolset)
for name, l := range map[string][]string{
"": {tool1.Name, tool2.Name},
"tool1_only": {tool1.Name},
"tool2_only": {tool2.Name},
} {
tc := tools.ToolsetConfig{Name: name, ToolNames: l}
m, err := tc.Initialize("0.0.0", toolsMap)
if err != nil {
t.Fatalf("unable to initialize toolset %q: %s", name, err)
}
toolsets[name] = m
}
server := Server{tools: toolsMap, toolsets: toolsets}
r, err := apiRouter(&server)
if err != nil {
t.Fatalf("unable to initalize router: %s", err)
}
ts := httptest.NewServer(r)
defer ts.Close()
// wantRepsonse is a struct for checks against test cases
type wantResponse struct {
statusCode int
isErr bool
version string
tools []string
}
testCases := []struct {
name string
toolsetName string
want wantResponse
}{
{
name: "'default' manifest",
toolsetName: "",
want: wantResponse{
statusCode: http.StatusOK,
version: "0.0.0",
tools: []string{tool1.Name, tool2.Name},
},
},
{
name: "invalid toolset name",
toolsetName: "some_imaginary_toolset",
want: wantResponse{
statusCode: http.StatusNotFound,
isErr: true,
},
},
{
name: "single toolset 1",
toolsetName: "tool1_only",
want: wantResponse{
statusCode: http.StatusOK,
version: "0.0.0",
tools: []string{tool1.Name},
},
},
{
name: "single toolset 2",
toolsetName: "tool2_only",
want: wantResponse{
statusCode: http.StatusOK,
version: "0.0.0",
tools: []string{tool2.Name},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
resp, body, err := testRequest(ts, http.MethodGet, fmt.Sprintf("/toolset/%s", tc.toolsetName), nil)
if err != nil {
t.Fatalf("unexpected error during request: %s", err)
}
if resp.StatusCode != tc.want.statusCode {
t.Logf("response body: %s", body)
t.Fatalf("unexpected status code: want %d, got %d", tc.want.statusCode, resp.StatusCode)
}
if tc.want.isErr {
// skip the rest of the checks if this is an error case
return
}
var m tools.ToolsetManifest
err = json.Unmarshal(body, &m)
if err != nil {
t.Fatalf("unable to parse ToolsetManifest: %s", err)
}
// Check the version is correct
if m.ServerVersion != tc.want.version {
t.Fatalf("unexpected ServerVersion: want %q, got %q", tc.want.version, m.ServerVersion)
}
// validate that the tools in the toolset are correct
for _, name := range tc.want.tools {
_, ok := m.ToolsManifest[name]
if !ok {
t.Errorf("%q tool not found in manfiest", name)
}
}
})
}
}
func testRequest(ts *httptest.Server, method, path string, body io.Reader) (*http.Response, []byte, error) {
req, err := http.NewRequest(method, ts.URL+path, body)
if err != nil {
return nil, nil, fmt.Errorf("unable to create request: %w", err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, nil, fmt.Errorf("unable to send request: %w", err)
}
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, nil, fmt.Errorf("unable to read request body: %w", err)
}
defer resp.Body.Close()
return resp, respBody, nil
}

View File

@@ -19,6 +19,8 @@ import (
)
type Config struct {
// Server version
Version string
// Address is the address of the interface the server will listen on.
Address string
// Port is the port the server will listen on.

View File

@@ -53,7 +53,7 @@ func NewServer(cfg Config) (*Server, error) {
for name, sc := range cfg.SourceConfigs {
s, err := sc.Initialize()
if err != nil {
return nil, fmt.Errorf("Unable to initialize tool %s: %w", name, err)
return nil, fmt.Errorf("unable to initialize source %q: %w", name, err)
}
sourcesMap[name] = s
}
@@ -64,22 +64,31 @@ func NewServer(cfg Config) (*Server, error) {
for name, tc := range cfg.ToolConfigs {
t, err := tc.Initialize(sourcesMap)
if err != nil {
return nil, fmt.Errorf("Unable to initialize tool %s: %w", name, err)
return nil, fmt.Errorf("unable to initialize tool %q: %w", name, err)
}
toolsMap[name] = t
}
fmt.Printf("Initalized %d tools.\n", len(toolsMap))
// initalize and validate the tools
// create a default toolset that contains all tools
allToolNames := make([]string, 0, len(toolsMap))
for name := range toolsMap {
allToolNames = append(allToolNames, name)
}
if cfg.ToolsetConfigs == nil {
cfg.ToolsetConfigs = make(tools.ToolsetConfigs)
}
cfg.ToolsetConfigs[""] = tools.ToolsetConfig{Name: "", ToolNames: allToolNames}
// initalize and validate the toolsets
toolsetsMap := make(map[string]tools.Toolset)
for name, tc := range cfg.ToolsetConfigs {
t, err := tc.Initialize(toolsMap)
t, err := tc.Initialize(cfg.Version, toolsMap)
if err != nil {
return nil, fmt.Errorf("Unable to initialize toolset %s: %w", name, err)
return nil, fmt.Errorf("unable to initialize toolset %q: %w", name, err)
}
toolsetsMap[name] = t
}
fmt.Printf("Initalized %d tools.\n", len(toolsetsMap))
fmt.Printf("Initalized %d toolsets.\n", len(toolsetsMap))
s := &Server{
conf: cfg,
@@ -88,7 +97,12 @@ func NewServer(cfg Config) (*Server, error) {
tools: toolsMap,
toolsets: toolsetsMap,
}
r.Mount("/api", apiRouter(s))
if router, err := apiRouter(s); err != nil {
return nil, err
} else {
r.Mount("/api", router)
}
return s, nil
}

View File

@@ -49,7 +49,7 @@ func TestServe(t *testing.T) {
}
s, err := server.NewServer(cfg)
if err != nil {
t.Fatalf("Unable initialize server!")
t.Fatalf("Unable to initialize server! %v", err)
}
// start server in background

View File

@@ -36,30 +36,31 @@ type CloudSQLPgGenericConfig struct {
Parameters []Parameter `yaml:"parameters"`
}
func (r CloudSQLPgGenericConfig) toolKind() string {
func (cfg CloudSQLPgGenericConfig) toolKind() string {
return CloudSQLPgSQLGenericKind
}
func (r CloudSQLPgGenericConfig) Initialize(srcs map[string]sources.Source) (Tool, error) {
func (cfg CloudSQLPgGenericConfig) Initialize(srcs map[string]sources.Source) (Tool, error) {
// verify source exists
rawS, ok := srcs[r.Source]
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("No source named %q configured!", r.Source)
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is the right kind
s, ok := rawS.(sources.CloudSQLPgSource)
if !ok {
return nil, fmt.Errorf("Sources for %q tools must be of kind %q!", CloudSQLPgSQLGenericKind, sources.CloudSQLPgKind)
return nil, fmt.Errorf("sources for %q tools must be of kind %q", CloudSQLPgSQLGenericKind, sources.CloudSQLPgKind)
}
// finish tool setup
t := CloudSQLPgGenericTool{
Name: r.Name,
Name: cfg.Name,
Kind: CloudSQLPgSQLGenericKind,
Source: s,
Statement: r.Statement,
Parameters: r.Parameters,
Statement: cfg.Statement,
Parameters: cfg.Parameters,
manifest: Manifest{cfg.Description, cfg.Parameters},
}
return t, nil
}
@@ -73,20 +74,21 @@ type CloudSQLPgGenericTool struct {
Source sources.CloudSQLPgSource
Statement string
Parameters []Parameter `yaml:"parameters"`
manifest Manifest
}
func (t CloudSQLPgGenericTool) Invoke(params []any) (string, error) {
fmt.Printf("Invoked tool %s\n", t.Name)
results, err := t.Source.Pool.Query(context.Background(), t.Statement, params...)
if err != nil {
return "", fmt.Errorf("Unable to execute query: %w", err)
return "", fmt.Errorf("unable to execute query: %w", err)
}
var out strings.Builder
for results.Next() {
v, err := results.Values()
if err != nil {
return "", fmt.Errorf("Unable to parse row: %w", err)
return "", fmt.Errorf("unable to parse row: %w", err)
}
out.WriteString(fmt.Sprintf("%s", v))
}
@@ -95,5 +97,9 @@ func (t CloudSQLPgGenericTool) Invoke(params []any) (string, error) {
}
func (t CloudSQLPgGenericTool) ParseParams(data map[string]any) ([]any, error) {
return parseParams(t.Parameters, data)
return ParseParams(t.Parameters, data)
}
func (t CloudSQLPgGenericTool) Manifest() Manifest {
return t.manifest
}

View File

@@ -18,7 +18,7 @@ import (
"regexp"
)
var validName = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
var validName = regexp.MustCompile(`^[a-zA-Z0-9_-]*$`)
func IsValidName(s string) bool {
return validName.MatchString(s)

View File

@@ -66,6 +66,7 @@ func (c *Configs) UnmarshalYAML(node *yaml.Node) error {
type Tool interface {
Invoke([]any) (string, error)
ParseParams(data map[string]any) ([]any, error)
Manifest() Manifest
}
type Parameter struct {
@@ -85,8 +86,8 @@ func (e ParseTypeError) Error() string {
return fmt.Sprintf("Error parsing parameter %q: %q not type %q", e.Name, e.Value, e.Type)
}
// ParseParams is a helper function for parsing Parameters from an arbitratyJSON object.
func parseParams(ps []Parameter, data map[string]any) ([]any, error) {
// ParseParams is a helper function for parsing Parameters from an arbitraryJSON object.
func ParseParams(ps []Parameter, data map[string]any) ([]any, error) {
params := []any{}
for _, p := range ps {
v, ok := data[p.Name]
@@ -110,3 +111,8 @@ func parseParams(ps []Parameter, data map[string]any) ([]any, error) {
}
return params, nil
}
type Manifest struct {
Description string `json:"description"`
Parameters []Parameter `json:"parameters"`
}

View File

@@ -20,17 +20,24 @@ import (
"gopkg.in/yaml.v3"
)
type Toolset struct {
Name string `yaml:"name"`
Tools []*Tool `yaml:",inline"`
}
type ToolsetConfig struct {
Name string `yaml:"name"`
ToolNames []string `yaml:",inline"`
}
type ToolsetConfigs map[string]ToolsetConfig
type Toolset struct {
Name string `yaml:"name"`
Tools []*Tool `yaml:",inline"`
Manifest ToolsetManifest `yaml:",inline"`
}
type ToolsetManifest struct {
ServerVersion string `json:"serverVersion"`
ToolsManifest map[string]Manifest `json:"tools"`
}
// validate interface
var _ yaml.Unmarshaler = &ToolsetConfigs{}
@@ -48,7 +55,7 @@ func (c *ToolsetConfigs) UnmarshalYAML(node *yaml.Node) error {
return nil
}
func (t ToolsetConfig) Initialize(toolsMap map[string]Tool) (Toolset, error) {
func (t ToolsetConfig) Initialize(serverVersion string, toolsMap map[string]Tool) (Toolset, error) {
// finish toolset setup
// Check each declared tool name exists
var toolset Toolset
@@ -57,13 +64,18 @@ func (t ToolsetConfig) Initialize(toolsMap map[string]Tool) (Toolset, error) {
return toolset, fmt.Errorf("invalid toolset name: %s", t)
}
toolset.Tools = make([]*Tool, len(t.ToolNames))
toolset.Manifest = ToolsetManifest{
ServerVersion: serverVersion,
ToolsManifest: make(map[string]Manifest),
}
for _, toolName := range t.ToolNames {
tool, ok := toolsMap[toolName]
if !ok {
return toolset, fmt.Errorf("tool does not exist: %s", t)
}
toolset.Tools = append(toolset.Tools, &tool)
toolset.Manifest.ToolsManifest[toolName] = tool.Manifest()
}
return toolset, nil
}