mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-02-11 23:55:07 -05:00
# Add Multiple YAML Configuration File Support ## 🎯 Overview This PR introduces support for loading and merging multiple YAML configuration files in Toolbox, addressing the need for modular configuration management in complex deployments. ## ✨ New Features ### 1. Multiple Files Support (`--tools-files`) - **Usage**: `--tools-files=file1.yaml,file2.yaml,file3.yaml` - Load and intelligently merge multiple YAML configuration files - Comma-separated file paths for maximum flexibility ### 2. Directory Support (`--tools-folder`) - **Usage**: `--tools-folder=config-directory` - Automatically discover and load all `.yaml` and `.yml` files from a directory - Simplifies configuration management for organized deployments ### 3. Smart Merging Logic - **Sources/AuthServices/Tools**: Later files override earlier files with same names - **Toolsets**: Tools from same-named toolsets are combined without duplicates - Preserves all existing functionality while enabling composition ## 🔒 Safety & Validation - **Mutual Exclusivity**: Prevents simultaneous use of `--tools-file`, `--tools-files`, `--tools-folder`, and `--prebuilt` - **Clear Error Messages**: Descriptive validation errors guide users to correct usage - **Comprehensive Error Handling**: Proper handling of missing files, directories, and parsing errors - **Full Backward Compatibility**: Existing configurations continue to work unchanged ## 🏗️ Implementation Details ### Core Functions Added - `mergeToolsFiles()` - Smart merging with configurable override rules - `loadAndMergeToolsFiles()` - Multi-file loading and processing - `loadAndMergeToolsFolder()` - Directory scanning and batch loading ### Command Structure Updates - New `tools_files []string` field for multiple file paths - New `tools_folder string` field for directory path - Enhanced validation logic in `run()` function - Updated flag definitions with proper descriptions ## 📋 Use Cases ### Organizational Benefits - **Modular Configuration**: Separate database, API, and auth configurations - **Team Collaboration**: Multiple developers can work on different config files - **Environment Management**: Easy configuration swapping for different environments - **Scalability**: Large configurations can be broken into manageable chunks ### Example Usage Patterns ```bash # Multiple specific files ./toolbox --tools-files=database.yaml,apis.yaml,auth.yaml # Directory-based loading ./toolbox --tools-folder=./production-configs # Error case (properly handled) ./toolbox --tools-file=single.yaml --tools-folder=configs # ERROR: --tools-file, --tools-files, and --tools-folder flags cannot be used simultaneously --------- Co-authored-by: Wenxin Du <117315983+duwenxin99@users.noreply.github.com>
1128 lines
29 KiB
Go
1128 lines
29 KiB
Go
// Copyright 2024 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package cmd
|
|
|
|
import (
|
|
"bytes"
|
|
_ "embed"
|
|
"os"
|
|
"runtime"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
|
|
"github.com/googleapis/genai-toolbox/internal/auth/google"
|
|
"github.com/googleapis/genai-toolbox/internal/prebuiltconfigs"
|
|
"github.com/googleapis/genai-toolbox/internal/server"
|
|
cloudsqlpgsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
|
|
httpsrc "github.com/googleapis/genai-toolbox/internal/sources/http"
|
|
"github.com/googleapis/genai-toolbox/internal/testutils"
|
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
|
"github.com/googleapis/genai-toolbox/internal/tools/http"
|
|
"github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql"
|
|
"github.com/spf13/cobra"
|
|
)
|
|
|
|
func withDefaults(c server.ServerConfig) server.ServerConfig {
|
|
data, _ := os.ReadFile("version.txt")
|
|
version := strings.TrimSpace(string(data)) // Preserving 'data', new var for clarity
|
|
c.Version = version + "+" + strings.Join([]string{"dev", runtime.GOOS, runtime.GOARCH}, ".")
|
|
|
|
if c.Address == "" {
|
|
c.Address = "127.0.0.1"
|
|
}
|
|
if c.Port == 0 {
|
|
c.Port = 5000
|
|
}
|
|
if c.TelemetryServiceName == "" {
|
|
c.TelemetryServiceName = "toolbox"
|
|
}
|
|
return c
|
|
}
|
|
|
|
func invokeCommand(args []string) (*Command, string, error) {
|
|
c := NewCommand()
|
|
|
|
// Keep the test output quiet
|
|
c.SilenceUsage = true
|
|
c.SilenceErrors = true
|
|
|
|
// Capture output
|
|
buf := new(bytes.Buffer)
|
|
c.SetOut(buf)
|
|
c.SetErr(buf)
|
|
c.SetArgs(args)
|
|
|
|
// Disable execute behavior
|
|
c.RunE = func(*cobra.Command, []string) error {
|
|
return nil
|
|
}
|
|
|
|
err := c.Execute()
|
|
|
|
return c, buf.String(), err
|
|
}
|
|
|
|
func TestVersion(t *testing.T) {
|
|
data, err := os.ReadFile("version.txt")
|
|
if err != nil {
|
|
t.Fatalf("failed to read version.txt: %v", err)
|
|
}
|
|
want := strings.TrimSpace(string(data))
|
|
|
|
_, got, err := invokeCommand([]string{"--version"})
|
|
if err != nil {
|
|
t.Fatalf("error invoking command: %s", err)
|
|
}
|
|
|
|
if !strings.Contains(got, want) {
|
|
t.Errorf("cli did not return correct version: want %q, got %q", want, got)
|
|
}
|
|
}
|
|
|
|
func TestServerConfigFlags(t *testing.T) {
|
|
tcs := []struct {
|
|
desc string
|
|
args []string
|
|
want server.ServerConfig
|
|
}{
|
|
{
|
|
desc: "default values",
|
|
args: []string{},
|
|
want: withDefaults(server.ServerConfig{}),
|
|
},
|
|
{
|
|
desc: "address short",
|
|
args: []string{"-a", "127.0.1.1"},
|
|
want: withDefaults(server.ServerConfig{
|
|
Address: "127.0.1.1",
|
|
}),
|
|
},
|
|
{
|
|
desc: "address long",
|
|
args: []string{"--address", "0.0.0.0"},
|
|
want: withDefaults(server.ServerConfig{
|
|
Address: "0.0.0.0",
|
|
}),
|
|
},
|
|
{
|
|
desc: "port short",
|
|
args: []string{"-p", "5052"},
|
|
want: withDefaults(server.ServerConfig{
|
|
Port: 5052,
|
|
}),
|
|
},
|
|
{
|
|
desc: "port long",
|
|
args: []string{"--port", "5050"},
|
|
want: withDefaults(server.ServerConfig{
|
|
Port: 5050,
|
|
}),
|
|
},
|
|
{
|
|
desc: "logging format",
|
|
args: []string{"--logging-format", "JSON"},
|
|
want: withDefaults(server.ServerConfig{
|
|
LoggingFormat: "JSON",
|
|
}),
|
|
},
|
|
{
|
|
desc: "debug logs",
|
|
args: []string{"--log-level", "WARN"},
|
|
want: withDefaults(server.ServerConfig{
|
|
LogLevel: "WARN",
|
|
}),
|
|
},
|
|
{
|
|
desc: "telemetry gcp",
|
|
args: []string{"--telemetry-gcp"},
|
|
want: withDefaults(server.ServerConfig{
|
|
TelemetryGCP: true,
|
|
}),
|
|
},
|
|
{
|
|
desc: "telemetry otlp",
|
|
args: []string{"--telemetry-otlp", "http://127.0.0.1:4553"},
|
|
want: withDefaults(server.ServerConfig{
|
|
TelemetryOTLP: "http://127.0.0.1:4553",
|
|
}),
|
|
},
|
|
{
|
|
desc: "telemetry service name",
|
|
args: []string{"--telemetry-service-name", "toolbox-custom"},
|
|
want: withDefaults(server.ServerConfig{
|
|
TelemetryServiceName: "toolbox-custom",
|
|
}),
|
|
},
|
|
{
|
|
desc: "stdio",
|
|
args: []string{"--stdio"},
|
|
want: withDefaults(server.ServerConfig{
|
|
Stdio: true,
|
|
}),
|
|
},
|
|
}
|
|
for _, tc := range tcs {
|
|
t.Run(tc.desc, func(t *testing.T) {
|
|
c, _, err := invokeCommand(tc.args)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error invoking command: %s", err)
|
|
}
|
|
|
|
if !cmp.Equal(c.cfg, tc.want) {
|
|
t.Fatalf("got %v, want %v", c.cfg, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestToolFileFlag(t *testing.T) {
|
|
tcs := []struct {
|
|
desc string
|
|
args []string
|
|
want string
|
|
}{
|
|
{
|
|
desc: "default value",
|
|
args: []string{},
|
|
want: "",
|
|
},
|
|
{
|
|
desc: "foo file",
|
|
args: []string{"--tools-file", "foo.yaml"},
|
|
want: "foo.yaml",
|
|
},
|
|
{
|
|
desc: "address long",
|
|
args: []string{"--tools-file", "bar.yaml"},
|
|
want: "bar.yaml",
|
|
},
|
|
{
|
|
desc: "deprecated flag",
|
|
args: []string{"--tools_file", "foo.yaml"},
|
|
want: "foo.yaml",
|
|
},
|
|
}
|
|
for _, tc := range tcs {
|
|
t.Run(tc.desc, func(t *testing.T) {
|
|
c, _, err := invokeCommand(tc.args)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error invoking command: %s", err)
|
|
}
|
|
if c.tools_file != tc.want {
|
|
t.Fatalf("got %v, want %v", c.cfg, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestToolsFilesFlag(t *testing.T) {
|
|
tcs := []struct {
|
|
desc string
|
|
args []string
|
|
want []string
|
|
}{
|
|
{
|
|
desc: "no value",
|
|
args: []string{},
|
|
want: []string{},
|
|
},
|
|
{
|
|
desc: "single file",
|
|
args: []string{"--tools-files", "foo.yaml"},
|
|
want: []string{"foo.yaml"},
|
|
},
|
|
{
|
|
desc: "multiple files",
|
|
args: []string{"--tools-files", "foo.yaml,bar.yaml"},
|
|
want: []string{"foo.yaml", "bar.yaml"},
|
|
},
|
|
}
|
|
for _, tc := range tcs {
|
|
t.Run(tc.desc, func(t *testing.T) {
|
|
c, _, err := invokeCommand(tc.args)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error invoking command: %s", err)
|
|
}
|
|
if diff := cmp.Diff(c.tools_files, tc.want); diff != "" {
|
|
t.Fatalf("got %v, want %v", c.tools_files, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestToolsFolderFlag(t *testing.T) {
|
|
tcs := []struct {
|
|
desc string
|
|
args []string
|
|
want string
|
|
}{
|
|
{
|
|
desc: "no value",
|
|
args: []string{},
|
|
want: "",
|
|
},
|
|
{
|
|
desc: "folder set",
|
|
args: []string{"--tools-folder", "test-folder"},
|
|
want: "test-folder",
|
|
},
|
|
}
|
|
for _, tc := range tcs {
|
|
t.Run(tc.desc, func(t *testing.T) {
|
|
c, _, err := invokeCommand(tc.args)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error invoking command: %s", err)
|
|
}
|
|
if c.tools_folder != tc.want {
|
|
t.Fatalf("got %v, want %v", c.tools_folder, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestPrebuiltFlag(t *testing.T) {
|
|
tcs := []struct {
|
|
desc string
|
|
args []string
|
|
want string
|
|
}{
|
|
{
|
|
desc: "default value",
|
|
args: []string{},
|
|
want: "",
|
|
},
|
|
{
|
|
desc: "custom pre built flag",
|
|
args: []string{"--tools-file", "alloydb"},
|
|
want: "alloydb",
|
|
},
|
|
}
|
|
for _, tc := range tcs {
|
|
t.Run(tc.desc, func(t *testing.T) {
|
|
c, _, err := invokeCommand(tc.args)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error invoking command: %s", err)
|
|
}
|
|
if c.tools_file != tc.want {
|
|
t.Fatalf("got %v, want %v", c.cfg, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFailServerConfigFlags(t *testing.T) {
|
|
tcs := []struct {
|
|
desc string
|
|
args []string
|
|
}{
|
|
{
|
|
desc: "logging format",
|
|
args: []string{"--logging-format", "fail"},
|
|
},
|
|
{
|
|
desc: "debug logs",
|
|
args: []string{"--log-level", "fail"},
|
|
},
|
|
}
|
|
for _, tc := range tcs {
|
|
t.Run(tc.desc, func(t *testing.T) {
|
|
_, _, err := invokeCommand(tc.args)
|
|
if err == nil {
|
|
t.Fatalf("expected an error, but got nil")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDefaultLoggingFormat(t *testing.T) {
|
|
c, _, err := invokeCommand([]string{})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error invoking command: %s", err)
|
|
}
|
|
got := c.cfg.LoggingFormat.String()
|
|
want := "standard"
|
|
if got != want {
|
|
t.Fatalf("unexpected default logging format flag: got %v, want %v", got, want)
|
|
}
|
|
}
|
|
|
|
func TestDefaultLogLevel(t *testing.T) {
|
|
c, _, err := invokeCommand([]string{})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error invoking command: %s", err)
|
|
}
|
|
got := c.cfg.LogLevel.String()
|
|
want := "info"
|
|
if got != want {
|
|
t.Fatalf("unexpected default log level flag: got %v, want %v", got, want)
|
|
}
|
|
}
|
|
|
|
func TestParseToolFile(t *testing.T) {
|
|
ctx, err := testutils.ContextWithNewLogger()
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %s", err)
|
|
}
|
|
tcs := []struct {
|
|
description string
|
|
in string
|
|
wantToolsFile ToolsFile
|
|
}{
|
|
{
|
|
description: "basic example",
|
|
in: `
|
|
sources:
|
|
my-pg-instance:
|
|
kind: cloud-sql-postgres
|
|
project: my-project
|
|
region: my-region
|
|
instance: my-instance
|
|
database: my_db
|
|
user: my_user
|
|
password: my_pass
|
|
tools:
|
|
example_tool:
|
|
kind: postgres-sql
|
|
source: my-pg-instance
|
|
description: some description
|
|
statement: |
|
|
SELECT * FROM SQL_STATEMENT;
|
|
parameters:
|
|
- name: country
|
|
type: string
|
|
description: some description
|
|
toolsets:
|
|
example_toolset:
|
|
- example_tool
|
|
`,
|
|
wantToolsFile: ToolsFile{
|
|
Sources: server.SourceConfigs{
|
|
"my-pg-instance": cloudsqlpgsrc.Config{
|
|
Name: "my-pg-instance",
|
|
Kind: cloudsqlpgsrc.SourceKind,
|
|
Project: "my-project",
|
|
Region: "my-region",
|
|
Instance: "my-instance",
|
|
IPType: "public",
|
|
Database: "my_db",
|
|
User: "my_user",
|
|
Password: "my_pass",
|
|
},
|
|
},
|
|
Tools: server.ToolConfigs{
|
|
"example_tool": postgressql.Config{
|
|
Name: "example_tool",
|
|
Kind: "postgres-sql",
|
|
Source: "my-pg-instance",
|
|
Description: "some description",
|
|
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
|
Parameters: []tools.Parameter{
|
|
tools.NewStringParameter("country", "some description"),
|
|
},
|
|
AuthRequired: []string{},
|
|
},
|
|
},
|
|
Toolsets: server.ToolsetConfigs{
|
|
"example_toolset": tools.ToolsetConfig{
|
|
Name: "example_toolset",
|
|
ToolNames: []string{"example_tool"},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
for _, tc := range tcs {
|
|
t.Run(tc.description, func(t *testing.T) {
|
|
toolsFile, err := parseToolsFile(ctx, testutils.FormatYaml(tc.in))
|
|
if err != nil {
|
|
t.Fatalf("failed to parse input: %v", err)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolsFile.Sources, toolsFile.Sources); diff != "" {
|
|
t.Fatalf("incorrect sources parse: diff %v", diff)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolsFile.AuthServices, toolsFile.AuthServices); diff != "" {
|
|
t.Fatalf("incorrect authServices parse: diff %v", diff)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolsFile.Tools, toolsFile.Tools); diff != "" {
|
|
t.Fatalf("incorrect tools parse: diff %v", diff)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolsFile.Toolsets, toolsFile.Toolsets); diff != "" {
|
|
t.Fatalf("incorrect tools parse: diff %v", diff)
|
|
}
|
|
})
|
|
}
|
|
|
|
}
|
|
|
|
func TestParseToolFileWithAuth(t *testing.T) {
|
|
ctx, err := testutils.ContextWithNewLogger()
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %s", err)
|
|
}
|
|
tcs := []struct {
|
|
description string
|
|
in string
|
|
wantToolsFile ToolsFile
|
|
}{
|
|
{
|
|
description: "basic example",
|
|
in: `
|
|
sources:
|
|
my-pg-instance:
|
|
kind: cloud-sql-postgres
|
|
project: my-project
|
|
region: my-region
|
|
instance: my-instance
|
|
database: my_db
|
|
user: my_user
|
|
password: my_pass
|
|
authServices:
|
|
my-google-service:
|
|
kind: google
|
|
clientId: my-client-id
|
|
other-google-service:
|
|
kind: google
|
|
clientId: other-client-id
|
|
|
|
tools:
|
|
example_tool:
|
|
kind: postgres-sql
|
|
source: my-pg-instance
|
|
description: some description
|
|
statement: |
|
|
SELECT * FROM SQL_STATEMENT;
|
|
parameters:
|
|
- name: country
|
|
type: string
|
|
description: some description
|
|
- name: id
|
|
type: integer
|
|
description: user id
|
|
authServices:
|
|
- name: my-google-service
|
|
field: user_id
|
|
- name: email
|
|
type: string
|
|
description: user email
|
|
authServices:
|
|
- name: my-google-service
|
|
field: email
|
|
- name: other-google-service
|
|
field: other_email
|
|
|
|
toolsets:
|
|
example_toolset:
|
|
- example_tool
|
|
`,
|
|
wantToolsFile: ToolsFile{
|
|
Sources: server.SourceConfigs{
|
|
"my-pg-instance": cloudsqlpgsrc.Config{
|
|
Name: "my-pg-instance",
|
|
Kind: cloudsqlpgsrc.SourceKind,
|
|
Project: "my-project",
|
|
Region: "my-region",
|
|
Instance: "my-instance",
|
|
IPType: "public",
|
|
Database: "my_db",
|
|
User: "my_user",
|
|
Password: "my_pass",
|
|
},
|
|
},
|
|
AuthServices: server.AuthServiceConfigs{
|
|
"my-google-service": google.Config{
|
|
Name: "my-google-service",
|
|
Kind: google.AuthServiceKind,
|
|
ClientID: "my-client-id",
|
|
},
|
|
"other-google-service": google.Config{
|
|
Name: "other-google-service",
|
|
Kind: google.AuthServiceKind,
|
|
ClientID: "other-client-id",
|
|
},
|
|
},
|
|
Tools: server.ToolConfigs{
|
|
"example_tool": postgressql.Config{
|
|
Name: "example_tool",
|
|
Kind: "postgres-sql",
|
|
Source: "my-pg-instance",
|
|
Description: "some description",
|
|
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
|
AuthRequired: []string{},
|
|
Parameters: []tools.Parameter{
|
|
tools.NewStringParameter("country", "some description"),
|
|
tools.NewIntParameterWithAuth("id", "user id", []tools.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}),
|
|
tools.NewStringParameterWithAuth("email", "user email", []tools.ParamAuthService{{Name: "my-google-service", Field: "email"}, {Name: "other-google-service", Field: "other_email"}}),
|
|
},
|
|
},
|
|
},
|
|
Toolsets: server.ToolsetConfigs{
|
|
"example_toolset": tools.ToolsetConfig{
|
|
Name: "example_toolset",
|
|
ToolNames: []string{"example_tool"},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
description: "basic example with authSources",
|
|
in: `
|
|
sources:
|
|
my-pg-instance:
|
|
kind: cloud-sql-postgres
|
|
project: my-project
|
|
region: my-region
|
|
instance: my-instance
|
|
database: my_db
|
|
user: my_user
|
|
password: my_pass
|
|
authSources:
|
|
my-google-service:
|
|
kind: google
|
|
clientId: my-client-id
|
|
other-google-service:
|
|
kind: google
|
|
clientId: other-client-id
|
|
|
|
tools:
|
|
example_tool:
|
|
kind: postgres-sql
|
|
source: my-pg-instance
|
|
description: some description
|
|
statement: |
|
|
SELECT * FROM SQL_STATEMENT;
|
|
parameters:
|
|
- name: country
|
|
type: string
|
|
description: some description
|
|
- name: id
|
|
type: integer
|
|
description: user id
|
|
authSources:
|
|
- name: my-google-service
|
|
field: user_id
|
|
- name: email
|
|
type: string
|
|
description: user email
|
|
authSources:
|
|
- name: my-google-service
|
|
field: email
|
|
- name: other-google-service
|
|
field: other_email
|
|
|
|
toolsets:
|
|
example_toolset:
|
|
- example_tool
|
|
`,
|
|
wantToolsFile: ToolsFile{
|
|
Sources: server.SourceConfigs{
|
|
"my-pg-instance": cloudsqlpgsrc.Config{
|
|
Name: "my-pg-instance",
|
|
Kind: cloudsqlpgsrc.SourceKind,
|
|
Project: "my-project",
|
|
Region: "my-region",
|
|
Instance: "my-instance",
|
|
IPType: "public",
|
|
Database: "my_db",
|
|
User: "my_user",
|
|
Password: "my_pass",
|
|
},
|
|
},
|
|
AuthSources: server.AuthServiceConfigs{
|
|
"my-google-service": google.Config{
|
|
Name: "my-google-service",
|
|
Kind: google.AuthServiceKind,
|
|
ClientID: "my-client-id",
|
|
},
|
|
"other-google-service": google.Config{
|
|
Name: "other-google-service",
|
|
Kind: google.AuthServiceKind,
|
|
ClientID: "other-client-id",
|
|
},
|
|
},
|
|
Tools: server.ToolConfigs{
|
|
"example_tool": postgressql.Config{
|
|
Name: "example_tool",
|
|
Kind: "postgres-sql",
|
|
Source: "my-pg-instance",
|
|
Description: "some description",
|
|
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
|
AuthRequired: []string{},
|
|
Parameters: []tools.Parameter{
|
|
tools.NewStringParameter("country", "some description"),
|
|
tools.NewIntParameterWithAuth("id", "user id", []tools.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}),
|
|
tools.NewStringParameterWithAuth("email", "user email", []tools.ParamAuthService{{Name: "my-google-service", Field: "email"}, {Name: "other-google-service", Field: "other_email"}}),
|
|
},
|
|
},
|
|
},
|
|
Toolsets: server.ToolsetConfigs{
|
|
"example_toolset": tools.ToolsetConfig{
|
|
Name: "example_toolset",
|
|
ToolNames: []string{"example_tool"},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
description: "basic example with authRequired",
|
|
in: `
|
|
sources:
|
|
my-pg-instance:
|
|
kind: cloud-sql-postgres
|
|
project: my-project
|
|
region: my-region
|
|
instance: my-instance
|
|
database: my_db
|
|
user: my_user
|
|
password: my_pass
|
|
authServices:
|
|
my-google-service:
|
|
kind: google
|
|
clientId: my-client-id
|
|
other-google-service:
|
|
kind: google
|
|
clientId: other-client-id
|
|
|
|
tools:
|
|
example_tool:
|
|
kind: postgres-sql
|
|
source: my-pg-instance
|
|
description: some description
|
|
statement: |
|
|
SELECT * FROM SQL_STATEMENT;
|
|
authRequired:
|
|
- my-google-service
|
|
parameters:
|
|
- name: country
|
|
type: string
|
|
description: some description
|
|
- name: id
|
|
type: integer
|
|
description: user id
|
|
authServices:
|
|
- name: my-google-service
|
|
field: user_id
|
|
- name: email
|
|
type: string
|
|
description: user email
|
|
authServices:
|
|
- name: my-google-service
|
|
field: email
|
|
- name: other-google-service
|
|
field: other_email
|
|
|
|
toolsets:
|
|
example_toolset:
|
|
- example_tool
|
|
`,
|
|
wantToolsFile: ToolsFile{
|
|
Sources: server.SourceConfigs{
|
|
"my-pg-instance": cloudsqlpgsrc.Config{
|
|
Name: "my-pg-instance",
|
|
Kind: cloudsqlpgsrc.SourceKind,
|
|
Project: "my-project",
|
|
Region: "my-region",
|
|
Instance: "my-instance",
|
|
IPType: "public",
|
|
Database: "my_db",
|
|
User: "my_user",
|
|
Password: "my_pass",
|
|
},
|
|
},
|
|
AuthServices: server.AuthServiceConfigs{
|
|
"my-google-service": google.Config{
|
|
Name: "my-google-service",
|
|
Kind: google.AuthServiceKind,
|
|
ClientID: "my-client-id",
|
|
},
|
|
"other-google-service": google.Config{
|
|
Name: "other-google-service",
|
|
Kind: google.AuthServiceKind,
|
|
ClientID: "other-client-id",
|
|
},
|
|
},
|
|
Tools: server.ToolConfigs{
|
|
"example_tool": postgressql.Config{
|
|
Name: "example_tool",
|
|
Kind: "postgres-sql",
|
|
Source: "my-pg-instance",
|
|
Description: "some description",
|
|
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
|
AuthRequired: []string{"my-google-service"},
|
|
Parameters: []tools.Parameter{
|
|
tools.NewStringParameter("country", "some description"),
|
|
tools.NewIntParameterWithAuth("id", "user id", []tools.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}),
|
|
tools.NewStringParameterWithAuth("email", "user email", []tools.ParamAuthService{{Name: "my-google-service", Field: "email"}, {Name: "other-google-service", Field: "other_email"}}),
|
|
},
|
|
},
|
|
},
|
|
Toolsets: server.ToolsetConfigs{
|
|
"example_toolset": tools.ToolsetConfig{
|
|
Name: "example_toolset",
|
|
ToolNames: []string{"example_tool"},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
for _, tc := range tcs {
|
|
t.Run(tc.description, func(t *testing.T) {
|
|
toolsFile, err := parseToolsFile(ctx, testutils.FormatYaml(tc.in))
|
|
if err != nil {
|
|
t.Fatalf("failed to parse input: %v", err)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolsFile.Sources, toolsFile.Sources); diff != "" {
|
|
t.Fatalf("incorrect sources parse: diff %v", diff)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolsFile.AuthServices, toolsFile.AuthServices); diff != "" {
|
|
t.Fatalf("incorrect authServices parse: diff %v", diff)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolsFile.Tools, toolsFile.Tools); diff != "" {
|
|
t.Fatalf("incorrect tools parse: diff %v", diff)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolsFile.Toolsets, toolsFile.Toolsets); diff != "" {
|
|
t.Fatalf("incorrect tools parse: diff %v", diff)
|
|
}
|
|
})
|
|
}
|
|
|
|
}
|
|
|
|
func TestEnvVarReplacement(t *testing.T) {
|
|
ctx, err := testutils.ContextWithNewLogger()
|
|
os.Setenv("TestHeader", "ACTUAL_HEADER")
|
|
os.Setenv("API_KEY", "ACTUAL_API_KEY")
|
|
os.Setenv("clientId", "ACTUAL_CLIENT_ID")
|
|
os.Setenv("clientId2", "ACTUAL_CLIENT_ID_2")
|
|
os.Setenv("toolset_name", "ACTUAL_TOOLSET_NAME")
|
|
os.Setenv("cat_string", "cat")
|
|
os.Setenv("food_string", "food")
|
|
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %s", err)
|
|
}
|
|
tcs := []struct {
|
|
description string
|
|
in string
|
|
wantToolsFile ToolsFile
|
|
}{
|
|
{
|
|
description: "file with env var example",
|
|
in: `
|
|
sources:
|
|
my-http-instance:
|
|
kind: http
|
|
baseUrl: http://test_server/
|
|
timeout: 10s
|
|
headers:
|
|
Authorization: ${TestHeader}
|
|
queryParams:
|
|
api-key: ${API_KEY}
|
|
authServices:
|
|
my-google-service:
|
|
kind: google
|
|
clientId: ${clientId}
|
|
other-google-service:
|
|
kind: google
|
|
clientId: ${clientId2}
|
|
|
|
tools:
|
|
example_tool:
|
|
kind: http
|
|
source: my-instance
|
|
method: GET
|
|
path: "search?name=alice&pet=${cat_string}"
|
|
description: some description
|
|
authRequired:
|
|
- my-google-auth-service
|
|
- other-auth-service
|
|
queryParams:
|
|
- name: country
|
|
type: string
|
|
description: some description
|
|
authServices:
|
|
- name: my-google-auth-service
|
|
field: user_id
|
|
- name: other-auth-service
|
|
field: user_id
|
|
requestBody: |
|
|
{
|
|
"age": {{.age}},
|
|
"city": "{{.city}}",
|
|
"food": "${food_string}",
|
|
"other": "$OTHER"
|
|
}
|
|
bodyParams:
|
|
- name: age
|
|
type: integer
|
|
description: age num
|
|
- name: city
|
|
type: string
|
|
description: city string
|
|
headers:
|
|
Authorization: API_KEY
|
|
Content-Type: application/json
|
|
headerParams:
|
|
- name: Language
|
|
type: string
|
|
description: language string
|
|
|
|
toolsets:
|
|
${toolset_name}:
|
|
- example_tool
|
|
`,
|
|
wantToolsFile: ToolsFile{
|
|
Sources: server.SourceConfigs{
|
|
"my-http-instance": httpsrc.Config{
|
|
Name: "my-http-instance",
|
|
Kind: httpsrc.SourceKind,
|
|
BaseURL: "http://test_server/",
|
|
Timeout: "10s",
|
|
DefaultHeaders: map[string]string{"Authorization": "ACTUAL_HEADER"},
|
|
QueryParams: map[string]string{"api-key": "ACTUAL_API_KEY"},
|
|
},
|
|
},
|
|
AuthServices: server.AuthServiceConfigs{
|
|
"my-google-service": google.Config{
|
|
Name: "my-google-service",
|
|
Kind: google.AuthServiceKind,
|
|
ClientID: "ACTUAL_CLIENT_ID",
|
|
},
|
|
"other-google-service": google.Config{
|
|
Name: "other-google-service",
|
|
Kind: google.AuthServiceKind,
|
|
ClientID: "ACTUAL_CLIENT_ID_2",
|
|
},
|
|
},
|
|
Tools: server.ToolConfigs{
|
|
"example_tool": http.Config{
|
|
Name: "example_tool",
|
|
Kind: "http",
|
|
Source: "my-instance",
|
|
Method: "GET",
|
|
Path: "search?name=alice&pet=cat",
|
|
Description: "some description",
|
|
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
|
QueryParams: []tools.Parameter{
|
|
tools.NewStringParameterWithAuth("country", "some description",
|
|
[]tools.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"},
|
|
{Name: "other-auth-service", Field: "user_id"}}),
|
|
},
|
|
RequestBody: `{
|
|
"age": {{.age}},
|
|
"city": "{{.city}}",
|
|
"food": "food",
|
|
"other": "$OTHER"
|
|
}
|
|
`,
|
|
BodyParams: []tools.Parameter{tools.NewIntParameter("age", "age num"), tools.NewStringParameter("city", "city string")},
|
|
Headers: map[string]string{"Authorization": "API_KEY", "Content-Type": "application/json"},
|
|
HeaderParams: []tools.Parameter{tools.NewStringParameter("Language", "language string")},
|
|
},
|
|
},
|
|
Toolsets: server.ToolsetConfigs{
|
|
"ACTUAL_TOOLSET_NAME": tools.ToolsetConfig{
|
|
Name: "ACTUAL_TOOLSET_NAME",
|
|
ToolNames: []string{"example_tool"},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
for _, tc := range tcs {
|
|
t.Run(tc.description, func(t *testing.T) {
|
|
toolsFile, err := parseToolsFile(ctx, testutils.FormatYaml(tc.in))
|
|
if err != nil {
|
|
t.Fatalf("failed to parse input: %v", err)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolsFile.Sources, toolsFile.Sources); diff != "" {
|
|
t.Fatalf("incorrect sources parse: diff %v", diff)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolsFile.AuthServices, toolsFile.AuthServices); diff != "" {
|
|
t.Fatalf("incorrect authServices parse: diff %v", diff)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolsFile.Tools, toolsFile.Tools); diff != "" {
|
|
t.Fatalf("incorrect tools parse: diff %v", diff)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolsFile.Toolsets, toolsFile.Toolsets); diff != "" {
|
|
t.Fatalf("incorrect tools parse: diff %v", diff)
|
|
}
|
|
})
|
|
}
|
|
|
|
}
|
|
|
|
func TestPrebuiltTools(t *testing.T) {
|
|
alloydb_config, _ := prebuiltconfigs.Get("alloydb-postgres")
|
|
bigquery_config, _ := prebuiltconfigs.Get("bigquery")
|
|
cloudsqlpg_config, _ := prebuiltconfigs.Get("cloud-sql-postgres")
|
|
cloudsqlmysql_config, _ := prebuiltconfigs.Get("cloud-sql-mysql")
|
|
cloudsqlmssql_config, _ := prebuiltconfigs.Get("cloud-sql-mssql")
|
|
postgresconfig, _ := prebuiltconfigs.Get("postgres")
|
|
spanner_config, _ := prebuiltconfigs.Get("spanner")
|
|
spannerpg_config, _ := prebuiltconfigs.Get("spanner-postgres")
|
|
ctx, err := testutils.ContextWithNewLogger()
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %s", err)
|
|
}
|
|
tcs := []struct {
|
|
name string
|
|
in []byte
|
|
wantToolset server.ToolsetConfigs
|
|
}{
|
|
{
|
|
name: "alloydb prebuilt tools",
|
|
in: alloydb_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"alloydb-postgres-database-tools": tools.ToolsetConfig{
|
|
Name: "alloydb-postgres-database-tools",
|
|
ToolNames: []string{"execute_sql", "list_tables"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "bigquery prebuilt tools",
|
|
in: bigquery_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"bigquery-database-tools": tools.ToolsetConfig{
|
|
Name: "bigquery-database-tools",
|
|
ToolNames: []string{"execute_sql", "get_dataset_info", "get_table_info", "list_dataset_ids", "list_table_ids"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "cloudsqlpg prebuilt tools",
|
|
in: cloudsqlpg_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"cloud-sql-postgres-database-tools": tools.ToolsetConfig{
|
|
Name: "cloud-sql-postgres-database-tools",
|
|
ToolNames: []string{"execute_sql", "list_tables"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "cloudsqlmysql prebuilt tools",
|
|
in: cloudsqlmysql_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"cloud-sql-mysql-database-tools": tools.ToolsetConfig{
|
|
Name: "cloud-sql-mysql-database-tools",
|
|
ToolNames: []string{"execute_sql", "list_tables"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "cloudsqlmssql prebuilt tools",
|
|
in: cloudsqlmssql_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"cloud-sql-mssql-database-tools": tools.ToolsetConfig{
|
|
Name: "cloud-sql-mssql-database-tools",
|
|
ToolNames: []string{"execute_sql", "list_tables"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "postgres prebuilt tools",
|
|
in: postgresconfig,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"postgres-database-tools": tools.ToolsetConfig{
|
|
Name: "postgres-database-tools",
|
|
ToolNames: []string{"execute_sql", "list_tables"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "spanner prebuilt tools",
|
|
in: spanner_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"spanner-database-tools": tools.ToolsetConfig{
|
|
Name: "spanner-database-tools",
|
|
ToolNames: []string{"execute_sql", "execute_sql_dql", "list_tables"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "spanner pg prebuilt tools",
|
|
in: spannerpg_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"spanner-postgres-database-tools": tools.ToolsetConfig{
|
|
Name: "spanner-postgres-database-tools",
|
|
ToolNames: []string{"execute_sql", "execute_sql_dql", "list_tables"},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range tcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
toolsFile, err := parseToolsFile(ctx, tc.in)
|
|
if err != nil {
|
|
t.Fatalf("failed to parse input: %v", err)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolset, toolsFile.Toolsets); diff != "" {
|
|
t.Fatalf("incorrect tools parse: diff %v", diff)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestUpdateLogLevel(t *testing.T) {
|
|
tcs := []struct {
|
|
desc string
|
|
stdio bool
|
|
logLevel string
|
|
want bool
|
|
}{
|
|
{
|
|
desc: "no stdio",
|
|
stdio: false,
|
|
logLevel: "info",
|
|
want: false,
|
|
},
|
|
{
|
|
desc: "stdio with info log",
|
|
stdio: true,
|
|
logLevel: "info",
|
|
want: true,
|
|
},
|
|
{
|
|
desc: "stdio with debug log",
|
|
stdio: true,
|
|
logLevel: "debug",
|
|
want: true,
|
|
},
|
|
{
|
|
desc: "stdio with warn log",
|
|
stdio: true,
|
|
logLevel: "warn",
|
|
want: false,
|
|
},
|
|
{
|
|
desc: "stdio with error log",
|
|
stdio: true,
|
|
logLevel: "error",
|
|
want: false,
|
|
},
|
|
}
|
|
for _, tc := range tcs {
|
|
t.Run(tc.desc, func(t *testing.T) {
|
|
got := updateLogLevel(tc.stdio, tc.logLevel)
|
|
if got != tc.want {
|
|
t.Fatalf("incorrect indication to update log level: got %t, want %t", got, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|