mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 23:48:04 -05:00
## Description The list_database_stats test fails intermittently when run in parallel on shared instances. Specifically, the "filter by tablespace" and "sort by size" test cases fail because they encounter unexpected databases in the pg_default tablespace created by concurrent test runs. This PR narrows the scope of these test cases by filtering for specific database names. This ensures assertions remain isolated to the current test run regardless of other databases present in the shared environment. ``` go test -tags=integration tests/postgres/postgres_integration_test.go ok command-line-arguments 14.455s ``` > Should include a concise description of the changes (bug or feature), its > impact, along with a summary of the solution ## PR Checklist > Thank you for opening a Pull Request! Before submitting your PR, there are a > few things you can do to make sure it goes smoothly: - [x] Make sure you reviewed [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md) - [x] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [x] Ensure the tests and linter pass - [x] Code coverage does not decrease (if any source code was changed) - [x] Appropriate docs were updated (if necessary) - [x] Make sure to add `!` if this involves a breaking change 🛠️ Fixes #<1738>
4419 lines
151 KiB
Go
4419 lines
151 KiB
Go
// Copyright 2025 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package tests
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"reflect"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
"github.com/google/go-cmp/cmp/cmpopts"
|
|
"github.com/google/uuid"
|
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
)
|
|
|
|
// RunToolGetTest checks the tool GET endpoint for my-simple-tool.
//
// It sends a GET request to the server at 127.0.0.1:5000, requires a 200
// response, decodes the JSON body and compares the value stored under the
// "tools" key with the expected manifest using reflect.DeepEqual. Any
// mismatch fails the subtest.
func RunToolGetTest(t *testing.T) {
	// Test tool get endpoint
	tcs := []struct {
		name string
		api  string
		want map[string]any
	}{
		{
			name: "get my-simple-tool",
			api:  "http://127.0.0.1:5000/api/tool/my-simple-tool/",
			want: map[string]any{
				"my-simple-tool": map[string]any{
					"description":  "Simple tool to test end to end functionality.",
					"parameters":   []any{},
					"authRequired": []any{},
				},
			},
		},
	}
	for _, tc := range tcs {
		t.Run(tc.name, func(t *testing.T) {
			resp, err := http.Get(tc.api)
			if err != nil {
				t.Fatalf("error when sending a request: %s", err)
			}
			defer resp.Body.Close()
			if resp.StatusCode != http.StatusOK {
				t.Fatalf("response status code is not 200, got %d", resp.StatusCode)
			}

			var body map[string]any
			if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
				t.Fatalf("error parsing response body: %s", err)
			}

			got, ok := body["tools"]
			if !ok {
				t.Fatalf("unable to find tools in response body")
			}
			if !reflect.DeepEqual(got, tc.want) {
				// %v, not %q: the compared values are maps and slices, and %q
				// renders any non-string element as "%!q(...)" noise.
				t.Fatalf("got %v, want %v", got, tc.want)
			}
		})
	}
}
|
|
|
|
// RunToolGetTestByName checks the tool GET endpoint for the tool called name.
//
// It sends a GET request to http://127.0.0.1:5000/api/tool/<name>/ inside a
// subtest named "get <name>", requires a 200 response, decodes the JSON body
// and compares the value stored under the "tools" key with want using
// reflect.DeepEqual.
func RunToolGetTestByName(t *testing.T, name string, want map[string]any) {
	api := fmt.Sprintf("http://127.0.0.1:5000/api/tool/%s/", name)

	t.Run(fmt.Sprintf("get %s", name), func(t *testing.T) {
		resp, err := http.Get(api)
		if err != nil {
			t.Fatalf("error when sending a request: %s", err)
		}
		defer resp.Body.Close()
		if resp.StatusCode != http.StatusOK {
			t.Fatalf("response status code is not 200, got %d", resp.StatusCode)
		}

		var body map[string]any
		if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
			t.Fatalf("error parsing response body: %s", err)
		}

		got, ok := body["tools"]
		if !ok {
			t.Fatalf("unable to find tools in response body")
		}
		if !reflect.DeepEqual(got, want) {
			// %v, not %q: the compared values are maps and slices, and %q
			// renders any non-string element as "%!q(...)" noise.
			t.Fatalf("got %v, want %v", got, want)
		}
	})
}
|
|
|
|
// RunToolInvokeSimpleTest runs the tool invoke endpoint with no parameters
|
|
func RunToolInvokeSimpleTest(t *testing.T, name string, simpleWant string) {
|
|
// Test tool invoke endpoint
|
|
invokeTcs := []struct {
|
|
name string
|
|
api string
|
|
requestHeader map[string]string
|
|
requestBody io.Reader
|
|
want string
|
|
isErr bool
|
|
}{
|
|
{
|
|
name: fmt.Sprintf("invoke %s", name),
|
|
api: fmt.Sprintf("http://127.0.0.1:5000/api/tool/%s/invoke", name),
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
|
want: simpleWant,
|
|
isErr: false,
|
|
},
|
|
}
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Send Tool invocation request
|
|
resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader)
|
|
if resp.StatusCode != http.StatusOK {
|
|
if tc.isErr {
|
|
return
|
|
}
|
|
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
// Check response body
|
|
var body map[string]interface{}
|
|
err := json.Unmarshal(respBody, &body)
|
|
if err != nil {
|
|
t.Fatalf("error parsing response body")
|
|
}
|
|
|
|
got, ok := body["result"].(string)
|
|
if !ok {
|
|
t.Fatalf("unable to find result in response body")
|
|
}
|
|
|
|
if !strings.Contains(got, tc.want) {
|
|
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func RunToolInvokeParametersTest(t *testing.T, name string, params []byte, simpleWant string) {
|
|
// Test tool invoke endpoint
|
|
invokeTcs := []struct {
|
|
name string
|
|
api string
|
|
requestHeader map[string]string
|
|
requestBody io.Reader
|
|
want string
|
|
isErr bool
|
|
}{
|
|
{
|
|
name: fmt.Sprintf("invoke %s", name),
|
|
api: fmt.Sprintf("http://127.0.0.1:5000/api/tool/%s/invoke", name),
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer(params),
|
|
want: simpleWant,
|
|
isErr: false,
|
|
},
|
|
}
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Send Tool invocation request
|
|
resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader)
|
|
if resp.StatusCode != http.StatusOK {
|
|
if tc.isErr {
|
|
return
|
|
}
|
|
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
// Check response body
|
|
var body map[string]interface{}
|
|
err := json.Unmarshal(respBody, &body)
|
|
if err != nil {
|
|
t.Fatalf("error parsing response body")
|
|
}
|
|
|
|
got, ok := body["result"].(string)
|
|
if !ok {
|
|
t.Fatalf("unable to find result in response body")
|
|
}
|
|
|
|
if !strings.Contains(got, tc.want) {
|
|
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// RunToolInvoke runs the tool invoke endpoint
|
|
func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOption) {
|
|
// Resolve options
|
|
// Default values for InvokeTestConfig
|
|
configs := &InvokeTestConfig{
|
|
myToolId3NameAliceWant: "[{\"id\":1,\"name\":\"Alice\"},{\"id\":3,\"name\":\"Sid\"}]",
|
|
myToolById4Want: "[{\"id\":4,\"name\":null}]",
|
|
myArrayToolWant: "[{\"id\":1,\"name\":\"Alice\"},{\"id\":3,\"name\":\"Sid\"}]",
|
|
nullWant: "null",
|
|
supportOptionalNullParam: true,
|
|
supportArrayParam: true,
|
|
supportClientAuth: false,
|
|
supportSelect1Want: true,
|
|
supportSelect1Auth: true,
|
|
}
|
|
|
|
// Apply provided options
|
|
for _, option := range options {
|
|
option(configs)
|
|
}
|
|
|
|
// Get ID token
|
|
idToken, err := GetGoogleIdToken(ClientId)
|
|
if err != nil {
|
|
t.Fatalf("error getting Google ID token: %s", err)
|
|
}
|
|
|
|
// Get access token
|
|
accessToken, err := sources.GetIAMAccessToken(t.Context())
|
|
if err != nil {
|
|
t.Fatalf("error getting access token from ADC: %s", err)
|
|
}
|
|
accessToken = "Bearer " + accessToken
|
|
|
|
// Test tool invoke endpoint
|
|
invokeTcs := []struct {
|
|
name string
|
|
api string
|
|
enabled bool
|
|
requestHeader map[string]string
|
|
requestBody io.Reader
|
|
wantStatusCode int
|
|
wantBody string
|
|
}{
|
|
{
|
|
name: "invoke my-simple-tool",
|
|
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
|
|
enabled: configs.supportSelect1Want,
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
|
wantBody: select1Want,
|
|
wantStatusCode: http.StatusOK,
|
|
},
|
|
{
|
|
name: "invoke my-tool",
|
|
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
|
|
enabled: true,
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(`{"id": 3, "name": "Alice"}`)),
|
|
wantBody: configs.myToolId3NameAliceWant,
|
|
wantStatusCode: http.StatusOK,
|
|
},
|
|
{
|
|
name: "invoke my-tool-by-id with nil response",
|
|
api: "http://127.0.0.1:5000/api/tool/my-tool-by-id/invoke",
|
|
enabled: true,
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(`{"id": 4}`)),
|
|
wantBody: configs.myToolById4Want,
|
|
wantStatusCode: http.StatusOK,
|
|
},
|
|
{
|
|
name: "invoke my-tool-by-name with nil response",
|
|
api: "http://127.0.0.1:5000/api/tool/my-tool-by-name/invoke",
|
|
enabled: configs.supportOptionalNullParam,
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
|
wantBody: configs.nullWant,
|
|
wantStatusCode: http.StatusOK,
|
|
},
|
|
{
|
|
name: "Invoke my-tool without parameters",
|
|
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
|
|
enabled: true,
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
|
wantBody: "",
|
|
wantStatusCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "Invoke my-tool with insufficient parameters",
|
|
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
|
|
enabled: true,
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(`{"id": 1}`)),
|
|
wantBody: "",
|
|
wantStatusCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "invoke my-array-tool",
|
|
api: "http://127.0.0.1:5000/api/tool/my-array-tool/invoke",
|
|
enabled: configs.supportArrayParam,
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(`{"idArray": [1,2,3], "nameArray": ["Alice", "Sid", "RandomName"], "cmdArray": ["HGETALL", "row3"]}`)),
|
|
wantBody: configs.myArrayToolWant,
|
|
wantStatusCode: http.StatusOK,
|
|
},
|
|
{
|
|
name: "Invoke my-auth-tool with auth token",
|
|
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
|
|
enabled: configs.supportSelect1Auth,
|
|
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
|
wantBody: configs.myAuthToolWant,
|
|
wantStatusCode: http.StatusOK,
|
|
},
|
|
{
|
|
name: "Invoke my-auth-tool with invalid auth token",
|
|
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
|
|
enabled: configs.supportSelect1Auth,
|
|
requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
|
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
|
wantBody: "",
|
|
wantStatusCode: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
name: "Invoke my-auth-tool without auth token",
|
|
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
|
|
enabled: true,
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
|
wantBody: "",
|
|
wantStatusCode: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
name: "Invoke my-auth-required-tool with auth token",
|
|
api: "http://127.0.0.1:5000/api/tool/my-auth-required-tool/invoke",
|
|
enabled: configs.supportSelect1Auth,
|
|
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
|
wantBody: select1Want,
|
|
wantStatusCode: http.StatusOK,
|
|
},
|
|
{
|
|
name: "Invoke my-auth-required-tool with invalid auth token",
|
|
api: "http://127.0.0.1:5000/api/tool/my-auth-required-tool/invoke",
|
|
enabled: true,
|
|
requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
|
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
|
wantBody: "",
|
|
wantStatusCode: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
name: "Invoke my-auth-required-tool without auth token",
|
|
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
|
|
enabled: true,
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
|
wantBody: "",
|
|
wantStatusCode: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
name: "Invoke my-client-auth-tool with auth token",
|
|
api: "http://127.0.0.1:5000/api/tool/my-client-auth-tool/invoke",
|
|
enabled: configs.supportClientAuth,
|
|
requestHeader: map[string]string{"Authorization": accessToken},
|
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
|
wantBody: select1Want,
|
|
wantStatusCode: http.StatusOK,
|
|
},
|
|
{
|
|
name: "Invoke my-client-auth-tool without auth token",
|
|
api: "http://127.0.0.1:5000/api/tool/my-client-auth-tool/invoke",
|
|
enabled: configs.supportClientAuth,
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
|
wantStatusCode: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
|
|
name: "Invoke my-client-auth-tool with invalid auth token",
|
|
api: "http://127.0.0.1:5000/api/tool/my-client-auth-tool/invoke",
|
|
enabled: configs.supportClientAuth,
|
|
requestHeader: map[string]string{"Authorization": "Bearer invalid-token"},
|
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
|
wantStatusCode: http.StatusUnauthorized,
|
|
},
|
|
}
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
if !tc.enabled {
|
|
return
|
|
}
|
|
// Send Tool invocation request
|
|
resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader)
|
|
|
|
// Check status code
|
|
if resp.StatusCode != tc.wantStatusCode {
|
|
t.Errorf("StatusCode mismatch: got %d, want %d. Response body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
|
|
}
|
|
|
|
// skip response body check
|
|
if tc.wantBody == "" {
|
|
return
|
|
}
|
|
|
|
// Check response body
|
|
var body map[string]interface{}
|
|
err = json.Unmarshal(respBody, &body)
|
|
if err != nil {
|
|
t.Fatalf("error parsing response body: %s", err)
|
|
}
|
|
|
|
got, ok := body["result"].(string)
|
|
if !ok {
|
|
t.Fatalf("unable to find result in response body")
|
|
}
|
|
|
|
if got != tc.wantBody {
|
|
t.Fatalf("unexpected value: got %q, want %q", got, tc.wantBody)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// RunToolInvokeWithTemplateParameters runs tool invoke test cases with template parameters.
|
|
func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options ...TemplateParamOption) {
|
|
// Resolve options
|
|
// Default values for TemplateParameterTestConfig
|
|
configs := &TemplateParameterTestConfig{
|
|
ddlWant: "null",
|
|
selectAllWant: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"},{\"age\":100,\"id\":2,\"name\":\"Alice\"}]",
|
|
selectId1Want: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"}]",
|
|
selectNameWant: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"}]",
|
|
selectEmptyWant: "null",
|
|
insert1Want: "null",
|
|
|
|
nameFieldArray: `["name"]`,
|
|
nameColFilter: "name",
|
|
createColArray: `["id INT","name VARCHAR(20)","age INT"]`,
|
|
|
|
supportDdl: true,
|
|
supportInsert: true,
|
|
}
|
|
|
|
// Apply provided options
|
|
for _, option := range options {
|
|
option(configs)
|
|
}
|
|
|
|
selectOnlyNamesWant := "[{\"name\":\"Alex\"},{\"name\":\"Alice\"}]"
|
|
|
|
// Test tool invoke endpoint
|
|
invokeTcs := []struct {
|
|
name string
|
|
enabled bool
|
|
ddl bool
|
|
insert bool
|
|
api string
|
|
requestHeader map[string]string
|
|
requestBody io.Reader
|
|
want string
|
|
isErr bool
|
|
}{
|
|
{
|
|
name: "invoke create-table-templateParams-tool",
|
|
ddl: true,
|
|
api: "http://127.0.0.1:5000/api/tool/create-table-templateParams-tool/invoke",
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "columns":%s}`, tableName, configs.createColArray))),
|
|
want: configs.ddlWant,
|
|
isErr: false,
|
|
},
|
|
{
|
|
name: "invoke insert-table-templateParams-tool",
|
|
insert: true,
|
|
api: "http://127.0.0.1:5000/api/tool/insert-table-templateParams-tool/invoke",
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "columns":["id","name","age"], "values":"1, 'Alex', 21"}`, tableName))),
|
|
want: configs.insert1Want,
|
|
isErr: false,
|
|
},
|
|
{
|
|
name: "invoke insert-table-templateParams-tool",
|
|
insert: true,
|
|
api: "http://127.0.0.1:5000/api/tool/insert-table-templateParams-tool/invoke",
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "columns":["id","name","age"], "values":"2, 'Alice', 100"}`, tableName))),
|
|
want: configs.insert1Want,
|
|
isErr: false,
|
|
},
|
|
{
|
|
name: "invoke select-templateParams-tool",
|
|
api: "http://127.0.0.1:5000/api/tool/select-templateParams-tool/invoke",
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s"}`, tableName))),
|
|
want: configs.selectAllWant,
|
|
isErr: false,
|
|
},
|
|
{
|
|
name: "invoke select-templateParams-combined-tool",
|
|
api: "http://127.0.0.1:5000/api/tool/select-templateParams-combined-tool/invoke",
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"id": 1, "tableName": "%s"}`, tableName))),
|
|
want: configs.selectId1Want,
|
|
isErr: false,
|
|
},
|
|
{
|
|
name: "invoke select-templateParams-combined-tool with no results",
|
|
api: "http://127.0.0.1:5000/api/tool/select-templateParams-combined-tool/invoke",
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"id": 999, "tableName": "%s"}`, tableName))),
|
|
want: configs.selectEmptyWant,
|
|
isErr: false,
|
|
},
|
|
{
|
|
name: "invoke select-fields-templateParams-tool",
|
|
enabled: configs.supportSelectFields,
|
|
api: "http://127.0.0.1:5000/api/tool/select-fields-templateParams-tool/invoke",
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "fields":%s}`, tableName, configs.nameFieldArray))),
|
|
want: selectOnlyNamesWant,
|
|
isErr: false,
|
|
},
|
|
{
|
|
name: "invoke select-filter-templateParams-combined-tool",
|
|
api: "http://127.0.0.1:5000/api/tool/select-filter-templateParams-combined-tool/invoke",
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"name": "Alex", "tableName": "%s", "columnFilter": "%s"}`, tableName, configs.nameColFilter))),
|
|
want: configs.selectNameWant,
|
|
isErr: false,
|
|
},
|
|
{
|
|
name: "invoke drop-table-templateParams-tool",
|
|
ddl: true,
|
|
api: "http://127.0.0.1:5000/api/tool/drop-table-templateParams-tool/invoke",
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s"}`, tableName))),
|
|
want: configs.ddlWant,
|
|
isErr: false,
|
|
},
|
|
}
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
if !tc.enabled {
|
|
return
|
|
}
|
|
// if test case is DDL and source support ddl test cases
|
|
ddlAllow := !tc.ddl || (tc.ddl && configs.supportDdl)
|
|
// if test case is insert statement and source support insert test cases
|
|
insertAllow := !tc.insert || (tc.insert && configs.supportInsert)
|
|
if ddlAllow && insertAllow {
|
|
// Send Tool invocation request
|
|
resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader)
|
|
if resp.StatusCode != http.StatusOK {
|
|
if tc.isErr {
|
|
return
|
|
}
|
|
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
// Check response body
|
|
var body map[string]interface{}
|
|
err := json.Unmarshal(respBody, &body)
|
|
if err != nil {
|
|
t.Fatalf("error parsing response body")
|
|
}
|
|
|
|
got, ok := body["result"].(string)
|
|
if !ok {
|
|
t.Fatalf("unable to find result in response body")
|
|
}
|
|
|
|
if got != tc.want {
|
|
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func RunExecuteSqlToolInvokeTest(t *testing.T, createTableStatement, select1Want string, options ...ExecuteSqlOption) {
|
|
// Resolve options
|
|
// Default values for ExecuteSqlTestConfig
|
|
configs := &ExecuteSqlTestConfig{
|
|
select1Statement: `"SELECT 1"`,
|
|
}
|
|
|
|
// Apply provided options
|
|
for _, option := range options {
|
|
option(configs)
|
|
}
|
|
|
|
// Get ID token
|
|
idToken, err := GetGoogleIdToken(ClientId)
|
|
if err != nil {
|
|
t.Fatalf("error getting Google ID token: %s", err)
|
|
}
|
|
|
|
// Test tool invoke endpoint
|
|
invokeTcs := []struct {
|
|
name string
|
|
api string
|
|
requestHeader map[string]string
|
|
requestBody io.Reader
|
|
want string
|
|
isErr bool
|
|
}{
|
|
{
|
|
name: "invoke my-exec-sql-tool",
|
|
api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke",
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"sql": %s}`, configs.select1Statement))),
|
|
want: select1Want,
|
|
isErr: false,
|
|
},
|
|
{
|
|
name: "invoke my-exec-sql-tool create table",
|
|
api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke",
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"sql": %s}`, createTableStatement))),
|
|
want: "null",
|
|
isErr: false,
|
|
},
|
|
{
|
|
name: "invoke my-exec-sql-tool select table",
|
|
api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke",
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT * FROM t"}`)),
|
|
want: "null",
|
|
isErr: false,
|
|
},
|
|
{
|
|
name: "invoke my-exec-sql-tool drop table",
|
|
api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke",
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(`{"sql":"DROP TABLE t"}`)),
|
|
want: "null",
|
|
isErr: false,
|
|
},
|
|
{
|
|
name: "invoke my-exec-sql-tool without body",
|
|
api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke",
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
|
isErr: true,
|
|
},
|
|
{
|
|
name: "Invoke my-auth-exec-sql-tool with auth token",
|
|
api: "http://127.0.0.1:5000/api/tool/my-auth-exec-sql-tool/invoke",
|
|
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"sql": %s}`, configs.select1Statement))),
|
|
isErr: false,
|
|
want: select1Want,
|
|
},
|
|
{
|
|
name: "Invoke my-auth-exec-sql-tool with invalid auth token",
|
|
api: "http://127.0.0.1:5000/api/tool/my-auth-exec-sql-tool/invoke",
|
|
requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"sql": %s}`, configs.select1Statement))),
|
|
isErr: true,
|
|
},
|
|
{
|
|
name: "Invoke my-auth-exec-sql-tool without auth token",
|
|
api: "http://127.0.0.1:5000/api/tool/my-auth-exec-sql-tool/invoke",
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"sql": %s}`, configs.select1Statement))),
|
|
isErr: true,
|
|
},
|
|
{
|
|
name: "invoke my-exec-sql-tool with invalid SELECT SQL",
|
|
api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke",
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT * FROM non_existent_table"}`)),
|
|
isErr: true,
|
|
},
|
|
{
|
|
name: "invoke my-exec-sql-tool with invalid ALTER SQL",
|
|
api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke",
|
|
requestHeader: map[string]string{},
|
|
requestBody: bytes.NewBuffer([]byte(`{"sql":"ALTER TALE t ALTER COLUMN id DROP NOT NULL"}`)),
|
|
isErr: true,
|
|
},
|
|
}
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Send Tool invocation request
|
|
resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader)
|
|
if resp.StatusCode != http.StatusOK {
|
|
if tc.isErr {
|
|
return
|
|
}
|
|
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
// Check response body
|
|
var body map[string]interface{}
|
|
err = json.Unmarshal(respBody, &body)
|
|
if err != nil {
|
|
t.Fatalf("error parsing response body")
|
|
}
|
|
|
|
got, ok := body["result"].(string)
|
|
if !ok {
|
|
t.Fatalf("unable to find result in response body")
|
|
}
|
|
|
|
if got != tc.want {
|
|
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// RunInitialize runs the initialize lifecycle for mcp to set up client-server connection
|
|
func RunInitialize(t *testing.T, protocolVersion string) string {
|
|
url := "http://127.0.0.1:5000/mcp"
|
|
|
|
initializeRequestBody := map[string]any{
|
|
"jsonrpc": "2.0",
|
|
"id": "mcp-initialize",
|
|
"method": "initialize",
|
|
"params": map[string]any{
|
|
"protocolVersion": protocolVersion,
|
|
},
|
|
}
|
|
reqMarshal, err := json.Marshal(initializeRequestBody)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error during marshaling of body")
|
|
}
|
|
|
|
resp, _ := RunRequest(t, http.MethodPost, url, bytes.NewBuffer(reqMarshal), nil)
|
|
if resp.StatusCode != 200 {
|
|
t.Fatalf("response status code is not 200")
|
|
}
|
|
|
|
if contentType := resp.Header.Get("Content-type"); contentType != "application/json" {
|
|
t.Fatalf("unexpected content-type header: want %s, got %s", "application/json", contentType)
|
|
}
|
|
|
|
sessionId := resp.Header.Get("Mcp-Session-Id")
|
|
|
|
header := map[string]string{}
|
|
if sessionId != "" {
|
|
header["Mcp-Session-Id"] = sessionId
|
|
}
|
|
|
|
initializeNotificationBody := map[string]any{
|
|
"jsonrpc": "2.0",
|
|
"method": "notifications/initialized",
|
|
}
|
|
notiMarshal, err := json.Marshal(initializeNotificationBody)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error during marshaling of notifications body")
|
|
}
|
|
|
|
_, _ = RunRequest(t, http.MethodPost, url, bytes.NewBuffer(notiMarshal), header)
|
|
return sessionId
|
|
}
|
|
|
|
// RunMCPToolCallMethod runs the tool/call for mcp endpoint
|
|
func RunMCPToolCallMethod(t *testing.T, myFailToolWant, select1Want string, options ...McpTestOption) {
|
|
// Resolve options
|
|
// Default values for MCPTestConfig
|
|
configs := &MCPTestConfig{
|
|
myToolId3NameAliceWant: `{"jsonrpc":"2.0","id":"my-tool","result":{"content":[{"type":"text","text":"{\"id\":1,\"name\":\"Alice\"}"},{"type":"text","text":"{\"id\":3,\"name\":\"Sid\"}"}]}}`,
|
|
mcpSelect1Want: select1Want,
|
|
supportClientAuth: false,
|
|
supportSelect1Auth: true,
|
|
}
|
|
|
|
// Apply provided options
|
|
for _, option := range options {
|
|
option(configs)
|
|
}
|
|
|
|
sessionId := RunInitialize(t, "2024-11-05")
|
|
|
|
// Get access token
|
|
accessToken, err := sources.GetIAMAccessToken(t.Context())
|
|
if err != nil {
|
|
t.Fatalf("error getting access token from ADC: %s", err)
|
|
}
|
|
accessToken = "Bearer " + accessToken
|
|
|
|
idToken, err := GetGoogleIdToken(ClientId)
|
|
if err != nil {
|
|
t.Fatalf("error getting Google ID token: %s", err)
|
|
}
|
|
|
|
// Test tool invoke endpoint
|
|
invokeTcs := []struct {
|
|
name string
|
|
api string
|
|
enabled bool // switch to turn on/off the test case
|
|
requestBody jsonrpc.JSONRPCRequest
|
|
requestHeader map[string]string
|
|
wantStatusCode int
|
|
wantBody string
|
|
}{
|
|
{
|
|
name: "MCP Invoke my-tool",
|
|
api: "http://127.0.0.1:5000/mcp",
|
|
enabled: true,
|
|
requestHeader: map[string]string{},
|
|
requestBody: jsonrpc.JSONRPCRequest{
|
|
Jsonrpc: "2.0",
|
|
Id: "my-tool",
|
|
Request: jsonrpc.Request{
|
|
Method: "tools/call",
|
|
},
|
|
Params: map[string]any{
|
|
"name": "my-tool",
|
|
"arguments": map[string]any{
|
|
"id": int(3),
|
|
"name": "Alice",
|
|
},
|
|
},
|
|
},
|
|
wantStatusCode: http.StatusOK,
|
|
wantBody: configs.myToolId3NameAliceWant,
|
|
},
|
|
{
|
|
name: "MCP Invoke invalid tool",
|
|
api: "http://127.0.0.1:5000/mcp",
|
|
enabled: true,
|
|
requestHeader: map[string]string{},
|
|
requestBody: jsonrpc.JSONRPCRequest{
|
|
Jsonrpc: "2.0",
|
|
Id: "invalid-tool",
|
|
Request: jsonrpc.Request{
|
|
Method: "tools/call",
|
|
},
|
|
Params: map[string]any{
|
|
"name": "foo",
|
|
"arguments": map[string]any{},
|
|
},
|
|
},
|
|
wantStatusCode: http.StatusOK,
|
|
wantBody: `{"jsonrpc":"2.0","id":"invalid-tool","error":{"code":-32602,"message":"invalid tool name: tool with name \"foo\" does not exist"}}`,
|
|
},
|
|
{
|
|
name: "MCP Invoke my-tool without parameters",
|
|
api: "http://127.0.0.1:5000/mcp",
|
|
enabled: true,
|
|
requestHeader: map[string]string{},
|
|
requestBody: jsonrpc.JSONRPCRequest{
|
|
Jsonrpc: "2.0",
|
|
Id: "invoke-without-parameter",
|
|
Request: jsonrpc.Request{
|
|
Method: "tools/call",
|
|
},
|
|
Params: map[string]any{
|
|
"name": "my-tool",
|
|
"arguments": map[string]any{},
|
|
},
|
|
},
|
|
wantStatusCode: http.StatusOK,
|
|
wantBody: `{"jsonrpc":"2.0","id":"invoke-without-parameter","error":{"code":-32602,"message":"provided parameters were invalid: parameter \"id\" is required"}}`,
|
|
},
|
|
{
|
|
name: "MCP Invoke my-tool with insufficient parameters",
|
|
api: "http://127.0.0.1:5000/mcp",
|
|
enabled: true,
|
|
requestHeader: map[string]string{},
|
|
requestBody: jsonrpc.JSONRPCRequest{
|
|
Jsonrpc: "2.0",
|
|
Id: "invoke-insufficient-parameter",
|
|
Request: jsonrpc.Request{
|
|
Method: "tools/call",
|
|
},
|
|
Params: map[string]any{
|
|
"name": "my-tool",
|
|
"arguments": map[string]any{"id": 1},
|
|
},
|
|
},
|
|
wantStatusCode: http.StatusOK,
|
|
wantBody: `{"jsonrpc":"2.0","id":"invoke-insufficient-parameter","error":{"code":-32602,"message":"provided parameters were invalid: parameter \"name\" is required"}}`,
|
|
},
|
|
{
|
|
name: "MCP Invoke my-auth-required-tool",
|
|
api: "http://127.0.0.1:5000/mcp",
|
|
enabled: configs.supportSelect1Auth,
|
|
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
|
requestBody: jsonrpc.JSONRPCRequest{
|
|
Jsonrpc: "2.0",
|
|
Id: "invoke my-auth-required-tool",
|
|
Request: jsonrpc.Request{
|
|
Method: "tools/call",
|
|
},
|
|
Params: map[string]any{
|
|
"name": "my-auth-required-tool",
|
|
"arguments": map[string]any{},
|
|
},
|
|
},
|
|
wantStatusCode: http.StatusOK,
|
|
wantBody: configs.mcpSelect1Want,
|
|
},
|
|
{
|
|
name: "MCP Invoke my-auth-required-tool with invalid auth token",
|
|
api: "http://127.0.0.1:5000/mcp",
|
|
requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
|
|
requestBody: jsonrpc.JSONRPCRequest{
|
|
Jsonrpc: "2.0",
|
|
Id: "invoke my-auth-required-tool with invalid token",
|
|
Request: jsonrpc.Request{
|
|
Method: "tools/call",
|
|
},
|
|
Params: map[string]any{
|
|
"name": "my-auth-required-tool",
|
|
"arguments": map[string]any{},
|
|
},
|
|
},
|
|
wantStatusCode: http.StatusUnauthorized,
|
|
wantBody: "{\"jsonrpc\":\"2.0\",\"id\":\"invoke my-auth-required-tool with invalid token\",\"error\":{\"code\":-32600,\"message\":\"unauthorized Tool call: Please make sure your specify correct auth headers: unauthorized\"}}",
|
|
},
|
|
{
|
|
name: "MCP Invoke my-auth-required-tool without auth token",
|
|
api: "http://127.0.0.1:5000/mcp",
|
|
requestHeader: map[string]string{},
|
|
requestBody: jsonrpc.JSONRPCRequest{
|
|
Jsonrpc: "2.0",
|
|
Id: "invoke my-auth-required-tool without token",
|
|
Request: jsonrpc.Request{
|
|
Method: "tools/call",
|
|
},
|
|
Params: map[string]any{
|
|
"name": "my-auth-required-tool",
|
|
"arguments": map[string]any{},
|
|
},
|
|
},
|
|
wantStatusCode: http.StatusUnauthorized,
|
|
wantBody: "{\"jsonrpc\":\"2.0\",\"id\":\"invoke my-auth-required-tool without token\",\"error\":{\"code\":-32600,\"message\":\"unauthorized Tool call: Please make sure your specify correct auth headers: unauthorized\"}}",
|
|
},
|
|
|
|
{
|
|
name: "MCP Invoke my-client-auth-tool",
|
|
enabled: configs.supportClientAuth,
|
|
api: "http://127.0.0.1:5000/mcp",
|
|
requestHeader: map[string]string{"Authorization": accessToken},
|
|
requestBody: jsonrpc.JSONRPCRequest{
|
|
Jsonrpc: "2.0",
|
|
Id: "invoke my-client-auth-tool",
|
|
Request: jsonrpc.Request{
|
|
Method: "tools/call",
|
|
},
|
|
Params: map[string]any{
|
|
"name": "my-client-auth-tool",
|
|
"arguments": map[string]any{},
|
|
},
|
|
},
|
|
wantStatusCode: http.StatusOK,
|
|
wantBody: "{\"jsonrpc\":\"2.0\",\"id\":\"invoke my-client-auth-tool\",\"result\":{\"content\":[{\"type\":\"text\",\"text\":\"{\\\"f0_\\\":1}\"}]}}",
|
|
},
|
|
{
|
|
name: "MCP Invoke my-client-auth-tool without access token",
|
|
enabled: configs.supportClientAuth,
|
|
api: "http://127.0.0.1:5000/mcp",
|
|
requestHeader: map[string]string{},
|
|
requestBody: jsonrpc.JSONRPCRequest{
|
|
Jsonrpc: "2.0",
|
|
Id: "invoke my-client-auth-tool",
|
|
Request: jsonrpc.Request{
|
|
Method: "tools/call",
|
|
},
|
|
Params: map[string]any{
|
|
"name": "my-client-auth-tool",
|
|
"arguments": map[string]any{},
|
|
},
|
|
},
|
|
wantStatusCode: http.StatusUnauthorized,
|
|
wantBody: "{\"jsonrpc\":\"2.0\",\"id\":\"invoke my-client-auth-tool\",\"error\":{\"code\":-32600,\"message\":\"missing access token in the 'Authorization' header\"}",
|
|
},
|
|
{
|
|
name: "MCP Invoke my-client-auth-tool with invalid access token",
|
|
enabled: configs.supportClientAuth,
|
|
api: "http://127.0.0.1:5000/mcp",
|
|
requestHeader: map[string]string{"Authorization": "Bearer invalid-token"},
|
|
requestBody: jsonrpc.JSONRPCRequest{
|
|
Jsonrpc: "2.0",
|
|
Id: "invoke my-client-auth-tool",
|
|
Request: jsonrpc.Request{
|
|
Method: "tools/call",
|
|
},
|
|
Params: map[string]any{
|
|
"name": "my-client-auth-tool",
|
|
"arguments": map[string]any{},
|
|
},
|
|
},
|
|
wantStatusCode: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
name: "MCP Invoke my-fail-tool",
|
|
api: "http://127.0.0.1:5000/mcp",
|
|
enabled: true,
|
|
requestHeader: map[string]string{},
|
|
requestBody: jsonrpc.JSONRPCRequest{
|
|
Jsonrpc: "2.0",
|
|
Id: "invoke-fail-tool",
|
|
Request: jsonrpc.Request{
|
|
Method: "tools/call",
|
|
},
|
|
Params: map[string]any{
|
|
"name": "my-fail-tool",
|
|
"arguments": map[string]any{"id": 1},
|
|
},
|
|
},
|
|
wantStatusCode: http.StatusOK,
|
|
wantBody: myFailToolWant,
|
|
},
|
|
}
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
if !tc.enabled {
|
|
return
|
|
}
|
|
reqMarshal, err := json.Marshal(tc.requestBody)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error during marshaling of request body")
|
|
}
|
|
|
|
// add headers
|
|
headers := map[string]string{}
|
|
if sessionId != "" {
|
|
headers["Mcp-Session-Id"] = sessionId
|
|
}
|
|
for key, value := range tc.requestHeader {
|
|
headers[key] = value
|
|
}
|
|
|
|
httpResponse, respBody := RunRequest(t, http.MethodPost, tc.api, bytes.NewBuffer(reqMarshal), headers)
|
|
|
|
// Check status code
|
|
if httpResponse.StatusCode != tc.wantStatusCode {
|
|
t.Errorf("StatusCode mismatch: got %d, want %d", httpResponse.StatusCode, tc.wantStatusCode)
|
|
}
|
|
|
|
// Check response body
|
|
got := string(bytes.TrimSpace(respBody))
|
|
if !strings.Contains(got, tc.wantBody) {
|
|
t.Fatalf("Expected substring not found:\ngot: %q\nwant: %q (to be contained within got)", got, tc.wantBody)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func setupPostgresSchemas(t *testing.T, ctx context.Context, pool *pgxpool.Pool, schemaName string) func() {
|
|
createSchemaStmt := fmt.Sprintf("CREATE SCHEMA %s", schemaName)
|
|
_, err := pool.Exec(ctx, createSchemaStmt)
|
|
if err != nil {
|
|
t.Fatalf("failed to create schema: %v", err)
|
|
}
|
|
|
|
return func() {
|
|
dropSchemaStmt := fmt.Sprintf("DROP SCHEMA %s CASCADE", schemaName)
|
|
_, err := pool.Exec(ctx, dropSchemaStmt)
|
|
if err != nil {
|
|
t.Fatalf("failed to drop schema: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func RunPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth, user string) {
|
|
// TableNameParam columns to construct want
|
|
paramTableColumns := fmt.Sprintf(`[
|
|
{"data_type": "integer", "column_name": "id", "column_default": "nextval('%s_id_seq'::regclass)", "is_not_nullable": true, "ordinal_position": 1, "column_comment": null},
|
|
{"data_type": "text", "column_name": "name", "column_default": null, "is_not_nullable": false, "ordinal_position": 2, "column_comment": null}
|
|
]`, tableNameParam)
|
|
|
|
// TableNameAuth columns to construct want
|
|
authTableColumns := fmt.Sprintf(`[
|
|
{"data_type": "integer", "column_name": "id", "column_default": "nextval('%s_id_seq'::regclass)", "is_not_nullable": true, "ordinal_position": 1, "column_comment": null},
|
|
{"data_type": "text", "column_name": "name", "column_default": null, "is_not_nullable": false, "ordinal_position": 2, "column_comment": null},
|
|
{"data_type": "text", "column_name": "email", "column_default": null, "is_not_nullable": false, "ordinal_position": 3, "column_comment": null}
|
|
]`, tableNameAuth)
|
|
|
|
const (
|
|
// Template to construct detailed output want
|
|
detailedObjectTemplate = `{
|
|
"object_name": "%[1]s", "schema_name": "public",
|
|
"object_details": {
|
|
"owner": "%[3]s", "comment": null,
|
|
"indexes": [{"is_primary": true, "is_unique": true, "index_name": "%[1]s_pkey", "index_method": "btree", "index_columns": ["id"], "index_definition": "CREATE UNIQUE INDEX %[1]s_pkey ON public.%[1]s USING btree (id)"}],
|
|
"triggers": [], "columns": %[2]s, "object_name": "%[1]s", "object_type": "TABLE", "schema_name": "public",
|
|
"constraints": [{"constraint_name": "%[1]s_pkey", "constraint_type": "PRIMARY KEY", "constraint_columns": ["id"], "constraint_definition": "PRIMARY KEY (id)", "foreign_key_referenced_table": null, "foreign_key_referenced_columns": null}]
|
|
}
|
|
}`
|
|
|
|
// Template to construct simple output want
|
|
simpleObjectTemplate = `{"object_name":"%s", "schema_name":"public", "object_details":{"name":"%s"}}`
|
|
)
|
|
|
|
// Helper to build json for detailed want
|
|
getDetailedWant := func(tableName, columnJSON string) string {
|
|
return fmt.Sprintf(detailedObjectTemplate, tableName, columnJSON, user)
|
|
}
|
|
|
|
// Helper to build template for simple want
|
|
getSimpleWant := func(tableName string) string {
|
|
return fmt.Sprintf(simpleObjectTemplate, tableName, tableName)
|
|
}
|
|
|
|
invokeTcs := []struct {
|
|
name string
|
|
api string
|
|
requestBody io.Reader
|
|
wantStatusCode int
|
|
want string
|
|
isAllTables bool
|
|
}{
|
|
{
|
|
name: "invoke list_tables all tables detailed output",
|
|
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
|
requestBody: bytes.NewBuffer([]byte(`{"table_names": ""}`)),
|
|
wantStatusCode: http.StatusOK,
|
|
want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)),
|
|
isAllTables: true,
|
|
},
|
|
{
|
|
name: "invoke list_tables all tables simple output",
|
|
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
|
requestBody: bytes.NewBuffer([]byte(`{"table_names": "", "output_format": "simple"}`)),
|
|
wantStatusCode: http.StatusOK,
|
|
want: fmt.Sprintf("[%s,%s]", getSimpleWant(tableNameAuth), getSimpleWant(tableNameParam)),
|
|
isAllTables: true,
|
|
},
|
|
{
|
|
name: "invoke list_tables detailed output",
|
|
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s"}`, tableNameAuth))),
|
|
wantStatusCode: http.StatusOK,
|
|
want: fmt.Sprintf("[%s]", getDetailedWant(tableNameAuth, authTableColumns)),
|
|
},
|
|
{
|
|
name: "invoke list_tables simple output",
|
|
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s", "output_format": "simple"}`, tableNameAuth))),
|
|
wantStatusCode: http.StatusOK,
|
|
want: fmt.Sprintf("[%s]", getSimpleWant(tableNameAuth)),
|
|
},
|
|
{
|
|
name: "invoke list_tables with invalid output format",
|
|
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
|
requestBody: bytes.NewBuffer([]byte(`{"table_names": "", "output_format": "abcd"}`)),
|
|
wantStatusCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "invoke list_tables with malformed table_names parameter",
|
|
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
|
requestBody: bytes.NewBuffer([]byte(`{"table_names": 12345, "output_format": "detailed"}`)),
|
|
wantStatusCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "invoke list_tables with multiple table names",
|
|
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s,%s"}`, tableNameParam, tableNameAuth))),
|
|
wantStatusCode: http.StatusOK,
|
|
want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)),
|
|
},
|
|
{
|
|
name: "invoke list_tables with non-existent table",
|
|
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
|
requestBody: bytes.NewBuffer([]byte(`{"table_names": "non_existent_table"}`)),
|
|
wantStatusCode: http.StatusOK,
|
|
want: `[]`,
|
|
},
|
|
{
|
|
name: "invoke list_tables with one existing and one non-existent table",
|
|
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s,non_existent_table"}`, tableNameParam))),
|
|
wantStatusCode: http.StatusOK,
|
|
want: fmt.Sprintf("[%s]", getDetailedWant(tableNameParam, paramTableColumns)),
|
|
},
|
|
}
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
resp, respBytes := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, nil)
|
|
if resp.StatusCode != tc.wantStatusCode {
|
|
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBytes))
|
|
}
|
|
|
|
if tc.wantStatusCode == http.StatusOK {
|
|
var bodyWrapper map[string]json.RawMessage
|
|
|
|
if err := json.Unmarshal(respBytes, &bodyWrapper); err != nil {
|
|
t.Fatalf("error parsing response wrapper: %s, body: %s", err, string(respBytes))
|
|
}
|
|
|
|
resultJSON, ok := bodyWrapper["result"]
|
|
if !ok {
|
|
t.Fatal("unable to find 'result' in response body")
|
|
}
|
|
|
|
var resultString string
|
|
if err := json.Unmarshal(resultJSON, &resultString); err != nil {
|
|
t.Fatalf("'result' is not a JSON-encoded string: %s", err)
|
|
}
|
|
|
|
var got, want []any
|
|
|
|
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
|
|
t.Fatalf("failed to unmarshal actual result string: %v", err)
|
|
}
|
|
if err := json.Unmarshal([]byte(tc.want), &want); err != nil {
|
|
t.Fatalf("failed to unmarshal expected want string: %v", err)
|
|
}
|
|
|
|
// Checking only the default public schema where the test tables are created to avoid brittle tests.
|
|
if tc.isAllTables {
|
|
var filteredGot []any
|
|
for _, item := range got {
|
|
if tableMap, ok := item.(map[string]interface{}); ok {
|
|
if schema, ok := tableMap["schema_name"]; ok && schema == "public" {
|
|
filteredGot = append(filteredGot, item)
|
|
}
|
|
}
|
|
}
|
|
got = filteredGot
|
|
}
|
|
|
|
sort.SliceStable(got, func(i, j int) bool {
|
|
return fmt.Sprintf("%v", got[i]) < fmt.Sprintf("%v", got[j])
|
|
})
|
|
sort.SliceStable(want, func(i, j int) bool {
|
|
return fmt.Sprintf("%v", want[i]) < fmt.Sprintf("%v", want[j])
|
|
})
|
|
|
|
if !reflect.DeepEqual(got, want) {
|
|
t.Errorf("Unexpected result: got %#v, want: %#v", got, want)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func setUpPostgresViews(t *testing.T, ctx context.Context, pool *pgxpool.Pool, viewName string) func() {
|
|
createView := fmt.Sprintf("CREATE VIEW %s AS SELECT 1 AS col", viewName)
|
|
_, err := pool.Exec(ctx, createView)
|
|
if err != nil {
|
|
t.Fatalf("failed to create view: %v", err)
|
|
}
|
|
return func() {
|
|
dropView := fmt.Sprintf("DROP VIEW %s", viewName)
|
|
_, err := pool.Exec(ctx, dropView)
|
|
if err != nil {
|
|
t.Fatalf("failed to drop view: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func RunPostgresListViewsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
|
|
//adding this line temporarily
|
|
viewName := "test_view_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
|
dropViewfunc1 := setUpPostgresViews(t, ctx, pool, viewName)
|
|
defer dropViewfunc1()
|
|
|
|
invokeTcs := []struct {
|
|
name string
|
|
requestBody io.Reader
|
|
wantStatusCode int
|
|
want string
|
|
}{
|
|
{
|
|
name: "invoke list_views with newly created view",
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"view_name": "%s"}`, viewName))),
|
|
wantStatusCode: http.StatusOK,
|
|
want: fmt.Sprintf(`[{"schema_name":"public","view_name":"%s","owner_name":"postgres","definition":" SELECT 1 AS col;"}]`, viewName),
|
|
},
|
|
{
|
|
name: "invoke list_views with non-existent_view",
|
|
requestBody: bytes.NewBuffer([]byte(`{"view_name": "non_existent_view"}`)),
|
|
wantStatusCode: http.StatusOK,
|
|
want: `null`,
|
|
},
|
|
}
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
const api = "http://127.0.0.1:5000/api/tool/list_views/invoke"
|
|
resp, body := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
|
|
|
if resp.StatusCode != tc.wantStatusCode {
|
|
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body))
|
|
}
|
|
if tc.wantStatusCode != http.StatusOK {
|
|
return
|
|
}
|
|
|
|
var bodyWrapper struct {
|
|
Result json.RawMessage `json:"result"`
|
|
}
|
|
if err := json.Unmarshal(body, &bodyWrapper); err != nil {
|
|
t.Fatalf("error decoding response wrapper: %v", err)
|
|
}
|
|
|
|
var resultString string
|
|
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
|
resultString = string(bodyWrapper.Result)
|
|
}
|
|
|
|
var got, want any
|
|
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
|
|
t.Fatalf("failed to unmarshal nested result string: %v", err)
|
|
}
|
|
if err := json.Unmarshal([]byte(tc.want), &want); err != nil {
|
|
t.Fatalf("failed to unmarshal want string: %v", err)
|
|
}
|
|
|
|
if diff := cmp.Diff(want, got); diff != "" {
|
|
t.Errorf("Unexpected result (-want +got):\n%s", diff)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func RunPostgresListSchemasTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
|
|
schemaName := "test_schema_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
|
cleanup := setupPostgresSchemas(t, ctx, pool, schemaName)
|
|
defer cleanup()
|
|
|
|
wantSchema := map[string]any{"functions": float64(0), "grants": map[string]any{}, "owner": "postgres", "schema_name": schemaName, "tables": float64(0), "views": float64(0)}
|
|
|
|
invokeTcs := []struct {
|
|
name string
|
|
requestBody io.Reader
|
|
wantStatusCode int
|
|
want []map[string]any
|
|
compareSubset bool
|
|
}{
|
|
{
|
|
name: "invoke list_schemas with schema_name",
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"schema_name": "%s"}`, schemaName))),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []map[string]any{wantSchema},
|
|
},
|
|
{
|
|
name: "invoke list_schemas with owner name",
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"owner": "%s"}`, "postgres"))),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []map[string]any{wantSchema},
|
|
compareSubset: true,
|
|
},
|
|
{
|
|
name: "invoke list_schemas with limit 1",
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"schema_name": "%s","limit": 1}`, schemaName))),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []map[string]any{wantSchema},
|
|
},
|
|
{
|
|
name: "invoke list_schemas with non-existent schema",
|
|
requestBody: bytes.NewBuffer([]byte(`{"schema_name": "non_existent_schema"}`)),
|
|
wantStatusCode: http.StatusOK,
|
|
want: nil,
|
|
},
|
|
}
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
const api = "http://127.0.0.1:5000/api/tool/list_schemas/invoke"
|
|
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
|
if resp.StatusCode != tc.wantStatusCode {
|
|
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
|
|
}
|
|
if tc.wantStatusCode != http.StatusOK {
|
|
return
|
|
}
|
|
|
|
var bodyWrapper struct {
|
|
Result json.RawMessage `json:"result"`
|
|
}
|
|
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
|
|
t.Fatalf("error decoding response wrapper: %v", err)
|
|
}
|
|
|
|
var resultString string
|
|
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
|
resultString = string(bodyWrapper.Result)
|
|
}
|
|
|
|
var got []map[string]any
|
|
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
|
|
t.Fatalf("failed to unmarshal nested result string: %v", err)
|
|
}
|
|
|
|
if tc.compareSubset {
|
|
// Assert that the 'wantTrigger' is present in the 'got' list.
|
|
found := false
|
|
for _, resultSchema := range got {
|
|
if resultSchema["schema_name"] == wantSchema["schema_name"] {
|
|
found = true
|
|
if diff := cmp.Diff(wantSchema, resultSchema); diff != "" {
|
|
t.Errorf("Mismatch in fields for the expected trigger (-want +got):\n%s", diff)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
t.Errorf("Expected schema '%s' not found in the list of all schemas.", wantSchema)
|
|
}
|
|
} else {
|
|
if diff := cmp.Diff(tc.want, got); diff != "" {
|
|
t.Errorf("Unexpected result (-want +got):\n%s", diff)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func RunPostgresDatabaseOverviewTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
|
|
const api = "http://127.0.0.1:5000/api/tool/database_overview/invoke"
|
|
requestBody := bytes.NewBuffer([]byte(`{}`))
|
|
|
|
resp, respBody := RunRequest(t, http.MethodPost, api, requestBody, nil)
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, http.StatusOK, string(respBody))
|
|
}
|
|
|
|
var bodyWrapper struct {
|
|
Result json.RawMessage `json:"result"`
|
|
}
|
|
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
|
|
t.Fatalf("error decoding response wrapper: %v, body: %s", err, string(respBody))
|
|
}
|
|
|
|
var resultString string
|
|
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
|
resultString = string(bodyWrapper.Result)
|
|
}
|
|
|
|
var got []map[string]any
|
|
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
|
|
t.Fatalf("failed to unmarshal nested result string: %v, result string: %s", err, resultString)
|
|
}
|
|
|
|
if len(got) != 1 {
|
|
t.Fatalf("Expected exactly one row in the result, got %d", len(got))
|
|
}
|
|
|
|
resultRow := got[0]
|
|
|
|
// Define expected keys based on the SELECT statement
|
|
expectedKeys := []string{
|
|
"pg_version",
|
|
"is_replica",
|
|
"uptime",
|
|
"max_connections",
|
|
"current_connections",
|
|
"active_connections",
|
|
"pct_connections_used",
|
|
}
|
|
|
|
for _, key := range expectedKeys {
|
|
if _, ok := resultRow[key]; !ok {
|
|
t.Errorf("Missing expected key in result: %s", key)
|
|
}
|
|
}
|
|
|
|
// Check types of the fields. JSON numbers are unmarshalled into float64.
|
|
if _, ok := resultRow["pg_version"].(string); !ok {
|
|
t.Errorf("Expected 'pg_version' to be a string, got %T", resultRow["pg_version"])
|
|
}
|
|
if _, ok := resultRow["is_replica"].(bool); !ok {
|
|
t.Errorf("Expected 'is_replica' to be a bool, got %T", resultRow["is_replica"])
|
|
}
|
|
if _, ok := resultRow["uptime"].(string); !ok {
|
|
t.Errorf("Expected 'uptime' to be a string, got %T", resultRow["uptime"])
|
|
}
|
|
if _, ok := resultRow["max_connections"].(float64); !ok {
|
|
t.Errorf("Expected 'max_connections' to be a number (float64), got %T", resultRow["max_connections"])
|
|
}
|
|
if _, ok := resultRow["current_connections"].(float64); !ok {
|
|
t.Errorf("Expected 'current_connections' to be a number (float64), got %T", resultRow["current_connections"])
|
|
}
|
|
if _, ok := resultRow["active_connections"].(float64); !ok {
|
|
t.Errorf("Expected 'active_connections' to be a number (float64), got %T", resultRow["active_connections"])
|
|
}
|
|
if _, ok := resultRow["pct_connections_used"].(float64); !ok {
|
|
t.Errorf("Expected 'pct_connections_used' to be a number (float64), got %T", resultRow["pct_connections_used"])
|
|
}
|
|
|
|
// Basic sanity checks on values
|
|
if maxConn, ok := resultRow["max_connections"].(float64); ok {
|
|
if maxConn <= 0 {
|
|
t.Errorf("Expected 'max_connections' to be positive, got %f", maxConn)
|
|
}
|
|
}
|
|
|
|
if pctUsed, ok := resultRow["pct_connections_used"].(float64); ok {
|
|
if pctUsed < 0 || pctUsed > 100 {
|
|
t.Errorf("Expected 'pct_connections_used' to be between 0 and 100, got %f", pctUsed)
|
|
}
|
|
}
|
|
}
|
|
|
|
func setupPostgresTrigger(t *testing.T, ctx context.Context, pool *pgxpool.Pool, schemaName, tableName, functionName, triggerName string) func() {
|
|
t.Helper()
|
|
|
|
createSchemaStmt := fmt.Sprintf("CREATE SCHEMA %s", schemaName)
|
|
if _, err := pool.Exec(ctx, createSchemaStmt); err != nil {
|
|
t.Fatalf("failed to create schema %s: %v", schemaName, err)
|
|
}
|
|
|
|
createTableStmt := fmt.Sprintf("CREATE TABLE %s.%s (id SERIAL PRIMARY KEY, name TEXT)", schemaName, tableName)
|
|
if _, err := pool.Exec(ctx, createTableStmt); err != nil {
|
|
t.Fatalf("failed to create table %s.%s: %v", schemaName, tableName, err)
|
|
}
|
|
|
|
createFunctionStmt := fmt.Sprintf(`
|
|
CREATE OR REPLACE FUNCTION %s.%s() RETURNS TRIGGER AS $$
|
|
BEGIN
|
|
RETURN NEW;
|
|
END;
|
|
$$ LANGUAGE plpgsql;
|
|
`, schemaName, functionName)
|
|
if _, err := pool.Exec(ctx, createFunctionStmt); err != nil {
|
|
t.Fatalf("failed to create function %s.%s: %v", schemaName, functionName, err)
|
|
}
|
|
|
|
createTriggerStmt := fmt.Sprintf(`
|
|
CREATE TRIGGER %s
|
|
AFTER INSERT ON %s.%s
|
|
FOR EACH ROW
|
|
EXECUTE FUNCTION %s.%s();
|
|
`, triggerName, schemaName, tableName, schemaName, functionName)
|
|
if _, err := pool.Exec(ctx, createTriggerStmt); err != nil {
|
|
t.Fatalf("failed to create trigger %s: %v", triggerName, err)
|
|
}
|
|
|
|
return func() {
|
|
dropSchemaStmt := fmt.Sprintf("DROP SCHEMA %s CASCADE", schemaName)
|
|
if _, err := pool.Exec(ctx, dropSchemaStmt); err != nil {
|
|
t.Fatalf("failed to drop schema %s: %v", schemaName, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func RunPostgresListTriggersTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
|
|
uniqueID := strings.ReplaceAll(uuid.New().String(), "-", "")
|
|
schemaName := "test_schema_" + uniqueID
|
|
tableName := "test_table_" + uniqueID
|
|
functionName := "test_func_" + uniqueID
|
|
triggerName := "test_trigger_" + uniqueID
|
|
|
|
cleanup := setupPostgresTrigger(t, ctx, pool, schemaName, tableName, functionName, triggerName)
|
|
defer cleanup()
|
|
|
|
// Definition can vary slightly based on server version/settings, so we fetch it to compare.
|
|
var expectedDef string
|
|
getDefQuery := fmt.Sprintf("SELECT pg_get_triggerdef(oid) FROM pg_trigger WHERE tgname = '%s'", triggerName)
|
|
err := pool.QueryRow(ctx, getDefQuery).Scan(&expectedDef)
|
|
if err != nil {
|
|
t.Fatalf("failed to fetch trigger definition: %v", err)
|
|
}
|
|
|
|
wantTrigger := map[string]any{
|
|
"trigger_name": triggerName,
|
|
"schema_name": schemaName,
|
|
"table_name": tableName,
|
|
"status": "ENABLED",
|
|
"timing": "AFTER",
|
|
"events": "INSERT",
|
|
"activation_level": "ROW",
|
|
"function_name": functionName,
|
|
"definition": expectedDef,
|
|
}
|
|
|
|
invokeTcs := []struct {
|
|
name string
|
|
requestBody io.Reader
|
|
wantStatusCode int
|
|
want []map[string]any
|
|
compareSubset bool
|
|
}{
|
|
{
|
|
name: "list all triggers (expecting the one we created)",
|
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []map[string]any{wantTrigger},
|
|
compareSubset: true, // avoid test flakiness in race condition
|
|
},
|
|
{
|
|
name: "filter by trigger_name",
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"trigger_name": "%s"}`, triggerName))),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []map[string]any{wantTrigger},
|
|
},
|
|
{
|
|
name: "filter by schema_name",
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"schema_name": "%s"}`, schemaName))),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []map[string]any{wantTrigger},
|
|
},
|
|
{
|
|
name: "filter by table_name",
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_name": "%s"}`, tableName))),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []map[string]any{wantTrigger},
|
|
},
|
|
{
|
|
name: "filter by non-existent trigger_name",
|
|
requestBody: bytes.NewBuffer([]byte(`{"trigger_name": "non_existent_trigger"}`)),
|
|
wantStatusCode: http.StatusOK,
|
|
want: nil,
|
|
},
|
|
{
|
|
name: "filter by non-existent schema_name",
|
|
requestBody: bytes.NewBuffer([]byte(`{"schema_name": "non_existent_schema"}`)),
|
|
wantStatusCode: http.StatusOK,
|
|
want: nil,
|
|
},
|
|
{
|
|
name: "filter by non-existent table_name",
|
|
requestBody: bytes.NewBuffer([]byte(`{"table_name": "non_existent_table"}`)),
|
|
wantStatusCode: http.StatusOK,
|
|
want: nil,
|
|
},
|
|
}
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
const api = "http://127.0.0.1:5000/api/tool/list_triggers/invoke"
|
|
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
|
if resp.StatusCode != tc.wantStatusCode {
|
|
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
|
|
}
|
|
if tc.wantStatusCode != http.StatusOK {
|
|
return
|
|
}
|
|
|
|
var bodyWrapper struct {
|
|
Result json.RawMessage `json:"result"`
|
|
}
|
|
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
|
|
t.Fatalf("error decoding response wrapper: %v", err)
|
|
}
|
|
|
|
var resultString string
|
|
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
|
resultString = string(bodyWrapper.Result)
|
|
}
|
|
|
|
var got []map[string]any
|
|
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
|
|
t.Fatalf("failed to unmarshal nested result string: %v, content: %s", err, resultString)
|
|
}
|
|
|
|
if tc.compareSubset {
|
|
// Assert that the 'wantTrigger' is present in the 'got' list.
|
|
found := false
|
|
for _, resultTrigger := range got {
|
|
if resultTrigger["trigger_name"] == wantTrigger["trigger_name"] {
|
|
found = true
|
|
if diff := cmp.Diff(wantTrigger, resultTrigger); diff != "" {
|
|
t.Errorf("Mismatch in fields for the expected trigger (-want +got):\n%s", diff)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
t.Errorf("Expected trigger '%s' not found in the list of all triggers.", triggerName)
|
|
}
|
|
} else {
|
|
if diff := cmp.Diff(tc.want, got); diff != "" {
|
|
t.Errorf("Unexpected result (-want +got):\n%s", diff)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func setupPostgresPublicationTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool, tableName string, pubName string) func(t *testing.T) {
|
|
t.Helper()
|
|
createTableStmt := fmt.Sprintf("CREATE TABLE %s (id SERIAL PRIMARY KEY, name TEXT);", tableName)
|
|
if _, err := pool.Exec(ctx, createTableStmt); err != nil {
|
|
t.Fatalf("unable to create table %s: %v", tableName, err)
|
|
}
|
|
|
|
createPubStmt := fmt.Sprintf("CREATE PUBLICATION %s FOR TABLE %s;", pubName, tableName)
|
|
if _, err := pool.Exec(ctx, createPubStmt); err != nil {
|
|
if _, dropErr := pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s;", tableName)); dropErr != nil {
|
|
t.Errorf("unable to drop table after failing to create publication: %v", dropErr)
|
|
}
|
|
t.Fatalf("unable to create publication %s: %v", pubName, err)
|
|
}
|
|
|
|
return func(t *testing.T) {
|
|
t.Helper()
|
|
if _, err := pool.Exec(ctx, fmt.Sprintf("DROP PUBLICATION IF EXISTS %s;", pubName)); err != nil {
|
|
t.Errorf("unable to drop publication %s: %v", pubName, err)
|
|
}
|
|
if _, err := pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s;", tableName)); err != nil {
|
|
t.Errorf("unable to drop table %s: %v", tableName, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func RunPostgresListPublicationTablesTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
|
|
table1Name := "pub_table_1"
|
|
pub1Name := "pub_1"
|
|
|
|
table2Name := "pub_table_2"
|
|
pub2Name := "pub_2"
|
|
|
|
cleanup := setupPostgresPublicationTable(t, ctx, pool, table1Name, pub1Name)
|
|
defer cleanup(t)
|
|
cleanup2 := setupPostgresPublicationTable(t, ctx, pool, table2Name, pub2Name)
|
|
defer cleanup2(t)
|
|
|
|
// Fetch the current user to match the publication_owner
|
|
var currentUser string
|
|
err := pool.QueryRow(ctx, "SELECT current_user;").Scan(¤tUser)
|
|
if err != nil {
|
|
t.Fatalf("unable to fetch current user: %v", err)
|
|
}
|
|
|
|
wantTable1 := map[string]any{
|
|
"publication_name": pub1Name,
|
|
"schema_name": "public",
|
|
"table_name": table1Name,
|
|
"publishes_all_tables": false,
|
|
"publishes_inserts": true,
|
|
"publishes_updates": true,
|
|
"publishes_deletes": true,
|
|
"publishes_truncates": true,
|
|
"publication_owner": currentUser,
|
|
}
|
|
|
|
wantTable2 := map[string]any{
|
|
"publication_name": pub2Name,
|
|
"schema_name": "public",
|
|
"table_name": table2Name,
|
|
"publishes_all_tables": false,
|
|
"publishes_inserts": true,
|
|
"publishes_updates": true,
|
|
"publishes_deletes": true,
|
|
"publishes_truncates": true,
|
|
"publication_owner": currentUser,
|
|
}
|
|
|
|
invokeTcs := []struct {
|
|
name string
|
|
requestBody io.Reader
|
|
wantStatusCode int
|
|
want []map[string]any
|
|
}{
|
|
{
|
|
name: "list all publication tables",
|
|
requestBody: bytes.NewBufferString(`{}`),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []map[string]any{wantTable1, wantTable2},
|
|
},
|
|
{
|
|
name: "list all tables for the created publication",
|
|
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"publication_names": "%s"}`, pub1Name)),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []map[string]any{wantTable1},
|
|
},
|
|
{
|
|
name: "filter by table_name",
|
|
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_names": "%s, %s"}`, table1Name, table2Name)),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []map[string]any{wantTable1, wantTable2},
|
|
},
|
|
{
|
|
name: "filter by schema_name and table_name",
|
|
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_names": "public", "table_name": "%s , %s"}`, table1Name, table2Name)),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []map[string]any{wantTable1, wantTable2},
|
|
},
|
|
{
|
|
name: "invoke list_publication_tables with non-existent table",
|
|
requestBody: bytes.NewBufferString(`{"table_names": "non_existent_table"}`),
|
|
wantStatusCode: http.StatusOK,
|
|
want: nil,
|
|
},
|
|
{
|
|
name: "invoke list_publication_tables with non-existent publication",
|
|
requestBody: bytes.NewBufferString(`{"publication_names": "non_existent_pub"}`),
|
|
wantStatusCode: http.StatusOK,
|
|
want: nil,
|
|
},
|
|
}
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
const api = "http://127.0.0.1:5000/api/tool/list_publication_tables/invoke"
|
|
|
|
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
|
if resp.StatusCode != tc.wantStatusCode {
|
|
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
|
|
}
|
|
if tc.wantStatusCode != http.StatusOK {
|
|
return
|
|
}
|
|
|
|
var bodyWrapper struct {
|
|
Result json.RawMessage `json:"result"`
|
|
}
|
|
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
|
|
t.Fatalf("error decoding response wrapper: %v", err)
|
|
}
|
|
|
|
var resultString string
|
|
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
|
resultString = string(bodyWrapper.Result)
|
|
}
|
|
|
|
var got []map[string]any
|
|
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
|
|
t.Fatalf("failed to unmarshal nested result string: %v, content: %s", err, resultString)
|
|
}
|
|
|
|
if diff := cmp.Diff(tc.want, got); diff != "" {
|
|
t.Errorf("Unexpected result (-want +got):\n%s", diff)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func RunPostgresListActiveQueriesTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
|
|
type queryListDetails struct {
|
|
ProcessId any `json:"pid"`
|
|
User string `json:"user"`
|
|
Datname string `json:"datname"`
|
|
ApplicationName string `json:"application_name"`
|
|
ClientAddress string `json:"client_addr"`
|
|
State string `json:"state"`
|
|
WaitEventType string `json:"wait_event_type"`
|
|
WaitEvent string `json:"wait_event"`
|
|
BackendStart any `json:"backend_start"`
|
|
TransactionStart any `json:"xact_start"`
|
|
QueryStart any `json:"query_start"`
|
|
QueryDuration any `json:"query_duration"`
|
|
Query string `json:"query"`
|
|
}
|
|
|
|
singleQueryWanted := queryListDetails{
|
|
ProcessId: any(nil),
|
|
User: "",
|
|
Datname: "",
|
|
ApplicationName: "",
|
|
ClientAddress: "",
|
|
State: "",
|
|
WaitEventType: "",
|
|
WaitEvent: "",
|
|
BackendStart: any(nil),
|
|
TransactionStart: any(nil),
|
|
QueryStart: any(nil),
|
|
QueryDuration: any(nil),
|
|
Query: "SELECT pg_sleep(10);",
|
|
}
|
|
|
|
invokeTcs := []struct {
|
|
name string
|
|
requestBody io.Reader
|
|
clientSleepSecs int
|
|
waitSecsBeforeCheck int
|
|
wantStatusCode int
|
|
want any
|
|
}{
|
|
// exclude background monitoring apps such as "wal_uploader"
|
|
{
|
|
name: "invoke list_active_queries when the system is idle",
|
|
requestBody: bytes.NewBufferString(`{"exclude_application_names": "wal_uploader"}`),
|
|
clientSleepSecs: 0,
|
|
waitSecsBeforeCheck: 0,
|
|
wantStatusCode: http.StatusOK,
|
|
want: []queryListDetails(nil),
|
|
},
|
|
{
|
|
name: "invoke list_active_queries when there is 1 ongoing but lower than the threshold",
|
|
requestBody: bytes.NewBufferString(`{"min_duration": "100 seconds", "exclude_application_names": "wal_uploader"}`),
|
|
clientSleepSecs: 1,
|
|
waitSecsBeforeCheck: 1,
|
|
wantStatusCode: http.StatusOK,
|
|
want: []queryListDetails(nil),
|
|
},
|
|
{
|
|
name: "invoke list_active_queries when 1 ongoing query should show up",
|
|
requestBody: bytes.NewBufferString(`{"min_duration": "1 seconds", "exclude_application_names": "wal_uploader"}`),
|
|
clientSleepSecs: 10,
|
|
waitSecsBeforeCheck: 5,
|
|
wantStatusCode: http.StatusOK,
|
|
want: []queryListDetails{singleQueryWanted},
|
|
},
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
if tc.clientSleepSecs > 0 {
|
|
wg.Add(1)
|
|
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
err := pool.Ping(ctx)
|
|
if err != nil {
|
|
t.Errorf("unable to connect to test database: %s", err)
|
|
return
|
|
}
|
|
_, err = pool.Exec(ctx, fmt.Sprintf("SELECT pg_sleep(%d);", tc.clientSleepSecs))
|
|
if err != nil {
|
|
t.Errorf("Executing 'SELECT pg_sleep' failed: %s", err)
|
|
}
|
|
}()
|
|
}
|
|
|
|
if tc.waitSecsBeforeCheck > 0 {
|
|
time.Sleep(time.Duration(tc.waitSecsBeforeCheck) * time.Second)
|
|
}
|
|
|
|
const api = "http://127.0.0.1:5000/api/tool/list_active_queries/invoke"
|
|
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
|
if resp.StatusCode != tc.wantStatusCode {
|
|
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
|
|
}
|
|
if tc.wantStatusCode != http.StatusOK {
|
|
return
|
|
}
|
|
|
|
var bodyWrapper struct {
|
|
Result json.RawMessage `json:"result"`
|
|
}
|
|
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
|
|
t.Fatalf("error decoding response wrapper: %v", err)
|
|
}
|
|
|
|
var resultString string
|
|
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
|
resultString = string(bodyWrapper.Result)
|
|
}
|
|
|
|
var got any
|
|
var details []queryListDetails
|
|
if err := json.Unmarshal([]byte(resultString), &details); err != nil {
|
|
t.Fatalf("failed to unmarshal nested ObjectDetails string: %v", err)
|
|
}
|
|
got = details
|
|
|
|
if diff := cmp.Diff(tc.want, got, cmp.Comparer(func(a, b queryListDetails) bool {
|
|
return a.Query == b.Query
|
|
})); diff != "" {
|
|
t.Errorf("Unexpected result: got %#v, want: %#v", got, tc.want)
|
|
}
|
|
})
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func RunPostgresListAvailableExtensionsTest(t *testing.T) {
|
|
invokeTcs := []struct {
|
|
name string
|
|
api string
|
|
requestBody io.Reader
|
|
wantStatusCode int
|
|
}{
|
|
{
|
|
name: "invoke list_available_extensions output",
|
|
api: "http://127.0.0.1:5000/api/tool/list_available_extensions/invoke",
|
|
wantStatusCode: http.StatusOK,
|
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
|
},
|
|
}
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, nil)
|
|
if resp.StatusCode != tc.wantStatusCode {
|
|
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
// Intentionally not adding the output check as output depends on the postgres instance used where the the functional test runs.
|
|
// Adding the check will make the test flaky.
|
|
})
|
|
}
|
|
}
|
|
|
|
func RunPostgresListInstalledExtensionsTest(t *testing.T) {
|
|
invokeTcs := []struct {
|
|
name string
|
|
api string
|
|
requestBody io.Reader
|
|
wantStatusCode int
|
|
}{
|
|
{
|
|
name: "invoke list_installed_extensions output",
|
|
api: "http://127.0.0.1:5000/api/tool/list_installed_extensions/invoke",
|
|
wantStatusCode: http.StatusOK,
|
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
|
},
|
|
}
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
resp, bodyBytes := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, nil)
|
|
if resp.StatusCode != tc.wantStatusCode {
|
|
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
|
}
|
|
|
|
// Intentionally not adding the output check as output depends on the postgres instance used where the the functional test runs.
|
|
// Adding the check will make the test flaky.
|
|
})
|
|
}
|
|
}
|
|
|
|
func setupPostgresIndex(t *testing.T, ctx context.Context, pool *pgxpool.Pool, schemaName string, tableName string) func(t *testing.T) {
|
|
t.Helper()
|
|
createSchemaStmt := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s;", schemaName)
|
|
if _, err := pool.Exec(ctx, createSchemaStmt); err != nil {
|
|
t.Fatalf("unable to create schema %s: %v", schemaName, err)
|
|
}
|
|
|
|
fullTableName := fmt.Sprintf("%s.%s", schemaName, tableName)
|
|
createTableStmt := fmt.Sprintf("CREATE TABLE %s (id SERIAL PRIMARY KEY, name TEXT, email TEXT);", fullTableName)
|
|
if _, err := pool.Exec(ctx, createTableStmt); err != nil {
|
|
t.Fatalf("unable to create table %s: %v", fullTableName, err)
|
|
}
|
|
|
|
// Create a unique index on email
|
|
index1Stmt := fmt.Sprintf("CREATE UNIQUE INDEX %s_email_idx ON %s (email);", tableName, fullTableName)
|
|
if _, err := pool.Exec(ctx, index1Stmt); err != nil {
|
|
t.Fatalf("unable to create index %s_email_idx: %v", tableName, err)
|
|
}
|
|
|
|
// Create a non-unique index on name
|
|
index2Stmt := fmt.Sprintf("CREATE INDEX %s_name_idx ON %s (name);", tableName, fullTableName)
|
|
if _, err := pool.Exec(ctx, index2Stmt); err != nil {
|
|
t.Fatalf("unable to create index %s_name_idx: %v", tableName, err)
|
|
}
|
|
|
|
return func(t *testing.T) {
|
|
t.Helper()
|
|
if _, err := pool.Exec(ctx, fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE;", schemaName)); err != nil {
|
|
t.Errorf("unable to drop schema: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func RunPostgresListIndexesTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
|
|
schemaName := "testschema_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
|
tableName := "table1_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
|
cleanup := setupPostgresIndex(t, ctx, pool, schemaName, tableName)
|
|
defer cleanup(t)
|
|
|
|
// Primary key index
|
|
wantIndexPK := map[string]any{
|
|
"schema_name": schemaName,
|
|
"table_name": tableName,
|
|
"index_name": tableName + "_pkey",
|
|
"index_type": "btree",
|
|
"is_unique": true,
|
|
"is_primary": true,
|
|
"is_used": false,
|
|
"index_definition": fmt.Sprintf("CREATE UNIQUE INDEX %s_pkey ON %s.%s USING btree (id)", tableName, schemaName, tableName),
|
|
// Size and scan counts can vary, so omitting them from strict checks or using ranges might be better in real tests.
|
|
}
|
|
// Email unique index
|
|
wantIndexEmail := map[string]any{
|
|
"schema_name": schemaName,
|
|
"table_name": tableName,
|
|
"index_name": tableName + "_email_idx",
|
|
"index_type": "btree",
|
|
"is_unique": true,
|
|
"is_primary": false,
|
|
"is_used": false,
|
|
"index_definition": fmt.Sprintf("CREATE UNIQUE INDEX %s_email_idx ON %s.%s USING btree (email)", tableName, schemaName, tableName),
|
|
}
|
|
// Name non-unique index
|
|
wantIndexName := map[string]any{
|
|
"schema_name": schemaName,
|
|
"table_name": tableName,
|
|
"index_name": tableName + "_name_idx",
|
|
"index_type": "btree",
|
|
"is_unique": false,
|
|
"is_primary": false,
|
|
"is_used": false,
|
|
"index_definition": fmt.Sprintf("CREATE INDEX %s_name_idx ON %s.%s USING btree (name)", tableName, schemaName, tableName),
|
|
}
|
|
|
|
allWantIndexes := []map[string]any{wantIndexEmail, wantIndexName, wantIndexPK}
|
|
|
|
invokeTcs := []struct {
|
|
name string
|
|
requestBody io.Reader
|
|
wantStatusCode int
|
|
want []map[string]any
|
|
}{
|
|
// List all indexes is skipped because the output might include indexes for other database tables
|
|
// defined outside of this test, which could make the test flaky.
|
|
{
|
|
name: "list_indexes for a specific schema and table",
|
|
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "%s", "table_name": "%s"}`, schemaName, tableName)),
|
|
wantStatusCode: http.StatusOK,
|
|
want: allWantIndexes,
|
|
},
|
|
{
|
|
name: "list_indexes for a specific schema",
|
|
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "%s"}`, schemaName)),
|
|
wantStatusCode: http.StatusOK,
|
|
want: allWantIndexes,
|
|
},
|
|
{
|
|
name: "list_indexes with non-existent schema",
|
|
requestBody: bytes.NewBufferString(`{"schema_name": "non_existent_schema"}`),
|
|
wantStatusCode: http.StatusOK,
|
|
want: nil,
|
|
},
|
|
{
|
|
name: "list_indexes with non-existent table in existing schema",
|
|
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "%s", "table_name": "non_existent_table"}`, schemaName)),
|
|
wantStatusCode: http.StatusOK,
|
|
want: nil,
|
|
},
|
|
{
|
|
name: "list_indexes filter by index name",
|
|
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "%s", "table_name": "%s", "index_name": "%s"}`, schemaName, tableName, tableName+"_email_idx")),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []map[string]any{wantIndexEmail},
|
|
},
|
|
{
|
|
name: "list_indexes filter by non-existent index name",
|
|
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "%s", "table_name": "%s", "index_name": "non_existent_idx"}`, schemaName, tableName)),
|
|
wantStatusCode: http.StatusOK,
|
|
want: nil,
|
|
},
|
|
}
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
const api = "http://127.0.0.1:5000/api/tool/list_indexes/invoke"
|
|
|
|
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
|
if resp.StatusCode != tc.wantStatusCode {
|
|
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
|
|
}
|
|
if tc.wantStatusCode != http.StatusOK {
|
|
return
|
|
}
|
|
|
|
var bodyWrapper struct {
|
|
Result json.RawMessage `json:"result"`
|
|
}
|
|
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
|
|
t.Fatalf("error decoding response wrapper: %v", err)
|
|
}
|
|
|
|
var resultString string
|
|
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
|
resultString = string(bodyWrapper.Result)
|
|
}
|
|
|
|
var got []map[string]any
|
|
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
|
|
t.Fatalf("failed to unmarshal nested result string: %v, resultString: %s", err, resultString)
|
|
}
|
|
// Normalize got by removing fields that are hard to predict (like size)
|
|
for _, index := range got {
|
|
delete(index, "index_size_bytes")
|
|
delete(index, "index_scans")
|
|
delete(index, "tuples_read")
|
|
delete(index, "tuples_fetched")
|
|
}
|
|
|
|
if diff := cmp.Diff(tc.want, got); diff != "" {
|
|
t.Errorf("Unexpected result (-want +got):\n%s", diff)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func setupListSequencesTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) (string, func(t *testing.T)) {
|
|
sequenceName := "list_sequences_seq1_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
|
createSequence1Stmt := fmt.Sprintf("CREATE SEQUENCE %s INCREMENT 1 START 1;", sequenceName)
|
|
|
|
_, err := pool.Exec(ctx, createSequence1Stmt)
|
|
if err != nil {
|
|
t.Fatalf("unable to create sequence %s: %s", sequenceName, err)
|
|
}
|
|
return sequenceName, func(t *testing.T) {
|
|
_, err := pool.Exec(ctx, fmt.Sprintf("DROP SEQUENCE IF EXISTS %s;", sequenceName))
|
|
if err != nil {
|
|
t.Errorf("unable to drop sequences: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func RunPostgresListSequencesTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
|
|
sequenceName, teardown := setupListSequencesTest(t, ctx, pool)
|
|
defer teardown(t)
|
|
|
|
wantSequence := map[string]any{
|
|
"sequence_name": sequenceName,
|
|
"schema_name": "public",
|
|
"sequence_owner": "postgres",
|
|
"data_type": "bigint",
|
|
"start_value": float64(1),
|
|
"min_value": float64(1),
|
|
"max_value": float64(9223372036854775807),
|
|
"increment_by": float64(1),
|
|
"last_value": nil,
|
|
}
|
|
|
|
invokeTcs := []struct {
|
|
name string
|
|
api string
|
|
requestBody io.Reader
|
|
wantStatusCode int
|
|
want []map[string]any
|
|
}{
|
|
{
|
|
name: "invoke list_sequences",
|
|
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"sequence_name": "%s"}`, sequenceName)),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []map[string]any{wantSequence},
|
|
},
|
|
{
|
|
name: "invoke list_sequences with non-existent sequence",
|
|
requestBody: bytes.NewBufferString(`{"sequence_name": "non_existent_sequence"}`),
|
|
wantStatusCode: http.StatusOK,
|
|
want: nil,
|
|
},
|
|
}
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
const api = "http://127.0.0.1:5000/api/tool/list_sequences/invoke"
|
|
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
|
if resp.StatusCode != tc.wantStatusCode {
|
|
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
|
|
}
|
|
if tc.wantStatusCode != http.StatusOK {
|
|
return
|
|
}
|
|
|
|
var bodyWrapper struct {
|
|
Result json.RawMessage `json:"result"`
|
|
}
|
|
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
|
|
t.Fatalf("error decoding response wrapper: %v", err)
|
|
}
|
|
|
|
var resultString string
|
|
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
|
resultString = string(bodyWrapper.Result)
|
|
}
|
|
|
|
var got []map[string]any
|
|
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
|
|
t.Fatalf("failed to unmarshal nested result string: %v", err)
|
|
}
|
|
|
|
if diff := cmp.Diff(tc.want, got); diff != "" {
|
|
t.Errorf("Unexpected result (-want +got):\n%s", diff)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func RunPostgresListTableSpacesTest(t *testing.T) {
|
|
invokeTcs := []struct {
|
|
name string
|
|
api string
|
|
requestBody io.Reader
|
|
wantStatusCode int
|
|
}{
|
|
{
|
|
name: "invoke list_tablespaces output",
|
|
api: "http://127.0.0.1:5000/api/tool/list_tablespaces/invoke",
|
|
wantStatusCode: http.StatusOK,
|
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
|
},
|
|
}
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, nil)
|
|
if resp.StatusCode != tc.wantStatusCode {
|
|
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
// Intentionally not adding the output check as output depends on the postgres instance used where the the functional test runs.
|
|
// Adding the check will make the test flaky.
|
|
})
|
|
}
|
|
}
|
|
|
|
func RunPostgresListPgSettingsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
|
|
targetSetting := "maintenance_work_mem"
|
|
var name, setting, unit, shortDesc, source, contextVal string
|
|
|
|
// We query the raw pg_settings to get the data needed to reconstruct the logic
|
|
// defined in your listPgSettingQuery.
|
|
err := pool.QueryRow(ctx, `
|
|
SELECT name, setting, unit, short_desc, source, context
|
|
FROM pg_settings
|
|
WHERE name = $1
|
|
`, targetSetting).Scan(&name, &setting, &unit, &shortDesc, &source, &contextVal)
|
|
|
|
if err != nil {
|
|
t.Fatalf("Setup failed: could not fetch postgres setting '%s': %v", targetSetting, err)
|
|
}
|
|
|
|
// Replicate the SQL CASE logic for 'requires_restart' field
|
|
requiresRestart := "No"
|
|
switch contextVal {
|
|
case "postmaster":
|
|
requiresRestart = "Yes"
|
|
case "sighup":
|
|
requiresRestart = "No (Reload sufficient)"
|
|
}
|
|
|
|
expectedObject := map[string]interface{}{
|
|
"name": name,
|
|
"current_value": setting,
|
|
"unit": unit,
|
|
"short_desc": shortDesc,
|
|
"source": source,
|
|
"requires_restart": requiresRestart,
|
|
}
|
|
expectedJSON, _ := json.Marshal([]interface{}{expectedObject})
|
|
|
|
invokeTcs := []struct {
|
|
name string
|
|
requestBody io.Reader
|
|
wantStatusCode int
|
|
want string
|
|
}{
|
|
{
|
|
name: "invoke list_pg_settings with specific setting",
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"setting_name": "%s"}`, targetSetting))),
|
|
wantStatusCode: http.StatusOK,
|
|
want: string(expectedJSON),
|
|
},
|
|
{
|
|
name: "invoke list_pg_settings with non-existent setting",
|
|
requestBody: bytes.NewBuffer([]byte(`{"setting_name": "non_existent_config_xyz"}`)),
|
|
wantStatusCode: http.StatusOK,
|
|
want: `null`,
|
|
},
|
|
}
|
|
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
const api = "http://127.0.0.1:5000/api/tool/list_pg_settings/invoke"
|
|
resp, body := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
|
|
|
if resp.StatusCode != tc.wantStatusCode {
|
|
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body))
|
|
}
|
|
if tc.wantStatusCode != http.StatusOK {
|
|
return
|
|
}
|
|
|
|
var bodyWrapper struct {
|
|
Result json.RawMessage `json:"result"`
|
|
}
|
|
if err := json.Unmarshal(body, &bodyWrapper); err != nil {
|
|
t.Fatalf("error decoding response wrapper: %v", err)
|
|
}
|
|
|
|
var resultString string
|
|
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
|
resultString = string(bodyWrapper.Result)
|
|
}
|
|
|
|
var got, want any
|
|
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
|
|
t.Fatalf("failed to unmarshal nested result string: %v", err)
|
|
}
|
|
if err := json.Unmarshal([]byte(tc.want), &want); err != nil {
|
|
t.Fatalf("failed to unmarshal want string: %v", err)
|
|
}
|
|
|
|
if diff := cmp.Diff(want, got); diff != "" {
|
|
t.Errorf("Unexpected result (-want +got):\n%s", diff)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// RunPostgresDatabaseStatsTest tests the database_stats tool by comparing API results
|
|
// against a direct query to the database.
|
|
func RunPostgresListDatabaseStatsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
|
|
dbName1 := "test_db_stats_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
|
dbOwner1 := "test_user_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
|
dbName2 := "test_db_stats_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
|
dbOwner2 := "test_user_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
|
|
|
cleanup1 := setUpDatabase(t, ctx, pool, dbName1, dbOwner1)
|
|
defer cleanup1()
|
|
cleanup2 := setUpDatabase(t, ctx, pool, dbName2, dbOwner2)
|
|
defer cleanup2()
|
|
|
|
requiredKeys := map[string]bool{
|
|
"database_name": true,
|
|
"database_owner": true,
|
|
"default_tablespace": true,
|
|
"is_connectable": true,
|
|
}
|
|
|
|
db1Want := map[string]interface{}{
|
|
"database_name": dbName1,
|
|
"database_owner": dbOwner1,
|
|
"default_tablespace": "pg_default",
|
|
"is_connectable": true,
|
|
}
|
|
|
|
db2Want := map[string]interface{}{
|
|
"database_name": dbName2,
|
|
"database_owner": dbOwner2,
|
|
"default_tablespace": "pg_default",
|
|
"is_connectable": true,
|
|
}
|
|
|
|
invokeTcs := []struct {
|
|
name string
|
|
requestBody io.Reader
|
|
wantStatusCode int
|
|
want []map[string]interface{}
|
|
}{
|
|
{
|
|
name: "invoke database_stats filtering by specific database name",
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"database_name": "%s"}`, dbName1))),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []map[string]interface{}{db1Want},
|
|
},
|
|
{
|
|
name: "invoke database_stats filtering by specific owner",
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"database_owner": "%s"}`, dbOwner2))),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []map[string]interface{}{db2Want},
|
|
},
|
|
{
|
|
name: "filter by tablespace",
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"default_tablespace": "pg_default", "database_name": "%s"}`, dbName1))),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []map[string]interface{}{db1Want},
|
|
},
|
|
{
|
|
name: "sort by size",
|
|
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"sort_by": "size", "database_name": "%s"}`, dbName2))),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []map[string]interface{}{db2Want},
|
|
},
|
|
}
|
|
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
const api = "http://127.0.0.1:5000/api/tool/list_database_stats/invoke"
|
|
resp, body := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
|
|
|
if resp.StatusCode != tc.wantStatusCode {
|
|
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body))
|
|
}
|
|
var bodyWrapper struct {
|
|
Result json.RawMessage `json:"result"`
|
|
}
|
|
if err := json.Unmarshal(body, &bodyWrapper); err != nil {
|
|
t.Fatalf("error decoding response wrapper: %v", err)
|
|
}
|
|
|
|
var resultString string
|
|
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
|
resultString = string(bodyWrapper.Result)
|
|
}
|
|
|
|
var got []map[string]interface{}
|
|
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
|
|
t.Fatalf("failed to unmarshal nested result string: %v", err)
|
|
}
|
|
|
|
// Configuration for comparison
|
|
opts := []cmp.Option{
|
|
// Ensure consistent order based on name for comparison
|
|
cmpopts.SortSlices(func(a, b map[string]interface{}) bool {
|
|
return a["database_name"].(string) < b["database_name"].(string)
|
|
}),
|
|
|
|
// Ignore Volatile Keys which change in every run and only compare the keys in 'requiredKeys'
|
|
cmpopts.IgnoreMapEntries(func(key string, _ interface{}) bool {
|
|
return !requiredKeys[key]
|
|
}),
|
|
|
|
// Ignore Irrelevant Databases
|
|
cmpopts.IgnoreSliceElements(func(v map[string]interface{}) bool {
|
|
name, ok := v["database_name"].(string)
|
|
if !ok {
|
|
return true
|
|
}
|
|
return name != dbName1 && name != dbName2
|
|
}),
|
|
}
|
|
|
|
if diff := cmp.Diff(tc.want, got, opts...); diff != "" {
|
|
t.Errorf("Unexpected result (-want +got):\n%s", diff)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func setUpDatabase(t *testing.T, ctx context.Context, pool *pgxpool.Pool, dbName, dbOwner string) func() {
|
|
_, err := pool.Exec(ctx, fmt.Sprintf("CREATE ROLE %s LOGIN PASSWORD 'password';", dbOwner))
|
|
if err != nil {
|
|
_, _ = pool.Exec(ctx, fmt.Sprintf("DROP ROLE %s;", dbOwner))
|
|
t.Fatalf("failed to create %s: %v", dbOwner, err)
|
|
}
|
|
_, err = pool.Exec(ctx, fmt.Sprintf("GRANT %s TO current_user;", dbOwner))
|
|
if err != nil {
|
|
t.Fatalf("failed to grant %s to current_user: %v", dbOwner, err)
|
|
}
|
|
_, err = pool.Exec(ctx, fmt.Sprintf("CREATE DATABASE %s OWNER %s;", dbName, dbOwner))
|
|
if err != nil {
|
|
t.Fatalf("failed to create %s: %v", dbName, err)
|
|
}
|
|
return func() {
|
|
_, _ = pool.Exec(ctx, fmt.Sprintf("DROP DATABASE IF EXISTS %s;", dbName))
|
|
_, _ = pool.Exec(ctx, fmt.Sprintf("DROP ROLE IF EXISTS %s;", dbOwner))
|
|
}
|
|
}
|
|
|
|
func setupPostgresRoles(t *testing.T, ctx context.Context, pool *pgxpool.Pool) (string, string, string, func(t *testing.T)) {
|
|
t.Helper()
|
|
suffix := strings.ReplaceAll(uuid.New().String(), "-", "")
|
|
|
|
adminUser := "test_role_admin_" + suffix
|
|
superUser := "test_role_super_" + suffix
|
|
normalUser := "test_role_normal_" + suffix
|
|
|
|
createAdminStmt := fmt.Sprintf("CREATE ROLE %s NOLOGIN;", adminUser)
|
|
if _, err := pool.Exec(ctx, createAdminStmt); err != nil {
|
|
t.Fatalf("unable to create role %s: %v", adminUser, err)
|
|
}
|
|
|
|
createSuperUserStmt := fmt.Sprintf("CREATE ROLE %s LOGIN CREATEDB;", superUser)
|
|
if _, err := pool.Exec(ctx, createSuperUserStmt); err != nil {
|
|
t.Fatalf("unable to create role %s: %v", superUser, err)
|
|
}
|
|
|
|
createNormalUserStmt := fmt.Sprintf("CREATE ROLE %s LOGIN;", normalUser)
|
|
if _, err := pool.Exec(ctx, createNormalUserStmt); err != nil {
|
|
t.Fatalf("unable to create role %s: %v", normalUser, err)
|
|
}
|
|
|
|
// Establish Relationships (Admin -> Superuser -> Normal)
|
|
if _, err := pool.Exec(ctx, fmt.Sprintf("GRANT %s TO %s;", adminUser, superUser)); err != nil {
|
|
t.Fatalf("unable to grant %s to %s: %v", adminUser, superUser, err)
|
|
}
|
|
if _, err := pool.Exec(ctx, fmt.Sprintf("GRANT %s TO %s;", superUser, normalUser)); err != nil {
|
|
t.Fatalf("unable to grant %s to %s: %v", superUser, normalUser, err)
|
|
}
|
|
|
|
return adminUser, superUser, normalUser, func(t *testing.T) {
|
|
t.Helper()
|
|
_, _ = pool.Exec(ctx, fmt.Sprintf("DROP ROLE IF EXISTS %s;", normalUser))
|
|
_, _ = pool.Exec(ctx, fmt.Sprintf("DROP ROLE IF EXISTS %s;", superUser))
|
|
_, _ = pool.Exec(ctx, fmt.Sprintf("DROP ROLE IF EXISTS %s;", adminUser))
|
|
}
|
|
}
|
|
|
|
func RunPostgresListRolesTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
|
|
adminUser, superUser, normalUser, cleanup := setupPostgresRoles(t, ctx, pool)
|
|
defer cleanup(t)
|
|
|
|
wantAdmin := map[string]any{
|
|
"role_name": adminUser,
|
|
"connection_limit": float64(-1),
|
|
"is_superuser": false,
|
|
"inherits_privileges": true,
|
|
"can_create_roles": false,
|
|
"can_create_db": false,
|
|
"can_login": false,
|
|
"is_replication_role": false,
|
|
"bypass_rls": false,
|
|
"direct_members": []any{superUser},
|
|
"member_of": []any{},
|
|
}
|
|
|
|
wantSuperUser := map[string]any{
|
|
"role_name": superUser,
|
|
"connection_limit": float64(-1),
|
|
"is_superuser": false,
|
|
"inherits_privileges": true,
|
|
"can_create_roles": false,
|
|
"can_create_db": true,
|
|
"can_login": true,
|
|
"is_replication_role": false,
|
|
"bypass_rls": false,
|
|
"direct_members": []any{normalUser},
|
|
"member_of": []any{adminUser},
|
|
}
|
|
|
|
wantNormalUser := map[string]any{
|
|
"role_name": normalUser,
|
|
"connection_limit": float64(-1),
|
|
"is_superuser": false,
|
|
"inherits_privileges": true,
|
|
"can_create_roles": false,
|
|
"can_create_db": false,
|
|
"can_login": true,
|
|
"is_replication_role": false,
|
|
"bypass_rls": false,
|
|
"direct_members": []any{},
|
|
"member_of": []any{superUser},
|
|
}
|
|
|
|
invokeTcs := []struct {
|
|
name string
|
|
requestBody io.Reader
|
|
wantStatusCode int
|
|
want []map[string]any
|
|
}{
|
|
{
|
|
name: "list_roles with filter for created roles",
|
|
requestBody: bytes.NewBufferString(`{"role_name": "test_role_"}`),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []map[string]any{wantAdmin, wantNormalUser, wantSuperUser},
|
|
},
|
|
{
|
|
name: "list_roles filter specific role",
|
|
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"role_name": "%s"}`, superUser)),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []map[string]any{wantSuperUser},
|
|
},
|
|
{
|
|
name: "list_roles non-existent role",
|
|
requestBody: bytes.NewBufferString(`{"role_name": "non_existent_role_xyz"}`),
|
|
wantStatusCode: http.StatusOK,
|
|
want: nil,
|
|
},
|
|
}
|
|
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
const api = "http://127.0.0.1:5000/api/tool/list_roles/invoke"
|
|
|
|
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
|
if resp.StatusCode != tc.wantStatusCode {
|
|
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
|
|
}
|
|
if tc.wantStatusCode != http.StatusOK {
|
|
return
|
|
}
|
|
|
|
var bodyWrapper struct {
|
|
Result json.RawMessage `json:"result"`
|
|
}
|
|
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
|
|
t.Fatalf("error decoding response wrapper: %v", err)
|
|
}
|
|
|
|
var resultString string
|
|
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
|
resultString = string(bodyWrapper.Result)
|
|
}
|
|
|
|
var got []map[string]any
|
|
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
|
|
t.Fatalf("failed to unmarshal nested result string: %v, resultString: %s", err, resultString)
|
|
}
|
|
|
|
gotMap := make(map[string]map[string]any)
|
|
for _, role := range got {
|
|
// Remove fields that change every run
|
|
delete(role, "oid")
|
|
delete(role, "valid_until")
|
|
|
|
if name, ok := role["role_name"].(string); ok {
|
|
gotMap[name] = role
|
|
}
|
|
}
|
|
|
|
// Check that every role in 'want' exists in 'got' and matches
|
|
for _, wantRole := range tc.want {
|
|
roleName, _ := wantRole["role_name"].(string)
|
|
|
|
gotRole, exists := gotMap[roleName]
|
|
if !exists {
|
|
t.Errorf("Expected role %q was not found in the response", roleName)
|
|
continue
|
|
}
|
|
|
|
if diff := cmp.Diff(wantRole, gotRole); diff != "" {
|
|
t.Errorf("Role %q mismatch (-want +got):\n%s", roleName, diff)
|
|
}
|
|
}
|
|
|
|
// Verify that if want is nil/empty, got is also empty
|
|
if len(tc.want) == 0 && len(got) != 0 {
|
|
t.Errorf("Expected empty result, but got %d roles", len(got))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// RunMySQLListTablesTest runs tests against the mysql-list-tables tool.
|
|
func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNameAuth, expectedOwner string) {
|
|
var ownerWant any
|
|
if expectedOwner == "" {
|
|
ownerWant = nil
|
|
} else {
|
|
ownerWant = expectedOwner
|
|
}
|
|
|
|
type tableInfo struct {
|
|
ObjectName string `json:"object_name"`
|
|
SchemaName string `json:"schema_name"`
|
|
ObjectDetails string `json:"object_details"`
|
|
}
|
|
|
|
type column struct {
|
|
DataType string `json:"data_type"`
|
|
ColumnName string `json:"column_name"`
|
|
ColumnComment string `json:"column_comment"`
|
|
ColumnDefault any `json:"column_default"`
|
|
IsNotNullable int `json:"is_not_nullable"`
|
|
OrdinalPosition int `json:"ordinal_position"`
|
|
}
|
|
|
|
type objectDetails struct {
|
|
Owner any `json:"owner"`
|
|
Columns []column `json:"columns"`
|
|
Comment string `json:"comment"`
|
|
Indexes []any `json:"indexes"`
|
|
Triggers []any `json:"triggers"`
|
|
Constraints []any `json:"constraints"`
|
|
ObjectName string `json:"object_name"`
|
|
ObjectType string `json:"object_type"`
|
|
SchemaName string `json:"schema_name"`
|
|
}
|
|
|
|
paramTableWant := objectDetails{
|
|
ObjectName: tableNameParam,
|
|
SchemaName: databaseName,
|
|
ObjectType: "TABLE",
|
|
Owner: ownerWant,
|
|
Columns: []column{
|
|
{DataType: "int", ColumnName: "id", IsNotNullable: 1, OrdinalPosition: 1},
|
|
{DataType: "varchar(255)", ColumnName: "name", OrdinalPosition: 2},
|
|
},
|
|
Indexes: []any{map[string]any{"index_columns": []any{"id"}, "index_name": "PRIMARY", "is_primary": float64(1), "is_unique": float64(1)}},
|
|
Triggers: []any{},
|
|
Constraints: []any{map[string]any{"constraint_columns": []any{"id"}, "constraint_name": "PRIMARY", "constraint_type": "PRIMARY KEY", "foreign_key_referenced_columns": any(nil), "foreign_key_referenced_table": any(nil), "constraint_definition": ""}},
|
|
}
|
|
|
|
authTableWant := objectDetails{
|
|
ObjectName: tableNameAuth,
|
|
SchemaName: databaseName,
|
|
ObjectType: "TABLE",
|
|
Owner: ownerWant,
|
|
Columns: []column{
|
|
{DataType: "int", ColumnName: "id", IsNotNullable: 1, OrdinalPosition: 1},
|
|
{DataType: "varchar(255)", ColumnName: "name", OrdinalPosition: 2},
|
|
{DataType: "varchar(255)", ColumnName: "email", OrdinalPosition: 3},
|
|
},
|
|
Indexes: []any{map[string]any{"index_columns": []any{"id"}, "index_name": "PRIMARY", "is_primary": float64(1), "is_unique": float64(1)}},
|
|
Triggers: []any{},
|
|
Constraints: []any{map[string]any{"constraint_columns": []any{"id"}, "constraint_name": "PRIMARY", "constraint_type": "PRIMARY KEY", "foreign_key_referenced_columns": any(nil), "foreign_key_referenced_table": any(nil), "constraint_definition": ""}},
|
|
}
|
|
|
|
invokeTcs := []struct {
|
|
name string
|
|
requestBody io.Reader
|
|
wantStatusCode int
|
|
want any
|
|
isSimple bool
|
|
isAllTables bool
|
|
}{
|
|
{
|
|
name: "invoke list_tables for all tables detailed output",
|
|
requestBody: bytes.NewBufferString(`{"table_names":""}`),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []objectDetails{authTableWant, paramTableWant},
|
|
isAllTables: true,
|
|
},
|
|
{
|
|
name: "invoke list_tables detailed output",
|
|
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_names": "%s"}`, tableNameAuth)),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []objectDetails{authTableWant},
|
|
},
|
|
{
|
|
name: "invoke list_tables simple output",
|
|
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_names": "%s", "output_format": "simple"}`, tableNameAuth)),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []map[string]any{{"name": tableNameAuth}},
|
|
isSimple: true,
|
|
},
|
|
{
|
|
name: "invoke list_tables with multiple table names",
|
|
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_names": "%s,%s"}`, tableNameParam, tableNameAuth)),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []objectDetails{authTableWant, paramTableWant},
|
|
},
|
|
{
|
|
name: "invoke list_tables with one existing and one non-existent table",
|
|
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_names": "%s,non_existent_table"}`, tableNameAuth)),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []objectDetails{authTableWant},
|
|
},
|
|
{
|
|
name: "invoke list_tables with non-existent table",
|
|
requestBody: bytes.NewBufferString(`{"table_names": "non_existent_table"}`),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []objectDetails{},
|
|
},
|
|
}
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
const api = "http://127.0.0.1:5000/api/tool/list_tables/invoke"
|
|
resp, body := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
|
if resp.StatusCode != tc.wantStatusCode {
|
|
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body))
|
|
}
|
|
if tc.wantStatusCode != http.StatusOK {
|
|
return
|
|
}
|
|
|
|
var bodyWrapper struct {
|
|
Result json.RawMessage `json:"result"`
|
|
}
|
|
if err := json.Unmarshal(body, &bodyWrapper); err != nil {
|
|
t.Fatalf("error decoding response wrapper: %v", err)
|
|
}
|
|
|
|
var resultString string
|
|
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
|
resultString = string(bodyWrapper.Result)
|
|
}
|
|
|
|
var got any
|
|
if tc.isSimple {
|
|
var tables []tableInfo
|
|
if err := json.Unmarshal([]byte(resultString), &tables); err != nil {
|
|
t.Fatalf("failed to unmarshal outer JSON array into []tableInfo: %v", err)
|
|
}
|
|
details := []map[string]any{}
|
|
for _, table := range tables {
|
|
var d map[string]any
|
|
if err := json.Unmarshal([]byte(table.ObjectDetails), &d); err != nil {
|
|
t.Fatalf("failed to unmarshal nested ObjectDetails string: %v", err)
|
|
}
|
|
details = append(details, d)
|
|
}
|
|
got = details
|
|
} else {
|
|
var tables []tableInfo
|
|
if err := json.Unmarshal([]byte(resultString), &tables); err != nil {
|
|
t.Fatalf("failed to unmarshal outer JSON array into []tableInfo: %v", err)
|
|
}
|
|
details := []objectDetails{}
|
|
for _, table := range tables {
|
|
var d objectDetails
|
|
if err := json.Unmarshal([]byte(table.ObjectDetails), &d); err != nil {
|
|
t.Fatalf("failed to unmarshal nested ObjectDetails string: %v", err)
|
|
}
|
|
details = append(details, d)
|
|
}
|
|
got = details
|
|
}
|
|
|
|
opts := []cmp.Option{
|
|
cmpopts.SortSlices(func(a, b objectDetails) bool { return a.ObjectName < b.ObjectName }),
|
|
cmpopts.SortSlices(func(a, b column) bool { return a.ColumnName < b.ColumnName }),
|
|
cmpopts.SortSlices(func(a, b map[string]any) bool { return a["name"].(string) < b["name"].(string) }),
|
|
}
|
|
|
|
// Checking only the current database where the test tables are created to avoid brittle tests.
|
|
if tc.isAllTables {
|
|
filteredGot := []objectDetails{}
|
|
if got != nil {
|
|
for _, item := range got.([]objectDetails) {
|
|
if item.SchemaName == databaseName {
|
|
filteredGot = append(filteredGot, item)
|
|
}
|
|
}
|
|
}
|
|
got = filteredGot
|
|
}
|
|
|
|
if diff := cmp.Diff(tc.want, got, opts...); diff != "" {
|
|
t.Errorf("Unexpected result: got %#v, want: %#v", got, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// RunMySQLListActiveQueriesTest run tests against the mysql-list-active-queries tests
|
|
func RunMySQLListActiveQueriesTest(t *testing.T, ctx context.Context, pool *sql.DB) {
|
|
type queryListDetails struct {
|
|
ProcessId any `json:"process_id"`
|
|
Query string `json:"query"`
|
|
TrxStarted any `json:"trx_started"`
|
|
TrxDuration any `json:"trx_duration_seconds"`
|
|
TrxWaitDuration any `json:"trx_wait_duration_seconds"`
|
|
QueryTime any `json:"query_time"`
|
|
TrxState string `json:"trx_state"`
|
|
ProcessState string `json:"process_state"`
|
|
User string `json:"user"`
|
|
TrxRowsLocked any `json:"trx_rows_locked"`
|
|
TrxRowsModified any `json:"trx_rows_modified"`
|
|
Db string `json:"db"`
|
|
}
|
|
|
|
singleQueryWanted := queryListDetails{
|
|
ProcessId: any(nil),
|
|
Query: "SELECT sleep(10)",
|
|
TrxStarted: any(nil),
|
|
TrxDuration: any(nil),
|
|
TrxWaitDuration: any(nil),
|
|
QueryTime: any(nil),
|
|
TrxState: "",
|
|
ProcessState: "User sleep",
|
|
User: "",
|
|
TrxRowsLocked: any(nil),
|
|
TrxRowsModified: any(nil),
|
|
Db: "",
|
|
}
|
|
|
|
invokeTcs := []struct {
|
|
name string
|
|
requestBody io.Reader
|
|
clientSleepSecs int
|
|
waitSecsBeforeCheck int
|
|
wantStatusCode int
|
|
want any
|
|
}{
|
|
{
|
|
name: "invoke list_active_queries when the system is idle",
|
|
requestBody: bytes.NewBufferString(`{}`),
|
|
clientSleepSecs: 0,
|
|
waitSecsBeforeCheck: 0,
|
|
wantStatusCode: http.StatusOK,
|
|
want: []queryListDetails(nil),
|
|
},
|
|
{
|
|
name: "invoke list_active_queries when there is 1 ongoing but lower than the threshold",
|
|
requestBody: bytes.NewBufferString(`{"min_duration_secs": 100}`),
|
|
clientSleepSecs: 10,
|
|
waitSecsBeforeCheck: 1,
|
|
wantStatusCode: http.StatusOK,
|
|
want: []queryListDetails(nil),
|
|
},
|
|
{
|
|
name: "invoke list_active_queries when 1 ongoing query should show up",
|
|
requestBody: bytes.NewBufferString(`{"min_duration_secs": 5}`),
|
|
clientSleepSecs: 0,
|
|
waitSecsBeforeCheck: 5,
|
|
wantStatusCode: http.StatusOK,
|
|
want: []queryListDetails{singleQueryWanted},
|
|
},
|
|
{
|
|
name: "invoke list_active_queries when 2 ongoing query should show up",
|
|
requestBody: bytes.NewBufferString(`{"min_duration_secs": 2}`),
|
|
clientSleepSecs: 10,
|
|
waitSecsBeforeCheck: 3,
|
|
wantStatusCode: http.StatusOK,
|
|
want: []queryListDetails{singleQueryWanted, singleQueryWanted},
|
|
},
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
if tc.clientSleepSecs > 0 {
|
|
wg.Add(1)
|
|
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
err := pool.PingContext(ctx)
|
|
if err != nil {
|
|
t.Errorf("unable to connect to test database: %s", err)
|
|
return
|
|
}
|
|
_, err = pool.ExecContext(ctx, fmt.Sprintf("SELECT sleep(%d);", tc.clientSleepSecs))
|
|
if err != nil {
|
|
t.Errorf("Executing 'SELECT sleep' failed: %s", err)
|
|
}
|
|
}()
|
|
}
|
|
|
|
if tc.waitSecsBeforeCheck > 0 {
|
|
time.Sleep(time.Duration(tc.waitSecsBeforeCheck) * time.Second)
|
|
}
|
|
|
|
const api = "http://127.0.0.1:5000/api/tool/list_active_queries/invoke"
|
|
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
|
if resp.StatusCode != tc.wantStatusCode {
|
|
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
|
|
}
|
|
if tc.wantStatusCode != http.StatusOK {
|
|
return
|
|
}
|
|
|
|
var bodyWrapper struct {
|
|
Result json.RawMessage `json:"result"`
|
|
}
|
|
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
|
|
t.Fatalf("error decoding response wrapper: %v", err)
|
|
}
|
|
|
|
var resultString string
|
|
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
|
resultString = string(bodyWrapper.Result)
|
|
}
|
|
|
|
var got any
|
|
var details []queryListDetails
|
|
if err := json.Unmarshal([]byte(resultString), &details); err != nil {
|
|
t.Fatalf("failed to unmarshal nested ObjectDetails string: %v", err)
|
|
}
|
|
got = details
|
|
|
|
if diff := cmp.Diff(tc.want, got, cmp.Comparer(func(a, b queryListDetails) bool {
|
|
return a.Query == b.Query && a.ProcessState == b.ProcessState
|
|
})); diff != "" {
|
|
t.Errorf("Unexpected result: got %#v, want: %#v", got, tc.want)
|
|
}
|
|
})
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func RunMySQLListTablesMissingUniqueIndexes(t *testing.T, ctx context.Context, pool *sql.DB, databaseName string) {
|
|
type listDetails struct {
|
|
TableSchema string `json:"table_schema"`
|
|
TableName string `json:"table_name"`
|
|
}
|
|
|
|
// bunch of wanted
|
|
nonUniqueKeyTableName := "t03_non_unqiue_key_table"
|
|
noKeyTableName := "t04_no_key_table"
|
|
nonUniqueKeyTableWant := listDetails{
|
|
TableSchema: databaseName,
|
|
TableName: nonUniqueKeyTableName,
|
|
}
|
|
noKeyTableWant := listDetails{
|
|
TableSchema: databaseName,
|
|
TableName: noKeyTableName,
|
|
}
|
|
|
|
invokeTcs := []struct {
|
|
name string
|
|
requestBody io.Reader
|
|
newTableName string
|
|
newTablePrimaryKey bool
|
|
newTableUniqueKey bool
|
|
newTableNonUniqueKey bool
|
|
wantStatusCode int
|
|
want any
|
|
}{
|
|
{
|
|
name: "invoke list_tables_missing_unique_indexes when nothing to be found",
|
|
requestBody: bytes.NewBufferString(`{}`),
|
|
newTableName: "",
|
|
newTablePrimaryKey: false,
|
|
newTableUniqueKey: false,
|
|
newTableNonUniqueKey: false,
|
|
wantStatusCode: http.StatusOK,
|
|
want: []listDetails(nil),
|
|
},
|
|
{
|
|
name: "invoke list_tables_missing_unique_indexes pk table will not show",
|
|
requestBody: bytes.NewBufferString(`{}`),
|
|
newTableName: "t01",
|
|
newTablePrimaryKey: true,
|
|
newTableUniqueKey: false,
|
|
newTableNonUniqueKey: false,
|
|
wantStatusCode: http.StatusOK,
|
|
want: []listDetails(nil),
|
|
},
|
|
{
|
|
name: "invoke list_tables_missing_unique_indexes uk table will not show",
|
|
requestBody: bytes.NewBufferString(`{}`),
|
|
newTableName: "t02",
|
|
newTablePrimaryKey: false,
|
|
newTableUniqueKey: true,
|
|
newTableNonUniqueKey: false,
|
|
wantStatusCode: http.StatusOK,
|
|
want: []listDetails(nil),
|
|
},
|
|
{
|
|
name: "invoke list_tables_missing_unique_indexes non-unique key only table will show",
|
|
requestBody: bytes.NewBufferString(`{}`),
|
|
newTableName: nonUniqueKeyTableName,
|
|
newTablePrimaryKey: false,
|
|
newTableUniqueKey: false,
|
|
newTableNonUniqueKey: true,
|
|
wantStatusCode: http.StatusOK,
|
|
want: []listDetails{nonUniqueKeyTableWant},
|
|
},
|
|
{
|
|
name: "invoke list_tables_missing_unique_indexes table with no key at all will show",
|
|
requestBody: bytes.NewBufferString(`{}`),
|
|
newTableName: noKeyTableName,
|
|
newTablePrimaryKey: false,
|
|
newTableUniqueKey: false,
|
|
newTableNonUniqueKey: false,
|
|
wantStatusCode: http.StatusOK,
|
|
want: []listDetails{nonUniqueKeyTableWant, noKeyTableWant},
|
|
},
|
|
{
|
|
name: "invoke list_tables_missing_unique_indexes table w/ both pk & uk will not show",
|
|
requestBody: bytes.NewBufferString(`{}`),
|
|
newTableName: "t05",
|
|
newTablePrimaryKey: true,
|
|
newTableUniqueKey: true,
|
|
newTableNonUniqueKey: false,
|
|
wantStatusCode: http.StatusOK,
|
|
want: []listDetails{nonUniqueKeyTableWant, noKeyTableWant},
|
|
},
|
|
{
|
|
name: "invoke list_tables_missing_unique_indexes table w/ uk & nk will not show",
|
|
requestBody: bytes.NewBufferString(`{}`),
|
|
newTableName: "t06",
|
|
newTablePrimaryKey: false,
|
|
newTableUniqueKey: true,
|
|
newTableNonUniqueKey: true,
|
|
wantStatusCode: http.StatusOK,
|
|
want: []listDetails{nonUniqueKeyTableWant, noKeyTableWant},
|
|
},
|
|
{
|
|
name: "invoke list_tables_missing_unique_indexes table w/ pk & nk will not show",
|
|
requestBody: bytes.NewBufferString(`{}`),
|
|
newTableName: "t07",
|
|
newTablePrimaryKey: true,
|
|
newTableUniqueKey: false,
|
|
newTableNonUniqueKey: true,
|
|
wantStatusCode: http.StatusOK,
|
|
want: []listDetails{nonUniqueKeyTableWant, noKeyTableWant},
|
|
},
|
|
{
|
|
name: "invoke list_tables_missing_unique_indexes with a non-exist database, nothing to show",
|
|
requestBody: bytes.NewBufferString(`{"table_schema": "non-exist-database"}`),
|
|
newTableName: "",
|
|
newTablePrimaryKey: false,
|
|
newTableUniqueKey: false,
|
|
newTableNonUniqueKey: false,
|
|
wantStatusCode: http.StatusOK,
|
|
want: []listDetails(nil),
|
|
},
|
|
{
|
|
name: "invoke list_tables_missing_unique_indexes with the right database, show everything",
|
|
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_schema": "%s"}`, databaseName)),
|
|
newTableName: "",
|
|
newTablePrimaryKey: false,
|
|
newTableUniqueKey: false,
|
|
newTableNonUniqueKey: false,
|
|
wantStatusCode: http.StatusOK,
|
|
want: []listDetails{nonUniqueKeyTableWant, noKeyTableWant},
|
|
},
|
|
{
|
|
name: "invoke list_tables_missing_unique_indexes with limited output",
|
|
requestBody: bytes.NewBufferString(`{"limit": 1}`),
|
|
newTableName: "",
|
|
newTablePrimaryKey: false,
|
|
newTableUniqueKey: false,
|
|
newTableNonUniqueKey: false,
|
|
wantStatusCode: http.StatusOK,
|
|
want: []listDetails{nonUniqueKeyTableWant},
|
|
},
|
|
}
|
|
|
|
createTableHelper := func(t *testing.T, tableName, databaseName string, primaryKey, uniqueKey, nonUniqueKey bool, ctx context.Context, pool *sql.DB) func() {
|
|
var stmt strings.Builder
|
|
stmt.WriteString(fmt.Sprintf("CREATE TABLE %s (", tableName))
|
|
stmt.WriteString("c1 INT")
|
|
if primaryKey {
|
|
stmt.WriteString(" PRIMARY KEY")
|
|
}
|
|
stmt.WriteString(", c2 INT, c3 CHAR(8)")
|
|
if uniqueKey {
|
|
stmt.WriteString(", UNIQUE(c2)")
|
|
}
|
|
if nonUniqueKey {
|
|
stmt.WriteString(", INDEX(c3)")
|
|
}
|
|
stmt.WriteString(")")
|
|
|
|
t.Logf("Creating table: %s", stmt.String())
|
|
if _, err := pool.ExecContext(ctx, stmt.String()); err != nil {
|
|
t.Fatalf("failed executing %s: %v", stmt.String(), err)
|
|
}
|
|
|
|
return func() {
|
|
t.Logf("Dropping table: %s", tableName)
|
|
if _, err := pool.ExecContext(ctx, fmt.Sprintf("DROP TABLE %s", tableName)); err != nil {
|
|
t.Errorf("failed to drop table %s: %v", tableName, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
var cleanups []func()
|
|
defer func() {
|
|
for i := len(cleanups) - 1; i >= 0; i-- {
|
|
cleanups[i]()
|
|
}
|
|
}()
|
|
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
if tc.newTableName != "" {
|
|
cleanup := createTableHelper(t, tc.newTableName, databaseName, tc.newTablePrimaryKey, tc.newTableUniqueKey, tc.newTableNonUniqueKey, ctx, pool)
|
|
cleanups = append(cleanups, cleanup)
|
|
}
|
|
|
|
const api = "http://127.0.0.1:5000/api/tool/list_tables_missing_unique_indexes/invoke"
|
|
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
|
if resp.StatusCode != tc.wantStatusCode {
|
|
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
|
|
}
|
|
if tc.wantStatusCode != http.StatusOK {
|
|
return
|
|
}
|
|
|
|
var bodyWrapper struct {
|
|
Result json.RawMessage `json:"result"`
|
|
}
|
|
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
|
|
t.Fatalf("error decoding response wrapper: %v", err)
|
|
}
|
|
|
|
var resultString string
|
|
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
|
resultString = string(bodyWrapper.Result)
|
|
}
|
|
|
|
var got any
|
|
var details []listDetails
|
|
if err := json.Unmarshal([]byte(resultString), &details); err != nil {
|
|
t.Fatalf("failed to unmarshal nested listDetails string: %v", err)
|
|
}
|
|
got = details
|
|
|
|
if diff := cmp.Diff(tc.want, got, cmp.Comparer(func(a, b listDetails) bool {
|
|
return a.TableSchema == b.TableSchema && a.TableName == b.TableName
|
|
})); diff != "" {
|
|
t.Errorf("Unexpected result: got %#v, want: %#v", got, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func RunMySQLListTableFragmentationTest(t *testing.T, databaseName, tableNameParam, tableNameAuth string) {
|
|
type tableFragmentationDetails struct {
|
|
TableSchema string `json:"table_schema"`
|
|
TableName string `json:"table_name"`
|
|
DataSize any `json:"data_size"`
|
|
IndexSize any `json:"index_size"`
|
|
DataFree any `json:"data_free"`
|
|
FragmentationPercentage any `json:"fragmentation_percentage"`
|
|
}
|
|
|
|
paramTableEntryWanted := tableFragmentationDetails{
|
|
TableSchema: databaseName,
|
|
TableName: tableNameParam,
|
|
DataSize: any(nil),
|
|
IndexSize: any(nil),
|
|
DataFree: any(nil),
|
|
FragmentationPercentage: any(nil),
|
|
}
|
|
authTableEntryWanted := tableFragmentationDetails{
|
|
TableSchema: databaseName,
|
|
TableName: tableNameAuth,
|
|
DataSize: any(nil),
|
|
IndexSize: any(nil),
|
|
DataFree: any(nil),
|
|
FragmentationPercentage: any(nil),
|
|
}
|
|
|
|
invokeTcs := []struct {
|
|
name string
|
|
requestBody io.Reader
|
|
wantStatusCode int
|
|
want any
|
|
}{
|
|
{
|
|
name: "invoke list_table_fragmentation on all, no data_free threshold, expected to have 2 results",
|
|
requestBody: bytes.NewBufferString(`{"data_free_threshold_bytes": 0}`),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []tableFragmentationDetails{authTableEntryWanted, paramTableEntryWanted},
|
|
},
|
|
{
|
|
name: "invoke list_table_fragmentation on all, no data_free threshold, limit to 1, expected to have 1 results",
|
|
requestBody: bytes.NewBufferString(`{"data_free_threshold_bytes": 0, "limit": 1}`),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []tableFragmentationDetails{authTableEntryWanted},
|
|
},
|
|
{
|
|
name: "invoke list_table_fragmentation on all databases and 1 specific table name, no data_free threshold, expected to have 1 result",
|
|
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_name": "%s","data_free_threshold_bytes": 0}`, tableNameAuth)),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []tableFragmentationDetails{authTableEntryWanted},
|
|
},
|
|
{
|
|
name: "invoke list_table_fragmentation on 1 database and 1 specific table name, no data_free threshold, expected to have 1 result",
|
|
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_schema": "%s", "table_name": "%s", "data_free_threshold_bytes": 0}`, databaseName, tableNameParam)),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []tableFragmentationDetails{paramTableEntryWanted},
|
|
},
|
|
{
|
|
name: "invoke list_table_fragmentation on 1 database and 1 specific table name, high data_free threshold, expected to have 0 result",
|
|
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_schema": "%s", "table_name": "%s", "data_free_threshold_bytes": 1000000000}`, databaseName, tableNameParam)),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []tableFragmentationDetails(nil),
|
|
},
|
|
{
|
|
name: "invoke list_table_fragmentation on 1 non-exist database, no data_free threshold, expected to have 0 result",
|
|
requestBody: bytes.NewBufferString(`{"table_schema": "non_existent_database", "data_free_threshold_bytes": 0}`),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []tableFragmentationDetails(nil),
|
|
},
|
|
{
|
|
name: "invoke list_table_fragmentation on 1 non-exist table, no data_free threshold, expected to have 0 result",
|
|
requestBody: bytes.NewBufferString(`{"table_name": "non_existent_table", "data_free_threshold_bytes": 0}`),
|
|
wantStatusCode: http.StatusOK,
|
|
want: []tableFragmentationDetails(nil),
|
|
},
|
|
}
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
const api = "http://127.0.0.1:5000/api/tool/list_table_fragmentation/invoke"
|
|
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
|
if resp.StatusCode != tc.wantStatusCode {
|
|
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
|
|
}
|
|
if tc.wantStatusCode != http.StatusOK {
|
|
return
|
|
}
|
|
|
|
var bodyWrapper struct {
|
|
Result json.RawMessage `json:"result"`
|
|
}
|
|
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
|
|
t.Fatalf("error decoding response wrapper: %v", err)
|
|
}
|
|
|
|
var resultString string
|
|
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
|
resultString = string(bodyWrapper.Result)
|
|
}
|
|
|
|
var got any
|
|
var details []tableFragmentationDetails
|
|
if err := json.Unmarshal([]byte(resultString), &details); err != nil {
|
|
t.Fatalf("failed to unmarshal outer JSON array into []tableInfo: %v", err)
|
|
}
|
|
got = details
|
|
|
|
if diff := cmp.Diff(tc.want, got, cmp.Comparer(func(a, b tableFragmentationDetails) bool {
|
|
return a.TableSchema == b.TableSchema && a.TableName == b.TableName
|
|
})); diff != "" {
|
|
t.Errorf("Unexpected result: got %#v, want: %#v", got, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func RunMySQLGetQueryPlanTest(t *testing.T, ctx context.Context, pool *sql.DB, databaseName, tableNameParam string) {
|
|
// Create a simple query to explain
|
|
query := fmt.Sprintf("SELECT * FROM %s", tableNameParam)
|
|
|
|
invokeTcs := []struct {
|
|
name string
|
|
requestBody io.Reader
|
|
wantStatusCode int
|
|
checkResult func(t *testing.T, result any)
|
|
}{
|
|
{
|
|
name: "invoke get_query_plan with valid query",
|
|
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"sql_statement": "%s"}`, query)),
|
|
wantStatusCode: http.StatusOK,
|
|
checkResult: func(t *testing.T, result any) {
|
|
resultMap, ok := result.(map[string]any)
|
|
if !ok {
|
|
t.Fatalf("result should be a map, got %T", result)
|
|
}
|
|
if _, ok := resultMap["query_block"]; !ok {
|
|
t.Errorf("result should contain 'query_block', got %v", resultMap)
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "invoke get_query_plan with invalid query",
|
|
requestBody: bytes.NewBufferString(`{"sql_statement": "SELECT * FROM non_existent_table"}`),
|
|
wantStatusCode: http.StatusBadRequest,
|
|
checkResult: nil,
|
|
},
|
|
}
|
|
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
const api = "http://127.0.0.1:5000/api/tool/get_query_plan/invoke"
|
|
resp, respBytes := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
|
if resp.StatusCode != tc.wantStatusCode {
|
|
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBytes))
|
|
}
|
|
if tc.wantStatusCode != http.StatusOK {
|
|
return
|
|
}
|
|
|
|
var bodyWrapper map[string]json.RawMessage
|
|
|
|
if err := json.Unmarshal(respBytes, &bodyWrapper); err != nil {
|
|
t.Fatalf("error parsing response wrapper: %s, body: %s", err, string(respBytes))
|
|
}
|
|
|
|
resultJSON, ok := bodyWrapper["result"]
|
|
if !ok {
|
|
t.Fatal("unable to find 'result' in response body")
|
|
}
|
|
|
|
var resultString string
|
|
if err := json.Unmarshal(resultJSON, &resultString); err != nil {
|
|
if string(resultJSON) == "null" {
|
|
resultString = "null"
|
|
} else {
|
|
t.Fatalf("'result' is not a JSON-encoded string: %s", err)
|
|
}
|
|
}
|
|
|
|
var got map[string]any
|
|
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
|
|
t.Fatalf("failed to unmarshal actual result string: %v", err)
|
|
}
|
|
|
|
if tc.checkResult != nil {
|
|
tc.checkResult(t, got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// RunMSSQLListTablesTest run tests againsts the mssql-list-tables tools.
|
|
func RunMSSQLListTablesTest(t *testing.T, tableNameParam, tableNameAuth string) {
|
|
// TableNameParam columns to construct want.
|
|
const paramTableColumns = `[
|
|
{"column_name": "id", "data_type": "INT", "column_ordinal_position": 1, "is_not_nullable": true},
|
|
{"column_name": "name", "data_type": "VARCHAR(255)", "column_ordinal_position": 2, "is_not_nullable": false}
|
|
]`
|
|
|
|
// TableNameAuth columns to construct want
|
|
const authTableColumns = `[
|
|
{"column_name": "id", "data_type": "INT", "column_ordinal_position": 1, "is_not_nullable": true},
|
|
{"column_name": "name", "data_type": "VARCHAR(255)", "column_ordinal_position": 2, "is_not_nullable": false},
|
|
{"column_name": "email", "data_type": "VARCHAR(255)", "column_ordinal_position": 3, "is_not_nullable": false}
|
|
]`
|
|
|
|
const (
|
|
// Template to construct detailed output want.
|
|
detailedObjectTemplate = `{
|
|
"schema_name": "dbo",
|
|
"object_name": "%[1]s",
|
|
"object_details": {
|
|
"owner": "dbo",
|
|
"triggers": [],
|
|
"columns": %[2]s,
|
|
"object_name": "%[1]s",
|
|
"object_type": "TABLE",
|
|
"schema_name": "dbo"
|
|
}
|
|
}`
|
|
|
|
// Template to construct simple output want
|
|
simpleObjectTemplate = `{"object_name":"%s", "schema_name":"dbo", "object_details":{"name":"%s"}}`
|
|
)
|
|
|
|
// Helper to build json for detailed want
|
|
getDetailedWant := func(tableName, columnJSON string) string {
|
|
return fmt.Sprintf(detailedObjectTemplate, tableName, columnJSON)
|
|
}
|
|
|
|
// Helper to build template for simple want
|
|
getSimpleWant := func(tableName string) string {
|
|
return fmt.Sprintf(simpleObjectTemplate, tableName, tableName)
|
|
}
|
|
|
|
invokeTcs := []struct {
|
|
name string
|
|
api string
|
|
requestBody string
|
|
wantStatusCode int
|
|
want string
|
|
isAllTables bool
|
|
}{
|
|
{
|
|
name: "invoke list_tables for all tables detailed output",
|
|
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
|
requestBody: `{"table_names": ""}`,
|
|
wantStatusCode: http.StatusOK,
|
|
want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)),
|
|
isAllTables: true,
|
|
},
|
|
{
|
|
name: "invoke list_tables for all tables simple output",
|
|
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
|
requestBody: `{"table_names": "", "output_format": "simple"}`,
|
|
wantStatusCode: http.StatusOK,
|
|
want: fmt.Sprintf("[%s,%s]", getSimpleWant(tableNameAuth), getSimpleWant(tableNameParam)),
|
|
isAllTables: true,
|
|
},
|
|
{
|
|
name: "invoke list_tables detailed output",
|
|
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
|
requestBody: fmt.Sprintf(`{"table_names": "%s"}`, tableNameAuth),
|
|
wantStatusCode: http.StatusOK,
|
|
want: fmt.Sprintf("[%s]", getDetailedWant(tableNameAuth, authTableColumns)),
|
|
},
|
|
{
|
|
name: "invoke list_tables simple output",
|
|
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
|
requestBody: fmt.Sprintf(`{"table_names": "%s", "output_format": "simple"}`, tableNameAuth),
|
|
wantStatusCode: http.StatusOK,
|
|
want: fmt.Sprintf("[%s]", getSimpleWant(tableNameAuth)),
|
|
},
|
|
{
|
|
name: "invoke list_tables with invalid output format",
|
|
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
|
requestBody: `{"table_names": "", "output_format": "abcd"}`,
|
|
wantStatusCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "invoke list_tables with malformed table_names parameter",
|
|
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
|
requestBody: `{"table_names": 12345, "output_format": "detailed"}`,
|
|
wantStatusCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "invoke list_tables with multiple table names",
|
|
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
|
requestBody: fmt.Sprintf(`{"table_names": "%s,%s"}`, tableNameParam, tableNameAuth),
|
|
wantStatusCode: http.StatusOK,
|
|
want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)),
|
|
},
|
|
{
|
|
name: "invoke list_tables with non-existent table",
|
|
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
|
requestBody: `{"table_names": "non_existent_table"}`,
|
|
wantStatusCode: http.StatusOK,
|
|
want: `[]`,
|
|
},
|
|
{
|
|
name: "invoke list_tables with one existing and one non-existent table",
|
|
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
|
requestBody: fmt.Sprintf(`{"table_names": "%s,non_existent_table"}`, tableNameParam),
|
|
wantStatusCode: http.StatusOK,
|
|
want: fmt.Sprintf("[%s]", getDetailedWant(tableNameParam, paramTableColumns)),
|
|
},
|
|
}
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
resp, respBytes := RunRequest(t, http.MethodPost, tc.api, bytes.NewBuffer([]byte(tc.requestBody)), nil)
|
|
|
|
if resp.StatusCode != tc.wantStatusCode {
|
|
t.Fatalf("response status code is not %d, got %d: %s", tc.wantStatusCode, resp.StatusCode, string(respBytes))
|
|
}
|
|
|
|
if tc.wantStatusCode == http.StatusOK {
|
|
var bodyWrapper map[string]json.RawMessage
|
|
|
|
if err := json.Unmarshal(respBytes, &bodyWrapper); err != nil {
|
|
t.Fatalf("error parsing response wrapper: %s, body: %s", err, string(respBytes))
|
|
}
|
|
|
|
resultJSON, ok := bodyWrapper["result"]
|
|
if !ok {
|
|
t.Fatal("unable to find 'result' in response body")
|
|
}
|
|
|
|
var resultString string
|
|
if err := json.Unmarshal(resultJSON, &resultString); err != nil {
|
|
if string(resultJSON) == "null" {
|
|
resultString = "null"
|
|
} else {
|
|
t.Fatalf("'result' is not a JSON-encoded string: %s", err)
|
|
}
|
|
}
|
|
|
|
var got, want []any
|
|
|
|
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
|
|
t.Fatalf("failed to unmarshal actual result string: %v", err)
|
|
}
|
|
if err := json.Unmarshal([]byte(tc.want), &want); err != nil {
|
|
t.Fatalf("failed to unmarshal expected want string: %v", err)
|
|
}
|
|
|
|
for _, item := range got {
|
|
itemMap, ok := item.(map[string]any)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
detailsStr, ok := itemMap["object_details"].(string)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
var detailsMap map[string]any
|
|
if err := json.Unmarshal([]byte(detailsStr), &detailsMap); err != nil {
|
|
t.Fatalf("failed to unmarshal nested object_details string: %v", err)
|
|
}
|
|
|
|
// clean unpredictable fields
|
|
delete(detailsMap, "constraints")
|
|
delete(detailsMap, "indexes")
|
|
|
|
itemMap["object_details"] = detailsMap
|
|
}
|
|
|
|
// Checking only the default dbo schema where the test tables are created to avoid brittle tests.
|
|
if tc.isAllTables {
|
|
var filteredGot []any
|
|
for _, item := range got {
|
|
if tableMap, ok := item.(map[string]interface{}); ok {
|
|
if schema, ok := tableMap["schema_name"]; ok && schema == "dbo" {
|
|
filteredGot = append(filteredGot, item)
|
|
}
|
|
}
|
|
}
|
|
got = filteredGot
|
|
}
|
|
|
|
sort.SliceStable(got, func(i, j int) bool {
|
|
return fmt.Sprintf("%v", got[i]) < fmt.Sprintf("%v", got[j])
|
|
})
|
|
sort.SliceStable(want, func(i, j int) bool {
|
|
return fmt.Sprintf("%v", want[i]) < fmt.Sprintf("%v", want[j])
|
|
})
|
|
|
|
if !reflect.DeepEqual(got, want) {
|
|
gotJSON, _ := json.MarshalIndent(got, "", " ")
|
|
wantJSON, _ := json.MarshalIndent(want, "", " ")
|
|
t.Errorf("Unexpected result:\ngot:\n%s\n\nwant:\n%s", string(gotJSON), string(wantJSON))
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// RunPostgresListLocksTest runs tests for the postgres list-locks tool
|
|
func RunPostgresListLocksTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
|
|
type lockDetails struct {
|
|
Pid any `json:"pid"`
|
|
Usename string `json:"usename"`
|
|
Database string `json:"database"`
|
|
RelName string `json:"relname"`
|
|
LockType string `json:"locktype"`
|
|
Mode string `json:"mode"`
|
|
Granted bool `json:"granted"`
|
|
FastPath bool `json:"fastpath"`
|
|
VirtualXid any `json:"virtualxid"`
|
|
TransactionId any `json:"transactionid"`
|
|
ClassId any `json:"classid"`
|
|
ObjId any `json:"objid"`
|
|
ObjSubId any `json:"objsubid"`
|
|
PageNumber any `json:"page"`
|
|
TupleNumber any `json:"tuple"`
|
|
VirtualBlock any `json:"virtualblock"`
|
|
BlockNumber any `json:"blockno"`
|
|
}
|
|
|
|
invokeTcs := []struct {
|
|
name string
|
|
requestBody io.Reader
|
|
wantStatusCode int
|
|
expectResults bool
|
|
}{
|
|
{
|
|
name: "invoke list_locks with no arguments",
|
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
|
wantStatusCode: http.StatusOK,
|
|
expectResults: false, // locks may or may not exist
|
|
},
|
|
}
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
const api = "http://127.0.0.1:5000/api/tool/list_locks/invoke"
|
|
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
|
if resp.StatusCode != tc.wantStatusCode {
|
|
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
|
|
}
|
|
if tc.wantStatusCode != http.StatusOK {
|
|
return
|
|
}
|
|
|
|
var bodyWrapper struct {
|
|
Result json.RawMessage `json:"result"`
|
|
}
|
|
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
|
|
t.Fatalf("error decoding response wrapper: %v", err)
|
|
}
|
|
|
|
var resultString string
|
|
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
|
resultString = string(bodyWrapper.Result)
|
|
}
|
|
|
|
var got []lockDetails
|
|
if resultString != "null" {
|
|
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
|
|
t.Fatalf("failed to unmarshal result: %v, result string: %s", err, resultString)
|
|
}
|
|
}
|
|
|
|
// Verify that if results exist, they have the expected structure
|
|
for _, lock := range got {
|
|
if lock.LockType == "" {
|
|
t.Errorf("lock type should not be empty")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// RunPostgresLongRunningTransactionsTest runs tests for the postgres long-running-transactions tool
|
|
func RunPostgresLongRunningTransactionsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
|
|
type transactionDetails struct {
|
|
Pid any `json:"pid"`
|
|
Usename string `json:"usename"`
|
|
Database string `json:"database"`
|
|
ApplicationName string `json:"application_name"`
|
|
XactStart any `json:"xact_start"`
|
|
XactDurationSecs any `json:"xact_duration_secs"`
|
|
IdleInTransaction string `json:"idle_in_transaction"`
|
|
Query string `json:"query"`
|
|
}
|
|
|
|
invokeTcs := []struct {
|
|
name string
|
|
requestBody io.Reader
|
|
wantStatusCode int
|
|
}{
|
|
{
|
|
name: "invoke long_running_transactions with default threshold",
|
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
|
wantStatusCode: http.StatusOK,
|
|
},
|
|
{
|
|
name: "invoke long_running_transactions with custom threshold",
|
|
requestBody: bytes.NewBuffer([]byte(`{"min_transaction_duration_secs": 3600}`)),
|
|
wantStatusCode: http.StatusOK,
|
|
},
|
|
}
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
const api = "http://127.0.0.1:5000/api/tool/long_running_transactions/invoke"
|
|
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
|
if resp.StatusCode != tc.wantStatusCode {
|
|
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
|
|
}
|
|
if tc.wantStatusCode != http.StatusOK {
|
|
return
|
|
}
|
|
|
|
var bodyWrapper struct {
|
|
Result json.RawMessage `json:"result"`
|
|
}
|
|
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
|
|
t.Fatalf("error decoding response wrapper: %v", err)
|
|
}
|
|
|
|
var resultString string
|
|
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
|
resultString = string(bodyWrapper.Result)
|
|
}
|
|
|
|
var got []transactionDetails
|
|
if resultString != "null" {
|
|
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
|
|
t.Fatalf("failed to unmarshal result: %v, result string: %s", err, resultString)
|
|
}
|
|
}
|
|
|
|
// Verify that if results exist, they have the expected structure
|
|
for _, tx := range got {
|
|
if tx.XactDurationSecs == nil {
|
|
t.Errorf("transaction duration should not be null for long-running transactions")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// RunPostgresReplicationStatsTest runs tests for the postgres replication-stats tool
|
|
func RunPostgresReplicationStatsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
|
|
type replicationStats struct {
|
|
ClientAddr string `json:"client_addr"`
|
|
Username string `json:"usename"`
|
|
ApplicationName string `json:"application_name"`
|
|
ClientHostname string `json:"client_hostname"`
|
|
BackendStart any `json:"backend_start"`
|
|
State string `json:"state"`
|
|
SyncState string `json:"sync_state"`
|
|
ReplyTime any `json:"reply_time"`
|
|
FlushLsn string `json:"flush_lsn"`
|
|
ReplayLsn string `json:"replay_lsn"`
|
|
WriteLag any `json:"write_lag"`
|
|
FlushLag any `json:"flush_lag"`
|
|
ReplayLag any `json:"replay_lag"`
|
|
SyncPriority any `json:"sync_priority"`
|
|
ReplicationSlotName any `json:"slot_name"`
|
|
IsStreaming bool `json:"is_streaming"`
|
|
}
|
|
|
|
invokeTcs := []struct {
|
|
name string
|
|
requestBody io.Reader
|
|
wantStatusCode int
|
|
}{
|
|
{
|
|
name: "invoke replication_stats with no arguments",
|
|
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
|
wantStatusCode: http.StatusOK,
|
|
},
|
|
}
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
const api = "http://127.0.0.1:5000/api/tool/replication_stats/invoke"
|
|
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
|
if resp.StatusCode != tc.wantStatusCode {
|
|
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
|
|
}
|
|
if tc.wantStatusCode != http.StatusOK {
|
|
return
|
|
}
|
|
|
|
var bodyWrapper struct {
|
|
Result json.RawMessage `json:"result"`
|
|
}
|
|
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
|
|
t.Fatalf("error decoding response wrapper: %v", err)
|
|
}
|
|
|
|
var resultString string
|
|
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
|
resultString = string(bodyWrapper.Result)
|
|
}
|
|
|
|
var got []replicationStats
|
|
if resultString != "null" {
|
|
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
|
|
t.Fatalf("failed to unmarshal result: %v, result string: %s", err, resultString)
|
|
}
|
|
}
|
|
|
|
// Verify that if results exist, they have the expected structure
|
|
for _, stat := range got {
|
|
if stat.State == "" {
|
|
t.Errorf("replication state should not be empty")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func RunPostgresGetColumnCardinalityTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
|
|
schemaName := "testschema_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
|
tableName := "table1_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
|
cleanup := setupPostgresSchemas(t, ctx, pool, schemaName)
|
|
defer cleanup()
|
|
|
|
// Create table with multiple columns
|
|
createTableStmt := fmt.Sprintf(`
|
|
CREATE TABLE %s.%s (
|
|
id SERIAL PRIMARY KEY,
|
|
email VARCHAR(100) UNIQUE,
|
|
name VARCHAR(50),
|
|
status VARCHAR(20),
|
|
created_at TIMESTAMP
|
|
)
|
|
`, schemaName, tableName)
|
|
|
|
if _, err := pool.Exec(ctx, createTableStmt); err != nil {
|
|
t.Fatalf("unable to create table: %s", err)
|
|
}
|
|
|
|
// Insert larger sample data to ensure statistics are collected
|
|
insertStmt := fmt.Sprintf(`
|
|
INSERT INTO %s.%s (email, name, status, created_at) VALUES
|
|
('user1@example.com', 'Alice', 'active', NOW()),
|
|
('user2@example.com', 'Bob', 'inactive', NOW()),
|
|
('user3@example.com', 'Charlie', 'active', NOW()),
|
|
('user4@example.com', 'David', 'active', NOW()),
|
|
('user5@example.com', 'Eve', 'inactive', NOW()),
|
|
('user6@example.com', 'Frank', 'active', NOW()),
|
|
('user7@example.com', 'Grace', 'inactive', NOW()),
|
|
('user8@example.com', 'Henry', 'active', NOW()),
|
|
('user9@example.com', 'Ivy', 'active', NOW()),
|
|
('user10@example.com', 'Jack', 'inactive', NOW())
|
|
`, schemaName, tableName)
|
|
|
|
if _, err := pool.Exec(ctx, insertStmt); err != nil {
|
|
t.Fatalf("unable to insert data: %s", err)
|
|
}
|
|
|
|
// Run ANALYZE to update statistics
|
|
analyzeStmt := fmt.Sprintf(`ANALYZE %s.%s`, schemaName, tableName)
|
|
if _, err := pool.Exec(ctx, analyzeStmt); err != nil {
|
|
t.Fatalf("unable to run ANALYZE: %s", err)
|
|
}
|
|
|
|
invokeTcs := []struct {
|
|
name string
|
|
requestBody io.Reader
|
|
wantStatusCode int
|
|
shouldHaveData bool // Whether we expect data in the response
|
|
}{
|
|
{
|
|
name: "get cardinality for a specific column",
|
|
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "%s", "table_name": "%s", "column_name": "email"}`, schemaName, tableName)),
|
|
wantStatusCode: http.StatusOK,
|
|
shouldHaveData: true,
|
|
},
|
|
{
|
|
name: "get cardinality for all columns",
|
|
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "%s", "table_name": "%s"}`, schemaName, tableName)),
|
|
wantStatusCode: http.StatusOK,
|
|
shouldHaveData: true,
|
|
},
|
|
{
|
|
name: "get cardinality with non-existent column",
|
|
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "%s", "table_name": "%s", "column_name": "non_existent"}`, schemaName, tableName)),
|
|
wantStatusCode: http.StatusOK,
|
|
shouldHaveData: false,
|
|
},
|
|
{
|
|
name: "get cardinality with non-existent schema",
|
|
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "non_existent_schema", "table_name": "%s"}`, tableName)),
|
|
wantStatusCode: http.StatusOK,
|
|
shouldHaveData: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
const api = "http://127.0.0.1:5000/api/tool/get_column_cardinality/invoke"
|
|
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
|
if resp.StatusCode != tc.wantStatusCode {
|
|
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
|
|
}
|
|
if tc.wantStatusCode != http.StatusOK {
|
|
return
|
|
}
|
|
|
|
var bodyWrapper struct {
|
|
Result json.RawMessage `json:"result"`
|
|
}
|
|
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
|
|
t.Fatalf("error decoding response wrapper: %v", err)
|
|
}
|
|
|
|
var resultString string
|
|
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
|
resultString = string(bodyWrapper.Result)
|
|
}
|
|
|
|
var got []map[string]any
|
|
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
|
|
t.Fatalf("failed to unmarshal nested result string: %v", err)
|
|
}
|
|
|
|
// Verify that we got the expected data presence
|
|
if tc.shouldHaveData {
|
|
if len(got) == 0 {
|
|
t.Logf("warning: expected data but got empty result. This can happen if pg_stats is not populated yet.")
|
|
return
|
|
}
|
|
|
|
// Verify column names and cardinality values
|
|
for _, row := range got {
|
|
columnName, ok := row["column_name"].(string)
|
|
if !ok {
|
|
t.Fatalf("column_name is not a string: %v", row["column_name"])
|
|
}
|
|
|
|
// Check that estimated_cardinality is present and is a number
|
|
cardinality, ok := row["estimated_cardinality"]
|
|
if !ok {
|
|
t.Fatalf("estimated_cardinality is missing for column %s", columnName)
|
|
}
|
|
|
|
// Convert to float64 for numeric checks
|
|
cardinalityFloat, ok := cardinality.(float64)
|
|
if !ok {
|
|
t.Fatalf("estimated_cardinality is not a number: %v", cardinality)
|
|
}
|
|
|
|
// Cardinality should be >= 0
|
|
if cardinalityFloat < 0 {
|
|
t.Errorf("cardinality for column %s is negative: %v", columnName, cardinalityFloat)
|
|
}
|
|
}
|
|
} else {
|
|
if len(got) != 0 {
|
|
t.Errorf("expected no data but got: %v", got)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func createPostgresExtension(t *testing.T, ctx context.Context, pool *pgxpool.Pool, extensionName string) func() {
|
|
createExtensionCmd := fmt.Sprintf("CREATE EXTENSION IF NOT EXISTS %s", extensionName)
|
|
_, err := pool.Exec(ctx, createExtensionCmd)
|
|
if err != nil {
|
|
t.Fatalf("failed to create extension: %v", err)
|
|
}
|
|
return func() {
|
|
dropExtensionCmd := fmt.Sprintf("DROP EXTENSION IF EXISTS %s", extensionName)
|
|
_, err := pool.Exec(ctx, dropExtensionCmd)
|
|
if err != nil {
|
|
t.Fatalf("failed to drop extension: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func RunPostgresListQueryStatsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
|
|
// Insert a simple query by running a SELECT statement
|
|
// This will record statistics in pg_stat_statements
|
|
selectStmt := "SELECT 1 as test_query"
|
|
if _, err := pool.Exec(ctx, selectStmt); err != nil {
|
|
t.Logf("warning: unable to execute test query: %s", err)
|
|
}
|
|
|
|
dropExtensionFunc := createPostgresExtension(t, ctx, pool, "pg_stat_statements")
|
|
defer dropExtensionFunc()
|
|
|
|
type queryStatDetails struct {
|
|
Datname string `json:"datname"`
|
|
Query string `json:"query"`
|
|
Calls any `json:"calls"`
|
|
TotalExecTime any `json:"total_exec_time"`
|
|
MinExecTime any `json:"min_exec_time"`
|
|
MaxExecTime any `json:"max_exec_time"`
|
|
MeanExecTime any `json:"mean_exec_time"`
|
|
Rows any `json:"rows"`
|
|
SharedBlksHit any `json:"shared_blks_hit"`
|
|
SharedBlksRead any `json:"shared_blks_read"`
|
|
}
|
|
|
|
invokeTcs := []struct {
|
|
name string
|
|
requestBody io.Reader
|
|
wantStatusCode int
|
|
}{
|
|
{
|
|
name: "list query stats with default limit",
|
|
requestBody: bytes.NewBufferString(`{}`),
|
|
wantStatusCode: http.StatusOK,
|
|
},
|
|
{
|
|
name: "list query stats with custom limit",
|
|
requestBody: bytes.NewBufferString(`{"limit": 10}`),
|
|
wantStatusCode: http.StatusOK,
|
|
},
|
|
{
|
|
name: "list query stats for specific database",
|
|
requestBody: bytes.NewBufferString(`{"database_name": "postgres"}`),
|
|
wantStatusCode: http.StatusOK,
|
|
},
|
|
{
|
|
name: "list query stats with non-existent database name",
|
|
requestBody: bytes.NewBufferString(`{"database_name": "non_existent_db_xyz"}`),
|
|
wantStatusCode: http.StatusOK,
|
|
},
|
|
}
|
|
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
const api = "http://127.0.0.1:5000/api/tool/list_query_stats/invoke"
|
|
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
|
if resp.StatusCode != tc.wantStatusCode {
|
|
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
|
|
}
|
|
if tc.wantStatusCode != http.StatusOK {
|
|
return
|
|
}
|
|
|
|
var bodyWrapper struct {
|
|
Result json.RawMessage `json:"result"`
|
|
}
|
|
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
|
|
t.Fatalf("error decoding response wrapper: %v", err)
|
|
}
|
|
|
|
var resultString string
|
|
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
|
resultString = string(bodyWrapper.Result)
|
|
}
|
|
|
|
var got []map[string]any
|
|
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
|
|
t.Fatalf("failed to unmarshal nested result string: %v, resultString: %s", err, resultString)
|
|
}
|
|
|
|
// For databases with pg_stat_statements available, verify response structure
|
|
if len(got) > 0 {
|
|
// Verify the response has the expected fields
|
|
requiredFields := []string{"datname", "query", "calls", "total_exec_time", "min_exec_time", "max_exec_time", "mean_exec_time", "rows", "shared_blks_hit", "shared_blks_read"}
|
|
for _, field := range requiredFields {
|
|
if _, ok := got[0][field]; !ok {
|
|
t.Errorf("missing expected field: %s in result: %v", field, got[0])
|
|
}
|
|
}
|
|
|
|
// Verify data types
|
|
var stat queryStatDetails
|
|
statData, _ := json.Marshal(got[0])
|
|
if err := json.Unmarshal(statData, &stat); err != nil {
|
|
t.Logf("warning: unable to unmarshal query stat: %v", err)
|
|
}
|
|
|
|
// Verify that results are ordered by total_exec_time (descending)
|
|
if len(got) > 1 {
|
|
for i := 0; i < len(got)-1; i++ {
|
|
currentTime, ok1 := got[i]["total_exec_time"].(float64)
|
|
nextTime, ok2 := got[i+1]["total_exec_time"].(float64)
|
|
if ok1 && ok2 && currentTime < nextTime {
|
|
t.Logf("warning: results may not be ordered by total_exec_time descending: %f vs %f", currentTime, nextTime)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// RunPostgresListTableStatsTest runs tests for the postgres list-table-stats tool
|
|
func RunPostgresListTableStatsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
|
|
type tableStatsDetails struct {
|
|
SchemaName string `json:"schema_name"`
|
|
TableName string `json:"table_name"`
|
|
Owner string `json:"owner"`
|
|
TotalSizeBytes any `json:"total_size_bytes"`
|
|
SeqScan any `json:"seq_scan"`
|
|
IdxScan any `json:"idx_scan"`
|
|
IdxScanRatioPercent float64 `json:"idx_scan_ratio_percent"`
|
|
LiveRows any `json:"live_rows"`
|
|
DeadRows any `json:"dead_rows"`
|
|
DeadRowRatioPercent float64 `json:"dead_row_ratio_percent"`
|
|
NTupIns any `json:"n_tup_ins"`
|
|
NTupUpd any `json:"n_tup_upd"`
|
|
NTupDel any `json:"n_tup_del"`
|
|
LastVacuum any `json:"last_vacuum"`
|
|
LastAutovacuum any `json:"last_autovacuum"`
|
|
LastAutoanalyze any `json:"last_autoanalyze"`
|
|
}
|
|
|
|
// Create a test table to generate statistics
|
|
testTableName := "test_list_table_stats_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
|
createTableStmt := fmt.Sprintf(`
|
|
CREATE TABLE %s (
|
|
id SERIAL PRIMARY KEY,
|
|
name VARCHAR(100),
|
|
email VARCHAR(100)
|
|
)
|
|
`, testTableName)
|
|
|
|
if _, err := pool.Exec(ctx, createTableStmt); err != nil {
|
|
t.Fatalf("unable to create test table: %s", err)
|
|
}
|
|
defer func() {
|
|
dropTableStmt := fmt.Sprintf("DROP TABLE IF EXISTS %s", testTableName)
|
|
if _, err := pool.Exec(ctx, dropTableStmt); err != nil {
|
|
t.Logf("warning: unable to drop test table: %v", err)
|
|
}
|
|
}()
|
|
|
|
// Insert some data to generate statistics
|
|
insertStmt := fmt.Sprintf(`
|
|
INSERT INTO %s (name, email) VALUES
|
|
('Alice', 'alice@example.com'),
|
|
('Bob', 'bob@example.com'),
|
|
('Charlie', 'charlie@example.com'),
|
|
('David', 'david@example.com'),
|
|
('Eve', 'eve@example.com')
|
|
`, testTableName)
|
|
|
|
if _, err := pool.Exec(ctx, insertStmt); err != nil {
|
|
t.Fatalf("unable to insert test data: %s", err)
|
|
}
|
|
|
|
// Run some sequential scans to generate statistics
|
|
for i := 0; i < 3; i++ {
|
|
selectStmt := fmt.Sprintf("SELECT * FROM %s WHERE name = 'Alice'", testTableName)
|
|
if _, err := pool.Exec(ctx, selectStmt); err != nil {
|
|
t.Logf("warning: unable to execute select: %v", err)
|
|
}
|
|
}
|
|
|
|
// Run ANALYZE to update statistics
|
|
analyzeStmt := fmt.Sprintf("ANALYZE %s", testTableName)
|
|
if _, err := pool.Exec(ctx, analyzeStmt); err != nil {
|
|
t.Logf("warning: unable to run ANALYZE: %v", err)
|
|
}
|
|
|
|
invokeTcs := []struct {
|
|
name string
|
|
requestBody io.Reader
|
|
wantStatusCode int
|
|
shouldHaveData bool
|
|
filterTable bool
|
|
}{
|
|
{
|
|
name: "list table stats with no arguments (default limit)",
|
|
requestBody: bytes.NewBufferString(`{}`),
|
|
wantStatusCode: http.StatusOK,
|
|
shouldHaveData: false, // may or may not have data depending on what's in the database
|
|
},
|
|
{
|
|
name: "list table stats with default limit",
|
|
requestBody: bytes.NewBufferString(`{"schema_name": "public"}`),
|
|
wantStatusCode: http.StatusOK,
|
|
shouldHaveData: false,
|
|
},
|
|
{
|
|
name: "list table stats filtering by specific table",
|
|
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_name": "%s"}`, testTableName)),
|
|
wantStatusCode: http.StatusOK,
|
|
shouldHaveData: true,
|
|
filterTable: true,
|
|
},
|
|
{
|
|
name: "list table stats with custom limit",
|
|
requestBody: bytes.NewBufferString(`{"limit": 10}`),
|
|
wantStatusCode: http.StatusOK,
|
|
shouldHaveData: false,
|
|
},
|
|
{
|
|
name: "list table stats sorted by size",
|
|
requestBody: bytes.NewBufferString(`{"sort_by": "size", "limit": 5}`),
|
|
wantStatusCode: http.StatusOK,
|
|
shouldHaveData: false,
|
|
},
|
|
{
|
|
name: "list table stats sorted by seq_scan",
|
|
requestBody: bytes.NewBufferString(`{"sort_by": "seq_scan", "limit": 5}`),
|
|
wantStatusCode: http.StatusOK,
|
|
shouldHaveData: false,
|
|
},
|
|
{
|
|
name: "list table stats sorted by idx_scan",
|
|
requestBody: bytes.NewBufferString(`{"sort_by": "idx_scan", "limit": 5}`),
|
|
wantStatusCode: http.StatusOK,
|
|
shouldHaveData: false,
|
|
},
|
|
{
|
|
name: "list table stats sorted by dead_rows",
|
|
requestBody: bytes.NewBufferString(`{"sort_by": "dead_rows", "limit": 5}`),
|
|
wantStatusCode: http.StatusOK,
|
|
shouldHaveData: false,
|
|
},
|
|
{
|
|
name: "list table stats with non-existent table filter",
|
|
requestBody: bytes.NewBufferString(`{"table_name": "non_existent_table_xyz"}`),
|
|
wantStatusCode: http.StatusOK,
|
|
shouldHaveData: false,
|
|
},
|
|
{
|
|
name: "list table stats with non-existent schema filter",
|
|
requestBody: bytes.NewBufferString(`{"schema_name": "non_existent_schema_xyz"}`),
|
|
wantStatusCode: http.StatusOK,
|
|
shouldHaveData: false,
|
|
},
|
|
{
|
|
name: "list table stats with owner filter",
|
|
requestBody: bytes.NewBufferString(`{"owner": "postgres"}`),
|
|
wantStatusCode: http.StatusOK,
|
|
shouldHaveData: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range invokeTcs {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
const api = "http://127.0.0.1:5000/api/tool/list_table_stats/invoke"
|
|
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
|
if resp.StatusCode != tc.wantStatusCode {
|
|
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
|
|
}
|
|
if tc.wantStatusCode != http.StatusOK {
|
|
return
|
|
}
|
|
|
|
var bodyWrapper struct {
|
|
Result json.RawMessage `json:"result"`
|
|
}
|
|
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
|
|
t.Fatalf("error decoding response wrapper: %v", err)
|
|
}
|
|
|
|
var resultString string
|
|
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
|
resultString = string(bodyWrapper.Result)
|
|
}
|
|
|
|
var got []tableStatsDetails
|
|
if resultString != "null" {
|
|
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
|
|
t.Fatalf("failed to unmarshal result: %v, result string: %s", err, resultString)
|
|
}
|
|
}
|
|
|
|
// Verify expected data presence
|
|
if tc.shouldHaveData {
|
|
if len(got) == 0 {
|
|
t.Fatalf("expected data but got empty result")
|
|
}
|
|
|
|
// Verify the test table is in results
|
|
found := false
|
|
for _, row := range got {
|
|
if row.TableName == testTableName {
|
|
found = true
|
|
// Verify expected fields are present
|
|
if row.SchemaName == "" {
|
|
t.Errorf("schema_name should not be empty")
|
|
}
|
|
if row.Owner == "" {
|
|
t.Errorf("owner should not be empty")
|
|
}
|
|
if row.TotalSizeBytes == nil {
|
|
t.Errorf("total_size_bytes should not be null")
|
|
}
|
|
if row.LiveRows == nil {
|
|
t.Errorf("live_rows should not be null")
|
|
}
|
|
break
|
|
}
|
|
}
|
|
|
|
if !found {
|
|
t.Errorf("test table %s not found in results", testTableName)
|
|
}
|
|
} else if tc.filterTable {
|
|
// For filtered queries that shouldn't find anything
|
|
if len(got) != 0 {
|
|
t.Logf("warning: expected no data but got: %v", len(got))
|
|
}
|
|
}
|
|
|
|
// Verify result structure and data types
|
|
for _, stat := range got {
|
|
// Verify schema_name and table_name are strings
|
|
if stat.SchemaName == "" && stat.TableName != "" {
|
|
t.Errorf("schema_name is empty for table %s", stat.TableName)
|
|
}
|
|
|
|
// Verify numeric fields are valid
|
|
if stat.IdxScanRatioPercent < 0 || stat.IdxScanRatioPercent > 100 {
|
|
t.Errorf("idx_scan_ratio_percent should be between 0 and 100, got %f", stat.IdxScanRatioPercent)
|
|
}
|
|
|
|
if stat.DeadRowRatioPercent < 0 || stat.DeadRowRatioPercent > 100 {
|
|
t.Errorf("dead_row_ratio_percent should be between 0 and 100, got %f", stat.DeadRowRatioPercent)
|
|
}
|
|
}
|
|
|
|
// Verify sorting for specific sort_by options
|
|
if tc.name == "list table stats sorted by size" && len(got) > 1 {
|
|
for i := 0; i < len(got)-1; i++ {
|
|
current, ok1 := got[i].TotalSizeBytes.(float64)
|
|
next, ok2 := got[i+1].TotalSizeBytes.(float64)
|
|
if ok1 && ok2 && current < next {
|
|
t.Logf("warning: results may not be sorted by total_size_bytes descending")
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// RunRequest sends a single HTTP request with http.DefaultClient and returns
// the response together with its fully read body.
//
// The Content-Type header is set to application/json first; entries in headers
// are applied afterwards and therefore override it when they use the same key.
// Any failure to build the request, send it, or read the response body fails
// the calling test via t.Fatalf.
//
// The response body has already been read and closed by the time this function
// returns, so callers must use the returned byte slice rather than resp.Body.
func RunRequest(t *testing.T, method, url string, body io.Reader, headers map[string]string) (*http.Response, []byte) {
	// Attribute failures to the calling test line rather than to this helper.
	t.Helper()

	req, err := http.NewRequest(method, url, body)
	if err != nil {
		t.Fatalf("unable to create request: %s", err)
	}

	req.Header.Set("Content-type", "application/json")
	for k, v := range headers {
		req.Header.Set(k, v)
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("unable to send request: %s", err)
	}
	// Register the close before reading: t.Fatalf exits through
	// runtime.Goexit, which runs deferred calls, so the body is released even
	// when the read below fails.
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("unable to read response body: %s", err)
	}

	return resp, respBody
}
|