mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-02-18 11:02:26 -05:00
## Description Implements the 'postgres-upgrade-precheck' tool to allow users to validate instance readiness for major version upgrades for CloudSQL PostgreSQL. This includes the tool implementation, unit tests for YAML parsing, integration tests for tool invocation, and documentation. The tool is also added to the CloudSQL PostgreSQL prebuilt set. TEST output: <img width="3406" height="1646" alt="image" src="https://github.com/user-attachments/assets/6abaa535-285d-4645-9dd3-7ebcd447d448" /> <img width="3532" height="1490" alt="image" src="https://github.com/user-attachments/assets/4d512af1-51fd-4187-b80f-be13198aba68" /> ## PR Checklist > Thank you for opening a Pull Request! Before submitting your PR, there are a > few things you can do to make sure it goes smoothly: - [x] Make sure you reviewed [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md) - [x] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [x] Ensure the tests and linter pass - [x] Code coverage does not decrease (if any source code was changed) - [x] Appropriate docs were updated (if necessary) - [x] Make sure to add `!` if this involve a breaking change 🛠️ Fixes #1721 --------- Co-authored-by: Averi Kitsch <akitsch@google.com>
1895 lines
57 KiB
Go
1895 lines
57 KiB
Go
// Copyright 2024 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package cmd
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
_ "embed"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"path"
|
|
"path/filepath"
|
|
"regexp"
|
|
"runtime"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
|
|
"github.com/googleapis/genai-toolbox/internal/auth/google"
|
|
"github.com/googleapis/genai-toolbox/internal/log"
|
|
"github.com/googleapis/genai-toolbox/internal/prebuiltconfigs"
|
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
|
"github.com/googleapis/genai-toolbox/internal/prompts/custom"
|
|
"github.com/googleapis/genai-toolbox/internal/server"
|
|
cloudsqlpgsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
|
|
httpsrc "github.com/googleapis/genai-toolbox/internal/sources/http"
|
|
"github.com/googleapis/genai-toolbox/internal/telemetry"
|
|
"github.com/googleapis/genai-toolbox/internal/testutils"
|
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
|
"github.com/googleapis/genai-toolbox/internal/tools/http"
|
|
"github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql"
|
|
"github.com/googleapis/genai-toolbox/internal/util"
|
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
|
"github.com/spf13/cobra"
|
|
)
|
|
|
|
func withDefaults(c server.ServerConfig) server.ServerConfig {
|
|
data, _ := os.ReadFile("version.txt")
|
|
version := strings.TrimSpace(string(data)) // Preserving 'data', new var for clarity
|
|
c.Version = version + "+" + strings.Join([]string{"dev", runtime.GOOS, runtime.GOARCH}, ".")
|
|
|
|
if c.Address == "" {
|
|
c.Address = "127.0.0.1"
|
|
}
|
|
if c.Port == 0 {
|
|
c.Port = 5000
|
|
}
|
|
if c.TelemetryServiceName == "" {
|
|
c.TelemetryServiceName = "toolbox"
|
|
}
|
|
return c
|
|
}
|
|
|
|
func invokeCommand(args []string) (*Command, string, error) {
|
|
c := NewCommand()
|
|
|
|
// Keep the test output quiet
|
|
c.SilenceUsage = true
|
|
c.SilenceErrors = true
|
|
|
|
// Capture output
|
|
buf := new(bytes.Buffer)
|
|
c.SetOut(buf)
|
|
c.SetErr(buf)
|
|
c.SetArgs(args)
|
|
|
|
// Disable execute behavior
|
|
c.RunE = func(*cobra.Command, []string) error {
|
|
return nil
|
|
}
|
|
|
|
err := c.Execute()
|
|
|
|
return c, buf.String(), err
|
|
}
|
|
|
|
func TestVersion(t *testing.T) {
|
|
data, err := os.ReadFile("version.txt")
|
|
if err != nil {
|
|
t.Fatalf("failed to read version.txt: %v", err)
|
|
}
|
|
want := strings.TrimSpace(string(data))
|
|
|
|
_, got, err := invokeCommand([]string{"--version"})
|
|
if err != nil {
|
|
t.Fatalf("error invoking command: %s", err)
|
|
}
|
|
|
|
if !strings.Contains(got, want) {
|
|
t.Errorf("cli did not return correct version: want %q, got %q", want, got)
|
|
}
|
|
}
|
|
|
|
func TestServerConfigFlags(t *testing.T) {
|
|
tcs := []struct {
|
|
desc string
|
|
args []string
|
|
want server.ServerConfig
|
|
}{
|
|
{
|
|
desc: "default values",
|
|
args: []string{},
|
|
want: withDefaults(server.ServerConfig{}),
|
|
},
|
|
{
|
|
desc: "address short",
|
|
args: []string{"-a", "127.0.1.1"},
|
|
want: withDefaults(server.ServerConfig{
|
|
Address: "127.0.1.1",
|
|
}),
|
|
},
|
|
{
|
|
desc: "address long",
|
|
args: []string{"--address", "0.0.0.0"},
|
|
want: withDefaults(server.ServerConfig{
|
|
Address: "0.0.0.0",
|
|
}),
|
|
},
|
|
{
|
|
desc: "port short",
|
|
args: []string{"-p", "5052"},
|
|
want: withDefaults(server.ServerConfig{
|
|
Port: 5052,
|
|
}),
|
|
},
|
|
{
|
|
desc: "port long",
|
|
args: []string{"--port", "5050"},
|
|
want: withDefaults(server.ServerConfig{
|
|
Port: 5050,
|
|
}),
|
|
},
|
|
{
|
|
desc: "logging format",
|
|
args: []string{"--logging-format", "JSON"},
|
|
want: withDefaults(server.ServerConfig{
|
|
LoggingFormat: "JSON",
|
|
}),
|
|
},
|
|
{
|
|
desc: "debug logs",
|
|
args: []string{"--log-level", "WARN"},
|
|
want: withDefaults(server.ServerConfig{
|
|
LogLevel: "WARN",
|
|
}),
|
|
},
|
|
{
|
|
desc: "telemetry gcp",
|
|
args: []string{"--telemetry-gcp"},
|
|
want: withDefaults(server.ServerConfig{
|
|
TelemetryGCP: true,
|
|
}),
|
|
},
|
|
{
|
|
desc: "telemetry otlp",
|
|
args: []string{"--telemetry-otlp", "http://127.0.0.1:4553"},
|
|
want: withDefaults(server.ServerConfig{
|
|
TelemetryOTLP: "http://127.0.0.1:4553",
|
|
}),
|
|
},
|
|
{
|
|
desc: "telemetry service name",
|
|
args: []string{"--telemetry-service-name", "toolbox-custom"},
|
|
want: withDefaults(server.ServerConfig{
|
|
TelemetryServiceName: "toolbox-custom",
|
|
}),
|
|
},
|
|
{
|
|
desc: "stdio",
|
|
args: []string{"--stdio"},
|
|
want: withDefaults(server.ServerConfig{
|
|
Stdio: true,
|
|
}),
|
|
},
|
|
{
|
|
desc: "disable reload",
|
|
args: []string{"--disable-reload"},
|
|
want: withDefaults(server.ServerConfig{
|
|
DisableReload: true,
|
|
}),
|
|
},
|
|
}
|
|
for _, tc := range tcs {
|
|
t.Run(tc.desc, func(t *testing.T) {
|
|
c, _, err := invokeCommand(tc.args)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error invoking command: %s", err)
|
|
}
|
|
|
|
if !cmp.Equal(c.cfg, tc.want) {
|
|
t.Fatalf("got %v, want %v", c.cfg, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestParseEnv(t *testing.T) {
|
|
tcs := []struct {
|
|
desc string
|
|
env map[string]string
|
|
in string
|
|
want string
|
|
err bool
|
|
errString string
|
|
}{
|
|
{
|
|
desc: "without default without env",
|
|
in: "${FOO}",
|
|
want: "",
|
|
err: true,
|
|
errString: `environment variable not found: "FOO"`,
|
|
},
|
|
{
|
|
desc: "without default with env",
|
|
env: map[string]string{
|
|
"FOO": "bar",
|
|
},
|
|
in: "${FOO}",
|
|
want: "bar",
|
|
},
|
|
{
|
|
desc: "with empty default",
|
|
in: "${FOO:}",
|
|
want: "",
|
|
},
|
|
{
|
|
desc: "with default",
|
|
in: "${FOO:bar}",
|
|
want: "bar",
|
|
},
|
|
{
|
|
desc: "with default with env",
|
|
env: map[string]string{
|
|
"FOO": "hello",
|
|
},
|
|
in: "${FOO:bar}",
|
|
want: "hello",
|
|
},
|
|
}
|
|
for _, tc := range tcs {
|
|
t.Run(tc.desc, func(t *testing.T) {
|
|
if tc.env != nil {
|
|
for k, v := range tc.env {
|
|
t.Setenv(k, v)
|
|
}
|
|
}
|
|
got, err := parseEnv(tc.in)
|
|
if tc.err {
|
|
if err == nil {
|
|
t.Fatalf("expected error not found")
|
|
}
|
|
if tc.errString != err.Error() {
|
|
t.Fatalf("incorrect error string: got %s, want %s", err, tc.errString)
|
|
}
|
|
}
|
|
if tc.want != got {
|
|
t.Fatalf("unexpected want: got %s, want %s", got, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestToolFileFlag(t *testing.T) {
|
|
tcs := []struct {
|
|
desc string
|
|
args []string
|
|
want string
|
|
}{
|
|
{
|
|
desc: "default value",
|
|
args: []string{},
|
|
want: "",
|
|
},
|
|
{
|
|
desc: "foo file",
|
|
args: []string{"--tools-file", "foo.yaml"},
|
|
want: "foo.yaml",
|
|
},
|
|
{
|
|
desc: "address long",
|
|
args: []string{"--tools-file", "bar.yaml"},
|
|
want: "bar.yaml",
|
|
},
|
|
{
|
|
desc: "deprecated flag",
|
|
args: []string{"--tools_file", "foo.yaml"},
|
|
want: "foo.yaml",
|
|
},
|
|
}
|
|
for _, tc := range tcs {
|
|
t.Run(tc.desc, func(t *testing.T) {
|
|
c, _, err := invokeCommand(tc.args)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error invoking command: %s", err)
|
|
}
|
|
if c.tools_file != tc.want {
|
|
t.Fatalf("got %v, want %v", c.cfg, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestToolsFilesFlag(t *testing.T) {
|
|
tcs := []struct {
|
|
desc string
|
|
args []string
|
|
want []string
|
|
}{
|
|
{
|
|
desc: "no value",
|
|
args: []string{},
|
|
want: []string{},
|
|
},
|
|
{
|
|
desc: "single file",
|
|
args: []string{"--tools-files", "foo.yaml"},
|
|
want: []string{"foo.yaml"},
|
|
},
|
|
{
|
|
desc: "multiple files",
|
|
args: []string{"--tools-files", "foo.yaml,bar.yaml"},
|
|
want: []string{"foo.yaml", "bar.yaml"},
|
|
},
|
|
}
|
|
for _, tc := range tcs {
|
|
t.Run(tc.desc, func(t *testing.T) {
|
|
c, _, err := invokeCommand(tc.args)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error invoking command: %s", err)
|
|
}
|
|
if diff := cmp.Diff(c.tools_files, tc.want); diff != "" {
|
|
t.Fatalf("got %v, want %v", c.tools_files, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestToolsFolderFlag(t *testing.T) {
|
|
tcs := []struct {
|
|
desc string
|
|
args []string
|
|
want string
|
|
}{
|
|
{
|
|
desc: "no value",
|
|
args: []string{},
|
|
want: "",
|
|
},
|
|
{
|
|
desc: "folder set",
|
|
args: []string{"--tools-folder", "test-folder"},
|
|
want: "test-folder",
|
|
},
|
|
}
|
|
for _, tc := range tcs {
|
|
t.Run(tc.desc, func(t *testing.T) {
|
|
c, _, err := invokeCommand(tc.args)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error invoking command: %s", err)
|
|
}
|
|
if c.tools_folder != tc.want {
|
|
t.Fatalf("got %v, want %v", c.tools_folder, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestPrebuiltFlag(t *testing.T) {
|
|
tcs := []struct {
|
|
desc string
|
|
args []string
|
|
want string
|
|
}{
|
|
{
|
|
desc: "default value",
|
|
args: []string{},
|
|
want: "",
|
|
},
|
|
{
|
|
desc: "custom pre built flag",
|
|
args: []string{"--tools-file", "alloydb"},
|
|
want: "alloydb",
|
|
},
|
|
}
|
|
for _, tc := range tcs {
|
|
t.Run(tc.desc, func(t *testing.T) {
|
|
c, _, err := invokeCommand(tc.args)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error invoking command: %s", err)
|
|
}
|
|
if c.tools_file != tc.want {
|
|
t.Fatalf("got %v, want %v", c.cfg, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFailServerConfigFlags(t *testing.T) {
|
|
tcs := []struct {
|
|
desc string
|
|
args []string
|
|
}{
|
|
{
|
|
desc: "logging format",
|
|
args: []string{"--logging-format", "fail"},
|
|
},
|
|
{
|
|
desc: "debug logs",
|
|
args: []string{"--log-level", "fail"},
|
|
},
|
|
}
|
|
for _, tc := range tcs {
|
|
t.Run(tc.desc, func(t *testing.T) {
|
|
_, _, err := invokeCommand(tc.args)
|
|
if err == nil {
|
|
t.Fatalf("expected an error, but got nil")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDefaultLoggingFormat(t *testing.T) {
|
|
c, _, err := invokeCommand([]string{})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error invoking command: %s", err)
|
|
}
|
|
got := c.cfg.LoggingFormat.String()
|
|
want := "standard"
|
|
if got != want {
|
|
t.Fatalf("unexpected default logging format flag: got %v, want %v", got, want)
|
|
}
|
|
}
|
|
|
|
func TestDefaultLogLevel(t *testing.T) {
|
|
c, _, err := invokeCommand([]string{})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error invoking command: %s", err)
|
|
}
|
|
got := c.cfg.LogLevel.String()
|
|
want := "info"
|
|
if got != want {
|
|
t.Fatalf("unexpected default log level flag: got %v, want %v", got, want)
|
|
}
|
|
}
|
|
|
|
func TestParseToolFile(t *testing.T) {
|
|
ctx, err := testutils.ContextWithNewLogger()
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %s", err)
|
|
}
|
|
tcs := []struct {
|
|
description string
|
|
in string
|
|
wantToolsFile ToolsFile
|
|
}{
|
|
{
|
|
description: "basic example",
|
|
in: `
|
|
sources:
|
|
my-pg-instance:
|
|
kind: cloud-sql-postgres
|
|
project: my-project
|
|
region: my-region
|
|
instance: my-instance
|
|
database: my_db
|
|
user: my_user
|
|
password: my_pass
|
|
tools:
|
|
example_tool:
|
|
kind: postgres-sql
|
|
source: my-pg-instance
|
|
description: some description
|
|
statement: |
|
|
SELECT * FROM SQL_STATEMENT;
|
|
parameters:
|
|
- name: country
|
|
type: string
|
|
description: some description
|
|
toolsets:
|
|
example_toolset:
|
|
- example_tool
|
|
`,
|
|
wantToolsFile: ToolsFile{
|
|
Sources: server.SourceConfigs{
|
|
"my-pg-instance": cloudsqlpgsrc.Config{
|
|
Name: "my-pg-instance",
|
|
Kind: cloudsqlpgsrc.SourceKind,
|
|
Project: "my-project",
|
|
Region: "my-region",
|
|
Instance: "my-instance",
|
|
IPType: "public",
|
|
Database: "my_db",
|
|
User: "my_user",
|
|
Password: "my_pass",
|
|
},
|
|
},
|
|
Tools: server.ToolConfigs{
|
|
"example_tool": postgressql.Config{
|
|
Name: "example_tool",
|
|
Kind: "postgres-sql",
|
|
Source: "my-pg-instance",
|
|
Description: "some description",
|
|
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
|
Parameters: []parameters.Parameter{
|
|
parameters.NewStringParameter("country", "some description"),
|
|
},
|
|
AuthRequired: []string{},
|
|
},
|
|
},
|
|
Toolsets: server.ToolsetConfigs{
|
|
"example_toolset": tools.ToolsetConfig{
|
|
Name: "example_toolset",
|
|
ToolNames: []string{"example_tool"},
|
|
},
|
|
},
|
|
Prompts: nil,
|
|
},
|
|
},
|
|
{
|
|
description: "with prompts example",
|
|
in: `
|
|
prompts:
|
|
my-prompt:
|
|
description: A prompt template for data analysis.
|
|
arguments:
|
|
- name: country
|
|
description: The country to analyze.
|
|
messages:
|
|
- content: Analyze the data for {{.country}}.
|
|
`,
|
|
wantToolsFile: ToolsFile{
|
|
Sources: nil,
|
|
AuthServices: nil,
|
|
Tools: nil,
|
|
Toolsets: nil,
|
|
Prompts: server.PromptConfigs{
|
|
"my-prompt": &custom.Config{
|
|
Name: "my-prompt",
|
|
Description: "A prompt template for data analysis.",
|
|
Arguments: prompts.Arguments{
|
|
{Parameter: parameters.NewStringParameter("country", "The country to analyze.")},
|
|
},
|
|
Messages: []prompts.Message{
|
|
{Role: "user", Content: "Analyze the data for {{.country}}."},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
for _, tc := range tcs {
|
|
t.Run(tc.description, func(t *testing.T) {
|
|
toolsFile, err := parseToolsFile(ctx, testutils.FormatYaml(tc.in))
|
|
if err != nil {
|
|
t.Fatalf("failed to parse input: %v", err)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolsFile.Sources, toolsFile.Sources); diff != "" {
|
|
t.Fatalf("incorrect sources parse: diff %v", diff)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolsFile.AuthServices, toolsFile.AuthServices); diff != "" {
|
|
t.Fatalf("incorrect authServices parse: diff %v", diff)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolsFile.Tools, toolsFile.Tools); diff != "" {
|
|
t.Fatalf("incorrect tools parse: diff %v", diff)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolsFile.Toolsets, toolsFile.Toolsets); diff != "" {
|
|
t.Fatalf("incorrect toolsets parse: diff %v", diff)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolsFile.Prompts, toolsFile.Prompts); diff != "" {
|
|
t.Fatalf("incorrect prompts parse: diff %v", diff)
|
|
}
|
|
})
|
|
}
|
|
|
|
}
|
|
|
|
func TestParseToolFileWithAuth(t *testing.T) {
|
|
ctx, err := testutils.ContextWithNewLogger()
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %s", err)
|
|
}
|
|
tcs := []struct {
|
|
description string
|
|
in string
|
|
wantToolsFile ToolsFile
|
|
}{
|
|
{
|
|
description: "basic example",
|
|
in: `
|
|
sources:
|
|
my-pg-instance:
|
|
kind: cloud-sql-postgres
|
|
project: my-project
|
|
region: my-region
|
|
instance: my-instance
|
|
database: my_db
|
|
user: my_user
|
|
password: my_pass
|
|
authServices:
|
|
my-google-service:
|
|
kind: google
|
|
clientId: my-client-id
|
|
other-google-service:
|
|
kind: google
|
|
clientId: other-client-id
|
|
|
|
tools:
|
|
example_tool:
|
|
kind: postgres-sql
|
|
source: my-pg-instance
|
|
description: some description
|
|
statement: |
|
|
SELECT * FROM SQL_STATEMENT;
|
|
parameters:
|
|
- name: country
|
|
type: string
|
|
description: some description
|
|
- name: id
|
|
type: integer
|
|
description: user id
|
|
authServices:
|
|
- name: my-google-service
|
|
field: user_id
|
|
- name: email
|
|
type: string
|
|
description: user email
|
|
authServices:
|
|
- name: my-google-service
|
|
field: email
|
|
- name: other-google-service
|
|
field: other_email
|
|
|
|
toolsets:
|
|
example_toolset:
|
|
- example_tool
|
|
`,
|
|
wantToolsFile: ToolsFile{
|
|
Sources: server.SourceConfigs{
|
|
"my-pg-instance": cloudsqlpgsrc.Config{
|
|
Name: "my-pg-instance",
|
|
Kind: cloudsqlpgsrc.SourceKind,
|
|
Project: "my-project",
|
|
Region: "my-region",
|
|
Instance: "my-instance",
|
|
IPType: "public",
|
|
Database: "my_db",
|
|
User: "my_user",
|
|
Password: "my_pass",
|
|
},
|
|
},
|
|
AuthServices: server.AuthServiceConfigs{
|
|
"my-google-service": google.Config{
|
|
Name: "my-google-service",
|
|
Kind: google.AuthServiceKind,
|
|
ClientID: "my-client-id",
|
|
},
|
|
"other-google-service": google.Config{
|
|
Name: "other-google-service",
|
|
Kind: google.AuthServiceKind,
|
|
ClientID: "other-client-id",
|
|
},
|
|
},
|
|
Tools: server.ToolConfigs{
|
|
"example_tool": postgressql.Config{
|
|
Name: "example_tool",
|
|
Kind: "postgres-sql",
|
|
Source: "my-pg-instance",
|
|
Description: "some description",
|
|
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
|
AuthRequired: []string{},
|
|
Parameters: []parameters.Parameter{
|
|
parameters.NewStringParameter("country", "some description"),
|
|
parameters.NewIntParameterWithAuth("id", "user id", []parameters.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}),
|
|
parameters.NewStringParameterWithAuth("email", "user email", []parameters.ParamAuthService{{Name: "my-google-service", Field: "email"}, {Name: "other-google-service", Field: "other_email"}}),
|
|
},
|
|
},
|
|
},
|
|
Toolsets: server.ToolsetConfigs{
|
|
"example_toolset": tools.ToolsetConfig{
|
|
Name: "example_toolset",
|
|
ToolNames: []string{"example_tool"},
|
|
},
|
|
},
|
|
Prompts: nil,
|
|
},
|
|
},
|
|
{
|
|
description: "basic example with authSources",
|
|
in: `
|
|
sources:
|
|
my-pg-instance:
|
|
kind: cloud-sql-postgres
|
|
project: my-project
|
|
region: my-region
|
|
instance: my-instance
|
|
database: my_db
|
|
user: my_user
|
|
password: my_pass
|
|
authSources:
|
|
my-google-service:
|
|
kind: google
|
|
clientId: my-client-id
|
|
other-google-service:
|
|
kind: google
|
|
clientId: other-client-id
|
|
|
|
tools:
|
|
example_tool:
|
|
kind: postgres-sql
|
|
source: my-pg-instance
|
|
description: some description
|
|
statement: |
|
|
SELECT * FROM SQL_STATEMENT;
|
|
parameters:
|
|
- name: country
|
|
type: string
|
|
description: some description
|
|
- name: id
|
|
type: integer
|
|
description: user id
|
|
authSources:
|
|
- name: my-google-service
|
|
field: user_id
|
|
- name: email
|
|
type: string
|
|
description: user email
|
|
authSources:
|
|
- name: my-google-service
|
|
field: email
|
|
- name: other-google-service
|
|
field: other_email
|
|
|
|
toolsets:
|
|
example_toolset:
|
|
- example_tool
|
|
`,
|
|
wantToolsFile: ToolsFile{
|
|
Sources: server.SourceConfigs{
|
|
"my-pg-instance": cloudsqlpgsrc.Config{
|
|
Name: "my-pg-instance",
|
|
Kind: cloudsqlpgsrc.SourceKind,
|
|
Project: "my-project",
|
|
Region: "my-region",
|
|
Instance: "my-instance",
|
|
IPType: "public",
|
|
Database: "my_db",
|
|
User: "my_user",
|
|
Password: "my_pass",
|
|
},
|
|
},
|
|
AuthSources: server.AuthServiceConfigs{
|
|
"my-google-service": google.Config{
|
|
Name: "my-google-service",
|
|
Kind: google.AuthServiceKind,
|
|
ClientID: "my-client-id",
|
|
},
|
|
"other-google-service": google.Config{
|
|
Name: "other-google-service",
|
|
Kind: google.AuthServiceKind,
|
|
ClientID: "other-client-id",
|
|
},
|
|
},
|
|
Tools: server.ToolConfigs{
|
|
"example_tool": postgressql.Config{
|
|
Name: "example_tool",
|
|
Kind: "postgres-sql",
|
|
Source: "my-pg-instance",
|
|
Description: "some description",
|
|
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
|
AuthRequired: []string{},
|
|
Parameters: []parameters.Parameter{
|
|
parameters.NewStringParameter("country", "some description"),
|
|
parameters.NewIntParameterWithAuth("id", "user id", []parameters.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}),
|
|
parameters.NewStringParameterWithAuth("email", "user email", []parameters.ParamAuthService{{Name: "my-google-service", Field: "email"}, {Name: "other-google-service", Field: "other_email"}}),
|
|
},
|
|
},
|
|
},
|
|
Toolsets: server.ToolsetConfigs{
|
|
"example_toolset": tools.ToolsetConfig{
|
|
Name: "example_toolset",
|
|
ToolNames: []string{"example_tool"},
|
|
},
|
|
},
|
|
Prompts: nil,
|
|
},
|
|
},
|
|
{
|
|
description: "basic example with authRequired",
|
|
in: `
|
|
sources:
|
|
my-pg-instance:
|
|
kind: cloud-sql-postgres
|
|
project: my-project
|
|
region: my-region
|
|
instance: my-instance
|
|
database: my_db
|
|
user: my_user
|
|
password: my_pass
|
|
authServices:
|
|
my-google-service:
|
|
kind: google
|
|
clientId: my-client-id
|
|
other-google-service:
|
|
kind: google
|
|
clientId: other-client-id
|
|
|
|
tools:
|
|
example_tool:
|
|
kind: postgres-sql
|
|
source: my-pg-instance
|
|
description: some description
|
|
statement: |
|
|
SELECT * FROM SQL_STATEMENT;
|
|
authRequired:
|
|
- my-google-service
|
|
parameters:
|
|
- name: country
|
|
type: string
|
|
description: some description
|
|
- name: id
|
|
type: integer
|
|
description: user id
|
|
authServices:
|
|
- name: my-google-service
|
|
field: user_id
|
|
- name: email
|
|
type: string
|
|
description: user email
|
|
authServices:
|
|
- name: my-google-service
|
|
field: email
|
|
- name: other-google-service
|
|
field: other_email
|
|
|
|
toolsets:
|
|
example_toolset:
|
|
- example_tool
|
|
`,
|
|
wantToolsFile: ToolsFile{
|
|
Sources: server.SourceConfigs{
|
|
"my-pg-instance": cloudsqlpgsrc.Config{
|
|
Name: "my-pg-instance",
|
|
Kind: cloudsqlpgsrc.SourceKind,
|
|
Project: "my-project",
|
|
Region: "my-region",
|
|
Instance: "my-instance",
|
|
IPType: "public",
|
|
Database: "my_db",
|
|
User: "my_user",
|
|
Password: "my_pass",
|
|
},
|
|
},
|
|
AuthServices: server.AuthServiceConfigs{
|
|
"my-google-service": google.Config{
|
|
Name: "my-google-service",
|
|
Kind: google.AuthServiceKind,
|
|
ClientID: "my-client-id",
|
|
},
|
|
"other-google-service": google.Config{
|
|
Name: "other-google-service",
|
|
Kind: google.AuthServiceKind,
|
|
ClientID: "other-client-id",
|
|
},
|
|
},
|
|
Tools: server.ToolConfigs{
|
|
"example_tool": postgressql.Config{
|
|
Name: "example_tool",
|
|
Kind: "postgres-sql",
|
|
Source: "my-pg-instance",
|
|
Description: "some description",
|
|
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
|
AuthRequired: []string{"my-google-service"},
|
|
Parameters: []parameters.Parameter{
|
|
parameters.NewStringParameter("country", "some description"),
|
|
parameters.NewIntParameterWithAuth("id", "user id", []parameters.ParamAuthService{{Name: "my-google-service", Field: "user_id"}}),
|
|
parameters.NewStringParameterWithAuth("email", "user email", []parameters.ParamAuthService{{Name: "my-google-service", Field: "email"}, {Name: "other-google-service", Field: "other_email"}}),
|
|
},
|
|
},
|
|
},
|
|
Toolsets: server.ToolsetConfigs{
|
|
"example_toolset": tools.ToolsetConfig{
|
|
Name: "example_toolset",
|
|
ToolNames: []string{"example_tool"},
|
|
},
|
|
},
|
|
Prompts: nil,
|
|
},
|
|
},
|
|
}
|
|
for _, tc := range tcs {
|
|
t.Run(tc.description, func(t *testing.T) {
|
|
toolsFile, err := parseToolsFile(ctx, testutils.FormatYaml(tc.in))
|
|
if err != nil {
|
|
t.Fatalf("failed to parse input: %v", err)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolsFile.Sources, toolsFile.Sources); diff != "" {
|
|
t.Fatalf("incorrect sources parse: diff %v", diff)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolsFile.AuthServices, toolsFile.AuthServices); diff != "" {
|
|
t.Fatalf("incorrect authServices parse: diff %v", diff)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolsFile.Tools, toolsFile.Tools); diff != "" {
|
|
t.Fatalf("incorrect tools parse: diff %v", diff)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolsFile.Toolsets, toolsFile.Toolsets); diff != "" {
|
|
t.Fatalf("incorrect toolsets parse: diff %v", diff)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolsFile.Prompts, toolsFile.Prompts); diff != "" {
|
|
t.Fatalf("incorrect prompts parse: diff %v", diff)
|
|
}
|
|
})
|
|
}
|
|
|
|
}
|
|
|
|
func TestEnvVarReplacement(t *testing.T) {
|
|
ctx, err := testutils.ContextWithNewLogger()
|
|
t.Setenv("TestHeader", "ACTUAL_HEADER")
|
|
t.Setenv("API_KEY", "ACTUAL_API_KEY")
|
|
t.Setenv("clientId", "ACTUAL_CLIENT_ID")
|
|
t.Setenv("clientId2", "ACTUAL_CLIENT_ID_2")
|
|
t.Setenv("toolset_name", "ACTUAL_TOOLSET_NAME")
|
|
t.Setenv("cat_string", "cat")
|
|
t.Setenv("food_string", "food")
|
|
t.Setenv("TestHeader", "ACTUAL_HEADER")
|
|
t.Setenv("prompt_name", "ACTUAL_PROMPT_NAME")
|
|
t.Setenv("prompt_content", "ACTUAL_CONTENT")
|
|
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %s", err)
|
|
}
|
|
tcs := []struct {
|
|
description string
|
|
in string
|
|
wantToolsFile ToolsFile
|
|
}{
|
|
{
|
|
description: "file with env var example",
|
|
in: `
|
|
sources:
|
|
my-http-instance:
|
|
kind: http
|
|
baseUrl: http://test_server/
|
|
timeout: 10s
|
|
headers:
|
|
Authorization: ${TestHeader}
|
|
queryParams:
|
|
api-key: ${API_KEY}
|
|
authServices:
|
|
my-google-service:
|
|
kind: google
|
|
clientId: ${clientId}
|
|
other-google-service:
|
|
kind: google
|
|
clientId: ${clientId2}
|
|
|
|
tools:
|
|
example_tool:
|
|
kind: http
|
|
source: my-instance
|
|
method: GET
|
|
path: "search?name=alice&pet=${cat_string}"
|
|
description: some description
|
|
authRequired:
|
|
- my-google-auth-service
|
|
- other-auth-service
|
|
queryParams:
|
|
- name: country
|
|
type: string
|
|
description: some description
|
|
authServices:
|
|
- name: my-google-auth-service
|
|
field: user_id
|
|
- name: other-auth-service
|
|
field: user_id
|
|
requestBody: |
|
|
{
|
|
"age": {{.age}},
|
|
"city": "{{.city}}",
|
|
"food": "${food_string}",
|
|
"other": "$OTHER"
|
|
}
|
|
bodyParams:
|
|
- name: age
|
|
type: integer
|
|
description: age num
|
|
- name: city
|
|
type: string
|
|
description: city string
|
|
headers:
|
|
Authorization: API_KEY
|
|
Content-Type: application/json
|
|
headerParams:
|
|
- name: Language
|
|
type: string
|
|
description: language string
|
|
|
|
toolsets:
|
|
${toolset_name}:
|
|
- example_tool
|
|
|
|
|
|
prompts:
|
|
${prompt_name}:
|
|
description: A test prompt for {{.name}}.
|
|
messages:
|
|
- role: user
|
|
content: ${prompt_content}
|
|
`,
|
|
wantToolsFile: ToolsFile{
|
|
Sources: server.SourceConfigs{
|
|
"my-http-instance": httpsrc.Config{
|
|
Name: "my-http-instance",
|
|
Kind: httpsrc.SourceKind,
|
|
BaseURL: "http://test_server/",
|
|
Timeout: "10s",
|
|
DefaultHeaders: map[string]string{"Authorization": "ACTUAL_HEADER"},
|
|
QueryParams: map[string]string{"api-key": "ACTUAL_API_KEY"},
|
|
},
|
|
},
|
|
AuthServices: server.AuthServiceConfigs{
|
|
"my-google-service": google.Config{
|
|
Name: "my-google-service",
|
|
Kind: google.AuthServiceKind,
|
|
ClientID: "ACTUAL_CLIENT_ID",
|
|
},
|
|
"other-google-service": google.Config{
|
|
Name: "other-google-service",
|
|
Kind: google.AuthServiceKind,
|
|
ClientID: "ACTUAL_CLIENT_ID_2",
|
|
},
|
|
},
|
|
Tools: server.ToolConfigs{
|
|
"example_tool": http.Config{
|
|
Name: "example_tool",
|
|
Kind: "http",
|
|
Source: "my-instance",
|
|
Method: "GET",
|
|
Path: "search?name=alice&pet=cat",
|
|
Description: "some description",
|
|
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
|
QueryParams: []parameters.Parameter{
|
|
parameters.NewStringParameterWithAuth("country", "some description",
|
|
[]parameters.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"},
|
|
{Name: "other-auth-service", Field: "user_id"}}),
|
|
},
|
|
RequestBody: `{
|
|
"age": {{.age}},
|
|
"city": "{{.city}}",
|
|
"food": "food",
|
|
"other": "$OTHER"
|
|
}
|
|
`,
|
|
BodyParams: []parameters.Parameter{parameters.NewIntParameter("age", "age num"), parameters.NewStringParameter("city", "city string")},
|
|
Headers: map[string]string{"Authorization": "API_KEY", "Content-Type": "application/json"},
|
|
HeaderParams: []parameters.Parameter{parameters.NewStringParameter("Language", "language string")},
|
|
},
|
|
},
|
|
Toolsets: server.ToolsetConfigs{
|
|
"ACTUAL_TOOLSET_NAME": tools.ToolsetConfig{
|
|
Name: "ACTUAL_TOOLSET_NAME",
|
|
ToolNames: []string{"example_tool"},
|
|
},
|
|
},
|
|
Prompts: server.PromptConfigs{
|
|
"ACTUAL_PROMPT_NAME": &custom.Config{
|
|
Name: "ACTUAL_PROMPT_NAME",
|
|
Description: "A test prompt for {{.name}}.",
|
|
Messages: []prompts.Message{
|
|
{
|
|
Role: "user",
|
|
Content: "ACTUAL_CONTENT",
|
|
},
|
|
},
|
|
Arguments: nil,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
for _, tc := range tcs {
|
|
t.Run(tc.description, func(t *testing.T) {
|
|
toolsFile, err := parseToolsFile(ctx, testutils.FormatYaml(tc.in))
|
|
if err != nil {
|
|
t.Fatalf("failed to parse input: %v", err)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolsFile.Sources, toolsFile.Sources); diff != "" {
|
|
t.Fatalf("incorrect sources parse: diff %v", diff)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolsFile.AuthServices, toolsFile.AuthServices); diff != "" {
|
|
t.Fatalf("incorrect authServices parse: diff %v", diff)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolsFile.Tools, toolsFile.Tools); diff != "" {
|
|
t.Fatalf("incorrect tools parse: diff %v", diff)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolsFile.Toolsets, toolsFile.Toolsets); diff != "" {
|
|
t.Fatalf("incorrect toolsets parse: diff %v", diff)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolsFile.Prompts, toolsFile.Prompts); diff != "" {
|
|
t.Fatalf("incorrect prompts parse: diff %v", diff)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// normalizeFilepaths is a helper function to allow same filepath formats for Mac and Windows.
|
|
// this prevents needing multiple "want" cases for TestResolveWatcherInputs
|
|
func normalizeFilepaths(m map[string]bool) map[string]bool {
|
|
newMap := make(map[string]bool)
|
|
for k, v := range m {
|
|
newMap[filepath.ToSlash(k)] = v
|
|
}
|
|
return newMap
|
|
}
|
|
|
|
func TestResolveWatcherInputs(t *testing.T) {
|
|
tcs := []struct {
|
|
description string
|
|
toolsFile string
|
|
toolsFiles []string
|
|
toolsFolder string
|
|
wantWatchDirs map[string]bool
|
|
wantWatchedFiles map[string]bool
|
|
}{
|
|
{
|
|
description: "single tools file",
|
|
toolsFile: "tools_folder/example_tools.yaml",
|
|
toolsFiles: []string{},
|
|
toolsFolder: "",
|
|
wantWatchDirs: map[string]bool{"tools_folder": true},
|
|
wantWatchedFiles: map[string]bool{"tools_folder/example_tools.yaml": true},
|
|
},
|
|
{
|
|
description: "default tools file (root dir)",
|
|
toolsFile: "tools.yaml",
|
|
toolsFiles: []string{},
|
|
toolsFolder: "",
|
|
wantWatchDirs: map[string]bool{".": true},
|
|
wantWatchedFiles: map[string]bool{"tools.yaml": true},
|
|
},
|
|
{
|
|
description: "multiple files in different folders",
|
|
toolsFile: "",
|
|
toolsFiles: []string{"tools_folder/example_tools.yaml", "tools_folder2/example_tools.yaml"},
|
|
toolsFolder: "",
|
|
wantWatchDirs: map[string]bool{"tools_folder": true, "tools_folder2": true},
|
|
wantWatchedFiles: map[string]bool{
|
|
"tools_folder/example_tools.yaml": true,
|
|
"tools_folder2/example_tools.yaml": true,
|
|
},
|
|
},
|
|
{
|
|
description: "multiple files in same folder",
|
|
toolsFile: "",
|
|
toolsFiles: []string{"tools_folder/example_tools.yaml", "tools_folder/example_tools2.yaml"},
|
|
toolsFolder: "",
|
|
wantWatchDirs: map[string]bool{"tools_folder": true},
|
|
wantWatchedFiles: map[string]bool{
|
|
"tools_folder/example_tools.yaml": true,
|
|
"tools_folder/example_tools2.yaml": true,
|
|
},
|
|
},
|
|
{
|
|
description: "multiple files in different levels",
|
|
toolsFile: "",
|
|
toolsFiles: []string{
|
|
"tools_folder/example_tools.yaml",
|
|
"tools_folder/special_tools/example_tools2.yaml"},
|
|
toolsFolder: "",
|
|
wantWatchDirs: map[string]bool{"tools_folder": true, "tools_folder/special_tools": true},
|
|
wantWatchedFiles: map[string]bool{
|
|
"tools_folder/example_tools.yaml": true,
|
|
"tools_folder/special_tools/example_tools2.yaml": true,
|
|
},
|
|
},
|
|
{
|
|
description: "tools folder",
|
|
toolsFile: "",
|
|
toolsFiles: []string{},
|
|
toolsFolder: "tools_folder",
|
|
wantWatchDirs: map[string]bool{"tools_folder": true},
|
|
wantWatchedFiles: map[string]bool{},
|
|
},
|
|
}
|
|
for _, tc := range tcs {
|
|
t.Run(tc.description, func(t *testing.T) {
|
|
gotWatchDirs, gotWatchedFiles := resolveWatcherInputs(tc.toolsFile, tc.toolsFiles, tc.toolsFolder)
|
|
|
|
normalizedGotWatchDirs := normalizeFilepaths(gotWatchDirs)
|
|
normalizedGotWatchedFiles := normalizeFilepaths(gotWatchedFiles)
|
|
|
|
if diff := cmp.Diff(tc.wantWatchDirs, normalizedGotWatchDirs); diff != "" {
|
|
t.Errorf("incorrect watchDirs: diff %v", diff)
|
|
}
|
|
if diff := cmp.Diff(tc.wantWatchedFiles, normalizedGotWatchedFiles); diff != "" {
|
|
t.Errorf("incorrect watchedFiles: diff %v", diff)
|
|
}
|
|
|
|
})
|
|
}
|
|
}
|
|
|
|
// helper function for testing file detection in dynamic reloading
|
|
func tmpFileWithCleanup(content []byte) (string, func(), error) {
|
|
f, err := os.CreateTemp("", "*")
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
cleanup := func() { os.Remove(f.Name()) }
|
|
|
|
if _, err := f.Write(content); err != nil {
|
|
cleanup()
|
|
return "", nil, err
|
|
}
|
|
if err := f.Close(); err != nil {
|
|
cleanup()
|
|
return "", nil, err
|
|
}
|
|
return f.Name(), cleanup, err
|
|
}
|
|
|
|
func TestSingleEdit(t *testing.T) {
|
|
ctx, cancelCtx := context.WithTimeout(context.Background(), time.Minute)
|
|
defer cancelCtx()
|
|
|
|
pr, pw := io.Pipe()
|
|
defer pw.Close()
|
|
defer pr.Close()
|
|
|
|
fileToWatch, cleanup, err := tmpFileWithCleanup([]byte("initial content"))
|
|
if err != nil {
|
|
t.Fatalf("error editing tools file %s", err)
|
|
}
|
|
defer cleanup()
|
|
|
|
logger, err := log.NewStdLogger(pw, pw, "DEBUG")
|
|
if err != nil {
|
|
t.Fatalf("failed to setup logger %s", err)
|
|
}
|
|
ctx = util.WithLogger(ctx, logger)
|
|
|
|
instrumentation, err := telemetry.CreateTelemetryInstrumentation(versionString)
|
|
if err != nil {
|
|
t.Fatalf("failed to setup instrumentation %s", err)
|
|
}
|
|
ctx = util.WithInstrumentation(ctx, instrumentation)
|
|
|
|
mockServer := &server.Server{}
|
|
|
|
cleanFileToWatch := filepath.Clean(fileToWatch)
|
|
watchDir := filepath.Dir(cleanFileToWatch)
|
|
|
|
watchedFiles := map[string]bool{cleanFileToWatch: true}
|
|
watchDirs := map[string]bool{watchDir: true}
|
|
|
|
go watchChanges(ctx, watchDirs, watchedFiles, mockServer)
|
|
|
|
// escape backslash so regex doesn't fail on windows filepaths
|
|
regexEscapedPathFile := strings.ReplaceAll(cleanFileToWatch, `\`, `\\\\*\\`)
|
|
regexEscapedPathFile = path.Clean(regexEscapedPathFile)
|
|
|
|
regexEscapedPathDir := strings.ReplaceAll(watchDir, `\`, `\\\\*\\`)
|
|
regexEscapedPathDir = path.Clean(regexEscapedPathDir)
|
|
|
|
begunWatchingDir := regexp.MustCompile(fmt.Sprintf(`DEBUG "Added directory %s to watcher."`, regexEscapedPathDir))
|
|
_, err = testutils.WaitForString(ctx, begunWatchingDir, pr)
|
|
if err != nil {
|
|
t.Fatalf("timeout or error waiting for watcher to start: %s", err)
|
|
}
|
|
|
|
err = os.WriteFile(fileToWatch, []byte("modification"), 0777)
|
|
if err != nil {
|
|
t.Fatalf("error writing to file: %v", err)
|
|
}
|
|
|
|
// only check substring of DEBUG message due to some OS/editors firing different operations
|
|
detectedFileChange := regexp.MustCompile(fmt.Sprintf(`event detected in %s"`, regexEscapedPathFile))
|
|
_, err = testutils.WaitForString(ctx, detectedFileChange, pr)
|
|
if err != nil {
|
|
t.Fatalf("timeout or error waiting for file to detect write: %s", err)
|
|
}
|
|
}
|
|
|
|
func TestPrebuiltTools(t *testing.T) {
|
|
// Get prebuilt configs
|
|
alloydb_admin_config, _ := prebuiltconfigs.Get("alloydb-postgres-admin")
|
|
alloydb_config, _ := prebuiltconfigs.Get("alloydb-postgres")
|
|
bigquery_config, _ := prebuiltconfigs.Get("bigquery")
|
|
clickhouse_config, _ := prebuiltconfigs.Get("clickhouse")
|
|
cloudsqlpg_config, _ := prebuiltconfigs.Get("cloud-sql-postgres")
|
|
cloudsqlpg_admin_config, _ := prebuiltconfigs.Get("cloud-sql-postgres-admin")
|
|
cloudsqlmysql_config, _ := prebuiltconfigs.Get("cloud-sql-mysql")
|
|
cloudsqlmysql_admin_config, _ := prebuiltconfigs.Get("cloud-sql-mysql-admin")
|
|
cloudsqlmssql_config, _ := prebuiltconfigs.Get("cloud-sql-mssql")
|
|
cloudsqlmssql_admin_config, _ := prebuiltconfigs.Get("cloud-sql-mssql-admin")
|
|
dataplex_config, _ := prebuiltconfigs.Get("dataplex")
|
|
firestoreconfig, _ := prebuiltconfigs.Get("firestore")
|
|
mysql_config, _ := prebuiltconfigs.Get("mysql")
|
|
mssql_config, _ := prebuiltconfigs.Get("mssql")
|
|
looker_config, _ := prebuiltconfigs.Get("looker")
|
|
lookerca_config, _ := prebuiltconfigs.Get("looker-conversational-analytics")
|
|
postgresconfig, _ := prebuiltconfigs.Get("postgres")
|
|
spanner_config, _ := prebuiltconfigs.Get("spanner")
|
|
spannerpg_config, _ := prebuiltconfigs.Get("spanner-postgres")
|
|
mindsdb_config, _ := prebuiltconfigs.Get("mindsdb")
|
|
sqlite_config, _ := prebuiltconfigs.Get("sqlite")
|
|
neo4jconfig, _ := prebuiltconfigs.Get("neo4j")
|
|
alloydbobsvconfig, _ := prebuiltconfigs.Get("alloydb-postgres-observability")
|
|
cloudsqlpgobsvconfig, _ := prebuiltconfigs.Get("cloud-sql-postgres-observability")
|
|
cloudsqlmysqlobsvconfig, _ := prebuiltconfigs.Get("cloud-sql-mysql-observability")
|
|
cloudsqlmssqlobsvconfig, _ := prebuiltconfigs.Get("cloud-sql-mssql-observability")
|
|
serverless_spark_config, _ := prebuiltconfigs.Get("serverless-spark")
|
|
cloudhealthcare_config, _ := prebuiltconfigs.Get("cloud-healthcare")
|
|
|
|
// Set environment variables
|
|
t.Setenv("API_KEY", "your_api_key")
|
|
|
|
t.Setenv("BIGQUERY_PROJECT", "your_gcp_project_id")
|
|
t.Setenv("DATAPLEX_PROJECT", "your_gcp_project_id")
|
|
t.Setenv("FIRESTORE_PROJECT", "your_gcp_project_id")
|
|
t.Setenv("FIRESTORE_DATABASE", "your_firestore_db_name")
|
|
|
|
t.Setenv("SPANNER_PROJECT", "your_gcp_project_id")
|
|
t.Setenv("SPANNER_INSTANCE", "your_spanner_instance")
|
|
t.Setenv("SPANNER_DATABASE", "your_spanner_db")
|
|
|
|
t.Setenv("ALLOYDB_POSTGRES_PROJECT", "your_gcp_project_id")
|
|
t.Setenv("ALLOYDB_POSTGRES_REGION", "your_gcp_region")
|
|
t.Setenv("ALLOYDB_POSTGRES_CLUSTER", "your_alloydb_cluster")
|
|
t.Setenv("ALLOYDB_POSTGRES_INSTANCE", "your_alloydb_instance")
|
|
t.Setenv("ALLOYDB_POSTGRES_DATABASE", "your_alloydb_db")
|
|
t.Setenv("ALLOYDB_POSTGRES_USER", "your_alloydb_user")
|
|
t.Setenv("ALLOYDB_POSTGRES_PASSWORD", "your_alloydb_password")
|
|
|
|
t.Setenv("CLICKHOUSE_PROTOCOL", "your_clickhouse_protocol")
|
|
t.Setenv("CLICKHOUSE_DATABASE", "your_clickhouse_database")
|
|
t.Setenv("CLICKHOUSE_PASSWORD", "your_clickhouse_password")
|
|
t.Setenv("CLICKHOUSE_USER", "your_clickhouse_user")
|
|
t.Setenv("CLICKHOUSE_HOST", "your_clickhosue_host")
|
|
t.Setenv("CLICKHOUSE_PORT", "8123")
|
|
|
|
t.Setenv("CLOUD_SQL_POSTGRES_PROJECT", "your_pg_project")
|
|
t.Setenv("CLOUD_SQL_POSTGRES_INSTANCE", "your_pg_instance")
|
|
t.Setenv("CLOUD_SQL_POSTGRES_DATABASE", "your_pg_db")
|
|
t.Setenv("CLOUD_SQL_POSTGRES_REGION", "your_pg_region")
|
|
t.Setenv("CLOUD_SQL_POSTGRES_USER", "your_pg_user")
|
|
t.Setenv("CLOUD_SQL_POSTGRES_PASS", "your_pg_pass")
|
|
|
|
t.Setenv("CLOUD_SQL_MYSQL_PROJECT", "your_gcp_project_id")
|
|
t.Setenv("CLOUD_SQL_MYSQL_REGION", "your_gcp_region")
|
|
t.Setenv("CLOUD_SQL_MYSQL_INSTANCE", "your_instance")
|
|
t.Setenv("CLOUD_SQL_MYSQL_DATABASE", "your_cloudsql_mysql_db")
|
|
t.Setenv("CLOUD_SQL_MYSQL_USER", "your_cloudsql_mysql_user")
|
|
t.Setenv("CLOUD_SQL_MYSQL_PASSWORD", "your_cloudsql_mysql_password")
|
|
|
|
t.Setenv("CLOUD_SQL_MSSQL_PROJECT", "your_gcp_project_id")
|
|
t.Setenv("CLOUD_SQL_MSSQL_REGION", "your_gcp_region")
|
|
t.Setenv("CLOUD_SQL_MSSQL_INSTANCE", "your_cloudsql_mssql_instance")
|
|
t.Setenv("CLOUD_SQL_MSSQL_DATABASE", "your_cloudsql_mssql_db")
|
|
t.Setenv("CLOUD_SQL_MSSQL_IP_ADDRESS", "127.0.0.1")
|
|
t.Setenv("CLOUD_SQL_MSSQL_USER", "your_cloudsql_mssql_user")
|
|
t.Setenv("CLOUD_SQL_MSSQL_PASSWORD", "your_cloudsql_mssql_password")
|
|
t.Setenv("CLOUD_SQL_POSTGRES_PASSWORD", "your_cloudsql_pg_password")
|
|
|
|
t.Setenv("SERVERLESS_SPARK_PROJECT", "your_gcp_project_id")
|
|
t.Setenv("SERVERLESS_SPARK_LOCATION", "your_gcp_location")
|
|
|
|
t.Setenv("POSTGRES_HOST", "localhost")
|
|
t.Setenv("POSTGRES_PORT", "5432")
|
|
t.Setenv("POSTGRES_DATABASE", "your_postgres_db")
|
|
t.Setenv("POSTGRES_USER", "your_postgres_user")
|
|
t.Setenv("POSTGRES_PASSWORD", "your_postgres_password")
|
|
|
|
t.Setenv("MYSQL_HOST", "localhost")
|
|
t.Setenv("MYSQL_PORT", "3306")
|
|
t.Setenv("MYSQL_DATABASE", "your_mysql_db")
|
|
t.Setenv("MYSQL_USER", "your_mysql_user")
|
|
t.Setenv("MYSQL_PASSWORD", "your_mysql_password")
|
|
|
|
t.Setenv("MSSQL_HOST", "localhost")
|
|
t.Setenv("MSSQL_PORT", "1433")
|
|
t.Setenv("MSSQL_DATABASE", "your_mssql_db")
|
|
t.Setenv("MSSQL_USER", "your_mssql_user")
|
|
t.Setenv("MSSQL_PASSWORD", "your_mssql_password")
|
|
|
|
t.Setenv("MINDSDB_HOST", "localhost")
|
|
t.Setenv("MINDSDB_PORT", "47334")
|
|
t.Setenv("MINDSDB_DATABASE", "your_mindsdb_db")
|
|
t.Setenv("MINDSDB_USER", "your_mindsdb_user")
|
|
t.Setenv("MINDSDB_PASS", "your_mindsdb_password")
|
|
|
|
t.Setenv("LOOKER_BASE_URL", "https://your_company.looker.com")
|
|
t.Setenv("LOOKER_CLIENT_ID", "your_looker_client_id")
|
|
t.Setenv("LOOKER_CLIENT_SECRET", "your_looker_client_secret")
|
|
t.Setenv("LOOKER_VERIFY_SSL", "true")
|
|
|
|
t.Setenv("LOOKER_PROJECT", "your_project_id")
|
|
t.Setenv("LOOKER_LOCATION", "us")
|
|
|
|
t.Setenv("SQLITE_DATABASE", "test.db")
|
|
|
|
t.Setenv("NEO4J_URI", "bolt://localhost:7687")
|
|
t.Setenv("NEO4J_DATABASE", "neo4j")
|
|
t.Setenv("NEO4J_USERNAME", "your_neo4j_user")
|
|
t.Setenv("NEO4J_PASSWORD", "your_neo4j_password")
|
|
|
|
t.Setenv("CLOUD_HEALTHCARE_PROJECT", "your_gcp_project_id")
|
|
t.Setenv("CLOUD_HEALTHCARE_REGION", "your_gcp_region")
|
|
t.Setenv("CLOUD_HEALTHCARE_DATASET", "your_healthcare_dataset")
|
|
|
|
ctx, err := testutils.ContextWithNewLogger()
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %s", err)
|
|
}
|
|
tcs := []struct {
|
|
name string
|
|
in []byte
|
|
wantToolset server.ToolsetConfigs
|
|
}{
|
|
{
|
|
name: "alloydb postgres admin prebuilt tools",
|
|
in: alloydb_admin_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"alloydb_postgres_admin_tools": tools.ToolsetConfig{
|
|
Name: "alloydb_postgres_admin_tools",
|
|
ToolNames: []string{"create_cluster", "wait_for_operation", "create_instance", "list_clusters", "list_instances", "list_users", "create_user", "get_cluster", "get_instance", "get_user"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "cloudsql pg admin prebuilt tools",
|
|
in: cloudsqlpg_admin_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"cloud_sql_postgres_admin_tools": tools.ToolsetConfig{
|
|
Name: "cloud_sql_postgres_admin_tools",
|
|
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "cloudsql mysql admin prebuilt tools",
|
|
in: cloudsqlmysql_admin_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"cloud_sql_mysql_admin_tools": tools.ToolsetConfig{
|
|
Name: "cloud_sql_mysql_admin_tools",
|
|
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "cloudsql mssql admin prebuilt tools",
|
|
in: cloudsqlmssql_admin_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"cloud_sql_mssql_admin_tools": tools.ToolsetConfig{
|
|
Name: "cloud_sql_mssql_admin_tools",
|
|
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "alloydb prebuilt tools",
|
|
in: alloydb_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"alloydb_postgres_database_tools": tools.ToolsetConfig{
|
|
Name: "alloydb_postgres_database_tools",
|
|
ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "bigquery prebuilt tools",
|
|
in: bigquery_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"bigquery_database_tools": tools.ToolsetConfig{
|
|
Name: "bigquery_database_tools",
|
|
ToolNames: []string{"analyze_contribution", "ask_data_insights", "execute_sql", "forecast", "get_dataset_info", "get_table_info", "list_dataset_ids", "list_table_ids", "search_catalog"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "clickhouse prebuilt tools",
|
|
in: clickhouse_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"clickhouse_database_tools": tools.ToolsetConfig{
|
|
Name: "clickhouse_database_tools",
|
|
ToolNames: []string{"execute_sql", "list_databases", "list_tables"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "cloudsqlpg prebuilt tools",
|
|
in: cloudsqlpg_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"cloud_sql_postgres_database_tools": tools.ToolsetConfig{
|
|
Name: "cloud_sql_postgres_database_tools",
|
|
ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "cloudsqlmysql prebuilt tools",
|
|
in: cloudsqlmysql_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"cloud_sql_mysql_database_tools": tools.ToolsetConfig{
|
|
Name: "cloud_sql_mysql_database_tools",
|
|
ToolNames: []string{"execute_sql", "list_tables", "get_query_plan", "list_active_queries", "list_tables_missing_unique_indexes", "list_table_fragmentation"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "cloudsqlmssql prebuilt tools",
|
|
in: cloudsqlmssql_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"cloud_sql_mssql_database_tools": tools.ToolsetConfig{
|
|
Name: "cloud_sql_mssql_database_tools",
|
|
ToolNames: []string{"execute_sql", "list_tables"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "dataplex prebuilt tools",
|
|
in: dataplex_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"dataplex_tools": tools.ToolsetConfig{
|
|
Name: "dataplex_tools",
|
|
ToolNames: []string{"search_entries", "lookup_entry", "search_aspect_types"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "serverless spark prebuilt tools",
|
|
in: serverless_spark_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"serverless_spark_tools": tools.ToolsetConfig{
|
|
Name: "serverless_spark_tools",
|
|
ToolNames: []string{"list_batches", "get_batch", "cancel_batch"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "firestore prebuilt tools",
|
|
in: firestoreconfig,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"firestore_database_tools": tools.ToolsetConfig{
|
|
Name: "firestore_database_tools",
|
|
ToolNames: []string{"get_documents", "add_documents", "update_document", "list_collections", "delete_documents", "query_collection", "get_rules", "validate_rules"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "mysql prebuilt tools",
|
|
in: mysql_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"mysql_database_tools": tools.ToolsetConfig{
|
|
Name: "mysql_database_tools",
|
|
ToolNames: []string{"execute_sql", "list_tables", "get_query_plan", "list_active_queries", "list_tables_missing_unique_indexes", "list_table_fragmentation"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "mssql prebuilt tools",
|
|
in: mssql_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"mssql_database_tools": tools.ToolsetConfig{
|
|
Name: "mssql_database_tools",
|
|
ToolNames: []string{"execute_sql", "list_tables"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "looker prebuilt tools",
|
|
in: looker_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"looker_tools": tools.ToolsetConfig{
|
|
Name: "looker_tools",
|
|
ToolNames: []string{"get_models", "get_explores", "get_dimensions", "get_measures", "get_filters", "get_parameters", "query", "query_sql", "query_url", "get_looks", "run_look", "make_look", "get_dashboards", "run_dashboard", "make_dashboard", "add_dashboard_element", "health_pulse", "health_analyze", "health_vacuum", "dev_mode", "get_projects", "get_project_files", "get_project_file", "create_project_file", "update_project_file", "delete_project_file", "get_connections", "get_connection_schemas", "get_connection_databases", "get_connection_tables", "get_connection_table_columns"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "looker-conversational-analytics prebuilt tools",
|
|
in: lookerca_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"looker_conversational_analytics_tools": tools.ToolsetConfig{
|
|
Name: "looker_conversational_analytics_tools",
|
|
ToolNames: []string{"ask_data_insights", "get_models", "get_explores"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "postgres prebuilt tools",
|
|
in: postgresconfig,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"postgres_database_tools": tools.ToolsetConfig{
|
|
Name: "postgres_database_tools",
|
|
ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "spanner prebuilt tools",
|
|
in: spanner_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"spanner-database-tools": tools.ToolsetConfig{
|
|
Name: "spanner-database-tools",
|
|
ToolNames: []string{"execute_sql", "execute_sql_dql", "list_tables"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "spanner pg prebuilt tools",
|
|
in: spannerpg_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"spanner_postgres_database_tools": tools.ToolsetConfig{
|
|
Name: "spanner_postgres_database_tools",
|
|
ToolNames: []string{"execute_sql", "execute_sql_dql", "list_tables"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "mindsdb prebuilt tools",
|
|
in: mindsdb_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"mindsdb-tools": tools.ToolsetConfig{
|
|
Name: "mindsdb-tools",
|
|
ToolNames: []string{"mindsdb-execute-sql", "mindsdb-sql"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "sqlite prebuilt tools",
|
|
in: sqlite_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"sqlite_database_tools": tools.ToolsetConfig{
|
|
Name: "sqlite_database_tools",
|
|
ToolNames: []string{"execute_sql", "list_tables"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "neo4j prebuilt tools",
|
|
in: neo4jconfig,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"neo4j_database_tools": tools.ToolsetConfig{
|
|
Name: "neo4j_database_tools",
|
|
ToolNames: []string{"execute_cypher", "get_schema"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "alloydb postgres observability prebuilt tools",
|
|
in: alloydbobsvconfig,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"alloydb_postgres_cloud_monitoring_tools": tools.ToolsetConfig{
|
|
Name: "alloydb_postgres_cloud_monitoring_tools",
|
|
ToolNames: []string{"get_system_metrics", "get_query_metrics"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "cloudsql postgres observability prebuilt tools",
|
|
in: cloudsqlpgobsvconfig,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"cloud_sql_postgres_cloud_monitoring_tools": tools.ToolsetConfig{
|
|
Name: "cloud_sql_postgres_cloud_monitoring_tools",
|
|
ToolNames: []string{"get_system_metrics", "get_query_metrics"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "cloudsql mysql observability prebuilt tools",
|
|
in: cloudsqlmysqlobsvconfig,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"cloud_sql_mysql_cloud_monitoring_tools": tools.ToolsetConfig{
|
|
Name: "cloud_sql_mysql_cloud_monitoring_tools",
|
|
ToolNames: []string{"get_system_metrics", "get_query_metrics"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "cloudsql mssql observability prebuilt tools",
|
|
in: cloudsqlmssqlobsvconfig,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"cloud_sql_mssql_cloud_monitoring_tools": tools.ToolsetConfig{
|
|
Name: "cloud_sql_mssql_cloud_monitoring_tools",
|
|
ToolNames: []string{"get_system_metrics"},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "cloud healthcare prebuilt tools",
|
|
in: cloudhealthcare_config,
|
|
wantToolset: server.ToolsetConfigs{
|
|
"cloud_healthcare_dataset_tools": tools.ToolsetConfig{
|
|
Name: "cloud_healthcare_dataset_tools",
|
|
ToolNames: []string{"get_dataset", "list_dicom_stores", "list_fhir_stores"},
|
|
},
|
|
"cloud_healthcare_fhir_tools": tools.ToolsetConfig{
|
|
Name: "cloud_healthcare_fhir_tools",
|
|
ToolNames: []string{"get_fhir_store", "get_fhir_store_metrics", "get_fhir_resource", "fhir_patient_search", "fhir_patient_everything", "fhir_fetch_page"},
|
|
},
|
|
"cloud_healthcare_dicom_tools": tools.ToolsetConfig{
|
|
Name: "cloud_healthcare_dicom_tools",
|
|
ToolNames: []string{"get_dicom_store", "get_dicom_store_metrics", "search_dicom_studies", "search_dicom_series", "search_dicom_instances", "retrieve_rendered_dicom_instance"},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range tcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
toolsFile, err := parseToolsFile(ctx, tc.in)
|
|
if err != nil {
|
|
t.Fatalf("failed to parse input: %v", err)
|
|
}
|
|
if diff := cmp.Diff(tc.wantToolset, toolsFile.Toolsets); diff != "" {
|
|
t.Fatalf("incorrect tools parse: diff %v", diff)
|
|
}
|
|
// Prebuilt configs do not have prompts, so assert empty maps.
|
|
if len(toolsFile.Prompts) != 0 {
|
|
t.Fatalf("expected empty prompts map for prebuilt config, got: %v", toolsFile.Prompts)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestMutuallyExclusiveFlags(t *testing.T) {
|
|
testCases := []struct {
|
|
desc string
|
|
args []string
|
|
errString string
|
|
}{
|
|
{
|
|
desc: "--prebuilt and --tools-file",
|
|
args: []string{"--prebuilt", "alloydb", "--tools-file", "my.yaml"},
|
|
errString: "--prebuilt and --tools-file/--tools-files/--tools-folder flags cannot be used simultaneously",
|
|
},
|
|
{
|
|
desc: "--tools-file and --tools-files",
|
|
args: []string{"--tools-file", "my.yaml", "--tools-files", "a.yaml,b.yaml"},
|
|
errString: "--tools-file, --tools-files, and --tools-folder flags cannot be used simultaneously",
|
|
},
|
|
{
|
|
desc: "--tools-folder and --tools-files",
|
|
args: []string{"--tools-folder", "./", "--tools-files", "a.yaml,b.yaml"},
|
|
errString: "--tools-file, --tools-files, and --tools-folder flags cannot be used simultaneously",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.desc, func(t *testing.T) {
|
|
cmd := NewCommand()
|
|
cmd.SetArgs(tc.args)
|
|
err := cmd.Execute()
|
|
if err == nil {
|
|
t.Fatalf("expected an error but got none")
|
|
}
|
|
if !strings.Contains(err.Error(), tc.errString) {
|
|
t.Errorf("expected error message to contain %q, but got %q", tc.errString, err.Error())
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFileLoadingErrors(t *testing.T) {
|
|
t.Run("non-existent tools-file", func(t *testing.T) {
|
|
cmd := NewCommand()
|
|
// Use a file that is guaranteed not to exist
|
|
nonExistentFile := filepath.Join(t.TempDir(), "non-existent-tools.yaml")
|
|
cmd.SetArgs([]string{"--tools-file", nonExistentFile})
|
|
|
|
err := cmd.Execute()
|
|
if err == nil {
|
|
t.Fatal("expected an error for non-existent file but got none")
|
|
}
|
|
if !strings.Contains(err.Error(), "unable to read tool file") {
|
|
t.Errorf("expected error about reading file, but got: %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("non-existent tools-folder", func(t *testing.T) {
|
|
cmd := NewCommand()
|
|
nonExistentFolder := filepath.Join(t.TempDir(), "non-existent-folder")
|
|
cmd.SetArgs([]string{"--tools-folder", nonExistentFolder})
|
|
|
|
err := cmd.Execute()
|
|
if err == nil {
|
|
t.Fatal("expected an error for non-existent folder but got none")
|
|
}
|
|
if !strings.Contains(err.Error(), "unable to access tools folder") {
|
|
t.Errorf("expected error about accessing folder, but got: %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestMergeToolsFiles(t *testing.T) {
|
|
file1 := ToolsFile{
|
|
Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}},
|
|
Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}},
|
|
Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}},
|
|
}
|
|
file2 := ToolsFile{
|
|
AuthServices: server.AuthServiceConfigs{"auth1": google.Config{Name: "auth1"}},
|
|
Tools: server.ToolConfigs{"tool2": http.Config{Name: "tool2"}},
|
|
Toolsets: server.ToolsetConfigs{"set2": tools.ToolsetConfig{Name: "set2"}},
|
|
}
|
|
fileWithConflicts := ToolsFile{
|
|
Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}},
|
|
Tools: server.ToolConfigs{"tool2": http.Config{Name: "tool2"}},
|
|
}
|
|
|
|
testCases := []struct {
|
|
name string
|
|
files []ToolsFile
|
|
want ToolsFile
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "merge two distinct files",
|
|
files: []ToolsFile{file1, file2},
|
|
want: ToolsFile{
|
|
Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}},
|
|
AuthServices: server.AuthServiceConfigs{"auth1": google.Config{Name: "auth1"}},
|
|
Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}, "tool2": http.Config{Name: "tool2"}},
|
|
Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}, "set2": tools.ToolsetConfig{Name: "set2"}},
|
|
Prompts: server.PromptConfigs{},
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "merge with conflicts",
|
|
files: []ToolsFile{file1, file2, fileWithConflicts},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "merge single file",
|
|
files: []ToolsFile{file1},
|
|
want: ToolsFile{
|
|
Sources: file1.Sources,
|
|
AuthServices: make(server.AuthServiceConfigs),
|
|
Tools: file1.Tools,
|
|
Toolsets: file1.Toolsets,
|
|
Prompts: server.PromptConfigs{},
|
|
},
|
|
},
|
|
{
|
|
name: "merge empty list",
|
|
files: []ToolsFile{},
|
|
want: ToolsFile{
|
|
Sources: make(server.SourceConfigs),
|
|
AuthServices: make(server.AuthServiceConfigs),
|
|
Tools: make(server.ToolConfigs),
|
|
Toolsets: make(server.ToolsetConfigs),
|
|
Prompts: server.PromptConfigs{},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
got, err := mergeToolsFiles(tc.files...)
|
|
if (err != nil) != tc.wantErr {
|
|
t.Fatalf("mergeToolsFiles() error = %v, wantErr %v", err, tc.wantErr)
|
|
}
|
|
if !tc.wantErr {
|
|
if diff := cmp.Diff(tc.want, got); diff != "" {
|
|
t.Errorf("mergeToolsFiles() mismatch (-want +got):\n%s", diff)
|
|
}
|
|
} else {
|
|
if err == nil {
|
|
t.Fatal("expected an error for conflicting files but got none")
|
|
}
|
|
if !strings.Contains(err.Error(), "resource conflicts detected") {
|
|
t.Errorf("expected conflict error, but got: %v", err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|