mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 15:38:08 -05:00
ci: Increase AlloyDB integration test coverage (#187)
add tests for tool invocation with params, connection over public IP, and connection over private IP.
This commit is contained in:
2
go.mod
2
go.mod
@@ -14,6 +14,7 @@ require (
|
||||
github.com/go-chi/httplog/v2 v2.1.1
|
||||
github.com/go-chi/render v1.0.3
|
||||
github.com/google/go-cmp v0.6.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.7.1
|
||||
github.com/spf13/cobra v1.8.1
|
||||
go.opentelemetry.io/contrib/propagators/autoprop v0.58.0
|
||||
@@ -53,7 +54,6 @@ require (
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
|
||||
github.com/google/s2a-go v0.1.8 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.14.0 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0 // indirect
|
||||
|
||||
@@ -20,13 +20,20 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cloud.google.com/go/alloydbconn"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -39,7 +46,7 @@ var (
|
||||
ALLOYDB_POSTGRES_PASS = os.Getenv("ALLOYDB_POSTGRES_PASS")
|
||||
)
|
||||
|
||||
func requireAlloyDBPgVars(t *testing.T) {
|
||||
func requireAlloyDBPgVars(t *testing.T) map[string]any {
|
||||
switch "" {
|
||||
case ALLOYDB_POSTGRES_PROJECT:
|
||||
t.Fatal("'ALLOYDB_POSTGRES_PROJECT' not set")
|
||||
@@ -56,9 +63,64 @@ func requireAlloyDBPgVars(t *testing.T) {
|
||||
case ALLOYDB_POSTGRES_PASS:
|
||||
t.Fatal("'ALLOYDB_POSTGRES_PASS' not set")
|
||||
}
|
||||
return map[string]any{
|
||||
"kind": "alloydb-postgres",
|
||||
"project": ALLOYDB_POSTGRES_PROJECT,
|
||||
"cluster": ALLOYDB_POSTGRES_CLUSTER,
|
||||
"instance": ALLOYDB_POSTGRES_INSTANCE,
|
||||
"region": ALLOYDB_POSTGRES_REGION,
|
||||
"database": ALLOYDB_POSTGRES_DATABASE,
|
||||
"user": ALLOYDB_POSTGRES_USER,
|
||||
"password": ALLOYDB_POSTGRES_PASS,
|
||||
}
|
||||
}
|
||||
|
||||
func TestAlloyDBPostgres(t *testing.T) {
|
||||
// Copied over from alloydb_pg.go
|
||||
func getDialOpts(ip_type string) ([]alloydbconn.DialOption, error) {
|
||||
switch strings.ToLower(ip_type) {
|
||||
case "private":
|
||||
return []alloydbconn.DialOption{alloydbconn.WithPrivateIP()}, nil
|
||||
case "public":
|
||||
return []alloydbconn.DialOption{alloydbconn.WithPublicIP()}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid ip_type %s", ip_type)
|
||||
}
|
||||
}
|
||||
|
||||
// Copied over from alloydb_pg.go
|
||||
func initAlloyDBPgConnectionPool(project, region, cluster, instance, ip_type, user, pass, dbname string) (*pgxpool.Pool, error) {
|
||||
// Configure the driver to connect to the database
|
||||
dsn := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=disable", user, pass, dbname)
|
||||
config, err := pgxpool.ParseConfig(dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse connection uri: %w", err)
|
||||
}
|
||||
|
||||
// Create a new dialer with options
|
||||
dialOpts, err := getDialOpts(ip_type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d, err := alloydbconn.NewDialer(context.Background(), alloydbconn.WithDefaultDialOptions(dialOpts...))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse connection uri: %w", err)
|
||||
}
|
||||
|
||||
// Tell the driver to use the AlloyDB Go Connector to create connections
|
||||
i := fmt.Sprintf("projects/%s/locations/%s/clusters/%s/instances/%s", project, region, cluster, instance)
|
||||
config.ConnConfig.DialFunc = func(ctx context.Context, _ string, instance string) (net.Conn, error) {
|
||||
return d.Dial(ctx, i)
|
||||
}
|
||||
|
||||
// Interact with the driver directly as you normally would
|
||||
pool, err := pgxpool.NewWithConfig(context.Background(), config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
func TestAlloyDBSimpleToolEndpoints(t *testing.T) {
|
||||
requireAlloyDBPgVars(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
@@ -187,3 +249,79 @@ func TestAlloyDBPostgres(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublicIpConnection(t *testing.T) {
|
||||
// Test connecting to an AlloyDB source's public IP
|
||||
sourceConfig := requireAlloyDBPgVars(t)
|
||||
sourceConfig["ipType"] = "public"
|
||||
RunSourceConnectionTest(t, sourceConfig, "postgres-sql")
|
||||
}
|
||||
|
||||
func TestPrivateIpConnection(t *testing.T) {
|
||||
// Test connecting to an AlloyDB source's private IP
|
||||
sourceConfig := requireAlloyDBPgVars(t)
|
||||
sourceConfig["ipType"] = "private"
|
||||
RunSourceConnectionTest(t, sourceConfig, "postgres-sql")
|
||||
}
|
||||
|
||||
func setupParamTest(t *testing.T, ctx context.Context, tableName string) func(*testing.T) {
|
||||
// Set up Tool invocation with parameters test
|
||||
pool, err := initAlloyDBPgConnectionPool(ALLOYDB_POSTGRES_PROJECT, ALLOYDB_POSTGRES_REGION, ALLOYDB_POSTGRES_CLUSTER, ALLOYDB_POSTGRES_INSTANCE, "public", ALLOYDB_POSTGRES_USER, ALLOYDB_POSTGRES_PASS, ALLOYDB_POSTGRES_DATABASE)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create AlloyDB connection pool: %s", err)
|
||||
}
|
||||
|
||||
err = pool.Ping(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to connect to test database: %s", err)
|
||||
}
|
||||
|
||||
_, err = pool.Query(ctx, fmt.Sprintf(`
|
||||
CREATE TABLE %s (
|
||||
id SERIAL PRIMARY KEY,
|
||||
name TEXT
|
||||
);
|
||||
`, tableName))
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create test table %s: %s", tableName, err)
|
||||
}
|
||||
|
||||
// Insert test data
|
||||
statement := fmt.Sprintf(`
|
||||
INSERT INTO %s (name)
|
||||
VALUES ($1), ($2), ($3);
|
||||
`, tableName)
|
||||
|
||||
params := []any{"Alice", "Jane", "Sid"}
|
||||
_, err = pool.Query(ctx, statement, params...)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to insert test data: %s", err)
|
||||
}
|
||||
|
||||
return func(t *testing.T) {
|
||||
// tear down test
|
||||
_, err = pool.Exec(ctx, fmt.Sprintf("DROP TABLE %s;", tableName))
|
||||
if err != nil {
|
||||
t.Errorf("Teardown failed: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolInvocationWithParams(t *testing.T) {
|
||||
// create test configs
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// create source config
|
||||
sourceConfig := requireAlloyDBPgVars(t)
|
||||
|
||||
// create table name with UUID
|
||||
tableName := "param_test_table_" + strings.Replace(uuid.New().String(), "-", "", -1)
|
||||
|
||||
// test setup function reterns teardown function
|
||||
teardownTest := setupParamTest(t, ctx, tableName)
|
||||
defer teardownTest(t)
|
||||
|
||||
// call generic invocation test helper
|
||||
RunToolInvocationWithParamsTest(t, sourceConfig, "postgres-sql", tableName)
|
||||
}
|
||||
|
||||
@@ -22,12 +22,17 @@ package tests
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
@@ -203,3 +208,163 @@ func (c *CmdExec) WaitForString(ctx context.Context, re *regexp.Regexp) (string,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func RunToolInvocationWithParamsTest(t *testing.T, sourceConfig map[string]any, toolKind string, tableName string) {
|
||||
// Write config into a file and pass it to command
|
||||
var statement string
|
||||
switch toolKind {
|
||||
case "postgres-sql":
|
||||
statement = fmt.Sprintf("SELECT * FROM %s WHERE id = $1 OR name = $2;", tableName)
|
||||
default:
|
||||
t.Fatalf("invalid tool kind: %s", toolKind)
|
||||
}
|
||||
|
||||
toolsFile := map[string]any{
|
||||
"sources": map[string]any{
|
||||
"my-instance": sourceConfig,
|
||||
},
|
||||
"tools": map[string]any{
|
||||
"my-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test invocation with params.",
|
||||
"statement": statement,
|
||||
"parameters": []any{
|
||||
map[string]any{
|
||||
"name": "id",
|
||||
"type": "integer",
|
||||
"description": "user ID",
|
||||
},
|
||||
map[string]any{
|
||||
"name": "name",
|
||||
"type": "string",
|
||||
"description": "user name",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Initialize a test command
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var args []string
|
||||
|
||||
cmd, cleanup, err := StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
|
||||
// Test Tool invocation with parameters
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
|
||||
requestBody io.Reader
|
||||
want string
|
||||
isErr bool
|
||||
}{
|
||||
{
|
||||
name: "Invoke my-tool with parameters",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"id": 3, "name": "Alice"}`)),
|
||||
isErr: false,
|
||||
want: "Stub tool call for \"my-tool\"! Parameters parsed: [{\"id\" '\\x03'} {\"name\" \"Alice\"}] \n Output: [%!s(int32=1) Alice][%!s(int32=3) Sid]",
|
||||
},
|
||||
{
|
||||
name: "Invoke my-tool without parameters",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
isErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-tool without insufficient parameters",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"id": 1}`)),
|
||||
isErr: true,
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Send Tool invocation request with parameters
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if tc.isErr == true {
|
||||
return
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
// Check response body
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
got, ok := body["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
|
||||
if got != tc.want {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func RunSourceConnectionTest(t *testing.T, sourceConfig map[string]any, toolKind string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var args []string
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := map[string]any{
|
||||
"sources": map[string]any{
|
||||
"my-instance": sourceConfig,
|
||||
},
|
||||
"tools": map[string]any{
|
||||
"my-simple-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"statement": "SELECT 1;",
|
||||
},
|
||||
},
|
||||
}
|
||||
cmd, cleanup, err := StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user