mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 15:38:08 -05:00
feat: add basic parsing from tools file (#8)
This commit is contained in:
46
cmd/root.go
46
cmd/root.go
@@ -22,7 +22,10 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/spf13/cobra"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -59,35 +62,62 @@ func Execute() {
|
||||
type Command struct {
|
||||
*cobra.Command
|
||||
|
||||
cfg server.Config
|
||||
cfg server.Config
|
||||
tools_file string
|
||||
}
|
||||
|
||||
// NewCommand returns a Command object representing an invocation of the CLI.
|
||||
func NewCommand() *Command {
|
||||
c := &Command{
|
||||
cmd := &Command{
|
||||
Command: &cobra.Command{
|
||||
Use: "toolbox",
|
||||
Version: versionString,
|
||||
},
|
||||
}
|
||||
|
||||
flags := c.Flags()
|
||||
flags.StringVarP(&c.cfg.Address, "address", "a", "127.0.0.1", "Address of the interface the server will listen on.")
|
||||
flags.IntVarP(&c.cfg.Port, "port", "p", 5000, "Port the server will listen on.")
|
||||
flags := cmd.Flags()
|
||||
flags.StringVarP(&cmd.cfg.Address, "address", "a", "127.0.0.1", "Address of the interface the server will listen on.")
|
||||
flags.IntVarP(&cmd.cfg.Port, "port", "p", 5000, "Port the server will listen on.")
|
||||
|
||||
flags.StringVar(&cmd.tools_file, "tools_file", "tools.yaml", "File path specifying the tool configuration")
|
||||
|
||||
// wrap RunE command so that we have access to original Command object
|
||||
c.RunE = func(*cobra.Command, []string) error { return run(c) }
|
||||
cmd.RunE = func(*cobra.Command, []string) error { return run(cmd) }
|
||||
|
||||
return c
|
||||
return cmd
|
||||
}
|
||||
|
||||
// parseToolsFile parses the provided yaml into appropriate configs.
|
||||
func parseToolsFile(raw []byte) (sources.Configs, tools.Configs, error) {
|
||||
tools_file := &struct {
|
||||
Sources sources.Configs `yaml:"sources"`
|
||||
Tools tools.Configs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(raw, tools_file)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return tools_file.Sources, tools_file.Tools, nil
|
||||
}
|
||||
|
||||
func run(cmd *Command) error {
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
|
||||
// Read tool file contents
|
||||
buf, err := os.ReadFile(cmd.tools_file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to read tool file at %q: %w", cmd.tools_file, err)
|
||||
}
|
||||
cmd.cfg.SourceConfigs, cmd.cfg.ToolConfigs, err = parseToolsFile(buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to parse tool file at %q: %w", cmd.tools_file, err)
|
||||
}
|
||||
|
||||
// run server
|
||||
s := server.NewServer(cmd.cfg)
|
||||
err := s.ListenAndServe(ctx)
|
||||
err = s.ListenAndServe(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Toolbox crashed with the following error: %w", err)
|
||||
}
|
||||
|
||||
114
cmd/root_test.go
114
cmd/root_test.go
@@ -21,7 +21,11 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@@ -65,7 +69,7 @@ func TestVersion(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlags(t *testing.T) {
|
||||
func TestAddrPort(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
args []string
|
||||
@@ -119,9 +123,115 @@ func TestFlags(t *testing.T) {
|
||||
t.Fatalf("unexpected error invoking command: %s", err)
|
||||
}
|
||||
|
||||
if c.cfg != tc.want {
|
||||
if !cmp.Equal(c.cfg, tc.want) {
|
||||
t.Fatalf("got %v, want %v", c.cfg, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolFileFlag(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
args []string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
desc: "default value",
|
||||
args: []string{},
|
||||
want: "tools.yaml",
|
||||
},
|
||||
{
|
||||
desc: "foo file",
|
||||
args: []string{"--tools_file", "foo.yaml"},
|
||||
want: "foo.yaml",
|
||||
},
|
||||
{
|
||||
desc: "address long",
|
||||
args: []string{"--tools_file", "bar.yaml"},
|
||||
want: "bar.yaml",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
c, _, err := invokeCommand(tc.args)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error invoking command: %s", err)
|
||||
}
|
||||
if c.tools_file != tc.want {
|
||||
t.Fatalf("got %v, want %v", c.cfg, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolFile(t *testing.T) {
|
||||
tcs := []struct {
|
||||
description string
|
||||
in string
|
||||
wantSources sources.Configs
|
||||
wantTools tools.Configs
|
||||
}{
|
||||
{
|
||||
description: "basic example",
|
||||
in: `
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
tools:
|
||||
example_tool:
|
||||
kind: cloud-sql-postgres-generic
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
country:
|
||||
type: string
|
||||
description: some description
|
||||
`,
|
||||
wantSources: sources.Configs{
|
||||
"my-pg-instance": sources.CloudSQLPgConfig{
|
||||
Kind: sources.CloudSQLPgKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
Database: "my_db",
|
||||
},
|
||||
},
|
||||
wantTools: tools.Configs{
|
||||
"example_tool": tools.CloudSQLPgGenericConfig{
|
||||
Kind: tools.CloudSQLPgSQLGenericKind,
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
Parameters: map[string]tools.Parameter{
|
||||
"country": {
|
||||
Type: "string",
|
||||
Description: "some description",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.description, func(t *testing.T) {
|
||||
gotSources, gotTools, err := parseToolsFile(testutils.FormatYaml(tc.in))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse input: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.wantSources, gotSources); diff != "" {
|
||||
t.Fatalf("incorrect sources parse: diff %v", diff)
|
||||
}
|
||||
if diff := cmp.Diff(tc.wantTools, gotTools); diff != "" {
|
||||
t.Fatalf("incorrect tools parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
2
go.mod
2
go.mod
@@ -4,7 +4,9 @@ go 1.22.2
|
||||
|
||||
require (
|
||||
github.com/go-chi/chi/v5 v5.1.0
|
||||
github.com/google/go-cmp v0.6.0
|
||||
github.com/spf13/cobra v1.8.1
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
|
||||
4
go.sum
4
go.sum
@@ -1,6 +1,8 @@
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw=
|
||||
github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
@@ -8,5 +10,7 @@ github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
|
||||
github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y=
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@@ -13,9 +13,18 @@
|
||||
// limitations under the License.
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
// address is the address of the interface the server will listen on.
|
||||
// Address is the address of the interface the server will listen on.
|
||||
Address string
|
||||
// port is the port the server will listen on.
|
||||
// Port is the port the server will listen on.
|
||||
Port int
|
||||
// SourceConfigs defines what sources of data are available for tools.
|
||||
SourceConfigs sources.Configs
|
||||
// ToolConfigs defines what tools are available.
|
||||
ToolConfigs tools.Configs
|
||||
}
|
||||
|
||||
32
internal/sources/cloud_sql_pg.go
Normal file
32
internal/sources/cloud_sql_pg.go
Normal file
@@ -0,0 +1,32 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sources
|
||||
|
||||
const CloudSQLPgKind string = "cloud-sql-postgres"
|
||||
|
||||
// validate interface
|
||||
var _ Config = CloudSQLPgConfig{}
|
||||
|
||||
type CloudSQLPgConfig struct {
|
||||
Kind string `yaml:"kind"`
|
||||
Project string `yaml:"project"`
|
||||
Region string `yaml:"region"`
|
||||
Instance string `yaml:"instance"`
|
||||
Database string `yaml:"database"`
|
||||
}
|
||||
|
||||
func (r CloudSQLPgConfig) sourceKind() string {
|
||||
return CloudSQLPgKind
|
||||
}
|
||||
70
internal/sources/cloud_sql_pg_test.go
Normal file
70
internal/sources/cloud_sql_pg_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sources_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestParseFromYaml(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want sources.Configs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
`,
|
||||
want: sources.Configs{
|
||||
"my-pg-instance": sources.CloudSQLPgConfig{
|
||||
Kind: sources.CloudSQLPgKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
Database: "my_db",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources sources.Configs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
62
internal/sources/sources.go
Normal file
62
internal/sources/sources.go
Normal file
@@ -0,0 +1,62 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sources
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type Config interface {
|
||||
sourceKind() string
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ yaml.Unmarshaler = &Configs{}
|
||||
|
||||
// Configs is a type used to allow unmarshal of the data source config map
|
||||
type Configs map[string]Config
|
||||
|
||||
func (c *Configs) UnmarshalYAML(node *yaml.Node) error {
|
||||
*c = make(Configs)
|
||||
// Parse the 'kind' fields for each source
|
||||
var raw map[string]yaml.Node
|
||||
if err := node.Decode(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name, n := range raw {
|
||||
var k struct {
|
||||
Kind string `yaml:"kind"`
|
||||
}
|
||||
err := n.Decode(&k)
|
||||
if err != nil {
|
||||
return fmt.Errorf("missing 'kind' field for %q", k)
|
||||
}
|
||||
switch k.Kind {
|
||||
case CloudSQLPgKind:
|
||||
actual := CloudSQLPgConfig{}
|
||||
if err := n.Decode(&actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
default:
|
||||
return fmt.Errorf("%q is not a valid kind of data source", k.Kind)
|
||||
}
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
28
internal/testutils/testutils.go
Normal file
28
internal/testutils/testutils.go
Normal file
@@ -0,0 +1,28 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package testutils
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// formatYaml is a utility function for stripping out tabs in multiline strings
|
||||
func FormatYaml(in string) []byte {
|
||||
// removes any leading indentation(tabs)
|
||||
in = strings.ReplaceAll(in, "\n\t", "\n ")
|
||||
// converts remaining indentation
|
||||
in = strings.ReplaceAll(in, "\t", " ")
|
||||
return []byte(in)
|
||||
}
|
||||
32
internal/tools/cloud_sql_pg.go
Normal file
32
internal/tools/cloud_sql_pg.go
Normal file
@@ -0,0 +1,32 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tools
|
||||
|
||||
const CloudSQLPgSQLGenericKind string = "cloud-sql-postgres-generic"
|
||||
|
||||
// validate interface
|
||||
var _ Config = CloudSQLPgGenericConfig{}
|
||||
|
||||
type CloudSQLPgGenericConfig struct {
|
||||
Kind string `yaml:"kind"`
|
||||
Source string `yaml:"source"`
|
||||
Description string `yaml:"description"`
|
||||
Statement string `yaml:"statement"`
|
||||
Parameters map[string]Parameter `yaml:"parameters"`
|
||||
}
|
||||
|
||||
func (r CloudSQLPgGenericConfig) toolKind() string {
|
||||
return CloudSQLPgSQLGenericKind
|
||||
}
|
||||
79
internal/tools/cloud_sql_pg_test.go
Normal file
79
internal/tools/cloud_sql_pg_test.go
Normal file
@@ -0,0 +1,79 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tools_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestParseFromYaml(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want tools.Configs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: cloud-sql-postgres-generic
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
country:
|
||||
type: string
|
||||
description: some description
|
||||
`,
|
||||
want: tools.Configs{
|
||||
"example_tool": tools.CloudSQLPgGenericConfig{
|
||||
Kind: tools.CloudSQLPgSQLGenericKind,
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
Parameters: map[string]tools.Parameter{
|
||||
"country": {
|
||||
Type: "string",
|
||||
Description: "some description",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools tools.Configs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
69
internal/tools/tools.go
Normal file
69
internal/tools/tools.go
Normal file
@@ -0,0 +1,69 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// SourceConfigs is a type used to allow unmarshal of the data source config map
|
||||
type Configs map[string]Config
|
||||
|
||||
type Config interface {
|
||||
toolKind() string
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ yaml.Unmarshaler = &Configs{}
|
||||
|
||||
func (c *Configs) UnmarshalYAML(node *yaml.Node) error {
|
||||
*c = make(Configs)
|
||||
// Parse the 'kind' fields for each source
|
||||
var raw map[string]yaml.Node
|
||||
if err := node.Decode(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name, n := range raw {
|
||||
var k struct {
|
||||
Kind string `yaml:"kind"`
|
||||
}
|
||||
err := n.Decode(&k)
|
||||
if err != nil {
|
||||
return fmt.Errorf("missing 'kind' field for %q", k)
|
||||
}
|
||||
switch k.Kind {
|
||||
case CloudSQLPgSQLGenericKind:
|
||||
actual := CloudSQLPgGenericConfig{}
|
||||
if err := n.Decode(&actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
default:
|
||||
return fmt.Errorf("%q is not a valid kind of tool", k.Kind)
|
||||
}
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Parameter struct {
|
||||
Name string `yaml:"name"`
|
||||
Type string `yaml:"type"`
|
||||
Description string `yaml:"description"`
|
||||
Required bool `yaml:"required"`
|
||||
}
|
||||
Reference in New Issue
Block a user