mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-10 16:08:16 -05:00
feat: add toolset configuration (#12)
Add `Toolset` implementation to the `tools` package: - struct and configs. - Custom `UnmarshalYAML` function. - Initialization function that validates if tools specified for the toolset exist.
This commit is contained in:
13
cmd/root.go
13
cmd/root.go
@@ -88,17 +88,18 @@ func NewCommand() *Command {
|
||||
}
|
||||
|
||||
// parseToolsFile parses the provided yaml into appropriate configs.
|
||||
func parseToolsFile(raw []byte) (sources.Configs, tools.Configs, error) {
|
||||
func parseToolsFile(raw []byte) (sources.Configs, tools.Configs, tools.ToolsetConfigs, error) {
|
||||
tools_file := &struct {
|
||||
Sources sources.Configs `yaml:"sources"`
|
||||
Tools tools.Configs `yaml:"tools"`
|
||||
Sources sources.Configs `yaml:"sources"`
|
||||
Tools tools.Configs `yaml:"tools"`
|
||||
Toolsets tools.ToolsetConfigs `yaml:"toolsets"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(raw, tools_file)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
return tools_file.Sources, tools_file.Tools, nil
|
||||
return tools_file.Sources, tools_file.Tools, tools_file.Toolsets, nil
|
||||
}
|
||||
|
||||
func run(cmd *Command) error {
|
||||
@@ -110,7 +111,7 @@ func run(cmd *Command) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to read tool file at %q: %w", cmd.tools_file, err)
|
||||
}
|
||||
cmd.cfg.SourceConfigs, cmd.cfg.ToolConfigs, err = parseToolsFile(buf)
|
||||
cmd.cfg.SourceConfigs, cmd.cfg.ToolConfigs, cmd.cfg.ToolsetConfigs, err = parseToolsFile(buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to parse tool file at %q: %w", cmd.tools_file, err)
|
||||
}
|
||||
|
||||
@@ -167,10 +167,11 @@ func TestToolFileFlag(t *testing.T) {
|
||||
|
||||
func TestParseToolFile(t *testing.T) {
|
||||
tcs := []struct {
|
||||
description string
|
||||
in string
|
||||
wantSources sources.Configs
|
||||
wantTools tools.Configs
|
||||
description string
|
||||
in string
|
||||
wantSources sources.Configs
|
||||
wantTools tools.Configs
|
||||
wantToolsets tools.ToolsetConfigs
|
||||
}{
|
||||
{
|
||||
description: "basic example",
|
||||
@@ -193,6 +194,9 @@ func TestParseToolFile(t *testing.T) {
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
toolsets:
|
||||
example_toolset:
|
||||
- example_tool
|
||||
`,
|
||||
wantSources: sources.Configs{
|
||||
"my-pg-instance": sources.CloudSQLPgConfig{
|
||||
@@ -220,11 +224,17 @@ func TestParseToolFile(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
wantToolsets: tools.ToolsetConfigs{
|
||||
"example_toolset": tools.ToolsetConfig{
|
||||
Name: "example_toolset",
|
||||
ToolNames: []string{"example_tool"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.description, func(t *testing.T) {
|
||||
gotSources, gotTools, err := parseToolsFile(testutils.FormatYaml(tc.in))
|
||||
gotSources, gotTools, gotToolsets, err := parseToolsFile(testutils.FormatYaml(tc.in))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse input: %v", err)
|
||||
}
|
||||
@@ -234,6 +244,9 @@ func TestParseToolFile(t *testing.T) {
|
||||
if diff := cmp.Diff(tc.wantTools, gotTools); diff != "" {
|
||||
t.Fatalf("incorrect tools parse: diff %v", diff)
|
||||
}
|
||||
if diff := cmp.Diff(tc.wantToolsets, gotToolsets); diff != "" {
|
||||
t.Fatalf("incorrect tools parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -27,4 +27,6 @@ type Config struct {
|
||||
SourceConfigs sources.Configs
|
||||
// ToolConfigs defines what tools are available.
|
||||
ToolConfigs tools.Configs
|
||||
// ToolsetConfigs defines what tools are available.
|
||||
ToolsetConfigs tools.ToolsetConfigs
|
||||
}
|
||||
|
||||
@@ -33,8 +33,9 @@ type Server struct {
|
||||
conf Config
|
||||
root chi.Router
|
||||
|
||||
sources map[string]sources.Source
|
||||
tools map[string]tools.Tool
|
||||
sources map[string]sources.Source
|
||||
tools map[string]tools.Tool
|
||||
toolsets map[string]tools.Toolset
|
||||
}
|
||||
|
||||
// NewServer returns a Server object based on provided Config.
|
||||
@@ -48,32 +49,44 @@ func NewServer(cfg Config) (*Server, error) {
|
||||
})
|
||||
|
||||
// initalize and validate the sources
|
||||
sources := make(map[string]sources.Source)
|
||||
sourcesMap := make(map[string]sources.Source)
|
||||
for name, sc := range cfg.SourceConfigs {
|
||||
s, err := sc.Initialize()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to initialize tool %s: %w", name, err)
|
||||
}
|
||||
sources[name] = s
|
||||
sourcesMap[name] = s
|
||||
}
|
||||
fmt.Printf("Initalized %d sources.\n", len(sources))
|
||||
fmt.Printf("Initalized %d sources.\n", len(sourcesMap))
|
||||
|
||||
// initalize and validate the tools
|
||||
tools := make(map[string]tools.Tool)
|
||||
toolsMap := make(map[string]tools.Tool)
|
||||
for name, tc := range cfg.ToolConfigs {
|
||||
t, err := tc.Initialize(sources)
|
||||
t, err := tc.Initialize(sourcesMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to initialize tool %s: %w", name, err)
|
||||
}
|
||||
tools[name] = t
|
||||
toolsMap[name] = t
|
||||
}
|
||||
fmt.Printf("Initalized %d tools.\n", len(tools))
|
||||
fmt.Printf("Initalized %d tools.\n", len(toolsMap))
|
||||
|
||||
// initalize and validate the tools
|
||||
toolsetsMap := make(map[string]tools.Toolset)
|
||||
for name, tc := range cfg.ToolsetConfigs {
|
||||
t, err := tc.Initialize(toolsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to initialize toolset %s: %w", name, err)
|
||||
}
|
||||
toolsetsMap[name] = t
|
||||
}
|
||||
fmt.Printf("Initalized %d tools.\n", len(toolsetsMap))
|
||||
|
||||
s := &Server{
|
||||
conf: cfg,
|
||||
root: r,
|
||||
sources: sources,
|
||||
tools: tools,
|
||||
conf: cfg,
|
||||
root: r,
|
||||
sources: sourcesMap,
|
||||
tools: toolsMap,
|
||||
toolsets: toolsetsMap,
|
||||
}
|
||||
r.Mount("/api", apiRouter(s))
|
||||
|
||||
|
||||
25
internal/tools/common.go
Normal file
25
internal/tools/common.go
Normal file
@@ -0,0 +1,25 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tools
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
)
|
||||
|
||||
var validName = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
|
||||
func IsValidName(s string) bool {
|
||||
return validName.MatchString(s)
|
||||
}
|
||||
69
internal/tools/toolsets.go
Normal file
69
internal/tools/toolsets.go
Normal file
@@ -0,0 +1,69 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type Toolset struct {
|
||||
Name string `yaml:"name"`
|
||||
Tools []*Tool `yaml:",inline"`
|
||||
}
|
||||
|
||||
type ToolsetConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
ToolNames []string `yaml:",inline"`
|
||||
}
|
||||
type ToolsetConfigs map[string]ToolsetConfig
|
||||
|
||||
// validate interface
|
||||
var _ yaml.Unmarshaler = &ToolsetConfigs{}
|
||||
|
||||
func (c *ToolsetConfigs) UnmarshalYAML(node *yaml.Node) error {
|
||||
*c = make(ToolsetConfigs)
|
||||
|
||||
var raw map[string][]string
|
||||
if err := node.Decode(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name, toolList := range raw {
|
||||
(*c)[name] = ToolsetConfig{Name: name, ToolNames: toolList}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t ToolsetConfig) Initialize(toolsMap map[string]Tool) (Toolset, error) {
|
||||
// finish toolset setup
|
||||
// Check each declared tool name exists
|
||||
var toolset Toolset
|
||||
toolset.Name = t.Name
|
||||
if !IsValidName(toolset.Name) {
|
||||
return toolset, fmt.Errorf("invalid toolset name: %s", t)
|
||||
}
|
||||
toolset.Tools = make([]*Tool, len(t.ToolNames))
|
||||
for _, toolName := range t.ToolNames {
|
||||
tool, ok := toolsMap[toolName]
|
||||
if !ok {
|
||||
return toolset, fmt.Errorf("tool does not exist: %s", t)
|
||||
}
|
||||
toolset.Tools = append(toolset.Tools, &tool)
|
||||
|
||||
}
|
||||
return toolset, nil
|
||||
}
|
||||
Reference in New Issue
Block a user