mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 07:28:05 -05:00
test: add prebuilt tools test to all postgres source (#1874)
Move postgres prebuilt integration tests to `common.go` and `tool.go`. Run those tests from alloydbpg and cloudsqlpg as well. alloydbpg and cloudsqlpg integration test coverage calculate against the whole `internal/tools/postgres/` folder. If not added, the coverage will eventually drop below minimum requirement.
This commit is contained in:
@@ -129,6 +129,9 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
|
||||
t.Fatalf("unable to create AlloyDB connection pool: %s", err)
|
||||
}
|
||||
|
||||
// cleanup test environment
|
||||
tests.CleanupPostgresTables(t, ctx, pool)
|
||||
|
||||
// create table name with UUID
|
||||
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
@@ -175,7 +178,14 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
|
||||
tests.RunMCPToolCallMethod(t, failInvocationWant, mcpSelect1Want)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam)
|
||||
|
||||
// Run Postgres prebuilt tool tests
|
||||
tests.RunPostgresListTablesTest(t, tableNameParam, tableNameAuth, AlloyDBPostgresUser)
|
||||
tests.RunPostgresListViewsTest(t, ctx, pool, tableNameParam)
|
||||
tests.RunPostgresListSchemasTest(t, ctx, pool)
|
||||
tests.RunPostgresListActiveQueriesTest(t, ctx, pool)
|
||||
tests.RunPostgresListAvailableExtensionsTest(t)
|
||||
tests.RunPostgresListInstalledExtensionsTest(t)
|
||||
}
|
||||
|
||||
// Test connection with different IP type
|
||||
|
||||
@@ -114,6 +114,9 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
|
||||
t.Fatalf("unable to create Cloud SQL connection pool: %s", err)
|
||||
}
|
||||
|
||||
// cleanup test environment
|
||||
tests.CleanupPostgresTables(t, ctx, pool)
|
||||
|
||||
// create table name with UUID
|
||||
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
@@ -159,7 +162,14 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
|
||||
tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam)
|
||||
|
||||
// Run Postgres prebuilt tool tests
|
||||
tests.RunPostgresListTablesTest(t, tableNameParam, tableNameAuth, CloudSQLPostgresUser)
|
||||
tests.RunPostgresListViewsTest(t, ctx, pool, tableNameParam)
|
||||
tests.RunPostgresListSchemasTest(t, ctx, pool)
|
||||
tests.RunPostgresListActiveQueriesTest(t, ctx, pool)
|
||||
tests.RunPostgresListAvailableExtensionsTest(t)
|
||||
tests.RunPostgresListInstalledExtensionsTest(t)
|
||||
}
|
||||
|
||||
// Test connection with different IP type
|
||||
|
||||
@@ -33,10 +33,6 @@ import (
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
var (
|
||||
PostgresListSchemasToolKind = "postgres-list-schemas"
|
||||
)
|
||||
|
||||
// GetToolsConfig returns a mock tools config file
|
||||
func GetToolsConfig(sourceConfig map[string]any, toolKind, paramToolStatement, idParamToolStmt, nameParamToolStmt, arrayToolStatement, authToolStatement string) map[string]any {
|
||||
// Write config into a file and pass it to command
|
||||
@@ -195,10 +191,46 @@ func AddExecuteSqlConfig(t *testing.T, config map[string]any, toolKind string) m
|
||||
}
|
||||
|
||||
func AddPostgresPrebuiltConfig(t *testing.T, config map[string]any) map[string]any {
|
||||
var (
|
||||
PostgresListSchemasToolKind = "postgres-list-schemas"
|
||||
PostgresListTablesToolKind = "postgres-list-tables"
|
||||
PostgresListActiveQueriesToolKind = "postgres-list-active-queries"
|
||||
PostgresListInstalledExtensionsToolKind = "postgres-list-installed-extensions"
|
||||
PostgresListAvailableExtensionsToolKind = "postgres-list-available-extensions"
|
||||
PostgresListViewsToolKind = "postgres-list-views"
|
||||
)
|
||||
|
||||
tools, ok := config["tools"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("unable to get tools from config")
|
||||
}
|
||||
tools["list_tables"] = map[string]any{
|
||||
"kind": PostgresListTablesToolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Lists tables in the database.",
|
||||
}
|
||||
tools["list_active_queries"] = map[string]any{
|
||||
"kind": PostgresListActiveQueriesToolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Lists active queries in the database.",
|
||||
}
|
||||
|
||||
tools["list_installed_extensions"] = map[string]any{
|
||||
"kind": PostgresListInstalledExtensionsToolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Lists installed extensions in the database.",
|
||||
}
|
||||
|
||||
tools["list_available_extensions"] = map[string]any{
|
||||
"kind": PostgresListAvailableExtensionsToolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Lists available extensions in the database.",
|
||||
}
|
||||
|
||||
tools["list_views"] = map[string]any{
|
||||
"kind": PostgresListViewsToolKind,
|
||||
"source": "my-instance",
|
||||
}
|
||||
tools["list_schemas"] = map[string]any{
|
||||
"kind": PostgresListSchemasToolKind,
|
||||
"source": "my-instance",
|
||||
|
||||
@@ -15,23 +15,15 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/uuid"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
@@ -39,18 +31,13 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
PostgresSourceKind = "postgres"
|
||||
PostgresToolKind = "postgres-sql"
|
||||
PostgresListTablesToolKind = "postgres-list-tables"
|
||||
PostgresListActiveQueriesToolKind = "postgres-list-active-queries"
|
||||
PostgresListInstalledExtensionsToolKind = "postgres-list-installed-extensions"
|
||||
PostgresListAvailableExtensionsToolKind = "postgres-list-available-extensions"
|
||||
PostgresListViewsToolKind = "postgres-list-views"
|
||||
PostgresDatabase = os.Getenv("POSTGRES_DATABASE")
|
||||
PostgresHost = os.Getenv("POSTGRES_HOST")
|
||||
PostgresPort = os.Getenv("POSTGRES_PORT")
|
||||
PostgresUser = os.Getenv("POSTGRES_USER")
|
||||
PostgresPass = os.Getenv("POSTGRES_PASS")
|
||||
PostgresSourceKind = "postgres"
|
||||
PostgresToolKind = "postgres-sql"
|
||||
PostgresDatabase = os.Getenv("POSTGRES_DATABASE")
|
||||
PostgresHost = os.Getenv("POSTGRES_HOST")
|
||||
PostgresPort = os.Getenv("POSTGRES_PORT")
|
||||
PostgresUser = os.Getenv("POSTGRES_USER")
|
||||
PostgresPass = os.Getenv("POSTGRES_PASS")
|
||||
)
|
||||
|
||||
func getPostgresVars(t *testing.T) map[string]any {
|
||||
@@ -77,43 +64,6 @@ func getPostgresVars(t *testing.T) map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
func addPrebuiltToolConfig(t *testing.T, config map[string]any) map[string]any {
|
||||
tools, ok := config["tools"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("unable to get tools from config")
|
||||
}
|
||||
tools["list_tables"] = map[string]any{
|
||||
"kind": PostgresListTablesToolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Lists tables in the database.",
|
||||
}
|
||||
tools["list_active_queries"] = map[string]any{
|
||||
"kind": PostgresListActiveQueriesToolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Lists active queries in the database.",
|
||||
}
|
||||
|
||||
tools["list_installed_extensions"] = map[string]any{
|
||||
"kind": PostgresListInstalledExtensionsToolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Lists installed extensions in the database.",
|
||||
}
|
||||
|
||||
tools["list_available_extensions"] = map[string]any{
|
||||
"kind": PostgresListAvailableExtensionsToolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Lists available extensions in the database.",
|
||||
}
|
||||
|
||||
tools["list_views"] = map[string]any{
|
||||
"kind": PostgresListViewsToolKind,
|
||||
"source": "my-instance",
|
||||
}
|
||||
|
||||
config["tools"] = tools
|
||||
return config
|
||||
}
|
||||
|
||||
// Copied over from postgres.go
|
||||
func initPostgresConnectionPool(host, port, user, pass, dbname string) (*pgxpool.Pool, error) {
|
||||
// urlExample := "postgres:dd//username:password@localhost:5432/database_name"
|
||||
@@ -166,8 +116,6 @@ func TestPostgres(t *testing.T) {
|
||||
toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql")
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, PostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
toolsFile = addPrebuiltToolConfig(t, toolsFile)
|
||||
toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
@@ -194,506 +142,11 @@ func TestPostgres(t *testing.T) {
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam)
|
||||
|
||||
// Run specific Postgres tool tests
|
||||
runPostgresListTablesTest(t, tableNameParam, tableNameAuth)
|
||||
runPostgresListViewsTest(t, ctx, pool, tableNameParam)
|
||||
// Run Postgres prebuilt tool tests
|
||||
tests.RunPostgresListTablesTest(t, tableNameParam, tableNameAuth, PostgresUser)
|
||||
tests.RunPostgresListViewsTest(t, ctx, pool, tableNameParam)
|
||||
tests.RunPostgresListSchemasTest(t, ctx, pool)
|
||||
runPostgresListActiveQueriesTest(t, ctx, pool)
|
||||
runPostgresListAvailableExtensionsTest(t)
|
||||
runPostgresListInstalledExtensionsTest(t)
|
||||
}
|
||||
|
||||
func runPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth string) {
|
||||
// TableNameParam columns to construct want
|
||||
paramTableColumns := fmt.Sprintf(`[
|
||||
{"data_type": "integer", "column_name": "id", "column_default": "nextval('%s_id_seq'::regclass)", "is_not_nullable": true, "ordinal_position": 1, "column_comment": null},
|
||||
{"data_type": "text", "column_name": "name", "column_default": null, "is_not_nullable": false, "ordinal_position": 2, "column_comment": null}
|
||||
]`, tableNameParam)
|
||||
|
||||
// TableNameAuth columns to construct want
|
||||
authTableColumns := fmt.Sprintf(`[
|
||||
{"data_type": "integer", "column_name": "id", "column_default": "nextval('%s_id_seq'::regclass)", "is_not_nullable": true, "ordinal_position": 1, "column_comment": null},
|
||||
{"data_type": "text", "column_name": "name", "column_default": null, "is_not_nullable": false, "ordinal_position": 2, "column_comment": null},
|
||||
{"data_type": "text", "column_name": "email", "column_default": null, "is_not_nullable": false, "ordinal_position": 3, "column_comment": null}
|
||||
]`, tableNameAuth)
|
||||
|
||||
const (
|
||||
// Template to construct detailed output want
|
||||
detailedObjectTemplate = `{
|
||||
"object_name": "%[1]s", "schema_name": "public",
|
||||
"object_details": {
|
||||
"owner": "%[3]s", "comment": null,
|
||||
"indexes": [{"is_primary": true, "is_unique": true, "index_name": "%[1]s_pkey", "index_method": "btree", "index_columns": ["id"], "index_definition": "CREATE UNIQUE INDEX %[1]s_pkey ON public.%[1]s USING btree (id)"}],
|
||||
"triggers": [], "columns": %[2]s, "object_name": "%[1]s", "object_type": "TABLE", "schema_name": "public",
|
||||
"constraints": [{"constraint_name": "%[1]s_pkey", "constraint_type": "PRIMARY KEY", "constraint_columns": ["id"], "constraint_definition": "PRIMARY KEY (id)", "foreign_key_referenced_table": null, "foreign_key_referenced_columns": null}]
|
||||
}
|
||||
}`
|
||||
|
||||
// Template to construct simple output want
|
||||
simpleObjectTemplate = `{"object_name":"%s", "schema_name":"public", "object_details":{"name":"%s"}}`
|
||||
)
|
||||
|
||||
// Helper to build json for detailed want
|
||||
getDetailedWant := func(tableName, columnJSON string) string {
|
||||
return fmt.Sprintf(detailedObjectTemplate, tableName, columnJSON, PostgresUser)
|
||||
}
|
||||
|
||||
// Helper to build template for simple want
|
||||
getSimpleWant := func(tableName string) string {
|
||||
return fmt.Sprintf(simpleObjectTemplate, tableName, tableName)
|
||||
}
|
||||
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestBody io.Reader
|
||||
wantStatusCode int
|
||||
want string
|
||||
isAllTables bool
|
||||
}{
|
||||
{
|
||||
name: "invoke list_tables all tables detailed output",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"table_names": ""}`)),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)),
|
||||
isAllTables: true,
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables all tables simple output",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"table_names": "", "output_format": "simple"}`)),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: fmt.Sprintf("[%s,%s]", getSimpleWant(tableNameAuth), getSimpleWant(tableNameParam)),
|
||||
isAllTables: true,
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables detailed output",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s"}`, tableNameAuth))),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: fmt.Sprintf("[%s]", getDetailedWant(tableNameAuth, authTableColumns)),
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables simple output",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s", "output_format": "simple"}`, tableNameAuth))),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: fmt.Sprintf("[%s]", getSimpleWant(tableNameAuth)),
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables with invalid output format",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"table_names": "", "output_format": "abcd"}`)),
|
||||
wantStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables with malformed table_names parameter",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"table_names": 12345, "output_format": "detailed"}`)),
|
||||
wantStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables with multiple table names",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s,%s"}`, tableNameParam, tableNameAuth))),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)),
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables with non-existent table",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"table_names": "non_existent_table"}`)),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: `null`,
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables with one existing and one non-existent table",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s,non_existent_table"}`, tableNameParam))),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: fmt.Sprintf("[%s]", getDetailedWant(tableNameParam, paramTableColumns)),
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
if tc.wantStatusCode == http.StatusOK {
|
||||
var bodyWrapper map[string]json.RawMessage
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("error reading response body: %s", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(respBytes, &bodyWrapper); err != nil {
|
||||
t.Fatalf("error parsing response wrapper: %s, body: %s", err, string(respBytes))
|
||||
}
|
||||
|
||||
resultJSON, ok := bodyWrapper["result"]
|
||||
if !ok {
|
||||
t.Fatal("unable to find 'result' in response body")
|
||||
}
|
||||
|
||||
var resultString string
|
||||
if err := json.Unmarshal(resultJSON, &resultString); err != nil {
|
||||
t.Fatalf("'result' is not a JSON-encoded string: %s", err)
|
||||
}
|
||||
|
||||
var got, want []any
|
||||
|
||||
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
|
||||
t.Fatalf("failed to unmarshal actual result string: %v", err)
|
||||
}
|
||||
if err := json.Unmarshal([]byte(tc.want), &want); err != nil {
|
||||
t.Fatalf("failed to unmarshal expected want string: %v", err)
|
||||
}
|
||||
|
||||
// Checking only the default public schema where the test tables are created to avoid brittle tests.
|
||||
if tc.isAllTables {
|
||||
var filteredGot []any
|
||||
for _, item := range got {
|
||||
if tableMap, ok := item.(map[string]interface{}); ok {
|
||||
if schema, ok := tableMap["schema_name"]; ok && schema == "public" {
|
||||
filteredGot = append(filteredGot, item)
|
||||
}
|
||||
}
|
||||
}
|
||||
got = filteredGot
|
||||
}
|
||||
|
||||
sort.SliceStable(got, func(i, j int) bool {
|
||||
return fmt.Sprintf("%v", got[i]) < fmt.Sprintf("%v", got[j])
|
||||
})
|
||||
sort.SliceStable(want, func(i, j int) bool {
|
||||
return fmt.Sprintf("%v", want[i]) < fmt.Sprintf("%v", want[j])
|
||||
})
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("Unexpected result: got %#v, want: %#v", got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runPostgresListActiveQueriesTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
|
||||
type queryListDetails struct {
|
||||
ProcessId any `json:"pid"`
|
||||
User string `json:"user"`
|
||||
Datname string `json:"datname"`
|
||||
ApplicationName string `json:"application_name"`
|
||||
ClientAddress string `json:"client_addr"`
|
||||
State string `json:"state"`
|
||||
WaitEventType string `json:"wait_event_type"`
|
||||
WaitEvent string `json:"wait_event"`
|
||||
BackendStart any `json:"backend_start"`
|
||||
TransactionStart any `json:"xact_start"`
|
||||
QueryStart any `json:"query_start"`
|
||||
QueryDuration any `json:"query_duration"`
|
||||
Query string `json:"query"`
|
||||
}
|
||||
|
||||
singleQueryWanted := queryListDetails{
|
||||
ProcessId: any(nil),
|
||||
User: "",
|
||||
Datname: "",
|
||||
ApplicationName: "",
|
||||
ClientAddress: "",
|
||||
State: "",
|
||||
WaitEventType: "",
|
||||
WaitEvent: "",
|
||||
BackendStart: any(nil),
|
||||
TransactionStart: any(nil),
|
||||
QueryStart: any(nil),
|
||||
QueryDuration: any(nil),
|
||||
Query: "SELECT pg_sleep(10);",
|
||||
}
|
||||
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
requestBody io.Reader
|
||||
clientSleepSecs int
|
||||
waitSecsBeforeCheck int
|
||||
wantStatusCode int
|
||||
want any
|
||||
}{
|
||||
// exclude background monitoring apps such as "wal_uploader"
|
||||
{
|
||||
name: "invoke list_active_queries when the system is idle",
|
||||
requestBody: bytes.NewBufferString(`{"exclude_application_names": "wal_uploader"}`),
|
||||
clientSleepSecs: 0,
|
||||
waitSecsBeforeCheck: 0,
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: []queryListDetails(nil),
|
||||
},
|
||||
{
|
||||
name: "invoke list_active_queries when there is 1 ongoing but lower than the threshold",
|
||||
requestBody: bytes.NewBufferString(`{"min_duration": "100 seconds", "exclude_application_names": "wal_uploader"}`),
|
||||
clientSleepSecs: 1,
|
||||
waitSecsBeforeCheck: 1,
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: []queryListDetails(nil),
|
||||
},
|
||||
{
|
||||
name: "invoke list_active_queries when 1 ongoing query should show up",
|
||||
requestBody: bytes.NewBufferString(`{"min_duration": "1 seconds", "exclude_application_names": "wal_uploader"}`),
|
||||
clientSleepSecs: 10,
|
||||
waitSecsBeforeCheck: 5,
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: []queryListDetails{singleQueryWanted},
|
||||
},
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if tc.clientSleepSecs > 0 {
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
err := pool.Ping(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("unable to connect to test database: %s", err)
|
||||
return
|
||||
}
|
||||
_, err = pool.Exec(ctx, fmt.Sprintf("SELECT pg_sleep(%d);", tc.clientSleepSecs))
|
||||
if err != nil {
|
||||
t.Errorf("Executing 'SELECT pg_sleep' failed: %s", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if tc.waitSecsBeforeCheck > 0 {
|
||||
time.Sleep(time.Duration(tc.waitSecsBeforeCheck) * time.Second)
|
||||
}
|
||||
|
||||
const api = "http://127.0.0.1:5000/api/tool/list_active_queries/invoke"
|
||||
req, err := http.NewRequest(http.MethodPost, api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %v", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body))
|
||||
}
|
||||
if tc.wantStatusCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
|
||||
var bodyWrapper struct {
|
||||
Result json.RawMessage `json:"result"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&bodyWrapper); err != nil {
|
||||
t.Fatalf("error decoding response wrapper: %v", err)
|
||||
}
|
||||
|
||||
var resultString string
|
||||
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
||||
resultString = string(bodyWrapper.Result)
|
||||
}
|
||||
|
||||
var got any
|
||||
var details []queryListDetails
|
||||
if err := json.Unmarshal([]byte(resultString), &details); err != nil {
|
||||
t.Fatalf("failed to unmarshal nested ObjectDetails string: %v", err)
|
||||
}
|
||||
got = details
|
||||
|
||||
if diff := cmp.Diff(tc.want, got, cmp.Comparer(func(a, b queryListDetails) bool {
|
||||
return a.Query == b.Query
|
||||
})); diff != "" {
|
||||
t.Errorf("Unexpected result: got %#v, want: %#v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func setUpPostgresViews(t *testing.T, ctx context.Context, pool *pgxpool.Pool, viewName, tableName string) func() {
|
||||
createView := fmt.Sprintf("CREATE VIEW %s AS SELECT name FROM %s", viewName, tableName)
|
||||
_, err := pool.Exec(ctx, createView)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create view: %v", err)
|
||||
}
|
||||
return func() {
|
||||
dropView := fmt.Sprintf("DROP VIEW %s", viewName)
|
||||
_, err := pool.Exec(ctx, dropView)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to drop view: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func runPostgresListViewsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool, tableName string) {
|
||||
viewName1 := "test_view_1" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
dropViewfunc1 := setUpPostgresViews(t, ctx, pool, viewName1, tableName)
|
||||
defer dropViewfunc1()
|
||||
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
requestBody io.Reader
|
||||
wantStatusCode int
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "invoke list_views with newly created view",
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"viewname": "%s"}`, viewName1))),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: fmt.Sprintf(`[{"schemaname":"public","viewname":"%s","viewowner":"postgres"}]`, viewName1),
|
||||
},
|
||||
{
|
||||
name: "invoke list_views with non-existent_view",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"viewname": "non_existent_view"}`)),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: `null`,
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
const api = "http://127.0.0.1:5000/api/tool/list_views/invoke"
|
||||
req, err := http.NewRequest(http.MethodPost, api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %v", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body))
|
||||
}
|
||||
if tc.wantStatusCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
|
||||
var bodyWrapper struct {
|
||||
Result json.RawMessage `json:"result"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&bodyWrapper); err != nil {
|
||||
t.Fatalf("error decoding response wrapper: %v", err)
|
||||
}
|
||||
|
||||
var resultString string
|
||||
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
||||
resultString = string(bodyWrapper.Result)
|
||||
}
|
||||
|
||||
var got, want any
|
||||
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
|
||||
t.Fatalf("failed to unmarshal nested result string: %v", err)
|
||||
}
|
||||
if err := json.Unmarshal([]byte(tc.want), &want); err != nil {
|
||||
t.Fatalf("failed to unmarshal want string: %v", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("Unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runPostgresListAvailableExtensionsTest(t *testing.T) {
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestBody io.Reader
|
||||
wantStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "invoke list_available_extensions output",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_available_extensions/invoke",
|
||||
wantStatusCode: http.StatusOK,
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
// Intentionally not adding the output check as output depends on the postgres instance used where the the functional test runs.
|
||||
// Adding the check will make the test flaky.
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runPostgresListInstalledExtensionsTest(t *testing.T) {
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestBody io.Reader
|
||||
wantStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "invoke list_installed_extensions output",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_installed_extensions/invoke",
|
||||
wantStatusCode: http.StatusOK,
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
// Intentionally not adding the output check as output depends on the postgres instance used where the the functional test runs.
|
||||
// Adding the check will make the test flaky.
|
||||
})
|
||||
}
|
||||
tests.RunPostgresListActiveQueriesTest(t, ctx, pool)
|
||||
tests.RunPostgresListAvailableExtensionsTest(t)
|
||||
tests.RunPostgresListInstalledExtensionsTest(t)
|
||||
}
|
||||
|
||||
624
tests/tool.go
624
tests/tool.go
@@ -149,31 +149,17 @@ func RunToolInvokeSimpleTest(t *testing.T, name string, simpleWant string) {
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Send Tool invocation request
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
for k, v := range tc.requestHeader {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if tc.isErr {
|
||||
return
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
// Check response body
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
err := json.Unmarshal(respBody, &body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
@@ -212,31 +198,17 @@ func RunToolInvokeParametersTest(t *testing.T, name string, params []byte, simpl
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Send Tool invocation request
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
for k, v := range tc.requestHeader {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if tc.isErr {
|
||||
return
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
// Check response body
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
err := json.Unmarshal(respBody, &body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
@@ -447,25 +419,11 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
|
||||
return
|
||||
}
|
||||
// Send Tool invocation request
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
// Add headers
|
||||
for k, v := range tc.requestHeader {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader)
|
||||
|
||||
// Check status code
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Errorf("StatusCode mismatch: got %d, want %d. Response body: %s", resp.StatusCode, tc.wantStatusCode, string(body))
|
||||
t.Errorf("StatusCode mismatch: got %d, want %d. Response body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
|
||||
}
|
||||
|
||||
// skip response body check
|
||||
@@ -475,7 +433,7 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
|
||||
|
||||
// Check response body
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
err = json.Unmarshal(respBody, &body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body: %s", err)
|
||||
}
|
||||
@@ -620,32 +578,17 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options
|
||||
insertAllow := !tc.insert || (tc.insert && configs.supportInsert)
|
||||
if ddlAllow && insertAllow {
|
||||
// Send Tool invocation request
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
for k, v := range tc.requestHeader {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if tc.isErr {
|
||||
return
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
// Check response body
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
err := json.Unmarshal(respBody, &body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
@@ -769,31 +712,17 @@ func RunExecuteSqlToolInvokeTest(t *testing.T, createTableStatement, select1Want
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Send Tool invocation request
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
for k, v := range tc.requestHeader {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if tc.isErr {
|
||||
return
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
// Check response body
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
err = json.Unmarshal(respBody, &body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
@@ -1157,6 +1086,257 @@ func setupPostgresSchemas(t *testing.T, ctx context.Context, pool *pgxpool.Pool,
|
||||
}
|
||||
}
|
||||
|
||||
func RunPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth, user string) {
|
||||
// TableNameParam columns to construct want
|
||||
paramTableColumns := fmt.Sprintf(`[
|
||||
{"data_type": "integer", "column_name": "id", "column_default": "nextval('%s_id_seq'::regclass)", "is_not_nullable": true, "ordinal_position": 1, "column_comment": null},
|
||||
{"data_type": "text", "column_name": "name", "column_default": null, "is_not_nullable": false, "ordinal_position": 2, "column_comment": null}
|
||||
]`, tableNameParam)
|
||||
|
||||
// TableNameAuth columns to construct want
|
||||
authTableColumns := fmt.Sprintf(`[
|
||||
{"data_type": "integer", "column_name": "id", "column_default": "nextval('%s_id_seq'::regclass)", "is_not_nullable": true, "ordinal_position": 1, "column_comment": null},
|
||||
{"data_type": "text", "column_name": "name", "column_default": null, "is_not_nullable": false, "ordinal_position": 2, "column_comment": null},
|
||||
{"data_type": "text", "column_name": "email", "column_default": null, "is_not_nullable": false, "ordinal_position": 3, "column_comment": null}
|
||||
]`, tableNameAuth)
|
||||
|
||||
const (
|
||||
// Template to construct detailed output want
|
||||
detailedObjectTemplate = `{
|
||||
"object_name": "%[1]s", "schema_name": "public",
|
||||
"object_details": {
|
||||
"owner": "%[3]s", "comment": null,
|
||||
"indexes": [{"is_primary": true, "is_unique": true, "index_name": "%[1]s_pkey", "index_method": "btree", "index_columns": ["id"], "index_definition": "CREATE UNIQUE INDEX %[1]s_pkey ON public.%[1]s USING btree (id)"}],
|
||||
"triggers": [], "columns": %[2]s, "object_name": "%[1]s", "object_type": "TABLE", "schema_name": "public",
|
||||
"constraints": [{"constraint_name": "%[1]s_pkey", "constraint_type": "PRIMARY KEY", "constraint_columns": ["id"], "constraint_definition": "PRIMARY KEY (id)", "foreign_key_referenced_table": null, "foreign_key_referenced_columns": null}]
|
||||
}
|
||||
}`
|
||||
|
||||
// Template to construct simple output want
|
||||
simpleObjectTemplate = `{"object_name":"%s", "schema_name":"public", "object_details":{"name":"%s"}}`
|
||||
)
|
||||
|
||||
// Helper to build json for detailed want
|
||||
getDetailedWant := func(tableName, columnJSON string) string {
|
||||
return fmt.Sprintf(detailedObjectTemplate, tableName, columnJSON, user)
|
||||
}
|
||||
|
||||
// Helper to build template for simple want
|
||||
getSimpleWant := func(tableName string) string {
|
||||
return fmt.Sprintf(simpleObjectTemplate, tableName, tableName)
|
||||
}
|
||||
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestBody io.Reader
|
||||
wantStatusCode int
|
||||
want string
|
||||
isAllTables bool
|
||||
}{
|
||||
{
|
||||
name: "invoke list_tables all tables detailed output",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"table_names": ""}`)),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)),
|
||||
isAllTables: true,
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables all tables simple output",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"table_names": "", "output_format": "simple"}`)),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: fmt.Sprintf("[%s,%s]", getSimpleWant(tableNameAuth), getSimpleWant(tableNameParam)),
|
||||
isAllTables: true,
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables detailed output",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s"}`, tableNameAuth))),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: fmt.Sprintf("[%s]", getDetailedWant(tableNameAuth, authTableColumns)),
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables simple output",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s", "output_format": "simple"}`, tableNameAuth))),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: fmt.Sprintf("[%s]", getSimpleWant(tableNameAuth)),
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables with invalid output format",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"table_names": "", "output_format": "abcd"}`)),
|
||||
wantStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables with malformed table_names parameter",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"table_names": 12345, "output_format": "detailed"}`)),
|
||||
wantStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables with multiple table names",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s,%s"}`, tableNameParam, tableNameAuth))),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)),
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables with non-existent table",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"table_names": "non_existent_table"}`)),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: `null`,
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables with one existing and one non-existent table",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s,non_existent_table"}`, tableNameParam))),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: fmt.Sprintf("[%s]", getDetailedWant(tableNameParam, paramTableColumns)),
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp, respBytes := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, nil)
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBytes))
|
||||
}
|
||||
|
||||
if tc.wantStatusCode == http.StatusOK {
|
||||
var bodyWrapper map[string]json.RawMessage
|
||||
|
||||
if err := json.Unmarshal(respBytes, &bodyWrapper); err != nil {
|
||||
t.Fatalf("error parsing response wrapper: %s, body: %s", err, string(respBytes))
|
||||
}
|
||||
|
||||
resultJSON, ok := bodyWrapper["result"]
|
||||
if !ok {
|
||||
t.Fatal("unable to find 'result' in response body")
|
||||
}
|
||||
|
||||
var resultString string
|
||||
if err := json.Unmarshal(resultJSON, &resultString); err != nil {
|
||||
t.Fatalf("'result' is not a JSON-encoded string: %s", err)
|
||||
}
|
||||
|
||||
var got, want []any
|
||||
|
||||
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
|
||||
t.Fatalf("failed to unmarshal actual result string: %v", err)
|
||||
}
|
||||
if err := json.Unmarshal([]byte(tc.want), &want); err != nil {
|
||||
t.Fatalf("failed to unmarshal expected want string: %v", err)
|
||||
}
|
||||
|
||||
// Checking only the default public schema where the test tables are created to avoid brittle tests.
|
||||
if tc.isAllTables {
|
||||
var filteredGot []any
|
||||
for _, item := range got {
|
||||
if tableMap, ok := item.(map[string]interface{}); ok {
|
||||
if schema, ok := tableMap["schema_name"]; ok && schema == "public" {
|
||||
filteredGot = append(filteredGot, item)
|
||||
}
|
||||
}
|
||||
}
|
||||
got = filteredGot
|
||||
}
|
||||
|
||||
sort.SliceStable(got, func(i, j int) bool {
|
||||
return fmt.Sprintf("%v", got[i]) < fmt.Sprintf("%v", got[j])
|
||||
})
|
||||
sort.SliceStable(want, func(i, j int) bool {
|
||||
return fmt.Sprintf("%v", want[i]) < fmt.Sprintf("%v", want[j])
|
||||
})
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("Unexpected result: got %#v, want: %#v", got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func setUpPostgresViews(t *testing.T, ctx context.Context, pool *pgxpool.Pool, viewName, tableName string) func() {
|
||||
createView := fmt.Sprintf("CREATE VIEW %s AS SELECT name FROM %s", viewName, tableName)
|
||||
_, err := pool.Exec(ctx, createView)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create view: %v", err)
|
||||
}
|
||||
return func() {
|
||||
dropView := fmt.Sprintf("DROP VIEW %s", viewName)
|
||||
_, err := pool.Exec(ctx, dropView)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to drop view: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func RunPostgresListViewsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool, tableName string) {
|
||||
viewName1 := "test_view_1" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
dropViewfunc1 := setUpPostgresViews(t, ctx, pool, viewName1, tableName)
|
||||
defer dropViewfunc1()
|
||||
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
requestBody io.Reader
|
||||
wantStatusCode int
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "invoke list_views with newly created view",
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"viewname": "%s"}`, viewName1))),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: fmt.Sprintf(`[{"schemaname":"public","viewname":"%s","viewowner":"postgres"}]`, viewName1),
|
||||
},
|
||||
{
|
||||
name: "invoke list_views with non-existent_view",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"viewname": "non_existent_view"}`)),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: `null`,
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
const api = "http://127.0.0.1:5000/api/tool/list_views/invoke"
|
||||
resp, body := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
||||
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body))
|
||||
}
|
||||
if tc.wantStatusCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
|
||||
var bodyWrapper struct {
|
||||
Result json.RawMessage `json:"result"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &bodyWrapper); err != nil {
|
||||
t.Fatalf("error decoding response wrapper: %v", err)
|
||||
}
|
||||
|
||||
var resultString string
|
||||
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
||||
resultString = string(bodyWrapper.Result)
|
||||
}
|
||||
|
||||
var got, want any
|
||||
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
|
||||
t.Fatalf("failed to unmarshal nested result string: %v", err)
|
||||
}
|
||||
if err := json.Unmarshal([]byte(tc.want), &want); err != nil {
|
||||
t.Fatalf("failed to unmarshal want string: %v", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("Unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func RunPostgresListSchemasTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
|
||||
schemaName := "test_schema_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
cleanup := setupPostgresSchemas(t, ctx, pool, schemaName)
|
||||
@@ -1186,20 +1366,9 @@ func RunPostgresListSchemasTest(t *testing.T, ctx context.Context, pool *pgxpool
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
const api = "http://127.0.0.1:5000/api/tool/list_schemas/invoke"
|
||||
req, err := http.NewRequest(http.MethodPost, api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %v", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body))
|
||||
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
|
||||
}
|
||||
if tc.wantStatusCode != http.StatusOK {
|
||||
return
|
||||
@@ -1208,7 +1377,7 @@ func RunPostgresListSchemasTest(t *testing.T, ctx context.Context, pool *pgxpool
|
||||
var bodyWrapper struct {
|
||||
Result json.RawMessage `json:"result"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&bodyWrapper); err != nil {
|
||||
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
|
||||
t.Fatalf("error decoding response wrapper: %v", err)
|
||||
}
|
||||
|
||||
@@ -1229,6 +1398,191 @@ func RunPostgresListSchemasTest(t *testing.T, ctx context.Context, pool *pgxpool
|
||||
}
|
||||
}
|
||||
|
||||
func RunPostgresListActiveQueriesTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
|
||||
type queryListDetails struct {
|
||||
ProcessId any `json:"pid"`
|
||||
User string `json:"user"`
|
||||
Datname string `json:"datname"`
|
||||
ApplicationName string `json:"application_name"`
|
||||
ClientAddress string `json:"client_addr"`
|
||||
State string `json:"state"`
|
||||
WaitEventType string `json:"wait_event_type"`
|
||||
WaitEvent string `json:"wait_event"`
|
||||
BackendStart any `json:"backend_start"`
|
||||
TransactionStart any `json:"xact_start"`
|
||||
QueryStart any `json:"query_start"`
|
||||
QueryDuration any `json:"query_duration"`
|
||||
Query string `json:"query"`
|
||||
}
|
||||
|
||||
singleQueryWanted := queryListDetails{
|
||||
ProcessId: any(nil),
|
||||
User: "",
|
||||
Datname: "",
|
||||
ApplicationName: "",
|
||||
ClientAddress: "",
|
||||
State: "",
|
||||
WaitEventType: "",
|
||||
WaitEvent: "",
|
||||
BackendStart: any(nil),
|
||||
TransactionStart: any(nil),
|
||||
QueryStart: any(nil),
|
||||
QueryDuration: any(nil),
|
||||
Query: "SELECT pg_sleep(10);",
|
||||
}
|
||||
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
requestBody io.Reader
|
||||
clientSleepSecs int
|
||||
waitSecsBeforeCheck int
|
||||
wantStatusCode int
|
||||
want any
|
||||
}{
|
||||
// exclude background monitoring apps such as "wal_uploader"
|
||||
{
|
||||
name: "invoke list_active_queries when the system is idle",
|
||||
requestBody: bytes.NewBufferString(`{"exclude_application_names": "wal_uploader"}`),
|
||||
clientSleepSecs: 0,
|
||||
waitSecsBeforeCheck: 0,
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: []queryListDetails(nil),
|
||||
},
|
||||
{
|
||||
name: "invoke list_active_queries when there is 1 ongoing but lower than the threshold",
|
||||
requestBody: bytes.NewBufferString(`{"min_duration": "100 seconds", "exclude_application_names": "wal_uploader"}`),
|
||||
clientSleepSecs: 1,
|
||||
waitSecsBeforeCheck: 1,
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: []queryListDetails(nil),
|
||||
},
|
||||
{
|
||||
name: "invoke list_active_queries when 1 ongoing query should show up",
|
||||
requestBody: bytes.NewBufferString(`{"min_duration": "1 seconds", "exclude_application_names": "wal_uploader"}`),
|
||||
clientSleepSecs: 10,
|
||||
waitSecsBeforeCheck: 5,
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: []queryListDetails{singleQueryWanted},
|
||||
},
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if tc.clientSleepSecs > 0 {
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
err := pool.Ping(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("unable to connect to test database: %s", err)
|
||||
return
|
||||
}
|
||||
_, err = pool.Exec(ctx, fmt.Sprintf("SELECT pg_sleep(%d);", tc.clientSleepSecs))
|
||||
if err != nil {
|
||||
t.Errorf("Executing 'SELECT pg_sleep' failed: %s", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if tc.waitSecsBeforeCheck > 0 {
|
||||
time.Sleep(time.Duration(tc.waitSecsBeforeCheck) * time.Second)
|
||||
}
|
||||
|
||||
const api = "http://127.0.0.1:5000/api/tool/list_active_queries/invoke"
|
||||
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
|
||||
}
|
||||
if tc.wantStatusCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
|
||||
var bodyWrapper struct {
|
||||
Result json.RawMessage `json:"result"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
|
||||
t.Fatalf("error decoding response wrapper: %v", err)
|
||||
}
|
||||
|
||||
var resultString string
|
||||
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
||||
resultString = string(bodyWrapper.Result)
|
||||
}
|
||||
|
||||
var got any
|
||||
var details []queryListDetails
|
||||
if err := json.Unmarshal([]byte(resultString), &details); err != nil {
|
||||
t.Fatalf("failed to unmarshal nested ObjectDetails string: %v", err)
|
||||
}
|
||||
got = details
|
||||
|
||||
if diff := cmp.Diff(tc.want, got, cmp.Comparer(func(a, b queryListDetails) bool {
|
||||
return a.Query == b.Query
|
||||
})); diff != "" {
|
||||
t.Errorf("Unexpected result: got %#v, want: %#v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func RunPostgresListAvailableExtensionsTest(t *testing.T) {
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestBody io.Reader
|
||||
wantStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "invoke list_available_extensions output",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_available_extensions/invoke",
|
||||
wantStatusCode: http.StatusOK,
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, nil)
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
// Intentionally not adding the output check as output depends on the postgres instance used where the the functional test runs.
|
||||
// Adding the check will make the test flaky.
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func RunPostgresListInstalledExtensionsTest(t *testing.T) {
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestBody io.Reader
|
||||
wantStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "invoke list_installed_extensions output",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_installed_extensions/invoke",
|
||||
wantStatusCode: http.StatusOK,
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp, bodyBytes := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, nil)
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
// Intentionally not adding the output check as output depends on the postgres instance used where the the functional test runs.
|
||||
// Adding the check will make the test flaky.
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RunMySQLListTablesTest run tests against the mysql-list-tables tool
|
||||
func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNameAuth string) {
|
||||
type tableInfo struct {
|
||||
@@ -1335,20 +1689,8 @@ func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNam
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
const api = "http://127.0.0.1:5000/api/tool/list_tables/invoke"
|
||||
req, err := http.NewRequest(http.MethodPost, api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %v", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
resp, body := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body))
|
||||
}
|
||||
if tc.wantStatusCode != http.StatusOK {
|
||||
@@ -1358,7 +1700,7 @@ func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNam
|
||||
var bodyWrapper struct {
|
||||
Result json.RawMessage `json:"result"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&bodyWrapper); err != nil {
|
||||
if err := json.Unmarshal(body, &bodyWrapper); err != nil {
|
||||
t.Fatalf("error decoding response wrapper: %v", err)
|
||||
}
|
||||
|
||||
@@ -1532,21 +1874,9 @@ func RunMySQLListActiveQueriesTest(t *testing.T, ctx context.Context, pool *sql.
|
||||
}
|
||||
|
||||
const api = "http://127.0.0.1:5000/api/tool/list_active_queries/invoke"
|
||||
req, err := http.NewRequest(http.MethodPost, api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %v", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body))
|
||||
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
|
||||
}
|
||||
if tc.wantStatusCode != http.StatusOK {
|
||||
return
|
||||
@@ -1555,7 +1885,7 @@ func RunMySQLListActiveQueriesTest(t *testing.T, ctx context.Context, pool *sql.
|
||||
var bodyWrapper struct {
|
||||
Result json.RawMessage `json:"result"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&bodyWrapper); err != nil {
|
||||
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
|
||||
t.Fatalf("error decoding response wrapper: %v", err)
|
||||
}
|
||||
|
||||
@@ -1765,21 +2095,9 @@ func RunMySQLListTablesMissingUniqueIndexes(t *testing.T, ctx context.Context, p
|
||||
}
|
||||
|
||||
const api = "http://127.0.0.1:5000/api/tool/list_tables_missing_unique_indexes/invoke"
|
||||
req, err := http.NewRequest(http.MethodPost, api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %v", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body))
|
||||
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
|
||||
}
|
||||
if tc.wantStatusCode != http.StatusOK {
|
||||
return
|
||||
@@ -1788,7 +2106,7 @@ func RunMySQLListTablesMissingUniqueIndexes(t *testing.T, ctx context.Context, p
|
||||
var bodyWrapper struct {
|
||||
Result json.RawMessage `json:"result"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&bodyWrapper); err != nil {
|
||||
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
|
||||
t.Fatalf("error decoding response wrapper: %v", err)
|
||||
}
|
||||
|
||||
@@ -1892,21 +2210,9 @@ func RunMySQLListTableFragmentationTest(t *testing.T, databaseName, tableNamePar
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
const api = "http://127.0.0.1:5000/api/tool/list_table_fragmentation/invoke"
|
||||
req, err := http.NewRequest(http.MethodPost, api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %v", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body))
|
||||
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
|
||||
}
|
||||
if tc.wantStatusCode != http.StatusOK {
|
||||
return
|
||||
@@ -1915,7 +2221,7 @@ func RunMySQLListTableFragmentationTest(t *testing.T, databaseName, tableNamePar
|
||||
var bodyWrapper struct {
|
||||
Result json.RawMessage `json:"result"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&bodyWrapper); err != nil {
|
||||
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
|
||||
t.Fatalf("error decoding response wrapper: %v", err)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user