ci: Increase AlloyDB integration test coverage (#187)

add tests for tool invocation with params, connection over public IP,
and connection over private IP.
This commit is contained in:
Wenxin Du
2025-01-14 12:07:34 +08:00
committed by GitHub
parent 0c86e89506
commit 6ffcca0573
3 changed files with 306 additions and 3 deletions

2
go.mod
View File

@@ -14,6 +14,7 @@ require (
github.com/go-chi/httplog/v2 v2.1.1
github.com/go-chi/render v1.0.3
github.com/google/go-cmp v0.6.0
github.com/google/uuid v1.6.0
github.com/jackc/pgx/v5 v5.7.1
github.com/spf13/cobra v1.8.1
go.opentelemetry.io/contrib/propagators/autoprop v0.58.0
@@ -53,7 +54,6 @@ require (
github.com/go-logr/stdr v1.2.2 // indirect
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
github.com/google/s2a-go v0.1.8 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
github.com/googleapis/gax-go/v2 v2.14.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0 // indirect

View File

@@ -20,13 +20,20 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"reflect"
"regexp"
"strings"
"testing"
"time"
"cloud.google.com/go/alloydbconn"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
)
var (
@@ -39,7 +46,7 @@ var (
ALLOYDB_POSTGRES_PASS = os.Getenv("ALLOYDB_POSTGRES_PASS")
)
func requireAlloyDBPgVars(t *testing.T) {
func requireAlloyDBPgVars(t *testing.T) map[string]any {
switch "" {
case ALLOYDB_POSTGRES_PROJECT:
t.Fatal("'ALLOYDB_POSTGRES_PROJECT' not set")
@@ -56,9 +63,64 @@ func requireAlloyDBPgVars(t *testing.T) {
case ALLOYDB_POSTGRES_PASS:
t.Fatal("'ALLOYDB_POSTGRES_PASS' not set")
}
return map[string]any{
"kind": "alloydb-postgres",
"project": ALLOYDB_POSTGRES_PROJECT,
"cluster": ALLOYDB_POSTGRES_CLUSTER,
"instance": ALLOYDB_POSTGRES_INSTANCE,
"region": ALLOYDB_POSTGRES_REGION,
"database": ALLOYDB_POSTGRES_DATABASE,
"user": ALLOYDB_POSTGRES_USER,
"password": ALLOYDB_POSTGRES_PASS,
}
}
func TestAlloyDBPostgres(t *testing.T) {
// Copied over from alloydb_pg.go
func getDialOpts(ip_type string) ([]alloydbconn.DialOption, error) {
switch strings.ToLower(ip_type) {
case "private":
return []alloydbconn.DialOption{alloydbconn.WithPrivateIP()}, nil
case "public":
return []alloydbconn.DialOption{alloydbconn.WithPublicIP()}, nil
default:
return nil, fmt.Errorf("invalid ip_type %s", ip_type)
}
}
// Copied over from alloydb_pg.go
func initAlloyDBPgConnectionPool(project, region, cluster, instance, ip_type, user, pass, dbname string) (*pgxpool.Pool, error) {
// Configure the driver to connect to the database
dsn := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=disable", user, pass, dbname)
config, err := pgxpool.ParseConfig(dsn)
if err != nil {
return nil, fmt.Errorf("unable to parse connection uri: %w", err)
}
// Create a new dialer with options
dialOpts, err := getDialOpts(ip_type)
if err != nil {
return nil, err
}
d, err := alloydbconn.NewDialer(context.Background(), alloydbconn.WithDefaultDialOptions(dialOpts...))
if err != nil {
return nil, fmt.Errorf("unable to parse connection uri: %w", err)
}
// Tell the driver to use the AlloyDB Go Connector to create connections
i := fmt.Sprintf("projects/%s/locations/%s/clusters/%s/instances/%s", project, region, cluster, instance)
config.ConnConfig.DialFunc = func(ctx context.Context, _ string, instance string) (net.Conn, error) {
return d.Dial(ctx, i)
}
// Interact with the driver directly as you normally would
pool, err := pgxpool.NewWithConfig(context.Background(), config)
if err != nil {
return nil, err
}
return pool, nil
}
func TestAlloyDBSimpleToolEndpoints(t *testing.T) {
requireAlloyDBPgVars(t)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
@@ -187,3 +249,79 @@ func TestAlloyDBPostgres(t *testing.T) {
})
}
}
func TestPublicIpConnection(t *testing.T) {
// Test connecting to an AlloyDB source's public IP
sourceConfig := requireAlloyDBPgVars(t)
sourceConfig["ipType"] = "public"
RunSourceConnectionTest(t, sourceConfig, "postgres-sql")
}
func TestPrivateIpConnection(t *testing.T) {
// Test connecting to an AlloyDB source's private IP
sourceConfig := requireAlloyDBPgVars(t)
sourceConfig["ipType"] = "private"
RunSourceConnectionTest(t, sourceConfig, "postgres-sql")
}
func setupParamTest(t *testing.T, ctx context.Context, tableName string) func(*testing.T) {
// Set up Tool invocation with parameters test
pool, err := initAlloyDBPgConnectionPool(ALLOYDB_POSTGRES_PROJECT, ALLOYDB_POSTGRES_REGION, ALLOYDB_POSTGRES_CLUSTER, ALLOYDB_POSTGRES_INSTANCE, "public", ALLOYDB_POSTGRES_USER, ALLOYDB_POSTGRES_PASS, ALLOYDB_POSTGRES_DATABASE)
if err != nil {
t.Fatalf("unable to create AlloyDB connection pool: %s", err)
}
err = pool.Ping(ctx)
if err != nil {
t.Fatalf("unable to connect to test database: %s", err)
}
_, err = pool.Query(ctx, fmt.Sprintf(`
CREATE TABLE %s (
id SERIAL PRIMARY KEY,
name TEXT
);
`, tableName))
if err != nil {
t.Fatalf("unable to create test table %s: %s", tableName, err)
}
// Insert test data
statement := fmt.Sprintf(`
INSERT INTO %s (name)
VALUES ($1), ($2), ($3);
`, tableName)
params := []any{"Alice", "Jane", "Sid"}
_, err = pool.Query(ctx, statement, params...)
if err != nil {
t.Fatalf("unable to insert test data: %s", err)
}
return func(t *testing.T) {
// tear down test
_, err = pool.Exec(ctx, fmt.Sprintf("DROP TABLE %s;", tableName))
if err != nil {
t.Errorf("Teardown failed: %s", err)
}
}
}
func TestToolInvocationWithParams(t *testing.T) {
// create test configs
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
// create source config
sourceConfig := requireAlloyDBPgVars(t)
// create table name with UUID
tableName := "param_test_table_" + strings.Replace(uuid.New().String(), "-", "", -1)
// test setup function reterns teardown function
teardownTest := setupParamTest(t, ctx, tableName)
defer teardownTest(t)
// call generic invocation test helper
RunToolInvocationWithParamsTest(t, sourceConfig, "postgres-sql", tableName)
}

View File

@@ -22,12 +22,17 @@ package tests
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"regexp"
"strings"
"testing"
"time"
"gopkg.in/yaml.v3"
@@ -203,3 +208,163 @@ func (c *CmdExec) WaitForString(ctx context.Context, re *regexp.Regexp) (string,
}
}
}
func RunToolInvocationWithParamsTest(t *testing.T, sourceConfig map[string]any, toolKind string, tableName string) {
// Write config into a file and pass it to command
var statement string
switch toolKind {
case "postgres-sql":
statement = fmt.Sprintf("SELECT * FROM %s WHERE id = $1 OR name = $2;", tableName)
default:
t.Fatalf("invalid tool kind: %s", toolKind)
}
toolsFile := map[string]any{
"sources": map[string]any{
"my-instance": sourceConfig,
},
"tools": map[string]any{
"my-tool": map[string]any{
"kind": toolKind,
"source": "my-instance",
"description": "Tool to test invocation with params.",
"statement": statement,
"parameters": []any{
map[string]any{
"name": "id",
"type": "integer",
"description": "user ID",
},
map[string]any{
"name": "name",
"type": "string",
"description": "user name",
},
},
},
},
}
// Initialize a test command
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
var args []string
cmd, cleanup, err := StartCmd(ctx, toolsFile, args...)
if err != nil {
t.Fatalf("command initialization returned an error: %s", err)
}
defer cleanup()
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Minute)
defer cancel()
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
if err != nil {
t.Logf("toolbox command logs: \n%s", out)
t.Fatalf("toolbox didn't start successfully: %s", err)
}
// Test Tool invocation with parameters
invokeTcs := []struct {
name string
api string
requestBody io.Reader
want string
isErr bool
}{
{
name: "Invoke my-tool with parameters",
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
requestBody: bytes.NewBuffer([]byte(`{"id": 3, "name": "Alice"}`)),
isErr: false,
want: "Stub tool call for \"my-tool\"! Parameters parsed: [{\"id\" '\\x03'} {\"name\" \"Alice\"}] \n Output: [%!s(int32=1) Alice][%!s(int32=3) Sid]",
},
{
name: "Invoke my-tool without parameters",
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
requestBody: bytes.NewBuffer([]byte(`{}`)),
isErr: true,
},
{
name: "Invoke my-tool without insufficient parameters",
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
requestBody: bytes.NewBuffer([]byte(`{"id": 1}`)),
isErr: true,
},
}
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
// Send Tool invocation request with parameters
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
if err != nil {
t.Fatalf("unable to create request: %s", err)
}
req.Header.Add("Content-type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("unable to send request: %s", err)
}
if resp.StatusCode != http.StatusOK {
if tc.isErr == true {
return
}
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
}
// Check response body
var body map[string]interface{}
err = json.NewDecoder(resp.Body).Decode(&body)
if err != nil {
t.Fatalf("error parsing response body")
}
got, ok := body["result"].(string)
if !ok {
t.Fatalf("unable to find result in response body")
}
if got != tc.want {
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
}
})
}
}
func RunSourceConnectionTest(t *testing.T, sourceConfig map[string]any, toolKind string) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
var args []string
// Write config into a file and pass it to command
toolsFile := map[string]any{
"sources": map[string]any{
"my-instance": sourceConfig,
},
"tools": map[string]any{
"my-simple-tool": map[string]any{
"kind": toolKind,
"source": "my-instance",
"description": "Simple tool to test end to end functionality.",
"statement": "SELECT 1;",
},
},
}
cmd, cleanup, err := StartCmd(ctx, toolsFile, args...)
if err != nil {
t.Fatalf("command initialization returned an error: %s", err)
}
defer cleanup()
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
if err != nil {
t.Logf("toolbox command logs: \n%s", out)
t.Fatalf("toolbox didn't start successfully: %s", err)
}
}