feat: add basic parsing from tools file (#8)

This commit is contained in:
Kurtis Van Gent
2024-08-05 10:33:07 -05:00
committed by GitHub
parent df9ad9e33f
commit b9ba364fb6
12 changed files with 539 additions and 12 deletions

View File

@@ -22,7 +22,10 @@ import (
"strings"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/spf13/cobra"
"gopkg.in/yaml.v3"
)
var (
@@ -59,35 +62,62 @@ func Execute() {
type Command struct {
*cobra.Command
cfg server.Config
cfg server.Config
tools_file string
}
// NewCommand returns a Command object representing an invocation of the CLI.
func NewCommand() *Command {
c := &Command{
cmd := &Command{
Command: &cobra.Command{
Use: "toolbox",
Version: versionString,
},
}
flags := c.Flags()
flags.StringVarP(&c.cfg.Address, "address", "a", "127.0.0.1", "Address of the interface the server will listen on.")
flags.IntVarP(&c.cfg.Port, "port", "p", 5000, "Port the server will listen on.")
flags := cmd.Flags()
flags.StringVarP(&cmd.cfg.Address, "address", "a", "127.0.0.1", "Address of the interface the server will listen on.")
flags.IntVarP(&cmd.cfg.Port, "port", "p", 5000, "Port the server will listen on.")
flags.StringVar(&cmd.tools_file, "tools_file", "tools.yaml", "File path specifying the tool configuration")
// wrap RunE command so that we have access to original Command object
c.RunE = func(*cobra.Command, []string) error { return run(c) }
cmd.RunE = func(*cobra.Command, []string) error { return run(cmd) }
return c
return cmd
}
// parseToolsFile parses the provided yaml into appropriate configs.
func parseToolsFile(raw []byte) (sources.Configs, tools.Configs, error) {
tools_file := &struct {
Sources sources.Configs `yaml:"sources"`
Tools tools.Configs `yaml:"tools"`
}{}
// Parse contents
err := yaml.Unmarshal(raw, tools_file)
if err != nil {
return nil, nil, err
}
return tools_file.Sources, tools_file.Tools, nil
}
func run(cmd *Command) error {
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
// Read tool file contents
buf, err := os.ReadFile(cmd.tools_file)
if err != nil {
return fmt.Errorf("Unable to read tool file at %q: %w", cmd.tools_file, err)
}
cmd.cfg.SourceConfigs, cmd.cfg.ToolConfigs, err = parseToolsFile(buf)
if err != nil {
return fmt.Errorf("Unable to parse tool file at %q: %w", cmd.tools_file, err)
}
// run server
s := server.NewServer(cmd.cfg)
err := s.ListenAndServe(ctx)
err = s.ListenAndServe(ctx)
if err != nil {
return fmt.Errorf("Toolbox crashed with the following error: %w", err)
}

View File

@@ -21,7 +21,11 @@ import (
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/spf13/cobra"
)
@@ -65,7 +69,7 @@ func TestVersion(t *testing.T) {
}
}
func TestFlags(t *testing.T) {
func TestAddrPort(t *testing.T) {
tcs := []struct {
desc string
args []string
@@ -119,9 +123,115 @@ func TestFlags(t *testing.T) {
t.Fatalf("unexpected error invoking command: %s", err)
}
if c.cfg != tc.want {
if !cmp.Equal(c.cfg, tc.want) {
t.Fatalf("got %v, want %v", c.cfg, tc.want)
}
})
}
}
func TestToolFileFlag(t *testing.T) {
tcs := []struct {
desc string
args []string
want string
}{
{
desc: "default value",
args: []string{},
want: "tools.yaml",
},
{
desc: "foo file",
args: []string{"--tools_file", "foo.yaml"},
want: "foo.yaml",
},
{
desc: "address long",
args: []string{"--tools_file", "bar.yaml"},
want: "bar.yaml",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
c, _, err := invokeCommand(tc.args)
if err != nil {
t.Fatalf("unexpected error invoking command: %s", err)
}
if c.tools_file != tc.want {
t.Fatalf("got %v, want %v", c.cfg, tc.want)
}
})
}
}
func TestParseToolFile(t *testing.T) {
tcs := []struct {
description string
in string
wantSources sources.Configs
wantTools tools.Configs
}{
{
description: "basic example",
in: `
sources:
my-pg-instance:
kind: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
tools:
example_tool:
kind: cloud-sql-postgres-generic
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
country:
type: string
description: some description
`,
wantSources: sources.Configs{
"my-pg-instance": sources.CloudSQLPgConfig{
Kind: sources.CloudSQLPgKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
Database: "my_db",
},
},
wantTools: tools.Configs{
"example_tool": tools.CloudSQLPgGenericConfig{
Kind: tools.CloudSQLPgSQLGenericKind,
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
Parameters: map[string]tools.Parameter{
"country": {
Type: "string",
Description: "some description",
},
},
},
},
},
}
for _, tc := range tcs {
t.Run(tc.description, func(t *testing.T) {
gotSources, gotTools, err := parseToolsFile(testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("failed to parse input: %v", err)
}
if diff := cmp.Diff(tc.wantSources, gotSources); diff != "" {
t.Fatalf("incorrect sources parse: diff %v", diff)
}
if diff := cmp.Diff(tc.wantTools, gotTools); diff != "" {
t.Fatalf("incorrect tools parse: diff %v", diff)
}
})
}
}

2
go.mod
View File

@@ -4,7 +4,9 @@ go 1.22.2
require (
github.com/go-chi/chi/v5 v5.1.0
github.com/google/go-cmp v0.6.0
github.com/spf13/cobra v1.8.1
gopkg.in/yaml.v3 v3.0.1
)
require (

4
go.sum
View File

@@ -1,6 +1,8 @@
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw=
github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
@@ -8,5 +10,7 @@ github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -13,9 +13,18 @@
// limitations under the License.
package server
import (
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
)
type Config struct {
// address is the address of the interface the server will listen on.
// Address is the address of the interface the server will listen on.
Address string
// port is the port the server will listen on.
// Port is the port the server will listen on.
Port int
// SourceConfigs defines what sources of data are available for tools.
SourceConfigs sources.Configs
// ToolConfigs defines what tools are available.
ToolConfigs tools.Configs
}

View File

@@ -0,0 +1,32 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sources
const CloudSQLPgKind string = "cloud-sql-postgres"
// validate interface
var _ Config = CloudSQLPgConfig{}
type CloudSQLPgConfig struct {
Kind string `yaml:"kind"`
Project string `yaml:"project"`
Region string `yaml:"region"`
Instance string `yaml:"instance"`
Database string `yaml:"database"`
}
func (r CloudSQLPgConfig) sourceKind() string {
return CloudSQLPgKind
}

View File

@@ -0,0 +1,70 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sources_test
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/testutils"
"gopkg.in/yaml.v3"
)
func TestParseFromYaml(t *testing.T) {
tcs := []struct {
desc string
in string
want sources.Configs
}{
{
desc: "basic example",
in: `
sources:
my-pg-instance:
kind: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
`,
want: sources.Configs{
"my-pg-instance": sources.CloudSQLPgConfig{
Kind: sources.CloudSQLPgKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
Database: "my_db",
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources sources.Configs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
}

View File

@@ -0,0 +1,62 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sources
import (
"fmt"
"gopkg.in/yaml.v3"
)
type Config interface {
sourceKind() string
}
// validate interface
var _ yaml.Unmarshaler = &Configs{}
// Configs is a type used to allow unmarshal of the data source config map
type Configs map[string]Config
func (c *Configs) UnmarshalYAML(node *yaml.Node) error {
*c = make(Configs)
// Parse the 'kind' fields for each source
var raw map[string]yaml.Node
if err := node.Decode(&raw); err != nil {
return err
}
for name, n := range raw {
var k struct {
Kind string `yaml:"kind"`
}
err := n.Decode(&k)
if err != nil {
return fmt.Errorf("missing 'kind' field for %q", k)
}
switch k.Kind {
case CloudSQLPgKind:
actual := CloudSQLPgConfig{}
if err := n.Decode(&actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
}
(*c)[name] = actual
default:
return fmt.Errorf("%q is not a valid kind of data source", k.Kind)
}
}
return nil
}

View File

@@ -0,0 +1,28 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package testutils
import (
"strings"
)
// formatYaml is a utility function for stripping out tabs in multiline strings
func FormatYaml(in string) []byte {
// removes any leading indentation(tabs)
in = strings.ReplaceAll(in, "\n\t", "\n ")
// converts remaining indentation
in = strings.ReplaceAll(in, "\t", " ")
return []byte(in)
}

View File

@@ -0,0 +1,32 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tools
const CloudSQLPgSQLGenericKind string = "cloud-sql-postgres-generic"
// validate interface
var _ Config = CloudSQLPgGenericConfig{}
type CloudSQLPgGenericConfig struct {
Kind string `yaml:"kind"`
Source string `yaml:"source"`
Description string `yaml:"description"`
Statement string `yaml:"statement"`
Parameters map[string]Parameter `yaml:"parameters"`
}
func (r CloudSQLPgGenericConfig) toolKind() string {
return CloudSQLPgSQLGenericKind
}

View File

@@ -0,0 +1,79 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tools_test
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools"
"gopkg.in/yaml.v3"
)
func TestParseFromYaml(t *testing.T) {
tcs := []struct {
desc string
in string
want tools.Configs
}{
{
desc: "basic example",
in: `
tools:
example_tool:
kind: cloud-sql-postgres-generic
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
country:
type: string
description: some description
`,
want: tools.Configs{
"example_tool": tools.CloudSQLPgGenericConfig{
Kind: tools.CloudSQLPgSQLGenericKind,
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
Parameters: map[string]tools.Parameter{
"country": {
Type: "string",
Description: "some description",
},
},
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Tools tools.Configs `yaml:"tools"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
t.Fatalf("incorrect parse: diff %v", diff)
}
})
}
}

69
internal/tools/tools.go Normal file
View File

@@ -0,0 +1,69 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tools
import (
"fmt"
"gopkg.in/yaml.v3"
)
// SourceConfigs is a type used to allow unmarshal of the data source config map
type Configs map[string]Config
type Config interface {
toolKind() string
}
// validate interface
var _ yaml.Unmarshaler = &Configs{}
func (c *Configs) UnmarshalYAML(node *yaml.Node) error {
*c = make(Configs)
// Parse the 'kind' fields for each source
var raw map[string]yaml.Node
if err := node.Decode(&raw); err != nil {
return err
}
for name, n := range raw {
var k struct {
Kind string `yaml:"kind"`
}
err := n.Decode(&k)
if err != nil {
return fmt.Errorf("missing 'kind' field for %q", k)
}
switch k.Kind {
case CloudSQLPgSQLGenericKind:
actual := CloudSQLPgGenericConfig{}
if err := n.Decode(&actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
}
(*c)[name] = actual
default:
return fmt.Errorf("%q is not a valid kind of tool", k.Kind)
}
}
return nil
}
type Parameter struct {
Name string `yaml:"name"`
Type string `yaml:"type"`
Description string `yaml:"description"`
Required bool `yaml:"required"`
}