test: add prebuilt tools test to all postgres source (#1874)

Move postgres prebuilt integration tests to `common.go` and `tool.go`.
Run those tests from alloydbpg and cloudsqlpg as well.

alloydbpg and cloudsqlpg integration test coverage calculate against the
whole `internal/tools/postgres/` folder. If not added, the coverage will
eventually drop below minimum requirement.
This commit is contained in:
Yuan Teoh
2025-11-06 15:54:38 -08:00
committed by GitHub
parent 47bbbd8c7f
commit 1af43db6f2
5 changed files with 534 additions and 723 deletions

View File

@@ -129,6 +129,9 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
t.Fatalf("unable to create AlloyDB connection pool: %s", err)
}
// cleanup test environment
tests.CleanupPostgresTables(t, ctx, pool)
// create table name with UUID
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
@@ -175,7 +178,14 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
tests.RunMCPToolCallMethod(t, failInvocationWant, mcpSelect1Want)
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam)
// Run Postgres prebuilt tool tests
tests.RunPostgresListTablesTest(t, tableNameParam, tableNameAuth, AlloyDBPostgresUser)
tests.RunPostgresListViewsTest(t, ctx, pool, tableNameParam)
tests.RunPostgresListSchemasTest(t, ctx, pool)
tests.RunPostgresListActiveQueriesTest(t, ctx, pool)
tests.RunPostgresListAvailableExtensionsTest(t)
tests.RunPostgresListInstalledExtensionsTest(t)
}
// Test connection with different IP type

View File

@@ -114,6 +114,9 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
t.Fatalf("unable to create Cloud SQL connection pool: %s", err)
}
// cleanup test environment
tests.CleanupPostgresTables(t, ctx, pool)
// create table name with UUID
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
@@ -159,7 +162,14 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want)
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam)
// Run Postgres prebuilt tool tests
tests.RunPostgresListTablesTest(t, tableNameParam, tableNameAuth, CloudSQLPostgresUser)
tests.RunPostgresListViewsTest(t, ctx, pool, tableNameParam)
tests.RunPostgresListSchemasTest(t, ctx, pool)
tests.RunPostgresListActiveQueriesTest(t, ctx, pool)
tests.RunPostgresListAvailableExtensionsTest(t)
tests.RunPostgresListInstalledExtensionsTest(t)
}
// Test connection with different IP type

View File

