mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-11 00:18:17 -05:00
feat: Add Toolset manifest endpoint (#11)
1. Calculate tool manifests when server starts. 2. Add toolset manifest endpoints. --------- Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com>
This commit is contained in:
8
.github/workflows/tests.yaml
vendored
8
.github/workflows/tests.yaml
vendored
@@ -77,3 +77,11 @@ jobs:
|
||||
|
||||
- name: Run tests
|
||||
run: go test -race -v ./...
|
||||
env:
|
||||
PROJECT: ${{ secrets.PROJECT }}
|
||||
CLOUD_SQL_PG_INSTANCE: ${{ secrets.CLOUD_SQL_PG_INSTANCE }}
|
||||
REGION: ${{ secrets.REGION }}
|
||||
DATABASE: ${{ secrets.DATABASE }}
|
||||
USER: ${{ secrets.USER }}
|
||||
PASSWORD: ${{ secrets.PASSWORD }}
|
||||
|
||||
|
||||
@@ -109,21 +109,21 @@ func run(cmd *Command) error {
|
||||
// Read tool file contents
|
||||
buf, err := os.ReadFile(cmd.tools_file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to read tool file at %q: %w", cmd.tools_file, err)
|
||||
return fmt.Errorf("unable to read tool file at %q: %w", cmd.tools_file, err)
|
||||
}
|
||||
cmd.cfg.SourceConfigs, cmd.cfg.ToolConfigs, cmd.cfg.ToolsetConfigs, err = parseToolsFile(buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to parse tool file at %q: %w", cmd.tools_file, err)
|
||||
return fmt.Errorf("unable to parse tool file at %q: %w", cmd.tools_file, err)
|
||||
}
|
||||
|
||||
// run server
|
||||
s, err := server.NewServer(cmd.cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Toolbox failed to start with the following error: %w", err)
|
||||
return fmt.Errorf("toolbox failed to start with the following error: %w", err)
|
||||
}
|
||||
err = s.ListenAndServe(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Toolbox crashed with the following error: %w", err)
|
||||
return fmt.Errorf("toolbox crashed with the following error: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -15,7 +15,9 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
@@ -24,36 +26,87 @@ import (
|
||||
)
|
||||
|
||||
// apiRouter creates a router that represents the routes under /api
|
||||
func apiRouter(s *Server) chi.Router {
|
||||
func apiRouter(s *Server) (chi.Router, error) {
|
||||
r := chi.NewRouter()
|
||||
|
||||
r.Get("/toolset/{toolsetName}", toolsetHandler(s))
|
||||
r.Get("/toolset/", func(w http.ResponseWriter, r *http.Request) { toolsetHandler(s, w, r) })
|
||||
r.Get("/toolset/{toolsetName}", func(w http.ResponseWriter, r *http.Request) { toolsetHandler(s, w, r) })
|
||||
|
||||
r.Route("/tool/{toolName}", func(r chi.Router) {
|
||||
r.Use(middleware.AllowContentType("application/json"))
|
||||
r.Post("/invoke", newToolHandler(s))
|
||||
r.Post("/invoke", func(w http.ResponseWriter, r *http.Request) { toolInvokeHandler(s, w, r) })
|
||||
})
|
||||
|
||||
return r
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func toolsetHandler(s *Server) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
toolsetName := chi.URLParam(r, "toolsetName")
|
||||
_, _ = w.Write([]byte(fmt.Sprintf("Stub for toolset %s manifest!", toolsetName)))
|
||||
// toolInvokeHandler handles the request for information about a Toolset.
|
||||
func toolsetHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
toolsetName := chi.URLParam(r, "toolsetName")
|
||||
toolset, ok := s.toolsets[toolsetName]
|
||||
if !ok {
|
||||
http.Error(w, fmt.Sprintf("Toolset %q does not exist", toolsetName), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
b, err := json.Marshal(toolset.Manifest)
|
||||
if err != nil {
|
||||
log.Printf("unable to JSON the toolset manifest: %s", err)
|
||||
http.Error(w, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
_, _ = w.Write(b)
|
||||
}
|
||||
|
||||
// toolInvokeHandler handles the API request to invoke a specific Tool.
|
||||
func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
toolName := chi.URLParam(r, "toolName")
|
||||
tool, ok := s.tools[toolName]
|
||||
if !ok {
|
||||
err := fmt.Errorf("Invalid tool name. Tool with name %q does not exist", toolName)
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusNotFound))
|
||||
return
|
||||
}
|
||||
|
||||
var data map[string]interface{}
|
||||
if err := render.DecodeJSON(r.Body, &data); err != nil {
|
||||
render.Status(r, http.StatusBadRequest)
|
||||
err := fmt.Errorf("Request body was invalid JSON: %w", err)
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
|
||||
return
|
||||
}
|
||||
|
||||
params, err := tool.ParseParams(data)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("Provided parameters were invalid: %w", err)
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
|
||||
return
|
||||
}
|
||||
|
||||
res, err := tool.Invoke(params)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("Error while invoking tool: %w", err)
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError))
|
||||
return
|
||||
}
|
||||
|
||||
_ = render.Render(w, r, &resultResponse{Result: res})
|
||||
}
|
||||
|
||||
var _ render.Renderer = &resultResponse{} // Renderer interface for managing response payloads.
|
||||
|
||||
// resultResponse is the response sent back when the tool was invocated succesffully.
|
||||
type resultResponse struct {
|
||||
Result string `json:"result"` // result of tool invocation
|
||||
}
|
||||
|
||||
// Render renders a single payload and respond to the client request.
|
||||
func (rr resultResponse) Render(w http.ResponseWriter, r *http.Request) error {
|
||||
render.Status(r, http.StatusOK)
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ render.Renderer = &errResponse{} // Renderer interface for managing response payloads.
|
||||
|
||||
// newErrResponse is a helper function initalizing an ErrResponse
|
||||
func newErrResponse(err error, code int) *errResponse {
|
||||
return &errResponse{
|
||||
@@ -78,39 +131,3 @@ func (e *errResponse) Render(w http.ResponseWriter, r *http.Request) error {
|
||||
render.Status(r, e.HTTPStatusCode)
|
||||
return nil
|
||||
}
|
||||
|
||||
func newToolHandler(s *Server) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
toolName := chi.URLParam(r, "toolName")
|
||||
tool, ok := s.tools[toolName]
|
||||
if !ok {
|
||||
err := fmt.Errorf("Invalid tool name. Tool with name %q does not exist", toolName)
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusNotFound))
|
||||
return
|
||||
}
|
||||
|
||||
var data map[string]interface{}
|
||||
if err := render.DecodeJSON(r.Body, &data); err != nil {
|
||||
render.Status(r, http.StatusBadRequest)
|
||||
err := fmt.Errorf("Request body was invalid JSON: %w", err)
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
|
||||
return
|
||||
}
|
||||
|
||||
params, err := tool.ParseParams(data)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("Provided parameters were invalid: %w", err)
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
|
||||
return
|
||||
}
|
||||
|
||||
res, err := tool.Invoke(params)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("Error while invoking tool: %w", err)
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError))
|
||||
return
|
||||
}
|
||||
|
||||
_ = render.Render(w, r, &resultResponse{Result: res})
|
||||
}
|
||||
}
|
||||
|
||||
181
internal/server/api_test.go
Normal file
181
internal/server/api_test.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
var _ tools.Tool = &MockTool{}
|
||||
|
||||
type MockTool struct {
|
||||
Name string
|
||||
Description string
|
||||
Params []tools.Parameter
|
||||
}
|
||||
|
||||
func (t MockTool) Invoke([]any) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (t MockTool) ParseParams(data map[string]any) ([]any, error) {
|
||||
return tools.ParseParams(t.Params, data)
|
||||
}
|
||||
|
||||
func (t MockTool) Manifest() tools.Manifest {
|
||||
return tools.Manifest{Description: t.Description, Parameters: t.Params}
|
||||
}
|
||||
|
||||
func TestToolsetEndpoint(t *testing.T) {
|
||||
// Set up resources to test against
|
||||
tool1 := MockTool{
|
||||
Name: "no_params",
|
||||
Params: []tools.Parameter{},
|
||||
}
|
||||
tool2 := MockTool{
|
||||
Name: "some_params",
|
||||
Params: []tools.Parameter{
|
||||
{
|
||||
Name: "param1",
|
||||
Type: "int",
|
||||
Description: "This is the first parameter.",
|
||||
},
|
||||
{
|
||||
Name: "param2",
|
||||
Type: "string",
|
||||
Description: "This is the second parameter.",
|
||||
},
|
||||
},
|
||||
}
|
||||
toolsMap := map[string]tools.Tool{tool1.Name: tool1, tool2.Name: tool2}
|
||||
|
||||
toolsets := make(map[string]tools.Toolset)
|
||||
for name, l := range map[string][]string{
|
||||
"": {tool1.Name, tool2.Name},
|
||||
"tool1_only": {tool1.Name},
|
||||
"tool2_only": {tool2.Name},
|
||||
} {
|
||||
tc := tools.ToolsetConfig{Name: name, ToolNames: l}
|
||||
m, err := tc.Initialize("0.0.0", toolsMap)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to initialize toolset %q: %s", name, err)
|
||||
}
|
||||
toolsets[name] = m
|
||||
}
|
||||
|
||||
server := Server{tools: toolsMap, toolsets: toolsets}
|
||||
r, err := apiRouter(&server)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to initalize router: %s", err)
|
||||
}
|
||||
ts := httptest.NewServer(r)
|
||||
defer ts.Close()
|
||||
|
||||
// wantRepsonse is a struct for checks against test cases
|
||||
type wantResponse struct {
|
||||
statusCode int
|
||||
isErr bool
|
||||
version string
|
||||
tools []string
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
toolsetName string
|
||||
want wantResponse
|
||||
}{
|
||||
{
|
||||
name: "'default' manifest",
|
||||
toolsetName: "",
|
||||
want: wantResponse{
|
||||
statusCode: http.StatusOK,
|
||||
version: "0.0.0",
|
||||
tools: []string{tool1.Name, tool2.Name},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid toolset name",
|
||||
toolsetName: "some_imaginary_toolset",
|
||||
want: wantResponse{
|
||||
statusCode: http.StatusNotFound,
|
||||
isErr: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single toolset 1",
|
||||
toolsetName: "tool1_only",
|
||||
want: wantResponse{
|
||||
statusCode: http.StatusOK,
|
||||
version: "0.0.0",
|
||||
tools: []string{tool1.Name},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single toolset 2",
|
||||
toolsetName: "tool2_only",
|
||||
want: wantResponse{
|
||||
statusCode: http.StatusOK,
|
||||
version: "0.0.0",
|
||||
tools: []string{tool2.Name},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp, body, err := testRequest(ts, http.MethodGet, fmt.Sprintf("/toolset/%s", tc.toolsetName), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error during request: %s", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != tc.want.statusCode {
|
||||
t.Logf("response body: %s", body)
|
||||
t.Fatalf("unexpected status code: want %d, got %d", tc.want.statusCode, resp.StatusCode)
|
||||
}
|
||||
if tc.want.isErr {
|
||||
// skip the rest of the checks if this is an error case
|
||||
return
|
||||
}
|
||||
var m tools.ToolsetManifest
|
||||
err = json.Unmarshal(body, &m)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to parse ToolsetManifest: %s", err)
|
||||
}
|
||||
// Check the version is correct
|
||||
if m.ServerVersion != tc.want.version {
|
||||
t.Fatalf("unexpected ServerVersion: want %q, got %q", tc.want.version, m.ServerVersion)
|
||||
}
|
||||
// validate that the tools in the toolset are correct
|
||||
for _, name := range tc.want.tools {
|
||||
_, ok := m.ToolsManifest[name]
|
||||
if !ok {
|
||||
t.Errorf("%q tool not found in manfiest", name)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testRequest(ts *httptest.Server, method, path string, body io.Reader) (*http.Response, []byte, error) {
|
||||
req, err := http.NewRequest(method, ts.URL+path, body)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("unable to create request: %w", err)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("unable to send request: %w", err)
|
||||
}
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("unable to read request body: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return resp, respBody, nil
|
||||
}
|
||||
@@ -19,6 +19,8 @@ import (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
// Server version
|
||||
Version string
|
||||
// Address is the address of the interface the server will listen on.
|
||||
Address string
|
||||
// Port is the port the server will listen on.
|
||||
|
||||
@@ -53,7 +53,7 @@ func NewServer(cfg Config) (*Server, error) {
|
||||
for name, sc := range cfg.SourceConfigs {
|
||||
s, err := sc.Initialize()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to initialize tool %s: %w", name, err)
|
||||
return nil, fmt.Errorf("unable to initialize source %q: %w", name, err)
|
||||
}
|
||||
sourcesMap[name] = s
|
||||
}
|
||||
@@ -64,22 +64,31 @@ func NewServer(cfg Config) (*Server, error) {
|
||||
for name, tc := range cfg.ToolConfigs {
|
||||
t, err := tc.Initialize(sourcesMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to initialize tool %s: %w", name, err)
|
||||
return nil, fmt.Errorf("unable to initialize tool %q: %w", name, err)
|
||||
}
|
||||
toolsMap[name] = t
|
||||
}
|
||||
fmt.Printf("Initalized %d tools.\n", len(toolsMap))
|
||||
|
||||
// initalize and validate the tools
|
||||
// create a default toolset that contains all tools
|
||||
allToolNames := make([]string, 0, len(toolsMap))
|
||||
for name := range toolsMap {
|
||||
allToolNames = append(allToolNames, name)
|
||||
}
|
||||
if cfg.ToolsetConfigs == nil {
|
||||
cfg.ToolsetConfigs = make(tools.ToolsetConfigs)
|
||||
}
|
||||
cfg.ToolsetConfigs[""] = tools.ToolsetConfig{Name: "", ToolNames: allToolNames}
|
||||
// initalize and validate the toolsets
|
||||
toolsetsMap := make(map[string]tools.Toolset)
|
||||
for name, tc := range cfg.ToolsetConfigs {
|
||||
t, err := tc.Initialize(toolsMap)
|
||||
t, err := tc.Initialize(cfg.Version, toolsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to initialize toolset %s: %w", name, err)
|
||||
return nil, fmt.Errorf("unable to initialize toolset %q: %w", name, err)
|
||||
}
|
||||
toolsetsMap[name] = t
|
||||
}
|
||||
fmt.Printf("Initalized %d tools.\n", len(toolsetsMap))
|
||||
fmt.Printf("Initalized %d toolsets.\n", len(toolsetsMap))
|
||||
|
||||
s := &Server{
|
||||
conf: cfg,
|
||||
@@ -88,7 +97,12 @@ func NewServer(cfg Config) (*Server, error) {
|
||||
tools: toolsMap,
|
||||
toolsets: toolsetsMap,
|
||||
}
|
||||
r.Mount("/api", apiRouter(s))
|
||||
|
||||
if router, err := apiRouter(s); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
r.Mount("/api", router)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ func TestServe(t *testing.T) {
|
||||
}
|
||||
s, err := server.NewServer(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable initialize server!")
|
||||
t.Fatalf("Unable to initialize server! %v", err)
|
||||
}
|
||||
|
||||
// start server in background
|
||||
|
||||
@@ -36,30 +36,31 @@ type CloudSQLPgGenericConfig struct {
|
||||
Parameters []Parameter `yaml:"parameters"`
|
||||
}
|
||||
|
||||
func (r CloudSQLPgGenericConfig) toolKind() string {
|
||||
func (cfg CloudSQLPgGenericConfig) toolKind() string {
|
||||
return CloudSQLPgSQLGenericKind
|
||||
}
|
||||
|
||||
func (r CloudSQLPgGenericConfig) Initialize(srcs map[string]sources.Source) (Tool, error) {
|
||||
func (cfg CloudSQLPgGenericConfig) Initialize(srcs map[string]sources.Source) (Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[r.Source]
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("No source named %q configured!", r.Source)
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is the right kind
|
||||
s, ok := rawS.(sources.CloudSQLPgSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Sources for %q tools must be of kind %q!", CloudSQLPgSQLGenericKind, sources.CloudSQLPgKind)
|
||||
return nil, fmt.Errorf("sources for %q tools must be of kind %q", CloudSQLPgSQLGenericKind, sources.CloudSQLPgKind)
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := CloudSQLPgGenericTool{
|
||||
Name: r.Name,
|
||||
Name: cfg.Name,
|
||||
Kind: CloudSQLPgSQLGenericKind,
|
||||
Source: s,
|
||||
Statement: r.Statement,
|
||||
Parameters: r.Parameters,
|
||||
Statement: cfg.Statement,
|
||||
Parameters: cfg.Parameters,
|
||||
manifest: Manifest{cfg.Description, cfg.Parameters},
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -73,20 +74,21 @@ type CloudSQLPgGenericTool struct {
|
||||
Source sources.CloudSQLPgSource
|
||||
Statement string
|
||||
Parameters []Parameter `yaml:"parameters"`
|
||||
manifest Manifest
|
||||
}
|
||||
|
||||
func (t CloudSQLPgGenericTool) Invoke(params []any) (string, error) {
|
||||
fmt.Printf("Invoked tool %s\n", t.Name)
|
||||
results, err := t.Source.Pool.Query(context.Background(), t.Statement, params...)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Unable to execute query: %w", err)
|
||||
return "", fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
|
||||
var out strings.Builder
|
||||
for results.Next() {
|
||||
v, err := results.Values()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Unable to parse row: %w", err)
|
||||
return "", fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
out.WriteString(fmt.Sprintf("%s", v))
|
||||
}
|
||||
@@ -95,5 +97,9 @@ func (t CloudSQLPgGenericTool) Invoke(params []any) (string, error) {
|
||||
}
|
||||
|
||||
func (t CloudSQLPgGenericTool) ParseParams(data map[string]any) ([]any, error) {
|
||||
return parseParams(t.Parameters, data)
|
||||
return ParseParams(t.Parameters, data)
|
||||
}
|
||||
|
||||
func (t CloudSQLPgGenericTool) Manifest() Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
"regexp"
|
||||
)
|
||||
|
||||
var validName = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
var validName = regexp.MustCompile(`^[a-zA-Z0-9_-]*$`)
|
||||
|
||||
func IsValidName(s string) bool {
|
||||
return validName.MatchString(s)
|
||||
|
||||
@@ -66,6 +66,7 @@ func (c *Configs) UnmarshalYAML(node *yaml.Node) error {
|
||||
type Tool interface {
|
||||
Invoke([]any) (string, error)
|
||||
ParseParams(data map[string]any) ([]any, error)
|
||||
Manifest() Manifest
|
||||
}
|
||||
|
||||
type Parameter struct {
|
||||
@@ -85,8 +86,8 @@ func (e ParseTypeError) Error() string {
|
||||
return fmt.Sprintf("Error parsing parameter %q: %q not type %q", e.Name, e.Value, e.Type)
|
||||
}
|
||||
|
||||
// ParseParams is a helper function for parsing Parameters from an arbitratyJSON object.
|
||||
func parseParams(ps []Parameter, data map[string]any) ([]any, error) {
|
||||
// ParseParams is a helper function for parsing Parameters from an arbitraryJSON object.
|
||||
func ParseParams(ps []Parameter, data map[string]any) ([]any, error) {
|
||||
params := []any{}
|
||||
for _, p := range ps {
|
||||
v, ok := data[p.Name]
|
||||
@@ -110,3 +111,8 @@ func parseParams(ps []Parameter, data map[string]any) ([]any, error) {
|
||||
}
|
||||
return params, nil
|
||||
}
|
||||
|
||||
type Manifest struct {
|
||||
Description string `json:"description"`
|
||||
Parameters []Parameter `json:"parameters"`
|
||||
}
|
||||
|
||||
@@ -20,17 +20,24 @@ import (
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type Toolset struct {
|
||||
Name string `yaml:"name"`
|
||||
Tools []*Tool `yaml:",inline"`
|
||||
}
|
||||
|
||||
type ToolsetConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
ToolNames []string `yaml:",inline"`
|
||||
}
|
||||
|
||||
type ToolsetConfigs map[string]ToolsetConfig
|
||||
|
||||
type Toolset struct {
|
||||
Name string `yaml:"name"`
|
||||
Tools []*Tool `yaml:",inline"`
|
||||
Manifest ToolsetManifest `yaml:",inline"`
|
||||
}
|
||||
|
||||
type ToolsetManifest struct {
|
||||
ServerVersion string `json:"serverVersion"`
|
||||
ToolsManifest map[string]Manifest `json:"tools"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ yaml.Unmarshaler = &ToolsetConfigs{}
|
||||
|
||||
@@ -48,7 +55,7 @@ func (c *ToolsetConfigs) UnmarshalYAML(node *yaml.Node) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t ToolsetConfig) Initialize(toolsMap map[string]Tool) (Toolset, error) {
|
||||
func (t ToolsetConfig) Initialize(serverVersion string, toolsMap map[string]Tool) (Toolset, error) {
|
||||
// finish toolset setup
|
||||
// Check each declared tool name exists
|
||||
var toolset Toolset
|
||||
@@ -57,13 +64,18 @@ func (t ToolsetConfig) Initialize(toolsMap map[string]Tool) (Toolset, error) {
|
||||
return toolset, fmt.Errorf("invalid toolset name: %s", t)
|
||||
}
|
||||
toolset.Tools = make([]*Tool, len(t.ToolNames))
|
||||
toolset.Manifest = ToolsetManifest{
|
||||
ServerVersion: serverVersion,
|
||||
ToolsManifest: make(map[string]Manifest),
|
||||
}
|
||||
for _, toolName := range t.ToolNames {
|
||||
tool, ok := toolsMap[toolName]
|
||||
if !ok {
|
||||
return toolset, fmt.Errorf("tool does not exist: %s", t)
|
||||
}
|
||||
toolset.Tools = append(toolset.Tools, &tool)
|
||||
|
||||
toolset.Manifest.ToolsManifest[toolName] = tool.Manifest()
|
||||
}
|
||||
|
||||
return toolset, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user