Compare commits

...

1 Commits

Author SHA1 Message Date
Yuan Teoh
fea5fed265 chore: add preprocessing layer to convert tools file 2025-09-05 17:40:07 -07:00
2 changed files with 267 additions and 0 deletions

View File

@@ -15,6 +15,7 @@
package cmd
import (
"bytes"
"context"
_ "embed"
"fmt"
@@ -293,6 +294,76 @@ func parseEnv(input string) (string, error) {
return output, err
}
func convertToolsFile(ctx context.Context, raw []byte) ([]byte, error) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
panic(err)
}
keysToCheck := []string{"sources", "authServices", "authSources", "tools", "toolsets"}
var input map[string]any
if err := yaml.Unmarshal(raw, &input); err != nil {
return nil, fmt.Errorf("error unmarshaling tools file: %s", err)
}
// convert to tools file v2
var toolsFileV1 bool
var buf bytes.Buffer
for _, kind := range keysToCheck {
resource, ok := input[kind]
if !ok {
// if this is skipped for all keys, the tools file is in v2
continue
}
toolsFileV1 = true
// convert `authSources` to `authServices`
if kind == "authSources" {
logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` instead")
kind = "authServices"
}
r, ok := resource.(map[string]any)
if !ok {
return nil, fmt.Errorf("'%s' is not a map", kind)
}
for name, d := range r {
buf.WriteString(fmt.Sprintf("kind: %s\n", kind))
buf.WriteString(fmt.Sprintf("name: %s\n", name))
if kind == "toolsets" {
b, err := yaml.Marshal(d)
if err != nil {
return nil, fmt.Errorf("error marshaling: %s", err)
}
buf.WriteString(fmt.Sprintf("tools:\n%s", b))
} else {
fields, ok := d.(map[string]any)
if !ok {
return nil, fmt.Errorf("fields for %s is not a map", name)
}
// Copy all existing fields, renaming 'kind' to 'type'.
buf.WriteString(fmt.Sprintf("type: %s\n", fields["kind"]))
for key, value := range fields {
if key == "kind" {
continue
}
o := map[string]any{key: value}
b, err := yaml.Marshal(o)
if err != nil {
return nil, fmt.Errorf("error marshaling: %s", err)
}
buf.Write(b)
}
}
buf.WriteString("---\n")
}
}
if !toolsFileV1 {
return raw, nil
}
return buf.Bytes(), nil
}
// parseToolsFile parses the provided yaml into appropriate configs.
func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) {
var toolsFile ToolsFile

View File

@@ -23,6 +23,7 @@ import (
"os"
"path"
"path/filepath"
"reflect"
"regexp"
"runtime"
"strings"
@@ -31,6 +32,7 @@ import (
"github.com/google/go-cmp/cmp"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/auth/google"
"github.com/googleapis/genai-toolbox/internal/log"
"github.com/googleapis/genai-toolbox/internal/prebuiltconfigs"
@@ -455,6 +457,200 @@ func TestDefaultLogLevel(t *testing.T) {
}
}
func TestConvertToolsFile(t *testing.T) {
ctx, cancelCtx := context.WithTimeout(context.Background(), time.Minute)
defer cancelCtx()
pr, pw := io.Pipe()
defer pw.Close()
defer pr.Close()
logger, err := log.NewStdLogger(pw, pw, "DEBUG")
if err != nil {
t.Fatalf("failed to setup logger %s", err)
}
ctx = util.WithLogger(ctx, logger)
tcs := []struct {
desc string
in string
want string
isErr bool
errStr string
}{
{
desc: "basic convert",
in: `
sources:
my-pg-instance:
kind: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
tools:
example_tool:
kind: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: country
type: string
description: some description
toolsets:
example_toolset:
- example_tool`,
want: `
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
---
kind: tools
name: example_tool
type: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: country
type: string
description: some description
---
kind: toolsets
name: example_toolset
tools:
- example_tool`,
},
{
desc: "no convertion needed",
in: `
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
---
kind: tools
name: example_tool
type: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: country
type: string
description: some description
---
kind: toolsets
name: example_toolset
tools:
- example_tool`,
want: `
kind: sources
name: my-pg-instance
type: cloud-sql-postgres
project: my-project
region: my-region
instance: my-instance
database: my_db
user: my_user
password: my_pass
---
kind: tools
name: example_tool
type: postgres-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: country
type: string
description: some description
---
kind: toolsets
name: example_toolset
tools:
- example_tool`,
},
{
desc: "invalid source",
in: `sources: invalid`,
isErr: true,
errStr: "'sources' is not a map",
},
{
desc: "invalid toolset",
in: `toolsets: invalid`,
isErr: true,
errStr: "'toolsets' is not a map",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
output, err := convertToolsFile(ctx, []byte(tc.in))
if tc.isErr {
if err == nil {
t.Fatalf("missing error: %s", tc.errStr)
}
if err.Error() != tc.errStr {
t.Fatalf("invalid error string: got %s, want %s", err, tc.errStr)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
outputObj, err := unmarshalYAMLUtil(output)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
wantObj, err := unmarshalYAMLUtil([]byte(tc.want))
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
if !reflect.DeepEqual(outputObj, wantObj) {
t.Fatalf("incorrect output: got %s, want %s", outputObj, wantObj)
}
})
}
}
// unmarshalYAMLUtil decodes a multi-document YAML string into a slice of generic maps.
func unmarshalYAMLUtil(yamlStr []byte) ([]map[string]any, error) {
decoder := yaml.NewDecoder(bytes.NewReader(yamlStr))
var docs []map[string]any
for {
var doc map[string]any
err := decoder.Decode(&doc)
if err != nil {
if err.Error() == "EOF" {
break
}
return nil, err
}
docs = append(docs, doc)
}
return docs, nil
}
func TestParseToolFile(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {