mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-13 09:28:12 -05:00
Compare commits
1 Commits
config-sou
...
config-pre
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fea5fed265 |
71
cmd/root.go
71
cmd/root.go
@@ -15,6 +15,7 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
@@ -293,6 +294,76 @@ func parseEnv(input string) (string, error) {
|
||||
return output, err
|
||||
}
|
||||
|
||||
func convertToolsFile(ctx context.Context, raw []byte) ([]byte, error) {
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
keysToCheck := []string{"sources", "authServices", "authSources", "tools", "toolsets"}
|
||||
var input map[string]any
|
||||
if err := yaml.Unmarshal(raw, &input); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling tools file: %s", err)
|
||||
}
|
||||
|
||||
// convert to tools file v2
|
||||
var toolsFileV1 bool
|
||||
var buf bytes.Buffer
|
||||
for _, kind := range keysToCheck {
|
||||
resource, ok := input[kind]
|
||||
if !ok {
|
||||
// if this is skipped for all keys, the tools file is in v2
|
||||
continue
|
||||
}
|
||||
toolsFileV1 = true
|
||||
// convert `authSources` to `authServices`
|
||||
if kind == "authSources" {
|
||||
logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` instead")
|
||||
kind = "authServices"
|
||||
}
|
||||
r, ok := resource.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("'%s' is not a map", kind)
|
||||
}
|
||||
|
||||
for name, d := range r {
|
||||
buf.WriteString(fmt.Sprintf("kind: %s\n", kind))
|
||||
buf.WriteString(fmt.Sprintf("name: %s\n", name))
|
||||
|
||||
if kind == "toolsets" {
|
||||
b, err := yaml.Marshal(d)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshaling: %s", err)
|
||||
}
|
||||
buf.WriteString(fmt.Sprintf("tools:\n%s", b))
|
||||
} else {
|
||||
fields, ok := d.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("fields for %s is not a map", name)
|
||||
}
|
||||
// Copy all existing fields, renaming 'kind' to 'type'.
|
||||
buf.WriteString(fmt.Sprintf("type: %s\n", fields["kind"]))
|
||||
for key, value := range fields {
|
||||
if key == "kind" {
|
||||
continue
|
||||
}
|
||||
o := map[string]any{key: value}
|
||||
b, err := yaml.Marshal(o)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshaling: %s", err)
|
||||
}
|
||||
buf.Write(b)
|
||||
}
|
||||
}
|
||||
buf.WriteString("---\n")
|
||||
}
|
||||
}
|
||||
if !toolsFileV1 {
|
||||
return raw, nil
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// parseToolsFile parses the provided yaml into appropriate configs.
|
||||
func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) {
|
||||
var toolsFile ToolsFile
|
||||
|
||||
196
cmd/root_test.go
196
cmd/root_test.go
@@ -23,6 +23,7 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
@@ -31,6 +32,7 @@ import (
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/auth/google"
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"github.com/googleapis/genai-toolbox/internal/prebuiltconfigs"
|
||||
@@ -455,6 +457,200 @@ func TestDefaultLogLevel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertToolsFile(t *testing.T) {
|
||||
ctx, cancelCtx := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancelCtx()
|
||||
pr, pw := io.Pipe()
|
||||
defer pw.Close()
|
||||
defer pr.Close()
|
||||
|
||||
logger, err := log.NewStdLogger(pw, pw, "DEBUG")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to setup logger %s", err)
|
||||
}
|
||||
ctx = util.WithLogger(ctx, logger)
|
||||
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want string
|
||||
isErr bool
|
||||
errStr string
|
||||
}{
|
||||
{
|
||||
desc: "basic convert",
|
||||
in: `
|
||||
sources:
|
||||
my-pg-instance:
|
||||
kind: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
tools:
|
||||
example_tool:
|
||||
kind: postgres-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
toolsets:
|
||||
example_toolset:
|
||||
- example_tool`,
|
||||
want: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
---
|
||||
kind: tools
|
||||
name: example_tool
|
||||
type: postgres-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
---
|
||||
kind: toolsets
|
||||
name: example_toolset
|
||||
tools:
|
||||
- example_tool`,
|
||||
},
|
||||
{
|
||||
desc: "no convertion needed",
|
||||
in: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
---
|
||||
kind: tools
|
||||
name: example_tool
|
||||
type: postgres-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
---
|
||||
kind: toolsets
|
||||
name: example_toolset
|
||||
tools:
|
||||
- example_tool`,
|
||||
want: `
|
||||
kind: sources
|
||||
name: my-pg-instance
|
||||
type: cloud-sql-postgres
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
---
|
||||
kind: tools
|
||||
name: example_tool
|
||||
type: postgres-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
---
|
||||
kind: toolsets
|
||||
name: example_toolset
|
||||
tools:
|
||||
- example_tool`,
|
||||
},
|
||||
{
|
||||
desc: "invalid source",
|
||||
in: `sources: invalid`,
|
||||
isErr: true,
|
||||
errStr: "'sources' is not a map",
|
||||
},
|
||||
{
|
||||
desc: "invalid toolset",
|
||||
in: `toolsets: invalid`,
|
||||
isErr: true,
|
||||
errStr: "'toolsets' is not a map",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
output, err := convertToolsFile(ctx, []byte(tc.in))
|
||||
if tc.isErr {
|
||||
if err == nil {
|
||||
t.Fatalf("missing error: %s", tc.errStr)
|
||||
}
|
||||
if err.Error() != tc.errStr {
|
||||
t.Fatalf("invalid error string: got %s, want %s", err, tc.errStr)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
|
||||
outputObj, err := unmarshalYAMLUtil(output)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
|
||||
wantObj, err := unmarshalYAMLUtil([]byte(tc.want))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
if !reflect.DeepEqual(outputObj, wantObj) {
|
||||
t.Fatalf("incorrect output: got %s, want %s", outputObj, wantObj)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// unmarshalYAMLUtil decodes a multi-document YAML string into a slice of generic maps.
|
||||
func unmarshalYAMLUtil(yamlStr []byte) ([]map[string]any, error) {
|
||||
decoder := yaml.NewDecoder(bytes.NewReader(yamlStr))
|
||||
var docs []map[string]any
|
||||
for {
|
||||
var doc map[string]any
|
||||
err := decoder.Decode(&doc)
|
||||
if err != nil {
|
||||
if err.Error() == "EOF" {
|
||||
break
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
docs = append(docs, doc)
|
||||
}
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
func TestParseToolFile(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user