@@ -33,10 +33,6 @@ import (
"github.com/jackc/pgx/v5/pgxpool"
)
var (
PostgresListSchemasToolKind = "postgres-list-schemas"
)
// GetToolsConfig returns a mock tools config file
func GetToolsConfig(sourceConfig map[string]any, toolKind, paramToolStatement, idParamToolStmt, nameParamToolStmt, arrayToolStatement, authToolStatement string) map[string]any {
// Write config into a file and pass it to command
@@ -195,10 +191,46 @@ func AddExecuteSqlConfig(t *testing.T, config map[string]any, toolKind string) m
}
func AddPostgresPrebuiltConfig(t *testing.T, config map[string]any) map[string]any {
var (
PostgresListSchemasToolKind = "postgres-list-schemas"
PostgresListTablesToolKind = "postgres-list-tables"
PostgresListActiveQueriesToolKind = "postgres-list-active-queries"
PostgresListInstalledExtensionsToolKind = "postgres-list-installed-extensions"
PostgresListAvailableExtensionsToolKind = "postgres-list-available-extensions"
PostgresListViewsToolKind = "postgres-list-views"
)
tools, ok := config["tools"].(map[string]any)
if !ok {
t.Fatalf("unable to get tools from config")
}
tools["list_tables"] = map[string]any{
"kind": PostgresListTablesToolKind,
"source": "my-instance",
"description": "Lists tables in the database.",
}
tools["list_active_queries"] = map[string]any{
"kind": PostgresListActiveQueriesToolKind,
"source": "my-instance",
"description": "Lists active queries in the database.",
}
tools["list_installed_extensions"] = map[string]any{
"kind": PostgresListInstalledExtensionsToolKind,
"source": "my-instance",
"description": "Lists installed extensions in the database.",
}
tools["list_available_extensions"] = map[string]any{
"kind": PostgresListAvailableExtensionsToolKind,
"source": "my-instance",
"description": "Lists available extensions in the database.",
}
tools["list_views"] = map[string]any{
"kind": PostgresListViewsToolKind,
"source": "my-instance",
}
tools["list_schemas"] = map[string]any{
"kind": PostgresListSchemasToolKind,
"source": "my-instance",

View File

@@ -15,23 +15,15 @@
package postgres
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"reflect"
"regexp"
"sort"
"strings"
"sync"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/uuid"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/tests"
@@ -39,18 +31,13 @@ import (
)
var (
PostgresSourceKind = "postgres"
PostgresToolKind = "postgres-sql"
PostgresListTablesToolKind = "postgres-list-tables"
PostgresListActiveQueriesToolKind = "postgres-list-active-queries"
PostgresListInstalledExtensionsToolKind = "postgres-list-installed-extensions"
PostgresListAvailableExtensionsToolKind = "postgres-list-available-extensions"
PostgresListViewsToolKind = "postgres-list-views"
PostgresDatabase = os.Getenv("POSTGRES_DATABASE")
PostgresHost = os.Getenv("POSTGRES_HOST")
PostgresPort = os.Getenv("POSTGRES_PORT")
PostgresUser = os.Getenv("POSTGRES_USER")
PostgresPass = os.Getenv("POSTGRES_PASS")
PostgresSourceKind = "postgres"
PostgresToolKind = "postgres-sql"
PostgresDatabase = os.Getenv("POSTGRES_DATABASE")
PostgresHost = os.Getenv("POSTGRES_HOST")
PostgresPort = os.Getenv("POSTGRES_PORT")
PostgresUser = os.Getenv("POSTGRES_USER")
PostgresPass = os.Getenv("POSTGRES_PASS")
)
func getPostgresVars(t *testing.T) map[string]any {
@@ -77,43 +64,6 @@ func getPostgresVars(t *testing.T) map[string]any {
}
}
func addPrebuiltToolConfig(t *testing.T, config map[string]any) map[string]any {
tools, ok := config["tools"].(map[string]any)
if !ok {
t.Fatalf("unable to get tools from config")
}
tools["list_tables"] = map[string]any{
"kind": PostgresListTablesToolKind,
"source": "my-instance",
"description": "Lists tables in the database.",
}
tools["list_active_queries"] = map[string]any{
"kind": PostgresListActiveQueriesToolKind,
"source": "my-instance",
"description": "Lists active queries in the database.",
}
tools["list_installed_extensions"] = map[string]any{
"kind": PostgresListInstalledExtensionsToolKind,
"source": "my-instance",
"description": "Lists installed extensions in the database.",
}
tools["list_available_extensions"] = map[string]any{
"kind": PostgresListAvailableExtensionsToolKind,
"source": "my-instance",
"description": "Lists available extensions in the database.",
}
tools["list_views"] = map[string]any{
"kind": PostgresListViewsToolKind,
"source": "my-instance",
}
config["tools"] = tools
return config
}
// Copied over from postgres.go
func initPostgresConnectionPool(host, port, user, pass, dbname string) (*pgxpool.Pool, error) {
// urlExample := "postgres:dd//username:password@localhost:5432/database_name"
@@ -166,8 +116,6 @@ func TestPostgres(t *testing.T) {
toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql")
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, PostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
toolsFile = addPrebuiltToolConfig(t, toolsFile)
toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
@@ -194,506 +142,11 @@ func TestPostgres(t *testing.T) {
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam)
// Run specific Postgres tool tests
runPostgresListTablesTest(t, tableNameParam, tableNameAuth)
runPostgresListViewsTest(t, ctx, pool, tableNameParam)
// Run Postgres prebuilt tool tests
tests.RunPostgresListTablesTest(t, tableNameParam, tableNameAuth, PostgresUser)
tests.RunPostgresListViewsTest(t, ctx, pool, tableNameParam)
tests.RunPostgresListSchemasTest(t, ctx, pool)
runPostgresListActiveQueriesTest(t, ctx, pool)
runPostgresListAvailableExtensionsTest(t)
runPostgresListInstalledExtensionsTest(t)
}
func runPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth string) {
// TableNameParam columns to construct want
paramTableColumns := fmt.Sprintf(`[
{"data_type": "integer", "column_name": "id", "column_default": "nextval('%s_id_seq'::regclass)", "is_not_nullable": true, "ordinal_position": 1, "column_comment": null},
{"data_type": "text", "column_name": "name", "column_default": null, "is_not_nullable": false, "ordinal_position": 2, "column_comment": null}
]`, tableNameParam)
// TableNameAuth columns to construct want
authTableColumns := fmt.Sprintf(`[
{"data_type": "integer", "column_name": "id", "column_default": "nextval('%s_id_seq'::regclass)", "is_not_nullable": true, "ordinal_position": 1, "column_comment": null},
{"data_type": "text", "column_name": "name", "column_default": null, "is_not_nullable": false, "ordinal_position": 2, "column_comment": null},
{"data_type": "text", "column_name": "email", "column_default": null, "is_not_nullable": false, "ordinal_position": 3, "column_comment": null}
]`, tableNameAuth)
const (
// Template to construct detailed output want
detailedObjectTemplate = `{
"object_name": "%[1]s", "schema_name": "public",
"object_details": {
"owner": "%[3]s", "comment": null,
"indexes": [{"is_primary": true, "is_unique": true, "index_name": "%[1]s_pkey", "index_method": "btree", "index_columns": ["id"], "index_definition": "CREATE UNIQUE INDEX %[1]s_pkey ON public.%[1]s USING btree (id)"}],
"triggers": [], "columns": %[2]s, "object_name": "%[1]s", "object_type": "TABLE", "schema_name": "public",
"constraints": [{"constraint_name": "%[1]s_pkey", "constraint_type": "PRIMARY KEY", "constraint_columns": ["id"], "constraint_definition": "PRIMARY KEY (id)", "foreign_key_referenced_table": null, "foreign_key_referenced_columns": null}]
}
}`
// Template to construct simple output want
simpleObjectTemplate = `{"object_name":"%s", "schema_name":"public", "object_details":{"name":"%s"}}`
)
// Helper to build json for detailed want
getDetailedWant := func(tableName, columnJSON string) string {
return fmt.Sprintf(detailedObjectTemplate, tableName, columnJSON, PostgresUser)
}
// Helper to build template for simple want
getSimpleWant := func(tableName string) string {
return fmt.Sprintf(simpleObjectTemplate, tableName, tableName)
}
invokeTcs := []struct {
name string
api string
requestBody io.Reader
wantStatusCode int
want string
isAllTables bool
}{
{
name: "invoke list_tables all tables detailed output",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: bytes.NewBuffer([]byte(`{"table_names": ""}`)),
wantStatusCode: http.StatusOK,
want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)),
isAllTables: true,
},
{
name: "invoke list_tables all tables simple output",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: bytes.NewBuffer([]byte(`{"table_names": "", "output_format": "simple"}`)),
wantStatusCode: http.StatusOK,
want: fmt.Sprintf("[%s,%s]", getSimpleWant(tableNameAuth), getSimpleWant(tableNameParam)),
isAllTables: true,
},
{
name: "invoke list_tables detailed output",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s"}`, tableNameAuth))),
wantStatusCode: http.StatusOK,
want: fmt.Sprintf("[%s]", getDetailedWant(tableNameAuth, authTableColumns)),
},
{
name: "invoke list_tables simple output",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s", "output_format": "simple"}`, tableNameAuth))),
wantStatusCode: http.StatusOK,
want: fmt.Sprintf("[%s]", getSimpleWant(tableNameAuth)),
},
{
name: "invoke list_tables with invalid output format",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: bytes.NewBuffer([]byte(`{"table_names": "", "output_format": "abcd"}`)),
wantStatusCode: http.StatusBadRequest,
},
{
name: "invoke list_tables with malformed table_names parameter",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: bytes.NewBuffer([]byte(`{"table_names": 12345, "output_format": "detailed"}`)),
wantStatusCode: http.StatusBadRequest,
},
{
name: "invoke list_tables with multiple table names",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s,%s"}`, tableNameParam, tableNameAuth))),
wantStatusCode: http.StatusOK,
want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)),
},
{
name: "invoke list_tables with non-existent table",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: bytes.NewBuffer([]byte(`{"table_names": "non_existent_table"}`)),
wantStatusCode: http.StatusOK,
want: `null`,
},
{
name: "invoke list_tables with one existing and one non-existent table",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s,non_existent_table"}`, tableNameParam))),
wantStatusCode: http.StatusOK,
want: fmt.Sprintf("[%s]", getDetailedWant(tableNameParam, paramTableColumns)),
},
}
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
if err != nil {
t.Fatalf("unable to create request: %s", err)
}
req.Header.Add("Content-type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("unable to send request: %s", err)
}
defer resp.Body.Close()
if resp.StatusCode != tc.wantStatusCode {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
}
if tc.wantStatusCode == http.StatusOK {
var bodyWrapper map[string]json.RawMessage
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("error reading response body: %s", err)
}
if err := json.Unmarshal(respBytes, &bodyWrapper); err != nil {
t.Fatalf("error parsing response wrapper: %s, body: %s", err, string(respBytes))
}
resultJSON, ok := bodyWrapper["result"]
if !ok {
t.Fatal("unable to find 'result' in response body")
}
var resultString string
if err := json.Unmarshal(resultJSON, &resultString); err != nil {
t.Fatalf("'result' is not a JSON-encoded string: %s", err)
}
var got, want []any
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
t.Fatalf("failed to unmarshal actual result string: %v", err)
}
if err := json.Unmarshal([]byte(tc.want), &want); err != nil {
t.Fatalf("failed to unmarshal expected want string: %v", err)
}
// Checking only the default public schema where the test tables are created to avoid brittle tests.
if tc.isAllTables {
var filteredGot []any
for _, item := range got {
if tableMap, ok := item.(map[string]interface{}); ok {
if schema, ok := tableMap["schema_name"]; ok && schema == "public" {
filteredGot = append(filteredGot, item)
}
}
}
got = filteredGot
}
sort.SliceStable(got, func(i, j int) bool {
return fmt.Sprintf("%v", got[i]) < fmt.Sprintf("%v", got[j])
})
sort.SliceStable(want, func(i, j int) bool {
return fmt.Sprintf("%v", want[i]) < fmt.Sprintf("%v", want[j])
})
if !reflect.DeepEqual(got, want) {
t.Errorf("Unexpected result: got %#v, want: %#v", got, want)
}
}
})
}
}
func runPostgresListActiveQueriesTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
type queryListDetails struct {
ProcessId any `json:"pid"`
User string `json:"user"`
Datname string `json:"datname"`
ApplicationName string `json:"application_name"`
ClientAddress string `json:"client_addr"`
State string `json:"state"`
WaitEventType string `json:"wait_event_type"`
WaitEvent string `json:"wait_event"`
BackendStart any `json:"backend_start"`
TransactionStart any `json:"xact_start"`
QueryStart any `json:"query_start"`
QueryDuration any `json:"query_duration"`
Query string `json:"query"`
}
singleQueryWanted := queryListDetails{
ProcessId: any(nil),
User: "",
Datname: "",
ApplicationName: "",
ClientAddress: "",
State: "",
WaitEventType: "",
WaitEvent: "",
BackendStart: any(nil),
TransactionStart: any(nil),
QueryStart: any(nil),
QueryDuration: any(nil),
Query: "SELECT pg_sleep(10);",
}
invokeTcs := []struct {
name string
requestBody io.Reader
clientSleepSecs int
waitSecsBeforeCheck int
wantStatusCode int
want any
}{
// exclude background monitoring apps such as "wal_uploader"
{
name: "invoke list_active_queries when the system is idle",
requestBody: bytes.NewBufferString(`{"exclude_application_names": "wal_uploader"}`),
clientSleepSecs: 0,
waitSecsBeforeCheck: 0,
wantStatusCode: http.StatusOK,
want: []queryListDetails(nil),
},
{
name: "invoke list_active_queries when there is 1 ongoing but lower than the threshold",
requestBody: bytes.NewBufferString(`{"min_duration": "100 seconds", "exclude_application_names": "wal_uploader"}`),
clientSleepSecs: 1,
waitSecsBeforeCheck: 1,
wantStatusCode: http.StatusOK,
want: []queryListDetails(nil),
},
{
name: "invoke list_active_queries when 1 ongoing query should show up",
requestBody: bytes.NewBufferString(`{"min_duration": "1 seconds", "exclude_application_names": "wal_uploader"}`),
clientSleepSecs: 10,
waitSecsBeforeCheck: 5,
wantStatusCode: http.StatusOK,
want: []queryListDetails{singleQueryWanted},
},
}
var wg sync.WaitGroup
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
if tc.clientSleepSecs > 0 {
wg.Add(1)
go func() {
defer wg.Done()
err := pool.Ping(ctx)
if err != nil {
t.Errorf("unable to connect to test database: %s", err)
return
}
_, err = pool.Exec(ctx, fmt.Sprintf("SELECT pg_sleep(%d);", tc.clientSleepSecs))
if err != nil {
t.Errorf("Executing 'SELECT pg_sleep' failed: %s", err)
}
}()
}
if tc.waitSecsBeforeCheck > 0 {
time.Sleep(time.Duration(tc.waitSecsBeforeCheck) * time.Second)
}
const api = "http://127.0.0.1:5000/api/tool/list_active_queries/invoke"
req, err := http.NewRequest(http.MethodPost, api, tc.requestBody)
if err != nil {
t.Fatalf("unable to create request: %v", err)
}
req.Header.Add("Content-type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("unable to send request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != tc.wantStatusCode {
body, _ := io.ReadAll(resp.Body)
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body))
}
if tc.wantStatusCode != http.StatusOK {
return
}
var bodyWrapper struct {
Result json.RawMessage `json:"result"`
}
if err := json.NewDecoder(resp.Body).Decode(&bodyWrapper); err != nil {
t.Fatalf("error decoding response wrapper: %v", err)
}
var resultString string
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
resultString = string(bodyWrapper.Result)
}
var got any
var details []queryListDetails
if err := json.Unmarshal([]byte(resultString), &details); err != nil {
t.Fatalf("failed to unmarshal nested ObjectDetails string: %v", err)
}
got = details
if diff := cmp.Diff(tc.want, got, cmp.Comparer(func(a, b queryListDetails) bool {
return a.Query == b.Query
})); diff != "" {
t.Errorf("Unexpected result: got %#v, want: %#v", got, tc.want)
}
})
}
wg.Wait()
}
func setUpPostgresViews(t *testing.T, ctx context.Context, pool *pgxpool.Pool, viewName, tableName string) func() {
createView := fmt.Sprintf("CREATE VIEW %s AS SELECT name FROM %s", viewName, tableName)
_, err := pool.Exec(ctx, createView)
if err != nil {
t.Fatalf("failed to create view: %v", err)
}
return func() {
dropView := fmt.Sprintf("DROP VIEW %s", viewName)
_, err := pool.Exec(ctx, dropView)
if err != nil {
t.Fatalf("failed to drop view: %v", err)
}
}
}
func runPostgresListViewsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool, tableName string) {
viewName1 := "test_view_1" + strings.ReplaceAll(uuid.New().String(), "-", "")
dropViewfunc1 := setUpPostgresViews(t, ctx, pool, viewName1, tableName)
defer dropViewfunc1()
invokeTcs := []struct {
name string
requestBody io.Reader
wantStatusCode int
want string
}{
{
name: "invoke list_views with newly created view",
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"viewname": "%s"}`, viewName1))),
wantStatusCode: http.StatusOK,
want: fmt.Sprintf(`[{"schemaname":"public","viewname":"%s","viewowner":"postgres"}]`, viewName1),
},
{
name: "invoke list_views with non-existent_view",
requestBody: bytes.NewBuffer([]byte(`{"viewname": "non_existent_view"}`)),
wantStatusCode: http.StatusOK,
want: `null`,
},
}
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
const api = "http://127.0.0.1:5000/api/tool/list_views/invoke"
req, err := http.NewRequest(http.MethodPost, api, tc.requestBody)
if err != nil {
t.Fatalf("unable to create request: %v", err)
}
req.Header.Add("Content-type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("unable to send request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != tc.wantStatusCode {
body, _ := io.ReadAll(resp.Body)
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body))
}
if tc.wantStatusCode != http.StatusOK {
return
}
var bodyWrapper struct {
Result json.RawMessage `json:"result"`
}
if err := json.NewDecoder(resp.Body).Decode(&bodyWrapper); err != nil {
t.Fatalf("error decoding response wrapper: %v", err)
}
var resultString string
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
resultString = string(bodyWrapper.Result)
}
var got, want any
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
t.Fatalf("failed to unmarshal nested result string: %v", err)
}
if err := json.Unmarshal([]byte(tc.want), &want); err != nil {
t.Fatalf("failed to unmarshal want string: %v", err)
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("Unexpected result (-want +got):\n%s", diff)
}
})
}
}
func runPostgresListAvailableExtensionsTest(t *testing.T) {
invokeTcs := []struct {
name string
api string
requestBody io.Reader
wantStatusCode int
}{
{
name: "invoke list_available_extensions output",
api: "http://127.0.0.1:5000/api/tool/list_available_extensions/invoke",
wantStatusCode: http.StatusOK,
requestBody: bytes.NewBuffer([]byte(`{}`)),
},
}
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
if err != nil {
t.Fatalf("unable to create request: %s", err)
}
req.Header.Add("Content-type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("unable to send request: %s", err)
}
defer resp.Body.Close()
if resp.StatusCode != tc.wantStatusCode {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
}
// Intentionally not adding the output check as output depends on the postgres instance used where the the functional test runs.
// Adding the check will make the test flaky.
})
}
}
func runPostgresListInstalledExtensionsTest(t *testing.T) {
invokeTcs := []struct {
name string
api string
requestBody io.Reader
wantStatusCode int
}{
{
name: "invoke list_installed_extensions output",
api: "http://127.0.0.1:5000/api/tool/list_installed_extensions/invoke",
wantStatusCode: http.StatusOK,
requestBody: bytes.NewBuffer([]byte(`{}`)),
},
}
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
if err != nil {
t.Fatalf("unable to create request: %s", err)
}
req.Header.Add("Content-type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("unable to send request: %s", err)
}
defer resp.Body.Close()
if resp.StatusCode != tc.wantStatusCode {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
}
// Intentionally not adding the output check as output depends on the postgres instance used where the the functional test runs.
// Adding the check will make the test flaky.
})
}
tests.RunPostgresListActiveQueriesTest(t, ctx, pool)
tests.RunPostgresListAvailableExtensionsTest(t)
tests.RunPostgresListInstalledExtensionsTest(t)
}

View File

@@ -149,31 +149,17 @@ func RunToolInvokeSimpleTest(t *testing.T, name string, simpleWant string) {
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
// Send Tool invocation request
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
if err != nil {
t.Fatalf("unable to create request: %s", err)
}
req.Header.Add("Content-type", "application/json")
for k, v := range tc.requestHeader {
req.Header.Add(k, v)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("unable to send request: %s", err)
}
defer resp.Body.Close()
resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader)
if resp.StatusCode != http.StatusOK {
if tc.isErr {
return
}
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody))
}
// Check response body
var body map[string]interface{}
err = json.NewDecoder(resp.Body).Decode(&body)
err := json.Unmarshal(respBody, &body)
if err != nil {
t.Fatalf("error parsing response body")
}
@@ -212,31 +198,17 @@ func RunToolInvokeParametersTest(t *testing.T, name string, params []byte, simpl
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
// Send Tool invocation request
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
if err != nil {
t.Fatalf("unable to create request: %s", err)
}
req.Header.Add("Content-type", "application/json")
for k, v := range tc.requestHeader {
req.Header.Add(k, v)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("unable to send request: %s", err)
}
defer resp.Body.Close()
resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader)
if resp.StatusCode != http.StatusOK {
if tc.isErr {
return
}
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody))
}
// Check response body
var body map[string]interface{}
err = json.NewDecoder(resp.Body).Decode(&body)
err := json.Unmarshal(respBody, &body)
if err != nil {
t.Fatalf("error parsing response body")
}
@@ -447,25 +419,11 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
return
}
// Send Tool invocation request
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
if err != nil {
t.Fatalf("unable to create request: %s", err)
}
req.Header.Add("Content-type", "application/json")
// Add headers
for k, v := range tc.requestHeader {
req.Header.Add(k, v)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("unable to send request: %s", err)
}
defer resp.Body.Close()
resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader)
// Check status code
if resp.StatusCode != tc.wantStatusCode {
body, _ := io.ReadAll(resp.Body)
t.Errorf("StatusCode mismatch: got %d, want %d. Response body: %s", resp.StatusCode, tc.wantStatusCode, string(body))
t.Errorf("StatusCode mismatch: got %d, want %d. Response body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
}
// skip response body check
@@ -475,7 +433,7 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
// Check response body
var body map[string]interface{}
err = json.NewDecoder(resp.Body).Decode(&body)
err = json.Unmarshal(respBody, &body)
if err != nil {
t.Fatalf("error parsing response body: %s", err)
}
@@ -620,32 +578,17 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options
insertAllow := !tc.insert || (tc.insert && configs.supportInsert)
if ddlAllow && insertAllow {
// Send Tool invocation request
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
if err != nil {
t.Fatalf("unable to create request: %s", err)
}
req.Header.Add("Content-type", "application/json")
for k, v := range tc.requestHeader {
req.Header.Add(k, v)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("unable to send request: %s", err)
}
defer resp.Body.Close()
resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader)
if resp.StatusCode != http.StatusOK {
if tc.isErr {
return
}
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody))
}
// Check response body
var body map[string]interface{}
err = json.NewDecoder(resp.Body).Decode(&body)
err := json.Unmarshal(respBody, &body)
if err != nil {
t.Fatalf("error parsing response body")
}
@@ -769,31 +712,17 @@ func RunExecuteSqlToolInvokeTest(t *testing.T, createTableStatement, select1Want
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
// Send Tool invocation request
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
if err != nil {
t.Fatalf("unable to create request: %s", err)
}
req.Header.Add("Content-type", "application/json")
for k, v := range tc.requestHeader {
req.Header.Add(k, v)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("unable to send request: %s", err)
}
defer resp.Body.Close()
resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader)
if resp.StatusCode != http.StatusOK {
if tc.isErr {
return
}
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody))
}
// Check response body
var body map[string]interface{}
err = json.NewDecoder(resp.Body).Decode(&body)
err = json.Unmarshal(respBody, &body)
if err != nil {
t.Fatalf("error parsing response body")
}
@@ -1157,6 +1086,257 @@ func setupPostgresSchemas(t *testing.T, ctx context.Context, pool *pgxpool.Pool,
}
}
func RunPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth, user string) {
// TableNameParam columns to construct want
paramTableColumns := fmt.Sprintf(`[
{"data_type": "integer", "column_name": "id", "column_default": "nextval('%s_id_seq'::regclass)", "is_not_nullable": true, "ordinal_position": 1, "column_comment": null},
{"data_type": "text", "column_name": "name", "column_default": null, "is_not_nullable": false, "ordinal_position": 2, "column_comment": null}
]`, tableNameParam)
// TableNameAuth columns to construct want
authTableColumns := fmt.Sprintf(`[
{"data_type": "integer", "column_name": "id", "column_default": "nextval('%s_id_seq'::regclass)", "is_not_nullable": true, "ordinal_position": 1, "column_comment": null},
{"data_type": "text", "column_name": "name", "column_default": null, "is_not_nullable": false, "ordinal_position": 2, "column_comment": null},
{"data_type": "text", "column_name": "email", "column_default": null, "is_not_nullable": false, "ordinal_position": 3, "column_comment": null}
]`, tableNameAuth)
const (
// Template to construct detailed output want
detailedObjectTemplate = `{
"object_name": "%[1]s", "schema_name": "public",
"object_details": {
"owner": "%[3]s", "comment": null,
"indexes": [{"is_primary": true, "is_unique": true, "index_name": "%[1]s_pkey", "index_method": "btree", "index_columns": ["id"], "index_definition": "CREATE UNIQUE INDEX %[1]s_pkey ON public.%[1]s USING btree (id)"}],
"triggers": [], "columns": %[2]s, "object_name": "%[1]s", "object_type": "TABLE", "schema_name": "public",
"constraints": [{"constraint_name": "%[1]s_pkey", "constraint_type": "PRIMARY KEY", "constraint_columns": ["id"], "constraint_definition": "PRIMARY KEY (id)", "foreign_key_referenced_table": null, "foreign_key_referenced_columns": null}]
}
}`
// Template to construct simple output want
simpleObjectTemplate = `{"object_name":"%s", "schema_name":"public", "object_details":{"name":"%s"}}`
)
// Helper to build json for detailed want
getDetailedWant := func(tableName, columnJSON string) string {
return fmt.Sprintf(detailedObjectTemplate, tableName, columnJSON, user)
}
// Helper to build template for simple want
getSimpleWant := func(tableName string) string {
return fmt.Sprintf(simpleObjectTemplate, tableName, tableName)
}
invokeTcs := []struct {
name string
api string
requestBody io.Reader
wantStatusCode int
want string
isAllTables bool
}{
{
name: "invoke list_tables all tables detailed output",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: bytes.NewBuffer([]byte(`{"table_names": ""}`)),
wantStatusCode: http.StatusOK,
want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)),
isAllTables: true,
},
{
name: "invoke list_tables all tables simple output",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: bytes.NewBuffer([]byte(`{"table_names": "", "output_format": "simple"}`)),
wantStatusCode: http.StatusOK,
want: fmt.Sprintf("[%s,%s]", getSimpleWant(tableNameAuth), getSimpleWant(tableNameParam)),
isAllTables: true,
},
{
name: "invoke list_tables detailed output",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s"}`, tableNameAuth))),
wantStatusCode: http.StatusOK,
want: fmt.Sprintf("[%s]", getDetailedWant(tableNameAuth, authTableColumns)),
},
{
name: "invoke list_tables simple output",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s", "output_format": "simple"}`, tableNameAuth))),
wantStatusCode: http.StatusOK,
want: fmt.Sprintf("[%s]", getSimpleWant(tableNameAuth)),
},
{
name: "invoke list_tables with invalid output format",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: bytes.NewBuffer([]byte(`{"table_names": "", "output_format": "abcd"}`)),
wantStatusCode: http.StatusBadRequest,
},
{
name: "invoke list_tables with malformed table_names parameter",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: bytes.NewBuffer([]byte(`{"table_names": 12345, "output_format": "detailed"}`)),
wantStatusCode: http.StatusBadRequest,
},
{
name: "invoke list_tables with multiple table names",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s,%s"}`, tableNameParam, tableNameAuth))),
wantStatusCode: http.StatusOK,
want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)),
},
{
name: "invoke list_tables with non-existent table",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: bytes.NewBuffer([]byte(`{"table_names": "non_existent_table"}`)),
wantStatusCode: http.StatusOK,
want: `null`,
},
{
name: "invoke list_tables with one existing and one non-existent table",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s,non_existent_table"}`, tableNameParam))),
wantStatusCode: http.StatusOK,
want: fmt.Sprintf("[%s]", getDetailedWant(tableNameParam, paramTableColumns)),
},
}
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
resp, respBytes := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, nil)
if resp.StatusCode != tc.wantStatusCode {
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBytes))
}
if tc.wantStatusCode == http.StatusOK {
var bodyWrapper map[string]json.RawMessage
if err := json.Unmarshal(respBytes, &bodyWrapper); err != nil {
t.Fatalf("error parsing response wrapper: %s, body: %s", err, string(respBytes))
}
resultJSON, ok := bodyWrapper["result"]
if !ok {
t.Fatal("unable to find 'result' in response body")
}
var resultString string
if err := json.Unmarshal(resultJSON, &resultString); err != nil {
t.Fatalf("'result' is not a JSON-encoded string: %s", err)
}
var got, want []any
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
t.Fatalf("failed to unmarshal actual result string: %v", err)
}
if err := json.Unmarshal([]byte(tc.want), &want); err != nil {
t.Fatalf("failed to unmarshal expected want string: %v", err)
}
// Checking only the default public schema where the test tables are created to avoid brittle tests.
if tc.isAllTables {
var filteredGot []any
for _, item := range got {
if tableMap, ok := item.(map[string]interface{}); ok {
if schema, ok := tableMap["schema_name"]; ok && schema == "public" {
filteredGot = append(filteredGot, item)
}
}
}
got = filteredGot
}
sort.SliceStable(got, func(i, j int) bool {
return fmt.Sprintf("%v", got[i]) < fmt.Sprintf("%v", got[j])
})
sort.SliceStable(want, func(i, j int) bool {
return fmt.Sprintf("%v", want[i]) < fmt.Sprintf("%v", want[j])
})
if !reflect.DeepEqual(got, want) {
t.Errorf("Unexpected result: got %#v, want: %#v", got, want)
}
}
})
}
}
func setUpPostgresViews(t *testing.T, ctx context.Context, pool *pgxpool.Pool, viewName, tableName string) func() {
createView := fmt.Sprintf("CREATE VIEW %s AS SELECT name FROM %s", viewName, tableName)
_, err := pool.Exec(ctx, createView)
if err != nil {
t.Fatalf("failed to create view: %v", err)
}
return func() {
dropView := fmt.Sprintf("DROP VIEW %s", viewName)
_, err := pool.Exec(ctx, dropView)
if err != nil {
t.Fatalf("failed to drop view: %v", err)
}
}
}
func RunPostgresListViewsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool, tableName string) {
viewName1 := "test_view_1" + strings.ReplaceAll(uuid.New().String(), "-", "")
dropViewfunc1 := setUpPostgresViews(t, ctx, pool, viewName1, tableName)
defer dropViewfunc1()
invokeTcs := []struct {
name string
requestBody io.Reader
wantStatusCode int
want string
}{
{
name: "invoke list_views with newly created view",
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"viewname": "%s"}`, viewName1))),
wantStatusCode: http.StatusOK,
want: fmt.Sprintf(`[{"schemaname":"public","viewname":"%s","viewowner":"postgres"}]`, viewName1),
},
{
name: "invoke list_views with non-existent_view",
requestBody: bytes.NewBuffer([]byte(`{"viewname": "non_existent_view"}`)),
wantStatusCode: http.StatusOK,
want: `null`,
},
}
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
const api = "http://127.0.0.1:5000/api/tool/list_views/invoke"
resp, body := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
if resp.StatusCode != tc.wantStatusCode {
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body))
}
if tc.wantStatusCode != http.StatusOK {
return
}
var bodyWrapper struct {
Result json.RawMessage `json:"result"`
}
if err := json.Unmarshal(body, &bodyWrapper); err != nil {
t.Fatalf("error decoding response wrapper: %v", err)
}
var resultString string
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
resultString = string(bodyWrapper.Result)
}
var got, want any
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
t.Fatalf("failed to unmarshal nested result string: %v", err)
}
if err := json.Unmarshal([]byte(tc.want), &want); err != nil {
t.Fatalf("failed to unmarshal want string: %v", err)
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("Unexpected result (-want +got):\n%s", diff)
}
})
}
}
func RunPostgresListSchemasTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
schemaName := "test_schema_" + strings.ReplaceAll(uuid.New().String(), "-", "")
cleanup := setupPostgresSchemas(t, ctx, pool, schemaName)
@@ -1186,20 +1366,9 @@ func RunPostgresListSchemasTest(t *testing.T, ctx context.Context, pool *pgxpool
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
const api = "http://127.0.0.1:5000/api/tool/list_schemas/invoke"
req, err := http.NewRequest(http.MethodPost, api, tc.requestBody)
if err != nil {
t.Fatalf("unable to create request: %v", err)
}
req.Header.Add("Content-type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("unable to send request: %v", err)
}
defer resp.Body.Close()
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
if resp.StatusCode != tc.wantStatusCode {
body, _ := io.ReadAll(resp.Body)
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body))
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
}
if tc.wantStatusCode != http.StatusOK {
return
@@ -1208,7 +1377,7 @@ func RunPostgresListSchemasTest(t *testing.T, ctx context.Context, pool *pgxpool
var bodyWrapper struct {
Result json.RawMessage `json:"result"`
}
if err := json.NewDecoder(resp.Body).Decode(&bodyWrapper); err != nil {
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
t.Fatalf("error decoding response wrapper: %v", err)
}
@@ -1229,6 +1398,191 @@ func RunPostgresListSchemasTest(t *testing.T, ctx context.Context, pool *pgxpool
}
}
func RunPostgresListActiveQueriesTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
type queryListDetails struct {
ProcessId any `json:"pid"`
User string `json:"user"`
Datname string `json:"datname"`
ApplicationName string `json:"application_name"`
ClientAddress string `json:"client_addr"`
State string `json:"state"`
WaitEventType string `json:"wait_event_type"`
WaitEvent string `json:"wait_event"`
BackendStart any `json:"backend_start"`
TransactionStart any `json:"xact_start"`
QueryStart any `json:"query_start"`
QueryDuration any `json:"query_duration"`
Query string `json:"query"`
}
singleQueryWanted := queryListDetails{
ProcessId: any(nil),
User: "",
Datname: "",
ApplicationName: "",
ClientAddress: "",
State: "",
WaitEventType: "",
WaitEvent: "",
BackendStart: any(nil),
TransactionStart: any(nil),
QueryStart: any(nil),
QueryDuration: any(nil),
Query: "SELECT pg_sleep(10);",
}
invokeTcs := []struct {
name string
requestBody io.Reader
clientSleepSecs int
waitSecsBeforeCheck int
wantStatusCode int
want any
}{
// exclude background monitoring apps such as "wal_uploader"
{
name: "invoke list_active_queries when the system is idle",
requestBody: bytes.NewBufferString(`{"exclude_application_names": "wal_uploader"}`),
clientSleepSecs: 0,
waitSecsBeforeCheck: 0,
wantStatusCode: http.StatusOK,
want: []queryListDetails(nil),
},
{
name: "invoke list_active_queries when there is 1 ongoing but lower than the threshold",
requestBody: bytes.NewBufferString(`{"min_duration": "100 seconds", "exclude_application_names": "wal_uploader"}`),
clientSleepSecs: 1,
waitSecsBeforeCheck: 1,
wantStatusCode: http.StatusOK,
want: []queryListDetails(nil),
},
{
name: "invoke list_active_queries when 1 ongoing query should show up",
requestBody: bytes.NewBufferString(`{"min_duration": "1 seconds", "exclude_application_names": "wal_uploader"}`),
clientSleepSecs: 10,
waitSecsBeforeCheck: 5,
wantStatusCode: http.StatusOK,
want: []queryListDetails{singleQueryWanted},
},
}
var wg sync.WaitGroup
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
if tc.clientSleepSecs > 0 {
wg.Add(1)
go func() {
defer wg.Done()
err := pool.Ping(ctx)
if err != nil {
t.Errorf("unable to connect to test database: %s", err)
return
}
_, err = pool.Exec(ctx, fmt.Sprintf("SELECT pg_sleep(%d);", tc.clientSleepSecs))
if err != nil {
t.Errorf("Executing 'SELECT pg_sleep' failed: %s", err)
}
}()
}
if tc.waitSecsBeforeCheck > 0 {
time.Sleep(time.Duration(tc.waitSecsBeforeCheck) * time.Second)
}
const api = "http://127.0.0.1:5000/api/tool/list_active_queries/invoke"
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
if resp.StatusCode != tc.wantStatusCode {
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
}
if tc.wantStatusCode != http.StatusOK {
return
}
var bodyWrapper struct {
Result json.RawMessage `json:"result"`
}
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
t.Fatalf("error decoding response wrapper: %v", err)
}
var resultString string
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
resultString = string(bodyWrapper.Result)
}
var got any
var details []queryListDetails
if err := json.Unmarshal([]byte(resultString), &details); err != nil {
t.Fatalf("failed to unmarshal nested ObjectDetails string: %v", err)
}
got = details
if diff := cmp.Diff(tc.want, got, cmp.Comparer(func(a, b queryListDetails) bool {
return a.Query == b.Query
})); diff != "" {
t.Errorf("Unexpected result: got %#v, want: %#v", got, tc.want)
}
})
}
wg.Wait()
}
func RunPostgresListAvailableExtensionsTest(t *testing.T) {
invokeTcs := []struct {
name string
api string
requestBody io.Reader
wantStatusCode int
}{
{
name: "invoke list_available_extensions output",
api: "http://127.0.0.1:5000/api/tool/list_available_extensions/invoke",
wantStatusCode: http.StatusOK,
requestBody: bytes.NewBuffer([]byte(`{}`)),
},
}
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, nil)
if resp.StatusCode != tc.wantStatusCode {
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody))
}
// Intentionally not adding the output check as output depends on the postgres instance used where the the functional test runs.
// Adding the check will make the test flaky.
})
}
}
func RunPostgresListInstalledExtensionsTest(t *testing.T) {
invokeTcs := []struct {
name string
api string
requestBody io.Reader
wantStatusCode int
}{
{
name: "invoke list_installed_extensions output",
api: "http://127.0.0.1:5000/api/tool/list_installed_extensions/invoke",
wantStatusCode: http.StatusOK,
requestBody: bytes.NewBuffer([]byte(`{}`)),
},
}
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
resp, bodyBytes := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, nil)
if resp.StatusCode != tc.wantStatusCode {
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
}
// Intentionally not adding the output check as output depends on the postgres instance used where the the functional test runs.
// Adding the check will make the test flaky.
})
}
}
// RunMySQLListTablesTest run tests against the mysql-list-tables tool
func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNameAuth string) {
type tableInfo struct {
@@ -1335,20 +1689,8 @@ func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNam
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
const api = "http://127.0.0.1:5000/api/tool/list_tables/invoke"
req, err := http.NewRequest(http.MethodPost, api, tc.requestBody)
if err != nil {
t.Fatalf("unable to create request: %v", err)
}
req.Header.Add("Content-type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("unable to send request: %v", err)
}
defer resp.Body.Close()
resp, body := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
if resp.StatusCode != tc.wantStatusCode {
body, _ := io.ReadAll(resp.Body)
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body))
}
if tc.wantStatusCode != http.StatusOK {
@@ -1358,7 +1700,7 @@ func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNam
var bodyWrapper struct {
Result json.RawMessage `json:"result"`
}
if err := json.NewDecoder(resp.Body).Decode(&bodyWrapper); err != nil {
if err := json.Unmarshal(body, &bodyWrapper); err != nil {
t.Fatalf("error decoding response wrapper: %v", err)
}
@@ -1532,21 +1874,9 @@ func RunMySQLListActiveQueriesTest(t *testing.T, ctx context.Context, pool *sql.
}
const api = "http://127.0.0.1:5000/api/tool/list_active_queries/invoke"
req, err := http.NewRequest(http.MethodPost, api, tc.requestBody)
if err != nil {
t.Fatalf("unable to create request: %v", err)
}
req.Header.Add("Content-type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("unable to send request: %v", err)
}
defer resp.Body.Close()
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
if resp.StatusCode != tc.wantStatusCode {
body, _ := io.ReadAll(resp.Body)
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body))
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
}
if tc.wantStatusCode != http.StatusOK {
return
@@ -1555,7 +1885,7 @@ func RunMySQLListActiveQueriesTest(t *testing.T, ctx context.Context, pool *sql.
var bodyWrapper struct {
Result json.RawMessage `json:"result"`
}
if err := json.NewDecoder(resp.Body).Decode(&bodyWrapper); err != nil {
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
t.Fatalf("error decoding response wrapper: %v", err)
}
@@ -1765,21 +2095,9 @@ func RunMySQLListTablesMissingUniqueIndexes(t *testing.T, ctx context.Context, p
}
const api = "http://127.0.0.1:5000/api/tool/list_tables_missing_unique_indexes/invoke"
req, err := http.NewRequest(http.MethodPost, api, tc.requestBody)
if err != nil {
t.Fatalf("unable to create request: %v", err)
}
req.Header.Add("Content-type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("unable to send request: %v", err)
}
defer resp.Body.Close()
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
if resp.StatusCode != tc.wantStatusCode {
body, _ := io.ReadAll(resp.Body)
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body))
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
}
if tc.wantStatusCode != http.StatusOK {
return
@@ -1788,7 +2106,7 @@ func RunMySQLListTablesMissingUniqueIndexes(t *testing.T, ctx context.Context, p
var bodyWrapper struct {
Result json.RawMessage `json:"result"`
}
if err := json.NewDecoder(resp.Body).Decode(&bodyWrapper); err != nil {
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
t.Fatalf("error decoding response wrapper: %v", err)
}
@@ -1892,21 +2210,9 @@ func RunMySQLListTableFragmentationTest(t *testing.T, databaseName, tableNamePar
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
const api = "http://127.0.0.1:5000/api/tool/list_table_fragmentation/invoke"
req, err := http.NewRequest(http.MethodPost, api, tc.requestBody)
if err != nil {
t.Fatalf("unable to create request: %v", err)
}
req.Header.Add("Content-type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("unable to send request: %v", err)
}
defer resp.Body.Close()
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
if resp.StatusCode != tc.wantStatusCode {
body, _ := io.ReadAll(resp.Body)
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body))
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
}
if tc.wantStatusCode != http.StatusOK {
return
@@ -1915,7 +2221,7 @@ func RunMySQLListTableFragmentationTest(t *testing.T, databaseName, tableNamePar
var bodyWrapper struct {
Result json.RawMessage `json:"result"`
}
if err := json.NewDecoder(resp.Body).Decode(&bodyWrapper); err != nil {
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
t.Fatalf("error decoding response wrapper: %v", err)
}